
Rename layer_norm_fn to layer_norm_factory (#46)

Can Balioglu, 1 year ago
Parent
Current commit ae29a800c8

+ 18 - 18
src/seamless_communication/models/unity/adaptor_block.py

@@ -47,7 +47,7 @@ class UnitYEncoderAdaptor(TransformerEncoder):
         adaptor_layers: Iterable[TransformerEncoderLayer],
         *,
         inner_layer_norm: bool = False,
-        layer_norm_fn: Optional[LayerNormFactory] = None,
+        layer_norm_factory: Optional[LayerNormFactory] = None,
         device: Optional[Device] = None,
         dtype: Optional[DataType] = None,
     ) -> None:
@@ -58,20 +58,20 @@ class UnitYEncoderAdaptor(TransformerEncoder):
             The adaptor layers to stack on top of ``inner``.
         :param inner_layer_norm:
             If ``True``, applies Layer Normalization to outputs of ``inner``.
-        :param layer_norm_fn:
+        :param layer_norm_factory:
             The factory to use to construct the Layer Normalization modules.
         """
         model_dim = inner.model_dim
 
         super().__init__(model_dim)
 
-        if layer_norm_fn is None:
-            layer_norm_fn = create_default_layer_norm
+        if layer_norm_factory is None:
+            layer_norm_factory = create_default_layer_norm
 
         self.inner = inner
 
         if inner_layer_norm:
-            self.inner_layer_norm = layer_norm_fn(model_dim, device=device, dtype=dtype)
+            self.inner_layer_norm = layer_norm_factory(model_dim, device=device, dtype=dtype)
         else:
             self.register_module("inner_layer_norm", None)
 
@@ -91,7 +91,7 @@ class UnitYEncoderAdaptor(TransformerEncoder):
 
         self.adaptor_layers = layer_list
 
-        self.layer_norm = layer_norm_fn(model_dim, device=device, dtype=dtype)
+        self.layer_norm = layer_norm_factory(model_dim, device=device, dtype=dtype)
 
         check_model_dim(self)
 
@@ -161,7 +161,7 @@ class UnitYTransformerAdaptorLayer(TransformerEncoderLayer):
         stride: int,
         *,
         dropout_p: float = 0.1,
-        layer_norm_fn: Optional[LayerNormFactory] = None,
+        layer_norm_factory: Optional[LayerNormFactory] = None,
         device: Optional[Device] = None,
         dtype: Optional[DataType] = None,
     ) -> None:
@@ -177,20 +177,20 @@ class UnitYTransformerAdaptorLayer(TransformerEncoderLayer):
         :param dropout_p:
             The dropout probability on outputs of the self attention layer and
             the feed-forward network.
-        :param layer_norm_fn:
+        :param layer_norm_factory:
             The factory to use to construct the Layer Normalization modules.
         """
         model_dim = self_attn.model_dim
 
         super().__init__(model_dim)
 
-        if layer_norm_fn is None:
-            layer_norm_fn = create_default_layer_norm
+        if layer_norm_factory is None:
+            layer_norm_factory = create_default_layer_norm
 
         self.kernel_size = kernel_size
         self.stride = stride
 
-        self.residual_layer_norm = layer_norm_fn(model_dim, device=device, dtype=dtype)
+        self.residual_layer_norm = layer_norm_factory(model_dim, device=device, dtype=dtype)
 
         self.residual_conv = Conv1d(
             model_dim,
@@ -204,7 +204,7 @@ class UnitYTransformerAdaptorLayer(TransformerEncoderLayer):
 
         self.residual_activation = GLU(dim=1)
 
-        self.self_attn_layer_norm = layer_norm_fn(model_dim, device=device, dtype=dtype)
+        self.self_attn_layer_norm = layer_norm_factory(model_dim, device=device, dtype=dtype)
 
         self.self_attn_conv = Conv1d(
             model_dim,
@@ -225,7 +225,7 @@ class UnitYTransformerAdaptorLayer(TransformerEncoderLayer):
         else:
             self.register_module("self_attn_dropout", None)
 
-        self.ffn_layer_norm = layer_norm_fn(model_dim, device=device, dtype=dtype)
+        self.ffn_layer_norm = layer_norm_factory(model_dim, device=device, dtype=dtype)
 
         self.ffn = ffn
 
@@ -347,7 +347,7 @@ class UnitYConformerAdaptorLayer(TransformerEncoderLayer):
         stride: int,
         *,
         layer_norm: bool = False,
-        layer_norm_fn: Optional[LayerNormFactory] = None,
+        layer_norm_factory: Optional[LayerNormFactory] = None,
         device: Optional[Device] = None,
         dtype: Optional[DataType] = None,
     ) -> None:
@@ -360,19 +360,19 @@ class UnitYConformerAdaptorLayer(TransformerEncoderLayer):
             The stride for 1D pooling convolutions.
         :param layer_norm:
             If ``True``, applies Layer Normalization to inputs before pooling.
-        :param layer_norm_fn:
+        :param layer_norm_factory:
             The factory to use to construct the Layer Normalization modules.
         """
         super().__init__(block.model_dim)
 
-        if layer_norm_fn is None:
-            layer_norm_fn = create_default_layer_norm
+        if layer_norm_factory is None:
+            layer_norm_factory = create_default_layer_norm
 
         self.kernel_size = kernel_size
         self.stride = stride
 
         if layer_norm:
-            self.layer_norm = layer_norm_fn(self.model_dim, device=device, dtype=dtype)
+            self.layer_norm = layer_norm_factory(self.model_dim, device=device, dtype=dtype)
         else:
             self.register_module("layer_norm", None)
 
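
Because `layer_norm_factory` is a public keyword-only argument on `UnitYEncoderAdaptor`, `UnitYTransformerAdaptorLayer`, and `UnitYConformerAdaptorLayer`, any caller that passed the old `layer_norm_fn=` keyword has to be updated as well. Below is a minimal sketch of a custom factory and the call-site change; `my_layer_norm` is a hypothetical name, and plain `torch.nn.LayerNorm` stands in for whatever LayerNorm class the `LayerNormFactory` type imported in this file actually expects.

    from typing import Optional

    import torch
    from torch.nn import LayerNorm


    def my_layer_norm(
        model_dim: int,
        *,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> LayerNorm:
        # Matches the call shape used above: factory(model_dim, device=..., dtype=...).
        # A non-default eps is used purely to show why a custom factory might exist.
        return LayerNorm(model_dim, eps=1e-6, device=device, dtype=dtype)


    # Before this commit (old keyword):
    #     adaptor = UnitYEncoderAdaptor(inner, adaptor_layers, layer_norm_fn=my_layer_norm)
    # After this commit (new keyword):
    #     adaptor = UnitYEncoderAdaptor(inner, adaptor_layers, layer_norm_factory=my_layer_norm)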

+ 3 - 3
src/seamless_communication/models/unity/length_regulator.py

@@ -75,9 +75,9 @@ class VariancePredictor(Module):
             ReLU(),
         )
 
-        layer_norm_fn = create_default_layer_norm
+        layer_norm_factory = create_default_layer_norm
 
-        self.ln1 = layer_norm_fn(var_pred_hidden_dim, device=device, dtype=dtype)
+        self.ln1 = layer_norm_factory(var_pred_hidden_dim, device=device, dtype=dtype)
 
         self.dropout_module = Dropout(p=var_pred_dropout)
 
@@ -95,7 +95,7 @@ class VariancePredictor(Module):
             ReLU(),
         )
 
-        self.ln2 = layer_norm_fn(var_pred_hidden_dim, device=device, dtype=dtype)
+        self.ln2 = layer_norm_factory(var_pred_hidden_dim, device=device, dtype=dtype)
 
         self.proj = Linear(
             var_pred_hidden_dim, 1, bias=True, device=device, dtype=dtype

+ 3 - 3
src/seamless_communication/models/unity/nar_decoder_layer.py

@@ -148,9 +148,9 @@ class NARTransformerDecoderLayer(TransformerDecoderLayer):
         else:
             self.register_module("self_attn_dropout", None)
 
-        layer_norm_fn = create_default_layer_norm
+        layer_norm_factory = create_default_layer_norm
 
-        self.self_attn_layer_norm = layer_norm_fn(model_dim, device=device, dtype=dtype)
+        self.self_attn_layer_norm = layer_norm_factory(model_dim, device=device, dtype=dtype)
 
         self.conv1d = conv1d
 
@@ -159,7 +159,7 @@ class NARTransformerDecoderLayer(TransformerDecoderLayer):
         else:
             self.register_module("conv1d_dropout", None)
 
-        self.conv1d_layer_norm = layer_norm_fn(model_dim, device=device, dtype=dtype)
+        self.conv1d_layer_norm = layer_norm_factory(model_dim, device=device, dtype=dtype)
 
         check_model_dim(self)
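
In length_regulator.py and nar_decoder_layer.py the rename is purely internal: `layer_norm_factory` is a local alias for `create_default_layer_norm`, so no public signature changes. The sketch below mirrors that pattern under the same assumptions as before, with a hypothetical `default_layer_norm` standing in for the repo's `create_default_layer_norm` helper.

    from typing import Optional

    import torch
    from torch import nn


    def default_layer_norm(
        dim: int,
        *,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> nn.LayerNorm:
        # Hypothetical stand-in for create_default_layer_norm.
        return nn.LayerNorm(dim, device=device, dtype=dtype)


    class TwoNormBlock(nn.Module):
        # Same pattern as VariancePredictor and NARTransformerDecoderLayer above:
        # bind the factory to a local name, then call it once per normalization
        # site, forwarding device and dtype.
        def __init__(
            self,
            hidden_dim: int,
            *,
            device: Optional[torch.device] = None,
            dtype: Optional[torch.dtype] = None,
        ) -> None:
            super().__init__()

            layer_norm_factory = default_layer_norm

            self.ln1 = layer_norm_factory(hidden_dim, device=device, dtype=dtype)
            self.ln2 = layer_norm_factory(hidden_dim, device=device, dtype=dtype)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.ln2(self.ln1(x))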