# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from torch import Tensor

from fairseq2.data.text import TextTokenizer
from fairseq2.generation import (
    Seq2SeqGenerator,
    SequenceGeneratorOptions,
    SequenceGeneratorOutput,
    SequenceToTextGenerator,
    SequenceToTextOutput,
)
from fairseq2.nn.utils.module import infer_device

from seamless_communication.models.unity.model import UnitYModel, UnitYX2TModel
from seamless_communication.models.unity.unit_tokenizer import (
    UnitTokenDecoder,
    UnitTokenizer,
)


class UnitYGenerator:
    """Generates text translations and speech units from a UnitY model."""

    model: UnitYModel
    s2t_generator: SequenceToTextGenerator
    t2t_generator: Optional[SequenceToTextGenerator]
    unit_decoder: Optional[UnitTokenDecoder]
    unit_generator: Optional[Seq2SeqGenerator]

    def __init__(
        self,
        model: UnitYModel,
        text_tokenizer: TextTokenizer,
        target_lang: str,
        unit_tokenizer: Optional[UnitTokenizer] = None,
        text_opts: Optional[SequenceGeneratorOptions] = None,
        unit_opts: Optional[SequenceGeneratorOptions] = None,
    ) -> None:
        """
        :param model:
            The UnitY model to use for generation.
        :param text_tokenizer:
            The text tokenizer to use.
        :param target_lang:
            The target language.
        :param unit_tokenizer:
            The unit tokenizer to use. If ``None``, only text is generated.
        :param text_opts:
            The options to pass to the underlying text :class:`Seq2SeqGenerator`.
        :param unit_opts:
            The options to pass to the underlying unit :class:`Seq2SeqGenerator`.
        """
        if model.t2u_model is None:
            raise ValueError(
                "`model` does not have a T2U sub-model. "
                "For text generation only, "
                "use `SequenceToTextGenerator` instead."
            )

        model.eval()

        self.model = model

        s2t_model = UnitYX2TModel(
            encoder_frontend=model.speech_encoder_frontend,
            encoder=model.speech_encoder,
            decoder_frontend=model.text_decoder_frontend,
            decoder=model.text_decoder,
            final_proj=model.final_proj,
            pad_idx=model.pad_idx,
        )

        self.s2t_generator = SequenceToTextGenerator(
            s2t_model, text_tokenizer, target_lang, text_opts
        )

        if model.text_encoder is None:
            self.t2t_generator = None
        else:
            # The assert narrows the Optional type for static checkers.
            assert model.text_encoder_frontend is not None

            t2t_model = UnitYX2TModel(
                encoder_frontend=model.text_encoder_frontend,
                encoder=model.text_encoder,
                decoder_frontend=model.text_decoder_frontend,
                decoder=model.text_decoder,
                final_proj=model.final_proj,
                pad_idx=model.pad_idx,
            )

            self.t2t_generator = SequenceToTextGenerator(
                t2t_model, text_tokenizer, target_lang, text_opts
            )

        self.unit_generator = None
        self.unit_decoder = None

        # Set up the unit generator.
        if unit_tokenizer is not None:
            self.unit_decoder = unit_tokenizer.create_decoder()

            unit_encoder = unit_tokenizer.create_encoder(
                lang=target_lang, device=infer_device(model.t2u_model)
            )

            if unit_opts is None:
                # Speech sequences are typically much longer than text sequences.
                unit_opts = SequenceGeneratorOptions(
                    soft_max_seq_len=(1, 50), hard_max_seq_len=5000
                )

            self.unit_generator = Seq2SeqGenerator(
                model.t2u_model,
                unit_tokenizer.vocab_info,
                unit_encoder.prefix_indices,
                unit_opts,
            )

    @torch.inference_mode()
    def __call__(
        self,
        source_seqs: Tensor,
        source_seq_lens: Optional[Tensor],
        input_modality: str = "speech",
        output_modality: str = "speech",
    ) -> Tuple[SequenceToTextOutput, Optional["SequenceToUnitOutput"]]:
        """
        :param source_seqs:
            The source sequences to use for generation. *Shape:* :math:`(N,S,*)`,
            where :math:`N` is the batch size, :math:`S` is the sequence length,
            and :math:`*` is any number of sequence-specific dimensions
            including none.
        :param source_seq_lens:
            An array where each element represents the length of the sequence at
            the same index in ``source_seqs``. *Shape:* :math:`(N)`, where
            :math:`N` is the batch size.
        :param input_modality:
            The modality of the input sequences; either "speech" or "text".
        :param output_modality:
            The modality of the output to generate; either "speech" or "text".

        :returns:
            - The output of the text generator.
            - The output of the unit generator, or ``None`` if
              ``output_modality`` is "text".
        """
        if input_modality == "speech":
            text_output = self.s2t_generator.generate_ex(source_seqs, source_seq_lens)
        elif input_modality == "text" and self.t2t_generator is not None:
            text_output = self.t2t_generator.generate_ex(source_seqs, source_seq_lens)
        elif input_modality == "text" and self.t2t_generator is None:
            raise ValueError(
                "`model` does not have a text encoder. Set `use_text_encoder` to "
                "`True` in the model config to encode text."
            )
        else:
            raise ValueError(f"Unsupported input_modality: {input_modality}")

        # We skip T2U when we only need to output text.
        if output_modality == "text":
            return text_output, None

        text_seqs, text_seq_lens = text_output.generator_output.collate()

        # Use the output of the text generator to compute the decoder output.
        decoder_output, decoder_padding_mask = self.model.decode(
            text_seqs,
            text_seq_lens,
            text_output.encoder_output,
            text_output.encoder_padding_mask,
        )

        assert self.model.t2u_model is not None

        t2u_encoder_output, t2u_encoder_padding_mask = self.model.t2u_model.encode(
            decoder_output, decoder_padding_mask
        )

        assert self.unit_generator is not None
        assert self.unit_decoder is not None

        unit_gen_output = self.unit_generator(
            t2u_encoder_output,
            t2u_encoder_padding_mask,
            source_seq_len=source_seqs.size(1),
        )

        unit_seqs, _ = unit_gen_output.collate()

        # Convert the generated token indices to speech units.
        units = self.unit_decoder(unit_seqs)

        unit_output = SequenceToUnitOutput(
            units, unit_gen_output, t2u_encoder_output, t2u_encoder_padding_mask
        )

        return text_output, unit_output


@dataclass
class SequenceToUnitOutput:
    units: Tensor
    """The generated units."""

    generator_output: SequenceGeneratorOutput
    """The output of the underlying :class:`Seq2SeqGenerator`."""

    t2u_encoder_output: Tensor
    """The encoder output of the underlying UnitY T2U model used to generate the
    units. *Shape:* :math:`(N,S_{enc},M)`, where :math:`N` is the batch size,
    :math:`S_{enc}` is the encoder output sequence length, and :math:`M` is the
    dimensionality of the model."""

    t2u_encoder_padding_mask: Optional[Tensor]
    """The float padding mask of :attr:`t2u_encoder_output`. *Shape:*
    :math:`(N,S_{enc})`, where :math:`N` is the batch size and :math:`S_{enc}`
    is the encoder output sequence length."""
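

# Usage sketch (illustrative, not part of the module API): a minimal end-to-end
# example of driving `UnitYGenerator`. The "seamlessM4T_medium" model card, the
# dummy fbank feature shapes, and the `text_output.sentences` access below are
# assumptions for illustration; substitute the checkpoint and real audio
# features you actually use.
if __name__ == "__main__":
    from seamless_communication.models.unity import (
        load_unity_model,
        load_unity_text_tokenizer,
        load_unity_unit_tokenizer,
    )

    device = torch.device("cpu")

    # Load a UnitY checkpoint and its tokenizers (assumed model card name).
    model = load_unity_model("seamlessM4T_medium", device=device, dtype=torch.float32)
    text_tokenizer = load_unity_text_tokenizer("seamlessM4T_medium")
    unit_tokenizer = load_unity_unit_tokenizer("seamlessM4T_medium")

    generator = UnitYGenerator(
        model, text_tokenizer, target_lang="eng", unit_tokenizer=unit_tokenizer
    )

    # A dummy batch standing in for real 80-dim log-mel filterbank features:
    # shape (N, S, *) with N=1, S=1024.
    source_seqs = torch.randn(1, 1024, 80, device=device)
    source_seq_lens = torch.tensor([1024], device=device)

    text_output, unit_output = generator(
        source_seqs, source_seq_lens, input_modality="speech", output_modality="speech"
    )

    print(text_output.sentences[0])  # the translated text

    if unit_output is not None:
        print(unit_output.units)  # the generated discrete speech units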