瀏覽代碼

Derive UnitYT2UModel from EncoderDecoderModel (#77)

Can Balioglu 1 年之前
父節點
當前提交
74364ff4db
共有 1 個文件被更改,包括 6 次插入44 次删除
  1. 6 44
      src/seamless_communication/models/unity/model.py

+ 6 - 44
src/seamless_communication/models/unity/model.py

@@ -8,7 +8,7 @@ from dataclasses import dataclass
 from typing import Optional, Tuple, Union, final
 
 from fairseq2.data import VocabularyInfo
-from fairseq2.models.encoder_decoder import EncoderDecoderModel, Seq2SeqDecoder
+from fairseq2.models.encoder_decoder import EncoderDecoderModel
 from fairseq2.models.sequence import SequenceModelOutput
 from fairseq2.models.transformer.frontend import TransformerFrontend
 from fairseq2.nn.incremental_state import IncrementalStateBag
@@ -228,16 +228,14 @@ class UnitYX2TModel(EncoderDecoderModel):
 
 
 @final
-class UnitYT2UModel(Module, Seq2SeqDecoder):
+class UnitYT2UModel(EncoderDecoderModel):
     """Represents a UnitY T2U model as described in
     :cite:t`https://doi.org/10.48550/arxiv.2212.08055`."""
 
-    model_dim: int
     encoder: Optional[TransformerEncoder]
     decoder_frontend: TransformerFrontend
     decoder: TransformerDecoder
     final_proj: Projection
-    target_vocab_info: VocabularyInfo
 
     def __init__(
         self,
@@ -247,61 +245,25 @@ class UnitYT2UModel(Module, Seq2SeqDecoder):
         final_proj: Projection,
         target_vocab_info: VocabularyInfo,
     ) -> None:
-        super().__init__()
-
-        self.model_dim = decoder.model_dim
+        super().__init__(decoder.model_dim, target_vocab_info)
 
         if encoder is not None:
-            if encoder.model_dim != self.model_dim:
-                raise ValueError(
-                    f"`model_dim` of `encoder` and `model_dim` of `decoder` must be equal, but are {encoder.model_dim} and {self.model_dim} instead."
-                )
-
             self.encoder = encoder
         else:
             self.register_module("encoder", None)
 
-        if decoder_frontend.model_dim != self.model_dim:
-            raise ValueError(
-                f"`model_dim` of `decoder_frontend` and `model_dim` of `decoder` must be equal, but are {decoder_frontend.model_dim} and {self.model_dim} instead."
-            )
-
         self.decoder_frontend = decoder_frontend
         self.decoder = decoder
 
         self.final_proj = final_proj
 
-        self.target_vocab_info = target_vocab_info
-
-    def forward(
-        self,
-        text_decoder_output: Tensor,
-        text_decoder_padding_mask: Optional[PaddingMask],
-        target_seqs: Tensor,
-        target_padding_mask: Optional[PaddingMask],
-    ) -> SequenceModelOutput:
-        encoder_output, encoder_padding_mask = self.encode(
-            text_decoder_output, text_decoder_padding_mask
-        )
-
-        decoder_output, decoder_padding_mask = self.decode(
-            target_seqs,
-            target_padding_mask,
-            encoder_output,
-            encoder_padding_mask,
-        )
-
-        return self.project(decoder_output, decoder_padding_mask)
-
     def encode(
-        self,
-        text_decoder_output: Tensor,
-        text_decoder_padding_mask: Optional[PaddingMask],
+        self, seqs: Tensor, padding_mask: Optional[PaddingMask]
     ) -> Tuple[Tensor, Optional[PaddingMask]]:
         if self.encoder is None:
-            return text_decoder_output, text_decoder_padding_mask
+            return seqs, padding_mask
 
-        return self.encoder(text_decoder_output, text_decoder_padding_mask)  # type: ignore[no-any-return]
+        return self.encoder(seqs, padding_mask)  # type: ignore[no-any-return]
 
     def decode(
         self,