Merge pull request #41 from fairinternal/static_kv

Use static_kv for encoder-decoder attention
Kaushik Ram Sadagopan, 2 years ago
Parent
Commit
6bac442c00
1 changed file with 5 additions and 2 deletions
  1. src/seamless_communication/models/unity/t2u_builder.py (+5 -2)

+5 -2
src/seamless_communication/models/unity/t2u_builder.py

@@ -433,7 +433,7 @@ class UnitYT2UBuilder:
             )
         else:
             encoder_decoder_attn = self.build_attention(
-                self.config.num_decoder_attn_heads
+                self.config.num_decoder_attn_heads, encoder_decoder=True
             )

             ffn = self.build_ffn()
@@ -448,13 +448,16 @@ class UnitYT2UBuilder:
                 dtype=self.dtype,
             )

-    def build_attention(self, num_heads: int) -> MultiheadAttention:
+    def build_attention(
+        self, num_heads: int, encoder_decoder: bool = False
+    ) -> MultiheadAttention:
         """Build a Transformer multi-head attention layer."""
         sdpa = create_default_sdpa(attn_dropout_p=self.config.dropout_p)

         return StandardMultiheadAttention(
             self.config.model_dim,
             num_heads,
+            static_kv=encoder_decoder,
             sdpa=sdpa,
             device=self.device,
             dtype=self.dtype,
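
For context: during incremental decoding the encoder output that feeds the decoder's encoder-decoder (cross) attention never changes from step to step, so its key/value projections only need to be computed once and can then be read from a cache on every later step. The static_kv flag requests that behavior for the T2U decoder's cross-attention layers. Below is a minimal, self-contained sketch of the idea in plain PyTorch; the class, cache handling, and names are illustrative assumptions, not the fairseq2 StandardMultiheadAttention implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F


class StaticKVCrossAttention(nn.Module):
    """Illustrative cross-attention with a static key/value cache (not the fairseq2 API)."""

    def __init__(self, model_dim: int, num_heads: int) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = model_dim // num_heads
        self.q_proj = nn.Linear(model_dim, model_dim)
        self.k_proj = nn.Linear(model_dim, model_dim)
        self.v_proj = nn.Linear(model_dim, model_dim)
        self.out_proj = nn.Linear(model_dim, model_dim)

    def forward(self, queries, encoder_out, cache=None):
        # queries:     (batch, tgt_len, model_dim) - decoder state for the current step(s)
        # encoder_out: (batch, src_len, model_dim) - fixed for the whole decode
        b = queries.size(0)

        if cache is None:
            # The encoder output is static during incremental decoding, so its
            # key/value projections are computed exactly once ("static KV")...
            k = self._split_heads(self.k_proj(encoder_out), b)
            v = self._split_heads(self.v_proj(encoder_out), b)
            cache = (k, v)
        else:
            # ...and reused verbatim on every subsequent decoding step.
            k, v = cache

        q = self._split_heads(self.q_proj(queries), b)
        attn = F.scaled_dot_product_attention(q, k, v)
        attn = attn.transpose(1, 2).contiguous().view(b, -1, self.num_heads * self.head_dim)
        return self.out_proj(attn), cache

    def _split_heads(self, x, b):
        # (batch, seq, model_dim) -> (batch, heads, seq, head_dim)
        return x.view(b, -1, self.num_heads, self.head_dim).transpose(1, 2)

The first call builds the cache from the encoder output; later calls pass it back in and skip the key/value projections entirely, which is the saving that static KV caching is meant to provide for encoder-decoder attention.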