Forráskód Böngészése

Clean up aligner assets, refactor aligner test to run on GPU with fp16 as well. (#148)

Kaushik Ram Sadagopan 1 éve
szülő
commit
26bc428198

+ 2 - 2
src/seamless_communication/cards/nar_t2u_aligner.yaml

@@ -5,10 +5,10 @@
 # LICENSE file in the root directory of this source tree.
 
 name: nar_t2u_aligner
-char_tokenizer: "file:///private/home/hirofumii/large_experiments/datasets/m4t/t2u_v2/spm_char_lang38_tc.model"
+char_tokenizer: "file:///checkpoint/krs/unity2/spm_char_lang38_tc.model"
 model_type: unity2_aligner
 model_arch: nar_t2u_aligner
-checkpoint: "file:///checkpoint/kulikov/nar_t2u_m4tv2_aligner.pt"
+checkpoint: "file:///large_experiments/seamless/ust/krs/fairseq2_checkpoints/unity2_aligner.pt"
 num_units: 10000
 unit_langs:
   - arb

+ 6 - 0
src/seamless_communication/models/aligner/__init__.py

@@ -1,3 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
 from seamless_communication.models.aligner.model import (
     UnitY2AlignmentEncoder as UnitY2AlignmentEncoder,
 )

+ 2 - 3
src/seamless_communication/models/aligner/alignment_extractor.py

@@ -11,7 +11,6 @@ import numpy
 import torch
 import torch.nn as nn
 import torchaudio
-from fairseq2.data import CString
 from fairseq2.typing import DataType, Device
 from fairseq2.data.typing import StringLike
 from torch import Tensor
@@ -127,7 +126,7 @@ class AlignmentExtractor(nn.Module):
                 text, add_trailing_silence=add_trailing_silence
             )
         )
-        alignment_lprobs, alignment_durations = self.alignment_model(
+        _, alignment_durations = self.alignment_model(
             tokenized_text_ids, tokenized_unit_ids
         )
 
@@ -148,7 +147,7 @@ class AlignmentExtractor(nn.Module):
             raise RuntimeError(
                 "Please `pip install matplotlib` in order to use plot alignment."
             )
-        fig, ax = plt.subplots(figsize=(22, 3.5))
+        _, ax = plt.subplots(figsize=(22, 3.5))
         ax.plot(audio, color="gray", linewidth=0.3)
         durations_cumul = numpy.concatenate([numpy.array([0]), numpy.cumsum(durations)])
         alignment_ticks = durations_cumul * 320  # 320 is hardcoded for 20ms rate here

+ 7 - 1
src/seamless_communication/models/aligner/loader.py

@@ -4,7 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Any, List, Mapping, final
+from typing import Any, List, Mapping
 
 import torch
 from fairseq2.assets import asset_store, download_manager
@@ -22,6 +22,12 @@ from seamless_communication.models.unity.char_tokenizer import load_unity_char_t
 def convert_unity2_aligner_checkpoint(
     checkpoint: Mapping[str, Any], config: UnitY2AlignmentConfig
 ) -> Mapping[str, Any]:
+    if (
+        "model" in checkpoint
+        and "alignment_encoder.t_conv.1.weight" in checkpoint["model"]
+    ):
+        return checkpoint
+
     alignment_frontend_statedict = {}
     text_emb_state_keymap = {"weight": "alignment_frontend.embed_text.weight"}
     for k, v in checkpoint["text_emb_state"].items():

+ 0 - 1
src/seamless_communication/models/aligner/model.py

@@ -15,7 +15,6 @@ from fairseq2.data import CString
 from fairseq2.nn.embedding import StandardEmbedding
 from fairseq2.nn.padding import to_padding_mask
 from fairseq2.typing import DataType
-from numba import jit
 from torch import Tensor
 from torch.nn import Module
 

+ 34 - 18
tests/integration/models/test_unity2_aligner.py

@@ -7,48 +7,64 @@
 from typing import Final
 
 import torch
-from fairseq2.typing import Device
 from torch import tensor
 
-from tests.common import assert_equal, device
+from fairseq2.data.audio import AudioDecoderOutput
 from seamless_communication.models.aligner.alignment_extractor import AlignmentExtractor
-from fairseq2.data.audio import (
-    AudioDecoder,
-    AudioDecoderOutput
-)
-from fairseq2.memory import MemoryBlock
-from urllib.request import urlretrieve
-import tempfile
-from tests.common import assert_equal, device
+from tests.common import assert_equal, device, get_default_dtype
+
 
 REF_TEXT = "the examination and testimony of the experts enabled the commision to conclude that five shots may have been fired"
 
-REF_DURATIONS: Final = [[ 1,  1,  2,  1,  1,  5,  5,  6,  4,  3,  2,  3,  4,  4,  2,  2,  2,  1,
+# fmt: off
+REF_DURATIONS_FP16: Final = [[ 1,  1,  2,  1,  1,  5,  5,  6,  4,  3,  2,  3,  4,  4,  2,  2,  2,  1,
+          1,  1,  3,  3,  3,  4,  3,  3,  3,  4,  4,  3,  2,  2,  1,  1,  1,  1,
+          2,  4,  6,  5,  4,  3,  4,  5,  5, 16,  6,  3,  5,  5,  3,  3,  1,  2,
+          1,  1,  1,  2,  3,  2,  3,  1,  3,  3,  3,  2,  2,  4,  2,  2,  2,  3,
+          2,  4,  5,  4,  5,  8,  3, 17,  2,  2,  3,  2,  5,  4,  6,  3,  1,  1,
+          4,  4,  3,  5,  3,  3,  2,  2,  2,  2,  2,  2,  2,  1,  2,  2,  1,  1,
+          2,  6,  4,  5,  9,  5,  1, 12]]
+# fmt: on
+
+# fmt: off
+REF_DURATIONS_FP32: Final = [[ 1,  1,  2,  1,  1,  5,  5,  6,  4,  3,  2,  3,  4,  4,  2,  2,  2,  1,
            1,  1,  3,  3,  3,  4,  3,  3,  4,  3,  4,  3,  2,  2,  1,  1,  1,  1,
            2,  4,  6,  5,  4,  3,  4,  5,  5, 16,  6,  3,  5,  5,  3,  3,  1,  2,
            1,  1,  1,  2,  3,  2,  3,  1,  3,  3,  3,  2,  2,  4,  2,  2,  2,  3,
            2,  4,  5,  4,  5,  8,  3, 17,  2,  2,  3,  2,  5,  4,  6,  3,  1,  1,
            4,  4,  3,  5,  3,  3,  2,  2,  2,  2,  2,  2,  2,  1,  2,  2,  1,  1,
            2,  6,  4,  5,  9,  5,  1, 12]]
+# fmt: on
 
-def test_aligner(example_rate16k_audio: AudioDecoderOutput) -> None:
 
+def test_aligner(example_rate16k_audio: AudioDecoderOutput) -> None:
     aligner_name = "nar_t2u_aligner"
     unit_extractor_name = "xlsr2_1b_v2"
     unit_extractor_output_layer_n = 35
     unit_extractor_kmeans_uri = "https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/kmeans_10k.npy"
+    dtype = get_default_dtype()
+    if dtype == torch.float32:
+        ref_tensor = REF_DURATIONS_FP32
+    else:
+        ref_tensor = REF_DURATIONS_FP16
+
+    audio = example_rate16k_audio["waveform"].mean(
+        1
+    )  # averaging mono to get [Time] shape required by aligner
 
     extractor = AlignmentExtractor(
         aligner_name,
         unit_extractor_name,
         unit_extractor_output_layer_n,
         unit_extractor_kmeans_uri,
-        device=device
+        device=device,
+        dtype=dtype,
     )
 
-    audio = example_rate16k_audio["waveform"].mean(1)  # averaging mono to get [Time] shape required by aligner
-
-    alignment_durations, _, _ = extractor.extract_alignment(audio, REF_TEXT, plot=False, add_trailing_silence=True)
-
-    assert_equal(alignment_durations, tensor(REF_DURATIONS, device=device, dtype=torch.int64))
+    alignment_durations, _, _ = extractor.extract_alignment(
+        audio, REF_TEXT, plot=False, add_trailing_silence=True
+    )
 
+    assert_equal(
+        alignment_durations, tensor(ref_tensor, device=device, dtype=torch.int64)
+    )