2
0

pile.py 3.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. import os
  2. import math
  3. import json
  4. from typing import *
  5. from os.path import join
  6. from bisect import bisect_right
  7. from itertools import accumulate
  8. from collections import defaultdict
  9. from evaluation import LanguageModelTask, LanguageModelTaskDataset, print_rank_0
  10. def calculate_bpb_score(loss: List[float], data: List[Dict]):
  11. loss_per_category = defaultdict(lambda: 0.0)
  12. utf8_length_per_category = defaultdict(lambda: 0.0)
  13. weights = []
  14. for item in data:
  15. weights.append(item["num_sequences"])
  16. utf8_length_per_category[item["meta"]["pile_set_name"]] += item["utf8_length"]
  17. weights = list(accumulate(weights))
  18. for idx in range(len(loss)):
  19. document_idx = bisect_right(weights, idx)
  20. loss_per_category[data[document_idx]["meta"]["pile_set_name"]] += loss[idx]
  21. return {
  22. name: (loss_per_category[name] / utf8_length_per_category[name] / math.log(2)) for name in loss_per_category
  23. }
  24. class Pile(LanguageModelTask):
  25. @property
  26. def metrics(self) -> Dict[str, Callable]:
  27. return {"BPB": calculate_bpb_score}
  28. def build_dataset(self, relative_path):
  29. return PileDataset(join(self.config.path, relative_path), self.model, self.config)
  30. def report_single_metrics(self, file: str, result_dict: Dict[str, float]):
  31. pass
  32. def report_group_metrics(
  33. self, group_name, result_dict_group: Dict[str, Tuple[Dict[str, Dict[str, float]], int]], level=1
  34. ):
  35. output_str = f" Finish group {group_name}:\n"
  36. result = list(result_dict_group.values())[0][0]["BPB"]
  37. for key, value in result.items():
  38. output_str += f" {key} = {value:.3f}\n"
  39. print_rank_0(output_str)
  40. pass
  41. def report_overall_metrics(self, result_dict_all: Dict[str, Tuple[Dict[str, float], int]]):
  42. pass
  43. class PileDataset(LanguageModelTaskDataset):
  44. def __len__(self):
  45. return self.weights[-1]
  46. def process_single_file(self, path):
  47. num_sequences = []
  48. with open(os.path.join(path), "r", encoding="utf-8") as file:
  49. for line in file:
  50. item = json.loads(line)
  51. if len(item["text"]) == 0:
  52. continue
  53. self.data.append(
  54. {
  55. "raw_text": item["text"],
  56. "utf8_length": len(item["text_pretokenized"].encode("utf-8")),
  57. "num_sequences": max(
  58. math.ceil(
  59. max(len(item["text"]) - (self.config.max_seq_length - 1), 0)
  60. / self.config.generation_length
  61. )
  62. + 1,
  63. 1,
  64. ),
  65. "meta": item["meta"],
  66. }
  67. )
  68. num_sequences.append(self.data[-1]["num_sequences"])
  69. self.weights = list(accumulate(num_sequences))
  70. self.left_weights = [0] + self.weights[:-1]