Add integration test for ProsodyUnitY/seamless_expressivity model (#99)

* Add integration test for ProsodyUnitY/seamless_expressivity model

* minor

* uploaded the wav file to fbaipublicfiles

* tmp

* move fixture to common.py

* add end-to-end Expressivity test

* tmp

* fix to fp32 in integration test

* move download to fixture

* formatter

* mypy

* address PR comment
Yilin Yang · 1 year ago
commit 0cc7bc610a

+ 28 - 1
tests/common.py

@@ -8,7 +8,9 @@ from contextlib import contextmanager
 from typing import Any, Generator, List, Optional, Union
 
 import torch
-from fairseq2.typing import Device
+from fairseq2.data import Collater
+from fairseq2.data.audio import WaveformToFbankConverter, WaveformToFbankInput
+from fairseq2.typing import DataType, Device
 from torch import Tensor
 
 # The default device that tests should use. Note that pytest can change it based
@@ -64,3 +66,28 @@ def tmp_rng_seed(device: Device, seed: int = 0) -> Generator[None, None, None]:
         torch.manual_seed(seed)
 
         yield
+
+
+def get_default_dtype() -> DataType:
+    if device == Device("cpu"):
+        dtype = torch.float32
+    else:
+        dtype = torch.float16
+    return dtype
+
+
+def convert_to_collated_fbank(audio_dict: WaveformToFbankInput, dtype: DataType) -> Any:
+    convert_to_fbank = WaveformToFbankConverter(
+        num_mel_bins=80,
+        waveform_scale=2**15,
+        channel_last=True,
+        standardize=True,
+        device=device,
+        dtype=dtype,
+    )
+
+    collater = Collater(pad_value=1)
+
+    feat = collater(convert_to_fbank(audio_dict))["fbank"]
+
+    return feat
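
For orientation, a minimal usage sketch of the two new helpers (not part of the commit). The synthetic waveform stands in for the decoded audio that the real tests get from the conftest fixture below; the WaveformToFbankInput keys mirror how test_pretssel_vocoder.py builds one.

    from typing import cast

    import torch
    from fairseq2.data.audio import WaveformToFbankInput

    from tests.common import convert_to_collated_fbank, get_default_dtype

    dtype = get_default_dtype()  # float32 on CPU, float16 otherwise

    # channel_last=True in the converter, so the waveform is laid out as
    # (num_samples, num_channels); ~1 s of fake 16 kHz mono audio here.
    audio_dict = cast(
        WaveformToFbankInput,
        {"waveform": torch.randn(16_000, 1), "sample_rate": 16_000},
    )

    fbank = convert_to_collated_fbank(audio_dict, dtype=dtype)
    feat = fbank["seqs"][0]  # (num_frames, 80) features of the one example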

+ 20 - 0
tests/conftest.py

@@ -4,10 +4,15 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+import tempfile
 from argparse import ArgumentTypeError
 from typing import cast
+from urllib.request import urlretrieve
 
 import pytest
+import torch
+from fairseq2.data.audio import AudioDecoder, AudioDecoderOutput
+from fairseq2.memory import MemoryBlock
 from fairseq2.typing import Device
 
 import tests.common
@@ -31,3 +36,18 @@ def pytest_addoption(parser: pytest.Parser) -> None:
 
 def pytest_sessionstart(session: pytest.Session) -> None:
     tests.common.device = cast(Device, session.config.getoption("device"))
+
+
+@pytest.fixture(scope="module")
+def example_rate16k_audio() -> AudioDecoderOutput:
+    url = "https://dl.fbaipublicfiles.com/seamlessM4T/LJ037-0171_sr16k.wav"
+
+    audio_decoder = AudioDecoder(dtype=torch.float32, device=tests.common.device)
+
+    with tempfile.NamedTemporaryFile() as f:
+        urlretrieve(url, f.name)
+        with open(f.name, "rb") as fb:
+            block = MemoryBlock(fb.read())
+        decoded_audio = audio_decoder(block)
+
+    return decoded_audio
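
Because the fixture is module-scoped, the wav is downloaded and decoded once per test module rather than once per test. A hypothetical consumer (not in the commit) just declares it as a parameter:

    from fairseq2.data.audio import AudioDecoderOutput


    def test_uses_example_audio(example_rate16k_audio: AudioDecoderOutput) -> None:
        # The decoded dict doubles as a WaveformToFbankInput in tests/common.py,
        # so it carries at least "waveform" and "sample_rate".
        assert example_rate16k_audio["sample_rate"] == 16_000
        assert example_rate16k_audio["waveform"].numel() > 0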

+ 4 - 13
tests/integration/inference/test_translator.py

@@ -10,7 +10,7 @@ import torch
 from fairseq2.typing import Device
 
 from seamless_communication.inference import Translator
-from tests.common import device
+from tests.common import device, get_default_dtype
 
 # fmt: off
 ENG_SENTENCE:     Final = "On Monday, scientists from the Stanford University School of Medicine announced the invention of a new diagnostic tool that can sort cells by type: a tiny printable chip that can be manufactured using standard inkjet printers for possibly about one U.S. cent each."
@@ -24,10 +24,7 @@ def test_seamless_m4t_large_t2tt() -> None:
     src_lang = "eng"
     tgt_lang = "deu"
 
-    if device == Device("cpu"):
-        dtype = torch.float32
-    else:
-        dtype = torch.float16
+    dtype = get_default_dtype()
 
     translator = Translator(model_name, "vocoder_36langs", device, dtype=dtype)
     text_output, _ = translator.predict(
@@ -44,10 +41,7 @@ def test_seamless_m4t_v2_large_t2tt() -> None:
     src_lang = "eng"
     tgt_lang = "deu"
 
-    if device == Device("cpu"):
-        dtype = torch.float32
-    else:
-        dtype = torch.float16
+    dtype = get_default_dtype()
 
     translator = Translator(model_name, "vocoder_v2", device, dtype=dtype)
     text_output, _ = translator.predict(
@@ -67,10 +61,7 @@ def test_seamless_m4t_v2_large_multiple_tasks() -> None:
     ref_spanish_text = "Hola, espero que todos estéis haciendo bien."
     ref_spanish_asr_text = "Hola, espero que todos estéis haciendo bien."
 
-    if device == Device("cpu"):
-        dtype = torch.float32
-    else:
-        dtype = torch.float16
+    dtype = get_default_dtype()
 
     translator = Translator(model_name, "vocoder_v2", device, dtype=dtype)
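
All three tests above now reduce to the same pattern; as a sketch (model, vocoder, and text taken from tests in this commit; the src_lang keyword and argument order are assumptions based on the expressivity test below):

    from seamless_communication.inference import Translator
    from tests.common import device, get_default_dtype

    dtype = get_default_dtype()  # one shared dtype policy instead of three copies

    translator = Translator("seamlessM4T_v2_large", "vocoder_v2", device, dtype=dtype)
    text_output, _ = translator.predict(
        "Hello! I hope you're all doing well.",  # any source text
        "t2tt",
        "deu",
        src_lang="eng",
    )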
 

+ 84 - 0
tests/integration/models/test_expressivity.py

@@ -0,0 +1,84 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Final
+
+import torch
+from fairseq2.data.audio import AudioDecoderOutput
+from torch import tensor
+
+from seamless_communication.inference import Translator
+from seamless_communication.inference.pretssel_generator import PretsselGenerator
+from seamless_communication.models.unit_extractor import UnitExtractor
+from seamless_communication.models.unity import load_gcmvn_stats
+from tests.common import (
+    assert_equal,
+    convert_to_collated_fbank,
+    device,
+    get_default_dtype,
+)
+
+# fmt: off
+REF_UNITS: Final = [8976, 6589, 6589, 5736, 7542, 6515, 1240, 8335, 2381, 1076, 1076, 3380, 4085, 8207, 7957, 4446, 2641, 2544, 5552, 5529, 6319, 2779, 2890, 2890, 3229, 3303, 9751, 1979, 664, 1859, 1302, 528, 1303, 9543, 5770, 3532, 1286, 1286, 1727, 9287, 5248, 5586, 594, 3385, 2613, 1717, 7529, 7634, 931, 1602, 4512, 850, 2748, 5056, 1086, 2320, 2320, 9320, 3223, 5592, 1122, 419, 24, 4126, 5200, 2712, 9549, 8676, 8676, 3443, 7598, 7598, 2200, 2745, 1215, 118, 3840, 2703, 1616, 8788, 1240, 3349, 4890, 2756, 166, 9574, 9773, 5887, 2516, 9332, 6092, 3377, 4334, 3127, 3127, 3127, 944, 3089, 5947, 6572, 6572, 7561, 4358, 4358, 4358, 8124, 5549, 9275, 82, 8830, 8830, 5949, 22, 6729, 6878, 3817, 1871, 6092, 1441, 3127, 3928, 8254, 7984, 1116, 2796, 1806, 3710, 797, 9269, 576, 576, 2020, 137, 6624, 3815, 8690, 3634, 6036, 3530, 8719, 3458, 138, 8745, 5233, 2235, 8580, 8580, 6831, 2709, 7136, 9693, 3437, 3437, 3238, 4368, 2321, 2321, 391, 391, 4976, 8622, 6722, 3864, 9113, 9113, 7222, 7222, 7937, 999, 1286, 1286, 7789, 9396, 9603, 6690, 5233, 2235, 618, 8830, 6954, 3668, 4302, 596, 1934, 2886, 2704, 9097, 4161, 458, 4147, 9245, 9245, 3127, 3127, 944, 9676, 9676, 3468, 270, 270, 4608, 5549, 4182, 102, 8568, 1286, 1286, 5087, 817, 4153, 207, 207, 3763, 6415, 5188, 6010, 554, 753, 9953, 5104, 3828, 1879, 995, 9683, 6932, 3644, 2683, 9335, 183, 5525, 7023, 9568, 6222, 6315, 676, 3443, 6971, 2084, 999, 1286, 1286, 9620, 9620, 1048, 5577, 9328, 4963, 1364, 8328, 4573, 4573, 7917, 7917, 560, 2020, 4923, 137, 9542, 5832, 9775, 4780, 9400, 2745, 2745, 8984, 628, 8834, 6932, 3817, 8312, 5393, 458, 4147, 9191, 2225, 2759, 8980, 2351, 193, 1476, 9347, 3063, 2076, 3641, 1614, 9832, 3554, 8197, 5589, 5589, 7306, 184, 1708, 2954, 2954, 3485, 3485, 7665, 8909, 5405, 3590, 3590, 3446, 6442, 6442, 2802, 5549, 3791]
+REF_WAVE_EXTRACTED_UNITS: Final = [8976, 2066, 3800, 2357, 2357, 8080, 9479, 2181, 311, 7241, 5301, 9666, 9925, 940, 9479, 9479, 9479, 3151, 9666, 9925, 2937, 9479, 9479, 3043, 9666, 9189, 9189, 4821, 2937, 2357, 9479, 9479, 9666, 9666, 9666, 9666, 9666, 9666, 9479, 1369, 247, 5025, 5574, 940, 2937, 9479, 9479, 9666, 9666, 9666, 9666, 9666, 5025, 9666, 9666, 9666, 9666, 9666, 9925, 9666, 9479, 9479, 9666, 9666, 9479, 9666, 9479, 9666, 1589, 9666, 9362, 940, 2937, 2937, 9479, 9479, 8063, 9666, 9925, 2937, 9479, 9479, 9666, 9666, 9666, 2130, 4978, 1589, 5574, 5574, 9925, 2937, 9479, 515, 2379, 9666, 9666, 9666, 1589, 4978, 9532, 225, 225, 225, 1251, 225, 3978, 3800, 6343, 1840, 8080, 9666, 9479, 5514, 9666, 6606, 940, 2937, 9479, 9479, 9479, 9666, 9666, 9666, 9666, 9479, 9479, 9666, 9666, 9666, 940, 8080, 9479, 9479, 9479, 9666, 9666, 9666, 9479, 9479, 515, 247, 5025, 5574, 940, 9536, 9479, 9479, 9666, 9666, 9666, 9666, 9666, 1369, 9666, 1653, 4978, 530, 1589, 5574, 940, 940, 9479, 9479, 9479, 9666, 9666, 9925, 9666, 2937, 9479, 8770, 515, 9666, 2130, 5574, 5574, 940, 2937, 9479, 9479, 9479, 8770, 1369, 9580, 1589, 1589, 5574, 940, 2937, 9479, 5634, 9666, 9479, 9202, 1351, 8193, 4660, 4660, 4660, 1463, 1251, 2130, 5574, 1840, 2937, 9479, 9479, 515, 2066, 1653, 7962, 530, 1589, 9666, 940, 940, 9479, 9479, 9666, 9666, 9479, 515, 515, 2720, 8819, 530, 9666, 8063, 940, 2937, 9666, 9666, 9666, 9479, 9666, 9666, 2379, 9925, 2937, 9479, 9479, 1351, 8193, 1589, 9666, 1589, 5574, 940, 2937, 9479, 9666, 9666, 9666, 9666, 9666, 9666, 9666, 9666, 9666, 9666, 9666, 9666, 9666, 9479, 1369, 7810, 9666, 1589, 8063, 5574, 940, 940, 2937, 9479, 9479, 9666, 9479, 311, 1369, 1589, 1589, 5574, 940, 2937, 9479, 7848, 3511, 1589, 1795, 5574, 940, 940, 5786, 2003, 8857, 8193, 8193, 1653, 979, 8471, 8471, 1275, 1885, 225, 225, 4199]
+# fmt: on
+
+
+def test_seamless_expressivity(example_rate16k_audio: AudioDecoderOutput) -> None:
+    # float16 shows non-deterministic behavior, so pin float32 here
+    dtype = torch.float32
+
+    audio_dict = example_rate16k_audio
+
+    feat = convert_to_collated_fbank(audio_dict, dtype=dtype)
+
+    unity_model_name = "seamless_expressivity"
+    vocoder_model_name = "vocoder_mel"
+    pretssel_model_name = "pretssel_v1"
+    target_lang = "fra"
+
+    translator = Translator(unity_model_name, None, device, dtype=dtype)
+
+    _, speech_output = translator.predict(
+        feat,
+        "s2st",
+        target_lang,
+        prosody_encoder_input=feat,
+    )
+
+    assert speech_output is not None
+
+    units = tensor(speech_output.units[0], device=device, dtype=torch.int64)
+
+    # same target units
+    assert_equal(units, tensor(REF_UNITS).to(units))
+
+    pretssel_generator = PretsselGenerator(
+        unity_model_name,
+        vocoder_model_name,
+        pretssel_model_name,
+        device=device,
+        dtype=dtype,
+    )
+
+    # same target mel-spectrogram (checked indirectly via re-extracted units below)
+    speech_output = pretssel_generator.predict(
+        speech_output.units,
+        tgt_lang=target_lang,
+        prosody_encoder_input=feat,
+    )
+
+    # UnitExtractor only operates in fp32
+    waveform = speech_output.audio_wavs[0][0].float()
+
+    unit_extractor = UnitExtractor(
+        "xlsr2_1b_v2",
+        "https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/kmeans_10k.npy",
+        device=device,
+    )
+    units = unit_extractor.predict(waveform, 34)
+
+    assert_equal(units, tensor(REF_WAVE_EXTRACTED_UNITS).to(units))
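
The final assertion closes the loop: translator units -> PretsselGenerator waveform -> units re-extracted from that waveform. In isolation, the extraction half looks roughly like this (checkpoint name and k-means URL come from the test; the placeholder waveform and the comment on the layer argument are assumptions):

    import torch

    from seamless_communication.models.unit_extractor import UnitExtractor
    from tests.common import device

    unit_extractor = UnitExtractor(
        "xlsr2_1b_v2",
        "https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/kmeans_10k.npy",
        device=device,
    )

    # UnitExtractor only operates in fp32, hence the .float() in the test above.
    waveform = torch.randn(16_000, device=device)  # placeholder 16 kHz audio
    units = unit_extractor.predict(waveform, 34)   # 34 picks the feature layer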

+ 19 - 26
tests/integration/models/test_pretssel_vocoder.py

@@ -4,46 +4,39 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-import tempfile
-from urllib.request import urlretrieve
+from typing import cast
 
 import torch
-import torchaudio
+from fairseq2.data.audio import AudioDecoderOutput, WaveformToFbankInput
 
 from seamless_communication.models.vocoder.loader import load_mel_vocoder_model
-from tests.common import assert_close, device
+from tests.common import (
+    assert_close,
+    convert_to_collated_fbank,
+    device,
+    get_default_dtype,
+)
 
 
-def test_pretssel_vocoder() -> None:
-    n_mel_bins = 80
+def test_pretssel_vocoder(example_rate16k_audio: AudioDecoderOutput) -> None:
     sample_rate = 16_000
 
-    vocoder = load_mel_vocoder_model(
-        "vocoder_mel", device=device, dtype=torch.float32
-    )
+    dtype = get_default_dtype()
 
-    url = "https://keithito.com/LJ-Speech-Dataset/LJ037-0171.wav"
+    audio_dict = example_rate16k_audio
 
-    with tempfile.NamedTemporaryFile() as f:
-        urlretrieve(url, f.name)
-        _wav, _sr = torchaudio.load(f.name)
+    feat = convert_to_collated_fbank(audio_dict, dtype=dtype)["seqs"][0]
 
-    wav = torchaudio.sox_effects.apply_effects_tensor(
-        _wav, _sr, [["rate", f"{sample_rate}"], ["channels", "1"]]
-    )[0].to(device=device)
-    feat = torchaudio.compliance.kaldi.fbank(
-        wav * (2**15), num_mel_bins=n_mel_bins, sample_frequency=sample_rate
-    )
+    vocoder = load_mel_vocoder_model("vocoder_mel", device=device, dtype=torch.float32)
+    vocoder.eval()
 
     with torch.no_grad():
-        wav_hat = vocoder(feat).t()
+        wav_hat = vocoder(feat).view(1, -1)
 
-    feat_hat = torchaudio.compliance.kaldi.fbank(
-        wav_hat * (2**15), num_mel_bins=n_mel_bins, sample_frequency=sample_rate
-    )
+    audio_hat = {"sample_rate": sample_rate, "waveform": wav_hat}
 
-    assert_close(feat_hat, feat[: feat_hat.shape[0], :], atol=0.0, rtol=5.0)
+    audio_hat_dict = cast(WaveformToFbankInput, audio_hat)
 
+    feat_hat = convert_to_collated_fbank(audio_hat_dict, dtype=dtype)["seqs"][0]
 
-if __name__ == "__main__":
-    test_pretssel_vocoder()
+    assert_close(feat_hat, feat[: feat_hat.shape[0], :], atol=0.0, rtol=5.0)
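
Design note: feat_hat is now computed by the same convert_to_collated_fbank helper as the reference features, so the analysis-synthesis comparison no longer depends on torchaudio (the tempfile download, sox resampling, and Kaldi fbank are all dropped); the vocoder itself is still loaded in float32 regardless of the test dtype.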

+ 2 - 5
tests/integration/models/test_unit_extractor.py

@@ -12,7 +12,7 @@ from torch import tensor
 
 from seamless_communication.inference import Translator
 from seamless_communication.models.unit_extractor import UnitExtractor
-from tests.common import assert_equal, device
+from tests.common import assert_equal, device, get_default_dtype
 
 # fmt: off
 REF_ENG_UNITS: Final = [8976, 8299, 0, 0, 9692, 5395, 785, 785, 7805, 6193, 2922, 4806, 3362, 3560, 8119, 8119, 4335, 205, 5424, 5424, 5064, 7421, 6547, 9952, 3728, 8544, 3321, 1093, 1443, 7962, 3978, 8063, 5168, 5491, 9133, 9275, 5912, 8729, 5097, 5495, 1650, 5048, 2839, 6756, 5665, 4191, 5205, 5205, 9568, 9568, 5932, 1190, 9339, 5839, 5839, 6244, 5320, 3454, 5216, 721, 6994, 6513, 7754, 3469, 296, 1849, 3254, 3254, 5042, 5042, 3961, 2079, 1907, 1846, 661, 2225, 944, 9295, 4712, 1785, 6060, 8701, 7646, 1355, 2876, 8199, 5901, 8199, 3861, 5153, 6420, 2897, 1389, 334, 6334]
@@ -23,10 +23,7 @@ def test_unit_extractor() -> None:
     model_name = "seamlessM4T_v2_large"
     english_text = "Hello! I hope you're all doing well."
 
-    if device == Device("cpu"):
-        dtype = torch.float32
-    else:
-        dtype = torch.float16
+    dtype = get_default_dtype()
 
     translator = Translator(model_name, "vocoder_v2", device, dtype=dtype)
     unit_extractor = UnitExtractor(