|
@@ -10,27 +10,14 @@ import torch
|
|
import torchaudio
|
|
import torchaudio
|
|
from pathlib import Path
|
|
from pathlib import Path
|
|
|
|
|
|
-from fairseq2.data import SequenceData
|
|
|
|
-from fairseq2.data.audio import WaveformToFbankConverter
|
|
|
|
-
|
|
|
|
-from seamless_communication.cli.expressivity.predict.pretssel_generator import (
|
|
|
|
- PretsselGenerator,
|
|
|
|
-)
|
|
|
|
from seamless_communication.cli.m4t.predict import (
|
|
from seamless_communication.cli.m4t.predict import (
|
|
add_inference_arguments,
|
|
add_inference_arguments,
|
|
set_generation_opts,
|
|
set_generation_opts,
|
|
)
|
|
)
|
|
-from seamless_communication.inference import Translator
|
|
|
|
-from seamless_communication.models.unity import (
|
|
|
|
- load_gcmvn_stats,
|
|
|
|
- load_unity_unit_tokenizer,
|
|
|
|
-)
|
|
|
|
|
|
+from seamless_communication.inference import ExpressiveTranslator
|
|
from seamless_communication.store import add_gated_assets
|
|
from seamless_communication.store import add_gated_assets
|
|
|
|
|
|
|
|
|
|
-AUDIO_SAMPLE_RATE = 16000
|
|
|
|
-
|
|
|
|
-
|
|
|
|
logging.basicConfig(
|
|
logging.basicConfig(
|
|
level=logging.INFO,
|
|
level=logging.INFO,
|
|
format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
|
|
format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
|
|
@@ -39,13 +26,6 @@ logging.basicConfig(
|
|
logger = logging.getLogger(__name__)
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
-def remove_prosody_tokens_from_text(text: str) -> str:
|
|
|
|
- # filter out prosody tokens, there is only emphasis '*', and pause '='
|
|
|
|
- text = text.replace("*", "").replace("=", "")
|
|
|
|
- text = " ".join(text.split())
|
|
|
|
- return text
|
|
|
|
-
|
|
|
|
-
|
|
|
|
def main() -> None:
|
|
def main() -> None:
|
|
parser = argparse.ArgumentParser(description="Running SeamlessExpressive inference.")
|
|
parser = argparse.ArgumentParser(description="Running SeamlessExpressive inference.")
|
|
parser.add_argument("input", type=str, help="Audio WAV file path.")
|
|
parser.add_argument("input", type=str, help="Audio WAV file path.")
|
|
@@ -82,59 +62,11 @@ def main() -> None:
|
|
|
|
|
|
logger.info(f"Running inference on {device=} with {dtype=}.")
|
|
logger.info(f"Running inference on {device=} with {dtype=}.")
|
|
|
|
|
|
- unit_tokenizer = load_unity_unit_tokenizer(args.model_name)
|
|
|
|
-
|
|
|
|
- translator = Translator(
|
|
|
|
|
|
+ expressive_translator = ExpressiveTranslator(
|
|
args.model_name,
|
|
args.model_name,
|
|
- vocoder_name_or_card=None,
|
|
|
|
- device=device,
|
|
|
|
- dtype=dtype,
|
|
|
|
- )
|
|
|
|
-
|
|
|
|
- pretssel_generator = PretsselGenerator(
|
|
|
|
args.vocoder_name,
|
|
args.vocoder_name,
|
|
- vocab_info=unit_tokenizer.vocab_info,
|
|
|
|
- device=device,
|
|
|
|
- dtype=dtype,
|
|
|
|
- )
|
|
|
|
-
|
|
|
|
- fbank_extractor = WaveformToFbankConverter(
|
|
|
|
- num_mel_bins=80,
|
|
|
|
- waveform_scale=2**15,
|
|
|
|
- channel_last=True,
|
|
|
|
- standardize=False,
|
|
|
|
- device=device,
|
|
|
|
- dtype=dtype,
|
|
|
|
- )
|
|
|
|
-
|
|
|
|
- _gcmvn_mean, _gcmvn_std = load_gcmvn_stats(args.vocoder_name)
|
|
|
|
- gcmvn_mean = torch.tensor(_gcmvn_mean, device=device, dtype=dtype)
|
|
|
|
- gcmvn_std = torch.tensor(_gcmvn_std, device=device, dtype=dtype)
|
|
|
|
-
|
|
|
|
- wav, sample_rate = torchaudio.load(args.input)
|
|
|
|
- wav = torchaudio.functional.resample(wav, orig_freq=sample_rate, new_freq=16_000)
|
|
|
|
- wav = wav.transpose(0, 1)
|
|
|
|
-
|
|
|
|
- data = fbank_extractor(
|
|
|
|
- {
|
|
|
|
- "waveform": wav,
|
|
|
|
- "sample_rate": 16000,
|
|
|
|
- }
|
|
|
|
- )
|
|
|
|
- fbank = data["fbank"]
|
|
|
|
- gcmvn_fbank = fbank.subtract(gcmvn_mean).divide(gcmvn_std)
|
|
|
|
- std, mean = torch.std_mean(fbank, dim=0)
|
|
|
|
- fbank = fbank.subtract(mean).divide(std)
|
|
|
|
-
|
|
|
|
- src = SequenceData(
|
|
|
|
- seqs=fbank.unsqueeze(0),
|
|
|
|
- seq_lens=torch.LongTensor([fbank.shape[0]]),
|
|
|
|
- is_ragged=False,
|
|
|
|
- )
|
|
|
|
- src_gcmvn = SequenceData(
|
|
|
|
- seqs=gcmvn_fbank.unsqueeze(0),
|
|
|
|
- seq_lens=torch.LongTensor([gcmvn_fbank.shape[0]]),
|
|
|
|
- is_ragged=False,
|
|
|
|
|
|
+ device,
|
|
|
|
+ dtype
|
|
)
|
|
)
|
|
|
|
|
|
text_generation_opts, unit_generation_opts = set_generation_opts(args)
|
|
text_generation_opts, unit_generation_opts = set_generation_opts(args)
|
|
@@ -145,22 +77,13 @@ def main() -> None:
|
|
f"unit_generation_ngram_filtering={args.unit_generation_ngram_filtering}"
|
|
f"unit_generation_ngram_filtering={args.unit_generation_ngram_filtering}"
|
|
)
|
|
)
|
|
|
|
|
|
- text_output, unit_output = translator.predict(
|
|
|
|
- src,
|
|
|
|
- "s2st",
|
|
|
|
|
|
+ speech_output, text_output = expressive_translator.predict(
|
|
|
|
+ args.input,
|
|
args.tgt_lang,
|
|
args.tgt_lang,
|
|
- text_generation_opts=text_generation_opts,
|
|
|
|
- unit_generation_opts=unit_generation_opts,
|
|
|
|
- unit_generation_ngram_filtering=args.unit_generation_ngram_filtering,
|
|
|
|
- duration_factor=args.duration_factor,
|
|
|
|
- prosody_encoder_input=src_gcmvn,
|
|
|
|
- )
|
|
|
|
-
|
|
|
|
- assert unit_output is not None
|
|
|
|
- speech_output = pretssel_generator.predict(
|
|
|
|
- unit_output.units,
|
|
|
|
- tgt_lang=args.tgt_lang,
|
|
|
|
- prosody_encoder_input=src_gcmvn,
|
|
|
|
|
|
+ text_generation_opts,
|
|
|
|
+ unit_generation_opts,
|
|
|
|
+ args.unit_generation_ngram_filtering,
|
|
|
|
+ args.duration_factor,
|
|
)
|
|
)
|
|
|
|
|
|
logger.info(f"Saving expressive translated audio in {args.tgt_lang}")
|
|
logger.info(f"Saving expressive translated audio in {args.tgt_lang}")
|
|
@@ -170,9 +93,7 @@ def main() -> None:
|
|
sample_rate=speech_output.sample_rate,
|
|
sample_rate=speech_output.sample_rate,
|
|
)
|
|
)
|
|
|
|
|
|
- text_out = remove_prosody_tokens_from_text(str(text_output[0]))
|
|
|
|
-
|
|
|
|
- logger.info(f"Translated text in {args.tgt_lang}: {text_out}")
|
|
|
|
|
|
+ logger.info(f"Translated text in {args.tgt_lang}: {text_output[0]}")
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
if __name__ == "__main__":
|