Explorar el Código

dtype("bool")

Guillaume Wenzek hace 1 año
padre
commit
c25c280302
Se han modificado 1 ficheros con 18 adiciones y 4 borrados
  1. 18 4
      ggml/ggml.py

+ 18 - 4
ggml/ggml.py

@@ -42,13 +42,27 @@ def numpy_dtype(ggml_type: ctypes.c_int) -> np.dtype:
     raise NotImplementedError(f"Can't convert GGML_TYPE({ggml_type}) to a numpy.dtype")
 
 
+@functools.lru_cache()
 def from_numpy_dtype(dtype: np.dtype) -> ctypes.c_int:
+
+    def _ggml_type(name: bytes, value: int) -> ctypes.c_int:
+        t = ctypes.c_int(value)
+        type_name = ggml_type_name(t)
+        if name != type_name:
+            raise RuntimeError(
+                f"Type '{name}' doesn't have value {value}. ggml.h was probably updated but not ggml.py"
+            )
+        return t
+
     if dtype == np.float32:
-        return ctypes.c_int(0)
-    elif dtype == np.int32:
-        return ctypes.c_int(18)
+        return _ggml_type(b"f32", 0)
     elif dtype == np.float16:
-        return ctypes.c_int(1)
+        return _ggml_type(b"f16", 1)
+    elif dtype == np.dtype("bool"):
+        return _ggml_type(b"i8", 16)
+    elif dtype == np.int32:
+        return _ggml_type(b"i32", 18)
+
     raise NotImplementedError(f"Can't convert {dtype} to a GGML_TYPE")