from __future__ import division, print_function, absolute_import

import numpy as np
import numpy.testing as npt
import pytest
from pytest import raises as assert_raises
from scipy._lib._numpy_compat import suppress_warnings
from scipy.integrate import IntegrationWarning
from scipy import stats
from scipy.special import betainc

from .common_tests import (check_normalization, check_moment, check_mean_expect,
                           check_var_expect, check_skew_expect,
                           check_kurt_expect, check_entropy,
                           check_private_entropy, check_entropy_vect_scale,
                           check_edge_support, check_named_args,
                           check_random_state_property,
                           check_meth_dtype, check_ppf_dtype, check_cmplx_deriv,
                           check_pickling, check_rvs_broadcast)
from scipy.stats._distr_params import distcont
  19. """
  20. Test all continuous distributions.
  21. Parameters were chosen for those distributions that pass the
  22. Kolmogorov-Smirnov test. This provides safe parameters for each
  23. distributions so that we can perform further testing of class methods.
  24. These tests currently check only/mostly for serious errors and exceptions,
  25. not for numerically exact results.
  26. """
# Note that you need to add new distributions you want tested
# to _distr_params
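# Each distcont entry pairs a distribution name with a tuple of shape
# parameters known to behave well, in the same [name, shapes] format as
# distcont_extra below -- for example (illustrative values only, not
# necessarily the ones actually stored in _distr_params):
#
#     ['gamma', (2.0,)]
#     ['beta', (2.3, 0.6)]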
DECIMAL = 5  # specify the precision of the tests  # increased from 0 to 5
# Last four of these fail all around. Need to be checked
distcont_extra = [
    ['betaprime', (100, 86)],
    ['fatiguelife', (5,)],
    ['mielke', (4.6420495492121487, 0.59707419545516938)],
    ['invweibull', (0.58847112119264788,)],
    # burr: sample mean test fails still for c < 1
    ['burr', (0.94839838075366045, 4.3820284068855795)],
    # genextreme: sample mean test, sf-logsf test fail
    ['genextreme', (3.3184017469423535,)],
]
distslow = ['kappa4', 'rdist', 'gausshyper',
            'recipinvgauss', 'ksone', 'genexpon',
            'vonmises', 'vonmises_line', 'mielke', 'semicircular',
            'cosine', 'invweibull', 'powerlognorm', 'johnsonsu', 'kstwobign']
# distslow are sorted by speed (very slow to slow)
# These distributions fail the complex derivative test below.
# Here 'fail' means producing wrong results and/or raising exceptions,
# depending on the implementation details of the corresponding special
# functions.
# cf https://github.com/scipy/scipy/pull/4979 for a discussion.
fails_cmplx = set(['beta', 'betaprime', 'chi', 'chi2', 'dgamma', 'dweibull',
                   'erlang', 'f', 'gamma', 'gausshyper', 'gengamma',
                   'gennorm', 'genpareto', 'halfgennorm', 'invgamma',
                   'ksone', 'kstwobign', 'levy_l', 'loggamma', 'logistic',
                   'maxwell', 'nakagami', 'ncf', 'nct', 'ncx2', 'norminvgauss',
                   'pearson3', 'rice', 't', 'skewnorm', 'tukeylambda',
                   'vonmises', 'vonmises_line', 'rv_histogram_instance'])
_h = np.histogram([1, 2, 2, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 5, 6,
                   6, 6, 6, 7, 7, 7, 8, 8, 9], bins=8)
histogram_test_instance = stats.rv_histogram(_h)
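# stats.rv_histogram accepts the (counts, bin_edges) tuple returned by
# np.histogram and exposes the usual rv_continuous methods for the resulting
# piecewise-constant density; a quick illustrative check (not part of the
# test suite):
#
#     >>> histogram_test_instance.pdf(5.0) > 0
#     True
#     >>> 0.0 < histogram_test_instance.cdf(5.0) < 1.0
#     True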
def cases_test_cont_basic():
    for distname, arg in distcont[:] + [(histogram_test_instance, tuple())]:
        if distname == 'levy_stable':
            continue
        elif distname in distslow:
            yield pytest.param(distname, arg, marks=pytest.mark.slow)
        else:
            yield distname, arg
@pytest.mark.parametrize('distname,arg', cases_test_cont_basic())
def test_cont_basic(distname, arg):
    # this test skips slow distributions
    if distname == 'truncnorm':
        pytest.xfail(reason=distname)
    try:
        distfn = getattr(stats, distname)
    except TypeError:
        distfn = distname
        distname = 'rv_histogram_instance'

    np.random.seed(765456)
    sn = 500
    with suppress_warnings() as sup:
        # frechet_l and frechet_r are deprecated, so all their
        # methods generate DeprecationWarnings.
        sup.filter(category=DeprecationWarning, message=".*frechet_")
        rvs = distfn.rvs(size=sn, *arg)
        sm = rvs.mean()
        sv = rvs.var()
        m, v = distfn.stats(*arg)

        check_sample_meanvar_(distfn, arg, m, v, sm, sv, sn,
                              distname + 'sample mean test')
        check_cdf_ppf(distfn, arg, distname)
        check_sf_isf(distfn, arg, distname)
        check_pdf(distfn, arg, distname)
        check_pdf_logpdf(distfn, arg, distname)
        check_cdf_logcdf(distfn, arg, distname)
        check_sf_logsf(distfn, arg, distname)

        alpha = 0.01
        if distname == 'rv_histogram_instance':
            check_distribution_rvs(distfn.cdf, arg, alpha, rvs)
        else:
            check_distribution_rvs(distname, arg, alpha, rvs)

        locscale_defaults = (0, 1)
        meths = [distfn.pdf, distfn.logpdf, distfn.cdf, distfn.logcdf,
                 distfn.logsf]
        # make sure arguments are within support
        spec_x = {'frechet_l': -0.5, 'weibull_max': -0.5, 'levy_l': -0.5,
                  'pareto': 1.5, 'tukeylambda': 0.3,
                  'rv_histogram_instance': 5.0}
        x = spec_x.get(distname, 0.5)
        if distname == 'invweibull':
            arg = (1,)
        elif distname == 'ksone':
            arg = (3,)
        check_named_args(distfn, x, arg, locscale_defaults, meths)
        check_random_state_property(distfn, arg)
        check_pickling(distfn, arg)

        # Entropy
        if distname not in ['ksone', 'kstwobign', 'ncf', 'crystalball']:
            check_entropy(distfn, arg, distname)
        if distfn.numargs == 0:
            check_vecentropy(distfn, arg)
        if (distfn.__class__._entropy != stats.rv_continuous._entropy
                and distname != 'vonmises'):
            check_private_entropy(distfn, arg, stats.rv_continuous)

        with suppress_warnings() as sup:
            sup.filter(IntegrationWarning, "The occurrence of roundoff error")
            sup.filter(IntegrationWarning, "Extremely bad integrand")
            sup.filter(RuntimeWarning, "invalid value")
            check_entropy_vect_scale(distfn, arg)

        check_edge_support(distfn, arg)

        check_meth_dtype(distfn, arg, meths)
        check_ppf_dtype(distfn, arg)

        if distname not in fails_cmplx:
            check_cmplx_deriv(distfn, arg)

        if distname != 'truncnorm':
            check_ppf_private(distfn, arg, distname)
def test_levy_stable_random_state_property():
    # levy_stable only implements rvs(), so it is skipped in the
    # main loop in test_cont_basic().  Here we apply just the test
    # check_random_state_property to levy_stable.
    check_random_state_property(stats.levy_stable, (0.5, 0.1))
def cases_test_moments():
    fail_normalization = set(['vonmises', 'ksone'])
    fail_higher = set(['vonmises', 'ksone', 'ncf'])

    for distname, arg in distcont[:] + [(histogram_test_instance, tuple())]:
        if distname == 'levy_stable':
            continue

        cond1 = distname not in fail_normalization
        cond2 = distname not in fail_higher

        yield distname, arg, cond1, cond2, False

        if not cond1 or not cond2:
            # Run the distributions that have issues twice, once skipping the
            # not_ok parts, once with the not_ok parts but marked as knownfail
            yield pytest.param(distname, arg, True, True, True,
                               marks=pytest.mark.xfail)
@pytest.mark.slow
@pytest.mark.parametrize('distname,arg,normalization_ok,higher_ok,is_xfailing',
                         cases_test_moments())
def test_moments(distname, arg, normalization_ok, higher_ok, is_xfailing):
    try:
        distfn = getattr(stats, distname)
    except TypeError:
        distfn = distname
        distname = 'rv_histogram_instance'

    with suppress_warnings() as sup:
        sup.filter(IntegrationWarning,
                   "The integral is probably divergent, or slowly convergent.")
        sup.filter(category=DeprecationWarning, message=".*frechet_")
        if is_xfailing:
            sup.filter(IntegrationWarning)

        m, v, s, k = distfn.stats(*arg, moments='mvsk')

        if normalization_ok:
            check_normalization(distfn, arg, distname)

        if higher_ok:
            check_mean_expect(distfn, arg, m, distname)
            check_skew_expect(distfn, arg, m, v, s, distname)
            check_var_expect(distfn, arg, m, v, distname)
            check_kurt_expect(distfn, arg, m, v, k, distname)

        check_loc_scale(distfn, arg, m, v, distname)
        check_moment(distfn, arg, m, v, distname)
@pytest.mark.parametrize('dist,shape_args', distcont)
def test_rvs_broadcast(dist, shape_args):
    if dist in ['gausshyper', 'genexpon']:
        pytest.skip("too slow")

    # If shape_only is True, it means the _rvs method of the
    # distribution uses more than one random number to generate a random
    # variate.  That means the result of using rvs with broadcasting or
    # with a nontrivial size will not necessarily be the same as using the
    # numpy.vectorize'd version of rvs(), so we can only compare the shapes
    # of the results, not the values.
    # Whether or not a distribution is in the following list is an
    # implementation detail of the distribution, not a requirement.  If
    # the implementation of the rvs() method of a distribution changes, this
    # test might also have to be changed.
    shape_only = dist in ['betaprime', 'dgamma', 'exponnorm', 'norminvgauss',
                          'nct', 'dweibull', 'rice', 'levy_stable', 'skewnorm']

    distfunc = getattr(stats, dist)
    loc = np.zeros(2)
    scale = np.ones((3, 1))
    nargs = distfunc.numargs
    allargs = []
    bshape = [3, 2]
    # Generate shape parameter arguments...
    for k in range(nargs):
        shp = (k + 4,) + (1,)*(k + 2)
        allargs.append(shape_args[k]*np.ones(shp))
        bshape.insert(0, k + 4)
    allargs.extend([loc, scale])
    # bshape holds the expected shape when loc, scale, and the shape
    # parameters are all broadcast together; see the example below.
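    # For a one-argument distribution, for example, the shape parameter built
    # above has shape (4, 1, 1), loc has shape (2,) and scale has shape (3, 1),
    # so they broadcast to bshape == [4, 3, 2] (illustrative sketch):
    #
    #     >>> np.broadcast(np.ones((4, 1, 1)), np.zeros(2), np.ones((3, 1))).shape
    #     (4, 3, 2)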
    check_rvs_broadcast(distfunc, dist, allargs, bshape, shape_only, 'd')
def test_rvs_gh2069_regression():
    # Regression tests for gh-2069.  In scipy 0.17 and earlier,
    # these tests would fail.
    #
    # A typical example of the broken behavior:
    # >>> norm.rvs(loc=np.zeros(5), scale=np.ones(5))
    # array([-2.49613705, -2.49613705, -2.49613705, -2.49613705, -2.49613705])
    np.random.seed(123)
    vals = stats.norm.rvs(loc=np.zeros(5), scale=1)
    d = np.diff(vals)
    npt.assert_(np.all(d != 0), "All the values are equal, but they shouldn't be!")
    vals = stats.norm.rvs(loc=0, scale=np.ones(5))
    d = np.diff(vals)
    npt.assert_(np.all(d != 0), "All the values are equal, but they shouldn't be!")
    vals = stats.norm.rvs(loc=np.zeros(5), scale=np.ones(5))
    d = np.diff(vals)
    npt.assert_(np.all(d != 0), "All the values are equal, but they shouldn't be!")
    vals = stats.norm.rvs(loc=np.array([[0], [0]]), scale=np.ones(5))
    d = np.diff(vals.ravel())
    npt.assert_(np.all(d != 0), "All the values are equal, but they shouldn't be!")

    assert_raises(ValueError, stats.norm.rvs, [[0, 0], [0, 0]],
                  [[1, 1], [1, 1]], 1)
    assert_raises(ValueError, stats.gamma.rvs, [2, 3, 4, 5], 0, 1, (2, 2))
    assert_raises(ValueError, stats.gamma.rvs, [1, 1, 1, 1], [0, 0, 0, 0],
                  [[1], [2]], (4,))
def check_sample_meanvar_(distfn, arg, m, v, sm, sv, sn, msg):
    # this did not work, skipped silently by nose
    if np.isfinite(m):
        check_sample_mean(sm, sv, sn, m)
    if np.isfinite(v):
        check_sample_var(sv, sn, v)
def check_sample_mean(sm, v, n, popmean):
    # from stats.stats.ttest_1samp(a, popmean):
    # Calculates the t-obtained for the independent samples T-test on ONE group
    # of scores a, given a population mean.
    #
    # Returns: t-value, two-tailed prob
    df = n - 1
    svar = ((n - 1) * v) / float(df)    # looks redundant
    t = (sm - popmean) / np.sqrt(svar * (1.0 / n))
    prob = betainc(0.5*df, 0.5, df/(df + t*t))

    # return t, prob
    npt.assert_(prob > 0.01, 'mean fail, t,prob = %f, %f, m, sm=%f,%f' %
                (t, prob, popmean, sm))
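# Note on check_sample_mean: the betainc expression above is the closed form
# of the two-sided one-sample t-test p-value, so it should agree with
# scipy.stats.t; an illustrative sanity check (not part of the test suite):
#
#     >>> t, df = 1.5, 499
#     >>> np.allclose(betainc(0.5*df, 0.5, df/(df + t*t)), 2*stats.t.sf(t, df))
#     True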
def check_sample_var(sv, n, popvar):
    # two-sided chisquare test for sample variance equal to
    # hypothesized variance
    df = n - 1
    chi2 = (n - 1) * sv / float(popvar)
    pval = stats.distributions.chi2.sf(chi2, df) * 2
    npt.assert_(pval > 0.01, 'var fail, t, pval = %f, %f, v, sv=%f, %f' %
                (chi2, pval, popvar, sv))
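# Note on check_sample_var: if the data are normal with variance popvar, the
# statistic (n-1)*sv/popvar follows a chi-square distribution with n-1 degrees
# of freedom; an illustrative sketch of how the check is used:
#
#     >>> x = np.random.standard_normal(500)
#     >>> check_sample_var(x.var(), x.size, 1.0)   # passes for a typical sample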
def check_cdf_ppf(distfn, arg, msg):
    values = [0.001, 0.5, 0.999]
    npt.assert_almost_equal(distfn.cdf(distfn.ppf(values, *arg), *arg),
                            values, decimal=DECIMAL, err_msg=msg +
                            ' - cdf-ppf roundtrip')
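# Note on check_cdf_ppf: the ppf is the inverse of the cdf on (0, 1), so the
# roundtrip should recover the input probabilities; e.g. (illustrative):
#
#     >>> np.allclose(stats.norm.cdf(stats.norm.ppf([0.001, 0.5, 0.999])),
#     ...             [0.001, 0.5, 0.999])
#     True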
def check_sf_isf(distfn, arg, msg):
    npt.assert_almost_equal(distfn.sf(distfn.isf([0.1, 0.5, 0.9], *arg), *arg),
                            [0.1, 0.5, 0.9], decimal=DECIMAL, err_msg=msg +
                            ' - sf-isf roundtrip')
    npt.assert_almost_equal(distfn.cdf([0.1, 0.9], *arg),
                            1.0 - distfn.sf([0.1, 0.9], *arg),
                            decimal=DECIMAL, err_msg=msg +
                            ' - cdf-sf relationship')
def check_pdf(distfn, arg, msg):
    # compares pdf at median with numerical derivative of cdf
    median = distfn.ppf(0.5, *arg)
    eps = 1e-6
    pdfv = distfn.pdf(median, *arg)
    if (pdfv < 1e-4) or (pdfv > 1e4):
        # avoid checking a case where pdf is close to zero or
        # huge (singularity)
        median = median + 0.1
        pdfv = distfn.pdf(median, *arg)
    cdfdiff = (distfn.cdf(median + eps, *arg) -
               distfn.cdf(median - eps, *arg))/eps/2.0
    # replace with better diff and better test (more points),
    # actually, this works pretty well
    msg += ' - cdf-pdf relationship'
    npt.assert_almost_equal(pdfv, cdfdiff, decimal=DECIMAL, err_msg=msg)
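# Note on check_pdf: the central difference (cdf(x + eps) - cdf(x - eps)) /
# (2*eps) approximates the derivative of the cdf, which is the pdf; e.g.
# (illustrative sketch):
#
#     >>> x, eps = 0.0, 1e-6
#     >>> deriv = (stats.norm.cdf(x + eps) - stats.norm.cdf(x - eps)) / (2*eps)
#     >>> np.allclose(deriv, stats.norm.pdf(x))
#     True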
def check_pdf_logpdf(distfn, args, msg):
    # compares pdf at several points with the log of the pdf
    points = np.array([0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8])
    vals = distfn.ppf(points, *args)
    pdf = distfn.pdf(vals, *args)
    logpdf = distfn.logpdf(vals, *args)
    pdf = pdf[pdf != 0]
    logpdf = logpdf[np.isfinite(logpdf)]
    msg += " - logpdf-log(pdf) relationship"
    npt.assert_almost_equal(np.log(pdf), logpdf, decimal=7, err_msg=msg)


def check_sf_logsf(distfn, args, msg):
    # compares sf at several points with the log of the sf
    points = np.array([0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8])
    vals = distfn.ppf(points, *args)
    sf = distfn.sf(vals, *args)
    logsf = distfn.logsf(vals, *args)
    sf = sf[sf != 0]
    logsf = logsf[np.isfinite(logsf)]
    msg += " - logsf-log(sf) relationship"
    npt.assert_almost_equal(np.log(sf), logsf, decimal=7, err_msg=msg)


def check_cdf_logcdf(distfn, args, msg):
    # compares cdf at several points with the log of the cdf
    points = np.array([0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8])
    vals = distfn.ppf(points, *args)
    cdf = distfn.cdf(vals, *args)
    logcdf = distfn.logcdf(vals, *args)
    cdf = cdf[cdf != 0]
    logcdf = logcdf[np.isfinite(logcdf)]
    msg += " - logcdf-log(cdf) relationship"
    npt.assert_almost_equal(np.log(cdf), logcdf, decimal=7, err_msg=msg)
def check_distribution_rvs(dist, args, alpha, rvs):
    # test from scipy.stats.tests
    # this version reuses existing random variables
    D, pval = stats.kstest(rvs, dist, args=args, N=1000)
    if (pval < alpha):
        # retry once with a freshly generated sample before reporting failure
        D, pval = stats.kstest(dist, '', args=args, N=1000)
        npt.assert_(pval > alpha, "D = " + str(D) + "; pval = " + str(pval) +
                    "; alpha = " + str(alpha) + "\nargs = " + str(args))
def check_vecentropy(distfn, args):
    npt.assert_equal(distfn.vecentropy(*args), distfn._entropy(*args))
def check_loc_scale(distfn, arg, m, v, msg):
    loc, scale = 10.0, 10.0
    mt, vt = distfn.stats(loc=loc, scale=scale, *arg)
    npt.assert_allclose(m*scale + loc, mt)
    npt.assert_allclose(v*scale*scale, vt)
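# Note on check_loc_scale: for Y = loc + scale*X, E[Y] = loc + scale*E[X] and
# Var[Y] = scale**2 * Var[X], which is what the two assertions verify; e.g.
# (illustrative):
#
#     >>> m, v = stats.norm.stats(loc=10.0, scale=10.0)
#     >>> float(m), float(v)
#     (10.0, 100.0)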
def check_ppf_private(distfn, arg, msg):
    # fails by design for truncnorm self.nb not defined
    ppfs = distfn._ppf(np.array([0.1, 0.5, 0.9]), *arg)
    npt.assert_(not np.any(np.isnan(ppfs)), msg + 'ppf private is nan')