@@ -7,7 +7,7 @@ import logging
 from dataclasses import dataclass
 from enum import Enum, auto
 from pathlib import Path
-from typing import Callable, List, Optional, Tuple, Union, cast
+from typing import List, Optional, Tuple, Union, cast

 import torch
 import torch.nn as nn
@@ -39,6 +39,7 @@ from seamless_communication.models.unity import (
 )
 from seamless_communication.models.vocoder import load_vocoder_model
 from seamless_communication.toxicity import (
+    BadWordChecker,
     load_bad_word_checker,
 )
 from seamless_communication.toxicity.mintox import mintox_pipeline
@@ -110,9 +111,9 @@ class Translator(nn.Module):
         # Load the model.
         if device == torch.device("cpu"):
             dtype = torch.float32
-        self.model = self.load_model_for_inference(
-            load_unity_model, model_name_or_card, device, dtype
-        )
+
+        self.model = load_unity_model(model_name_or_card, device=device, dtype=dtype)
+        self.model.eval()
         assert isinstance(self.model, UnitYModel)

         if text_tokenizer is None:
@@ -126,10 +127,9 @@ class Translator(nn.Module):
         if self.model.t2u_model is not None:
             self.unit_tokenizer = load_unity_unit_tokenizer(model_name_or_card)

+        self.bad_word_checker: Optional[BadWordChecker] = None
         if apply_mintox:
             self.bad_word_checker = load_bad_word_checker("mintox")
-        else:
-            self.bad_word_checker = None

         self.apply_mintox = apply_mintox

@@ -150,20 +150,10 @@ class Translator(nn.Module):
         if vocoder_name_or_card is not None and (
             output_modality is None or output_modality == Modality.SPEECH
         ):
-            self.vocoder = self.load_model_for_inference(
-                load_vocoder_model, vocoder_name_or_card, device, torch.float32
+            self.vocoder = load_vocoder_model(
+                vocoder_name_or_card, device=device, dtype=torch.float32
             )
-
-    @staticmethod
-    def load_model_for_inference(
-        load_model_fn: Callable[..., nn.Module],
-        model_name_or_card: Union[str, AssetCard],
-        device: Device,
-        dtype: DataType,
-    ) -> nn.Module:
-        model = load_model_fn(model_name_or_card, device=device, dtype=dtype)
-        model.eval()
-        return model
+            self.vocoder.eval()

     @classmethod
     def get_prediction(
@@ -272,7 +262,9 @@ class Translator(nn.Module):
         input_modality, output_modality = self.get_modalities_from_task_str(task_str)

         if self.apply_mintox and src_lang is None:
-            raise ValueError("`src_lang` must be specified when `apply_mintox` is `True`.")
+            raise ValueError(
+                "`src_lang` must be specified when `apply_mintox` is `True`."
+            )

         if isinstance(input, dict):
             src = cast(SequenceData, input)
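
Note (not part of the diff): the change drops the `load_model_for_inference` indirection because the loaders already accept `device` and `dtype` keyword arguments, and the only remaining step is switching the returned module to eval mode. A minimal sketch of the resulting loading pattern follows; the asset card names and the fp16-on-GPU choice are illustrative assumptions, only the fp32-on-CPU fallback comes from the diff itself.

import torch

from seamless_communication.models.unity import load_unity_model
from seamless_communication.models.vocoder import load_vocoder_model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The diff only shows the fp32 fallback on CPU; fp16 on GPU is an example choice.
dtype = torch.float16 if device.type == "cuda" else torch.float32

# Card names below are illustrative placeholders, not taken from the diff.
model = load_unity_model("seamlessM4T_v2_large", device=device, dtype=dtype)
model.eval()  # inference only: disables dropout, fixes normalization statistics

vocoder = load_vocoder_model("vocoder_v2", device=device, dtype=torch.float32)
vocoder.eval()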