|
@@ -14,7 +14,7 @@ from fairseq2.models.wav2vec2.builder import (
|
|
Wav2Vec2EncoderBuilder,
|
|
Wav2Vec2EncoderBuilder,
|
|
Wav2Vec2EncoderConfig,
|
|
Wav2Vec2EncoderConfig,
|
|
)
|
|
)
|
|
-from fairseq2.nn.transformer import SDPA, ShawRelativePositionSDPA
|
|
|
|
|
|
+from fairseq2.nn.transformer import SDPA, ShawRelativePositionSDPA, create_default_sdpa
|
|
from fairseq2.typing import DataType, Device
|
|
from fairseq2.typing import DataType, Device
|
|
|
|
|
|
from seamless_communication.models.wav2vec2_chunk.encoder import ChunkTransformerEncoder
|
|
from seamless_communication.models.wav2vec2_chunk.encoder import ChunkTransformerEncoder
|
|
@@ -133,14 +133,17 @@ class Wav2Vec2ChunkEncoderBuilder(Wav2Vec2EncoderBuilder):
|
|
"`shaw_rel_pos_sdpa_config` must be specified when `pos_encoder_type` is 'shaw_relative'."
|
|
"`shaw_rel_pos_sdpa_config` must be specified when `pos_encoder_type` is 'shaw_relative'."
|
|
)
|
|
)
|
|
|
|
|
|
|
|
+ sdpa = create_default_sdpa(attn_dropout_p=self.config.attn_dropout_p)
|
|
|
|
+
|
|
sdpa_config = self.config.shaw_rel_pos_sdpa_config
|
|
sdpa_config = self.config.shaw_rel_pos_sdpa_config
|
|
|
|
+
|
|
return ShawRelativePositionSDPA(
|
|
return ShawRelativePositionSDPA(
|
|
self.config.model_dim,
|
|
self.config.model_dim,
|
|
self.config.num_encoder_attn_heads,
|
|
self.config.num_encoder_attn_heads,
|
|
sdpa_config.max_left_rel_pos,
|
|
sdpa_config.max_left_rel_pos,
|
|
max_right_rel_pos=sdpa_config.max_right_rel_pos,
|
|
max_right_rel_pos=sdpa_config.max_right_rel_pos,
|
|
use_rel_pos_values=sdpa_config.use_rel_pos_values,
|
|
use_rel_pos_values=sdpa_config.use_rel_pos_values,
|
|
- attn_dropout_p=self.config.attn_dropout_p,
|
|
|
|
|
|
+ inner_sdpa=sdpa,
|
|
device=self.device,
|
|
device=self.device,
|
|
dtype=self.dtype,
|
|
dtype=self.dtype,
|
|
)
|
|
)
|