@@ -8,19 +8,19 @@ from dataclasses import dataclass
 from typing import List, Optional, Tuple
 
 import torch
-from fairseq2.data import SequenceData
+from fairseq2.data import SequenceData, StringLike
 from fairseq2.data.text import TextTokenizer
 from fairseq2.generation import (
+    BeamSearchSeq2SeqGenerator,
     Seq2SeqGenerator,
-    SequenceGeneratorOptions,
-    SequenceGeneratorOutput,
-    SequenceToTextGenerator,
-    SequenceToTextOutput,
+    SequenceToTextConverter,
+    StepProcessor,
 )
 from fairseq2.nn.padding import (
     PaddingMask,
     apply_padding_mask,
     get_seqs_and_padding_mask,
+    pad_seqs,
 )
 from fairseq2.nn.utils.module import infer_device
 from torch import Tensor
@@ -56,13 +56,34 @@ def remove_consecutive_repeated_ngrams(
     return [token for idx, token in enumerate(sequence) if idx not in drop_idx]
 
 
+@dataclass
+class SequenceGeneratorOptions:
+    """Holds the options to pass to a sequence generator."""
+
+    beam_size: int = 5
+    """The beam size."""
+
+    soft_max_seq_len: Tuple[int, int] = (1, 200)
+    """The terms ``a`` and ``b`` of ``ax + b`` where ``x`` is the source
+    sequence length. The generated sequences (including the prefix sequence)
+    will have a maximum length of ``min(hard_max_seq_len, ax + b)``. See also
+    ``hard_max_seq_len``."""
+
+    hard_max_seq_len: int = 1024
+    """The hard limit on the maximum length of generated sequences."""
+
+    step_processor: Optional[StepProcessor] = None
+    """The processor called at each generation step."""
+
+
 class UnitYGenerator:
     """Generates text translations and speech units from a UnitY model."""
 
     model: UnitYModel
-    s2t_generator: SequenceToTextGenerator
-    t2t_generator: Optional[SequenceToTextGenerator]
+    s2t_converter: SequenceToTextConverter
+    t2t_converter: Optional[SequenceToTextConverter]
     unit_decoder: Optional[UnitTokenDecoder]
+    unit_prefix_indices: Optional[Tensor]
     unit_generator: Optional[Seq2SeqGenerator]
 
     def __init__(
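The `soft_max_seq_len` pair above encodes the `a` and `b` of the length cap `min(hard_max_seq_len, ax + b)`. A minimal sketch of that rule, using the defaults from this patch (the `max_len_for` helper is hypothetical, not part of the patch):

```python
# Hypothetical helper; only the min(hard, a*x + b) rule and the default
# values come from the patch.
def max_len_for(source_len: int, soft: tuple = (1, 200), hard: int = 1024) -> int:
    a, b = soft
    return min(hard, a * source_len + b)

assert max_len_for(120) == 320                   # text defaults: 1 * 120 + 200
assert max_len_for(120, (25, 50), 5000) == 3050  # unit defaults used further down
assert max_len_for(400, (25, 50), 5000) == 5000  # the hard limit wins
```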
@@ -92,6 +113,9 @@ class UnitYGenerator:
 
         self.model = model
 
+        if text_opts is None:
+            text_opts = SequenceGeneratorOptions()
+
         if model.text_decoder is None:
             raise ValueError(
                 "`UnitYGenerator` requires a text decoder, but the current UnitY model does not have one."
@@ -107,8 +131,21 @@
             final_proj=model.final_proj,
             target_vocab_info=model.target_vocab_info,
         )
-        self.s2t_generator = SequenceToTextGenerator(
-            s2t_model, text_tokenizer, target_lang, text_opts
+
+        step_processors = []
+        if text_opts.step_processor is not None:
+            step_processors.append(text_opts.step_processor)
+
+        generator = BeamSearchSeq2SeqGenerator(
+            s2t_model,
+            beam_size=text_opts.beam_size,
+            max_gen_len=text_opts.soft_max_seq_len,
+            max_seq_len=text_opts.hard_max_seq_len,
+            echo_prompt=True,
+            step_processors=step_processors,
+        )
+        self.s2t_converter = SequenceToTextConverter(
+            generator, text_tokenizer, "translation", target_lang
         )
 
         if model.text_encoder is None:
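This options-to-converter wiring is repeated verbatim for the text-to-text path in the next hunk. Factored out, the pattern is (a sketch; `build_converter` is a hypothetical helper, but every call inside it appears in the patch):

```python
from fairseq2.generation import BeamSearchSeq2SeqGenerator, SequenceToTextConverter

def build_converter(model, text_tokenizer, target_lang, opts):
    # Collect the optional per-step hook into the list form the generator expects.
    step_processors = []
    if opts.step_processor is not None:
        step_processors.append(opts.step_processor)

    generator = BeamSearchSeq2SeqGenerator(
        model,
        beam_size=opts.beam_size,
        max_gen_len=opts.soft_max_seq_len,
        max_seq_len=opts.hard_max_seq_len,
        echo_prompt=True,  # keep the prompt tokens in the generated sequence
        step_processors=step_processors,
    )

    return SequenceToTextConverter(generator, text_tokenizer, "translation", target_lang)
```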
@@ -124,8 +161,16 @@
             final_proj=model.final_proj,
             target_vocab_info=model.target_vocab_info,
         )
-        self.t2t_generator = SequenceToTextGenerator(
-            t2t_model, text_tokenizer, target_lang, text_opts
+        generator = BeamSearchSeq2SeqGenerator(
+            t2t_model,
+            beam_size=text_opts.beam_size,
+            max_gen_len=text_opts.soft_max_seq_len,
+            max_seq_len=text_opts.hard_max_seq_len,
+            echo_prompt=True,
+            step_processors=step_processors,
+        )
+        self.t2t_converter = SequenceToTextConverter(
+            generator, text_tokenizer, "translation", target_lang
         )
 
         self.unit_generator = None
@@ -143,18 +188,26 @@
             lang=target_lang, device=infer_device(model.t2u_model)
         )
 
+        self.unit_prefix_indices = unit_encoder.prefix_indices
+
         if isinstance(self.model.t2u_model, UnitYT2UModel):
             if unit_opts is None:
                 # Speech sequences are typically much longer than text sequences.
                 unit_opts = SequenceGeneratorOptions(
-                    soft_max_seq_len=(1, 50), hard_max_seq_len=5000
+                    soft_max_seq_len=(25, 50), hard_max_seq_len=5000
                 )
 
-            self.unit_generator = Seq2SeqGenerator(
+            step_processors = []
+            if unit_opts.step_processor is not None:
+                step_processors.append(unit_opts.step_processor)
+
+            self.unit_generator = BeamSearchSeq2SeqGenerator(
                 self.model.t2u_model,
-                unit_tokenizer.vocab_info,
-                unit_encoder.prefix_indices,
-                unit_opts,
+                beam_size=unit_opts.beam_size,
+                max_gen_len=unit_opts.soft_max_seq_len,
+                max_seq_len=unit_opts.hard_max_seq_len,
+                echo_prompt=True,
+                step_processors=step_processors,
             )
 
     @torch.inference_mode()
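The `step_processor` option is the hook for constraining decoding at each step. A sketch of plugging one in, assuming fairseq2 v0.2's `BannedSequenceProcessor` and its one-argument constructor (verify both against your fairseq2 version):

```python
import torch

from fairseq2.generation import BannedSequenceProcessor

banned = [torch.tensor([4588, 1203])]  # hypothetical token ids to suppress

unit_opts = SequenceGeneratorOptions(
    soft_max_seq_len=(25, 50),
    hard_max_seq_len=5000,
    step_processor=BannedSequenceProcessor(banned),
)
```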
@@ -167,7 +220,7 @@
         ngram_filtering: bool = False,
         duration_factor: float = 1.0,
         prosody_encoder_input: Optional[SequenceData] = None,
-    ) -> Tuple[SequenceToTextOutput, Optional["SequenceToUnitOutput"]]:
+    ) -> Tuple[List[StringLike], Optional[Tensor]]:
         """
         :param source_seqs:
             The source sequences to use for generation. *Shape:* :math:`(N,S,*)`,
@@ -191,25 +244,31 @@
         """
 
         if input_modality == "speech":
-            text_output = self.s2t_generator.generate_ex(
+            texts, text_gen_output = self.s2t_converter.batch_convert(
                 source_seqs, source_padding_mask
             )
-        elif input_modality == "text" and self.t2t_generator is not None:
-            text_output = self.t2t_generator.generate_ex(
+        elif input_modality == "text":
+            if self.t2t_converter is None:
+                raise ValueError(
+                    "Please set `use_text_encoder` to `True` in your model config to encode text."
+                )
+            texts, text_gen_output = self.t2t_converter.batch_convert(
                 source_seqs, source_padding_mask
             )
-        elif input_modality == "text" and self.t2t_generator is None:
-            raise ValueError(
-                "Please set use_text_encoder to True in your model config to encode text."
-            )
         else:
             raise ValueError(f"Unsupported input_modality: {input_modality}")
 
         # We skip T2U when we only need to output text.
         if output_modality == "text":
-            return text_output, None
+            return texts, None
+
+        assert self.model.target_vocab_info.pad_idx is not None
 
-        text_seqs, text_padding_mask = text_output.generator_output.collate()
+        text_seq_list = [h[0].seq for h in text_gen_output.hypotheses]
+
+        text_seqs, text_padding_mask = pad_seqs(
+            text_seq_list, self.model.target_vocab_info.pad_idx
+        )
 
         # Manually trim the final EOS token to be consistent with fairseq.
         text_seqs = text_seqs[:, :-1]
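`batch_convert` returns the decoded texts together with the raw generator output, so the T2U stage re-pads the best hypothesis (`h[0]`) of each batch element into a single tensor. The same padding step sketched in plain PyTorch (pad index 0 is an assumption here; the real value is `model.target_vocab_info.pad_idx`):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Two ragged best-hypothesis sequences, as pulled from `hypotheses[i][0].seq`.
seqs = [torch.tensor([5, 9, 2]), torch.tensor([7, 2])]

padded = pad_sequence(seqs, batch_first=True, padding_value=0)
# tensor([[5, 9, 2],
#         [7, 2, 0]])
```

fairseq2's `pad_seqs` additionally returns the matching `PaddingMask`, which the code above keeps as `text_padding_mask`.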
@@ -221,8 +280,8 @@
         decoder_output, decoder_padding_mask = self.model.decode(
             text_seqs,
             text_padding_mask,
-            text_output.encoder_output,
-            text_output.encoder_padding_mask,
+            text_gen_output.encoder_output,
+            text_gen_output.encoder_padding_mask,
         )
 
         assert self.model.t2u_model is not None
@@ -242,15 +301,25 @@ class UnitYGenerator:
 
         if isinstance(self.model.t2u_model, UnitYT2UModel):
             assert self.unit_generator is not None
-            t2u_encoder_output, t2u_encoder_padding_mask = self.model.t2u_model.encode(
-                decoder_output, decoder_padding_mask
-            )
+            assert self.unit_prefix_indices is not None
+
+            # (S_pre) -> (N, S_pre)
+            prefix_seqs = self.unit_prefix_indices.expand(decoder_output.size(0), -1)
+
             unit_gen_output = self.unit_generator(
-                t2u_encoder_output,
-                t2u_encoder_padding_mask,
-                source_seq_len=source_seqs.size(1),
+                source_seqs=decoder_output,
+                source_padding_mask=decoder_padding_mask,
+                prompt_seqs=prefix_seqs,
+                prompt_padding_mask=None,
+            )
+
+            assert self.model.t2u_model.target_vocab_info.pad_idx is not None
+
+            unit_seq_list = [h[0].seq for h in unit_gen_output.hypotheses]
+
+            unit_seqs, _ = pad_seqs(
+                unit_seq_list, self.model.t2u_model.target_vocab_info.pad_idx
             )
-            unit_seqs, _ = unit_gen_output.collate()
         else:
             t2u_model_output, decoder_padding_mask, _ = self.model.t2u_model(
                 text_decoder_output=decoder_output,
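`expand` turns the stored 1-D prefix into a per-batch prompt without copying it; a quick check of that behavior:

```python
import torch

prefix = torch.tensor([3, 256047])  # hypothetical prefix token ids, shape (S_pre,)

prompt = prefix.expand(4, -1)  # (S_pre) -> (N, S_pre) for a batch of N = 4

assert prompt.shape == (4, 2)
assert prompt.data_ptr() == prefix.data_ptr()  # a zero-copy view, not a copy
```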
@@ -273,15 +342,4 @@
             arr = remove_consecutive_repeated_ngrams(units.cpu().numpy().tolist())
             units = torch.tensor(arr)
 
-        unit_output = SequenceToUnitOutput(units, unit_gen_output)
-
-        return text_output, unit_output
-
-
-@dataclass
-class SequenceToUnitOutput:
-    units: Tensor
-    """The generated units."""
-
-    generator_output: Optional[SequenceGeneratorOutput]
-    """The output of the underlying :class:`Seq2SeqGenerator`."""
+        return texts, units
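Caller-side, the method now returns a plain `(texts, units)` pair instead of the removed wrapper classes. A usage sketch (the call signature is abbreviated; `generator` is an already-constructed `UnitYGenerator` and the inputs follow the docstring shapes):

```python
texts, units = generator(source_seqs, source_padding_mask)

print(str(texts[0]))   # one StringLike translation per batch element
if units is not None:  # None when output_modality == "text"
    print(units)       # a bare unit-token Tensor replaces SequenceToUnitOutput
```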