_generate_pyx.py 47 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378
"""
python _generate_pyx.py

Generate Ufunc definition source files for scipy.special. Produces
files '_ufuncs.c' and '_ufuncs_cxx.c' by first producing Cython.

This will generate both calls to PyUFunc_FromFuncAndData and the
required ufunc inner loops.

The function signatures are contained in 'functions.json'; the syntax
for a function signature is

    <function>:       <name> ':' <input> '*' <output>
                        '->' <retval> '*' <ignored_retval>
    <input>:          <typecode>*
    <output>:         <typecode>*
    <retval>:         <typecode>?
    <ignored_retval>: <typecode>?
    <headers>:        <header_name> [',' <header_name>]*

The input parameter types are denoted by single character type
codes, according to

   'f': 'float'
   'd': 'double'
   'g': 'long double'
   'F': 'float complex'
   'D': 'double complex'
   'G': 'long double complex'
   'i': 'int'
   'l': 'long'
   'v': 'void'

If multiple kernel functions are given for a single ufunc, the one
which is used is determined by the standard ufunc mechanism. Kernel
functions that are listed first are also matched first against the
ufunc input types, so functions listed earlier take precedence.

In addition, versions with cast variables, such as d->f, D->F and
i->l, are automatically generated.

There should be either a single header that contains all of the kernel
functions listed, or there should be one header for each kernel
function. Cython pxd files are allowed in addition to .h files.

Cython functions may use fused types, but the names in the list
should be the specialized ones, such as 'somefunc[float]'.

Functions coming from C++ should have ``++`` appended to the name of
the header.

Floating-point exceptions inside these Ufuncs are converted to
special function errors --- which are separately controlled by the
user, and off by default, as they are usually not especially useful
for the user.


The C++ module
--------------

In addition to the ``_ufuncs`` module, a second module ``_ufuncs_cxx`` is
generated. This module only exports function pointers that are to be
used when constructing some of the ufuncs in ``_ufuncs``. The function
pointers are exported via Cython's standard mechanism.

This mainly avoids build issues --- Python distutils has no way to
figure out what to do if you want to link both C++ and Fortran code in
the same shared library.
"""
  54. from __future__ import division, print_function, absolute_import
#---------------------------------------------------------------------------------
# Extra code
#---------------------------------------------------------------------------------

# Verbatim text fragments spliced into the generated sources. Which file
# each fragment ends up in is decided by code outside this section.

# Header for a generated Cython module: a "do not edit" banner followed by
# an include of _ufuncs_extra_code_common.pxi.
UFUNCS_EXTRA_CODE_COMMON = """\
# This file is automatically generated by _generate_pyx.py.
# Do not edit manually!
from __future__ import absolute_import
include "_ufuncs_extra_code_common.pxi"
"""

# Include of _ufuncs_extra_code.pxi.
# NOTE(review): presumably emitted only into _ufuncs (not _ufuncs_cxx),
# unlike the COMMON fragment above -- confirm against the writer code.
UFUNCS_EXTRA_CODE = """\
include "_ufuncs_extra_code.pxi"
"""

# Trailing code for the generated module: defines ``jn`` as an alias of
# ``jv``.
UFUNCS_EXTRA_CODE_BOTTOM = """\
#
# Aliases
#
jn = jv
"""

# "Do not edit" banner for the generated cython_special .pxd file.
# NOTE(review): target file inferred from the name -- confirm.
CYTHON_SPECIAL_PXD = """\
# This file is automatically generated by _generate_pyx.py.
# Do not edit manually!
"""
  77. CYTHON_SPECIAL_PYX = """\
  78. # This file is automatically generated by _generate_pyx.py.
  79. # Do not edit manually!
  80. \"\"\"
  81. .. highlight:: cython
  82. ================================
  83. Cython API for Special Functions
  84. ================================
  85. Scalar, typed versions of many of the functions in ``scipy.special``
  86. can be accessed directly from Cython; the complete list is given
  87. below. Functions are overloaded using Cython fused types so their
  88. names match their ufunc counterpart. The module follows the following
  89. conventions:
  90. - If a function's ufunc counterpart returns multiple values, then the
  91. function returns its outputs via pointers in the final arguments
  92. - If a function's ufunc counterpart returns a single value, then the
  93. function's output is returned directly.
  94. The module is usable from Cython via::
  95. cimport scipy.special.cython_special
  96. Error Handling
  97. ==============
  98. Functions can indicate an error by returning ``nan``; however they
  99. cannot emit warnings like their counterparts in ``scipy.special``.
  100. Available Functions
  101. ===================
  102. FUNCLIST
  103. \"\"\"
  104. from __future__ import absolute_import
  105. include "_cython_special.pxi"
  106. """
  107. #---------------------------------------------------------------------------------
  108. # Code generation
  109. #---------------------------------------------------------------------------------
  110. import os
  111. import optparse
  112. import re
  113. import textwrap
  114. import itertools
  115. import numpy
  116. import json
# Absolute path of the directory containing this script.
BASE_DIR = os.path.abspath(os.path.dirname(__file__))

# Module providing the ufunc docstrings; Ufunc.__init__ looks them up with
# add_newdocs.get("scipy.special.<name>").
# NOTE(review): imported by bare module name via __import__, so this assumes
# the script's own directory is importable (e.g. the script is run from
# there) -- confirm.
add_newdocs = __import__('add_newdocs')

# Type code -> Cython type name; used in generated Cython code and in the
# Cython-typed function prototypes.
CY_TYPES = {
    'f': 'float',
    'd': 'double',
    'g': 'long double',
    'F': 'float complex',
    'D': 'double complex',
    'G': 'long double complex',
    'i': 'int',
    'l': 'long',
    'v': 'void',
}

# Type code -> NumPy C type name; used for the C-typed function prototypes.
C_TYPES = {
    'f': 'npy_float',
    'd': 'npy_double',
    'g': 'npy_longdouble',
    'F': 'npy_cfloat',
    'D': 'npy_cdouble',
    'G': 'npy_clongdouble',
    'i': 'npy_int',
    'l': 'npy_long',
    'v': 'void',
}

# Type code -> NumPy type-number constant written into the ufunc type
# tables. 'v' (void) has no entry: void outputs are rejected when the ufunc
# loops are built, and the signature parser does not accept 'v' inputs.
TYPE_NAMES = {
    'f': 'NPY_FLOAT',
    'd': 'NPY_DOUBLE',
    'g': 'NPY_LONGDOUBLE',
    'F': 'NPY_CFLOAT',
    'D': 'NPY_CDOUBLE',
    'G': 'NPY_CLONGDOUBLE',
    'i': 'NPY_INT',
    'l': 'NPY_LONG',
}

# Benchmarks to generate: scipy.special function name -> list of type-code
# strings of the form "<inputs>[*<outputs>]".
# NOTE(review): presumably each (name, codes) pair is passed to
# generate_bench by code outside this view -- confirm.
CYTHON_SPECIAL_BENCHFUNCS = {
    'airy': ['d*dddd', 'D*DDDD'],
    'beta': ['dd'],
    'erf': ['d', 'D'],
    'exprel': ['d'],
    'gamma': ['d', 'D'],
    'jv': ['dd', 'dD'],
    'loggamma': ['D'],
    'logit': ['d'],
    'psi': ['d', 'D'],
}
  162. def underscore(arg):
  163. return arg.replace(" ", "_")
  164. def cast_order(c):
  165. return ['ilfdgFDG'.index(x) for x in c]
# (source type, destination type) code pairs whose conversion can lose
# information. generate_loop guards each such cast by checking that the
# value round-trips exactly; these downcasts will cause the function to
# return NaNs, unless the values happen to coincide exactly.
DANGEROUS_DOWNCAST = set([
    ('F', 'i'), ('F', 'l'), ('F', 'f'), ('F', 'd'), ('F', 'g'),
    ('D', 'i'), ('D', 'l'), ('D', 'f'), ('D', 'd'), ('D', 'g'),
    ('G', 'i'), ('G', 'l'), ('G', 'f'), ('G', 'd'), ('G', 'g'),
    ('f', 'i'), ('f', 'l'),
    ('d', 'i'), ('d', 'l'),
    ('g', 'i'), ('g', 'l'),
    ('l', 'i'),
])

# Value assigned to an output of the given type code when a guarded cast
# fails. Integer types have no NaN, so a recognizable bit pattern is used.
# NOTE(review): 0xbad0bad0 exceeds the range of a 32-bit signed int, so the
# stored 'i' value is implementation-defined after the C cast -- confirm
# this is intended.
NAN_VALUE = {
    'f': 'NPY_NAN',
    'd': 'NPY_NAN',
    'g': 'NPY_NAN',
    'F': 'NPY_NAN',
    'D': 'NPY_NAN',
    'G': 'NPY_NAN',
    'i': '0xbad0bad0',
    'l': '0xbad0bad0',
}
  187. def generate_loop(func_inputs, func_outputs, func_retval,
  188. ufunc_inputs, ufunc_outputs):
  189. """
  190. Generate a UFunc loop function that calls a function given as its
  191. data parameter with the specified input and output arguments and
  192. return value.
  193. This function can be passed to PyUFunc_FromFuncAndData.
  194. Parameters
  195. ----------
  196. func_inputs, func_outputs, func_retval : str
  197. Signature of the function to call, given as type codes of the
  198. input, output and return value arguments. These 1-character
  199. codes are given according to the CY_TYPES and TYPE_NAMES
  200. lists above.
  201. The corresponding C function signature to be called is:
  202. retval func(intype1 iv1, intype2 iv2, ..., outtype1 *ov1, ...);
  203. If len(ufunc_outputs) == len(func_outputs)+1, the return value
  204. is treated as the first output argument. Otherwise, the return
  205. value is ignored.
  206. ufunc_inputs, ufunc_outputs : str
  207. Ufunc input and output signature.
  208. This does not have to exactly match the function signature,
  209. as long as the type casts work out on the C level.
  210. Returns
  211. -------
  212. loop_name
  213. Name of the generated loop function.
  214. loop_body
  215. Generated C code for the loop.
  216. """
  217. if len(func_inputs) != len(ufunc_inputs):
  218. raise ValueError("Function and ufunc have different number of inputs")
  219. if len(func_outputs) != len(ufunc_outputs) and not (
  220. func_retval != "v" and len(func_outputs)+1 == len(ufunc_outputs)):
  221. raise ValueError("Function retval and ufunc outputs don't match")
  222. name = "loop_%s_%s_%s_As_%s_%s" % (
  223. func_retval, func_inputs, func_outputs, ufunc_inputs, ufunc_outputs
  224. )
  225. body = "cdef void %s(char **args, np.npy_intp *dims, np.npy_intp *steps, void *data) nogil:\n" % name
  226. body += " cdef np.npy_intp i, n = dims[0]\n"
  227. body += " cdef void *func = (<void**>data)[0]\n"
  228. body += " cdef char *func_name = <char*>(<void**>data)[1]\n"
  229. for j in range(len(ufunc_inputs)):
  230. body += " cdef char *ip%d = args[%d]\n" % (j, j)
  231. for j in range(len(ufunc_outputs)):
  232. body += " cdef char *op%d = args[%d]\n" % (j, j + len(ufunc_inputs))
  233. ftypes = []
  234. fvars = []
  235. outtypecodes = []
  236. for j in range(len(func_inputs)):
  237. ftypes.append(CY_TYPES[func_inputs[j]])
  238. fvars.append("<%s>(<%s*>ip%d)[0]" % (
  239. CY_TYPES[func_inputs[j]],
  240. CY_TYPES[ufunc_inputs[j]], j))
  241. if len(func_outputs)+1 == len(ufunc_outputs):
  242. func_joff = 1
  243. outtypecodes.append(func_retval)
  244. body += " cdef %s ov0\n" % (CY_TYPES[func_retval],)
  245. else:
  246. func_joff = 0
  247. for j, outtype in enumerate(func_outputs):
  248. body += " cdef %s ov%d\n" % (CY_TYPES[outtype], j+func_joff)
  249. ftypes.append("%s *" % CY_TYPES[outtype])
  250. fvars.append("&ov%d" % (j+func_joff))
  251. outtypecodes.append(outtype)
  252. body += " for i in range(n):\n"
  253. if len(func_outputs)+1 == len(ufunc_outputs):
  254. rv = "ov0 = "
  255. else:
  256. rv = ""
  257. funcall = " %s(<%s(*)(%s) nogil>func)(%s)\n" % (
  258. rv, CY_TYPES[func_retval], ", ".join(ftypes), ", ".join(fvars))
  259. # Cast-check inputs and call function
  260. input_checks = []
  261. for j in range(len(func_inputs)):
  262. if (ufunc_inputs[j], func_inputs[j]) in DANGEROUS_DOWNCAST:
  263. chk = "<%s>(<%s*>ip%d)[0] == (<%s*>ip%d)[0]" % (
  264. CY_TYPES[func_inputs[j]], CY_TYPES[ufunc_inputs[j]], j,
  265. CY_TYPES[ufunc_inputs[j]], j)
  266. input_checks.append(chk)
  267. if input_checks:
  268. body += " if %s:\n" % (" and ".join(input_checks))
  269. body += " " + funcall
  270. body += " else:\n"
  271. body += " sf_error.error(func_name, sf_error.DOMAIN, \"invalid input argument\")\n"
  272. for j, outtype in enumerate(outtypecodes):
  273. body += " ov%d = <%s>%s\n" % (
  274. j, CY_TYPES[outtype], NAN_VALUE[outtype])
  275. else:
  276. body += funcall
  277. # Assign and cast-check output values
  278. for j, (outtype, fouttype) in enumerate(zip(ufunc_outputs, outtypecodes)):
  279. if (fouttype, outtype) in DANGEROUS_DOWNCAST:
  280. body += " if ov%d == <%s>ov%d:\n" % (j, CY_TYPES[outtype], j)
  281. body += " (<%s *>op%d)[0] = <%s>ov%d\n" % (
  282. CY_TYPES[outtype], j, CY_TYPES[outtype], j)
  283. body += " else:\n"
  284. body += " sf_error.error(func_name, sf_error.DOMAIN, \"invalid output\")\n"
  285. body += " (<%s *>op%d)[0] = <%s>%s\n" % (
  286. CY_TYPES[outtype], j, CY_TYPES[outtype], NAN_VALUE[outtype])
  287. else:
  288. body += " (<%s *>op%d)[0] = <%s>ov%d\n" % (
  289. CY_TYPES[outtype], j, CY_TYPES[outtype], j)
  290. for j in range(len(ufunc_inputs)):
  291. body += " ip%d += steps[%d]\n" % (j, j)
  292. for j in range(len(ufunc_outputs)):
  293. body += " op%d += steps[%d]\n" % (j, j + len(ufunc_inputs))
  294. body += " sf_error.check_fpe(func_name)\n"
  295. return name, body
  296. def generate_fused_type(codes):
  297. """
  298. Generate name of and cython code for a fused type.
  299. Parameters
  300. ----------
  301. typecodes : str
  302. Valid inputs to CY_TYPES (i.e. f, d, g, ...).
  303. """
  304. cytypes = map(lambda x: CY_TYPES[x], codes)
  305. name = codes + "_number_t"
  306. declaration = ["ctypedef fused " + name + ":"]
  307. for cytype in cytypes:
  308. declaration.append(" " + cytype)
  309. declaration = "\n".join(declaration)
  310. return name, declaration
  311. def generate_bench(name, codes):
  312. tab = " "*4
  313. top, middle, end = [], [], []
  314. tmp = codes.split("*")
  315. if len(tmp) > 1:
  316. incodes = tmp[0]
  317. outcodes = tmp[1]
  318. else:
  319. incodes = tmp[0]
  320. outcodes = ""
  321. inargs, inargs_and_types = [], []
  322. for n, code in enumerate(incodes):
  323. arg = "x{}".format(n)
  324. inargs.append(arg)
  325. inargs_and_types.append("{} {}".format(CY_TYPES[code], arg))
  326. line = "def {{}}(int N, {}):".format(", ".join(inargs_and_types))
  327. top.append(line)
  328. top.append(tab + "cdef int n")
  329. outargs = []
  330. for n, code in enumerate(outcodes):
  331. arg = "y{}".format(n)
  332. outargs.append("&{}".format(arg))
  333. line = "cdef {} {}".format(CY_TYPES[code], arg)
  334. middle.append(tab + line)
  335. end.append(tab + "for n in range(N):")
  336. end.append(2*tab + "{}({})")
  337. pyfunc = "_bench_{}_{}_{}".format(name, incodes, "py")
  338. cyfunc = "_bench_{}_{}_{}".format(name, incodes, "cy")
  339. pytemplate = "\n".join(top + end)
  340. cytemplate = "\n".join(top + middle + end)
  341. pybench = pytemplate.format(pyfunc, "_ufuncs." + name, ", ".join(inargs))
  342. cybench = cytemplate.format(cyfunc, name, ", ".join(inargs + outargs))
  343. return pybench, cybench
  344. def generate_doc(name, specs):
  345. tab = " "*4
  346. doc = ["- :py:func:`~scipy.special.{}`::\n".format(name)]
  347. for spec in specs:
  348. incodes, outcodes = spec.split("->")
  349. incodes = incodes.split("*")
  350. intypes = list(map(lambda x: CY_TYPES[x], incodes[0]))
  351. if len(incodes) > 1:
  352. types = map(lambda x: "{} *".format(CY_TYPES[x]), incodes[1])
  353. intypes.extend(types)
  354. outtype = CY_TYPES[outcodes]
  355. line = "{} {}({})".format(outtype, name, ", ".join(intypes))
  356. doc.append(2*tab + line)
  357. doc[-1] = "{}\n".format(doc[-1])
  358. doc = "\n".join(doc)
  359. return doc
  360. def npy_cdouble_from_double_complex(var):
  361. """Cast a cython double complex to a numpy cdouble."""
  362. res = "_complexstuff.npy_cdouble_from_double_complex({})".format(var)
  363. return res
  364. def double_complex_from_npy_cdouble(var):
  365. """Cast a numpy cdouble to a cython double complex."""
  366. res = "_complexstuff.double_complex_from_npy_cdouble({})".format(var)
  367. return res
  368. def iter_variants(inputs, outputs):
  369. """
  370. Generate variants of UFunc signatures, by changing variable types,
  371. within the limitation that the corresponding C types casts still
  372. work out.
  373. This does not generate all possibilities, just the ones required
  374. for the ufunc to work properly with the most common data types.
  375. Parameters
  376. ----------
  377. inputs, outputs : str
  378. UFunc input and output signature strings
  379. Yields
  380. ------
  381. new_input, new_output : str
  382. Modified input and output strings.
  383. Also the original input/output pair is yielded.
  384. """
  385. maps = [
  386. # always use long instead of int (more common type on 64-bit)
  387. ('i', 'l'),
  388. ]
  389. # float32-preserving signatures
  390. if not ('i' in inputs or 'l' in inputs):
  391. # Don't add float32 versions of ufuncs with integer arguments, as this
  392. # can lead to incorrect dtype selection if the integer arguments are
  393. # arrays, but float arguments are scalars.
  394. # For instance sph_harm(0,[0],0,0).dtype == complex64
  395. # This may be a Numpy bug, but we need to work around it.
  396. # cf. gh-4895, https://github.com/numpy/numpy/issues/5895
  397. maps = maps + [(a + 'dD', b + 'fF') for a, b in maps]
  398. # do the replacements
  399. for src, dst in maps:
  400. new_inputs = inputs
  401. new_outputs = outputs
  402. for a, b in zip(src, dst):
  403. new_inputs = new_inputs.replace(a, b)
  404. new_outputs = new_outputs.replace(a, b)
  405. yield new_inputs, new_outputs
  406. class Func(object):
  407. """
  408. Base class for Ufunc and FusedFunc.
  409. """
  410. def __init__(self, name, signatures):
  411. self.name = name
  412. self.signatures = []
  413. self.function_name_overrides = {}
  414. for header in signatures.keys():
  415. for name, sig in signatures[header].items():
  416. inarg, outarg, ret = self._parse_signature(sig)
  417. self.signatures.append((name, inarg, outarg, ret, header))
  418. def _parse_signature(self, sig):
  419. m = re.match(r"\s*([fdgFDGil]*)\s*\*\s*([fdgFDGil]*)\s*->\s*([*fdgFDGil]*)\s*$", sig)
  420. if m:
  421. inarg, outarg, ret = [x.strip() for x in m.groups()]
  422. if ret.count('*') > 1:
  423. raise ValueError("{}: Invalid signature: {}".format(self.name, sig))
  424. return inarg, outarg, ret
  425. m = re.match(r"\s*([fdgFDGil]*)\s*->\s*([fdgFDGil]?)\s*$", sig)
  426. if m:
  427. inarg, ret = [x.strip() for x in m.groups()]
  428. return inarg, "", ret
  429. raise ValueError("{}: Invalid signature: {}".format(self.name, sig))
  430. def get_prototypes(self, nptypes_for_h=False):
  431. prototypes = []
  432. for func_name, inarg, outarg, ret, header in self.signatures:
  433. ret = ret.replace('*', '')
  434. c_args = ([C_TYPES[x] for x in inarg]
  435. + [C_TYPES[x] + ' *' for x in outarg])
  436. cy_args = ([CY_TYPES[x] for x in inarg]
  437. + [CY_TYPES[x] + ' *' for x in outarg])
  438. c_proto = "%s (*)(%s)" % (C_TYPES[ret], ", ".join(c_args))
  439. if header.endswith("h") and nptypes_for_h:
  440. cy_proto = c_proto + "nogil"
  441. else:
  442. cy_proto = "%s (*)(%s) nogil" % (CY_TYPES[ret], ", ".join(cy_args))
  443. prototypes.append((func_name, c_proto, cy_proto, header))
  444. return prototypes
  445. def cython_func_name(self, c_name, specialized=False, prefix="_func_",
  446. override=True):
  447. # act on function name overrides
  448. if override and c_name in self.function_name_overrides:
  449. c_name = self.function_name_overrides[c_name]
  450. prefix = ""
  451. # support fused types
  452. m = re.match(r'^(.*?)(\[.*\])$', c_name)
  453. if m:
  454. c_base_name, fused_part = m.groups()
  455. else:
  456. c_base_name, fused_part = c_name, ""
  457. if specialized:
  458. return "%s%s%s" % (prefix, c_base_name, fused_part.replace(' ', '_'))
  459. else:
  460. return "%s%s" % (prefix, c_base_name,)
  461. class Ufunc(Func):
  462. """
  463. Ufunc signature, restricted format suitable for special functions.
  464. Parameters
  465. ----------
  466. name
  467. Name of the ufunc to create
  468. signature
  469. String of form 'func: fff*ff->f, func2: ddd->*i' describing
  470. the C-level functions and types of their input arguments
  471. and return values.
  472. The syntax is 'function_name: inputparams*outputparams->output_retval*ignored_retval'
  473. Attributes
  474. ----------
  475. name : str
  476. Python name for the Ufunc
  477. signatures : list of (func_name, inarg_spec, outarg_spec, ret_spec, header_name)
  478. List of parsed signatures
  479. doc : str
  480. Docstring, obtained from add_newdocs
  481. function_name_overrides : dict of str->str
  482. Overrides for the function names in signatures
  483. """
  484. def __init__(self, name, signatures):
  485. super(Ufunc, self).__init__(name, signatures)
  486. self.doc = add_newdocs.get("scipy.special." + name)
  487. if self.doc is None:
  488. raise ValueError("No docstring for ufunc %r" % name)
  489. self.doc = textwrap.dedent(self.doc).strip()
  490. def _get_signatures_and_loops(self, all_loops):
  491. inarg_num = None
  492. outarg_num = None
  493. seen = set()
  494. variants = []
  495. def add_variant(func_name, inarg, outarg, ret, inp, outp):
  496. if inp in seen:
  497. return
  498. seen.add(inp)
  499. sig = (func_name, inp, outp)
  500. if "v" in outp:
  501. raise ValueError("%s: void signature %r" % (self.name, sig))
  502. if len(inp) != inarg_num or len(outp) != outarg_num:
  503. raise ValueError("%s: signature %r does not have %d/%d input/output args" % (
  504. self.name, sig,
  505. inarg_num, outarg_num))
  506. loop_name, loop = generate_loop(inarg, outarg, ret, inp, outp)
  507. all_loops[loop_name] = loop
  508. variants.append((func_name, loop_name, inp, outp))
  509. # First add base variants
  510. for func_name, inarg, outarg, ret, header in self.signatures:
  511. outp = re.sub(r'\*.*', '', ret) + outarg
  512. ret = ret.replace('*', '')
  513. if inarg_num is None:
  514. inarg_num = len(inarg)
  515. outarg_num = len(outp)
  516. inp, outp = list(iter_variants(inarg, outp))[0]
  517. add_variant(func_name, inarg, outarg, ret, inp, outp)
  518. # Then the supplementary ones
  519. for func_name, inarg, outarg, ret, header in self.signatures:
  520. outp = re.sub(r'\*.*', '', ret) + outarg
  521. ret = ret.replace('*', '')
  522. for inp, outp in iter_variants(inarg, outp):
  523. add_variant(func_name, inarg, outarg, ret, inp, outp)
  524. # Then sort variants to input argument cast order
  525. # -- the sort is stable, so functions earlier in the signature list
  526. # are still preferred
  527. variants.sort(key=lambda v: cast_order(v[2]))
  528. return variants, inarg_num, outarg_num
  529. def generate(self, all_loops):
  530. toplevel = ""
  531. variants, inarg_num, outarg_num = self._get_signatures_and_loops(all_loops)
  532. loops = []
  533. funcs = []
  534. types = []
  535. for func_name, loop_name, inputs, outputs in variants:
  536. for x in inputs:
  537. types.append(TYPE_NAMES[x])
  538. for x in outputs:
  539. types.append(TYPE_NAMES[x])
  540. loops.append(loop_name)
  541. funcs.append(func_name)
  542. toplevel += "cdef np.PyUFuncGenericFunction ufunc_%s_loops[%d]\n" % (self.name, len(loops))
  543. toplevel += "cdef void *ufunc_%s_ptr[%d]\n" % (self.name, 2*len(funcs))
  544. toplevel += "cdef void *ufunc_%s_data[%d]\n" % (self.name, len(funcs))
  545. toplevel += "cdef char ufunc_%s_types[%d]\n" % (self.name, len(types))
  546. toplevel += 'cdef char *ufunc_%s_doc = (\n "%s")\n' % (
  547. self.name,
  548. self.doc.replace("\\", "\\\\").replace('"', '\\"').replace('\n', '\\n\"\n "')
  549. )
  550. for j, function in enumerate(loops):
  551. toplevel += "ufunc_%s_loops[%d] = <np.PyUFuncGenericFunction>%s\n" % (self.name, j, function)
  552. for j, type in enumerate(types):
  553. toplevel += "ufunc_%s_types[%d] = <char>%s\n" % (self.name, j, type)
  554. for j, func in enumerate(funcs):
  555. toplevel += "ufunc_%s_ptr[2*%d] = <void*>%s\n" % (self.name, j,
  556. self.cython_func_name(func, specialized=True))
  557. toplevel += "ufunc_%s_ptr[2*%d+1] = <void*>(<char*>\"%s\")\n" % (self.name, j,
  558. self.name)
  559. for j, func in enumerate(funcs):
  560. toplevel += "ufunc_%s_data[%d] = &ufunc_%s_ptr[2*%d]\n" % (
  561. self.name, j, self.name, j)
  562. toplevel += ('@ = np.PyUFunc_FromFuncAndData(ufunc_@_loops, '
  563. 'ufunc_@_data, ufunc_@_types, %d, %d, %d, 0, '
  564. '"@", ufunc_@_doc, 0)\n' % (len(types)/(inarg_num+outarg_num),
  565. inarg_num, outarg_num)
  566. ).replace('@', self.name)
  567. return toplevel
  568. class FusedFunc(Func):
  569. """
  570. Generate code for a fused-type special function that can be
  571. cimported in cython.
  572. """
  573. def __init__(self, name, signatures):
  574. super(FusedFunc, self).__init__(name, signatures)
  575. self.doc = "See the documentation for scipy.special." + self.name
  576. # "codes" are the keys for CY_TYPES
  577. self.incodes, self.outcodes = self._get_codes()
  578. self.fused_types = set()
  579. self.intypes, infused_types = self._get_types(self.incodes)
  580. self.fused_types.update(infused_types)
  581. self.outtypes, outfused_types = self._get_types(self.outcodes)
  582. self.fused_types.update(outfused_types)
  583. self.invars, self.outvars = self._get_vars()
  584. def _get_codes(self):
  585. inarg_num, outarg_num = None, None
  586. all_inp, all_outp = [], []
  587. for _, inarg, outarg, ret, _ in self.signatures:
  588. outp = re.sub(r'\*.*', '', ret) + outarg
  589. if inarg_num is None:
  590. inarg_num = len(inarg)
  591. outarg_num = len(outp)
  592. inp, outp = list(iter_variants(inarg, outp))[0]
  593. all_inp.append(inp)
  594. all_outp.append(outp)
  595. incodes = []
  596. for n in range(inarg_num):
  597. codes = unique(map(lambda x: x[n], all_inp))
  598. codes.sort()
  599. incodes.append(''.join(codes))
  600. outcodes = []
  601. for n in range(outarg_num):
  602. codes = unique(map(lambda x: x[n], all_outp))
  603. codes.sort()
  604. outcodes.append(''.join(codes))
  605. return tuple(incodes), tuple(outcodes)
  606. def _get_types(self, codes):
  607. all_types = []
  608. fused_types = set()
  609. for code in codes:
  610. if len(code) == 1:
  611. # It's not a fused type
  612. all_types.append((CY_TYPES[code], code))
  613. else:
  614. # It's a fused type
  615. fused_type, dec = generate_fused_type(code)
  616. fused_types.add(dec)
  617. all_types.append((fused_type, code))
  618. return all_types, fused_types
  619. def _get_vars(self):
  620. invars = []
  621. for n in range(len(self.intypes)):
  622. invars.append("x{}".format(n))
  623. outvars = []
  624. for n in range(len(self.outtypes)):
  625. outvars.append("y{}".format(n))
  626. return invars, outvars
  627. def _get_conditional(self, types, codes, adverb):
  628. """Generate an if/elif/else clause that selects a specialization of
  629. fused types.
  630. """
  631. clauses = []
  632. seen = set()
  633. for (typ, typcode), code in zip(types, codes):
  634. if len(typcode) == 1:
  635. continue
  636. if typ not in seen:
  637. clauses.append("{} is {}".format(typ, underscore(CY_TYPES[code])))
  638. seen.add(typ)
  639. if clauses and adverb != "else":
  640. line = "{} {}:".format(adverb, " and ".join(clauses))
  641. elif clauses and adverb == "else":
  642. line = "else:"
  643. else:
  644. line = None
  645. return line
  646. def _get_incallvars(self, intypes, c):
  647. """Generate pure input variables to a specialization,
  648. i.e. variables that aren't used to return a value.
  649. """
  650. incallvars = []
  651. for n, intype in enumerate(intypes):
  652. var = self.invars[n]
  653. if c and intype == "double complex":
  654. var = npy_cdouble_from_double_complex(var)
  655. incallvars.append(var)
  656. return incallvars
  657. def _get_outcallvars(self, outtypes, c):
  658. """Generate output variables to a specialization,
  659. i.e. pointers that are used to return values.
  660. """
  661. outcallvars, tmpvars, casts = [], [], []
  662. # If there are more out variables than out types, we want the
  663. # tail of the out variables
  664. start = len(self.outvars) - len(outtypes)
  665. outvars = self.outvars[start:]
  666. for n, (var, outtype) in enumerate(zip(outvars, outtypes)):
  667. if c and outtype == "double complex":
  668. tmp = "tmp{}".format(n)
  669. tmpvars.append(tmp)
  670. outcallvars.append("&{}".format(tmp))
  671. tmpcast = double_complex_from_npy_cdouble(tmp)
  672. casts.append("{}[0] = {}".format(var, tmpcast))
  673. else:
  674. outcallvars.append("{}".format(var))
  675. return outcallvars, tmpvars, casts
  676. def _get_nan_decs(self):
  677. """Set all variables to nan for specializations of fused types for
  678. which don't have signatures.
  679. """
  680. # Set non fused-type variables to nan
  681. tab = " "*4
  682. fused_types, lines = [], [tab + "else:"]
  683. seen = set()
  684. for outvar, outtype, code in zip(self.outvars, self.outtypes, self.outcodes):
  685. if len(code) == 1:
  686. line = "{}[0] = {}".format(outvar, NAN_VALUE[code])
  687. lines.append(2*tab + line)
  688. else:
  689. fused_type = outtype
  690. name, _ = fused_type
  691. if name not in seen:
  692. fused_types.append(fused_type)
  693. seen.add(name)
  694. if not fused_types:
  695. return lines
  696. # Set fused-type variables to nan
  697. all_codes = []
  698. for fused_type in fused_types:
  699. _, codes = fused_type
  700. all_codes.append(codes)
  701. all_codes = tuple(all_codes)
  702. codelens = list(map(lambda x: len(x), all_codes))
  703. last = numpy.product(codelens) - 1
  704. for m, codes in enumerate(itertools.product(*all_codes)):
  705. fused_codes, decs = [], []
  706. for n, fused_type in enumerate(fused_types):
  707. code = codes[n]
  708. fused_codes.append(underscore(CY_TYPES[code]))
  709. for nn, outvar in enumerate(self.outvars):
  710. if self.outtypes[nn] == fused_type:
  711. line = "{}[0] = {}".format(outvar, NAN_VALUE[code])
  712. decs.append(line)
  713. if m == 0:
  714. adverb = "if"
  715. elif m == last:
  716. adverb = "else"
  717. else:
  718. adverb = "elif"
  719. cond = self._get_conditional(fused_types, codes, adverb)
  720. lines.append(2*tab + cond)
  721. lines.extend(map(lambda x: 3*tab + x, decs))
  722. return lines
  723. def _get_tmp_decs(self, all_tmpvars):
  724. """Generate the declarations of any necessary temporary
  725. variables.
  726. """
  727. tab = " "*4
  728. tmpvars = list(all_tmpvars)
  729. tmpvars.sort()
  730. tmpdecs = []
  731. for tmpvar in tmpvars:
  732. line = "cdef npy_cdouble {}".format(tmpvar)
  733. tmpdecs.append(tab + line)
  734. return tmpdecs
  735. def _get_python_wrap(self):
  736. """Generate a python wrapper for functions which pass their
  737. arguments as pointers.
  738. """
  739. tab = " "*4
  740. body, callvars = [], []
  741. for (intype, _), invar in zip(self.intypes, self.invars):
  742. callvars.append("{} {}".format(intype, invar))
  743. line = "def _{}_pywrap({}):".format(self.name, ", ".join(callvars))
  744. body.append(line)
  745. for (outtype, _), outvar in zip(self.outtypes, self.outvars):
  746. line = "cdef {} {}".format(outtype, outvar)
  747. body.append(tab + line)
  748. addr_outvars = map(lambda x: "&{}".format(x), self.outvars)
  749. line = "{}({}, {})".format(self.name, ", ".join(self.invars),
  750. ", ".join(addr_outvars))
  751. body.append(tab + line)
  752. line = "return {}".format(", ".join(self.outvars))
  753. body.append(tab + line)
  754. body = "\n".join(body)
  755. return body
  756. def _get_common(self, signum, sig):
  757. """Generate code common to all the _generate_* methods."""
  758. tab = " "*4
  759. func_name, incodes, outcodes, retcode, header = sig
  760. # Convert ints to longs; cf. iter_variants()
  761. incodes = incodes.replace('i', 'l')
  762. outcodes = outcodes.replace('i', 'l')
  763. retcode = retcode.replace('i', 'l')
  764. if header.endswith("h"):
  765. c = True
  766. else:
  767. c = False
  768. if header.endswith("++"):
  769. cpp = True
  770. else:
  771. cpp = False
  772. intypes = list(map(lambda x: CY_TYPES[x], incodes))
  773. outtypes = list(map(lambda x: CY_TYPES[x], outcodes))
  774. retcode = re.sub(r'\*.*', '', retcode)
  775. if not retcode:
  776. retcode = 'v'
  777. rettype = CY_TYPES[retcode]
  778. if cpp:
  779. # Functions from _ufuncs_cxx are exported as a void*
  780. # pointers; cast them to the correct types
  781. func_name = "scipy.special._ufuncs_cxx._export_{}".format(func_name)
  782. func_name = "(<{}(*)({}) nogil>{})"\
  783. .format(rettype, ", ".join(intypes + outtypes), func_name)
  784. else:
  785. func_name = self.cython_func_name(func_name, specialized=True)
  786. if signum == 0:
  787. adverb = "if"
  788. else:
  789. adverb = "elif"
  790. cond = self._get_conditional(self.intypes, incodes, adverb)
  791. if cond:
  792. lines = [tab + cond]
  793. sp = 2*tab
  794. else:
  795. lines = []
  796. sp = tab
  797. return func_name, incodes, outcodes, retcode, \
  798. intypes, outtypes, rettype, c, lines, sp
    def _generate_from_return_and_no_outargs(self):
        """Generate the Cython source of a fused function whose kernels
        return their result and take no output pointers.

        Returns
        -------
        dec : str
            ``cpdef`` declaration line (without the trailing colon); the
            declared return type is the type name of ``self.outtypes[0]``.
        src : str
            Declaration, docstring and body joined with newlines.
        specs : list of str
            One ``"<incodes>-><retcode>"`` string per signature.
        """
        tab = " "*4
        specs, body = [], []
        for signum, sig in enumerate(self.signatures):
            # Each signature contributes an (optional) if/elif branch on the
            # fused input types, followed by a return statement.
            func_name, incodes, outcodes, retcode, intypes, outtypes, \
                rettype, c, lines, sp = self._get_common(signum, sig)
            body.extend(lines)

            # Generate the call to the specialized function
            callvars = self._get_incallvars(intypes, c)
            call = "{}({})".format(func_name, ", ".join(callvars))
            if c and rettype == "double complex":
                # Header (C) kernels: wrap a complex result with
                # double_complex_from_npy_cdouble.
                call = double_complex_from_npy_cdouble(call)
            line = sp + "return {}".format(call)
            body.append(line)
            sig = "{}->{}".format(incodes, retcode)
            specs.append(sig)

        if len(specs) > 1:
            # Return nan for signatures without a specialization
            body.append(tab + "else:")
            outtype, outcodes = self.outtypes[0]
            last = len(outcodes) - 1
            if len(outcodes) == 1:
                # Non-fused return type: a single NaN return.
                line = "return {}".format(NAN_VALUE[outcodes])
                body.append(2*tab + line)
            else:
                # Fused return type: nested if/elif/else over its codes,
                # returning the NaN value for each concrete code.
                for n, code in enumerate(outcodes):
                    if n == 0:
                        adverb = "if"
                    elif n == last:
                        adverb = "else"
                    else:
                        adverb = "elif"
                    cond = self._get_conditional(self.outtypes, code, adverb)
                    body.append(2*tab + cond)
                    line = "return {}".format(NAN_VALUE[code])
                    body.append(3*tab + line)

        # Generate the head of the function
        callvars, head = [], []
        for n, (intype, _) in enumerate(self.intypes):
            callvars.append("{} {}".format(intype, self.invars[n]))
        (outtype, _) = self.outtypes[0]
        dec = "cpdef {} {}({}) nogil".format(outtype, self.name, ", ".join(callvars))
        head.append(dec + ":")
        head.append(tab + '"""{}"""'.format(self.doc))

        src = "\n".join(head + body)
        return dec, src, specs
    def _generate_from_outargs_and_no_return(self):
        """Generate the Cython source of a fused function whose kernels
        write their results through output pointers and return void.

        With a single output variable the generated function is a ``cpdef``
        returning that output; otherwise it is a ``cdef void`` function that
        takes one pointer per output.

        Returns
        -------
        dec : str
            Declaration line (without the trailing colon).
        src : str
            Declaration, docstring, local declarations and body.
        specs : list of str
            Per signature, ``"<in>-><out>"`` for one output code, else
            ``"<in>*<outs>->v"``.
        """
        tab = " "*4
        # Names of npy_cdouble temporaries needed by any specialization.
        all_tmpvars = set()
        specs, body = [], []
        for signum, sig in enumerate(self.signatures):
            func_name, incodes, outcodes, retcode, intypes, outtypes, \
                rettype, c, lines, sp = self._get_common(signum, sig)
            body.extend(lines)

            # Generate the call to the specialized function
            callvars = self._get_incallvars(intypes, c)
            outcallvars, tmpvars, casts = self._get_outcallvars(outtypes, c)
            callvars.extend(outcallvars)
            all_tmpvars.update(tmpvars)

            call = "{}({})".format(func_name, ", ".join(callvars))
            body.append(sp + call)
            # Copy complex temporaries back into the real outputs.
            body.extend(map(lambda x: sp + x, casts))
            if len(outcodes) == 1:
                sig = "{}->{}".format(incodes, outcodes)
                specs.append(sig)
            else:
                sig = "{}*{}->v".format(incodes, outcodes)
                specs.append(sig)

        if len(specs) > 1:
            # Fallback branch assigning NaN to the outputs.
            lines = self._get_nan_decs()
            body.extend(lines)

        # NOTE(review): in the single-output case the output is declared
        # below as a plain local ("cdef <type> <var>"), yet it is indexed
        # with [0] here and passed to the kernel by name (not by address) by
        # _get_outcallvars. This looks inconsistent -- confirm whether any
        # signature actually reaches this path.
        if len(self.outvars) == 1:
            line = "return {}[0]".format(self.outvars[0])
            body.append(tab + line)

        # Generate the head of the function
        callvars, head = [], []
        for invar, (intype, _) in zip(self.invars, self.intypes):
            callvars.append("{} {}".format(intype, invar))
        if len(self.outvars) > 1:
            for outvar, (outtype, _) in zip(self.outvars, self.outtypes):
                callvars.append("{} *{}".format(outtype, outvar))
        if len(self.outvars) == 1:
            outtype, _ = self.outtypes[0]
            dec = "cpdef {} {}({}) nogil".format(outtype, self.name, ", ".join(callvars))
        else:
            dec = "cdef void {}({}) nogil".format(self.name, ", ".join(callvars))
        head.append(dec + ":")
        head.append(tab + '"""{}"""'.format(self.doc))
        if len(self.outvars) == 1:
            outvar = self.outvars[0]
            outtype, _ = self.outtypes[0]
            line = "cdef {} {}".format(outtype, outvar)
            head.append(tab + line)
        head.extend(self._get_tmp_decs(all_tmpvars))

        src = "\n".join(head + body)
        return dec, src, specs
    def _generate_from_outargs_and_return(self):
        """Generate the Cython source of a fused function whose kernels both
        return a value and write through output pointers.

        The generated ``cdef void`` function takes one pointer per output
        variable. The kernel's return value is stored in
        ``self.outvars[0]``, and the kernel's output pointers are bound to
        the trailing output variables (see ``_get_outcallvars``).

        Returns
        -------
        dec : str
            Declaration line (without the trailing colon).
        src : str
            Declaration, docstring, temporary declarations and body.
        specs : list of str
            One ``"<in>*<outs><ret>->v"`` string per signature.
        """
        tab = " "*4
        # Names of npy_cdouble temporaries needed by any specialization.
        all_tmpvars = set()
        specs, body = [], []
        for signum, sig in enumerate(self.signatures):
            func_name, incodes, outcodes, retcode, intypes, outtypes, \
                rettype, c, lines, sp = self._get_common(signum, sig)
            body.extend(lines)

            # Generate the call to the specialized function
            callvars = self._get_incallvars(intypes, c)
            outcallvars, tmpvars, casts = self._get_outcallvars(outtypes, c)
            callvars.extend(outcallvars)
            all_tmpvars.update(tmpvars)
            call = "{}({})".format(func_name, ", ".join(callvars))
            if c and rettype == "double complex":
                # Header (C) kernels: wrap a complex result with
                # double_complex_from_npy_cdouble.
                call = double_complex_from_npy_cdouble(call)
            # Store the kernel's return value in the first output.
            call = "{}[0] = {}".format(self.outvars[0], call)
            body.append(sp + call)
            # Copy complex temporaries back into the real outputs.
            body.extend(map(lambda x: sp + x, casts))
            sig = "{}*{}->v".format(incodes, outcodes + retcode)
            specs.append(sig)

        if len(specs) > 1:
            # Fallback branch assigning NaN to the outputs.
            lines = self._get_nan_decs()
            body.extend(lines)

        # Generate the head of the function
        callvars, head = [], []
        for invar, (intype, _) in zip(self.invars, self.intypes):
            callvars.append("{} {}".format(intype, invar))
        for outvar, (outtype, _) in zip(self.outvars, self.outtypes):
            callvars.append("{} *{}".format(outtype, outvar))
        dec = "cdef void {}({}) nogil".format(self.name, ", ".join(callvars))
        head.append(dec + ":")
        head.append(tab + '"""{}"""'.format(self.doc))
        head.extend(self._get_tmp_decs(all_tmpvars))

        src = "\n".join(head + body)
        return dec, src, specs
  931. def generate(self):
  932. _, _, outcodes, retcode, _ = self.signatures[0]
  933. retcode = re.sub(r'\*.*', '', retcode)
  934. if not retcode:
  935. retcode = 'v'
  936. if len(outcodes) == 0 and retcode != 'v':
  937. dec, src, specs = self._generate_from_return_and_no_outargs()
  938. elif len(outcodes) > 0 and retcode == 'v':
  939. dec, src, specs = self._generate_from_outargs_and_no_return()
  940. elif len(outcodes) > 0 and retcode != 'v':
  941. dec, src, specs = self._generate_from_outargs_and_return()
  942. else:
  943. raise ValueError("Invalid signature")
  944. if len(self.outvars) > 1:
  945. wrap = self._get_python_wrap()
  946. else:
  947. wrap = None
  948. return dec, src, specs, self.fused_types, wrap
  949. def get_declaration(ufunc, c_name, c_proto, cy_proto, header, proto_h_filename):
  950. """
  951. Construct a Cython declaration of a function coming either from a
  952. pxd or a header file. Do sufficient tricks to enable compile-time
  953. type checking against the signature expected by the ufunc.
  954. """
  955. defs = []
  956. defs_h = []
  957. var_name = c_name.replace('[', '_').replace(']', '_').replace(' ', '_')
  958. if header.endswith('.pxd'):
  959. defs.append("from .%s cimport %s as %s" % (
  960. header[:-4], ufunc.cython_func_name(c_name, prefix=""),
  961. ufunc.cython_func_name(c_name)))
  962. # check function signature at compile time
  963. proto_name = '_proto_%s_t' % var_name
  964. defs.append("ctypedef %s" % (cy_proto.replace('(*)', proto_name)))
  965. defs.append("cdef %s *%s_var = &%s" % (
  966. proto_name, proto_name, ufunc.cython_func_name(c_name, specialized=True)))
  967. else:
  968. # redeclare the function, so that the assumed
  969. # signature is checked at compile time
  970. new_name = "%s \"%s\"" % (ufunc.cython_func_name(c_name), c_name)
  971. defs.append("cdef extern from \"%s\":" % proto_h_filename)
  972. defs.append(" cdef %s" % (cy_proto.replace('(*)', new_name)))
  973. defs_h.append("#include \"%s\"" % header)
  974. defs_h.append("%s;" % (c_proto.replace('(*)', c_name)))
  975. return defs, defs_h, var_name
def generate_ufuncs(fn_prefix, cxx_fn_prefix, ufuncs):
    """Write the ufunc module sources and their companion C++ module.

    Writes, relative to the current directory:

    * ``<fn_prefix>.pyx`` and ``<fn_prefix>_defs.h``;
    * ``<cxx_fn_prefix>.pyx``, ``<cxx_fn_prefix>.pxd`` and
      ``<cxx_fn_prefix>_defs.h``, for kernels whose header ends in ``++``.

    Note that *ufuncs* is sorted in place by name.
    """
    filename = fn_prefix + ".pyx"
    proto_h_filename = fn_prefix + '_defs.h'

    cxx_proto_h_filename = cxx_fn_prefix + '_defs.h'
    cxx_pyx_filename = cxx_fn_prefix + ".pyx"
    cxx_pxd_filename = cxx_fn_prefix + ".pxd"

    toplevel = ""

    # for _ufuncs*
    defs = []
    defs_h = []
    all_loops = {}

    # for _ufuncs_cxx*
    cxx_defs = []
    cxx_pxd_defs = [
        "from . cimport sf_error",
        "cdef void _set_action(sf_error.sf_error_t, sf_error.sf_action_t) nogil"
    ]
    cxx_defs_h = []

    ufuncs.sort(key=lambda u: u.name)

    for ufunc in ufuncs:
        # generate function declaration and type checking snippets
        cfuncs = ufunc.get_prototypes()
        for c_name, c_proto, cy_proto, header in cfuncs:
            if header.endswith('++'):
                header = header[:-2]

                # for the CXX module
                item_defs, item_defs_h, var_name = get_declaration(ufunc, c_name, c_proto, cy_proto,
                                                                   header, cxx_proto_h_filename)
                cxx_defs.extend(item_defs)
                cxx_defs_h.extend(item_defs_h)

                # Export the kernel from the C++ module as a void* ...
                cxx_defs.append("cdef void *_export_%s = <void*>%s" % (
                    var_name, ufunc.cython_func_name(c_name, specialized=True, override=False)))
                cxx_pxd_defs.append("cdef void *_export_%s" % (var_name,))

                # let cython grab the function pointer from the c++ shared library
                # (this override must be registered before ufunc.generate()
                # below is called for this ufunc)
                ufunc.function_name_overrides[c_name] = "scipy.special._ufuncs_cxx._export_" + var_name
            else:
                # usual case
                item_defs, item_defs_h, _ = get_declaration(ufunc, c_name, c_proto, cy_proto, header,
                                                            proto_h_filename)
                defs.extend(item_defs)
                defs_h.extend(item_defs_h)

        # ufunc creation code snippet
        # (all_loops is shared across ufuncs; its values are emitted sorted below)
        t = ufunc.generate(all_loops)
        toplevel += t + "\n"

    # Produce output
    toplevel = "\n".join(sorted(all_loops.values()) + defs + [toplevel])

    with open(filename, 'w') as f:
        f.write(UFUNCS_EXTRA_CODE_COMMON)
        f.write(UFUNCS_EXTRA_CODE)
        f.write("\n")
        f.write(toplevel)
        f.write(UFUNCS_EXTRA_CODE_BOTTOM)

    # Header lines are de-duplicated, keeping first occurrences in order.
    defs_h = unique(defs_h)
    with open(proto_h_filename, 'w') as f:
        f.write("#ifndef UFUNCS_PROTO_H\n#define UFUNCS_PROTO_H 1\n")
        f.write("\n".join(defs_h))
        f.write("\n#endif\n")

    # NOTE(review): this header uses the same UFUNCS_PROTO_H include guard
    # as the one above, so a translation unit including both would silently
    # skip the second. Confirm they are never included together.
    cxx_defs_h = unique(cxx_defs_h)
    with open(cxx_proto_h_filename, 'w') as f:
        f.write("#ifndef UFUNCS_PROTO_H\n#define UFUNCS_PROTO_H 1\n")
        f.write("\n".join(cxx_defs_h))
        f.write("\n#endif\n")

    with open(cxx_pyx_filename, 'w') as f:
        f.write(UFUNCS_EXTRA_CODE_COMMON)
        f.write("\n")
        f.write("\n".join(cxx_defs))
        f.write("\n# distutils: language = c++\n")

    with open(cxx_pxd_filename, 'w') as f:
        f.write("\n".join(cxx_pxd_defs))
def generate_fused_funcs(modname, ufunc_fn_prefix, fused_funcs):
    """Write ``<modname>.pxd`` and ``<modname>.pyx`` for the fused-type
    functions.

    Functions whose name starts with ``_`` are skipped. Kernels from
    ``++`` headers are not redeclared here. Other kernels are declared
    against ``<ufunc_fn_prefix>_defs.h``.
    """
    pxdfile = modname + ".pxd"
    pyxfile = modname + ".pyx"
    proto_h_filename = ufunc_fn_prefix + '_defs.h'

    sources = []
    declarations = []
    # Code for benchmarks
    bench_aux = []
    fused_types = set()
    # Per-function documentation lines (substituted for FUNCLIST in the
    # .pyx header below)
    doc = []
    defs = []

    for func in fused_funcs:
        if func.name.startswith("_"):
            # Don't try to deal with functions that have extra layers
            # of wrappers.
            continue

        # Get the function declaration for the .pxd and the source
        # code for the .pyx
        dec, src, specs, func_fused_types, wrap = func.generate()
        declarations.append(dec)
        sources.append(src)
        if wrap:
            sources.append(wrap)
        fused_types.update(func_fused_types)

        # Declare the specializations
        cfuncs = func.get_prototypes(nptypes_for_h=True)
        for c_name, c_proto, cy_proto, header in cfuncs:
            if header.endswith('++'):
                # We grab the c++ functions from the c++ module
                continue
            item_defs, _, _ = get_declaration(func, c_name, c_proto,
                                              cy_proto, header,
                                              proto_h_filename)
            defs.extend(item_defs)

        # Add a line to the documentation
        doc.append(generate_doc(func.name, specs))

        # Generate code for benchmarks
        if func.name in CYTHON_SPECIAL_BENCHFUNCS:
            for codes in CYTHON_SPECIAL_BENCHFUNCS[func.name]:
                pybench, cybench = generate_bench(func.name, codes)
                bench_aux.extend([pybench, cybench])

    # Sorted so the generated .pxd does not depend on set iteration order.
    fused_types = list(fused_types)
    fused_types.sort()

    with open(pxdfile, 'w') as f:
        f.write(CYTHON_SPECIAL_PXD)
        f.write("\n")
        f.write("\n\n".join(fused_types))
        f.write("\n\n")
        f.write("\n".join(declarations))
    with open(pyxfile, 'w') as f:
        header = CYTHON_SPECIAL_PYX
        header = header.replace("FUNCLIST", "\n".join(doc))
        f.write(header)
        f.write("\n")
        f.write("\n".join(defs))
        f.write("\n\n")
        f.write("\n\n".join(sources))
        f.write("\n\n")
        f.write("\n\n".join(bench_aux))
  1105. def unique(lst):
  1106. """
  1107. Return a list without repeated entries (first occurrence is kept),
  1108. preserving order.
  1109. """
  1110. seen = set()
  1111. new_lst = []
  1112. for item in lst:
  1113. if item in seen:
  1114. continue
  1115. seen.add(item)
  1116. new_lst.append(item)
  1117. return new_lst
  1118. def all_newer(src_files, dst_files):
  1119. from distutils.dep_util import newer
  1120. return all(os.path.exists(dst) and newer(dst, src)
  1121. for dst in dst_files for src in src_files)
  1122. def main():
  1123. p = optparse.OptionParser(usage=(__doc__ or '').strip())
  1124. options, args = p.parse_args()
  1125. if len(args) != 0:
  1126. p.error('invalid number of arguments')
  1127. pwd = os.path.dirname(__file__)
  1128. src_files = (os.path.abspath(__file__),
  1129. os.path.abspath(os.path.join(pwd, 'functions.json')),
  1130. os.path.abspath(os.path.join(pwd, 'add_newdocs.py')))
  1131. dst_files = ('_ufuncs.pyx',
  1132. '_ufuncs_defs.h',
  1133. '_ufuncs_cxx.pyx',
  1134. '_ufuncs_cxx.pxd',
  1135. '_ufuncs_cxx_defs.h',
  1136. 'cython_special.pyx',
  1137. 'cython_special.pxd')
  1138. os.chdir(BASE_DIR)
  1139. if all_newer(src_files, dst_files):
  1140. print("scipy/special/_generate_pyx.py: all files up-to-date")
  1141. return
  1142. ufuncs, fused_funcs = [], []
  1143. with open('functions.json') as data:
  1144. functions = json.load(data)
  1145. for f, sig in functions.items():
  1146. ufuncs.append(Ufunc(f, sig))
  1147. fused_funcs.append(FusedFunc(f, sig))
  1148. generate_ufuncs("_ufuncs", "_ufuncs_cxx", ufuncs)
  1149. generate_fused_funcs("cython_special", "_ufuncs", fused_funcs)
# Regenerate the sources when executed as a script (see main()).
if __name__ == "__main__":
    main()