2
0

configs.py 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. from __future__ import annotations
  2. from dataclass_wizard import YAMLWizard
  3. from dataclasses import dataclass, field
  4. from enum import Enum
  5. from typing import Optional, List, Dict
  6. class TaskType(Enum):
  7. MULTICHOICE = "mul"
  8. GENERATION = "gen"
  9. OTHER = "other"
  10. @dataclass
  11. class BaseConfig(YAMLWizard):
  12. name: str # Task name
  13. type: TaskType # Task type
  14. path: str # task data path relative to DATA_PATH
  15. module: Optional[str] = None # Custom task module file, optional
  16. metrics: List[str] = field(default_factory=list) # Evaluation metrics
  17. use_task_mask: bool = False # Whether to use [gMASK] for evaluation
  18. use_multitask_encoding: bool = False # Not supported now
  19. unidirectional: bool = False # Whether to use unidirectional attention
  20. max_seq_length: int = 2048 # Max sequence length
  21. file_pattern: str | Dict[str, str] = "**/*.json*" # Organize data file in groups
  22. micro_batch_size: int = 1 # 'gen' task only support mbs = 1 for now
  23. def __post_init__(self):
  24. assert self.use_task_mask or not self.unidirectional, "[MASK] doesn't support unidirectional attention"
  25. @dataclass
  26. class MultiChoiceTaskConfig(BaseConfig):
  27. module = "evaluation.MultiChoiceTask"
  28. metrics: List[str] = field(default_factory=lambda: ["Accuracy"])
  29. @dataclass
  30. class GenerationTaskConfig(BaseConfig):
  31. module = "evaluation.GenerationTask"
  32. metrics: List[str] = field(default_factory=lambda: ["EM", "F1"])
  33. sampling_strategy: str = "BaseStrategy"
  34. num_beams: int = 4
  35. length_penalty: float = 1.0
  36. no_repeat_ngram_size: int = 3
  37. min_gen_length: int = 0
  38. max_gen_length: int = 128
  39. def __post_init__(self):
  40. assert self.micro_batch_size == 1, "Only support micro batch size = 1 for generation task"