@@ -8,6 +8,7 @@ import fairseq2.nn
 import fairseq2.nn.transformer
 import logging
 import sys
+import functools
 from typing import Tuple
 from pathlib import Path
 from ctypes_utils import Ptr
@@ -246,7 +247,7 @@ def test_ggml_slice(ctx: Ctx) -> None:
     assert np.allclose(a[2:5, :], s1)


-@pytest.mark.xfail(reason="not implemented")
+@pytest.mark.xfail(reason="to_numpy not implemented")
 def test_ggml_transpose_and_slice(ctx: Ctx) -> None:
     ga = ggml.ggml_new_tensor_2d(ctx, ggml.GGML_TYPE_F32, 10, 5)
     a = ggml.to_numpy(ga)
@@ -281,12 +282,33 @@ def test_numpy_mul_mat(ctx: Ctx) -> None:
     # gy = ggml.ggml_add(ctx, ggml.ggml_mul_mat(ctx, gw, gx), gb)  # (dim_out, seq_len)
     gy = ggml.ggml_mul_mat(ctx, gw, gx)  # (dim_out, seq_len)

-    gf = ggml.ggml_build_forward(gy)
-    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
+    ggml.build_and_compute(ctx, gy)

-    y = ggml.to_numpy(gf.nodes[gf.n_nodes - 1])
+    y = ggml.to_numpy(gy)
     assert np.allclose(y_exp, y)

+@pytest.mark.parametrize("ndim", [2, 3, 4])
+def test_flatten(ctx: Ctx, ndim: int) -> None:
+    shape = [11, 7, 5, 3][:ndim]  # Prime numbers to avoid surprises
+    numel = functools.reduce(lambda a, b: a * b, shape, 1)
+    x = torch.arange(numel, dtype=torch.float32).reshape(shape)
+    for torch_dim in range(ndim - 1):
+        ggml_dim = ndim - 1 - torch_dim
+        n = x.shape[torch_dim + 1]
+
+        gx = ggml.from_numpy(ctx, x)
+        gx1 = ggml.ggml_flatten_1d(ctx, gx, ggml_dim - 1)
+        gy = ggml.ggml_unflatten_1d(ctx, gx1, ggml_dim - 1, n)
+
+        x1 = x.flatten(torch_dim, torch_dim + 1)
+        y = x1.unflatten(torch_dim, (-1, n))
+        assert y.shape == x.shape
+        assert np.allclose(y.numpy(), x.numpy())
+        assert x1.shape == ggml.shape(gx1)
+        assert np.allclose(x1.numpy(), ggml.to_numpy(gx1))
+        assert y.shape == ggml.shape(gy)
+        assert np.allclose(y.numpy(), ggml.to_numpy(gy))
+

 @torch.no_grad()
 def test_torch_spda_vs_ggml_flash_attn(ctx: Ctx) -> None:
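
Note (illustrative, not part of the patch): the `ggml_dim = ndim - 1 - torch_dim` mapping in test_flatten reflects that ggml lists tensor dimensions fastest-varying first (ne[0] is the contiguous axis), i.e. the reverse of the numpy/torch shape order. A minimal numpy-only sketch of that index reversal, assuming a row-major array; names here are for illustration only:

    import numpy as np

    x = np.arange(2 * 3 * 4, dtype=np.float32).reshape(2, 3, 4)  # numpy/torch shape (2, 3, 4)
    ne = x.shape[::-1]  # the same shape in ggml's ne order: (4, 3, 2)

    for torch_dim in range(x.ndim):
        ggml_dim = x.ndim - 1 - torch_dim
        # both indices name the same axis, counted from opposite ends
        assert x.shape[torch_dim] == ne[ggml_dim]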