
Introduce Prosody encoder (#87)

Can Balioglu 1 year ago
parent commit 05419775be

+ 1 - 1
pyproject.toml

@@ -12,7 +12,7 @@ per-file-ignores = [
 profile = "black"
 
 [tool.mypy]
-disable_error_code = "type-abstract"
+disable_error_code = "type-abstract,typeddict-unknown-key"
 disallow_untyped_calls = false
 disallow_untyped_decorators = false
 ignore_missing_imports = true
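
The new `typeddict-unknown-key` suppression matches how the expressivity evaluation below attaches an extra `gcmvn_fbank` key to fairseq2's `WaveformToFbankOutput` TypedDict. A minimal sketch of the pattern; the `Output` class here is a stand-in, not the real fairseq2 type:

```python
from typing import TypedDict

import torch


class Output(TypedDict):  # stand-in for fairseq2's WaveformToFbankOutput
    fbank: torch.Tensor


data: Output = {"fbank": torch.zeros(10, 80)}
# Writing a key that the TypedDict does not declare is what
# "typeddict-unknown-key" would flag without the suppression.
data["gcmvn_fbank"] = torch.zeros(10, 80)
```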

+ 51 - 0
src/seamless_communication/cards/seamless_expressivity.yaml

@@ -0,0 +1,51 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+name: seamless_expressivity
+base: unity_nllb-100
+model_arch: expressivity_v2
+char_tokenizer: "file://checkpoint/krs/unity2/spm_char_lang38_tc.model"
+checkpoint: "file://checkpoint/hygong/Expressivity/multilingual_models/m2m.clean.ecapa_tdnn2.dim512.all.all.lr5e-05.mk4k.config_t2_fbank_nosa_gcmvn_10k.rdrop0.ls0.2.uf3.wu5k.fp16.mem_fp16.seed1.dr0.1.ld0.2.mp0.3.cmp0.25.ma.ak8.as8.al1.ald0.0.dld0.0.ca.D24L.t2uE4L.t2uD4L.usesfilm.inj_dec.ngpu64/checkpoint_best_export.pt"
+num_units: 10000
+unit_langs:
+  - arb
+  - ben
+  - cat
+  - ces
+  - cmn
+  - cym
+  - dan
+  - deu
+  - eng
+  - est
+  - fin
+  - fra
+  - hin
+  - ind
+  - ita
+  - jpn
+  - kan
+  - kor
+  - mlt
+  - nld
+  - pes
+  - pol
+  - por
+  - ron
+  - rus
+  - slk
+  - spa
+  - swe
+  - swh
+  - tam
+  - tel
+  - tgl
+  - tha
+  - tur
+  - ukr
+  - urd
+  - uzn
+  - vie
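
The card can then be referenced by name through the inference API touched later in this commit. A hedged sketch, assuming the checkpoint paths above resolve on your asset store; the vocoder card name is a placeholder, not taken from this commit:

```python
import torch

from seamless_communication.inference import Translator
from seamless_communication.models.unity import load_unity_text_tokenizer

# "vocoder_36langs" is a placeholder; substitute the expressive vocoder card
# that matches this model.
text_tokenizer = load_unity_text_tokenizer("seamless_expressivity")
translator = Translator(
    "seamless_expressivity",
    "vocoder_36langs",
    torch.device("cpu"),
    text_tokenizer=text_tokenizer,
    dtype=torch.float32,
)
```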

+ 1 - 2
src/seamless_communication/cli/eval_utils/compute_metrics.py

@@ -10,17 +10,16 @@ from typing import Tuple, Union
 
 import pandas as pd
 import whisper
-
 from fairseq2.typing import Device
 from jiwer import cer, wer
 from sacrebleu.metrics.base import Score, Signature
 from sacrebleu.metrics.bleu import BLEU
 from sacrebleu.metrics.chrf import CHRF
-from seamless_communication.cli.eval_utils.lang_mapping import LANG3_LANG2
 from tqdm import tqdm
 from whisper import Whisper
 from whisper.normalizers import BasicTextNormalizer, EnglishTextNormalizer
 
+from seamless_communication.cli.eval_utils.lang_mapping import LANG3_LANG2
 
 logging.basicConfig(
     level=logging.INFO,

+ 0 - 0
src/seamless_communication/cli/expressivity/__init__.py


+ 0 - 0
src/seamless_communication/cli/expressivity/evaluate/__init__.py


+ 423 - 0
src/seamless_communication/cli/expressivity/evaluate/evaluate.py

@@ -0,0 +1,423 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import contextlib
+import logging
+import subprocess
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import numpy as np
+import torch
+import torchaudio
+from fairseq2.data import Collater, CString, DataPipeline, FileMapper
+from fairseq2.data.audio import (
+    AudioDecoder,
+    WaveformToFbankConverter,
+    WaveformToFbankOutput,
+)
+from fairseq2.data.text import StrSplitter, TextTokenizer, read_text
+from fairseq2.data.typing import PathLike, StringLike
+from fairseq2.generation import SequenceGeneratorOptions
+from fairseq2.typing import DataType, Device
+from sacrebleu.metrics import BLEU  # type: ignore[attr-defined]
+from torch import Tensor
+from tqdm import tqdm
+
+from seamless_communication.cli.m4t.predict import (
+    add_inference_arguments,
+    set_generation_opts,
+)
+from seamless_communication.inference import BatchedSpeechOutput, Modality, Translator
+from seamless_communication.models.unity import load_unity_text_tokenizer
+
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
+)
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class EvalContext:
+    task: str
+    """String representing the task. Valid choices are
+    "S2ST", "S2TT", "T2ST", "T2TT", "ASR"."""
+
+    output_modality: Modality
+    """The output modality of the task."""
+
+    model_name: str
+    """The name of the S2T UnitY model."""
+
+    data_file: Path
+    """The pathname of the test TSV data file."""
+
+    audio_root_dir: Optional[Path]
+    """The pathname of the directory under which
+    audio files are stored."""
+
+    target_lang: str
+    """The target translation language."""
+
+    source_lang: Optional[str]
+    """The source language."""
+
+    batch_size: int
+    """The batch size for model input."""
+
+    device: Device
+    """The device on which to run inference."""
+
+    dtype: DataType
+    """The data type with which to run inference."""
+
+    output_path: Path
+    """The pathname of the output directory to save
+    the evaluation results."""
+
+    ref_field: str
+    """The reference target text field to compute
+    the BLEU score against."""
+
+    text_generation_opts: SequenceGeneratorOptions
+    """Text generation hyperparameters."""
+
+    unit_generation_opts: Optional[SequenceGeneratorOptions]
+    """Unit generation hyperparameters, not applicable
+    for the NAR T2U decoder."""
+
+    unit_generation_ngram_filtering: bool
+    """If True, removes consecutive repeating ngrams
+    from the decoded unit output."""
+
+    gcmvn_stats: Optional[PathLike] = None
+    """The pathname of the gcmvn stats file used by the prosody encoder."""
+
+
+def count_lines(filename: Path) -> int:
+    result = subprocess.run(["wc", "-l", filename], stdout=subprocess.PIPE)
+    return int(result.stdout.decode().split()[0])
+
+
+def build_data_pipeline(
+    ctx: EvalContext,
+    text_tokenizer: TextTokenizer,
+) -> DataPipeline:
+    with open(ctx.data_file, "r") as f:
+        header = f.readline().strip("\n").split("\t")
+
+    # TODO: This will be soon auto-tuned. Right now hand-tuned for devfair.
+    n_parallel = 4
+
+    split_tsv = StrSplitter(names=header)
+
+    if ctx.gcmvn_stats is not None:
+        if isinstance(ctx.gcmvn_stats, CString):
+            ctx.gcmvn_stats = str(ctx.gcmvn_stats)
+        gcmvn_stats: Dict[str, np.ndarray] = np.load(ctx.gcmvn_stats)  # type: ignore[type-arg]
+        gcmvn_mean = torch.tensor(
+            gcmvn_stats["mean"], device=ctx.device, dtype=ctx.dtype
+        )
+        gcmvn_std = torch.tensor(gcmvn_stats["std"], device=ctx.device, dtype=ctx.dtype)
+
+    pipeline_builder = read_text(ctx.data_file, rtrim=True).skip(1).map(split_tsv)
+
+    assert ctx.audio_root_dir is not None
+
+    map_file = FileMapper(root_dir=ctx.audio_root_dir, cached_fd_count=10)
+
+    pipeline_builder.map(map_file, selector="audio", num_parallel_calls=n_parallel)
+
+    decode_audio = AudioDecoder(dtype=torch.float32, device=ctx.device)
+
+    convert_to_fbank = WaveformToFbankConverter(
+        num_mel_bins=80,
+        waveform_scale=2**15,
+        channel_last=True,
+        standardize=False,
+        device=ctx.device,
+        dtype=ctx.dtype,
+    )
+
+    def normalize_fbank(data: WaveformToFbankOutput) -> WaveformToFbankOutput:
+        fbank = data["fbank"]
+        std, mean = torch.std_mean(fbank, dim=0)
+        data["fbank"] = fbank.subtract(mean).divide(std)
+        if ctx.gcmvn_stats is not None:
+            data["gcmvn_fbank"] = fbank.subtract(gcmvn_mean).divide(gcmvn_std)
+        return data
+
+    pipeline_builder.map(
+        [decode_audio, convert_to_fbank, normalize_fbank],
+        selector="audio.data",
+        num_parallel_calls=n_parallel,
+    )
+
+    pipeline_builder.bucket(bucket_size=ctx.batch_size)
+
+    collate = Collater(pad_value=0, pad_to_multiple=1)
+
+    pipeline_builder.map(collate, num_parallel_calls=n_parallel)
+
+    pipeline_builder.prefetch(4)
+
+    return pipeline_builder.and_return()
+
+
+def adjust_output_for_corrupted_inputs(
+    valid_sequences: Tensor,
+    text_output: List[StringLike],
+    speech_output: Optional[BatchedSpeechOutput],
+) -> Tuple[List[StringLike], Optional[BatchedSpeechOutput]]:
+    adjusted_text_output: List[StringLike] = []
+    adjusted_speech_output: Optional[BatchedSpeechOutput] = None
+
+    if speech_output is not None:
+        assert (
+            len(text_output)
+            == len(speech_output.units)
+            == len(speech_output.audio_wavs)
+        )
+        adjusted_speech_output = BatchedSpeechOutput(units=[], audio_wavs=[])
+
+    batch_counter = 0
+    for is_valid in valid_sequences:
+        if is_valid:
+            adjusted_text_output.append(text_output[batch_counter])
+            if speech_output is not None:
+                assert adjusted_speech_output is not None
+                adjusted_speech_output.units.append(speech_output.units[batch_counter])
+                adjusted_speech_output.audio_wavs.append(
+                    speech_output.audio_wavs[batch_counter]
+                )
+            batch_counter += 1
+        else:
+            # For the corrupted inputs, we save the following dummy outputs:
+            # empty string for text, empty list for units, 1 second of silence for audio.
+            adjusted_text_output.append("")
+            if adjusted_speech_output is not None:
+                sample_rate = adjusted_speech_output.sample_rate
+                adjusted_speech_output.units.append([])
+                adjusted_speech_output.audio_wavs.append(
+                    torch.zeros(sample_rate).unsqueeze(0).unsqueeze(0)
+                )
+    return (
+        adjusted_text_output,
+        adjusted_speech_output,
+    )
+
+
+def run_eval(
+    translator: Translator, text_tokenizer: TextTokenizer, ctx: EvalContext
+) -> None:
+    pipeline = build_data_pipeline(ctx, text_tokenizer)
+
+    total_steps = count_lines(ctx.data_file) - 1
+    progress_bar = tqdm(total=total_steps)
+
+    output_path = ctx.output_path / ctx.data_file.stem
+    output_path.mkdir(parents=True, exist_ok=True)
+
+    if ctx.output_modality == Modality.SPEECH:
+        waveforms_dir = output_path / f"waveform_{ctx.data_file.stem}"
+        waveforms_dir.mkdir(parents=True, exist_ok=True)
+
+    hyps = []
+    refs = []
+
+    with contextlib.ExitStack() as stack:
+        hyp_file = stack.enter_context(
+            open(output_path / f"text_output-{ctx.data_file.stem}.txt", "w")
+        )
+        if ctx.output_modality == Modality.SPEECH:
+            unit_file = stack.enter_context(
+                open(output_path / f"unit_output-{ctx.data_file.stem}.txt", "w")
+            )
+
+        sample_id = 0
+        for example in pipeline:
+            valid_sequences: Optional[Tensor] = None
+            src = example["audio"]["data"]["fbank"]
+            # Skip corrupted audio tensors.
+            valid_sequences = ~torch.any(
+                torch.any(torch.isnan(src["seqs"]), dim=1), dim=1
+            )
+            if not valid_sequences.all():
+                logger.warning(
+                    f"Sample IDs {sample_id} to {sample_id + ctx.batch_size} have some corrupted input."
+                )
+                src["seqs"] = src["seqs"][valid_sequences]
+                src["seq_lens"] = src["seq_lens"][valid_sequences]
+
+            # Skip performing inference when the input is entirely corrupted.
+            if src["seqs"].numel() > 0:
+                (
+                    text_output,
+                    speech_output,
+                ) = translator.predict(
+                    src,
+                    ctx.task,
+                    ctx.target_lang,
+                    src_lang=ctx.source_lang,
+                    text_generation_opts=ctx.text_generation_opts,
+                    unit_generation_opts=ctx.unit_generation_opts,
+                    unit_generation_ngram_filtering=ctx.unit_generation_ngram_filtering,
+                    gcmvn_fbank=example["audio"]["data"].get("gcmvn_fbank", None),
+                )
+            else:
+                text_output = []
+                if ctx.output_modality == Modality.SPEECH:
+                    speech_output = BatchedSpeechOutput(units=[], audio_wavs=[])
+                else:
+                    speech_output = None
+
+            if valid_sequences is not None and not valid_sequences.all():
+                (
+                    text_output,
+                    speech_output,
+                ) = adjust_output_for_corrupted_inputs(
+                    valid_sequences,
+                    text_output,
+                    speech_output,
+                )
+
+            hyps += [str(s) for s in text_output]
+            refs += [str(s) for s in example[ctx.ref_field]]
+
+            for i in range(len(text_output)):
+                t = text_output[i]
+                hyp_file.write(f"{t}\n")
+
+                if ctx.output_modality == Modality.SPEECH:
+                    assert speech_output is not None
+                    u = speech_output.units[i]
+                    str_units = [str(i) for i in u]
+                    unit_file.write(" ".join(str_units) + "\n")
+                    torchaudio.save(
+                        waveforms_dir / f"{sample_id}_pred.wav",
+                        speech_output.audio_wavs[i][0].to(torch.float32).cpu(),
+                        sample_rate=speech_output.sample_rate,
+                    )
+
+                sample_id += 1
+                progress_bar.update(1)
+
+    progress_bar.close()
+    logger.info(f"Processed {len(hyps)} hyps, {len(refs)} refs")
+
+    assert len(hyps) == len(refs)
+    if len(hyps) > 0:
+        if ctx.target_lang in ("cmn", "jpn", "lao", "mya", "tha"):
+            tokenizer = "char"
+        else:
+            tokenizer = "13a"
+
+        bleu = BLEU(tokenize=tokenizer)
+        score = bleu.corpus_score(hyps, [refs])
+        bleu_filename = output_path / f"{ctx.data_file.stem}_text_output_bleu.json"
+        with open(bleu_filename, "w") as f:
+            f.write(score.format(signature=str(bleu.get_signature()), is_json=True))
+        logger.info(score.format(signature=bleu.get_signature()))
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(
+        description="Expressivity evaluation for tasks supported by Translator."
+    )
+    parser.add_argument("data_file", type=str, help="Data file (.tsv) to be evaluated.")
+
+    parser = add_inference_arguments(parser)
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        help="Inference batch size.",
+        default=4,
+    )
+    parser.add_argument(
+        "--audio_root_dir",
+        type=str,
+        help="Root directory for the audio filenames in the data file.",
+        default="",
+    )
+    parser.add_argument(
+        "--ref_field",
+        type=str,
+        help="Reference target text field to compute the BLEU score against.",
+        default="tgt_text",
+    )
+    parser.add_argument(
+        "--gcmvn_stats",
+        type=str,
+        help="Path to the gcmvn fbank stats; if provided, the data pipeline yields an additional gcmvn-normalized copy of the fbank features (for the prosody encoder).",
+        default=None,
+    )
+    args = parser.parse_args()
+
+    input_modality, output_modality = Translator.get_modalities_from_task_str(args.task)
+
+    if input_modality == Modality.SPEECH and not Path(args.audio_root_dir).exists():
+        raise ValueError(
+            f"Invalid audio_root_dir: {args.audio_root_dir} for speech input."
+        )
+
+    if torch.cuda.is_available():
+        device = torch.device("cuda:0")
+        dtype = torch.float32
+    else:
+        device = torch.device("cpu")
+        dtype = torch.float32
+
+    text_tokenizer = load_unity_text_tokenizer(args.model_name)
+
+    # TODO: Avoid loading the T2U model, vocoder when the output
+    # modality is text.
+    translator = Translator(
+        args.model_name,
+        args.vocoder_name,
+        device,
+        text_tokenizer=text_tokenizer,
+        dtype=dtype,
+    )
+
+    text_generation_opts, unit_generation_opts = set_generation_opts(args)
+
+    logger.info(f"{text_generation_opts=}")
+    logger.info(f"{unit_generation_opts=}")
+    logger.info(
+        f"unit_generation_ngram_filtering={args.unit_generation_ngram_filtering}"
+    )
+
+    # fmt: off
+    ctx = EvalContext(
+        task=args.task,
+        output_modality=output_modality,
+        model_name=args.model_name,
+        data_file=Path(args.data_file),
+        audio_root_dir=Path(args.audio_root_dir),
+        target_lang=args.tgt_lang,
+        source_lang=args.src_lang,
+        batch_size=args.batch_size,
+        device=device,
+        dtype=dtype,
+        ref_field=args.ref_field,
+        text_generation_opts=text_generation_opts,
+        unit_generation_opts=unit_generation_opts,
+        unit_generation_ngram_filtering=args.unit_generation_ngram_filtering,
+        output_path=Path(args.output_path),
+        gcmvn_stats=args.gcmvn_stats,
+    )
+    # fmt: on
+    logger.info(f"Running inference on {device=} with {dtype=}, {ctx.batch_size=}.")
+
+    run_eval(translator, text_tokenizer, ctx)
+
+
+if __name__ == "__main__":
+    main()
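
A standalone sketch of the two normalizations applied in `normalize_fbank` above: per-utterance statistics feed the speech encoder, while the global (gcmvn) statistics produce the extra copy consumed by the prosody encoder. The stats filename is illustrative; the file is expected to hold `mean` and `std` arrays.

```python
import numpy as np
import torch

stats = np.load("gcmvn_stats.npz")  # illustrative path; must contain "mean" and "std"
gcmvn_mean = torch.tensor(stats["mean"])
gcmvn_std = torch.tensor(stats["std"])

fbank = torch.rand(200, 80)  # (time, mel bins)

std, mean = torch.std_mean(fbank, dim=0)
utt_fbank = (fbank - mean) / std                # per-utterance CMVN, as before
gcmvn_fbank = (fbank - gcmvn_mean) / gcmvn_std  # extra copy for the prosody encoder
```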

+ 8 - 2
src/seamless_communication/cli/m4t/evaluate/evaluate.py

@@ -267,7 +267,10 @@ def run_eval(
 
             # Skip performing inference when the input is entirely corrupted.
             if src["seqs"].numel() > 0:
-                (text_output, speech_output,) = translator.predict(
+                (
+                    text_output,
+                    speech_output,
+                ) = translator.predict(
                     src,
                     ctx.task,
                     ctx.target_lang,
@@ -284,7 +287,10 @@ def run_eval(
                     speech_output = None
 
             if valid_sequences is not None and not valid_sequences.all():
-                (text_output, speech_output,) = adjust_output_for_corrupted_inputs(
+                (
+                    text_output,
+                    speech_output,
+                ) = adjust_output_for_corrupted_inputs(
                     valid_sequences,
                     text_output,
                     speech_output,

+ 10 - 2
src/seamless_communication/inference/generator.py

@@ -153,6 +153,7 @@ class UnitYGenerator:
         input_modality: str = "speech",
         output_modality: str = "speech",
         ngram_filtering: bool = False,
+        gcmvn_seqs: Optional[Tensor] = None,
     ) -> Tuple[SequenceToTextOutput, Optional["SequenceToUnitOutput"]]:
         """
         :param source_seqs:
@@ -215,6 +216,12 @@ class UnitYGenerator:
         assert self.unit_decoder is not None
 
         unit_gen_output = None
+        prosody_encoder_out = None
+        if self.model.prosody_encoder_model is not None:
+            prosody_encoder_out = self.model.prosody_encoder_model(
+                gcmvn_seqs, source_padding_mask
+            ).unsqueeze(1)
+
         if isinstance(self.model.t2u_model, UnitYT2UModel):
             assert self.unit_generator is not None
             t2u_encoder_output, t2u_encoder_padding_mask = self.model.t2u_model.encode(
@@ -231,6 +238,7 @@ class UnitYGenerator:
                 text_decoder_output=decoder_output,
                 text_decoder_padding_mask=decoder_padding_mask,
                 text_seqs=text_seqs,
+                film_cond_emb=prosody_encoder_out,
             )
             # (B, S_unit, V_unit)
             unit_seqs = unit_decoder_output.logits.argmax(dim=2)
@@ -243,8 +251,8 @@ class UnitYGenerator:
         units = self.unit_decoder(unit_seqs)
 
         if ngram_filtering:
-            units = remove_consecutive_repeated_ngrams(units.cpu().numpy().tolist())
-            units = torch.tensor(units)
+            arr = remove_consecutive_repeated_ngrams(units.cpu().numpy().tolist())
+            units = torch.tensor(arr)
 
         unit_output = SequenceToUnitOutput(units, unit_gen_output)
 

+ 11 - 4
src/seamless_communication/inference/translator.py

@@ -7,7 +7,7 @@ import logging
 from dataclasses import dataclass
 from enum import Enum, auto
 from pathlib import Path
-from typing import Any, Dict, Callable, List, Optional, Tuple, Union, cast
+from typing import Callable, List, Optional, Tuple, Union, cast
 
 import torch
 import torch.nn as nn
@@ -144,6 +144,7 @@ class Translator(nn.Module):
         text_generation_opts: SequenceGeneratorOptions,
         unit_generation_opts: Optional[SequenceGeneratorOptions],
         unit_generation_ngram_filtering: bool = False,
+        gcmvn_fbank: Optional[SequenceData] = None,
     ) -> Tuple[SequenceToTextOutput, Optional[SequenceToUnitOutput]]:
         # We disregard unit generations opts for the NAR T2U decoder.
         if output_modality != Modality.SPEECH or isinstance(
@@ -160,12 +161,18 @@ class Translator(nn.Module):
             unit_opts=unit_generation_opts,
         )
         seqs, padding_mask = get_seqs_and_padding_mask(src)
+        if gcmvn_fbank is not None:
+            gcmvn_seqs = gcmvn_fbank["seqs"]
+        else:
+            gcmvn_seqs = None
+
         return generator(
             seqs,
             padding_mask,
             input_modality.value,
             output_modality.value,
             ngram_filtering=unit_generation_ngram_filtering,
+            gcmvn_seqs=gcmvn_seqs,
         )
 
     @staticmethod
@@ -188,7 +195,7 @@ class Translator(nn.Module):
     @torch.inference_mode()
     def predict(
         self,
-        input: Union[str, Tensor, Dict[str, Any]],
+        input: Union[str, Tensor, SequenceData],
         task_str: str,
         tgt_lang: str,
         src_lang: Optional[str] = None,
@@ -201,6 +208,7 @@ class Translator(nn.Module):
         spkr: Optional[int] = -1,
         sample_rate: int = 16000,
         unit_generation_ngram_filtering: bool = False,
+        gcmvn_fbank: Optional[SequenceData] = None,
    ) -> Tuple[List[StringLike], Optional[BatchedSpeechOutput]]:
         """
         The main method used to perform inference on all tasks.
@@ -231,8 +239,6 @@ class Translator(nn.Module):
         input_modality, output_modality = self.get_modalities_from_task_str(task_str)
 
         if isinstance(input, dict):
-            assert "seqs" in input
-            assert "seq_lens" in input
             src = cast(SequenceData, input)
         elif input_modality == Modality.SPEECH:
             audio = input
@@ -282,6 +288,7 @@ class Translator(nn.Module):
             text_generation_opts,
             unit_generation_opts,
             unit_generation_ngram_filtering=unit_generation_ngram_filtering,
+            gcmvn_fbank=gcmvn_fbank,
         )
 
         if output_modality == Modality.TEXT:
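
A hedged sketch of the new argument in use; `translator`, `src`, `gcmvn`, and `text_generation_opts` are assumed to be prepared as in the expressivity evaluation script above.

```python
# `src` and `gcmvn` are collated SequenceData dicts ("seqs"/"seq_lens") holding
# per-utterance-normalized and gcmvn-normalized fbanks, respectively.
text_output, speech_output = translator.predict(
    src,
    "S2ST",
    "eng",
    text_generation_opts=text_generation_opts,
    unit_generation_opts=None,  # disregarded for the NAR T2U decoder
    gcmvn_fbank=gcmvn,          # conditions the prosody encoder
)
```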

+ 16 - 0
src/seamless_communication/models/pretssel/__init__.py

@@ -0,0 +1,16 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from seamless_communication.models.pretssel.ecapa_tdnn import ECAPA_TDNN as ECAPA_TDNN
+from seamless_communication.models.pretssel.ecapa_tdnn_builder import (
+    EcapaTDNNBuilder as EcapaTDNNBuilder,
+)
+from seamless_communication.models.pretssel.ecapa_tdnn_builder import (
+    EcapaTDNNConfig as EcapaTDNNConfig,
+)
+from seamless_communication.models.pretssel.ecapa_tdnn_builder import (
+    ecapa_tdnn_archs as ecapa_tdnn_archs,
+)

+ 477 - 0
src/seamless_communication/models/pretssel/ecapa_tdnn.py

@@ -0,0 +1,477 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import List, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from fairseq2.nn.padding import PaddingMask, to_padding_mask
+from torch import Tensor
+from torch.nn import Conv1d, LayerNorm, Module, ModuleList, ReLU, Sigmoid, Tanh, init
+
+
+class ECAPA_TDNN(Module):
+    """
+    Represents the ECAPA-TDNN model described in paper:
+    :cite:t:`https://doi.org/10.48550/arxiv.2005.07143`.
+
+    Arguments
+    ---------
+    :param channels:
+        Output channels for TDNN/SERes2Net layer.
+    :param kernel_sizes:
+        List of kernel sizes for each layer.
+    :param dilations:
+        List of dilations for kernels in each layer.
+    :param groups:
+        List of groups for kernels in each layer.
+    """
+
+    def __init__(
+        self,
+        channels: List[int],
+        kernel_sizes: List[int],
+        dilations: List[int],
+        attention_channels: int,
+        res2net_scale: int,
+        se_channels: int,
+        global_context: bool,
+        groups: List[int],
+        embed_dim: int,
+        input_dim: int,
+    ):
+        super().__init__()
+        assert len(channels) == len(kernel_sizes) == len(dilations)
+        self.channels = channels
+        self.embed_dim = embed_dim
+        self.blocks = ModuleList()
+
+        self.blocks.append(
+            TDNNBlock(
+                input_dim,
+                channels[0],
+                kernel_sizes[0],
+                dilations[0],
+                groups[0],
+            )
+        )
+
+        # SE-Res2Net layers
+        for i in range(1, len(channels) - 1):
+            self.blocks.append(
+                SERes2NetBlock(
+                    channels[i - 1],
+                    channels[i],
+                    res2net_scale=res2net_scale,
+                    se_channels=se_channels,
+                    kernel_size=kernel_sizes[i],
+                    dilation=dilations[i],
+                    groups=groups[i],
+                )
+            )
+
+        # Multi-layer feature aggregation
+        self.mfa = TDNNBlock(
+            channels[-1],
+            channels[-1],
+            kernel_sizes[-1],
+            dilations[-1],
+            groups=groups[-1],
+        )
+
+        # Attentive Statistical Pooling
+        self.asp = AttentiveStatisticsPooling(
+            channels[-1],
+            attention_channels=attention_channels,
+            global_context=global_context,
+        )
+        self.asp_norm = LayerNorm(channels[-1] * 2, eps=1e-12)
+
+        # Final linear transformation
+        self.fc = Conv1d(
+            in_channels=channels[-1] * 2,
+            out_channels=embed_dim,
+            kernel_size=1,
+        )
+
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        """Reset the parameters and buffers of the module."""
+
+        def encoder_init(m: Module) -> None:
+            if isinstance(m, Conv1d):
+                init.xavier_uniform_(m.weight, init.calculate_gain("relu"))
+
+        self.apply(encoder_init)
+
+    def forward(
+        self,
+        x: Tensor,
+        padding_mask: Optional[PaddingMask] = None,
+    ) -> Tensor:
+        """Returns the embedding vector.
+
+        Arguments
+        ---------
+        x : torch.Tensor
+            Tensor of shape (batch, time, channel).
+        """
+        # Minimize transpose for efficiency
+        x = x.transpose(1, 2)
+
+        xl = []
+        for layer in self.blocks:
+            x = layer(x, padding_mask=padding_mask)
+            xl.append(x)
+
+        # Multi-layer feature aggregation
+        x = torch.cat(xl[1:], dim=1)
+        x = self.mfa(x)
+
+        # Attentive Statistical Pooling
+        x = self.asp(x, padding_mask=padding_mask)
+        x = self.asp_norm(x.transpose(1, 2)).transpose(1, 2)
+
+        # Final linear transformation
+        x = self.fc(x)
+
+        x = x.transpose(1, 2).squeeze(1)  # B x C
+        return F.normalize(x, dim=-1)
+
+
+class TDNNBlock(Module):
+    """An implementation of TDNN.
+
+    Arguments
+    ----------
+    :param in_channels : int
+        Number of input channels.
+    :param out_channels : int
+        The number of output channels.
+    :param kernel_size : int
+        The kernel size of the TDNN blocks.
+    :param dilation : int
+        The dilation of the TDNN block.
+    :param groups: int
+        The groups size of the TDNN blocks.
+
+    Example
+    -------
+    >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
+    >>> layer = TDNNBlock(64, 64, kernel_size=3, dilation=1)
+    >>> out_tensor = layer(inp_tensor).transpose(1, 2)
+    >>> out_tensor.shape
+    torch.Size([8, 120, 64])
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        dilation: int,
+        groups: int = 1,
+    ):
+        super().__init__()
+        self.conv = Conv1d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            dilation=dilation,
+            padding=dilation * (kernel_size - 1) // 2,
+            groups=groups,
+        )
+        self.activation = ReLU()
+        self.norm = LayerNorm(out_channels, eps=1e-12)
+
+    def forward(self, x: Tensor, padding_mask: Optional[PaddingMask] = None) -> Tensor:
+        """Processes the input tensor x and returns an output tensor."""
+        x = self.activation(self.conv(x))
+
+        return self.norm(x.transpose(1, 2)).transpose(1, 2)  # type: ignore[no-any-return]
+
+
+class Res2NetBlock(Module):
+    """An implementation of Res2NetBlock w/ dilation.
+
+    Arguments
+    ---------
+    :param in_channels : int
+        The number of channels expected in the input.
+    :param out_channels : int
+        The number of output channels.
+    :param scale : int
+        The scale of the Res2Net block.
+    :param kernel_size: int
+        The kernel size of the Res2Net block.
+    :param dilation : int
+        The dilation of the Res2Net block.
+
+    Example
+    -------
+    >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
+    >>> layer = Res2NetBlock(64, 64, scale=4, dilation=3)
+    >>> out_tensor = layer(inp_tensor).transpose(1, 2)
+    >>> out_tensor.shape
+    torch.Size([8, 120, 64])
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        scale: int = 8,
+        kernel_size: int = 3,
+        dilation: int = 1,
+    ):
+        super().__init__()
+        assert in_channels % scale == 0
+        assert out_channels % scale == 0
+
+        in_channel = in_channels // scale
+        hidden_channel = out_channels // scale
+        self.blocks = ModuleList(
+            [
+                TDNNBlock(
+                    in_channel,
+                    hidden_channel,
+                    kernel_size=kernel_size,
+                    dilation=dilation,
+                )
+                for i in range(scale - 1)
+            ]
+        )
+        self.scale = scale
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Processes the input tensor x and returns an output tensor."""
+        y = []
+        for i, x_i in enumerate(torch.chunk(x, self.scale, dim=1)):
+            if i == 0:
+                y_i = x_i
+            elif i == 1:
+                y_i = self.blocks[i - 1](x_i)
+            else:
+                y_i = self.blocks[i - 1](x_i + y_i)
+            y.append(y_i)
+
+        y_tensor = torch.cat(y, dim=1)
+        return y_tensor
+
+
+class SEBlock(Module):
+    """An implementation of squeeze-and-excitation block.
+
+    Arguments
+    ---------
+    in_channels : int
+        The number of input channels.
+    se_channels : int
+        The number of output channels after squeeze.
+    out_channels : int
+        The number of output channels.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        se_channels: int,
+        out_channels: int,
+    ):
+        super().__init__()
+
+        self.conv1 = Conv1d(
+            in_channels=in_channels, out_channels=se_channels, kernel_size=1
+        )
+        self.relu = ReLU(inplace=True)
+        self.conv2 = Conv1d(
+            in_channels=se_channels, out_channels=out_channels, kernel_size=1
+        )
+        self.sigmoid = Sigmoid()
+
+    def forward(self, x: Tensor, padding_mask: Optional[PaddingMask] = None) -> Tensor:
+        """Processes the input tensor x and returns an output tensor."""
+        if padding_mask is not None:
+            mask = padding_mask.materialize().unsqueeze(1)
+            s = (x * mask).sum(dim=2, keepdim=True) / padding_mask.seq_lens[
+                :, None, None
+            ]
+        else:
+            s = x.mean(dim=2, keepdim=True)
+
+        s = self.relu(self.conv1(s))
+        s = self.sigmoid(self.conv2(s))
+
+        return s * x
+
+
+class AttentiveStatisticsPooling(Module):
+    """This class implements an attentive statistic pooling layer for each channel.
+    It returns the concatenated mean and std of the input tensor.
+
+    Arguments
+    ---------
+    channels: int
+        The number of input channels.
+    attention_channels: int
+        The number of attention channels.
+    """
+
+    def __init__(
+        self, channels: int, attention_channels: int = 128, global_context: bool = True
+    ):
+        super().__init__()
+
+        self.eps = 1e-12
+        self.global_context = global_context
+        if global_context:
+            self.tdnn = TDNNBlock(channels * 3, attention_channels, 1, 1)
+        else:
+            self.tdnn = TDNNBlock(channels, attention_channels, 1, 1)
+
+        self.tanh = Tanh()
+        self.conv = Conv1d(
+            in_channels=attention_channels, out_channels=channels, kernel_size=1
+        )
+
+    def forward(self, x: Tensor, padding_mask: Optional[PaddingMask] = None) -> Tensor:
+        """Calculates mean and std for a batch (input tensor).
+
+        Arguments
+        ---------
+        x : torch.Tensor
+            Tensor of shape [N, C, L].
+        """
+        L = x.shape[-1]
+
+        def _compute_statistics(
+            x: Tensor, m: Tensor, dim: int = 2, eps: float = self.eps
+        ) -> Tuple[Tensor, Tensor]:
+            mean = (m * x).sum(dim)
+            std = torch.sqrt((m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps))
+            return mean, std
+
+        # if lengths is None:
+        #     lengths = [x.shape[0]]
+
+        # Make binary mask of shape [N, 1, L]
+        # mask = to_padding_mask(lengths, max(lengths))
+        if padding_mask is not None:
+            mask = padding_mask.materialize()
+        else:
+            mask = to_padding_mask(torch.IntTensor([L]), L).repeat(x.shape[0], 1).to(x)
+        mask = mask.unsqueeze(1)
+
+        # Expand the temporal context of the pooling layer by allowing the
+        # self-attention to look at global properties of the utterance.
+        if self.global_context:
+            # torch.std is unstable for backward computation
+            # https://github.com/pytorch/pytorch/issues/4320
+            total = mask.sum(dim=2, keepdim=True).to(x)
+            mean, std = _compute_statistics(x, mask / total)
+            mean = mean.unsqueeze(2).repeat(1, 1, L)
+            std = std.unsqueeze(2).repeat(1, 1, L)
+            attn = torch.cat([x, mean, std], dim=1)
+        else:
+            attn = x
+
+        # Apply layers
+        attn = self.conv(self.tanh(self.tdnn(attn)))
+
+        # Filter out zero-paddings
+        attn = attn.masked_fill(mask == 0, float("-inf"))
+
+        attn = F.softmax(attn, dim=2)
+        mean, std = _compute_statistics(x, attn)
+        # Append mean and std of the batch
+        pooled_stats = torch.cat((mean, std), dim=1)
+        pooled_stats = pooled_stats.unsqueeze(2)
+
+        return pooled_stats
+
+
+class SERes2NetBlock(Module):
+    """An implementation of building block in ECAPA-TDNN, i.e.,
+    TDNN-Res2Net-TDNN-SEBlock.
+
+    Arguments
+    ----------
+    out_channels: int
+        The number of output channels.
+    res2net_scale: int
+        The scale of the Res2Net block.
+    kernel_size: int
+        The kernel size of the TDNN blocks.
+    dilation: int
+        The dilation of the Res2Net block.
+    groups: int
+        Number of blocked connections from input channels to output channels.
+
+    Example
+    -------
+    >>> x = torch.rand(8, 120, 64).transpose(1, 2)
+    >>> conv = SERes2NetBlock(64, 64, res2net_scale=4)
+    >>> out = conv(x).transpose(1, 2)
+    >>> out.shape
+    torch.Size([8, 120, 64])
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        res2net_scale: int = 8,
+        se_channels: int = 128,
+        kernel_size: int = 1,
+        dilation: int = 1,
+        groups: int = 1,
+    ):
+        super().__init__()
+        self.out_channels = out_channels
+        self.tdnn1 = TDNNBlock(
+            in_channels,
+            out_channels,
+            kernel_size=1,
+            dilation=1,
+            groups=groups,
+        )
+        self.res2net_block = Res2NetBlock(
+            out_channels,
+            out_channels,
+            res2net_scale,
+            kernel_size,
+            dilation,
+        )
+        self.tdnn2 = TDNNBlock(
+            out_channels,
+            out_channels,
+            kernel_size=1,
+            dilation=1,
+            groups=groups,
+        )
+        self.se_block = SEBlock(out_channels, se_channels, out_channels)
+
+        self.shortcut = None
+        if in_channels != out_channels:
+            self.shortcut = Conv1d(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=1,
+            )
+
+    def forward(self, x: Tensor, padding_mask: Optional[PaddingMask] = None) -> Tensor:
+        """Processes the input tensor x and returns an output tensor."""
+        residual = x
+        if self.shortcut:
+            residual = self.shortcut(x)
+
+        x = self.tdnn1(x)
+        x = self.res2net_block(x)
+        x = self.tdnn2(x)
+        x = self.se_block(x, padding_mask=padding_mask)
+
+        return x + residual
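
A small usage sketch of the encoder as a prosody embedder, instantiated with the values that the `base` arch in `ecapa_tdnn_builder.py` registers below:

```python
import torch

from seamless_communication.models.pretssel import ECAPA_TDNN

encoder = ECAPA_TDNN(
    channels=[512, 512, 512, 512, 1536],
    kernel_sizes=[5, 3, 3, 3, 1],
    dilations=[1, 2, 3, 4, 1],
    attention_channels=128,
    res2net_scale=8,
    se_channels=128,
    global_context=True,
    groups=[1, 1, 1, 1, 1],
    embed_dim=512,
    input_dim=80,
)

fbank = torch.rand(4, 200, 80)  # (batch, time, mel bins), e.g. gcmvn fbanks
embedding = encoder(fbank)      # (4, 512), L2-normalized utterance embeddings
```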

+ 112 - 0
src/seamless_communication/models/pretssel/ecapa_tdnn_builder.py

@@ -0,0 +1,112 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+from typing import List, Optional
+
+from fairseq2.models.utils.arch_registry import ArchitectureRegistry
+from fairseq2.typing import DataType, Device
+
+from seamless_communication.models.pretssel.ecapa_tdnn import ECAPA_TDNN
+
+
+@dataclass
+class EcapaTDNNConfig:
+    channels: List[int]
+    kernel_sizes: List[int]
+    dilations: List[int]
+    attention_channels: int
+    res2net_scale: int
+    se_channels: int
+    global_context: bool
+    groups: List[int]
+    embed_dim: int
+    input_dim: int
+
+
+ecapa_tdnn_archs = ArchitectureRegistry[EcapaTDNNConfig]("ecapa_tdnn")
+
+ecapa_tdnn_arch = ecapa_tdnn_archs.marker
+
+
+@ecapa_tdnn_arch("base")
+def _base_ecapa_tdnn() -> EcapaTDNNConfig:
+    return EcapaTDNNConfig(
+        channels=[512, 512, 512, 512, 1536],
+        kernel_sizes=[5, 3, 3, 3, 1],
+        dilations=[1, 2, 3, 4, 1],
+        attention_channels=128,
+        res2net_scale=8,
+        se_channels=128,
+        global_context=True,
+        groups=[1, 1, 1, 1, 1],
+        embed_dim=512,
+        input_dim=80,
+    )
+
+
+class EcapaTDNNBuilder:
+    """
+    Builder module for ECAPA_TDNN model
+    """
+
+    config: EcapaTDNNConfig
+    device: Optional[Device]
+    dtype: Optional[DataType]
+
+    def __init__(
+        self,
+        config: EcapaTDNNConfig,
+        *,
+        device: Optional[Device] = None,
+        dtype: Optional[DataType] = None,
+    ) -> None:
+        """
+        :param config:
+            The configuration to use.
+        :param device:
+            The device on which to initialize modules.
+        :param dtype:
+            The data type of module parameters and buffers.
+        """
+        self.config = config
+
+        self.device, self.dtype = device, dtype
+
+    def build_model(self) -> ECAPA_TDNN:
+        """Build a model."""
+        model = ECAPA_TDNN(
+            self.config.channels,
+            self.config.kernel_sizes,
+            self.config.dilations,
+            self.config.attention_channels,
+            self.config.res2net_scale,
+            self.config.se_channels,
+            self.config.global_context,
+            self.config.groups,
+            self.config.embed_dim,
+            self.config.input_dim,
+        )
+        model.to(device=self.device, dtype=self.dtype)
+        return model
+
+
+def create_ecapa_tdnn_model(
+    config: EcapaTDNNConfig,
+    device: Optional[Device] = None,
+    dtype: Optional[DataType] = None,
+) -> ECAPA_TDNN:
+    """Create an ECAPA_TDNN model.
+
+    :param config:
+        The configuration to use.
+    :param device:
+        The device on which to initialize modules.
+    :param dtype:
+        The data type of module parameters and buffers.
+    """
+
+    return EcapaTDNNBuilder(config, device=device, dtype=dtype).build_model()
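
The registry makes the same construction a two-liner; a brief sketch:

```python
import torch

from seamless_communication.models.pretssel.ecapa_tdnn_builder import (
    create_ecapa_tdnn_model,
    ecapa_tdnn_archs,
)

config = ecapa_tdnn_archs.get_config("base")
prosody_encoder = create_ecapa_tdnn_model(config, dtype=torch.float32)
```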

+ 1 - 0
src/seamless_communication/models/unity/__init__.py

@@ -20,6 +20,7 @@ from seamless_communication.models.unity.char_tokenizer import (
 from seamless_communication.models.unity.char_tokenizer import (
     load_unity_char_tokenizer as load_unity_char_tokenizer,
 )
+from seamless_communication.models.unity.film import FiLM
 from seamless_communication.models.unity.length_regulator import (
     HardUpsampling as HardUpsampling,
 )

+ 96 - 4
src/seamless_communication/models/unity/builder.py

@@ -14,15 +14,23 @@ from fairseq2.models.w2vbert import w2vbert_archs
 from fairseq2.models.wav2vec2 import Wav2Vec2EncoderBuilder, Wav2Vec2EncoderConfig
 from fairseq2.nn.projection import TiedProjection
 from fairseq2.nn.transformer import (
+    FeedForwardNetwork,
     MultiheadAttention,
     StandardFeedForwardNetwork,
     StandardMultiheadAttention,
     TransformerEncoder,
     TransformerEncoderLayer,
+    TransformerNormOrder,
     create_default_sdpa,
 )
-from fairseq2.typing import DataType, Device
+from fairseq2.typing import DataType, Device, override
+from torch.nn import GELU, ReLU
 
+from seamless_communication.models.pretssel import (
+    EcapaTDNNBuilder,
+    EcapaTDNNConfig,
+    ecapa_tdnn_archs,
+)
 from seamless_communication.models.unity.adaptor_block import (
     UnitYConformerAdaptorLayer,
     UnitYEncoderAdaptor,
@@ -59,12 +67,19 @@ class UnitYConfig:
     t2u_config: Optional[UnitYT2UConfig]
     """The configuration of the UnitY T2U sub-model."""
 
+    prosody_encoder_config: Optional[EcapaTDNNConfig]
+    """The configuration of the expressive prosody encoder."""
+
     use_text_encoder: bool
     """If ``True``, uses an aligned MT encoder for the MT task."""
 
     use_conformer_adaptor: bool
     """If ``True``, uses a Conformer-based adaptor block."""
 
+    use_gelu: bool
+    """If ``True``, uses GELU activation function in feed-forward networks of
+    adaptor blocks and decoder layers."""
+
     num_adaptor_layers: int
     """The number of Transformer encoder layers in the adaptor block."""
 
@@ -103,8 +118,10 @@ def _base() -> UnitYConfig:
         w2v2_encoder_config=w2vbert_config.w2v2_config.encoder_config,
         mt_model_config=mt_model_config,
         t2u_config=t2u_config,
+        prosody_encoder_config=None,
         use_text_encoder=True,
         use_conformer_adaptor=False,
+        use_gelu=False,
         num_adaptor_layers=1,
         adaptor_kernel_size=8,
         adaptor_stride=8,
@@ -128,8 +145,10 @@ def _medium() -> UnitYConfig:
         w2v2_encoder_config=w2vbert_config.w2v2_config.encoder_config,
         mt_model_config=mt_model_config,
         t2u_config=t2u_config,
+        prosody_encoder_config=None,
         use_text_encoder=True,
         use_conformer_adaptor=False,
+        use_gelu=False,
         num_adaptor_layers=1,
         adaptor_kernel_size=8,
         adaptor_stride=8,
@@ -155,8 +174,43 @@ def _base_v2() -> UnitYConfig:
         w2v2_encoder_config=w2v2_chunk_encoder_config,
         mt_model_config=mt_model_config,
         t2u_config=t2u_config,
+        prosody_encoder_config=None,
         use_text_encoder=True,
         use_conformer_adaptor=False,
+        use_gelu=False,
+        num_adaptor_layers=1,
+        adaptor_kernel_size=8,
+        adaptor_stride=8,
+        adaptor_layer_norm=True,
+        adaptor_dropout_p=0.1,
+    )
+
+
+@unity_arch("expressivity_v2")
+def _expressivity_v2() -> UnitYConfig:
+    w2v2_chunk_encoder_config = wav2vec2_chunk_archs.get_config("600m")
+
+    mt_model_config: NllbConfig = nllb_archs.get_config("dense_1b")
+
+    mt_model_config.vocab_info.size = 256102  # NLLB-100
+
+    mt_model_config.vocab_info.pad_idx = 1
+
+    mt_model_config.max_seq_len = 4000
+
+    t2u_config = unity_t2u_archs.get_config("expressivity_nar")
+
+    prosody_encoder_config = ecapa_tdnn_archs.get_config("base")
+
+    return UnitYConfig(
+        model_dim=1024,
+        w2v2_encoder_config=w2v2_chunk_encoder_config,
+        mt_model_config=mt_model_config,
+        t2u_config=t2u_config,
+        prosody_encoder_config=prosody_encoder_config,
+        use_text_encoder=False,
+        use_conformer_adaptor=False,
+        use_gelu=True,
         num_adaptor_layers=1,
         adaptor_kernel_size=8,
         adaptor_stride=8,
@@ -176,6 +230,7 @@ class UnitYBuilder:
     w2v2_encoder_builder: Wav2Vec2EncoderBuilder
     mt_model_builder: NllbBuilder
     t2u_builder: Union[UnitYT2UBuilder, UnitYNART2UBuilder, None]
+    prosody_encoder_builder: Optional[EcapaTDNNBuilder]
     device: Optional[Device]
     dtype: Optional[DataType]
 
@@ -185,6 +240,7 @@ class UnitYBuilder:
         w2v2_encoder_builder: Wav2Vec2EncoderBuilder,
         mt_model_builder: NllbBuilder,
         t2u_builder: Union[UnitYT2UBuilder, UnitYNART2UBuilder, None],
+        prosody_encoder_builder: Optional[EcapaTDNNBuilder],
         *,
         device: Optional[Device] = None,
         dtype: Optional[DataType] = None,
@@ -223,6 +279,7 @@ class UnitYBuilder:
         self.w2v2_encoder_builder = w2v2_encoder_builder
         self.mt_model_builder = mt_model_builder
         self.t2u_builder = t2u_builder
+        self.prosody_encoder_builder = prosody_encoder_builder
 
         self.device, self.dtype = device, dtype
 
@@ -251,6 +308,11 @@ class UnitYBuilder:
         else:
             t2u_model = self.t2u_builder.build_model()
 
+        if self.prosody_encoder_builder is None:
+            prosody_encoder_model = None
+        else:
+            prosody_encoder_model = self.prosody_encoder_builder.build_model()
+
         return UnitYModel(
             speech_encoder_frontend,
             speech_encoder,
@@ -261,6 +323,7 @@
             final_proj,
             t2u_model,
             self.config.mt_model_config.vocab_info,
+            prosody_encoder_model,
         )
 
     def build_speech_encoder(self) -> TransformerEncoder:
@@ -292,11 +355,10 @@
             self.w2v2_encoder_builder.config.num_encoder_attn_heads
         )
 
-        # Unlike wav2vec2, we use ReLU (i.e. standard FFN activation function)
-        # instead of GELU.
         ffn = StandardFeedForwardNetwork(
             self.config.model_dim,
             self.w2v2_encoder_builder.config.ffn_inner_dim,
+            inner_activation=GELU() if self.config.use_gelu else ReLU(),
             bias=True,
             device=self.device,
             dtype=self.dtype,
@@ -365,6 +427,20 @@
         )
 
 
+class NllbWithGELUBuilder(NllbBuilder):
+    @override
+    def build_ffn(self) -> FeedForwardNetwork:
+        return StandardFeedForwardNetwork(
+            self.config.model_dim,
+            self.config.ffn_inner_dim,
+            bias=True,
+            inner_activation=GELU(),
+            norm_order=TransformerNormOrder.PRE,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+
 def create_unity_model(
     config: UnitYConfig,
     device: Optional[Device] = None,
@@ -397,12 +473,28 @@
     else:
         t2u_builder = UnitYNART2UBuilder(config.t2u_config, device=device, dtype=dtype)
 
-    mt_model_builder = NllbBuilder(config.mt_model_config, device=device, dtype=dtype)
+    if config.prosody_encoder_config is None:
+        prosody_encoder_builder = None
+    else:
+        prosody_encoder_builder = EcapaTDNNBuilder(
+            config.prosody_encoder_config, device=device, dtype=dtype
+        )
+
+    if config.use_gelu:
+        mt_model_builder: NllbBuilder = NllbWithGELUBuilder(
+            config.mt_model_config, device=device, dtype=dtype
+        )
+    else:
+        mt_model_builder = NllbBuilder(
+            config.mt_model_config, device=device, dtype=dtype
+        )
+
     unity_builder = UnitYBuilder(
         config,
         w2v2_encoder_builder,
        mt_model_builder,
        t2u_builder,
+        prosody_encoder_builder,
        device=device,
        dtype=dtype,
    )
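
A hedged sketch of building the new arch end to end; `unity_archs` is assumed to be the `ArchitectureRegistry` behind the `@unity_arch(...)` marker used above, and a full build requires the memory of a >1B-parameter model:

```python
import torch

from seamless_communication.models.unity.builder import create_unity_model, unity_archs

config = unity_archs.get_config("expressivity_v2")
model = create_unity_model(config, device=torch.device("cpu"), dtype=torch.float32)

# The expressive config attaches an ECAPA-TDNN prosody encoder to the model.
assert model.prosody_encoder_model is not None
```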

+ 68 - 0
src/seamless_communication/models/unity/film.py

@@ -0,0 +1,68 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import Optional
+
+import torch
+from fairseq2.nn.projection import Linear
+from fairseq2.typing import DataType, Device
+from torch import Tensor
+from torch.nn import Module, Parameter
+
+
+class FiLM(Module):
+    """
+    A Feature-wise Linear Modulation Layer from
+    'FiLM: Visual Reasoning with a General Conditioning Layer'
+    """
+
+    proj: Linear
+    s_gamma: Parameter
+    s_beta: Parameter
+
+    def __init__(
+        self,
+        cond_dim: int,
+        embed_dim: int,
+        device: Optional[Device] = None,
+        dtype: Optional[DataType] = None,
+    ):
+        super().__init__()
+
+        self.proj = Linear(
+            cond_dim, 2 * embed_dim, bias=True, device=device, dtype=dtype
+        )
+
+        self.s_gamma = Parameter(
+            torch.ones(
+                1,
+                device=device,
+                dtype=dtype,
+            ),
+            requires_grad=True,
+        )
+
+        self.s_beta = Parameter(
+            torch.ones(
+                1,
+                device=device,
+                dtype=dtype,
+            ),
+            requires_grad=True,
+        )
+
+    def forward(self, x: Tensor, cond_embs: Tensor) -> Tensor:
+        """
+        x -- input features, [B, T, H]
+        cond_embs -- conditioning embedding, [B, 1, C]
+        """
+        # predict per-channel gamma and beta from the conditioning embedding
+        gammas, betas = self.proj(cond_embs).chunk(2, dim=-1)  # B x 1 x H
+
+        # apply FiLM: gamma is parameterized as an offset from 1, so zero outputs give identity
+        gammas = self.s_gamma * gammas.expand_as(x)
+        betas = self.s_beta * betas.expand_as(x)
+
+        return (gammas + 1.0) * x + betas  # type: ignore[no-any-return]
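For reference, the conditioning applied by this layer reduces to a per-channel affine transform predicted from the conditioning embedding. A minimal, torch-only sketch of the same math (shapes as in the docstring; this is an illustration, not the module above):

# Project the conditioning embedding to per-channel gamma/beta, then apply
# (gamma + 1) * x + beta, exactly as in FiLM.forward above.
import torch

B, T, H, C = 2, 5, 8, 4
x = torch.randn(B, T, H)        # features to modulate, [B, T, H]
cond = torch.randn(B, 1, C)     # pooled conditioning embedding, [B, 1, C]

proj = torch.nn.Linear(C, 2 * H, bias=True)
s_gamma = torch.ones(1)         # learnable scales in the module above
s_beta = torch.ones(1)

gammas, betas = proj(cond).chunk(2, dim=-1)   # each [B, 1, H]
gammas = s_gamma * gammas.expand_as(x)
betas = s_beta * betas.expand_as(x)

y = (gammas + 1.0) * x + betas  # identity when the projection outputs zeros
assert y.shape == x.shape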

+ 26 - 2
src/seamless_communication/models/unity/length_regulator.py

@@ -14,6 +14,8 @@ from fairseq2.typing import DataType, Device
 from torch import Tensor
 from torch import Tensor
 from torch.nn import Conv1d, Dropout, Module, ReLU, Sequential
 from torch.nn import Conv1d, Dropout, Module, ReLU, Sequential
 
 
+from seamless_communication.models.unity.film import FiLM
+
 
 
 class HardUpsampling(Module):
 class HardUpsampling(Module):
     """Upsamples sequences in a deterministic way as governed by durations."""
     """Upsamples sequences in a deterministic way as governed by durations."""
@@ -46,6 +48,7 @@ class VariancePredictor(Module):
     conv2: Sequential
     conv2: Sequential
     ln2: LayerNorm
     ln2: LayerNorm
     proj: Linear
     proj: Linear
+    film: Optional[FiLM]
 
 
     def __init__(
     def __init__(
         self,
         self,
@@ -54,6 +57,8 @@ class VariancePredictor(Module):
         var_pred_kernel_size: int,
         var_pred_kernel_size: int,
         var_pred_dropout: float,
         var_pred_dropout: float,
         bias: bool = True,
         bias: bool = True,
+        use_film: bool = False,
+        film_cond_dim: int = 512,
         device: Optional[Device] = None,
         device: Optional[Device] = None,
         dtype: Optional[DataType] = None,
         dtype: Optional[DataType] = None,
     ):
     ):
@@ -99,7 +104,19 @@ class VariancePredictor(Module):
             var_pred_hidden_dim, 1, bias=True, device=device, dtype=dtype
             var_pred_hidden_dim, 1, bias=True, device=device, dtype=dtype
         )
         )
 
 
-    def forward(self, seqs: Tensor, padding_mask: Optional[PaddingMask]) -> Tensor:
+        if use_film:
+            self.film = FiLM(
+                film_cond_dim, var_pred_hidden_dim, device=device, dtype=dtype
+            )
+        else:
+            self.register_module("film", None)
+
+    def forward(
+        self,
+        seqs: Tensor,
+        padding_mask: Optional[PaddingMask],
+        film_cond_emb: Optional[Tensor] = None,
+    ) -> Tensor:
         # Ensure that we do not leak padded positions in the convolution layer.
         # Ensure that we do not leak padded positions in the convolution layer.
         seqs = apply_padding_mask(seqs, padding_mask)
         seqs = apply_padding_mask(seqs, padding_mask)
 
 
@@ -131,6 +148,12 @@ class VariancePredictor(Module):
 
 
         seqs = self.dropout_module(seqs)
         seqs = self.dropout_module(seqs)
 
 
+        seqs = apply_padding_mask(seqs, padding_mask)
+
+        if self.film is not None and film_cond_emb is not None:
+            seqs = self.film(seqs, film_cond_emb)
+            seqs = apply_padding_mask(seqs, padding_mask)
+
         # (N, S, H) -> (N, S, 1) -> (N, S)
         # (N, S, H) -> (N, S, 1) -> (N, S)
         seqs = self.proj(seqs).squeeze(dim=2)
         seqs = self.proj(seqs).squeeze(dim=2)
 
 
@@ -174,8 +197,9 @@ class VarianceAdaptor(Module):
         padding_mask: Optional[PaddingMask],
         padding_mask: Optional[PaddingMask],
         duration_factor: float = 1.0,
         duration_factor: float = 1.0,
         min_duration: int = 0,
         min_duration: int = 0,
+        film_cond_emb: Optional[Tensor] = None,
     ) -> Tuple[Tensor, PaddingMask]:
     ) -> Tuple[Tensor, PaddingMask]:
-        log_durations = self.duration_predictor(seqs, padding_mask)
+        log_durations = self.duration_predictor(seqs, padding_mask, film_cond_emb)
 
 
         durations = torch.clamp(
         durations = torch.clamp(
             torch.round((torch.exp(log_durations) - 1) * duration_factor).long(),
             torch.round((torch.exp(log_durations) - 1) * duration_factor).long(),
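The duration rounding in this hunk converts the (optionally FiLM-conditioned) log-duration predictions into integer repeat counts for the upsampler. A small worked example of that formula, with values chosen purely for illustration:

# Worked example of the clamp/round step above: the result is how many times
# each source position is repeated by HardUpsampling.
import torch

log_durations = torch.tensor([[0.0, 1.2, 2.3]])  # (N, S) predictor output
duration_factor, min_duration = 1.0, 0

durations = torch.clamp(
    torch.round((torch.exp(log_durations) - 1) * duration_factor).long(),
    min=min_duration,
)
print(durations)  # tensor([[0, 2, 9]])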

+ 79 - 35
src/seamless_communication/models/unity/loader.py

@@ -47,10 +47,16 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
 
 
         keys_to_delete = []
         keys_to_delete = []
 
 
+        # ExpressiveUnitY model (from multi_arch codebase)
+        if config.prosody_encoder_config is not None:
+            encoder_key = "s2t_model.encoder"
+            decoder_key = "s2t_model.decoder"
+            t2u_decoder_key = "t2s_model.decoder"
         # X2T/S2T + T2U model.
         # X2T/S2T + T2U model.
-        if config.t2u_config is not None:
+        elif config.t2u_config is not None:
             encoder_key = "encoder"
             encoder_key = "encoder"
             decoder_key = "target_letter_decoder"
             decoder_key = "target_letter_decoder"
+            t2u_decoder_key = "decoder"
         # X2T model.
         # X2T model.
         elif config.use_text_encoder:
         elif config.use_text_encoder:
             encoder_key = "speech_encoder"
             encoder_key = "speech_encoder"
@@ -70,12 +76,18 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
         # Remnant of wav2vec2 pretraining, not needed for eval or fine-tuning.
         # Remnant of wav2vec2 pretraining, not needed for eval or fine-tuning.
         keys_to_delete.append(f"{encoder_key}.w2v_encoder.w2v_model.mask_emb")
         keys_to_delete.append(f"{encoder_key}.w2v_encoder.w2v_model.mask_emb")
 
 
-        keys_to_delete.append("decoder.char_upsampler.embed_positions._float_tensor")
-        keys_to_delete.append("decoder.char_upsampler.embed_tokens_char.weight")
+        keys_to_delete.append(
+            f"{t2u_decoder_key}.char_upsampler.embed_positions._float_tensor"
+        )
+        keys_to_delete.append(
+            f"{t2u_decoder_key}.char_upsampler.embed_tokens_char.weight"
+        )
 
 
         # Delete AlignmentEncoder keys for inference.
         # Delete AlignmentEncoder keys for inference.
         alignment_encoder_keys = [
         alignment_encoder_keys = [
-            key for key in state_dict if key.startswith("decoder.alignment_encoder.")
+            key
+            for key in state_dict
+            if key.startswith(f"{t2u_decoder_key}.alignment_encoder.")
         ]
         ]
         keys_to_delete.extend(alignment_encoder_keys)
         keys_to_delete.extend(alignment_encoder_keys)
 
 
@@ -87,6 +99,17 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
             ]
             ]
         )
         )
 
 
+        if config.prosody_encoder_config is not None:
+            keys_to_delete.extend(
+                [
+                    f"{t2u_decoder_key}.embed_positions._float_tensor",
+                    "t2s_model.global_proj_dec.weight",
+                    "t2s_model.global_proj_dec.bias",
+                    "t2s_model.decoder_target_letter_nllb_spm_decoder.encoder.proj.weight",
+                    "t2s_model.decoder_target_letter_nllb_spm_decoder.encoder.proj.bias",
+                ]
+            )
+
         for key in keys_to_delete:
         for key in keys_to_delete:
             if key in state_dict:
             if key in state_dict:
                 del state_dict[key]
                 del state_dict[key]
@@ -157,10 +180,19 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
 
 
     @staticmethod
     @staticmethod
     def _fairseq_key_map(config: UnitYConfig) -> Dict[str, str]:
     def _fairseq_key_map(config: UnitYConfig) -> Dict[str, str]:
+        # ExpressiveUnitY model (from multi_arch codebase)
+        if config.prosody_encoder_config is not None:
+            encoder_key = "s2t_model.encoder"
+            decoder_key = "s2t_model.decoder"
+            t2u_encoder_key = "t2s_model.encoder"
+            t2u_decoder_key = "t2s_model.decoder"
+            ecapa_tdnn_key = "global_prosody"
         # X2T/S2T + T2U model.
         # X2T/S2T + T2U model.
-        if config.t2u_config is not None:
+        elif config.t2u_config is not None:
             encoder_key = "encoder"
             encoder_key = "encoder"
             decoder_key = "target_letter_decoder"
             decoder_key = "target_letter_decoder"
+            t2u_encoder_key = "synthesizer_encoder"
+            t2u_decoder_key = "decoder"
         # X2T model.
         # X2T model.
         elif config.use_text_encoder:
         elif config.use_text_encoder:
             encoder_key = "speech_encoder"
             encoder_key = "speech_encoder"
@@ -231,8 +263,8 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
         # fairseq was accidentally run with a pre-LN encoder, and ended up with
         # fairseq was accidentally run with a pre-LN encoder, and ended up with
         # a redundant `LayerNorm` right after the Conformer blocks. We mitigate
         # a redundant `LayerNorm` right after the Conformer blocks. We mitigate
         # that issue here by moving that `LayerNorm` to the adaptor block.
         # that issue here by moving that `LayerNorm` to the adaptor block.
+        # fmt: off
         if config.w2v2_encoder_config.use_conformer:
         if config.w2v2_encoder_config.use_conformer:
-            # fmt: off
             key_map.update(
             key_map.update(
                 {
                 {
                     fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layer_norm\.": r"speech_encoder.inner_layer_norm."
                     fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layer_norm\.": r"speech_encoder.inner_layer_norm."
@@ -244,7 +276,7 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
                     rf"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layer_norm\.": r"speech_encoder.inner.layer_norm."
                     rf"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layer_norm\.": r"speech_encoder.inner.layer_norm."
                 }
                 }
             )
             )
-            # fmt: on
+        # fmt: on
 
 
         if config.use_conformer_adaptor:
         if config.use_conformer_adaptor:
             key_map.update(
             key_map.update(
@@ -303,44 +335,56 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
                 # fmt: on
                 # fmt: on
             }
             }
         )
         )
+        # ExpressiveUnitY model (from multi_arch codebase)
+        if config.prosody_encoder_config is not None:
+            key_map.update(
+                {
+                    # fmt: off
+                    fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.film\.":              r"t2u_model.decoder.layers.\1.film.",
+                    fr"^{ecapa_tdnn_key}\.":                                       r"prosody_encoder_model.",
+                    r"^t2s_model\.global_proj_enc\.":                             r"t2u_model.prosody_proj.",
+                    # fmt: on
+                }
+            )
+
         # X2T/S2T + T2U model.
         # X2T/S2T + T2U model.
         if config.t2u_config is not None:
         if config.t2u_config is not None:
             key_map.update(
             key_map.update(
                 {
                 {
                     # fmt: off
                     # fmt: off
                     # T2U Encoder
                     # T2U Encoder
-                    r"^synthesizer_encoder\.layers\.([0-9]+)\.self_attn\.out_proj\.":     r"t2u_model.encoder.layers.\1.self_attn.output_proj.",
-                    r"^synthesizer_encoder\.layers\.([0-9]+)\.self_attn\.":               r"t2u_model.encoder.layers.\1.self_attn.",
-                    r"^synthesizer_encoder\.layers\.([0-9]+)\.self_attn_layer_norm\.":    r"t2u_model.encoder.layers.\1.self_attn_layer_norm.",
-                    r"^synthesizer_encoder\.layers\.([0-9]+)\.fc1\.":                     r"t2u_model.encoder.layers.\1.ffn.inner_proj.",
-                    r"^synthesizer_encoder\.layers\.([0-9]+)\.fc2\.":                     r"t2u_model.encoder.layers.\1.ffn.output_proj.",
-                    r"^synthesizer_encoder\.layers\.([0-9]+)\.final_layer_norm\.":        r"t2u_model.encoder.layers.\1.ffn_layer_norm.",
-                    r"^synthesizer_encoder\.layer_norm\.":                                r"t2u_model.encoder.layer_norm.",
+                    fr"^{t2u_encoder_key}\.layers\.([0-9]+)\.self_attn\.out_proj\.":     r"t2u_model.encoder.layers.\1.self_attn.output_proj.",
+                    fr"^{t2u_encoder_key}\.layers\.([0-9]+)\.self_attn\.":               r"t2u_model.encoder.layers.\1.self_attn.",
+                    fr"^{t2u_encoder_key}\.layers\.([0-9]+)\.self_attn_layer_norm\.":    r"t2u_model.encoder.layers.\1.self_attn_layer_norm.",
+                    fr"^{t2u_encoder_key}\.layers\.([0-9]+)\.fc1\.":                     r"t2u_model.encoder.layers.\1.ffn.inner_proj.",
+                    fr"^{t2u_encoder_key}\.layers\.([0-9]+)\.fc2\.":                     r"t2u_model.encoder.layers.\1.ffn.output_proj.",
+                    fr"^{t2u_encoder_key}\.layers\.([0-9]+)\.final_layer_norm\.":        r"t2u_model.encoder.layers.\1.ffn_layer_norm.",
+                    fr"^{t2u_encoder_key}\.layer_norm\.":                                r"t2u_model.encoder.layer_norm.",
 
 
                     # T2U Decoder frontend
                     # T2U Decoder frontend
-                    r"^decoder\.embed_tokens_text\.":                           r"t2u_model.decoder_frontend.embed_char.",
-                    r"^decoder\.embed_tokens_unit\.":                           r"t2u_model.decoder_frontend.embed.",
-                    r"^decoder\.embed_tokens\.":                                r"t2u_model.decoder_frontend.embed.",
-                    r"^decoder\.var_adaptor\.duration_predictor\.":             r"t2u_model.decoder_frontend.variance_adaptor.duration_predictor.",
-                    r"^decoder\.dec_pos_emb_alpha":                             r"t2u_model.decoder_frontend.pos_emb_alpha",
-                    r"^decoder\.char_upsampler\.pos_emb_alpha":                 r"t2u_model.decoder_frontend.pos_emb_alpha_char",
+                    fr"^{t2u_decoder_key}\.embed_tokens_text\.":                           r"t2u_model.decoder_frontend.embed_char.",
+                    fr"^{t2u_decoder_key}\.embed_tokens_unit\.":                           r"t2u_model.decoder_frontend.embed.",
+                    fr"^{t2u_decoder_key}\.embed_tokens\.":                                r"t2u_model.decoder_frontend.embed.",
+                    fr"^{t2u_decoder_key}\.var_adaptor\.duration_predictor\.":             r"t2u_model.decoder_frontend.variance_adaptor.duration_predictor.",
+                    fr"^{t2u_decoder_key}\.dec_pos_emb_alpha":                             r"t2u_model.decoder_frontend.pos_emb_alpha",
+                    fr"^{t2u_decoder_key}\.char_upsampler\.pos_emb_alpha":                 r"t2u_model.decoder_frontend.pos_emb_alpha_char",
 
 
                     # T2U Decoder
                     # T2U Decoder
-                    r"^decoder\.layers\.([0-9]+)\.self_attn\.out_proj\.":     r"t2u_model.decoder.layers.\1.self_attn.output_proj.",
-                    r"^decoder\.layers\.([0-9]+)\.self_attn\.":               r"t2u_model.decoder.layers.\1.self_attn.",
-                    r"^decoder\.layers\.([0-9]+)\.self_attn_layer_norm\.":    r"t2u_model.decoder.layers.\1.self_attn_layer_norm.",
-                    r"^decoder\.layers\.([0-9]+)\.layer_norm\.":              r"t2u_model.decoder.layers.\1.self_attn_layer_norm.",
-                    r"^decoder\.layers\.([0-9]+)\.encoder_attn\.out_proj\.":  r"t2u_model.decoder.layers.\1.encoder_decoder_attn.output_proj.",
-                    r"^decoder\.layers\.([0-9]+)\.encoder_attn\.":            r"t2u_model.decoder.layers.\1.encoder_decoder_attn.",
-                    r"^decoder\.layers\.([0-9]+)\.encoder_attn_layer_norm\.": r"t2u_model.decoder.layers.\1.encoder_decoder_attn_layer_norm.",
-                    r"^decoder\.layers\.([0-9]+)\.fc1\.":                     r"t2u_model.decoder.layers.\1.ffn.inner_proj.",
-                    r"^decoder\.layers\.([0-9]+)\.fc2\.":                     r"t2u_model.decoder.layers.\1.ffn.output_proj.",
-                    r"^decoder\.layers\.([0-9]+)\.final_layer_norm\.":        r"t2u_model.decoder.layers.\1.ffn_layer_norm.",
-                    r"^decoder\.layers\.([0-9]+)\.ffn\.ffn\.0\.":             r"t2u_model.decoder.layers.\1.conv1d.conv1.",
-                    r"^decoder\.layers\.([0-9]+)\.ffn\.ffn\.2\.":             r"t2u_model.decoder.layers.\1.conv1d.conv2.",
-                    r"^decoder\.layers\.([0-9]+)\.ffn\.layer_norm\.":         r"t2u_model.decoder.layers.\1.conv1d_layer_norm.",
-                    r"^decoder\.layer_norm\.":                                r"t2u_model.decoder.layer_norm.",
-                    r"^decoder\.output_projection\.":                         r"t2u_model.final_proj.",
+                    fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.self_attn\.out_proj\.":     r"t2u_model.decoder.layers.\1.self_attn.output_proj.",
+                    fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.self_attn\.":               r"t2u_model.decoder.layers.\1.self_attn.",
+                    fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.self_attn_layer_norm\.":    r"t2u_model.decoder.layers.\1.self_attn_layer_norm.",
+                    fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.layer_norm\.":              r"t2u_model.decoder.layers.\1.self_attn_layer_norm.",
+                    fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.encoder_attn\.out_proj\.":  r"t2u_model.decoder.layers.\1.encoder_decoder_attn.output_proj.",
+                    fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.encoder_attn\.":            r"t2u_model.decoder.layers.\1.encoder_decoder_attn.",
+                    fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.encoder_attn_layer_norm\.": r"t2u_model.decoder.layers.\1.encoder_decoder_attn_layer_norm.",
+                    fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.fc1\.":                     r"t2u_model.decoder.layers.\1.ffn.inner_proj.",
+                    fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.fc2\.":                     r"t2u_model.decoder.layers.\1.ffn.output_proj.",
+                    fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.final_layer_norm\.":        r"t2u_model.decoder.layers.\1.ffn_layer_norm.",
+                    fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.ffn\.ffn\.0\.":             r"t2u_model.decoder.layers.\1.conv1d.conv1.",
+                    fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.ffn\.ffn\.2\.":             r"t2u_model.decoder.layers.\1.conv1d.conv2.",
+                    fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.ffn\.layer_norm\.":         r"t2u_model.decoder.layers.\1.conv1d_layer_norm.",
+                    fr"^{t2u_decoder_key}\.layer_norm\.":                                r"t2u_model.decoder.layer_norm.",
+                    fr"^{t2u_decoder_key}\.output_projection\.":                         r"t2u_model.final_proj.",
                     # fmt: on
                     # fmt: on
                 }
                 }
             )
             )
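The key map above is built from f-string prefixes so the same table serves both plain UnitY checkpoints (decoder.*) and ExpressiveUnitY checkpoints (t2s_model.decoder.*). A hedged illustration of a single entry, using re.sub only to show the mapping; the actual renaming is performed by the fairseq2 model loader, not by this snippet:

# One entry from the table above, applied by hand. Pattern and replacement are
# copied from the hunk.
import re

t2u_decoder_key = "t2s_model.decoder"  # ExpressiveUnitY layout
pattern = fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.fc1\."
replacement = r"t2u_model.decoder.layers.\1.ffn.inner_proj."

old_key = "t2s_model.decoder.layers.3.fc1.weight"
new_key = re.sub(pattern, replacement, old_key)
print(new_key)  # t2u_model.decoder.layers.3.ffn.inner_proj.weight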

+ 19 - 2
src/seamless_communication/models/unity/model.py

@@ -19,6 +19,7 @@ from overrides import final as finaloverride
 from torch import Tensor
 from torch import Tensor
 from torch.nn import Module
 from torch.nn import Module
 
 
+from seamless_communication.models.pretssel.ecapa_tdnn import ECAPA_TDNN
 from seamless_communication.models.unity.nar_decoder import NARTransformerDecoder
 from seamless_communication.models.unity.nar_decoder import NARTransformerDecoder
 from seamless_communication.models.unity.nar_decoder_frontend import NARDecoderFrontend
 from seamless_communication.models.unity.nar_decoder_frontend import NARDecoderFrontend
 
 
@@ -42,6 +43,7 @@ class UnitYModel(EncoderDecoderModel):
     text_decoder: TransformerDecoder
     text_decoder: TransformerDecoder
     final_proj: Projection
     final_proj: Projection
     t2u_model: Union["UnitYT2UModel", "UnitYNART2UModel", None]
     t2u_model: Union["UnitYT2UModel", "UnitYNART2UModel", None]
+    prosody_encoder_model: Optional[ECAPA_TDNN]
 
 
     def __init__(
     def __init__(
         self,
         self,
@@ -54,6 +56,7 @@ class UnitYModel(EncoderDecoderModel):
         final_proj: Projection,
         final_proj: Projection,
         t2u_model: Union["UnitYT2UModel", "UnitYNART2UModel", None],
         t2u_model: Union["UnitYT2UModel", "UnitYNART2UModel", None],
         target_vocab_info: VocabularyInfo,
         target_vocab_info: VocabularyInfo,
+        prosody_encoder_model: Optional[ECAPA_TDNN] = None,
         input_modality: str = "speech",
         input_modality: str = "speech",
     ) -> None:
     ) -> None:
         model_dim = speech_encoder.model_dim
         model_dim = speech_encoder.model_dim
@@ -93,6 +96,10 @@ class UnitYModel(EncoderDecoderModel):
             self.register_module("t2u_model", None)
             self.register_module("t2u_model", None)
 
 
         self.target_vocab_info = target_vocab_info
         self.target_vocab_info = target_vocab_info
+        if prosody_encoder_model is not None:
+            self.prosody_encoder_model = prosody_encoder_model
+        else:
+            self.register_module("prosody_encoder_model", None)
 
 
     @finaloverride
     @finaloverride
     def encode(
     def encode(
@@ -304,6 +311,7 @@ class UnitYNART2UModel(Module):
     decoder: NARTransformerDecoder
     decoder: NARTransformerDecoder
     final_proj: Projection
     final_proj: Projection
     target_vocab_info: VocabularyInfo
     target_vocab_info: VocabularyInfo
+    prosody_proj: Optional[Projection]
 
 
     def __init__(
     def __init__(
         self,
         self,
@@ -312,6 +320,7 @@ class UnitYNART2UModel(Module):
         decoder: NARTransformerDecoder,
         decoder: NARTransformerDecoder,
         final_proj: Projection,
         final_proj: Projection,
         target_vocab_info: VocabularyInfo,
         target_vocab_info: VocabularyInfo,
+        prosody_proj: Optional[Projection] = None,
     ) -> None:
     ) -> None:
         super().__init__()
         super().__init__()
 
 
@@ -339,20 +348,27 @@ class UnitYNART2UModel(Module):
 
 
         self.target_vocab_info = target_vocab_info
         self.target_vocab_info = target_vocab_info
 
 
+        self.prosody_proj = prosody_proj
+
     def forward(
     def forward(
         self,
         self,
         text_decoder_output: Tensor,
         text_decoder_output: Tensor,
         text_decoder_padding_mask: Optional[PaddingMask],
         text_decoder_padding_mask: Optional[PaddingMask],
         text_seqs: Optional[Tensor],
         text_seqs: Optional[Tensor],
+        film_cond_emb: Optional[Tensor] = None,
     ) -> Tuple[SequenceModelOutput, Optional[PaddingMask]]:
     ) -> Tuple[SequenceModelOutput, Optional[PaddingMask]]:
         encoder_output, encoder_padding_mask = self.encode(
         encoder_output, encoder_padding_mask = self.encode(
             text_decoder_output, text_decoder_padding_mask
             text_decoder_output, text_decoder_padding_mask
         )
         )
 
 
+        if self.prosody_proj is not None and film_cond_emb is not None:
+            encoder_output = encoder_output + self.prosody_proj(film_cond_emb)
+
         decoder_output, decoder_padding_mask = self.decode(
         decoder_output, decoder_padding_mask = self.decode(
             encoder_output,
             encoder_output,
             encoder_padding_mask,
             encoder_padding_mask,
             text_seqs,
             text_seqs,
+            film_cond_emb,
         )
         )
 
 
         return self.project(decoder_output), decoder_padding_mask
         return self.project(decoder_output), decoder_padding_mask
@@ -372,14 +388,15 @@ class UnitYNART2UModel(Module):
         encoder_output: Tensor,
         encoder_output: Tensor,
         encoder_padding_mask: Optional[PaddingMask],
         encoder_padding_mask: Optional[PaddingMask],
         text_seqs: Optional[Tensor],
         text_seqs: Optional[Tensor],
+        film_cond_emb: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Optional[PaddingMask]]:
     ) -> Tuple[Tensor, Optional[PaddingMask]]:
         # encoder_output: (N, S, M)
         # encoder_output: (N, S, M)
         # text_seqs: (N, S)
         # text_seqs: (N, S)
         seqs, padding_mask = self.decoder_frontend(
         seqs, padding_mask = self.decoder_frontend(
-            encoder_output, encoder_padding_mask, text_seqs
+            encoder_output, encoder_padding_mask, text_seqs, film_cond_emb
         )
         )
 
 
-        return self.decoder(seqs, padding_mask)  # type: ignore[no-any-return]
+        return self.decoder(seqs, padding_mask, film_cond_emb=film_cond_emb)  # type: ignore[no-any-return]
 
 
     def project(self, decoder_output: Tensor) -> SequenceModelOutput:
     def project(self, decoder_output: Tensor) -> SequenceModelOutput:
         logits = self.final_proj(decoder_output)
         logits = self.final_proj(decoder_output)
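The new prosody_proj path projects the pooled prosody embedding to the T2U model dimension and adds it to every encoder position before decoding. A torch-only sketch of that step; the dimensions follow the expressivity configs in this PR and the names are illustrative:

# (N, 1, C) prosody embedding -> Linear(C, M) -> broadcast-add over the time
# axis of the T2U encoder output, as in UnitYNART2UModel.forward above.
import torch

N, S, M, C = 2, 7, 1024, 512
encoder_output = torch.randn(N, S, M)   # T2U encoder states
film_cond_emb = torch.randn(N, 1, C)    # prosody encoder output (e.g. ECAPA_TDNN)

prosody_proj = torch.nn.Linear(C, M, bias=True)
encoder_output = encoder_output + prosody_proj(film_cond_emb)  # still (N, S, M)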

+ 2 - 1
src/seamless_communication/models/unity/nar_decoder.py

@@ -66,9 +66,10 @@ class NARTransformerDecoder(Module):
         self,
         self,
         seqs: Tensor,
         seqs: Tensor,
         padding_mask: Optional[PaddingMask],
         padding_mask: Optional[PaddingMask],
+        film_cond_emb: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Optional[PaddingMask]]:
     ) -> Tuple[Tensor, Optional[PaddingMask]]:
         for layer in self.layers.drop_iter():
         for layer in self.layers.drop_iter():
-            seqs, padding_mask = layer(seqs, padding_mask)
+            seqs, padding_mask = layer(seqs, padding_mask, film_cond_emb=film_cond_emb)
 
 
         if self.layer_norm is not None:
         if self.layer_norm is not None:
             seqs = self.layer_norm(seqs)
             seqs = self.layer_norm(seqs)

+ 2 - 0
src/seamless_communication/models/unity/nar_decoder_frontend.py

@@ -302,6 +302,7 @@ class NARDecoderFrontend(Module):
         encoder_output: Tensor,
         encoder_output: Tensor,
         encoder_padding_mask: Optional[PaddingMask],
         encoder_padding_mask: Optional[PaddingMask],
         text_seqs: Optional[Tensor],
         text_seqs: Optional[Tensor],
+        film_cond_emb: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Optional[PaddingMask]]:
     ) -> Tuple[Tensor, Optional[PaddingMask]]:
         assert text_seqs is not None
         assert text_seqs is not None
 
 
@@ -323,6 +324,7 @@ class NARDecoderFrontend(Module):
             seqs,
             seqs,
             encoder_padding_mask,
             encoder_padding_mask,
             min_duration=1,
             min_duration=1,
+            film_cond_emb=film_cond_emb,
         )
         )
 
 
         seqs = self.forward_unit_pos_embedding(seqs, padding_mask)
         seqs = self.forward_unit_pos_embedding(seqs, padding_mask)

+ 19 - 0
src/seamless_communication/models/unity/nar_decoder_layer.py

@@ -13,6 +13,8 @@ from fairseq2.typing import DataType, Device, finaloverride
 from torch import Tensor
 from torch import Tensor
 from torch.nn import Conv1d, Dropout, Module, ReLU
 from torch.nn import Conv1d, Dropout, Module, ReLU
 
 
+from seamless_communication.models.unity.film import FiLM
+
 
 
 @final
 @final
 class Conv1dBlock(Module):
 class Conv1dBlock(Module):
@@ -111,6 +113,7 @@ class NARTransformerDecoderLayer(Module):
     conv1d: Conv1dBlock
     conv1d: Conv1dBlock
     conv1d_dropout: Optional[Dropout]
     conv1d_dropout: Optional[Dropout]
     conv1d_layer_norm: LayerNorm
     conv1d_layer_norm: LayerNorm
+    film: Optional[FiLM]
 
 
     def __init__(
     def __init__(
         self,
         self,
@@ -118,6 +121,8 @@ class NARTransformerDecoderLayer(Module):
         conv1d: Conv1dBlock,
         conv1d: Conv1dBlock,
         dropout_p: float = 0.1,
         dropout_p: float = 0.1,
         conv1d_dropout_p: float = 0.1,
         conv1d_dropout_p: float = 0.1,
+        use_film: bool = False,
+        film_cond_dim: int = 512,
         device: Optional[Device] = None,
         device: Optional[Device] = None,
         dtype: Optional[DataType] = None,
         dtype: Optional[DataType] = None,
     ) -> None:
     ) -> None:
@@ -130,6 +135,10 @@ class NARTransformerDecoderLayer(Module):
             The dropout probability on the outputs of the self attention layer.
             The dropout probability on the outputs of the self attention layer.
         :param conv1d_dropout_p:
         :param conv1d_dropout_p:
             The dropout probability on the outputs of the conv1d block.
             The dropout probability on the outputs of the conv1d block.
+        :param use_film:
+            Whether to condition on a fixed-size vector through FiLM.
+        :param film_cond_dim:
+            The dimensionality of the fixed-size conditioning vector used by FiLM in the forward pass.
         """
         """
         super().__init__()
         super().__init__()
 
 
@@ -159,16 +168,26 @@ class NARTransformerDecoderLayer(Module):
             self.model_dim, device=device, dtype=dtype
             self.model_dim, device=device, dtype=dtype
         )
         )
 
 
+        if use_film:
+            self.film = FiLM(film_cond_dim, self.model_dim, device=device, dtype=dtype)
+        else:
+            self.register_module("film", None)
+
     @finaloverride
     @finaloverride
     def forward(
     def forward(
         self,
         self,
         seqs: Tensor,
         seqs: Tensor,
         padding_mask: Optional[PaddingMask],
         padding_mask: Optional[PaddingMask],
+        film_cond_emb: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Optional[PaddingMask]]:
     ) -> Tuple[Tensor, Optional[PaddingMask]]:
         seqs = self._forward_self_attn(seqs, padding_mask)
         seqs = self._forward_self_attn(seqs, padding_mask)
 
 
         seqs = self._forward_conv1d(seqs, padding_mask)
         seqs = self._forward_conv1d(seqs, padding_mask)
 
 
+        if self.film is not None and film_cond_emb is not None:
+            seqs = self.film(seqs, film_cond_emb)
+            seqs = apply_padding_mask(seqs, padding_mask)
+
         return seqs, padding_mask
         return seqs, padding_mask
 
 
     def _forward_self_attn(
     def _forward_self_attn(
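Note the ordering in this decoder layer: FiLM is applied after the conv1d block, and the padding mask is re-applied afterwards so the FiLM bias cannot leak into padded positions. A plain-torch approximation of that masking step, with a boolean length mask standing in for fairseq2's PaddingMask / apply_padding_mask:

# Re-mask after FiLM; this is a sketch, not the library helper itself.
import torch

seqs = torch.randn(2, 4, 8)                # (N, S, H) after the FiLM layer
seq_lens = torch.tensor([4, 2])            # valid lengths per sequence
mask = torch.arange(4)[None, :] < seq_lens[:, None]  # (N, S), True = keep
seqs = seqs * mask.unsqueeze(-1)           # zero out padded positions again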

+ 106 - 7
src/seamless_communication/models/unity/t2u_builder.py

@@ -17,7 +17,7 @@ from fairseq2.models.transformer import (
 from fairseq2.models.utils.arch_registry import ArchitectureRegistry
 from fairseq2.models.utils.arch_registry import ArchitectureRegistry
 from fairseq2.nn.embedding import Embedding, StandardEmbedding, init_scaled_embedding
 from fairseq2.nn.embedding import Embedding, StandardEmbedding, init_scaled_embedding
 from fairseq2.nn.position_encoder import SinusoidalPositionEncoder
 from fairseq2.nn.position_encoder import SinusoidalPositionEncoder
-from fairseq2.nn.projection import TiedProjection
+from fairseq2.nn.projection import Linear, Projection, TiedProjection
 from fairseq2.nn.transformer import (
 from fairseq2.nn.transformer import (
     FeedForwardNetwork,
     FeedForwardNetwork,
     MultiheadAttention,
     MultiheadAttention,
@@ -35,6 +35,7 @@ from fairseq2.nn.transformer import (
     create_default_sdpa,
     create_default_sdpa,
 )
 )
 from fairseq2.typing import DataType, Device
 from fairseq2.typing import DataType, Device
+from torch.nn import GELU, ReLU
 
 
 from seamless_communication.models.unity.char_tokenizer import load_unity_char_tokenizer
 from seamless_communication.models.unity.char_tokenizer import load_unity_char_tokenizer
 from seamless_communication.models.unity.length_regulator import (
 from seamless_communication.models.unity.length_regulator import (
@@ -55,6 +56,8 @@ class VariancePredictorConfig:
     var_pred_hidden_dim: int
     var_pred_hidden_dim: int
     var_pred_kernel_size: int
     var_pred_kernel_size: int
     var_pred_dropout: float
     var_pred_dropout: float
+    use_film: bool
+    film_cond_dim: int
 
 
 
 
 @dataclass
 @dataclass
@@ -73,6 +76,8 @@ class NARDecoderConfig:
     conv1d_kernel_size: int
     conv1d_kernel_size: int
     conv1d_inner_dim: int
     conv1d_inner_dim: int
     conv1d_dropout_p: float
     conv1d_dropout_p: float
+    use_film: bool
+    film_cond_dim: int
 
 
 
 
 @dataclass
 @dataclass
@@ -113,9 +118,17 @@ class UnitYT2UConfig:
     dropout_p: float
     dropout_p: float
     """The dropout probability in Transformer layers."""
     """The dropout probability in Transformer layers."""
 
 
-    def update_unit_vocabulary(self, info: VocabularyInfo) -> None:
-        """Update unit vocabulary configuration from ``info``."""
-        self.unit_vocabulary_size, self.unit_pad_idx = info.size, info.pad_idx
+    use_gelu: bool
+    """If ``True``, uses the GELU activation function in feed-forward networks."""
+
+    char_pad_idx: int
+    """The index of the pad symbol in the char vocabulary."""
+
+    use_prosody_proj: bool
+    """If ``True``, uses a prosody projection layer."""
+
+    prosody_encoder_dim: int
+    """The dimensionality of the prosody encoder (e.g. ECAPA_TDNN) output."""
 
 
 
 
 unity_t2u_archs = ArchitectureRegistry[UnitYT2UConfig]("unity_t2u")
 unity_t2u_archs = ArchitectureRegistry[UnitYT2UConfig]("unity_t2u")
@@ -140,6 +153,10 @@ def _base_t2u() -> UnitYT2UConfig:
         num_decoder_attn_heads=16,
         num_decoder_attn_heads=16,
         ffn_inner_dim=1024 * 8,
         ffn_inner_dim=1024 * 8,
         dropout_p=0.1,
         dropout_p=0.1,
+        use_gelu=False,
+        char_pad_idx=0,
+        use_prosody_proj=False,
+        prosody_encoder_dim=0,
     )
     )
 
 
 
 
@@ -159,6 +176,10 @@ def _medium_t2u() -> UnitYT2UConfig:
         num_decoder_attn_heads=16,
         num_decoder_attn_heads=16,
         ffn_inner_dim=1024 * 8,
         ffn_inner_dim=1024 * 8,
         dropout_p=0.1,
         dropout_p=0.1,
+        use_gelu=False,
+        char_pad_idx=0,
+        use_prosody_proj=False,
+        prosody_encoder_dim=0,
     )
     )
 
 
 
 
@@ -168,6 +189,8 @@ def _base_nar() -> UnitYT2UConfig:
         var_pred_hidden_dim=256,
         var_pred_hidden_dim=256,
         var_pred_kernel_size=3,
         var_pred_kernel_size=3,
         var_pred_dropout=0.5,
         var_pred_dropout=0.5,
+        use_film=False,
+        film_cond_dim=0,
     )
     )
 
 
     nar_decoder_frontend_config = NARDecoderFrontendConfig(
     nar_decoder_frontend_config = NARDecoderFrontendConfig(
@@ -184,6 +207,8 @@ def _base_nar() -> UnitYT2UConfig:
         conv1d_kernel_size=7,
         conv1d_kernel_size=7,
         conv1d_inner_dim=1024,
         conv1d_inner_dim=1024,
         conv1d_dropout_p=0.1,
         conv1d_dropout_p=0.1,
+        use_film=False,
+        film_cond_dim=0,
     )
     )
 
 
     return UnitYT2UConfig(
     return UnitYT2UConfig(
@@ -200,6 +225,59 @@ def _base_nar() -> UnitYT2UConfig:
         num_decoder_attn_heads=16,
         num_decoder_attn_heads=16,
         ffn_inner_dim=1024 * 8,
         ffn_inner_dim=1024 * 8,
         dropout_p=0.0,
         dropout_p=0.0,
+        use_gelu=False,
+        char_pad_idx=0,
+        use_prosody_proj=False,
+        prosody_encoder_dim=0,
+    )
+
+
+@unity_t2u_arch("expressivity_nar")
+def _expressivity_nar() -> UnitYT2UConfig:
+    duration_predictor_config = VariancePredictorConfig(
+        var_pred_hidden_dim=256,
+        var_pred_kernel_size=3,
+        var_pred_dropout=0.5,
+        use_film=True,
+        film_cond_dim=512,
+    )
+
+    nar_decoder_frontend_config = NARDecoderFrontendConfig(
+        subword_to_unit_upsampling_type="hard",
+        duration_predictor_config=duration_predictor_config,
+        pitch_predictor_config=None,
+        energy_predictor_config=None,
+    )
+
+    nar_decoder_config = NARDecoderConfig(
+        model_name_or_card="seamless_expressivity",
+        char_vocabulary_size=10904,
+        char_max_seq_len=4000,
+        conv1d_kernel_size=7,
+        conv1d_inner_dim=1024,
+        conv1d_dropout_p=0.1,
+        use_film=True,
+        film_cond_dim=512,
+    )
+
+    return UnitYT2UConfig(
+        model_dim=1024,
+        unit_max_seq_len=4000,
+        target_vocab_info=VocabularyInfo(
+            size=10005, unk_idx=3, bos_idx=0, eos_idx=2, pad_idx=1
+        ),
+        num_encoder_layers=4,
+        num_decoder_layers=4,
+        nar_decoder_frontend_config=nar_decoder_frontend_config,
+        nar_decoder_config=nar_decoder_config,
+        num_encoder_attn_heads=16,
+        num_decoder_attn_heads=16,
+        ffn_inner_dim=1024 * 8,
+        dropout_p=0.0,
+        use_gelu=True,
+        char_pad_idx=1,
+        use_prosody_proj=True,
+        prosody_encoder_dim=512,
     )
     )
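The expressivity_nar values above imply a simple shape contract: the prosody encoder emits a 512-dimensional embedding that feeds both the decoder-layer FiLM (film_cond_dim=512) and the prosody projection (prosody_encoder_dim=512 into model_dim=1024). A quick torch check of those dimensions, using plain Linear layers as stand-ins for the real modules:

# Dimension check for the config above: one 512-d prosody embedding serves the
# FiLM projection (-> 2 * model_dim scale/shift) and the prosody projection
# (-> model_dim). Values mirror the hunk; the modules here are illustrative.
import torch

model_dim, prosody_encoder_dim, film_cond_dim = 1024, 512, 512
prosody_emb = torch.randn(2, 1, prosody_encoder_dim)

prosody_proj = torch.nn.Linear(prosody_encoder_dim, model_dim)
film_proj = torch.nn.Linear(film_cond_dim, 2 * model_dim)  # inside the decoder-layer FiLM

assert prosody_proj(prosody_emb).shape == (2, 1, model_dim)
assert film_proj(prosody_emb).shape == (2, 1, 2 * model_dim)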
 
 
 
 
@@ -417,12 +495,15 @@ class UnitYNART2UBuilder:
 
 
         decoder_frontend = self.build_decoder_frontend(embed_unit)
         decoder_frontend = self.build_decoder_frontend(embed_unit)
 
 
+        prosody_proj = self.build_prosody_proj()
+
         return UnitYNART2UModel(
         return UnitYNART2UModel(
             encoder,
             encoder,
             decoder_frontend,
             decoder_frontend,
             decoder,
             decoder,
             final_proj,
             final_proj,
             self.config.target_vocab_info,
             self.config.target_vocab_info,
+            prosody_proj=prosody_proj,
         )
         )
 
 
     def build_unit_embedding(self) -> StandardEmbedding:
     def build_unit_embedding(self) -> StandardEmbedding:
@@ -482,6 +563,8 @@ class UnitYNART2UBuilder:
             duration_predictor_config.var_pred_hidden_dim,
             duration_predictor_config.var_pred_hidden_dim,
             duration_predictor_config.var_pred_kernel_size,
             duration_predictor_config.var_pred_kernel_size,
             duration_predictor_config.var_pred_dropout,
             duration_predictor_config.var_pred_dropout,
+            use_film=duration_predictor_config.use_film,
+            film_cond_dim=duration_predictor_config.film_cond_dim,
             device=self.device,
             device=self.device,
             dtype=self.dtype,
             dtype=self.dtype,
         )
         )
@@ -518,19 +601,18 @@ class UnitYNART2UBuilder:
         nllb_tokenizer = NllbTokenizerLoader(asset_store, download_manager)(
         nllb_tokenizer = NllbTokenizerLoader(asset_store, download_manager)(
             self.config.nar_decoder_config.model_name_or_card
             self.config.nar_decoder_config.model_name_or_card
         )
         )
-        text_pad_idx = nllb_tokenizer.vocab_info.pad_idx
 
 
         char_pos_encoder = SinusoidalPositionEncoder(
         char_pos_encoder = SinusoidalPositionEncoder(
             self.config.model_dim,
             self.config.model_dim,
             self.config.nar_decoder_config.char_max_seq_len,
             self.config.nar_decoder_config.char_max_seq_len,
-            _legacy_pad_idx=text_pad_idx,
+            _legacy_pad_idx=self.config.char_pad_idx,
             device=self.device,
             device=self.device,
         )
         )
 
 
         embed_char = StandardEmbedding(
         embed_char = StandardEmbedding(
             num_embeddings=self.config.nar_decoder_config.char_vocabulary_size,
             num_embeddings=self.config.nar_decoder_config.char_vocabulary_size,
             embedding_dim=self.config.model_dim,
             embedding_dim=self.config.model_dim,
-            pad_idx=text_pad_idx,
+            pad_idx=self.config.char_pad_idx,
             init_fn=init_scaled_embedding,
             init_fn=init_scaled_embedding,
             device=self.device,
             device=self.device,
             dtype=self.dtype,
             dtype=self.dtype,
@@ -584,6 +666,8 @@ class UnitYNART2UBuilder:
             conv1d,
             conv1d,
             dropout_p=self.config.dropout_p,
             dropout_p=self.config.dropout_p,
             conv1d_dropout_p=self.config.nar_decoder_config.conv1d_dropout_p,
             conv1d_dropout_p=self.config.nar_decoder_config.conv1d_dropout_p,
+            use_film=self.config.nar_decoder_config.use_film,
+            film_cond_dim=self.config.nar_decoder_config.film_cond_dim,
             device=self.device,
             device=self.device,
             dtype=self.dtype,
             dtype=self.dtype,
         )
         )
@@ -608,11 +692,26 @@ class UnitYNART2UBuilder:
             self.config.model_dim,
             self.config.model_dim,
             self.config.ffn_inner_dim,
             self.config.ffn_inner_dim,
             bias=True,
             bias=True,
+            inner_activation=GELU() if self.config.use_gelu else ReLU(),
             norm_order=TransformerNormOrder.PRE,
             norm_order=TransformerNormOrder.PRE,
             device=self.device,
             device=self.device,
             dtype=self.dtype,
             dtype=self.dtype,
         )
         )
 
 
+    def build_prosody_proj(self) -> Optional[Projection]:
+        """Build a prosody projection layer if needed."""
+
+        if self.config.use_prosody_proj:
+            return Linear(
+                self.config.prosody_encoder_dim,
+                self.config.model_dim,
+                bias=True,
+                dtype=self.dtype,
+                device=self.device,
+            )
+        else:
+            return None
+
 
 
 def create_unity_t2u_model(
 def create_unity_t2u_model(
     config: UnitYT2UConfig,
     config: UnitYT2UConfig,