2
0

evaluate.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. import time
  2. import importlib
  3. from os.path import join, isdir, isfile, relpath
  4. from glob import glob
  5. from evaluation import BaseConfig, ModelForEvaluation, DEFAULT_CLASS, print_rank_0
  6. from initialize import initialize, initialize_model_and_tokenizer
  7. def add_evaluation_specific_args(parser):
  8. """Arguments for evaluation"""
  9. group = parser.add_argument_group("evaluation", "Evaluation configurations")
  10. # Task
  11. group.add_argument("--task", nargs="+", default=[], help="All task config to evaluation")
  12. group.add_argument("--data-path", type=str, required=True, help="Data dir path for all tasks")
  13. return parser
  14. def find_all_tasks(all_task_config_path):
  15. tasks = []
  16. for task in all_task_config_path:
  17. if isdir(task):
  18. tasks += [relpath(path, ".") for path in glob(join(task, "**/*.yaml"), recursive=True)]
  19. elif isfile(task):
  20. tasks.append(task)
  21. return tasks
  22. def evaluate_all_tasks(data_path, model, tokenizer, all_task_config_path, task_classes):
  23. for config_path, task_class in zip(all_task_config_path, task_classes):
  24. config = task_class.config_class().from_yaml_file(config_path)
  25. config.path = join(data_path, config.path)
  26. task = task_class(model, tokenizer, config)
  27. task.evaluate()
  28. def main():
  29. args = initialize(extra_args_provider=add_evaluation_specific_args)
  30. args.task = find_all_tasks(args.task)
  31. task_classes = []
  32. print_rank_0("> Loading task configs")
  33. for task_config_path in args.task:
  34. config = BaseConfig.from_yaml_file(task_config_path)
  35. if config.module:
  36. path = ".".join(config.module.split(".")[:-1])
  37. module = importlib.import_module(path)
  38. class_name = config.module.split(".")[-1]
  39. task_class = getattr(module, class_name)
  40. task_classes.append(task_class)
  41. else:
  42. task_classes.append(DEFAULT_CLASS[config.type])
  43. print_rank_0(f" Task {config.name} loaded from config {task_config_path}")
  44. print_rank_0(f"> Successfully load {len(task_classes)} task{'s' if len(task_classes) > 1 else ''}")
  45. model, tokenizer = initialize_model_and_tokenizer(args)
  46. model = ModelForEvaluation(model, args.position_encoding_2d)
  47. start = time.time()
  48. evaluate_all_tasks(args.data_path, model, tokenizer, args.task, task_classes)
  49. print_rank_0(f"Finish {len(task_classes)} task{'s' if len(task_classes) > 1 else ''} in {time.time() - start:.1f}s")
# Script entry point: run the full evaluation pipeline when executed directly.
if __name__ == "__main__":
    main()