@@ -8,6 +8,8 @@ from dataclasses import dataclass
 from typing import Optional, Tuple, List
 
 import torch
+
+from torch import Tensor
 from fairseq2.data.text import TextTokenizer
 from fairseq2.generation import (
     Seq2SeqGenerator,
@@ -16,6 +18,9 @@ from fairseq2.generation import (
     SequenceToTextGenerator,
     SequenceToTextOutput,
 )
+from fairseq2.nn.padding import PaddingMask, apply_padding_mask
+from fairseq2.nn.utils.module import infer_device
+
 from seamless_communication.models.unity.model import (
     UnitYModel,
     UnitYX2TModel,
@@ -25,9 +30,6 @@ from seamless_communication.models.unity.unit_tokenizer import (
     UnitTokenDecoder,
     UnitTokenizer,
 )
-from fairseq2.nn.utils.module import infer_device
-from fairseq2.nn.padding import PaddingMask
-from torch import Tensor
 
 
 def remove_consecutive_repeated_ngrams(
@@ -233,10 +235,9 @@ class UnitYGenerator:
         # (B, S_unit, V_unit)
         unit_seqs = unit_decoder_output.logits.argmax(dim=2)
         # Apply the padding mask to the generated units.
-        if decoder_padding_mask is not None:
-            m = decoder_padding_mask.materialize()
-
-            unit_seqs[m == -torch.inf] = unit_decoder_output.pad_idx
+        unit_seqs = apply_padding_mask(
+            unit_seqs, decoder_padding_mask, fill_value=unit_decoder_output.pad_idx
+        )
 
         # Convert to speech units.
         units = self.unit_decoder(unit_seqs)
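
For reference, a minimal sketch of the behaviour the new `apply_padding_mask` call is relied on for, mirroring the manual `materialize()` / `-torch.inf` masking removed in the last hunk. The `PaddingMask(seq_lens, batch_seq_len)` construction and the example values are illustrative assumptions, not part of this change:

```python
import torch

from fairseq2.nn.padding import PaddingMask, apply_padding_mask

# Two generated unit sequences of lengths 3 and 2, argmax-decoded and padded to length 4.
unit_seqs = torch.tensor([[5, 7, 9, 0], [4, 6, 0, 0]])

# Assumed constructor: per-sequence lengths plus the padded batch length.
padding_mask = PaddingMask(torch.tensor([3, 2]), batch_seq_len=4)

# Padded positions are overwritten with the fill value (the unit pad index in the
# hunk above), replacing the manual `m == -torch.inf` indexing that was removed.
unit_seqs = apply_padding_mask(unit_seqs, padding_mask, fill_value=1)

print(unit_seqs)  # expected: tensor([[5, 7, 9, 1], [4, 6, 1, 1]])
```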