configs.py 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. from __future__ import annotations
  2. from dataclass_wizard import YAMLWizard
  3. from dataclasses import dataclass, field
  4. from enum import Enum
  5. from typing import Optional, List, Dict
  6. class TaskType(Enum):
  7. MULTICHOICE = "mul"
  8. GENERATION = "gen"
  9. LANGUAGE_MODEL = "lm"
  10. OTHER = "other"
  11. @dataclass
  12. class BaseConfig(YAMLWizard):
  13. name: str # Task name
  14. type: TaskType # Task type
  15. path: str # task data path relative to DATA_PATH
  16. module: Optional[str] = None # Custom task module file, optional
  17. metrics: List[str] = field(default_factory=list) # Evaluation metrics
  18. use_task_mask: bool = False # Whether to use [gMASK] for evaluation
  19. use_multitask_encoding: bool = False # Not supported now
  20. unidirectional: bool = False # Whether to use unidirectional attention
  21. max_seq_length: int = 2048 # Max sequence length
  22. file_pattern: str | Dict[str, str] = "**/*.json*" # Organize data file in groups
  23. micro_batch_size: int = 1 # 'gen' task only support mbs = 1 for now
  24. def __post_init__(self):
  25. assert self.use_task_mask or not self.unidirectional, "[MASK] doesn't support unidirectional attention"
  26. @dataclass
  27. class MultiChoiceTaskConfig(BaseConfig):
  28. module = "evaluation.MultiChoiceTask"
  29. metrics: List[str] = field(default_factory=lambda: ["Accuracy"])
  30. @dataclass
  31. class GenerationTaskConfig(BaseConfig):
  32. module = "evaluation.GenerationTask"
  33. metrics: List[str] = field(default_factory=lambda: ["EM", "F1"])
  34. sampling_strategy: str = "BaseStrategy"
  35. num_beams: int = 4
  36. length_penalty: float = 1.0
  37. no_repeat_ngram_size: int = 3
  38. min_gen_length: int = 0
  39. max_gen_length: int = 128
  40. def __post_init__(self):
  41. assert self.micro_batch_size == 1, "Only support micro batch size = 1 for generation task"
  42. @dataclass
  43. class LanguageModelTaskConfig(BaseConfig):
  44. module = "evaluation.LanguageModelTask"
  45. metrics: List[str] = field(default_factory=lambda: ["PPL"])
  46. generation_length: int = 256 # Generated length in each window