|
@@ -11,6 +11,8 @@ from typing import Callable, List, Optional, Tuple, Union, cast
|
|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
+
|
|
|
+from fairseq2.assets import asset_store
|
|
|
from fairseq2.assets.card import AssetCard
|
|
|
from fairseq2.data import Collater, SequenceData
|
|
|
from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
|
|
@@ -34,6 +36,7 @@ from seamless_communication.models.unity import (
|
|
|
load_unity_model,
|
|
|
load_unity_text_tokenizer,
|
|
|
load_unity_unit_tokenizer,
|
|
|
+ unity_archs,
|
|
|
)
|
|
|
from seamless_communication.models.vocoder import load_vocoder_model
|
|
|
|
|
@@ -78,9 +81,28 @@ class Translator(nn.Module):
|
|
|
device: Device,
|
|
|
text_tokenizer: Optional[TextTokenizer] = None,
|
|
|
dtype: DataType = torch.float16,
|
|
|
+ input_modality: Optional[Modality] = None,
|
|
|
output_modality: Optional[Modality] = None,
|
|
|
):
|
|
|
super().__init__()
|
|
|
+
|
|
|
+ if isinstance(model_name_or_card, str):
|
|
|
+ model_name_or_card = asset_store.retrieve_card(model_name_or_card)
|
|
|
+
|
|
|
+ assert isinstance(model_name_or_card, AssetCard)
|
|
|
+
|
|
|
+ if input_modality or output_modality:
|
|
|
+ unity_config = unity_archs.get_config(
|
|
|
+ model_name_or_card.field("model_arch").as_(str)
|
|
|
+ )
|
|
|
+ # Skip loading the text encoder.
|
|
|
+ if input_modality == Modality.SPEECH:
|
|
|
+ unity_config.use_text_encoder = False
|
|
|
+ # Skip loading the T2U model.
|
|
|
+ if output_modality == Modality.TEXT:
|
|
|
+ unity_config.t2u_config = None
|
|
|
+ model_name_or_card.field("model_config").set(unity_config)
|
|
|
+
|
|
|
# Load the model.
|
|
|
if device == torch.device("cpu"):
|
|
|
dtype = torch.float32
|