configs.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Optional, Union, get_args, get_origin

import yaml

@dataclass
class Config:
    def serialize(self) -> Dict[str, Any]:
        """Recursively convert this config (and any nested Configs) to a plain dict."""
        asdict: Dict[str, Any] = {}
        for key in self.__dataclass_fields__.keys():
            value = getattr(self, key)
            if isinstance(value, Config):
                asdict[key] = value.serialize()
            else:
                asdict[key] = value
        return asdict

    @classmethod
    def _is_config(cls, type_like: Any) -> bool:
        """Check whether `type_like` is a subclass of Config."""
        try:
            if issubclass(type_like, Config):
                return True
        except TypeError:
            pass
        return False

    @classmethod
    def _is_optional_config(cls, type_like: Any) -> bool:
        """Check whether `type_like` is Optional[<subclass of Config>]."""
        if get_origin(type_like) is not Union:
            return False
        args = [arg for arg in get_args(type_like) if arg is not type(None)]
        return len(args) == 1 and cls._is_config(args[0])

    @classmethod
    def deserialize(cls, asdict: Dict[str, Any]):
        kwargs = {}
        for key, field_desc in cls.__dataclass_fields__.items():
            non_null = asdict.get(key) is not None
            # Optional[Config]
            if cls._is_optional_config(field_desc.type):
                if non_null:
                    type_arg = [
                        arg
                        for arg in get_args(field_desc.type)
                        if arg is not type(None)
                    ][0]
                    kwargs[key] = type_arg.deserialize(asdict[key])
                else:
                    kwargs[key] = None
            # TODO: add containers with Config
            # NB: get_origin() returns the bare `list`/`dict` for List/Dict annotations.
            elif get_origin(field_desc.type) in (Union, Literal, list, dict):
                kwargs[key] = asdict.get(key)
            elif cls._is_config(field_desc.type):
                if non_null:
                    kwargs[key] = field_desc.type.deserialize(asdict[key])
                else:
                    # Fall back to the dataclass default for a missing nested config.
                    kwargs[key] = field_desc.default  # type: ignore
            else:
                kwargs[key] = asdict.get(key)
        return cls(**kwargs)

    @classmethod
    def from_string(cls, serialized_config: str):
        return cls.deserialize(yaml.load(serialized_config, Loader=yaml.FullLoader))

    @classmethod
    def from_file(cls, config_path: str):
        # Read the file before parsing; passing the path string straight to
        # yaml.load() would parse the path itself as a YAML document.
        with open(config_path, "r") as fp_in:
            return cls.deserialize(yaml.load(fp_in, Loader=yaml.FullLoader))
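
# Round-trip example (a minimal sketch; `ExampleConfig` is hypothetical and
# not part of this module):
#
#     @dataclass
#     class ExampleConfig(Config):
#         name: Optional[str] = None
#
#     cfg = ExampleConfig.from_string("name: demo")
#     assert cfg.serialize() == {"name": "demo"}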


@dataclass
class TextTokenizationConfig(Config):
    from_model: Optional[str] = "seamlessM4T_large"
    """If set, use the text tokenizer from this model card."""
    spm_path: Optional[str] = None
    """Path to a custom spm model. Not used if `from_model` is set."""
    langtoks: Optional[List[str]] = None
    """List of language tokens to add. Not used if `from_model` is set."""


@dataclass
class UnitTokenizationConfig(Config):
    from_model: Optional[str] = "seamlessM4T_large"
    """If set, use the unit tokenizer from this model card."""
    num_units: Optional[int] = None
    """Alternatively, build a custom tokenizer with this number of units."""
    langtoks: Optional[List[str]] = None
    """List of language tokens to add. Not used if `from_model` is set."""


@dataclass
class AudioProcessingConfig(Config):
    audio_root_dir: str = "/"
    """The root directory of the zipped audio files."""
    fbanks_standardize_audio: bool = True
    fbanks_num_mel_bins: int = 80
    fbanks_waveform_scale: int = 2**15


@dataclass
class DataLoadingConfig(Config):
    manifest_list_path: Optional[str] = None
    """Path to a file with the list of tsv manifests."""
    manifest_list: Optional[str] = None
    """Comma-separated list of tsv manifests. Can be combined with `manifest_list_path`."""
    manifest_path_prefix: Optional[str] = None
    """Path prefix for the manifest files (root directory)."""
    audio: AudioProcessingConfig = AudioProcessingConfig()
    """Audio processing params."""
    text_tokenization: TextTokenizationConfig = TextTokenizationConfig()
    """Text tokenization params."""
    unit_tokenization: UnitTokenizationConfig = UnitTokenizationConfig()
    """Unit tokenization params."""
    unit_tokenizer_name: Optional[str] = "seamlessM4T_large"
    prepend_tgt_lang_tag: bool = True
    """Prepend the output text sequence with the target language token."""
    fbank_feats_pad_idx: int = 0
    """The pad index to use when batching fbanks."""
    max_tgt_text_tokens_per_batch: Optional[int] = 1000
    """Defines flexible batch construction: cap on target text tokens per batch."""
    max_batch_size: Optional[int] = None
    """In flexible batch construction, the max allowed number of samples per batch."""
    fixed_batch_size: Optional[int] = None
    """If set, use a fixed batch size instead."""
    max_seconds_per_input_audio: int = 15
    """Accept only samples whose audio is shorter than this duration in seconds
    (duration = waveform.shape[0] / sample_rate)."""
    max_tgt_text_tokens_per_sample: int = 300
    """Accept only samples with fewer target text tokens than this."""
    max_units_per_sample: int = 1500
    """Accept only samples with fewer units than this."""
    num_threads: int = 5
    """The number of parallel threads used for data reading and processing."""
    shuffle_window: Optional[int] = 1000
    """The size of the sliding shuffle window."""
    prefetch_batches: Optional[int] = 10
    """How many batches to prefetch in the background."""


@dataclass
class CustomModelParams(Config):
    model_embed_dim: int = 1024
    w2v2_encoder_layers: int = 24
    w2v2_encoder_layers_use_conformer: bool = True
    w2v2_encoder_layers_layernorm_features: bool = False
    w2v2_pos_encoder_type: Literal["conv", "relative", "rotary"] = "relative"
    w2v2_pos_encoder_depth: int = 0
    w2v2_pos_conv_kernel_size: int = 0
    w2v2_num_pos_conv_groups: int = 0
    nllb_encoder_layers: int = 24
    nllb_decoder_layers: int = 24
    t2u_encoder_layers: int = 6
    t2u_decoder_layers: int = 6
    nllb_vocabulary_size: int = 256102  # num_tokens + langs + special symbols
    unit_vocabulary_size: int = 10082


@dataclass
class ModelConfig(Config):
    from_model: Optional[str] = None
    """If set, initialize the model defined by this model card and load its weights."""
    from_model_config: Optional[str] = None
    """If set, initialize the model defined by this model card without loading weights."""
    custom_params: Optional[CustomModelParams] = None
    """If set, initialize a new model with these custom parameters."""
    pretrained_w2v2_path: Optional[str] = None
    """If set, use a pre-trained w2v2 block."""
    pretrained_s2t_decoder_path: Optional[str] = None
    """If set, use a pre-trained s2t decoder (NLLB)."""
    pretrained_t2u_path: Optional[str] = None
    """If set, use pre-trained t2u weights."""


@dataclass
class TrainingParams(Config):
    max_epochs: int = 100
    """Maximum number of training epochs."""
    label_smoothing: float = 0.2
    """Label smoothing coefficient for nll_loss."""
    warmup_steps: int = 1000
    """Number of steps with linearly increasing learning rate."""
    log_steps: int = 200
    """Log the training loss every `log_steps` training steps."""
    eval_steps: int = 1000
    """Compute the eval loss every `eval_steps` training steps."""
    patience: int = 10
    """Terminate if the eval loss has not improved
    over the last `patience * eval_steps` training steps."""
    learning_rate: float = 1e-4
    """Optimizer learning rate."""
    start_learning_rate: float = 1e-7
    """Initial learning rate at the start of warmup."""
    float_dtype: Literal["fp16", "bf16", "fp32"] = "bf16"
    """Dtype used for floating-point numbers; defines training precision."""


@dataclass
class WorkflowParams(Config):
    training: TrainingParams
    model: ModelConfig
    train_data: DataLoadingConfig
    eval_data: DataLoadingConfig
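

# Minimal smoke test (an illustrative sketch, not part of the training CLI):
# build a default workflow config and round-trip it through YAML.
if __name__ == "__main__":
    workflow = WorkflowParams(
        training=TrainingParams(),
        model=ModelConfig(from_model="seamlessM4T_large"),
        train_data=DataLoadingConfig(),
        eval_data=DataLoadingConfig(),
    )
    restored = WorkflowParams.from_string(yaml.dump(workflow.serialize()))
    assert restored.serialize() == workflow.serialize()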