Эх сурвалжийг харах

Update Shaw attention init (#101)

Can Balioglu 1 жил өмнө
parent
commit
a618cd43f0

+ 5 - 2
src/seamless_communication/models/wav2vec2_chunk/builder.py

@@ -14,7 +14,7 @@ from fairseq2.models.wav2vec2.builder import (
     Wav2Vec2EncoderBuilder,
     Wav2Vec2EncoderConfig,
 )
-from fairseq2.nn.transformer import SDPA, ShawRelativePositionSDPA
+from fairseq2.nn.transformer import SDPA, ShawRelativePositionSDPA, create_default_sdpa
 from fairseq2.typing import DataType, Device
 
 from seamless_communication.models.wav2vec2_chunk.encoder import ChunkTransformerEncoder
@@ -133,14 +133,17 @@ class Wav2Vec2ChunkEncoderBuilder(Wav2Vec2EncoderBuilder):
                     "`shaw_rel_pos_sdpa_config` must be specified when `pos_encoder_type` is 'shaw_relative'."
                 )
 
+            sdpa = create_default_sdpa(attn_dropout_p=self.config.attn_dropout_p)
+
             sdpa_config = self.config.shaw_rel_pos_sdpa_config
+
             return ShawRelativePositionSDPA(
                 self.config.model_dim,
                 self.config.num_encoder_attn_heads,
                 sdpa_config.max_left_rel_pos,
                 max_right_rel_pos=sdpa_config.max_right_rel_pos,
                 use_rel_pos_values=sdpa_config.use_rel_pos_values,
-                attn_dropout_p=self.config.attn_dropout_p,
+                inner_sdpa=sdpa,
                 device=self.device,
                 dtype=self.dtype,
             )