stubs.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. """
  2. This generates .pyi stubs for the cffi Python bindings generated by regenerate.py
  3. """
  4. import sys, re, itertools
  5. sys.path.extend(['.', '..']) # for pycparser
  6. from pycparser import c_ast, parse_file, CParser
  7. import pycparser.plyparser
  8. from pycparser.c_ast import PtrDecl, TypeDecl, FuncDecl, EllipsisParam, IdentifierType, Struct, Enum, Typedef
  9. from typing import Tuple
  10. __c_type_to_python_type = {
  11. 'void': 'None', '_Bool': 'bool',
  12. 'char': 'int', 'short': 'int', 'int': 'int', 'long': 'int',
  13. 'ptrdiff_t': 'int', 'size_t': 'int',
  14. 'int8_t': 'int', 'uint8_t': 'int',
  15. 'int16_t': 'int', 'uint16_t': 'int',
  16. 'int32_t': 'int', 'uint32_t': 'int',
  17. 'int64_t': 'int', 'uint64_t': 'int',
  18. 'float': 'float', 'double': 'float',
  19. 'ggml_fp16_t': 'np.float16',
  20. }
  21. def format_type(t: TypeDecl):
  22. if isinstance(t, PtrDecl) or isinstance(t, Struct):
  23. return 'ffi.CData'
  24. if isinstance(t, Enum):
  25. return 'int'
  26. if isinstance(t, TypeDecl):
  27. return format_type(t.type)
  28. if isinstance(t, IdentifierType):
  29. assert len(t.names) == 1, f'Expected a single name, got {t.names}'
  30. return __c_type_to_python_type.get(t.names[0]) or 'ffi.CData'
  31. return t.name
  32. class PythonStubFuncDeclVisitor(c_ast.NodeVisitor):
  33. def __init__(self):
  34. self.sigs = {}
  35. self.sources = {}
  36. def get_source_snippet_lines(self, coord: pycparser.plyparser.Coord) -> Tuple[list[str], list[str]]:
  37. if coord.file not in self.sources:
  38. with open(coord.file, 'rt') as f:
  39. self.sources[coord.file] = f.readlines()
  40. source_lines = self.sources[coord.file]
  41. ncomment_lines = len(list(itertools.takewhile(lambda i: re.search(r'^\s*(//|/\*)', source_lines[i]), range(coord.line - 2, -1, -1))))
  42. comment_lines = [l.strip() for l in source_lines[coord.line - 1 - ncomment_lines:coord.line - 1]]
  43. decl_lines = []
  44. for line in source_lines[coord.line - 1:]:
  45. decl_lines.append(line.rstrip())
  46. if (';' in line) or ('{' in line): break
  47. return (comment_lines, decl_lines)
  48. def visit_Enum(self, node: Enum):
  49. if node.values is not None:
  50. for e in node.values.enumerators:
  51. self.sigs[e.name] = f' @property\n def {e.name}(self) -> int: ...'
  52. def visit_Typedef(self, node: Typedef):
  53. pass
  54. def visit_FuncDecl(self, node: FuncDecl):
  55. ret_type = node.type
  56. is_ptr = False
  57. while isinstance(ret_type, PtrDecl):
  58. ret_type = ret_type.type
  59. is_ptr = True
  60. fun_name = ret_type.declname
  61. if fun_name.startswith('__'):
  62. return
  63. args = []
  64. argnames = []
  65. def gen_name(stem):
  66. i = 1
  67. while True:
  68. new_name = stem if i == 1 else f'{stem}{i}'
  69. if new_name not in argnames: return new_name
  70. i += 1
  71. for a in node.args.params:
  72. if isinstance(a, EllipsisParam):
  73. arg_name = gen_name('args')
  74. argnames.append(arg_name)
  75. args.append('*' + gen_name('args'))
  76. elif format_type(a.type) == 'None':
  77. continue
  78. else:
  79. arg_name = a.name or gen_name('arg')
  80. argnames.append(arg_name)
  81. args.append(f'{arg_name}: {format_type(a.type)}')
  82. ret = format_type(ret_type if not is_ptr else node.type)
  83. comment_lines, decl_lines = self.get_source_snippet_lines(node.coord)
  84. lines = [f' def {fun_name}({", ".join(args)}) -> {ret}:']
  85. if len(comment_lines) == 0 and len(decl_lines) == 1:
  86. lines += [f' """{decl_lines[0]}"""']
  87. else:
  88. lines += [' """']
  89. lines += [f' {c.lstrip("/* ")}' for c in comment_lines]
  90. if len(comment_lines) > 0:
  91. lines += ['']
  92. lines += [f' {d}' for d in decl_lines]
  93. lines += [' """']
  94. lines += [' ...']
  95. self.sigs[fun_name] = '\n'.join(lines)
  96. def generate_stubs(header: str):
  97. """
  98. Generates a .pyi Python stub file for the GGML API using C header files.
  99. """
  100. v = PythonStubFuncDeclVisitor()
  101. v.visit(CParser().parse(header, "<input>"))
  102. keys = list(v.sigs.keys())
  103. keys.sort()
  104. return '\n'.join([
  105. '# auto-generated file',
  106. 'import ggml.ffi as ffi',
  107. 'import numpy as np',
  108. 'class lib:',
  109. *[v.sigs[k] for k in keys]
  110. ])