
allow flash attn

Guillaume Wenzek, 1 year ago
parent commit e7c3b7a4ba
2 changed files with 35 additions and 17 deletions
  1. ggml/examples/unity/fairseq2.cpp (+11 -2)
  2. ggml/test_unity_cpp.py (+24 -15)

ggml/examples/unity/fairseq2.cpp (+11 -2)

@@ -98,15 +98,15 @@ ggml_tensor* reshape_num_head(ggml_context* ctx, ggml_tensor* x, int num_heads)
     return x;
 }
 
+#define UNITY_FLASH_ATTN
 
-// TODO: borken
 extern "C" ggml_tensor* MultiheadAttention_forward(
     fairseq2_model& model,
     const std::string &prefix,
     ggml_tensor* queries,  // (slen, d_in)
     ggml_tensor* keys,  // (klen, d_in)
     ggml_tensor* values,  // (klen, d_out)
-    ggml_tensor* mask // (klen, slen)  TODO: do we need to pass mask here ?
+    ggml_tensor* mask // (klen, slen)
 ) {
     int slen = queries->ne[1];
     int slenk = keys->ne[1];
@@ -126,6 +126,14 @@ extern "C" ggml_tensor* MultiheadAttention_forward(
     v = ggml_cont(ctx, v);
     ggml_set_name(v, "v");
 
+#ifdef UNITY_FLASH_ATTN
+    // For flash_attn, we assume either no mask or a triangular (causal) mask.
+    ggml_tensor* attn = ggml_flash_attn(ctx, q, k, v, /*masked*/mask != nullptr);  // (H, S, H_dim)
+    ggml_set_name(attn, "attn");
+    attn = ggml_permute(ctx, attn, 0, 2, 1, 3); // (S, H, H_dim)
+    attn = ggml_cont(ctx, attn);
+    attn = ggml_reshape_2d(ctx, attn, num_heads * head_dim, slen); // (S, H * H_dim)
+#else
     // (H, Sk, H_dim) x (H, S, H_dim) -> (H, S, Sk)
     ggml_tensor* qk = ggml_mul_mat(ctx, k, q);
     ggml_set_name(qk, "qk");
@@ -146,6 +154,7 @@ extern "C" ggml_tensor* MultiheadAttention_forward(
     attn = ggml_transpose(ctx, attn); // (S, H * H_dim)
     // // I'm not sure why this one is needed ...
     attn = ggml_cont(ctx, attn);
+#endif  // UNITY_FLASH_ATTN
     // out -> (S, d_out)
     ggml_tensor* out = Linear_forward(model, prefix + ".output_proj", attn);
     ggml_set_name(out, "out");
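For reference, both branches compute scaled dot-product attention: the #else branch materializes the (H, S, Sk) weights explicitly, while ggml_flash_attn fuses the whole computation and only distinguishes "no mask" from "triangular (causal) mask" through its boolean masked argument, hence the mask != nullptr test. A rough NumPy sketch of the math (the names and the explicit 1/sqrt(d) scaling are illustrative assumptions, not lifted from the repository):

import numpy as np

def reference_attention(q, k, v, causal):
    # q: (H, S, H_dim), k/v: (H, Sk, H_dim); returns (H, S, H_dim).
    H, S, H_dim = q.shape
    Sk = k.shape[1]
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(H_dim)  # (H, S, Sk)
    if causal:
        # Triangular mask aligned to the last key: query i sees keys j <= i + (Sk - S).
        i = np.arange(S)[:, None]
        j = np.arange(Sk)[None, :]
        scores = np.where(j <= i + (Sk - S), scores, -np.inf)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)  # softmax over keys
    return weights @ v  # (H, S, H_dim)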

ggml/test_unity_cpp.py (+24 -15)

@@ -337,6 +337,7 @@ def test_torch_spda_vs_ggml_flash_attn(ctx: Ctx) -> None:
     gq = ggml.from_numpy(ctx, q.numpy())
     gk = ggml.from_numpy(ctx, k.numpy())
     # ggml flash attention expect a different order of axis for v:
+    # (H, slen, H_dim) -> (H, H_dim, slen)
     gv = ggml.from_numpy(ctx, v.transpose(1, 2).contiguous().numpy())
     assert ggml.shape(gv) == (num_heads, d_in, slen)
     gy = ggml.ggml_flash_attn(ctx, gq, gk, gv, True)
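The torch side of this comparison presumably relies on torch.nn.functional.scaled_dot_product_attention (the "spda" in the test name); a minimal sketch of such a reference, with hypothetical sizes rather than the ones the test actually builds:

import torch
import torch.nn.functional as F

num_heads, slen, head_dim = 16, 21, 64  # hypothetical sizes
q = torch.rand(num_heads, slen, head_dim)
k = torch.rand(num_heads, slen, head_dim)
v = torch.rand(num_heads, slen, head_dim)

# Torch keeps v as (H, slen, H_dim); only the ggml side needs the axis swap above.
y_exp = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # (H, slen, H_dim)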
@@ -410,23 +411,25 @@ def test_forward_self_attn(ctx: Ctx, g_model: NativeObj, pt_model: Any) -> None:
     # TODO: implement spda
     # self_attn.spda = lambda *qkv, **kwargs: qkv[0]
 
+    # Note: we use different lengths for queries and keys;
+    # this also exercises the implementation in a decoding context.
+    # Note2: ggml_flash_attn requires more keys than queries.
+    gxq = ggml.from_numpy(ctx, x[0, :11, :])
     gx = ggml.from_numpy(ctx, x[0])
-    gxk = ggml.from_numpy(ctx, x[0, :11, :])
-    gxv = ggml.from_numpy(ctx, x[0, :11, :])
     ggml.ggml_set_name(gx, b"x")
     gy = ggml.forward(
         "MultiheadAttention",
         g_model,
         "text_encoder.layers.0.self_attn",
+        gxq,
         gx,
-        gxk,
-        gxv,
-        None,
+        gx,
+        None,  # attention mask
     )
     gf = ggml.ggml_build_forward(gy)
     ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
 
-    q_exp = self_attn._project_q(x, None, None).squeeze(0).numpy()
+    q_exp = self_attn._project_q(x[:, :11, :], None, None).squeeze(0).numpy()
 
     y = ggml.to_numpy(gy)
     nodes = {}
@@ -440,7 +443,7 @@ def test_forward_self_attn(ctx: Ctx, g_model: NativeObj, pt_model: Any) -> None:
     attn_weights_hook = fairseq2.nn.transformer.StoreAttentionWeights([])
     self_attn.register_attn_weight_hook(attn_weights_hook)
 
-    y_exp = self_attn(x, None, x[:, :11, :], x[:, :11, :]).numpy()
+    y_exp = self_attn(x[:, :11, :], None, x, x).numpy()
     y_exp = y_exp.squeeze(0)  # remove batch dimension
 
     q = nodes[b"q"]
@@ -448,17 +451,23 @@ def test_forward_self_attn(ctx: Ctx, g_model: NativeObj, pt_model: Any) -> None:
     assert np.allclose(q_exp, q, atol=1e-5)
 
     attn_exp, attn_weights_exp = map(lambda t: t.squeeze(0).numpy(), attn_weights_hook._storage[0])
-    attn_weights = nodes[b"attn_weights"]
-    assert attn_weights_exp.shape == attn_weights.shape
-    # GGML is very agressively reducing small softmax weights to 0.
-    # Not sure to what this is due.
-    assert np.allclose(attn_weights_exp, attn_weights, atol=1e-3)
 
-    attn_exp = attn_exp.transpose(0, 2, 1)
+    # with flash_attn we don't have attn_weights
+    flash_attn = b"attn_weights" not in nodes
+
+    if not flash_attn:
+        attn_weights = nodes[b"attn_weights"]
+        assert attn_weights_exp.shape == attn_weights.shape
+        # GGML very aggressively reduces small softmax weights to 0.
+        # Not sure what causes this.
+        assert np.allclose(attn_weights_exp, attn_weights, atol=1e-3)
+        attn_exp = attn_exp.transpose(0, 2, 1)
+
     attn = nodes[b"attn"]
     assert attn_exp.shape == attn.shape
     # Because of rounding errors in softmax, it's even worse here.
-    assert np.allclose(attn_exp, attn, atol=1e-2)
+    # flash attention has better numerical precision, though.
+    assert np.allclose(attn_exp, attn, atol=1e-4 if flash_attn else 1e-2)
 
     assert y.shape == y_exp.shape
-    assert np.allclose(y_exp, y, atol=1e-2)
+    assert np.allclose(y_exp, y, atol=1e-4 if flash_attn else 1e-2)
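For completeness, the calls exercised above fit into a small helper. This is only a sketch reusing binding functions that already appear in this test file (ggml.from_numpy, ggml.ggml_flash_attn, ggml.ggml_build_forward, ggml.ggml_graph_compute_with_ctx, ggml.to_numpy); it assumes a ggml context ctx obtained the same way as the test fixture:

import ctypes
import numpy as np
import ggml

def flash_attn_numpy(ctx, q, k, v, masked=True):
    # q: (H, S, H_dim), k/v: (H, Sk, H_dim) float32 arrays;
    # as noted above, ggml_flash_attn wants more keys than queries.
    gq = ggml.from_numpy(ctx, q)
    gk = ggml.from_numpy(ctx, k)
    # ggml flash attention expects the value tensor laid out as (H, H_dim, Sk).
    gv = ggml.from_numpy(ctx, np.ascontiguousarray(v.transpose(0, 2, 1)))
    gy = ggml.ggml_flash_attn(ctx, gq, gk, gv, masked)  # masked -> triangular mask
    gf = ggml.ggml_build_forward(gy)
    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
    return ggml.to_numpy(gy)  # (H, S, H_dim)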