tasks.py

import torch
import time
import numpy as np
import torch.distributed as dist
from tqdm import tqdm
from typing import Dict, Callable, Type, Tuple, List, Any
from abc import ABC, abstractmethod
from glob import glob
from os.path import join, relpath
from collections import defaultdict
from SwissArmyTransformer.generation.sampling_strategies import BaseStrategy
from SwissArmyTransformer.tokenization.icetk_glm_130B.ice_tokenizer import _IceTokenizer
from generation import BeamSearchStrategy
from .configs import BaseConfig, GenerationTaskConfig, MultiChoiceTaskConfig
from .model import ModelForEvaluation
from .dataset import EvaluationDataset, GenerationTaskDataset, MultiChoiceTaskDataset
from .utils import build_data_loader, gather_result, print_rank_0
from .metrics import DEFAULT_METRICS
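
# Task abstraction for evaluation: BaseTask drives the per-file evaluation loop over the
# file groups resolved from config.file_pattern; GenerationTask and MultiChoiceTask below
# specialize build_dataset() and predict_single_batch() for their task type.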
class BaseTask(ABC):
    model: ModelForEvaluation
    tokenizer: _IceTokenizer
    config: BaseConfig
    file_groups: Dict[str, List[str]]

    @classmethod
    def config_class(cls) -> Type[BaseConfig]:
        return BaseConfig

    @property
    def metrics(self) -> Dict[str, Callable]:
        return {metric: DEFAULT_METRICS[metric] for metric in self.config.metrics}

    def __init__(self, model: ModelForEvaluation, tokenizer: _IceTokenizer, config: BaseConfig):
        self.model = model
        self.tokenizer = tokenizer
        self.config = config
        self.config.metrics = list(self.metrics.keys())
        self.file_groups = self.get_file_groups()
        self.verbose = dist.get_rank() == 0
        self.save_prediction = config.save_prediction

    def save_prediction_to_file(self, file, prediction, data):
        pass
    def get_file_groups(self):
        pattern_group = {}
        if isinstance(self.config.file_pattern, str):
            pattern_group["all"] = self.config.file_pattern
        else:
            pattern_group = self.config.file_pattern
        return {
            name: [
                relpath(path, start=self.config.path)
                for path in sorted(glob(join(self.config.path, pattern), recursive=True))
            ]
            for name, pattern in pattern_group.items()
        }
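    # Example (hypothetical paths, for illustration): with config.path = "data/mytask" and
    # config.file_pattern = {"validation": "val/**/*.json*", "test": "test/**/*.json*"},
    # get_file_groups() returns paths relative to config.path, e.g.
    # {"validation": ["val/part1.jsonl", ...], "test": ["test/part1.jsonl", ...]};
    # a plain string pattern goes under the single group "all".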
    def evaluate(self):
        dist.barrier()
        start = time.time()
        print_rank_0("\n")
        print_rank_0(f"{self.config}")
        print_rank_0(f"Evaluating task {self.config.name}:")

        result_dict_all = {}
        for group_name, filelist in self.file_groups.items():
            print_rank_0(f" Evaluating group {group_name}:")

            result_dict_group = {}
            for file in filelist:
                dataset = self.build_dataset(file, group_name)
                dataloader = build_data_loader(
                    dataset,
                    micro_batch_size=self.config.micro_batch_size,
                    num_workers=1,
                    drop_last=False,
                    collate_fn=dataset.collate_fn if dataset.has_collate_fn else None,
                )

                prediction = []
                tqdm_wrapper = tqdm if torch.distributed.get_rank() == 0 else lambda x: x
                with torch.no_grad():
                    for idx, batch in tqdm_wrapper(enumerate(dataloader)):
                        p_batch = self.predict_single_batch(batch)
                        prediction.append(p_batch)

                prediction = gather_result(prediction, len(dataset), self.config.micro_batch_size)
                result_dict = {key: metric(prediction, dataset.data) for key, metric in self.metrics.items()}
                result_dict_group[file] = (result_dict, len(dataset))
                if torch.distributed.get_rank() == 0 and self.save_prediction:
                    self.save_prediction_to_file(file, prediction, dataset.data)

                if self.verbose:
                    self.report_single_metrics(file, result_dict)

            result_dict_all[group_name] = result_dict_group

        print_rank_0(f"Evaluation results of task {self.config.name}:")

        if self.verbose:
            for group_name, result_dict_group in result_dict_all.items():
                self.report_group_metrics(group_name, result_dict_group)
            self.report_overall_metrics(
                {k: v for result_dict_group in result_dict_all.values() for k, v in result_dict_group.items()},
            )

        print_rank_0(f"Finish task {self.config.name} in {time.time() - start:.1f}s.")
    def report_single_metrics(self, file: str, result_dict: Dict[str, float]):
        output_str = f" Finish {file}"
        for key, value in result_dict.items():
            output_str += f", {key} = {value:.3f}"
        print_rank_0(output_str)
    @staticmethod
    def calc_group_metrics(result_dict_group: Dict[str, Tuple[Dict[str, float], int]]):
        metrics_dict = defaultdict(lambda: [])
        weight = []
        for file, (result_dict, length) in result_dict_group.items():
            for key, value in result_dict.items():
                metrics_dict[key].append(value)
            weight.append(length)
        return {
            name: {
                "max": np.max(value),
                "median": np.median(value),
                "average": np.average(value, weights=weight),
            }
            for name, value in metrics_dict.items()
        }
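    # Worked example (hypothetical numbers): two files scoring 0.80 (100 samples) and
    # 0.90 (300 samples) yield max = 0.90, median = 0.85, and a length-weighted
    # average of (0.80 * 100 + 0.90 * 300) / 400 = 0.875.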
    def report_group_metrics(self, group_name, result_dict_group: Dict[str, Tuple[Dict[str, float], int]], level=1):
        stats_dict = self.calc_group_metrics(result_dict_group)
        if len(stats_dict) == 1:
            name, stats = next(iter(stats_dict.items()))
            print_rank_0(
                " " * level + f"Group {group_name} {name}: max = {stats['max']:.3f}, "
                f"median = {stats['median']:.3f}, average = {stats['average']:.3f}"
            )
        else:
            print_rank_0(" " * level + f" Group {group_name}: ")
            for name, stats in stats_dict.items():
                print(
                    " " * (level + 1) + f"Metric {name}: max = {stats['max']:.3f}, "
                    f"median = {stats['median']:.3f}, average = {stats['average']:.3f}"
                )
    def report_overall_metrics(self, result_dict_all: Dict[str, Tuple[Dict[str, float], int]]):
        pass

    @abstractmethod
    def predict_single_batch(self, batch) -> List[Any]:
        pass

    @abstractmethod
    def build_dataset(self, relative_path: str, split: str) -> EvaluationDataset:
        pass
class GenerationTask(BaseTask, ABC):
    config: GenerationTaskConfig

    @classmethod
    def config_class(cls):
        return GenerationTaskConfig

    def build_dataset(self, relative_path, split):
        return GenerationTaskDataset(join(self.config.path, relative_path), self.config)

    def __init__(self, model: ModelForEvaluation, tokenizer: _IceTokenizer, config: GenerationTaskConfig):
        super(GenerationTask, self).__init__(model, tokenizer, config)

        end_tokens = [tokenizer.get_command("eop"), tokenizer.get_command("eos")]
        if self.config.end_tokens:
            for token in self.config.end_tokens:
                end_tokens.append(self.tokenizer.tokenize(token)[-1])
            print_rank_0(f"End tokens {end_tokens}")
        if self.config.sampling_strategy == "BaseStrategy":
            self.strategy = BaseStrategy(temperature=1.0, top_k=1, end_tokens=end_tokens)
        elif self.config.sampling_strategy == "BeamSearchStrategy":
            self.strategy = BeamSearchStrategy(
                self.config.num_beams,
                length_penalty=self.config.length_penalty,
                consider_end=True,
                end_tokens=end_tokens,
                no_repeat_ngram_size=self.config.no_repeat_ngram_size,
                min_gen_length=self.config.min_gen_length,
                deterministic=True,  # for evaluation, we need a deterministic generation strategy
            )
        else:
            raise ValueError(f"unknown strategy {self.config.sampling_strategy}")
    def predict_single_batch(self, batch) -> List[List[int]]:
        # micro batch size = 1 for generation task,
        # but we still need to return a list of predictions for consistency
        output = self.model.generate_text(batch, self.strategy, return_all_beams=False)
        return [output]
class MultiChoiceTask(BaseTask, ABC):
    config: MultiChoiceTaskConfig

    @classmethod
    def config_class(cls):
        return MultiChoiceTaskConfig

    def build_dataset(self, relative_path, split):
        return MultiChoiceTaskDataset(join(self.config.path, relative_path), self.config)

    def predict_single_batch(self, batch) -> List[int]:
        log_probs = self.model.cond_log_prob(batch)
        return [np.argmax(log_probs_single).item() for log_probs_single in log_probs]
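
# Minimal usage sketch (hypothetical; how a config instance is built depends on .configs
# elsewhere in this package):
#
#   class MyMultiChoiceTask(MultiChoiceTask):
#       pass  # dataset construction and prediction are inherited from MultiChoiceTask
#
#   config = ...  # a MultiChoiceTaskConfig with path, file_pattern, metrics, etc.
#   task = MyMultiChoiceTask(model, tokenizer, config)
#   task.evaluate()  # runs every file group and reports per-file / per-group metrics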