from __future__ import annotations

from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional

from dataclass_wizard import YAMLWizard


class TaskType(Enum):
    MULTICHOICE = "mul"
    GENERATION = "gen"
    LANGUAGE_MODEL = "lm"
    OTHER = "other"


@dataclass
class BaseConfig(YAMLWizard):
    name: str  # Task name
    type: TaskType  # Task type
    path: str  # Task data path, relative to DATA_PATH
    module: Optional[str] = None  # Custom task module file, optional
    metrics: List[str] = field(default_factory=list)  # Evaluation metrics
    use_task_mask: bool = False  # Whether to use [gMASK] for evaluation
    use_multitask_encoding: bool = False  # Not supported for now
    unidirectional: bool = False  # Whether to use unidirectional attention
    max_seq_length: int = 2048  # Max sequence length
    no_tokenized: bool = False  # Undocumented upstream; presumably the data is raw text rather than pre-tokenized
    file_pattern: str | Dict[str, str] = "**/*.json*"  # Organize data files in groups
    micro_batch_size: int = 1  # 'gen' tasks only support micro_batch_size = 1 for now

    def __post_init__(self):
        assert self.use_task_mask or not self.unidirectional, "[MASK] doesn't support unidirectional attention"
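
# A minimal loading sketch (the file name and field values below are assumed
# for illustration, not shipped with this module): YAMLWizard contributes
# `from_yaml` / `from_yaml_file` / `to_yaml`, so a task file such as
#
#   name: my-task
#   type: mul            # parsed into TaskType.MULTICHOICE via the enum value
#   path: my-task/data
#
# can be loaded with `BaseConfig.from_yaml_file("my-task.yaml")`.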


@dataclass
class MultiChoiceTaskConfig(BaseConfig):
    # Annotated so the dataclass machinery uses this as the field default;
    # a bare class attribute would be shadowed by BaseConfig's `module=None`
    # default when the generated __init__ assigns the instance attribute.
    module: Optional[str] = "evaluation.MultiChoiceTask"
    metrics: List[str] = field(default_factory=lambda: ["Accuracy"])


@dataclass
class GenerationTaskConfig(BaseConfig):
    module: Optional[str] = "evaluation.GenerationTask"  # Annotated for the same reason as above
    metrics: List[str] = field(default_factory=list)
    sampling_strategy: str = "BaseStrategy"
    num_beams: int = 4
    length_penalty: float = 1.0
    no_repeat_ngram_size: int = 3
    min_gen_length: int = 0
    max_gen_length: int = 128
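
# Note (an assumption from the strategy names, not verified against the
# evaluation code): `sampling_strategy` selects the decoding strategy, and the
# beam-search fields above (`num_beams`, `length_penalty`,
# `no_repeat_ngram_size`) should only take effect when a beam-search strategy
# is selected in place of the default "BaseStrategy".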


@dataclass
class LanguageModelTaskConfig(BaseConfig):
    module: Optional[str] = "evaluation.LanguageModelTask"  # Annotated for the same reason as above
    metrics: List[str] = field(default_factory=lambda: ["PPL"])
    generation_length: int = 256  # Generated length in each window
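

if __name__ == "__main__":
    # Hedged round-trip example (the field values are made up for the demo):
    # construct a config in code and serialize it with YAMLWizard's `to_yaml`.
    cfg = GenerationTaskConfig(
        name="demo-gen",
        type=TaskType.GENERATION,
        path="demo-gen/data",
        max_gen_length=64,
    )
    print(cfg.to_yaml())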