|
@@ -42,13 +42,27 @@ def numpy_dtype(ggml_type: ctypes.c_int) -> np.dtype:
|
|
raise NotImplementedError(f"Can't convert GGML_TYPE({ggml_type}) to a numpy.dtype")
|
|
raise NotImplementedError(f"Can't convert GGML_TYPE({ggml_type}) to a numpy.dtype")
|
|
|
|
|
|
|
|
|
|
|
|
+@functools.lru_cache()
|
|
def from_numpy_dtype(dtype: np.dtype) -> ctypes.c_int:
|
|
def from_numpy_dtype(dtype: np.dtype) -> ctypes.c_int:
|
|
|
|
+
|
|
|
|
+ def _ggml_type(name: bytes, value: int) -> ctypes.c_int:
|
|
|
|
+ t = ctypes.c_int(value)
|
|
|
|
+ type_name = ggml_type_name(t)
|
|
|
|
+ if name != type_name:
|
|
|
|
+ raise RuntimeError(
|
|
|
|
+ f"Type '{name}' doesn't have value {value}. ggml.h was probably updated but not ggml.py"
|
|
|
|
+ )
|
|
|
|
+ return t
|
|
|
|
+
|
|
if dtype == np.float32:
|
|
if dtype == np.float32:
|
|
- return ctypes.c_int(0)
|
|
|
|
- elif dtype == np.int32:
|
|
|
|
- return ctypes.c_int(18)
|
|
|
|
|
|
+ return _ggml_type(b"f32", 0)
|
|
elif dtype == np.float16:
|
|
elif dtype == np.float16:
|
|
- return ctypes.c_int(1)
|
|
|
|
|
|
+ return _ggml_type(b"f16", 1)
|
|
|
|
+ elif dtype == np.dtype("bool"):
|
|
|
|
+ return _ggml_type(b"i8", 16)
|
|
|
|
+ elif dtype == np.int32:
|
|
|
|
+ return _ggml_type(b"i32", 18)
|
|
|
|
+
|
|
raise NotImplementedError(f"Can't convert {dtype} to a GGML_TYPE")
|
|
raise NotImplementedError(f"Can't convert {dtype} to a GGML_TYPE")
|
|
|
|
|
|
|
|
|