|
@@ -1,15 +1,19 @@
|
|
|
import os
|
|
|
+import math
|
|
|
import json
|
|
|
|
|
|
import numpy as np
|
|
|
import torch
|
|
|
|
|
|
+from typing import List, Union
|
|
|
from abc import ABC, abstractmethod
|
|
|
from scipy.linalg import block_diag
|
|
|
+from itertools import accumulate
|
|
|
+from bisect import bisect_right
|
|
|
|
|
|
from SwissArmyTransformer import get_tokenizer
|
|
|
|
|
|
-from .configs import BaseConfig, MultiChoiceTaskConfig, GenerationTaskConfig
|
|
|
+from .configs import BaseConfig, MultiChoiceTaskConfig, GenerationTaskConfig, LanguageModelTaskConfig
|
|
|
from .utils import get_tokenized_input
|
|
|
|
|
|
|
|
@@ -35,21 +39,19 @@ class EvaluationDataset(torch.utils.data.Dataset, ABC):
|
|
|
If [MASK] not in context, will append [MASK] after text
|
|
|
"""
|
|
|
|
|
|
def __init__(self, path: Union[str, List[str]], config: BaseConfig):
    """Load every input file into ``self.data``.

    Args:
        path: A single file path, or a list/tuple of file paths. A single
            path is wrapped so that ``self.path`` is always a list.
        config: Task configuration; ``config.max_seq_length`` is copied
            onto the instance.
    """
    # Normalise to a list. Tuples are accepted as well: previously a tuple
    # was wrapped as one "path" and failed later when passed to ``open``.
    if isinstance(path, (list, tuple)):
        self.path = list(path)
    else:
        self.path = [path]
    self.config = config
    self.max_seq_length = self.config.max_seq_length
    self.dtype = np.int64

    # NOTE(review): called without arguments, so this presumably relies on
    # the tokenizer having been configured elsewhere - confirm.
    self.tokenizer = get_tokenizer()
    self.mask_id = self.tokenizer.get_command("[MASK]")
    self.gmask_id = self.tokenizer.get_command("[gMASK]")

    # Subclasses decide how a file is turned into entries of ``self.data``
    # by overriding ``process_single_file``.
    self.data = []
    for p in self.path:
        self.process_single_file(p)
|
|
|
|
|
|
@property
|
|
|
def has_collate_fn(self) -> bool:
|
|
@@ -58,6 +60,12 @@ class EvaluationDataset(torch.utils.data.Dataset, ABC):
|
|
|
def collate_fn(self, samples):
    """Default collate hook: ignore ``samples`` and return ``None``."""
    return
|
|
|
|
|
|
def process_single_file(self, path):
    """Parse one JSON-lines file and append its processed items to ``self.data``.

    Every non-blank line is decoded with ``json.loads`` and passed through
    ``self.process_single_item``; results are appended in file order.

    Args:
        path: Path of the JSON-lines file to read (UTF-8).

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(path, "r", encoding="utf-8") as file:
        for line in file:
            # Blank or whitespace-only lines (e.g. a stray trailing newline)
            # are not records; ``json.loads`` would raise on them.
            if not line.strip():
                continue
            self.data.append(self.process_single_item(json.loads(line)))
|
|
|
+
|
|
|
@abstractmethod
def process_single_item(self, item) -> dict:
    """Convert one decoded record into a sample dict; subclasses must override."""
|
|
@@ -257,3 +265,57 @@ class MultiChoiceTaskDataset(EvaluationDataset):
|
|
|
)
|
|
|
sample["label"] = item["label"]
|
|
|
return sample
|
|
|
+
|
|
|
+
|
|
|
class LanguageModelTaskDataset(EvaluationDataset):
    """Sliding-window language-modelling dataset built from raw text files.

    Each input file is read whole, tokenized, and stored as one document
    record in ``self.data``. A document is split into overlapping windows of
    at most ``config.max_seq_length - 1`` tokens that advance by
    ``config.generation_length`` tokens; each window is one dataset item.
    Items are numbered document by document, in the order the files were
    loaded.
    """

    config: LanguageModelTaskConfig

    def process_single_file(self, path):
        """Tokenize one text file and append its document record to ``self.data``.

        The record holds the token list (under the key ``"raw_text"``), a
        count of single-space-separated pieces of the stripped text, and the
        number of windows needed to cover the document.
        """
        with open(path, "r", encoding="utf-8") as file:
            raw_text = file.read()
        tokens = self.tokenizer.tokenize(raw_text)

        # One slot of max_seq_length is reserved for the mask token.
        window = self.config.max_seq_length - 1
        overflow = max(len(tokens) - window, 0)
        self.data.append(
            {
                "raw_text": tokens,
                # Deliberately splits on single spaces only; changing this
                # would change any metric normalised by this count.
                "num_original_tokens": len(raw_text.strip().split(" ")),
                # First window plus one extra window per generation_length
                # tokens that do not fit into it; never fewer than one.
                "num_sequences": max(
                    math.ceil(overflow / self.config.generation_length) + 1,
                    1,
                ),
            }
        )

    def process_single_item(self, item):
        """Unused: this dataset consumes whole files, not per-line records."""
        pass

    def __len__(self):
        """Return the total number of windows across all loaded documents."""
        return sum(doc["num_sequences"] for doc in self.data)

    def __getitem__(self, idx):
        """Build the model inputs for window ``idx``.

        Args:
            idx: Global window index; negative values count from the end.

        Returns:
            A dict with ``tokens``, ``targets``, ``position_ids`` and
            ``loss_masks`` (int64 arrays) and ``attention_mask`` (boolean
            matrix, ``True`` where the 0/1 visibility matrix is 0).

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        # Cumulative window counts: ends[i] is the first global index that
        # lies beyond document i.
        ends = list(accumulate(doc["num_sequences"] for doc in self.data))
        total = ends[-1] if ends else 0
        if idx < 0:
            idx += total
        if not 0 <= idx < total:
            raise IndexError("LanguageModelTaskDataset index out of range")

        # Locate the owning document and the window number inside it.
        doc_idx = bisect_right(ends, idx)
        local_idx = idx - (ends[doc_idx - 1] if doc_idx else 0)

        start_idx = local_idx * self.config.generation_length
        end_idx = start_idx + self.config.max_seq_length - 1  # for additional [gMASK]
        tokens = self.data[doc_idx]["raw_text"][start_idx:end_idx]

        mask_id = self.gmask_id if self.config.use_task_mask else self.mask_id
        sop_id = self.tokenizer.get_command("sop")

        if local_idx == 0 or self.config.unidirectional:
            # First window of a document (or unidirectional mode): only the
            # first token is context, everything after it is scored.
            prompt, text = tokens[:1], tokens[1:]
        else:
            # Later windows: the leading tokens overlap the previous window
            # and serve as context; only the trailing part is scored.
            prompt_length = self.config.max_seq_length - 1 - self.config.generation_length
            prompt, text = tokens[:prompt_length], tokens[prompt_length:]

        # NOTE(review): if ``text`` is empty (document shorter than two
        # tokens) ``tokens`` ends up one element longer than ``targets``.
        seq_length = len(prompt) + len(text) + 1
        # Causal (lower-triangular) visibility, with the prompt + mask block
        # made fully visible to itself.
        attention_mask = np.tril(np.ones((seq_length, seq_length), dtype=np.int64))
        attention_mask[: len(prompt) + 1, : len(prompt) + 1] = 1

        return {
            "tokens": np.array(prompt + [mask_id, sop_id] + text[:-1], dtype=np.int64),
            "targets": np.array(prompt + [mask_id] + text, dtype=np.int64),
            "position_ids": np.arange(0, seq_length, dtype=np.int64),
            "attention_mask": attention_mask < 0.5,
            "loss_masks": np.array([0] * (len(prompt) + 1) + [1] * len(text), dtype=np.int64),
        }
|