
Apply padding_mask to output unit sequences.

Kaushik Ram Sadagopan, 2 years ago
commit 157a1628f0

+ 3 - 1
src/seamless_communication/models/unity/generator.py

@@ -219,7 +219,7 @@ class UnitYGenerator:
             )
             unit_seqs, _ = unit_gen_output.collate()
         else:
-            unit_decoder_output = self.model.t2u_model(
+            unit_decoder_output, decoder_padding_mask = self.model.t2u_model(
                 text_decoder_output=decoder_output,
                 text_decoder_padding_mask=decoder_padding_mask,
                 target_seqs=None,
@@ -228,6 +228,8 @@ class UnitYGenerator:
             )
             # (B, S_unit, V_unit)
             unit_seqs = unit_decoder_output.logits.argmax(dim=2)
+            # Apply the padding mask to the generated units.
+            unit_seqs[decoder_padding_mask == -torch.inf] = unit_decoder_output.pad_idx
 
         # Convert to speech units.
         units = self.unit_decoder(unit_seqs)
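
In the non-autoregressive branch, the t2u model now also returns the decoder padding mask, and the argmax'd unit IDs at padded positions are overwritten with the model's pad index so that padding never reaches the unit decoder as real units. Below is a minimal, self-contained sketch of that masking step; the shapes and the pad index are illustrative placeholders, not values taken from the model.

```python
# Standalone sketch of the added masking step (illustrative shapes/values only).
# In fairseq2 here, the padding mask is a float tensor with -inf at padded positions.
import torch

pad_idx = 1                                  # assumed unit pad index
logits = torch.randn(2, 5, 100)              # (B, S_unit, V_unit)
unit_seqs = logits.argmax(dim=2)             # (B, S_unit)

# Example mask: the last two steps of the second sequence are padding.
decoder_padding_mask = torch.zeros(2, 5)
decoder_padding_mask[1, 3:] = -torch.inf

# Same masking step as the added line above: padded positions become pad_idx.
unit_seqs[decoder_padding_mask == -torch.inf] = pad_idx
print(unit_seqs)
```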

+ 3 - 2
src/seamless_communication/models/unity/length_regulator.py

@@ -14,6 +14,7 @@ from fairseq2.typing import DataType, Device
 from fairseq2.nn.transformer import create_default_layer_norm
 from fairseq2.nn.normalization import LayerNorm
 from fairseq2.nn.projection import Linear
+from fairseq2.nn.utils.mask import apply_padding_mask
 
 
 class HardUpsampling(Module):
@@ -177,8 +178,8 @@ class VarianceAdaptor(Module):
             torch.round((torch.exp(log_durations) - 1) * duration_factor).long(),
             min=min_duration,
         )
-        if padding_mask is not None:
-            durations.masked_fill_(padding_mask, 0)
+
+        durations = apply_padding_mask(durations, padding_mask)
 
         # TODO: Implement pitch, energy predictors.
         # TODO: Implement GaussianUpsampling.
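
The duration masking in VarianceAdaptor now goes through fairseq2's apply_padding_mask helper instead of an explicit masked_fill_, which also covers the case where padding_mask is None. The snippet below is a rough standalone equivalent of that call, under the assumption that the mask is a float tensor with -inf at padded positions; it is not the fairseq2 implementation.

```python
# Hedged sketch of a padding-mask helper equivalent to the call above; NOT the
# fairseq2 source, just an illustration of the intended effect on durations.
from typing import Optional

import torch
from torch import Tensor


def apply_padding_mask_sketch(seqs: Tensor, padding_mask: Optional[Tensor]) -> Tensor:
    """Zero out entries at padded positions; pass through when there is no mask."""
    if padding_mask is None:
        return seqs
    return seqs.masked_fill(padding_mask == -torch.inf, 0)


durations = torch.tensor([[3, 2, 4], [5, 1, 2]])
mask = torch.tensor([[0.0, 0.0, 0.0], [0.0, -torch.inf, -torch.inf]])
print(apply_padding_mask_sketch(durations, mask))
# tensor([[3, 2, 4],
#         [5, 0, 0]])
```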

+ 3 - 5
src/seamless_communication/models/unity/model.py

@@ -379,7 +379,7 @@ class UnitYNART2UModel(Module):
         target_seqs: Optional[Tensor],
         target_seq_lens: Optional[Tensor],
         text_seqs: Optional[Tensor],
-    ) -> SequenceModelOutput:
+    ) -> Tuple[SequenceModelOutput, Optional[Tensor]]:
         encoder_output, encoder_padding_mask = self.encode(
             text_decoder_output, text_decoder_padding_mask
         )
@@ -392,7 +392,7 @@ class UnitYNART2UModel(Module):
             text_seqs,
         )
 
-        return self.project(decoder_output, decoder_padding_mask)
+        return self.project(decoder_output), decoder_padding_mask
 
     def encode(
         self,
@@ -420,9 +420,7 @@ class UnitYNART2UModel(Module):
 
         return self.decoder(seqs, padding_mask)  # type: ignore[no-any-return]
 
-    def project(
-        self, decoder_output: Tensor, decoder_padding_mask: Optional[Tensor]
-    ) -> SequenceModelOutput:
+    def project(self, decoder_output: Tensor) -> SequenceModelOutput:
         logits = self.final_proj(decoder_output)
 
         return SequenceModelOutput(logits, self.pad_idx)
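
With the padding mask now returned directly from forward(), project() is reduced to the final projection plus wrapping the logits, and the caller applies the mask itself. Below is a hedged, self-contained sketch of that simplified step; SequenceModelOutputSketch, the dimensions, and the pad index are placeholders rather than the real fairseq2 types.

```python
# Sketch of the simplified project() step (placeholders, not the real model).
import torch
from torch import Tensor, nn


class SequenceModelOutputSketch:
    """Stand-in for fairseq2's SequenceModelOutput: logits plus a pad index."""

    def __init__(self, logits: Tensor, pad_idx: int) -> None:
        self.logits = logits
        self.pad_idx = pad_idx


final_proj = nn.Linear(16, 100)           # model_dim -> unit vocabulary size
pad_idx = 1                               # illustrative pad index


def project(decoder_output: Tensor) -> SequenceModelOutputSketch:
    # Mirrors the simplified signature: no padding mask is needed here anymore.
    return SequenceModelOutputSketch(final_proj(decoder_output), pad_idx)


decoder_output = torch.randn(2, 7, 16)    # (B, S_unit, model_dim)
out = project(decoder_output)
print(out.logits.shape)                   # torch.Size([2, 7, 100])
```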

+ 5 - 2
src/seamless_communication/models/unity/nar_decoder_frontend.py

@@ -62,8 +62,11 @@ class TagManager:
         """Add back 0s in place of the prefix, suffix tokens."""
         N = dur_or_len.shape[0]
 
-        prefix = dur_or_len.new_zeros((N, self.prefix_len))
-        suffix = dur_or_len.new_zeros((N, self.suffix_len))
+        # prefix = dur_or_len.new_zeros((N, self.prefix_len))
+        # suffix = dur_or_len.new_zeros((N, self.suffix_len))
+
+        prefix = dur_or_len.new_zeros((N, 1))
+        suffix = dur_or_len.new_zeros((N, 1))
 
         dur_or_len = torch.cat([prefix, dur_or_len, suffix], dim=1)
         return dur_or_len
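
In TagManager, the zero columns added back in place of the prefix and suffix tokens are now hard-coded to width 1 (the previous self.prefix_len / self.suffix_len based lines are left commented out in the commit). A small illustrative sketch of this pad-back step, with made-up duration values:

```python
# Illustrative sketch of the pad-back step: one zero column on each side so the
# prefix and suffix tokens contribute no upsampled frames. Values are made up.
import torch

dur_or_len = torch.tensor([[2, 3, 1],
                           [4, 1, 2]])    # (N, S) per-token durations

N = dur_or_len.shape[0]
prefix = dur_or_len.new_zeros((N, 1))     # zero column for the prefix token
suffix = dur_or_len.new_zeros((N, 1))     # zero column for the suffix token

padded = torch.cat([prefix, dur_or_len, suffix], dim=1)
print(padded)
# tensor([[0, 2, 3, 1, 0],
#         [0, 4, 1, 2, 0]])
```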