ctypes_utils.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. import ctypes
  2. import dataclasses
  3. import functools
  4. import inspect
  5. import types
  6. from typing import Any, Callable, Generic, Optional, Type, TypeVar
  7. T = TypeVar("T")
  8. class Ptr(Generic[T], ctypes._Pointer): # type: ignore
  9. contents: T
  10. def __new__(cls, x: T) -> "Ptr[T]":
  11. return ctypes.pointer(x) # type: ignore
  12. NULLPTR: Ptr[Any] = None # type: ignore[assignment]
  13. def c_struct(cls: Type[T]) -> Type[T]:
  14. struct = types.new_class(cls.__name__, bases=(ctypes.Structure,))
  15. struct.__module__ = cls.__module__
  16. struct._fields_ = [ # type: ignore
  17. (k, _py_type_to_ctype(v)) for k, v in cls.__annotations__.items()
  18. ]
  19. def nice_init(self: T, *args: Any, **kwargs: Any) -> None:
  20. dc = cls(*args, **kwargs)
  21. for k, _ in self._fields_: # type: ignore
  22. setattr(self, k, getattr(dc, k))
  23. setattr(struct, "__init__", nice_init)
  24. return struct
  25. @functools.lru_cache(256)
  26. def _py_type_to_ctype(t: type) -> type:
  27. if isinstance(t, str):
  28. raise ValueError(
  29. f"Type parsing of '{t}' isn't supported, you need to provide a real type annotation."
  30. )
  31. if t is None:
  32. return None
  33. if isinstance(t, type):
  34. if t.__module__ == "ctypes":
  35. return t
  36. if issubclass(t, ctypes.Structure):
  37. return t
  38. if issubclass(t, ctypes._Pointer):
  39. return t
  40. if t is int:
  41. return ctypes.c_int
  42. if t is float:
  43. return ctypes.c_float
  44. if t is bool:
  45. return ctypes.c_bool
  46. if t is bytes:
  47. return ctypes.c_char_p
  48. if t is str:
  49. raise ValueError("str type is't supported by ctypes ?")
  50. if getattr(t, "__origin__", None) is Ptr:
  51. pointee = _py_type_to_ctype(t.__args__[0]) # type: ignore
  52. return ctypes.POINTER(pointee)
  53. return ctypes.c_void_p
  54. F = TypeVar("F", bound=Callable[..., Any])
  55. def _c_fn(module: Any, fn: F) -> F:
  56. if callable(module):
  57. c_fn = module
  58. else:
  59. c_fn = getattr(module, fn.__name__)
  60. annotations = fn.__annotations__
  61. if "return" not in annotations:
  62. raise ValueError(
  63. "@c_fn decorator requires type annotations on the decorated function."
  64. )
  65. c_fn.argtypes = [
  66. _py_type_to_ctype(t) for k, t in fn.__annotations__.items() if k != "return"
  67. ]
  68. c_fn.restype = _py_type_to_ctype(fn.__annotations__["return"])
  69. @functools.wraps(fn)
  70. def actual_fn(*args, **kwargs): # type: ignore
  71. raw_res = c_fn(*args, **kwargs)
  72. return raw_res
  73. return actual_fn # type: ignore
  74. def c_fn(module: Any) -> Callable[[F], F]:
  75. return functools.partial(_c_fn, module)