test_cdflib.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409
  1. """
  2. Test cdflib functions versus mpmath, if available.
  3. The following functions still need tests:
  4. - ncfdtr
  5. - ncfdtri
  6. - ncfdtridfn
  7. - ncfdtridfd
  8. - ncfdtrinc
  9. - nbdtrik
  10. - nbdtrin
  11. - nrdtrimn
  12. - nrdtrisd
  13. - pdtrik
  14. - nctdtr
  15. - nctdtrit
  16. - nctdtridf
  17. - nctdtrinc
  18. """
  19. from __future__ import division, print_function, absolute_import
  20. import itertools
  21. import numpy as np
  22. from numpy.testing import assert_equal
  23. import pytest
  24. import scipy.special as sp
  25. from scipy._lib.six import with_metaclass
  26. from scipy.special._testutils import (
  27. MissingModule, check_version, FuncData)
  28. from scipy.special._mptestutils import (
  29. Arg, IntArg, get_args, mpf2float, assert_mpmath_equal)
  30. try:
  31. import mpmath
  32. except ImportError:
  33. mpmath = MissingModule('mpmath')
  34. class ProbArg(object):
  35. """Generate a set of probabilities on [0, 1]."""
  36. def __init__(self):
  37. # Include the endpoints for compatibility with Arg et. al.
  38. self.a = 0
  39. self.b = 1
  40. def values(self, n):
  41. """Return an array containing approximatively n numbers."""
  42. m = max(1, n//3)
  43. v1 = np.logspace(-30, np.log10(0.3), m)
  44. v2 = np.linspace(0.3, 0.7, m + 1, endpoint=False)[1:]
  45. v3 = 1 - np.logspace(np.log10(0.3), -15, m)
  46. v = np.r_[v1, v2, v3]
  47. return np.unique(v)
  48. class EndpointFilter(object):
  49. def __init__(self, a, b, rtol, atol):
  50. self.a = a
  51. self.b = b
  52. self.rtol = rtol
  53. self.atol = atol
  54. def __call__(self, x):
  55. mask1 = np.abs(x - self.a) < self.rtol*np.abs(self.a) + self.atol
  56. mask2 = np.abs(x - self.b) < self.rtol*np.abs(self.b) + self.atol
  57. return np.where(mask1 | mask2, False, True)
  58. class _CDFData(object):
  59. def __init__(self, spfunc, mpfunc, index, argspec, spfunc_first=True,
  60. dps=20, n=5000, rtol=None, atol=None,
  61. endpt_rtol=None, endpt_atol=None):
  62. self.spfunc = spfunc
  63. self.mpfunc = mpfunc
  64. self.index = index
  65. self.argspec = argspec
  66. self.spfunc_first = spfunc_first
  67. self.dps = dps
  68. self.n = n
  69. self.rtol = rtol
  70. self.atol = atol
  71. if not isinstance(argspec, list):
  72. self.endpt_rtol = None
  73. self.endpt_atol = None
  74. elif endpt_rtol is not None or endpt_atol is not None:
  75. if isinstance(endpt_rtol, list):
  76. self.endpt_rtol = endpt_rtol
  77. else:
  78. self.endpt_rtol = [endpt_rtol]*len(self.argspec)
  79. if isinstance(endpt_atol, list):
  80. self.endpt_atol = endpt_atol
  81. else:
  82. self.endpt_atol = [endpt_atol]*len(self.argspec)
  83. else:
  84. self.endpt_rtol = None
  85. self.endpt_atol = None
  86. def idmap(self, *args):
  87. if self.spfunc_first:
  88. res = self.spfunc(*args)
  89. if np.isnan(res):
  90. return np.nan
  91. args = list(args)
  92. args[self.index] = res
  93. with mpmath.workdps(self.dps):
  94. res = self.mpfunc(*tuple(args))
  95. # Imaginary parts are spurious
  96. res = mpf2float(res.real)
  97. else:
  98. with mpmath.workdps(self.dps):
  99. res = self.mpfunc(*args)
  100. res = mpf2float(res.real)
  101. args = list(args)
  102. args[self.index] = res
  103. res = self.spfunc(*tuple(args))
  104. return res
  105. def get_param_filter(self):
  106. if self.endpt_rtol is None and self.endpt_atol is None:
  107. return None
  108. filters = []
  109. for rtol, atol, spec in zip(self.endpt_rtol, self.endpt_atol, self.argspec):
  110. if rtol is None and atol is None:
  111. filters.append(None)
  112. continue
  113. elif rtol is None:
  114. rtol = 0.0
  115. elif atol is None:
  116. atol = 0.0
  117. filters.append(EndpointFilter(spec.a, spec.b, rtol, atol))
  118. return filters
  119. def check(self):
  120. # Generate values for the arguments
  121. args = get_args(self.argspec, self.n)
  122. param_filter = self.get_param_filter()
  123. param_columns = tuple(range(args.shape[1]))
  124. result_columns = args.shape[1]
  125. args = np.hstack((args, args[:,self.index].reshape(args.shape[0], 1)))
  126. FuncData(self.idmap, args,
  127. param_columns=param_columns, result_columns=result_columns,
  128. rtol=self.rtol, atol=self.atol, vectorized=False,
  129. param_filter=param_filter).check()
  130. def _assert_inverts(*a, **kw):
  131. d = _CDFData(*a, **kw)
  132. d.check()
  133. def _binomial_cdf(k, n, p):
  134. k, n, p = mpmath.mpf(k), mpmath.mpf(n), mpmath.mpf(p)
  135. if k <= 0:
  136. return mpmath.mpf(0)
  137. elif k >= n:
  138. return mpmath.mpf(1)
  139. onemp = mpmath.fsub(1, p, exact=True)
  140. return mpmath.betainc(n - k, k + 1, x2=onemp, regularized=True)
  141. def _f_cdf(dfn, dfd, x):
  142. if x < 0:
  143. return mpmath.mpf(0)
  144. dfn, dfd, x = mpmath.mpf(dfn), mpmath.mpf(dfd), mpmath.mpf(x)
  145. ub = dfn*x/(dfn*x + dfd)
  146. res = mpmath.betainc(dfn/2, dfd/2, x2=ub, regularized=True)
  147. return res
  148. def _student_t_cdf(df, t, dps=None):
  149. if dps is None:
  150. dps = mpmath.mp.dps
  151. with mpmath.workdps(dps):
  152. df, t = mpmath.mpf(df), mpmath.mpf(t)
  153. fac = mpmath.hyp2f1(0.5, 0.5*(df + 1), 1.5, -t**2/df)
  154. fac *= t*mpmath.gamma(0.5*(df + 1))
  155. fac /= mpmath.sqrt(mpmath.pi*df)*mpmath.gamma(0.5*df)
  156. return 0.5 + fac
  157. def _noncentral_chi_pdf(t, df, nc):
  158. res = mpmath.besseli(df/2 - 1, mpmath.sqrt(nc*t))
  159. res *= mpmath.exp(-(t + nc)/2)*(t/nc)**(df/4 - 1/2)/2
  160. return res
  161. def _noncentral_chi_cdf(x, df, nc, dps=None):
  162. if dps is None:
  163. dps = mpmath.mp.dps
  164. x, df, nc = mpmath.mpf(x), mpmath.mpf(df), mpmath.mpf(nc)
  165. with mpmath.workdps(dps):
  166. res = mpmath.quad(lambda t: _noncentral_chi_pdf(t, df, nc), [0, x])
  167. return res
  168. def _tukey_lmbda_quantile(p, lmbda):
  169. # For lmbda != 0
  170. return (p**lmbda - (1 - p)**lmbda)/lmbda
  171. @pytest.mark.slow
  172. @check_version(mpmath, '0.19')
  173. class TestCDFlib(object):
  174. @pytest.mark.xfail(run=False)
  175. def test_bdtrik(self):
  176. _assert_inverts(
  177. sp.bdtrik,
  178. _binomial_cdf,
  179. 0, [ProbArg(), IntArg(1, 1000), ProbArg()],
  180. rtol=1e-4)
  181. def test_bdtrin(self):
  182. _assert_inverts(
  183. sp.bdtrin,
  184. _binomial_cdf,
  185. 1, [IntArg(1, 1000), ProbArg(), ProbArg()],
  186. rtol=1e-4, endpt_atol=[None, None, 1e-6])
  187. def test_btdtria(self):
  188. _assert_inverts(
  189. sp.btdtria,
  190. lambda a, b, x: mpmath.betainc(a, b, x2=x, regularized=True),
  191. 0, [ProbArg(), Arg(0, 1e2, inclusive_a=False),
  192. Arg(0, 1, inclusive_a=False, inclusive_b=False)],
  193. rtol=1e-6)
  194. def test_btdtrib(self):
  195. # Use small values of a or mpmath doesn't converge
  196. _assert_inverts(
  197. sp.btdtrib,
  198. lambda a, b, x: mpmath.betainc(a, b, x2=x, regularized=True),
  199. 1, [Arg(0, 1e2, inclusive_a=False), ProbArg(),
  200. Arg(0, 1, inclusive_a=False, inclusive_b=False)],
  201. rtol=1e-7, endpt_atol=[None, 1e-18, 1e-15])
  202. @pytest.mark.xfail(run=False)
  203. def test_fdtridfd(self):
  204. _assert_inverts(
  205. sp.fdtridfd,
  206. _f_cdf,
  207. 1, [IntArg(1, 100), ProbArg(), Arg(0, 100, inclusive_a=False)],
  208. rtol=1e-7)
  209. def test_gdtria(self):
  210. _assert_inverts(
  211. sp.gdtria,
  212. lambda a, b, x: mpmath.gammainc(b, b=a*x, regularized=True),
  213. 0, [ProbArg(), Arg(0, 1e3, inclusive_a=False),
  214. Arg(0, 1e4, inclusive_a=False)], rtol=1e-7,
  215. endpt_atol=[None, 1e-7, 1e-10])
  216. def test_gdtrib(self):
  217. # Use small values of a and x or mpmath doesn't converge
  218. _assert_inverts(
  219. sp.gdtrib,
  220. lambda a, b, x: mpmath.gammainc(b, b=a*x, regularized=True),
  221. 1, [Arg(0, 1e2, inclusive_a=False), ProbArg(),
  222. Arg(0, 1e3, inclusive_a=False)], rtol=1e-5)
  223. def test_gdtrix(self):
  224. _assert_inverts(
  225. sp.gdtrix,
  226. lambda a, b, x: mpmath.gammainc(b, b=a*x, regularized=True),
  227. 2, [Arg(0, 1e3, inclusive_a=False), Arg(0, 1e3, inclusive_a=False),
  228. ProbArg()], rtol=1e-7,
  229. endpt_atol=[None, 1e-7, 1e-10])
  230. def test_stdtr(self):
  231. # Ideally the left endpoint for Arg() should be 0.
  232. assert_mpmath_equal(
  233. sp.stdtr,
  234. _student_t_cdf,
  235. [IntArg(1, 100), Arg(1e-10, np.inf)], rtol=1e-7)
  236. @pytest.mark.xfail(run=False)
  237. def test_stdtridf(self):
  238. _assert_inverts(
  239. sp.stdtridf,
  240. _student_t_cdf,
  241. 0, [ProbArg(), Arg()], rtol=1e-7)
  242. def test_stdtrit(self):
  243. _assert_inverts(
  244. sp.stdtrit,
  245. _student_t_cdf,
  246. 1, [IntArg(1, 100), ProbArg()], rtol=1e-7,
  247. endpt_atol=[None, 1e-10])
  248. def test_chdtriv(self):
  249. _assert_inverts(
  250. sp.chdtriv,
  251. lambda v, x: mpmath.gammainc(v/2, b=x/2, regularized=True),
  252. 0, [ProbArg(), IntArg(1, 100)], rtol=1e-4)
  253. @pytest.mark.xfail(run=False)
  254. def test_chndtridf(self):
  255. # Use a larger atol since mpmath is doing numerical integration
  256. _assert_inverts(
  257. sp.chndtridf,
  258. _noncentral_chi_cdf,
  259. 1, [Arg(0, 100, inclusive_a=False), ProbArg(),
  260. Arg(0, 100, inclusive_a=False)],
  261. n=1000, rtol=1e-4, atol=1e-15)
  262. @pytest.mark.xfail(run=False)
  263. def test_chndtrinc(self):
  264. # Use a larger atol since mpmath is doing numerical integration
  265. _assert_inverts(
  266. sp.chndtrinc,
  267. _noncentral_chi_cdf,
  268. 2, [Arg(0, 100, inclusive_a=False), IntArg(1, 100), ProbArg()],
  269. n=1000, rtol=1e-4, atol=1e-15)
  270. def test_chndtrix(self):
  271. # Use a larger atol since mpmath is doing numerical integration
  272. _assert_inverts(
  273. sp.chndtrix,
  274. _noncentral_chi_cdf,
  275. 0, [ProbArg(), IntArg(1, 100), Arg(0, 100, inclusive_a=False)],
  276. n=1000, rtol=1e-4, atol=1e-15,
  277. endpt_atol=[1e-6, None, None])
  278. def test_tklmbda_zero_shape(self):
  279. # When lmbda = 0 the CDF has a simple closed form
  280. one = mpmath.mpf(1)
  281. assert_mpmath_equal(
  282. lambda x: sp.tklmbda(x, 0),
  283. lambda x: one/(mpmath.exp(-x) + one),
  284. [Arg()], rtol=1e-7)
  285. def test_tklmbda_neg_shape(self):
  286. _assert_inverts(
  287. sp.tklmbda,
  288. _tukey_lmbda_quantile,
  289. 0, [ProbArg(), Arg(-25, 0, inclusive_b=False)],
  290. spfunc_first=False, rtol=1e-5,
  291. endpt_atol=[1e-9, 1e-5])
  292. @pytest.mark.xfail(run=False)
  293. def test_tklmbda_pos_shape(self):
  294. _assert_inverts(
  295. sp.tklmbda,
  296. _tukey_lmbda_quantile,
  297. 0, [ProbArg(), Arg(0, 100, inclusive_a=False)],
  298. spfunc_first=False, rtol=1e-5)
  299. def test_nonfinite():
  300. funcs = [
  301. ("btdtria", 3),
  302. ("btdtrib", 3),
  303. ("bdtrik", 3),
  304. ("bdtrin", 3),
  305. ("chdtriv", 2),
  306. ("chndtr", 3),
  307. ("chndtrix", 3),
  308. ("chndtridf", 3),
  309. ("chndtrinc", 3),
  310. ("fdtridfd", 3),
  311. ("ncfdtr", 4),
  312. ("ncfdtri", 4),
  313. ("ncfdtridfn", 4),
  314. ("ncfdtridfd", 4),
  315. ("ncfdtrinc", 4),
  316. ("gdtrix", 3),
  317. ("gdtrib", 3),
  318. ("gdtria", 3),
  319. ("nbdtrik", 3),
  320. ("nbdtrin", 3),
  321. ("nrdtrimn", 3),
  322. ("nrdtrisd", 3),
  323. ("pdtrik", 2),
  324. ("stdtr", 2),
  325. ("stdtrit", 2),
  326. ("stdtridf", 2),
  327. ("nctdtr", 3),
  328. ("nctdtrit", 3),
  329. ("nctdtridf", 3),
  330. ("nctdtrinc", 3),
  331. ("tklmbda", 2),
  332. ]
  333. np.random.seed(1)
  334. for func, numargs in funcs:
  335. func = getattr(sp, func)
  336. args_choices = [(float(x), np.nan, np.inf, -np.inf) for x in
  337. np.random.rand(numargs)]
  338. for args in itertools.product(*args_choices):
  339. res = func(*args)
  340. if any(np.isnan(x) for x in args):
  341. # Nan inputs should result to nan output
  342. assert_equal(res, np.nan)
  343. else:
  344. # All other inputs should return something (but not
  345. # raise exceptions or cause hangs)
  346. pass