@@ -3,30 +3,31 @@
 #
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+from dataclasses import dataclass
+from enum import Enum, auto
 from pathlib import Path
-from typing import Callable, Optional, Tuple, Union
+from torch import Tensor
+from typing import Callable, List, Optional, Tuple, Union, cast
 
 import torch
 import torch.nn as nn
+
 from fairseq2.assets.card import AssetCard
 from fairseq2.data import Collater, SequenceData
 from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
-from fairseq2.data.text.text_tokenizer import TextTokenizer
+from fairseq2.data.text import TextTokenizer
 from fairseq2.data.typing import StringLike
 from fairseq2.generation import SequenceToTextOutput, SequenceGeneratorOptions
 from fairseq2.memory import MemoryBlock
 from fairseq2.nn.padding import get_seqs_and_padding_mask
 from fairseq2.typing import DataType, Device
-from torch import Tensor
-from enum import Enum, auto
-from seamless_communication.models.inference.ngram_repeat_block_processor import (
-    NGramRepeatBlockProcessor,
-)
+
 from seamless_communication.models.unity import (
     UnitTokenizer,
     UnitYGenerator,
     UnitYModel,
+    UnitYNART2UModel,
     UnitYT2UModel,
     load_unity_model,
     load_unity_text_tokenizer,
@@ -49,12 +50,25 @@ class Modality(Enum):
     TEXT = "text"
 
 
+@dataclass
+class BatchedSpeechOutput:
+    units: List[List[int]]
+    """The batched list of generated units."""
+
+    audio_wavs: List[Tensor]
+    """The batched list of audio waveforms."""
+
+    sample_rate: int = 16000
+    """Sample rate of the audio waveforms."""
+
+
 class Translator(nn.Module):
     def __init__(
         self,
         model_name_or_card: Union[str, AssetCard],
         vocoder_name_or_card: Union[str, AssetCard],
         device: Device,
+        text_tokenizer: Optional[TextTokenizer] = None,
         dtype: DataType = torch.float16,
     ):
         super().__init__()
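Not part of the patch, but for orientation: a minimal usage sketch of the new `BatchedSpeechOutput` return type. The `seamlessM4T_large` and `vocoder_36langs` asset card names and the CUDA device are assumptions taken from the project README, not from this diff.

```python
# Usage sketch (assumed asset cards and device; adjust to your setup).
import torch
import torchaudio

from seamless_communication.models.inference import Translator

translator = Translator(
    "seamlessM4T_large",
    "vocoder_36langs",
    device=torch.device("cuda:0"),
    dtype=torch.float16,
)

# predict() now returns a batch of texts plus an optional BatchedSpeechOutput.
texts, speech = translator.predict(
    "Hello, world!", "t2st", tgt_lang="fra", src_lang="eng"
)
print(texts[0])
torchaudio.save(
    "out.wav",
    speech.audio_wavs[0][0].cpu(),  # first batch element; drop the vocoder batch dim
    sample_rate=speech.sample_rate,
)
```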
@@ -66,7 +80,12 @@ class Translator(nn.Module):
         )
         assert isinstance(self.model, UnitYModel)
 
-        self.text_tokenizer = load_unity_text_tokenizer(model_name_or_card)
+        if text_tokenizer is None:
+            self.text_tokenizer: TextTokenizer = load_unity_text_tokenizer(
+                model_name_or_card
+            )
+        else:
+            self.text_tokenizer = text_tokenizer
 
         self.unit_tokenizer: Optional[UnitTokenizer] = None
         if self.model.t2u_model is not None:
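One payoff of the new optional `text_tokenizer` argument, sketched below: the tokenizer can be loaded once and injected into several `Translator` instances instead of being reloaded in each constructor. The asset card names are the same assumptions as in the previous sketch.

```python
# Sketch: share one tokenizer across Translator instances.
import torch

from seamless_communication.models.inference import Translator
from seamless_communication.models.unity import load_unity_text_tokenizer

tokenizer = load_unity_text_tokenizer("seamlessM4T_large")

translator_a = Translator(
    "seamlessM4T_large",
    "vocoder_36langs",
    device=torch.device("cuda:0"),
    text_tokenizer=tokenizer,
)
translator_b = Translator(
    "seamlessM4T_large",
    "vocoder_36langs",
    device=torch.device("cuda:1"),
    text_tokenizer=tokenizer,
)
```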
@@ -112,40 +131,23 @@ class Translator(nn.Module):
         input_modality: Modality,
         output_modality: Modality,
         tgt_lang: str,
-        ngram_filtering: bool = False,
-        text_max_len_a: int = 1,
-        text_max_len_b: int = 200,
-        unit_max_len_a: Optional[int] = None,
-        unit_max_len_b: Optional[int] = None,
+        text_generation_opts: SequenceGeneratorOptions,
+        unit_generation_opts: Optional[SequenceGeneratorOptions],
+        unit_generation_ngram_filtering: bool = False,
     ) -> Tuple[SequenceToTextOutput, Optional[SequenceToUnitOutput]]:
-        if unit_max_len_a is None:
-            # need to adjust this for T2ST since src_len is smaller for text.
-            if input_modality == Modality.TEXT:
-                unit_max_len_a = 25
-            else:
-                unit_max_len_a = 1
-
-        text_opts = SequenceGeneratorOptions(
-            beam_size=5, soft_max_seq_len=(text_max_len_a, text_max_len_b)
-        )
-        unit_opts = SequenceGeneratorOptions(
-            beam_size=5, soft_max_seq_len=(unit_max_len_a, unit_max_len_b or 50)
-        )
+        # We disregard unit generation opts for the NAR T2U decoder.
+        if output_modality != Modality.SPEECH or isinstance(
+            model.t2u_model, UnitYNART2UModel
+        ):
+            unit_generation_opts = None
 
-        if ngram_filtering:
-            text_opts.logits_processor = NGramRepeatBlockProcessor(
-                no_repeat_ngram_size=4
-            )
-            unit_opts.logits_processor = NGramRepeatBlockProcessor(
-                no_repeat_ngram_size=4
-            )
         generator = UnitYGenerator(
             model,
             text_tokenizer,
             tgt_lang,
             unit_tokenizer if output_modality == Modality.SPEECH else None,
-            text_opts=text_opts,
-            unit_opts=unit_opts,
+            text_opts=text_generation_opts,
+            unit_opts=unit_generation_opts,
         )
         seqs, padding_mask = get_seqs_and_padding_mask(src)
         return generator(
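The hard-coded generation options above move to the caller. The sketch below reconstructs the deleted defaults explicitly (beam size 5; text max length 1 * src_len + 200; unit max length 25 * src_len + 50). Note that the removed speech-input special case (`unit_max_len_a = 1`) has no equivalent under the new uniform default. `translator` is reused from the first sketch.

```python
from fairseq2.generation import SequenceGeneratorOptions

# Equivalents of the deleted built-in defaults.
text_generation_opts = SequenceGeneratorOptions(
    beam_size=5, soft_max_seq_len=(1, 200)  # max_len = 1 * src_len + 200
)
unit_generation_opts = SequenceGeneratorOptions(
    beam_size=5, soft_max_seq_len=(25, 50)  # max_len = 25 * src_len + 50
)

texts, speech = translator.predict(
    "Guten Morgen!",
    "t2st",
    tgt_lang="eng",
    src_lang="deu",
    text_generation_opts=text_generation_opts,
    unit_generation_opts=unit_generation_opts,
    unit_generation_ngram_filtering=True,  # successor to the old ngram_filtering flag
)
```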
@@ -153,10 +155,16 @@ class Translator(nn.Module):
             padding_mask,
             input_modality.value,
             output_modality.value,
-            ngram_filtering=ngram_filtering,
+            ngram_filtering=unit_generation_ngram_filtering,
         )
 
-    def get_modalities_from_task(self, task: Task) -> Tuple[Modality, Modality]:
+    @staticmethod
+    def get_modalities_from_task_str(task_str: str) -> Tuple[Modality, Modality]:
+        try:
+            task = Task[task_str.upper()]
+        except KeyError:
+            raise ValueError(f"Unsupported task: {task_str}")
+
         if task == Task.S2ST:
             return Modality.SPEECH, Modality.SPEECH
         # ASR is treated as S2TT with src_lang == tgt_lang
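Task parsing is now a `@staticmethod` over the raw task string, so it works without an instance. The module path in the import below is an assumption based on this package's own imports above.

```python
# Sketch: resolve modalities without constructing a Translator.
from seamless_communication.models.inference.translator import Modality, Translator

src_mod, tgt_mod = Translator.get_modalities_from_task_str("s2st")
assert src_mod is Modality.SPEECH and tgt_mod is Modality.SPEECH
```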
@@ -170,18 +178,20 @@ class Translator(nn.Module):
     @torch.inference_mode()
     def predict(
         self,
-        input: Union[str, Tensor],
+        input: Union[str, Tensor, dict],
         task_str: str,
         tgt_lang: str,
         src_lang: Optional[str] = None,
+        text_generation_opts: SequenceGeneratorOptions = SequenceGeneratorOptions(
+            beam_size=5, soft_max_seq_len=(1, 200)
+        ),
+        unit_generation_opts: Optional[
+            SequenceGeneratorOptions
+        ] = SequenceGeneratorOptions(beam_size=5, soft_max_seq_len=(25, 50)),
         spkr: Optional[int] = -1,
-        ngram_filtering: bool = False,
         sample_rate: int = 16000,
-        text_max_len_a: int = 1,
-        text_max_len_b: int = 200,
-        unit_max_len_a: Optional[int] = None,
-        unit_max_len_b: Optional[int] = None,
-    ) -> Tuple[StringLike, Optional[Tensor], Optional[int]]:
+        unit_generation_ngram_filtering: bool = False,
+    ) -> Tuple[List[StringLike], Optional[BatchedSpeechOutput]]:
         """
         The main method used to perform inference on all tasks.
 
@@ -194,22 +204,27 @@ class Translator(nn.Module):
             Target language to decode into.
         :param src_lang:
             Source language of input, only required for T2ST, T2TT tasks.
+        :param text_generation_opts:
+            Text generation hyperparameters for incremental decoding.
+        :param unit_generation_opts:
+            Unit generation hyperparameters for incremental decoding.
         :param spkr:
             Speaker id for vocoder.
+        :param unit_generation_ngram_filtering:
+            If True, removes consecutive repeated ngrams
+            from the decoded unit output.
 
         :returns:
-            - Translated text.
-            - Generated output audio waveform corresponding to the translated text.
-            - Sample rate of output audio waveform.
+            - Batched list of translated texts.
+            - Translated speech as a BatchedSpeechOutput, or None for text output.
         """
-        try:
-            task = Task[task_str.upper()]
-        except KeyError:
-            raise ValueError(f"Unsupported task: {task_str}")
-
-        input_modality, output_modality = self.get_modalities_from_task(task)
+        input_modality, output_modality = self.get_modalities_from_task_str(task_str)
 
-        if input_modality == Modality.SPEECH:
+        if isinstance(input, dict):
+            assert "seqs" in input
+            assert "seq_lens" in input
+            src = cast(SequenceData, input)
+        elif input_modality == Modality.SPEECH:
             audio = input
             if isinstance(audio, str):
                 with Path(audio).open("rb") as fb:
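A sketch of the new pre-collated input path accepted by `predict`. The decoder/fbank/collater setup mirrors what this file builds in `__init__`; the exact `Collater` pad settings here are an assumption for illustration.

```python
# Sketch: batch two utterances into a SequenceData dict for S2TT.
import torch
from fairseq2.data import Collater
from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
from fairseq2.memory import MemoryBlock

device = torch.device("cuda:0")
decode_audio = AudioDecoder(dtype=torch.float32, device=device)
convert_to_fbank = WaveformToFbankConverter(
    num_mel_bins=80,
    waveform_scale=2**15,
    channel_last=True,
    standardize=True,
    device=device,
    dtype=torch.float16,
)
collate = Collater(pad_value=0, pad_to_multiple=2)

fbanks = []
for path in ("utt0.wav", "utt1.wav"):
    with open(path, "rb") as fb:
        decoded = decode_audio(MemoryBlock(fb.read()))
    fbanks.append(convert_to_fbank(decoded)["fbank"])

src = collate(fbanks)  # SequenceData dict with "seqs" and "seq_lens"
texts, speech = translator.predict(src, "s2tt", tgt_lang="spa")
```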
@@ -235,34 +250,51 @@ class Translator(nn.Module):
             src = self.collate(self.token_encoder(text))
 
         assert isinstance(self.model, UnitYModel)
-        result = self.get_prediction(
+        text_output, unit_output = self.get_prediction(
             self.model,
             self.text_tokenizer,
             self.unit_tokenizer,
             src,
             input_modality,
             output_modality,
-            tgt_lang=tgt_lang,
-            ngram_filtering=ngram_filtering,
-            text_max_len_a=text_max_len_a,
-            text_max_len_b=text_max_len_b,
-            unit_max_len_a=unit_max_len_a,
-            unit_max_len_b=unit_max_len_b,
+            tgt_lang,
+            text_generation_opts,
+            unit_generation_opts,
+            unit_generation_ngram_filtering=unit_generation_ngram_filtering,
         )
 
-        text_out = result[0]
-        unit_out = result[1]
         if output_modality == Modality.TEXT:
-            return text_out.sentences[0], None, None
+            return text_output.sentences, None
         else:
+            assert unit_output is not None
+
             if isinstance(self.model.t2u_model, UnitYT2UModel):
                 # Remove the lang token for AR UnitY since the vocoder doesn't need it
                 # in the unit sequence. tgt_lang is fed as an argument to the vocoder.
-                units = unit_out.units[:, 1:]
+                units = unit_output.units[:, 1:]
             else:
-                units = unit_out.units
+                units = unit_output.units
 
-            # TODO: batch_size set to 1 for now, implement batching.
-            units = units[0].cpu().numpy().tolist()
-            wav_out = self.vocoder(units, tgt_lang, spkr, dur_prediction=True)
-            return text_out.sentences[0], wav_out, sample_rate
+            audio_wavs = []
+            speech_units = []
+            for i in range(len(unit_output.units)):
+                u = units[i].cpu().numpy().tolist()
+                index_of_first_one = next(
+                    (index for index, value in enumerate(u) if value == 1), len(u)
+                )
+                u = u[:index_of_first_one]
+                speech_units.append(u)
+                # TODO: Implement batched inference for vocoder.
+                translated_audio_wav = self.vocoder(
+                    u, tgt_lang, spkr, dur_prediction=True
+                )
+                audio_wavs.append(translated_audio_wav)
+
+            return (
+                text_output.sentences,
+                BatchedSpeechOutput(
+                    units=speech_units,
+                    audio_wavs=audio_wavs,
+                    sample_rate=sample_rate,
+                ),
+            )
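Finally, callers that unpacked the old `(text, wav, sample_rate)` triple now iterate the batch. A sketch, reusing `translator` from the first example; the `wav[0]` indexing assumes the vocoder keeps a leading batch dimension, as in the loop above.

```python
import torchaudio

texts, speech = translator.predict("input.wav", "s2st", tgt_lang="fra")
assert speech is not None
for i, (text, wav) in enumerate(zip(texts, speech.audio_wavs)):
    print(f"[{i}] {text}")
    torchaudio.save(f"out_{i}.wav", wav[0].cpu(), sample_rate=speech.sample_rate)
```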