Add language modeling task

Sengxian, 3 years ago
parent
commit 113f5f1364

+ 6 - 2
evaluation/__init__.py

@@ -1,7 +1,11 @@
 from .configs import *
 from .model import ModelForEvaluation
-from .tasks import BaseTask, GenerationTask, MultiChoiceTask
+from .tasks import BaseTask, GenerationTask, MultiChoiceTask, LanguageModelTask
 from .metrics import qa_evaluate
 from .utils import print_rank_0
 
-DEFAULT_CLASS = {TaskType.GENERATION: GenerationTask, TaskType.MULTICHOICE: MultiChoiceTask}
+DEFAULT_CLASS = {
+    TaskType.GENERATION: GenerationTask,
+    TaskType.MULTICHOICE: MultiChoiceTask,
+    TaskType.LANGUAGE_MODEL: LanguageModelTask,
+}
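
Note: task dispatch goes through this registry. A minimal sketch, assuming the evaluation entry point parses the YAML `type:` field into a `TaskType` and that task constructors take the model and config (both assumptions; the driver code is not part of this diff):

    # Hypothetical dispatch sketch -- the real driver lives outside this diff.
    task_class = DEFAULT_CLASS[TaskType.LANGUAGE_MODEL]  # -> LanguageModelTask
    task = task_class(model, config)                     # constructor signature is assumed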

+ 9 - 0
evaluation/configs.py

@@ -8,6 +8,7 @@ from typing import Optional, List, Dict
 class TaskType(Enum):
     MULTICHOICE = "mul"
     GENERATION = "gen"
+    LANGUAGE_MODEL = "lm"
     OTHER = "other"
 
 
@@ -51,3 +52,11 @@ class GenerationTaskConfig(BaseConfig):
 
     def __post_init__(self):
         assert self.micro_batch_size == 1, "Only support micro batch size = 1 for generation task"
+
+
+@dataclass
+class LanguageModelTaskConfig(BaseConfig):
+    module = "evaluation.LanguageModelTask"
+    metrics: List[str] = field(default_factory=lambda: ["PPL"])
+
+    generation_length: int = 256  # number of tokens scored in each sliding window
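
Note: the two knobs interact as a sliding window. Each window covers `max_seq_length - 1` tokens and advances by `generation_length`. A worked sketch with assumed sizes, matching the `num_sequences` formula added to `dataset.py` below:

    import math

    max_seq_length = 2048      # assumed per-task value from BaseConfig
    generation_length = 256    # default above
    doc_tokens = 10_000        # hypothetical document length in tokens

    extra = max(doc_tokens - (max_seq_length - 1), 0)  # tokens past the first window
    num_sequences = max(math.ceil(extra / generation_length) + 1, 1)
    print(num_sequences)       # 33: the first full window plus 32 shifted windows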

+ 72 - 10
evaluation/dataset.py

@@ -1,15 +1,19 @@
 import os
+import math
 import json
 
 import numpy as np
 import torch
 
+from typing import List, Union
 from abc import ABC, abstractmethod
 from scipy.linalg import block_diag
+from itertools import accumulate
+from bisect import bisect_right
 
 from SwissArmyTransformer import get_tokenizer
 
-from .configs import BaseConfig, MultiChoiceTaskConfig, GenerationTaskConfig
+from .configs import BaseConfig, MultiChoiceTaskConfig, GenerationTaskConfig, LanguageModelTaskConfig
 from .utils import get_tokenized_input
 
 
@@ -35,21 +39,19 @@ class EvaluationDataset(torch.utils.data.Dataset, ABC):
     If [MASK] not in context, will append [MASK] after text
     """
 
-    def __init__(self, path, config: BaseConfig):
-        self.path = path
+    def __init__(self, path: Union[str, List[str]], config: BaseConfig):
+        self.path = path if isinstance(path, list) else [path]
         self.config = config
         self.max_seq_length = self.config.max_seq_length
         self.dtype = np.int64
 
-        tokenizer = get_tokenizer(tokenizer_type="icetk-glm-130B")
-        self.mask_id = tokenizer.get_command("[MASK]")
-        self.gmask_id = tokenizer.get_command("[gMASK]")
+        self.tokenizer = get_tokenizer()
+        self.mask_id = self.tokenizer.get_command("[MASK]")
+        self.gmask_id = self.tokenizer.get_command("[gMASK]")
 
         self.data = []
-        with open(os.path.join(path), "r", encoding="utf-8") as file:
-            for line in file:
-                item = json.loads(line)
-                self.data.append(self.process_single_item(item))
+        for p in self.path:
+            self.process_single_file(p)
 
     @property
     def has_collate_fn(self) -> bool:
@@ -58,6 +60,12 @@ class EvaluationDataset(torch.utils.data.Dataset, ABC):
     def collate_fn(self, samples):
         return None
 
+    def process_single_file(self, path):
+        with open(os.path.join(path), "r", encoding="utf-8") as file:
+            for line in file:
+                item = json.loads(line)
+                self.data.append(self.process_single_item(item))
+
     @abstractmethod
     def process_single_item(self, item) -> dict:
         pass
@@ -257,3 +265,57 @@ class MultiChoiceTaskDataset(EvaluationDataset):
         )
         sample["label"] = item["label"]
         return sample
+
+
+class LanguageModelTaskDataset(EvaluationDataset):
+    config: LanguageModelTaskConfig
+
+    def process_single_file(self, path):
+        with open(os.path.join(path), "r", encoding="utf-8") as file:
+            raw_text = file.read()
+            tokens = self.tokenizer.tokenize(raw_text)
+            self.data.append(
+                {
+                    "raw_text": tokens,
+                    "num_original_tokens": len(raw_text.strip().split(" ")),
+                    "num_sequences": max(
+                        math.ceil(
+                            max(len(tokens) - (self.config.max_seq_length - 1), 0) / self.config.generation_length
+                        )
+                        + 1,
+                        1,
+                    ),
+                }
+            )
+
+    def process_single_item(self, item):
+        pass
+
+    def __len__(self):
+        return self.data[0]["num_sequences"]
+
+    def __getitem__(self, idx):
+        start_idx = idx * self.config.generation_length
+        end_idx = start_idx + self.config.max_seq_length - 1  # for additional [gMASK]
+        tokens = self.data[0]["raw_text"][start_idx:end_idx]
+
+        mask_id = self.gmask_id if self.config.use_task_mask else self.mask_id
+        sop_id = self.tokenizer.get_command("sop")
+
+        if idx == 0 or self.config.unidirectional:
+            prompt, text = tokens[:1], tokens[1:]
+        else:
+            prompt_length = self.config.max_seq_length - 1 - self.config.generation_length
+            prompt, text = tokens[:prompt_length], tokens[prompt_length:]
+
+        seq_length = len(prompt) + len(text) + 1
+        attention_mask = np.tril(np.ones((seq_length, seq_length), dtype=np.int64))
+        attention_mask[: len(prompt) + 1, : len(prompt) + 1] = 1
+
+        return {
+            "tokens": np.array(prompt + [mask_id, sop_id] + text[:-1], dtype=np.int64),
+            "targets": np.array(prompt + [mask_id] + text, dtype=np.int64),
+            "position_ids": np.arange(0, seq_length, dtype=np.int64),
+            "attention_mask": attention_mask < 0.5,
+            "loss_masks": np.array([0] * (len(prompt) + 1) + [1] * len(text), dtype=np.int64),
+        }
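
Note: to make the windowing concrete, each `__getitem__` call slices a window that shares all but the last `generation_length` tokens with its predecessor, scores only that trailing span (so no token is counted twice), and builds a mask that is bidirectional over the prompt plus the mask slot and causal over the scored text. A standalone sketch with toy sizes (assumptions, not the real config):

    import numpy as np

    max_seq_length, generation_length = 9, 3  # toy sizes for illustration
    tokens = list(range(20))                  # stand-in token ids

    idx = 2                                   # third window
    start = idx * generation_length           # 6
    window = tokens[start : start + max_seq_length - 1]
    prompt_length = max_seq_length - 1 - generation_length
    prompt, text = window[:prompt_length], window[prompt_length:]  # 5 + 3 tokens

    seq_length = len(prompt) + len(text) + 1  # +1 for the [MASK]/[gMASK] slot
    attention_mask = np.tril(np.ones((seq_length, seq_length), dtype=np.int64))
    attention_mask[: len(prompt) + 1, : len(prompt) + 1] = 1  # bidirectional prefix

    # Only the trailing text span contributes to the loss.
    loss_masks = np.array([0] * (len(prompt) + 1) + [1] * len(text), dtype=np.int64)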

+ 11 - 2
evaluation/metrics.py

@@ -1,7 +1,11 @@
-import string
 import re
+import math
+import string
 import functools
 
+import numpy as np
+
+from typing import Tuple, List
 from collections import Counter
 
 from SwissArmyTransformer import get_tokenizer
@@ -79,4 +83,9 @@ def qa_evaluate(predictions, examples, metric):
 qa_exact_match = functools.partial(qa_evaluate, metric=exact_match_score)
 qa_f1 = functools.partial(qa_evaluate, metric=f1_score)
 
-DEFAULT_METRICS = {"EM": qa_exact_match, "F1": qa_f1, "Accuracy": accuracy_metric}
+
+def calculate_perplexity(loss: List[float], data):
+    return math.exp(min(20, np.sum(loss) / data[0]["num_original_tokens"]))
+
+
+DEFAULT_METRICS = {"EM": qa_exact_match, "F1": qa_f1, "Accuracy": accuracy_metric, "PPL": calculate_perplexity}
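
Note: `calculate_perplexity` follows the word-level convention for PTB/WikiText. The summed token-level negative log-likelihood is normalized by the whitespace-split word count of the raw file (`num_original_tokens`), not by the subword count, and the exponent is clipped at 20 to avoid overflow. With hypothetical numbers:

    import math

    total_loss = 13_074.0        # assumed sum of per-token losses across all windows
    num_original_tokens = 4_358  # whitespace-split words in the raw test file
    ppl = math.exp(min(20, total_loss / num_original_tokens))  # exp(3.0) ~ 20.1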

+ 22 - 0
evaluation/model.py

@@ -3,6 +3,7 @@ import torch
 from typing import List, Union
 
 from SwissArmyTransformer.generation.autoregressive_sampling import filling_sequence
+from SwissArmyTransformer.mpu import vocab_parallel_cross_entropy
 
 
 class ModelForEvaluation(torch.nn.Module):
@@ -86,3 +87,24 @@ class ModelForEvaluation(torch.nn.Module):
             output_targets.append(line)
 
         return output_targets if return_all_beams else output_targets[0]
+
+    def calculate_loss(self, batch) -> List[float]:
+        tokens, position_ids, attention_mask = self.process_data(batch)
+        targets, loss_masks = (
+            batch["targets"].to(device=torch.cuda.current_device()).long(),
+            batch["loss_masks"].to(device=torch.cuda.current_device()).long(),
+        )
+
+        original_parallel_output = self.model.transformer.parallel_output
+        self.model.transformer.parallel_output = True
+        self.model.eval()
+
+        with torch.no_grad():
+            logits, *output_per_layers = self.model(tokens, position_ids, attention_mask, log_attention_weights=None)
+            losses = vocab_parallel_cross_entropy(logits.contiguous().float(), targets)
+            loss = torch.sum(losses * loss_masks, dim=-1)
+
+        self.model.transformer.parallel_output = original_parallel_output
+
+        # return list(zip(loss.tolist(), loss_masks.sum(dim=-1).tolist()))
+        return loss.tolist()
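
Note: `vocab_parallel_cross_entropy` computes token-level cross-entropy over a vocabulary sharded across model-parallel ranks (hence the `parallel_output = True` toggle). For intuition, a single-GPU equivalent sketch in plain PyTorch (an illustration, not the code path used here):

    import torch
    import torch.nn.functional as F

    def sequence_loss(logits, targets, loss_masks):
        # logits: [batch, seq, vocab]; targets, loss_masks: [batch, seq]
        per_token = F.cross_entropy(
            logits.float().transpose(1, 2),  # cross_entropy expects [batch, vocab, seq]
            targets,
            reduction="none",
        )
        # Masked sum per sequence, matching calculate_loss above.
        return (per_token * loss_masks).sum(dim=-1).tolist()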

+ 16 - 2
evaluation/tasks.py

@@ -13,9 +13,9 @@ from SwissArmyTransformer.generation.sampling_strategies import BaseStrategy
 from SwissArmyTransformer.tokenization.icetk_glm_130B.ice_tokenizer import _IceTokenizer
 
 from generation import BeamSearchStrategy
-from .configs import BaseConfig, GenerationTaskConfig, MultiChoiceTaskConfig
+from .configs import BaseConfig, GenerationTaskConfig, MultiChoiceTaskConfig, LanguageModelTaskConfig
 from .model import ModelForEvaluation
-from .dataset import EvaluationDataset, GenerationTaskDataset, MultiChoiceTaskDataset
+from .dataset import EvaluationDataset, GenerationTaskDataset, MultiChoiceTaskDataset, LanguageModelTaskDataset
 from .utils import build_data_loader, gather_result, print_rank_0
 from .metrics import DEFAULT_METRICS
 
@@ -205,3 +205,17 @@ class MultiChoiceTask(BaseTask, ABC):
     def predict_single_batch(self, batch) -> List[int]:
         log_probs = self.model.cond_log_prob(batch)
         return [np.argmax(log_probs_single).item() for log_probs_single in log_probs]
+
+
+class LanguageModelTask(BaseTask, ABC):
+    config: LanguageModelTaskConfig
+
+    @classmethod
+    def config_class(cls):
+        return LanguageModelTaskConfig
+
+    def build_dataset(self, relative_path):
+        return LanguageModelTaskDataset(join(self.config.path, relative_path), self.config)
+
+    def predict_single_batch(self, batch) -> List[float]:
+        return self.model.calculate_loss(batch)
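
Note: a hedged sketch of how the pieces fit at evaluation time (`BaseTask`'s driver loop is not part of this diff, so the names here are assumptions). Per-batch summed losses from `predict_single_batch` are collected over the dataset and handed to the "PPL" metric together with `dataset.data`, which carries `num_original_tokens`:

    losses = []
    for batch in dataloader:  # assumed DataLoader over LanguageModelTaskDataset
        losses.extend(task.predict_single_batch(batch))
    ppl = DEFAULT_METRICS["PPL"](losses, dataset.data)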

+ 8 - 0
tasks/language-modeling/ptb.yaml

@@ -0,0 +1,8 @@
+name: "Penn Treebank"
+type: "lm"
+path: "ptbdataset"
+file-pattern:
+  test: "**/ptb.test.txt"
+
+generation-length: 256
+use_task_mask: true

+ 8 - 0
tasks/language-modeling/wikitext-103.yaml

@@ -0,0 +1,8 @@
+name: "WikiText-103"
+type: "lm"
+path: "wikitext-103"
+file-pattern:
+  test: "**/wiki.test.tokens"
+
+generation-length: 256
+use_task_mask: true

+ 8 - 0
tasks/language-modeling/wikitext-2.yaml

@@ -0,0 +1,8 @@
+name: "WikiText-2"
+type: "lm"
+path: "wikitext-2"
+file-pattern:
+  test: "**/wiki.test.tokens"
+
+generation-length: 256
+use_task_mask: true
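
Note: these configs assume an on-disk layout like the following, with `path` resolved against the task data root and `file-pattern` globbed beneath it (directory names are illustrative):

    <data-root>/ptbdataset/ptb.test.txt
    <data-root>/wikitext-2/wikitext-2-raw/wiki.test.tokens
    <data-root>/wikitext-103/wikitext-103-raw/wiki.test.tokens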