@@ -4,7 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -15,7 +15,7 @@ from fairseq2.data.text.text_tokenizer import TextTokenizer
 from fairseq2.data.typing import StringLike
 from fairseq2.generation import SequenceToTextOutput, SequenceGeneratorOptions
 from fairseq2.memory import MemoryBlock
-from fairseq2.typing import Device
+from fairseq2.typing import DataType, Device
 from torch import Tensor
 from enum import Enum, auto
 from seamless_communication.models.inference.ngram_repeat_block_processor import (
@@ -56,11 +56,11 @@ class Translator(nn.Module):
         sample_rate: int = 16000,
     ):
         super().__init__()
+        dtype = torch.float16 if "cuda" in device.type else torch.float32
         # Load the model.
-        self.model: UnitYModel = load_unity_model(
-            model_name_or_card, device=device, dtype=torch.float16
+        self.model: UnitYModel = self.load_model_for_inference(
+            load_unity_model, model_name_or_card, device, dtype
         )
-        self.model.eval()
         self.text_tokenizer = load_unity_text_tokenizer(model_name_or_card)
         self.unit_tokenizer = load_unity_unit_tokenizer(model_name_or_card)
         self.device = device
@@ -71,15 +71,27 @@ class Translator(nn.Module):
             channel_last=True,
             standardize=True,
             device=device,
-            dtype=torch.float16,
+            dtype=dtype,
         )
         self.collate = Collater(
             pad_idx=self.text_tokenizer.vocab_info.pad_idx, pad_to_multiple=2
         )
         # Load the vocoder.
-        self.vocoder: Vocoder = load_vocoder_model(vocoder_name_or_card, device=device)
-        self.vocoder.eval()
-        self.sr = sample_rate
+        self.vocoder: Vocoder = self.load_model_for_inference(
+            load_vocoder_model, vocoder_name_or_card, device, torch.float32
+        )
+        self.sample_rate = sample_rate
+
+    @staticmethod
+    def load_model_for_inference(
+        load_model_fn: Callable[..., nn.Module],
+        model_name_or_card: Union[str, AssetCard],
+        device: Device,
+        dtype: DataType,
+    ) -> nn.Module:
+        model = load_model_fn(model_name_or_card, device=device, dtype=dtype)
+        model.eval()
+        return model
 
     @classmethod
     def get_prediction(
@@ -136,17 +148,7 @@ class Translator(nn.Module):
         else:
             return Modality.TEXT, Modality.SPEECH
 
-    @torch.no_grad()
-    def synthesize_speech(
-        self,
-        code: List[int],
-        lang: str,
-        speaker: Optional[int] = None,
-        dur_prediction: Optional[bool] = True,
-    ) -> Tuple[List[Tensor], int]:
-        return self.vocoder(code, lang, speaker, dur_prediction), self.sr
-
-    @torch.no_grad()
+    @torch.inference_mode()
     def predict(
         self,
         input: Union[str, torch.Tensor],
@@ -173,8 +175,8 @@
 
         :returns:
             - Translated text.
-            - Audio waveform.
-            - Sampling rate of audio waveform.
+            - Generated output audio waveform corresponding to the translated text.
+            - Sample rate of output audio waveform.
         """
         try:
             task = Task[task_str.upper()]
@@ -192,7 +194,7 @@
         else:
             decoded_audio = {
                 "waveform": audio,
-                "sample_rate": self.sr,
+                "sample_rate": self.sample_rate,
                 "format": -1,
             }
             src = self.collate(self.convert_to_fbank(decoded_audio))["fbank"]
@@ -223,5 +225,5 @@
             return text_out.sentences[0], None, None
         else:
             units = unit_out.units[:, 1:][0].cpu().numpy().tolist()
-            wav_out, sr_out = self.synthesize_speech(units, tgt_lang, spkr)
-            return text_out.sentences[0], wav_out, sr_out
+            wav_out = self.vocoder(units, tgt_lang, spkr, dur_prediction=True)
+            return text_out.sentences[0], wav_out, self.sample_rate
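
Usage sketch (not part of the patch): a minimal example of driving the refactored
Translator after this change. The asset card names ("seamlessM4T_large",
"vocoder_36langs") and the module path are assumptions inferred from the imports
above; only the constructor shape, the device-derived dtype behavior, and the
(text, waveform, sample rate) return triple come from the diff itself.

    import torch

    from seamless_communication.models.inference.translator import Translator

    # The constructor now derives dtype from the device (fp16 on CUDA,
    # fp32 otherwise), so no dtype argument is passed here.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Card names below are assumed for illustration.
    translator = Translator("seamlessM4T_large", "vocoder_36langs", device)

    # Speech-output tasks return the translated text plus the generated
    # waveform and its sample rate; text-only tasks return (text, None, None).
    text, wav, sr = translator.predict(
        "Hello, world!", "t2st", tgt_lang="fra", src_lang="eng"
    )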