浏览代码

Make unit_tokenizer optional so that it doesn't block T2TT into text-only languages.

Kaushik Ram Sadagopan 2 年之前
父节点
当前提交
4bebf663c3

+ 0 - 0
src/seamless_communication/models/audio_to_units/.gitkeep


+ 1 - 1
src/seamless_communication/models/inference/translator.py

@@ -98,8 +98,8 @@ class Translator(nn.Module):
         generator = UnitYGenerator(
             model,
             text_tokenizer,
-            unit_tokenizer,
             tgt_lang,
+            unit_tokenizer if output_modality == Modality.SPEECH else None,
             text_opts=SequenceGeneratorOptions(beam_size=5, soft_max_seq_len=(1, 200)),
             unit_opts=SequenceGeneratorOptions(
                 beam_size=5, soft_max_seq_len=(max_len_a, 50)

+ 24 - 18
src/seamless_communication/models/unity/generator.py

@@ -31,15 +31,15 @@ class UnitYGenerator:
     model: UnitYModel
     s2t_generator: SequenceToTextGenerator
     t2t_generator: Optional[SequenceToTextGenerator]
-    unit_decoder: UnitTokenDecoder
-    unit_generator: Seq2SeqGenerator
+    unit_decoder: Optional[UnitTokenDecoder]
+    unit_generator: Optional[Seq2SeqGenerator]
 
     def __init__(
         self,
         model: UnitYModel,
         text_tokenizer: TextTokenizer,
-        unit_tokenizer: UnitTokenizer,
         target_lang: str,
+        unit_tokenizer: Optional[UnitTokenizer] = None,
         text_opts: Optional[SequenceGeneratorOptions] = None,
         unit_opts: Optional[SequenceGeneratorOptions] = None,
     ) -> None:
@@ -97,25 +97,28 @@ class UnitYGenerator:
                 t2t_model, text_tokenizer, target_lang, text_opts
             )
 
+        self.unit_generator = None
+        self.unit_decoder = None
         # Set up unit generator.
-        self.unit_decoder = unit_tokenizer.create_decoder()
+        if unit_tokenizer is not None:
+            self.unit_decoder = unit_tokenizer.create_decoder()
 
-        unit_encoder = unit_tokenizer.create_encoder(
-            lang=target_lang, device=infer_device(model.t2u_model)
-        )
-
-        if unit_opts is None:
-            # Speech sequences are typically much longer than text sequences.
-            unit_opts = SequenceGeneratorOptions(
-                soft_max_seq_len=(1, 50), hard_max_seq_len=5000
+            unit_encoder = unit_tokenizer.create_encoder(
+                lang=target_lang, device=infer_device(model.t2u_model)
             )
 
-        self.unit_generator = Seq2SeqGenerator(
-            model.t2u_model,
-            unit_tokenizer.vocab_info,
-            unit_encoder.prefix_indices,
-            unit_opts,
-        )
+            if unit_opts is None:
+                # Speech sequences are typically much longer than text sequences.
+                unit_opts = SequenceGeneratorOptions(
+                    soft_max_seq_len=(1, 50), hard_max_seq_len=5000
+                )
+
+            self.unit_generator = Seq2SeqGenerator(
+                model.t2u_model,
+                unit_tokenizer.vocab_info,
+                unit_encoder.prefix_indices,
+                unit_opts,
+            )
 
     @torch.inference_mode()
     def __call__(
@@ -176,6 +179,9 @@ class UnitYGenerator:
             decoder_output, decoder_padding_mask
         )
 
+        assert self.unit_generator is not None
+        assert self.unit_decoder is not None
+
         unit_gen_output = self.unit_generator(
             t2u_encoder_output,
             t2u_encoder_padding_mask,