|
@@ -18,6 +18,9 @@ from fairseq2.memory import MemoryBlock
|
|
from fairseq2.typing import Device
|
|
from fairseq2.typing import Device
|
|
from torch import Tensor
|
|
from torch import Tensor
|
|
from enum import Enum, auto
|
|
from enum import Enum, auto
|
|
|
|
+from seamless_communication.models.inference.ngram_repeat_block_logits_processor import (
|
|
|
|
+ NGramRepeatBlockLogitsProcessor,
|
|
|
|
+)
|
|
|
|
|
|
from seamless_communication.models.unity import (
|
|
from seamless_communication.models.unity import (
|
|
UnitTokenizer,
|
|
UnitTokenizer,
|
|
@@ -30,6 +33,13 @@ from seamless_communication.models.unity import (
|
|
from seamless_communication.models.unity.generator import SequenceToUnitOutput
|
|
from seamless_communication.models.unity.generator import SequenceToUnitOutput
|
|
from seamless_communication.models.vocoder import load_vocoder_model, Vocoder
|
|
from seamless_communication.models.vocoder import load_vocoder_model, Vocoder
|
|
|
|
|
|
|
|
+import urllib.request
|
|
|
|
+from urllib.request import urlopen
|
|
|
|
+import ssl
|
|
|
|
+import json
|
|
|
|
+
|
|
|
|
+ssl._create_default_https_context = ssl._create_unverified_context
|
|
|
|
+
|
|
|
|
|
|
class Task(Enum):
|
|
class Task(Enum):
|
|
S2ST = auto()
|
|
S2ST = auto()
|
|
@@ -88,25 +98,38 @@ class Translator(nn.Module):
|
|
input_modality: Modality,
|
|
input_modality: Modality,
|
|
output_modality: Modality,
|
|
output_modality: Modality,
|
|
tgt_lang: str,
|
|
tgt_lang: str,
|
|
|
|
+ ngram_filtering: bool = False,
|
|
) -> Tuple[SequenceToTextOutput, Optional[SequenceToUnitOutput]]:
|
|
) -> Tuple[SequenceToTextOutput, Optional[SequenceToUnitOutput]]:
|
|
if input_modality == Modality.TEXT:
|
|
if input_modality == Modality.TEXT:
|
|
# need to adjust this since src_len is smaller for text.
|
|
# need to adjust this since src_len is smaller for text.
|
|
max_len_a = 25
|
|
max_len_a = 25
|
|
else:
|
|
else:
|
|
max_len_a = 1
|
|
max_len_a = 1
|
|
-
|
|
|
|
|
|
+ text_opts = SequenceGeneratorOptions(beam_size=5, soft_max_seq_len=(1, 200))
|
|
|
|
+ unit_opts = SequenceGeneratorOptions(
|
|
|
|
+ beam_size=5, soft_max_seq_len=(max_len_a, 50)
|
|
|
|
+ )
|
|
|
|
+ if ngram_filtering:
|
|
|
|
+ text_opts.logits_processor = NGramRepeatBlockLogitsProcessor(
|
|
|
|
+ no_repeat_ngram_size=10
|
|
|
|
+ )
|
|
|
|
+ unit_opts.logits_processor = NGramRepeatBlockLogitsProcessor(
|
|
|
|
+ no_repeat_ngram_size=10
|
|
|
|
+ )
|
|
generator = UnitYGenerator(
|
|
generator = UnitYGenerator(
|
|
model,
|
|
model,
|
|
text_tokenizer,
|
|
text_tokenizer,
|
|
tgt_lang,
|
|
tgt_lang,
|
|
unit_tokenizer if output_modality == Modality.SPEECH else None,
|
|
unit_tokenizer if output_modality == Modality.SPEECH else None,
|
|
- text_opts=SequenceGeneratorOptions(beam_size=5, soft_max_seq_len=(1, 200)),
|
|
|
|
- unit_opts=SequenceGeneratorOptions(
|
|
|
|
- beam_size=5, soft_max_seq_len=(max_len_a, 50)
|
|
|
|
- ),
|
|
|
|
|
|
+ text_opts=text_opts,
|
|
|
|
+ unit_opts=unit_opts,
|
|
)
|
|
)
|
|
return generator(
|
|
return generator(
|
|
- src["seqs"], src["seq_lens"], input_modality.value, output_modality.value
|
|
|
|
|
|
+ src["seqs"],
|
|
|
|
+ src["seq_lens"],
|
|
|
|
+ input_modality.value,
|
|
|
|
+ output_modality.value,
|
|
|
|
+ ngram_filtering=ngram_filtering,
|
|
)
|
|
)
|
|
|
|
|
|
def get_modalities_from_task(self, task: Task) -> Tuple[Modality, Modality]:
|
|
def get_modalities_from_task(self, task: Task) -> Tuple[Modality, Modality]:
|
|
@@ -138,6 +161,7 @@ class Translator(nn.Module):
|
|
tgt_lang: str,
|
|
tgt_lang: str,
|
|
src_lang: Optional[str] = None,
|
|
src_lang: Optional[str] = None,
|
|
spkr: Optional[int] = -1,
|
|
spkr: Optional[int] = -1,
|
|
|
|
+ ngram_filtering: bool = False,
|
|
) -> Tuple[StringLike, Optional[List[Tensor]], Optional[int]]:
|
|
) -> Tuple[StringLike, Optional[List[Tensor]], Optional[int]]:
|
|
"""
|
|
"""
|
|
The main method used to perform inference on all tasks.
|
|
The main method used to perform inference on all tasks.
|
|
@@ -197,6 +221,7 @@ class Translator(nn.Module):
|
|
input_modality,
|
|
input_modality,
|
|
output_modality,
|
|
output_modality,
|
|
tgt_lang=tgt_lang,
|
|
tgt_lang=tgt_lang,
|
|
|
|
+ ngram_filtering=ngram_filtering,
|
|
)
|
|
)
|
|
|
|
|
|
text_out = result[0]
|
|
text_out = result[0]
|
|
@@ -207,3 +232,19 @@ class Translator(nn.Module):
|
|
units = unit_out.units[:, 1:][0].cpu().numpy().tolist()
|
|
units = unit_out.units[:, 1:][0].cpu().numpy().tolist()
|
|
wav_out, sr_out = self.synthesize_speech(units, tgt_lang, spkr)
|
|
wav_out, sr_out = self.synthesize_speech(units, tgt_lang, spkr)
|
|
return text_out.sentences[0], wav_out, sr_out
|
|
return text_out.sentences[0], wav_out, sr_out
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+if __name__ == "__main__":
|
|
|
|
+ import torchaudio
|
|
|
|
+
|
|
|
|
+ # audio = "/data/home/dnn/LJ003-0001.wav"
|
|
|
|
+ audio = "/data/home/dnn/oss_sc/seamless_communication/spanish_repeat.wav"
|
|
|
|
+ translator = Translator(
|
|
|
|
+ "seamlessM4T_large", "vocoder_36langs", torch.device("cuda:0")
|
|
|
|
+ )
|
|
|
|
+ text_out, wav, sr = translator.predict(audio, "s2st", "deu", ngram_filtering=True) # type: ignore
|
|
|
|
+ torchaudio.save(
|
|
|
|
+ "/data/home/dnn/deu_testing.wav",
|
|
|
|
+ wav[0].cpu(),
|
|
|
|
+ sample_rate=sr,
|
|
|
|
+ )
|