|
@@ -8,7 +8,7 @@ from dataclasses import dataclass
|
|
|
from typing import Optional, Tuple, Union, final
|
|
|
|
|
|
from fairseq2.data import VocabularyInfo
|
|
|
-from fairseq2.models.encoder_decoder import EncoderDecoderModel, Seq2SeqDecoder
|
|
|
+from fairseq2.models.encoder_decoder import EncoderDecoderModel
|
|
|
from fairseq2.models.sequence import SequenceModelOutput
|
|
|
from fairseq2.models.transformer.frontend import TransformerFrontend
|
|
|
from fairseq2.nn.incremental_state import IncrementalStateBag
|
|
@@ -228,16 +228,14 @@ class UnitYX2TModel(EncoderDecoderModel):
|
|
|
|
|
|
|
|
|
@final
|
|
|
-class UnitYT2UModel(Module, Seq2SeqDecoder):
|
|
|
+class UnitYT2UModel(EncoderDecoderModel):
|
|
|
"""Represents a UnitY T2U model as described in
|
|
|
:cite:t`https://doi.org/10.48550/arxiv.2212.08055`."""
|
|
|
|
|
|
- model_dim: int
|
|
|
encoder: Optional[TransformerEncoder]
|
|
|
decoder_frontend: TransformerFrontend
|
|
|
decoder: TransformerDecoder
|
|
|
final_proj: Projection
|
|
|
- target_vocab_info: VocabularyInfo
|
|
|
|
|
|
def __init__(
|
|
|
self,
|
|
@@ -247,61 +245,25 @@ class UnitYT2UModel(Module, Seq2SeqDecoder):
|
|
|
final_proj: Projection,
|
|
|
target_vocab_info: VocabularyInfo,
|
|
|
) -> None:
|
|
|
- super().__init__()
|
|
|
-
|
|
|
- self.model_dim = decoder.model_dim
|
|
|
+ super().__init__(decoder.model_dim, target_vocab_info)
|
|
|
|
|
|
if encoder is not None:
|
|
|
- if encoder.model_dim != self.model_dim:
|
|
|
- raise ValueError(
|
|
|
- f"`model_dim` of `encoder` and `model_dim` of `decoder` must be equal, but are {encoder.model_dim} and {self.model_dim} instead."
|
|
|
- )
|
|
|
-
|
|
|
self.encoder = encoder
|
|
|
else:
|
|
|
self.register_module("encoder", None)
|
|
|
|
|
|
- if decoder_frontend.model_dim != self.model_dim:
|
|
|
- raise ValueError(
|
|
|
- f"`model_dim` of `decoder_frontend` and `model_dim` of `decoder` must be equal, but are {decoder_frontend.model_dim} and {self.model_dim} instead."
|
|
|
- )
|
|
|
-
|
|
|
self.decoder_frontend = decoder_frontend
|
|
|
self.decoder = decoder
|
|
|
|
|
|
self.final_proj = final_proj
|
|
|
|
|
|
- self.target_vocab_info = target_vocab_info
|
|
|
-
|
|
|
- def forward(
|
|
|
- self,
|
|
|
- text_decoder_output: Tensor,
|
|
|
- text_decoder_padding_mask: Optional[PaddingMask],
|
|
|
- target_seqs: Tensor,
|
|
|
- target_padding_mask: Optional[PaddingMask],
|
|
|
- ) -> SequenceModelOutput:
|
|
|
- encoder_output, encoder_padding_mask = self.encode(
|
|
|
- text_decoder_output, text_decoder_padding_mask
|
|
|
- )
|
|
|
-
|
|
|
- decoder_output, decoder_padding_mask = self.decode(
|
|
|
- target_seqs,
|
|
|
- target_padding_mask,
|
|
|
- encoder_output,
|
|
|
- encoder_padding_mask,
|
|
|
- )
|
|
|
-
|
|
|
- return self.project(decoder_output, decoder_padding_mask)
|
|
|
-
|
|
|
def encode(
|
|
|
- self,
|
|
|
- text_decoder_output: Tensor,
|
|
|
- text_decoder_padding_mask: Optional[PaddingMask],
|
|
|
+ self, seqs: Tensor, padding_mask: Optional[PaddingMask]
|
|
|
) -> Tuple[Tensor, Optional[PaddingMask]]:
|
|
|
if self.encoder is None:
|
|
|
- return text_decoder_output, text_decoder_padding_mask
|
|
|
+ return seqs, padding_mask
|
|
|
|
|
|
- return self.encoder(text_decoder_output, text_decoder_padding_mask) # type: ignore[no-any-return]
|
|
|
+ return self.encoder(seqs, padding_mask) # type: ignore[no-any-return]
|
|
|
|
|
|
def decode(
|
|
|
self,
|