# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field
from typing import Any, Dict, List, Literal, Optional, Union, get_args, get_origin


@dataclass
class Config:
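    """Base class for dict-serializable dataclass configs.

    Subclasses get `serialize()` / `deserialize()` round-tripping through plain
    dicts, including nested `Config` and `Optional[Config]` fields.

    Illustrative sketch (`_Inner` and `_Outer` are hypothetical classes, not
    part of this module):

    >>> @dataclass
    ... class _Inner(Config):
    ...     value: int = 1
    >>> @dataclass
    ... class _Outer(Config):
    ...     inner: Optional[_Inner] = None
    >>> blob = _Outer(inner=_Inner(value=2)).serialize()
    >>> blob
    {'inner': {'value': 2}}
    >>> _Outer.deserialize(blob).inner.value
    2
    """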

    def serialize(self):
        asdict = {}
        for key in self.__dataclass_fields__.keys():
            value = getattr(self, key)
            if isinstance(value, Config):
                asdict[key] = value.serialize()
            else:
                asdict[key] = value
        return asdict

    @classmethod
    def _is_config(cls, type_like: Any) -> bool:
        """Checks whether `type_like` is a subclass of `Config`."""
        try:
            if issubclass(type_like, Config):
                return True
        except TypeError:
            pass
        return False

    @classmethod
    def _is_optional_config(cls, type_like: Any) -> bool:
        """Checks whether `type_like` is `Optional[<subclass of Config>]`."""
        if get_origin(type_like) is not Union:
            return False
        args = [arg for arg in get_args(type_like) if arg is not type(None)]
        return len(args) == 1 and cls._is_config(args[0])

    @classmethod
    def deserialize(cls, asdict: Dict[str, Any]):
        kwargs = {}
        for key, field_desc in cls.__dataclass_fields__.items():
            non_null = asdict.get(key) is not None
            if cls._is_optional_config(field_desc.type):
                # Optional[Config]: recurse into the nested config if present.
                if non_null:
                    type_arg = [
                        arg
                        for arg in get_args(field_desc.type)
                        if arg is not type(None)
                    ][0]
                    kwargs[key] = type_arg.deserialize(asdict[key])
                else:
                    kwargs[key] = None
            # TODO: add containers with Config
            elif get_origin(field_desc.type) in (Union, list, dict, Literal):
                # Unions, plain containers and literals are passed through as-is.
                kwargs[key] = asdict.get(key)
            elif cls._is_config(field_desc.type):
                # Non-optional nested Config: recurse, or fall back to a freshly
                # constructed default instance if the key is missing or null.
                if non_null:
                    kwargs[key] = field_desc.type.deserialize(asdict[key])
                else:
                    kwargs[key] = field_desc.type()
            else:
                kwargs[key] = asdict.get(key)
        return cls(**kwargs)


@dataclass
class TextTokenizationConfig(Config):
    from_model: Optional[str] = "seamlessM4T_large"
    """If set, use the text tokenizer from this model card."""
    spm_path: Optional[str] = None
    """Path to a custom spm model. Not used if `from_model` is set."""
    langtoks: Optional[List[str]] = None
    """List of language tokens that should be added. Not used if `from_model` is set."""


@dataclass
class UnitTokenizationConfig(Config):
    from_model: Optional[str] = "seamlessM4T_large"
    """If set, use the unit tokenizer from this model card."""
    num_units: Optional[int] = None
    """Alternatively, build a custom tokenizer with this number of units."""
    langtoks: Optional[List[str]] = None
    """List of language tokens that should be added. Not used if `from_model` is set."""


@dataclass
class AudioProcessingConfig(Config):
    audio_root_dir: str = "/"
    """The root directory of the zipped audio files."""
    fbanks_standardize_audio: bool = True
    fbanks_num_mel_bins: int = 80
    fbanks_waveform_scale: int = 2**15


@dataclass
class DataLoadingConfig(Config):
    manifest_list_path: Optional[str] = None
    """Path to a file with the list of tsv manifests."""
    manifest_list: Optional[str] = None
    """Comma-separated list of tsv manifests. Can be combined with `manifest_list_path`."""
    manifest_path_prefix: Optional[str] = None
    """Path prefix to the manifest files (root directory)."""
    # Nested configs use default_factory so that each instance gets its own
    # default sub-config (and to satisfy the mutable-default check on Python 3.11+).
    audio: AudioProcessingConfig = field(default_factory=AudioProcessingConfig)
    """Audio processing params."""
    text_tokenization: TextTokenizationConfig = field(
        default_factory=TextTokenizationConfig
    )
    """Text tokenization params."""
    unit_tokenization: UnitTokenizationConfig = field(
        default_factory=UnitTokenizationConfig
    )
    """Unit tokenization params."""
    unit_tokenizer_name: Optional[str] = "seamlessM4T_large"
    prepend_tgt_lang_tag: bool = True
    """Prepend the output text sequence with the target language token."""
    fbank_feats_pad_idx: int = 0
    """The pad index to use when batching fbanks."""
    max_tgt_text_tokens_per_batch: Optional[int] = 1000
    """Cap on target text tokens per batch; defines flexible batch construction."""
    fixed_batch_size: Optional[int] = None
    """If set, uses a fixed batch size instead."""
    max_seconds_per_input_audio: int = 15
    """Accept only samples whose audio duration (waveform.shape[0] / SR) is below this limit."""
    max_tgt_text_tokens_per_sample: int = 300
    """Accept only samples with fewer than `max_tgt_text_tokens_per_sample` target text tokens."""
    max_units_per_sample: int = 1500
    """Accept only samples with fewer than `max_units_per_sample` units."""
    num_threads: int = 5
    """The number of parallel threads used for data reading and processing."""
    shuffle_window: Optional[int] = 1000
    """The size of the sliding shuffle window."""
    prefech_batches: Optional[int] = 10
    """How many batches to prefetch in the background."""


@dataclass
class CustomModelParams(Config):
    model_embed_dim: int = 1024
    w2v2_encoder_layers: int = 24
    w2v2_encoder_layers_use_conformer: bool = True
    w2v2_encoder_layers_layernorm_features: bool = False
    w2v2_pos_encoder_type: Literal["conv", "relative", "rotary"] = "relative"
    w2v2_pos_encoder_depth: int = 0
    w2v2_pos_conv_kernel_size: int = 0
    w2v2_num_pos_conv_groups: int = 0
    nllb_encoder_layers: int = 24
    nllb_decoder_layers: int = 24
    t2u_encoder_layers: int = 6
    t2u_decoder_layers: int = 6
    nllb_vocabulary_size: int = 256102  # num_tokens + langs + special symbols
    unit_vocabulary_size: int = 10082


@dataclass
class ModelConfig(Config):
    from_model: Optional[str] = None
    """If set, initialize a model defined in the model cards and load its weights."""
    from_model_config: Optional[str] = None
    """If set, initialize a model defined in the model cards without loading weights."""
    custom_params: Optional[CustomModelParams] = None
    """If set, initialize a new model with custom parameters."""
    pretrained_w2v2_path: Optional[str] = None
    """If set, use a pre-trained w2v2 block."""
    pretrained_s2t_decoder_path: Optional[str] = None
    """If set, use a pre-trained s2t decoder (NLLB)."""
    pretrained_t2u_path: Optional[str] = None
    """If set, use pre-trained t2u weights."""


@dataclass
class TrainingParams(Config):
    max_epochs: int = 100
    """Maximum number of training epochs."""
    label_smoothing: float = 0.2
    """Label smoothing coefficient for nll_loss."""
    warmup_steps: int = 1000
    """Number of steps with linearly increasing LR."""
    log_steps: int = 200
    """Log the training loss every `log_steps` training steps."""
    eval_steps: int = 1000
    """Compute the eval loss every `eval_steps` training steps."""
    patience: int = 10
    """Terminate if the eval loss did not improve over the last
    `patience * eval_steps` training steps."""
    learning_rate: float = 1e-4
    """Optimizer learning rate."""
    start_learning_rate: float = 1e-7
    """Learning rate at the start of warmup."""
    float_dtype: Literal["fp16", "bf16", "fp32"] = "bf16"
    """Dtype used for floating-point numbers; defines the training precision."""


@dataclass
class WorkflowParams(Config):
    training: TrainingParams
    model: ModelConfig
    train_data: DataLoadingConfig
    eval_data: DataLoadingConfig
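

# Illustrative usage sketch, not part of the original module: it builds a
# WorkflowParams from the defaults above and round-trips it through
# serialize()/deserialize() (the resulting plain dict could be dumped to
# YAML/JSON). The `params` / `as_dict` / `restored` names are hypothetical.
if __name__ == "__main__":
    params = WorkflowParams(
        training=TrainingParams(),
        model=ModelConfig(custom_params=CustomModelParams()),
        train_data=DataLoadingConfig(),
        eval_data=DataLoadingConfig(max_seconds_per_input_audio=20),
    )
    as_dict = params.serialize()
    restored = WorkflowParams.deserialize(as_dict)
    assert restored == params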
|