Prechádzať zdrojové kódy

Introduce Prosody encoder (#87)

Can Balioglu 1 rok pred
rodič
commit
05419775be

+ 1 - 1
pyproject.toml

@@ -12,7 +12,7 @@ per-file-ignores = [
 profile = "black"
 
 [tool.mypy]
-disable_error_code = "type-abstract"
+disable_error_code = "type-abstract,typeddict-unknown-key"
 disallow_untyped_calls = false
 disallow_untyped_decorators = false
 ignore_missing_imports = true

+ 51 - 0
src/seamless_communication/cards/seamless_expressivity.yaml

@@ -0,0 +1,51 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+name: seamless_expressivity
+base: unity_nllb-100
+model_arch: expressivity_v2
+char_tokenizer: "file://checkpoint/krs/unity2/spm_char_lang38_tc.model"
+checkpoint: "file://checkpoint/hygong/Expressivity/multilingual_models/m2m.clean.ecapa_tdnn2.dim512.all.all.lr5e-05.mk4k.config_t2_fbank_nosa_gcmvn_10k.rdrop0.ls0.2.uf3.wu5k.fp16.mem_fp16.seed1.dr0.1.ld0.2.mp0.3.cmp0.25.ma.ak8.as8.al1.ald0.0.dld0.0.ca.D24L.t2uE4L.t2uD4L.usesfilm.inj_dec.ngpu64/checkpoint_best_export.pt"
+num_units: 10000
+unit_langs:
+  - arb
+  - ben
+  - cat
+  - ces
+  - cmn
+  - cym
+  - dan
+  - deu
+  - eng
+  - est
+  - fin
+  - fra
+  - hin
+  - ind
+  - ita
+  - jpn
+  - kan
+  - kor
+  - mlt
+  - nld
+  - pes
+  - pol
+  - por
+  - ron
+  - rus
+  - slk
+  - spa
+  - swe
+  - swh
+  - tam
+  - tel
+  - tgl
+  - tha
+  - tur
+  - ukr
+  - urd
+  - uzn
+  - vie

+ 1 - 2
src/seamless_communication/cli/eval_utils/compute_metrics.py

@@ -10,17 +10,16 @@ from typing import Tuple, Union
 
 import pandas as pd
 import whisper
-
 from fairseq2.typing import Device
 from jiwer import cer, wer
 from sacrebleu.metrics.base import Score, Signature
 from sacrebleu.metrics.bleu import BLEU
 from sacrebleu.metrics.chrf import CHRF
-from seamless_communication.cli.eval_utils.lang_mapping import LANG3_LANG2
 from tqdm import tqdm
 from whisper import Whisper
 from whisper.normalizers import BasicTextNormalizer, EnglishTextNormalizer
 
+from seamless_communication.cli.eval_utils.lang_mapping import LANG3_LANG2
 
 logging.basicConfig(
     level=logging.INFO,

+ 0 - 0
src/seamless_communication/cli/expressivity/__init__.py


+ 0 - 0
src/seamless_communication/cli/expressivity/evaluate/__init__.py


+ 423 - 0
src/seamless_communication/cli/expressivity/evaluate/evaluate.py

@@ -0,0 +1,423 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import contextlib
+import logging
+import subprocess
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import numpy as np
+import torch
+import torchaudio
+from fairseq2.data import Collater, CString, DataPipeline, FileMapper
+from fairseq2.data.audio import (
+    AudioDecoder,
+    WaveformToFbankConverter,
+    WaveformToFbankOutput,
+)
+from fairseq2.data.text import StrSplitter, TextTokenizer, read_text
+from fairseq2.data.typing import PathLike, StringLike
+from fairseq2.generation import SequenceGeneratorOptions
+from fairseq2.typing import DataType, Device
+from sacrebleu.metrics import BLEU  # type: ignore[attr-defined]
+from torch import Tensor
+from tqdm import tqdm
+
+from seamless_communication.cli.m4t.predict import (
+    add_inference_arguments,
+    set_generation_opts,
+)
+from seamless_communication.inference import BatchedSpeechOutput, Modality, Translator
+from seamless_communication.models.unity import load_unity_text_tokenizer
+
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
+)
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class EvalContext:
+    task: str
+    """String representing the task. Valid choices are
+    "S2ST", "S2TT", "T2ST", "T2TT", "ASR"."""
+
+    output_modality: Modality
+    """The output modality of the task."""
+
+    model_name: str
+    """The name of the S2T UnitY model."""
+
+    data_file: Path
+    """The pathname of the test TSV data file."""
+
+    audio_root_dir: Optional[Path]
+    """The pathname of the directory under which
+    audio files are stored."""
+
+    target_lang: str
+    """The target translation language."""
+
+    source_lang: Optional[str]
+    """The source language."""
+
+    batch_size: int
+    """The batch size for model input."""
+
+    device: Device
+    """The device on which to run inference."""
+
+    dtype: DataType
+    """The data type with which to run inference."""
+
+    output_path: Path
+    """The pathname of the output directory to save
+    the evaluation results."""
+
+    ref_field: str
+    """The reference target text field to compute
+    the BLEU score against."""
+
+    text_generation_opts: SequenceGeneratorOptions
+    """Text generation hyperparameters."""
+
+    unit_generation_opts: Optional[SequenceGeneratorOptions]
+    """Unit generation hyperparameters, not applicable
+    for the NAR T2U decoder."""
+
+    unit_generation_ngram_filtering: bool
+    """If True, removes consecutive repeating ngrams
+    from the decoded unit output."""
+
+    gcmvn_stats: Optional[PathLike] = None
+    """the stats for gcmvn, used by Prosody Encoder"""
+
+
+def count_lines(filename: Path) -> int:
+    result = subprocess.run(["wc", "-l", filename], stdout=subprocess.PIPE)
+    return int(result.stdout.decode().split()[0])
+
+
+def build_data_pipeline(
+    ctx: EvalContext,
+    text_tokenizer: TextTokenizer,
+) -> DataPipeline:
+    with open(ctx.data_file, "r") as f:
+        header = f.readline().strip("\n").split("\t")
+
+    # TODO: This will be soon auto-tuned. Right now hand-tuned for devfair.
+    n_parallel = 4
+
+    split_tsv = StrSplitter(names=header)
+
+    if ctx.gcmvn_stats is not None:
+        if isinstance(ctx.gcmvn_stats, CString):
+            ctx.gcmvn_stats = str(ctx.gcmvn_stats)
+        gcmvn_stats: Dict[str, np.ndarray] = np.load(ctx.gcmvn_stats)  # type: ignore[type-arg]
+        gcmvn_mean = torch.tensor(
+            gcmvn_stats["mean"], device=ctx.device, dtype=ctx.dtype
+        )
+        gcmvn_std = torch.tensor(gcmvn_stats["std"], device=ctx.device, dtype=ctx.dtype)
+
+    pipeline_builder = read_text(ctx.data_file, rtrim=True).skip(1).map(split_tsv)
+
+    assert ctx.audio_root_dir is not None
+
+    map_file = FileMapper(root_dir=ctx.audio_root_dir, cached_fd_count=10)
+
+    pipeline_builder.map(map_file, selector="audio", num_parallel_calls=n_parallel)
+
+    decode_audio = AudioDecoder(dtype=torch.float32, device=ctx.device)
+
+    convert_to_fbank = WaveformToFbankConverter(
+        num_mel_bins=80,
+        waveform_scale=2**15,
+        channel_last=True,
+        standardize=False,
+        device=ctx.device,
+        dtype=ctx.dtype,
+    )
+
+    def normalize_fbank(data: WaveformToFbankOutput) -> WaveformToFbankOutput:
+        fbank = data["fbank"]
+        std, mean = torch.std_mean(fbank, dim=0)
+        data["fbank"] = fbank.subtract(mean).divide(std)
+        if ctx.gcmvn_stats is not None:
+            data["gcmvn_fbank"] = fbank.subtract(gcmvn_mean).divide(gcmvn_std)
+        return data
+
+    pipeline_builder.map(
+        [decode_audio, convert_to_fbank, normalize_fbank],
+        selector="audio.data",
+        num_parallel_calls=n_parallel,
+    )
+
+    pipeline_builder.bucket(bucket_size=ctx.batch_size)
+
+    collate = Collater(pad_value=0, pad_to_multiple=1)
+
+    pipeline_builder.map(collate, num_parallel_calls=n_parallel)
+
+    pipeline_builder.prefetch(4)
+
+    return pipeline_builder.and_return()
+
+
+def adjust_output_for_corrupted_inputs(
+    valid_sequences: Tensor,
+    text_output: List[StringLike],
+    speech_output: Optional[BatchedSpeechOutput],
+) -> Tuple[List[StringLike], Optional[BatchedSpeechOutput]]:
+    adjusted_text_output: List[StringLike] = []
+    adjusted_speech_output: Optional[BatchedSpeechOutput] = None
+
+    if speech_output is not None:
+        assert (
+            len(text_output)
+            == len(speech_output.units)
+            == len(speech_output.audio_wavs)
+        )
+        adjusted_speech_output = BatchedSpeechOutput(units=[], audio_wavs=[])
+
+    batch_counter = 0
+    for is_valid in valid_sequences:
+        if is_valid:
+            adjusted_text_output.append(text_output[batch_counter])
+            if speech_output is not None:
+                assert adjusted_speech_output is not None
+                adjusted_speech_output.units.append(speech_output.units[batch_counter])
+                adjusted_speech_output.audio_wavs.append(
+                    speech_output.audio_wavs[batch_counter]
+                )
+            batch_counter += 1
+        else:
+            # For the corrupted inputs, we save the following dummy outputs:
+            # empty string for text, empty list for units, 1 second of silence for audio.
+            adjusted_text_output.append("")
+            if adjusted_speech_output is not None:
+                sample_rate = adjusted_speech_output.sample_rate
+                adjusted_speech_output.units.append([])
+                adjusted_speech_output.audio_wavs.append(
+                    torch.zeros(sample_rate).unsqueeze(0).unsqueeze(0)
+                )
+    return (
+        adjusted_text_output,
+        adjusted_speech_output,
+    )
+
+
+def run_eval(
+    translator: Translator, text_tokenizer: TextTokenizer, ctx: EvalContext
+) -> None:
+    pipeline = build_data_pipeline(ctx, text_tokenizer)
+
+    total_steps = count_lines(ctx.data_file) - 1
+    progress_bar = tqdm(total=total_steps)
+
+    output_path = ctx.output_path / ctx.data_file.stem
+    output_path.mkdir(parents=True, exist_ok=True)
+
+    if ctx.output_modality == Modality.SPEECH:
+        waveforms_dir = output_path / f"waveform_{ctx.data_file.stem}"
+        waveforms_dir.mkdir(parents=True, exist_ok=True)
+
+    hyps = []
+    refs = []
+
+    with contextlib.ExitStack() as stack:
+        hyp_file = stack.enter_context(
+            open(output_path / f"text_output-{ctx.data_file.stem}.txt", "w")
+        )
+        if ctx.output_modality == Modality.SPEECH:
+            unit_file = stack.enter_context(
+                open(output_path / f"unit_output-{ctx.data_file.stem}.txt", "w")
+            )
+
+        sample_id = 0
+        for example in pipeline:
+            valid_sequences: Optional[Tensor] = None
+            src = example["audio"]["data"]["fbank"]
+            # Skip corrupted audio tensors.
+            valid_sequences = ~torch.any(
+                torch.any(torch.isnan(src["seqs"]), dim=1), dim=1
+            )
+            if not valid_sequences.all():
+                logger.warning(
+                    f"Sample IDs {sample_id} to {sample_id + ctx.batch_size} has some corrupted input."
+                )
+                src["seqs"] = src["seqs"][valid_sequences]
+                src["seq_lens"] = src["seq_lens"][valid_sequences]
+
+            # Skip performing inference when the input is entirely corrupted.
+            if src["seqs"].numel() > 0:
+                (
+                    text_output,
+                    speech_output,
+                ) = translator.predict(
+                    src,
+                    ctx.task,
+                    ctx.target_lang,
+                    src_lang=ctx.source_lang,
+                    text_generation_opts=ctx.text_generation_opts,
+                    unit_generation_opts=ctx.unit_generation_opts,
+                    unit_generation_ngram_filtering=ctx.unit_generation_ngram_filtering,
+                    gcmvn_fbank=example["audio"]["data"].get("gcmvn_fbank", None),
+                )
+            else:
+                text_output = []
+                if ctx.output_modality == Modality.SPEECH:
+                    speech_output = BatchedSpeechOutput(units=[], audio_wavs=[])
+                else:
+                    speech_output = None
+
+            if valid_sequences is not None and not valid_sequences.all():
+                (
+                    text_output,
+                    speech_output,
+                ) = adjust_output_for_corrupted_inputs(
+                    valid_sequences,
+                    text_output,
+                    speech_output,
+                )
+
+            hyps += [str(s) for s in text_output]
+            refs += [str(s) for s in example[ctx.ref_field]]
+
+            for i in range(len(text_output)):
+                t = text_output[i]
+                hyp_file.write(f"{t}\n")
+
+                if ctx.output_modality == Modality.SPEECH:
+                    assert speech_output is not None
+                    u = speech_output.units[i]
+                    str_units = [str(i) for i in u]
+                    unit_file.write(" ".join(str_units) + "\n")
+                    torchaudio.save(
+                        waveforms_dir / f"{sample_id}_pred.wav",
+                        speech_output.audio_wavs[i][0].to(torch.float32).cpu(),
+                        sample_rate=speech_output.sample_rate,
+                    )
+
+                sample_id += 1
+                progress_bar.update(1)
+
+    progress_bar.close()
+    logger.info(f"Processed {len(hyps)} hyps, {len(refs)} refs")
+
+    assert len(hyps) == len(refs)
+    if len(hyps) > 0:
+        if ctx.target_lang in ("cmn", "jpn", "lao", "mya", "tha"):
+            tokenizer = "char"
+        else:
+            tokenizer = "13a"
+
+        bleu = BLEU(tokenize=tokenizer)
+        score = bleu.corpus_score(hyps, [refs])
+        bleu_filename = output_path / f"{ctx.data_file.stem}_text_output_bleu.json"
+        with open(bleu_filename, "w") as f:
+            f.write(score.format(signature=str(bleu.get_signature()), is_json=True))
+        logger.info(score.format(signature=bleu.get_signature()))
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(
+        description="Expressivity evaluation for tasks supported by Translator."
+    )
+    parser.add_argument("data_file", type=str, help="Data file (.tsv) to be evaluated.")
+
+    parser = add_inference_arguments(parser)
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        help="Inference batch size.",
+        default=4,
+    )
+    parser.add_argument(
+        "--audio_root_dir",
+        type=str,
+        help="Root directory for the audio filenames in the data file.",
+        default="",
+    )
+    parser.add_argument(
+        "--ref_field",
+        type=str,
+        help="Reference target text field to compute the BLEU score against.",
+        default="tgt_text",
+    )
+    parser.add_argument(
+        "--gcmvn_stats",
+        type=str,
+        help="The path to gcmvn fbank stats, if provided, the DataPipeline'd have another copy of gcmvn fbank features (for P2V enc)",
+        default=None,
+    )
+    args = parser.parse_args()
+
+    input_modality, output_modality = Translator.get_modalities_from_task_str(args.task)
+
+    if input_modality == Modality.SPEECH and not Path(args.audio_root_dir).exists():
+        raise ValueError(
+            f"Invalid audio_root_dir: {args.audio_root_dir} for speech input."
+        )
+
+    if torch.cuda.is_available():
+        device = torch.device("cuda:0")
+        dtype = torch.float32
+    else:
+        device = torch.device("cpu")
+        dtype = torch.float32
+
+    text_tokenizer = load_unity_text_tokenizer(args.model_name)
+
+    # TODO: Avoid loading the T2U model, vocoder when the output
+    # modality is text.
+    translator = Translator(
+        args.model_name,
+        args.vocoder_name,
+        device,
+        text_tokenizer=text_tokenizer,
+        dtype=dtype,
+    )
+
+    text_generation_opts, unit_generation_opts = set_generation_opts(args)
+
+    logger.info(f"{text_generation_opts=}")
+    logger.info(f"{unit_generation_opts=}")
+    logger.info(
+        f"unit_generation_ngram_filtering={args.unit_generation_ngram_filtering}"
+    )
+
+    # fmt: off
+    ctx = EvalContext(
+        task=args.task,
+        output_modality=output_modality,
+        model_name=args.model_name,
+        data_file=Path(args.data_file),
+        audio_root_dir=Path(args.audio_root_dir),
+        target_lang=args.tgt_lang,
+        source_lang=args.src_lang,
+        batch_size=args.batch_size,
+        device=device,
+        dtype=dtype,
+        ref_field=args.ref_field,
+        text_generation_opts=text_generation_opts,
+        unit_generation_opts=unit_generation_opts,
+        unit_generation_ngram_filtering=args.unit_generation_ngram_filtering,
+        output_path=Path(args.output_path),
+        gcmvn_stats=args.gcmvn_stats,
+    )
+    # fmt: on
+    logger.info(f"Running inference on {device=} with {dtype=}, {ctx.batch_size=}.")
+
+    run_eval(translator, text_tokenizer, ctx)
+
+
+if __name__ == "__main__":
+    main()

+ 8 - 2
src/seamless_communication/cli/m4t/evaluate/evaluate.py

@@ -267,7 +267,10 @@ def run_eval(
 
             # Skip performing inference when the input is entirely corrupted.
             if src["seqs"].numel() > 0:
-                (text_output, speech_output,) = translator.predict(
+                (
+                    text_output,
+                    speech_output,
+                ) = translator.predict(
                     src,
                     ctx.task,
                     ctx.target_lang,
@@ -284,7 +287,10 @@ def run_eval(
                     speech_output = None
 
             if valid_sequences is not None and not valid_sequences.all():
-                (text_output, speech_output,) = adjust_output_for_corrupted_inputs(
+                (
+                    text_output,
+                    speech_output,
+                ) = adjust_output_for_corrupted_inputs(
                     valid_sequences,
                     text_output,
                     speech_output,

+ 10 - 2
src/seamless_communication/inference/generator.py

@@ -153,6 +153,7 @@ class UnitYGenerator:
         input_modality: str = "speech",
         output_modality: str = "speech",
         ngram_filtering: bool = False,
+        gcmvn_seqs: Optional[Tensor] = None,
     ) -> Tuple[SequenceToTextOutput, Optional["SequenceToUnitOutput"]]:
         """
         :param source_seqs:
@@ -215,6 +216,12 @@ class UnitYGenerator:
         assert self.unit_decoder is not None
 
         unit_gen_output = None
+        prosody_encoder_out = None
+        if self.model.prosody_encoder_model is not None:
+            prosody_encoder_out = self.model.prosody_encoder_model(
+                gcmvn_seqs, source_padding_mask
+            ).unsqueeze(1)
+
         if isinstance(self.model.t2u_model, UnitYT2UModel):
             assert self.unit_generator is not None
             t2u_encoder_output, t2u_encoder_padding_mask = self.model.t2u_model.encode(
@@ -231,6 +238,7 @@ class UnitYGenerator:
                 text_decoder_output=decoder_output,
                 text_decoder_padding_mask=decoder_padding_mask,
                 text_seqs=text_seqs,
+                film_cond_emb=prosody_encoder_out,
             )
             # (B, S_unit, V_unit)
             unit_seqs = unit_decoder_output.logits.argmax(dim=2)
@@ -243,8 +251,8 @@ class UnitYGenerator:
         units = self.unit_decoder(unit_seqs)
 
         if ngram_filtering:
-            units = remove_consecutive_repeated_ngrams(units.cpu().numpy().tolist())
-            units = torch.tensor(units)
+            arr = remove_consecutive_repeated_ngrams(units.cpu().numpy().tolist())
+            units = torch.tensor(arr)
 
         unit_output = SequenceToUnitOutput(units, unit_gen_output)
 

+ 11 - 4
src/seamless_communication/inference/translator.py

@@ -7,7 +7,7 @@ import logging
 from dataclasses import dataclass
 from enum import Enum, auto
 from pathlib import Path
-from typing import Any, Dict, Callable, List, Optional, Tuple, Union, cast
+from typing import Callable, List, Optional, Tuple, Union, cast
 
 import torch
 import torch.nn as nn
@@ -144,6 +144,7 @@ class Translator(nn.Module):
         text_generation_opts: SequenceGeneratorOptions,
         unit_generation_opts: Optional[SequenceGeneratorOptions],
         unit_generation_ngram_filtering: bool = False,
+        gcmvn_fbank: Optional[SequenceData] = None,
     ) -> Tuple[SequenceToTextOutput, Optional[SequenceToUnitOutput]]:
         # We disregard unit generations opts for the NAR T2U decoder.
         if output_modality != Modality.SPEECH or isinstance(
@@ -160,12 +161,18 @@ class Translator(nn.Module):
             unit_opts=unit_generation_opts,
         )
         seqs, padding_mask = get_seqs_and_padding_mask(src)
+        if gcmvn_fbank is not None:
+            gcmvn_seqs = gcmvn_fbank["seqs"]
+        else:
+            gcmvn_seqs = None
+
         return generator(
             seqs,
             padding_mask,
             input_modality.value,
             output_modality.value,
             ngram_filtering=unit_generation_ngram_filtering,
+            gcmvn_seqs=gcmvn_seqs,
         )
 
     @staticmethod
@@ -188,7 +195,7 @@ class Translator(nn.Module):
     @torch.inference_mode()
     def predict(
         self,
-        input: Union[str, Tensor, Dict[str, Any]],
+        input: Union[str, Tensor, SequenceData],
         task_str: str,
         tgt_lang: str,
         src_lang: Optional[str] = None,
@@ -201,6 +208,7 @@ class Translator(nn.Module):
         spkr: Optional[int] = -1,
         sample_rate: int = 16000,
         unit_generation_ngram_filtering: bool = False,
+        gcmvn_fbank: Optional[SequenceData] = None,
     ) -> Tuple[List[StringLike], Optional[BatchedSpeechOutput]]:
         """
         The main method used to perform inference on all tasks.
@@ -231,8 +239,6 @@ class Translator(nn.Module):
         input_modality, output_modality = self.get_modalities_from_task_str(task_str)
 
         if isinstance(input, dict):
-            assert "seqs" in input
-            assert "seq_lens" in input
             src = cast(SequenceData, input)
         elif input_modality == Modality.SPEECH:
             audio = input
@@ -282,6 +288,7 @@ class Translator(nn.Module):
             text_generation_opts,
             unit_generation_opts,
             unit_generation_ngram_filtering=unit_generation_ngram_filtering,
+            gcmvn_fbank=gcmvn_fbank,
         )
 
         if output_modality == Modality.TEXT:

+ 16 - 0
src/seamless_communication/models/pretssel/__init__.py

@@ -0,0 +1,16 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from seamless_communication.models.pretssel.ecapa_tdnn import ECAPA_TDNN as ECAPA_TDNN
+from seamless_communication.models.pretssel.ecapa_tdnn_builder import (
+    EcapaTDNNBuilder as EcapaTDNNBuilder,
+)
+from seamless_communication.models.pretssel.ecapa_tdnn_builder import (
+    EcapaTDNNConfig as EcapaTDNNConfig,
+)
+from seamless_communication.models.pretssel.ecapa_tdnn_builder import (
+    ecapa_tdnn_archs as ecapa_tdnn_archs,
+)

+ 477 - 0
src/seamless_communication/models/pretssel/ecapa_tdnn.py

@@ -0,0 +1,477 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import List, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from fairseq2.nn.padding import PaddingMask, to_padding_mask
+from torch import Tensor
+from torch.nn import Conv1d, LayerNorm, Module, ModuleList, ReLU, Sigmoid, Tanh, init
+
+
+class ECAPA_TDNN(Module):
+    """
+    Represents the ECAPA-TDNN model described in paper:
+    :cite:t`https://doi.org/10.48550/arxiv.2005.07143`.
+
+    Arguments
+    ---------
+    :param channels:
+        Output channels for TDNN/SERes2Net layer.
+    :param kernel_sizes:
+        List of kernel sizes for each layer.
+    :param dilations:
+        List of dilations for kernels in each layer.
+    :param groups:
+        List of groups for kernels in each layer.
+    """
+
+    def __init__(
+        self,
+        channels: List[int],
+        kernel_sizes: List[int],
+        dilations: List[int],
+        attention_channels: int,
+        res2net_scale: int,
+        se_channels: int,
+        global_context: bool,
+        groups: List[int],
+        embed_dim: int,
+        input_dim: int,
+    ):
+        super().__init__()
+        assert len(channels) == len(kernel_sizes) == len(dilations)
+        self.channels = channels
+        self.embed_dim = embed_dim
+        self.blocks = ModuleList()
+
+        self.blocks.append(
+            TDNNBlock(
+                input_dim,
+                channels[0],
+                kernel_sizes[0],
+                dilations[0],
+                groups[0],
+            )
+        )
+
+        # SE-Res2Net layers
+        for i in range(1, len(channels) - 1):
+            self.blocks.append(
+                SERes2NetBlock(
+                    channels[i - 1],
+                    channels[i],
+                    res2net_scale=res2net_scale,
+                    se_channels=se_channels,
+                    kernel_size=kernel_sizes[i],
+                    dilation=dilations[i],
+                    groups=groups[i],
+                )
+            )
+
+        # Multi-layer feature aggregation
+        self.mfa = TDNNBlock(
+            channels[-1],
+            channels[-1],
+            kernel_sizes[-1],
+            dilations[-1],
+            groups=groups[-1],
+        )
+
+        # Attentive Statistical Pooling
+        self.asp = AttentiveStatisticsPooling(
+            channels[-1],
+            attention_channels=attention_channels,
+            global_context=global_context,
+        )
+        self.asp_norm = LayerNorm(channels[-1] * 2, eps=1e-12)
+
+        # Final linear transformation
+        self.fc = Conv1d(
+            in_channels=channels[-1] * 2,
+            out_channels=embed_dim,
+            kernel_size=1,
+        )
+
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        """Reset the parameters and buffers of the module."""
+
+        def encoder_init(m: Module) -> None:
+            if isinstance(m, Conv1d):
+                init.xavier_uniform_(m.weight, init.calculate_gain("relu"))
+
+        self.apply(encoder_init)
+
+    def forward(
+        self,
+        x: Tensor,
+        padding_mask: Optional[PaddingMask] = None,
+    ) -> Tensor:
+        """Returns the embedding vector.
+
+        Arguments
+        ---------
+        x : torch.Tensor
+            Tensor of shape (batch, time, channel).
+        """
+        # Minimize transpose for efficiency
+        x = x.transpose(1, 2)
+
+        xl = []
+        for layer in self.blocks:
+            x = layer(x, padding_mask=padding_mask)
+            xl.append(x)
+
+        # Multi-layer feature aggregation
+        x = torch.cat(xl[1:], dim=1)
+        x = self.mfa(x)
+
+        # Attentive Statistical Pooling
+        x = self.asp(x, padding_mask=padding_mask)
+        x = self.asp_norm(x.transpose(1, 2)).transpose(1, 2)
+
+        # Final linear transformation
+        x = self.fc(x)
+
+        x = x.transpose(1, 2).squeeze(1)  # B x C
+        return F.normalize(x, dim=-1)
+
+
+class TDNNBlock(Module):
+    """An implementation of TDNN.
+
+    Arguments
+    ----------
+    :param in_channels : int
+        Number of input channels.
+    :param out_channels : int
+        The number of output channels.
+    :param kernel_size : int
+        The kernel size of the TDNN blocks.
+    :param dilation : int
+        The dilation of the TDNN block.
+    :param groups: int
+        The groups size of the TDNN blocks.
+
+    Example
+    -------
+    >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
+    >>> layer = TDNNBlock(64, 64, kernel_size=3, dilation=1)
+    >>> out_tensor = layer(inp_tensor).transpose(1, 2)
+    >>> out_tensor.shape
+    torch.Size([8, 120, 64])
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        dilation: int,
+        groups: int = 1,
+    ):
+        super().__init__()
+        self.conv = Conv1d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            dilation=dilation,
+            padding=dilation * (kernel_size - 1) // 2,
+            groups=groups,
+        )
+        self.activation = ReLU()
+        self.norm = LayerNorm(out_channels, eps=1e-12)
+
+    def forward(self, x: Tensor, padding_mask: Optional[PaddingMask] = None) -> Tensor:
+        """Processes the input tensor x and returns an output tensor."""
+        x = self.activation(self.conv(x))
+
+        return self.norm(x.transpose(1, 2)).transpose(1, 2)  # type: ignore[no-any-return]
+
+
+class Res2NetBlock(Module):
+    """An implementation of Res2NetBlock w/ dilation.
+
+    Arguments
+    ---------
+    :param in_channels : int
+        The number of channels expected in the input.
+    :param out_channels : int
+        The number of output channels.
+    :param scale : int
+        The scale of the Res2Net block.
+    :param kernel_size: int
+        The kernel size of the Res2Net block.
+    :param dilation : int
+        The dilation of the Res2Net block.
+
+    Example
+    -------
+    >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
+    >>> layer = Res2NetBlock(64, 64, scale=4, dilation=3)
+    >>> out_tensor = layer(inp_tensor).transpose(1, 2)
+    >>> out_tensor.shape
+    torch.Size([8, 120, 64])
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        scale: int = 8,
+        kernel_size: int = 3,
+        dilation: int = 1,
+    ):
+        super().__init__()
+        assert in_channels % scale == 0
+        assert out_channels % scale == 0
+
+        in_channel = in_channels // scale
+        hidden_channel = out_channels // scale
+        self.blocks = ModuleList(
+            [
+                TDNNBlock(
+                    in_channel,
+                    hidden_channel,
+                    kernel_size=kernel_size,
+                    dilation=dilation,
+                )
+                for i in range(scale - 1)
+            ]
+        )
+        self.scale = scale
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Processes the input tensor x and returns an output tensor."""
+        y = []
+        for i, x_i in enumerate(torch.chunk(x, self.scale, dim=1)):
+            if i == 0:
+                y_i = x_i
+            elif i == 1:
+                y_i = self.blocks[i - 1](x_i)
+            else:
+                y_i = self.blocks[i - 1](x_i + y_i)
+            y.append(y_i)
+
+        y_tensor = torch.cat(y, dim=1)
+        return y_tensor
+
+
+class SEBlock(Module):
+    """An implementation of squeeze-and-excitation block.
+
+    Arguments
+    ---------
+    in_channels : int
+        The number of input channels.
+    se_channels : int
+        The number of output channels after squeeze.
+    out_channels : int
+        The number of output channels.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        se_channels: int,
+        out_channels: int,
+    ):
+        super().__init__()
+
+        self.conv1 = Conv1d(
+            in_channels=in_channels, out_channels=se_channels, kernel_size=1
+        )
+        self.relu = ReLU(inplace=True)
+        self.conv2 = Conv1d(
+            in_channels=se_channels, out_channels=out_channels, kernel_size=1
+        )
+        self.sigmoid = Sigmoid()
+
+    def forward(self, x: Tensor, padding_mask: Optional[PaddingMask] = None) -> Tensor:
+        """Processes the input tensor x and returns an output tensor."""
+        if padding_mask is not None:
+            mask = padding_mask.materialize().unsqueeze(1)
+            s = (x * mask).sum(dim=2, keepdim=True) / padding_mask.seq_lens[
+                :, None, None
+            ]
+        else:
+            s = x.mean(dim=2, keepdim=True)
+
+        s = self.relu(self.conv1(s))
+        s = self.sigmoid(self.conv2(s))
+
+        return s * x
+
+
+class AttentiveStatisticsPooling(Module):
+    """This class implements an attentive statistic pooling layer for each channel.
+    It returns the concatenated mean and std of the input tensor.
+
+    Arguments
+    ---------
+    channels: int
+        The number of input channels.
+    attention_channels: int
+        The number of attention channels.
+    """
+
+    def __init__(
+        self, channels: int, attention_channels: int = 128, global_context: bool = True
+    ):
+        super().__init__()
+
+        self.eps = 1e-12
+        self.global_context = global_context
+        if global_context:
+            self.tdnn = TDNNBlock(channels * 3, attention_channels, 1, 1)
+        else:
+            self.tdnn = TDNNBlock(channels, attention_channels, 1, 1)
+
+        self.tanh = Tanh()
+        self.conv = Conv1d(
+            in_channels=attention_channels, out_channels=channels, kernel_size=1
+        )
+
+    def forward(self, x: Tensor, padding_mask: Optional[PaddingMask] = None) -> Tensor:
+        """Calculates mean and std for a batch (input tensor).
+
+        Arguments
+        ---------
+        x : torch.Tensor
+            Tensor of shape [N, C, L].
+        """
+        L = x.shape[-1]
+
+        def _compute_statistics(
+            x: Tensor, m: Tensor, dim: int = 2, eps: float = self.eps
+        ) -> Tuple[Tensor, Tensor]:
+            mean = (m * x).sum(dim)
+            std = torch.sqrt((m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps))
+            return mean, std
+
+        # if lengths is None:
+        #     lengths = [x.shape[0]]
+
+        # Make binary mask of shape [N, 1, L]
+        # mask = to_padding_mask(lengths, max(lengths))
+        if padding_mask is not None:
+            mask = padding_mask.materialize()
+        else:
+            mask = to_padding_mask(torch.IntTensor([L]), L).repeat(x.shape[0], 1).to(x)
+        mask = mask.unsqueeze(1)
+
+        # Expand the temporal context of the pooling layer by allowing the
+        # self-attention to look at global properties of the utterance.
+        if self.global_context:
+            # torch.std is unstable for backward computation
+            # https://github.com/pytorch/pytorch/issues/4320
+            total = mask.sum(dim=2, keepdim=True).to(x)
+            mean, std = _compute_statistics(x, mask / total)
+            mean = mean.unsqueeze(2).repeat(1, 1, L)
+            std = std.unsqueeze(2).repeat(1, 1, L)
+            attn = torch.cat([x, mean, std], dim=1)
+        else:
+            attn = x
+
+        # Apply layers
+        attn = self.conv(self.tanh(self.tdnn(attn)))
+
+        # Filter out zero-paddings
+        attn = attn.masked_fill(mask == 0, float("-inf"))
+
+        attn = F.softmax(attn, dim=2)
+        mean, std = _compute_statistics(x, attn)
+        # Append mean and std of the batch
+        pooled_stats = torch.cat((mean, std), dim=1)
+        pooled_stats = pooled_stats.unsqueeze(2)
+
+        return pooled_stats
+
+
+class SERes2NetBlock(Module):
+    """An implementation of building block in ECAPA-TDNN, i.e.,
+    TDNN-Res2Net-TDNN-SEBlock.
+
+    Arguments
+    ----------
+    out_channels: int
+        The number of output channels.
+    res2net_scale: int
+        The scale of the Res2Net block.
+    kernel_size: int
+        The kernel size of the TDNN blocks.
+    dilation: int
+        The dilation of the Res2Net block.
+    groups: int
+    Number of blocked connections from input channels to output channels.
+
+    Example
+    -------
+    >>> x = torch.rand(8, 120, 64).transpose(1, 2)
+    >>> conv = SERes2NetBlock(64, 64, res2net_scale=4)
+    >>> out = conv(x).transpose(1, 2)
+    >>> out.shape
+    torch.Size([8, 120, 64])
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        res2net_scale: int = 8,
+        se_channels: int = 128,
+        kernel_size: int = 1,
+        dilation: int = 1,
+        groups: int = 1,
+    ):
+        super().__init__()
+        self.out_channels = out_channels
+        self.tdnn1 = TDNNBlock(
+            in_channels,
+            out_channels,
+            kernel_size=1,
+            dilation=1,
+            groups=groups,
+        )
+        self.res2net_block = Res2NetBlock(
+            out_channels,
+            out_channels,
+            res2net_scale,
+            kernel_size,
+            dilation,
+        )
+        self.tdnn2 = TDNNBlock(
+            out_channels,
+            out_channels,
+            kernel_size=1,
+            dilation=1,
+            groups=groups,
+        )
+        self.se_block = SEBlock(out_channels, se_channels, out_channels)
+
+        self.shortcut = None
+        if in_channels != out_channels:
+            self.shortcut = Conv1d(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=1,
+            )
+
+    def forward(self, x: Tensor, padding_mask: Optional[PaddingMask] = None) -> Tensor:
+        """Processes the input tensor x and returns an output tensor."""
+        residual = x
+        if self.shortcut:
+            residual = self.shortcut(x)
+
+        x = self.tdnn1(x)
+        x = self.res2net_block(x)
+        x = self.tdnn2(x)
+        x = self.se_block(x, padding_mask=padding_mask)
+
+        return x + residual

+ 112 - 0
src/seamless_communication/models/pretssel/ecapa_tdnn_builder.py

@@ -0,0 +1,112 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+from typing import List, Optional
+
+from fairseq2.models.utils.arch_registry import ArchitectureRegistry
+from fairseq2.typing import DataType, Device
+
+from seamless_communication.models.pretssel.ecapa_tdnn import ECAPA_TDNN
+
+
+@dataclass
+class EcapaTDNNConfig:
+    channels: List[int]
+    kernel_sizes: List[int]
+    dilations: List[int]
+    attention_channels: int
+    res2net_scale: int
+    se_channels: int
+    global_context: bool
+    groups: List[int]
+    embed_dim: int
+    input_dim: int
+
+
+ecapa_tdnn_archs = ArchitectureRegistry[EcapaTDNNConfig]("ecapa_tdnn")
+
+ecapa_tdnn_arch = ecapa_tdnn_archs.marker
+
+
+@ecapa_tdnn_arch("base")
+def _base_ecapa_tdnn() -> EcapaTDNNConfig:
+    return EcapaTDNNConfig(
+        channels=[512, 512, 512, 512, 1536],
+        kernel_sizes=[5, 3, 3, 3, 1],
+        dilations=[1, 2, 3, 4, 1],
+        attention_channels=128,
+        res2net_scale=8,
+        se_channels=128,
+        global_context=True,
+        groups=[1, 1, 1, 1, 1],
+        embed_dim=512,
+        input_dim=80,
+    )
+
+
+class EcapaTDNNBuilder:
+    """
+    Builder module for ECAPA_TDNN model
+    """
+
+    config: EcapaTDNNConfig
+    device: Optional[Device]
+    dtype: Optional[DataType]
+
+    def __init__(
+        self,
+        config: EcapaTDNNConfig,
+        *,
+        device: Optional[Device] = None,
+        dtype: Optional[DataType] = None,
+    ) -> None:
+        """
+        :param config:
+            The configuration to use.
+        :param devicev:
+            The device on which to initialize modules.
+        :param dtype:
+            The data type of module parameters and buffers.
+        """
+        self.config = config
+
+        self.device, self.dtype = device, dtype
+
+    def build_model(self) -> ECAPA_TDNN:
+        """Build a model."""
+        model = ECAPA_TDNN(
+            self.config.channels,
+            self.config.kernel_sizes,
+            self.config.dilations,
+            self.config.attention_channels,
+            self.config.res2net_scale,
+            self.config.se_channels,
+            self.config.global_context,
+            self.config.groups,
+            self.config.embed_dim,
+            self.config.input_dim,
+        )
+        model.to(device=self.device, dtype=self.dtype)
+        return model
+
+
+def create_ecapa_tdnn_model(
+    config: EcapaTDNNConfig,
+    device: Optional[Device] = None,
+    dtype: Optional[DataType] = None,
+) -> ECAPA_TDNN:
+    """Create a ECAPA_TDNN model.
+
+    :param config:
+        The configuration to use.
+    :param device:
+        The device on which to initialize modules.
+    :param dtype:
+        The data type of module parameters and buffers.
+    """
+
+    return EcapaTDNNBuilder(config, device=device, dtype=dtype).build_model()

+ 1 - 0
src/seamless_communication/models/unity/__init__.py

@@ -20,6 +20,7 @@ from seamless_communication.models.unity.char_tokenizer import (
 from seamless_communication.models.unity.char_tokenizer import (
     load_unity_char_tokenizer as load_unity_char_tokenizer,
 )
+from seamless_communication.models.unity.film import FiLM
 from seamless_communication.models.unity.length_regulator import (
     HardUpsampling as HardUpsampling,
 )

+ 96 - 4
src/seamless_communication/models/unity/builder.py

@@ -14,15 +14,23 @@ from fairseq2.models.w2vbert import w2vbert_archs
 from fairseq2.models.wav2vec2 import Wav2Vec2EncoderBuilder, Wav2Vec2EncoderConfig
 from fairseq2.nn.projection import TiedProjection
 from fairseq2.nn.transformer import (
+    FeedForwardNetwork,
     MultiheadAttention,
     StandardFeedForwardNetwork,
     StandardMultiheadAttention,
     TransformerEncoder,
     TransformerEncoderLayer,
+    TransformerNormOrder,
     create_default_sdpa,
 )
-from fairseq2.typing import DataType, Device
+from fairseq2.typing import DataType, Device, override
+from torch.nn import GELU, ReLU
 
+from seamless_communication.models.pretssel import (
+    EcapaTDNNBuilder,
+    EcapaTDNNConfig,
+    ecapa_tdnn_archs,
+)
 from seamless_communication.models.unity.adaptor_block import (
     UnitYConformerAdaptorLayer,
     UnitYEncoderAdaptor,
@@ -59,12 +67,19 @@ class UnitYConfig:
     t2u_config: Optional[UnitYT2UConfig]
     """The configuration of the UnitY T2U sub-model."""
 
+    prosody_encoder_config: Optional[EcapaTDNNConfig]
+    """The configuration of the expressive prosody encoder."""
+
     use_text_encoder: bool
     """If ``True``, uses an aligned MT encoder for the MT task."""
 
     use_conformer_adaptor: bool
     """If ``True``, uses a Conformer-based adaptor block."""
 
+    use_gelu: bool
+    """If ``True``, uses GELU activation function in feed-forward networks of
+    adaptor blocks and decoder layers."""
+
     num_adaptor_layers: int
     """The number of Transformer encoder layers in the adaptor block."""
 
@@ -103,8 +118,10 @@ def _base() -> UnitYConfig:
         w2v2_encoder_config=w2vbert_config.w2v2_config.encoder_config,
         mt_model_config=mt_model_config,
         t2u_config=t2u_config,
+        prosody_encoder_config=None,
         use_text_encoder=True,
         use_conformer_adaptor=False,
+        use_gelu=False,
         num_adaptor_layers=1,
         adaptor_kernel_size=8,
         adaptor_stride=8,
@@ -128,8 +145,10 @@ def _medium() -> UnitYConfig:
         w2v2_encoder_config=w2vbert_config.w2v2_config.encoder_config,
         mt_model_config=mt_model_config,
         t2u_config=t2u_config,
+        prosody_encoder_config=None,
         use_text_encoder=True,
         use_conformer_adaptor=False,
+        use_gelu=False,
         num_adaptor_layers=1,
         adaptor_kernel_size=8,
         adaptor_stride=8,
@@ -155,8 +174,43 @@ def _base_v2() -> UnitYConfig:
         w2v2_encoder_config=w2v2_chunk_encoder_config,
         mt_model_config=mt_model_config,
         t2u_config=t2u_config,
+        prosody_encoder_config=None,
         use_text_encoder=True,
         use_conformer_adaptor=False,
+        use_gelu=False,
+        num_adaptor_layers=1,
+        adaptor_kernel_size=8,
+        adaptor_stride=8,
+        adaptor_layer_norm=True,
+        adaptor_dropout_p=0.1,
+    )
+
+
+@unity_arch("expressivity_v2")
+def _expressivity_v2() -> UnitYConfig:
+    w2v2_chunk_encoder_config = wav2vec2_chunk_archs.get_config("600m")
+
+    mt_model_config: NllbConfig = nllb_archs.get_config("dense_1b")
+
+    mt_model_config.vocab_info.size = 256102  # NLLB-100
+
+    mt_model_config.vocab_info.pad_idx = 1
+
+    mt_model_config.max_seq_len = 4000
+
+    t2u_config = unity_t2u_archs.get_config("expressivity_nar")
+
+    prosody_encoder_config = ecapa_tdnn_archs.get_config("base")
+
+    return UnitYConfig(
+        model_dim=1024,
+        w2v2_encoder_config=w2v2_chunk_encoder_config,
+        mt_model_config=mt_model_config,
+        t2u_config=t2u_config,
+        prosody_encoder_config=prosody_encoder_config,
+        use_text_encoder=False,
+        use_conformer_adaptor=False,
+        use_gelu=True,
         num_adaptor_layers=1,
         adaptor_kernel_size=8,
         adaptor_stride=8,
@@ -176,6 +230,7 @@ class UnitYBuilder:
     w2v2_encoder_builder: Wav2Vec2EncoderBuilder
     mt_model_builder: NllbBuilder
     t2u_builder: Union[UnitYT2UBuilder, UnitYNART2UBuilder, None]
+    prosody_encoder_builder: Optional[EcapaTDNNBuilder]
     device: Optional[Device]
     dtype: Optional[DataType]
 
@@ -185,6 +240,7 @@ class UnitYBuilder:
         w2v2_encoder_builder: Wav2Vec2EncoderBuilder,
         mt_model_builder: NllbBuilder,
         t2u_builder: Union[UnitYT2UBuilder, UnitYNART2UBuilder, None],
+        prosody_encoder_builder: Optional[EcapaTDNNBuilder],
         *,
         device: Optional[Device] = None,
         dtype: Optional[DataType] = None,
@@ -223,6 +279,7 @@ class UnitYBuilder:
         self.w2v2_encoder_builder = w2v2_encoder_builder
         self.mt_model_builder = mt_model_builder
         self.t2u_builder = t2u_builder
+        self.prosody_encoder_builder = prosody_encoder_builder
 
         self.device, self.dtype = device, dtype
 
@@ -251,6 +308,11 @@ class UnitYBuilder:
         else:
             t2u_model = self.t2u_builder.build_model()
 
+        if self.prosody_encoder_builder is None:
+            prosody_encoder_model = None
+        else:
+            prosody_encoder_model = self.prosody_encoder_builder.build_model()
+
         return UnitYModel(
             speech_encoder_frontend,
             speech_encoder,
@@ -261,6 +323,7 @@ class UnitYBuilder:
             final_proj,
             t2u_model,
             self.config.mt_model_config.vocab_info,
+            prosody_encoder_model,
         )
 
     def build_speech_encoder(self) -> TransformerEncoder:
@@ -292,11 +355,10 @@ class UnitYBuilder:
             self.w2v2_encoder_builder.config.num_encoder_attn_heads
         )
 
-        # Unlike wav2vec2, we use ReLU (i.e. standard FFN activation function)
-        # instead of GELU.
         ffn = StandardFeedForwardNetwork(
             self.config.model_dim,
             self.w2v2_encoder_builder.config.ffn_inner_dim,
+            inner_activation=GELU() if self.config.use_gelu else ReLU(),
             bias=True,
             device=self.device,
             dtype=self.dtype,
@@ -365,6 +427,20 @@ class UnitYBuilder:
         )
 
 
+class NllbWithGELUBuilder(NllbBuilder):
+    @override
+    def build_ffn(self) -> FeedForwardNetwork:
+        return StandardFeedForwardNetwork(
+            self.config.model_dim,
+            self.config.ffn_inner_dim,
+            bias=True,
+            inner_activation=GELU(),
+            norm_order=TransformerNormOrder.PRE,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+
 def create_unity_model(
     config: UnitYConfig,
     device: Optional[Device] = None,
@@ -397,12 +473,28 @@ def create_unity_model(
     else:
         t2u_builder = UnitYNART2UBuilder(config.t2u_config, device=device, dtype=dtype)
 
-    mt_model_builder = NllbBuilder(config.mt_model_config, device=device, dtype=dtype)
+    if config.prosody_encoder_config is None:
+        prosody_encoder_builder = None
+    else:
+        prosody_encoder_builder = EcapaTDNNBuilder(
+            config.prosody_encoder_config, device=device, dtype=dtype
+        )
+
+    if config.use_gelu:
+        mt_model_builder: NllbBuilder = NllbWithGELUBuilder(
+            config.mt_model_config, device=device, dtype=dtype
+        )
+    else:
+        mt_model_builder = NllbBuilder(
+            config.mt_model_config, device=device, dtype=dtype
+        )
+
     unity_builder = UnitYBuilder(
         config,
         w2v2_encoder_builder,
         mt_model_builder,
         t2u_builder,
+        prosody_encoder_builder,
         device=device,
         dtype=dtype,
     )

+ 68 - 0
src/seamless_communication/models/unity/film.py

@@ -0,0 +1,68 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import Optional
+
+import torch
+from fairseq2.nn.projection import Linear
+from fairseq2.typing import DataType, Device
+from torch import Tensor
+from torch.nn import Module, Parameter
+
+
+class FiLM(Module):
+    """
+    A Feature-wise Linear Modulation Layer from
+    'FiLM: Visual Reasoning with a General Conditioning Layer'
+    """
+
+    proj: Linear
+    s_gamma: Parameter
+    s_beta: Parameter
+
+    def __init__(
+        self,
+        cond_dim: int,
+        embed_dim: int,
+        device: Optional[Device] = None,
+        dtype: Optional[DataType] = None,
+    ):
+        super().__init__()
+
+        self.proj = Linear(
+            cond_dim, 2 * embed_dim, bias=True, device=device, dtype=dtype
+        )
+
+        self.s_gamma = Parameter(
+            torch.ones(
+                1,
+                device=device,
+                dtype=dtype,
+            ),
+            requires_grad=True,
+        )
+
+        self.s_beta = Parameter(
+            torch.ones(
+                1,
+                device=device,
+                dtype=dtype,
+            ),
+            requires_grad=True,
+        )
+
+    def forward(self, x: Tensor, cond_embs: Tensor) -> Tensor:
+        """
+        x -- [B, T, H]
+        cond_emb -- [B, 1, C]
+        """
+        # get trainable gamma, beta
+        gammas, betas = self.proj(cond_embs).chunk(2, dim=-1)  # B x 1 x H
+
+        # apply film
+        gammas = self.s_gamma * gammas.expand_as(x)
+        betas = self.s_beta * betas.expand_as(x)
+
+        return (gammas + 1.0) * x + betas  # type: ignore[no-any-return]

+ 26 - 2
src/seamless_communication/models/unity/length_regulator.py

@@ -14,6 +14,8 @@ from fairseq2.typing import DataType, Device
 from torch import Tensor
 from torch.nn import Conv1d, Dropout, Module, ReLU, Sequential
 
+from seamless_communication.models.unity.film import FiLM
+
 
 class HardUpsampling(Module):
     """Upsamples sequences in a deterministic way as governed by durations."""
@@ -46,6 +48,7 @@ class VariancePredictor(Module):
     conv2: Sequential
     ln2: LayerNorm
     proj: Linear
+    film: Optional[FiLM]
 
     def __init__(
         self,
@@ -54,6 +57,8 @@ class VariancePredictor(Module):
         var_pred_kernel_size: int,
         var_pred_dropout: float,
         bias: bool = True,
+        use_film: bool = False,
+        film_cond_dim: int = 512,
         device: Optional[Device] = None,
         dtype: Optional[DataType] = None,
     ):
@@ -99,7 +104,19 @@ class VariancePredictor(Module):
             var_pred_hidden_dim, 1, bias=True, device=device, dtype=dtype
         )
 
-    def forward(self, seqs: Tensor, padding_mask: Optional[PaddingMask]) -> Tensor:
+        if use_film:
+            self.film = FiLM(
+                film_cond_dim, var_pred_hidden_dim, device=device, dtype=dtype
+            )
+        else:
+            self.register_module("film", None)
+
+    def forward(
+        self,
+        seqs: Tensor,
+        padding_mask: Optional[PaddingMask],
+        film_cond_emb: Optional[Tensor] = None,
+    ) -> Tensor:
         # Ensure that we do not leak padded positions in the convolution layer.
         seqs = apply_padding_mask(seqs, padding_mask)
 
@@ -131,6 +148,12 @@ class VariancePredictor(Module):
 
         seqs = self.dropout_module(seqs)
 
+        seqs = apply_padding_mask(seqs, padding_mask)
+
+        if self.film is not None and film_cond_emb is not None:
+            seqs = self.film(seqs, film_cond_emb)
+            seqs = apply_padding_mask(seqs, padding_mask)
+
         # (N, S, H) -> (N, S, 1) -> (N, S)
         seqs = self.proj(seqs).squeeze(dim=2)
 
@@ -174,8 +197,9 @@ class VarianceAdaptor(Module):
         padding_mask: Optional[PaddingMask],
         duration_factor: float = 1.0,
         min_duration: int = 0,
+        film_cond_emb: Optional[Tensor] = None,
     ) -> Tuple[Tensor, PaddingMask]:
-        log_durations = self.duration_predictor(seqs, padding_mask)
+        log_durations = self.duration_predictor(seqs, padding_mask, film_cond_emb)
 
         durations = torch.clamp(
             torch.round((torch.exp(log_durations) - 1) * duration_factor).long(),

+ 79 - 35
src/seamless_communication/models/unity/loader.py

@@ -47,10 +47,16 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
 
         keys_to_delete = []
 
+        # ExpressiveUnitY model (from multi_arch codebase)
+        if config.prosody_encoder_config is not None:
+            encoder_key = "s2t_model.encoder"
+            decoder_key = "s2t_model.decoder"
+            t2u_decoder_key = "t2s_model.decoder"
         # X2T/S2T + T2U model.
-        if config.t2u_config is not None:
+        elif config.t2u_config is not None:
             encoder_key = "encoder"
             decoder_key = "target_letter_decoder"
+            t2u_decoder_key = "decoder"
         # X2T model.
         elif config.use_text_encoder:
             encoder_key = "speech_encoder"
@@ -70,12 +76,18 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
         # Remnant of wav2vec2 pretraining, not needed for eval or fine-tuning.
         keys_to_delete.append(f"{encoder_key}.w2v_encoder.w2v_model.mask_emb")
 
-        keys_to_delete.append("decoder.char_upsampler.embed_positions._float_tensor")
-        keys_to_delete.append("decoder.char_upsampler.embed_tokens_char.weight")
+        keys_to_delete.append(
+            f"{t2u_decoder_key}.char_upsampler.embed_positions._float_tensor"
+        )
+        keys_to_delete.append(
+            f"{t2u_decoder_key}.char_upsampler.embed_tokens_char.weight"
+        )
 
         # Delete AlignmentEncoder keys for inference.
         alignment_encoder_keys = [
-            key for key in state_dict if key.startswith("decoder.alignment_encoder.")
+            key
+            for key in state_dict
+            if key.startswith(f"{t2u_decoder_key}.alignment_encoder.")
         ]
         keys_to_delete.extend(alignment_encoder_keys)
 
@@ -87,6 +99,17 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
             ]
         )
 
+        if config.prosody_encoder_config is not None:
+            keys_to_delete.extend(
+                [
+                    f"{t2u_decoder_key}.embed_positions._float_tensor",
+                    "t2s_model.global_proj_dec.weight",
+                    "t2s_model.global_proj_dec.bias",
+                    "t2s_model.decoder_target_letter_nllb_spm_decoder.encoder.proj.weight",
+                    "t2s_model.decoder_target_letter_nllb_spm_decoder.encoder.proj.bias",
+                ]
+            )
+
         for key in keys_to_delete:
             if key in state_dict:
                 del state_dict[key]
@@ -157,10 +180,19 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
 
     @staticmethod
     def _fairseq_key_map(config: UnitYConfig) -> Dict[str, str]:
+        # ExpressiveUnitY model (from multi_arch codebase)
+        if config.prosody_encoder_config is not None:
+            encoder_key = "s2t_model.encoder"
+            decoder_key = "s2t_model.decoder"
+            t2u_encoder_key = "t2s_model.encoder"
+            t2u_decoder_key = "t2s_model.decoder"
+            ecapa_tdnn_key = "global_prosody"
         # X2T/S2T + T2U model.
-        if config.t2u_config is not None:
+        elif config.t2u_config is not None:
             encoder_key = "encoder"
             decoder_key = "target_letter_decoder"
+            t2u_encoder_key = "synthesizer_encoder"
+            t2u_decoder_key = "decoder"
         # X2T model.
         elif config.use_text_encoder:
             encoder_key = "speech_encoder"
@@ -231,8 +263,8 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
         # fairseq was accidentally run with a pre-LN encoder, and ended up with
         # a redundant `LayerNorm` right after the Conformer blocks. We mitigate
         # that issue here by moving that `LayerNorm` to the adaptor block.
+        # fmt: off
         if config.w2v2_encoder_config.use_conformer:
-            # fmt: off
             key_map.update(
                 {
                     fr"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layer_norm\.": r"speech_encoder.inner_layer_norm."
@@ -244,7 +276,7 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
                     rf"^{encoder_key}\.w2v_encoder\.w2v_model\.encoder\.layer_norm\.": r"speech_encoder.inner.layer_norm."
                 }
             )
-            # fmt: on
+        # fmt: on
 
         if config.use_conformer_adaptor:
             key_map.update(
@@ -303,44 +335,56 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
                 # fmt: on
             }
         )
+        # ExpressiveUnitY model (from multi_arch codebase)
+        if config.prosody_encoder_config is not None:
+            key_map.update(
+                {
+                    # fmt: off
+                    fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.film\.":              r"t2u_model.decoder.layers.\1.film.",
+                    fr"^{ecapa_tdnn_key}\.":                                       r"prosody_encoder_model.",
+                    r"^t2s_model\.global_proj_enc\.":                             r"t2u_model.prosody_proj.",
+                    # fmt: on
+                }
+            )
+
         # X2T/S2T + T2U model.
         if config.t2u_config is not None:
             key_map.update(
                 {
                     # fmt: off
                     # T2U Encoder
-                    r"^synthesizer_encoder\.layers\.([0-9]+)\.self_attn\.out_proj\.":     r"t2u_model.encoder.layers.\1.self_attn.output_proj.",
-                    r"^synthesizer_encoder\.layers\.([0-9]+)\.self_attn\.":               r"t2u_model.encoder.layers.\1.self_attn.",
-                    r"^synthesizer_encoder\.layers\.([0-9]+)\.self_attn_layer_norm\.":    r"t2u_model.encoder.layers.\1.self_attn_layer_norm.",
-                    r"^synthesizer_encoder\.layers\.([0-9]+)\.fc1\.":                     r"t2u_model.encoder.layers.\1.ffn.inner_proj.",
-                    r"^synthesizer_encoder\.layers\.([0-9]+)\.fc2\.":                     r"t2u_model.encoder.layers.\1.ffn.output_proj.",
-                    r"^synthesizer_encoder\.layers\.([0-9]+)\.final_layer_norm\.":        r"t2u_model.encoder.layers.\1.ffn_layer_norm.",
-                    r"^synthesizer_encoder\.layer_norm\.":                                r"t2u_model.encoder.layer_norm.",
+                    fr"^{t2u_encoder_key}\.layers\.([0-9]+)\.self_attn\.out_proj\.":     r"t2u_model.encoder.layers.\1.self_attn.output_proj.",
+                    fr"^{t2u_encoder_key}\.layers\.([0-9]+)\.self_attn\.":               r"t2u_model.encoder.layers.\1.self_attn.",
+                    fr"^{t2u_encoder_key}\.layers\.([0-9]+)\.self_attn_layer_norm\.":    r"t2u_model.encoder.layers.\1.self_attn_layer_norm.",
+                    fr"^{t2u_encoder_key}\.layers\.([0-9]+)\.fc1\.":                     r"t2u_model.encoder.layers.\1.ffn.inner_proj.",
+                    fr"^{t2u_encoder_key}\.layers\.([0-9]+)\.fc2\.":                     r"t2u_model.encoder.layers.\1.ffn.output_proj.",
+                    fr"^{t2u_encoder_key}\.layers\.([0-9]+)\.final_layer_norm\.":        r"t2u_model.encoder.layers.\1.ffn_layer_norm.",
+                    fr"^{t2u_encoder_key}\.layer_norm\.":                                r"t2u_model.encoder.layer_norm.",
 
                     # T2U Decoder frontend
-                    r"^decoder\.embed_tokens_text\.":                           r"t2u_model.decoder_frontend.embed_char.",
-                    r"^decoder\.embed_tokens_unit\.":                           r"t2u_model.decoder_frontend.embed.",
-                    r"^decoder\.embed_tokens\.":                                r"t2u_model.decoder_frontend.embed.",
-                    r"^decoder\.var_adaptor\.duration_predictor\.":             r"t2u_model.decoder_frontend.variance_adaptor.duration_predictor.",
-                    r"^decoder\.dec_pos_emb_alpha":                             r"t2u_model.decoder_frontend.pos_emb_alpha",
-                    r"^decoder\.char_upsampler\.pos_emb_alpha":                 r"t2u_model.decoder_frontend.pos_emb_alpha_char",
+                    fr"^{t2u_decoder_key}\.embed_tokens_text\.":                           r"t2u_model.decoder_frontend.embed_char.",
+                    fr"^{t2u_decoder_key}\.embed_tokens_unit\.":                           r"t2u_model.decoder_frontend.embed.",
+                    fr"^{t2u_decoder_key}\.embed_tokens\.":                                r"t2u_model.decoder_frontend.embed.",
+                    fr"^{t2u_decoder_key}\.var_adaptor\.duration_predictor\.":             r"t2u_model.decoder_frontend.variance_adaptor.duration_predictor.",
+                    fr"^{t2u_decoder_key}\.dec_pos_emb_alpha":                             r"t2u_model.decoder_frontend.pos_emb_alpha",
+                    fr"^{t2u_decoder_key}\.char_upsampler\.pos_emb_alpha":                 r"t2u_model.decoder_frontend.pos_emb_alpha_char",
 
                     # T2U Decoder
-                    r"^decoder\.layers\.([0-9]+)\.self_attn\.out_proj\.":     r"t2u_model.decoder.layers.\1.self_attn.output_proj.",
-                    r"^decoder\.layers\.([0-9]+)\.self_attn\.":               r"t2u_model.decoder.layers.\1.self_attn.",
-                    r"^decoder\.layers\.([0-9]+)\.self_attn_layer_norm\.":    r"t2u_model.decoder.layers.\1.self_attn_layer_norm.",
-                    r"^decoder\.layers\.([0-9]+)\.layer_norm\.":              r"t2u_model.decoder.layers.\1.self_attn_layer_norm.",
-                    r"^decoder\.layers\.([0-9]+)\.encoder_attn\.out_proj\.":  r"t2u_model.decoder.layers.\1.encoder_decoder_attn.output_proj.",
-                    r"^decoder\.layers\.([0-9]+)\.encoder_attn\.":            r"t2u_model.decoder.layers.\1.encoder_decoder_attn.",
-                    r"^decoder\.layers\.([0-9]+)\.encoder_attn_layer_norm\.": r"t2u_model.decoder.layers.\1.encoder_decoder_attn_layer_norm.",
-                    r"^decoder\.layers\.([0-9]+)\.fc1\.":                     r"t2u_model.decoder.layers.\1.ffn.inner_proj.",
-                    r"^decoder\.layers\.([0-9]+)\.fc2\.":                     r"t2u_model.decoder.layers.\1.ffn.output_proj.",
-                    r"^decoder\.layers\.([0-9]+)\.final_layer_norm\.":        r"t2u_model.decoder.layers.\1.ffn_layer_norm.",
-                    r"^decoder\.layers\.([0-9]+)\.ffn\.ffn\.0\.":             r"t2u_model.decoder.layers.\1.conv1d.conv1.",
-                    r"^decoder\.layers\.([0-9]+)\.ffn\.ffn\.2\.":             r"t2u_model.decoder.layers.\1.conv1d.conv2.",
-                    r"^decoder\.layers\.([0-9]+)\.ffn\.layer_norm\.":         r"t2u_model.decoder.layers.\1.conv1d_layer_norm.",
-                    r"^decoder\.layer_norm\.":                                r"t2u_model.decoder.layer_norm.",
-                    r"^decoder\.output_projection\.":                         r"t2u_model.final_proj.",
+                    fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.self_attn\.out_proj\.":     r"t2u_model.decoder.layers.\1.self_attn.output_proj.",
+                    fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.self_attn\.":               r"t2u_model.decoder.layers.\1.self_attn.",
+                    fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.self_attn_layer_norm\.":    r"t2u_model.decoder.layers.\1.self_attn_layer_norm.",
+                    fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.layer_norm\.":              r"t2u_model.decoder.layers.\1.self_attn_layer_norm.",
+                    fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.encoder_attn\.out_proj\.":  r"t2u_model.decoder.layers.\1.encoder_decoder_attn.output_proj.",
+                    fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.encoder_attn\.":            r"t2u_model.decoder.layers.\1.encoder_decoder_attn.",
+                    fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.encoder_attn_layer_norm\.": r"t2u_model.decoder.layers.\1.encoder_decoder_attn_layer_norm.",
+                    fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.fc1\.":                     r"t2u_model.decoder.layers.\1.ffn.inner_proj.",
+                    fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.fc2\.":                     r"t2u_model.decoder.layers.\1.ffn.output_proj.",
+                    fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.final_layer_norm\.":        r"t2u_model.decoder.layers.\1.ffn_layer_norm.",
+                    fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.ffn\.ffn\.0\.":             r"t2u_model.decoder.layers.\1.conv1d.conv1.",
+                    fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.ffn\.ffn\.2\.":             r"t2u_model.decoder.layers.\1.conv1d.conv2.",
+                    fr"^{t2u_decoder_key}\.layers\.([0-9]+)\.ffn\.layer_norm\.":         r"t2u_model.decoder.layers.\1.conv1d_layer_norm.",
+                    fr"^{t2u_decoder_key}\.layer_norm\.":                                r"t2u_model.decoder.layer_norm.",
+                    fr"^{t2u_decoder_key}\.output_projection\.":                         r"t2u_model.final_proj.",
                     # fmt: on
                 }
             )

+ 19 - 2
src/seamless_communication/models/unity/model.py

@@ -19,6 +19,7 @@ from overrides import final as finaloverride
 from torch import Tensor
 from torch.nn import Module
 
+from seamless_communication.models.pretssel.ecapa_tdnn import ECAPA_TDNN
 from seamless_communication.models.unity.nar_decoder import NARTransformerDecoder
 from seamless_communication.models.unity.nar_decoder_frontend import NARDecoderFrontend
 
@@ -42,6 +43,7 @@ class UnitYModel(EncoderDecoderModel):
     text_decoder: TransformerDecoder
     final_proj: Projection
     t2u_model: Union["UnitYT2UModel", "UnitYNART2UModel", None]
+    prosody_encoder_model: Optional[ECAPA_TDNN]
 
     def __init__(
         self,
@@ -54,6 +56,7 @@ class UnitYModel(EncoderDecoderModel):
         final_proj: Projection,
         t2u_model: Union["UnitYT2UModel", "UnitYNART2UModel", None],
         target_vocab_info: VocabularyInfo,
+        prosody_encoder_model: Optional[ECAPA_TDNN] = None,
         input_modality: str = "speech",
     ) -> None:
         model_dim = speech_encoder.model_dim
@@ -93,6 +96,10 @@ class UnitYModel(EncoderDecoderModel):
             self.register_module("t2u_model", None)
 
         self.target_vocab_info = target_vocab_info
+        if prosody_encoder_model is not None:
+            self.prosody_encoder_model = prosody_encoder_model
+        else:
+            self.register_module("prosody_encoder_model", None)
 
     @finaloverride
     def encode(
@@ -304,6 +311,7 @@ class UnitYNART2UModel(Module):
     decoder: NARTransformerDecoder
     final_proj: Projection
     target_vocab_info: VocabularyInfo
+    prosody_proj: Optional[Projection]
 
     def __init__(
         self,
@@ -312,6 +320,7 @@ class UnitYNART2UModel(Module):
         decoder: NARTransformerDecoder,
         final_proj: Projection,
         target_vocab_info: VocabularyInfo,
+        prosody_proj: Optional[Projection] = None,
     ) -> None:
         super().__init__()
 
@@ -339,20 +348,27 @@ class UnitYNART2UModel(Module):
 
         self.target_vocab_info = target_vocab_info
 
+        self.prosody_proj = prosody_proj
+
     def forward(
         self,
         text_decoder_output: Tensor,
         text_decoder_padding_mask: Optional[PaddingMask],
         text_seqs: Optional[Tensor],
+        film_cond_emb: Optional[Tensor] = None,
     ) -> Tuple[SequenceModelOutput, Optional[PaddingMask]]:
         encoder_output, encoder_padding_mask = self.encode(
             text_decoder_output, text_decoder_padding_mask
         )
 
+        if self.prosody_proj is not None and film_cond_emb is not None:
+            encoder_output = encoder_output + self.prosody_proj(film_cond_emb)
+
         decoder_output, decoder_padding_mask = self.decode(
             encoder_output,
             encoder_padding_mask,
             text_seqs,
+            film_cond_emb,
         )
 
         return self.project(decoder_output), decoder_padding_mask
@@ -372,14 +388,15 @@ class UnitYNART2UModel(Module):
         encoder_output: Tensor,
         encoder_padding_mask: Optional[PaddingMask],
         text_seqs: Optional[Tensor],
+        film_cond_emb: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Optional[PaddingMask]]:
         # encoder_output: (N, S, M)
         # text_seqs: (N, S)
         seqs, padding_mask = self.decoder_frontend(
-            encoder_output, encoder_padding_mask, text_seqs
+            encoder_output, encoder_padding_mask, text_seqs, film_cond_emb
         )
 
-        return self.decoder(seqs, padding_mask)  # type: ignore[no-any-return]
+        return self.decoder(seqs, padding_mask, film_cond_emb=film_cond_emb)  # type: ignore[no-any-return]
 
     def project(self, decoder_output: Tensor) -> SequenceModelOutput:
         logits = self.final_proj(decoder_output)

+ 2 - 1
src/seamless_communication/models/unity/nar_decoder.py

@@ -66,9 +66,10 @@ class NARTransformerDecoder(Module):
         self,
         seqs: Tensor,
         padding_mask: Optional[PaddingMask],
+        film_cond_emb: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Optional[PaddingMask]]:
         for layer in self.layers.drop_iter():
-            seqs, padding_mask = layer(seqs, padding_mask)
+            seqs, padding_mask = layer(seqs, padding_mask, film_cond_emb=film_cond_emb)
 
         if self.layer_norm is not None:
             seqs = self.layer_norm(seqs)

+ 2 - 0
src/seamless_communication/models/unity/nar_decoder_frontend.py

@@ -302,6 +302,7 @@ class NARDecoderFrontend(Module):
         encoder_output: Tensor,
         encoder_padding_mask: Optional[PaddingMask],
         text_seqs: Optional[Tensor],
+        film_cond_emb: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Optional[PaddingMask]]:
         assert text_seqs is not None
 
@@ -323,6 +324,7 @@ class NARDecoderFrontend(Module):
             seqs,
             encoder_padding_mask,
             min_duration=1,
+            film_cond_emb=film_cond_emb,
         )
 
         seqs = self.forward_unit_pos_embedding(seqs, padding_mask)

+ 19 - 0
src/seamless_communication/models/unity/nar_decoder_layer.py

@@ -13,6 +13,8 @@ from fairseq2.typing import DataType, Device, finaloverride
 from torch import Tensor
 from torch.nn import Conv1d, Dropout, Module, ReLU
 
+from seamless_communication.models.unity.film import FiLM
+
 
 @final
 class Conv1dBlock(Module):
@@ -111,6 +113,7 @@ class NARTransformerDecoderLayer(Module):
     conv1d: Conv1dBlock
     conv1d_dropout: Optional[Dropout]
     conv1d_layer_norm: LayerNorm
+    film: Optional[FiLM]
 
     def __init__(
         self,
@@ -118,6 +121,8 @@ class NARTransformerDecoderLayer(Module):
         conv1d: Conv1dBlock,
         dropout_p: float = 0.1,
         conv1d_dropout_p: float = 0.1,
+        use_film: bool = False,
+        film_cond_dim: int = 512,
         device: Optional[Device] = None,
         dtype: Optional[DataType] = None,
     ) -> None:
@@ -130,6 +135,10 @@ class NARTransformerDecoderLayer(Module):
             The dropout probability on the outputs of the self attention layer.
         :param conv1d_dropout_p:
             The dropout probability on the outputs of the conv1d block.
+        :param use_film:
+            Whether to condition on a fixed-size vector through FiLM.
+        :param film_cond_dim:
+            The dim of fixed-size vector conditioned on during model forward.
         """
         super().__init__()
 
@@ -159,16 +168,26 @@ class NARTransformerDecoderLayer(Module):
             self.model_dim, device=device, dtype=dtype
         )
 
+        if use_film:
+            self.film = FiLM(film_cond_dim, self.model_dim, device=device, dtype=dtype)
+        else:
+            self.register_module("film", None)
+
     @finaloverride
     def forward(
         self,
         seqs: Tensor,
         padding_mask: Optional[PaddingMask],
+        film_cond_emb: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Optional[PaddingMask]]:
         seqs = self._forward_self_attn(seqs, padding_mask)
 
         seqs = self._forward_conv1d(seqs, padding_mask)
 
+        if self.film is not None and film_cond_emb is not None:
+            seqs = self.film(seqs, film_cond_emb)
+            seqs = apply_padding_mask(seqs, padding_mask)
+
         return seqs, padding_mask
 
     def _forward_self_attn(

+ 106 - 7
src/seamless_communication/models/unity/t2u_builder.py

@@ -17,7 +17,7 @@ from fairseq2.models.transformer import (
 from fairseq2.models.utils.arch_registry import ArchitectureRegistry
 from fairseq2.nn.embedding import Embedding, StandardEmbedding, init_scaled_embedding
 from fairseq2.nn.position_encoder import SinusoidalPositionEncoder
-from fairseq2.nn.projection import TiedProjection
+from fairseq2.nn.projection import Linear, Projection, TiedProjection
 from fairseq2.nn.transformer import (
     FeedForwardNetwork,
     MultiheadAttention,
@@ -35,6 +35,7 @@ from fairseq2.nn.transformer import (
     create_default_sdpa,
 )
 from fairseq2.typing import DataType, Device
+from torch.nn import GELU, ReLU
 
 from seamless_communication.models.unity.char_tokenizer import load_unity_char_tokenizer
 from seamless_communication.models.unity.length_regulator import (
@@ -55,6 +56,8 @@ class VariancePredictorConfig:
     var_pred_hidden_dim: int
     var_pred_kernel_size: int
     var_pred_dropout: float
+    use_film: bool
+    film_cond_dim: int
 
 
 @dataclass
@@ -73,6 +76,8 @@ class NARDecoderConfig:
     conv1d_kernel_size: int
     conv1d_inner_dim: int
     conv1d_dropout_p: float
+    use_film: bool
+    film_cond_dim: int
 
 
 @dataclass
@@ -113,9 +118,17 @@ class UnitYT2UConfig:
     dropout_p: float
     """The dropout probability in Transformer layers."""
 
-    def update_unit_vocabulary(self, info: VocabularyInfo) -> None:
-        """Update unit vocabulary configuration from ``info``."""
-        self.unit_vocabulary_size, self.unit_pad_idx = info.size, info.pad_idx
+    use_gelu: bool
+    """If ``True``, uses GELU activation function in feed-forward networks."""
+
+    char_pad_idx: int
+    """The index of the pad symbol in the char vocabulary."""
+
+    use_prosody_proj: bool
+    """If ``True``, uses a prosody projection layer."""
+
+    prosody_encoder_dim: int
+    """The dimensionality of prosody encoder (e.g. ECAPA_TDNN) output"""
 
 
 unity_t2u_archs = ArchitectureRegistry[UnitYT2UConfig]("unity_t2u")
@@ -140,6 +153,10 @@ def _base_t2u() -> UnitYT2UConfig:
         num_decoder_attn_heads=16,
         ffn_inner_dim=1024 * 8,
         dropout_p=0.1,
+        use_gelu=False,
+        char_pad_idx=0,
+        use_prosody_proj=False,
+        prosody_encoder_dim=0,
     )
 
 
@@ -159,6 +176,10 @@ def _medium_t2u() -> UnitYT2UConfig:
         num_decoder_attn_heads=16,
         ffn_inner_dim=1024 * 8,
         dropout_p=0.1,
+        use_gelu=False,
+        char_pad_idx=0,
+        use_prosody_proj=False,
+        prosody_encoder_dim=0,
     )
 
 
@@ -168,6 +189,8 @@ def _base_nar() -> UnitYT2UConfig:
         var_pred_hidden_dim=256,
         var_pred_kernel_size=3,
         var_pred_dropout=0.5,
+        use_film=False,
+        film_cond_dim=0,
     )
 
     nar_decoder_frontend_config = NARDecoderFrontendConfig(
@@ -184,6 +207,8 @@ def _base_nar() -> UnitYT2UConfig:
         conv1d_kernel_size=7,
         conv1d_inner_dim=1024,
         conv1d_dropout_p=0.1,
+        use_film=False,
+        film_cond_dim=0,
     )
 
     return UnitYT2UConfig(
@@ -200,6 +225,59 @@ def _base_nar() -> UnitYT2UConfig:
         num_decoder_attn_heads=16,
         ffn_inner_dim=1024 * 8,
         dropout_p=0.0,
+        use_gelu=False,
+        char_pad_idx=0,
+        use_prosody_proj=False,
+        prosody_encoder_dim=0,
+    )
+
+
+@unity_t2u_arch("expressivity_nar")
+def _expressivity_nar() -> UnitYT2UConfig:
+    duration_predictor_config = VariancePredictorConfig(
+        var_pred_hidden_dim=256,
+        var_pred_kernel_size=3,
+        var_pred_dropout=0.5,
+        use_film=True,
+        film_cond_dim=512,
+    )
+
+    nar_decoder_frontend_config = NARDecoderFrontendConfig(
+        subword_to_unit_upsampling_type="hard",
+        duration_predictor_config=duration_predictor_config,
+        pitch_predictor_config=None,
+        energy_predictor_config=None,
+    )
+
+    nar_decoder_config = NARDecoderConfig(
+        model_name_or_card="seamless_expressivity",
+        char_vocabulary_size=10904,
+        char_max_seq_len=4000,
+        conv1d_kernel_size=7,
+        conv1d_inner_dim=1024,
+        conv1d_dropout_p=0.1,
+        use_film=True,
+        film_cond_dim=512,
+    )
+
+    return UnitYT2UConfig(
+        model_dim=1024,
+        unit_max_seq_len=4000,
+        target_vocab_info=VocabularyInfo(
+            size=10005, unk_idx=3, bos_idx=0, eos_idx=2, pad_idx=1
+        ),
+        num_encoder_layers=4,
+        num_decoder_layers=4,
+        nar_decoder_frontend_config=nar_decoder_frontend_config,
+        nar_decoder_config=nar_decoder_config,
+        num_encoder_attn_heads=16,
+        num_decoder_attn_heads=16,
+        ffn_inner_dim=1024 * 8,
+        dropout_p=0.0,
+        use_gelu=True,
+        char_pad_idx=1,
+        use_prosody_proj=True,
+        prosody_encoder_dim=512,
     )
 
 
@@ -417,12 +495,15 @@ class UnitYNART2UBuilder:
 
         decoder_frontend = self.build_decoder_frontend(embed_unit)
 
+        prosody_proj = self.build_prosody_proj()
+
         return UnitYNART2UModel(
             encoder,
             decoder_frontend,
             decoder,
             final_proj,
             self.config.target_vocab_info,
+            prosody_proj=prosody_proj,
         )
 
     def build_unit_embedding(self) -> StandardEmbedding:
@@ -482,6 +563,8 @@ class UnitYNART2UBuilder:
             duration_predictor_config.var_pred_hidden_dim,
             duration_predictor_config.var_pred_kernel_size,
             duration_predictor_config.var_pred_dropout,
+            use_film=duration_predictor_config.use_film,
+            film_cond_dim=duration_predictor_config.film_cond_dim,
             device=self.device,
             dtype=self.dtype,
         )
@@ -518,19 +601,18 @@ class UnitYNART2UBuilder:
         nllb_tokenizer = NllbTokenizerLoader(asset_store, download_manager)(
             self.config.nar_decoder_config.model_name_or_card
         )
-        text_pad_idx = nllb_tokenizer.vocab_info.pad_idx
 
         char_pos_encoder = SinusoidalPositionEncoder(
             self.config.model_dim,
             self.config.nar_decoder_config.char_max_seq_len,
-            _legacy_pad_idx=text_pad_idx,
+            _legacy_pad_idx=self.config.char_pad_idx,
             device=self.device,
         )
 
         embed_char = StandardEmbedding(
             num_embeddings=self.config.nar_decoder_config.char_vocabulary_size,
             embedding_dim=self.config.model_dim,
-            pad_idx=text_pad_idx,
+            pad_idx=self.config.char_pad_idx,
             init_fn=init_scaled_embedding,
             device=self.device,
             dtype=self.dtype,
@@ -584,6 +666,8 @@ class UnitYNART2UBuilder:
             conv1d,
             dropout_p=self.config.dropout_p,
             conv1d_dropout_p=self.config.nar_decoder_config.conv1d_dropout_p,
+            use_film=self.config.nar_decoder_config.use_film,
+            film_cond_dim=self.config.nar_decoder_config.film_cond_dim,
             device=self.device,
             dtype=self.dtype,
         )
@@ -608,11 +692,26 @@ class UnitYNART2UBuilder:
             self.config.model_dim,
             self.config.ffn_inner_dim,
             bias=True,
+            inner_activation=GELU() if self.config.use_gelu else ReLU(),
             norm_order=TransformerNormOrder.PRE,
             device=self.device,
             dtype=self.dtype,
         )
 
+    def build_prosody_proj(self) -> Optional[Projection]:
+        """Build a prosody projection layer if needed"""
+
+        if self.config.use_prosody_proj:
+            return Linear(
+                self.config.prosody_encoder_dim,
+                self.config.model_dim,
+                bias=True,
+                dtype=self.dtype,
+                device=self.device,
+            )
+        else:
+            return None
+
 
 def create_unity_t2u_model(
     config: UnitYT2UConfig,