Browse Source

Merge pull request #31 from fairinternal/fs2-fix

Fix the remaining BC breaking changes in adaptor layers
Kaushik Ram Sadagopan 2 years ago
parent
commit
634a284f8f

+ 11 - 5
src/seamless_communication/models/unity/adaptor_block.py

@@ -45,6 +45,7 @@ class UnitYEncoderAdaptor(TransformerEncoder):
         self,
         inner: TransformerEncoder,
         adaptor_layers: Iterable[TransformerEncoderLayer],
+        *,
         inner_layer_norm: bool = False,
         layer_norm_fn: Optional[LayerNormFactory] = None,
         device: Optional[Device] = None,
@@ -99,9 +100,12 @@ class UnitYEncoderAdaptor(TransformerEncoder):
         self,
         seqs: Tensor,
         padding_mask: Optional[Tensor],
+        *,
         layer_output_hook: Optional[EncoderLayerOutputHook] = None,
     ) -> Tuple[Tensor, Optional[Tensor]]:
-        seqs, padding_mask = self.inner(seqs, padding_mask, layer_output_hook)
+        seqs, padding_mask = self.inner(
+            seqs, padding_mask, layer_output_hook=layer_output_hook
+        )
 
         if self.inner_layer_norm is not None:
             seqs = self.inner_layer_norm(seqs)
@@ -153,8 +157,9 @@ class UnitYTransformerAdaptorLayer(TransformerEncoderLayer):
         self,
         self_attn: MultiheadAttention,
         ffn: FeedForwardNetwork,
-        kernel_size: int = 8,
-        stride: int = 8,
+        kernel_size: int,
+        stride: int,
+        *,
         dropout_p: float = 0.1,
         layer_norm_fn: Optional[LayerNormFactory] = None,
         device: Optional[Device] = None,
@@ -331,8 +336,9 @@ class UnitYConformerAdaptorLayer(TransformerEncoderLayer):
     def __init__(
         self,
         block: ConformerBlock,
-        kernel_size: int = 8,
-        stride: int = 8,
+        kernel_size: int,
+        stride: int,
+        *,
         layer_norm: bool = False,
         layer_norm_fn: Optional[LayerNormFactory] = None,
         device: Optional[Device] = None,

+ 3 - 3
src/seamless_communication/models/unity/builder.py

@@ -255,7 +255,7 @@ class UnitYBuilder:
         return UnitYEncoderAdaptor(
             w2v2_encoder,
             layers,
-            self.config.adaptor_layer_norm,
+            inner_layer_norm=self.config.adaptor_layer_norm,
             device=self.device,
             dtype=self.dtype,
         )
@@ -281,7 +281,7 @@ class UnitYBuilder:
             ffn,
             self.config.adaptor_kernel_size,
             self.config.adaptor_stride,
-            self.config.adaptor_dropout_p,
+            dropout_p=self.config.adaptor_dropout_p,
             device=self.device,
             dtype=self.dtype,
         )
@@ -321,7 +321,7 @@ class UnitYBuilder:
             block,
             self.config.adaptor_kernel_size,
             self.config.adaptor_stride,
-            layer_norm,
+            layer_norm=layer_norm,
             device=self.device,
             dtype=self.dtype,
         )