# generate_sparsetools.py
  1. """
  2. python generate_sparsetools.py
  3. Generate manual wrappers for C++ sparsetools code.
  4. Type codes used:
  5. 'i': integer scalar
  6. 'I': integer array
  7. 'T': data array
  8. 'B': boolean array
  9. 'V': std::vector<integer>*
  10. 'W': std::vector<data>*
  11. '*': indicates that the next argument is an output argument
  12. 'v': void
  13. 'l': 64-bit integer scalar
  14. See sparsetools.cxx for more details.
  15. """
  16. import optparse
  17. import os
  18. from distutils.dep_util import newer
#
# List of all routines and their argument types.
#
# The first code indicates the return value, the rest the arguments.
#
# Each non-blank line of a routine table reads
#     <routine name> <return code> <argument codes>
# using the type codes listed in the module docstring.  main() splits the
# routine name off at the first whitespace and then drops all remaining
# whitespace, so the codes may be spaced freely; lines starting with '#'
# are skipped.  main() also decides, from the codes, whether a routine is
# instantiated per index type only or per (index type, data type) pair.
#

# bsr.h
# The comparison routines (ne/lt/gt/le/ge) write a 'B' (boolean) output
# array where the arithmetic routines write a 'T' one.
BSR_ROUTINES = """
bsr_diagonal v iiiiiIIT*T
bsr_tocsr v iiiiIIT*I*I*T
bsr_scale_rows v iiiiII*TT
bsr_scale_columns v iiiiII*TT
bsr_sort_indices v iiii*I*I*T
bsr_transpose v iiiiIIT*I*I*T
bsr_matmat_pass2 v iiiiiIITIIT*I*I*T
bsr_matvec v iiiiIITT*T
bsr_matvecs v iiiiiIITT*T
bsr_elmul_bsr v iiiiIITIIT*I*I*T
bsr_eldiv_bsr v iiiiIITIIT*I*I*T
bsr_plus_bsr v iiiiIITIIT*I*I*T
bsr_minus_bsr v iiiiIITIIT*I*I*T
bsr_maximum_bsr v iiiiIITIIT*I*I*T
bsr_minimum_bsr v iiiiIITIIT*I*I*T
bsr_ne_bsr v iiiiIITIIT*I*I*B
bsr_lt_bsr v iiiiIITIIT*I*I*B
bsr_gt_bsr v iiiiIITIIT*I*I*B
bsr_le_bsr v iiiiIITIIT*I*I*B
bsr_ge_bsr v iiiiIITIIT*I*I*B
"""
# csc.h
# Same "<name> <return code> <argument codes>" line format as BSR_ROUTINES.
CSC_ROUTINES = """
csc_diagonal v iiiIIT*T
csc_tocsr v iiIIT*I*I*T
csc_matmat_pass1 v iiIIII*I
csc_matmat_pass2 v iiIITIIT*I*I*T
csc_matvec v iiIITT*T
csc_matvecs v iiiIITT*T
csc_elmul_csc v iiIITIIT*I*I*T
csc_eldiv_csc v iiIITIIT*I*I*T
csc_plus_csc v iiIITIIT*I*I*T
csc_minus_csc v iiIITIIT*I*I*T
csc_maximum_csc v iiIITIIT*I*I*T
csc_minimum_csc v iiIITIIT*I*I*T
csc_ne_csc v iiIITIIT*I*I*B
csc_lt_csc v iiIITIIT*I*I*B
csc_gt_csc v iiIITIIT*I*I*B
csc_le_csc v iiIITIIT*I*I*B
csc_ge_csc v iiIITIIT*I*I*B
"""
# csr.h
# Same "<name> <return code> <argument codes>" line format as BSR_ROUTINES.
# Notes on the less regular entries:
#   - get_csr_submatrix returns its result through 'V' (std::vector<integer>*)
#     and 'W' (std::vector<data>*) output arguments;
#   - routines with return code 'i' hand the integer result of the C++ call
#     back from their thunk;
#   - test_throw_error has a return code and no arguments at all.
CSR_ROUTINES = """
csr_matmat_pass1 v iiIIII*I
csr_matmat_pass2 v iiIITIIT*I*I*T
csr_diagonal v iiiIIT*T
csr_tocsc v iiIIT*I*I*T
csr_tobsr v iiiiIIT*I*I*T
csr_todense v iiIIT*T
csr_matvec v iiIITT*T
csr_matvecs v iiiIITT*T
csr_elmul_csr v iiIITIIT*I*I*T
csr_eldiv_csr v iiIITIIT*I*I*T
csr_plus_csr v iiIITIIT*I*I*T
csr_minus_csr v iiIITIIT*I*I*T
csr_maximum_csr v iiIITIIT*I*I*T
csr_minimum_csr v iiIITIIT*I*I*T
csr_ne_csr v iiIITIIT*I*I*B
csr_lt_csr v iiIITIIT*I*I*B
csr_gt_csr v iiIITIIT*I*I*B
csr_le_csr v iiIITIIT*I*I*B
csr_ge_csr v iiIITIIT*I*I*B
csr_scale_rows v iiII*TT
csr_scale_columns v iiII*TT
csr_sort_indices v iI*I*T
csr_eliminate_zeros v ii*I*I*T
csr_sum_duplicates v ii*I*I*T
get_csr_submatrix v iiIITiiii*V*V*W
csr_sample_values v iiIITiII*T
csr_count_blocks i iiiiII
csr_sample_offsets i iiIIiII*I
expandptr v iI*I
test_throw_error i
csr_has_sorted_indices i iII
csr_has_canonical_format i iII
"""
# coo.h, dia.h, csgraph.h
# Same line format as BSR_ROUTINES.  coo_todense and coo_matvec use the 'l'
# code, which parse_routine always passes as npy_int64 regardless of the
# index type being instantiated.
OTHER_ROUTINES = """
coo_tocsr v iiiIIT*I*I*T
coo_todense v iilIIT*Ti
coo_matvec v lIITT*T
dia_matvec v iiiiITT*T
cs_graph_components i iII*I
"""
# List of compilation units, as (unit name, routine table) pairs.
# main() writes the generated code of each unit to
# sparsetools/<unit name>_impl.h, in this order.
COMPILATION_UNITS = [
    ('bsr', BSR_ROUTINES),
    ('csr', CSR_ROUTINES),
    ('csc', CSC_ROUTINES),
    ('other', OTHER_ROUTINES),
]
  117. #
  118. # List of the supported index typenums and the corresponding C++ types
  119. #
  120. I_TYPES = [
  121. ('NPY_INT32', 'npy_int32'),
  122. ('NPY_INT64', 'npy_int64'),
  123. ]
  124. #
  125. # List of the supported data typenums and the corresponding C++ types
  126. #
  127. T_TYPES = [
  128. ('NPY_BOOL', 'npy_bool_wrapper'),
  129. ('NPY_BYTE', 'npy_byte'),
  130. ('NPY_UBYTE', 'npy_ubyte'),
  131. ('NPY_SHORT', 'npy_short'),
  132. ('NPY_USHORT', 'npy_ushort'),
  133. ('NPY_INT', 'npy_int'),
  134. ('NPY_UINT', 'npy_uint'),
  135. ('NPY_LONG', 'npy_long'),
  136. ('NPY_ULONG', 'npy_ulong'),
  137. ('NPY_LONGLONG', 'npy_longlong'),
  138. ('NPY_ULONGLONG', 'npy_ulonglong'),
  139. ('NPY_FLOAT', 'npy_float'),
  140. ('NPY_DOUBLE', 'npy_double'),
  141. ('NPY_LONGDOUBLE', 'npy_longdouble'),
  142. ('NPY_CFLOAT', 'npy_cfloat_wrapper'),
  143. ('NPY_CDOUBLE', 'npy_cdouble_wrapper'),
  144. ('NPY_CLONGDOUBLE', 'npy_clongdouble_wrapper'),
  145. ]
#
# Code templates
#
# Each template is filled with %-style dict formatting by the code
# generation functions further down.

# Per-routine dispatcher.  `thunk_content` (built by parse_routine) calls
# get_thunk_case() and switches on the result to pick the matching C++
# template instantiation.  Placeholders: name, thunk_content.
THUNK_TEMPLATE = """
static PY_LONG_LONG %(name)s_thunk(int I_typenum, int T_typenum, void **a)
{
%(thunk_content)s
}
"""
# Python-visible entry point of a routine: forwards the return and argument
# type codes plus the thunk to call_thunk().  call_thunk is not defined by
# this generator (presumably in sparsetools.cxx -- TODO confirm).
# Placeholders: name, ret_spec, arg_spec.
METHOD_TEMPLATE = """
NPY_VISIBILITY_HIDDEN PyObject *
%(name)s_method(PyObject *self, PyObject *args)
{
return call_thunk('%(ret_spec)s', "%(arg_spec)s", %(name)s_thunk, args);
}
"""
# Maps an (I_typenum, T_typenum) pair to the case number shared by all
# thunks, falling through to -1 when nothing matched.  `content` is the
# if / else-if chain built by get_thunk_type_set().
GET_THUNK_CASE_TEMPLATE = """
static int get_thunk_case(int I_typenum, int T_typenum)
{
%(content)s;
return -1;
}
"""
  169. #
  170. # Code generation
  171. #
  172. def get_thunk_type_set():
  173. """
  174. Get a list containing cartesian product of data types, plus a getter routine.
  175. Returns
  176. -------
  177. i_types : list [(j, I_typenum, None, I_type, None), ...]
  178. Pairing of index type numbers and the corresponding C++ types,
  179. and an unique index `j`. This is for routines that are parameterized
  180. only by I but not by T.
  181. it_types : list [(j, I_typenum, T_typenum, I_type, T_type), ...]
  182. Same as `i_types`, but for routines parameterized both by T and I.
  183. getter_code : str
  184. C++ code for a function that takes I_typenum, T_typenum and returns
  185. the unique index corresponding to the lists, or -1 if no match was
  186. found.
  187. """
  188. it_types = []
  189. i_types = []
  190. j = 0
  191. getter_code = " if (0) {}"
  192. for I_typenum, I_type in I_TYPES:
  193. piece = """
  194. else if (I_typenum == %(I_typenum)s) {
  195. if (T_typenum == -1) { return %(j)s; }"""
  196. getter_code += piece % dict(I_typenum=I_typenum, j=j)
  197. i_types.append((j, I_typenum, None, I_type, None))
  198. j += 1
  199. for T_typenum, T_type in T_TYPES:
  200. piece = """
  201. else if (T_typenum == %(T_typenum)s) { return %(j)s; }"""
  202. getter_code += piece % dict(T_typenum=T_typenum, j=j)
  203. it_types.append((j, I_typenum, T_typenum, I_type, T_type))
  204. j += 1
  205. getter_code += """
  206. }"""
  207. return i_types, it_types, GET_THUNK_CASE_TEMPLATE % dict(content=getter_code)
  208. def parse_routine(name, args, types):
  209. """
  210. Generate thunk and method code for a given routine.
  211. Parameters
  212. ----------
  213. name : str
  214. Name of the C++ routine
  215. args : str
  216. Argument list specification (in format explained above)
  217. types : list
  218. List of types to instantiate, as returned `get_thunk_type_set`
  219. """
  220. ret_spec = args[0]
  221. arg_spec = args[1:]
  222. def get_arglist(I_type, T_type):
  223. """
  224. Generate argument list for calling the C++ function
  225. """
  226. args = []
  227. next_is_writeable = False
  228. j = 0
  229. for t in arg_spec:
  230. const = '' if next_is_writeable else 'const '
  231. next_is_writeable = False
  232. if t == '*':
  233. next_is_writeable = True
  234. continue
  235. elif t == 'i':
  236. args.append("*(%s*)a[%d]" % (const + I_type, j))
  237. elif t == 'I':
  238. args.append("(%s*)a[%d]" % (const + I_type, j))
  239. elif t == 'T':
  240. args.append("(%s*)a[%d]" % (const + T_type, j))
  241. elif t == 'B':
  242. args.append("(npy_bool_wrapper*)a[%d]" % (j,))
  243. elif t == 'V':
  244. if const:
  245. raise ValueError("'V' argument must be an output arg")
  246. args.append("(std::vector<%s>*)a[%d]" % (I_type, j,))
  247. elif t == 'W':
  248. if const:
  249. raise ValueError("'W' argument must be an output arg")
  250. args.append("(std::vector<%s>*)a[%d]" % (T_type, j,))
  251. elif t == 'l':
  252. args.append("*(%snpy_int64*)a[%d]" % (const, j))
  253. else:
  254. raise ValueError("Invalid spec character %r" % (t,))
  255. j += 1
  256. return ", ".join(args)
  257. # Generate thunk code: a giant switch statement with different
  258. # type combinations inside.
  259. thunk_content = """int j = get_thunk_case(I_typenum, T_typenum);
  260. switch (j) {"""
  261. for j, I_typenum, T_typenum, I_type, T_type in types:
  262. arglist = get_arglist(I_type, T_type)
  263. if T_type is None:
  264. dispatch = "%s" % (I_type,)
  265. else:
  266. dispatch = "%s,%s" % (I_type, T_type)
  267. if 'B' in arg_spec:
  268. dispatch += ",npy_bool_wrapper"
  269. piece = """
  270. case %(j)s:"""
  271. if ret_spec == 'v':
  272. piece += """
  273. (void)%(name)s<%(dispatch)s>(%(arglist)s);
  274. return 0;"""
  275. else:
  276. piece += """
  277. return %(name)s<%(dispatch)s>(%(arglist)s);"""
  278. thunk_content += piece % dict(j=j, I_type=I_type, T_type=T_type,
  279. I_typenum=I_typenum, T_typenum=T_typenum,
  280. arglist=arglist, name=name,
  281. dispatch=dispatch)
  282. thunk_content += """
  283. default:
  284. throw std::runtime_error("internal error: invalid argument typenums");
  285. }"""
  286. thunk_code = THUNK_TEMPLATE % dict(name=name,
  287. thunk_content=thunk_content)
  288. # Generate method code
  289. method_code = METHOD_TEMPLATE % dict(name=name,
  290. ret_spec=ret_spec,
  291. arg_spec=arg_spec)
  292. return thunk_code, method_code
  293. def main():
  294. p = optparse.OptionParser(usage=(__doc__ or '').strip())
  295. p.add_option("--no-force", action="store_false",
  296. dest="force", default=True)
  297. options, args = p.parse_args()
  298. names = []
  299. i_types, it_types, getter_code = get_thunk_type_set()
  300. # Generate *_impl.h for each compilation unit
  301. for unit_name, routines in COMPILATION_UNITS:
  302. thunks = []
  303. methods = []
  304. # Generate thunks and methods for all routines
  305. for line in routines.splitlines():
  306. line = line.strip()
  307. if not line or line.startswith('#'):
  308. continue
  309. try:
  310. name, args = line.split(None, 1)
  311. except ValueError:
  312. raise ValueError("Malformed line: %r" % (line,))
  313. args = "".join(args.split())
  314. if 't' in args or 'T' in args:
  315. thunk, method = parse_routine(name, args, it_types)
  316. else:
  317. thunk, method = parse_routine(name, args, i_types)
  318. if name in names:
  319. raise ValueError("Duplicate routine %r" % (name,))
  320. names.append(name)
  321. thunks.append(thunk)
  322. methods.append(method)
  323. # Produce output
  324. dst = os.path.join(os.path.dirname(__file__),
  325. 'sparsetools',
  326. unit_name + '_impl.h')
  327. if newer(__file__, dst) or options.force:
  328. print("[generate_sparsetools] generating %r" % (dst,))
  329. with open(dst, 'w') as f:
  330. write_autogen_blurb(f)
  331. f.write(getter_code)
  332. for thunk in thunks:
  333. f.write(thunk)
  334. for method in methods:
  335. f.write(method)
  336. else:
  337. print("[generate_sparsetools] %r already up-to-date" % (dst,))
  338. # Generate code for method struct
  339. method_defs = ""
  340. for name in names:
  341. method_defs += "NPY_VISIBILITY_HIDDEN PyObject *%s_method(PyObject *, PyObject *);\n" % (name,)
  342. method_struct = """\nstatic struct PyMethodDef sparsetools_methods[] = {"""
  343. for name in names:
  344. method_struct += """
  345. {"%(name)s", (PyCFunction)%(name)s_method, METH_VARARGS, NULL},""" % dict(name=name)
  346. method_struct += """
  347. {NULL, NULL, 0, NULL}
  348. };"""
  349. # Produce sparsetools_impl.h
  350. dst = os.path.join(os.path.dirname(__file__),
  351. 'sparsetools',
  352. 'sparsetools_impl.h')
  353. if newer(__file__, dst) or options.force:
  354. print("[generate_sparsetools] generating %r" % (dst,))
  355. with open(dst, 'w') as f:
  356. write_autogen_blurb(f)
  357. f.write(method_defs)
  358. f.write(method_struct)
  359. else:
  360. print("[generate_sparsetools] %r already up-to-date" % (dst,))
  361. def write_autogen_blurb(stream):
  362. stream.write("""\
  363. /* This file is autogenerated by generate_sparsetools.py
  364. * Do not edit manually or check into VCS.
  365. */
  366. """)
  367. if __name__ == "__main__":
  368. main()