
Hook (#66)

* Update layer_output_hook
Can Balioglu, 1 year ago
parent commit bc32323fa2

+ 5 - 2
src/seamless_communication/models/unit_extraction/wav2vec2_layer_output.py

@@ -112,9 +112,10 @@ class Wav2Vec2LayerOutputModel(nn.Module):
             The batch of sequences to process.
         """
         seqs, padding_mask = self.encoder_frontend(batch.seqs, batch.padding_mask)
+
         w2v2_layer_output = None
 
-        def layer_output_hook(
+        def hook(
             layer_idx: int,
             layer_output: Tensor,
             layer_padding_mask: Optional[PaddingMask],
@@ -130,7 +131,9 @@ class Wav2Vec2LayerOutputModel(nn.Module):
 
             return True
 
-        _, _ = self.encoder(seqs, padding_mask, layer_output_hook=layer_output_hook)
+        with self.encoder.register_layer_output_hook(hook):
+            _, _ = self.encoder(seqs, padding_mask)
 
         assert w2v2_layer_output is not None
+
         return w2v2_layer_output
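
With this change the hook is no longer passed as a keyword argument to forward; it is registered on the encoder and is only active inside the with block. A minimal usage sketch of the pattern the updated Wav2Vec2LayerOutputModel follows, assuming the hook signature shown in the diff and fairseq2's Tensor/PaddingMask/TransformerEncoder types (the helper name and import paths are assumptions, not part of the commit):

    from typing import Optional

    from torch import Tensor
    from fairseq2.nn.padding import PaddingMask
    from fairseq2.nn.transformer import TransformerEncoder

    def extract_layer_output(
        encoder: TransformerEncoder,
        seqs: Tensor,
        padding_mask: Optional[PaddingMask],
        target_layer_idx: int,
    ) -> Tensor:
        output: Optional[Tensor] = None

        def hook(
            layer_idx: int,
            layer_output: Tensor,
            layer_padding_mask: Optional[PaddingMask],
            num_layers: int,
        ) -> bool:
            nonlocal output
            # Keep the output of the requested layer and return False so the
            # encoder stops running the remaining layers.
            if layer_idx == target_layer_idx:
                output = layer_output
                return False
            return True

        # The hook is registered only for the duration of the `with` block.
        with encoder.register_layer_output_hook(hook):
            encoder(seqs, padding_mask)

        assert output is not None
        return output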

+ 1 - 5
src/seamless_communication/models/unity/adaptor_block.py

@@ -100,12 +100,8 @@ class UnitYEncoderAdaptor(TransformerEncoder):
         self,
         seqs: Tensor,
         padding_mask: Optional[PaddingMask],
-        *,
-        layer_output_hook: Optional[EncoderLayerOutputHook] = None,
     ) -> Tuple[Tensor, Optional[PaddingMask]]:
-        seqs, padding_mask = self.inner(
-            seqs, padding_mask, layer_output_hook=layer_output_hook
-        )
+        seqs, padding_mask = self.inner(seqs, padding_mask)
 
         if self.inner_layer_norm is not None:
             seqs = self.inner_layer_norm(seqs)

+ 7 - 9
src/seamless_communication/models/wav2vec2_chunk/encoder.py

@@ -81,14 +81,12 @@ class ChunkTransformerEncoder(TransformerEncoder):
 
     @finaloverride
     def forward(
-        self,
-        seqs: Tensor,
-        padding_mask: Optional[PaddingMask],
-        *,
-        layer_output_hook: Optional[EncoderLayerOutputHook] = None,
+        self, seqs: Tensor, padding_mask: Optional[PaddingMask]
     ) -> Tuple[Tensor, Optional[PaddingMask]]:
-        if layer_output_hook is not None and self.layers.drop_p > 0.0:
-            raise ValueError("`layer_hook` must be `None` when LayerDrop is enabled.")
+        if self._layer_output_hooks and self.layers.drop_p > 0.0:
+            raise ValueError(
+                "The layer output hooks cannot be run when LayerDrop is enabled."
+            )
 
         if self.preliminary_dropout is not None:
             seqs = self.preliminary_dropout(seqs)
@@ -100,8 +98,8 @@ class ChunkTransformerEncoder(TransformerEncoder):
         for layer_idx, layer in enumerate(self.layers.drop_iter()):
             seqs, padding_mask = layer(seqs, padding_mask, self_attn_mask)
 
-            if layer_output_hook is not None:
-                if not layer_output_hook(layer_idx, seqs, padding_mask, num_layers):
+            for hook in self._layer_output_hooks.values():
+                if not hook(layer_idx, seqs, padding_mask, num_layers):
                     break
 
         return seqs, padding_mask
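
The registration side itself is not part of this commit. Purely to illustrate the pattern that ChunkTransformerEncoder.forward now relies on when it iterates self._layer_output_hooks.values(), here is one way a hook registry with a with-compatible register_layer_output_hook could look; the hook-id bookkeeping and the context-manager behaviour are a sketch under assumptions, not the actual fairseq2 implementation:

    from contextlib import contextmanager
    from typing import Callable, Dict, Iterator, Optional

    from torch import Tensor
    from fairseq2.nn.padding import PaddingMask

    # Same call signature as the hooks used in the diffs above.
    LayerOutputHook = Callable[[int, Tensor, Optional[PaddingMask], int], bool]

    class LayerOutputHookRegistry:
        """Illustrative mixin: stores hooks in a dict that forward() iterates."""

        def __init__(self) -> None:
            self._layer_output_hooks: Dict[int, LayerOutputHook] = {}
            self._next_hook_id = 0

        @contextmanager
        def register_layer_output_hook(self, hook: LayerOutputHook) -> Iterator[None]:
            hook_id = self._next_hook_id
            self._next_hook_id += 1
            self._layer_output_hooks[hook_id] = hook
            try:
                yield
            finally:
                # Unregister the hook when the `with` block exits, so later
                # forward passes no longer see it.
                del self._layer_output_hooks[hook_id]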