
Rename 'commercial' to 'v2' and condition duration prediction on T2U model type for vocoder. (#72)

Kaushik Ram Sadagopan, 1 year ago
commit c60f92c86f
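For context, the caller-facing effect of this commit is that anything which previously requested "vocoder_commercial" now requests "vocoder_v2", as the updated script default and tests below show, and the v2 (NAR T2U) path no longer asks the vocoder to predict unit durations. A minimal usage sketch; the model card name, language codes, input sentence, and predict() keyword names are assumptions, while the vocoder card name and the device/dtype pattern come from this diff:

import torch
from seamless_communication.models.inference import Translator

# dtype selection mirrors the updated tests: fp32 on CPU, fp16 otherwise.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dtype = torch.float32 if device == torch.device("cpu") else torch.float16

# "seamlessM4T_v2_large" is an assumed model card name; "vocoder_v2" is the
# vocoder card renamed in this commit (previously "vocoder_commercial").
translator = Translator("seamlessM4T_v2_large", "vocoder_v2", device, dtype=dtype)

# predict() returns a (text_output, speech_output) pair, as the updated tests
# below unpack it; speech_output is None for text-only tasks.
text_output, speech_output = translator.predict(
    "Hello, how are you today?", "t2st", tgt_lang="spa", src_lang="eng"
)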

+ 1 - 1
scripts/m4t/predict/predict.py

@@ -55,7 +55,7 @@ def add_inference_arguments(parser: argparse.ArgumentParser) -> argparse.Argumen
         "--vocoder_name",
         type=str,
         help="Vocoder model name",
-        default="vocoder_commercial",
+        default="vocoder_v2",
     )
     # Text generation args.
     parser.add_argument(
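The new default only takes effect when --vocoder_name is not passed explicitly; a tiny self-contained sketch of that argparse behavior (the argument setup mirrors the hunk above, the rest is illustrative):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--vocoder_name",
    type=str,
    help="Vocoder model name",
    default="vocoder_v2",
)
args = parser.parse_args([])  # no --vocoder_name given: the new default applies
assert args.vocoder_name == "vocoder_v2"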

+ 1 - 1
src/seamless_communication/assets/cards/vocoder_commercial.yaml → src/seamless_communication/assets/cards/vocoder_v2.yaml

@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-name: vocoder_commercial
+name: vocoder_v2
 model_type: vocoder_code_hifigan
 model_arch: base
 checkpoint: "file://large_experiments/seamless/ust/krs/M4T_Vocoder/lang_36_commercial/km_10000/seed_1/g_00600000"
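Only the card name changes; model_type, model_arch, and the checkpoint stay the same, so the model now resolves through the asset store under "vocoder_v2". A rough lookup sketch, assuming the package registers its cards with fairseq2's global asset_store on import (that registration and the exact field accessors are assumptions, not shown in this diff):

import seamless_communication  # assumed to register the package's asset cards on import
from fairseq2.assets import asset_store

card = asset_store.retrieve_card("vocoder_v2")
print(card.field("model_type").as_(str))   # vocoder_code_hifigan
print(card.field("model_arch").as_(str))   # base
print(card.field("checkpoint").as_uri())   # internal file:// checkpoint URI from the card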

+ 5 - 1
src/seamless_communication/models/inference/translator.py

@@ -291,8 +291,12 @@ class Translator(nn.Module):
                 # Remove the lang token for AR UnitY since the vocoder doesn't need it
                 # in the unit sequence. tgt_lang is fed as an argument to the vocoder.
                 units = unit_output.units[:, 1:]
+                duration_prediction = True
             else:
                 units = unit_output.units
+                # Vocoder duration predictions not required since the NAR
+                # T2U model already predicts duration in the units.
+                duration_prediction = False
 
             audio_wavs = []
             speech_units = []
@@ -305,7 +309,7 @@ class Translator(nn.Module):
                 speech_units.append(u)
                 # TODO: Implement batched inference for vocoder.
                 translated_audio_wav = self.vocoder(
-                    u, tgt_lang, spkr, dur_prediction=True
+                    u, tgt_lang, spkr, dur_prediction=duration_prediction
                 )
                 audio_wavs.append(translated_audio_wav)
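Net effect of this hunk: with the autoregressive UnitY T2U model, the unit sequence starts with a language token that must be stripped and carries no duration information, so the vocoder runs its duration predictor; with the NAR T2U model, the units already reflect predicted durations, so dur_prediction is switched off. A standalone sketch of that decision (the helper function and the boolean flag are illustrative; the real code branches on the T2U model type inside Translator):

from typing import Tuple

import torch

def select_vocoder_inputs(
    units: torch.Tensor, t2u_is_autoregressive: bool
) -> Tuple[torch.Tensor, bool]:
    """Mirror the updated Translator logic for feeding the code HiFi-GAN vocoder."""
    if t2u_is_autoregressive:
        # AR UnitY: drop the leading lang token and let the vocoder predict durations.
        return units[:, 1:], True
    # NAR T2U: durations are already baked into the unit sequence.
    return units, False

# Example: a fake batch of 42 units from a NAR T2U model.
units = torch.randint(0, 10_000, (1, 42))
u, dur_flag = select_vocoder_inputs(units, t2u_is_autoregressive=False)
# The Translator then calls the vocoder per unit sequence:
#     self.vocoder(u, tgt_lang, spkr, dur_prediction=dur_flag)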
 

+ 2 - 2
tests/integration/models/test_translator.py

@@ -48,7 +48,7 @@ def test_seamless_m4t_v2_large_t2tt() -> None:
     else:
         dtype = torch.float16
 
-    translator = Translator(model_name, "vocoder_commercial", device, dtype=dtype)
+    translator = Translator(model_name, "vocoder_v2", device, dtype=dtype)
     text_output, _ = translator.predict(
         ENG_SENTENCE,
         "t2tt",
@@ -71,7 +71,7 @@ def test_seamless_m4t_v2_large_multiple_tasks() -> None:
     else:
         dtype = torch.float16
 
-    translator = Translator(model_name, "vocoder_commercial", device, dtype=dtype)
+    translator = Translator(model_name, "vocoder_v2", device, dtype=dtype)
 
     # Generate english speech for the english text.
     _, english_speech_output = translator.predict(

+ 2 - 2
tests/integration/models/test_unit_extraction.py

@@ -15,7 +15,7 @@ from tests.common import assert_equal, device
 
 
 # fmt: off
-REF_ENG_UNITS: Final = [8976, 8299,    0,    0, 9692, 5395,  785,  785, 7805, 6193, 2922, 4806, 3362, 3560, 9007, 8119, 8119,  205, 5424, 5424, 5064, 7421, 6547, 9952, 3728, 8544, 3321, 1093, 1443, 7962, 3978, 9631, 5168, 5491, 9133, 9275, 5912, 8729, 5097, 5495, 1650, 5048, 3752, 6756,  963, 5665, 4191, 5205, 5205, 9568, 5092, 5932, 1190, 9339, 5839, 5839, 6244, 5320, 3454, 5216, 721, 6994, 6513, 7754, 3469,  296, 1849, 3254, 3254, 5042, 5042, 3961, 2079, 1907, 1846,  661, 2225,  944, 9295, 4712, 1785, 6060, 8701, 7646, 1355, 2876, 8199, 5901, 8199, 3861, 5153, 6420, 2897, 1389,  334, 6334]
+REF_ENG_UNITS: Final = [8976, 8299, 0, 0, 9692, 5395, 785, 785, 7805, 6193, 2922, 4806, 3362, 3560, 8119, 8119, 4335, 205, 5424, 5424, 5064, 7421, 6547, 9952, 3728, 8544, 3321, 1093, 1443, 7962, 3978, 8063, 5168, 5491, 9133, 9275, 5912, 8729, 5097, 5495, 1650, 5048, 2839, 6756, 5665, 4191, 5205, 5205, 9568, 9568, 5932, 1190, 9339, 5839, 5839, 6244, 5320, 3454, 5216, 721, 6994, 6513, 7754, 3469, 296, 1849, 3254, 3254, 5042, 5042, 3961, 2079, 1907, 1846, 661, 2225, 944, 9295, 4712, 1785, 6060, 8701, 7646, 1355, 2876, 8199, 5901, 8199, 3861, 5153, 6420, 2897, 1389, 334, 6334]
 # fmt: on
 
 
@@ -28,7 +28,7 @@ def test_unit_extraction() -> None:
     else:
         dtype = torch.float16
 
-    translator = Translator(model_name, "vocoder_commercial", device, dtype=dtype)
+    translator = Translator(model_name, "vocoder_v2", device, dtype=dtype)
     unit_extractor = UnitExtractor(
         "xlsr2_1b_v2",
         "https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/kmeans_10k.npy",