
Apply padding_mask in VariancePredictor, enforce TagManager parity with fairseq.

Kaushik Ram Sadagopan committed 2 years ago
commit 65851ae4cc

+ 12 - 2
src/seamless_communication/models/unity/length_regulator.py

@@ -101,7 +101,10 @@ class VariancePredictor(Module):
             var_pred_hidden_dim, 1, bias=True, device=device, dtype=dtype
         )
 
-    def forward(self, seqs: Tensor) -> Tensor:
+    def forward(self, seqs: Tensor, padding_mask: Optional[Tensor]) -> Tensor:
+        # Ensure that we do not leak padded positions in the convolution layer.
+        seqs = apply_padding_mask(seqs, padding_mask)
+
         # (N, S, M) -> (N, M, S)
         seqs = seqs.transpose(1, 2)
 
@@ -115,6 +118,8 @@ class VariancePredictor(Module):
 
         seqs = self.dropout_module(seqs)
 
+        seqs = apply_padding_mask(seqs, padding_mask)
+
         # (N, S, H) -> (N, H, S)
         seqs = seqs.transpose(1, 2)
 
@@ -128,9 +133,13 @@ class VariancePredictor(Module):
 
         seqs = self.dropout_module(seqs)
 
+        seqs = apply_padding_mask(seqs, padding_mask)
+
         # (N, S, H) -> (N, S, 1) -> (N, S)
         seqs = self.proj(seqs).squeeze(dim=2)
 
+        seqs = apply_padding_mask(seqs, padding_mask)
+
         return seqs
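
For reference, the masking helper used above can be pictured roughly as follows. This is an illustrative sketch only, assuming padding_mask is a boolean tensor of shape (N, S) with True at padded positions; it is not the fairseq2 implementation.

    import torch
    from torch import Tensor
    from typing import Optional

    def apply_padding_mask_sketch(seqs: Tensor, padding_mask: Optional[Tensor]) -> Tensor:
        # Zero out padded positions so the 1D convolutions above cannot mix
        # padding into real time steps.
        if padding_mask is None:
            return seqs
        mask = padding_mask
        if seqs.dim() == 3:
            mask = mask.unsqueeze(-1)  # (N, S) -> (N, S, 1), broadcast over features
        return seqs.masked_fill(mask, 0)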
 
 
@@ -172,13 +181,14 @@ class VarianceAdaptor(Module):
         duration_factor: float = 1.0,
         min_duration: int = 0,
     ) -> Tuple[Tensor, Tensor]:
-        log_durations = self.duration_predictor(seqs)
+        log_durations = self.duration_predictor(seqs, padding_mask)
 
         durations = torch.clamp(
             torch.round((torch.exp(log_durations) - 1) * duration_factor).long(),
             min=min_duration,
         )
 
+        # We need to apply the padding_mask again since we clamp by min_duration.
         durations = apply_padding_mask(durations, padding_mask)
 
         # TODO: Implement pitch, energy predictors.
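
A toy walk-through of the duration path above, with made-up numbers; it only illustrates why the mask is re-applied after clamping (a positive min_duration would otherwise lift padded positions to a nonzero duration):

    import torch

    log_durations = torch.tensor([[0.9, 1.4, 0.0, 0.0]])      # last two positions are padding
    padding_mask = torch.tensor([[False, False, True, True]])  # assumed convention: True = padded

    # Same recipe as VarianceAdaptor above, with duration_factor=1.0, min_duration=1.
    durations = torch.clamp(
        torch.round((torch.exp(log_durations) - 1) * 1.0).long(), min=1
    )
    # durations == [[1, 3, 1, 1]]: the padded positions were lifted to min_duration.

    # Re-applying the mask restores zero durations at padded positions.
    durations = durations.masked_fill(padding_mask, 0)
    # durations == [[1, 3, 0, 0]]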

+ 8 - 26
src/seamless_communication/models/unity/nar_decoder_frontend.py

@@ -34,24 +34,12 @@ SPACE = "▁"
 
 
 class TagManager:
-    def __init__(self, text_tokenizer: TextTokenizer):
-        self.vocab_info: VocabularyInfo = text_tokenizer.vocab_info
-        token_encoder = text_tokenizer.create_encoder(mode="target")
-        self.prefix_len: int = (
-            token_encoder.prefix_indices.shape[0]
-            if token_encoder.prefix_indices is not None
-            else 0
-        )
-        self.suffix_len: int = (
-            token_encoder.suffix_indices.shape[0]
-            if token_encoder.suffix_indices is not None
-            else 0
-        )
+    def __init__(self, vocab_info: VocabularyInfo):
+        self.vocab_info = vocab_info
 
     def preprocess_text_seqs(self, text_seqs: Tensor) -> Tensor:
-        """Remove the prefix, suffix tokens."""
-        seq_len = text_seqs.shape[1]
-        text_seqs = text_seqs[:, self.prefix_len : seq_len - self.suffix_len]
+        # Remove EOS, lang tokens as per NLLB "target" tokenizer mode.
+        text_seqs = text_seqs[:, 2:]
         assert self.vocab_info.pad_idx is not None
         text_seqs.masked_fill_(
             text_seqs == self.vocab_info.eos_idx, self.vocab_info.pad_idx
@@ -59,16 +47,10 @@ class TagManager:
         return text_seqs
 
     def postprocess_dur_or_len(self, dur_or_len: Tensor) -> Tensor:
-        """Add back 0s in place of the prefix, suffix tokens."""
         N = dur_or_len.shape[0]
-
-        # prefix = dur_or_len.new_zeros((N, self.prefix_len))
-        # suffix = dur_or_len.new_zeros((N, self.suffix_len))
-
-        prefix = dur_or_len.new_zeros((N, 1))
-        suffix = dur_or_len.new_zeros((N, 1))
-
-        dur_or_len = torch.cat([prefix, dur_or_len, suffix], dim=1)
+        pad_zero = dur_or_len.new_zeros((N, 1))
+        # Add pads for lang, EOS tokens as per NLLB "source" tokenizer mode.
+        dur_or_len = torch.cat([pad_zero, dur_or_len, pad_zero], dim=1)
         return dur_or_len
 
 
@@ -109,7 +91,7 @@ class NARDecoderFrontend(Module):
         self.embed_char = embed_char
         self.text_tokenizer = text_tokenizer
         self.char_tokenizer = char_tokenizer
-        self.tag_manager = TagManager(text_tokenizer)
+        self.tag_manager = TagManager(text_tokenizer.vocab_info)
 
         self.unk_idx = self.text_tokenizer.vocab_info.unk_idx
         self.pad_idx = self.text_tokenizer.vocab_info.pad_idx
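
To make the TagManager changes above concrete, here is a toy round-trip. The special-token indices are invented for illustration, and the assumed target-mode layout [EOS, LANG, tokens..., EOS] is only a working assumption about the tokenizer output:

    import torch

    EOS, PAD, LANG = 3, 1, 256047  # hypothetical indices, for illustration only

    # Assumed NLLB "target"-mode encoding of a 3-token sentence.
    text_seqs = torch.tensor([[EOS, LANG, 17, 42, 99, EOS]])

    # preprocess_text_seqs: drop the leading EOS and LANG, turn the trailing EOS into PAD.
    trimmed = text_seqs[:, 2:].masked_fill(text_seqs[:, 2:] == EOS, PAD)
    # trimmed == [[17, 42, 99, PAD]]

    # postprocess_dur_or_len: pad a zero on each side so per-token durations line
    # up again with a sequence that carries the lang and EOS tokens.
    durations = torch.tensor([[3, 2, 4, 0]])
    zero = durations.new_zeros((durations.shape[0], 1))
    restored = torch.cat([zero, durations, zero], dim=1)
    # restored == [[0, 3, 2, 4, 0, 0]]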

+ 1 - 1
src/seamless_communication/models/vocoder/codehifigan.py

@@ -78,7 +78,7 @@ class CodeGenerator(Generator):
 
         if self.dur_predictor and dur_prediction:
             assert x.size(0) == 1, "only support single sample"
-            log_dur_pred = self.dur_predictor(x.transpose(1, 2))
+            log_dur_pred = self.dur_predictor(x.transpose(1, 2), None)
             dur_out = torch.clamp(
                 torch.round((torch.exp(log_dur_pred) - 1)).long(), min=1
             )
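
Since the vocoder path only ever sees a single unpadded sample (note the assert above), passing None for padding_mask keeps its behavior unchanged. A minimal sketch of the new call shape, using a stand-in module and invented dimensions rather than the real VariancePredictor:

    import torch
    from torch import Tensor, nn
    from typing import Optional

    class ToyDurationPredictor(nn.Module):
        """Stand-in for VariancePredictor; only demonstrates the new signature."""

        def __init__(self, dim: int) -> None:
            super().__init__()
            self.proj = nn.Linear(dim, 1)

        def forward(self, seqs: Tensor, padding_mask: Optional[Tensor]) -> Tensor:
            if padding_mask is not None:
                seqs = seqs.masked_fill(padding_mask.unsqueeze(-1), 0)
            return self.proj(seqs).squeeze(2)

    predictor = ToyDurationPredictor(dim=256)
    x = torch.randn(1, 256, 37)                        # (N, C, S): one unpadded sample
    log_dur_pred = predictor(x.transpose(1, 2), None)  # padding_mask=None, as in the diff
    dur_out = torch.clamp(torch.round((torch.exp(log_dur_pred) - 1)).long(), min=1)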