  1. """
  2. We are vendoring https://github.com/abetlen/ggml-python (MIT License)
  3. adding a few utilities to convert between ggml and numpy tensors for testing.
  4. """
  5. import numpy as np
  6. import ctypes
  7. import torch
  8. import functools
  9. from pathlib import Path
  10. from typing import Dict
  11. from typing import Callable
  12. from typing import Any
  13. from typing import Tuple
  14. from typing import Union
  15. from typing import Type
  16. from third_party_ggml import *
  17. from ctypes_utils import c_struct, c_fn, Ptr
  18. ### Helpers


@functools.lru_cache(4)
def numpy_dtype(ggml_type: ctypes.c_int) -> np.dtype:
    if ggml_type == 0:
        # GGML_TYPE_F32 = 0
        return np.dtype(np.float32)
    if ggml_type == 1:
        # GGML_TYPE_F16 = 1
        return np.dtype(np.float16)
    if ggml_type == 18:
        # GGML_TYPE_I32 = 18
        return np.dtype(np.int32)
    raise NotImplementedError(f"Can't convert GGML_TYPE({ggml_type}) to a numpy.dtype")


def from_numpy_dtype(dtype: np.dtype) -> ctypes.c_int:
    if dtype == np.float32:
        return ctypes.c_int(0)
    elif dtype == np.int32:
        return ctypes.c_int(18)
    elif dtype == np.float16:
        return ctypes.c_int(1)
    raise NotImplementedError(f"Can't convert {dtype} to a GGML_TYPE")


def shape(tensor: Union[ggml_tensor, ggml_tensor_p]) -> Tuple[int, ...]:
    if isinstance(tensor, ctypes._Pointer):
        tensor = tensor.contents
    ndims = tensor.n_dims
    # ggml stores ne[0] as the contiguous (innermost) dimension, so reverse to
    # get a numpy/torch-style shape.
    return tuple([tensor.ne[i] for i in range(ndims)[::-1]])


def nb(tensor: Union[ggml_tensor, ggml_tensor_p]) -> Tuple[int, ...]:
    if isinstance(tensor, ctypes._Pointer):
        tensor = tensor.contents
    return tuple([tensor.nb[i] for i in range(4)])


def strides(tensor: Union[ggml_tensor, ggml_tensor_p]) -> Tuple[int, ...]:
    if isinstance(tensor, ctypes._Pointer):
        tensor = tensor.contents
    ndims = tensor.n_dims
    num_bytes = tuple([tensor.nb[i] for i in range(ndims)])
    strides = num_bytes[::-1]
    return strides
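
# For a float32 tensor built from a C-contiguous numpy array of shape (2, 3),
# ggml stores ne = (3, 2, 1, 1) and nb = (4, 12, 24, 24), so shape() returns
# (2, 3) and strides() returns (12, 4), matching numpy's byte strides.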


def to_numpy(tensor_p: ggml_tensor_p) -> np.ndarray:
    if not ggml_is_contiguous(tensor_p):
        return _strided_to_numpy(tensor_p)
    tensor = tensor_p.contents
    res = _void_p_to_np_array(tensor.data, shape(tensor), numpy_dtype(tensor.type))
    if ggml_is_transposed(tensor_p):
        # Patch up strides to work with transposed ggml_tensor
        res.strides = strides(tensor)  # type: ignore[assignment]
    return res
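
# Round-trip sketch (assuming `ggml_init` / `ggml_init_params` are re-exported
# by third_party_ggml, as in upstream ggml-python; the buffer size is arbitrary):
#
#   params = ggml_init_params(mem_size=16 * 1024 * 1024, mem_buffer=None)
#   ctx = ggml_init(params)
#   x = from_numpy(ctx, np.arange(6, dtype=np.float32).reshape(2, 3))
#   assert to_numpy(x).shape == (2, 3)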


def _strided_to_numpy(tensor_p: ggml_tensor_p) -> np.ndarray:
    if ggml_is_transposed(tensor_p):
        raise NotImplementedError(
            "to_numpy doesn't support tensors both transposed and strided."
        )

    tensor = tensor_p.contents
    n_dim = tensor.n_dims
    t_shape = shape(tensor)
    t_strides = strides(tensor)
    type_size = ggml_type_size(tensor.type)
    full_shape = []
    num_bytes = nb(tensor)

    # Determine the full backing slice of bytes to read.
    # TODO make this work for transposed array
    n = 1
    total_elements = 1
    for d in range(n_dim - 1):
        n = num_bytes[d + 1] // type_size // n
        full_shape.append(n)
        total_elements *= n
    # We don't need to guess for the first dimension, since this doesn't impact striding.
    full_shape.append(t_shape[0])
    total_elements *= t_shape[0]
    full_shape = full_shape[::-1]

    res = _void_p_to_np_array(tensor.data, tuple(full_shape), numpy_dtype(tensor.type))
    # Extract the correct slice (tuple indexing keeps this valid on Python < 3.11,
    # where the `res[*...]` unpacking syntax isn't available).
    res = res[tuple(slice(0, n) for n in t_shape)]
    # TODO: we could handle transposition here
    return res


def _void_p_to_np_array(
    data: ctypes.c_void_p, shape: Tuple[int, ...], dtype: np.dtype
) -> np.ndarray:
    # Convert the ggml data pointer to a pointer of bytes.
    # This is needed because Python ctypes doesn't have "float16",
    # and `as_array` only works with ctypes types.
    int_width: type = getattr(ctypes, f"c_uint{8 * dtype.itemsize}")
    ptr = ctypes.cast(data, ctypes.POINTER(int_width))
    # Create a numpy array with the wrong dtype
    int_arr = np.ctypeslib.as_array(ptr, shape=shape)
    # Reinterpret it to the right dtype
    return np.frombuffer(int_arr, dtype=dtype).reshape(shape)


GgmlNElem = ctypes.c_int64 * GGML_MAX_DIMS
GgmlNBytes = ctypes.c_uint64 * GGML_MAX_DIMS


def from_file(
    ctx: ggml_context_p, file: Path, shape: Tuple[int, ...], dtype: type = np.float32
) -> ggml_tensor_p:
    data = np.fromfile(str(file), dtype=dtype).reshape(shape)  # type: ignore
    return from_numpy(ctx, data)


def _shape_to_ne(shape: Tuple[int, ...]) -> Tuple[int, int, int, int]:
    # in GGML ne[0] indicates the contiguous dimension, ie the last one in numpy and torch
    ne = shape[::-1]
    if len(ne) >= GGML_MAX_DIMS:
        return ne  # type: ignore
    # ne is always of the same length
    padding = (1,) * (GGML_MAX_DIMS - len(ne))
    return ne + padding  # type: ignore
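
# For example, a numpy shape (2, 3) maps to ne = (3, 2, 1, 1): the contiguous
# numpy axis becomes ne[0], and the remaining slots are padded with 1.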


def _compute_nbytes(
    ne: Tuple[int, int, int, int], type: ctypes.c_int
) -> Tuple[int, int, int, int]:
    nb0 = ggml_type_size(type)
    nb1 = nb0 * (ne[0] // ggml_blck_size(type))
    nb2 = nb1 * ne[1]
    nb3 = nb2 * ne[2]
    return (nb0, nb1, nb2, nb3)
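
# For example, a contiguous float32 tensor with ne = (3, 2, 1, 1) (type size 4,
# block size 1) gets nb = (4, 12, 24, 24) bytes per dimension.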


def from_numpy(
    ctx: ggml_context_p, array: Union[np.ndarray, "torch.Tensor"]
) -> ggml_tensor_p:
    if type(array).__name__ == "Tensor":
        array = array.numpy()
    # Create an empty tensor so we don't allocate memory for the data pointer
    gtype = from_numpy_dtype(array.dtype)
    tensor_p = ggml_new_tensor_1d(ctx, gtype, 0)
    # Fill out the correct dimensions and shape.
    tensor_p.contents.n_dims = array.ndim
    ne = _shape_to_ne(array.shape)
    tensor_p.contents.ne = GgmlNElem(*ne)
    tensor_p.contents.nb = GgmlNBytes(*_compute_nbytes(ne, gtype))
    # Point the tensor data to the content of the numpy array.
    tensor_p.contents.data = array.ctypes.data_as(ctypes.c_void_p)
    # print(f"array: {array.shape} @0x{array.ctypes.data_as(ctypes.c_void_p)}")
    # print(f"tensor_p: {shape(tensor_p)} @0x{tensor_p.contents.data:x}")
    # Prevent the underlying numpy array from being freed.
    setattr(tensor_p, "__data", array)
    return tensor_p
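
# Note: the conversion is zero-copy. The returned tensor aliases the array's
# memory, so writes through one are visible through the other, e.g.:
#
#   t = from_numpy(ctx, np.ones((4,), dtype=np.float32))
#   to_numpy(t)[0] = 5.0  # also updates the ggml tensor's data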


def ggml_can_mul_mat(t0: ggml_tensor_p, t1: ggml_tensor_p) -> bool:
    assert GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"
    return (
        (t0.contents.ne[0] == t1.contents.ne[0])
        and (t1.contents.ne[2] % t0.contents.ne[2] == 0)
        and (t1.contents.ne[3] % t0.contents.ne[3] == 0)
    )


class NativeObj:
    AllocFn = Callable[[], ctypes.c_void_p]
    FreeFn = Callable[[ctypes.c_void_p], None]
    _cache: Dict[str, Tuple[AllocFn, FreeFn]] = {}

    @classmethod
    def _init_c_func(cls, kind: str) -> Tuple[AllocFn, FreeFn]:
        if kind in cls._cache:
            return cls._cache[kind]

        alloc_fn = getattr(lib, f"{kind}_alloc")
        alloc_fn.argtypes = []
        alloc_fn.restype = ctypes.c_void_p

        free_fn = getattr(lib, f"{kind}_free")
        free_fn.argtypes = [ctypes.c_void_p]
        free_fn.restype = None

        cls._cache[kind] = (alloc_fn, free_fn)
        return (alloc_fn, free_fn)

    def __init__(self, kind: str, ptr: ctypes.c_void_p = NULL):
        self.kind = kind
        alloc_fn, self._free_fn = self._init_c_func(kind)
        self.ptr = alloc_fn() if ptr is None else ptr
        # print(self)

    def free(self) -> None:
        if self.ptr is not None:
            self._free_fn(self.ptr)
            # print(f"freeing {self}")
            self.ptr = NULL

    def __enter__(self) -> ctypes.c_void_p:
        return self.ptr

    def __exit__(self, *args: Any) -> None:
        self.free()

    def __del__(self) -> None:
        self.free()

    def __repr__(self) -> str:
        return f"<{self.kind} native object at 0x{self.ptr:x}>"
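
# NativeObj pairs `<kind>_alloc` / `<kind>_free` symbols exported by `lib` and
# releases the pointer on `__exit__` or garbage collection, e.g.:
#
#   with NativeObj("fairseq2_model") as ptr:
#       ...  # ptr is a ctypes void pointer, freed when the block exits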


def MeasureArena() -> NativeObj:
    return NativeObj("ggml_allocr", ggml_allocr_new_measure(GGML_MEM_ALIGN))


def FixedSizeArena(mem_size: int) -> NativeObj:
    memory = torch.zeros(mem_size, dtype=torch.uint8)
    allocr = ggml_allocr_new(
        ctypes.c_void_p(memory.data_ptr()), mem_size, GGML_MEM_ALIGN
    )
    arena = NativeObj("ggml_allocr", allocr)
    # Add a reference from the arena object to the underlying tensor, otherwise it will be freed too early.
    setattr(arena, "__memory", memory)
    return arena


lib.fairseq2_model_set_inference_ctx.argtypes = [ctypes.c_void_p, ggml_context_p]


def Fairseq2Model() -> NativeObj:
    return NativeObj("fairseq2_model")


lib.std_string_alloc.argtypes = [ctypes.c_char_p]
lib.std_string_alloc.restype = ctypes.c_void_p
lib.std_string_free.argtypes = [ctypes.c_void_p]
lib.std_string_free.restype = None
NativeObj._cache["std_string"] = (lib.std_string_alloc, lib.std_string_free)


@functools.lru_cache(1024)
def CppStr(content: str) -> NativeObj:
    c_str = ctypes.create_string_buffer(content.encode("utf-8"))
    cpp_str = lib.std_string_alloc(c_str)
    return NativeObj("std_string", cpp_str)


lib.load_unity_ggml_file.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
lib.load_unity_ggml_file.restype = ctypes.c_int


def load_unity_ggml_file(model_file: Path) -> NativeObj:
    model = Fairseq2Model()
    bytes_file = ctypes.create_string_buffer(str(model_file).encode("utf-8"))
    err = lib.load_unity_ggml_file(model.ptr, bytes_file)
    if err:
        raise Exception("Failed to load model")
    return model


# lib.unity_audio_encoder_graph.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
# lib.unity_audio_encoder_graph.restype = ctypes.POINTER(ggml_cgraph)


# def unity_audio_encoder_graph(model: NativeObj, tensor: ggml_tensor_p) -> ggml_cgraph_p:
#     return lib.unity_audio_encoder_graph(model.ptr, tensor)  # type: ignore


# lib.unity_eval.argtypes = [
#     ctypes.c_void_p,
#     ctypes.c_void_p,
#     ctypes.POINTER(ggml_tensor),
#     ctypes.c_int,
# ]
# lib.unity_eval.restype = ctypes.POINTER(ggml_cgraph)


# def unity_eval(
#     allocr: ctypes.c_void_p, model: NativeObj, tensor: ggml_tensor_p, n_threads: int
# ) -> ggml_cgraph_p:
#     return lib.unity_eval(allocr, model.ptr, tensor, n_threads)


_FORWARD_CACHE: Dict[str, Callable[..., ggml_tensor_p]] = {}


def forward(
    layer_name: str, model: ctypes.c_void_p, prefix: str, *inputs: ggml_tensor_p
) -> ggml_tensor_p:
    fwd: Any = _FORWARD_CACHE.get(layer_name)
    if fwd is None:
        fwd = getattr(lib, layer_name + "_forward")
        num_inputs = len(inputs)
        fwd.argtypes = [ctypes.c_void_p, ctypes.c_void_p] + [
            ctypes.POINTER(ggml_tensor)
        ] * num_inputs
        fwd.restype = ctypes.POINTER(ggml_tensor)
        _FORWARD_CACHE[layer_name] = fwd

    with CppStr(prefix) as std_prefix:
        return fwd(model, std_prefix, *inputs)  # type: ignore[no-any-return]
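
# Usage sketch: a call like `forward("LayerNorm", model.ptr, "encoder.layers.0", x)`
# resolves a native `LayerNorm_forward` symbol and invokes it with the model,
# a C++ std::string prefix, and the input tensors. The layer name here is
# hypothetical; it must match a `<layer_name>_forward` function exported by `lib`.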


@c_fn(lib)
def causal_attention_mask(
    ctx: ggml_context_p, seqs: Ptr[ggml_tensor]
) -> Ptr[ggml_tensor]:
    ...


@c_fn(lib)
def ggml_slice(
    ctx: ggml_context_p,
    a: Ptr[ggml_tensor],
    axis: int,
    start: ctypes.c_int64,
    end: ctypes.c_int64,
) -> Ptr[ggml_tensor]:
    ...


@c_struct
class SequenceGeneratorOptions:
    beam_size: int
    min_seq_len: int
    soft_max_seq_len_a: int
    soft_max_seq_len_b: int
    hard_max_seq_len: int
    len_penalty: float
    unk_penalty: float
    normalize_scores: bool


@c_struct
class SequenceGeneratorJob:
    opts: SequenceGeneratorOptions
    prefix_seq: Ptr[ggml_tensor]
    pad_idx: int
    unk_idx: int
    bos_idx: int
    eos_idx: int


@c_fn(lib)
def generate_sequence(
    model: ctypes.c_void_p,
    job: Ptr[SequenceGeneratorJob],
    encoder_output: Ptr[ggml_tensor],
    encoder_padding_mask: Ptr[ggml_tensor],
    output_seq: Ptr[ggml_tensor],
) -> float:
    ...
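
# Usage sketch (illustrative values; assumes `c_struct` builds a ctypes.Structure
# subclass, so fields can be set attribute-by-attribute and the job passed by
# reference):
#
#   job = SequenceGeneratorJob()
#   job.opts.beam_size = 5
#   job.pad_idx, job.unk_idx, job.bos_idx, job.eos_idx = 0, 1, 2, 3
#   score = generate_sequence(model.ptr, ctypes.byref(job), enc_out, enc_mask, out_seq)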