@@ -3,22 +3,22 @@
# This source code is licensed under the license found in the
# MIT_LICENSE file in the root directory of this source tree.

-import torch
-import torchaudio
-
-from torch.nn import Module
+from typing import cast
+from copy import deepcopy
+from pathlib import Path
from typing import List, Optional, Tuple, Union

+import torch
+import torchaudio
from fairseq2.assets.card import AssetCard
from fairseq2.data import SequenceData, StringLike
from fairseq2.data.audio import WaveformToFbankConverter
from fairseq2.typing import DataType, Device
+from torch.nn import Module

-from seamless_communication.inference import BatchedSpeechOutput, Translator
from seamless_communication.inference.generator import SequenceGeneratorOptions
-from seamless_communication.inference.pretssel_generator import (
- PretsselGenerator,
-)
+from seamless_communication.inference.pretssel_generator import PretsselGenerator
+from seamless_communication.inference.translator import BatchedSpeechOutput, Translator
from seamless_communication.models.unity import (
load_gcmvn_stats,
load_unity_unit_tokenizer,
@@ -38,7 +38,7 @@ class ExpressiveTranslator(Module):
super().__init__()

unit_tokenizer = load_unity_unit_tokenizer(model_name_or_card)
-
+
self.translator = Translator(
model_name_or_card,
vocoder_name_or_card=None,
@@ -65,13 +65,14 @@ class ExpressiveTranslator(Module):
_gcmvn_mean, _gcmvn_std = load_gcmvn_stats(vocoder_name_or_card)
self.gcmvn_mean = torch.tensor(_gcmvn_mean, device=device, dtype=dtype)
self.gcmvn_std = torch.tensor(_gcmvn_std, device=device, dtype=dtype)
-
+
@staticmethod
def remove_prosody_tokens_from_text(text_output: List[str]) -> List[str]:
modified_text_output = []
for text in text_output:
# filter out prosody tokens, there is only emphasis '*', and pause '='
- text = text.replace("*", "").replace("=", "")
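+ # Decoded text may be a fairseq2 StringLike (e.g. CString) rather than str.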
+ text = str(text).replace("*", "").replace("=", "")
text = " ".join(text.split())
modified_text_output.append(text)
return modified_text_output
@@ -79,7 +79,7 @@ class ExpressiveTranslator(Module):
@torch.inference_mode()
def predict(
self,
- audio_path: str,
+ input: Union[Path, SequenceData],
tgt_lang: str,
text_generation_opts: Optional[SequenceGeneratorOptions] = None,
unit_generation_opts: Optional[SequenceGeneratorOptions] = None,
@@ -89,8 +89,8 @@ class ExpressiveTranslator(Module):
"""
The main method used to perform inference on all tasks.

- :param audio_path:
- Path to audio waveform.
+ :param input:
+ Either a path to the input audio file, or its extracted fbank features as SequenceData.
:param tgt_lang:
Target language to decode into.
:param text_generation_opts:
@@ -105,32 +105,52 @@ class ExpressiveTranslator(Module):
- Batched list of Translated text.
- Translated BatchedSpeechOutput.
"""
- # TODO: Replace with fairseq2.data once re-sampling is implemented.
- wav, sample_rate = torchaudio.load(audio_path)
- wav = torchaudio.functional.resample(wav, orig_freq=sample_rate, new_freq=16_000)
- wav = wav.transpose(0, 1)
-
- data = self.fbank_extractor(
- {
- "waveform": wav,
- "sample_rate": AUDIO_SAMPLE_RATE,
- }
- )
- fbank = data["fbank"]
- gcmvn_fbank = fbank.subtract(self.gcmvn_mean).divide(self.gcmvn_std)
- std, mean = torch.std_mean(fbank, dim=0)
- fbank = fbank.subtract(mean).divide(std)
-
- src = SequenceData(
- seqs=fbank.unsqueeze(0),
- seq_lens=torch.LongTensor([fbank.shape[0]]),
- is_ragged=False,
- )
- src_gcmvn = SequenceData(
- seqs=gcmvn_fbank.unsqueeze(0),
- seq_lens=torch.LongTensor([gcmvn_fbank.shape[0]]),
- is_ragged=False,
- )
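+ # Accept either precomputed fbank features (SequenceData) or a path to an audio file.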
+ if isinstance(input, dict):
+ src = cast(SequenceData, input)
+ src_gcmvn = deepcopy(src)
+
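+ # Two normalizations of the same features: global CMVN (precomputed stats from the
+ # vocoder card) for the PRETSSEL generator, per-utterance CMVN for the translator.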
+ fbank = src["seqs"]
+ gcmvn_fbank = fbank.subtract(self.gcmvn_mean).divide(self.gcmvn_std)
+ std, mean = torch.std_mean(fbank, dim=[0, 1]) # B x T x mel
+ ucmvn_fbank = fbank.subtract(mean).divide(std)
+
+ src["seqs"] = ucmvn_fbank
+ src_gcmvn["seqs"] = gcmvn_fbank
+
+ elif isinstance(input, Path):
+ # TODO: Replace with fairseq2.data once re-sampling is implemented.
+ wav, sample_rate = torchaudio.load(input)
+ wav = torchaudio.functional.resample(
+ wav, orig_freq=sample_rate, new_freq=AUDIO_SAMPLE_RATE,
+ )
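+ # torchaudio.load returns (channels, samples); the fbank converter expects (samples, channels).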
+ wav = wav.transpose(0, 1)
+
+ data = self.fbank_extractor(
+ {
+ "waveform": wav,
+ "sample_rate": AUDIO_SAMPLE_RATE,
+ }
+ )
+
+ fbank = data["fbank"]
+ gcmvn_fbank = fbank.subtract(self.gcmvn_mean).divide(self.gcmvn_std)
+ std, mean = torch.std_mean(fbank, dim=0)
+ fbank = fbank.subtract(mean).divide(std)
+
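+ # Wrap the normalized features as a batch of one for downstream generation.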
+ src = SequenceData(
+ seqs=fbank.unsqueeze(0),
+ seq_lens=torch.LongTensor([fbank.shape[0]]),
+ is_ragged=False,
+ )
+ src_gcmvn = SequenceData(
+ seqs=gcmvn_fbank.unsqueeze(0),
+ seq_lens=torch.LongTensor([gcmvn_fbank.shape[0]]),
+ is_ragged=False,
+ )

text_output, unit_output = self.translator.predict(
src,
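
For reviewers, a minimal sketch of exercising both input modes of the new predict() signature. The import path, card names, constructor argument order, and fbank_seq_data are assumptions for illustration, not part of this patch:

    from pathlib import Path

    import torch

    # Import path assumed from this file's location in the tree.
    from seamless_communication.inference.expressive_translator import ExpressiveTranslator

    translator = ExpressiveTranslator(
        "seamless_expressivity",  # model card name (assumed)
        "vocoder_pretssel",       # vocoder card name (assumed)
        device=torch.device("cpu"),
        dtype=torch.float32,
    )

    # Path mode: predict() loads the file, resamples to 16 kHz, and extracts fbanks.
    text_out, speech_out = translator.predict(Path("example.wav"), tgt_lang="fra")

    # SequenceData mode: fbank_seq_data is a dict built upstream with the same
    # seqs/seq_lens/is_ragged layout as `src` in the diff above.
    text_out, speech_out = translator.predict(fbank_seq_data, tgt_lang="fra")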