Browse Source

Statickv (#52)

Can Balioglu 1 year ago
parent
commit
d314c9c4da
1 changed files with 2 additions and 5 deletions
  1. 2 5
      src/seamless_communication/models/unity/t2u_builder.py

+ 2 - 5
src/seamless_communication/models/unity/t2u_builder.py

@@ -434,7 +434,7 @@ class UnitYT2UBuilder:
             )
         else:
             encoder_decoder_attn = self.build_attention(
-                self.config.num_decoder_attn_heads, encoder_decoder=True
+                self.config.num_decoder_attn_heads
             )
 
             ffn = self.build_ffn()
@@ -449,16 +449,13 @@ class UnitYT2UBuilder:
                 dtype=self.dtype,
             )
 
-    def build_attention(
-        self, num_heads: int, encoder_decoder: bool = False
-    ) -> MultiheadAttention:
+    def build_attention(self, num_heads: int) -> MultiheadAttention:
         """Build a Transformer multi-head attention layer."""
         sdpa = create_default_sdpa(attn_dropout_p=self.config.dropout_p)
 
         return StandardMultiheadAttention(
             self.config.model_dim,
             num_heads,
-            static_kv=encoder_decoder,
             sdpa=sdpa,
             device=self.device,
             dtype=self.dtype,