tasks.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234
  1. import torch
  2. import time
  3. import numpy as np
  4. import torch.distributed as dist
  5. from tqdm import tqdm
  6. from typing import Dict, Callable, Type, Tuple, List, Any
  7. from abc import ABC, abstractmethod
  8. from glob import glob
  9. from os.path import join, relpath
  10. from collections import defaultdict
  11. from SwissArmyTransformer.tokenization.icetk_glm_130B.ice_tokenizer import _IceTokenizer
  12. from generation import BaseStrategy, BeamSearchStrategy
  13. from .configs import BaseConfig, GenerationTaskConfig, MultiChoiceTaskConfig, LanguageModelTaskConfig
  14. from .model import ModelForEvaluation
  15. from .dataset import EvaluationDataset, GenerationTaskDataset, MultiChoiceTaskDataset, LanguageModelTaskDataset
  16. from .utils import build_data_loader, gather_result, print_rank_0
  17. from .metrics import DEFAULT_METRICS
  18. class BaseTask(ABC):
  19. model: ModelForEvaluation
  20. tokenizer: _IceTokenizer
  21. config: BaseConfig
  22. file_groups: Dict[str, List[str]]
  23. @classmethod
  24. def config_class(cls) -> Type[BaseConfig]:
  25. return BaseConfig
  26. @property
  27. def metrics(self) -> Dict[str, Callable]:
  28. return {metric: DEFAULT_METRICS[metric] for metric in self.config.metrics}
  29. def __init__(self, model: ModelForEvaluation, tokenizer: _IceTokenizer, config: BaseConfig):
  30. self.model = model
  31. self.tokenizer = tokenizer
  32. self.config = config
  33. self.config.metrics = list(self.metrics.keys())
  34. self.file_groups = self.get_file_groups()
  35. self.verbose = dist.get_rank() == 0
  36. self.save_prediction = config.save_prediction
  37. def save_prediction_to_file(self, file, prediction, data):
  38. pass
  39. def get_file_groups(self):
  40. pattern_group = {}
  41. if isinstance(self.config.file_pattern, str):
  42. pattern_group["all"] = self.config.file_pattern
  43. else:
  44. pattern_group = self.config.file_pattern
  45. return {
  46. name: [
  47. relpath(path, start=self.config.path)
  48. for path in sorted(glob(join(self.config.path, pattern), recursive=True))
  49. ]
  50. for name, pattern in pattern_group.items()
  51. }
  52. def evaluate(self):
  53. dist.barrier()
  54. start = time.time()
  55. print_rank_0("\n")
  56. print_rank_0(f"{self.config}")
  57. print_rank_0(f"Evaluating task {self.config.name}:")
  58. result_dict_all = {}
  59. for group_name, filelist in self.file_groups.items():
  60. print_rank_0(f" Evaluating group {group_name}:")
  61. result_dict_group = {}
  62. for file in filelist:
  63. dataset = self.build_dataset(file, group_name)
  64. dataloader = build_data_loader(
  65. dataset,
  66. micro_batch_size=self.config.micro_batch_size,
  67. num_workers=1,
  68. drop_last=False,
  69. collate_fn=dataset.collate_fn if dataset.has_collate_fn else None,
  70. )
  71. prediction = []
  72. tqdm_wrapper = tqdm if torch.distributed.get_rank() == 0 else lambda x:x
  73. with torch.no_grad():
  74. for idx, batch in tqdm_wrapper(enumerate(dataloader)):
  75. p_batch = self.predict_single_batch(batch)
  76. prediction.append(p_batch)
  77. prediction = gather_result(prediction, len(dataset), self.config.micro_batch_size)
  78. result_dict = {key: metric(prediction, dataset.data) for key, metric in self.metrics.items()}
  79. result_dict_group[file] = (result_dict, len(dataset))
  80. if torch.distributed.get_rank() == 0 and self.save_prediction:
  81. self.save_prediction_to_file(file, prediction, dataset.data)
  82. if self.verbose:
  83. self.report_single_metrics(file, result_dict)
  84. result_dict_all[group_name] = result_dict_group
  85. print_rank_0(f"Evaluation results of task {self.config.name}:")
  86. if self.verbose:
  87. for group_name, result_dict_group in result_dict_all.items():
  88. self.report_group_metrics(group_name, result_dict_group)
  89. self.report_overall_metrics(
  90. {k: v for result_dict_group in result_dict_all.values() for k, v in result_dict_group.items()},
  91. )
  92. print_rank_0(f"Finish task {self.config.name} in {time.time() - start:.1f}s.")
  93. def report_single_metrics(self, file: str, result_dict: Dict[str, float]):
  94. output_str = f" Finish {file}"
  95. for key, value in result_dict.items():
  96. output_str += f", {key} = {value:.3f}"
  97. print_rank_0(output_str)
  98. @staticmethod
  99. def calc_group_metrics(result_dict_group: Dict[str, Tuple[Dict[str, float], int]]):
  100. metrics_dict = defaultdict(lambda: [])
  101. weight = []
  102. for file, (result_dict, length) in result_dict_group.items():
  103. for key, value in result_dict.items():
  104. metrics_dict[key].append(value)
  105. weight.append(length)
  106. return {
  107. name: {
  108. "max": np.max(value),
  109. "median": np.median(value),
  110. "average": np.average(value, weights=weight),
  111. }
  112. for name, value in metrics_dict.items()
  113. }
  114. def report_group_metrics(self, group_name, result_dict_group: Dict[str, Tuple[Dict[str, float], int]], level=1):
  115. stats_dict = self.calc_group_metrics(result_dict_group)
  116. if len(stats_dict) == 1:
  117. name, stats = next(iter(stats_dict.items()))
  118. print_rank_0(
  119. " " * level + f"Group {group_name} {name}: max = {stats['max']:.3f}, "
  120. f"median = {stats['median']:.3f}, average = {stats['average']:.3f}"
  121. )
  122. else:
  123. print_rank_0(" " * level + f" Group {group_name}: ")
  124. for name, stats in stats_dict.items():
  125. print(
  126. " " * (level + 1) + f"Metric {name}: max = {stats['max']:.3f}, "
  127. f"median = {stats['median']:.3f}, average = {stats['average']:.3f}"
  128. )
  129. def report_overall_metrics(self, result_dict_all: Dict[str, Tuple[Dict[str, float], int]]):
  130. pass
  131. @abstractmethod
  132. def predict_single_batch(self, batch) -> List[Any]:
  133. pass
  134. @abstractmethod
  135. def build_dataset(self, relative_path: str, split: str) -> EvaluationDataset:
  136. pass
  137. class GenerationTask(BaseTask, ABC):
  138. config: GenerationTaskConfig
  139. @classmethod
  140. def config_class(cls):
  141. return GenerationTaskConfig
  142. def build_dataset(self, relative_path, split):
  143. return GenerationTaskDataset(join(self.config.path, relative_path), self.config)
  144. def __init__(self, model: ModelForEvaluation, tokenizer: _IceTokenizer, config: GenerationTaskConfig):
  145. super(GenerationTask, self).__init__(model, tokenizer, config)
  146. end_tokens = [tokenizer.get_command("eop"), tokenizer.get_command("eos")]
  147. if self.config.end_tokens:
  148. for token in self.config.end_tokens:
  149. end_tokens.append(self.tokenizer.tokenize(token)[-1])
  150. print_rank_0(f"End tokens {end_tokens}")
  151. if self.config.sampling_strategy == "BaseStrategy":
  152. self.strategy = BaseStrategy(batch_size=self.config.micro_batch_size, temperature=1.0, top_k=1,
  153. end_tokens=end_tokens)
  154. elif self.config.sampling_strategy == "BeamSearchStrategy":
  155. self.strategy = BeamSearchStrategy(
  156. self.config.micro_batch_size,
  157. self.config.num_beams,
  158. length_penalty=self.config.length_penalty,
  159. consider_end=True,
  160. end_tokens=end_tokens,
  161. no_repeat_ngram_size=self.config.no_repeat_ngram_size,
  162. min_gen_length=self.config.min_gen_length,
  163. deterministic=False, # For evaluation, we need a determined generation strategy
  164. )
  165. else:
  166. raise ValueError(f"unknown strategy {self.config.sampling_strategy}")
  167. def predict_single_batch(self, batch) -> List[List[int]]:
  168. output = self.model.generate_text(batch, self.strategy, return_all_beams=False)
  169. return output
  170. class MultiChoiceTask(BaseTask, ABC):
  171. config: MultiChoiceTaskConfig
  172. @classmethod
  173. def config_class(cls):
  174. return MultiChoiceTaskConfig
  175. def build_dataset(self, relative_path, split):
  176. return MultiChoiceTaskDataset(join(self.config.path, relative_path), self.config)
  177. def predict_single_batch(self, batch) -> List[int]:
  178. log_probs = self.model.cond_log_prob(batch)
  179. return [np.argmax(log_probs_single).item() for log_probs_single in log_probs]
  180. class LanguageModelTask(BaseTask, ABC):
  181. config: LanguageModelTaskConfig
  182. @classmethod
  183. def config_class(cls):
  184. return LanguageModelTaskConfig
  185. def build_dataset(self, relative_path):
  186. return LanguageModelTaskDataset(join(self.config.path, relative_path), self.config)
  187. def predict_single_batch(self, batch) -> List[float]:
  188. return self.model.calculate_loss(batch)