|
@@ -33,13 +33,6 @@ from seamless_communication.models.unity import (
|
|
from seamless_communication.models.unity.generator import SequenceToUnitOutput
|
|
from seamless_communication.models.unity.generator import SequenceToUnitOutput
|
|
from seamless_communication.models.vocoder import load_vocoder_model, Vocoder
|
|
from seamless_communication.models.vocoder import load_vocoder_model, Vocoder
|
|
|
|
|
|
-import urllib.request
|
|
|
|
-from urllib.request import urlopen
|
|
|
|
-import ssl
|
|
|
|
-import json
|
|
|
|
-
|
|
|
|
-ssl._create_default_https_context = ssl._create_unverified_context
|
|
|
|
-
|
|
|
|
|
|
|
|
class Task(Enum):
|
|
class Task(Enum):
|
|
S2ST = auto()
|
|
S2ST = auto()
|
|
@@ -106,16 +99,10 @@ class Translator(nn.Module):
|
|
else:
|
|
else:
|
|
max_len_a = 1
|
|
max_len_a = 1
|
|
text_opts = SequenceGeneratorOptions(beam_size=5, soft_max_seq_len=(1, 200))
|
|
text_opts = SequenceGeneratorOptions(beam_size=5, soft_max_seq_len=(1, 200))
|
|
- unit_opts = SequenceGeneratorOptions(
|
|
|
|
- beam_size=5, soft_max_seq_len=(max_len_a, 50)
|
|
|
|
- )
|
|
|
|
|
|
+ unit_opts = SequenceGeneratorOptions(beam_size=5, soft_max_seq_len=(max_len_a, 50))
|
|
if ngram_filtering:
|
|
if ngram_filtering:
|
|
- text_opts.logits_processor = NGramRepeatBlockLogitsProcessor(
|
|
|
|
- no_repeat_ngram_size=10
|
|
|
|
- )
|
|
|
|
- unit_opts.logits_processor = NGramRepeatBlockLogitsProcessor(
|
|
|
|
- no_repeat_ngram_size=10
|
|
|
|
- )
|
|
|
|
|
|
+ text_opts.logits_processor = NGramRepeatBlockLogitsProcessor(no_repeat_ngram_size=10)
|
|
|
|
+ unit_opts.logits_processor = NGramRepeatBlockLogitsProcessor(no_repeat_ngram_size=10)
|
|
generator = UnitYGenerator(
|
|
generator = UnitYGenerator(
|
|
model,
|
|
model,
|
|
text_tokenizer,
|
|
text_tokenizer,
|
|
@@ -125,11 +112,7 @@ class Translator(nn.Module):
|
|
unit_opts=unit_opts,
|
|
unit_opts=unit_opts,
|
|
)
|
|
)
|
|
return generator(
|
|
return generator(
|
|
- src["seqs"],
|
|
|
|
- src["seq_lens"],
|
|
|
|
- input_modality.value,
|
|
|
|
- output_modality.value,
|
|
|
|
- ngram_filtering=ngram_filtering,
|
|
|
|
|
|
+ src["seqs"], src["seq_lens"], input_modality.value, output_modality.value, ngram_filtering=ngram_filtering
|
|
)
|
|
)
|
|
|
|
|
|
def get_modalities_from_task(self, task: Task) -> Tuple[Modality, Modality]:
|
|
def get_modalities_from_task(self, task: Task) -> Tuple[Modality, Modality]:
|
|
@@ -221,7 +204,7 @@ class Translator(nn.Module):
|
|
input_modality,
|
|
input_modality,
|
|
output_modality,
|
|
output_modality,
|
|
tgt_lang=tgt_lang,
|
|
tgt_lang=tgt_lang,
|
|
- ngram_filtering=ngram_filtering,
|
|
|
|
|
|
+ ngram_filtering=ngram_filtering
|
|
)
|
|
)
|
|
|
|
|
|
text_out = result[0]
|
|
text_out = result[0]
|
|
@@ -232,19 +215,3 @@ class Translator(nn.Module):
|
|
units = unit_out.units[:, 1:][0].cpu().numpy().tolist()
|
|
units = unit_out.units[:, 1:][0].cpu().numpy().tolist()
|
|
wav_out, sr_out = self.synthesize_speech(units, tgt_lang, spkr)
|
|
wav_out, sr_out = self.synthesize_speech(units, tgt_lang, spkr)
|
|
return text_out.sentences[0], wav_out, sr_out
|
|
return text_out.sentences[0], wav_out, sr_out
|
|
-
|
|
|
|
-
|
|
|
|
-if __name__ == "__main__":
|
|
|
|
- import torchaudio
|
|
|
|
-
|
|
|
|
- # audio = "/data/home/dnn/LJ003-0001.wav"
|
|
|
|
- audio = "/data/home/dnn/oss_sc/seamless_communication/spanish_repeat.wav"
|
|
|
|
- translator = Translator(
|
|
|
|
- "seamlessM4T_large", "vocoder_36langs", torch.device("cuda:0")
|
|
|
|
- )
|
|
|
|
- text_out, wav, sr = translator.predict(audio, "s2st", "deu", ngram_filtering=True) # type: ignore
|
|
|
|
- torchaudio.save(
|
|
|
|
- "/data/home/dnn/deu_testing.wav",
|
|
|
|
- wav[0].cpu(),
|
|
|
|
- sample_rate=sr,
|
|
|
|
- )
|
|
|