tasks.py

import os
import time
import torch
import numpy as np
import torch.distributed as dist

from typing import Dict, Callable, Type, Tuple, List, Any
from abc import ABC, abstractmethod
from glob import glob
from os.path import join, relpath
from collections import defaultdict

from SwissArmyTransformer.tokenization.icetk_glm_130B.ice_tokenizer import _IceTokenizer

from generation import BaseStrategy, BeamSearchStrategy
from .configs import BaseConfig, GenerationTaskConfig, MultiChoiceTaskConfig, LanguageModelTaskConfig
from .model import ModelForEvaluation
from .dataset import EvaluationDataset, GenerationTaskDataset, MultiChoiceTaskDataset, LanguageModelTaskDataset
from .utils import build_data_loader, gather_result, print_rank_0
from .metrics import DEFAULT_METRICS


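# BaseTask drives evaluation over named groups of data files: it resolves the
# file groups from the task config, runs the subclass-provided per-batch
# prediction over every file, and reports per-file, per-group and overall metrics.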
class BaseTask(ABC):
    model: ModelForEvaluation
    tokenizer: _IceTokenizer
    config: BaseConfig
    file_groups: Dict[str, List[str]]

    @classmethod
    def config_class(cls) -> Type[BaseConfig]:
        return BaseConfig

    @property
    def metrics(self) -> Dict[str, Callable]:
        return {metric: DEFAULT_METRICS[metric] for metric in self.config.metrics}

    def __init__(self, model: ModelForEvaluation, tokenizer: _IceTokenizer, config: BaseConfig):
        self.model = model
        self.tokenizer = tokenizer
        self.config = config
        self.config.metrics = list(self.metrics.keys())

        self.file_groups = self.get_file_groups()
        self.verbose = dist.get_rank() == 0
        self.save_prediction = config.save_prediction

    def save_prediction_to_file(self, file, prediction, data):
        pass

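    # config.file_pattern is either a single glob string (grouped under "all") or a
    # mapping of group name -> glob pattern; matched paths are stored relative to
    # config.path.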
    def get_file_groups(self):
        pattern_group = {}
        if isinstance(self.config.file_pattern, str):
            pattern_group["all"] = self.config.file_pattern
        else:
            pattern_group = self.config.file_pattern
        return {
            name: [
                relpath(path, start=self.config.path)
                for path in sorted(glob(join(self.config.path, pattern), recursive=True))
            ]
            for name, pattern in pattern_group.items()
        }

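    # For each file in each group: build the dataset and data loader, run
    # predict_single_batch under torch.no_grad(), gather predictions across ranks,
    # compute the configured metrics against dataset.data, and optionally save the
    # predictions on rank 0.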
    def evaluate(self):
        dist.barrier()
        start = time.time()
        print_rank_0("\n")
        print_rank_0(f"{self.config}")
        print_rank_0(f"Evaluating task {self.config.name}:")

        result_dict_all = {}

        for group_name, filelist in self.file_groups.items():
            print_rank_0(f" Evaluating group {group_name}:")

            result_dict_group = {}
            for file in filelist:
                dataset = self.build_dataset(file)
                dataloader = build_data_loader(
                    dataset,
                    micro_batch_size=self.config.micro_batch_size,
                    num_workers=1,
                    drop_last=False,
                    collate_fn=dataset.collate_fn if dataset.has_collate_fn else None,
                )

                prediction = []
                with torch.no_grad():
                    for _, batch in enumerate(dataloader):
                        prediction.append(self.predict_single_batch(batch))

                prediction = gather_result(prediction, len(dataset), self.config.micro_batch_size)
                result_dict = {key: metric(prediction, dataset.data) for key, metric in self.metrics.items()}
                result_dict_group[file] = (result_dict, len(dataset))
                if torch.distributed.get_rank() == 0 and self.save_prediction:
                    self.save_prediction_to_file(file, prediction, dataset.data)

                if self.verbose:
                    self.report_single_metrics(file, result_dict)

            result_dict_all[group_name] = result_dict_group

        print_rank_0(f"Evaluation results of task {self.config.name}:")

        if self.verbose:
            for group_name, result_dict_group in result_dict_all.items():
                self.report_group_metrics(group_name, result_dict_group)
            self.report_overall_metrics(
                {k: v for result_dict_group in result_dict_all.values() for k, v in result_dict_group.items()},
            )

        print_rank_0(f"Finish task {self.config.name} in {time.time() - start:.1f}s.")

    def report_single_metrics(self, file: str, result_dict: Dict[str, float]):
        output_str = f" Finish {file}"
        for key, value in result_dict.items():
            output_str += f", {key} = {value:.3f}"
        print_rank_0(output_str)

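    # Aggregate per-file results within a group: max and median of each metric plus
    # an average weighted by the number of samples in each file.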
    @staticmethod
    def calc_group_metrics(result_dict_group: Dict[str, Tuple[Dict[str, float], int]]):
        metrics_dict = defaultdict(lambda: [])
        weight = []
        for file, (result_dict, length) in result_dict_group.items():
            for key, value in result_dict.items():
                metrics_dict[key].append(value)
            weight.append(length)
        return {
            name: {
                "max": np.max(value),
                "median": np.median(value),
                "average": np.average(value, weights=weight),
            }
            for name, value in metrics_dict.items()
        }

    def report_group_metrics(self, group_name, result_dict_group: Dict[str, Tuple[Dict[str, float], int]], level=1):
        stats_dict = self.calc_group_metrics(result_dict_group)
        if len(stats_dict) == 1:
            name, stats = next(iter(stats_dict.items()))
            print_rank_0(
                " " * level + f"Group {group_name} {name}: max = {stats['max']:.3f}, "
                f"median = {stats['median']:.3f}, average = {stats['average']:.3f}"
            )
        else:
            print_rank_0(" " * level + f" Group {group_name}: ")
            for name, stats in stats_dict.items():
                print(
                    " " * (level + 1) + f"Metric {name}: max = {stats['max']:.3f}, "
                    f"median = {stats['median']:.3f}, average = {stats['average']:.3f}"
                )

    def report_overall_metrics(self, result_dict_all: Dict[str, Tuple[Dict[str, float], int]]):
        pass

    @abstractmethod
    def predict_single_batch(self, batch) -> List[Any]:
        pass

    @abstractmethod
    def build_dataset(self, relative_path: str) -> EvaluationDataset:
        pass


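# Free-form generation tasks: decode with a greedy BaseStrategy (top_k=1) or a
# deterministic BeamSearchStrategy, stopping at <eop>/<eos> plus any extra end
# tokens listed in the config.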
class GenerationTask(BaseTask, ABC):
    config: GenerationTaskConfig

    @classmethod
    def config_class(cls):
        return GenerationTaskConfig

    def build_dataset(self, relative_path):
        return GenerationTaskDataset(join(self.config.path, relative_path), self.model, self.config)

    def save_prediction_to_file(self, file, prediction, data):
        filename = os.path.join("outputs", self.config.name, f"{file}.predict")
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        with open(filename, "w") as file:
            for item in prediction:
                file.write(self.tokenizer.detokenize(item) + "\n")

    def __init__(self, model: ModelForEvaluation, tokenizer: _IceTokenizer, config: GenerationTaskConfig):
        super(GenerationTask, self).__init__(model, tokenizer, config)

        end_tokens = [tokenizer.get_command("eop"), tokenizer.get_command("eos")]
        if self.config.end_tokens:
            for token in self.config.end_tokens:
                end_tokens.append(self.tokenizer.tokenize(token)[-1])
        print_rank_0(f"End tokens {end_tokens}")

        if self.config.sampling_strategy == "BaseStrategy":
            self.strategy = BaseStrategy(
                batch_size=self.config.micro_batch_size, temperature=1.0, top_k=1, end_tokens=end_tokens
            )
        elif self.config.sampling_strategy == "BeamSearchStrategy":
            self.strategy = BeamSearchStrategy(
                self.config.micro_batch_size,
                self.config.num_beams,
                length_penalty=self.config.length_penalty,
                consider_end=True,
                end_tokens=end_tokens,
                no_repeat_ngram_size=self.config.no_repeat_ngram_size,
                min_gen_length=self.config.min_gen_length,
                deterministic=True,  # For evaluation, we need a deterministic generation strategy
            )
        else:
            raise ValueError(f"unknown strategy {self.config.sampling_strategy}")

    def predict_single_batch(self, batch) -> List[List[int]]:
        output = self.model.generate_text(batch, self.strategy, return_all_beams=False)
        return output


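# Multiple-choice tasks: score each candidate answer with the model's conditional
# log-probability and predict the index of the highest-scoring choice.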
class MultiChoiceTask(BaseTask, ABC):
    config: MultiChoiceTaskConfig

    @classmethod
    def config_class(cls):
        return MultiChoiceTaskConfig

    def build_dataset(self, relative_path):
        return MultiChoiceTaskDataset(join(self.config.path, relative_path), self.model, self.config)

    def predict_single_batch(self, batch) -> List[int]:
        log_probs = self.model.cond_log_prob(batch)
        return [np.argmax(log_probs_single).item() for log_probs_single in log_probs]


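# Language-modeling tasks: the per-batch prediction is the loss returned by the model.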
class LanguageModelTask(BaseTask, ABC):
    config: LanguageModelTaskConfig

    @classmethod
    def config_class(cls):
        return LanguageModelTaskConfig

    def build_dataset(self, relative_path):
        return LanguageModelTaskDataset(join(self.config.path, relative_path), self.model, self.config)

    def predict_single_batch(self, batch) -> List[float]:
        return self.model.calculate_loss(batch)


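# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): a minimal concrete task
# built on GenerationTask. The class name and the overall-metrics override are
# hypothetical; GenerationTask already supplies build_dataset,
# predict_single_batch and prediction saving, so a concrete task mostly just
# binds its config class. How tasks are discovered and registered by the
# evaluation driver is outside this file.
class ExampleGenerationTask(GenerationTask):  # hypothetical, for illustration only
    @classmethod
    def config_class(cls):
        return GenerationTaskConfig

    def report_overall_metrics(self, result_dict_all):
        # result_dict_all maps file name -> (metric dict, number of samples)
        total = sum(length for _, length in result_dict_all.values())
        print_rank_0(f"Evaluated {len(result_dict_all)} files, {total} samples in total.")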