configs.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. from __future__ import annotations
  2. from dataclass_wizard import YAMLWizard
  3. from dataclasses import dataclass, field
  4. from enum import Enum
  5. from typing import Optional, List, Dict
  6. class TaskType(Enum):
  7. MULTICHOICE = "mul"
  8. GENERATION = "gen"
  9. LANGUAGE_MODEL = "lm"
  10. OTHER = "other"
  11. @dataclass
  12. class BaseConfig(YAMLWizard):
  13. name: str # Task name
  14. type: TaskType # Task type
  15. path: str # task data path relative to DATA_PATH
  16. module: Optional[str] = None # Custom task module file, optional
  17. metrics: List[str] = field(default_factory=list) # Evaluation metrics
  18. use_task_mask: bool = False # Whether to use [gMASK] for evaluation
  19. use_multitask_encoding: bool = False # Not supported now
  20. unidirectional: bool = False # Whether to use unidirectional attention
  21. max_seq_length: int = 2048 # Max sequence length
  22. no_tokenized: bool = False
  23. file_pattern: str | Dict[str, str] = "**/*.json*" # Organize data file in groups
  24. micro_batch_size: int = 1 # 'gen' task only support mbs = 1 for now
  25. def __post_init__(self):
  26. assert self.use_task_mask or not self.unidirectional, "[MASK] doesn't support unidirectional attention"
  27. @dataclass
  28. class MultiChoiceTaskConfig(BaseConfig):
  29. module = "evaluation.MultiChoiceTask"
  30. metrics: List[str] = field(default_factory=lambda: ["Accuracy"])
  31. @dataclass
  32. class GenerationTaskConfig(BaseConfig):
  33. module = "evaluation.GenerationTask"
  34. metrics: List[str] = field(default_factory=lambda: [])
  35. sampling_strategy: str = "BaseStrategy"
  36. num_beams: int = 4
  37. length_penalty: float = 1.0
  38. no_repeat_ngram_size: int = 3
  39. min_gen_length: int = 0
  40. max_gen_length: int = 128
  41. @dataclass
  42. class LanguageModelTaskConfig(BaseConfig):
  43. module = "evaluation.LanguageModelTask"
  44. metrics: List[str] = field(default_factory=lambda: ["PPL"])
  45. generation_length: int = 256 # Generated length in each window