瀏覽代碼

* Training recipees for M4T-nano/-micro\n* Adjustments to fairseq2/sc updates\n* Fixing MyPy warnings (#59)

Ruslan Mavlyutov 1 年之前
父節點
當前提交
ca1ebf90ea
共有 29 個文件被更改,包括 2817 次插入77 次删除
  1. 247 0
      src/seamless_communication/cli/m4t/train/cleaners.py
  2. 10 1
      src/seamless_communication/cli/m4t/train/configs.py
  3. 7 6
      src/seamless_communication/cli/m4t/train/dataloader.py
  4. 120 29
      src/seamless_communication/cli/m4t/train/model.py
  5. 101 0
      src/seamless_communication/cli/m4t/train/recipes/asr_1024_1_3_wh_transc_120ch.yaml
  6. 99 0
      src/seamless_communication/cli/m4t/train/recipes/asr_768_8_4_wh_transc.yaml
  7. 98 0
      src/seamless_communication/cli/m4t/train/recipes/asr_micro_5x5m.yaml
  8. 98 0
      src/seamless_communication/cli/m4t/train/recipes/asr_micro_5x5m_tune_on_noise.yaml
  9. 99 0
      src/seamless_communication/cli/m4t/train/recipes/asr_micro_optimized_5x5m.yaml
  10. 98 0
      src/seamless_communication/cli/m4t/train/recipes/asr_mini_5x5m.yaml
  11. 98 0
      src/seamless_communication/cli/m4t/train/recipes/asr_nano_5x5m.yaml
  12. 98 0
      src/seamless_communication/cli/m4t/train/recipes/asr_nano_5x5m_tune_on_noise.yaml
  13. 99 0
      src/seamless_communication/cli/m4t/train/recipes/asr_nano_optimized_5x5m.yaml
  14. 98 0
      src/seamless_communication/cli/m4t/train/recipes/asr_small_5x5m.yaml
  15. 11 9
      src/seamless_communication/cli/m4t/train/recipes/asr_small_eng_10m_wh.yaml
  16. 99 0
      src/seamless_communication/cli/m4t/train/recipes/asr_small_old_and_10m_wh.yaml
  17. 99 0
      src/seamless_communication/cli/m4t/train/recipes/asr_small_vit_transc_eng.yaml
  18. 101 0
      src/seamless_communication/cli/m4t/train/recipes/asr_small_wh_mms_120ch.yaml
  19. 101 0
      src/seamless_communication/cli/m4t/train/recipes/asr_small_wh_mms_120ch_fp16.yaml
  20. 101 0
      src/seamless_communication/cli/m4t/train/recipes/asr_small_wh_mms_120ch_nost.yaml
  21. 11 9
      src/seamless_communication/cli/m4t/train/recipes/asr_small_wh_transc.yaml
  22. 99 0
      src/seamless_communication/cli/m4t/train/recipes/asr_small_wh_transc_mms.yaml
  23. 101 0
      src/seamless_communication/cli/m4t/train/recipes/asr_small_wh_vary_audio.yaml
  24. 101 0
      src/seamless_communication/cli/m4t/train/recipes/asr_wide_wh_120ch.yaml
  25. 1 1
      src/seamless_communication/cli/m4t/train/recipes/large_M4T_v1.yaml
  26. 99 0
      src/seamless_communication/cli/m4t/train/recipes/mt_small_orig_dataset.yaml
  27. 99 0
      src/seamless_communication/cli/m4t/train/recipes/mt_small_orig_dataset_and_eng_10m.yaml
  28. 340 0
      src/seamless_communication/cli/m4t/train/run_eval.py
  29. 84 22
      src/seamless_communication/cli/m4t/train/trainer.py

+ 247 - 0
src/seamless_communication/cli/m4t/train/cleaners.py

@@ -0,0 +1,247 @@
+import logging
+import re
+import string
+from typing import Callable, List
+
+from unidecode import unidecode
+
+logger = logging.getLogger(__name__)
+
+
+# Regular expression matching whitespace:
+_whitespace_re = re.compile(r"\s+")
+
+# List of (regular expression, replacement) pairs for abbreviations:
+_abbreviations = {
+    "en": [
+        (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
+        for x in [
+            ("mrs", "misess"),
+            ("mr", "mister"),
+            ("dr", "doctor"),
+            ("st", "saint"),
+            ("co", "company"),
+            ("jr", "junior"),
+            ("maj", "major"),
+            ("gen", "general"),
+            ("drs", "doctors"),
+            ("rev", "reverend"),
+            ("lt", "lieutenant"),
+            ("hon", "honorable"),
+            ("sgt", "sergeant"),
+            ("capt", "captain"),
+            ("esq", "esquire"),
+            ("ltd", "limited"),
+            ("col", "colonel"),
+            ("ft", "fort"),
+        ]
+    ],
+}
+
+
+def expand_abbreviations(text, lang="en"):
+    if lang not in _abbreviations:
+        return text
+
+    for regex, replacement in _abbreviations[lang]:
+        text = re.sub(regex, replacement, text)
+    return text
+
+
+def expand_numbers(text, lang="en"):
+    # return normalize_numbers(text, lang)
+    return text
+
+
+def lowercase(text):
+    return text.lower()
+
+
+def collapse_whitespace(text):
+    return re.sub(_whitespace_re, " ", text)
+
+
+def convert_to_ascii(text):
+    return unidecode(text)
+
+
+def basic_cleaners(text):
+    """Basic pipeline that lowercases and collapses whitespace without transliteration."""
+    text = lowercase(text)
+    text = collapse_whitespace(text)
+    return text
+
+
+def transliteration_cleaners(text):
+    """Pipeline for non-English text that transliterates to ASCII."""
+    text = convert_to_ascii(text)
+    text = lowercase(text)
+    text = collapse_whitespace(text)
+    return text
+
+
+PUNCTUATIONS_EXCLUDE_APOSTROPHE = (
+    string.punctuation.replace("'", "") + "¡¨«°³º»¿‘“”…♪♫ˆᵉ™,ʾ˚"
+)
+PUNCTUATIONS_TO_SPACE = "-/–·—•"
+
+
+def remove_punctuations(text, punctuations=string.punctuation):
+    text = text.translate(
+        str.maketrans(PUNCTUATIONS_TO_SPACE, " " * len(PUNCTUATIONS_TO_SPACE))
+    )
+    return text.translate(str.maketrans("", "", punctuations))
+
+
+def remove_parentheses(text: str) -> str:
+    # remove all substring within () or []
+    out = ""
+    num_p = 0
+    start_i = 0
+    for i, c in enumerate(text):
+        if c == "(" or c == "[" or c == "(":
+            if num_p == 0 and i > start_i:
+                out += text[start_i:i]
+            num_p += 1
+        elif c == ")" or c == "]" or c == ")":
+            num_p -= 1
+            if num_p == 0:
+                start_i = i + 1
+
+    if len(text) > start_i:
+        out += text[start_i:]
+
+    return out.strip()
+
+
+REMAP_CHARS = {
+    "`": "'",
+    "’ ": " ",
+    "’": "'",
+}
+
+
+def remap_chars(text, remap_chars=REMAP_CHARS):
+    for k, v in remap_chars.items():
+        text = text.replace(k, v)
+    return text
+
+
+def expand_capitals(text):
+    words = text.split()
+    for i, w in enumerate(words):
+        if w.isupper():
+            words[i] = " ".join(w)
+
+    return " ".join(words)
+
+
+def english_cleaners(text, punctuations=string.punctuation):
+    """Pipeline for English text, including number and abbreviation expansion."""
+    text = convert_to_ascii(text)
+    text = remap_chars(text)
+    text = lowercase(text)
+    text = expand_numbers(text)
+    text = expand_abbreviations(text)
+    text = remove_parentheses(text)
+    text = remove_punctuations(text, punctuations=punctuations)
+    text = collapse_whitespace(text)
+    text = text.strip()
+    return text
+
+
+def english_cleaners_keep_apostrophe(text):
+    return english_cleaners(text, punctuations=PUNCTUATIONS_EXCLUDE_APOSTROPHE)
+
+
+def fisher_text_cleaners(text):
+    # remove the convert_to_ascii cleaner to keep Spanish characters
+    text = lowercase(text)
+    text = expand_numbers(text)
+    text = expand_abbreviations(text)
+    text = remove_punctuations(text, punctuations=PUNCTUATIONS_EXCLUDE_APOSTROPHE)
+    text = collapse_whitespace(text)
+    return text
+
+
+def text_cleaners(text, lang="en"):
+    if lang == "hok":
+        # no op for Hokkien TaiLo
+        return text
+
+    text = remap_chars(text)
+    text = expand_capitals(text)
+    text = lowercase(text)
+    text = remove_parentheses(text)
+
+    if lang == "zh":
+        raise NotImplementedError()
+    if lang in ["en", "fr", "es", "nl", "de", "bn"]:
+        try:
+            text = expand_numbers(text, lang)
+        except Exception:
+            logger.exception("Failed to expand numbers")
+            raise
+    text = expand_abbreviations(text, lang)
+    if lang == "zh":
+        raise NotImplementedError()
+    else:
+        text = remove_punctuations(text, punctuations=PUNCTUATIONS_EXCLUDE_APOSTROPHE)
+        text = collapse_whitespace(text)
+    if lang == "ar":
+        raise NotImplementedError()
+    text = text.strip()
+    return text
+
+
+def apply_text_functions(text_funcs: List[Callable], text: str) -> str:
+    for func in text_funcs:
+        text = func(text)
+    return text
+
+
+def merge_tailo_init_final(text):
+    sps = text.strip().split()
+    results = []
+    last_syllable = ""
+    for sp in sps:
+        if sp == "NULLINIT":
+            continue
+        last_syllable += sp
+        if sp[-1].isnumeric():
+            results.append(last_syllable)
+            last_syllable = ""
+    if last_syllable != "":
+        results.append(last_syllable)
+    return " ".join(results)
+
+
+def _numeric_feature_by_regex(regex, s):
+    match = re.search(regex, s)
+    if match is None:
+        return -50
+    return int(match.group(1))
+
+
+def normalize_text_references(refs, lang):
+    norm_refs = []
+    for text in refs:
+        text = basic_cleaners(text)
+        text = remove_punctuations(text, PUNCTUATIONS_EXCLUDE_APOSTROPHE)
+        if lang == "ja":
+            raise NotImplementedError()
+        norm_refs.append(text)
+    return norm_refs
+
+
+def normalize_text_whisper(refs, lang):
+    from whisper.normalizers import BasicTextNormalizer, EnglishTextNormalizer  # type: ignore
+
+    if lang in ["en", "eng"]:
+        normalizer = EnglishTextNormalizer()
+    else:
+        normalizer = BasicTextNormalizer()
+    norm_refs = []
+    for text in refs:
+        norm_refs.append(normalizer(text))
+    return norm_refs

+ 10 - 1
src/seamless_communication/cli/m4t/train/configs.py

@@ -74,7 +74,8 @@ class Config:
 
     @classmethod
     def from_file(cls, config_path: str):
-        return cls.deserialize(yaml.load(config_path, Loader=yaml.FullLoader))
+        with open(config_path, "r") as fp_in:
+            return cls.deserialize(yaml.load(fp_in, Loader=yaml.FullLoader))
 
 
 @dataclass
@@ -173,6 +174,12 @@ class DataLoadingConfig(Config):
 class CustomModelParams(Config):
     model_embed_dim: int = 1024
 
+    num_fbank_channels: int = 80
+
+    fbank_stride: int = 2
+
+    w2v2_ffn_inner_dim: Optional[int] = None
+
     w2v2_encoder_layers: int = 24
 
     w2v2_encoder_layers_use_conformer: bool = True
@@ -187,6 +194,8 @@ class CustomModelParams(Config):
 
     w2v2_num_pos_conv_groups: int = 0
 
+    nllb_ffn_inner_dim: Optional[int] = None
+
     nllb_encoder_layers: int = 24
 
     nllb_decoder_layers: int = 24

+ 7 - 6
src/seamless_communication/cli/m4t/train/dataloader.py

@@ -331,7 +331,7 @@ class UnityDataLoader:
             torch.LongTensor(
                 [
                     int(unit_id) + 4
-                    for unit_id in units_str.rstrip().bytes().decode("utf-8").split()
+                    for unit_id in str(units_str).split()
                 ]
                 + [self.unit_tokenizer.vocab_info.eos_idx]
             )
@@ -345,8 +345,8 @@ class UnityDataLoader:
         # prefixes for tokenized texts and speech units (<eos> <lang_tok>)
         prefix_builder = lambda lang_tok: torch.LongTensor(  # noqa: E731
             [
-                self.text_prefix_tokens[lang_tok.bytes().decode("utf8")],
-                self.unit_prefix_tokens[lang_tok.bytes().decode("utf8")],
+                self.text_prefix_tokens[str(lang_tok)],
+                self.unit_prefix_tokens[str(lang_tok)],
             ]
         )
         builder.map(
@@ -392,9 +392,10 @@ class UnityDataLoader:
         """Tells if NaNs present in fbank"""
         fbank = sample[self.ROOT_COLUMN][self.AUDIO_COLUMN_NAME]["data"]["fbank"]
         has_nans: bool = torch.any(torch.isnan(fbank)).item()  # type: ignore
-        if has_nans:
-            logger.warning("Sample fbank contains NaNs. Skipping")
-        return has_nans
+        empty = fbank.shape[0] == 0
+        if has_nans or empty:
+            logger.warning("Sample fbank contains NaNs or Empty. Skipping")
+        return has_nans or empty
 
     def _filter_samples(self, builder: DataPipelineBuilder) -> DataPipelineBuilder:
         # Drop:

+ 120 - 29
src/seamless_communication/cli/m4t/train/model.py

@@ -7,17 +7,15 @@
 
 import logging
 import os
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 
 import torch
+
 from fairseq2.data import VocabularyInfo
 from fairseq2.models.nllb.builder import NllbConfig
-from fairseq2.models.nllb.loader import NllbLoader
-from fairseq2.models.utils.checkpoint_loader import convert_model_state_dict
+from fairseq2.models.utils.checkpoint import convert_model_state_dict
 from fairseq2.models.wav2vec2 import Wav2Vec2EncoderConfig
-from fairseq2.models.wav2vec2.loader import Wav2Vec2Loader
 from fairseq2.nn.transformer import TransformerNormOrder
-
 from seamless_communication.cli.m4t.train.configs import CustomModelParams, ModelConfig
 from seamless_communication.models.unity import (
     UnitYConfig,
@@ -26,7 +24,10 @@ from seamless_communication.models.unity import (
     create_unity_model,
     load_unity_model,
 )
-from seamless_communication.models.unity.loader import UnitYLoader, load_unity_config
+from seamless_communication.models.unity.loader import (
+    _fairseq_key_map as unity_fairseq_key_map,
+)
+from seamless_communication.models.unity.loader import load_unity_config
 
 logger = logging.getLogger(__name__)
 
@@ -60,7 +61,28 @@ class ModelBuilder:
         """Load w2v2 encoder model trained in fairseq1"""
         logger.info(f"Loading w2v2 weights from {checkpoint_path}")
         state_dict = torch.load(checkpoint_path)["model"]
-        key_map = Wav2Vec2Loader._fairseq_key_map()
+        key_map = {
+            # fmt: off
+            r"^encoder\.layers\.([0-9]+)\.self_attn\.out_proj\.": r"encoder.layers.\1.self_attn.output_proj.",
+            r"^encoder\.layers\.([0-9]+)\.fc1\.":                 r"encoder.layers.\1.ffn.inner_proj.",
+            r"^encoder\.layers\.([0-9]+)\.fc2\.":                 r"encoder.layers.\1.ffn.output_proj.",
+            r"^encoder\.layers\.([0-9]+)\.final_layer_norm\.":    r"encoder.layers.\1.ffn_layer_norm.",
+            r"^decoder\.layers\.([0-9]+)\.final_layer_norm\.":    r"decoder.layers.\1.ffn_layer_norm.",
+            r"^encoder\.embed_tokens\.":                          r"encoder_frontend.embed.",
+            r"^encoder\.pos_conv\.0\.":                           r"encoder_frontend.pos_encoder.conv.",
+            r"^feature_extractor\.conv_layers\.([0-9]+)\.0\.":    r"encoder_frontend.feature_extractor.layers.\1.conv.",
+            r"^feature_extractor\.conv_layers\.([0-9]+)\.2\.1\.": \
+            r"encoder_frontend.feature_extractor.layers.\1.layer_norm.",
+            r"^feature_extractor\.conv_layers\.0\.2\.": \
+            r"encoder_frontend.feature_extractor.layers.0.group_norm.",
+            r"^layer_norm\.":                                     r"encoder_frontend.post_extract_layer_norm.",
+            r"^post_extract_proj\.":                              r"encoder_frontend.model_dim_proj.",
+            r"^mask_emb":                                         r"masker.temporal_mask_embed",
+            r"^quantizer\.vars":                                  r"quantizer.entries",
+            r"^quantizer\.weight_proj\.":                         r"quantizer.entry_proj.",
+            r"^project_q\.":                                      r"final_target_proj.",
+            # fmt: on
+        }
         key_map.update(
             {
                 r"^encoder.layers\.([0-9]+)\.conv_module.batch_norm.": r"encoder.layers.\1.conv.batch_norm.",
@@ -119,8 +141,28 @@ class ModelBuilder:
         shared_state_dict = cls._sel_and_upd_prefix(
             kv=state_dict, prefix="shared_decoder.", new_prefix=decoder_prefix
         )
+        nllb_fairseq_key_map = {
+            # fmt: off
+            r"^encoder\.embed_tokens\.":                              r"encoder_frontend.embed.",
+            r"^decoder\.embed_tokens\.":                              r"decoder_frontend.embed.",
+            r"^decoder\.layers\.([0-9]+)\.self_attn\.out_proj\.":     r"decoder.layers.\1.self_attn.output_proj.",
+            r"^encoder\.layers\.([0-9]+)\.self_attn\.out_proj\.":     r"encoder.layers.\1.self_attn.output_proj.",
+            r"^decoder\.layers\.([0-9]+)\.encoder_attn\.out_proj\.": \
+            r"decoder.layers.\1.encoder_decoder_attn.output_proj.",
+            r"^decoder\.layers\.([0-9]+)\.encoder_attn\.":            r"decoder.layers.\1.encoder_decoder_attn.",
+            r"^decoder\.layers\.([0-9]+)\.encoder_attn_layer_norm\.": \
+            r"decoder.layers.\1.encoder_decoder_attn_layer_norm.",
+            r"^encoder\.layers\.([0-9]+)\.fc1\.":                     r"encoder.layers.\1.ffn.inner_proj.",
+            r"^decoder\.layers\.([0-9]+)\.fc1\.":                     r"decoder.layers.\1.ffn.inner_proj.",
+            r"^encoder\.layers\.([0-9]+)\.fc2\.":                     r"encoder.layers.\1.ffn.output_proj.",
+            r"^decoder\.layers\.([0-9]+)\.fc2\.":                     r"decoder.layers.\1.ffn.output_proj.",
+            r"^encoder\.layers\.([0-9]+)\.final_layer_norm\.":        r"encoder.layers.\1.ffn_layer_norm.",
+            r"^decoder\.layers\.([0-9]+)\.final_layer_norm\.":        r"decoder.layers.\1.ffn_layer_norm.",
+            r"^decoder\.output_projection\.":                         r"final_proj.",
+            # fmt: on
+        }
         shared_state_dict = convert_model_state_dict(
-            state_dict=shared_state_dict, key_map=NllbLoader._fairseq_key_map()
+            state_dict=shared_state_dict, key_map=nllb_fairseq_key_map
         )
         for rm_key in ["decoder.embed_positions._float_tensor", "decoder.version"]:
             del shared_state_dict[rm_key]
@@ -133,11 +175,14 @@ class ModelBuilder:
         proj_state = cls._sel_and_upd_prefix(
             kv=shared_state_dict, prefix="final_proj.", new_prefix=""
         )
-        model.text_decoder_frontend.load_state_dict(frontend_state, strict=True)
+        if model.text_decoder_frontend is not None:
+            model.text_decoder_frontend.load_state_dict(frontend_state, strict=True)
         logger.info(f"Loaded s2t decoder frontend weights from {checkpoint_path}")
-        model.text_decoder.load_state_dict(decoder_state, strict=True)
+        if model.text_decoder is not None:
+            model.text_decoder.load_state_dict(decoder_state, strict=True)
         logger.info(f"Loaded s2t decoder weights from {checkpoint_path}")
-        model.final_proj.load_state_dict(proj_state, strict=True)
+        if model.final_proj is not None:
+            model.final_proj.load_state_dict(proj_state, strict=True)
         logger.info(f"Loaded s2t decoder final_proj weights from {checkpoint_path}")
 
     @classmethod
@@ -160,7 +205,7 @@ class ModelBuilder:
         }
         state_dict = convert_model_state_dict(
             state_dict=state_dict,
-            key_map=UnitYLoader._fairseq_key_map(config=model_config),
+            key_map=unity_fairseq_key_map(config=model_config),
         )
         t2u_state_dict = cls._sel_and_upd_prefix(
             kv=state_dict, prefix="t2u_model.", new_prefix=""
@@ -170,7 +215,16 @@ class ModelBuilder:
 
     def build_model(
         self,
+        skip_loading_weights: bool = False,
     ) -> UnitYModel:
+        """
+        Args:
+            skip_loading_weights (bool, optional):
+                Ignores pretrained_w2v2_path, pretrained_s2t_decoder_path, pretrained_t2u_path.
+                Defaults to False.
+        Returns:
+            UnitYModel: initialized UnitY model
+        """
         config = self.config
         logger.info("Initializing model")
         if config.from_model is not None:
@@ -193,25 +247,38 @@ class ModelBuilder:
         model = create_unity_model(
             config=model_config, dtype=self.dtype, device=self.device
         )
+        if not skip_loading_weights:
+            if self.config.pretrained_w2v2_path is not None:
+                self._load_pretrained_w2v2_encoder(
+                    model, self.config.pretrained_w2v2_path
+                )
 
-        if self.config.pretrained_w2v2_path is not None:
-            self._load_pretrained_w2v2_encoder(model, self.config.pretrained_w2v2_path)
+            if self.config.pretrained_s2t_decoder_path is not None:
+                self._load_pretrained_s2t_decoder(
+                    model, self.config.pretrained_s2t_decoder_path
+                )
 
-        if self.config.pretrained_s2t_decoder_path is not None:
-            self._load_pretrained_s2t_decoder(
-                model, self.config.pretrained_s2t_decoder_path
-            )
+            if self.config.pretrained_t2u_path is not None:
+                self._load_pretrained_t2u(
+                    model, model_config, self.config.pretrained_t2u_path
+                )
 
-        if self.config.pretrained_t2u_path is not None:
-            self._load_pretrained_t2u(
-                model, model_config, self.config.pretrained_t2u_path
+        def _num_s2t_params(model: UnitYModel) -> int:
+            return (
+                self._get_num_model_params(model.speech_encoder_frontend)
+                + self._get_num_model_params(model.speech_encoder)
+                + self._get_num_model_params(model.text_decoder_frontend)
+                + self._get_num_model_params(model.text_decoder)
             )
 
         logger.info(f"Number of model params: {self._get_num_model_params(model)}")
+        logger.info(f"Number of S2T params: {_num_s2t_params(model)}")
         return model
 
     @classmethod
-    def _get_num_model_params(cls, model: torch.nn.Module) -> int:
+    def _get_num_model_params(cls, model: Optional[torch.nn.Module]) -> int:
+        if model is None:
+            return 0
         pp = 0
         for p in list(model.parameters()):
             nn = 1
@@ -221,14 +288,32 @@ class ModelBuilder:
         return pp
 
     def _build_custom_model_config(self) -> UnitYConfig:
-        config = self.config.custom_params
+        assert self.config.custom_params is not None
+        config: CustomModelParams = self.config.custom_params
+        num_fbank_channels = (
+            config.num_fbank_channels if config.num_fbank_channels is not None else 80
+        )
+        fbank_stride = config.fbank_stride if config.fbank_stride is not None else 2
+        nllb_ffn_inner_dim = (
+            config.nllb_ffn_inner_dim
+            if config.nllb_ffn_inner_dim is not None
+            else config.model_embed_dim * 8
+        )
+        w2v2_ffn_inner_dim = (
+            config.w2v2_ffn_inner_dim
+            if config.w2v2_ffn_inner_dim is not None
+            else config.model_embed_dim * 4
+        )
         assert config is not None
         return UnitYConfig(
+            use_gelu=False,
+            use_text_decoder=True,
+            prosody_encoder_config=None,
             model_dim=config.model_embed_dim,
             w2v2_encoder_config=Wav2Vec2EncoderConfig(
                 model_dim=config.model_embed_dim,
                 max_seq_len=4096,
-                feature_dim=160,
+                feature_dim=num_fbank_channels * fbank_stride,
                 use_fbank=True,
                 first_pass_dropout_p=0.0,
                 layer_norm_features=config.w2v2_encoder_layers_layernorm_features,
@@ -236,8 +321,8 @@ class ModelBuilder:
                 feature_extractor_bias=False,
                 feature_extractor_layer_norm_convs=False,
                 feature_grad_scale=0,
-                num_fbank_channels=80,
-                fbank_stride=2,
+                num_fbank_channels=num_fbank_channels,
+                fbank_stride=fbank_stride,
                 sample_fbank_every_k=1,
                 pos_encoder_type=config.w2v2_pos_encoder_type,
                 pos_encoder_depth=config.w2v2_pos_encoder_depth,
@@ -246,7 +331,7 @@ class ModelBuilder:
                 use_conformer=config.w2v2_encoder_layers_use_conformer,
                 num_encoder_layers=config.w2v2_encoder_layers,
                 num_encoder_attn_heads=16,
-                ffn_inner_dim=config.model_embed_dim * 4,
+                ffn_inner_dim=w2v2_ffn_inner_dim,
                 dropout_p=0.0,
                 attn_dropout_p=0.0,
                 layer_drop_p=0.0,
@@ -267,10 +352,14 @@ class ModelBuilder:
                 num_decoder_layers=config.nllb_decoder_layers,
                 num_encoder_attn_heads=16,
                 num_decoder_attn_heads=16,
-                ffn_inner_dim=config.model_embed_dim * 8,
+                ffn_inner_dim=nllb_ffn_inner_dim,
                 dropout_p=0.1,
             ),
             t2u_config=UnitYT2UConfig(
+                use_gelu=False,
+                char_pad_idx=0,
+                use_prosody_proj=False,
+                prosody_encoder_dim=0,
                 model_dim=config.model_embed_dim,
                 unit_max_seq_len=2048,
                 target_vocab_info=VocabularyInfo(
@@ -312,5 +401,7 @@ if __name__ == "__main__":
         pretrained_s2t_decoder_path="/fsx-ust/spopuri/datasets/PT_CKPT/S2T/S2T_M4T_V1_V1_cleaned.pt",
         pretrained_t2u_path="/fsx-ust/spopuri/datasets/PT_CKPT/T2U/V5_10K_p2_14_80K.pt",
     )
+    config = ModelConfig(from_model_config="seamlessM4T_medium")
     builder = ModelBuilder(config=config)
-    model = ModelBuilder(config=config).build_model()
+    model = builder.build_model()
+    print(model)

+ 101 - 0
src/seamless_communication/cli/m4t/train/recipes/asr_1024_1_3_wh_transc_120ch.yaml

@@ -0,0 +1,101 @@
+eval_data:
+  audio:
+    audio_root_dir: /fsx-ust/data/audio_zips/
+    fbanks_num_mel_bins: 120
+    fbanks_standardize_audio: true
+    fbanks_waveform_scale: 32768
+  fbank_feats_pad_idx: 0
+  # manifest_list: dev_asr_only_aggregated_adapted_norm
+  manifest_list: dev_vpsr_eng_norm
+  manifest_list_path: null
+  manifest_path_prefix: /data/home/mavlyutov/s2t_ondevice/
+  max_seconds_per_input_audio: 15
+  fixed_batch_size: 30
+  max_tgt_text_tokens_per_batch: 1000
+  max_tgt_text_tokens_per_sample: 300
+  max_units_per_sample: 1500
+  num_threads: 5
+  prefech_batches: null 
+  prepend_tgt_lang_tag: true
+  shuffle_window: 1000
+  text_tokenization:
+    from_model: null
+    langtoks:
+    - eng
+    - rus
+    - hin
+    - por
+    - spa
+    spm_path: /data/home/mavlyutov/s2t_ondevice/vocab20k/5_5_20k.model
+  unit_tokenization:
+    from_model: seamlessM4T_large
+    langtoks: null
+    num_units: null
+  unit_tokenizer_name: seamlessM4T_large
+model:
+  custom_params:
+    model_embed_dim: 1024
+    num_fbank_channels: 120
+    fbank_stride: 2
+    nllb_decoder_layers: 3
+    nllb_encoder_layers: 1
+    nllb_vocabulary_size: 20010
+    t2u_decoder_layers: 1
+    t2u_encoder_layers: 1
+    unit_vocabulary_size: 10082
+    w2v2_encoder_layers: 1
+    w2v2_encoder_layers_layernorm_features: false
+    w2v2_encoder_layers_use_conformer: true
+    w2v2_num_pos_conv_groups: 16
+    w2v2_pos_conv_kernel_size: 128
+    w2v2_pos_encoder_depth: 1
+    w2v2_pos_encoder_type: relative
+    # w2v2_pos_encoder_type: conv
+  from_model: null
+  from_model_config: null
+  pretrained_s2t_decoder_path: null
+  pretrained_t2u_path: null
+  pretrained_w2v2_path: null
+train_data:
+  audio:
+    audio_root_dir: /fsx-ust/data/audio_zips/
+    fbanks_num_mel_bins: 120
+    fbanks_standardize_audio: true
+    fbanks_waveform_scale: 32768
+  fbank_feats_pad_idx: 0
+  manifest_list: train_asr_only_aggregated_5_dial_filtered_adapted_wh_transc_norm
+  manifest_list_path: null
+  manifest_path_prefix: /data/home/mavlyutov/s2t_ondevice/
+  max_seconds_per_input_audio: 15
+  fixed_batch_size: 80
+  max_tgt_text_tokens_per_batch: 600
+  max_tgt_text_tokens_per_sample: 300
+  max_units_per_sample: 1500
+  num_threads: 4 
+  prefech_batches: null 
+  prepend_tgt_lang_tag: true
+  shuffle_window: 1000 
+  text_tokenization:
+    from_model: null
+    langtoks:
+    - eng
+    - rus
+    - hin
+    - por
+    - spa
+    spm_path: /data/home/mavlyutov/s2t_ondevice/vocab20k/5_5_20k.model
+  unit_tokenization:
+    from_model: seamlessM4T_large
+    langtoks: null
+    num_units: null
+  unit_tokenizer_name: seamlessM4T_large
+training:
+  eval_steps: 5000
+  float_dtype: fp32
+  label_smoothing: 0.2
+  learning_rate: 0.0001
+  log_steps:  50
+  max_epochs: 100
+  patience: 10
+  start_learning_rate: 1.0e-07
+  warmup_steps: 1000

+ 99 - 0
src/seamless_communication/cli/m4t/train/recipes/asr_768_8_4_wh_transc.yaml

@@ -0,0 +1,99 @@
+eval_data:
+  audio:
+    audio_root_dir: /fsx-ust/data/audio_zips/
+    fbanks_num_mel_bins: 80
+    fbanks_standardize_audio: true
+    fbanks_waveform_scale: 32768
+  fbank_feats_pad_idx: 0
+  # manifest_list: dev_asr_only_aggregated_adapted_norm
+  manifest_list: dev_vpsr_eng_norm
+  manifest_list_path: null
+  manifest_path_prefix: /data/home/mavlyutov/s2t_ondevice/
+  max_seconds_per_input_audio: 15
+  fixed_batch_size: 30
+  max_tgt_text_tokens_per_batch: 1000
+  max_tgt_text_tokens_per_sample: 300
+  max_units_per_sample: 1500
+  num_threads: 5
+  prefech_batches: null 
+  prepend_tgt_lang_tag: true
+  shuffle_window: 1000
+  text_tokenization:
+    from_model: null
+    langtoks:
+    - eng
+    - rus
+    - hin
+    - por
+    - spa
+    spm_path: /data/home/mavlyutov/s2t_ondevice/vocab20k/5_5_20k.model
+  unit_tokenization:
+    from_model: seamlessM4T_large
+    langtoks: null
+    num_units: null
+  unit_tokenizer_name: seamlessM4T_large
+model:
+  custom_params:
+    model_embed_dim: 768
+    nllb_decoder_layers: 4
+    nllb_encoder_layers: 1
+    nllb_vocabulary_size: 20010
+    t2u_decoder_layers: 1
+    t2u_encoder_layers: 1
+    unit_vocabulary_size: 10082
+    w2v2_encoder_layers: 8
+    w2v2_encoder_layers_layernorm_features: false
+    w2v2_encoder_layers_use_conformer: true
+    w2v2_num_pos_conv_groups: 16
+    w2v2_pos_conv_kernel_size: 128
+    w2v2_pos_encoder_depth: 1
+    w2v2_pos_encoder_type: relative
+    # w2v2_pos_encoder_type: conv
+  from_model: null
+  from_model_config: null
+  pretrained_s2t_decoder_path: null
+  pretrained_t2u_path: null
+  pretrained_w2v2_path: null
+train_data:
+  audio:
+    audio_root_dir: /fsx-ust/data/audio_zips/
+    fbanks_num_mel_bins: 80
+    fbanks_standardize_audio: true
+    fbanks_waveform_scale: 32768
+  fbank_feats_pad_idx: 0
+  manifest_list: train_asr_only_aggregated_5_dial_filtered_adapted_wh_transc_norm
+  manifest_list_path: null
+  manifest_path_prefix: /data/home/mavlyutov/s2t_ondevice/
+  max_seconds_per_input_audio: 15
+  fixed_batch_size: 40
+  max_tgt_text_tokens_per_batch: 600
+  max_tgt_text_tokens_per_sample: 300
+  max_units_per_sample: 1500
+  num_threads: 4 
+  prefech_batches: null 
+  prepend_tgt_lang_tag: true
+  shuffle_window: 1000 
+  text_tokenization:
+    from_model: null
+    langtoks:
+    - eng
+    - rus
+    - hin
+    - por
+    - spa
+    spm_path: /data/home/mavlyutov/s2t_ondevice/vocab20k/5_5_20k.model
+  unit_tokenization:
+    from_model: seamlessM4T_large
+    langtoks: null
+    num_units: null
+  unit_tokenizer_name: seamlessM4T_large
+training:
+  eval_steps: 5000
+  float_dtype: fp32
+  label_smoothing: 0.2
+  learning_rate: 0.0001
+  log_steps:  50
+  max_epochs: 100
+  patience: 10
+  start_learning_rate: 1.0e-07
+  warmup_steps: 1000

+ 98 - 0
src/seamless_communication/cli/m4t/train/recipes/asr_micro_5x5m.yaml

@@ -0,0 +1,98 @@
+eval_data:
+  audio:
+    audio_root_dir: /fsx-ust/data/audio_zips/
+    fbanks_num_mel_bins: 80
+    fbanks_standardize_audio: true
+    fbanks_waveform_scale: 32768
+  fbank_feats_pad_idx: 0
+  manifest_list: dev_vpsr_eng
+  manifest_list_path: null
+  manifest_path_prefix: /data/home/mavlyutov/s2t_ondevice/
+  max_seconds_per_input_audio: 15
+  fixed_batch_size: 30
+  max_tgt_text_tokens_per_batch: 1000
+  max_tgt_text_tokens_per_sample: 300
+  max_units_per_sample: 1500
+  num_threads: 5
+  prefech_batches: null 
+  prepend_tgt_lang_tag: true
+  shuffle_window: 1000
+  text_tokenization:
+    from_model: null
+    langtoks:
+    - eng
+    - rus
+    - hin
+    - por
+    - spa
+    spm_path: /data/home/mavlyutov/s2t_ondevice/vocab20k/5_5_20k.model
+  unit_tokenization:
+    from_model: seamlessM4T_large
+    langtoks: null
+    num_units: null
+  unit_tokenizer_name: seamlessM4T_large
+model:
+  custom_params:
+    model_embed_dim: 512
+    nllb_decoder_layers: 3
+    nllb_encoder_layers: 1
+    nllb_vocabulary_size: 20010
+    t2u_decoder_layers: 1
+    t2u_encoder_layers: 1
+    unit_vocabulary_size: 10082
+    w2v2_encoder_layers: 6
+    w2v2_encoder_layers_layernorm_features: false
+    w2v2_encoder_layers_use_conformer: true
+    w2v2_num_pos_conv_groups: 16
+    w2v2_pos_conv_kernel_size: 128
+    w2v2_pos_encoder_depth: 1
+    w2v2_pos_encoder_type: relative
+    # w2v2_pos_encoder_type: conv
+  from_model: null
+  from_model_config: null
+  pretrained_s2t_decoder_path: null
+  pretrained_t2u_path: null
+  pretrained_w2v2_path: null
+train_data:
+  audio:
+    audio_root_dir: /fsx-ust/data/audio_zips/
+    fbanks_num_mel_bins: 80
+    fbanks_standardize_audio: true
+    fbanks_waveform_scale: 32768
+  fbank_feats_pad_idx: 0
+  manifest_list: eng2,por2,rus2,spa2,hin_m4t
+  manifest_list_path: null
+  manifest_path_prefix: /data/home/mavlyutov/s2t_ondevice/asr_5x5m/
+  max_seconds_per_input_audio: 15
+  fixed_batch_size: 40
+  max_tgt_text_tokens_per_batch: 600
+  max_tgt_text_tokens_per_sample: 300
+  max_units_per_sample: 1500
+  num_threads: 6
+  prefech_batches: null 
+  prepend_tgt_lang_tag: true
+  shuffle_window: 1000 
+  text_tokenization:
+    from_model: null
+    langtoks:
+    - eng
+    - rus
+    - hin
+    - por
+    - spa
+    spm_path: /data/home/mavlyutov/s2t_ondevice/vocab20k/5_5_20k.model
+  unit_tokenization:
+    from_model: seamlessM4T_large
+    langtoks: null
+    num_units: null
+  unit_tokenizer_name: seamlessM4T_large
+training:
+  eval_steps: 1000
+  float_dtype: fp32
+  label_smoothing: 0.2
+  learning_rate: 0.0005
+  log_steps:  100
+  max_epochs: 20
+  patience: 10
+  start_learning_rate: 1.0e-07
+  warmup_steps: 1000

+ 98 - 0
src/seamless_communication/cli/m4t/train/recipes/asr_micro_5x5m_tune_on_noise.yaml

@@ -0,0 +1,98 @@
+eval_data:
+  audio:
+    audio_root_dir: /fsx-ust/data/audio_zips/
+    fbanks_num_mel_bins: 80
+    fbanks_standardize_audio: true
+    fbanks_waveform_scale: 32768
+  fbank_feats_pad_idx: 0
+  manifest_list: dev_vpsr_eng
+  manifest_list_path: null
+  manifest_path_prefix: /data/home/mavlyutov/s2t_ondevice/
+  max_seconds_per_input_audio: 15
+  fixed_batch_size: 30
+  max_tgt_text_tokens_per_batch: 1000
+  max_tgt_text_tokens_per_sample: 300
+  max_units_per_sample: 1500
+  num_threads: 5
+  prefech_batches: null 
+  prepend_tgt_lang_tag: true
+  shuffle_window: 1000
+  text_tokenization:
+    from_model: null
+    langtoks:
+    - eng
+    - rus
+    - hin
+    - por
+    - spa
+    spm_path: /data/home/mavlyutov/s2t_ondevice/vocab20k/5_5_20k.model
+  unit_tokenization:
+    from_model: seamlessM4T_large
+    langtoks: null
+    num_units: null
+  unit_tokenizer_name: seamlessM4T_large
+model:
+  custom_params:
+    model_embed_dim: 512
+    nllb_decoder_layers: 3
+    nllb_encoder_layers: 1
+    nllb_vocabulary_size: 20010
+    t2u_decoder_layers: 1
+    t2u_encoder_layers: 1
+    unit_vocabulary_size: 10082
+    w2v2_encoder_layers: 6
+    w2v2_encoder_layers_layernorm_features: false
+    w2v2_encoder_layers_use_conformer: true
+    w2v2_num_pos_conv_groups: 16
+    w2v2_pos_conv_kernel_size: 128
+    w2v2_pos_encoder_depth: 1
+    w2v2_pos_encoder_type: relative
+    # w2v2_pos_encoder_type: conv
+  from_model: null
+  from_model_config: null
+  pretrained_s2t_decoder_path: null
+  pretrained_t2u_path: null
+  pretrained_w2v2_path: null
+train_data:
+  audio:
+    audio_root_dir: /fsx-ust/data/audio_zips/
+    fbanks_num_mel_bins: 80
+    fbanks_standardize_audio: true
+    fbanks_waveform_scale: 32768
+  fbank_feats_pad_idx: 0
+  manifest_list: eng2,por2,rus2,spa2,hin_m4t,esc
+  manifest_list_path: null
+  manifest_path_prefix: /data/home/mavlyutov/s2t_ondevice/asr_5x5m/
+  max_seconds_per_input_audio: 15
+  fixed_batch_size: 40
+  max_tgt_text_tokens_per_batch: 600
+  max_tgt_text_tokens_per_sample: 300
+  max_units_per_sample: 1500
+  num_threads: 6
+  prefech_batches: null 
+  prepend_tgt_lang_tag: true
+  shuffle_window: 1000 
+  text_tokenization:
+    from_model: null
+    langtoks:
+    - eng
+    - rus
+    - hin
+    - por
+    - spa
+    spm_path: /data/home/mavlyutov/s2t_ondevice/vocab20k/5_5_20k.model
+  unit_tokenization:
+    from_model: seamlessM4T_large
+    langtoks: null
+    num_units: null
+  unit_tokenizer_name: seamlessM4T_large
+training:
+  eval_steps: 200
+  float_dtype: fp32
+  label_smoothing: 0.2
+  learning_rate: 0.00001
+  log_steps:  100
+  max_epochs: 20
+  patience: 6
+  start_learning_rate: 1.0e-07
+  warmup_steps: 1000

+ 99 - 0
src/seamless_communication/cli/m4t/train/recipes/asr_micro_optimized_5x5m.yaml

@@ -0,0 +1,99 @@
+eval_data:
+  audio:
+    audio_root_dir: /fsx-ust/data/audio_zips/
+    fbanks_num_mel_bins: 80
+    fbanks_standardize_audio: true
+    fbanks_waveform_scale: 32768
+  fbank_feats_pad_idx: 0
+  manifest_list: dev_vpsr_eng
+  manifest_list_path: null
+  manifest_path_prefix: /data/home/mavlyutov/s2t_ondevice/
+  max_seconds_per_input_audio: 15
+  fixed_batch_size: 30
+  max_tgt_text_tokens_per_batch: 1000
+  max_tgt_text_tokens_per_sample: 300
+  max_units_per_sample: 1500
+  num_threads: 5
+  prefech_batches: null 
+  prepend_tgt_lang_tag: true
+  shuffle_window: 1000
+  text_tokenization:
+    from_model: null
+    langtoks:
+    - eng
+    - rus
+    - hin
+    - por
+    - spa
+    spm_path: /data/home/mavlyutov/s2t_ondevice/vocab20k/5_5_20k.model
+  unit_tokenization:
+    from_model: seamlessM4T_large
+    langtoks: null
+    num_units: null
+  unit_tokenizer_name: seamlessM4T_large
+model:
+  custom_params:
+    model_embed_dim: 512
+    nllb_decoder_layers: 2
+    nllb_encoder_layers: 1
+    nllb_vocabulary_size: 20010
+    nllb_ffn_inner_dim: 2048
+    t2u_decoder_layers: 1
+    t2u_encoder_layers: 1
+    unit_vocabulary_size: 10082
+    w2v2_encoder_layers: 4
+    w2v2_encoder_layers_layernorm_features: false
+    w2v2_encoder_layers_use_conformer: true
+    w2v2_num_pos_conv_groups: 16
+    w2v2_pos_conv_kernel_size: 128
+    w2v2_pos_encoder_depth: 1
+    w2v2_pos_encoder_type: relative
+    # w2v2_pos_encoder_type: conv
+  from_model: null
+  from_model_config: null
+  pretrained_s2t_decoder_path: null
+  pretrained_t2u_path: null
+  pretrained_w2v2_path: null
+train_data:
+  audio:
+    audio_root_dir: /fsx-ust/data/audio_zips/
+    fbanks_num_mel_bins: 80
+    fbanks_standardize_audio: true
+    fbanks_waveform_scale: 32768
+  fbank_feats_pad_idx: 0
+  manifest_list: eng2,por2,rus2,spa2,hin_m4t
+  manifest_list_path: null
+  manifest_path_prefix: /data/home/mavlyutov/s2t_ondevice/asr_5x5m/
+  max_seconds_per_input_audio: 15
+  fixed_batch_size: 40
+  max_tgt_text_tokens_per_batch: 600
+  max_tgt_text_tokens_per_sample: 300
+  max_units_per_sample: 1500
+  num_threads: 6
+  prefech_batches: null 
+  prepend_tgt_lang_tag: true
+  shuffle_window: 1000 
+  text_tokenization:
+    from_model: null
+    langtoks:
+    - eng
+    - rus
+    - hin
+    - por
+    - spa
+    spm_path: /data/home/mavlyutov/s2t_ondevice/vocab20k/5_5_20k.model
+  unit_tokenization:
+    from_model: seamlessM4T_large
+    langtoks: null
+    num_units: null
+  unit_tokenizer_name: seamlessM4T_large
+training:
+  eval_steps: 1000
+  float_dtype: fp32
+  label_smoothing: 0.2
+  learning_rate: 0.0005
+  log_steps:  100
+  max_epochs: 20
+  patience: 10
+  start_learning_rate: 1.0e-07
+  warmup_steps: 1000

+ 98 - 0
src/seamless_communication/cli/m4t/train/recipes/asr_mini_5x5m.yaml

@@ -0,0 +1,98 @@
+eval_data:
+  audio:
+    audio_root_dir: /fsx-ust/data/audio_zips/
+    fbanks_num_mel_bins: 80
+    fbanks_standardize_audio: true
+    fbanks_waveform_scale: 32768
+  fbank_feats_pad_idx: 0
+  manifest_list: dev_vpsr_eng
+  manifest_list_path: null
+  manifest_path_prefix: /data/home/mavlyutov/s2t_ondevice/
+  max_seconds_per_input_audio: 15
+  fixed_batch_size: 30
+  max_tgt_text_tokens_per_batch: 1000
+  max_tgt_text_tokens_per_sample: 300
+  max_units_per_sample: 1500
+  num_threads: 5
+  prefech_batches: null 
+  prepend_tgt_lang_tag: true
+  shuffle_window: 1000
+  text_tokenization:
+    from_model: null
+    langtoks:
+    - eng
+    - rus
+    - hin
+    - por
+    - spa
+    spm_path: /data/home/mavlyutov/s2t_ondevice/vocab20k/5_5_20k.model
+  unit_tokenization:
+    from_model: seamlessM4T_large
+    langtoks: null
+    num_units: null
+  unit_tokenizer_name: seamlessM4T_large
+model:
+  custom_params:
+    model_embed_dim: 768
+    nllb_decoder_layers: 3
+    nllb_encoder_layers: 1
+    nllb_vocabulary_size: 20010
+    t2u_decoder_layers: 1
+    t2u_encoder_layers: 1
+    unit_vocabulary_size: 10082
+    w2v2_encoder_layers: 6
+    w2v2_encoder_layers_layernorm_features: false
+    w2v2_encoder_layers_use_conformer: true
+    w2v2_num_pos_conv_groups: 16
+    w2v2_pos_conv_kernel_size: 128
+    w2v2_pos_encoder_depth: 1
+    w2v2_pos_encoder_type: relative
+    # w2v2_pos_encoder_type: conv
+  from_model: null
+  from_model_config: null
+  pretrained_s2t_decoder_path: null
+  pretrained_t2u_path: null
+  pretrained_w2v2_path: null
+train_data:
+  audio:
+    audio_root_dir: /fsx-ust/data/audio_zips/
+    fbanks_num_mel_bins: 80
+    fbanks_standardize_audio: true
+    fbanks_waveform_scale: 32768
+  fbank_feats_pad_idx: 0
+  manifest_list: eng2,por2,rus2,spa2,hin_m4t
+  manifest_list_path: null
+  manifest_path_prefix: /data/home/mavlyutov/s2t_ondevice/asr_5x5m/
+  max_seconds_per_input_audio: 15
+  fixed_batch_size: 40
+  max_tgt_text_tokens_per_batch: 600
+  max_tgt_text_tokens_per_sample: 300
+  max_units_per_sample: 1500
+  num_threads: 6
+  prefech_batches: null 
+  prepend_tgt_lang_tag: true
+  shuffle_window: 1000 
+  text_tokenization:
+    from_model: null
+    langtoks:
+    - eng
+    - rus
+    - hin
+    - por
+    - spa
+    spm_path: /data/home/mavlyutov/s2t_ondevice/vocab20k/5_5_20k.model
+  unit_tokenization:
+    from_model: seamlessM4T_large
+    langtoks: null
+    num_units: null
+  unit_tokenizer_name: seamlessM4T_large
+training:
+  eval_steps: 1000
+  float_dtype: fp32
+  label_smoothing: 0.2
+  learning_rate: 0.0005
+  log_steps:  100
+  max_epochs: 20
+  patience: 10
+  start_learning_rate: 1.0e-07
+  warmup_steps: 1000

+ 98 - 0
src/seamless_communication/cli/m4t/train/recipes/asr_nano_5x5m.yaml

@@ -0,0 +1,98 @@
+eval_data:
+  audio:
+    audio_root_dir: /fsx-ust/data/audio_zips/
+    fbanks_num_mel_bins: 80
+    fbanks_standardize_audio: true
+    fbanks_waveform_scale: 32768
+  fbank_feats_pad_idx: 0
+  manifest_list: dev_vpsr_eng
+  manifest_list_path: null
+  manifest_path_prefix: /data/home/mavlyutov/s2t_ondevice/
+  max_seconds_per_input_audio: 15
+  fixed_batch_size: 30
+  max_tgt_text_tokens_per_batch: 1000
+  max_tgt_text_tokens_per_sample: 300
+  max_units_per_sample: 1500
+  num_threads: 5
+  prefech_batches: null 
+  prepend_tgt_lang_tag: true
+  shuffle_window: 1000
+  text_tokenization:
+    from_model: null
+    langtoks:
+    - eng
+    - rus
+    - hin
+    - por
+    - spa
+    spm_path: /data/home/mavlyutov/s2t_ondevice/vocab20k/5_5_20k.model
+  unit_tokenization:
+    from_model: seamlessM4T_large
+    langtoks: null
+    num_units: null
+  unit_tokenizer_name: seamlessM4T_large
+model:
+  custom_params:
+    model_embed_dim: 256
+    nllb_decoder_layers: 3
+    nllb_encoder_layers: 1
+    nllb_vocabulary_size: 20010
+    t2u_decoder_layers: 1
+    t2u_encoder_layers: 1
+    unit_vocabulary_size: 10082
+    w2v2_encoder_layers: 6
+    w2v2_encoder_layers_layernorm_features: false
+    w2v2_encoder_layers_use_conformer: true
+    w2v2_num_pos_conv_groups: 16
+    w2v2_pos_conv_kernel_size: 128
+    w2v2_pos_encoder_depth: 1
+    w2v2_pos_encoder_type: relative
+    # w2v2_pos_encoder_type: conv
+  from_model: null
+  from_model_config: null
+  pretrained_s2t_decoder_path: null
+  pretrained_t2u_path: null
+  pretrained_w2v2_path: null
+train_data:
+  audio:
+    audio_root_dir: /fsx-ust/data/audio_zips/
+    fbanks_num_mel_bins: 80
+    fbanks_standardize_audio: true
+    fbanks_waveform_scale: 32768
+  fbank_feats_pad_idx: 0
+  manifest_list: eng2,por2,rus2,spa2,hin_m4t
+  manifest_list_path: null
+  manifest_path_prefix: /data/home/mavlyutov/s2t_ondevice/asr_5x5m/
+  max_seconds_per_input_audio: 15
+  fixed_batch_size: 40
+  max_tgt_text_tokens_per_batch: 600
+  max_tgt_text_tokens_per_sample: 300
+  max_units_per_sample: 1500
+  num_threads: 6
+  prefech_batches: null 
+  prepend_tgt_lang_tag: true
+  shuffle_window: 1000 
+  text_tokenization:
+    from_model: null
+    langtoks:
+    - eng
+    - rus
+    - hin
+    - por
+    - spa
+    spm_path: /data/home/mavlyutov/s2t_ondevice/vocab20k/5_5_20k.model
+  unit_tokenization:
+    from_model: seamlessM4T_large
+    langtoks: null
+    num_units: null
+  unit_tokenizer_name: seamlessM4T_large
+training:
+  eval_steps: 1000
+  float_dtype: fp32
+  label_smoothing: 0.2
+  learning_rate: 0.0005
+  log_steps:  100
+  max_epochs: 20
+  patience: 10
+  start_learning_rate: 1.0e-07
+  warmup_steps: 1000

+ 98 - 0
src/seamless_communication/cli/m4t/train/recipes/asr_nano_5x5m_tune_on_noise.yaml

@@ -0,0 +1,98 @@
+eval_data:
+  audio:
+    audio_root_dir: /fsx-ust/data/audio_zips/
+    fbanks_num_mel_bins: 80
+    fbanks_standardize_audio: true
+    fbanks_waveform_scale: 32768
+  fbank_feats_pad_idx: 0
+  manifest_list: dev_vpsr_eng
+  manifest_list_path: null
+  manifest_path_prefix: /data/home/mavlyutov/s2t_ondevice/
+  max_seconds_per_input_audio: 15
+  fixed_batch_size: 30
+  max_tgt_text_tokens_per_batch: 1000
+  max_tgt_text_tokens_per_sample: 300
+  max_units_per_sample: 1500
+  num_threads: 5
+  prefech_batches: null 
+  prepend_tgt_lang_tag: true
+  shuffle_window: 1000
+  text_tokenization:
+    from_model: null
+    langtoks:
+    - eng
+    - rus
+    - hin
+    - por
+    - spa
+    spm_path: /data/home/mavlyutov/s2t_ondevice/vocab20k/5_5_20k.model
+  unit_tokenization:
+    from_model: seamlessM4T_large
+    langtoks: null
+    num_units: null
+  unit_tokenizer_name: seamlessM4T_large
+model:
+  custom_params:
+    model_embed_dim: 256
+    nllb_decoder_layers: 3
+    nllb_encoder_layers: 1
+    nllb_vocabulary_size: 20010
+    t2u_decoder_layers: 1
+    t2u_encoder_layers: 1
+    unit_vocabulary_size: 10082
+    w2v2_encoder_layers: 6
+    w2v2_encoder_layers_layernorm_features: false
+    w2v2_encoder_layers_use_conformer: true
+    w2v2_num_pos_conv_groups: 16
+    w2v2_pos_conv_kernel_size: 128
+    w2v2_pos_encoder_depth: 1
+    w2v2_pos_encoder_type: relative
+    # w2v2_pos_encoder_type: conv
+  from_model: null
+  from_model_config: null
+  pretrained_s2t_decoder_path: null
+  pretrained_t2u_path: null
+  pretrained_w2v2_path: null
+train_data:
+  audio:
+    audio_root_dir: /fsx-ust/data/audio_zips/
+    fbanks_num_mel_bins: 80
+    fbanks_standardize_audio: true
+    fbanks_waveform_scale: 32768
+  fbank_feats_pad_idx: 0
+  manifest_list: eng2,por2,rus2,spa2,hin_m4t,esc
+  manifest_list_path: null
+  manifest_path_prefix: /data/home/mavlyutov/s2t_ondevice/asr_5x5m/
+  max_seconds_per_input_audio: 15
+  fixed_batch_size: 40
+  max_tgt_text_tokens_per_batch: 600
+  max_tgt_text_tokens_per_sample: 300
+  max_units_per_sample: 1500
+  num_threads: 6
+  prefech_batches: null 
+  prepend_tgt_lang_tag: true
+  shuffle_window: 1000 
+  text_tokenization:
+    from_model: null
+    langtoks:
+    - eng
+    - rus
+    - hin
+    - por
+    - spa
+    spm_path: /data/home/mavlyutov/s2t_ondevice/vocab20k/5_5_20k.model
+  unit_tokenization:
+    from_model: seamlessM4T_large
+    langtoks: null
+    num_units: null
+  unit_tokenizer_name: seamlessM4T_large
+training:
+  eval_steps: 200
+  float_dtype: fp32
+  label_smoothing: 0.2
+  learning_rate: 0.00001
+  log_steps:  100
+  max_epochs: 20
+  patience: 6
+  start_learning_rate: 1.0e-07
+  warmup_steps: 1000

+ 99 - 0
src/seamless_communication/cli/m4t/train/recipes/asr_nano_optimized_5x5m.yaml

@@ -0,0 +1,99 @@
+eval_data:
+  audio:
+    audio_root_dir: /fsx-ust/data/audio_zips/
+    fbanks_num_mel_bins: 80
+    fbanks_standardize_audio: true
+    fbanks_waveform_scale: 32768
+  fbank_feats_pad_idx: 0
+  manifest_list: dev_vpsr_eng
+  manifest_list_path: null
+  manifest_path_prefix: /data/home/mavlyutov/s2t_ondevice/
+  max_seconds_per_input_audio: 15
+  fixed_batch_size: 30
+  max_tgt_text_tokens_per_batch: 1000
+  max_tgt_text_tokens_per_sample: 300
+  max_units_per_sample: 1500
+  num_threads: 5
+  prefech_batches: null 
+  prepend_tgt_lang_tag: true
+  shuffle_window: 1000
+  text_tokenization:
+    from_model: null
+    langtoks:
+    - eng
+    - rus
+    - hin
+    - por
+    - spa
+    spm_path: /data/home/mavlyutov/s2t_ondevice/vocab20k/5_5_20k.model
+  unit_tokenization:
+    from_model: seamlessM4T_large
+    langtoks: null
+    num_units: null
+  unit_tokenizer_name: seamlessM4T_large
+model:
+  custom_params:
+    model_embed_dim: 384
+    nllb_decoder_layers: 2
+    nllb_encoder_layers: 1
+    nllb_vocabulary_size: 20010
+    nllb_ffn_inner_dim: 1536
+    t2u_decoder_layers: 1
+    t2u_encoder_layers: 1
+    unit_vocabulary_size: 10082
+    w2v2_encoder_layers: 6
+    w2v2_encoder_layers_layernorm_features: false
+    w2v2_encoder_layers_use_conformer: true
+    w2v2_num_pos_conv_groups: 16
+    w2v2_pos_conv_kernel_size: 128
+    w2v2_pos_encoder_depth: 1
+    w2v2_pos_encoder_type: relative
+    # w2v2_pos_encoder_type: conv
+  from_model: null
+  from_model_config: null
+  pretrained_s2t_decoder_path: null
+  pretrained_t2u_path: null
+  pretrained_w2v2_path: null
+train_data:
+  audio:
+    audio_root_dir: /fsx-ust/data/audio_zips/
+    fbanks_num_mel_bins: 80
+    fbanks_standardize_audio: true
+    fbanks_waveform_scale: 32768
+  fbank_feats_pad_idx: 0
+  manifest_list: eng2,por2,rus2,spa2,hin_m4t
+  manifest_list_path: null
+  manifest_path_prefix: /data/home/mavlyutov/s2t_ondevice/asr_5x5m/
+  max_seconds_per_input_audio: 15
+  fixed_batch_size: 40
+  max_tgt_text_tokens_per_batch: 600
+  max_tgt_text_tokens_per_sample: 300
+  max_units_per_sample: 1500
+  num_threads: 6
+  prefech_batches: null 
+  prepend_tgt_lang_tag: true
+  shuffle_window: 1000 
+  text_tokenization:
+    from_model: null
+    langtoks:
+    - eng
+    - rus
+    - hin
+    - por
+    - spa
+    spm_path: /data/home/mavlyutov/s2t_ondevice/vocab20k/5_5_20k.model
+  unit_tokenization:
+    from_model: seamlessM4T_large
+    langtoks: null
+    num_units: null
+  unit_tokenizer_name: seamlessM4T_large
+training:
+  eval_steps: 1000
+  float_dtype: fp32
+  label_smoothing: 0.2
+  learning_rate: 0.0005
+  log_steps:  100
+  max_epochs: 20
+  patience: 10
+  start_learning_rate: 1.0e-07
+  warmup_steps: 1000

+ 98 - 0
src/seamless_communication/cli/m4t/train/recipes/asr_small_5x5m.yaml

@@ -0,0 +1,98 @@
+eval_data:
+  audio:
+    audio_root_dir: /fsx-ust/data/audio_zips/
+    fbanks_num_mel_bins: 80
+    fbanks_standardize_audio: true
+    fbanks_waveform_scale: 32768
+  fbank_feats_pad_idx: 0
+  manifest_list: dev_vpsr_eng
+  manifest_list_path: null
+  manifest_path_prefix: /data/home/mavlyutov/s2t_ondevice/
+  max_seconds_per_input_audio: 15
+  fixed_batch_size: 30
+  max_tgt_text_tokens_per_batch: 1000
+  max_tgt_text_tokens_per_sample: 300
+  max_units_per_sample: 1500
+  num_threads: 5
+  prefech_batches: null 
+  prepend_tgt_lang_tag: true
+  shuffle_window: 1000
+  text_tokenization:
+    from_model: null
+    langtoks:
+    - eng
+    - rus
+    - hin
+    - por
+    - spa
+    spm_path: /data/home/mavlyutov/s2t_ondevice/vocab20k/5_5_20k.model
+  unit_tokenization:
+    from_model: seamlessM4T_large
+    langtoks: null
+    num_units: null
+  unit_tokenizer_name: seamlessM4T_large
+model:
+  custom_params:
+    model_embed_dim: 768
+    nllb_decoder_layers: 3
+    nllb_encoder_layers: 1
+    nllb_vocabulary_size: 20010
+    t2u_decoder_layers: 1
+    t2u_encoder_layers: 1
+    unit_vocabulary_size: 10082
+    w2v2_encoder_layers: 6
+    w2v2_encoder_layers_layernorm_features: false
+    w2v2_encoder_layers_use_conformer: true
+    w2v2_num_pos_conv_groups: 16
+    w2v2_pos_conv_kernel_size: 128
+    w2v2_pos_encoder_depth: 1
+    w2v2_pos_encoder_type: relative
+    # w2v2_pos_encoder_type: conv
+  from_model: null
+  from_model_config: null
+  pretrained_s2t_decoder_path: null
+  pretrained_t2u_path: null
+  pretrained_w2v2_path: null
+train_data:
+  audio:
+    audio_root_dir: /fsx-ust/data/audio_zips/
+    fbanks_num_mel_bins: 80
+    fbanks_standardize_audio: true
+    fbanks_waveform_scale: 32768
+  fbank_feats_pad_idx: 0
+  manifest_list: eng2,por2,rus2,spa2,hin_m4t
+  manifest_list_path: null
+  manifest_path_prefix: /data/home/mavlyutov/s2t_ondevice/asr_5x5m/
+  max_seconds_per_input_audio: 15
+  fixed_batch_size: 40
+  max_tgt_text_tokens_per_batch: 600
+  max_tgt_text_tokens_per_sample: 300
+  max_units_per_sample: 1500
+  num_threads: 6
+  prefech_batches: null 
+  prepend_tgt_lang_tag: true
+  shuffle_window: 1000 
+  text_tokenization:
+    from_model: null
+    langtoks:
+    - eng
+    - rus
+    - hin
+    - por
+    - spa
+    spm_path: /data/home/mavlyutov/s2t_ondevice/vocab20k/5_5_20k.model
+  unit_tokenization:
+    from_model: seamlessM4T_large
+    langtoks: null
+    num_units: null
+  unit_tokenizer_name: seamlessM4T_large
+training:
+  eval_steps: 1000
+  float_dtype: fp32
+  label_smoothing: 0.2
+  learning_rate: 0.0005
+  log_steps:  100
+  max_epochs: 20
+  patience: 10
+  start_learning_rate: 1.0e-07
+  warmup_steps: 1000

+ 11 - 9
src/seamless_communication/cli/m4t/train/recipes/asr_small.yaml → src/seamless_communication/cli/m4t/train/recipes/asr_small_eng_10m_wh.yaml

@@ -5,11 +5,12 @@ eval_data:
     fbanks_standardize_audio: true
     fbanks_waveform_scale: 32768
   fbank_feats_pad_idx: 0
-  manifest_list: dev_asr_only_aggregated_adapted
+  # manifest_list: dev_asr_only_aggregated_adapted_norm
+  manifest_list: dev_vpsr_eng
   manifest_list_path: null
   manifest_path_prefix: /data/home/mavlyutov/s2t_ondevice/
   max_seconds_per_input_audio: 15
-  fixed_batch_size: 40
+  fixed_batch_size: 30
   max_tgt_text_tokens_per_batch: 1000
   max_tgt_text_tokens_per_sample: 300
   max_units_per_sample: 1500
@@ -43,10 +44,11 @@ model:
     w2v2_encoder_layers: 6
     w2v2_encoder_layers_layernorm_features: false
     w2v2_encoder_layers_use_conformer: true
-    w2v2_num_pos_conv_groups: 0
-    w2v2_pos_conv_kernel_size: 0
-    w2v2_pos_encoder_depth: 0
+    w2v2_num_pos_conv_groups: 16
+    w2v2_pos_conv_kernel_size: 128
+    w2v2_pos_encoder_depth: 1
     w2v2_pos_encoder_type: relative
+    # w2v2_pos_encoder_type: conv
   from_model: null
   from_model_config: null
   pretrained_s2t_decoder_path: null
@@ -59,7 +61,7 @@ train_data:
     fbanks_standardize_audio: true
     fbanks_waveform_scale: 32768
   fbank_feats_pad_idx: 0
-  manifest_list: train_asr_only_aggregated_5_dial_filtered_adapted
+  manifest_list: eng_10m_wh_transc
   manifest_list_path: null
   manifest_path_prefix: /data/home/mavlyutov/s2t_ondevice/
   max_seconds_per_input_audio: 15
@@ -86,12 +88,12 @@ train_data:
     num_units: null
   unit_tokenizer_name: seamlessM4T_large
 training:
-  eval_steps: 5000 
+  eval_steps: 5000
   float_dtype: fp32
   label_smoothing: 0.2
   learning_rate: 0.0001
-  log_steps:  200 
+  log_steps:  50
   max_epochs: 100
   patience: 10
   start_learning_rate: 1.0e-07
-  warmup_steps: 1000
+  warmup_steps: 1000

+ 99 - 0
src/seamless_communication/cli/m4t/train/recipes/asr_small_old_and_10m_wh.yaml

@@ -0,0 +1,99 @@
+eval_data:
+  audio:
+    audio_root_dir: /fsx-ust/data/audio_zips/
+    fbanks_num_mel_bins: 80
+    fbanks_standardize_audio: true
+    fbanks_waveform_scale: 32768
+  fbank_feats_pad_idx: 0
+  # manifest_list: dev_asr_only_aggregated_adapted_norm
+  manifest_list: dev_vpsr_eng
+  manifest_list_path: null
+  manifest_path_prefix: /data/home/mavlyutov/s2t_ondevice/
+  max_seconds_per_input_audio: 15
+  fixed_batch_size: 30
+  max_tgt_text_tokens_per_batch: 1000
+  max_tgt_text_tokens_per_sample: 300
+  max_units_per_sample: 1500
+  num_threads: 5
+  prefech_batches: null 
+  prepend_tgt_lang_tag: true
+  shuffle_window: 1000
+  text_tokenization:
+    from_model: null
+    langtoks:
+    - eng
+    - rus
+    - hin
+    - por
+    - spa
+    spm_path: /data/home/mavlyutov/s2t_ondevice/vocab20k/5_5_20k.model
+  unit_tokenization:
+    from_model: seamlessM4T_large
+    langtoks: null
+    num_units: null
+  unit_tokenizer_name: seamlessM4T_large
+model:
+  custom_params:
+    model_embed_dim: 512
+    nllb_decoder_layers: 3
+    nllb_encoder_layers: 1
+    nllb_vocabulary_size: 20010
+    t2u_decoder_layers: 1
+    t2u_encoder_layers: 1
+    unit_vocabulary_size: 10082
+    w2v2_encoder_layers: 6
+    w2v2_encoder_layers_layernorm_features: false
+    w2v2_encoder_layers_use_conformer: true
+    w2v2_num_pos_conv_groups: 16
+    w2v2_pos_conv_kernel_size: 128
+    w2v2_pos_encoder_depth: 1
+    w2v2_pos_encoder_type: relative
+    # w2v2_pos_encoder_type: conv
+  from_model: null
+  from_model_config: null
+  pretrained_s2t_decoder_path: null
+  pretrained_t2u_path: null
+  pretrained_w2v2_path: null
+train_data:
+  audio:
+    audio_root_dir: /fsx-ust/data/audio_zips/
+    fbanks_num_mel_bins: 80
+    fbanks_standardize_audio: true
+    fbanks_waveform_scale: 32768
+  fbank_feats_pad_idx: 0
+  manifest_list: train_asr_only_aggregated_5_dial_filtered_adapted_wh_transc,eng_10m_wh_transc
+  manifest_list_path: null
+  manifest_path_prefix: /data/home/mavlyutov/s2t_ondevice/
+  max_seconds_per_input_audio: 15
+  fixed_batch_size: 80
+  max_tgt_text_tokens_per_batch: 600
+  max_tgt_text_tokens_per_sample: 300
+  max_units_per_sample: 1500
+  num_threads: 4 
+  prefech_batches: null 
+  prepend_tgt_lang_tag: true
+  shuffle_window: 1000 
+  text_tokenization:
+    from_model: null
+    langtoks:
+    - eng
+    - rus
+    - hin
+    - por
+    - spa
+    spm_path: /data/home/mavlyutov/s2t_ondevice/vocab20k/5_5_20k.model
+  unit_tokenization:
+    from_model: seamlessM4T_large
+    langtoks: null
+    num_units: null
+  unit_tokenizer_name: seamlessM4T_large
+training:
+  eval_steps: 5000
+  float_dtype: fp32
+  label_smoothing: 0.2
+  learning_rate: 0.0001
+  log_steps:  50
+  max_epochs: 100
+  patience: 10
+  start_learning_rate: 1.0e-07
+  warmup_steps: 1000

+ 99 - 0
src/seamless_communication/cli/m4t/train/recipes/asr_small_vit_transc_eng.yaml

@@ -0,0 +1,99 @@
+eval_data:
+  audio:
+    audio_root_dir: /fsx-ust/data/audio_zips/
+    fbanks_num_mel_bins: 80
+    fbanks_standardize_audio: true
+    fbanks_waveform_scale: 32768
+  fbank_feats_pad_idx: 0
+  # manifest_list: dev_asr_only_aggregated_adapted_norm
+  manifest_list: dev_vpsr_eng
+  manifest_list_path: null
+  manifest_path_prefix: /data/home/mavlyutov/s2t_ondevice/
+  max_seconds_per_input_audio: 15
+  fixed_batch_size: 30
+  max_tgt_text_tokens_per_batch: 1000
+  max_tgt_text_tokens_per_sample: 300
+  max_units_per_sample: 1500
+  num_threads: 5
+  prefech_batches: null 
+  prepend_tgt_lang_tag: true
+  shuffle_window: 1000
+  text_tokenization:
+    from_model: null
+    langtoks:
+    - eng
+    - rus
+    - hin
+    - por
+    - spa
+    spm_path: /data/home/mavlyutov/s2t_ondevice/vocab20k/5_5_20k.model
+  unit_tokenization:
+    from_model: seamlessM4T_large
+    langtoks: null
+    num_units: null
+  unit_tokenizer_name: seamlessM4T_large
+model:
+  custom_params:
+    model_embed_dim: 1024
+    nllb_decoder_layers: 3
+    nllb_encoder_layers: 1
+    nllb_vocabulary_size: 20010
+    t2u_decoder_layers: 1
+    t2u_encoder_layers: 1
+    unit_vocabulary_size: 10082
+    w2v2_encoder_layers: 3
+    w2v2_encoder_layers_layernorm_features: false
+    w2v2_encoder_layers_use_conformer: true
+    w2v2_num_pos_conv_groups: 16
+    w2v2_pos_conv_kernel_size: 128
+    w2v2_pos_encoder_depth: 1
+    w2v2_pos_encoder_type: relative
+    # w2v2_pos_encoder_type: conv
+  from_model: null
+  from_model_config: null
+  pretrained_s2t_decoder_path: null
+  pretrained_t2u_path: null
+  pretrained_w2v2_path: null
+train_data:
+  audio:
+    audio_root_dir: /fsx-ust/data/audio_zips/
+    fbanks_num_mel_bins: 80
+    fbanks_standardize_audio: true
+    fbanks_waveform_scale: 32768
+  fbank_feats_pad_idx: 0
+  manifest_list: train_asr_only_aggregated_5_dial_filtered_adapted_eng
+  manifest_list_path: null
+  manifest_path_prefix: /data/home/mavlyutov/s2t_ondevice/
+  max_seconds_per_input_audio: 15
+  fixed_batch_size: 40
+  max_tgt_text_tokens_per_batch: 600
+  max_tgt_text_tokens_per_sample: 300
+  max_units_per_sample: 1500
+  num_threads: 4 
+  prefech_batches: null 
+  prepend_tgt_lang_tag: true
+  shuffle_window: 1000 
+  text_tokenization:
+    from_model: null
+    langtoks:
+    - eng
+    - rus
+    - hin
+    - por
+    - spa
+    spm_path: /data/home/mavlyutov/s2t_ondevice/vocab20k/5_5_20k.model
+  unit_tokenization:
+    from_model: seamlessM4T_large
+    langtoks: null
+    num_units: null
+  unit_tokenizer_name: seamlessM4T_large
+training:
+  eval_steps: 5000
+  float_dtype: fp32
+  label_smoothing: 0.2
+  learning_rate: 0.0001
+  log_steps:  50
+  max_epochs: 100
+  patience: 10
+  start_learning_rate: 1.0e-07
+  warmup_steps: 1000

+ 101 - 0
src/seamless_communication/cli/m4t/train/recipes/asr_small_wh_mms_120ch.yaml

@@ -0,0 +1,101 @@
+eval_data:
+  audio:
+    audio_root_dir: /fsx-ust/data/audio_zips/
+    fbanks_num_mel_bins: 120
+    fbanks_standardize_audio: true
+    fbanks_waveform_scale: 32768
+  fbank_feats_pad_idx: 0
+  # manifest_list: dev_asr_only_aggregated_adapted_norm
+  manifest_list: dev_vpsr_eng_norm
+  manifest_list_path: null
+  manifest_path_prefix: /data/home/mavlyutov/s2t_ondevice/
+  max_seconds_per_input_audio: 15
+  fixed_batch_size: 30
+  max_tgt_text_tokens_per_batch: 1000
+  max_tgt_text_tokens_per_sample: 300
+  max_units_per_sample: 1500
+  num_threads: 5
+  prefech_batches: null 
+  prepend_tgt_lang_tag: true
+  shuffle_window: 1000
+  text_tokenization:
+    from_model: null
+    langtoks:
+    - eng
+    - rus
+    - hin
+    - por
+    - spa
+    spm_path: /data/home/mavlyutov/s2t_ondevice/vocab20k/5_5_20k.model
+  unit_tokenization:
+    from_model: seamlessM4T_large
+    langtoks: null
+    num_units: null
+  unit_tokenizer_name: seamlessM4T_large
+model:
+  custom_params:
+    model_embed_dim: 768
+    num_fbank_channels: 120
+    fbank_stride: 2
+    nllb_decoder_layers: 3
+    nllb_encoder_layers: 1
+    nllb_vocabulary_size: 20010
+    t2u_decoder_layers: 1
+    t2u_encoder_layers: 1
+    unit_vocabulary_size: 10082
+    w2v2_encoder_layers: 6
+    w2v2_encoder_layers_layernorm_features: false
+    w2v2_encoder_layers_use_conformer: true
+    w2v2_num_pos_conv_groups: 16
+    w2v2_pos_conv_kernel_size: 128
+    w2v2_pos_encoder_depth: 1
+    w2v2_pos_encoder_type: relative
+    # w2v2_pos_encoder_type: conv
+  from_model: null
+  from_model_config: null
+  pretrained_s2t_decoder_path: null
+  pretrained_t2u_path: null
+  pretrained_w2v2_path: null
+train_data:
+  audio:
+    audio_root_dir: /fsx-ust/data/audio_zips/
+    fbanks_num_mel_bins: 120
+    fbanks_standardize_audio: true
+    fbanks_waveform_scale: 32768
+  fbank_feats_pad_idx: 0
+  manifest_list: train_asr_only_aggregated_5_dial_filtered_adapted_wh_transc_norm,/data/home/mavlyutov/mmstts/dataset/premixed/all_mms
+  manifest_list_path: null
+  manifest_path_prefix: /data/home/mavlyutov/s2t_ondevice/
+  max_seconds_per_input_audio: 15
+  fixed_batch_size: 40
+  max_tgt_text_tokens_per_batch: 600
+  max_tgt_text_tokens_per_sample: 300
+  max_units_per_sample: 1500
+  num_threads: 4 
+  prefech_batches: null 
+  prepend_tgt_lang_tag: true
+  shuffle_window: 1000 
+  text_tokenization:
+    from_model: null
+    langtoks:
+    - eng
+    - rus
+    - hin
+    - por
+    - spa
+    spm_path: /data/home/mavlyutov/s2t_ondevice/vocab20k/5_5_20k.model
+  unit_tokenization:
+    from_model: seamlessM4T_large
+    langtoks: null
+    num_units: null
+  unit_tokenizer_name: seamlessM4T_large
+training:
+  eval_steps: 5000
+  float_dtype: fp32
+  label_smoothing: 0.2
+  learning_rate: 0.0001
+  log_steps:  100
+  max_epochs: 100
+  patience: 10
+  start_learning_rate: 1.0e-07
+  warmup_steps: 1000

+ 101 - 0
src/seamless_communication/cli/m4t/train/recipes/asr_small_wh_mms_120ch_fp16.yaml

@@ -0,0 +1,101 @@
+eval_data:
+  audio:
+    audio_root_dir: /fsx-ust/data/audio_zips/
+    fbanks_num_mel_bins: 120
+    fbanks_standardize_audio: true
+    fbanks_waveform_scale: 32768
+  fbank_feats_pad_idx: 0
+  # manifest_list: dev_asr_only_aggregated_adapted_norm
+  manifest_list: dev_vpsr_eng_norm
+  manifest_list_path: null
+  manifest_path_prefix: /data/home/mavlyutov/s2t_ondevice/
+  max_seconds_per_input_audio: 15
+  fixed_batch_size: 30
+  max_tgt_text_tokens_per_batch: 1000
+  max_tgt_text_tokens_per_sample: 300
+  max_units_per_sample: 1500
+  num_threads: 5
+  prefech_batches: null 
+  prepend_tgt_lang_tag: true
+  shuffle_window: 1000
+  text_tokenization:
+    from_model: null
+    langtoks:
+    - eng
+    - rus
+    - hin
+    - por
+    - spa
+    spm_path: /data/home/mavlyutov/s2t_ondevice/vocab20k/5_5_20k.model
+  unit_tokenization:
+    from_model: seamlessM4T_large
+    langtoks: null
+    num_units: null
+  unit_tokenizer_name: seamlessM4T_large
+model:
+  custom_params:
+    model_embed_dim: 768
+    num_fbank_channels: 120
+    fbank_stride: 2
+    nllb_decoder_layers: 3
+    nllb_encoder_layers: 1
+    nllb_vocabulary_size: 20010
+    t2u_decoder_layers: 1
+    t2u_encoder_layers: 1
+    unit_vocabulary_size: 10082
+    w2v2_encoder_layers: 6
+    w2v2_encoder_layers_layernorm_features: false
+    w2v2_encoder_layers_use_conformer: true
+    w2v2_num_pos_conv_groups: 16
+    w2v2_pos_conv_kernel_size: 128
+    w2v2_pos_encoder_depth: 1
+    w2v2_pos_encoder_type: relative
+    # w2v2_pos_encoder_type: conv
+  from_model: null
+  from_model_config: null
+  pretrained_s2t_decoder_path: null
+  pretrained_t2u_path: null
+  pretrained_w2v2_path: null
+train_data:
+  audio:
+    audio_root_dir: /fsx-ust/data/audio_zips/
+    fbanks_num_mel_bins: 120
+    fbanks_standardize_audio: true
+    fbanks_waveform_scale: 32768
+  fbank_feats_pad_idx: 0
+  manifest_list: train_asr_only_aggregated_5_dial_filtered_adapted_wh_transc_norm,/data/home/mavlyutov/mmstts/dataset/premixed/all_mms
+  manifest_list_path: null
+  manifest_path_prefix: /data/home/mavlyutov/s2t_ondevice/
+  max_seconds_per_input_audio: 15
+  fixed_batch_size: 40
+  max_tgt_text_tokens_per_batch: 600
+  max_tgt_text_tokens_per_sample: 300
+  max_units_per_sample: 1500
+  num_threads: 4 
+  prefech_batches: null 
+  prepend_tgt_lang_tag: true
+  shuffle_window: 1000 
+  text_tokenization:
+    from_model: null
+    langtoks:
+    - eng
+    - rus
+    - hin
+    - por
+    - spa
+    spm_path: /data/home/mavlyutov/s2t_ondevice/vocab20k/5_5_20k.model
+  unit_tokenization:
+    from_model: seamlessM4T_large
+    langtoks: null
+    num_units: null
+  unit_tokenizer_name: seamlessM4T_large
+training:
+  eval_steps: 1000
+  float_dtype: fp16
+  label_smoothing: 0.2
+  learning_rate: 0.00005
+  log_steps:  5
+  max_epochs: 100
+  patience: 10
+  start_learning_rate: 1.0e-07
+  warmup_steps: 1000

+ 101 - 0
src/seamless_communication/cli/m4t/train/recipes/asr_small_wh_mms_120ch_nost.yaml

@@ -0,0 +1,101 @@
+eval_data:
+  audio:
+    audio_root_dir: /fsx-ust/data/audio_zips/
+    fbanks_num_mel_bins: 80
+    fbanks_standardize_audio: false
+    fbanks_waveform_scale: 32768
+  fbank_feats_pad_idx: 0
+  # manifest_list: dev_asr_only_aggregated_adapted_norm
+  manifest_list: dev_vpsr_eng_norm
+  manifest_list_path: null
+  manifest_path_prefix: /data/home/mavlyutov/s2t_ondevice/
+  max_seconds_per_input_audio: 15
+  fixed_batch_size: 30
+  max_tgt_text_tokens_per_batch: 1000
+  max_tgt_text_tokens_per_sample: 300
+  max_units_per_sample: 1500
+  num_threads: 5
+  prefech_batches: null 
+  prepend_tgt_lang_tag: true
+  shuffle_window: 1000
+  text_tokenization:
+    from_model: null
+    langtoks:
+    - eng
+    - rus
+    - hin
+    - por
+    - spa
+    spm_path: /data/home/mavlyutov/s2t_ondevice/vocab20k/5_5_20k.model
+  unit_tokenization:
+    from_model: seamlessM4T_large
+    langtoks: null
+    num_units: null
+  unit_tokenizer_name: seamlessM4T_large
+model:
+  custom_params:
+    model_embed_dim: 768
+    num_fbank_channels: 80
+    fbank_stride: 2
+    nllb_decoder_layers: 3
+    nllb_encoder_layers: 1
+    nllb_vocabulary_size: 20010
+    t2u_decoder_layers: 1
+    t2u_encoder_layers: 1
+    unit_vocabulary_size: 10082
+    w2v2_encoder_layers: 6
+    w2v2_encoder_layers_layernorm_features: false
+    w2v2_encoder_layers_use_conformer: true
+    w2v2_num_pos_conv_groups: 16
+    w2v2_pos_conv_kernel_size: 128
+    w2v2_pos_encoder_depth: 1
+    w2v2_pos_encoder_type: relative
+    # w2v2_pos_encoder_type: conv
+  from_model: null
+  from_model_config: null
+  pretrained_s2t_decoder_path: null
+  pretrained_t2u_path: null
+  pretrained_w2v2_path: null
+train_data:
+  audio:
+    audio_root_dir: /fsx-ust/data/audio_zips/
+    fbanks_num_mel_bins: 80
+    fbanks_standardize_audio: false
+    fbanks_waveform_scale: 32768
+  fbank_feats_pad_idx: 0
+  manifest_list: train_asr_only_aggregated_5_dial_filtered_adapted_wh_transc_norm,/data/home/mavlyutov/mmstts/dataset/premixed/all_mms
+  manifest_list_path: null
+  manifest_path_prefix: /data/home/mavlyutov/s2t_ondevice/
+  max_seconds_per_input_audio: 15
+  fixed_batch_size: 40
+  max_tgt_text_tokens_per_batch: 600
+  max_tgt_text_tokens_per_sample: 300
+  max_units_per_sample: 1500
+  num_threads: 4 
+  prefech_batches: null 
+  prepend_tgt_lang_tag: true
+  shuffle_window: 1000 
+  text_tokenization:
+    from_model: null
+    langtoks:
+    - eng
+    - rus
+    - hin
+    - por
+    - spa
+    spm_path: /data/home/mavlyutov/s2t_ondevice/vocab20k/5_5_20k.model
+  unit_tokenization:
+    from_model: seamlessM4T_large
+    langtoks: null
+    num_units: null
+  unit_tokenizer_name: seamlessM4T_large
+training:
+  eval_steps: 5000
+  float_dtype: fp32
+  label_smoothing: 0.2
+  learning_rate: 0.0001
+  log_steps:  100
+  max_epochs: 100
+  patience: 10
+  start_learning_rate: 1.0e-07
+  warmup_steps: 1000

+ 11 - 9
src/seamless_communication/cli/m4t/train/recipes/asr_small_wh_transc.yaml

@@ -5,7 +5,8 @@ eval_data:
     fbanks_standardize_audio: true
     fbanks_waveform_scale: 32768
   fbank_feats_pad_idx: 0
-  manifest_list: dev_asr_only_aggregated_adapted
+  # manifest_list: dev_asr_only_aggregated_adapted_norm
+  manifest_list: dev_vpsr_eng_norm
   manifest_list_path: null
   manifest_path_prefix: /data/home/mavlyutov/s2t_ondevice/
   max_seconds_per_input_audio: 15
@@ -43,10 +44,11 @@ model:
     w2v2_encoder_layers: 6
     w2v2_encoder_layers_layernorm_features: false
     w2v2_encoder_layers_use_conformer: true
-    w2v2_num_pos_conv_groups: 0
-    w2v2_pos_conv_kernel_size: 0
-    w2v2_pos_encoder_depth: 0
+    w2v2_num_pos_conv_groups: 16
+    w2v2_pos_conv_kernel_size: 128
+    w2v2_pos_encoder_depth: 1
     w2v2_pos_encoder_type: relative
+    # w2v2_pos_encoder_type: conv
   from_model: null
   from_model_config: null
   pretrained_s2t_decoder_path: null
@@ -59,11 +61,11 @@ train_data:
     fbanks_standardize_audio: true
     fbanks_waveform_scale: 32768
   fbank_feats_pad_idx: 0
-  manifest_list: train_asr_only_aggregated_5_dial_filtered_adapted_wh_transc
+  manifest_list: train_asr_only_aggregated_5_dial_filtered_adapted_wh_transc_norm
   manifest_list_path: null
   manifest_path_prefix: /data/home/mavlyutov/s2t_ondevice/
   max_seconds_per_input_audio: 15
-  fixed_batch_size: 30
+  fixed_batch_size: 40
   max_tgt_text_tokens_per_batch: 600
   max_tgt_text_tokens_per_sample: 300
   max_units_per_sample: 1500
@@ -87,11 +89,11 @@ train_data:
   unit_tokenizer_name: seamlessM4T_large
 training:
   eval_steps: 5000
-  float_dtype: bf16
+  float_dtype: fp32
   label_smoothing: 0.2
   learning_rate: 0.0001
-  log_steps:  50 
+  log_steps:  50
   max_epochs: 100
   patience: 10
   start_learning_rate: 1.0e-07
-  warmup_steps: 1000
+  warmup_steps: 1000

+ 99 - 0
src/seamless_communication/cli/m4t/train/recipes/asr_small_wh_transc_mms.yaml

@@ -0,0 +1,99 @@
+eval_data:
+  audio:
+    audio_root_dir: /fsx-ust/data/audio_zips/
+    fbanks_num_mel_bins: 80
+    fbanks_standardize_audio: true
+    fbanks_waveform_scale: 32768
+  fbank_feats_pad_idx: 0
+  # manifest_list: dev_asr_only_aggregated_adapted_norm
+  manifest_list: dev_vpsr_eng_norm
+  manifest_list_path: null
+  manifest_path_prefix: /data/home/mavlyutov/s2t_ondevice/
+  max_seconds_per_input_audio: 15
+  fixed_batch_size: 30
+  max_tgt_text_tokens_per_batch: 1000
+  max_tgt_text_tokens_per_sample: 300
+  max_units_per_sample: 1500
+  num_threads: 5
+  prefech_batches: null 
+  prepend_tgt_lang_tag: true
+  shuffle_window: 1000
+  text_tokenization:
+    from_model: null
+    langtoks:
+    - eng
+    - rus
+    - hin
+    - por
+    - spa
+    spm_path: /data/home/mavlyutov/s2t_ondevice/vocab20k/5_5_20k.model
+  unit_tokenization:
+    from_model: seamlessM4T_large
+    langtoks: null
+    num_units: null
+  unit_tokenizer_name: seamlessM4T_large
+model:
+  custom_params:
+    model_embed_dim: 768
+    nllb_decoder_layers: 3
+    nllb_encoder_layers: 1
+    nllb_vocabulary_size: 20010
+    t2u_decoder_layers: 1
+    t2u_encoder_layers: 1
+    unit_vocabulary_size: 10082
+    w2v2_encoder_layers: 6
+    w2v2_encoder_layers_layernorm_features: false
+    w2v2_encoder_layers_use_conformer: true
+    w2v2_num_pos_conv_groups: 16
+    w2v2_pos_conv_kernel_size: 128
+    w2v2_pos_encoder_depth: 1
+    w2v2_pos_encoder_type: relative
+    # w2v2_pos_encoder_type: conv
+  from_model: null
+  from_model_config: null
+  pretrained_s2t_decoder_path: null
+  pretrained_t2u_path: null
+  pretrained_w2v2_path: null
+train_data:
+  audio:
+    audio_root_dir: /fsx-ust/data/audio_zips/
+    fbanks_num_mel_bins: 80
+    fbanks_standardize_audio: true
+    fbanks_waveform_scale: 32768
+  fbank_feats_pad_idx: 0
+  manifest_list: train_asr_only_aggregated_5_dial_filtered_adapted_wh_transc_norm,/data/home/mavlyutov/mmstts/dataset/premixed/all_mms
+  manifest_list_path: null
+  manifest_path_prefix: /data/home/mavlyutov/s2t_ondevice/
+  max_seconds_per_input_audio: 15
+  fixed_batch_size: 40
+  max_tgt_text_tokens_per_batch: 600
+  max_tgt_text_tokens_per_sample: 300
+  max_units_per_sample: 1500
+  num_threads: 4 
+  prefech_batches: null 
+  prepend_tgt_lang_tag: true
+  shuffle_window: 1000 
+  text_tokenization:
+    from_model: null
+    langtoks:
+    - eng
+    - rus
+    - hin
+    - por
+    - spa
+    spm_path: /data/home/mavlyutov/s2t_ondevice/vocab20k/5_5_20k.model
+  unit_tokenization:
+    from_model: seamlessM4T_large
+    langtoks: null
+    num_units: null
+  unit_tokenizer_name: seamlessM4T_large
+training:
+  eval_steps: 5000
+  float_dtype: fp32
+  label_smoothing: 0.2
+  learning_rate: 0.0001
+  log_steps:  50
+  max_epochs: 100
+  patience: 10
+  start_learning_rate: 1.0e-07
+  warmup_steps: 1000

+ 101 - 0
src/seamless_communication/cli/m4t/train/recipes/asr_small_wh_vary_audio.yaml

@@ -0,0 +1,101 @@
+eval_data:
+  audio:
+    audio_root_dir: /fsx-ust/data/audio_zips/
+    fbanks_num_mel_bins: 80
+    fbanks_standardize_audio: true
+    fbanks_waveform_scale: 32768
+  fbank_feats_pad_idx: 0
+  # manifest_list: dev_asr_only_aggregated_adapted_norm
+  manifest_list: dev_vpsr_eng_norm
+  manifest_list_path: null
+  manifest_path_prefix: /data/home/mavlyutov/s2t_ondevice/
+  max_seconds_per_input_audio: 15
+  fixed_batch_size: 30
+  max_tgt_text_tokens_per_batch: 1000
+  max_tgt_text_tokens_per_sample: 300
+  max_units_per_sample: 1500
+  num_threads: 5
+  prefech_batches: null 
+  prepend_tgt_lang_tag: true
+  shuffle_window: 1000
+  text_tokenization:
+    from_model: null
+    langtoks:
+    - eng
+    - rus
+    - hin
+    - por
+    - spa
+    spm_path: /data/home/mavlyutov/s2t_ondevice/vocab20k/5_5_20k.model
+  unit_tokenization:
+    from_model: seamlessM4T_large
+    langtoks: null
+    num_units: null
+  unit_tokenizer_name: seamlessM4T_large
+model:
+  custom_params:
+    model_embed_dim: 768
+    num_fbank_channels: 80
+    fbank_stride: 2
+    nllb_decoder_layers: 3
+    nllb_encoder_layers: 1
+    nllb_vocabulary_size: 20010
+    t2u_decoder_layers: 1
+    t2u_encoder_layers: 1
+    unit_vocabulary_size: 10082
+    w2v2_encoder_layers: 6
+    w2v2_encoder_layers_layernorm_features: false
+    w2v2_encoder_layers_use_conformer: true
+    w2v2_num_pos_conv_groups: 16
+    w2v2_pos_conv_kernel_size: 128
+    w2v2_pos_encoder_depth: 1
+    w2v2_pos_encoder_type: relative
+    # w2v2_pos_encoder_type: conv
+  from_model: null
+  from_model_config: null
+  pretrained_s2t_decoder_path: null
+  pretrained_t2u_path: null
+  pretrained_w2v2_path: null
+train_data:
+  audio:
+    audio_root_dir: /fsx-ust/data/audio_zips/
+    fbanks_num_mel_bins: 80
+    fbanks_standardize_audio: true
+    fbanks_waveform_scale: 32768
+  fbank_feats_pad_idx: 0
+  manifest_list: train_asr_only_aggregated_5_dial_filtered_adapted_wh_transc_norm
+  manifest_list_path: null
+  manifest_path_prefix: /data/home/mavlyutov/s2t_ondevice/
+  max_seconds_per_input_audio: 15
+  fixed_batch_size: 40
+  max_tgt_text_tokens_per_batch: 600
+  max_tgt_text_tokens_per_sample: 300
+  max_units_per_sample: 1500
+  num_threads: 4 
+  prefech_batches: null 
+  prepend_tgt_lang_tag: true
+  shuffle_window: 1000 
+  text_tokenization:
+    from_model: null
+    langtoks:
+    - eng
+    - rus
+    - hin
+    - por
+    - spa
+    spm_path: /data/home/mavlyutov/s2t_ondevice/vocab20k/5_5_20k.model
+  unit_tokenization:
+    from_model: seamlessM4T_large
+    langtoks: null
+    num_units: null
+  unit_tokenizer_name: seamlessM4T_large
+training:
+  eval_steps: 5000
+  float_dtype: fp32
+  label_smoothing: 0.2
+  learning_rate: 0.0001
+  log_steps:  50
+  max_epochs: 100
+  patience: 10
+  start_learning_rate: 1.0e-07
+  warmup_steps: 1000

+ 101 - 0
src/seamless_communication/cli/m4t/train/recipes/asr_wide_wh_120ch.yaml

@@ -0,0 +1,101 @@
+eval_data:
+  audio:
+    audio_root_dir: /fsx-ust/data/audio_zips/
+    fbanks_num_mel_bins: 120
+    fbanks_standardize_audio: true
+    fbanks_waveform_scale: 32768
+  fbank_feats_pad_idx: 0
+  # manifest_list: dev_asr_only_aggregated_adapted_norm
+  manifest_list: dev_vpsr_eng_norm
+  manifest_list_path: null
+  manifest_path_prefix: /data/home/mavlyutov/s2t_ondevice/
+  max_seconds_per_input_audio: 15
+  fixed_batch_size: 30
+  max_tgt_text_tokens_per_batch: 1000
+  max_tgt_text_tokens_per_sample: 300
+  max_units_per_sample: 1500
+  num_threads: 5
+  prefech_batches: null 
+  prepend_tgt_lang_tag: true
+  shuffle_window: 1000
+  text_tokenization:
+    from_model: null
+    langtoks:
+    - eng
+    - rus
+    - hin
+    - por
+    - spa
+    spm_path: /data/home/mavlyutov/s2t_ondevice/vocab20k/5_5_20k.model
+  unit_tokenization:
+    from_model: seamlessM4T_large
+    langtoks: null
+    num_units: null
+  unit_tokenizer_name: seamlessM4T_large
+model:
+  custom_params:
+    model_embed_dim: 1024
+    num_fbank_channels: 120
+    fbank_stride: 2
+    nllb_decoder_layers: 3
+    nllb_encoder_layers: 1
+    nllb_vocabulary_size: 20010
+    t2u_decoder_layers: 1
+    t2u_encoder_layers: 1
+    unit_vocabulary_size: 10082
+    w2v2_encoder_layers: 3
+    w2v2_encoder_layers_layernorm_features: false
+    w2v2_encoder_layers_use_conformer: true
+    w2v2_num_pos_conv_groups: 16
+    w2v2_pos_conv_kernel_size: 128
+    w2v2_pos_encoder_depth: 1
+    w2v2_pos_encoder_type: relative
+    # w2v2_pos_encoder_type: conv
+  from_model: null
+  from_model_config: null
+  pretrained_s2t_decoder_path: null
+  pretrained_t2u_path: null
+  pretrained_w2v2_path: null
+train_data:
+  audio:
+    audio_root_dir: /fsx-ust/data/audio_zips/
+    fbanks_num_mel_bins: 120
+    fbanks_standardize_audio: true
+    fbanks_waveform_scale: 32768
+  fbank_feats_pad_idx: 0
+  manifest_list: train_asr_only_aggregated_5_dial_filtered_adapted_wh_transc_norm
+  manifest_list_path: null
+  manifest_path_prefix: /data/home/mavlyutov/s2t_ondevice/
+  max_seconds_per_input_audio: 15
+  fixed_batch_size: 60
+  max_tgt_text_tokens_per_batch: 600
+  max_tgt_text_tokens_per_sample: 300
+  max_units_per_sample: 1500
+  num_threads: 4 
+  prefech_batches: null 
+  prepend_tgt_lang_tag: true
+  shuffle_window: 1000 
+  text_tokenization:
+    from_model: null
+    langtoks:
+    - eng
+    - rus
+    - hin
+    - por
+    - spa
+    spm_path: /data/home/mavlyutov/s2t_ondevice/vocab20k/5_5_20k.model
+  unit_tokenization:
+    from_model: seamlessM4T_large
+    langtoks: null
+    num_units: null
+  unit_tokenizer_name: seamlessM4T_large
+training:
+  eval_steps: 5000
+  float_dtype: fp32
+  label_smoothing: 0.2
+  learning_rate: 0.0001
+  log_steps:  100
+  max_epochs: 100
+  patience: 10
+  start_learning_rate: 1.0e-07
+  warmup_steps: 1000

+ 1 - 1
src/seamless_communication/cli/m4t/train/recipes/large_M4T_v1.yaml

@@ -81,7 +81,7 @@ training:
   float_dtype: bf16
   label_smoothing: 0.2
   learning_rate: 0.0001
-  log_steps: 200
+  log_steps: 1
   max_epochs: 100
   patience: 10
   start_learning_rate: 1.0e-07

+ 99 - 0
src/seamless_communication/cli/m4t/train/recipes/mt_small_orig_dataset.yaml

@@ -0,0 +1,99 @@
+eval_data:
+  audio:
+    audio_root_dir: /fsx-ust/data/audio_zips/
+    fbanks_num_mel_bins: 80
+    fbanks_standardize_audio: true
+    fbanks_waveform_scale: 32768
+  fbank_feats_pad_idx: 0
+  # manifest_list: dev_asr_only_aggregated_adapted_norm
+  manifest_list: dev_vpsr_eng
+  manifest_list_path: null
+  manifest_path_prefix: /data/home/mavlyutov/s2t_ondevice/
+  max_seconds_per_input_audio: 15
+  fixed_batch_size: 30
+  max_tgt_text_tokens_per_batch: 1000
+  max_tgt_text_tokens_per_sample: 300
+  max_units_per_sample: 1500
+  num_threads: 5
+  prefech_batches: null 
+  prepend_tgt_lang_tag: true
+  shuffle_window: 1000
+  text_tokenization:
+    from_model: null
+    langtoks:
+    - eng
+    - rus
+    - hin
+    - por
+    - spa
+    spm_path: /data/home/mavlyutov/s2t_ondevice/vocab20k/5_5_20k.model
+  unit_tokenization:
+    from_model: seamlessM4T_large
+    langtoks: null
+    num_units: null
+  unit_tokenizer_name: seamlessM4T_large
+model:
+  custom_params:
+    model_embed_dim: 768
+    nllb_decoder_layers: 4
+    nllb_encoder_layers: 1
+    nllb_vocabulary_size: 20010
+    t2u_decoder_layers: 1
+    t2u_encoder_layers: 1
+    unit_vocabulary_size: 10082
+    w2v2_encoder_layers: 8
+    w2v2_encoder_layers_layernorm_features: false
+    w2v2_encoder_layers_use_conformer: true
+    w2v2_num_pos_conv_groups: 16
+    w2v2_pos_conv_kernel_size: 128
+    w2v2_pos_encoder_depth: 1
+    w2v2_pos_encoder_type: relative
+    # w2v2_pos_encoder_type: conv
+  from_model: null
+  from_model_config: null
+  pretrained_s2t_decoder_path: null
+  pretrained_t2u_path: null
+  pretrained_w2v2_path: null
+train_data:
+  audio:
+    audio_root_dir: /fsx-ust/data/audio_zips/
+    fbanks_num_mel_bins: 80
+    fbanks_standardize_audio: true
+    fbanks_waveform_scale: 32768
+  fbank_feats_pad_idx: 0
+  manifest_list: train_hindi_filt_balanced_aggregated_adapted
+  manifest_list_path: null
+  manifest_path_prefix: /data/home/mavlyutov/s2t_ondevice/
+  max_seconds_per_input_audio: 15
+  fixed_batch_size: 40
+  max_tgt_text_tokens_per_batch: 600
+  max_tgt_text_tokens_per_sample: 300
+  max_units_per_sample: 1500
+  num_threads: 4 
+  prefech_batches: null 
+  prepend_tgt_lang_tag: true
+  shuffle_window: 1000 
+  text_tokenization:
+    from_model: null
+    langtoks:
+    - eng
+    - rus
+    - hin
+    - por
+    - spa
+    spm_path: /data/home/mavlyutov/s2t_ondevice/vocab20k/5_5_20k.model
+  unit_tokenization:
+    from_model: seamlessM4T_large
+    langtoks: null
+    num_units: null
+  unit_tokenizer_name: seamlessM4T_large
+training:
+  eval_steps: 5000
+  float_dtype: fp32
+  label_smoothing: 0.2
+  learning_rate: 0.0005
+  log_steps:  50
+  max_epochs: 100
+  patience: 10
+  start_learning_rate: 1.0e-07
+  warmup_steps: 1000

+ 99 - 0
src/seamless_communication/cli/m4t/train/recipes/mt_small_orig_dataset_and_eng_10m.yaml

@@ -0,0 +1,99 @@
+eval_data:
+  audio:
+    audio_root_dir: /fsx-ust/data/audio_zips/
+    fbanks_num_mel_bins: 80
+    fbanks_standardize_audio: true
+    fbanks_waveform_scale: 32768
+  fbank_feats_pad_idx: 0
+  # manifest_list: dev_asr_only_aggregated_adapted_norm
+  manifest_list: dev_vpsr_eng
+  manifest_list_path: null
+  manifest_path_prefix: /data/home/mavlyutov/s2t_ondevice/
+  max_seconds_per_input_audio: 15
+  fixed_batch_size: 30
+  max_tgt_text_tokens_per_batch: 1000
+  max_tgt_text_tokens_per_sample: 300
+  max_units_per_sample: 1500
+  num_threads: 5
+  prefech_batches: null 
+  prepend_tgt_lang_tag: true
+  shuffle_window: 1000
+  text_tokenization:
+    from_model: null
+    langtoks:
+    - eng
+    - rus
+    - hin
+    - por
+    - spa
+    spm_path: /data/home/mavlyutov/s2t_ondevice/vocab20k/5_5_20k.model
+  unit_tokenization:
+    from_model: seamlessM4T_large
+    langtoks: null
+    num_units: null
+  unit_tokenizer_name: seamlessM4T_large
+model:
+  custom_params:
+    model_embed_dim: 768
+    nllb_decoder_layers: 4
+    nllb_encoder_layers: 1
+    nllb_vocabulary_size: 20010
+    t2u_decoder_layers: 1
+    t2u_encoder_layers: 1
+    unit_vocabulary_size: 10082
+    w2v2_encoder_layers: 8
+    w2v2_encoder_layers_layernorm_features: false
+    w2v2_encoder_layers_use_conformer: true
+    w2v2_num_pos_conv_groups: 16
+    w2v2_pos_conv_kernel_size: 128
+    w2v2_pos_encoder_depth: 1
+    w2v2_pos_encoder_type: relative
+    # w2v2_pos_encoder_type: conv
+  from_model: null
+  from_model_config: null
+  pretrained_s2t_decoder_path: null
+  pretrained_t2u_path: null
+  pretrained_w2v2_path: null
+train_data:
+  audio:
+    audio_root_dir: /fsx-ust/data/audio_zips/
+    fbanks_num_mel_bins: 80
+    fbanks_standardize_audio: true
+    fbanks_waveform_scale: 32768
+  fbank_feats_pad_idx: 0
+  manifest_list: train_hindi_filt_balanced_aggregated_adapted,eng_10m_wh_transc_mt_cols
+  manifest_list_path: null
+  manifest_path_prefix: /data/home/mavlyutov/s2t_ondevice/
+  max_seconds_per_input_audio: 15
+  fixed_batch_size: 40
+  max_tgt_text_tokens_per_batch: 600
+  max_tgt_text_tokens_per_sample: 300
+  max_units_per_sample: 1500
+  num_threads: 4 
+  prefech_batches: null 
+  prepend_tgt_lang_tag: true
+  shuffle_window: 1000 
+  text_tokenization:
+    from_model: null
+    langtoks:
+    - eng
+    - rus
+    - hin
+    - por
+    - spa
+    spm_path: /data/home/mavlyutov/s2t_ondevice/vocab20k/5_5_20k.model
+  unit_tokenization:
+    from_model: seamlessM4T_large
+    langtoks: null
+    num_units: null
+  unit_tokenizer_name: seamlessM4T_large
+training:
+  eval_steps: 5000
+  float_dtype: fp32
+  label_smoothing: 0.2
+  learning_rate: 0.0001
+  log_steps:  50
+  max_epochs: 100
+  patience: 10
+  start_learning_rate: 1.0e-07
+  warmup_steps: 1000

+ 340 - 0
src/seamless_communication/cli/m4t/train/run_eval.py

@@ -0,0 +1,340 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import logging
+import os
+import platform
+from io import BytesIO
+from pathlib import Path
+from typing import Any, Dict, Iterator, List, NamedTuple, Tuple, Union
+from zipfile import ZipFile, ZipInfo
+
+import sacrebleu
+import torch
+import torchaudio  # type: ignore
+from jiwer import wer  # type: ignore
+
+import seamless_communication.cli.m4t.train.cleaners as cleaners
+from fairseq2.data.audio import WaveformToFbankConverter
+from fairseq2.generation import NGramRepeatBlockProcessor, SequenceGeneratorOptions
+from fairseq2.models.nllb.tokenizer import NllbTokenizer
+from seamless_communication.cli.m4t.train import model as _model
+from seamless_communication.cli.m4t.train import trainer as _trainer
+from seamless_communication.cli.m4t.train.configs import (
+    DataLoadingConfig,
+    WorkflowParams,
+)
+from seamless_communication.inference.generator import UnitYGenerator
+from seamless_communication.models.tokenizer import SPMTokenizer
+from seamless_communication.models.unity import (
+    UnitTokenizer,
+    UnitYModel,
+    load_unity_text_tokenizer,
+    load_unity_unit_tokenizer,
+)
+
+logging_format = f"%(asctime)s - {platform.node()} - %(process)s - %(levelname)s - %(name)s: %(message)s"
+logging.basicConfig(
+    level=logging.INFO,
+    format=logging_format,
+)
+
+logger = logging.getLogger("eval")
+
+
+class TestRecord(NamedTuple):
+    wav: torch.Tensor
+    tgt_lang: str
+    tgt_text: str
+
+
+RAW_TGT_TEXT_COL_NAME = "tgt_text"
+TGT_LANG_COL_NAME = "tgt_lang"
+AUDIO_COL_NAME = "audio"
+SAMPLE_RATE = 16000
+OPEN_ZIP_ARCHIVES: Dict[str, Tuple[ZipFile, List[ZipInfo]]] = {}
+
+
+def get_bleu(
+    translations: List[str], ref_translations: List[str], dialect: str = "en"
+) -> float:
+    if dialect.lower().startswith("ja") or dialect.lower().startswith("zh"):
+        tokenizer = "char"
+    else:
+        tokenizer = "13a"
+    print(f"Num samples {len(translations)} {len(ref_translations)}")
+    for idx in range(5):
+        logger.info(f"Example transl: {translations[idx]}")
+        logger.info(f"Example refere: {ref_translations[idx]}")
+        logger.info("---")
+    score = sacrebleu.corpus_bleu(translations, [ref_translations], tokenize=tokenizer)
+    return score.score
+
+
+def _remove_stuttering(text):
+    filt = []
+    for word in text.split():
+        if len(filt) > 1 and filt[-1] == word and filt[-2] == word:
+            continue
+        filt.append(word)
+    return " ".join(filt)
+
+
+def _normalize_text_for_wer(text, lang="en"):
+    text = cleaners.basic_cleaners(text)
+    text = cleaners.remove_punctuations(text, cleaners.PUNCTUATIONS_EXCLUDE_APOSTROPHE)
+    text = _remove_stuttering(text)
+    if lang == "ja":
+        text = cleaners.normalize_ja_text(text)
+    return text
+
+
+def get_wer(
+    translations: List[str], ref_translations: List[str], dialect: str = "en"
+) -> float:
+    reference = [_normalize_text_for_wer(txt) for txt in ref_translations]
+    hypothesis = [_normalize_text_for_wer(txt) for txt in translations]
+    return (
+        wer(
+            reference=reference,
+            hypothesis=hypothesis,
+        )
+        * 100
+    )
+
+
+def _iter_manifest(manifest_path: Path) -> Iterator[Tuple[str, str, str]]:
+    tgt_lang_idx = None
+    tgt_text_idx = None
+    audio_idx = None
+    with open(manifest_path) as fp_in:
+        for line in fp_in:
+            chunks = line.strip().split("\t")
+            if tgt_lang_idx is None:  # header
+                tgt_lang_idx = chunks.index(TGT_LANG_COL_NAME)
+                tgt_text_idx = chunks.index(RAW_TGT_TEXT_COL_NAME)
+                audio_idx = chunks.index(AUDIO_COL_NAME)
+                continue
+            yield chunks[audio_idx], chunks[tgt_lang_idx], chunks[tgt_text_idx]
+
+
+def _extract_audio_blob(arch_name: str, offset: int) -> bytes:
+    archive, records = OPEN_ZIP_ARCHIVES[arch_name]
+    for info in records:
+        info_offset = info.header_offset + len(info.FileHeader())
+        if abs(info_offset - offset) < 100:  # expect some misalignment
+            local_path = archive.extract(info)
+            with open(local_path, "rb") as fp_in:
+                content_bytes = fp_in.read()
+            os.unlink(local_path)
+            return content_bytes
+    raise ValueError(f"Didn't find record with offset {offset} in {arch_name}")
+
+
+def _load_archive_data(audio_zips_root: str, name: str) -> None:
+    if name in OPEN_ZIP_ARCHIVES:
+        return
+    archive = ZipFile(
+        os.path.join(audio_zips_root, name),
+        mode="r",
+    )
+    OPEN_ZIP_ARCHIVES[name] = (archive, archive.infolist())
+    logging.info(f"Loaded archive {name}")
+
+
+def _load_audio_wav(audio_zips_root: str, audio_str: str) -> torch.Tensor:
+    archive_name, offset_str, _ = audio_str.split(":")
+    offset = int(offset_str)
+    _load_archive_data(audio_zips_root=audio_zips_root, name=archive_name)
+    blob = _extract_audio_blob(arch_name=archive_name, offset=offset)
+    wav, samplerate = torchaudio.load(BytesIO(blob))
+    assert samplerate == SAMPLE_RATE
+    return wav
+
+
+def load_manifest(manifest_path: Path, audio_zips_root: str) -> Iterator[TestRecord]:
+    for audio_str, tgt_lang, tgt_text in _iter_manifest(manifest_path=manifest_path):
+        audio = _load_audio_wav(audio_zips_root=audio_zips_root, audio_str=audio_str)
+        yield TestRecord(wav=audio, tgt_lang=tgt_lang, tgt_text=tgt_text)
+
+
+def _init_unit_tokenizer(data_config: DataLoadingConfig) -> UnitTokenizer:
+    if data_config.unit_tokenization.from_model is not None:
+        return load_unity_unit_tokenizer(data_config.unit_tokenization.from_model)
+    else:
+        raise NotImplementedError("TBD")
+
+
+def _init_text_tokenizer(
+    data_config: DataLoadingConfig,
+) -> Union[NllbTokenizer, SPMTokenizer]:
+    if data_config.text_tokenization.from_model is not None:
+        return load_unity_text_tokenizer(data_config.text_tokenization.from_model)
+    else:
+        assert data_config.text_tokenization.langtoks is not None
+        assert data_config.text_tokenization.spm_path is not None
+        return SPMTokenizer(
+            pathname=data_config.text_tokenization.spm_path,
+            langs=data_config.text_tokenization.langtoks,
+        )
+
+
+def translate(
+    model: UnitYModel,
+    text_tokenizer: Union[NllbTokenizer, SPMTokenizer],
+    unit_tokenizer: UnitTokenizer,
+    fbank_extractor: WaveformToFbankConverter,
+    dtype: torch.dtype,
+    device: torch.device,
+    test_record: TestRecord,
+    ngram_filtering: bool = True,
+    text_max_len_a: int = 1,
+    text_max_len_b: int = 200,
+    unit_max_len_a: int = 1,
+    unit_max_len_b: int = 50,
+) -> Tuple[str, Any]:
+    """Runs S2T translation. TBD: add S2S"""
+    text_opts = SequenceGeneratorOptions(
+        beam_size=5, soft_max_seq_len=(text_max_len_a, text_max_len_b)
+    )
+    unit_opts = SequenceGeneratorOptions(
+        beam_size=5, soft_max_seq_len=(unit_max_len_a, unit_max_len_b)
+    )
+    if ngram_filtering:
+        text_opts.step_processor = NGramRepeatBlockProcessor(ngram_size=4)
+        unit_opts.step_processor = NGramRepeatBlockProcessor(ngram_size=4)
+    generator = UnitYGenerator(
+        model,
+        text_tokenizer,
+        test_record.tgt_lang,
+        unit_tokenizer,
+        text_opts=text_opts,
+        unit_opts=unit_opts,
+    )
+    wav = test_record.wav
+    assert len(wav.shape) in (1, 2)
+    if len(wav.shape) == 1:
+        wav = wav.unsqueeze(-1)
+    elif wav.shape[0] <= 2:  # channel is first, should be second:
+        wav = wav.transpose(0, 1)
+    fbank = fbank_extractor(
+        {
+            "waveform": wav,
+            "sample_rate": SAMPLE_RATE,
+        }
+    )["fbank"]
+    s2t_result, t2u_result = generator(
+        fbank.unsqueeze(0),
+        None,
+        "speech",
+        "text",
+        ngram_filtering=ngram_filtering,
+    )
+    s2t_out = str(s2t_result.sentences[0])
+    return s2t_out, None
+
+
+def run_evaluation(
+    parameters: WorkflowParams, checkpoint_path: Path, manifest_path: Path
+):
+    device = torch.device("cuda") if torch.cuda.device_count() else torch.device("cpu")
+    float_dtype = _trainer.UnitYTrainer._get_float_dtype(
+        parameters.training.float_dtype
+    )
+    logger.info(f"Device: {device}, float dtype: {float_dtype}")
+    audio_zips_root = parameters.train_data.audio.audio_root_dir
+    logger.info(f"Audio zip root: {audio_zips_root}")
+    model = _model.ModelBuilder(
+        config=parameters.model, dtype=float_dtype, device=device
+    ).build_model(skip_loading_weights=True)
+    logger.info(f"Loading checkpoint from {checkpoint_path}")
+    state_dict = torch.load(checkpoint_path, map_location=device)
+    # temporary fix for previous bug with checkpoint saving:
+    state_dict = {
+        _trainer.UnitYTrainer._strip_state_key_prefixes(
+            key.replace("t2u.", "t2u_model.")
+        ): value
+        for key, value in state_dict.items()
+    }
+    model.load_state_dict(state_dict)
+    model.eval()
+    text_tokenizer = _init_text_tokenizer(data_config=parameters.train_data)
+    unit_tokenizer = _init_unit_tokenizer(data_config=parameters.train_data)
+    fbank_extractor = WaveformToFbankConverter(
+        num_mel_bins=parameters.train_data.audio.fbanks_num_mel_bins or 80,
+        waveform_scale=parameters.train_data.audio.fbanks_waveform_scale,
+        channel_last=True,
+        standardize=parameters.train_data.audio.fbanks_standardize_audio,
+        device=device,
+        dtype=float_dtype,
+    )
+
+    logger.info(f"Model: {model}")
+    records = load_manifest(
+        manifest_path=manifest_path, audio_zips_root=audio_zips_root
+    )
+
+    model_translations = []
+    reference_translations = []
+    for idx, record in enumerate(records):
+        reference_translations.append(record.tgt_text)
+        s2t, t2u = translate(
+            model=model,
+            text_tokenizer=text_tokenizer,
+            unit_tokenizer=unit_tokenizer,
+            fbank_extractor=fbank_extractor,
+            test_record=record,
+            device=device,
+            dtype=float_dtype,
+        )
+        model_translations.append(s2t)
+        logger.info(f"{idx} ref: {record.tgt_text}")
+        logger.info(f"{idx} s2t: {s2t}")
+        logger.info("--")
+    model_wer = get_wer(model_translations, ref_translations=reference_translations)
+    logger.info(f"FINAL WER: {model_wer}, manifest: {manifest_path}")
+
+
+def init_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser(description="Run M4T training")
+    parser.add_argument(
+        "--manifest",
+        type=Path,
+        required=True,
+        help="Path to test manifest",
+    )
+    parser.add_argument(
+        "--checkpoint",
+        type=Path,
+        required=True,
+        help="Path to checkpoint",
+    )
+    parser.add_argument(
+        "--train_params",
+        type=Path,
+        required=True,
+        help="Training workflow config (*_config.yaml is available in work directory)",
+    )
+    return parser
+
+
+def main() -> None:
+    args = init_parser().parse_args()
+    manifest: Path = args.manifest
+    config_path: Path = args.train_params
+    checkpoint_path: Path = args.checkpoint
+    assert manifest.exists()
+    assert config_path.exists()
+    assert checkpoint_path.exists()
+    parameters = WorkflowParams.from_file(config_path.as_posix())
+    run_evaluation(
+        parameters=parameters, checkpoint_path=checkpoint_path, manifest_path=manifest
+    )
+
+
+if __name__ == "__main__":
+    main()

+ 84 - 22
src/seamless_communication/cli/m4t/train/trainer.py

@@ -8,7 +8,7 @@
 import logging
 import os
 import time
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
@@ -18,6 +18,7 @@ from fairseq2.models.sequence import SequenceModelOutput
 from fairseq2.nn.padding import PaddingMask
 from fairseq2.optim.lr_scheduler import MyleLR
 from torch.optim import Adam
+from fairseq2.optim import AdamW
 
 from seamless_communication.cli.m4t.train import dataloader, dist_utils
 from seamless_communication.cli.m4t.train.configs import TrainingParams
@@ -47,6 +48,7 @@ class UnitYTrainWrapper(nn.Module):
         assert self.model.t2u_model is not None
         assert batch.speech_to_text.src_tokens is not None
         # s2t
+        assert batch.speech_to_text.src_lengths is not None
         speech_padding_mask = PaddingMask(
             seq_lens=batch.speech_to_text.src_lengths,
             batch_seq_len=int(torch.max(batch.speech_to_text.src_lengths).item()),
@@ -56,6 +58,7 @@ class UnitYTrainWrapper(nn.Module):
             padding_mask=speech_padding_mask,
         )
         assert batch.speech_to_text.prev_output_tokens is not None
+        assert batch.speech_to_text.target_lengths is not None
         s2t_prev_out_tokens_padding_mask = PaddingMask(
             seq_lens=batch.speech_to_text.target_lengths,
             batch_seq_len=int(torch.max(batch.speech_to_text.target_lengths).item()),
@@ -66,14 +69,15 @@ class UnitYTrainWrapper(nn.Module):
             encoder_output=speech_encoder_out,
             encoder_padding_mask=speech_encoder_padding_mask,
         )
+        assert self.model.final_proj is not None
         text_logits = self.model.final_proj(text_decoder_out)
         # t2u
         (
             unit_encoder_out,
             unit_encoder_padding_mask,
         ) = self.t2u.encode(
-            text_decoder_output=text_decoder_out,
-            text_decoder_padding_mask=text_decoder_padding_mask,
+            seqs=text_decoder_out,
+            padding_mask=text_decoder_padding_mask,
         )
         t2u_prev_out_tokens_padding_mask = PaddingMask(
             seq_lens=batch.text_to_units.target_lengths,
@@ -122,6 +126,12 @@ class CalcLoss:
             ignore_prefix_size=self.s2t_ignore_prefix_size,
             label_smoothing=self.label_smoothing,
         )
+        loss = s2t_loss / s2t_numel
+        if torch.any(torch.isnan(loss)):
+            logger.error("LOSS IS NAN. EXITING")
+            logger.error(batch.speech_to_text)
+            raise ValueError("LOSS IS NAN")
+        return loss
         assert batch.text_to_units.target_lengths is not None
         s2u_numel = torch.sum(batch.text_to_units.target_lengths).to(unit_logits.device)
         s2u_loss = SequenceModelOutput(
@@ -200,18 +210,12 @@ class UnitYTrainer:
             t2u_vocab_info=model.t2u_model.target_vocab_info,
         )
         self._try_load_checkpoint(model=model)
+
         self.model = self._wrap_model_for_trainining(model=model)
 
         # TODO: make tweakable
-        self.optimizer = Adam(
-            params=self.model.parameters(),
-            lr=self.params.learning_rate,
-            betas=(0.9, 0.98),
-            eps=1e-08,
-            maximize=False,
-            weight_decay=0.0,
-            fused=True,
-        )
+        self._use_fairseq_adam_w = True
+        self.optimizer = self._build_optimizer()
 
         self.grad_scaler = torch.cuda.amp.GradScaler() if self.float_dtype == torch.float16 else None  # type: ignore
 
@@ -232,11 +236,40 @@ class UnitYTrainer:
         self.batch_sizes: List[int] = []
         self.gpu_usage: List[float] = []
 
-    def _try_load_checkpoint(self, model: torch.nn.Module):
+    def _build_optimizer(self) -> Union[Adam, AdamW]:
+        betas = (0.9, 0.98)
+        eps = 1e-08
+        maximize = False
+        weight_decay = 0.0
+        if self._use_fairseq_adam_w:
+            return AdamW(
+                params=self.model.parameters(),
+                lr=self.params.learning_rate,
+                betas=betas,
+                eps=eps,
+                maximize=maximize,
+                weight_decay=weight_decay,
+                impl="fused",
+            )
+        return Adam(
+            params=self.model.parameters(),
+            lr=self.params.learning_rate,
+            betas=betas,
+            eps=eps,
+            maximize=maximize,
+            weight_decay=weight_decay,
+            fused=True,
+        )
+
+    def _try_load_checkpoint(self, model: torch.nn.Module) -> None:
         chck_path = self.get_best_checkpoint_path()
         if os.path.exists(chck_path):
             logger.info(f"Loading state dict from {chck_path}")
             state_dict = torch.load(chck_path)
+            state_dict = {
+                self._strip_state_key_prefixes(key): value
+                for key, value in state_dict.items()
+            }
             model.load_state_dict(state_dict)
 
     @classmethod
@@ -332,7 +365,7 @@ class UnitYTrainer:
         eval_loss = loss_hist.reduce()
         self._update_eval_stats(eval_loss)
 
-    def _train_step_log(self):
+    def _train_step_log(self) -> None:
         """Log train stats"""
         if (self.update_idx + 1) % self.params.log_steps == 0:
             avg_loss = self.train_loss_hist.reduce()
@@ -348,6 +381,25 @@ class UnitYTrainer:
             )
             self._reset_log_stats()
 
+    def _get_grad_norms(self) -> None:
+        if self.update_idx > 5000:
+            return
+        path = os.path.join(self.chck_save_dir, "grads.txt")
+        if self.update_idx == 0 and os.path.exists(path):
+            os.unlink(path)
+        out_fp = open(path, "a")
+        for key, param in self.model.named_parameters():
+            try:
+                if param.grad is None:
+                    norm = None
+                else:
+                    norm = (torch.linalg.norm(param.grad.reshape(-1), ord=1) / torch.numel(param.grad)).item()
+            except Exception:
+                logger.exception(f"Failed for param {key}")
+                raise
+            out_fp.write(f"{self.update_idx}\t{key}\t{norm}\n")
+        out_fp.close()
+
     def _train_step(self, batch: dataloader.MultimodalSeqsBatch) -> None:
         """Run one train step"""
         self.model.train()
@@ -363,6 +415,8 @@ class UnitYTrainer:
             self.grad_scaler.update()
         else:
             loss.backward()
+            if dist_utils.is_main_process():
+                self._get_grad_norms()
             self.optimizer.step()
 
         self.lr_scheduler.step()
@@ -376,12 +430,16 @@ class UnitYTrainer:
         """Explicitly release large memory consumers"""
         del batch
 
-    def _strip_state_key_prefixes(self, key: str) -> str:
+    @classmethod
+    def _strip_state_key_prefixes(cls, key: str) -> str:
         """Removes state_dict keys prefixes associated with model wrappers"""
         to_strip = ["module.", "model."]
         for prefix in to_strip:
             if key.startswith(prefix):
-                key = key[len(prefix) :]
+                key = key[len(prefix):]  # noqa
+        for prefix, rewrite in [("t2u.", "t2u_model.")]:
+            if key.startswith(prefix):
+                key = rewrite + key[len(prefix):]
         return key
 
     def _get_state(self) -> Dict[str, Any]:
@@ -392,12 +450,13 @@ class UnitYTrainer:
         }
         return model_state_dict
 
-    def _get_chck_path(self) -> str:
+    def _get_chck_path(self, suffix: Optional[str] = None) -> str:
         ts = str(int(time.time()))
         epoch = str(self.epoch_idx).zfill(3)
         update = str(self.update_idx).zfill(6)
         eval_loss = f"{self.last_eval_loss:.4f}"
-        name = f"{ts}_{epoch}_{update}_{eval_loss}.pt"
+        chck_suffix = "" if suffix is None else f"_{suffix}"
+        name = f"{ts}_{epoch}_{update}_{eval_loss}{chck_suffix}.pt"
         return os.path.join(self.chck_save_dir, name)
 
     def _get_best_checkpoint_link_path(self) -> str:
@@ -406,16 +465,18 @@ class UnitYTrainer:
     def get_best_checkpoint_path(self) -> str:
         return os.path.realpath(self._get_best_checkpoint_link_path())
 
-    def _save_model(self):
+    def _save_model(self, suffix: Optional[str] = None) -> None:
         if dist_utils.is_main_process():
             state_dict = self._get_state()
-            save_path = self._get_chck_path()
+            save_path = self._get_chck_path(suffix)
             logger.info(f"Saving checkpoint to {save_path}")
             torch.save(state_dict, save_path)
             if self.is_best_state:
                 best_link_path = self._get_best_checkpoint_link_path()
-                if os.path.exists(best_link_path):
+                try:
                     os.unlink(best_link_path)
+                except FileNotFoundError:
+                    pass
                 os.symlink(save_path, best_link_path)
                 logger.info(
                     f"Updating pointer to the best checkpoint {best_link_path} -> {save_path}"
@@ -423,7 +484,7 @@ class UnitYTrainer:
         if dist_utils.is_dist_initialized():
             dist.barrier()
 
-    def run(self):
+    def run(self) -> None:
         logger.info("Start training")
         self._reset_stats()
         self._eval_model()
@@ -443,3 +504,4 @@ class UnitYTrainer:
                 self.update_idx += 1
             self.train_data_loader.reset()
             self.epoch_idx += 1
+        self._save_model(suffix="last")