| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128 | 
							- """
 
-   This generates .pyi stubs for the cffi Python bindings generated by regenerate.py
 
- """
 
- import sys, re, itertools
 
- sys.path.extend(['.', '..']) # for pycparser
 
- from pycparser import c_ast, parse_file, CParser
 
- import pycparser.plyparser
 
- from pycparser.c_ast import PtrDecl, TypeDecl, FuncDecl, EllipsisParam, IdentifierType, Struct, Enum, Typedef
 
- from typing import Tuple
 
# Maps single-word C type names (as pycparser reports them in an
# IdentifierType) to the Python annotation written into the stubs.
# Names missing from this table become 'ffi.CData' in format_type.
__c_type_to_python_type = {
    'void': 'None',
    '_Bool': 'bool',
    # Plain integer-like scalars.
    **dict.fromkeys(
        ('char', 'short', 'int', 'long', 'ptrdiff_t', 'size_t'), 'int'),
    # Fixed-width integers: int8_t, uint8_t, ..., int64_t, uint64_t.
    **{f'{sign}int{bits}_t': 'int'
       for bits in (8, 16, 32, 64)
       for sign in ('', 'u')},
    'float': 'float',
    'double': 'float',
    'ggml_fp16_t': 'np.float16',
}
 
def format_type(t: c_ast.Node) -> str:
    """
      Return the Python annotation used in the stubs for the C type node `t`.

      Pointers, arrays, structs and unions are opaque cffi objects
      ('ffi.CData'); enums are 'int'. Single-word scalar names are looked up
      in `__c_type_to_python_type`, and unknown names (e.g. typedefs) become
      'ffi.CData'. Integer specifiers such as 'unsigned', 'unsigned int' or
      'long long' map to 'int', and 'long double' maps to 'float'.

      Raises ValueError for a multi-word type name it does not recognise.
    """
    if isinstance(t, (PtrDecl, c_ast.ArrayDecl, Struct, c_ast.Union)):
        return 'ffi.CData'
    if isinstance(t, Enum):
        return 'int'
    if isinstance(t, TypeDecl):
        # Unwrap the declaration to reach the underlying type specifier.
        return format_type(t.type)
    if isinstance(t, IdentifierType):
        names = t.names
        if len(names) == 1 and names[0] in __c_type_to_python_type:
            return __c_type_to_python_type[names[0]]
        # pycparser splits multi-word specifiers ('unsigned int') into a list.
        if set(names) <= {'signed', 'unsigned', 'char', 'short', 'int', 'long'}:
            return 'int'
        if sorted(names) == ['double', 'long']:
            return 'float'
        if len(names) == 1:
            # Unknown single name (typically a typedef): opaque cffi object.
            return 'ffi.CData'
        raise ValueError(f'Expected a single name, got {t.names}')
    return t.name
 
class PythonStubFuncDeclVisitor(c_ast.NodeVisitor):
    """
      Collects stub members for function prototypes and enum constants.

      After visiting an AST, `sigs` maps each exported name to the text of its
      member inside the generated `class lib:` stub.
    """

    def __init__(self):
        # name -> stub text for one function or enum constant
        self.sigs = {}
        # source file name -> its lines, cached so each file is read once
        self.sources = {}

    def get_source_snippet_lines(self, coord: pycparser.plyparser.Coord) -> Tuple[list[str], list[str]]:
        """
          Return `(comment_lines, decl_lines)` for the declaration at `coord`.

          `comment_lines` are the consecutive lines starting with `//` or `/*`
          directly above the declaration (whitespace-stripped). `decl_lines`
          run from the declaration's first line up to and including the first
          line containing ';' or '{' (right-stripped).
        """
        if coord.file not in self.sources:
            # The text only ends up in docstrings, so undecodable bytes are
            # replaced rather than aborting the whole generation.
            with open(coord.file, 'rt', encoding='utf-8', errors='replace') as f:
                self.sources[coord.file] = f.readlines()
        source_lines = self.sources[coord.file]

        # coord.line is 1-based; walk upwards starting at the line just above.
        ncomment_lines = len(list(itertools.takewhile(
            lambda i: re.search(r'^\s*(//|/\*)', source_lines[i]),
            range(coord.line - 2, -1, -1))))
        comment_lines = [
            l.strip()
            for l in source_lines[coord.line - 1 - ncomment_lines:coord.line - 1]
        ]

        decl_lines = []
        for line in source_lines[coord.line - 1:]:
            decl_lines.append(line.rstrip())
            if (';' in line) or ('{' in line):
                break
        return (comment_lines, decl_lines)

    def visit_Enum(self, node: Enum):
        """Emit a read-only int property for every enumerator of `node`."""
        if node.values is not None:
            for e in node.values.enumerators:
                self.sigs[e.name] = f'  @property\n  def {e.name}(self) -> int: ...'

    def visit_Typedef(self, node: Typedef):
        """
          Only collect enum constants from `typedef enum { ... } name;`.

          Typedefs are otherwise not descended into (as before); presumably so
          that function-pointer typedefs do not turn into bogus stubs.
          NOTE(review): relies on cffi exposing such constants on `lib` like
          those of plain `enum` declarations — confirm against the bindings.
        """
        decl = node.type
        if isinstance(decl, TypeDecl) and isinstance(decl.type, Enum):
            self.visit(decl.type)

    def visit_FuncDecl(self, node: FuncDecl):
        """Record a `def` stub, documented with the C declaration, for `node`."""
        # Peel pointer levels off the return type to reach the TypeDecl that
        # carries the function's name.
        ret_type = node.type
        while isinstance(ret_type, PtrDecl):
            ret_type = ret_type.type
        if not isinstance(ret_type, TypeDecl):
            # e.g. a function returning a function pointer: no simple name.
            return

        fun_name = ret_type.declname
        if not fun_name or fun_name.startswith('__'):
            return

        args = self._format_args(node)
        # format_type maps a pointer return to 'ffi.CData' and unwraps a
        # plain TypeDecl, so the original node.type covers both cases.
        ret = format_type(node.type)

        comment_lines, decl_lines = self.get_source_snippet_lines(node.coord)
        lines = [f'  def {fun_name}({", ".join(args)}) -> {ret}:']
        lines += self._format_docstring(comment_lines, decl_lines)
        lines.append('    ...')
        self.sigs[fun_name] = '\n'.join(lines)

    def _format_args(self, node: FuncDecl) -> list[str]:
        """Render `node`'s parameters as Python stub parameters."""
        import keyword

        args = []
        argnames = []

        def gen_name(stem):
            # First unused name among stem, stem2, stem3, ...
            i = 1
            while True:
                new_name = stem if i == 1 else f'{stem}{i}'
                if new_name not in argnames:
                    return new_name
                i += 1

        # node.args is None for an unprototyped declaration such as `int f();`.
        params = node.args.params if node.args is not None else []
        for a in params:
            if isinstance(a, EllipsisParam):
                arg_name = gen_name('args')
                argnames.append(arg_name)
                args.append('*' + arg_name)
            elif format_type(a.type) == 'None':
                # `(void)` parameter list: no Python parameter.
                continue
            else:
                stem = a.name or 'arg'
                if keyword.iskeyword(stem):
                    # C names like `from` or `lambda` are invalid in Python.
                    stem += '_'
                arg_name = gen_name(stem)
                argnames.append(arg_name)
                args.append(f'{arg_name}: {format_type(a.type)}')
        return args

    @staticmethod
    def _format_docstring(comment_lines: list[str], decl_lines: list[str]) -> list[str]:
        """Build the stub's docstring lines quoting the C comment and declaration."""
        if not comment_lines and len(decl_lines) == 1:
            return [f'    """{decl_lines[0]}"""']
        lines = ['    """']
        # Drop the comment markers on both ends (`//`, `/*`, `*/`).
        lines += [f'    {c.lstrip("/* ").removesuffix("*/").rstrip()}' for c in comment_lines]
        if comment_lines:
            lines.append('')
        lines += [f'    {d}' for d in decl_lines]
        lines.append('    """')
        return lines
 
def generate_stubs(header: str):
    """
      Generates a .pyi Python stub file for the GGML API using C header files.

      `header` is C source text handed to pycparser's `CParser` under the file
      name "<input>". Returns the stub file's text: the imports followed by a
      `class lib:` whose members are sorted by name.
    """
    v = PythonStubFuncDeclVisitor()
    # Declarations not covered by preprocessor line markers carry
    # coord.file == "<input>"; seed the visitor's source cache with the header
    # itself so their docstring snippets resolve instead of open("<input>")
    # failing. Split on '\n' only, matching how the C lexer counts lines.
    v.sources["<input>"] = header.split('\n')
    v.visit(CParser().parse(header, "<input>"))

    # A class with no members would be a syntax error in the stub file.
    members = [v.sigs[k] for k in sorted(v.sigs)] or ['  ...']
    return '\n'.join([
        '# auto-generated file',
        'import ggml.ffi as ffi',
        'import numpy as np',
        'class lib:',
        *members,
    ]) + '\n'
 
 
  |