Browse Source

Add use_text_decoder parameter to UnitYConfig. (#88)

Kaushik Ram Sadagopan 1 year ago
parent
commit
a11376477b

+ 19 - 5
src/seamless_communication/models/unity/builder.py

@@ -73,6 +73,9 @@ class UnitYConfig:
     use_text_encoder: bool
     """If ``True``, uses an aligned MT encoder for the MT task."""
 
+    use_text_decoder: bool
+    """If ``False``, skips loading a text decoder, to be used with a Monotonic decoder."""
+
     use_conformer_adaptor: bool
     """If ``True``, uses a Conformer-based adaptor block."""
 
@@ -120,6 +123,7 @@ def _base() -> UnitYConfig:
         t2u_config=t2u_config,
         prosody_encoder_config=None,
         use_text_encoder=True,
+        use_text_decoder=True,
         use_conformer_adaptor=False,
         use_gelu=False,
         num_adaptor_layers=1,
@@ -147,6 +151,7 @@ def _medium() -> UnitYConfig:
         t2u_config=t2u_config,
         prosody_encoder_config=None,
         use_text_encoder=True,
+        use_text_decoder=True,
         use_conformer_adaptor=False,
         use_gelu=False,
         num_adaptor_layers=1,
@@ -176,6 +181,7 @@ def _base_v2() -> UnitYConfig:
         t2u_config=t2u_config,
         prosody_encoder_config=None,
         use_text_encoder=True,
+        use_text_decoder=True,
         use_conformer_adaptor=False,
         use_gelu=False,
         num_adaptor_layers=1,
@@ -209,6 +215,7 @@ def _expressivity_v2() -> UnitYConfig:
         t2u_config=t2u_config,
         prosody_encoder_config=prosody_encoder_config,
         use_text_encoder=False,
+        use_text_decoder=True,
         use_conformer_adaptor=False,
         use_gelu=True,
         num_adaptor_layers=1,
@@ -290,17 +297,24 @@ class UnitYBuilder:
         speech_encoder_frontend = self.w2v2_encoder_builder.build_frontend()
         speech_encoder = self.build_speech_encoder()
 
-        text_decoder_frontend = self.mt_model_builder.build_frontend(text_embed)
-        text_decoder = self.mt_model_builder.build_decoder()
-
         if self.config.use_text_encoder:
-            # We use shared embedding as in NLLB.
-            text_encoder_frontend = text_decoder_frontend
+            text_encoder_frontend = self.mt_model_builder.build_frontend(text_embed)
             text_encoder = self.mt_model_builder.build_encoder()
         else:
             text_encoder_frontend = None
             text_encoder = None
 
+        if self.config.use_text_decoder:
+            if text_encoder_frontend is not None:
+                # We use shared embedding as in NLLB.
+                text_decoder_frontend = text_encoder_frontend
+            else:
+                text_decoder_frontend = self.mt_model_builder.build_frontend(text_embed)
+            text_decoder = self.mt_model_builder.build_decoder()
+        else:
+            text_decoder_frontend = None
+            text_decoder = None
+
         final_proj = TiedProjection(text_embed.weight, bias=None)
 
         if self.t2u_builder is None:

+ 27 - 6
src/seamless_communication/models/unity/model.py

@@ -39,8 +39,8 @@ class UnitYModel(EncoderDecoderModel):
     speech_encoder: TransformerEncoder
     text_encoder_frontend: Optional[TransformerFrontend]
     text_encoder: Optional[TransformerEncoder]
-    text_decoder_frontend: TransformerFrontend
-    text_decoder: TransformerDecoder
+    text_decoder_frontend: Optional[TransformerFrontend]
+    text_decoder: Optional[TransformerDecoder]
     final_proj: Projection
     t2u_model: Union["UnitYT2UModel", "UnitYNART2UModel", None]
     prosody_encoder_model: Optional[ECAPA_TDNN]
@@ -51,8 +51,8 @@ class UnitYModel(EncoderDecoderModel):
         speech_encoder: TransformerEncoder,
         text_encoder_frontend: Optional[TransformerFrontend],
         text_encoder: Optional[TransformerEncoder],
-        text_decoder_frontend: TransformerFrontend,
-        text_decoder: TransformerDecoder,
+        text_decoder_frontend: Optional[TransformerFrontend],
+        text_decoder: Optional[TransformerDecoder],
         final_proj: Projection,
         t2u_model: Union["UnitYT2UModel", "UnitYNART2UModel", None],
         target_vocab_info: VocabularyInfo,
@@ -85,8 +85,22 @@ class UnitYModel(EncoderDecoderModel):
             self.register_module("text_encoder_frontend", None)
             self.register_module("text_encoder", None)
 
-        self.text_decoder_frontend = text_decoder_frontend
-        self.text_decoder = text_decoder
+        if text_decoder is not None:
+            if text_decoder_frontend is None:
+                raise ValueError(
+                    "Both `text_decoder` and `text_decoder_frontend` must be specified, but `text_decoder_frontend` is `None`."
+                )
+
+            self.text_decoder_frontend = text_decoder_frontend
+            self.text_decoder = text_decoder
+        else:
+            if text_decoder_frontend is not None:
+                raise ValueError(
+                    "Both `text_decoder` and `text_decoder_frontend` must be specified, but `text_decoder` is `None`."
+                )
+
+            self.register_module("text_decoder_frontend", None)
+            self.register_module("text_decoder", None)
 
         self.final_proj = final_proj
 
@@ -146,6 +160,13 @@ class UnitYModel(EncoderDecoderModel):
         *,
         state_bag: Optional[IncrementalStateBag] = None,
     ) -> Tuple[Tensor, Optional[PaddingMask]]:
+        if self.text_decoder is None:
+            raise ValueError(
+                "`decode()` requires a text decoder, but the current UnitY model does not have one."
+            )
+
+        assert self.text_decoder_frontend is not None
+
         seqs, padding_mask = self.text_decoder_frontend(
             seqs, padding_mask, state_bag=state_bag
         )