
Padding mask bug fix: the mask is boolean, not float, in the UnitY generator. (#54)

Kaushik Ram Sadagopan · 1 year ago
commit d8521ec060

1 changed file with 8 additions and 7 deletions:
    src/seamless_communication/models/unity/generator.py

+ 8 - 7
src/seamless_communication/models/unity/generator.py

@@ -8,6 +8,8 @@ from dataclasses import dataclass
 from typing import Optional, Tuple, List
 
 import torch
+
+from torch import Tensor
 from fairseq2.data.text import TextTokenizer
 from fairseq2.generation import (
     Seq2SeqGenerator,
@@ -16,6 +18,9 @@ from fairseq2.generation import (
     SequenceToTextGenerator,
     SequenceToTextOutput,
 )
+from fairseq2.nn.padding import PaddingMask, apply_padding_mask
+from fairseq2.nn.utils.module import infer_device
+
 from seamless_communication.models.unity.model import (
     UnitYModel,
     UnitYX2TModel,
@@ -25,9 +30,6 @@ from seamless_communication.models.unity.unit_tokenizer import (
     UnitTokenDecoder,
     UnitTokenizer,
 )
-from fairseq2.nn.utils.module import infer_device
-from fairseq2.nn.padding import PaddingMask
-from torch import Tensor
 
 
 def remove_consecutive_repeated_ngrams(
@@ -233,10 +235,9 @@ class UnitYGenerator:
             # (B, S_unit, V_unit)
             unit_seqs = unit_decoder_output.logits.argmax(dim=2)
             # Apply the padding mask to the generated units.
-            if decoder_padding_mask is not None:
-                m = decoder_padding_mask.materialize()
-
-                unit_seqs[m == -torch.inf] = unit_decoder_output.pad_idx
+            unit_seqs = apply_padding_mask(
+                unit_seqs, decoder_padding_mask, fill_value=unit_decoder_output.pad_idx
+            )
 
         # Convert to speech units.
         units = self.unit_decoder(unit_seqs)
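
Why the old code was a no-op, sketched below with plain PyTorch tensors rather than fairseq2's PaddingMask (the exact materialized form of that mask is an assumption here): a boolean mask is never equal to -inf, so the old equality check selected no elements and padded positions kept whatever the argmax produced, whereas the new apply_padding_mask call fills those positions with pad_idx.

    # Minimal sketch (assumption: the padding mask materializes as a boolean
    # tensor, True for real tokens, False for padding); not fairseq2's API.
    import torch

    pad_idx = 0
    # Fake generated unit ids, shape (B=2, S_unit=4); the last position of the
    # second row is padding and should be replaced with pad_idx.
    unit_seqs = torch.tensor([[5, 7, 9, 3],
                              [4, 8, 6, 2]])
    bool_mask = torch.tensor([[True, True, True, True],
                              [True, True, True, False]])

    # Old logic: a boolean tensor never equals -inf, so no element is selected
    # and the padded position is left untouched.
    broken = unit_seqs.clone()
    broken[bool_mask == -torch.inf] = pad_idx
    assert torch.equal(broken, unit_seqs)

    # Fixed logic: fill the False (padding) positions directly, which is what
    # the apply_padding_mask call in this commit does for the generator.
    fixed = unit_seqs.masked_fill(~bool_mask, pad_idx)
    print(fixed)  # tensor([[5, 7, 9, 3], [4, 8, 6, 0]])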