# ctypes_utils.py (1.9 KB)
  1. import inspect
  2. import ctypes
  3. import types
  4. import functools
  5. from typing import TypeVar
  6. from typing import Generic
  7. T = TypeVar("T")
  8. class Ptr(Generic[T]):
  9. contents: T
  10. def __new__(cls):
  11. breakpoint()
  12. return ctypes.pointer()
  13. def c_struct(cls):
  14. struct = types.new_class(cls.__name__, bases=(ctypes.Structure,))
  15. struct.__module__ = cls.__module__
  16. struct._fields_ = [
  17. (k, _py_type_to_ctype(v)) for k, v in cls.__annotations__.items()
  18. ]
  19. return struct
  20. @functools.lru_cache(256)
  21. def _py_type_to_ctype(t: type):
  22. if isinstance(t, str):
  23. raise ValueError(
  24. f"Type parsing of '{t}' isn't supported, you need to provide a real type annotation."
  25. )
  26. if t.__module__ == "ctypes":
  27. return t
  28. if isinstance(t, type):
  29. if issubclass(t, ctypes.Structure):
  30. return t
  31. if issubclass(t, ctypes._Pointer):
  32. return t
  33. if t is int:
  34. return ctypes.c_int
  35. if t is float:
  36. return ctypes.c_float
  37. if t is bool:
  38. return ctypes.c_bool
  39. if t is str:
  40. return ctypes.c_char_p
  41. if t is None:
  42. return None
  43. if getattr(t, "__origin__", None) is Ptr:
  44. pointee = _py_type_to_ctype(t.__args__[0])
  45. return ctypes.POINTER(pointee)
  46. return ctypes.c_void_p
  47. def _c_fn(module, fn):
  48. c_fn = getattr(module, fn.__name__)
  49. annotations = fn.__annotations__
  50. if "return" not in annotations:
  51. raise ValueError("@c_fn decorator requires type annotations on the decorated function.")
  52. c_fn.argtypes = [
  53. _py_type_to_ctype(t) for k, t in fn.__annotations__.items() if k != "return"
  54. ]
  55. c_fn.restype = _py_type_to_ctype(fn.__annotations__["return"])
  56. @functools.wraps(fn)
  57. def actual_fn(*args, **kwargs):
  58. raw_res = c_fn(*args, **kwargs)
  59. return raw_res
  60. return actual_fn
  61. def c_fn(module):
  62. return functools.partial(_c_fn, module)