nar_decoder_frontend.py

# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import List, Optional, Tuple, final

import torch
from fairseq2.data import VocabularyInfo
from fairseq2.models.nllb.tokenizer import NllbTokenizer
from fairseq2.nn.embedding import Embedding
from fairseq2.nn.normalization import LayerNorm
from fairseq2.nn.padding import PaddingMask
from fairseq2.nn.position_encoder import PositionEncoder
from fairseq2.nn.transformer import create_standard_layer_norm
from fairseq2.typing import DataType, Device, finaloverride
from torch import Tensor
from torch.nn import Dropout, Module, Parameter

from seamless_communication.models.unity.char_tokenizer import CharTokenizer
from seamless_communication.models.unity.length_regulator import (
    HardUpsampling,
    VarianceAdaptor,
)

SPACE = "▁"


class TagManager:
    def __init__(self, vocab_info: VocabularyInfo):
        self.vocab_info = vocab_info

    def preprocess_text_seqs(self, text_seqs: Tensor) -> Tensor:
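        """Drop the leading EOS and language tokens (NLLB "target" tokenizer
        mode) and replace any remaining EOS tokens with the pad index."""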
        # Remove EOS, lang tokens as per NLLB "target" tokenizer mode.
        text_seqs = text_seqs[:, 2:]

        assert self.vocab_info.pad_idx is not None

        text_seqs.masked_fill_(
            text_seqs == self.vocab_info.eos_idx, self.vocab_info.pad_idx
        )

        return text_seqs

    def postprocess_dur_or_len(self, dur_or_len: Tensor) -> Tensor:
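        """Prepend and append a zero entry so durations/lengths line up with
        the language and EOS tokens of NLLB "source" tokenizer mode."""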
        N = dur_or_len.shape[0]

        pad_zero = dur_or_len.new_zeros((N, 1))

        # Add pads for lang, EOS tokens as per NLLB "source" tokenizer mode.
        dur_or_len = torch.cat([pad_zero, dur_or_len, pad_zero], dim=1)

        return dur_or_len


@final
class NARDecoderFrontend(Module):
    """Represents a Non-autoregressive decoder front-end."""

    char_pos_encoder: PositionEncoder
    pos_emb_alpha_char: Parameter
    unit_pos_encoder: PositionEncoder
    pos_emb_alpha: Parameter
    scale: float
    char_length_regulator: HardUpsampling
    variance_adaptor: VarianceAdaptor
    layer_norm: Optional[LayerNorm]
    dropout: Optional[Dropout]

    def __init__(
        self,
        embed: Embedding,
        embed_char: Embedding,
        text_tokenizer: NllbTokenizer,
        char_tokenizer: CharTokenizer,
        unit_pos_encoder: PositionEncoder,
        char_pos_encoder: PositionEncoder,
        variance_adaptor: VarianceAdaptor,
        no_scale: bool = False,
        layer_norm: bool = False,
        dropout_p: float = 0.1,
        device: Optional[Device] = None,
        dtype: Optional[DataType] = None,
    ):
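        """
        :param embed:
            The embedding table for text subword tokens.
        :param embed_char:
            The embedding table for character tokens.
        :param text_tokenizer:
            The NLLB text tokenizer whose SPM model maps indices back to subwords.
        :param char_tokenizer:
            The character-level tokenizer used to index individual characters.
        :param unit_pos_encoder:
            The position encoder applied at the unit level.
        :param char_pos_encoder:
            The position encoder applied at the character level.
        :param variance_adaptor:
            The variance adaptor that expands character-level features to unit length.
        :param no_scale:
            If ``True``, character embeddings are not scaled by the square root
            of the model dimension.
        :param layer_norm:
            If ``True``, constructs a Layer Normalization module.
        :param dropout_p:
            The dropout probability applied after the unit position embeddings.
        """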
        super().__init__()

        self.model_dim = embed.embedding_dim

        self.embed = embed
        self.embed_char = embed_char
        self.text_tokenizer = text_tokenizer
        self.char_tokenizer = char_tokenizer

        self.tag_manager = TagManager(text_tokenizer.vocab_info)

        self.unk_idx = self.text_tokenizer.vocab_info.unk_idx
        self.pad_idx = self.text_tokenizer.vocab_info.pad_idx

        # TODO: Implement AlignmentEncoder for training.

        if unit_pos_encoder.encoding_dim != self.model_dim:
            raise ValueError(
                f"`encoding_dim` of `unit_pos_encoder` and `embedding_dim` of `embed` must be equal, but are {unit_pos_encoder.encoding_dim} and {self.model_dim} instead."
            )

        if char_pos_encoder.encoding_dim != self.model_dim:
            raise ValueError(
                f"`encoding_dim` of `char_pos_encoder` and `embedding_dim` of `embed` must be equal, but are {char_pos_encoder.encoding_dim} and {self.model_dim} instead."
            )

        self.unit_pos_encoder = unit_pos_encoder
        self.pos_emb_alpha = Parameter(torch.ones(1, device=device, dtype=dtype))

        self.char_pos_encoder = char_pos_encoder
        self.pos_emb_alpha_char = Parameter(torch.ones(1, device=device, dtype=dtype))

        self.scale = 1.0 if no_scale else math.sqrt(self.model_dim)

        self.char_length_regulator = HardUpsampling()

        self.variance_adaptor = variance_adaptor

        if layer_norm:
            self.layer_norm = create_standard_layer_norm(
                self.model_dim, device=device, dtype=dtype
            )
        else:
            self.register_module("layer_norm", None)

        if dropout_p > 0.0:
            self.dropout = Dropout(dropout_p)
        else:
            self.register_module("dropout", None)

    def indices_to_subwords(self, text_seqs: Tensor) -> List[List[str]]:
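        """Convert a batch of token indices to their subword strings via the
        text tokenizer's SentencePiece model."""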
        # TODO: To be replaced with fairseq2's indices_to_tokens SPM model method
        # once implemented.
        N, seq_len = text_seqs.shape

        subwords_batch = []
        for b in range(N):
            subwords = []
            for i in range(seq_len):
                subword = self.text_tokenizer.model.index_to_token(int(text_seqs[b, i]))
                subwords.append(str(subword))
            subwords_batch.append(subwords)

        return subwords_batch

    def text_to_char_seqs(self, text_seqs: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
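        """Convert a batch of subword token sequences to character token
        sequences.

        Returns the padded character sequences, their true lengths, and the
        per-subword character counts."""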
        text_seqs = self.tag_manager.preprocess_text_seqs(text_seqs)

        subwords_batch = self.indices_to_subwords(text_seqs)

        char_lens = self.count_character_length_in_subword(text_seqs, subwords_batch)

        char_lens = self.tag_manager.postprocess_dur_or_len(char_lens)

        char_seqs, char_seq_lens = self.get_char_seqs(
            text_seqs, subwords_batch, char_lens
        )

        return char_seqs, char_seq_lens, char_lens

    def count_character_length_in_subword(
        self,
        text_seqs: Tensor,
        subwords_batch: List[List[str]],
        merge_space_with_prev_subword: bool = False,
    ) -> Tensor:
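        """Count the number of characters contributed by each subword token.

        A leading SentencePiece space ("▁") is counted with the subword it
        prefixes by default (spaces following punctuation are credited to the
        punctuation mark), or with the preceding subword when
        ``merge_space_with_prev_subword`` is ``True``. Unknown tokens count as
        a single character."""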
        N, _ = text_seqs.shape

        char_lens = text_seqs.new_zeros(text_seqs.size())

        assert self.pad_idx is not None
        subword_lens = text_seqs.ne(self.pad_idx).sum(1)

        for b in range(N):
            # We slice out the tensor till the padding index.
            subword_indices = text_seqs[b, : subword_lens[b]]
            subwords = subwords_batch[b][: subword_lens[b]]

            assert subword_indices.shape[0] == len(subwords)

            is_next_start_with_space = [
                len(subwords[i + 1]) > 1 and subwords[i + 1][0] == SPACE
                if i < len(subwords) - 1
                else False
                for i in range(len(subwords))
            ]
            is_punc = [
                len(subwords[i]) == 1
                and not subwords[i].isalpha()
                and not subwords[i].isnumeric()
                and subwords[i] != SPACE
                for i in range(len(subwords))
            ]
            for i, (subword_idx, subword) in enumerate(zip(subword_indices, subwords)):
                if subword_idx == self.pad_idx:
                    break

                if subword_idx == self.unk_idx:
                    # We set char_len to 1 for an unk token.
                    char_len = 1

                    if merge_space_with_prev_subword and is_next_start_with_space[i]:
                        char_len += 1
                else:
                    # By default, spaces are merged with the next subword.
                    # char_len includes the space.
                    char_len = len(subword)

                    if merge_space_with_prev_subword:
                        # Add the space for the next subword.
                        if is_next_start_with_space[i]:
                            char_len += 1
                        # Subtract the space for the current subword.
                        if i > 0 and is_next_start_with_space[i - 1]:
                            char_len -= 1
                    else:
                        # Merge space with punctuation mark by default.
                        if is_punc[i] and is_next_start_with_space[i]:
                            char_len += 1
                        # Subtract the space for the subword succeeding the punctuation mark.
                        elif (
                            i > 0 and is_punc[i - 1] and is_next_start_with_space[i - 1]
                        ):
                            char_len -= 1

                char_lens[b, i] = char_len

        return char_lens

    def get_char_seqs(
        self, text_seqs: Tensor, subwords_batch: List[List[str]], char_lens: Tensor
    ) -> Tuple[Tensor, Tensor]:
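        """Build padded character index sequences for a batch of subword
        sequences, along with the true character length of each sequence."""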
        N = text_seqs.shape[0]
        max_len = int(char_lens.sum(1).max().item())

        assert self.pad_idx is not None
        char_seqs = text_seqs.new_zeros((N, max_len)).fill_(self.pad_idx)
        char_seq_lens = char_seqs.new_zeros(N)

        subword_lens = text_seqs.ne(self.pad_idx).sum(1)

        for b in range(N):
            total = 0
            subword_indices = text_seqs[b, : subword_lens[b]]
            subwords = subwords_batch[b][: subword_lens[b]]
            for subword_idx, subword in zip(subword_indices, subwords):
                if subword_idx == self.unk_idx:
                    char_ids = [self.unk_idx]
                else:
                    # Get char token indices corresponding to the subwords.
                    char_ids = [
                        self.char_tokenizer.model.token_to_index(ch)
                        for ch in list(subword)
                    ]
                char_seq_len = len(char_ids)
                char_seqs[b, total : total + char_seq_len] = torch.tensor(char_ids).to(
                    char_seqs
                )
                total += char_seq_len
            char_seq_lens[b] = total

        return char_seqs, char_seq_lens

    def character_level_upsampling(
        self,
        seqs: Tensor,
        padding_mask: Optional[PaddingMask],
        char_seqs: Tensor,
        char_lens: Tensor,
    ) -> Tensor:
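        """Upsample subword-level encoder output to character length and add
        character embeddings and scaled character position embeddings."""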
        seqs, _ = self.char_length_regulator(seqs, char_lens)

        pos_embeds = self.pos_emb_alpha_char * (
            self.char_pos_encoder(seqs, padding_mask) - seqs
        )

        char_embeds = self.embed_char(char_seqs)

        if self.scale != 1.0:
            char_embeds *= self.scale

        pos_embeds += char_embeds

        seqs += pos_embeds

        return seqs

    def forward_unit_pos_embedding(
        self, seqs: Tensor, padding_mask: Optional[PaddingMask]
    ) -> Tensor:
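        """Add scaled unit-level position embeddings to the upsampled
        sequences and apply dropout if configured."""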
        pos_embeds = self.pos_emb_alpha * (
            self.unit_pos_encoder(seqs, padding_mask) - seqs
        )

        seqs += pos_embeds

        if self.dropout is not None:
            seqs = self.dropout(seqs)

        return seqs

    @finaloverride
    def forward(
        self,
        encoder_output: Tensor,
        encoder_padding_mask: Optional[PaddingMask],
        text_seqs: Optional[Tensor],
        film_cond_emb: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[PaddingMask]]:
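        """Convert text-encoder output into unit-length decoder input.

        The subword-level encoder output of shape (N, S_text, M) is upsampled
        to the character level (N, S_char, M), expanded to unit length
        (N, S_unit, M) by the variance adaptor, and combined with unit position
        embeddings."""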
        assert text_seqs is not None

        # text_seqs: (N, S_text)
        char_seqs, char_seq_lens, char_lens = self.text_to_char_seqs(text_seqs)

        # char_seqs: (N, S_char)
        encoder_padding_mask = PaddingMask(
            char_seq_lens, batch_seq_len=char_seqs.size(1)
        )

        # (N, S_text, M) -> (N, S_char, M)
        seqs = self.character_level_upsampling(
            encoder_output, encoder_padding_mask, char_seqs, char_lens
        )

        # (N, S_char, M) -> (N, S_unit, M)
        seqs, padding_mask = self.variance_adaptor(
            seqs,
            encoder_padding_mask,
            min_duration=1,
            film_cond_emb=film_cond_emb,
        )

        seqs = self.forward_unit_pos_embedding(seqs, padding_mask)

        return seqs, padding_mask