|
@@ -14,7 +14,6 @@ from fairseq2.nn.incremental_state import IncrementalStateBag
|
|
from fairseq2.nn.padding import PaddingMask
|
|
from fairseq2.nn.padding import PaddingMask
|
|
from fairseq2.nn.projection import Projection
|
|
from fairseq2.nn.projection import Projection
|
|
from fairseq2.nn.transformer import TransformerDecoder, TransformerEncoder
|
|
from fairseq2.nn.transformer import TransformerDecoder, TransformerEncoder
|
|
-from fairseq2.nn.utils.module import check_model_dim
|
|
|
|
from overrides import final as finaloverride
|
|
from overrides import final as finaloverride
|
|
from torch import Tensor
|
|
from torch import Tensor
|
|
from torch.nn import Module
|
|
from torch.nn import Module
|
|
@@ -94,8 +93,6 @@ class UnitYModel(EncoderDecoderModel):
|
|
|
|
|
|
self.pad_idx = pad_idx
|
|
self.pad_idx = pad_idx
|
|
|
|
|
|
- check_model_dim(self)
|
|
|
|
-
|
|
|
|
@finaloverride
|
|
@finaloverride
|
|
def encode(
|
|
def encode(
|
|
self, seqs: Tensor, padding_mask: Optional[PaddingMask]
|
|
self, seqs: Tensor, padding_mask: Optional[PaddingMask]
|
|
@@ -189,7 +186,6 @@ class UnitYX2TModel(EncoderDecoderModel):
|
|
self.decoder = decoder
|
|
self.decoder = decoder
|
|
self.final_proj = final_proj
|
|
self.final_proj = final_proj
|
|
self.pad_idx = pad_idx
|
|
self.pad_idx = pad_idx
|
|
- check_model_dim(self)
|
|
|
|
|
|
|
|
@finaloverride
|
|
@finaloverride
|
|
def encode(
|
|
def encode(
|