Guillaume Wenzek, 1 year ago
parent
commit
18c919b0b8
1 changed file with 20 additions and 13 deletions
  1. ggml/test_unity_cpp.py (+20, -13)

ggml/test_unity_cpp.py  +20 -13

@@ -161,7 +161,8 @@ def _name(tensor: ggml.ggml_tensor_p) -> bytes:
         return b"???"
 
 
-def test_MultiheadAttention_forward(ctx: Ctx, g_model: c_void_p) -> None:
+@pytest.mark.parametrize("flash_attn", [False, True])
+def test_MultiheadAttention_forward(ctx: Ctx, g_model: c_void_p, flash_attn: bool) -> None:
     x = torch.empty((2, 21, 1024))
     torch.random.manual_seed(0)
     torch.nn.init.uniform_(x, -1, 1)
@@ -172,22 +173,29 @@ def test_MultiheadAttention_forward(ctx: Ctx, g_model: c_void_p) -> None:
     # Note: we use different lengths for queries and keys;
     # this also exercises the implementation in a decoding context.
     # Note2: ggml_flash_attn requires that we have more keys than queries
-    gxq = ggml.from_numpy(ctx, x[:, :11, :].contiguous())
-    gx = ggml.from_numpy(ctx, x)
-    ggml.ggml_set_name(gx, b"x")
+    if flash_attn:
+        xq = x[:, :11, :]
+        xk = x
+    else:
+        xq = x
+        xk = x[:, :13, :]  # keep n_keys != n_queries to also cover decoding
+
+    gxq = ggml.from_numpy(ctx, xq.contiguous())
+    gxk = ggml.from_numpy(ctx, xk.contiguous())  # xk may be a non-contiguous slice
+    ggml.ggml_set_name(gxk, b"xk")
     gy = ggml.forward(
         "MultiheadAttention",
         g_model,
         "text_encoder.layers.0.self_attn",
         gxq,
-        gx,
-        gx,
+        gxk,
+        gxk,
         None,  # TODO: test with causal attention masks
     )
     gf = ggml.ggml_build_forward(gy)
     ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
 
-    q_exp = self_attn.q_proj(x[:, :11, :]).numpy()
+    q_exp = self_attn.q_proj(xq).numpy()
 
     y = ggml.to_numpy(gy)
     nodes = {}
@@ -201,14 +209,14 @@ def test_MultiheadAttention_forward(ctx: Ctx, g_model: c_void_p) -> None:
     attn_weights_hook = fairseq2.nn.transformer.StoreAttentionWeights([])
     self_attn.register_attn_weight_hook(attn_weights_hook)
 
-    y_exp = self_attn(x[:, :11, :], None, x, x).numpy()
+    y_exp = self_attn(xq, None, xk, xk).numpy()
 
     q = nodes[b"q"]
     assert q.shape == q_exp.shape
     assert np.allclose(q_exp, q, atol=1e-5)
 
     # with flash_attn we don't have attn_weights
-    if not UNITY_FLASH_ATTN:
+    if not flash_attn:
         attn_weights = nodes[b"attn_weights"]
         [attn_weights_exp] = attn_weights_hook._storage
         attn_weights_exp = attn_weights_exp.numpy()
@@ -217,13 +225,12 @@ def test_MultiheadAttention_forward(ctx: Ctx, g_model: c_void_p) -> None:
         # so the error isn't that small
         assert np.allclose(attn_weights_exp, attn_weights, atol=1e-3)
         # But the sums should be close to 1
-        assert np.allclose(np.sum(attn_weights, axis=-1), np.ones((2 * 16, 11)))
+        assert np.allclose(np.sum(attn_weights, axis=-1), np.ones((2 * 16, 1)))  # broadcasts over the query axis
         # And the maximum index should match the original ones.
-        assert np.allclose(
-            np.argmax(attn_weights_exp, axis=-1), np.argmax(attn_weights, axis=-1)
+        assert np.allclose(np.argmax(attn_weights_exp, axis=-1), np.argmax(attn_weights, axis=-1)
         )
     assert y.shape == y_exp.shape
-    assert np.allclose(y_exp, y, atol=1e-4 if UNITY_FLASH_ATTN else 1e-2)
+    assert np.allclose(y_exp, y, atol=1e-4 if flash_attn else 1e-2)
 
 
 def test_StandardTransformerEncoderLayer_forward(
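
For reference, the row-sum and argmax assertions in the hunk above can be reproduced with a small NumPy sketch of scaled-dot-product attention weights. This is not the ggml or fairseq2 implementation: the helper name reference_attn_weights is hypothetical, and 16 heads of dimension 64 are assumed from the 2 * 16 factor and the 1024-dim inputs used in the test. It shows why every softmax row sums to 1 and why np.ones((2 * 16, 1)) broadcasts against the (2 * 16, n_queries) sum in either flash_attn branch.

import numpy as np


def reference_attn_weights(q: np.ndarray, k: np.ndarray, num_heads: int = 16) -> np.ndarray:
    # q: (batch, n_queries, model_dim); k: (batch, n_keys, model_dim)
    batch, n_q, model_dim = q.shape
    n_k = k.shape[1]
    head_dim = model_dim // num_heads
    # Split heads -> (batch * num_heads, seq_len, head_dim)
    qh = q.reshape(batch, n_q, num_heads, head_dim).transpose(0, 2, 1, 3).reshape(-1, n_q, head_dim)
    kh = k.reshape(batch, n_k, num_heads, head_dim).transpose(0, 2, 1, 3).reshape(-1, n_k, head_dim)
    scores = qh @ kh.transpose(0, 2, 1) / np.sqrt(head_dim)  # (batch * heads, n_q, n_k)
    scores -= scores.max(axis=-1, keepdims=True)  # for numerical stability
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)  # softmax over the key axis


rng = np.random.default_rng(0)
q = rng.uniform(-1, 1, (2, 21, 1024))  # full-length queries (the non-flash branch)
k = rng.uniform(-1, 1, (2, 13, 1024))  # truncated keys
attn_weights = reference_attn_weights(q, k)  # (2 * 16, 21, 13)
# Every softmax row sums to 1; the (2 * 16, 1) array broadcasts over the query axis.
assert np.allclose(np.sum(attn_weights, axis=-1), np.ones((2 * 16, 1)))
# The argmax over the key axis gives the most-attended key per query, as compared in the test.
top_keys = np.argmax(attn_weights, axis=-1)  # (2 * 16, 21)

With the flash_attn shapes (11 queries, 21 keys) the same checks hold unchanged, which is why the (2 * 16, 1) form is agnostic to the query length chosen per branch.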