_mptestutils.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455
  1. from __future__ import division, print_function, absolute_import
  2. import os
  3. import sys
  4. import time
  5. import numpy as np
  6. from numpy.testing import assert_
  7. import pytest
  8. from scipy._lib.six import reraise
  9. from scipy.special._testutils import assert_func_equal
  10. try:
  11. import mpmath
  12. except ImportError:
  13. pass
  14. # ------------------------------------------------------------------------------
  15. # Machinery for systematic tests with mpmath
  16. # ------------------------------------------------------------------------------
  17. class Arg(object):
  18. """Generate a set of numbers on the real axis, concentrating on
  19. 'interesting' regions and covering all orders of magnitude.
  20. """
  21. def __init__(self, a=-np.inf, b=np.inf, inclusive_a=True, inclusive_b=True):
  22. if a > b:
  23. raise ValueError("a should be less than or equal to b")
  24. if a == -np.inf:
  25. a = -0.5*np.finfo(float).max
  26. if b == np.inf:
  27. b = 0.5*np.finfo(float).max
  28. self.a, self.b = a, b
  29. self.inclusive_a, self.inclusive_b = inclusive_a, inclusive_b
  30. def _positive_values(self, a, b, n):
  31. if a < 0:
  32. raise ValueError("a should be positive")
  33. # Try to put half of the points into a linspace between a and
  34. # 10 the other half in a logspace.
  35. if n % 2 == 0:
  36. nlogpts = n//2
  37. nlinpts = nlogpts
  38. else:
  39. nlogpts = n//2
  40. nlinpts = nlogpts + 1
  41. if a >= 10:
  42. # Outside of linspace range; just return a logspace.
  43. pts = np.logspace(np.log10(a), np.log10(b), n)
  44. elif a > 0 and b < 10:
  45. # Outside of logspace range; just return a linspace
  46. pts = np.linspace(a, b, n)
  47. elif a > 0:
  48. # Linspace between a and 10 and a logspace between 10 and
  49. # b.
  50. linpts = np.linspace(a, 10, nlinpts, endpoint=False)
  51. logpts = np.logspace(1, np.log10(b), nlogpts)
  52. pts = np.hstack((linpts, logpts))
  53. elif a == 0 and b <= 10:
  54. # Linspace between 0 and b and a logspace between 0 and
  55. # the smallest positive point of the linspace
  56. linpts = np.linspace(0, b, nlinpts)
  57. if linpts.size > 1:
  58. right = np.log10(linpts[1])
  59. else:
  60. right = -30
  61. logpts = np.logspace(-30, right, nlogpts, endpoint=False)
  62. pts = np.hstack((logpts, linpts))
  63. else:
  64. # Linspace between 0 and 10, logspace between 0 and the
  65. # smallest positive point of the linspace, and a logspace
  66. # between 10 and b.
  67. if nlogpts % 2 == 0:
  68. nlogpts1 = nlogpts//2
  69. nlogpts2 = nlogpts1
  70. else:
  71. nlogpts1 = nlogpts//2
  72. nlogpts2 = nlogpts1 + 1
  73. linpts = np.linspace(0, 10, nlinpts, endpoint=False)
  74. if linpts.size > 1:
  75. right = np.log10(linpts[1])
  76. else:
  77. right = -30
  78. logpts1 = np.logspace(-30, right, nlogpts1, endpoint=False)
  79. logpts2 = np.logspace(1, np.log10(b), nlogpts2)
  80. pts = np.hstack((logpts1, linpts, logpts2))
  81. return np.sort(pts)
  82. def values(self, n):
  83. """Return an array containing n numbers."""
  84. a, b = self.a, self.b
  85. if a == b:
  86. return np.zeros(n)
  87. if not self.inclusive_a:
  88. n += 1
  89. if not self.inclusive_b:
  90. n += 1
  91. if n % 2 == 0:
  92. n1 = n//2
  93. n2 = n1
  94. else:
  95. n1 = n//2
  96. n2 = n1 + 1
  97. if a >= 0:
  98. pospts = self._positive_values(a, b, n)
  99. negpts = []
  100. elif b <= 0:
  101. pospts = []
  102. negpts = -self._positive_values(-b, -a, n)
  103. else:
  104. pospts = self._positive_values(0, b, n1)
  105. negpts = -self._positive_values(0, -a, n2 + 1)
  106. # Don't want to get zero twice
  107. negpts = negpts[1:]
  108. pts = np.hstack((negpts[::-1], pospts))
  109. if not self.inclusive_a:
  110. pts = pts[1:]
  111. if not self.inclusive_b:
  112. pts = pts[:-1]
  113. return pts
  114. class FixedArg(object):
  115. def __init__(self, values):
  116. self._values = np.asarray(values)
  117. def values(self, n):
  118. return self._values
  119. class ComplexArg(object):
  120. def __init__(self, a=complex(-np.inf, -np.inf), b=complex(np.inf, np.inf)):
  121. self.real = Arg(a.real, b.real)
  122. self.imag = Arg(a.imag, b.imag)
  123. def values(self, n):
  124. m = int(np.floor(np.sqrt(n)))
  125. x = self.real.values(m)
  126. y = self.imag.values(m + 1)
  127. return (x[:,None] + 1j*y[None,:]).ravel()
  128. class IntArg(object):
  129. def __init__(self, a=-1000, b=1000):
  130. self.a = a
  131. self.b = b
  132. def values(self, n):
  133. v1 = Arg(self.a, self.b).values(max(1 + n//2, n-5)).astype(int)
  134. v2 = np.arange(-5, 5)
  135. v = np.unique(np.r_[v1, v2])
  136. v = v[(v >= self.a) & (v < self.b)]
  137. return v
  138. def get_args(argspec, n):
  139. if isinstance(argspec, np.ndarray):
  140. args = argspec.copy()
  141. else:
  142. nargs = len(argspec)
  143. ms = np.asarray([1.5 if isinstance(spec, ComplexArg) else 1.0 for spec in argspec])
  144. ms = (n**(ms/sum(ms))).astype(int) + 1
  145. args = []
  146. for spec, m in zip(argspec, ms):
  147. args.append(spec.values(m))
  148. args = np.array(np.broadcast_arrays(*np.ix_(*args))).reshape(nargs, -1).T
  149. return args
  150. class MpmathData(object):
  151. def __init__(self, scipy_func, mpmath_func, arg_spec, name=None,
  152. dps=None, prec=None, n=None, rtol=1e-7, atol=1e-300,
  153. ignore_inf_sign=False, distinguish_nan_and_inf=True,
  154. nan_ok=True, param_filter=None):
  155. # mpmath tests are really slow (see gh-6989). Use a small number of
  156. # points by default, increase back to 5000 (old default) if XSLOW is
  157. # set
  158. if n is None:
  159. try:
  160. is_xslow = int(os.environ.get('SCIPY_XSLOW', '0'))
  161. except ValueError:
  162. is_xslow = False
  163. n = 5000 if is_xslow else 500
  164. self.scipy_func = scipy_func
  165. self.mpmath_func = mpmath_func
  166. self.arg_spec = arg_spec
  167. self.dps = dps
  168. self.prec = prec
  169. self.n = n
  170. self.rtol = rtol
  171. self.atol = atol
  172. self.ignore_inf_sign = ignore_inf_sign
  173. self.nan_ok = nan_ok
  174. if isinstance(self.arg_spec, np.ndarray):
  175. self.is_complex = np.issubdtype(self.arg_spec.dtype, np.complexfloating)
  176. else:
  177. self.is_complex = any([isinstance(arg, ComplexArg) for arg in self.arg_spec])
  178. self.ignore_inf_sign = ignore_inf_sign
  179. self.distinguish_nan_and_inf = distinguish_nan_and_inf
  180. if not name or name == '<lambda>':
  181. name = getattr(scipy_func, '__name__', None)
  182. if not name or name == '<lambda>':
  183. name = getattr(mpmath_func, '__name__', None)
  184. self.name = name
  185. self.param_filter = param_filter
  186. def check(self):
  187. np.random.seed(1234)
  188. # Generate values for the arguments
  189. argarr = get_args(self.arg_spec, self.n)
  190. # Check
  191. old_dps, old_prec = mpmath.mp.dps, mpmath.mp.prec
  192. try:
  193. if self.dps is not None:
  194. dps_list = [self.dps]
  195. else:
  196. dps_list = [20]
  197. if self.prec is not None:
  198. mpmath.mp.prec = self.prec
  199. # Proper casting of mpmath input and output types. Using
  200. # native mpmath types as inputs gives improved precision
  201. # in some cases.
  202. if np.issubdtype(argarr.dtype, np.complexfloating):
  203. pytype = mpc2complex
  204. def mptype(x):
  205. return mpmath.mpc(complex(x))
  206. else:
  207. def mptype(x):
  208. return mpmath.mpf(float(x))
  209. def pytype(x):
  210. if abs(x.imag) > 1e-16*(1 + abs(x.real)):
  211. return np.nan
  212. else:
  213. return mpf2float(x.real)
  214. # Try out different dps until one (or none) works
  215. for j, dps in enumerate(dps_list):
  216. mpmath.mp.dps = dps
  217. try:
  218. assert_func_equal(self.scipy_func,
  219. lambda *a: pytype(self.mpmath_func(*map(mptype, a))),
  220. argarr,
  221. vectorized=False,
  222. rtol=self.rtol, atol=self.atol,
  223. ignore_inf_sign=self.ignore_inf_sign,
  224. distinguish_nan_and_inf=self.distinguish_nan_and_inf,
  225. nan_ok=self.nan_ok,
  226. param_filter=self.param_filter)
  227. break
  228. except AssertionError:
  229. if j >= len(dps_list)-1:
  230. reraise(*sys.exc_info())
  231. finally:
  232. mpmath.mp.dps, mpmath.mp.prec = old_dps, old_prec
  233. def __repr__(self):
  234. if self.is_complex:
  235. return "<MpmathData: %s (complex)>" % (self.name,)
  236. else:
  237. return "<MpmathData: %s>" % (self.name,)
  238. def assert_mpmath_equal(*a, **kw):
  239. d = MpmathData(*a, **kw)
  240. d.check()
  241. def nonfunctional_tooslow(func):
  242. return pytest.mark.skip(reason=" Test not yet functional (too slow), needs more work.")(func)
  243. # ------------------------------------------------------------------------------
  244. # Tools for dealing with mpmath quirks
  245. # ------------------------------------------------------------------------------
  246. def mpf2float(x):
  247. """
  248. Convert an mpf to the nearest floating point number. Just using
  249. float directly doesn't work because of results like this:
  250. with mp.workdps(50):
  251. float(mpf("0.99999999999999999")) = 0.9999999999999999
  252. """
  253. return float(mpmath.nstr(x, 17, min_fixed=0, max_fixed=0))
  254. def mpc2complex(x):
  255. return complex(mpf2float(x.real), mpf2float(x.imag))
  256. def trace_args(func):
  257. def tofloat(x):
  258. if isinstance(x, mpmath.mpc):
  259. return complex(x)
  260. else:
  261. return float(x)
  262. def wrap(*a, **kw):
  263. sys.stderr.write("%r: " % (tuple(map(tofloat, a)),))
  264. sys.stderr.flush()
  265. try:
  266. r = func(*a, **kw)
  267. sys.stderr.write("-> %r" % r)
  268. finally:
  269. sys.stderr.write("\n")
  270. sys.stderr.flush()
  271. return r
  272. return wrap
  273. try:
  274. import posix
  275. import signal
  276. POSIX = ('setitimer' in dir(signal))
  277. except ImportError:
  278. POSIX = False
  279. class TimeoutError(Exception):
  280. pass
  281. def time_limited(timeout=0.5, return_val=np.nan, use_sigalrm=True):
  282. """
  283. Decorator for setting a timeout for pure-Python functions.
  284. If the function does not return within `timeout` seconds, the
  285. value `return_val` is returned instead.
  286. On POSIX this uses SIGALRM by default. On non-POSIX, settrace is
  287. used. Do not use this with threads: the SIGALRM implementation
  288. does probably not work well. The settrace implementation only
  289. traces the current thread.
  290. The settrace implementation slows down execution speed. Slowdown
  291. by a factor around 10 is probably typical.
  292. """
  293. if POSIX and use_sigalrm:
  294. def sigalrm_handler(signum, frame):
  295. raise TimeoutError()
  296. def deco(func):
  297. def wrap(*a, **kw):
  298. old_handler = signal.signal(signal.SIGALRM, sigalrm_handler)
  299. signal.setitimer(signal.ITIMER_REAL, timeout)
  300. try:
  301. return func(*a, **kw)
  302. except TimeoutError:
  303. return return_val
  304. finally:
  305. signal.setitimer(signal.ITIMER_REAL, 0)
  306. signal.signal(signal.SIGALRM, old_handler)
  307. return wrap
  308. else:
  309. def deco(func):
  310. def wrap(*a, **kw):
  311. start_time = time.time()
  312. def trace(frame, event, arg):
  313. if time.time() - start_time > timeout:
  314. raise TimeoutError()
  315. return trace
  316. sys.settrace(trace)
  317. try:
  318. return func(*a, **kw)
  319. except TimeoutError:
  320. sys.settrace(None)
  321. return return_val
  322. finally:
  323. sys.settrace(None)
  324. return wrap
  325. return deco
  326. def exception_to_nan(func):
  327. """Decorate function to return nan if it raises an exception"""
  328. def wrap(*a, **kw):
  329. try:
  330. return func(*a, **kw)
  331. except Exception:
  332. return np.nan
  333. return wrap
  334. def inf_to_nan(func):
  335. """Decorate function to return nan if it returns inf"""
  336. def wrap(*a, **kw):
  337. v = func(*a, **kw)
  338. if not np.isfinite(v):
  339. return np.nan
  340. return v
  341. return wrap
  342. def mp_assert_allclose(res, std, atol=0, rtol=1e-17):
  343. """
  344. Compare lists of mpmath.mpf's or mpmath.mpc's directly so that it
  345. can be done to higher precision than double.
  346. """
  347. try:
  348. len(res)
  349. except TypeError:
  350. res = list(res)
  351. n = len(std)
  352. if len(res) != n:
  353. raise AssertionError("Lengths of inputs not equal.")
  354. failures = []
  355. for k in range(n):
  356. try:
  357. assert_(mpmath.fabs(res[k] - std[k]) <= atol + rtol*mpmath.fabs(std[k]))
  358. except AssertionError:
  359. failures.append(k)
  360. ndigits = int(abs(np.log10(rtol)))
  361. msg = [""]
  362. msg.append("Bad results ({} out of {}) for the following points:"
  363. .format(len(failures), n))
  364. for k in failures:
  365. resrep = mpmath.nstr(res[k], ndigits, min_fixed=0, max_fixed=0)
  366. stdrep = mpmath.nstr(std[k], ndigits, min_fixed=0, max_fixed=0)
  367. if std[k] == 0:
  368. rdiff = "inf"
  369. else:
  370. rdiff = mpmath.fabs((res[k] - std[k])/std[k])
  371. rdiff = mpmath.nstr(rdiff, 3)
  372. msg.append("{}: {} != {} (rdiff {})".format(k, resrep, stdrep, rdiff))
  373. if failures:
  374. assert_(False, "\n".join(msg))