Guillaume Wenzek преди 2 години
родител
ревизия
06d4ed1475
променени са 4 файла, в които са добавени 210 реда и са изтрити 32 реда
  1. 48 1
      ggml/examples/unity/fairseq2.cpp
  2. 25 11
      ggml/ggml.py
  3. 1 3
      ggml/ggml_convert.py
  4. 136 17
      ggml/test_unity_cpp.py

+ 48 - 1
ggml/examples/unity/fairseq2.cpp

@@ -162,5 +162,52 @@ void MultiheadAttention_init(
     self.bias_v = ggml_new_tensor_3d(model.ctx, GGML_TYPE_F32, num_heads, 1, model_dim / num_heads);
 }
 
+ggml_tensor* reshape_num_head(ggml_context* ctx, ggml_tensor* x, int num_heads) {
+    int slen = x->ne[0];
+    // (S, M) -> (S, K_proj)
+    x = ggml_reshape_3d(ctx, x, slen, num_heads, x->ne[1] / num_heads);
+    // (S, K_proj) -> (H, S, K_h)
+    return ggml_transpose(ctx, x);
+}
+
+
+
+extern "C" ggml_tensor* // (d_in, seq_len)
+MultiheadAttention_forward(
+    fairseq2_model& model,
+    const std::string &prefix,
+    ggml_tensor* queries,  // (d_in, len_q)
+    ggml_tensor* keys,  // (d_in, len_k)
+    ggml_tensor* values,  // (d_out, len_k)
+    ggml_tensor* mask // (seq_len, len_q)
+) {
+    int num_heads = 16;
+    ggml_context* ctx = model.ctx;
+    ggml_tensor* q = Linear_forward(model, prefix + ".q_proj", queries);
+    q = reshape_num_head(ctx, q, num_heads);
+    ggml_tensor* k = Linear_forward(model, prefix + ".k_proj", keys);
+    k = reshape_num_head(ctx, k, num_heads);
+    ggml_tensor* v = Linear_forward(model, prefix + ".q_proj", queries);
+    v = reshape_num_head(ctx, v, num_heads);
+
+    ggml_tensor* attn = ggml_flash_attn(model.ctx, q, k, v, /*masked*/true);
+    attn = Linear_forward(model, prefix + ".output_proj", attn);
+    return attn;
+    // ggml_tensor* attn = SDPA_forward(q, k, v, nullptr);
+    // // (H, S, V_h) -> (S, H, V_h)
+    // attn = ggml_transpose(ctx, attn);
+    // // (S, H, V_h) -> (S, V_proj)
+    // attn = ggml_reshape_3d()
+}
 
-// void TransformerDecoderLayer_init(TransformerDecoderLayer& self);
+// extern "C" ggml_tensor* // (d_out, seq_len)
+// SDPA_forward(
+//     fairseq2_model& model,
+//     const std::string &prefix,
+//     ggml_tensor* queries,  // (d_in, len_q)
+//     ggml_tensor* keys,  // (d_in, len_k)
+//     ggml_tensor* values,  // (d_out, len_k)
+//     ggml_tensor* mask // (seq_len, len_q)
+// ) {
+//     return queries;
+// }

+ 25 - 11
ggml/ggml.py

@@ -44,7 +44,7 @@ def shape(tensor: Union[ggml_tensor, ggml_tensor_p]) -> Tuple[int, ...]:
     if isinstance(tensor, ctypes._Pointer):
         tensor = tensor.contents
     ndims = tensor.n_dims
-    return tuple([tensor.ne[i] for i in range(ndims)])
+    return tuple([tensor.ne[i] for i in range(ndims)[::-1]])
 
 
 def nb(tensor: Union[ggml_tensor, ggml_tensor_p]) -> Tuple[int, ...]:
@@ -70,7 +70,7 @@ def to_numpy(tensor: Union[ggml_tensor, ggml_tensor_p]) -> np.ndarray:
     t_shape = shape(tensor)
 
     # Convert the ggml data pointer to a pointer to ints with the same size (float16 -> uint16)
-    # This is needed because Python ctypes doesn't have "float16", and as_array only works with ctypes pointer
+    # This is needed because Python ctypes doesn't have "float16", and `as_array` only works with ctypes
     type_size = ggml_type_size(tensor.type)
     int_width: type = getattr(ctypes, f"c_uint{8 * type_size}")
     ptr = ctypes.cast(tensor.data, ctypes.POINTER(int_width))
@@ -84,7 +84,7 @@ def to_numpy(tensor: Union[ggml_tensor, ggml_tensor_p]) -> np.ndarray:
     return res
 
 
-GgmlShape = ctypes.c_int64 * GGML_MAX_DIMS
+GgmlNElem = ctypes.c_int64 * GGML_MAX_DIMS
 GgmlNBytes = ctypes.c_uint64 * GGML_MAX_DIMS
 
 
@@ -95,12 +95,15 @@ def from_file(
     return from_numpy(ctx, data)
 
 
-def _pad_shape(shape: Tuple[int, ...]) -> Tuple[int, int, int, int]:
-    if len(shape) >= 4:
-        return shape  # type: ignore
+def _shape_to_ne(shape: Tuple[int, ...]) -> Tuple[int, int, int, int]:
+    # in GGML ne[0] indicates the contiguous dimension, ie the last one in numpy and torch
+    ne = shape[::-1]
+    if len(ne) >= GGML_MAX_DIMS:
+        return   # type: ignore
 
-    padding = (1,) * (4 - len(shape))
-    return shape + padding  # type: ignore
+    # ne is always of the same length
+    padding = (1,) * (GGML_MAX_DIMS - len(ne))
+    return ne + padding  # type: ignore
 
 
 def _compute_nbytes(
@@ -123,9 +126,9 @@ def from_numpy(
     tensor_p = ggml_new_tensor_1d(ctx, gtype, 0)
     # Fill out the correct dimensions and shape.
     tensor_p.contents.n_dims = array.ndim
-    shape = _pad_shape(array.shape)
-    tensor_p.contents.ne = GgmlShape(*shape)
-    tensor_p.contents.nb = GgmlNBytes(*_compute_nbytes(shape, gtype))
+    ne = _shape_to_ne(array.shape)
+    tensor_p.contents.ne = GgmlNElem(*ne)
+    tensor_p.contents.nb = GgmlNBytes(*_compute_nbytes(ne, gtype))
     # point the tensor data to the content of the numpy array.
     tensor_p.contents.data = array.ctypes.data_as(ctypes.c_void_p)
     # print(f"array: {array.shape} @0x{array.ctypes.data_as(ctypes.c_void_p)}")
@@ -136,6 +139,16 @@ def from_numpy(
     return tensor_p
 
 
+def ggml_can_mul_mat(t0: ggml_tensor_p, t1: ggml_tensor_p) -> bool:
+    assert GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"
+
+    return (
+        (t0.contents.ne[0] == t1.contents.ne[0])
+        and (t1.contents.ne[2] % t0.contents.ne[2] == 0)
+        and (t1.contents.ne[3] % t0.contents.ne[3] == 0)
+    )
+
+
 class NativeObj:
     AllocFn = Callable[[], ctypes.c_void_p]
     FreeFn = Callable[[ctypes.c_void_p], None]
@@ -225,6 +238,7 @@ def CppStr(content: str) -> NativeObj:
 
 lib.unity_model_load.argtypes = [ctypes.c_char_p, ctypes.c_void_p, ctypes.c_void_p]
 
+
 def unity_model_load(model_file: Path) -> Tuple[NativeObj, NativeObj]:
     model = UnityModel()
     vocab = GptVocab()

+ 1 - 3
ggml/ggml_convert.py

@@ -50,9 +50,7 @@ def write_ggml_file(
         # Size of each tensor
         byte_size = sum(x.numel() * x.element_size() for x in state_dict.values())
         # + tensor overhead
-        byte_size += ggml.ggml_tensor_overhead() * len(state_dict)
-        # + some slack cause I'm bad at math
-        byte_size = int(byte_size * 1.2)
+        byte_size += ggml.ggml_tensor_overhead() * (len(state_dict) + 10)
         hparams["model_byte_size"] = byte_size
         logging.warning(f"Saving a ggml file with {len(state_dict)} tensors, for an estimated amount of {byte_size / (1024**3)} GGML Gb")
     # 6877961321223123048

+ 136 - 17
ggml/test_unity_cpp.py

@@ -4,6 +4,7 @@ import torch
 import pytest
 import numpy as np
 import torch
+import fairseq2.nn
 from typing import Any
 from pathlib import Path
 from typing import Iterator
@@ -51,16 +52,40 @@ def test_ggml_bindings_work(ctx: Ctx) -> None:
     output = ggml.ggml_get_f32_1d(f, 0)
     assert output == 16.0
 
+def test_ggml_matmul(ctx: Ctx) -> None:
+    # Instantiate tensors
+    a = ggml.ggml_new_tensor_2d(ctx, ggml.GGML_TYPE_F32, 4, 2)
+    x = ggml.ggml_new_tensor_2d(ctx, ggml.GGML_TYPE_F32, 4, 3)
+
+    # Use ggml operations to build a computational graph
+    y = ggml.ggml_mul_mat(ctx, a, x)
+    assert ggml.shape(y) == (3, 2)
+    gf = ggml.ggml_build_forward(y)
+
+    # Set the input values
+    ggml.ggml_set_f32(x, 0.0)
+    for i in range(4 * 3):
+        ggml.ggml_set_f32_1d(x, i, i)
+
+
+    ggml.ggml_set_f32(a, 0.0)
+    ggml.ggml_set_f32_1d(a, 1, 1.0)
+    ggml.ggml_set_f32_1d(a, 7, 1.0)
+    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
+    output = [[ggml.ggml_get_f32_1d(y, j * 2 + i) for j in range(3)] for i in range(2)]
+    assert output == [[1, 5, 9], [3, 7, 11]]
+
 
 def test_shape_works(ctx: Ctx) -> None:
+    """GGML shape order convention is the reverse from numpy"""
     a = ggml.ggml_new_tensor_1d(ctx, ggml.GGML_TYPE_F32, 10)
     assert ggml.shape(a) == (10,)
 
     b = ggml.ggml_new_tensor_2d(ctx, ggml.GGML_TYPE_F32, 11, 21)
-    assert ggml.shape(b) == (11, 21)
+    assert ggml.shape(b) == (21, 11)
 
     c = ggml.ggml_new_tensor_3d(ctx, ggml.GGML_TYPE_F32, 12, 22, 32)
-    assert ggml.shape(c) == (12, 22, 32)
+    assert ggml.shape(c) == (32, 22, 12)
 
 
 def test_nb_works(ctx: Ctx) -> None:
@@ -88,16 +113,43 @@ def test_strides_works(ctx: Ctx) -> None:
 
 def test_to_numpy_works_with_f32(ctx: Ctx) -> None:
     a = ggml.ggml_new_tensor_1d(ctx, ggml.GGML_TYPE_F32, 10)
-    a = ggml.ggml_set_f32(a, 2.14)
-    assert np.allclose(ggml.to_numpy(a), np.ones((10,)) * 2.14)
-
+    na = ggml.to_numpy(a)
+    for i in range(10):
+        ggml.ggml_set_f32_1d(a, i, i)
+    assert na[5] == 5
+    assert np.allclose(na, np.array(range(10), dtype=np.float32))
+    ggml.ggml_set_f32_1d(a, 5, -1.5)
+    assert na[5] == -1.5
+
+    # Note: GGML order of dims is reversed wrt numpy shapes
     b = ggml.ggml_new_tensor_2d(ctx, ggml.GGML_TYPE_F32, 11, 21)
-    b = ggml.ggml_set_f32(b, 2.14)
-    assert np.allclose(ggml.to_numpy(b), np.ones((11, 21)) * 2.14)
+    for i in range(11 * 21):
+        ggml.ggml_set_f32_1d(b, i, i)
+    nb = ggml.to_numpy(b)
+    # assert nb.shape == (21, 11)
+    assert nb[0, 5] == 5
+    assert nb[3, 5] == 11 * 3 + 5
+    assert np.allclose(nb, np.array(range(11 * 21), dtype=np.float32).reshape(ggml.shape(b)))
+    ggml.ggml_set_f32_1d(b, 11 * 3 + 5, -1.5)
+    assert nb[3, 5] == -1.5
+
+    sum_rows = ggml.ggml_sum_rows(ctx, b);
+    gf = ggml.ggml_build_forward(sum_rows)
+    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
+    np_sum_rows = np.sum(nb, axis=-1, keepdims=True)
+    assert np_sum_rows.shape == ggml.shape(sum_rows)
+    for i in range(11):
+        assert np_sum_rows[i] == ggml.ggml_get_f32_1d(sum_rows, i)
 
     c = ggml.ggml_new_tensor_3d(ctx, ggml.GGML_TYPE_F32, 12, 22, 32)
-    c = ggml.ggml_set_f32(c, 2.14)
-    assert np.allclose(ggml.to_numpy(c), np.ones((12, 22, 32)) * 2.14)
+    for i in range(12 * 22 * 32):
+        ggml.ggml_set_f32_1d(c, i, i)
+    nc = ggml.to_numpy(c)
+    assert ggml.shape(c) == (32, 22, 12)
+    assert nc[3, 5, 11] == 22 * 12 * 3 + 12 * 5 + 11
+    assert np.allclose(nc, np.array(range(12 * 22 * 32), dtype=np.float32).reshape(ggml.shape(c)))
+    ggml.ggml_set_f32_1d(c, 22 * 12 * 3 + 12 * 5 + 11, -1.5)
+    assert nc[3, 5, 11] == -1.5
 
 
 def test_from_numpy_works_with_f32(ctx: Ctx) -> None:
@@ -111,7 +163,7 @@ def test_from_numpy_works_with_f32(ctx: Ctx) -> None:
     ga = ggml.from_numpy(ctx, a)
     assert ggml.shape(ga) == (11, 21)
     assert ggml.nb(ga) == ggml.nb(
-        ggml.ggml_new_tensor_2d(ctx, ggml.GGML_TYPE_F32, 11, 21)
+        ggml.ggml_new_tensor_2d(ctx, ggml.GGML_TYPE_F32, *a.shape[::-1])
     )
     assert np.allclose(a, ggml.to_numpy(ga))
 
@@ -119,7 +171,7 @@ def test_from_numpy_works_with_f32(ctx: Ctx) -> None:
     ga = ggml.from_numpy(ctx, a)
     assert ggml.shape(ga) == (12, 22, 32)
     assert ggml.nb(ga) == ggml.nb(
-        ggml.ggml_new_tensor_3d(ctx, ggml.GGML_TYPE_F32, 12, 22, 32)
+        ggml.ggml_new_tensor_3d(ctx, ggml.GGML_TYPE_F32, *a.shape[::-1])
     )
     assert np.allclose(a, ggml.to_numpy(ga))
 
@@ -127,16 +179,25 @@ def test_from_numpy_works_with_f32(ctx: Ctx) -> None:
 def test_to_numpy_works_with_f16(ctx: Ctx) -> None:
     # We explicitly fill the tensor otherwise they might have non-zero values in them.
     a = ggml.ggml_new_tensor_1d(ctx, ggml.GGML_TYPE_F16, 10)
-    a = ggml.ggml_set_f32(a, 2.14)
-    assert np.allclose(ggml.to_numpy(a), np.ones((10,), dtype=np.float16) * 2.14)
+    na = ggml.to_numpy(a)
+    ggml.ggml_set_f32(a, 2.14)
+    assert np.allclose(na, np.ones((10,), dtype=np.float16) * 2.14)
+    ggml.ggml_set_f32(a, 4.28)
+    assert np.allclose(na, np.ones((10,), dtype=np.float16) * 4.28)
 
     b = ggml.ggml_new_tensor_2d(ctx, ggml.GGML_TYPE_F16, 11, 21)
-    b = ggml.ggml_set_f32(b, 4.18)
-    assert np.allclose(ggml.to_numpy(b), np.ones((11, 21), dtype=np.float16) * 4.18)
+    nb = ggml.to_numpy(b)
+    ggml.ggml_set_f32(b, 4.18)
+    assert np.allclose(nb, np.ones((21, 11), dtype=np.float16) * 4.18)
+    ggml.ggml_set_f32(b, 5.12)
+    assert np.allclose(nb, np.ones((21, 11), dtype=np.float16) * 5.12)
 
     c = ggml.ggml_new_tensor_3d(ctx, ggml.GGML_TYPE_F16, 12, 22, 32)
-    c = ggml.ggml_set_f32(c, 3.16)
-    assert np.allclose(ggml.to_numpy(c), np.ones((12, 22, 32), dtype=np.float16) * 3.16)
+    nc = ggml.to_numpy(c)
+    ggml.ggml_set_f32(c, 3.16)
+    assert np.allclose(nc, np.ones((32, 22, 12), dtype=np.float16) * 3.16)
+    ggml.ggml_set_f32(c, 5.08)
+    assert np.allclose(nc, np.ones((32, 22, 12), dtype=np.float16) * 5.08)
 
 
 def test_from_numpy_works_with_f16(ctx: Ctx) -> None:
@@ -152,6 +213,7 @@ def test_from_numpy_works_with_f16(ctx: Ctx) -> None:
 
 
 def test_ning_model_load(ctx: Ctx) -> None:
+    pytest.skip("borken")
     model, vocab = ggml.unity_model_load(UNITY_MODELS / "unity-large/ggml-model.bin")
     print(model, vocab)
 
@@ -204,6 +266,34 @@ def test_hparams_code_is_up_to_date() -> None:
     assert hparams_struct in actual_code
 
 
+def test_forward_linear(ctx: Ctx) -> None:
+    slen, d_in, d_out = (5, 4, 2)
+    # torch.nn and fairseq2.nn assumes (seq_len, dim) to represent inputs,
+    x = np.zeros((slen, d_in), dtype=np.float32)  # (seq_len, dim_in)
+    # torch.nn.init.uniform_(x, -1, 1)
+    x[0, :] = [1, 1/3, 0, 0]
+
+    # linear = fairseq2.nn.Linear(d_in, d_out, bias=False)
+    weight = np.eye(d_out, d_in, dtype=np.float32)
+    weight[1, 1] = 1
+    # assert weight.shape == (d_out, d_in) # (dim_out, dim_in)
+    y_exp = (x @ weight.T)  # (seq_len, dim_out)
+
+    gx = ggml.from_numpy(ctx, x)  # (dim_in, seq_len)
+    gw = ggml.from_numpy(ctx, weight)  # (dim_in, dim_out)
+    # gb = ggml.from_numpy(ctx, linear.bias.numpy())  # (dim_out)
+    # GGML linear impl
+    assert ggml.ggml_can_mul_mat(gw, gx)
+    # gy = ggml.ggml_add(ctx, ggml.ggml_mul_mat(ctx, gw, gx), gb)  # (dim_out, seq_len)
+    gy = ggml.ggml_mul_mat(ctx, gw, gx)  # (dim_out, seq_len)
+
+    gf = ggml.ggml_build_forward(gy)
+    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
+
+    y = ggml.to_numpy(gf.nodes[gf.n_nodes - 1])
+    assert np.allclose(y_exp, y)
+
+
 def test_forward_ffn(ctx: Ctx, g_model: NativeObj, pt_model: Any) -> None:
     x = torch.empty((1024))
     torch.nn.init.uniform_(x, -1, 1)
@@ -236,3 +326,32 @@ def test_forward_layer_norm(ctx: Ctx, g_model: NativeObj, pt_model: Any) -> None
     y = ggml.to_numpy(gf.nodes[gf.n_nodes - 1]).reshape(-1)
     abs_diff = np.max(np.abs(y - y_exp))
     assert np.allclose(y_exp, y)
+
+
+def test_forward_self_attn(ctx: Ctx, g_model: NativeObj, pt_model: Any) -> None:
+    x = torch.empty((1, 25, 1024))
+
+    torch.nn.init.uniform_(x, -1, 1)
+
+    self_attn = pt_model.text_encoder.layers[0].self_attn
+    # Replace spda by just returning queries
+    # TODO: implement spda
+    self_attn.spda = lambda *qkv, **kwargs: qkv[0]
+
+    y_exp = self_attn(x, None, x, x).numpy()
+    gx = ggml.from_numpy(ctx, x)
+    gy = ggml.forward(
+        "MultiheadAttention",
+        g_model,
+        "text_encoder.layers.0.self_attn",
+        gx,
+        gx,
+        gx,
+        None,
+    )
+    gf = ggml.ggml_build_forward(gy)
+    ggml.ggml_graph_compute_with_ctx(ctx, ctypes.pointer(gf), 1)
+
+    y = ggml.to_numpy(gf.nodes[gf.n_nodes - 1]).reshape(-1)
+    abs_diff = np.max(np.abs(y - y_exp))
+    assert np.allclose(y_exp, y)