_generate_pyx.py 25 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757
  1. """
  2. Code generator script to make the Cython BLAS and LAPACK wrappers
  3. from the files "cython_blas_signatures.txt" and
  4. "cython_lapack_signatures.txt" which contain the signatures for
  5. all the BLAS/LAPACK routines that should be included in the wrappers.
  6. """
  7. from collections import defaultdict
  8. from operator import itemgetter
  9. import os
  10. BASE_DIR = os.path.abspath(os.path.dirname(__file__))
  11. fortran_types = {'int': 'integer',
  12. 'c': 'complex',
  13. 'd': 'double precision',
  14. 's': 'real',
  15. 'z': 'complex*16',
  16. 'char': 'character',
  17. 'bint': 'logical'}
  18. c_types = {'int': 'int',
  19. 'c': 'npy_complex64',
  20. 'd': 'double',
  21. 's': 'float',
  22. 'z': 'npy_complex128',
  23. 'char': 'char',
  24. 'bint': 'int',
  25. 'cselect1': '_cselect1',
  26. 'cselect2': '_cselect2',
  27. 'dselect2': '_dselect2',
  28. 'dselect3': '_dselect3',
  29. 'sselect2': '_sselect2',
  30. 'sselect3': '_sselect3',
  31. 'zselect1': '_zselect1',
  32. 'zselect2': '_zselect2'}
  33. def arg_names_and_types(args):
  34. return zip(*[arg.split(' *') for arg in args.split(', ')])
# Cython source for one BLAS/LAPACK *function*. It contains:
#   - an extern declaration of the Fortran ``{name}wrp`` shim, which hands
#     the result back through its extra leading ``out`` pointer;
#   - a ``cdef`` wrapper with the routine's own signature that calls the
#     shim and returns ``out``.
# Filled in by pyx_decl_func.
pyx_func_template = """
cdef extern from "{header_name}":
    void _fortran_{name} "F_FUNC({name}wrp, {upname}WRP)"({ret_type} *out, {fort_args}) nogil
cdef {ret_type} {name}({args}) nogil:
    cdef {ret_type} out
    _fortran_{name}(&out, {argnames})
    return out
"""
  43. npy_types = {'c': 'npy_complex64', 'z': 'npy_complex128',
  44. 'cselect1': '_cselect1', 'cselect2': '_cselect2',
  45. 'dselect2': '_dselect2', 'dselect3': '_dselect3',
  46. 'sselect2': '_sselect2', 'sselect3': '_sselect3',
  47. 'zselect1': '_zselect1', 'zselect2': '_zselect2'}
  48. def arg_casts(arg):
  49. if arg in ['npy_complex64', 'npy_complex128', '_cselect1', '_cselect2',
  50. '_dselect2', '_dselect3', '_sselect2', '_sselect3',
  51. '_zselect1', '_zselect2']:
  52. return '<{0}*>'.format(arg)
  53. return ''
  54. def pyx_decl_func(name, ret_type, args, header_name):
  55. argtypes, argnames = arg_names_and_types(args)
  56. # Fix the case where one of the arguments has the same name as the
  57. # abbreviation for the argument type.
  58. # Otherwise the variable passed as an argument is considered overwrites
  59. # the previous typedef and Cython compilation fails.
  60. if ret_type in argnames:
  61. argnames = [n if n != ret_type else ret_type + '_' for n in argnames]
  62. argnames = [n if n not in ['lambda', 'in'] else n + '_'
  63. for n in argnames]
  64. args = ', '.join([' *'.join([n, t])
  65. for n, t in zip(argtypes, argnames)])
  66. argtypes = [npy_types.get(t, t) for t in argtypes]
  67. fort_args = ', '.join([' *'.join([n, t])
  68. for n, t in zip(argtypes, argnames)])
  69. argnames = [arg_casts(t) + n for n, t in zip(argnames, argtypes)]
  70. argnames = ', '.join(argnames)
  71. c_ret_type = c_types[ret_type]
  72. args = args.replace('lambda', 'lambda_')
  73. return pyx_func_template.format(name=name, upname=name.upper(), args=args,
  74. fort_args=fort_args, ret_type=ret_type,
  75. c_ret_type=c_ret_type, argnames=argnames,
  76. header_name=header_name)
  77. pyx_sub_template = """cdef extern from "{header_name}":
  78. void _fortran_{name} "F_FUNC({name},{upname})"({fort_args}) nogil
  79. cdef void {name}({args}) nogil:
  80. _fortran_{name}({argnames})
  81. """
  82. def pyx_decl_sub(name, args, header_name):
  83. argtypes, argnames = arg_names_and_types(args)
  84. argtypes = [npy_types.get(t, t) for t in argtypes]
  85. argnames = [n if n not in ['lambda', 'in'] else n + '_' for n in argnames]
  86. fort_args = ', '.join([' *'.join([n, t])
  87. for n, t in zip(argtypes, argnames)])
  88. argnames = [arg_casts(t) + n for n, t in zip(argnames, argtypes)]
  89. argnames = ', '.join(argnames)
  90. args = args.replace('*lambda,', '*lambda_,').replace('*in,', '*in_,')
  91. return pyx_sub_template.format(name=name, upname=name.upper(),
  92. args=args, fort_args=fort_args,
  93. argnames=argnames, header_name=header_name)
  94. blas_pyx_preamble = '''# cython: boundscheck = False
  95. # cython: wraparound = False
  96. # cython: cdivision = True
  97. """
  98. BLAS Functions for Cython
  99. =========================
  100. Usable from Cython via::
  101. cimport scipy.linalg.cython_blas
  102. These wrappers do not check for alignment of arrays.
  103. Alignment should be checked before these wrappers are used.
  104. Raw function pointers (Fortran-style pointer arguments):
  105. - {}
  106. """
  107. # Within scipy, these wrappers can be used via relative or absolute cimport.
  108. # Examples:
  109. # from ..linalg cimport cython_blas
  110. # from scipy.linalg cimport cython_blas
  111. # cimport scipy.linalg.cython_blas as cython_blas
  112. # cimport ..linalg.cython_blas as cython_blas
  113. # Within scipy, if BLAS functions are needed in C/C++/Fortran,
  114. # these wrappers should not be used.
  115. # The original libraries should be linked directly.
  116. from __future__ import absolute_import
  117. cdef extern from "fortran_defs.h":
  118. pass
  119. from numpy cimport npy_complex64, npy_complex128
  120. '''
  121. def make_blas_pyx_preamble(all_sigs):
  122. names = [sig[0] for sig in all_sigs]
  123. return blas_pyx_preamble.format("\n- ".join(names))
  124. lapack_pyx_preamble = '''"""
  125. LAPACK functions for Cython
  126. ===========================
  127. Usable from Cython via::
  128. cimport scipy.linalg.cython_lapack
  129. This module provides Cython-level wrappers for all primary routines included
  130. in LAPACK 3.4.0 except for ``zcgesv`` since its interface is not consistent
  131. from LAPACK 3.4.0 to 3.6.0. It also provides some of the
  132. fixed-api auxiliary routines.
  133. These wrappers do not check for alignment of arrays.
  134. Alignment should be checked before these wrappers are used.
  135. Raw function pointers (Fortran-style pointer arguments):
  136. - {}
  137. """
  138. # Within scipy, these wrappers can be used via relative or absolute cimport.
  139. # Examples:
  140. # from ..linalg cimport cython_lapack
  141. # from scipy.linalg cimport cython_lapack
  142. # cimport scipy.linalg.cython_lapack as cython_lapack
  143. # cimport ..linalg.cython_lapack as cython_lapack
  144. # Within scipy, if LAPACK functions are needed in C/C++/Fortran,
  145. # these wrappers should not be used.
  146. # The original libraries should be linked directly.
  147. from __future__ import absolute_import
  148. cdef extern from "fortran_defs.h":
  149. pass
  150. from numpy cimport npy_complex64, npy_complex128
  151. cdef extern from "_lapack_subroutines.h":
  152. # Function pointer type declarations for
  153. # gees and gges families of functions.
  154. ctypedef bint _cselect1(npy_complex64*)
  155. ctypedef bint _cselect2(npy_complex64*, npy_complex64*)
  156. ctypedef bint _dselect2(d*, d*)
  157. ctypedef bint _dselect3(d*, d*, d*)
  158. ctypedef bint _sselect2(s*, s*)
  159. ctypedef bint _sselect3(s*, s*, s*)
  160. ctypedef bint _zselect1(npy_complex128*)
  161. ctypedef bint _zselect2(npy_complex128*, npy_complex128*)
  162. '''
  163. def make_lapack_pyx_preamble(all_sigs):
  164. names = [sig[0] for sig in all_sigs]
  165. return lapack_pyx_preamble.format("\n- ".join(names))
  166. blas_py_wrappers = """
  167. # Python-accessible wrappers for testing:
  168. cdef inline bint _is_contiguous(double[:,:] a, int axis) nogil:
  169. return (a.strides[axis] == sizeof(a[0,0]) or a.shape[axis] == 1)
  170. cpdef float complex _test_cdotc(float complex[:] cx, float complex[:] cy) nogil:
  171. cdef:
  172. int n = cx.shape[0]
  173. int incx = cx.strides[0] // sizeof(cx[0])
  174. int incy = cy.strides[0] // sizeof(cy[0])
  175. return cdotc(&n, &cx[0], &incx, &cy[0], &incy)
  176. cpdef float complex _test_cdotu(float complex[:] cx, float complex[:] cy) nogil:
  177. cdef:
  178. int n = cx.shape[0]
  179. int incx = cx.strides[0] // sizeof(cx[0])
  180. int incy = cy.strides[0] // sizeof(cy[0])
  181. return cdotu(&n, &cx[0], &incx, &cy[0], &incy)
  182. cpdef double _test_dasum(double[:] dx) nogil:
  183. cdef:
  184. int n = dx.shape[0]
  185. int incx = dx.strides[0] // sizeof(dx[0])
  186. return dasum(&n, &dx[0], &incx)
  187. cpdef double _test_ddot(double[:] dx, double[:] dy) nogil:
  188. cdef:
  189. int n = dx.shape[0]
  190. int incx = dx.strides[0] // sizeof(dx[0])
  191. int incy = dy.strides[0] // sizeof(dy[0])
  192. return ddot(&n, &dx[0], &incx, &dy[0], &incy)
  193. cpdef int _test_dgemm(double alpha, double[:,:] a, double[:,:] b, double beta,
  194. double[:,:] c) nogil except -1:
  195. cdef:
  196. char *transa
  197. char *transb
  198. int m, n, k, lda, ldb, ldc
  199. double *a0=&a[0,0]
  200. double *b0=&b[0,0]
  201. double *c0=&c[0,0]
  202. # In the case that c is C contiguous, swap a and b and
  203. # swap whether or not each of them is transposed.
  204. # This can be done because a.dot(b) = b.T.dot(a.T).T.
  205. if _is_contiguous(c, 1):
  206. if _is_contiguous(a, 1):
  207. transb = 'n'
  208. ldb = (&a[1,0]) - a0 if a.shape[0] > 1 else 1
  209. elif _is_contiguous(a, 0):
  210. transb = 't'
  211. ldb = (&a[0,1]) - a0 if a.shape[1] > 1 else 1
  212. else:
  213. with gil:
  214. raise ValueError("Input 'a' is neither C nor Fortran contiguous.")
  215. if _is_contiguous(b, 1):
  216. transa = 'n'
  217. lda = (&b[1,0]) - b0 if b.shape[0] > 1 else 1
  218. elif _is_contiguous(b, 0):
  219. transa = 't'
  220. lda = (&b[0,1]) - b0 if b.shape[1] > 1 else 1
  221. else:
  222. with gil:
  223. raise ValueError("Input 'b' is neither C nor Fortran contiguous.")
  224. k = b.shape[0]
  225. if k != a.shape[1]:
  226. with gil:
  227. raise ValueError("Shape mismatch in input arrays.")
  228. m = b.shape[1]
  229. n = a.shape[0]
  230. if n != c.shape[0] or m != c.shape[1]:
  231. with gil:
  232. raise ValueError("Output array does not have the correct shape.")
  233. ldc = (&c[1,0]) - c0 if c.shape[0] > 1 else 1
  234. dgemm(transa, transb, &m, &n, &k, &alpha, b0, &lda, a0,
  235. &ldb, &beta, c0, &ldc)
  236. elif _is_contiguous(c, 0):
  237. if _is_contiguous(a, 1):
  238. transa = 't'
  239. lda = (&a[1,0]) - a0 if a.shape[0] > 1 else 1
  240. elif _is_contiguous(a, 0):
  241. transa = 'n'
  242. lda = (&a[0,1]) - a0 if a.shape[1] > 1 else 1
  243. else:
  244. with gil:
  245. raise ValueError("Input 'a' is neither C nor Fortran contiguous.")
  246. if _is_contiguous(b, 1):
  247. transb = 't'
  248. ldb = (&b[1,0]) - b0 if b.shape[0] > 1 else 1
  249. elif _is_contiguous(b, 0):
  250. transb = 'n'
  251. ldb = (&b[0,1]) - b0 if b.shape[1] > 1 else 1
  252. else:
  253. with gil:
  254. raise ValueError("Input 'b' is neither C nor Fortran contiguous.")
  255. m = a.shape[0]
  256. k = a.shape[1]
  257. if k != b.shape[0]:
  258. with gil:
  259. raise ValueError("Shape mismatch in input arrays.")
  260. n = b.shape[1]
  261. if m != c.shape[0] or n != c.shape[1]:
  262. with gil:
  263. raise ValueError("Output array does not have the correct shape.")
  264. ldc = (&c[0,1]) - c0 if c.shape[1] > 1 else 1
  265. dgemm(transa, transb, &m, &n, &k, &alpha, a0, &lda, b0,
  266. &ldb, &beta, c0, &ldc)
  267. else:
  268. with gil:
  269. raise ValueError("Input 'c' is neither C nor Fortran contiguous.")
  270. return 0
  271. cpdef double _test_dnrm2(double[:] x) nogil:
  272. cdef:
  273. int n = x.shape[0]
  274. int incx = x.strides[0] // sizeof(x[0])
  275. return dnrm2(&n, &x[0], &incx)
  276. cpdef double _test_dzasum(double complex[:] zx) nogil:
  277. cdef:
  278. int n = zx.shape[0]
  279. int incx = zx.strides[0] // sizeof(zx[0])
  280. return dzasum(&n, &zx[0], &incx)
  281. cpdef double _test_dznrm2(double complex[:] x) nogil:
  282. cdef:
  283. int n = x.shape[0]
  284. int incx = x.strides[0] // sizeof(x[0])
  285. return dznrm2(&n, &x[0], &incx)
  286. cpdef int _test_icamax(float complex[:] cx) nogil:
  287. cdef:
  288. int n = cx.shape[0]
  289. int incx = cx.strides[0] // sizeof(cx[0])
  290. return icamax(&n, &cx[0], &incx)
  291. cpdef int _test_idamax(double[:] dx) nogil:
  292. cdef:
  293. int n = dx.shape[0]
  294. int incx = dx.strides[0] // sizeof(dx[0])
  295. return idamax(&n, &dx[0], &incx)
  296. cpdef int _test_isamax(float[:] sx) nogil:
  297. cdef:
  298. int n = sx.shape[0]
  299. int incx = sx.strides[0] // sizeof(sx[0])
  300. return isamax(&n, &sx[0], &incx)
  301. cpdef int _test_izamax(double complex[:] zx) nogil:
  302. cdef:
  303. int n = zx.shape[0]
  304. int incx = zx.strides[0] // sizeof(zx[0])
  305. return izamax(&n, &zx[0], &incx)
  306. cpdef float _test_sasum(float[:] sx) nogil:
  307. cdef:
  308. int n = sx.shape[0]
  309. int incx = sx.shape[0] // sizeof(sx[0])
  310. return sasum(&n, &sx[0], &incx)
  311. cpdef float _test_scasum(float complex[:] cx) nogil:
  312. cdef:
  313. int n = cx.shape[0]
  314. int incx = cx.strides[0] // sizeof(cx[0])
  315. return scasum(&n, &cx[0], &incx)
  316. cpdef float _test_scnrm2(float complex[:] x) nogil:
  317. cdef:
  318. int n = x.shape[0]
  319. int incx = x.strides[0] // sizeof(x[0])
  320. return scnrm2(&n, &x[0], &incx)
  321. cpdef float _test_sdot(float[:] sx, float[:] sy) nogil:
  322. cdef:
  323. int n = sx.shape[0]
  324. int incx = sx.strides[0] // sizeof(sx[0])
  325. int incy = sy.strides[0] // sizeof(sy[0])
  326. return sdot(&n, &sx[0], &incx, &sy[0], &incy)
  327. cpdef float _test_snrm2(float[:] x) nogil:
  328. cdef:
  329. int n = x.shape[0]
  330. int incx = x.shape[0] // sizeof(x[0])
  331. return snrm2(&n, &x[0], &incx)
  332. cpdef double complex _test_zdotc(double complex[:] zx, double complex[:] zy) nogil:
  333. cdef:
  334. int n = zx.shape[0]
  335. int incx = zx.strides[0] // sizeof(zx[0])
  336. int incy = zy.strides[0] // sizeof(zy[0])
  337. return zdotc(&n, &zx[0], &incx, &zy[0], &incy)
  338. cpdef double complex _test_zdotu(double complex[:] zx, double complex[:] zy) nogil:
  339. cdef:
  340. int n = zx.shape[0]
  341. int incx = zx.strides[0] // sizeof(zx[0])
  342. int incy = zy.strides[0] // sizeof(zy[0])
  343. return zdotu(&n, &zx[0], &incx, &zy[0], &incy)
  344. """
  345. def generate_blas_pyx(func_sigs, sub_sigs, all_sigs, header_name):
  346. funcs = "\n".join(pyx_decl_func(*(s+(header_name,))) for s in func_sigs)
  347. subs = "\n" + "\n".join(pyx_decl_sub(*(s[::2]+(header_name,)))
  348. for s in sub_sigs)
  349. return make_blas_pyx_preamble(all_sigs) + funcs + subs + blas_py_wrappers
  350. lapack_py_wrappers = """
  351. # Python accessible wrappers for testing:
  352. def _test_dlamch(cmach):
  353. # This conversion is necessary to handle Python 3 strings.
  354. cmach_bytes = bytes(cmach)
  355. # Now that it is a bytes representation, a non-temporary variable
  356. # must be passed as a part of the function call.
  357. cdef char* cmach_char = cmach_bytes
  358. return dlamch(cmach_char)
  359. def _test_slamch(cmach):
  360. # This conversion is necessary to handle Python 3 strings.
  361. cmach_bytes = bytes(cmach)
  362. # Now that it is a bytes representation, a non-temporary variable
  363. # must be passed as a part of the function call.
  364. cdef char* cmach_char = cmach_bytes
  365. return slamch(cmach_char)
  366. """
  367. def generate_lapack_pyx(func_sigs, sub_sigs, all_sigs, header_name):
  368. funcs = "\n".join(pyx_decl_func(*(s+(header_name,))) for s in func_sigs)
  369. subs = "\n" + "\n".join(pyx_decl_sub(*(s[::2]+(header_name,)))
  370. for s in sub_sigs)
  371. preamble = make_lapack_pyx_preamble(all_sigs)
  372. return preamble + funcs + subs + lapack_py_wrappers
  373. pxd_template = """ctypedef {ret_type} {name}_t({args}) nogil
  374. cdef {name}_t *{name}_f
  375. """
  376. pxd_template = """cdef {ret_type} {name}({args}) nogil
  377. """
  378. def pxd_decl(name, ret_type, args):
  379. args = args.replace('lambda', 'lambda_').replace('*in,', '*in_,')
  380. return pxd_template.format(name=name, ret_type=ret_type, args=args)
  381. blas_pxd_preamble = """# Within scipy, these wrappers can be used via relative or absolute cimport.
  382. # Examples:
  383. # from ..linalg cimport cython_blas
  384. # from scipy.linalg cimport cython_blas
  385. # cimport scipy.linalg.cython_blas as cython_blas
  386. # cimport ..linalg.cython_blas as cython_blas
  387. # Within scipy, if BLAS functions are needed in C/C++/Fortran,
  388. # these wrappers should not be used.
  389. # The original libraries should be linked directly.
  390. ctypedef float s
  391. ctypedef double d
  392. ctypedef float complex c
  393. ctypedef double complex z
  394. """
  395. def generate_blas_pxd(all_sigs):
  396. body = '\n'.join(pxd_decl(*sig) for sig in all_sigs)
  397. return blas_pxd_preamble + body
  398. lapack_pxd_preamble = """# Within scipy, these wrappers can be used via relative or absolute cimport.
  399. # Examples:
  400. # from ..linalg cimport cython_lapack
  401. # from scipy.linalg cimport cython_lapack
  402. # cimport scipy.linalg.cython_lapack as cython_lapack
  403. # cimport ..linalg.cython_lapack as cython_lapack
  404. # Within scipy, if LAPACK functions are needed in C/C++/Fortran,
  405. # these wrappers should not be used.
  406. # The original libraries should be linked directly.
  407. ctypedef float s
  408. ctypedef double d
  409. ctypedef float complex c
  410. ctypedef double complex z
  411. # Function pointer type declarations for
  412. # gees and gges families of functions.
  413. ctypedef bint cselect1(c*)
  414. ctypedef bint cselect2(c*, c*)
  415. ctypedef bint dselect2(d*, d*)
  416. ctypedef bint dselect3(d*, d*, d*)
  417. ctypedef bint sselect2(s*, s*)
  418. ctypedef bint sselect3(s*, s*, s*)
  419. ctypedef bint zselect1(z*)
  420. ctypedef bint zselect2(z*, z*)
  421. """
  422. def generate_lapack_pxd(all_sigs):
  423. return lapack_pxd_preamble + '\n'.join(pxd_decl(*sig) for sig in all_sigs)
# Fixed-form Fortran wrapper that exposes a BLAS/LAPACK *function* as a
# subroutine ``{name}wrp``. The function's result is stored into the extra
# first argument ``ret``.
# Leading whitespace is significant here:
#   - statements begin in column 7;
#   - continuation lines carry ``+`` in column 6.
# ``{argnames}`` and ``{argdecls}`` are joined by fort_subroutine_wrapper
# using matching separators.
fortran_template = """      subroutine {name}wrp(
     +    ret,
     +    {argnames}
     +    )
        external {wrapper}
        {ret_type} {wrapper}
        {ret_type} ret
        {argdecls}
        ret = {wrapper}(
     +    {argnames}
     +    )
      end
"""
  437. dims = {'work': '(*)', 'ab': '(ldab,*)', 'a': '(lda,*)', 'dl': '(*)',
  438. 'd': '(*)', 'du': '(*)', 'ap': '(*)', 'e': '(*)', 'lld': '(*)'}
  439. xy_specialized_dims = {'x': '', 'y': ''}
  440. a_specialized_dims = {'a': '(*)'}
  441. special_cases = defaultdict(dict,
  442. ladiv = xy_specialized_dims,
  443. lanhf = a_specialized_dims,
  444. lansf = a_specialized_dims,
  445. lapy2 = xy_specialized_dims,
  446. lapy3 = xy_specialized_dims)
  447. def process_fortran_name(name, funcname):
  448. if 'inc' in name:
  449. return name
  450. special = special_cases[funcname[1:]]
  451. if 'x' in name or 'y' in name:
  452. suffix = special.get(name, '(n)')
  453. else:
  454. suffix = special.get(name, '')
  455. return name + suffix
  456. def called_name(name):
  457. included = ['cdotc', 'cdotu', 'zdotc', 'zdotu', 'cladiv', 'zladiv']
  458. if name in included:
  459. return "w" + name
  460. return name
  461. def fort_subroutine_wrapper(name, ret_type, args):
  462. wrapper = called_name(name)
  463. types, names = arg_names_and_types(args)
  464. argnames = ',\n + '.join(names)
  465. names = [process_fortran_name(n, name) for n in names]
  466. argdecls = '\n '.join('{0} {1}'.format(fortran_types[t], n)
  467. for n, t in zip(names, types))
  468. return fortran_template.format(name=name, wrapper=wrapper,
  469. argnames=argnames, argdecls=argdecls,
  470. ret_type=fortran_types[ret_type])
  471. def generate_fortran(func_sigs):
  472. return "\n".join(fort_subroutine_wrapper(*sig) for sig in func_sigs)
  473. def make_c_args(args):
  474. types, names = arg_names_and_types(args)
  475. types = [c_types[arg] for arg in types]
  476. return ', '.join('{0} *{1}'.format(t, n) for t, n in zip(types, names))
  477. c_func_template = ("void F_FUNC({name}wrp, {upname}WRP)"
  478. "({return_type} *ret, {args});\n")
  479. def c_func_decl(name, return_type, args):
  480. args = make_c_args(args)
  481. return_type = c_types[return_type]
  482. return c_func_template.format(name=name, upname=name.upper(),
  483. return_type=return_type, args=args)
  484. c_sub_template = "void F_FUNC({name},{upname})({args});\n"
  485. def c_sub_decl(name, return_type, args):
  486. args = make_c_args(args)
  487. return c_sub_template.format(name=name, upname=name.upper(), args=args)
# Opening of every generated C header: an include guard parameterized by the
# library name, followed by the headers the prototypes rely on.
# NOTE(review): presumably F_FUNC comes from fortran_defs.h and the
# npy_complex* types from numpy/arrayobject.h; confirm.
c_preamble = """#ifndef SCIPY_LINALG_{lib}_FORTRAN_WRAPPERS_H
#define SCIPY_LINALG_{lib}_FORTRAN_WRAPPERS_H
#include "fortran_defs.h"
#include "numpy/arrayobject.h"
"""
# C typedefs for the callback arguments (the ``_?select?`` entries of
# c_types) of the gees/gges-family routines. generate_c_header appends these
# only for the LAPACK header.
lapack_decls = """
typedef int (*_cselect1)(npy_complex64*);
typedef int (*_cselect2)(npy_complex64*, npy_complex64*);
typedef int (*_dselect2)(double*, double*);
typedef int (*_dselect3)(double*, double*, double*);
typedef int (*_sselect2)(float*, float*);
typedef int (*_sselect3)(float*, float*, float*);
typedef int (*_zselect1)(npy_complex128*);
typedef int (*_zselect2)(npy_complex128*, npy_complex128*);
"""
# Opens an ``extern "C"`` block when the header is compiled as C++.
cpp_guard = """
#ifdef __cplusplus
extern "C" {
#endif
"""
# Closes the ``extern "C"`` block opened by cpp_guard, then the include guard
# opened by c_preamble.
c_end = """
#ifdef __cplusplus
}
#endif
#endif
"""
  514. def generate_c_header(func_sigs, sub_sigs, all_sigs, lib_name):
  515. funcs = "".join(c_func_decl(*sig) for sig in func_sigs)
  516. subs = "\n" + "".join(c_sub_decl(*sig) for sig in sub_sigs)
  517. if lib_name == 'LAPACK':
  518. preamble = (c_preamble.format(lib=lib_name) + lapack_decls)
  519. else:
  520. preamble = c_preamble.format(lib=lib_name)
  521. return "".join([preamble, cpp_guard, funcs, subs, c_end])
  522. def split_signature(sig):
  523. name_and_type, args = sig[:-1].split('(')
  524. ret_type, name = name_and_type.split(' ')
  525. return name, ret_type, args
  526. def filter_lines(lines):
  527. lines = [line for line in map(str.strip, lines)
  528. if line and not line.startswith('#')]
  529. func_sigs = [split_signature(line) for line in lines
  530. if line.split(' ')[0] != 'void']
  531. sub_sigs = [split_signature(line) for line in lines
  532. if line.split(' ')[0] == 'void']
  533. all_sigs = list(sorted(func_sigs + sub_sigs, key=itemgetter(0)))
  534. return func_sigs, sub_sigs, all_sigs
  535. def all_newer(src_files, dst_files):
  536. from distutils.dep_util import newer
  537. return all(os.path.exists(dst) and newer(dst, src)
  538. for dst in dst_files for src in src_files)
  539. def make_all(blas_signature_file="cython_blas_signatures.txt",
  540. lapack_signature_file="cython_lapack_signatures.txt",
  541. blas_name="cython_blas",
  542. lapack_name="cython_lapack",
  543. blas_fortran_name="_blas_subroutine_wrappers.f",
  544. lapack_fortran_name="_lapack_subroutine_wrappers.f",
  545. blas_header_name="_blas_subroutines.h",
  546. lapack_header_name="_lapack_subroutines.h"):
  547. src_files = (os.path.abspath(__file__),
  548. blas_signature_file,
  549. lapack_signature_file)
  550. dst_files = (blas_name + '.pyx',
  551. blas_name + '.pxd',
  552. blas_fortran_name,
  553. blas_header_name,
  554. lapack_name + '.pyx',
  555. lapack_name + '.pxd',
  556. lapack_fortran_name,
  557. lapack_header_name)
  558. os.chdir(BASE_DIR)
  559. if all_newer(src_files, dst_files):
  560. print("scipy/linalg/_generate_pyx.py: all files up-to-date")
  561. return
  562. comments = ["This file was generated by _generate_pyx.py.\n",
  563. "Do not edit this file directly.\n"]
  564. ccomment = ''.join(['/* ' + line.rstrip() + ' */\n'
  565. for line in comments]) + '\n'
  566. pyxcomment = ''.join(['# ' + line for line in comments]) + '\n'
  567. fcomment = ''.join(['c ' + line for line in comments]) + '\n'
  568. with open(blas_signature_file, 'r') as f:
  569. blas_sigs = f.readlines()
  570. blas_sigs = filter_lines(blas_sigs)
  571. blas_pyx = generate_blas_pyx(*(blas_sigs + (blas_header_name,)))
  572. with open(blas_name + '.pyx', 'w') as f:
  573. f.write(pyxcomment)
  574. f.write(blas_pyx)
  575. blas_pxd = generate_blas_pxd(blas_sigs[2])
  576. with open(blas_name + '.pxd', 'w') as f:
  577. f.write(pyxcomment)
  578. f.write(blas_pxd)
  579. blas_fortran = generate_fortran(blas_sigs[0])
  580. with open(blas_fortran_name, 'w') as f:
  581. f.write(fcomment)
  582. f.write(blas_fortran)
  583. blas_c_header = generate_c_header(*(blas_sigs + ('BLAS',)))
  584. with open(blas_header_name, 'w') as f:
  585. f.write(ccomment)
  586. f.write(blas_c_header)
  587. with open(lapack_signature_file, 'r') as f:
  588. lapack_sigs = f.readlines()
  589. lapack_sigs = filter_lines(lapack_sigs)
  590. lapack_pyx = generate_lapack_pyx(*(lapack_sigs + (lapack_header_name,)))
  591. with open(lapack_name + '.pyx', 'w') as f:
  592. f.write(pyxcomment)
  593. f.write(lapack_pyx)
  594. lapack_pxd = generate_lapack_pxd(lapack_sigs[2])
  595. with open(lapack_name + '.pxd', 'w') as f:
  596. f.write(pyxcomment)
  597. f.write(lapack_pxd)
  598. lapack_fortran = generate_fortran(lapack_sigs[0])
  599. with open(lapack_fortran_name, 'w') as f:
  600. f.write(fcomment)
  601. f.write(lapack_fortran)
  602. lapack_c_header = generate_c_header(*(lapack_sigs + ('LAPACK',)))
  603. with open(lapack_header_name, 'w') as f:
  604. f.write(ccomment)
  605. f.write(lapack_c_header)
  606. if __name__ == '__main__':
  607. make_all()