
M4T training scripts and recipes

mavlyutov 1 year ago
parent commit efe88afa2e

+ 2 - 1
dev_requirements.txt

@@ -1,4 +1,5 @@
 pytest
 black
 flake8
-isort
+isort
+mypy

+ 0 - 0
scripts/m4t/train/__init__.py


+ 247 - 0
scripts/m4t/train/configs.py

@@ -0,0 +1,247 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from dataclasses import dataclass
+from typing import Dict, Any, Union, get_origin, get_args, List, Literal, Optional
+
+
+@dataclass
+class Config:
+    def serialize(self):
+        asdict = {}
+        for key in self.__dataclass_fields__.keys():
+            value = getattr(self, key)
+            if isinstance(value, Config):
+                asdict[key] = value.serialize()
+            else:
+                asdict[key] = value
+        return asdict
+
+    @classmethod
+    def _is_config(cls, type_like: Any) -> bool:
+        """ checks if type_like class is a subclass of Config"""
+        try:
+            if issubclass(type_like, Config):
+                return True
+        except TypeError:
+            pass
+        return False
+
+    @classmethod
+    def _is_optional_config(cls, type_like: Any) -> bool:
+        """ checks if type_like == Optional[subclass of Config] """
+        if not get_origin(type_like) == Union:
+            return False
+        args = [arg for arg in get_args(type_like) if arg is not type(None)]
+        return len(args) == 1 and cls._is_config(args[0])
+
+    @classmethod
+    def deserialize(cls, asdict: Dict[str, Any]):
+        kwargs = {}
+        for key, field_desc in cls.__dataclass_fields__.items():
+            non_null = asdict.get(key) is not None
+            # Optional[Config]
+            if cls._is_optional_config(field_desc.type):
+                if non_null:
+                    type_arg = [arg for arg in get_args(field_desc.type) if arg is not type(None)][0]
+                    kwargs[key] = type_arg.deserialize(asdict[key])
+                else:
+                    kwargs[key] = None
+            # TODO: add containers with Config
+            elif get_origin(field_desc.type) in [Union, List, Dict, Literal]:
+                kwargs[key] = asdict.get(key)
+            elif cls._is_config(field_desc.type):
+                if non_null:
+                    kwargs[key] = field_desc.type.deserialize(asdict[key])
+                else:
+                    kwargs[key] = field_desc.type.default  # type: ignore
+            else:
+                kwargs[key] = asdict.get(key)
+        return cls(**kwargs)
+
+
+@dataclass
+class TextTokenizationConfig(Config):
+    from_model: Optional[str] = "seamlessM4T_large"
+    """If set, using a tokenizer from the model cards."""
+
+    spm_path: Optional[str] = None
+    """Path to a custom spm model. Not used if `from_model` is set."""
+
+    langtoks: Optional[List[str]] = None
+    """List of language tokens that should be added. Not used if `from_model` is set."""
+
+
+@dataclass
+class UnitTokenizationConfig(Config):
+    from_model: Optional[str] = "seamlessM4T_large"
+    """If set, using tokenizer from a model card."""
+
+    num_units: Optional[int] = None
+    """Alternatively, build custom tokenizer, set number of units"""
+
+    langtoks: Optional[List[str]] = None
+    """List of language tokens that should be added. Not used if `from_model` is set."""
+
+
+@dataclass
+class AudioProcessingConfig(Config):
+    audio_root_dir: str = "/"
+    """The root directory of the zipped audio files."""
+
+    fbanks_standardize_audio: bool = True
+
+    fbanks_num_mel_bins: int = 80
+
+    fbanks_waveform_scale: int = 2**15
+
+
+@dataclass
+class DataLoadingConfig(Config):
+    manifest_list_path: Optional[str] = None
+    """Path to a file with the list of tsv manifests"""
+
+    manifest_list: Optional[str] = None
+    """Comma separated list of tsv manifests. Can be combined with `manifest_list_path`"""
+
+    manifest_path_prefix: Optional[str] = None
+    """Path prefix to manifest files (root directory)"""
+
+    audio: AudioProcessingConfig = AudioProcessingConfig()
+    """ Audio processing params """
+
+    text_tokenization: TextTokenizationConfig = TextTokenizationConfig()
+    """ Text tokenization params """
+
+    unit_tokenization: UnitTokenizationConfig = UnitTokenizationConfig()
+    """ Units tokenization params """
+
+    unit_tokenizer_name: Optional[str] = "seamlessM4T_large"
+
+    prepend_tgt_lang_tag: bool = True
+    """ Prepend output text sequence with target lang token"""
+
+    fbank_feats_pad_idx: int = 0
+    """The pad index to use in fbanks batching."""
+
+    max_tgt_text_tokens_per_batch: Optional[int] = 1000
+    """ Defines flexible batch construction """
+
+    fixed_batch_size: Optional[int] = None
+    """ If set, uses fixed batch size """
+
+    max_seconds_per_input_audio: int = 15
+    """Accept only samples with less than max_seconds_per_input_audio ( waveform.shape[0] * SR )"""
+
+    max_tgt_text_tokens_per_sample: int = 300
+    """Accept only samples with less than max_sequence_length units"""
+
+    max_units_per_sample: int = 1500
+    """Accept only samples with less than max_sequence_length units"""
+
+    num_threads: int = 5
+    """The number of parallel threads during data reading and processing."""
+
+    shuffle_window: Optional[int] = 1000
+    """The size of sliding shuffle window."""
+
+    prefech_batches: Optional[int] = 10
+    """How many batches to prefetch in the background."""
+
+
+@dataclass
+class CustomModelParams(Config):
+    model_embed_dim: int = 1024
+
+    w2v2_encoder_layers: int = 24
+
+    w2v2_encoder_layers_use_conformer: bool = True
+
+    w2v2_encoder_layers_layernorm_features: bool = False
+
+    w2v2_pos_encoder_type: Literal["conv", "relative", "rotary"] = "relative"
+
+    w2v2_pos_encoder_depth: int = 0
+
+    w2v2_pos_conv_kernel_size: int = 0
+
+    w2v2_num_pos_conv_groups: int = 0
+
+    nllb_encoder_layers: int = 24
+
+    nllb_decoder_layers: int = 24
+
+    t2u_encoder_layers: int = 6
+
+    t2u_decoder_layers: int = 6
+
+    nllb_vocabulary_size: int = 256102  # num_tokens + langs + spec symbols
+
+    unit_vocabulary_size: int = 10082
+
+
+@dataclass
+class ModelConfig(Config):
+    from_model: Optional[str] = None
+    """If set, initialize a model defined in model cards. Also loads model weights."""
+
+    from_model_config: Optional[str] = None
+    """If set, initialize a model defined in model cards. Doesn't load weights."""
+
+    custom_params: Optional[CustomModelParams] = None
+    """If set, intitalize a new model with custom parameters"""
+
+    pretrained_w2v2_path: Optional[str] = None
+    """If set, use pre-trained w2v block"""
+
+    pretrained_s2t_decoder_path: Optional[str] = None
+    """If set, use pre-trained s2t decoder (NLLB)"""
+
+    pretrained_t2u_path: Optional[str] = None
+    """If set, use pre-trained t2u weights"""
+
+
+@dataclass
+class TrainingParams(Config):
+    max_epochs: int = 100
+    """ Maximum number of trainign epochs"""
+
+    label_smoothing: float = 0.2
+    """ Label smoothing coefficient for nll_loss """
+
+    warmup_steps: int = 1000
+    """ Number of steps with linearly increasing LR"""
+
+    log_steps: int = 200
+    """ Log inner loss after each `log_steps` training steps"""
+
+    eval_steps: int = 1000
+    """ Get eval loss after each `eval_steps` training steps """
+
+    patience: int = 10
+    """ Terminate if eval loss did not improve
+    over the last `patience * eval_steps` training steps"""
+
+    learning_rate: float = 1e-4
+    """ Optimizer learining rate """
+
+    start_learning_rate: float = 1e-7
+    """ Start learining rate """
+
+    float_dtype: Literal["fp16", "bf16", "fp32"] = "bf16"
+    """ Dtype used for float numbers, defines training precision """
+
+
+@dataclass
+class WorkflowParams(Config):
+    training: TrainingParams
+
+    model: ModelConfig
+
+    train_data: DataLoadingConfig
+
+    eval_data: DataLoadingConfig
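
For reference, a minimal sketch of the Config round trip that the recipe YAML files below rely on (serialize a nested config to a plain dict, dump/load it as YAML, rebuild it with deserialize), assuming PyYAML and the m4t_scripts.train import path used by the other scripts:

    import yaml

    from m4t_scripts.train.configs import (
        CustomModelParams,
        DataLoadingConfig,
        ModelConfig,
        TrainingParams,
        WorkflowParams,
    )

    params = WorkflowParams(
        training=TrainingParams(max_epochs=10),
        model=ModelConfig(custom_params=CustomModelParams(w2v2_encoder_layers=6)),
        train_data=DataLoadingConfig(manifest_list="train_asr"),
        eval_data=DataLoadingConfig(manifest_list="dev_asr"),
    )

    as_dict = params.serialize()            # nested Config objects become plain dicts
    recipe_text = yaml.safe_dump(as_dict)   # same layout as the recipes/*.yaml files below

    restored = WorkflowParams.deserialize(yaml.safe_load(recipe_text))
    assert restored.model.custom_params is not None
    assert restored.model.custom_params.w2v2_encoder_layers == 6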

+ 520 - 0
scripts/m4t/train/dataloader.py

@@ -0,0 +1,520 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import logging
+import os
+from typing import Any, Dict, Iterator, List, NamedTuple, Optional, Tuple, Union
+import ctypes
+
+import torch
+from m4t_scripts.train.configs import AudioProcessingConfig, DataLoadingConfig
+from torch import Tensor
+
+from fairseq2.data import (
+    CollateOptionsOverride,
+    Collater,
+    DataPipeline,
+    DataPipelineBuilder,
+    FileMapper,
+)
+from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
+from fairseq2.data.text import SentencePieceEncoder, StrSplitter, read_text
+from fairseq2.models.nllb.tokenizer import NllbTokenizer
+from seamless_communication.models.tokenizer import SPMTokenizer
+from seamless_communication.models.unity import (
+    UnitTokenizer,
+    load_unity_text_tokenizer,
+    load_unity_unit_tokenizer,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class SeqsBatch(NamedTuple):
+    src_tokens: Optional[Tensor]
+    src_lengths: Optional[Tensor]
+    target_tokens: Tensor
+    prev_output_tokens: Tensor
+    target_lengths: Tensor
+    prefix_tokens: Optional[Tensor]
+
+
+class MultimodalSeqsBatch(NamedTuple):
+    speech_to_text: SeqsBatch
+    text_to_units: SeqsBatch
+
+
+class UnityDataLoader:
+    CPU_DEVICE = torch.device("cpu")
+    MANIFEST_EXT = ".tsv"
+    MANIFEST_COLUMN_SEP = "\t"
+    AUDIO_COLUMN_NAME = "audio"
+    TARGET_TEXT_COLUMN = "raw_tgt_text"
+    TARGET_UNITS_COLUMN = "tgt_text"
+    TARGET_LANG_COLUMN = "tgt_lang"
+    ROOT_COLUMN = "_"
+    BATCH_WIDTH_STEP = 8
+
+    def __init__(
+        self,
+        config: DataLoadingConfig,
+        rank: int = 0,
+        world_size: int = 1,
+        target_device: torch.device = CPU_DEVICE,
+        float_dtype: torch.dtype = torch.float16,  # training/inference precision
+    ):
+        self.config = config
+        self.rank = rank
+        self.world_size = world_size
+        self.target_device = target_device
+        self.float_dtype = float_dtype
+        self._set_mkl_num_threads()
+        self.manifest_paths = list(self._iterate_manifest_paths())
+        self.text_tokenizer = self._init_text_tokenizer()
+        self.unit_tokenizer = self._init_unit_tokenizer()
+        self.spm_encoder = SentencePieceEncoder(model=self.text_tokenizer.model, suffix_tokens=["</s>"])
+        self.text_prefix_tokens = self._build_text_tgt_prefixes()
+        self.unit_prefix_tokens = self._build_unit_tgt_prefixes()
+        if self.config.fixed_batch_size is None:
+            self.tgt_text_batch_shapes = self._calculate_tgt_text_batch_shapes()
+        else:
+            self.tgt_text_batch_shapes = []
+
+        self.pipeline = self._build_pipeline()
+
+    @classmethod
+    def _set_mkl_num_threads(cls):
+        """ Setting mkl num threads to 1, so that we don't get thread explosion."""
+        mkl_rt = ctypes.CDLL('libmkl_rt.so')
+        mkl_rt.mkl_set_num_threads(ctypes.byref(ctypes.c_int(1)))
+
+    def _calculate_tgt_text_batch_shapes(self) -> List[Tuple[int, int]]:
+        max_seq_len = self.config.max_tgt_text_tokens_per_sample
+        max_tokens_per_batch = self.config.max_tgt_text_tokens_per_batch
+        assert max_tokens_per_batch is not None, "max_tokens_per_batch is not set"
+        step = self.BATCH_WIDTH_STEP
+        bucket_sizes = []
+        for seq_len in range(step, max(step, max_seq_len) + 1, step):
+            bsz = max(1, max_tokens_per_batch // seq_len)
+            bucket_sizes.append((bsz, seq_len))
+        return bucket_sizes
+
+    def _build_text_tgt_prefixes(self) -> Dict[str, List[int]]:
+        return {
+            lang_tok: self.text_tokenizer.create_encoder(
+                lang=lang_tok, mode="target"
+            ).prefix_indices.tolist()  # type:ignore
+            for lang_tok in self.text_tokenizer.langs
+        }
+
+    def _build_unit_tgt_prefixes(self) -> Dict[str, List[int]]:
+        assert self.unit_tokenizer.vocab_info.eos_idx is not None
+        return {
+            lang_tok: [
+                self.unit_tokenizer.vocab_info.eos_idx,
+                self.unit_tokenizer.lang_to_index(lang_tok),
+            ]
+            for lang_tok in self.unit_tokenizer.langs
+        }  # type: ignore
+
+    def _init_text_tokenizer(self) -> Union[NllbTokenizer, SPMTokenizer]:
+        if self.config.text_tokenization.from_model is not None:
+            return load_unity_text_tokenizer(self.config.text_tokenization.from_model)
+        else:
+            assert self.config.text_tokenization.langtoks is not None
+            assert self.config.text_tokenization.spm_path is not None
+            return SPMTokenizer(
+                pathname=self.config.text_tokenization.spm_path, langs=self.config.text_tokenization.langtoks
+            )
+
+    def _init_unit_tokenizer(self) -> UnitTokenizer:
+        if self.config.unit_tokenization.from_model is not None:
+            return load_unity_unit_tokenizer(self.config.unit_tokenization.from_model)
+        else:
+            raise NotImplementedError("TBD")
+
+    def _load_manifest_list_from_file(self) -> Iterator[str]:
+        if self.config.manifest_list_path is not None:
+            for line in open(self.config.manifest_list_path).readlines():
+                line = line.split("#")[0].strip()  # allow comments
+                if line:
+                    yield line
+
+    def _load_raw_manifest_list(self) -> List[str]:
+        raw_list = []
+        if self.config.manifest_list is not None:
+            raw_list += self.config.manifest_list.strip().split(",")
+        raw_list += list(self._load_manifest_list_from_file())
+        return raw_list
+
+    def _infer_manifest_full_path(self, manifest_name: str) -> str:
+        full_path = manifest_name.strip()
+        if self.config.manifest_path_prefix is not None:
+            full_path = os.path.join(self.config.manifest_path_prefix.strip(), full_path)
+        if not full_path.endswith(self.MANIFEST_EXT) and not os.path.exists(full_path):
+            full_path += self.MANIFEST_EXT
+        if not os.path.exists(full_path):
+            raise FileNotFoundError(f"File not found {full_path}")
+        return full_path
+
+    def _iterate_manifest_paths(self, skip_missing_files: bool = True) -> Iterator[str]:
+        """Yields full paths to manifests described in the data config.
+        Check that each file exist.
+        Expects *.tsv files"""
+        raw_list = self._load_raw_manifest_list()
+        for manifest_name in raw_list:
+            try:
+                full_path = self._infer_manifest_full_path(manifest_name=manifest_name)
+            except FileNotFoundError:
+                if skip_missing_files:
+                    logger.warning(f"Skipping manifest {manifest_name}, file not found")
+                    continue
+                raise
+            yield full_path
+
+    def _read_column_names(self, manifest_path: str) -> List[str]:
+        """Gets the order of columns in the manifest file.
+        Also checks that expected columns are present."""
+        with open(manifest_path, "r") as in_fp:
+            column_names = in_fp.readline().strip().split("\t")
+        for column in [
+            self.AUDIO_COLUMN_NAME,
+            self.TARGET_TEXT_COLUMN,
+            self.TARGET_UNITS_COLUMN,
+            self.TARGET_LANG_COLUMN,
+        ]:
+            if column not in column_names:
+                raise ValueError(f"Column `{column}` is not present in `{manifest_path}` ")
+        return column_names
+
+    def _builder_from_manifest(self, manifest_path: str) -> DataPipelineBuilder:
+        """Creates a data pipeline builder for the specified manifest_path file."""
+        logger.debug(f"Initialiazing samples loader from {manifest_path}")
+
+        # Memory map file and read it in text mode (skip empty lines if any).
+        # Skip header.
+        tsv_lines = (
+            read_text(
+                pathname=manifest_path,
+                encoding="UTF-8",
+                rtrim=True,
+                skip_empty=True,
+                memory_map=True,
+            )
+            .skip(1)
+            .and_return()
+        )
+
+        # Assign column names:
+        # line content: `_`
+        # source manifest path: `manifest_path`
+        # line number: `lineno`
+        line_numbers = DataPipeline.count().and_return()
+        filename_const = DataPipeline.constant(manifest_path).and_return()
+        pipeline = DataPipeline.zip(
+            [tsv_lines, filename_const, line_numbers],
+            names=[self.ROOT_COLUMN, "manifest_path", "lineno"],
+            zip_to_shortest=True,
+        )
+
+        # Read every `world_size`th line starting from `rank`th item in the file.
+        pipeline.shard(self.rank, self.world_size)
+
+        if self.config.shuffle_window is not None:
+            pipeline.shuffle(self.config.shuffle_window)
+
+        # Split each text line into its fields.
+        fields = self._read_column_names(manifest_path)
+        logger.debug(f"Column names: {fields}")
+        txt_splitter = StrSplitter(sep=self.MANIFEST_COLUMN_SEP, names=fields, indices=[], exclude=True)
+        pipeline.map(
+            txt_splitter,
+            selector=self.ROOT_COLUMN,
+            num_parallel_calls=self.config.num_threads,
+        )
+        # And, create the pipeline for the TSV file.
+        return pipeline
+
+    def _get_manifest_funnel(self) -> DataPipelineBuilder:
+        """Creates a joined pipeline from all manifests.
+        Picks samples from per-manifest pipelines in a round-robin order"""
+        # TODO: add the ability to upsample/downsample manifests
+        logger.info(f"Aggregating data from {len(self.manifest_paths)} manifests")
+        builders = [self._builder_from_manifest(manifest_path=path) for path in self.manifest_paths]
+        pipelines = [builder.and_return() for builder in builders]
+        return DataPipeline.round_robin(pipelines=pipelines)
+
+    def _attach_audio(self, builder: DataPipelineBuilder) -> DataPipelineBuilder:
+        """Attaches audio waveforms and fbanks from linked autio files"""
+        audio_selector = f"{self.ROOT_COLUMN}.{self.AUDIO_COLUMN_NAME}"
+        audio_data_selector = f"{audio_selector}.data"
+
+        # Memory map each `audio_file`
+        map_file = FileMapper(self.config.audio.audio_root_dir, cached_fd_count=100)
+        builder.map(
+            map_file,
+            selector=audio_selector,
+            num_parallel_calls=self.config.num_threads,
+        )
+
+        # Decode each mmap'ed audio file using libsndfile.
+        decode_audio = AudioDecoder(dtype=torch.float32)
+        builder.map(
+            decode_audio,
+            selector=audio_data_selector,
+            num_parallel_calls=self.config.num_threads,
+        )
+
+        # And, convert from waveform to log-mel filterbank
+        convert_to_fbank = WaveformToFbankConverter(
+            num_mel_bins=self.config.audio.fbanks_num_mel_bins,
+            waveform_scale=self.config.audio.fbanks_waveform_scale,
+            channel_last=True,  # audio channel is the last dimension in the waveform
+            standardize=self.config.audio.fbanks_standardize_audio,
+            keep_waveform=False,
+            device=self.target_device,
+            dtype=self.float_dtype,
+        )
+        builder.map(
+            convert_to_fbank,
+            selector=audio_data_selector,
+            num_parallel_calls=self.config.num_threads,
+        )
+        return builder
+
+    def _attach_target_tokens(self, builder: DataPipelineBuilder) -> DataPipelineBuilder:
+        # Convert `raw_tgt_text` to (full) target tokenized sequences:
+        #                   <eos> <lang_tok> <tokens .. > <eos>
+        # Lang tokens change between rows, so can't use static encoder
+        builder.map(
+            [self.spm_encoder],
+            selector=f"{self.ROOT_COLUMN}.{self.TARGET_TEXT_COLUMN}",
+            num_parallel_calls=self.config.num_threads,
+        )
+
+        # Convert the `tgt_text` field into a unit tensor + EOS
+        # TODO: We should use unit tokenizer.
+        # Motivation for the current implementation:
+        # 1) lang_tok can change between rows.
+        #       If we want to attach lang_token_id here, we need a way to join values from two columns
+        # 2) StrToTensorConverter doesn't allow suffix tokens. Adding it later is less convenient.
+        # 3) Not a computational blocker
+        convert_to_units = lambda units_str: (  # noqa: E731
+            torch.LongTensor(
+                [int(unit_id) + 4 for unit_id in units_str.rstrip().bytes().decode("utf-8").split()]
+                + [self.unit_tokenizer.vocab_info.eos_idx]
+            )
+        )
+        builder.map(
+            [convert_to_units],
+            selector=f"{self.ROOT_COLUMN}.{self.TARGET_UNITS_COLUMN}",
+            num_parallel_calls=self.config.num_threads,
+        )
+
+        # prefixes for tokenized texts and speech units (<eos> <lang_tok>)
+        prefix_builder = lambda lang_tok: torch.LongTensor(  # noqa: E731
+            [
+                self.text_prefix_tokens[lang_tok.bytes().decode("utf8")],
+                self.unit_prefix_tokens[lang_tok.bytes().decode("utf8")],
+            ]
+        )
+        builder.map(
+            [prefix_builder],
+            selector=f"{self.ROOT_COLUMN}.{self.TARGET_LANG_COLUMN}",
+            num_parallel_calls=self.config.num_threads,
+        )
+        return builder
+
+    def _get_input_audio_seconds(self, sample: Any) -> float:
+        audio_data = sample[self.ROOT_COLUMN][self.AUDIO_COLUMN_NAME]["data"]
+        input_audio_sample_rate = audio_data["sample_rate"]
+        num_fbanks = max(audio_data["fbank"].shape)  # not guessing the dim order
+        # TODO: clarify where '* 2' comes from
+        waveform_length = num_fbanks * self.config.audio.fbanks_num_mel_bins * 2
+        input_audio_seconds = waveform_length / input_audio_sample_rate
+        return input_audio_seconds
+
+    def _is_long_sample(self, sample: Any) -> bool:
+        # input audio length
+        if self._get_input_audio_seconds(sample) > self.config.max_seconds_per_input_audio:
+            return True
+
+        # target text tokens
+        num_tgt_text_tokens = sample[self.ROOT_COLUMN][self.TARGET_TEXT_COLUMN].shape[-1]
+        if num_tgt_text_tokens > self.config.max_tgt_text_tokens_per_sample:
+            return True
+
+        # target units
+        num_tgt_units = sample[self.ROOT_COLUMN][self.TARGET_UNITS_COLUMN].shape[-1]  # target units
+        if num_tgt_units > self.config.max_units_per_sample:
+            return True
+        return False
+
+    def _filter_samples(self, builder: DataPipelineBuilder) -> DataPipelineBuilder:
+        # Drop long samples
+        builder.filter(lambda sample: not self._is_long_sample(sample))
+        return builder
+
+    def _batch_samples(self, builder: DataPipelineBuilder) -> DataPipelineBuilder:
+        if self.config.fixed_batch_size is not None:
+            builder.bucket(bucket_size=self.config.fixed_batch_size)
+        elif self.tgt_text_batch_shapes is not None:
+            builder.bucket_by_length(
+                self.tgt_text_batch_shapes,
+                selector=f"{self.ROOT_COLUMN}.{self.TARGET_TEXT_COLUMN}",
+            )
+        else:
+            raise ValueError("Unclear batching strategy")
+        # Collate bucketed elements into a batch.
+        collater = Collater(
+            pad_to_multiple=1,
+            overrides=[
+                CollateOptionsOverride(
+                    selector=f"{self.ROOT_COLUMN}.{self.AUDIO_COLUMN_NAME}.data.fbank",
+                    pad_idx=self.config.fbank_feats_pad_idx,
+                ),
+                CollateOptionsOverride(
+                    selector=f"{self.ROOT_COLUMN}.{self.TARGET_TEXT_COLUMN}",
+                    pad_idx=self.text_tokenizer.vocab_info.pad_idx,
+                ),
+                CollateOptionsOverride(
+                    selector=f"{self.ROOT_COLUMN}.{self.TARGET_UNITS_COLUMN}",
+                    pad_idx=self.unit_tokenizer.vocab_info.pad_idx,
+                ),
+            ],
+        )
+        builder.map(collater, num_parallel_calls=self.config.num_threads)
+        if self.config.prefech_batches is not None:
+            builder.prefetch(self.config.prefech_batches)
+        return builder
+
+    def _build_pipeline(self) -> DataPipeline:
+        data = self._get_manifest_funnel()
+        data = self._attach_audio(data)
+        data = self._attach_target_tokens(data)
+        data = self._filter_samples(data)
+        batches = self._batch_samples(data)
+        return batches.and_return()
+
+    def _gen_prev_toks_target_toks_target_lens(
+        self, seqs: Any, prefix_tokens: torch.Tensor, pad_idx: int, eos_idx: int
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        # <eos> <lang_tok> ... <eos> <pad>*
+        tokens = torch.cat((prefix_tokens, seqs["seqs"]), 1)
+        target_lengths = seqs["seq_lens"] + 1  # + <leng_tok>
+
+        prev_output_tokens = torch.clone(tokens)
+        # replace last <eos> with <pad> and remove last column
+        mask = prev_output_tokens == eos_idx
+        mask[:, 0] = 0
+        prev_output_tokens[mask] = pad_idx
+        prev_output_tokens = prev_output_tokens[:, :-1]
+
+        target_tokens = tokens[:, 1:]
+        assert torch.equal(torch.count_nonzero(prev_output_tokens != pad_idx, dim=1), target_lengths)
+        assert torch.equal(torch.count_nonzero(target_tokens != pad_idx, dim=1), target_lengths)
+        return prev_output_tokens, target_tokens, target_lengths
+
+    def _get_text_to_units_batch(self, raw_batch: Any) -> SeqsBatch:
+        root = raw_batch[self.ROOT_COLUMN]
+        seqs = root[self.TARGET_UNITS_COLUMN]
+        prefix_tokens = root[self.TARGET_LANG_COLUMN][:, 1, :]
+        pad_idx = self.unit_tokenizer.vocab_info.pad_idx
+        eos_idx = self.unit_tokenizer.vocab_info.eos_idx
+        assert pad_idx is not None
+        assert eos_idx is not None
+
+        (
+            prev_output_tokens,
+            target_tokens,
+            target_lengths,
+        ) = self._gen_prev_toks_target_toks_target_lens(
+            seqs=seqs,
+            prefix_tokens=prefix_tokens,
+            pad_idx=pad_idx,
+            eos_idx=eos_idx,
+        )
+
+        return SeqsBatch(
+            src_tokens=None,
+            src_lengths=None,
+            target_tokens=target_tokens.to(self.target_device),
+            prev_output_tokens=prev_output_tokens.to(self.target_device),
+            target_lengths=target_lengths.to(self.target_device),
+            prefix_tokens=prefix_tokens.to(self.target_device),
+        )
+
+    def _get_speech_src_tokens_and_lengths(self, raw_batch: Any) -> Tuple[torch.Tensor, torch.Tensor]:
+        fbanks = raw_batch[self.ROOT_COLUMN][self.AUDIO_COLUMN_NAME]["data"]["fbank"]
+        return fbanks["seqs"].to(self.float_dtype), fbanks["seq_lens"]
+
+    def _get_speech_to_text_batch(self, raw_batch: Any) -> SeqsBatch:
+        root = raw_batch[self.ROOT_COLUMN]
+        seqs = root[self.TARGET_TEXT_COLUMN]
+        prefix_tokens = root[self.TARGET_LANG_COLUMN][:, 0, :]
+        pad_idx = self.text_tokenizer.vocab_info.pad_idx
+        assert pad_idx is not None
+        eos_idx = self.text_tokenizer.vocab_info.eos_idx
+        assert eos_idx is not None
+
+        (
+            prev_output_tokens,
+            target_tokens,
+            target_lengths,
+        ) = self._gen_prev_toks_target_toks_target_lens(
+            seqs=seqs,
+            prefix_tokens=prefix_tokens,
+            pad_idx=pad_idx,
+            eos_idx=eos_idx,
+        )
+        src_tokens, src_lengths = self._get_speech_src_tokens_and_lengths(raw_batch=raw_batch)
+
+        return SeqsBatch(
+            src_tokens=src_tokens.to(self.target_device),
+            src_lengths=src_lengths.to(self.target_device),
+            target_tokens=target_tokens.to(self.target_device),
+            prev_output_tokens=prev_output_tokens.to(self.target_device),
+            target_lengths=target_lengths.to(self.target_device),
+            prefix_tokens=prefix_tokens.to(self.target_device),
+        )
+
+    def _convert_to_multimodal_seqs_batch(self, raw_batch: Any) -> MultimodalSeqsBatch:
+        return MultimodalSeqsBatch(
+            speech_to_text=self._get_speech_to_text_batch(raw_batch=raw_batch),
+            text_to_units=self._get_text_to_units_batch(raw_batch=raw_batch),
+        )
+
+    def iterate_batches(self) -> Iterator[MultimodalSeqsBatch]:
+        for raw_batch in self.pipeline:
+            yield self._convert_to_multimodal_seqs_batch(raw_batch)
+
+    def reset(self) -> None:
+        self.pipeline.reset()
+
+
+if __name__ == "__main__":
+    logging.basicConfig(
+        level=logging.INFO,
+        format=f"%(asctime)s %(levelname)s -- %(name)s.{os.getpid()}: %(message)s",
+    )
+    config = DataLoadingConfig(
+        audio=AudioProcessingConfig(
+            audio_root_dir="/fsx-ust/data/audio_zips/",
+        ),
+        manifest_path_prefix="/fsx-ust/spopuri/datasets/S2ST/V1/M4T_V1_phase2/primary",
+        manifest_list_path="/data/home/mavlyutov/train_manifests.txt",
+        shuffle_window=1000,
+        num_threads=5,
+    )
+    loader = UnityDataLoader(config=config, target_device=torch.device("cpu"))
+    for idx, batch in enumerate(loader.iterate_batches()):
+        if idx % 10 == 0:
+            assert batch.speech_to_text.src_tokens is not None
+            print(batch.speech_to_text.src_tokens.shape)
+            logger.info(f".. pulled {idx} batches")
+            if idx > 1000:
+                break
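
For reference, a standalone sketch of the prev/target construction done in _gen_prev_toks_target_toks_target_lens above, using made-up token indices (eos=3, pad=0 and a hypothetical lang token id); the spm-encoded target already ends with <eos>, and the prefix contributes <eos> <lang_tok>:

    import torch

    eos, pad, lang = 3, 0, 256001
    prefix_tokens = torch.LongTensor([[eos, lang]])      # <eos> <lang_tok>
    seqs = torch.LongTensor([[11, 12, 13, eos]])         # tokens ... <eos>
    seq_lens = torch.LongTensor([4])

    tokens = torch.cat((prefix_tokens, seqs), dim=1)     # <eos> <lang> 11 12 13 <eos>
    target_lengths = seq_lens + 1                        # +1 for <lang_tok>

    prev_output_tokens = tokens.clone()
    mask = prev_output_tokens == eos
    mask[:, 0] = False                                   # keep the leading <eos>
    prev_output_tokens[mask] = pad                       # trailing <eos> -> <pad>
    prev_output_tokens = prev_output_tokens[:, :-1]      # drop the last column

    target_tokens = tokens[:, 1:]                        # shifted left by one

    print(prev_output_tokens.tolist())  # [[3, 256001, 11, 12, 13]]
    print(target_tokens.tolist())       # [[256001, 11, 12, 13, 3]]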

+ 76 - 0
scripts/m4t/train/dist_utils.py

@@ -0,0 +1,76 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import logging
+import os
+from datetime import timedelta
+from typing import List
+
+import torch
+import torch.distributed as dist
+import torch.multiprocessing
+
+logger = logging.getLogger(__name__)
+
+
+def is_dist_initialized() -> bool:
+    if not dist.is_available():
+        return False
+    if not dist.is_initialized():
+        return False
+    return True
+
+
+def get_rank() -> int:
+    if not is_dist_initialized():
+        return 0
+    return dist.get_rank()
+
+
+def get_local_rank() -> int:
+    if not is_dist_initialized():
+        return 0
+    return int(os.environ["LOCAL_RANK"])
+
+
+def get_world_size() -> int:
+    if not is_dist_initialized():
+        return 1
+    return dist.get_world_size()
+
+
+def is_main_process() -> bool:
+    return get_rank() == 0
+
+
+def init_distributed(loggers: List[logging.Logger]) -> None:
+    """Initializes the distributed backend"""
+    torch.multiprocessing.set_start_method("spawn")
+    if "RANK" not in os.environ:
+        logger.error(
+            "Cannot init disributed context, as environment varaibles are not set."
+        )
+        return
+    rank = int(os.environ["RANK"])
+    world_size = int(os.environ["WORLD_SIZE"])
+    local_rank = int(os.environ["LOCAL_RANK"])
+    logger.info(
+        f"Rank={rank} local rank={local_rank}, world_size={world_size}, is_master={rank == 0}"
+    )
+    dist.init_process_group(
+        backend="nccl",
+        init_method="env://",
+        world_size=world_size,
+        rank=rank,
+        timeout=timedelta(seconds=180),
+    )
+    logger.info(f"Setting cuda:{local_rank} as main device")
+    if not is_main_process():
+        for to_mute in loggers:
+            to_mute.setLevel(logging.ERROR)
+    torch.cuda.set_device(local_rank)
+    dist.barrier()
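
For reference, a minimal sketch of how dist_utils is meant to be driven from a training entry point, assuming the process is launched with torchrun (which exports the RANK, WORLD_SIZE and LOCAL_RANK variables that init_distributed reads) and that the package is importable as m4t_scripts.train like the other scripts:

    # launched e.g. with: torchrun --nproc_per_node=8 train.py
    import logging

    from m4t_scripts.train import dist_utils

    logger = logging.getLogger("train")


    def main() -> None:
        # sets up the NCCL process group and demotes logging on non-main ranks
        dist_utils.init_distributed([logger])
        logger.info(
            f"rank={dist_utils.get_rank()} / world_size={dist_utils.get_world_size()}, "
            f"main={dist_utils.is_main_process()}"
        )


    if __name__ == "__main__":
        logging.basicConfig(level=logging.INFO)
        main()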

+ 79 - 0
scripts/m4t/train/install_devfair.sh

@@ -0,0 +1,79 @@
+
+#  This script installs seamless_communication (internal) + fairseq2 on an AWS cluster.
+
+set -e
+set -x
+
+echo "Installing Conda"
+export TGT=`echo ~/seacom`
+rm -rf $TGT
+wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -qO /tmp/conda.sh
+bash /tmp/conda.sh -bp $TGT
+export CONDA=$TGT/bin/conda
+export CONDA_ACTIVATE=$TGT/bin/activate
+export ENV_N=sc_fr2
+echo "Next step will take ~15 minutes. Get some coffee" 
+module add cuda/11.8
+$CONDA create -y -n ${ENV_N} python=3.10 pytorch=2.0.1 pytorch-cuda=11.8 torchvision torchaudio \
+             compilers libsndfile==1.0.31 gcc==11.4.0 \
+    --strict-channel-priority --override-channels \
+    -c pytorch \
+    -c nvidia \
+    -c conda-forge
+
+echo "Setting LD_LIBRARY_PATH"
+. $CONDA_ACTIVATE activate ${ENV_N}
+if [ -z "$CONDA_PREFIX" ]; then 
+  echo "CONDA_PREFIX env var is not set!" 
+  exit 1
+else 
+   path=$CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
+   echo "export LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH" >> ${path}
+fi
+. $CONDA_ACTIVATE activate ${ENV_N}  # update env vars
+
+#  Installing fairseq2.
+echo "Installing fairseq2"
+if [[ "${I_DONT_PLAN_TO_HACK_FAIRSEQ2:-No}" == "Yes" ]] ; then
+pip install fairseq2 \
+  --pre --extra-index-url https://fair.pkg.atmeta.com/fairseq2/whl/nightly/pt2.0.1/cu118
+else
+#  NOTICE: to compile CUDA kernels, you need NVCC. On AWS cluster an easy way would be to get a GPU container:
+#  srun -N 1 --gres=gpu:1 --cpus-per-task=20 --partition seamless --time 2400 --pty /bin/bash -l
+cd $TGT
+git clone --recurse-submodules  git@github.com:facebookresearch/fairseq2.git
+pip install -r fairseq2/fairseq2n/python/requirements-build.txt
+cd fairseq2
+pip install -e .  # it will install public fairseq2n, we rewrite it below
+cd fairseq2n
+args="-GNinja\
+  -DCMAKE_BUILD_TYPE=Release \
+  -DCMAKE_CUDA_ARCHITECTURES=80-real;80-virtual\
+  -DFAIRSEQ2N_INSTALL_STANDALONE=ON\
+  -DFAIRSEQ2N_PERFORM_LTO=ON\
+  -DFAIRSEQ2N_TREAT_WARNINGS_AS_ERRORS=OFF\
+  -DFAIRSEQ2N_USE_CUDA=ON\
+  -DFAIRSEQ2N_BUILD_PYTHON_BINDINGS=ON\
+  -DFAIRSEQ2N_PYTHON_DEVEL=OFF"
+cmake ${args} -B build
+cmake --build build
+cd python && pip install .
+fi
+# Quick test
+python -c "from fairseq2n.bindings.data.string import CString as CString"
+
+# Has to go before fairseq2 to make sure that it will not reinstall fairseq2n
+echo "Installing seamless_communication"
+cd $TGT
+git clone git@github.com:fairinternal/seamless_communication.git
+cd seamless_communication
+pip install -e .   # editable mode for hacking
+
+echo "One more time re-install fairseq2n (most propably overriden by seamless_communication)"
+cd $TGT/fairseq2/fairseq2n/python
+pip install .
+
+
+echo "Finished."
+echo "To activate the environment run: . $CONDA_ACTIVATE activate ${ENV_N}"
+echo "Location of seamless_communication checkout: $TGT/seamless_communication"

+ 90 - 0
scripts/m4t/train/install_fairaws.sh

@@ -0,0 +1,90 @@
+
+#  This script installs seamless_communication (internal) + fairseq2 on an AWS cluster.
+
+set -e
+set -x
+
+echo "Installing Conda"
+export TGT=`echo ~/seacom_aws_dev`
+rm -rf $TGT
+wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -qO /tmp/conda.sh
+bash /tmp/conda.sh -bp $TGT
+export CONDA=$TGT/bin/conda
+export CONDA_ACTIVATE=$TGT/bin/activate
+export ENV_N=sc_fr2_dev
+echo "Next step will take ~15 minutes. Get some coffee" 
+$CONDA create -y -n ${ENV_N} python=3.10 pytorch=2.0.1 pytorch-cuda=11.8 torchvision torchaudio \
+             compilers libsndfile==1.0.31 gcc==11.4.0 \
+    --strict-channel-priority --override-channels \
+    -c https://aws-ml-conda.s3.us-west-2.amazonaws.com \
+    -c pytorch \
+    -c nvidia \
+    -c conda-forge
+
+echo "Setting LD_LIBRARY_PATH"
+. $CONDA_ACTIVATE activate ${ENV_N}
+if [ -z "$CONDA_PREFIX" ]; then 
+  echo "CONDA_PREFIX env var is not set!" 
+  exit 1
+else 
+   path=$CONDA_PREFIX/etc/conda/activate.d/env_vars.sh
+   echo "export LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH" >> ${path}
+fi
+. $CONDA_ACTIVATE activate ${ENV_N}  # update env vars
+
+
+#  NOTICE: to compile CUDA kernels, you need NVCC. On AWS cluster an easy way would be to get a GPU container:
+#  srun -N 1 --gres=gpu:1 --cpus-per-task=20 --partition seamless --time 2400 --pty /bin/bash -l
+
+#  Installing fairseq2.
+echo "Installing fairseq2"
+set -e
+rm -rf fairseq2  # wipe existing clones
+if [[ "${I_DONT_PLAN_TO_HACK_FAIRSEQ2:-No}" == "Yes" ]] ; then
+pip install fairseq2 \
+  --pre --extra-index-url https://fair.pkg.atmeta.com/fairseq2/whl/nightly/pt2.0.1/cu118
+else
+nvidia-smi || echo "to compile CUDA kernels, you need NVCC.\n \
+   On AWS cluster an easy way would be to get a GPU container.\n \
+   Run smth like 'srun -N 1 --gres=gpu:1 --cpus-per-task=20 --partition seamless --time 2400 --pty /bin/bash -l' \n \
+   and continue from "Installing fairseq2" line. \
+   Terminating for now."
+nvidia-smi || exit 1
+cd $TGT
+. $CONDA_ACTIVATE activate ${ENV_N}
+git clone --recurse-submodules  git@github.com:facebookresearch/fairseq2.git
+pip install -r fairseq2/fairseq2n/python/requirements-build.txt
+cd fairseq2
+pip install -e .  # it will install public fairseq2n, we rewrite it below
+cd fairseq2n
+args="-GNinja\
+  -DCMAKE_BUILD_TYPE=Release \
+  -DCMAKE_CUDA_ARCHITECTURES=80-real;80-virtual\
+  -DFAIRSEQ2N_INSTALL_STANDALONE=ON\
+  -DFAIRSEQ2N_PERFORM_LTO=ON\
+  -DFAIRSEQ2N_TREAT_WARNINGS_AS_ERRORS=OFF\
+  -DFAIRSEQ2N_USE_CUDA=ON\
+  -DFAIRSEQ2N_BUILD_PYTHON_BINDINGS=ON\
+  -DFAIRSEQ2N_PYTHON_DEVEL=OFF"
+cmake ${args} -B build
+cmake --build build
+cd python && pip install .
+fi
+# Quick test
+python -c "from fairseq2n.bindings.data.string import CString as CString"
+
+echo "Installing seamless_communication"
+cd $TGT
+git clone git@github.com:fairinternal/seamless_communication.git
+cd seamless_communication
+pip install -e .   # editable mode for hacking
+
+
+echo "One more time re-install fairseq2n (most propably overriden by seamless_communication)"
+cd $TGT/fairseq2/fairseq2n/python
+pip install .
+
+
+echo "Finished."
+echo "To activate the environment run: . $CONDA_ACTIVATE activate ${ENV_N}"
+echo "Location of seamless_communication checkout: $TGT/seamless_communication"

+ 258 - 0
scripts/m4t/train/model.py

@@ -0,0 +1,258 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import logging
+import os
+from typing import Dict, Any
+
+import torch
+from m4t_scripts.train.configs import CustomModelParams, ModelConfig
+
+from seamless_communication.models.unity import (
+    UnitYConfig,
+    UnitYModel,
+    load_unity_model,
+    create_unity_model,
+)
+from seamless_communication.models.unity.loader import load_unity_config
+from seamless_communication.models.unity import UnitYT2UConfig
+from fairseq2.nn.transformer import TransformerNormOrder
+from fairseq2.models.wav2vec2 import Wav2Vec2EncoderConfig
+from fairseq2.models.nllb.builder import NllbConfig
+from fairseq2.models.utils.checkpoint_loader import convert_model_state_dict
+from fairseq2.models.wav2vec2.loader import Wav2Vec2Loader
+from seamless_communication.models.unity.loader import UnitYLoader
+
+from fairseq2.models.nllb.loader import NllbLoader
+
+logger = logging.getLogger(__name__)
+
+
+CPU_DEVICE = torch.device("cpu")
+
+
+class ModelBuilder:
+    def __init__(
+        self,
+        config: ModelConfig,
+        dtype: torch.dtype = torch.float16,
+        device: torch.device = CPU_DEVICE,
+    ):
+        self.config = config
+        self.dtype = dtype
+        self.device = device
+
+    @classmethod
+    def _sel_and_upd_prefix(cls, kv: Dict[str, Any], prefix: str, new_prefix: str = "") -> Dict[str, Any]:
+        # fmt: off
+        return {new_prefix + k[len(prefix):]: v for k, v in kv.items() if k.startswith(prefix)}
+        # fmt: on
+
+    @classmethod
+    def _load_pretrained_w2v2_encoder(cls, model: UnitYModel, checkpoint_path: str) -> None:
+        """Load w2v2 encoder model trained in fairseq1"""
+        logger.info(f"Loading w2v2 weights from {checkpoint_path}")
+        state_dict = torch.load(checkpoint_path)["model"]
+        key_map = Wav2Vec2Loader._fairseq_key_map()
+        key_map.update(
+            {
+                r"^encoder.layers\.([0-9]+)\.conv_module.batch_norm.": r"encoder.layers.\1.conv.batch_norm.",
+                r"^encoder.layers\.([0-9]+)\.conv_module.depthwise_conv.": r"encoder.layers.\1.conv.depthwise_conv.",
+                r"^encoder.layers\.([0-9]+)\.conv_module.pointwise_conv([0-9]+)\.": (
+                    r"encoder.layers.\1.conv.pointwise_conv\2."
+                ),
+                r"^encoder.layers\.([0-9]+)\.conv_module.layer_norm.": r"encoder.layers.\1.conv_layer_norm.",
+                r"^encoder.layers\.([0-9]+)\.ffn([0-9]+)\.layer_norm.": r"encoder.layers.\1.ffn\2_layer_norm.",
+                r"^encoder.layers\.([0-9]+)\.ffn([0-9]+)\.w_1.": r"encoder.layers.\1.ffn\2.inner_proj.",
+                r"^encoder.layers\.([0-9]+)\.ffn([0-9]+)\.w_2.": r"encoder.layers.\1.ffn\2.output_proj.",
+                r"^encoder.layers\.([0-9]+)\.self_attn.linear_k\.": r"encoder.layers.\1.self_attn.k_proj.",
+                r"^encoder.layers\.([0-9]+)\.self_attn.linear_q\.": r"encoder.layers.\1.self_attn.q_proj.",
+                r"^encoder.layers\.([0-9]+)\.self_attn.linear_v\.": r"encoder.layers.\1.self_attn.v_proj.",
+                r"^encoder.layers\.([0-9]+)\.self_attn.linear_out\.": r"encoder.layers.\1.self_attn.output_proj.",
+                r"^encoder.layers\.([0-9]+)\.self_attn.linear_pos.weight": (
+                    r"encoder.layers.\1.self_attn.sdpa.r_proj.weight"
+                ),
+                r"^encoder.layers\.([0-9]+)\.self_attn.pos_bias_u": r"encoder.layers.\1.self_attn.sdpa.u_bias",
+                r"^encoder.layers\.([0-9]+)\.self_attn.pos_bias_v": r"encoder.layers.\1.self_attn.sdpa.v_bias",
+                # overrides existing rule
+                r"^encoder\.layers\.([0-9]+)\.final_layer_norm\.": r"encoder.layers.\1.layer_norm.",
+            }
+        )
+        state_dict = convert_model_state_dict(state_dict=state_dict, key_map=key_map)
+        # w2v2_encoder in fairseq2 has the encoder layer_norm set to None
+        for rm_key in ["encoder.layer_norm.bias", "encoder.layer_norm.weight"]:
+            del state_dict[rm_key]
+        enc_state_dict = cls._sel_and_upd_prefix(kv=state_dict, prefix="encoder.")
+        model.speech_encoder.inner.load_state_dict(enc_state_dict, strict=True)  # type: ignore
+        logger.info(f"Loaded w2v2 encoder from {checkpoint_path}")
+
+        enc_frontend_state_dict = cls._sel_and_upd_prefix(kv=state_dict, prefix="encoder_frontend.")  # noqa
+        # TODO: reconcile discrepancies between fr1 and fr2 model designs
+        #  fr1-based w2v2 checkpoints with conv positional encoders use relpos self attention
+        #   this is not compatible with the fr2 model design
+        # model.speech_encoder_frontend.load_state_dict(enc_frontend_state_dict)
+        # logger.info(f"Loaded w2v2 encoder frontend from {checkpoint_path}")
+
+    @classmethod
+    def _load_pretrained_s2t_decoder(cls, model: UnitYModel, checkpoint_path: str) -> None:
+        """Load NLLB decoder trained in fairseq1"""
+        logger.info(f"Loading s2t decoder weights from {checkpoint_path}")
+        try:
+            state_dict = torch.load(checkpoint_path)["model"]
+        except ModuleNotFoundError:
+            logger.info("If seeing `No module named 'omegaconf'`, run `pip install omegaconf`")
+            raise
+        decoder_prefix = "decoder."
+        shared_state_dict = cls._sel_and_upd_prefix(kv=state_dict, prefix="shared_decoder.", new_prefix=decoder_prefix)
+        shared_state_dict = convert_model_state_dict(
+            state_dict=shared_state_dict, key_map=NllbLoader._fairseq_key_map()
+        )
+        for rm_key in ["decoder.embed_positions._float_tensor", "decoder.version"]:
+            del shared_state_dict[rm_key]
+        decoder_state = cls._sel_and_upd_prefix(kv=shared_state_dict, prefix=decoder_prefix, new_prefix="")
+        frontend_state = cls._sel_and_upd_prefix(kv=shared_state_dict, prefix="decoder_frontend.", new_prefix="")
+        proj_state = cls._sel_and_upd_prefix(kv=shared_state_dict, prefix="final_proj.", new_prefix="")
+        model.text_decoder_frontend.load_state_dict(frontend_state, strict=True)
+        logger.info(f"Loaded s2t decoder frontend weights from {checkpoint_path}")
+        model.text_decoder.load_state_dict(decoder_state, strict=True)
+        logger.info(f"Loaded s2t decoder weights from {checkpoint_path}")
+        model.final_proj.load_state_dict(proj_state, strict=True)
+        logger.info(f"Loaded s2t decoder final_proj weights from {checkpoint_path}")
+
+    @classmethod
+    def _load_pretrained_t2u(cls, model: UnitYModel, model_config: UnitYConfig, checkpoint_path: str) -> None:
+        logger.info(f"Loading t2u weights from {checkpoint_path}")
+        t2u_model = model.t2u_model
+        assert t2u_model is not None
+        try:
+            state_dict = torch.load(checkpoint_path)["model"]
+        except ModuleNotFoundError:
+            logger.info("If seeing `No module named 'omegaconf'`, run `pip install omegaconf`")
+            raise
+        state_dict = {k.replace("encoder.", "synthesizer_encoder."): v for k, v in state_dict.items()}
+        state_dict = convert_model_state_dict(
+            state_dict=state_dict, key_map=UnitYLoader._fairseq_key_map(config=model_config)
+        )
+        t2u_state_dict = cls._sel_and_upd_prefix(kv=state_dict, prefix="t2u_model.", new_prefix="")
+        t2u_model.load_state_dict(t2u_state_dict)
+        logger.info(f"Loaded t2u weights from {checkpoint_path}")
+
+    def build_model(
+        self,
+    ) -> UnitYModel:
+        config = self.config
+        logger.info("Initializing model")
+        if config.from_model is not None:
+            logger.info(f"Loading model and weights from `{config.from_model}`")
+            return load_unity_model(config.from_model, device=self.device, dtype=self.dtype)
+
+        if config.from_model_config is not None:
+            logger.info(f"Loading Unity config from `{config.from_model_config}`")
+            model_config = load_unity_config(config.from_model_config)
+        elif config.custom_params is not None:
+            logger.info("Creating custom Unity config")
+            model_config = self._build_custom_model_config()
+        else:
+            raise ValueError("One of params from_model, from_model_config or custom_params has to be set")
+        logger.info("Building model")
+        model = create_unity_model(config=model_config, dtype=self.dtype, device=self.device)
+
+        if self.config.pretrained_w2v2_path is not None:
+            self._load_pretrained_w2v2_encoder(model, self.config.pretrained_w2v2_path)
+
+        if self.config.pretrained_s2t_decoder_path is not None:
+            self._load_pretrained_s2t_decoder(model, self.config.pretrained_s2t_decoder_path)
+
+        if self.config.pretrained_t2u_path is not None:
+            self._load_pretrained_t2u(model, model_config, self.config.pretrained_t2u_path)
+
+        return model
+
+    def _build_custom_model_config(self) -> UnitYConfig:
+        config = self.config.custom_params
+        assert config is not None
+        return UnitYConfig(
+            model_dim=config.model_embed_dim,
+            w2v2_encoder_config=Wav2Vec2EncoderConfig(
+                model_dim=config.model_embed_dim,
+                max_seq_len=4096,
+                feature_dim=160,
+                use_fbank=True,
+                first_pass_dropout_p=0.0,
+                layer_norm_features=config.w2v2_encoder_layers_layernorm_features,
+                feature_extractor_layer_descs=[],
+                feature_extractor_bias=False,
+                feature_extractor_layer_norm_convs=False,
+                feature_grad_scale=0,
+                num_fbank_channels=80,
+                fbank_stride=2,
+                sample_fbank_every_k=1,
+                pos_encoder_type=config.w2v2_pos_encoder_type,
+                pos_encoder_depth=config.w2v2_pos_encoder_depth,
+                pos_conv_kernel_size=config.w2v2_pos_conv_kernel_size,
+                num_pos_conv_groups=config.w2v2_num_pos_conv_groups,
+                use_conformer=config.w2v2_encoder_layers_use_conformer,
+                num_encoder_layers=config.w2v2_encoder_layers,
+                num_encoder_attn_heads=16,
+                ffn_inner_dim=config.model_embed_dim * 4,
+                dropout_p=0.0,
+                attn_dropout_p=0.0,
+                layer_drop_p=0.0,
+                norm_order=TransformerNormOrder.POST,
+                depthwise_conv_kernel_size=31,
+            ),
+            mt_model_config=NllbConfig(
+                model_dim=config.model_embed_dim,
+                max_seq_len=1024,
+                vocabulary_size=config.nllb_vocabulary_size,  # num_tokens + langs + spec symbols
+                pad_idx=0,
+                num_encoder_layers=config.nllb_encoder_layers,
+                num_decoder_layers=config.nllb_decoder_layers,
+                num_encoder_attn_heads=16,
+                num_decoder_attn_heads=16,
+                ffn_inner_dim=config.model_embed_dim * 8,
+                dropout_p=0.1,
+            ),
+            t2u_config=UnitYT2UConfig(
+                model_dim=config.model_embed_dim,
+                unit_max_seq_len=2048,
+                unit_vocabulary_size=config.unit_vocabulary_size,
+                unit_pad_idx=1,
+                num_encoder_layers=config.t2u_encoder_layers,
+                num_decoder_layers=config.t2u_decoder_layers,
+                nar_decoder_frontend_config=None,
+                nar_decoder_config=None,
+                num_encoder_attn_heads=16,
+                num_decoder_attn_heads=16,
+                ffn_inner_dim=config.model_embed_dim * 8,
+                dropout_p=0.1,
+            ),
+            use_text_encoder=True,
+            use_conformer_adaptor=False,
+            num_adaptor_layers=1,
+            adaptor_kernel_size=8,
+            adaptor_stride=8,
+            adaptor_layer_norm=True,
+            adaptor_dropout_p=0.1,
+        )
+
+
+if __name__ == "__main__":
+    logging.basicConfig(
+        level=logging.INFO,
+        format=f"%(asctime)s %(levelname)s -- %(name)s.{os.getpid()}: %(message)s",
+    )
+    config = ModelConfig(
+        custom_params=CustomModelParams(
+            nllb_vocabulary_size=256103,
+        ),
+        pretrained_w2v2_path="/fsx-ust/spopuri/datasets/PT_CKPT/w2v2/w2vbert2rpq_600m_al5.pt",
+        pretrained_s2t_decoder_path="/fsx-ust/spopuri/datasets/PT_CKPT/S2T/S2T_M4T_V1_V1_cleaned.pt",
+        pretrained_t2u_path="/fsx-ust/spopuri/datasets/PT_CKPT/T2U/V5_10K_p2_14_80K.pt",
+    )
+    builder = ModelBuilder(config=config)
+    model = builder.build_model()
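
For reference, a tiny illustration of what ModelBuilder._sel_and_upd_prefix does when carving sub-module state dicts out of a flat fairseq1 checkpoint; the keys and values below are made up, and the import assumes the m4t_scripts.train package path used by the other scripts:

    from m4t_scripts.train.model import ModelBuilder

    state_dict = {
        "encoder.layers.0.self_attn.q_proj.weight": "tensor-a",
        "encoder.layers.0.self_attn.k_proj.weight": "tensor-b",
        "decoder.embed_tokens.weight": "tensor-c",
    }

    # keep only "encoder." keys and strip the prefix
    enc_only = ModelBuilder._sel_and_upd_prefix(kv=state_dict, prefix="encoder.")
    # {'layers.0.self_attn.q_proj.weight': 'tensor-a',
    #  'layers.0.self_attn.k_proj.weight': 'tensor-b'}

    # keep only "decoder." keys and swap in a new prefix
    renamed = ModelBuilder._sel_and_upd_prefix(kv=state_dict, prefix="decoder.", new_prefix="text_decoder.")
    # {'text_decoder.embed_tokens.weight': 'tensor-c'}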

+ 97 - 0
scripts/m4t/train/recipes/asr_small.yaml

@@ -0,0 +1,97 @@
+eval_data:
+  audio:
+    audio_root_dir: /fsx-ust/data/audio_zips/
+    fbanks_num_mel_bins: 80
+    fbanks_standardize_audio: true
+    fbanks_waveform_scale: 32768
+  fbank_feats_pad_idx: 0
+  manifest_list: dev_asr_only_aggregated_adapted
+  manifest_list_path: null
+  manifest_path_prefix: /data/home/mavlyutov/s2t_ondevice/
+  max_seconds_per_input_audio: 15
+  fixed_batch_size: 40
+  max_tgt_text_tokens_per_batch: 1000
+  max_tgt_text_tokens_per_sample: 300
+  max_units_per_sample: 1500
+  num_threads: 5
+  prefech_batches: null 
+  prepend_tgt_lang_tag: true
+  shuffle_window: 1000
+  text_tokenization:
+    from_model: null
+    langtoks:
+    - eng
+    - rus
+    - hin
+    - por
+    - spa
+    spm_path: /data/home/mavlyutov/s2t_ondevice/vocab20k/5_5_20k.model
+  unit_tokenization:
+    from_model: seamlessM4T_large
+    langtoks: null
+    num_units: null
+  unit_tokenizer_name: seamlessM4T_large
+model:
+  custom_params:
+    model_embed_dim: 768
+    nllb_decoder_layers: 3
+    nllb_encoder_layers: 1
+    nllb_vocabulary_size: 256102
+    t2u_decoder_layers: 1
+    t2u_encoder_layers: 1
+    unit_vocabulary_size: 10082
+    w2v2_encoder_layers: 6
+    w2v2_encoder_layers_layernorm_features: false
+    w2v2_encoder_layers_use_conformer: true
+    w2v2_num_pos_conv_groups: 0
+    w2v2_pos_conv_kernel_size: 0
+    w2v2_pos_encoder_depth: 0
+    w2v2_pos_encoder_type: relative
+  from_model: null
+  from_model_config: null
+  pretrained_s2t_decoder_path: null
+  pretrained_t2u_path: null
+  pretrained_w2v2_path: null
+train_data:
+  audio:
+    audio_root_dir: /fsx-ust/data/audio_zips/
+    fbanks_num_mel_bins: 80
+    fbanks_standardize_audio: true
+    fbanks_waveform_scale: 32768
+  fbank_feats_pad_idx: 0
+  manifest_list: train_asr_only_aggregated_5_dial_filtered_adapted
+  manifest_list_path: null
+  manifest_path_prefix: /data/home/mavlyutov/s2t_ondevice/
+  max_seconds_per_input_audio: 15
+  fixed_batch_size: 40
+  max_tgt_text_tokens_per_batch: 600
+  max_tgt_text_tokens_per_sample: 300
+  max_units_per_sample: 1500
+  num_threads: 4 
+  prefech_batches: null 
+  prepend_tgt_lang_tag: true
+  shuffle_window: 1000 
+  text_tokenization:
+    from_model: null
+    langtoks:
+    - eng
+    - rus
+    - hin
+    - por
+    - spa
+    spm_path: /data/home/mavlyutov/s2t_ondevice/vocab20k/5_5_20k.model
+  unit_tokenization:
+    from_model: seamlessM4T_large
+    langtoks: null
+    num_units: null
+  unit_tokenizer_name: seamlessM4T_large
+training:
+  eval_steps: 5000 
+  float_dtype: fp32
+  label_smoothing: 0.2
+  learning_rate: 0.0001
+  log_steps:  200 
+  max_epochs: 100
+  patience: 10
+  start_learning_rate: 1.0e-07
+  warmup_steps: 1000
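
For reference, a sketch of how a recipe like this one maps onto the dataclasses in configs.py, assuming PyYAML and the m4t_scripts.train import path used by the scripts above; the exact wiring in the training entry point may differ:

    import torch
    import yaml

    from m4t_scripts.train.configs import WorkflowParams
    from m4t_scripts.train.dataloader import UnityDataLoader
    from m4t_scripts.train.model import ModelBuilder

    with open("scripts/m4t/train/recipes/asr_small.yaml") as fp:
        params = WorkflowParams.deserialize(yaml.safe_load(fp))

    # float_dtype: fp32 in the recipe maps to torch.float32 here
    dtype = torch.float32
    model = ModelBuilder(config=params.model, dtype=dtype).build_model()
    train_loader = UnityDataLoader(config=params.train_data, float_dtype=dtype)
    eval_loader = UnityDataLoader(config=params.eval_data, float_dtype=dtype)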

+ 97 - 0
scripts/m4t/train/recipes/asr_small_wh_transc.yaml

@@ -0,0 +1,97 @@
+eval_data:
+  audio:
+    audio_root_dir: /fsx-ust/data/audio_zips/
+    fbanks_num_mel_bins: 80
+    fbanks_standardize_audio: true
+    fbanks_waveform_scale: 32768
+  fbank_feats_pad_idx: 0
+  manifest_list: dev_asr_only_aggregated_adapted
+  manifest_list_path: null
+  manifest_path_prefix: /data/home/mavlyutov/s2t_ondevice/
+  max_seconds_per_input_audio: 15
+  fixed_batch_size: 40
+  max_tgt_text_tokens_per_batch: 1000
+  max_tgt_text_tokens_per_sample: 300
+  max_units_per_sample: 1500
+  num_threads: 5
+  prefech_batches: null 
+  prepend_tgt_lang_tag: true
+  shuffle_window: 1000
+  text_tokenization:
+    from_model: null
+    langtoks:
+    - eng
+    - rus
+    - hin
+    - por
+    - spa
+    spm_path: /data/home/mavlyutov/s2t_ondevice/vocab20k/5_5_20k.model
+  unit_tokenization:
+    from_model: seamlessM4T_large
+    langtoks: null
+    num_units: null
+  unit_tokenizer_name: seamlessM4T_large
+model:
+  custom_params:
+    model_embed_dim: 768
+    nllb_decoder_layers: 3
+    nllb_encoder_layers: 1
+    nllb_vocabulary_size: 256102
+    t2u_decoder_layers: 1
+    t2u_encoder_layers: 1
+    unit_vocabulary_size: 10082
+    w2v2_encoder_layers: 6
+    w2v2_encoder_layers_layernorm_features: false
+    w2v2_encoder_layers_use_conformer: true
+    w2v2_num_pos_conv_groups: 0
+    w2v2_pos_conv_kernel_size: 0
+    w2v2_pos_encoder_depth: 0
+    w2v2_pos_encoder_type: relative
+  from_model: null
+  from_model_config: null
+  pretrained_s2t_decoder_path: null
+  pretrained_t2u_path: null
+  pretrained_w2v2_path: null
+train_data:
+  audio:
+    audio_root_dir: /fsx-ust/data/audio_zips/
+    fbanks_num_mel_bins: 80
+    fbanks_standardize_audio: true
+    fbanks_waveform_scale: 32768
+  fbank_feats_pad_idx: 0
+  manifest_list: train_asr_only_aggregated_5_dial_filtered_adapted_wh_transc
+  manifest_list_path: null
+  manifest_path_prefix: /data/home/mavlyutov/s2t_ondevice/
+  max_seconds_per_input_audio: 15
+  fixed_batch_size: 40
+  max_tgt_text_tokens_per_batch: 600
+  max_tgt_text_tokens_per_sample: 300
+  max_units_per_sample: 1500
+  num_threads: 4 
+  prefech_batches: null 
+  prepend_tgt_lang_tag: true
+  shuffle_window: 1000 
+  text_tokenization:
+    from_model: null
+    langtoks:
+    - eng
+    - rus
+    - hin
+    - por
+    - spa
+    spm_path: /data/home/mavlyutov/s2t_ondevice/vocab20k/5_5_20k.model
+  unit_tokenization:
+    from_model: seamlessM4T_large
+    langtoks: null
+    num_units: null
+  unit_tokenizer_name: seamlessM4T_large
+training:
+  eval_steps: 1000 
+  float_dtype: fp32
+  label_smoothing: 0.2
+  learning_rate: 0.0001
+  log_steps: 50
+  max_epochs: 100
+  patience: 10
+  start_learning_rate: 1.0e-07
+  warmup_steps: 1000

+ 74 - 0
scripts/m4t/train/recipes/large_M4T_v1.yaml

@@ -0,0 +1,74 @@
+eval_data:
+  audio:
+    audio_root_dir: /fsx-ust/data/audio_zips/
+    fbanks_num_mel_bins: 80
+    fbanks_standardize_audio: true
+    fbanks_waveform_scale: 32768
+  fbank_feats_pad_idx: 0
+  manifest_list: dev_fleurs_arb-eng,dev_fleurs_ben-eng,dev_fleurs_hin-eng,dev_fleurs_ind-eng,dev_fleurs_ita-eng,dev_fleurs_jpn-eng,dev_fleurs_por-eng,dev_fleurs_rus-eng,dev_fleurs_swh-eng,dev_fleurs_tha-eng,dev_fleurs_tur-eng,dev_fleurs_urd-eng,dev_fleurs_vie-eng,dev_fleurs_spa-eng,dev_fleurs_eng-arb,dev_fleurs_eng-ben,dev_fleurs_eng-hin,dev_fleurs_eng-ind,dev_fleurs_eng-ita,dev_fleurs_eng-jpn,dev_fleurs_eng-por,dev_fleurs_eng-rus,dev_fleurs_eng-swh,dev_fleurs_eng-tha,dev_fleurs_eng-tur,dev_fleurs_eng-urd,dev_fleurs_eng-vie,dev_fleurs_eng-spa
+  manifest_list_path: null
+  manifest_path_prefix: /fsx-ust/spopuri/datasets/S2ST/V1/M4T_V1_phase2/primary/
+  max_seconds_per_input_audio: 150
+  fixed_batch_size: 40
+  max_tgt_text_tokens_per_batch: null
+  max_tgt_text_tokens_per_sample: 3000
+  max_units_per_sample: 1500
+  num_threads: 10 
+  prefech_batches: 10
+  prepend_tgt_lang_tag: true
+  shuffle_window: 1000
+  text_tokenization:
+    from_model: seamlessM4T_large
+    spm_path: null
+    langtoks: null
+  unit_tokenization:
+    from_model: seamlessM4T_large
+    langtoks: null
+    num_units: null
+  unit_tokenizer_name: seamlessM4T_large
+model:
+  custom_params:
+    nllb_vocabulary_size: 256103
+  from_model: null
+  from_model_config: null
+  pretrained_s2t_decoder_path: /fsx-ust/spopuri/datasets/PT_CKPT/S2T/S2T_M4T_V1_V1_cleaned.pt
+  pretrained_t2u_path: /fsx-ust/spopuri/datasets/PT_CKPT/T2U/V5_10K_p2_14_80K.pt 
+  pretrained_w2v2_path: /fsx-ust/spopuri/datasets/PT_CKPT/w2v2/w2vbert2rpq_600m_al5.pt
+train_data:
+  audio:
+    audio_root_dir: /fsx-ust/data/audio_zips/
+    fbanks_num_mel_bins: 80
+    fbanks_standardize_audio: true
+    fbanks_waveform_scale: 32768
+  fbank_feats_pad_idx: 0
+  manifest_list: null 
+  manifest_list_path: /data/home/mavlyutov/train_configs/m4t_v1_train_manifests.txt
+  manifest_path_prefix: /fsx-ust/spopuri/datasets/S2ST/V1/M4T_V1_phase2/primary 
+  max_seconds_per_input_audio: 15
+  fixed_batch_size: null 
+  max_tgt_text_tokens_per_batch: 600
+  max_tgt_text_tokens_per_sample: 300
+  max_units_per_sample: 1500
+  num_threads: 10 
+  prefech_batches: 10
+  prepend_tgt_lang_tag: true
+  shuffle_window: 1000
+  text_tokenization:
+    from_model: seamlessM4T_large
+    spm_path: null
+    langtoks: null
+  unit_tokenization:
+    from_model: seamlessM4T_large
+    langtoks: null
+    num_units: null
+  unit_tokenizer_name: seamlessM4T_large
+training:
+  eval_steps: 5000 
+  float_dtype: fp16
+  label_smoothing: 0.2
+  learning_rate: 0.0001
+  log_steps: 200 
+  max_epochs: 100
+  patience: 10
+  start_learning_rate: 1.0e-07
+  warmup_steps: 1000
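These recipes are plain YAML serializations of the training workflow config consumed by run_training.py below. A minimal sketch of parsing one from Python, assuming the repository recipe path is available locally (illustrative only):

import yaml
from m4t_scripts.train.configs import WorkflowParams

# parse a recipe into the typed workflow config
with open("scripts/m4t/train/recipes/large_M4T_v1.yaml") as fp_in:
    params = WorkflowParams.deserialize(yaml.load(fp_in, Loader=yaml.FullLoader))
print(params.training.learning_rate)  # 0.0001 for this recipe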

+ 94 - 0
scripts/m4t/train/recipes/m4t_v1_train_manifests.txt

@@ -0,0 +1,94 @@
+train_mc_eng-arb
+train_mc_eng-ita
+train_mc_eng-por
+train_mc_eng-rus
+train_mc_eng-spa
+train_mc_eng-tur
+train_mc_eng-vie
+train_cv11_eng-arb
+train_cv11_eng-ben
+train_cv11_eng-hin
+train_cv11_eng-ind
+train_cv11_eng-ita
+train_cv11_eng-jpn
+train_cv11_eng-por
+train_cv11_eng-rus
+train_cv11_eng-spa
+train_cv11_eng-swh
+train_cv11_eng-tha
+train_cv11_eng-tur
+train_cv11_eng-urd
+train_cv11_eng-vie
+train_epst_eng-ita
+train_epst_eng-por
+train_epst_eng-spa
+train_licds2s_eng-vie
+train_cv12_arb-eng
+train_masc_arb-eng
+train_mtedx_arb-eng
+train_shaip_arb-eng
+train_slr108_arb-eng
+train_css10_spa-eng
+train_cv12_spa-eng
+train_epst_spa-eng
+train_mls_spa-eng
+train_mtedx_spa-eng
+train_slr108_spa-eng
+train_vpsr_spa-eng
+train_vpst_spa-eng
+train_cv12_hin-eng
+train_slr118_hin-eng
+train_speechocean_hin-eng
+train_cv12_ind-eng
+train_mdata-c_ind-eng
+train_mdata-s_ind-eng
+train_shaip_ind-eng
+train_speechocean_ind-eng
+train_tt221213_ind-eng
+train_bbl_tur-eng
+train_cv12_tur-eng
+train_mdata-s_tur-eng
+train_slr108_tur-eng
+train_speechocean_tur-eng
+train_tt221213_tur-eng
+train_bbl_swh-eng
+train_cv12_swh-eng
+train_shaip_swh-eng
+train_css10_rus-eng
+train_cv12_rus-eng
+train_mtedx_rus-eng
+train_ruls_rus-eng
+train_bbl_ben-eng
+train_bbl_vie-eng
+train_css10_jpn-eng
+train_epst_ita-eng
+train_epst_por-eng
+train_fosd_vie-eng
+train_kokoro_jpn-eng
+train_mdata-s_jpn-eng
+train_mdata-s_tha-eng
+train_mls_ita-eng
+train_mls_por-eng
+train_mtedx_ita-eng
+train_mtedx_por-eng
+train_reazonspeech-m_jpn-eng
+train_shaip_ben-eng
+train_shaip_jpn-eng
+train_shaip_tha-eng
+train_shaip_vie-eng
+train_slr53_ben-eng
+train_speechocean_urd-eng
+train_tt221213_jpn-eng
+train_tt221213_tha-eng
+train_vivos_vie-eng
+train_vpsr_ita-eng
+train_vpst_ita-eng
+train_cv12_ben-eng
+train_cv12_ita-eng
+train_cv12_jpn-eng
+train_cv12_por-eng
+train_cv12_tha-eng
+train_cv12_urd-eng
+train_cv12_vie-eng
+train_speechocean_urd_2-eng
+train_licds2s_vie-eng
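This manifest list is referenced from large_M4T_v1.yaml via train_data.manifest_list_path (the dev set uses the inline, comma-separated manifest_list field instead). The sketch below only illustrates the assumed format, one manifest name per line:

# illustrative only: read a manifest list file into a list of manifest names
with open("scripts/m4t/train/recipes/m4t_v1_train_manifests.txt") as fp_in:
    manifests = [line.strip() for line in fp_in if line.strip()]
assert len(manifests) == 94  # one entry per non-empty line in this file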

+ 118 - 0
scripts/m4t/train/run_training.py

@@ -0,0 +1,118 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import logging
+import os
+import platform
+import shutil
+import time
+from pathlib import Path
+from typing import List
+
+import torch
+import yaml
+from m4t_scripts.train import dataloader as _dataloader
+from m4t_scripts.train import dist_utils
+from m4t_scripts.train import model as _model
+from m4t_scripts.train import trainer as _trainer
+from m4t_scripts.train.configs import WorkflowParams
+
+logging_format = f"%(asctime)s - {platform.node()} - %(process)s - %(levelname)s - %(name)s: %(message)s"
+logging.basicConfig(
+    level=logging.INFO,
+    format=logging_format,
+)
+
+logger = logging.getLogger("train")
+
+
+def init_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser(description="Run M4T training")
+    parser.add_argument(
+        "--wd",
+        type=Path,
+        required=True,
+        help="Work directory, where logs, checkpoints and core dumps will be stored",
+    )
+    parser.add_argument(
+        "--params",
+        type=Path,
+        required=True,
+        help="Config with training parameters",
+    )
+    return parser
+
+
+def run_training(parameters: WorkflowParams, work_dir: str, checkpoint_dir: str) -> None:
+    logger.info(f"Workflow params: {parameters}")
+    rank, world_size = dist_utils.get_rank(), dist_utils.get_world_size()
+    logger.info(f"Rank: {rank}, world_size: {world_size}")
+    assert torch.cuda.device_count() > 0, "GPU is not available"
+    device = torch.device("cuda")
+    float_dtype = _trainer.UnitYTrainer._get_float_dtype(parameters.training.float_dtype)
+    logger.info(f"Device: {device}, float dtype: {float_dtype}")
+    model = _model.ModelBuilder(config=parameters.model, dtype=float_dtype, device=device).build_model()
+    logger.info(f"Model: {model}")
+    train_data = _dataloader.UnityDataLoader(
+        config=parameters.train_data, rank=rank, world_size=world_size, target_device=device, float_dtype=float_dtype
+    )
+    eval_data = _dataloader.UnityDataLoader(
+        config=parameters.eval_data, rank=rank, world_size=world_size, target_device=device, float_dtype=float_dtype
+    )
+    trainer = _trainer.UnitYTrainer(
+        model=model,
+        params=parameters.training,
+        train_data_loader=train_data,
+        eval_data_loader=eval_data,
+        chck_save_dir=checkpoint_dir,
+        device=device,
+    )
+    trainer.run()
+
+
+def get_loggers() -> List[logging.Logger]:
+    return [logger, _trainer.logger, _dataloader.logger, _model.logger, dist_utils.logger]
+
+
+def set_file_output_for_loggers(log_filename: str) -> None:
+    handler = logging.FileHandler(filename=log_filename, mode="a", delay=False)
+    formatter = logging.Formatter(logging_format)
+    handler.setFormatter(formatter)
+    for logger in get_loggers():
+        logger.handlers.append(handler)
+
+
+def main() -> None:
+    args = init_parser().parse_args()
+    dist_utils.init_distributed(get_loggers())
+    is_master = dist_utils.is_main_process()
+    with open(args.params, "r") as fp_in:
+        parameters = WorkflowParams.deserialize(yaml.load(fp_in, Loader=yaml.FullLoader))
+    ts = str(int(time.time()))
+    work_dir = args.wd
+    checkpoint_dir = os.path.join(work_dir, "checkpoints")
+    if not os.path.exists(checkpoint_dir) and is_master:
+        logger.info(f"Creating checkpoint dir: {checkpoint_dir}")
+        # checkpoint_dir is not used before downstream syncs,
+        #   so we don't expect a race condition and don't run a barrier here
+        os.makedirs(checkpoint_dir)
+    config_path = os.path.join(work_dir, f"{ts}_config.yaml")
+    # copy to work dir to keep a snapshot of workflow config
+    if is_master:
+        shutil.copy(args.params, config_path)
+    log_path = os.path.join(work_dir, "train_log.txt")
+    logger.info(f"Set logging to {log_path}")
+    set_file_output_for_loggers(log_path)
+    try:
+        run_training(parameters=parameters, work_dir=work_dir, checkpoint_dir=checkpoint_dir)
+    except Exception:
+        # make sure that the stack trace is logged to the log files
+        logger.exception("Training failed")
+
+
+if __name__ == "__main__":
+    main()

+ 166 - 0
scripts/m4t/train/run_with_slurm.py

@@ -0,0 +1,166 @@
+import argparse
+import logging
+import os
+import platform
+import shutil
+import subprocess
+import time
+from pathlib import Path
+
+
+logging_format = f"%(asctime)s - {platform.node()} - %(process)s - %(levelname)s - %(name)s: %(message)s"
+logging.basicConfig(
+    level=logging.INFO,
+    format=logging_format,
+)
+
+logger = logging.getLogger("train")
+
+
+def init_parser() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser(description="Run M4T training")
+    parser.add_argument(
+        "-w",
+        type=Path,
+        required=True,
+        help="Work directory, where logs, checkpoints and core dumps will be stored",
+    )
+    parser.add_argument(
+        "-p",
+        type=Path,
+        required=True,
+        help="Training workflow config",
+    )
+    parser.add_argument(
+        "-n",
+        type=int,
+        required=False,
+        default=1,
+        help="Number of training nodes",
+    )
+    parser.add_argument(
+        "-c",
+        type=str,
+        required=False,
+        default="seamless",
+        help="Cluster partitions to use",
+    )
+    parser.add_argument(
+        "-j",
+        type=str,
+        required=False,
+        default="train",
+        help="Slurm job name",
+    )
+    return parser
+
+
+def prepare_sbatch_config(
+    job_name: str,
+    params_file: str,
+    num_nodes: int,
+    partitions: str,
+    work_dir: str,
+    cluster_logs_dir: str,
+    run_script: str,
+) -> str:
+    return f"""#!/bin/bash
+## job name
+#SBATCH --job-name={job_name}
+
+## filename for job standard output (stdout)
+## %j is the job id, %u is the user id
+#SBATCH --output={cluster_logs_dir}/%j.out
+
+## filename for job standard error output (stderr)
+#SBATCH --error={cluster_logs_dir}/%j.err
+
+## partition name
+#SBATCH --partition={partitions}
+
+## number of nodes
+#SBATCH --nodes={num_nodes}
+
+## number of gpus per node
+#SBATCH --gpus-per-node=8
+
+## number of cpus per task
+#SBATCH --cpus-per-task=96
+
+#SBATCH --gres=gpu:8
+
+## number of tasks per node
+#SBATCH --ntasks-per-node=1
+
+## amount of mem
+#SBATCH --mem 50G
+
+## amount of time in minutes
+#SBATCH --time 2400
+
+set -x
+export WANDB_DISABLED=true
+export HDF5_USE_FILE_LOCKING='FALSE'
+export PARENT=`/bin/hostname -s`
+export MPORT=24198
+export CHILDREN=`scontrol show hostnames $SLURM_JOB_NODELIST | grep -v $PARENT`
+export HOSTLIST="$PARENT $CHILDREN"
+echo $HOSTLIST
+export WORLD_SIZE=$SLURM_NTASKS
+srun --label bash -c 'which python && torchrun \\
+ --nproc_per_node=8 \\
+ --nnodes=$SLURM_JOB_NUM_NODES \\
+ --node_rank="$SLURM_PROCID" \\
+ --master_addr="$PARENT" \\
+ --master_port="$MPORT" \\
+ --log-dir={cluster_logs_dir} \\
+{run_script} --params {params_file}  --wd {work_dir}'
+"""
+
+
+def main() -> None:
+    args = init_parser().parse_args()
+    params_file = args.p
+    num_nodes = args.n
+    partitions = args.c
+    work_dir = args.w
+    job_name = args.j
+
+    assert job_name is not None
+    assert len(job_name.split()) == 1, "spaces in job name not allowed"
+    assert partitions and len(partitions.split()) == 1, "spaces in partitions not allowed"
+    assert os.path.exists(params_file), "config file is missing"
+    training_script_path = os.path.join(os.path.dirname(__file__), "run_training.py")
+    assert os.path.exists(training_script_path), f"Can't find training script {training_script_path}"
+    assert num_nodes > 0
+    if not os.path.exists(work_dir):
+        logger.info(f"Creating workdir {work_dir}")
+        os.makedirs(work_dir)
+    cluster_logs_dir = os.path.join(work_dir, "cluster_logs")
+    if os.path.exists(cluster_logs_dir):
+        logger.info(f"Clearing cluster logs dir {cluster_logs_dir}")
+        shutil.rmtree(cluster_logs_dir)
+    os.makedirs(cluster_logs_dir)
+    config_text = prepare_sbatch_config(
+        job_name=job_name,
+        params_file=params_file,
+        num_nodes=num_nodes,
+        partitions=partitions,
+        work_dir=work_dir,
+        cluster_logs_dir=cluster_logs_dir,
+        run_script=training_script_path,
+    )
+    logger.info(f"SBATCH config to launch: \n{config_text}")
+    fname = f"{int(time.time())}_sbatch.sh"
+    config_path = os.path.join(work_dir, fname)
+    with open(config_path, "w") as fp_out:
+        fp_out.write(config_text)
+        logger.info(f"Saved to {config_path}")
+    command = f"sbatch {config_path}"
+    logger.info(f"Executing command: '{command}'")
+    subprocess.Popen(command, shell=True).communicate()
+    logger.info(f"Train log: {os.path.join(work_dir, 'train_log.txt')}")
+
+
+if __name__ == "__main__":
+    main()
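For reference, a hypothetical multi-node launch with this wrapper could look like the command below (work directory and recipe path are illustrative). The script writes a timestamped sbatch file into the work directory, submits it, and the generated job runs torchrun with run_training.py on each node:

python scripts/m4t/train/run_with_slurm.py \
    -p scripts/m4t/train/recipes/large_M4T_v1.yaml \
    -w /path/to/work_dir \
    -n 4 -c seamless -j m4t_large_v1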

+ 394 - 0
scripts/m4t/train/trainer.py

@@ -0,0 +1,394 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import logging
+from typing import Any, Optional, Tuple, Dict, List
+
+import os
+import time
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from fairseq2.models.sequence import SequenceModelOutput
+from fairseq2.optim.lr_scheduler import MyleLR
+from m4t_scripts.train import dataloader, dist_utils
+from torch.optim import Adam
+
+from seamless_communication.models.unity import UnitYModel, UnitYT2UModel
+from m4t_scripts.train.configs import TrainingParams
+
+logger = logging.getLogger(__name__)
+
+
+class UnitYTrainWrapper(nn.Module):
+    """Convenience wrapper that does a forward pass
+    and returns S2T and T2U logits"""
+
+    def __init__(self, model: UnitYModel):
+        super().__init__()
+        self.model: UnitYModel = model
+        if isinstance(self.model.t2u_model, UnitYT2UModel):
+            self.t2u: UnitYT2UModel = self.model.t2u_model
+        else:
+            raise NotImplementedError("Expand UnitYTrainWrapper supports only instances of UnitYT2UModel as t2u")
+
+    def forward(self, batch: dataloader.MultimodalSeqsBatch) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Forward pass, computes S2T and T2U losses"""
+        assert self.model.t2u_model is not None
+        assert batch.speech_to_text.src_tokens is not None
+        # s2t
+        speech_encoder_out, speech_encoder_padding_mask = self.model.encode_speech(
+            seqs=batch.speech_to_text.src_tokens,
+            seq_lens=batch.speech_to_text.src_lengths,
+        )
+        assert batch.speech_to_text.prev_output_tokens is not None
+        text_decoder_out, text_decoder_padding_mask = self.model.decode(
+            seqs=batch.speech_to_text.prev_output_tokens,
+            seq_lens=batch.speech_to_text.target_lengths,
+            encoder_output=speech_encoder_out,
+            encoder_padding_mask=speech_encoder_padding_mask,
+        )
+        text_logits = self.model.final_proj(text_decoder_out)
+        # t2u
+        (
+            unit_encoder_out,
+            unit_encoder_padding_mask,
+        ) = self.t2u.encode(
+            text_decoder_output=text_decoder_out,
+            text_decoder_padding_mask=text_decoder_padding_mask,
+        )
+        unit_decoder_out, _ = self.t2u.decode(
+            seqs=batch.text_to_units.prev_output_tokens,
+            seq_lens=batch.text_to_units.target_lengths,
+            encoder_output=unit_encoder_out,
+            encoder_padding_mask=unit_encoder_padding_mask,
+        )
+        unit_logits = self.model.t2u_model.final_proj(unit_decoder_out)
+        return (text_logits, unit_logits)
+
+
+class CalcLoss:
+    """Calculates per-token negative log likelihood loss for S2T and T2U"""
+
+    def __init__(
+        self,
+        label_smoothing: float,
+        s2t_pad_idx: Optional[int],
+        t2u_pad_idx: Optional[int],
+        s2t_skip_langtok_loss: bool = False,
+    ):
+        self.label_smoothing = label_smoothing
+        self.s2t_pad_idx = s2t_pad_idx
+        self.t2u_pad_idx = t2u_pad_idx
+        self.s2t_ignore_prefix_size = 1 if s2t_skip_langtok_loss else 0
+        self.t2u_ignore_prefix_size = 1
+
+    def __call__(
+        self,
+        batch: dataloader.MultimodalSeqsBatch,
+        text_logits: torch.Tensor,
+        unit_logits: torch.Tensor,
+    ) -> torch.Tensor:
+        assert batch.speech_to_text.target_lengths is not None
+        s2t_numel = torch.sum(batch.speech_to_text.target_lengths).to(text_logits.device)
+        s2t_loss = SequenceModelOutput(logits=text_logits, pad_idx=self.s2t_pad_idx).compute_loss(
+            targets=batch.speech_to_text.target_tokens.to(text_logits.device),
+            ignore_prefix_size=self.s2t_ignore_prefix_size,
+            label_smoothing=self.label_smoothing,
+        )
+        assert batch.text_to_units.target_lengths is not None
+        s2u_numel = torch.sum(batch.text_to_units.target_lengths).to(unit_logits.device)
+        s2u_loss = SequenceModelOutput(logits=unit_logits, pad_idx=self.t2u_pad_idx).compute_loss(
+            targets=batch.text_to_units.target_tokens.to(unit_logits.device),
+            ignore_prefix_size=self.t2u_ignore_prefix_size,
+            label_smoothing=self.label_smoothing,
+        )
+        return s2t_loss / s2t_numel + s2u_loss / s2u_numel
+
+
+class LossCollector:
+    """Aggregrates loss history across nodes"""
+
+    def __init__(self, device: Optional[torch.device] = None, reduce_op: str = "avg"):
+        self.n_samples: float = 0
+        self.val_sum: float = 0.0
+        self.reduce_op = reduce_op
+        self.device = device
+        self.is_distributed = dist_utils.is_dist_initialized()
+
+    def reset(self) -> None:
+        self.n_samples = 0
+        self.val_sum = 0.0
+
+    def update(self, n_samples: int, batch_loss: float) -> None:
+        self.n_samples += n_samples
+        self.val_sum += batch_loss
+
+    def reduce(self) -> float:
+        n_samples, val_sum = self._collect()
+        if self.reduce_op == "avg":
+            return val_sum / (n_samples + 1)  # +1 guards against division by zero on an empty history
+        if self.reduce_op == "sum":
+            return val_sum
+        raise ValueError()
+
+    def _collect(self) -> Tuple[float, float]:
+        if not self.is_distributed:
+            return self.n_samples, self.val_sum
+        local_val = torch.tensor([[self.n_samples, self.val_sum]], device=self.device)
+        all_vals = [torch.zeros((1, 2), device=self.device) for _ in range(dist_utils.get_world_size())]
+        dist.all_gather(all_vals, local_val)
+        losses = torch.concat(all_vals, dim=0)
+        reduced = torch.sum(losses, dim=0).reshape(2).cpu()
+        return reduced[0].item(), reduced[1].item()
+
+
+class UnitYTrainer:
+    CHECKPOINT_BEST = "checkpoint_best.pt"
+
+    def __init__(
+        self,
+        model: UnitYModel,
+        params: TrainingParams,
+        train_data_loader: dataloader.UnityDataLoader,
+        eval_data_loader: Optional[dataloader.UnityDataLoader],
+        chck_save_dir: str,
+        device: torch.device,
+    ):
+        self.params = params
+        self.device = device
+        self.float_dtype = self._get_float_dtype(self.params.float_dtype)
+        self.train_data_loader = train_data_loader
+        self.eval_data_loader = eval_data_loader
+        self.chck_save_dir = chck_save_dir
+
+        assert model.t2u_model is not None
+        self.calc_loss = CalcLoss(
+            label_smoothing=self.params.label_smoothing,
+            s2t_pad_idx=model.pad_idx,
+            t2u_pad_idx=model.t2u_model.pad_idx,
+        )
+        self._try_load_checkpoint(model=model)
+        self.model = self._wrap_model_for_training(model=model)
+
+        # TODO: make tweakable
+        self.optimizer = Adam(
+            params=self.model.parameters(),
+            lr=self.params.learning_rate,
+            betas=(0.9, 0.98),
+            eps=1e-08,
+            maximize=False,
+            weight_decay=0.0,
+            fused=True,
+        )
+
+        self.grad_scaler = torch.cuda.amp.GradScaler() if self.float_dtype == torch.float16 else None  # type: ignore
+
+        # TODO: allow scheduler selection
+        self.lr_scheduler = MyleLR(
+            optimizer=self.optimizer,
+            num_warmup_steps=self.params.warmup_steps,
+            start_lr=self.params.start_learning_rate,
+        )
+
+        self.train_loss_hist = LossCollector(device=self.device)
+        self.epoch_idx: int = 0
+        self.update_idx: int = 0
+        self.patience_left: int = self.params.patience
+        self.last_eval_loss: Optional[float] = None
+        self.best_eval_loss: Optional[float] = None
+        self.is_best_state: bool = False
+        self.batch_sizes: List[int] = []
+        self.gpu_usage: List[float] = []
+
+    def _try_load_checkpoint(self, model: torch.nn.Module):
+        chck_path = self.get_best_checkpoint_path()
+        if os.path.exists(chck_path):
+            logger.info(f"Loading state dict from {chck_path}")
+            state_dict = torch.load(chck_path)
+            model.load_state_dict(state_dict)
+
+    @classmethod
+    def _get_float_dtype(cls, float_dtype: str) -> torch.dtype:
+        if float_dtype == "fp16":
+            return torch.float16
+        elif float_dtype == "fp32":
+            return torch.float32
+        elif float_dtype == "bf16":
+            return torch.bfloat16
+        else:
+            raise ValueError(f"Unkown dtype literal: {float_dtype}")
+
+    def _reset_stats(self) -> None:
+        self.train_loss_hist.reset()
+        self.epoch_idx = 0
+        self.update_idx = 0
+        self.patience_left = self.params.patience
+        self.last_eval_loss = None
+        self.best_eval_loss = None
+        self.is_best_state = False
+        self._reset_log_stats()
+
+    def _reset_log_stats(self) -> None:
+        self.batch_sizes.clear()
+        self.gpu_usage.clear()
+        self.ts = time.time()
+        self.last_update_idx = self.update_idx
+
+    def _record_gpu_usage(self) -> None:
+        gb = (torch.cuda.memory_reserved(self.device) >> 20) / 1024.0
+        self.gpu_usage.append(gb)
+
+    def _get_avg_bsz(self) -> float:
+        """Avg training batch size"""
+        return sum(self.batch_sizes) / len(self.batch_sizes) if self.batch_sizes else 0.0
+
+    def _get_ups(self) -> float:
+        """Updates per second"""
+        ts_delta = time.time() - self.ts
+        return (self.update_idx - self.last_update_idx) / ts_delta
+
+    def _get_avg_gpu_usage(self) -> float:
+        return sum(self.gpu_usage) / len(self.gpu_usage) if self.gpu_usage else 0.0
+
+    def _wrap_model_for_training(self, model: UnitYModel) -> nn.Module:
+        wrapped_model = UnitYTrainWrapper(model=model)
+        if not dist_utils.is_dist_initialized():
+            return wrapped_model
+        return nn.parallel.DistributedDataParallel(
+            wrapped_model,
+            device_ids=[dist_utils.get_local_rank()],
+            find_unused_parameters=True,
+        )
+
+    def _update_eval_stats(self, eval_loss: float) -> None:
+        self.last_eval_loss = eval_loss
+        self.is_best_state = self.best_eval_loss is None or eval_loss < self.best_eval_loss
+        self.best_eval_loss = eval_loss if self.is_best_state else self.best_eval_loss
+        self.patience_left = self.params.patience if self.is_best_state else self.patience_left - 1
+        logger.info(
+            f"Eval after {self.update_idx} updates: "
+            f"loss={eval_loss:.4f} "
+            f"best_loss={self.best_eval_loss:.4f} "
+            f"patience_steps_left={self.patience_left}"
+        )
+
+    def _eval_model(self) -> None:
+        """Calc avg loss on eval dataset and update evaluation stats"""
+        if self.eval_data_loader is None:
+            return
+        logger.info("Run evaluation")
+        loss_hist = LossCollector(device=self.device)
+        self.model.eval()
+        with torch.no_grad():
+            self.eval_data_loader.reset()
+            for batch in self.eval_data_loader.iterate_batches():
+                assert batch.speech_to_text.src_tokens is not None
+                loss = self.calc_loss(batch, *self.model(batch))
+                if loss.isnan():
+                    logger.warning("Eval loss value is NaN, setting to inf")
+                    loss_val = float("Inf")
+                else:
+                    loss_val = loss.item()
+                del batch  # force memory release
+                loss_hist.update(1, loss_val)
+        eval_loss = loss_hist.reduce()
+        self._update_eval_stats(eval_loss)
+
+    def _train_step_log(self):
+        """Log train stats"""
+        if (self.update_idx + 1) % self.params.log_steps == 0:
+            avg_loss = self.train_loss_hist.reduce()
+            self.train_loss_hist.reset()
+            logger.info(
+                f"Epoch {str(self.epoch_idx + 1).zfill(3)} / "
+                f"update {str(self.update_idx + 1).zfill(5)}: "
+                f"train loss={avg_loss:.4f} "
+                f"last lr={self.lr_scheduler.get_last_lr()[0]:.2E} "
+                f"bsz_avg={self._get_avg_bsz():.1f} "
+                f"ups={self._get_ups():.2f} "
+                f"gpu_avg={self._get_avg_gpu_usage():.2f}Gb"
+            )
+            self._reset_log_stats()
+
+    def _train_step(self, batch: dataloader.MultimodalSeqsBatch) -> None:
+        """Run one train step"""
+        self.model.train()
+        self.optimizer.zero_grad()
+        tokens, units = self.model(batch)
+        loss = self.calc_loss(batch, tokens, units)
+        # record reserved GPU memory at this point (proxy for peak usage within the step)
+        self._record_gpu_usage()
+
+        if self.grad_scaler is not None:
+            self.grad_scaler.scale(loss).backward()  # type: ignore
+            self.grad_scaler.step(self.optimizer)
+            self.grad_scaler.update()
+        else:
+            loss.backward()
+            self.optimizer.step()
+
+        self.lr_scheduler.step()
+        assert batch.speech_to_text.src_tokens is not None
+        self.train_loss_hist.update(1, loss.item())
+        self.batch_sizes.append(batch.speech_to_text.src_tokens.shape[0])
+        self._train_step_log()
+
+    def _get_state(self) -> Dict[str, Any]:
+        model_state_dict = self.model.state_dict()
+        model_state_dict = {key.replace("module.model.", ""): value for key, value in model_state_dict.items()}
+        return model_state_dict
+
+    def _get_chck_path(self) -> str:
+        ts = str(int(time.time()))
+        epoch = str(self.epoch_idx).zfill(3)
+        update = str(self.update_idx).zfill(6)
+        eval_loss = f"{self.last_eval_loss:.4f}"
+        name = f"{ts}_{epoch}_{update}_{eval_loss}.pt"
+        return os.path.join(self.chck_save_dir, name)
+
+    def _get_best_checkpoint_link_path(self) -> str:
+        return os.path.join(self.chck_save_dir, self.CHECKPOINT_BEST)
+
+    def get_best_checkpoint_path(self) -> str:
+        return os.path.realpath(self._get_best_checkpoint_link_path())
+
+    def _save_model(self):
+        if dist_utils.is_main_process():
+            state_dict = self._get_state()
+            save_path = self._get_chck_path()
+            logger.info(f"Saving checkpoint to {save_path}")
+            torch.save(state_dict, save_path)
+            if self.is_best_state:
+                best_link_path = self._get_best_checkpoint_link_path()
+                if os.path.exists(best_link_path):
+                    os.unlink(best_link_path)
+                os.symlink(save_path, best_link_path)
+                logger.info(f"Updating pointer to the best checkpoint {best_link_path} -> {save_path}")
+        if dist_utils.is_dist_initialized():
+            dist.barrier()
+
+    def run(self):
+        logger.info("Start training")
+        self._reset_stats()
+        self._eval_model()
+        while self.epoch_idx < self.params.max_epochs and self.patience_left:
+            for train_batch in self.train_data_loader.iterate_batches():
+                self._train_step(batch=train_batch)
+                if self.update_idx and self.update_idx % self.params.eval_steps == 0:
+                    self._eval_model()
+                    if self.is_best_state:
+                        self._save_model()
+                    elif not self.patience_left:
+                        no_improve_steps = self.params.eval_steps * self.params.patience
+                        logger.info(
+                            f"Early termination, as eval loss did not improve over last {no_improve_steps} updates"
+                        )
+                        break
+                self.update_idx += 1
+            self.train_data_loader.reset()
+            self.epoch_idx += 1
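The trainer writes timestamped checkpoints named <ts>_<epoch>_<update>_<eval_loss>.pt and keeps a checkpoint_best.pt symlink pointing at the best one by eval loss; on start-up it resumes from that symlink if it exists. A minimal sketch of loading the resulting weights after training (work directory path illustrative):

import os
import torch

# checkpoint_best.pt lives in <work_dir>/checkpoints and resolves to the best timestamped checkpoint
best_path = os.path.realpath("/path/to/work_dir/checkpoints/checkpoint_best.pt")
state_dict = torch.load(best_path, map_location="cpu")
# keys are plain UnitYModel parameter names; the "module.model." DDP wrapper prefix is stripped on save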

+ 105 - 0
src/seamless_communication/models/tokenizer.py

@@ -0,0 +1,105 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Optional, Sequence, Set, final
+
+from fairseq2.data.text import (
+    SentencePieceDecoder,
+    SentencePieceEncoder,
+    SentencePieceModel,
+    TextTokenDecoder,
+    TextTokenEncoder,
+    TextTokenizer,
+    vocabulary_from_sentencepiece,
+)
+from fairseq2.data.typing import PathLike
+from fairseq2.typing import Device, finaloverride
+
+
+@final
+class SPMTokenizer(TextTokenizer):
+    """Represents standard SPM-based tokenizer used in MT tasks"""
+
+    model: SentencePieceModel
+    langs: Set[str]
+    prepend_target_langtok_to_target: bool
+
+    def __init__(self, pathname: PathLike, langs: Sequence[str], prepend_target_langtok_to_target: bool = True) -> None:
+        """
+        :param pathname:
+            The pathname of the SentencePiece model file.
+        :param langs:
+            The list of supported languages.
+        :param prepend_target_langtok_to_target:
+            If ``True``, the target language token is prepended to target-side sequences.
+        """
+        self.langs = set(langs)
+        self.prepend_target_langtok_to_target = prepend_target_langtok_to_target
+
+        # Each language is represented by a `__lang__` control symbol.
+        control_symbols = [self._lang_tok_to_internal(lang) for lang in sorted(langs)]
+        self.model = SentencePieceModel(pathname, control_symbols)
+        vocab_info = vocabulary_from_sentencepiece(self.model)
+        super().__init__(vocab_info)
+
+    @classmethod
+    def _lang_tok_to_internal(cls, lang: str) -> str:
+        return f"__{lang}__"
+
+    @finaloverride
+    def create_encoder(
+        self,
+        *,
+        task: Optional[str] = None,
+        lang: Optional[str] = None,
+        mode: Optional[str] = None,
+        device: Optional[Device] = None,
+        pin_memory: bool = False,
+    ) -> TextTokenEncoder:
+        """Create a token encoder.
+
+        :param task:
+            Must be 'translation'. If ``None``, defaults to 'translation'.
+        :param lang:
+            A language from :attr:`langs`. Must not be ``None``.
+        :param mode:
+            Must be 'source' or 'target'.
+        :param device:
+            The device on which to construct tensors.
+        :param pin_memory:
+            If ``True``, uses pinned memory while constructing tensors.
+        """
+        if task is not None and task != "translation":
+            raise ValueError(f"`task` must be 'translation', but is '{task}' instead.")
+
+        assert lang is not None
+
+        if lang not in self.langs:
+            raise ValueError(f"`lang` must be a supported language, but is '{lang}' instead.")
+
+        if mode is None or mode == "source":
+            prefix_tokens = []
+            suffix_tokens = ["</s>"]
+        elif mode == "target":
+            prefix_tokens = (
+                ["</s>"] + [self._lang_tok_to_internal(lang)] if self.prepend_target_langtok_to_target else []
+            )
+            suffix_tokens = ["</s>"]
+        else:
+            raise ValueError(f"`mode` must be 'source' or 'target', but is '{mode}' instead.")
+
+        return SentencePieceEncoder(
+            self.model,
+            prefix_tokens=prefix_tokens,
+            suffix_tokens=suffix_tokens,
+            device=device,
+            pin_memory=pin_memory,
+        )
+
+    @finaloverride
+    def create_decoder(self) -> TextTokenDecoder:
+        return SentencePieceDecoder(self.model)
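A minimal usage sketch of this tokenizer, presumably what gets built when a recipe sets text_tokenization.spm_path and langtoks (spm path illustrative; the callable encoder/decoder behaviour follows fairseq2's SentencePiece wrappers):

from seamless_communication.models.tokenizer import SPMTokenizer

tokenizer = SPMTokenizer(
    pathname="/path/to/vocab20k/5_5_20k.model",  # illustrative path
    langs=["eng", "rus", "hin", "por", "spa"],
)
# target-side encoding prepends the "</s>" and "__eng__" control tokens when the flag is enabled (default)
encoder = tokenizer.create_encoder(task="translation", lang="eng", mode="target")
token_ids = encoder("hello world")
text = tokenizer.create_decoder()(token_ids)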