소스 검색

Statickv (#52)

Can Balioglu 1 년 전
부모
커밋
d314c9c4da
1개의 변경된 파일2개의 추가작업 그리고 5개의 파일을 삭제
  1. 2 5
      src/seamless_communication/models/unity/t2u_builder.py

+ 2 - 5
src/seamless_communication/models/unity/t2u_builder.py

@@ -434,7 +434,7 @@ class UnitYT2UBuilder:
             )
             )
         else:
         else:
             encoder_decoder_attn = self.build_attention(
             encoder_decoder_attn = self.build_attention(
-                self.config.num_decoder_attn_heads, encoder_decoder=True
+                self.config.num_decoder_attn_heads
             )
             )
 
 
             ffn = self.build_ffn()
             ffn = self.build_ffn()
@@ -449,16 +449,13 @@ class UnitYT2UBuilder:
                 dtype=self.dtype,
                 dtype=self.dtype,
             )
             )
 
 
-    def build_attention(
-        self, num_heads: int, encoder_decoder: bool = False
-    ) -> MultiheadAttention:
+    def build_attention(self, num_heads: int) -> MultiheadAttention:
         """Build a Transformer multi-head attention layer."""
         """Build a Transformer multi-head attention layer."""
         sdpa = create_default_sdpa(attn_dropout_p=self.config.dropout_p)
         sdpa = create_default_sdpa(attn_dropout_p=self.config.dropout_p)
 
 
         return StandardMultiheadAttention(
         return StandardMultiheadAttention(
             self.config.model_dim,
             self.config.model_dim,
             num_heads,
             num_heads,
-            static_kv=encoder_decoder,
             sdpa=sdpa,
             sdpa=sdpa,
             device=self.device,
             device=self.device,
             dtype=self.dtype,
             dtype=self.dtype,