
UnitY2 aligner for release (#112)

* initial aligner draft code

* more for aligner

* aligner working

* aligner changes

* test added

* formatting

* Fix fairseq2 API changes

* adding fixture

* remove unnecessary

* fixes

* more changes

* device configurable

* dtype fp16 works now

* add sampling rate check in unit extractor module

* formatting

* naming consistency

* Update src/seamless_communication/models/aligner/model.py

Co-authored-by: David Dale <dale.david@mail.ru>

* formatting

* using fairseq2 pad fn

* fixing the test

* mypy + formatting

* support using aligner without unit extractor

* mypy numpy

* addressing comments

* return string tokens as well

---------

Co-authored-by: Can Balioglu <cbalioglu@users.noreply.github.com>
Co-authored-by: David Dale <dale.david@mail.ru>
Ilia Kulikov 1 year ago
parent
commit
78d6dac3a9

+ 51 - 0
src/seamless_communication/cards/nar_t2u_aligner.yaml

@@ -0,0 +1,51 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+name: nar_t2u_aligner
+char_tokenizer: "file:///private/home/hirofumii/large_experiments/datasets/m4t/t2u_v2/spm_char_lang38_tc.model"
+model_type: unity2_aligner
+model_arch: nar_t2u_aligner
+checkpoint: "file:///checkpoint/kulikov/nar_t2u_m4tv2_aligner.pt"
+num_units: 10000
+unit_langs:
+  - arb
+  - ben
+  - cat
+  - ces
+  - cmn
+  - cym
+  - dan
+  - deu
+  - eng
+  - est
+  - fin
+  - fra
+  - hin
+  - ind
+  - ita
+  - jpn
+  - kan
+  - kor
+  - mlt
+  - nld
+  - pes
+  - pol
+  - por
+  - ron
+  - rus
+  - slk
+  - spa
+  - swe
+  - swh
+  - tam
+  - tel
+  - tgl
+  - tha
+  - tur
+  - ukr
+  - urd
+  - uzn
+  - vie

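A minimal loading sketch for the card above, using the load_unity2_alignment_model loader added later in this PR; the CUDA device and fp16 dtype are illustrative assumptions, and the checkpoint/tokenizer URIs in the card must be reachable in your environment.

import torch
from fairseq2.typing import Device

from seamless_communication.models.aligner.loader import load_unity2_alignment_model

# Resolve the "nar_t2u_aligner" asset card and load the aligner weights.
aligner = load_unity2_alignment_model(
    "nar_t2u_aligner", device=Device("cuda:0"), dtype=torch.float16
)
aligner.eval()
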
+ 9 - 0
src/seamless_communication/models/aligner/__init__.py

@@ -0,0 +1,9 @@
+from seamless_communication.models.aligner.model import (
+    UnitY2AlignmentEncoder as UnitY2AlignmentEncoder,
+)
+from seamless_communication.models.aligner.model import (
+    UnitY2AlignmentFrontend as UnitY2AlignmentFrontend,
+)
+from seamless_communication.models.aligner.model import (
+    UnitY2AlignmentModel as UnitY2AlignmentModel,
+)

+ 175 - 0
src/seamless_communication/models/aligner/alignment_extractor.py

@@ -0,0 +1,175 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+from typing import Any, List, Tuple, Union
+
+import numpy
+import torch
+import torch.nn as nn
+import torchaudio
+from fairseq2.data import CString
+from fairseq2.typing import DataType, Device
+from fairseq2.data.typing import StringLike
+from torch import Tensor
+
+from seamless_communication.models.aligner.loader import load_unity2_alignment_model
+from seamless_communication.models.unit_extractor import UnitExtractor
+
+try:
+    import matplotlib.pyplot as plt
+
+    matplotlib_available = True
+except ImportError:
+    matplotlib_available = False
+
+
+class AlignmentExtractor(nn.Module):
+    def __init__(
+        self,
+        aligner_model_name_or_card: str,
+        unit_extractor_model_name_or_card: Union[Any, str] = None,
+        unit_extractor_output_layer: Union[Any, int] = None,
+        unit_extractor_kmeans_model_uri: Union[Any, str] = None,
+        device: Device = Device("cpu"),
+        dtype: DataType = torch.float32,
+    ):
+        super().__init__()
+        self.device = device
+        self.dtype = dtype
+
+        if self.dtype == torch.float16 and self.device == Device("cpu"):
+            raise RuntimeError("FP16 only works on GPU, set args accordingly")
+
+        self.alignment_model = load_unity2_alignment_model(
+            aligner_model_name_or_card, device=self.device, dtype=self.dtype
+        )
+        self.alignment_model.eval()
+
+        self.unit_extractor = None
+        self.unit_extractor_output_layer = 0
+
+        if unit_extractor_model_name_or_card is not None:
+            self.unit_extractor = UnitExtractor(
+                unit_extractor_model_name_or_card,
+                unit_extractor_kmeans_model_uri,
+                device=device,
+                dtype=dtype,
+            )
+            self.unit_extractor_output_layer = unit_extractor_output_layer
+
+    def load_audio(
+        self, audio_path: str, sampling_rate: int = 16_000
+    ) -> Tuple[Tensor, int]:
+        assert os.path.exists(audio_path)
+        audio, rate = torchaudio.load(audio_path)
+        if rate != sampling_rate:
+            audio = torchaudio.functional.resample(audio, rate, sampling_rate)
+            rate = sampling_rate
+        return audio, rate
+
+    def prepare_audio(self, audio: Union[str, Tensor]) -> Tensor:
+        # TODO: switch to fairseq2 data pipeline once it supports resampling
+        if isinstance(audio, str):
+            audio, _ = self.load_audio(audio, sampling_rate=16_000)
+        if audio.ndim > 1:
+            # averaging over channels
+            assert audio.size(0) < audio.size(
+                1
+            ), "Expected [Channel,Time] shape, but Channel > Time"
+            audio = audio.mean(0)
+        assert (
+            audio.ndim == 1
+        ), "After channel averaging, audio shape is expected to be [Time], i.e. mono audio"
+        audio = audio.to(self.device, self.dtype)
+
+        return audio
+
+    def extract_units(self, audio: Tensor) -> Tensor:
+        assert isinstance(
+            self.unit_extractor, UnitExtractor
+        ), "Unit extractor is required to get units from audio tensor"
+        units = self.unit_extractor.predict(audio, self.unit_extractor_output_layer)
+        return units
+
+    @torch.inference_mode()
+    def extract_alignment(
+        self,
+        audio: Union[str, Tensor],
+        text: str,
+        plot: bool = False,
+        add_trailing_silence: bool = False,
+    ) -> Tuple[Tensor, Tensor, List[StringLike]]:
+        if isinstance(audio, Tensor) and not torch.is_floating_point(audio):
+            # we got units as audio arg
+            units = audio
+            units = units.to(self.device)
+            audio_tensor = None
+        else:
+            audio_tensor = self.prepare_audio(audio)
+            units = self.extract_units(audio_tensor)
+
+        tokenized_unit_ids = self.alignment_model.alignment_frontend.tokenize_unit(
+            units
+        ).unsqueeze(0)
+        tokenized_text_ids = (
+            self.alignment_model.alignment_frontend.tokenize_text(
+                text, add_trailing_silence=add_trailing_silence
+            )
+            .to(self.device)
+            .unsqueeze(0)
+        )
+        tokenized_text_tokens = (
+            self.alignment_model.alignment_frontend.tokenize_text_to_tokens(
+                text, add_trailing_silence=add_trailing_silence
+            )
+        )
+        alignment_lprobs, alignment_durations = self.alignment_model(
+            tokenized_text_ids, tokenized_unit_ids
+        )
+
+        if plot and (audio_tensor is not None):
+            self.plot_alignment(
+                audio_tensor.cpu(), tokenized_text_tokens, alignment_durations.cpu()
+            )
+
+        return alignment_durations, tokenized_text_ids, tokenized_text_tokens
+
+    def detokenize_text(self, tokenized_text_ids: Tensor) -> StringLike:
+        return self.alignment_model.alignment_frontend.decode_text(tokenized_text_ids)
+
+    def plot_alignment(
+        self, audio: Tensor, text_tokens: List[StringLike], durations: Tensor
+    ) -> None:
+        if not matplotlib_available:
+            raise RuntimeError(
+                "Please `pip install matplotlib` in order to use plot alignment."
+            )
+        fig, ax = plt.subplots(figsize=(22, 3.5))
+        ax.plot(audio, color="gray", linewidth=0.3)
+        durations_cumul = numpy.concatenate([numpy.array([0]), numpy.cumsum(durations)])
+        alignment_ticks = durations_cumul * 320  # 320 is hardcoded for 20ms rate here
+
+        ax.vlines(
+            alignment_ticks,
+            ymax=1,
+            ymin=-1,
+            color="indigo",
+            linestyles="dashed",
+            lw=0.5,
+        )
+
+        middle_tick_positions = (
+            durations_cumul[:-1] + (durations_cumul[1:] - durations_cumul[:-1]) / 2
+        )
+        ax.set_xticks(middle_tick_positions * 320)
+        ax.set_xticklabels(text_tokens, fontsize=13)
+        ax.set_xlim(0, len(audio))
+
+        ax.set_ylim(audio.min(), audio.max())
+        ax.set_yticks([])
+        plt.tight_layout()
+        plt.show()

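A hedged usage sketch for AlignmentExtractor as defined above. The audio path and transcript are placeholders; the unit extractor name, output layer, and kmeans URI match the ones exercised by the integration test at the end of this PR.

import torch
from fairseq2.typing import Device

from seamless_communication.models.aligner.alignment_extractor import AlignmentExtractor

extractor = AlignmentExtractor(
    aligner_model_name_or_card="nar_t2u_aligner",
    unit_extractor_model_name_or_card="xlsr2_1b_v2",
    unit_extractor_output_layer=35,
    unit_extractor_kmeans_model_uri="https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/kmeans_10k.npy",
    device=Device("cuda:0"),
    dtype=torch.float16,
)

# 16 kHz mono audio (placeholder path) aligned against its transcript.
durations, text_ids, text_tokens = extractor.extract_alignment(
    "/path/to/audio_16khz.wav",
    "transcript of the audio",
    plot=False,
    add_trailing_silence=True,
)
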
+ 186 - 0
src/seamless_communication/models/aligner/builder.py

@@ -0,0 +1,186 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import torch
+from fairseq2.assets.card import AssetCard
+from fairseq2.data.vocabulary_info import VocabularyInfo
+from fairseq2.models.utils.arch_registry import ArchitectureRegistry
+from fairseq2.nn.embedding import StandardEmbedding, init_scaled_embedding
+from fairseq2.typing import DataType, Device
+
+from seamless_communication.models.aligner.model import (
+    UnitY2AlignmentEncoder,
+    UnitY2AlignmentFrontend,
+    UnitY2AlignmentModel,
+)
+from seamless_communication.models.unity.char_tokenizer import load_unity_char_tokenizer
+from seamless_communication.models.unity.loader import load_unity_unit_tokenizer
+
+
+@dataclass
+class AlignmentEncoderConfig:
+    model_dim: int
+
+    feat_dim: int
+
+    num_text_layers: int
+
+    num_feat_layers: int
+
+    dropout: float
+
+    temperature: float
+
+    reduction_factor: int
+
+
+@dataclass
+class UnitY2AlignmentFrontendConfig:
+    unit_vocab_info: VocabularyInfo
+
+    text_vocab_size: int
+
+
+@dataclass
+class UnitY2AlignmentConfig:
+    model_name_or_card: Union[str, AssetCard]
+
+    alignment_encoder_config: AlignmentEncoderConfig
+
+    alignment_frontend_config: UnitY2AlignmentFrontendConfig
+
+
+aligner_archs = ArchitectureRegistry[UnitY2AlignmentConfig]("unity2_aligner")
+
+aligner_arch = aligner_archs.decorator
+
+
+@aligner_arch("nar_t2u_aligner")
+def _aligner_nar_t2u() -> UnitY2AlignmentConfig:
+    encoder_config = AlignmentEncoderConfig(
+        model_dim=1024,
+        feat_dim=1024,
+        num_text_layers=2,
+        num_feat_layers=3,
+        dropout=0.1,
+        temperature=1.0,
+        reduction_factor=1,
+    )
+
+    frontend_config = UnitY2AlignmentFrontendConfig(
+        unit_vocab_info=VocabularyInfo(
+            size=10082, unk_idx=3, bos_idx=0, eos_idx=2, pad_idx=1
+        ),
+        text_vocab_size=10943,
+    )
+
+    return UnitY2AlignmentConfig(
+        model_name_or_card="nar_t2u_aligner",
+        alignment_encoder_config=encoder_config,
+        alignment_frontend_config=frontend_config,
+    )
+
+
+class UnitY2AlignmentBuilder:
+    config: UnitY2AlignmentConfig
+    device: Optional[Device]
+    dtype: DataType
+
+    def __init__(
+        self,
+        config: UnitY2AlignmentConfig,
+        *,
+        device: Optional[Device] = None,
+        dtype: DataType = torch.float32,
+    ) -> None:
+        """
+        :param config:
+            The configuration to use.
+        :param device:
+            The device on which to initialize modules.
+        :param dtype:
+            The data type of module parameters and buffers.
+        """
+        self.config = config
+
+        self.device, self.dtype = device, dtype
+
+    def build_model(self) -> UnitY2AlignmentModel:
+        alignment_frontend = self.build_alignment_frontend()
+
+        alignment_encoder = self.build_alignment_encoder()
+
+        return UnitY2AlignmentModel(alignment_frontend, alignment_encoder)
+
+    def build_alignment_frontend(self) -> UnitY2AlignmentFrontend:
+        text_tokenizer = load_unity_char_tokenizer(self.config.model_name_or_card)
+
+        unit_tokenizer = load_unity_unit_tokenizer(self.config.model_name_or_card)
+
+        embed_text = StandardEmbedding(
+            num_embeddings=self.config.alignment_frontend_config.text_vocab_size,
+            embedding_dim=self.config.alignment_encoder_config.model_dim,
+            pad_idx=self.config.alignment_frontend_config.unit_vocab_info.pad_idx,
+            init_fn=init_scaled_embedding,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+        embed_unit = StandardEmbedding(
+            num_embeddings=self.config.alignment_frontend_config.unit_vocab_info.size,
+            embedding_dim=self.config.alignment_encoder_config.model_dim,
+            pad_idx=self.config.alignment_frontend_config.unit_vocab_info.pad_idx,
+            init_fn=init_scaled_embedding,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+        return UnitY2AlignmentFrontend(
+            embed_text, embed_unit, text_tokenizer, unit_tokenizer
+        )
+
+    def build_alignment_encoder(self, training: bool = False) -> UnitY2AlignmentEncoder:
+        cfg = self.config.alignment_encoder_config
+        alignment_encoder = UnitY2AlignmentEncoder(
+            embed_dim=cfg.model_dim,
+            feat_dim=cfg.feat_dim,
+            text_layers=cfg.num_text_layers,
+            feat_layers=cfg.num_feat_layers,
+            dropout=cfg.dropout,
+            temperature=cfg.temperature,
+            reduction_factor=cfg.reduction_factor,
+            dtype=self.dtype,
+        )
+        alignment_encoder.training = training
+
+        return alignment_encoder
+
+
+def create_unity2_alignment_model(
+    config: UnitY2AlignmentConfig,
+    device: Optional[Device] = None,
+    dtype: DataType = torch.float32,
+) -> UnitY2AlignmentModel:
+    """Create a UnitY model.
+
+    :param config:
+        The configuration to use.
+    :param device:
+        The device on which to initialize modules.
+    :param dtype:
+        The data type of module parameters and buffers.
+    """
+
+    unity2_aligner_builder = UnitY2AlignmentBuilder(
+        config,
+        device=device,
+        dtype=dtype,
+    )
+
+    return unity2_aligner_builder.build_model()

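A short sketch of building an (untrained) model straight from the registered architecture, assuming ArchitectureRegistry.get_config resolves the arch name as in other fairseq2 model builders; note that the frontend still pulls the char/unit tokenizers referenced by the card, and in practice the loader in the next file handles all of this, including checkpoint conversion.

import torch

from seamless_communication.models.aligner.builder import (
    aligner_archs,
    create_unity2_alignment_model,
)

# Fetch the "nar_t2u_aligner" config registered above and build a randomly initialized model.
config = aligner_archs.get_config("nar_t2u_aligner")
model = create_unity2_alignment_model(config, dtype=torch.float32)
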
+ 82 - 0
src/seamless_communication/models/aligner/loader.py

@@ -0,0 +1,82 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Any, List, Mapping, final
+
+import torch
+from fairseq2.assets import asset_store, download_manager
+from fairseq2.models.utils import ConfigLoader, ModelLoader
+
+from seamless_communication.models.aligner.builder import (
+    UnitY2AlignmentConfig,
+    aligner_archs,
+    create_unity2_alignment_model,
+)
+from seamless_communication.models.aligner.model import UnitY2AlignmentModel
+from seamless_communication.models.unity.char_tokenizer import load_unity_char_tokenizer
+
+
+def convert_unity2_aligner_checkpoint(
+    checkpoint: Mapping[str, Any], config: UnitY2AlignmentConfig
+) -> Mapping[str, Any]:
+    alignment_frontend_statedict = {}
+    text_emb_state_keymap = {"weight": "alignment_frontend.embed_text.weight"}
+    for k, v in checkpoint["text_emb_state"].items():
+        alignment_frontend_statedict[text_emb_state_keymap[k]] = v
+
+    unit_emb_state_keymap = {"weight": "alignment_frontend.embed_unit.weight"}
+    for k, v in checkpoint["unit_emb_state"].items():
+        alignment_frontend_statedict[unit_emb_state_keymap[k]] = v
+
+    alignment_encoder_state_dict = {}
+    for k, v in checkpoint["aligner_state"].items():
+        alignment_encoder_state_dict[f"alignment_encoder.{k}"] = v
+
+    model_state = {
+        **alignment_encoder_state_dict,
+        **alignment_frontend_statedict,
+    }
+
+    char_embeds = model_state["alignment_frontend.embed_text.weight"]
+
+    index_mapping = _get_char_index_mapping(config)
+    vocab_size = len(index_mapping)
+    char_embeds[torch.arange(vocab_size)] = char_embeds[index_mapping]
+
+    checkpoint["model"] = model_state
+
+    return checkpoint
+
+
+def _get_char_index_mapping(config: UnitY2AlignmentConfig) -> List[int]:
+    char_tokenizer = load_unity_char_tokenizer(config.model_name_or_card)
+    spm_order = [
+        char_tokenizer.model.index_to_token(i)
+        for i in range(char_tokenizer.model.vocabulary_size)
+    ][4:]
+    spm_to_dict_mapping = {
+        ch: idx
+        for (idx, ch) in zip(
+            range(4, char_tokenizer.model.vocabulary_size),
+            sorted(spm_order),
+        )
+    }
+    model_to_dict_mapping = [0, 1, 2, 3] + [spm_to_dict_mapping[ch] for ch in spm_order]
+    return model_to_dict_mapping
+
+
+load_unity2_alignment_config = ConfigLoader[UnitY2AlignmentConfig](
+    asset_store, aligner_archs
+)
+
+load_unity2_alignment_model = ModelLoader[UnitY2AlignmentModel, UnitY2AlignmentConfig](
+    asset_store,
+    download_manager,
+    load_unity2_alignment_config,
+    create_unity2_alignment_model,
+    convert_unity2_aligner_checkpoint,
+    restrict_checkpoints=False,
+)

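The config loader can also be called on its own to inspect the resolved architecture without instantiating the model; a small sketch, assuming the nar_t2u_aligner card is visible to the default asset store.

from seamless_communication.models.aligner.loader import load_unity2_alignment_config

config = load_unity2_alignment_config("nar_t2u_aligner")
print(config.alignment_frontend_config.text_vocab_size)  # 10943 for this arch
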
+ 305 - 0
src/seamless_communication/models/aligner/model.py

@@ -0,0 +1,305 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Any, List, Tuple, Union
+
+import numpy as np
+import numpy.typing as npt
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from fairseq2.data import CString
+from fairseq2.nn.embedding import StandardEmbedding
+from fairseq2.nn.padding import to_padding_mask
+from fairseq2.typing import DataType
+from numba import jit
+from torch import Tensor
+from torch.nn import Module
+
+from seamless_communication.models.unity.char_tokenizer import CharTokenizer
+from seamless_communication.models.unity.unit_tokenizer import UnitTokenizer
+
+
+class UnitY2AlignmentFrontend(Module):
+    def __init__(
+        self,
+        embed_text: StandardEmbedding,
+        embed_unit: StandardEmbedding,
+        text_tokenizer: CharTokenizer,
+        unit_tokenizer: UnitTokenizer,
+    ):
+        super().__init__()
+        self.embed_text = embed_text
+        self.embed_unit = embed_unit
+        self.text_tokenizer = text_tokenizer
+        self.unit_tokenizer = unit_tokenizer
+        unit_tokenizer.is_nar_decoder = True
+
+        self.encode_text = self.text_tokenizer.create_raw_encoder()
+        # text decoder can be used to map aligned characters to words
+        self.decode_text = self.text_tokenizer.create_decoder()
+        self.encode_unit = self.unit_tokenizer.create_encoder(lang="eng")
+
+    def tokenize_text(
+        self, text: str, return_tokens: bool = False, add_trailing_silence: bool = False
+    ) -> Tensor:
+        tokenized = self.encode_text(text)
+        if add_trailing_silence:
+            tokenized = torch.cat([tokenized, tokenized[0:1]])
+
+        return tokenized
+
+    def tokenize_text_to_tokens(
+        self, text: str, add_trailing_silence: bool = False
+    ) -> List[Union[CString, str]]:
+        tokenized = self.encode_text.encode_as_tokens(text)
+        if add_trailing_silence:
+            tokenized = tokenized + [tokenized[0]]
+
+        return tokenized
+
+    def tokenize_unit(self, units: Union[str, Tensor]) -> Tensor:
+        if isinstance(units, str):
+            units = torch.tensor([int(u) for u in units.split(" ")])
+        return self.encode_unit(units)
+
+    def forward(self, text: Tensor, unit: Tensor) -> Tuple[Any, Any]:
+        embs_unit = self.embed_unit(unit)
+        embs_text = self.embed_text(text)
+        return embs_text, embs_unit
+
+
+class Permute12(nn.Module):
+    def forward(self, x: Tensor) -> Tensor:
+        return x.transpose(1, 2)
+
+
+class UnitY2AlignmentEncoder(Module):
+    """
+    UnitY2 Aligner component
+    """
+
+    def __init__(
+        self,
+        embed_dim: int,
+        feat_dim: int,
+        text_layers: int,
+        feat_layers: int,
+        dropout: float,
+        temperature: float,
+        reduction_factor: int,
+        dtype: DataType,
+    ):
+        super().__init__()
+        self.temperature = temperature
+        self.reduction_factor = reduction_factor  # for unit
+
+        layers: List[Module] = [Permute12()]
+        for i in range(text_layers):
+            if i < text_layers - 1:
+                layers.append(
+                    nn.Conv1d(
+                        embed_dim, embed_dim, kernel_size=3, padding=1, dtype=dtype
+                    )
+                )
+                layers.append(nn.ReLU())
+                layers.append(nn.Dropout(p=dropout))
+            else:
+                layers.append(
+                    nn.Conv1d(
+                        embed_dim, embed_dim, kernel_size=1, padding=0, dtype=dtype
+                    )
+                )
+                layers.append(nn.Dropout(p=dropout))
+                layers.append(Permute12())
+        self.t_conv = nn.Sequential(*layers)
+
+        layers = [Permute12()]
+        input_dim = feat_dim
+        for i in range(feat_layers):
+            if i < feat_layers - 1:
+                layers.append(
+                    nn.Conv1d(
+                        input_dim, embed_dim, kernel_size=3, padding=1, dtype=dtype
+                    )
+                )
+                layers.append(nn.ReLU())
+                layers.append(nn.Dropout(p=dropout))
+            else:
+                layers.append(
+                    nn.Conv1d(
+                        input_dim,
+                        embed_dim,
+                        kernel_size=1,
+                        padding=0,
+                        stride=reduction_factor,
+                        dtype=dtype,
+                    )
+                )
+                layers.append(nn.Dropout(p=dropout))
+                layers.append(Permute12())
+            input_dim = embed_dim
+        self.f_conv = nn.Sequential(*layers)
+
+    def forward(
+        self,
+        text_emb: Tensor,
+        feat_emb: Tensor,
+        text_lengths: Tensor,
+        feat_lengths: Tensor,
+    ) -> Tuple[Tensor, Tensor]:
+        """Compute alignment between sequence of text and feature embeddings
+
+        Args:
+            text_emb (Tensor): Batched text embedding (B, T_text, C).
+            feat_emb (Tensor): Batched acoustic feature (B, T_feat, feat_dim).
+            text_lengths (Tensor): Source text length (B,).
+            feat_lengths (Tensor): Target feature length (B,).
+
+        Returns:
+            Tensor: Log probability of attention matrix (B, T_feat, T_text)
+            Tensor: Unit durations of every text token (B, T_text)
+
+        """
+        _feat_lengths = feat_lengths.clone()
+        if self.reduction_factor > 1:
+            feat_lengths = torch.ceil(feat_lengths / self.reduction_factor).long()
+
+        text_emb = self.t_conv(text_emb)
+        feat_emb = self.f_conv(feat_emb)
+
+        dist = feat_emb.unsqueeze(2) - text_emb.unsqueeze(1)
+        dist = torch.norm(dist, p=2, dim=3)
+        score = -self.temperature * dist
+
+        padding_mask = ~(to_padding_mask(text_lengths, max(text_lengths)))
+        padding_mask = padding_mask.unsqueeze(-2)
+        score = score.masked_fill(padding_mask, -np.inf)
+
+        attn_lprob = F.log_softmax(score, dim=-1)
+
+        attn_hard_dur = viterbi_decode(attn_lprob, text_lengths, feat_lengths)
+
+        if self.reduction_factor > 1:
+            attn_hard_dur = self.postprocess_alignment(
+                attn_hard_dur, text_lengths, _feat_lengths
+            )
+
+        return attn_lprob, attn_hard_dur
+
+    def postprocess_alignment(
+        self, attn_hard_dur: Tensor, text_lengths: Tensor, feat_lengths: Tensor
+    ) -> Tensor:
+        attn_hard_dur = attn_hard_dur * self.reduction_factor
+        B, T = attn_hard_dur.size()  # B x T_text
+        dur_cumsum = torch.cumsum(attn_hard_dur, dim=1)
+        for b in range(B):
+            for t in range(text_lengths[b]):
+                # truncate the right frames
+                if dur_cumsum[b, t] >= feat_lengths[b]:
+                    if t == 0:
+                        attn_hard_dur[b, t] = feat_lengths[b]
+                    else:
+                        attn_hard_dur[b, t] = feat_lengths[b] - dur_cumsum[b, t - 1]
+                    if t < text_lengths[b] - 1:
+                        attn_hard_dur[b, t + 1 :] = 0
+                    break
+        return attn_hard_dur
+
+
+def _monotonic_alignment_search(
+    attn_lprob: npt.NDArray[np.float64],
+) -> npt.NDArray[np.int64]:
+    # https://arxiv.org/abs/2005.11129
+    T_feat = attn_lprob.shape[0]
+    T_text = attn_lprob.shape[1]
+    Q = np.full((T_text, T_feat), fill_value=-np.inf)
+
+    log_prob = attn_lprob.transpose(1, 0)  # -> (T_text, T_feat)
+    # 1.  Q <- init first row for all j
+    for j in range(T_feat):
+        Q[0, j] = log_prob[0, : j + 1].sum()
+
+    # 2.
+    for j in range(1, T_feat):
+        for i in range(1, min(j + 1, T_text)):
+            Q[i, j] = max(Q[i - 1, j - 1], Q[i, j - 1]) + log_prob[i, j]
+
+    # 3.
+    A = np.full((T_feat,), fill_value=T_text - 1)
+    for j in range(T_feat - 2, -1, -1):  # T_feat-2, ..., 0
+        # 'i' in {A[j+1]-1, A[j+1]}
+        i_a = A[j + 1] - 1
+        i_b = A[j + 1]
+        if i_b == 0:
+            argmax_i = 0
+        elif Q[i_a, j] >= Q[i_b, j]:
+            argmax_i = i_a
+        else:
+            argmax_i = i_b
+        A[j] = argmax_i
+    return A
+
+
+def viterbi_decode(
+    attn_lprob: Tensor, text_lengths: Tensor, feat_lengths: Tensor
+) -> Tensor:
+    """Extract duration from an attention probability matrix
+
+    Args:
+        attn_lprob (Tensor): Batched log probability of attention
+            matrix (B, T_feat, T_text).
+        text_lengths (Tensor): Text length tensor (B,).
+        feat_lengths (Tensor): Feature length tensor (B,).
+
+    Returns:
+        Tensor: Batched token durations extracted from `attn_lprob` (B, T_text),
+            one duration per text token.
+
+    """
+    B = attn_lprob.size(0)
+    T_text = attn_lprob.size(2)
+    device = attn_lprob.device
+
+    durations = torch.zeros((B, T_text), device=device, dtype=torch.long)
+    for b in range(B):
+        assert feat_lengths[b] > 0
+        assert text_lengths[b] > 0
+        cur_log_p_attn = attn_lprob[b, : feat_lengths[b], : text_lengths[b]]
+        viterbi = _monotonic_alignment_search(
+            cur_log_p_attn.float().detach().cpu().numpy()
+        )
+        _durations = np.bincount(viterbi)
+        durations[b, : len(_durations)] = torch.from_numpy(_durations).to(device)
+
+    return durations
+
+
+class UnitY2AlignmentModel(Module):
+    alignment_encoder: UnitY2AlignmentEncoder
+    alignment_frontend: UnitY2AlignmentFrontend
+
+    def __init__(
+        self,
+        alignment_frontend: UnitY2AlignmentFrontend,
+        alignment_encoder: UnitY2AlignmentEncoder,
+    ):
+        super().__init__()
+        self.alignment_frontend = alignment_frontend
+        self.alignment_encoder = alignment_encoder
+
+    def forward(self, input_text: Tensor, input_unit: Tensor) -> Tuple[Tensor, Tensor]:
+        assert input_text.ndim == 2
+        assert input_unit.ndim == 2
+        embs_text, embs_unit = self.alignment_frontend(input_text, input_unit)
+        attn_lprob, attn_hard_dur = self.alignment_encoder(
+            embs_text,
+            embs_unit,
+            torch.tensor([embs_text.size(1)]).to(embs_text).int(),
+            torch.tensor([embs_unit.size(1)]).to(embs_unit).int(),
+        )
+
+        return attn_lprob, attn_hard_dur

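A self-contained toy example of the duration extraction performed by viterbi_decode above; the attention probabilities are made up purely for illustration.

import torch

from seamless_communication.models.aligner.model import viterbi_decode

# One batch item: 3 feature frames aligned against 2 text tokens.
attn_lprob = torch.log(
    torch.tensor([[[0.9, 0.1], [0.6, 0.4], [0.2, 0.8]]])
)  # (B, T_feat, T_text)
durations = viterbi_decode(attn_lprob, torch.tensor([2]), torch.tensor([3]))
print(durations)  # tensor([[2, 1]]) -- the durations sum to T_feat
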
+ 4 - 0
src/seamless_communication/models/unit_extractor/unit_extractor.py

@@ -70,6 +70,10 @@ class UnitExtractor(nn.Module):
             with Path(audio).open("rb") as fb:
                 block = MemoryBlock(fb.read())
             decoded_audio = self.decode_audio(block)
+            assert (
+                sample_rate == decoded_audio["sample_rate"]
+            ), f"Input audio must have {sample_rate} sampling rate"
+
         else:
             assert audio.dim() <= 2, "The audio tensor can't be more than 2 dimensions."
             if audio.dim() == 1:

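Because the new assertion rejects file inputs whose sampling rate differs from sample_rate, callers can resample in memory and pass a tensor instead; a hedged sketch with a placeholder path, reusing the model and kmeans names from the test below.

import torch
import torchaudio
from fairseq2.typing import Device

from seamless_communication.models.unit_extractor import UnitExtractor

waveform, rate = torchaudio.load("/path/to/audio.wav")  # placeholder path
if rate != 16_000:
    waveform = torchaudio.functional.resample(waveform, rate, 16_000)

unit_extractor = UnitExtractor(
    "xlsr2_1b_v2",
    "https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/kmeans_10k.npy",
    device=Device("cpu"),
    dtype=torch.float32,
)
units = unit_extractor.predict(waveform.mean(0), 35)  # output layer 35, as in the test
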
+ 54 - 0
tests/integration/models/test_unity2_aligner.py

@@ -0,0 +1,54 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Final
+
+import torch
+from torch import tensor
+from fairseq2.data.audio import AudioDecoderOutput
+
+from seamless_communication.models.aligner.alignment_extractor import AlignmentExtractor
+from tests.common import assert_equal, device
+
+REF_TEXT = "the examination and testimony of the experts enabled the commision to conclude that five shots may have been fired"
+
+REF_DURATIONS: Final = [[ 1,  1,  2,  1,  1,  5,  5,  6,  4,  3,  2,  3,  4,  4,  2,  2,  2,  1,
+           1,  1,  3,  3,  3,  4,  3,  3,  4,  3,  4,  3,  2,  2,  1,  1,  1,  1,
+           2,  4,  6,  5,  4,  3,  4,  5,  5, 16,  6,  3,  5,  5,  3,  3,  1,  2,
+           1,  1,  1,  2,  3,  2,  3,  1,  3,  3,  3,  2,  2,  4,  2,  2,  2,  3,
+           2,  4,  5,  4,  5,  8,  3, 17,  2,  2,  3,  2,  5,  4,  6,  3,  1,  1,
+           4,  4,  3,  5,  3,  3,  2,  2,  2,  2,  2,  2,  2,  1,  2,  2,  1,  1,
+           2,  6,  4,  5,  9,  5,  1, 12]]
+
+def test_aligner(example_rate16k_audio: AudioDecoderOutput) -> None:
+
+    aligner_name = "nar_t2u_aligner"
+    unit_extractor_name = "xlsr2_1b_v2"
+    unit_extractor_output_layer_n = 35
+    unit_extractor_kmeans_uri = "https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/kmeans_10k.npy"
+
+    extractor = AlignmentExtractor(
+        aligner_name,
+        unit_extractor_name,
+        unit_extractor_output_layer_n,
+        unit_extractor_kmeans_uri,
+        device=device
+    )
+
+    audio = example_rate16k_audio["waveform"].mean(1)  # average channels to mono, giving the [Time] shape required by the aligner
+
+    alignment_durations, _, _ = extractor.extract_alignment(audio, REF_TEXT, plot=False, add_trailing_silence=True)
+
+    assert_equal(alignment_durations, tensor(REF_DURATIONS, device=device, dtype=torch.int64))
+