Эх сурвалжийг харах

Made unity_nar_multilingual compatible with Translator.

Kaushik Ram Sadagopan 2 жил өмнө
parent
commit
8fac405bcb

+ 1 - 1
src/seamless_communication/assets/cards/unity_nar_multilingual.yaml

@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-name: nar_multilingual
+name: unity_nar_multilingual
 base: unity_nllb-100
 model_arch: nar_multilingual
 char_tokenizer: "file://checkpoint/krs/unity2/spm_char_lang38_tc.model"

+ 9 - 1
src/seamless_communication/models/inference/translator.py

@@ -26,6 +26,7 @@ from seamless_communication.models.unity import (
     UnitTokenizer,
     UnitYGenerator,
     UnitYModel,
+    UnitYT2UModel,
     load_unity_model,
     load_unity_text_tokenizer,
     load_unity_unit_tokenizer,
@@ -242,6 +243,13 @@ class Translator(nn.Module):
         if output_modality == Modality.TEXT:
             return text_out.sentences[0], None, None
         else:
-            units = unit_out.units[:, 1:][0].cpu().numpy().tolist()
+            if isinstance(self.model.t2u_model, UnitYT2UModel):
+                # Remove the lang token for AR UnitY.
+                units = unit_out.units[:, 1:]
+            else:
+                units = unit_out.units
+
+            # TODO: batch_size set to 1 for now, implement batching.
+            units = units[0].cpu().numpy().tolist()
             wav_out = self.vocoder(units, tgt_lang, spkr, dur_prediction=True)
             return text_out.sentences[0], wav_out, sample_rate