test_flatten

Guillaume Wenzek 1 year ago
parent
commit b3e6d3c0c7
1 changed file with 26 additions and 4 deletions

ggml/test_ggml_integration.py  +26 -4

@@ -8,6 +8,7 @@ import fairseq2.nn
 import fairseq2.nn.transformer
 import logging
 import sys
+import functools
 from typing import Tuple
 from pathlib import Path
 from ctypes_utils import Ptr
@@ -246,7 +247,7 @@ def test_ggml_slice(ctx: Ctx) -> None:
     assert np.allclose(a[2:5, :], s1)


-@pytest.mark.xfail(reason="not implemented")
+@pytest.mark.xfail(reason="to_numpy not implemented")
 def test_ggml_transpose_and_slice(ctx: Ctx) -> None:
     ga = ggml.ggml_new_tensor_2d(ctx, ggml.GGML_TYPE_F32, 10, 5)
     a = ggml.to_numpy(ga)
@@ -281,12 +282,33 @@ def test_numpy_mul_mat(ctx: Ctx) -> None:
     # gy = ggml.ggml_add(ctx, ggml.ggml_mul_mat(ctx, gw, gx), gb)  # (dim_out, seq_len)
     gy = ggml.ggml_mul_mat(ctx, gw, gx)  # (dim_out, seq_len)

-    gf = ggml.ggml_build_forward(gy)
-    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
+    ggml.build_and_compute(ctx, gy)

-    y = ggml.to_numpy(gf.nodes[gf.n_nodes - 1])
+    y = ggml.to_numpy(gy)
     assert np.allclose(y_exp, y)

+@pytest.mark.parametrize("ndim", [2, 3, 4])
+def test_flatten(ctx: Ctx, ndim: int) -> None:
+    shape = [11, 7, 5, 3][:ndim]  # Prime numbers to avoid surprises
+    numel = functools.reduce(lambda a, b: a * b, shape, 1)
+    x = torch.arange(numel, dtype=torch.float32).reshape(shape)
+    for torch_dim in range(ndim - 1):
+        ggml_dim = ndim - 1 - torch_dim  # ggml orders dims innermost-first, reverse of torch
+        n = x.shape[torch_dim + 1]
+
+        gx = ggml.from_numpy(ctx, x)
+        gx1 = ggml.ggml_flatten_1d(ctx, gx, ggml_dim - 1)
+        gy = ggml.ggml_unflatten_1d(ctx, gx1, ggml_dim - 1, n)
+
+        x1 = x.flatten(torch_dim, torch_dim + 1)
+        y = x1.unflatten(torch_dim, (-1, n))
+        assert y.shape == x.shape
+        assert np.allclose(y.numpy(), x.numpy())
+        assert x1.shape == ggml.shape(gx1)
+        assert np.allclose(x1.numpy(), ggml.to_numpy(gx1))
+        assert y.shape == ggml.shape(gy)
+        assert np.allclose(y.numpy(), ggml.to_numpy(gy))
+

 @torch.no_grad()
 def test_torch_spda_vs_ggml_flash_attn(ctx: Ctx) -> None:
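
The new `ggml.build_and_compute` call replaces the graph boilerplate that the test previously spelled out. A minimal sketch of such a helper, assuming it does nothing more than wrap the two removed calls (`ggml` here is the project's ctypes bindings module; the actual helper in the repository may differ):

    import ctypes

    def build_and_compute(ctx, tensor, n_threads: int = 1) -> None:
        # Build the forward graph that ends at `tensor`, then execute it
        # in the given context. Afterwards `tensor` holds the computed
        # values, so callers can read it back with ggml.to_numpy(tensor).
        gf = ggml.ggml_build_forward(tensor)
        ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), n_threads)

This is also why the test now reads `ggml.to_numpy(gy)` directly instead of digging the result out of `gf.nodes[gf.n_nodes - 1]`.

In `test_flatten`, ggml stores dimensions innermost-first (`ne[0]` is the contiguous one), the reverse of torch, so torch dim `d` corresponds to ggml dim `ndim - 1 - d`. The torch side of the round trip the test checks, runnable on its own:

    import torch

    x = torch.arange(11 * 7 * 5, dtype=torch.float32).reshape(11, 7, 5)
    x1 = x.flatten(0, 1)          # merge dims 0 and 1: (11, 7, 5) -> (77, 5)
    y = x1.unflatten(0, (-1, 7))  # split dim 0 back: (77, 5) -> (11, 7, 5)
    assert y.shape == x.shape
    assert torch.equal(y, x)

`ggml_flatten_1d` and `ggml_unflatten_1d` are expected to produce the same shapes and contents on the ggml side, which is exactly what the test's assertions compare.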