ctypes_utils.py 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
import ctypes
import functools
import inspect
import types
from typing import ForwardRef
from typing import Generic
from typing import TypeVar
  7. T = TypeVar("T")
  8. class Ptr(Generic[T]):
  9. contents: T
  10. def __new__(cls):
  11. return ctypes.pointer()
  12. def c_struct(cls):
  13. struct = types.new_class(cls.__name__, bases=(ctypes.Structure,))
  14. struct.__module__ = "ctypes"
  15. struct._fields_ = [
  16. (k, _py_type_to_ctype(v)) for k, v in cls.__annotations__.items()
  17. ]
  18. return struct
  19. @functools.lru_cache(256)
  20. def _py_type_to_ctype(t: type):
  21. if isinstance(t, str):
  22. raise ValueError(
  23. f"Type parsing of '{t}' isn't supported, you need to provide a real type annotation."
  24. )
  25. if t.__module__ == "ctypes":
  26. return t
  27. if isinstance(t, type) and issubclass(t, ctypes.Structure):
  28. return t
  29. if t is int:
  30. return ctypes.c_int
  31. if t is float:
  32. return ctypes.c_float
  33. if t is bool:
  34. return ctypes.c_bool
  35. if t is str:
  36. return ctypes.c_char_p
  37. if t is None:
  38. return None
  39. if getattr(t, "__origin__", None) is Ptr:
  40. pointee = _py_type_to_ctype(t.__args__[0])
  41. return ctypes.POINTER(pointee)
  42. return ctypes.c_void_p
  43. def _c_fn(module, fn):
  44. c_fn = getattr(module, fn.__name__)
  45. annotations = fn.__annotations__
  46. if "return" not in annotations:
  47. raise ValueError("@c_fn decorator requires type annotations on the decorated function.")
  48. c_fn.argtypes = [
  49. _py_type_to_ctype(t) for k, t in fn.__annotations__.items() if k != "return"
  50. ]
  51. c_fn.restype = _py_type_to_ctype(fn.__annotations__["return"])
  52. @functools.wraps(fn)
  53. def actual_fn(*args, **kwargs):
  54. return c_fn(*args, **kwargs)
  55. return actual_fn
  56. def c_fn(module):
  57. return functools.partial(_c_fn, module)