configs.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import MISSING, dataclass, field
from typing import Any, Dict, List, Literal, Optional, Union, get_args, get_origin


@dataclass
class Config:
    def serialize(self):
        asdict = {}
        for key in self.__dataclass_fields__.keys():
            value = getattr(self, key)
            if isinstance(value, Config):
                asdict[key] = value.serialize()
            else:
                asdict[key] = value
        return asdict

    @classmethod
    def _is_config(cls, type_like: Any) -> bool:
        """Checks if type_like is a subclass of Config."""
        try:
            if issubclass(type_like, Config):
                return True
        except TypeError:
            pass
        return False

    @classmethod
    def _is_optional_config(cls, type_like: Any) -> bool:
        """Checks if type_like == Optional[subclass of Config]."""
        if get_origin(type_like) is not Union:
            return False
        args = [arg for arg in get_args(type_like) if arg is not type(None)]
        return len(args) == 1 and cls._is_config(args[0])

    @classmethod
    def deserialize(cls, asdict: Dict[str, Any]):
        kwargs = {}
        for key, field_desc in cls.__dataclass_fields__.items():
            non_null = asdict.get(key) is not None
            # Optional[Config]
            if cls._is_optional_config(field_desc.type):
                if non_null:
                    type_arg = [
                        arg for arg in get_args(field_desc.type) if arg is not type(None)
                    ][0]
                    kwargs[key] = type_arg.deserialize(asdict[key])
                else:
                    kwargs[key] = None
            # TODO: add containers with Config
            elif get_origin(field_desc.type) in [Union, List, Dict, Literal]:
                kwargs[key] = asdict.get(key)
            elif cls._is_config(field_desc.type):
                if non_null:
                    kwargs[key] = field_desc.type.deserialize(asdict[key])
                else:
                    # Fall back to the field's declared default; nested Config
                    # defaults are built via default_factory.
                    kwargs[key] = (
                        field_desc.default_factory()
                        if field_desc.default_factory is not MISSING
                        else field_desc.default
                    )
            else:
                kwargs[key] = asdict.get(key)
        return cls(**kwargs)
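

# A minimal round-trip sketch (illustrative only, not part of the library API;
# `_Inner` and `_Outer` are hypothetical classes): any Config subclass can be
# dumped to a plain dict with serialize() and rebuilt with deserialize(),
# including nested and Optional[Config] fields.
#
#     @dataclass
#     class _Inner(Config):
#         dim: int = 8
#
#     @dataclass
#     class _Outer(Config):
#         inner: Optional[_Inner] = None
#         name: str = "demo"
#
#     blob = _Outer(inner=_Inner(dim=16)).serialize()
#     assert blob == {"inner": {"dim": 16}, "name": "demo"}
#     assert _Outer.deserialize(blob) == _Outer(inner=_Inner(dim=16))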


@dataclass
class TextTokenizationConfig(Config):
    from_model: Optional[str] = "seamlessM4T_large"
    """If set, use a tokenizer from the model cards."""
    spm_path: Optional[str] = None
    """Path to a custom spm model. Not used if `from_model` is set."""
    langtoks: Optional[List[str]] = None
    """List of language tokens that should be added. Not used if `from_model` is set."""


@dataclass
class UnitTokenizationConfig(Config):
    from_model: Optional[str] = "seamlessM4T_large"
    """If set, use a tokenizer from a model card."""
    num_units: Optional[int] = None
    """Alternatively, build a custom tokenizer with this number of units."""
    langtoks: Optional[List[str]] = None
    """List of language tokens that should be added. Not used if `from_model` is set."""


@dataclass
class AudioProcessingConfig(Config):
    audio_root_dir: str = "/"
    """The root directory of the zipped audio files."""
    fbanks_standardize_audio: bool = True
    fbanks_num_mel_bins: int = 80
    fbanks_waveform_scale: int = 2**15


@dataclass
class DataLoadingConfig(Config):
    manifest_list_path: Optional[str] = None
    """Path to a file with the list of tsv manifests."""
    manifest_list: Optional[str] = None
    """Comma-separated list of tsv manifests. Can be combined with `manifest_list_path`."""
    manifest_path_prefix: Optional[str] = None
    """Path prefix to manifest files (root directory)."""
    audio: AudioProcessingConfig = field(default_factory=AudioProcessingConfig)
    """Audio processing params."""
    text_tokenization: TextTokenizationConfig = field(default_factory=TextTokenizationConfig)
    """Text tokenization params."""
    unit_tokenization: UnitTokenizationConfig = field(default_factory=UnitTokenizationConfig)
    """Unit tokenization params."""
    unit_tokenizer_name: Optional[str] = "seamlessM4T_large"
    prepend_tgt_lang_tag: bool = True
    """Prepend output text sequence with target lang token."""
    fbank_feats_pad_idx: int = 0
    """The pad index to use in fbanks batching."""
    max_tgt_text_tokens_per_batch: Optional[int] = 1000
    """Defines flexible batch construction."""
    fixed_batch_size: Optional[int] = None
    """If set, uses a fixed batch size."""
    max_seconds_per_input_audio: int = 15
    """Accept only samples whose audio duration (waveform.shape[0] / sample_rate) is below this limit."""
    max_tgt_text_tokens_per_sample: int = 300
    """Accept only samples with fewer than `max_tgt_text_tokens_per_sample` target text tokens."""
    max_units_per_sample: int = 1500
    """Accept only samples with fewer than `max_units_per_sample` units."""
    num_threads: int = 5
    """The number of parallel threads during data reading and processing."""
    shuffle_window: Optional[int] = 1000
    """The size of the sliding shuffle window."""
    prefech_batches: Optional[int] = 10
    """How many batches to prefetch in the background."""


@dataclass
class CustomModelParams(Config):
    model_embed_dim: int = 1024
    w2v2_encoder_layers: int = 24
    w2v2_encoder_layers_use_conformer: bool = True
    w2v2_encoder_layers_layernorm_features: bool = False
    w2v2_pos_encoder_type: Literal["conv", "relative", "rotary"] = "relative"
    w2v2_pos_encoder_depth: int = 0
    w2v2_pos_conv_kernel_size: int = 0
    w2v2_num_pos_conv_groups: int = 0
    nllb_encoder_layers: int = 24
    nllb_decoder_layers: int = 24
    t2u_encoder_layers: int = 6
    t2u_decoder_layers: int = 6
    nllb_vocabulary_size: int = 256102  # num_tokens + langs + special symbols
    unit_vocabulary_size: int = 10082


@dataclass
class ModelConfig(Config):
    from_model: Optional[str] = None
    """If set, initialize a model defined in the model cards. Also loads model weights."""
    from_model_config: Optional[str] = None
    """If set, initialize a model defined in the model cards. Doesn't load weights."""
    custom_params: Optional[CustomModelParams] = None
    """If set, initialize a new model with custom parameters."""
    pretrained_w2v2_path: Optional[str] = None
    """If set, use a pre-trained w2v2 block."""
    pretrained_s2t_decoder_path: Optional[str] = None
    """If set, use a pre-trained s2t decoder (NLLB)."""
    pretrained_t2u_path: Optional[str] = None
    """If set, use pre-trained t2u weights."""


@dataclass
class TrainingParams(Config):
    max_epochs: int = 100
    """Maximum number of training epochs."""
    label_smoothing: float = 0.2
    """Label smoothing coefficient for nll_loss."""
    warmup_steps: int = 1000
    """Number of steps with linearly increasing LR."""
    log_steps: int = 200
    """Log inner loss after each `log_steps` training steps."""
    eval_steps: int = 1000
    """Get eval loss after each `eval_steps` training steps."""
    patience: int = 10
    """Terminate if eval loss did not improve
    over the last `patience * eval_steps` training steps."""
    learning_rate: float = 1e-4
    """Optimizer learning rate."""
    start_learning_rate: float = 1e-7
    """Start learning rate."""
    float_dtype: Literal["fp16", "bf16", "fp32"] = "bf16"
    """Dtype used for float numbers, defines training precision."""


@dataclass
class WorkflowParams(Config):
    training: TrainingParams
    model: ModelConfig
    train_data: DataLoadingConfig
    eval_data: DataLoadingConfig
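

# Minimal usage sketch (illustrative, not part of the library): build a full
# workflow config from defaults, serialize it to a plain dict (e.g. before
# dumping to YAML/JSON), and rebuild an equal object from that dict.
if __name__ == "__main__":
    params = WorkflowParams(
        training=TrainingParams(),
        model=ModelConfig(custom_params=CustomModelParams()),
        train_data=DataLoadingConfig(),
        eval_data=DataLoadingConfig(),
    )
    blob = params.serialize()  # nested Configs become nested dicts
    assert blob["training"]["max_epochs"] == 100
    restored = WorkflowParams.deserialize(blob)
    assert restored == params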