|
@@ -434,7 +434,7 @@ class UnitYT2UBuilder:
|
|
|
)
|
|
|
else:
|
|
|
encoder_decoder_attn = self.build_attention(
|
|
|
- self.config.num_decoder_attn_heads, encoder_decoder=True
|
|
|
+ self.config.num_decoder_attn_heads
|
|
|
)
|
|
|
|
|
|
ffn = self.build_ffn()
|
|
@@ -449,16 +449,13 @@ class UnitYT2UBuilder:
|
|
|
dtype=self.dtype,
|
|
|
)
|
|
|
|
|
|
- def build_attention(
|
|
|
- self, num_heads: int, encoder_decoder: bool = False
|
|
|
- ) -> MultiheadAttention:
|
|
|
+ def build_attention(self, num_heads: int) -> MultiheadAttention:
|
|
|
"""Build a Transformer multi-head attention layer."""
|
|
|
sdpa = create_default_sdpa(attn_dropout_p=self.config.dropout_p)
|
|
|
|
|
|
return StandardMultiheadAttention(
|
|
|
self.config.model_dim,
|
|
|
num_heads,
|
|
|
- static_kv=encoder_decoder,
|
|
|
sdpa=sdpa,
|
|
|
device=self.device,
|
|
|
dtype=self.dtype,
|