@@ -337,6 +337,7 @@ def test_torch_spda_vs_ggml_flash_attn(ctx: Ctx) -> None:
gq = ggml.from_numpy(ctx, q.numpy())
gk = ggml.from_numpy(ctx, k.numpy())
# ggml flash attention expect a different order of axis for v:
+ # (H, slen, H_dim) -> (H, H_dim, slen)
gv = ggml.from_numpy(ctx, v.transpose(1, 2).contiguous().numpy())
assert ggml.shape(gv) == (num_heads, d_in, slen)
gy = ggml.ggml_flash_attn(ctx, gq, gk, gv, True)
@@ -410,23 +411,25 @@ def test_forward_self_attn(ctx: Ctx, g_model: NativeObj, pt_model: Any) -> None:
# TODO: implement spda
# self_attn.spda = lambda *qkv, **kwargs: qkv[0]

+ # Note: we use different lengths for queries and keys;
+ # this also tests the implementation in a decoding context.
+ # Note2: ggml_flash_attn requires that we have more keys than queries
+ gxq = ggml.from_numpy(ctx, x[0, :11, :])
gx = ggml.from_numpy(ctx, x[0])
- gxk = ggml.from_numpy(ctx, x[0, :11, :])
- gxv = ggml.from_numpy(ctx, x[0, :11, :])
ggml.ggml_set_name(gx, b"x")
gy = ggml.forward(
"MultiheadAttention",
g_model,
"text_encoder.layers.0.self_attn",
+ gxq,
gx,
- gxk,
- gxv,
- None,
+ gx,
+ None, # attention mask
)
gf = ggml.ggml_build_forward(gy)
ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)

- q_exp = self_attn._project_q(x, None, None).squeeze(0).numpy()
+ q_exp = self_attn._project_q(x[:, :11, :], None, None).squeeze(0).numpy()

y = ggml.to_numpy(gy)
nodes = {}
@@ -440,7 +443,7 @@ def test_forward_self_attn(ctx: Ctx, g_model: NativeObj, pt_model: Any) -> None:
attn_weights_hook = fairseq2.nn.transformer.StoreAttentionWeights([])
self_attn.register_attn_weight_hook(attn_weights_hook)

- y_exp = self_attn(x, None, x[:, :11, :], x[:, :11, :]).numpy()
+ y_exp = self_attn(x[:, :11, :], None, x, x).numpy()
y_exp = y_exp.squeeze(0) # remove batch dimension

q = nodes[b"q"]
@@ -448,17 +451,23 @@ def test_forward_self_attn(ctx: Ctx, g_model: NativeObj, pt_model: Any) -> None:
assert np.allclose(q_exp, q, atol=1e-5)

attn_exp, attn_weights_exp = map(lambda t: t.squeeze(0).numpy(), attn_weights_hook._storage[0])
- attn_weights = nodes[b"attn_weights"]
- assert attn_weights_exp.shape == attn_weights.shape
- # GGML is very agressively reducing small softmax weights to 0.
- # Not sure to what this is due.
- assert np.allclose(attn_weights_exp, attn_weights, atol=1e-3)

- attn_exp = attn_exp.transpose(0, 2, 1)
+ # with flash_attn we don't have attn_weights
+ flash_attn = b"attn_weights" not in nodes
+
+ if not flash_attn:
+     attn_weights = nodes[b"attn_weights"]
+     assert attn_weights_exp.shape == attn_weights.shape
+     # GGML is very aggressively reducing small softmax weights to 0.
+     # Not sure what this is due to.
+     assert np.allclose(attn_weights_exp, attn_weights, atol=1e-3)
+     attn_exp = attn_exp.transpose(0, 2, 1)
+
attn = nodes[b"attn"]
assert attn_exp.shape == attn.shape
# Because of rounding errors in softmax, it's even worse here.
- assert np.allclose(attn_exp, attn, atol=1e-2)
+ # flash attention has better numerical precision though.
+ assert np.allclose(attn_exp, attn, atol=1e-4 if flash_attn else 1e-2)

assert y.shape == y_exp.shape
- assert np.allclose(y_exp, y, atol=1e-2)
+ assert np.allclose(y_exp, y, atol=1e-4 if flash_attn else 1e-2)
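A minimal sketch (not part of the patch) of the layout detail the first hunk comments on: ggml_flash_attn is handed v with its last two axes swapped, while a PyTorch reference on the usual (num_heads, slen, head_dim) layout stays untransposed. The sizes, the causal flag, and the use of torch.nn.functional.scaled_dot_product_attention as the reference are illustrative assumptions, not taken from this diff.

```python
import torch

# Hypothetical sizes, reusing the variable names from the test above.
num_heads, slen, d_in = 2, 5, 4

q = torch.randn(num_heads, slen, d_in)
k = torch.randn(num_heads, slen, d_in)
v = torch.randn(num_heads, slen, d_in)

# Reference attention on the usual (H, slen, H_dim) layout. Treating the `True`
# flag passed to ggml_flash_attn as "causal" is an assumption here.
y_ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
assert y_ref.shape == (num_heads, slen, d_in)

# ggml_flash_attn instead takes v transposed, (H, slen, H_dim) -> (H, H_dim, slen),
# which is what v.transpose(1, 2).contiguous() prepares before ggml.from_numpy.
v_for_ggml = v.transpose(1, 2).contiguous()
assert v_for_ggml.shape == (num_heads, d_in, slen)
```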