ctypes_utils.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. import inspect
  2. import ctypes
  3. import types
  4. import functools
  5. import dataclasses
  6. from typing import Callable
  7. from typing import Any
  8. from typing import Optional
  9. from typing import Type
  10. from typing import TypeVar
  11. from typing import Generic
  12. T = TypeVar("T")
  13. class Ptr(Generic[T], ctypes._Pointer): # type: ignore
  14. contents: T
  15. def __new__(cls, x: T) -> "Ptr[T]":
  16. return ctypes.pointer(x) # type: ignore
  17. NULLPTR: Ptr[Any] = None # type: ignore[assignment]
  18. def c_struct(cls: Type[T]) -> Type[T]:
  19. struct = types.new_class(cls.__name__, bases=(ctypes.Structure,))
  20. struct.__module__ = cls.__module__
  21. struct._fields_ = [ # type: ignore
  22. (k, _py_type_to_ctype(v)) for k, v in cls.__annotations__.items()
  23. ]
  24. def nice_init(self: T, *args: Any, **kwargs: Any) -> None:
  25. dc = cls(*args, **kwargs)
  26. for k, _ in self._fields_: # type: ignore
  27. setattr(self, k, getattr(dc, k))
  28. setattr(struct, "__init__", nice_init)
  29. return struct
  30. @functools.lru_cache(256)
  31. def _py_type_to_ctype(t: type) -> type:
  32. if isinstance(t, str):
  33. raise ValueError(
  34. f"Type parsing of '{t}' isn't supported, you need to provide a real type annotation."
  35. )
  36. if t is None:
  37. return None
  38. if isinstance(t, type):
  39. if t.__module__ == "ctypes":
  40. return t
  41. if issubclass(t, ctypes.Structure):
  42. return t
  43. if issubclass(t, ctypes._Pointer):
  44. return t
  45. if t is int:
  46. return ctypes.c_int
  47. if t is float:
  48. return ctypes.c_float
  49. if t is bool:
  50. return ctypes.c_bool
  51. if t is str:
  52. return ctypes.c_char_p
  53. if getattr(t, "__origin__", None) is Ptr:
  54. pointee = _py_type_to_ctype(t.__args__[0]) # type: ignore
  55. return ctypes.POINTER(pointee)
  56. return ctypes.c_void_p
  57. F = TypeVar("F", bound=Callable[..., Any])
  58. def _c_fn(module: Any, fn: F) -> F:
  59. c_fn = getattr(module, fn.__name__)
  60. annotations = fn.__annotations__
  61. if "return" not in annotations:
  62. raise ValueError("@c_fn decorator requires type annotations on the decorated function.")
  63. c_fn.argtypes = [
  64. _py_type_to_ctype(t) for k, t in fn.__annotations__.items() if k != "return"
  65. ]
  66. c_fn.restype = _py_type_to_ctype(fn.__annotations__["return"])
  67. @functools.wraps(fn)
  68. def actual_fn(*args, **kwargs): # type: ignore
  69. raw_res = c_fn(*args, **kwargs)
  70. return raw_res
  71. return actual_fn # type: ignore
  72. def c_fn(module: Any) -> Callable[[F], F]:
  73. return functools.partial(_c_fn, module)