Browse Source

Address comments.

Kaushik Ram Sadagopan 2 years ago
parent
commit
85c1b894cd

+ 2 - 1
src/seamless_communication/models/inference/translator.py

@@ -244,7 +244,8 @@ class Translator(nn.Module):
             return text_out.sentences[0], None, None
             return text_out.sentences[0], None, None
         else:
         else:
             if isinstance(self.model.t2u_model, UnitYT2UModel):
             if isinstance(self.model.t2u_model, UnitYT2UModel):
-                # Remove the lang token for AR UnitY.
+                # Remove the lang token for AR UnitY since the vocoder doesn't need it
+                # in the unit sequence. tgt_lang is fed as an argument to the vocoder.
                 units = unit_out.units[:, 1:]
                 units = unit_out.units[:, 1:]
             else:
             else:
                 units = unit_out.units
                 units = unit_out.units

+ 2 - 2
src/seamless_communication/models/unity/builder.py

@@ -383,11 +383,11 @@ def create_unity_model(
     else:
     else:
         t2u_builder = UnitYT2UBuilder(config.t2u_config, device=device, dtype=dtype)
         t2u_builder = UnitYT2UBuilder(config.t2u_config, device=device, dtype=dtype)
 
 
-    nllb_builder = NllbBuilder(config.mt_model_config, device=device, dtype=dtype)
+    mt_model_builder = NllbBuilder(config.mt_model_config, device=device, dtype=dtype)
     unity_builder = UnitYBuilder(
     unity_builder = UnitYBuilder(
         config,
         config,
         w2v2_encoder_builder,
         w2v2_encoder_builder,
-        nllb_builder,
+        mt_model_builder,
         t2u_builder,
         t2u_builder,
         device=device,
         device=device,
         dtype=dtype,
         dtype=dtype,

+ 6 - 1
src/seamless_communication/models/unity/unit_tokenizer.py

@@ -69,7 +69,12 @@ class UnitTokenizer:
 
 
     def index_to_lang(self, idx: int) -> str:
     def index_to_lang(self, idx: int) -> str:
         """Return the language of the specified language symbol index."""
         """Return the language of the specified language symbol index."""
-        relative_idx = idx - self.num_units - 5
+        relative_idx = (
+            idx
+            - self.num_units
+            - (self.lang_symbol_repititions - 1) * len(self.langs)
+            - 5
+        )
 
 
         if relative_idx < 0 or relative_idx >= len(self.langs):
         if relative_idx < 0 or relative_idx >= len(self.langs):
             raise ValueError(
             raise ValueError(