Browse Source

use same convention than ggml.c

Guillaume Wenzek 1 year ago
parent
commit
9021fad301
3 changed files with 252 additions and 22 deletions
  1. 11 9
      ggml/examples/unity/unity.cpp
  2. 126 7
      ggml/ggml.py
  3. 115 6
      ggml/test_unity_cpp.py

+ 11 - 9
ggml/examples/unity/unity.cpp

@@ -111,19 +111,19 @@ struct unity_model {
     std::map<std::string, struct ggml_tensor *> tensors;
 };
 
-extern "C" unity_model* alloc_unity_model() {
+extern "C" unity_model* unity_model_alloc() {
     return new unity_model;
 }
 
-extern "C" void free_unity_model(unity_model* model) {
+extern "C" void unity_model_free(unity_model* model) {
     delete model;
 }
 
-extern "C" gpt_vocab* alloc_gpt_vocab() {
+extern "C" gpt_vocab* gpt_vocab_alloc() {
     return new gpt_vocab;
 }
 
-extern "C" void free_gpt_vocab(gpt_vocab* vocab) {
+extern "C" void gpt_vocab_free(gpt_vocab* vocab) {
     delete vocab;
 }
 
@@ -469,7 +469,7 @@ extern "C" bool unity_model_load(const char* fname, unity_model& model, gpt_voca
 }
 
 // build the computation graph
-struct ggml_cgraph * unity_graph(
+extern "C" struct ggml_cgraph * unity_graph(
         const unity_model & model,
         struct ggml_allocr * allocr) {
 
@@ -603,12 +603,12 @@ struct ggml_cgraph * unity_graph(
     return gf;
 }
 
-extern "C" bool unity_eval(
+extern "C" struct ggml_cgraph * unity_eval(
         const unity_model & model,
         struct ggml_allocr * allocr,
         const int n_threads) {
 
-    const auto & hparams = model.hparams;
+    // const auto & hparams = model.hparams;
 
     // reset the allocator to free all the memory allocated during the previous inference
     ggml_allocr_reset(allocr);
@@ -627,11 +627,13 @@ extern "C" bool unity_eval(
 
     // in this case, the output tensor is the last one in the graph
     struct ggml_tensor * inpL = gf->nodes[gf->n_nodes - 1];
+    printf("gf: %p, gf.nodes: %p, gf.n_nodes: %p", (void *)gf, (void *)gf->nodes, (void *)&(gf->n_nodes));
     for (int i = 0; i < 10; ++i) {
         printf("%8.4f ", ((float *)(inpL->data))[i]);
     }
+    printf("\n");
 
-    return true;
+    return gf;
 }
 
 int main(int argc, char ** argv) {
@@ -661,7 +663,7 @@ int main(int argc, char ** argv) {
 
     // load the model
     {
-        if (!unity_model_load(params.model.c_str(), &model, &vocab)) {
+        if (!unity_model_load(params.model.c_str(), model, vocab)) {
             fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str());
             return 1;
         }

+ 126 - 7
ggml/ggml.py

@@ -54,13 +54,21 @@ import sys
 import ctypes
 import pathlib
 import importlib.resources
+import numpy as np
+from typing import Union
+from typing import Type
+from typing import Callable
 from typing import Tuple
 from typing import Dict
+from typing import Self
 from typing import Any
 from pathlib import Path
 from typing import List, Optional, Sequence, Union
 from typing_extensions import TypeAlias
 
+NULL: ctypes.c_void_p = None  # ignore: type
+GGML_MEM_ALIGN = 16
+
 
 # Load the library
 def load_shared_library(base_path: Path, lib_base_name: str):
@@ -8045,36 +8053,116 @@ if GGML_USE_CLBLAST:
     ]
     lib.ggml_cl_transform_tensor.restype = None
 
+### Helpers
+
+
+def numpy_dtype(ggml_type: ctypes.c_int) -> Type:
+    if ggml_type == 0:
+        # GGML_TYPE_F32  = 0,
+        return np.float32
+
+    if ggml_type == 1:
+        # GGML_TYPE_F16  = 1,
+        return np.float16
+
+    raise NotImplementedError(f"Can't convert GGML_TYPE({ggml_type}) to a numpy.dtype")
+
+
+def from_numpy_dtype(dtype: np.dtype) -> ctypes.c_int:
+    if dtype == np.float32:
+        return ctypes.c_int(0)
+    elif dtype == np.float16:
+        return ctypes.c_int(1)
+    raise NotImplementedError(f"Can't convert {dtype} to a GGML_TYPE")
+
+
+def shape(tensor: Union[ggml_tensor, ggml_tensor_p]) -> Tuple[int, ...]:
+    if isinstance(tensor, ctypes._Pointer):
+        tensor = tensor.contents
+    ndims = tensor.n_dims
+    return tuple([tensor.ne[i] for i in range(ndims)])
+
+
+def strides(tensor: Union[ggml_tensor, ggml_tensor_p]) -> Tuple[int, ...]:
+    if isinstance(tensor, ctypes._Pointer):
+        tensor = tensor.contents
+    ndims = tensor.n_dims
+    return tuple([tensor.nb[i] for i in range(ndims)])
+
+
+def to_numpy(tensor: Union[ggml_tensor, ggml_tensor_p]) -> np.ndarray:
+    if isinstance(tensor, ctypes._Pointer):
+        tensor = tensor.contents
+
+    t_shape = shape(tensor)
+
+    # Convert the ggml data pointer to a pointer to ints with the same size (float16 -> uint16)
+    # This is needed because Python ctypes doesn't have "float16", and as_array only works with ctypes pointer
+    type_size = ggml_type_size(tensor.type)
+    int_width: Type[Any] = getattr(ctypes, f"c_uint{8 * type_size}")
+    ptr = ctypes.cast(tensor.data, ctypes.POINTER(int_width))
+    # Create a numpy array with the wrong dtype
+    int_arr = np.ctypeslib.as_array(ptr, shape=t_shape)
+    # Reinterpret it to the right dtype
+    res = np.frombuffer(int_arr, dtype=numpy_dtype(tensor.type)).reshape(t_shape)
+
+    # TODO: assert strides / check contiguous
+    # assert strides(tensor) == res.strides, "TODO: support strided tensor"
+    return res
+
+
+GgmlShape = ctypes.c_int64 * GGML_MAX_DIMS
+
+
+def from_numpy(ctx: ggml_context_p, array: np.ndarray) -> ggml_tensor_p:
+    tensor_p = ggml_new_tensor(
+        ctx, from_numpy_dtype(array.dtype), 1, GgmlShape(0, 0, 0, 0)
+    )
+    tensor_p.contents.n_dims = array.ndim
+    tensor_p.contents.data = array.ctypes.data_as(ctypes.c_void_p)
+    tensor_p.contents.ne = GgmlShape(*array.shape)
+    # print(f"array: {array.shape} @0x{array.ctypes.data_as(ctypes.c_void_p)}")
+    # print(f"tensor_p: {shape(tensor_p)} @0x{tensor_p.contents.data:x}")
+    return tensor_p
+
 
 class NativeObj:
-    _cache: Dict[str, Any] = {}
+    AllocFn = Callable[[], ctypes.c_void_p]
+    FreeFn = Callable[[ctypes.c_void_p], None]
+    _cache: Dict[str, Tuple[AllocFn, FreeFn]] = {}
 
     @classmethod
-    def _init_c_func(cls, kind: str) -> Any:
+    def _init_c_func(cls, kind: str) -> Tuple[AllocFn, FreeFn]:
         if kind in cls._cache:
             return cls._cache[kind]
 
-        alloc_fn = getattr(lib, f"alloc_{kind}")
+        alloc_fn = getattr(lib, f"{kind}_alloc")
         alloc_fn.argtypes = []
         alloc_fn.restype = ctypes.c_void_p
 
-        free_fn = getattr(lib, f"free_{kind}")
+        free_fn = getattr(lib, f"{kind}_free")
         free_fn.argtypes = [ctypes.c_void_p]
         free_fn.restype = None
 
         cls._cache[kind] = (alloc_fn, free_fn)
         return (alloc_fn, free_fn)
 
-    def __init__(self, kind: str):
+    def __init__(self, kind: str, ptr: ctypes.c_void_p = NULL):
         self.kind = kind
         alloc_fn, self._free_fn = self._init_c_func(kind)
-        self.ptr = alloc_fn()
+        self.ptr = alloc_fn() if ptr is None else ptr
         print(self)
 
     def free(self) -> None:
         if self.ptr is not None:
             self._free_fn(self.ptr)
-            self.ptr = None
+            self.ptr = NULL
+
+    def __enter__(self) -> Self:
+        return self
+
+    def __exit__(self, *args: Any) -> None:
+        self.free()
 
     def __del__(self) -> None:
         self.free()
@@ -8083,6 +8171,9 @@ class NativeObj:
         return f"<{self.kind} native object at 0x{self.ptr:x}>"
 
 
+### unity.cpp stuff
+
+
 def UnityModel() -> NativeObj:
     return NativeObj("unity_model")
 
@@ -8091,6 +8182,18 @@ def GptVocab() -> NativeObj:
     return NativeObj("gpt_vocab")
 
 
+def MeasureArena() -> NativeObj:
+    return NativeObj("ggml_allocr", ggml_allocr_new_measure(GGML_MEM_ALIGN))
+
+
+def FixedSizeArena(mem_size: int) -> NativeObj:
+    memory = np.zeros(mem_size, dtype=np.uint8)
+    allocr = ggml_allocr_new(
+        memory.ctypes.data_as(ctypes.POINTER(ctypes.c_byte)), mem_size, GGML_MEM_ALIGN
+    )
+    return NativeObj("ggml_allocr", allocr)
+
+
 lib.unity_model_load.argtypes = [ctypes.c_char_p, ctypes.c_void_p, ctypes.c_void_p]
 
 
@@ -8103,3 +8206,19 @@ def unity_model_load(model_file: Path) -> Tuple[NativeObj, NativeObj]:
         vocab.ptr,
     )
     return model, vocab
+
+
+lib.unity_graph.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
+lib.unity_graph.restype = ctypes.POINTER(ggml_cgraph)
+
+
+def unity_graph(model: NativeObj, allocr: NativeObj) -> ggml_cgraph_p:
+    return lib.unity_graph(model.ptr, allocr.ptr)  # type: ignore
+
+
+lib.unity_eval.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
+lib.unity_eval.restype = ctypes.POINTER(ggml_cgraph)
+
+
+def unity_eval(model: NativeObj, allocr: NativeObj, n_threads: int) -> ggml_cgraph_p:
+    return lib.unity_eval(model.ptr, allocr.ptr, n_threads)

+ 115 - 6
ggml/test_unity_cpp.py

@@ -1,12 +1,26 @@
 import ggml
 import ctypes
+import torch
+import pytest
+import numpy as np
+from typing import Iterator
+from ggml import NativeObj
 
+Ctx = ggml.ggml_context_p
 
-def test_ggml_bindings_work() -> None:
-    # Allocate a new context with 16 MB of memory
-    params = ggml.ggml_init_params(mem_size=16 * 1024 * 1024, mem_buffer=None)
-    ctx = ggml.ggml_init(params=params)
+PARAMS_16MB = ggml.ggml_init_params(mem_size=16 * 1024 * 1024, mem_buffer=None)
 
+@pytest.fixture(name="ctx")
+def _ctx() -> Iterator[Ctx]:
+    """Allocate a new context with 16 MB of memory"""
+    try:
+        ctx = ggml.ggml_init(params=PARAMS_16MB)
+        yield ctx
+    finally:
+        ggml.ggml_free(ctx)
+
+
+def test_ggml_bindings_work(ctx: Ctx) -> None:
     # Instantiate tensors
     x = ggml.ggml_new_tensor_1d(ctx, ggml.GGML_TYPE_F32, 1)
     a = ggml.ggml_new_tensor_1d(ctx, ggml.GGML_TYPE_F32, 1)
@@ -30,8 +44,79 @@ def test_ggml_bindings_work() -> None:
     output = ggml.ggml_get_f32_1d(f, 0)
     assert output == 16.0
 
-    # Free the context
-    ggml.ggml_free(ctx)
+
+def test_shape_works(ctx: Ctx) -> None:
+    a = ggml.ggml_new_tensor_1d(ctx, ggml.GGML_TYPE_F32, 10)
+    assert ggml.shape(a) == (10,)
+
+    b = ggml.ggml_new_tensor_2d(ctx, ggml.GGML_TYPE_F32, 11, 21)
+    assert ggml.shape(b) == (11, 21)
+
+    c = ggml.ggml_new_tensor_3d(ctx, ggml.GGML_TYPE_F32, 12, 22, 32)
+    assert ggml.shape(c) == (12, 22, 32)
+
+
+@pytest.mark.xfail(
+    reason="TODO: understand diff between ggml strides and numpy strides"
+)
+def test_strides_works(ctx: Ctx) -> None:
+    a = ggml.ggml_new_tensor_1d(ctx, ggml.GGML_TYPE_F32, 10)
+    assert ggml.strides(a) == np.ones((10,), dtype=np.float32).strides
+
+    b = ggml.ggml_new_tensor_2d(ctx, ggml.GGML_TYPE_F32, 11, 21)
+    assert ggml.strides(b) == np.ones((11, 21), dtype=np.float32).strides
+
+    c = ggml.ggml_new_tensor_3d(ctx, ggml.GGML_TYPE_F32, 12, 22, 32)
+    assert ggml.strides(c) == np.ones((12, 22, 32), dtype=np.float32).strides
+
+
+def test_to_numpy_works_with_f32(ctx: Ctx) -> None:
+    a = ggml.ggml_new_tensor_1d(ctx, ggml.GGML_TYPE_F32, 10)
+    a = ggml.ggml_set_f32(a, 2.14)
+    assert np.allclose(ggml.to_numpy(a), np.ones((10,)) * 2.14)
+    b = ggml.ggml_new_tensor_2d(ctx, ggml.GGML_TYPE_F32, 11, 21)
+    assert np.allclose(ggml.to_numpy(b), np.zeros((11, 21)))
+    c = ggml.ggml_new_tensor_3d(ctx, ggml.GGML_TYPE_F32, 12, 22, 32)
+    assert np.allclose(ggml.to_numpy(c), np.zeros((12, 22, 32)))
+
+
+def test_from_numpy_works_with_f32(ctx: Ctx) -> None:
+    a = np.random.normal(size=(10,)).astype(dtype=np.float32)
+    ga = ggml.from_numpy(ctx, a)
+    assert np.allclose(a, ggml.to_numpy(ga))
+    a = np.random.normal(size=(11, 21)).astype(dtype=np.float32)
+    ga = ggml.from_numpy(ctx, a)
+    assert np.allclose(a, ggml.to_numpy(ga))
+    a = np.random.normal(size=(12, 22, 32)).astype(dtype=np.float32)
+    ga = ggml.from_numpy(ctx, a)
+    assert np.allclose(a, ggml.to_numpy(ga))
+
+
+def test_to_numpy_works_with_f16(ctx: Ctx) -> None:
+    # We explicitly fill the tensor otherwise they might have non-zero values in them.
+    a = ggml.ggml_new_tensor_1d(ctx, ggml.GGML_TYPE_F16, 10)
+    a = ggml.ggml_set_f32(a, 2.14)
+    assert np.allclose(ggml.to_numpy(a), np.ones((10,), dtype=np.float16) * 2.14)
+
+    b = ggml.ggml_new_tensor_2d(ctx, ggml.GGML_TYPE_F16, 11, 21)
+    b = ggml.ggml_set_f32(b, 4.18)
+    assert np.allclose(ggml.to_numpy(b), np.ones((11, 21), dtype=np.float16) * 4.18)
+
+    c = ggml.ggml_new_tensor_3d(ctx, ggml.GGML_TYPE_F16, 12, 22, 32)
+    c = ggml.ggml_set_f32(c, 3.16)
+    assert np.allclose(ggml.to_numpy(c), np.ones((12, 22, 32), dtype=np.float16) * 3.16)
+
+
+def test_from_numpy_works_with_f16(ctx: Ctx) -> None:
+    a = np.random.normal(size=(10,)).astype(dtype=np.float16)
+    ga = ggml.from_numpy(ctx, a)
+    assert np.allclose(a, ggml.to_numpy(ga))
+    a = np.random.normal(size=(11, 21)).astype(dtype=np.float16)
+    ga = ggml.from_numpy(ctx, a)
+    assert np.allclose(a, ggml.to_numpy(ga))
+    a = np.random.normal(size=(12, 22, 32)).astype(dtype=np.float16)
+    ga = ggml.from_numpy(ctx, a)
+    assert np.allclose(a, ggml.to_numpy(ga))
 
 
 def test_unity_model_load() -> None:
@@ -39,3 +124,27 @@ def test_unity_model_load() -> None:
         "examples/unity/models/unity-large/ggml-model.bin"
     )
     print(model, vocab)
+    with ggml.MeasureArena() as arena:
+        # compute graph
+        graph = ggml.unity_graph(model, arena)
+        # required memory
+        # TODO: why the extra padding ?
+        mem_size = ggml.ggml_allocr_alloc_graph(arena.ptr, graph) + ggml.GGML_MEM_ALIGN
+
+    compute_buffer = torch.zeros(mem_size, dtype=torch.uint8)
+    allocr = NativeObj(
+        "ggml_allocr",
+        ggml.ggml_allocr_new(compute_buffer.data_ptr(), mem_size, ggml.GGML_MEM_ALIGN),
+    )
+    print(
+        f"unity_graph: compute buffer size: {mem_size/1024/1024} MB  @0x{compute_buffer.data_ptr():x}"
+    )
+
+    eval_res_ptr = ggml.unity_eval(model, allocr, 1)
+    eval_res = eval_res_ptr.contents
+    inpL = ggml.to_numpy(eval_res.nodes[eval_res.n_nodes - 1])
+    expected_raw = (
+        "-0.1308,0.0346,-0.2656,0.2873,-0.0104,0.0574,0.4033,-0.1125,-0.0460,-0.0496"
+    )
+    expected = map(float, expected_raw.split(","))
+    assert np.allclose(inpL[0, :10], list(expected), atol=1e-4)