|
@@ -0,0 +1,83 @@
|
|
|
|
+import os
|
|
|
|
+import math
|
|
|
|
+import json
|
|
|
|
+
|
|
|
|
+from typing import *
|
|
|
|
+from os.path import join
|
|
|
|
+from bisect import bisect_right
|
|
|
|
+from itertools import accumulate
|
|
|
|
+from collections import defaultdict
|
|
|
|
+
|
|
|
|
+from evaluation import LanguageModelTask, LanguageModelTaskDataset, print_rank_0
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def calculate_bpb_score(loss: List[float], data: List[Dict]):
|
|
|
|
+ loss_per_category = defaultdict(lambda: 0.0)
|
|
|
|
+ utf8_length_per_category = defaultdict(lambda: 0.0)
|
|
|
|
+ weights = []
|
|
|
|
+ for item in data:
|
|
|
|
+ weights.append(item["num_sequences"])
|
|
|
|
+ utf8_length_per_category[item["meta"]["pile_set_name"]] += item["utf8_length"]
|
|
|
|
+ weights = list(accumulate(weights))
|
|
|
|
+ for idx in range(len(loss)):
|
|
|
|
+ document_idx = bisect_right(weights, idx)
|
|
|
|
+ loss_per_category[data[document_idx]["meta"]["pile_set_name"]] += loss[idx]
|
|
|
|
+ return {
|
|
|
|
+ name: (loss_per_category[name] / utf8_length_per_category[name] / math.log(2)) for name in loss_per_category
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+class Pile(LanguageModelTask):
|
|
|
|
+ @property
|
|
|
|
+ def metrics(self) -> Dict[str, Callable]:
|
|
|
|
+ return {"BPB": calculate_bpb_score}
|
|
|
|
+
|
|
|
|
+ def build_dataset(self, relative_path):
|
|
|
|
+ return PileDataset(join(self.config.path, relative_path), self.config)
|
|
|
|
+
|
|
|
|
+ def report_single_metrics(self, file: str, result_dict: Dict[str, float]):
|
|
|
|
+ pass
|
|
|
|
+
|
|
|
|
+ def report_group_metrics(
|
|
|
|
+ self, group_name, result_dict_group: Dict[str, Tuple[Dict[str, Dict[str, float]], int]], level=1
|
|
|
|
+ ):
|
|
|
|
+ output_str = f" Finish group {group_name}:\n"
|
|
|
|
+ result = list(result_dict_group.values())[0][0]["BPB"]
|
|
|
|
+ for key, value in result.items():
|
|
|
|
+ output_str += f" {key} = {value:.3f}\n"
|
|
|
|
+ print_rank_0(output_str)
|
|
|
|
+ pass
|
|
|
|
+
|
|
|
|
+ def report_overall_metrics(self, result_dict_all: Dict[str, Tuple[Dict[str, float], int]]):
|
|
|
|
+ pass
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+class PileDataset(LanguageModelTaskDataset):
|
|
|
|
+ def __len__(self):
|
|
|
|
+ return self.weights[-1]
|
|
|
|
+
|
|
|
|
+ def process_single_file(self, path):
|
|
|
|
+ num_sequences = []
|
|
|
|
+ with open(os.path.join(path), "r", encoding="utf-8") as file:
|
|
|
|
+ for line in file:
|
|
|
|
+ item = json.loads(line)
|
|
|
|
+ if len(item["text"]) == 0:
|
|
|
|
+ continue
|
|
|
|
+ self.data.append(
|
|
|
|
+ {
|
|
|
|
+ "raw_text": item["text"],
|
|
|
|
+ "utf8_length": len(item["text_pretokenized"].encode("utf-8")),
|
|
|
|
+ "num_sequences": max(
|
|
|
|
+ math.ceil(
|
|
|
|
+ max(len(item["text"]) - (self.config.max_seq_length - 1), 0)
|
|
|
|
+ / self.config.generation_length
|
|
|
|
+ )
|
|
|
|
+ + 1,
|
|
|
|
+ 1,
|
|
|
|
+ ),
|
|
|
|
+ "meta": item["meta"],
|
|
|
|
+ }
|
|
|
|
+ )
|
|
|
|
+ num_sequences.append(self.data[-1]["num_sequences"])
|
|
|
|
+ self.weights = list(accumulate(num_sequences))
|
|
|
|
+ self.left_weights = [0] + self.weights[:-1]
|