فهرست منبع

Fix BC breaking changes to current state of fairseq2

Can Balioglu 1 سال پیش
والد
کامیت
246ec1b13b

+ 1 - 1
requirements.txt

@@ -3,4 +3,4 @@ datasets
 torchaudio
 soundfile
 librosa
-fairseq2==0.1.*
+fairseq2==0.2.*

+ 6 - 1
src/seamless_communication/models/unit_extraction/wav2vec2_layer_output.py

@@ -118,12 +118,17 @@ class Wav2Vec2LayerOutputModel(nn.Module):
             layer_output: Tensor,
             layer_padding_mask: Optional[Tensor],
             num_layers: int,
-        ) -> None:
+        ) -> bool:
             nonlocal w2v2_layer_output
 
             if layer_idx == out_layer_idx:
                 w2v2_layer_output = layer_output
 
+                # We don't need to execute the remaining layers.
+                return False
+
+            return True
+
         _, _ = self.encoder(seqs, padding_mask, layer_output_hook)
 
         assert w2v2_layer_output is not None

+ 5 - 5
src/seamless_communication/models/unity/adaptor_block.py

@@ -70,7 +70,7 @@ class UnitYEncoderAdaptor(TransformerEncoder):
         self.inner = inner
 
         if inner_layer_norm:
-            self.inner_layer_norm = layer_norm_fn(model_dim, device, dtype)
+            self.inner_layer_norm = layer_norm_fn(model_dim, device=device, dtype=dtype)
         else:
             self.register_module("inner_layer_norm", None)
 
@@ -90,7 +90,7 @@ class UnitYEncoderAdaptor(TransformerEncoder):
 
         self.adaptor_layers = layer_list
 
-        self.layer_norm = layer_norm_fn(model_dim, device, dtype)
+        self.layer_norm = layer_norm_fn(model_dim, device=device, dtype=dtype)
 
         check_model_dim(self)
 
@@ -185,7 +185,7 @@ class UnitYTransformerAdaptorLayer(TransformerEncoderLayer):
         self.kernel_size = kernel_size
         self.stride = stride
 
-        self.residual_layer_norm = layer_norm_fn(model_dim, device, dtype)
+        self.residual_layer_norm = layer_norm_fn(model_dim, device=device, dtype=dtype)
 
         self.residual_conv = Conv1d(
             model_dim,
@@ -199,7 +199,7 @@ class UnitYTransformerAdaptorLayer(TransformerEncoderLayer):
 
         self.residual_activation = GLU(dim=1)
 
-        self.self_attn_layer_norm = layer_norm_fn(model_dim, device, dtype)
+        self.self_attn_layer_norm = layer_norm_fn(model_dim, device=device, dtype=dtype)
 
         self.self_attn_conv = Conv1d(
             model_dim,
@@ -220,7 +220,7 @@ class UnitYTransformerAdaptorLayer(TransformerEncoderLayer):
         else:
             self.register_module("self_attn_dropout", None)
 
-        self.ffn_layer_norm = layer_norm_fn(model_dim, device, dtype)
+        self.ffn_layer_norm = layer_norm_fn(model_dim, device=device, dtype=dtype)
 
         self.ffn = ffn
 

+ 13 - 6
src/seamless_communication/models/unity/builder.py

@@ -218,7 +218,7 @@ class UnitYBuilder:
             text_encoder_frontend = None
             text_encoder = None
 
-        final_proj = TiedProjection(text_embed.weight)
+        final_proj = TiedProjection(text_embed.weight, bias=None)
 
         if self.t2u_builder is None:
             t2u_model = None
@@ -271,6 +271,7 @@ class UnitYBuilder:
         ffn = StandardFeedForwardNetwork(
             self.config.model_dim,
             self.w2v2_encoder_builder.config.ffn_inner_dim,
+            bias=True,
             device=self.device,
             dtype=self.dtype,
         )
@@ -353,18 +354,23 @@ def create_unity_model(
         The data type of module parameters and buffers.
     """
     w2v2_encoder_builder = Wav2Vec2EncoderBuilder(
-        config.w2v2_encoder_config, device, dtype
+        config.w2v2_encoder_config, device=device, dtype=dtype
     )
 
-    nllb_builder = NllbBuilder(config.nllb_config, device, dtype)
+    nllb_builder = NllbBuilder(config.nllb_config, device=device, dtype=dtype)
 
     if config.t2u_config is None:
         t2u_builder = None
     else:
-        t2u_builder = UnitYT2UBuilder(config.t2u_config, device, dtype)
+        t2u_builder = UnitYT2UBuilder(config.t2u_config, device=device, dtype=dtype)
 
     unity_builder = UnitYBuilder(
-        config, w2v2_encoder_builder, nllb_builder, t2u_builder, device, dtype
+        config,
+        w2v2_encoder_builder,
+        nllb_builder,
+        t2u_builder,
+        device=device,
+        dtype=dtype,
     )
 
     return unity_builder.build_model()
@@ -486,7 +492,7 @@ class UnitYT2UBuilder:
         decoder_frontend = self.build_decoder_frontend(embed)
         decoder = self.build_decoder()
 
-        final_proj = TiedProjection(embed.weight)
+        final_proj = TiedProjection(embed.weight, bias=None)
 
         return UnitYT2UModel(
             encoder,
@@ -603,6 +609,7 @@ class UnitYT2UBuilder:
         return StandardFeedForwardNetwork(
             self.config.model_dim,
             self.config.ffn_inner_dim,
+            bias=True,
             norm_order=TransformerNormOrder.PRE,
             device=self.device,
             dtype=self.dtype,

+ 5 - 1
src/seamless_communication/models/unity/loader.py

@@ -231,7 +231,11 @@ class UnitYLoader(ModelLoader[UnitYModel, UnitYConfig]):
 
 
 load_unity_model = UnitYLoader(
-    asset_store, download_manager, create_unity_model, unity_archs
+    asset_store,
+    download_manager,
+    create_unity_model,
+    unity_archs,
+    restrict_checkpoints=False,
 )
 
 

+ 20 - 6
src/seamless_communication/models/unity/model.py

@@ -138,10 +138,16 @@ class UnitYModel(EncoderDecoderModel):
         encoder_padding_mask: Optional[Tensor],
         state_bag: Optional[IncrementalStateBag] = None,
     ) -> Tuple[Tensor, Optional[Tensor]]:
-        seqs, padding_mask = self.text_decoder_frontend(seqs, seq_lens, state_bag)
+        seqs, padding_mask = self.text_decoder_frontend(
+            seqs, seq_lens, state_bag=state_bag
+        )
 
         return self.text_decoder(  # type: ignore[no-any-return]
-            seqs, padding_mask, encoder_output, encoder_padding_mask, state_bag
+            seqs,
+            padding_mask,
+            encoder_output,
+            encoder_padding_mask,
+            state_bag=state_bag,
         )
 
     @finaloverride
@@ -199,10 +205,14 @@ class UnitYX2TModel(EncoderDecoderModel):
         encoder_padding_mask: Optional[Tensor],
         state_bag: Optional[IncrementalStateBag] = None,
     ) -> Tuple[Tensor, Optional[Tensor]]:
-        seqs, padding_mask = self.decoder_frontend(seqs, seq_lens, state_bag)
+        seqs, padding_mask = self.decoder_frontend(seqs, seq_lens, state_bag=state_bag)
 
         return self.decoder(  # type: ignore[no-any-return]
-            seqs, padding_mask, encoder_output, encoder_padding_mask, state_bag
+            seqs,
+            padding_mask,
+            encoder_output,
+            encoder_padding_mask,
+            state_bag=state_bag,
         )
 
     @finaloverride
@@ -292,10 +302,14 @@ class UnitYT2UModel(Module, Seq2SeqDecoder):
         encoder_padding_mask: Optional[Tensor],
         state_bag: Optional[IncrementalStateBag] = None,
     ) -> Tuple[Tensor, Optional[Tensor]]:
-        seqs, padding_mask = self.decoder_frontend(seqs, seq_lens, state_bag)
+        seqs, padding_mask = self.decoder_frontend(seqs, seq_lens, state_bag=state_bag)
 
         return self.decoder(  # type: ignore[no-any-return]
-            seqs, padding_mask, encoder_output, encoder_padding_mask, state_bag
+            seqs,
+            padding_mask,
+            encoder_output,
+            encoder_padding_mask,
+            state_bag=state_bag,
         )
 
     def project(