tasks.py

import torch
import time
import numpy as np
import torch.distributed as dist

from typing import Dict, Callable, Type, Tuple, List, Any
from abc import ABC, abstractmethod
from glob import glob
from os.path import join, relpath
from collections import defaultdict

from SwissArmyTransformer.tokenization.icetk_glm_130B.ice_tokenizer import _IceTokenizer

from generation import BaseStrategy, BeamSearchStrategy
from .configs import BaseConfig, GenerationTaskConfig, MultiChoiceTaskConfig, LanguageModelTaskConfig
from .model import ModelForEvaluation
from .dataset import EvaluationDataset, GenerationTaskDataset, MultiChoiceTaskDataset, LanguageModelTaskDataset
from .utils import build_data_loader, gather_result, print_rank_0
from .metrics import DEFAULT_METRICS


class BaseTask(ABC):
    """Base class for an evaluation task: discovers data files via config.file_pattern,
    runs batched prediction with the model, and aggregates the configured metrics."""

    model: ModelForEvaluation
    tokenizer: _IceTokenizer
    config: BaseConfig
    file_groups: Dict[str, List[str]]

    @classmethod
    def config_class(cls) -> Type[BaseConfig]:
        return BaseConfig

    @property
    def metrics(self) -> Dict[str, Callable]:
        return {metric: DEFAULT_METRICS[metric] for metric in self.config.metrics}

    def __init__(self, model: ModelForEvaluation, tokenizer: _IceTokenizer, config: BaseConfig):
        self.model = model
        self.tokenizer = tokenizer
        self.config = config
        self.config.metrics = list(self.metrics.keys())

        self.file_groups = self.get_file_groups()
        self.verbose = dist.get_rank() == 0

    def get_file_groups(self):
        pattern_group = {}
        if isinstance(self.config.file_pattern, str):
            pattern_group["all"] = self.config.file_pattern
        else:
            pattern_group = self.config.file_pattern
        return {
            name: [
                relpath(path, start=self.config.path)
                for path in sorted(glob(join(self.config.path, pattern), recursive=True))
            ]
            for name, pattern in pattern_group.items()
        }
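
    # A hedged illustration (added, not part of the original file) of the two accepted
    # forms of config.file_pattern; the paths and group names below are hypothetical:
    #
    #   file_pattern = "**/*.jsonl"
    #       -> {"all": ["dev/x.jsonl", "test/y.jsonl"]}
    #   file_pattern = {"dev": "dev/*.jsonl", "test": "test/*.jsonl"}
    #       -> {"dev": ["dev/x.jsonl"], "test": ["test/y.jsonl"]}
    #
    # Returned paths are relative to config.path and sorted within each group.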

    def evaluate(self):
        dist.barrier()
        start = time.time()
        print_rank_0("\n")
        print_rank_0(f"{self.config}")
        print_rank_0(f"Evaluating task {self.config.name}:")

        result_dict_all = {}

        for group_name, filelist in self.file_groups.items():
            print_rank_0(f"    Evaluating group {group_name}:")

            result_dict_group = {}
            for file in filelist:
                dataset = self.build_dataset(file)
                dataloader = build_data_loader(
                    dataset,
                    micro_batch_size=self.config.micro_batch_size,
                    num_workers=1,
                    drop_last=False,
                    collate_fn=dataset.collate_fn if dataset.has_collate_fn else None,
                )

                prediction = []
                with torch.no_grad():
                    for _, batch in enumerate(dataloader):
                        prediction.append(self.predict_single_batch(batch))

                prediction = gather_result(prediction, len(dataset), self.config.micro_batch_size)
                result_dict = {key: metric(prediction, dataset.data) for key, metric in self.metrics.items()}
                result_dict_group[file] = (result_dict, len(dataset))

                if self.verbose:
                    self.report_single_metrics(file, result_dict)

            result_dict_all[group_name] = result_dict_group

        print_rank_0(f"Evaluation results of task {self.config.name}:")

        if self.verbose:
            for group_name, result_dict_group in result_dict_all.items():
                self.report_group_metrics(group_name, result_dict_group)
            self.report_overall_metrics(
                {k: v for result_dict_group in result_dict_all.values() for k, v in result_dict_group.items()},
            )

        print_rank_0(f"Finish task {self.config.name} in {time.time() - start:.1f}s.")

    def report_single_metrics(self, file: str, result_dict: Dict[str, float]):
        output_str = f"        Finish {file}"
        for key, value in result_dict.items():
            output_str += f", {key} = {value:.3f}"
        print_rank_0(output_str)

    @staticmethod
    def calc_group_metrics(result_dict_group: Dict[str, Tuple[Dict[str, float], int]]):
        metrics_dict = defaultdict(lambda: [])
        weight = []
        for file, (result_dict, length) in result_dict_group.items():
            for key, value in result_dict.items():
                metrics_dict[key].append(value)
            weight.append(length)
        return {
            name: {
                "max": np.max(value),
                "median": np.median(value),
                "average": np.average(value, weights=weight),
            }
            for name, value in metrics_dict.items()
        }
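
    # A worked example (added, not part of the original file) of calc_group_metrics,
    # with hypothetical file names, metric values, and dataset sizes:
    #
    #   result_dict_group = {"a.jsonl": ({"Accuracy": 0.8}, 100),
    #                        "b.jsonl": ({"Accuracy": 0.6}, 300)}
    #   -> {"Accuracy": {"max": 0.8, "median": 0.7, "average": 0.65}}
    #
    # "average" is weighted by dataset length: (0.8 * 100 + 0.6 * 300) / 400 = 0.65.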

    def report_group_metrics(self, group_name, result_dict_group: Dict[str, Tuple[Dict[str, float], int]], level=1):
        stats_dict = self.calc_group_metrics(result_dict_group)
        if len(stats_dict) == 1:
            name, stats = next(iter(stats_dict.items()))
            print_rank_0(
                "    " * level + f"Group {group_name} {name}: max = {stats['max']:.3f}, "
                f"median = {stats['median']:.3f}, average = {stats['average']:.3f}"
            )
        else:
            print_rank_0("    " * level + f"Group {group_name}:")
            for name, stats in stats_dict.items():
                print_rank_0(
                    "    " * (level + 1) + f"Metric {name}: max = {stats['max']:.3f}, "
                    f"median = {stats['median']:.3f}, average = {stats['average']:.3f}"
                )

    def report_overall_metrics(self, result_dict_all: Dict[str, Tuple[Dict[str, float], int]]):
        pass

    @abstractmethod
    def predict_single_batch(self, batch) -> List[Any]:
        pass

    @abstractmethod
    def build_dataset(self, relative_path: str) -> EvaluationDataset:
        pass


class GenerationTask(BaseTask, ABC):
    config: GenerationTaskConfig

    @classmethod
    def config_class(cls):
        return GenerationTaskConfig

    def build_dataset(self, relative_path):
        return GenerationTaskDataset(join(self.config.path, relative_path), self.config)

    def __init__(self, model: ModelForEvaluation, tokenizer: _IceTokenizer, config: GenerationTaskConfig):
        super(GenerationTask, self).__init__(model, tokenizer, config)

        end_tokens = [tokenizer.get_command("eop"), tokenizer.get_command("eos")]
        if self.config.sampling_strategy == "BaseStrategy":
            # top_k=1 makes this effectively greedy decoding
            self.strategy = BaseStrategy(
                batch_size=self.config.micro_batch_size, temperature=1.0, top_k=1, end_tokens=end_tokens
            )
        elif self.config.sampling_strategy == "BeamSearchStrategy":
            self.strategy = BeamSearchStrategy(
                self.config.micro_batch_size,
                self.config.num_beams,
                length_penalty=self.config.length_penalty,
                consider_end=True,
                end_tokens=end_tokens,
                no_repeat_ngram_size=self.config.no_repeat_ngram_size,
                min_gen_length=self.config.min_gen_length,
                deterministic=True,  # For evaluation, we need a deterministic generation strategy
            )
        else:
            raise ValueError(f"unknown strategy {self.config.sampling_strategy}")

    def predict_single_batch(self, batch) -> List[List[int]]:
        # micro batch size = 1 for generation tasks,
        # but we still need to return a list of predictions for consistency
        output = self.model.generate_text(
            batch, self.strategy, return_all_beams=False, max_gen_length=self.config.max_gen_length
        )
        return output


class MultiChoiceTask(BaseTask, ABC):
    config: MultiChoiceTaskConfig

    @classmethod
    def config_class(cls):
        return MultiChoiceTaskConfig

    def build_dataset(self, relative_path):
        return MultiChoiceTaskDataset(join(self.config.path, relative_path), self.config)

    def predict_single_batch(self, batch) -> List[int]:
        log_probs = self.model.cond_log_prob(batch)
        return [np.argmax(log_probs_single).item() for log_probs_single in log_probs]
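
# A worked illustration (added, not part of the original file) of
# MultiChoiceTask.predict_single_batch, assuming model.cond_log_prob returns one list of
# per-choice log-probabilities per sample in the batch:
#
#   log_probs = [[-0.2, -1.6, -2.3], [-3.0, -0.1, -1.2]]
#   -> predictions [0, 1]   (index of the highest-scoring choice for each sample)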


class LanguageModelTask(BaseTask, ABC):
    config: LanguageModelTaskConfig

    @classmethod
    def config_class(cls):
        return LanguageModelTaskConfig

    def build_dataset(self, relative_path):
        return LanguageModelTaskDataset(join(self.config.path, relative_path), self.config)

    def predict_single_batch(self, batch) -> List[float]:
        return self.model.calculate_loss(batch)
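
# A hedged usage sketch (added, not part of the original file) of how these classes are
# typically driven. How `model`, `tokenizer`, and `config` are constructed is assumed
# here rather than shown in this file; only the task classes and evaluate() come from
# the code above.
#
#   task = GenerationTask(model, tokenizer, config)   # config: GenerationTaskConfig
#   task.evaluate()   # iterates every file group, predicts batch by batch, reports metrics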