generator.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the BSD-style license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. from dataclasses import dataclass
  7. from typing import Optional, Tuple
  8. import torch
  9. from fairseq2.data.text import TextTokenizer
  10. from fairseq2.generation import (
  11. Seq2SeqGenerator,
  12. SequenceGeneratorOptions,
  13. SequenceGeneratorOutput,
  14. SequenceToTextGenerator,
  15. SequenceToTextOutput,
  16. )
  17. from seamless_communication.models.unity.model import UnitYModel, UnitYX2TModel
  18. from seamless_communication.models.unity.unit_tokenizer import (
  19. UnitTokenDecoder,
  20. UnitTokenizer,
  21. )
  22. from fairseq2.nn.utils.module import infer_device
  23. from torch import Tensor
  24. class UnitYGenerator:
  25. """Generates text translations and speech units from a UnitY model."""
  26. model: UnitYModel
  27. s2t_generator: SequenceToTextGenerator
  28. t2t_generator: Optional[SequenceToTextGenerator]
  29. unit_decoder: Optional[UnitTokenDecoder]
  30. unit_generator: Optional[Seq2SeqGenerator]
  31. def __init__(
  32. self,
  33. model: UnitYModel,
  34. text_tokenizer: TextTokenizer,
  35. target_lang: str,
  36. unit_tokenizer: Optional[UnitTokenizer] = None,
  37. text_opts: Optional[SequenceGeneratorOptions] = None,
  38. unit_opts: Optional[SequenceGeneratorOptions] = None,
  39. ) -> None:
  40. """
  41. :param model:
  42. The UnitY model to use for generation.
  43. :param text_tokenizer:
  44. The text tokenizer to use.
  45. :param unit_tokenizer:
  46. The unit tokenizer to use.
  47. :param target_lang:
  48. The target language.
  49. :param text_generator_opts:
  50. The options to pass to the underlying text :class:`Seq2SeqGenerator`.
  51. :param unit_generator_opts:
  52. The options to pass to the underlying unit :class:`Seq2SeqGenerator`.
  53. """
  54. if model.t2u_model is None:
  55. raise ValueError(
  56. "`model` does not have a T2U sub-model. "
  57. "For text generation only, "
  58. "use `SequenceToTextGenerator` instead."
  59. )
  60. model.eval()
  61. self.model = model
  62. s2t_model = UnitYX2TModel(
  63. encoder_frontend=model.speech_encoder_frontend,
  64. encoder=model.speech_encoder,
  65. decoder_frontend=model.text_decoder_frontend,
  66. decoder=model.text_decoder,
  67. final_proj=model.final_proj,
  68. pad_idx=model.pad_idx,
  69. )
  70. self.s2t_generator = SequenceToTextGenerator(
  71. s2t_model, text_tokenizer, target_lang, text_opts
  72. )
  73. if model.text_encoder is None:
  74. self.t2t_generator = None
  75. else:
  76. assert model.text_encoder_frontend is not None
  77. assert model.text_encoder is not None
  78. t2t_model = UnitYX2TModel(
  79. encoder_frontend=model.text_encoder_frontend,
  80. encoder=model.text_encoder,
  81. decoder_frontend=model.text_decoder_frontend,
  82. decoder=model.text_decoder,
  83. final_proj=model.final_proj,
  84. pad_idx=model.pad_idx,
  85. )
  86. self.t2t_generator = SequenceToTextGenerator(
  87. t2t_model, text_tokenizer, target_lang, text_opts
  88. )
  89. self.unit_generator = None
  90. self.unit_decoder = None
  91. # Set up unit generator.
  92. if unit_tokenizer is not None:
  93. self.unit_decoder = unit_tokenizer.create_decoder()
  94. unit_encoder = unit_tokenizer.create_encoder(
  95. lang=target_lang, device=infer_device(model.t2u_model)
  96. )
  97. if unit_opts is None:
  98. # Speech sequences are typically much longer than text sequences.
  99. unit_opts = SequenceGeneratorOptions(
  100. soft_max_seq_len=(1, 50), hard_max_seq_len=5000
  101. )
  102. self.unit_generator = Seq2SeqGenerator(
  103. model.t2u_model,
  104. unit_tokenizer.vocab_info,
  105. unit_encoder.prefix_indices,
  106. unit_opts,
  107. )
  108. @torch.inference_mode()
  109. def __call__(
  110. self,
  111. source_seqs: Tensor,
  112. source_seq_lens: Optional[Tensor],
  113. input_modality: str = "speech",
  114. output_modality: str = "speech",
  115. ) -> Tuple[SequenceToTextOutput, Optional["SequenceToUnitOutput"]]:
  116. """
  117. :param source_seqs:
  118. The source sequences to use for generation. *Shape:* :math:`(N,S,*)`,
  119. where :math:`N` is the batch size, :math:`S` is the sequence length,
  120. and :math:`*` is any number of sequence-specific dimensions
  121. including none.
  122. :param source_seq_lens:
  123. An array where each element represents the length of the sequence at
  124. the same index in ``source_seqs``. *Shape:* :math:`(N)`, where
  125. :math:`N` is the batch size.
  126. :param input_modality:
  127. The type of modality to encode.
  128. :param output_modality:
  129. The type of modality to decode.
  130. :returns:
  131. - The output of the text generator.
  132. - The output of the unit generator.
  133. """
  134. if input_modality == "speech":
  135. text_output = self.s2t_generator.generate_ex(source_seqs, source_seq_lens)
  136. elif input_modality == "text" and self.t2t_generator is not None:
  137. text_output = self.t2t_generator.generate_ex(source_seqs, source_seq_lens)
  138. elif input_modality == "text" and self.t2t_generator is None:
  139. raise ValueError(
  140. f"Please set use_text_encoder to True in your model config to encode text."
  141. )
  142. else:
  143. raise ValueError(f"Unsupported input_modality: {input_modality}")
  144. # We skip T2U when we only need to output text.
  145. if output_modality == "text":
  146. return text_output, None
  147. text_seqs, text_seq_lens = text_output.generator_output.collate()
  148. # Use the output of the text generator to compute the decoder output.
  149. decoder_output, decoder_padding_mask = self.model.decode(
  150. text_seqs,
  151. text_seq_lens,
  152. text_output.encoder_output,
  153. text_output.encoder_padding_mask,
  154. )
  155. assert self.model.t2u_model is not None
  156. t2u_encoder_output, t2u_encoder_padding_mask = self.model.t2u_model.encode(
  157. decoder_output, decoder_padding_mask
  158. )
  159. assert self.unit_generator is not None
  160. assert self.unit_decoder is not None
  161. unit_gen_output = self.unit_generator(
  162. t2u_encoder_output,
  163. t2u_encoder_padding_mask,
  164. source_seq_len=source_seqs.size(1),
  165. )
  166. unit_seqs, _ = unit_gen_output.collate()
  167. # Convert to speech units.
  168. units = self.unit_decoder(unit_seqs)
  169. unit_output = SequenceToUnitOutput(
  170. units, unit_gen_output, t2u_encoder_output, t2u_encoder_padding_mask
  171. )
  172. return text_output, unit_output
  173. @dataclass
  174. class SequenceToUnitOutput:
  175. units: Tensor
  176. """The generated units."""
  177. generator_output: SequenceGeneratorOutput
  178. """The output of the underlying :class:`Seq2SeqGenerator`."""
  179. t2u_encoder_output: Tensor
  180. """The encoder output of the underlying UnitY T2U model used to generate the
  181. units. *Shape:* :math:`(N,S_{enc},M)`, where :math:`N` is the batch size,
  182. :math:`S_{enc}` is the encoder output sequence length, and :math:`M` is the
  183. dimensionality of the model."""
  184. t2u_encoder_padding_mask: Optional[Tensor]
  185. """The float padding mask of :attr:`encoder_output`. *Shape:*
  186. :math:`(N,S_{enc})`, where :math:`N` is the batch size and :math:`S_{enc}`
  187. is the encoder output sequence length."""