
Skip loading text_encoder for S2X tasks, skip loading T2U model for X2T tasks. (#95)

* Use the same asset card to load any type of unity model.

* Apply the change when the user specifies a model_card directly.
Kaushik Ram Sadagopan, 1 year ago
Parent commit: 4e93254fa5
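
For callers, the new input_modality / output_modality arguments are opt-in. Below is a minimal usage sketch, not code from this commit: the "seamlessM4T_large" and "vocoder_36langs" asset card names, the positional vocoder argument, and the import path of the Modality enum are assumptions taken from typical usage of this repository.

import torch

# Assumption: Modality is exported from the translator module changed below.
from seamless_communication.inference.translator import Modality, Translator

# S2TT (speech in, text out): with this change both the text encoder and the
# T2U model are skipped, reducing load time and memory.
translator = Translator(
    "seamlessM4T_large",              # resolved to an AssetCard via asset_store
    "vocoder_36langs",                # vocoder card; unused when output is text
    torch.device("cuda:0"),
    dtype=torch.float16,
    input_modality=Modality.SPEECH,   # S2X task: skip loading the text encoder
    output_modality=Modality.TEXT,    # X2T task: skip loading the T2U model
)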

+ 3 - 8
src/seamless_communication/cli/m4t/evaluate/evaluate.py

@@ -267,10 +267,7 @@ def run_eval(
 
             # Skip performing inference when the input is entirely corrupted.
             if src["seqs"].numel() > 0:
-                (
-                    text_output,
-                    speech_output,
-                ) = translator.predict(
+                (text_output, speech_output,) = translator.predict(
                     src,
                     ctx.task,
                     ctx.target_lang,
@@ -287,10 +284,7 @@ def run_eval(
                     speech_output = None
 
             if valid_sequences is not None and not valid_sequences.all():
-                (
-                    text_output,
-                    speech_output,
-                ) = adjust_output_for_corrupted_inputs(
+                (text_output, speech_output,) = adjust_output_for_corrupted_inputs(
                     valid_sequences,
                     text_output,
                     speech_output,
@@ -398,6 +392,7 @@ def main(optional_args: Optional[Dict[str, Any]] = None) -> None:
         device,
         text_tokenizer=text_tokenizer,
         dtype=dtype,
+        input_modality=input_modality,
         output_modality=output_modality,
     )
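
For context, one way the evaluate CLI could derive the input_modality and output_modality values passed above from the task name. The mapping table and the modalities_for_task helper are illustrative assumptions, not code from this commit; the actual CLI may resolve modalities elsewhere.

from typing import Tuple

# Assumption: Modality lives in the translator module, as referenced in the
# constructor change below.
from seamless_communication.inference.translator import Modality

# Hypothetical task-to-modality table for the M4T tasks.
_TASK_MODALITIES = {
    "S2ST": (Modality.SPEECH, Modality.SPEECH),
    "S2TT": (Modality.SPEECH, Modality.TEXT),
    "ASR": (Modality.SPEECH, Modality.TEXT),
    "T2ST": (Modality.TEXT, Modality.SPEECH),
    "T2TT": (Modality.TEXT, Modality.TEXT),
}

def modalities_for_task(task: str) -> Tuple[Modality, Modality]:
    """Return (input_modality, output_modality) for an M4T task string."""
    return _TASK_MODALITIES[task.upper()]

input_modality, output_modality = modalities_for_task("S2TT")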
 

+ 22 - 0
src/seamless_communication/inference/translator.py

@@ -11,6 +11,8 @@ from typing import Callable, List, Optional, Tuple, Union, cast
 
 import torch
 import torch.nn as nn
+
+from fairseq2.assets import asset_store
 from fairseq2.assets.card import AssetCard
 from fairseq2.data import Collater, SequenceData
 from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
@@ -34,6 +36,7 @@ from seamless_communication.models.unity import (
     load_unity_model,
     load_unity_text_tokenizer,
     load_unity_unit_tokenizer,
+    unity_archs,
 )
 from seamless_communication.models.vocoder import load_vocoder_model
 
@@ -78,9 +81,28 @@ class Translator(nn.Module):
         device: Device,
         text_tokenizer: Optional[TextTokenizer] = None,
         dtype: DataType = torch.float16,
+        input_modality: Optional[Modality] = None,
         output_modality: Optional[Modality] = None,
     ):
         super().__init__()
+
+        if isinstance(model_name_or_card, str):
+            model_name_or_card = asset_store.retrieve_card(model_name_or_card)
+
+        assert isinstance(model_name_or_card, AssetCard)
+
+        if input_modality or output_modality:
+            unity_config = unity_archs.get_config(
+                model_name_or_card.field("model_arch").as_(str)
+            )
+            # Skip loading the text encoder.
+            if input_modality == Modality.SPEECH:
+                unity_config.use_text_encoder = False
+            # Skip loading the T2U model.
+            if output_modality == Modality.TEXT:
+                unity_config.t2u_config = None
+            model_name_or_card.field("model_config").set(unity_config)
+
         # Load the model.
         if device == torch.device("cpu"):
             dtype = torch.float32
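
The same card override also works standalone for code that bypasses Translator and calls load_unity_model directly. A sketch, reusing the assumed "seamlessM4T_large" card name and assuming load_unity_model accepts device and dtype keyword arguments, as fairseq2 model loaders generally do:

import torch

from fairseq2.assets import asset_store

from seamless_communication.models.unity import load_unity_model, unity_archs

card = asset_store.retrieve_card("seamlessM4T_large")
config = unity_archs.get_config(card.field("model_arch").as_(str))
config.use_text_encoder = False  # speech-only input: drop the text encoder
config.t2u_config = None         # text-only output: drop the T2U model
card.field("model_config").set(config)

# Assumption: device/dtype keywords are supported by the loader.
model = load_unity_model(card, device=torch.device("cpu"), dtype=torch.float32)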