@@ -7,6 +7,7 @@
 from dataclasses import dataclass
 from typing import Optional, Tuple, Union, final
 
+from fairseq2.data import VocabularyInfo
 from fairseq2.models.encoder_decoder import EncoderDecoderModel, Seq2SeqDecoder
 from fairseq2.models.sequence import SequenceModelOutput
 from fairseq2.models.transformer.frontend import TransformerFrontend
@@ -41,7 +42,6 @@ class UnitYModel(EncoderDecoderModel):
     text_decoder: TransformerDecoder
     final_proj: Projection
     t2u_model: Union["UnitYT2UModel", "UnitYNART2UModel", None]
-    pad_idx: Optional[int]
 
     def __init__(
         self,
@@ -53,12 +53,12 @@ class UnitYModel(EncoderDecoderModel):
         text_decoder: TransformerDecoder,
         final_proj: Projection,
         t2u_model: Union["UnitYT2UModel", "UnitYNART2UModel", None],
-        pad_idx: Optional[int],
+        target_vocab_info: VocabularyInfo,
         input_modality: str = "speech",
     ) -> None:
         model_dim = speech_encoder.model_dim
 
-        super().__init__(model_dim)
+        super().__init__(model_dim, target_vocab_info)
 
         self.input_modality = input_modality
 
@@ -92,7 +92,7 @@ class UnitYModel(EncoderDecoderModel):
         else:
             self.register_module("t2u_model", None)
 
-        self.pad_idx = pad_idx
+        self.target_vocab_info = target_vocab_info
 
     @finaloverride
     def encode(
@@ -136,6 +136,7 @@ class UnitYModel(EncoderDecoderModel):
         padding_mask: Optional[PaddingMask],
         encoder_output: Tensor,
         encoder_padding_mask: Optional[PaddingMask],
+        *,
         state_bag: Optional[IncrementalStateBag] = None,
     ) -> Tuple[Tensor, Optional[PaddingMask]]:
         seqs, padding_mask = self.text_decoder_frontend(
@@ -156,7 +157,7 @@ class UnitYModel(EncoderDecoderModel):
     ) -> SequenceModelOutput:
         logits = self.final_proj(decoder_output)
 
-        return SequenceModelOutput(logits, self.pad_idx)
+        return SequenceModelOutput(logits, self.target_vocab_info)
 
 
 @final
@@ -167,7 +168,6 @@ class UnitYX2TModel(EncoderDecoderModel):
     decoder_frontend: TransformerFrontend
     decoder: TransformerDecoder
     final_proj: Projection
-    pad_idx: Optional[int]
 
     def __init__(
         self,
@@ -176,17 +176,18 @@ class UnitYX2TModel(EncoderDecoderModel):
         decoder_frontend: TransformerFrontend,
         decoder: TransformerDecoder,
         final_proj: Projection,
-        pad_idx: Optional[int],
+        target_vocab_info: VocabularyInfo,
     ) -> None:
         model_dim = encoder.model_dim
-        super().__init__(model_dim)
+
+        super().__init__(model_dim, target_vocab_info)
 
         self.encoder_frontend = encoder_frontend
         self.encoder = encoder
         self.decoder_frontend = decoder_frontend
         self.decoder = decoder
         self.final_proj = final_proj
-        self.pad_idx = pad_idx
+        self.target_vocab_info = target_vocab_info
 
     @finaloverride
     def encode(
@@ -202,6 +203,7 @@ class UnitYX2TModel(EncoderDecoderModel):
         padding_mask: Optional[PaddingMask],
         encoder_output: Tensor,
         encoder_padding_mask: Optional[PaddingMask],
+        *,
         state_bag: Optional[IncrementalStateBag] = None,
     ) -> Tuple[Tensor, Optional[PaddingMask]]:
         seqs, padding_mask = self.decoder_frontend(
@@ -222,7 +224,7 @@ class UnitYX2TModel(EncoderDecoderModel):
     ) -> SequenceModelOutput:
         logits = self.final_proj(decoder_output)
 
-        return SequenceModelOutput(logits, self.pad_idx)
+        return SequenceModelOutput(logits, self.target_vocab_info)
 
 
 @final
@@ -235,7 +237,7 @@ class UnitYT2UModel(Module, Seq2SeqDecoder):
     decoder_frontend: TransformerFrontend
     decoder: TransformerDecoder
     final_proj: Projection
-    pad_idx: Optional[int]
+    target_vocab_info: VocabularyInfo
 
     def __init__(
         self,
@@ -243,7 +245,7 @@ class UnitYT2UModel(Module, Seq2SeqDecoder):
         decoder_frontend: TransformerFrontend,
         decoder: TransformerDecoder,
         final_proj: Projection,
-        pad_idx: Optional[int],
+        target_vocab_info: VocabularyInfo,
     ) -> None:
         super().__init__()
 
@@ -269,7 +271,7 @@ class UnitYT2UModel(Module, Seq2SeqDecoder):
 
         self.final_proj = final_proj
 
-        self.pad_idx = pad_idx
+        self.target_vocab_info = target_vocab_info
 
     def forward(
         self,
@@ -307,6 +309,7 @@ class UnitYT2UModel(Module, Seq2SeqDecoder):
         padding_mask: Optional[PaddingMask],
         encoder_output: Tensor,
         encoder_padding_mask: Optional[PaddingMask],
+        *,
         state_bag: Optional[IncrementalStateBag] = None,
     ) -> Tuple[Tensor, Optional[PaddingMask]]:
         seqs, padding_mask = self.decoder_frontend(
@@ -326,7 +329,7 @@ class UnitYT2UModel(Module, Seq2SeqDecoder):
     ) -> SequenceModelOutput:
         logits = self.final_proj(decoder_output)
 
-        return SequenceModelOutput(logits, self.pad_idx)
+        return SequenceModelOutput(logits, self.target_vocab_info)
 
 
 @final
@@ -338,7 +341,7 @@ class UnitYNART2UModel(Module):
     decoder_frontend: NARDecoderFrontend
    decoder: NARTransformerDecoder
     final_proj: Projection
-    pad_idx: Optional[int]
+    target_vocab_info: VocabularyInfo
 
     def __init__(
         self,
@@ -346,7 +349,7 @@ class UnitYNART2UModel(Module):
         decoder_frontend: NARDecoderFrontend,
         decoder: NARTransformerDecoder,
         final_proj: Projection,
-        pad_idx: Optional[int],
+        target_vocab_info: VocabularyInfo,
     ) -> None:
         super().__init__()
 
@@ -372,7 +375,7 @@ class UnitYNART2UModel(Module):
 
         self.final_proj = final_proj
 
-        self.pad_idx = pad_idx
+        self.target_vocab_info = target_vocab_info
 
     def forward(
         self,
@@ -419,7 +422,7 @@ class UnitYNART2UModel(Module):
     def project(self, decoder_output: Tensor) -> SequenceModelOutput:
         logits = self.final_proj(decoder_output)
 
-        return SequenceModelOutput(logits, self.pad_idx)
+        return SequenceModelOutput(logits, self.target_vocab_info)
 
 
 @dataclass
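
The hunks above only swap a bare pad_idx argument for a full target_vocab_info object; callers therefore have to construct a VocabularyInfo themselves. The sketch below is illustrative only, assuming fairseq2's VocabularyInfo dataclass exposes size, unk_idx, bos_idx, eos_idx, and pad_idx fields; the concrete index values are placeholders, not values taken from this diff.

    # Minimal sketch: building the new target_vocab_info argument.
    # Field values are placeholders for illustration.
    from fairseq2.data import VocabularyInfo

    target_vocab_info = VocabularyInfo(
        size=10000,  # vocabulary size (placeholder)
        unk_idx=1,   # placeholder
        bos_idx=2,   # placeholder
        eos_idx=3,   # placeholder
        pad_idx=0,   # the index that was previously passed on its own as pad_idx
    )

    # Before this change a model was given the pad index directly,
    # e.g. UnitYT2UModel(..., pad_idx=0); after it, the whole vocabulary
    # description is passed instead:
    # UnitYT2UModel(..., target_vocab_info=target_vocab_info)

Because SequenceModelOutput now receives the full vocabulary description rather than just a pad index, downstream consumers can read the other special-token indices from the same object.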