
Merge pull request #41 from fairinternal/static_kv

Use static_kv for encoder-decoder attention
Kaushik Ram Sadagopan 2 years ago
parent commit 6bac442c00
1 changed file with 5 additions and 2 deletions
  1. src/seamless_communication/models/unity/t2u_builder.py  +5 -2

+ 5 - 2  src/seamless_communication/models/unity/t2u_builder.py

@@ -433,7 +433,7 @@ class UnitYT2UBuilder:
             )
         else:
             encoder_decoder_attn = self.build_attention(
-                self.config.num_decoder_attn_heads
+                self.config.num_decoder_attn_heads, encoder_decoder=True
             )
 
             ffn = self.build_ffn()
@@ -448,13 +448,16 @@ class UnitYT2UBuilder:
                 dtype=self.dtype,
             )
 
-    def build_attention(self, num_heads: int) -> MultiheadAttention:
+    def build_attention(
+        self, num_heads: int, encoder_decoder: bool = False
+    ) -> MultiheadAttention:
         """Build a Transformer multi-head attention layer."""
         sdpa = create_default_sdpa(attn_dropout_p=self.config.dropout_p)
 
         return StandardMultiheadAttention(
             self.config.model_dim,
             num_heads,
+            static_kv=encoder_decoder,
             sdpa=sdpa,
             device=self.device,
             dtype=self.dtype,
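            )

In encoder-decoder (cross) attention, the keys and values are projections of the encoder output, which is fixed for the whole utterance; during incremental decoding they therefore only need to be computed on the first step and can be reused on every step after that. The static_kv flag passed above marks the key/value inputs as static so the attention layer can cache their projections. Below is a minimal sketch of that caching behaviour in plain PyTorch. It is not the fairseq2 implementation; the class, its attributes, and the cache handling are all illustrative.

    import torch
    import torch.nn as nn
    from typing import Optional, Tuple


    class CrossAttentionSketch(nn.Module):
        """Cross-attention that caches the encoder-side K/V projections."""

        def __init__(self, model_dim: int, num_heads: int) -> None:
            super().__init__()
            self.num_heads = num_heads
            self.head_dim = model_dim // num_heads
            self.q_proj = nn.Linear(model_dim, model_dim)
            self.k_proj = nn.Linear(model_dim, model_dim)
            self.v_proj = nn.Linear(model_dim, model_dim)
            self.out_proj = nn.Linear(model_dim, model_dim)
            # Holds the projected encoder keys/values once computed.
            self.kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None

        def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
            # (batch, seq, model_dim) -> (batch, heads, seq, head_dim)
            b, s, _ = x.shape
            return x.view(b, s, self.num_heads, self.head_dim).transpose(1, 2)

        def forward(
            self, queries: torch.Tensor, encoder_out: torch.Tensor
        ) -> torch.Tensor:
            q = self._split_heads(self.q_proj(queries))
            if self.kv_cache is None:
                # First decoding step: project the encoder output once.
                k = self._split_heads(self.k_proj(encoder_out))
                v = self._split_heads(self.v_proj(encoder_out))
                self.kv_cache = (k, v)
            else:
                # Later steps: the encoder output is static, so reuse the
                # cached projections instead of recomputing them.
                k, v = self.kv_cache
            out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
            b, _, s, _ = out.shape
            return self.out_proj(out.transpose(1, 2).reshape(b, s, -1))


    attn = CrossAttentionSketch(model_dim=512, num_heads=8)
    encoder_out = torch.randn(1, 20, 512)           # fixed for the utterance
    y1 = attn(torch.randn(1, 1, 512), encoder_out)  # projects and caches K/V
    y2 = attn(torch.randn(1, 1, 512), encoder_out)  # reuses the cached K/V

Note that the cache has to be cleared between utterances; in a real decoder that bookkeeping would typically live in the per-sequence incremental decoding state rather than on the module itself, as sketched here for brevity.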