@@ -161,7 +161,8 @@ def _name(tensor: ggml.ggml_tensor_p) -> bytes:
     return b"???"
 
 
-def test_MultiheadAttention_forward(ctx: Ctx, g_model: c_void_p) -> None:
+@pytest.mark.parametrize("flash_attn", [False, True])
+def test_MultiheadAttention_forward(ctx: Ctx, g_model: c_void_p, flash_attn: bool) -> None:
     x = torch.empty((2, 21, 1024))
     torch.random.manual_seed(0)
     torch.nn.init.uniform_(x, -1, 1)
@@ -172,22 +173,29 @@ def test_MultiheadAttention_forward(ctx: Ctx, g_model: c_void_p) -> None:
     # Note: we use different lengths for queries and keys,
     # this tests the implementation in decoding context too.
     # Note2: ggml_flash_attn requires that we have more keys than queries
-    gxq = ggml.from_numpy(ctx, x[:, :11, :].contiguous())
-    gx = ggml.from_numpy(ctx, x)
-    ggml.ggml_set_name(gx, b"x")
+    if flash_attn:
+        xq = x[:, :11, :]
+        xk = x
+    else:
+        xq = x
+        xk = x[:, :13, :]
+
+    gxq = ggml.from_numpy(ctx, xq.contiguous())
+    gxk = ggml.from_numpy(ctx, xk)
+    ggml.ggml_set_name(gxk, b"xk")
     gy = ggml.forward(
         "MultiheadAttention",
         g_model,
         "text_encoder.layers.0.self_attn",
         gxq,
-        gx,
-        gx,
+        gxk,
+        gxk,
         None,  # TODO: tests with causal attention masks
     )
     gf = ggml.ggml_build_forward(gy)
     ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
 
-    q_exp = self_attn.q_proj(x[:, :11, :]).numpy()
+    q_exp = self_attn.q_proj(xq).numpy()
 
     y = ggml.to_numpy(gy)
     nodes = {}
@@ -201,14 +209,14 @@ def test_MultiheadAttention_forward(ctx: Ctx, g_model: c_void_p) -> None:
     attn_weights_hook = fairseq2.nn.transformer.StoreAttentionWeights([])
     self_attn.register_attn_weight_hook(attn_weights_hook)
 
-    y_exp = self_attn(x[:, :11, :], None, x, x).numpy()
+    y_exp = self_attn(xq, None, xk, xk).numpy()
 
     q = nodes[b"q"]
     assert q.shape == q_exp.shape
     assert np.allclose(q_exp, q, atol=1e-5)
 
     # with flash_attn we don't have attn_weights
-    if not UNITY_FLASH_ATTN:
+    if not flash_attn:
         attn_weights = nodes[b"attn_weights"]
         [attn_weights_exp] = attn_weights_hook._storage
         attn_weights_exp = attn_weights_exp.numpy()
@@ -217,13 +225,12 @@ def test_MultiheadAttention_forward(ctx: Ctx, g_model: c_void_p) -> None:
         # so the error isn't that small
         assert np.allclose(attn_weights_exp, attn_weights, atol=1e-3)
         # But the sums should be close to 1
-        assert np.allclose(np.sum(attn_weights, axis=-1), np.ones((2 * 16, 11)))
+        assert np.allclose(np.sum(attn_weights, axis=-1), np.ones((2 * 16, 1)))
         # And the maximum index should match the original ones.
-        assert np.allclose(
-            np.argmax(attn_weights_exp, axis=-1), np.argmax(attn_weights, axis=-1)
+        assert np.allclose(np.argmax(attn_weights_exp, axis=-1), np.argmax(attn_weights, axis=-1)
         )
     assert y.shape == y_exp.shape
-    assert np.allclose(y_exp, y, atol=1e-4 if UNITY_FLASH_ATTN else 1e-2)
+    assert np.allclose(y_exp, y, atol=1e-4 if flash_attn else 1e-2)
 
 
 def test_StandardTransformerEncoderLayer_forward(