Can Balioglu · 1 year ago
commit 1bb19a6aa1

+ 0 - 5
src/seamless_communication/models/unity/adaptor_block.py

@@ -22,7 +22,6 @@ from fairseq2.nn.transformer import (
     TransformerEncoderLayer,
     create_standard_layer_norm,
 )
-from fairseq2.nn.utils.module import check_model_dim
 from fairseq2.typing import DataType, Device
 from overrides import final as finaloverride
 from torch import Tensor
@@ -96,8 +95,6 @@ class UnitYEncoderAdaptor(TransformerEncoder):

         self.layer_norm = layer_norm_factory(model_dim, device=device, dtype=dtype)

-        check_model_dim(self)
-
     @finaloverride
     def forward(
         self,
@@ -241,8 +238,6 @@ class UnitYTransformerAdaptorLayer(TransformerEncoderLayer):
         else:
             self.register_module("ffn_dropout", None)

-        check_model_dim(self)
-
     @finaloverride
     def forward(
         self,

+ 0 - 4
src/seamless_communication/models/unity/model.py

@@ -14,7 +14,6 @@ from fairseq2.nn.incremental_state import IncrementalStateBag
 from fairseq2.nn.padding import PaddingMask
 from fairseq2.nn.projection import Projection
 from fairseq2.nn.transformer import TransformerDecoder, TransformerEncoder
-from fairseq2.nn.utils.module import check_model_dim
 from overrides import final as finaloverride
 from torch import Tensor
 from torch.nn import Module
@@ -94,8 +93,6 @@ class UnitYModel(EncoderDecoderModel):

         self.pad_idx = pad_idx

-        check_model_dim(self)
-
     @finaloverride
     def encode(
         self, seqs: Tensor, padding_mask: Optional[PaddingMask]
@@ -189,7 +186,6 @@ class UnitYX2TModel(EncoderDecoderModel):
         self.decoder = decoder
         self.final_proj = final_proj
         self.pad_idx = pad_idx
-        check_model_dim(self)

     @finaloverride
     def encode(

+ 0 - 3
src/seamless_communication/models/unity/nar_decoder_layer.py

@@ -18,7 +18,6 @@ from fairseq2.nn.transformer import (
 from fairseq2.nn.incremental_state import IncrementalStateBag
 from fairseq2.nn.padding import PaddingMask, apply_padding_mask
 from fairseq2.nn.transformer import create_standard_layer_norm
-from fairseq2.nn.utils.module import check_model_dim
 from fairseq2.typing import DataType, Device, finaloverride

@@ -166,8 +165,6 @@ class NARTransformerDecoderLayer(TransformerDecoderLayer):
             model_dim, device=device, dtype=dtype
         )

-        check_model_dim(self)
-
     @finaloverride
     def forward(
         self,

+ 0 - 3
src/seamless_communication/models/wav2vec2_chunk/encoder.py

@@ -9,7 +9,6 @@ from typing import Iterable, Optional, Tuple, final
 from torch import Tensor
 from torch.nn import Dropout

-from fairseq2.nn.utils.module import check_model_dim
 from fairseq2.nn.module_list import ModuleList
 from fairseq2.nn.normalization import LayerNorm
 from fairseq2.nn.padding import PaddingMask
@@ -80,8 +79,6 @@ class ChunkTransformerEncoder(TransformerEncoder):

         self.layers = layer_list

-        check_model_dim(self)
-
     @finaloverride
     def forward(
         self,
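
Note: every deletion in this commit removes the same fairseq2 helper, check_model_dim, from a module constructor. Its exact body depends on the fairseq2 version pinned by this repo; the sketch below is a hypothetical reimplementation of the behavior the helper's name and call sites suggest (verifying that every submodule exposing a model_dim attribute agrees with the parent module's model_dim), not the library's actual code.

    from torch.nn import Module

    def check_model_dim(model: Module) -> None:
        # Hypothetical sketch; the real helper lives in
        # fairseq2.nn.utils.module and may differ.
        model_dim = getattr(model, "model_dim", None)
        if model_dim is None:
            return  # Nothing to validate against.

        # Walk all submodules and flag any dimension mismatch.
        for name, submodule in model.named_modules():
            sub_dim = getattr(submodule, "model_dim", None)
            if sub_dim is not None and sub_dim != model_dim:
                raise ValueError(
                    f"`model_dim` of `{name}` must be {model_dim}, "
                    f"but is {sub_dim} instead."
                )

Under that assumption, the calls were a constructor-time consistency check, so removing them shifts any dimension-mismatch failure from module construction to the first forward pass.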