test_discrete_basic.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234
  1. from __future__ import division, print_function, absolute_import
  2. import numpy.testing as npt
  3. import numpy as np
  4. from scipy._lib.six import xrange
  5. import pytest
  6. from scipy import stats
  7. from .common_tests import (check_normalization, check_moment, check_mean_expect,
  8. check_var_expect, check_skew_expect,
  9. check_kurt_expect, check_entropy,
  10. check_private_entropy, check_edge_support,
  11. check_named_args, check_random_state_property,
  12. check_pickling, check_rvs_broadcast)
  13. from scipy.stats._distr_params import distdiscrete
  14. vals = ([1, 2, 3, 4], [0.1, 0.2, 0.3, 0.4])
  15. distdiscrete += [[stats.rv_discrete(values=vals), ()]]
  16. def cases_test_discrete_basic():
  17. seen = set()
  18. for distname, arg in distdiscrete:
  19. yield distname, arg, distname not in seen
  20. seen.add(distname)
  21. @pytest.mark.parametrize('distname,arg,first_case', cases_test_discrete_basic())
  22. def test_discrete_basic(distname, arg, first_case):
  23. try:
  24. distfn = getattr(stats, distname)
  25. except TypeError:
  26. distfn = distname
  27. distname = 'sample distribution'
  28. np.random.seed(9765456)
  29. rvs = distfn.rvs(size=2000, *arg)
  30. supp = np.unique(rvs)
  31. m, v = distfn.stats(*arg)
  32. check_cdf_ppf(distfn, arg, supp, distname + ' cdf_ppf')
  33. check_pmf_cdf(distfn, arg, distname)
  34. check_oth(distfn, arg, supp, distname + ' oth')
  35. check_edge_support(distfn, arg)
  36. alpha = 0.01
  37. check_discrete_chisquare(distfn, arg, rvs, alpha,
  38. distname + ' chisquare')
  39. if first_case:
  40. locscale_defaults = (0,)
  41. meths = [distfn.pmf, distfn.logpmf, distfn.cdf, distfn.logcdf,
  42. distfn.logsf]
  43. # make sure arguments are within support
  44. spec_k = {'randint': 11, 'hypergeom': 4, 'bernoulli': 0, }
  45. k = spec_k.get(distname, 1)
  46. check_named_args(distfn, k, arg, locscale_defaults, meths)
  47. if distname != 'sample distribution':
  48. check_scale_docstring(distfn)
  49. check_random_state_property(distfn, arg)
  50. check_pickling(distfn, arg)
  51. # Entropy
  52. check_entropy(distfn, arg, distname)
  53. if distfn.__class__._entropy != stats.rv_discrete._entropy:
  54. check_private_entropy(distfn, arg, stats.rv_discrete)
  55. @pytest.mark.parametrize('distname,arg', distdiscrete)
  56. def test_moments(distname, arg):
  57. try:
  58. distfn = getattr(stats, distname)
  59. except TypeError:
  60. distfn = distname
  61. distname = 'sample distribution'
  62. m, v, s, k = distfn.stats(*arg, moments='mvsk')
  63. check_normalization(distfn, arg, distname)
  64. # compare `stats` and `moment` methods
  65. check_moment(distfn, arg, m, v, distname)
  66. check_mean_expect(distfn, arg, m, distname)
  67. check_var_expect(distfn, arg, m, v, distname)
  68. check_skew_expect(distfn, arg, m, v, s, distname)
  69. if distname not in ['zipf', 'yulesimon']:
  70. check_kurt_expect(distfn, arg, m, v, k, distname)
  71. # frozen distr moments
  72. check_moment_frozen(distfn, arg, m, 1)
  73. check_moment_frozen(distfn, arg, v+m*m, 2)
  74. @pytest.mark.parametrize('dist,shape_args', distdiscrete)
  75. def test_rvs_broadcast(dist, shape_args):
  76. # If shape_only is True, it means the _rvs method of the
  77. # distribution uses more than one random number to generate a random
  78. # variate. That means the result of using rvs with broadcasting or
  79. # with a nontrivial size will not necessarily be the same as using the
  80. # numpy.vectorize'd version of rvs(), so we can only compare the shapes
  81. # of the results, not the values.
  82. # Whether or not a distribution is in the following list is an
  83. # implementation detail of the distribution, not a requirement. If
  84. # the implementation the rvs() method of a distribution changes, this
  85. # test might also have to be changed.
  86. shape_only = dist in ['skellam', 'yulesimon']
  87. try:
  88. distfunc = getattr(stats, dist)
  89. except TypeError:
  90. distfunc = dist
  91. dist = 'rv_discrete(values=(%r, %r))' % (dist.xk, dist.pk)
  92. loc = np.zeros(2)
  93. nargs = distfunc.numargs
  94. allargs = []
  95. bshape = []
  96. # Generate shape parameter arguments...
  97. for k in range(nargs):
  98. shp = (k + 3,) + (1,)*(k + 1)
  99. param_val = shape_args[k]
  100. allargs.append(param_val*np.ones(shp, dtype=np.array(param_val).dtype))
  101. bshape.insert(0, shp[0])
  102. allargs.append(loc)
  103. bshape.append(loc.size)
  104. # bshape holds the expected shape when loc, scale, and the shape
  105. # parameters are all broadcast together.
  106. check_rvs_broadcast(distfunc, dist, allargs, bshape, shape_only, [np.int_])
  107. def check_cdf_ppf(distfn, arg, supp, msg):
  108. # cdf is a step function, and ppf(q) = min{k : cdf(k) >= q, k integer}
  109. npt.assert_array_equal(distfn.ppf(distfn.cdf(supp, *arg), *arg),
  110. supp, msg + '-roundtrip')
  111. npt.assert_array_equal(distfn.ppf(distfn.cdf(supp, *arg) - 1e-8, *arg),
  112. supp, msg + '-roundtrip')
  113. if not hasattr(distfn, 'xk'):
  114. supp1 = supp[supp < distfn.b]
  115. npt.assert_array_equal(distfn.ppf(distfn.cdf(supp1, *arg) + 1e-8, *arg),
  116. supp1 + distfn.inc, msg + ' ppf-cdf-next')
  117. # -1e-8 could cause an error if pmf < 1e-8
  118. def check_pmf_cdf(distfn, arg, distname):
  119. if hasattr(distfn, 'xk'):
  120. index = distfn.xk
  121. else:
  122. startind = int(distfn.ppf(0.01, *arg) - 1)
  123. index = list(range(startind, startind + 10))
  124. cdfs = distfn.cdf(index, *arg)
  125. pmfs_cum = distfn.pmf(index, *arg).cumsum()
  126. atol, rtol = 1e-10, 1e-10
  127. if distname == 'skellam': # ncx2 accuracy
  128. atol, rtol = 1e-5, 1e-5
  129. npt.assert_allclose(cdfs - cdfs[0], pmfs_cum - pmfs_cum[0],
  130. atol=atol, rtol=rtol)
  131. def check_moment_frozen(distfn, arg, m, k):
  132. npt.assert_allclose(distfn(*arg).moment(k), m,
  133. atol=1e-10, rtol=1e-10)
  134. def check_oth(distfn, arg, supp, msg):
  135. # checking other methods of distfn
  136. npt.assert_allclose(distfn.sf(supp, *arg), 1. - distfn.cdf(supp, *arg),
  137. atol=1e-10, rtol=1e-10)
  138. q = np.linspace(0.01, 0.99, 20)
  139. npt.assert_allclose(distfn.isf(q, *arg), distfn.ppf(1. - q, *arg),
  140. atol=1e-10, rtol=1e-10)
  141. median_sf = distfn.isf(0.5, *arg)
  142. npt.assert_(distfn.sf(median_sf - 1, *arg) > 0.5)
  143. npt.assert_(distfn.cdf(median_sf + 1, *arg) > 0.5)
  144. def check_discrete_chisquare(distfn, arg, rvs, alpha, msg):
  145. """Perform chisquare test for random sample of a discrete distribution
  146. Parameters
  147. ----------
  148. distname : string
  149. name of distribution function
  150. arg : sequence
  151. parameters of distribution
  152. alpha : float
  153. significance level, threshold for p-value
  154. Returns
  155. -------
  156. result : bool
  157. 0 if test passes, 1 if test fails
  158. """
  159. wsupp = 0.05
  160. # construct intervals with minimum mass `wsupp`.
  161. # intervals are left-half-open as in a cdf difference
  162. lo = int(max(distfn.a, -1000))
  163. distsupport = xrange(lo, int(min(distfn.b, 1000)) + 1)
  164. last = 0
  165. distsupp = [lo]
  166. distmass = []
  167. for ii in distsupport:
  168. current = distfn.cdf(ii, *arg)
  169. if current - last >= wsupp - 1e-14:
  170. distsupp.append(ii)
  171. distmass.append(current - last)
  172. last = current
  173. if current > (1 - wsupp):
  174. break
  175. if distsupp[-1] < distfn.b:
  176. distsupp.append(distfn.b)
  177. distmass.append(1 - last)
  178. distsupp = np.array(distsupp)
  179. distmass = np.array(distmass)
  180. # convert intervals to right-half-open as required by histogram
  181. histsupp = distsupp + 1e-8
  182. histsupp[0] = distfn.a
  183. # find sample frequencies and perform chisquare test
  184. freq, hsupp = np.histogram(rvs, histsupp)
  185. chis, pval = stats.chisquare(np.array(freq), len(rvs)*distmass)
  186. npt.assert_(pval > alpha,
  187. 'chisquare - test for %s at arg = %s with pval = %s' %
  188. (msg, str(arg), str(pval)))
  189. def check_scale_docstring(distfn):
  190. if distfn.__doc__ is not None:
  191. # Docstrings can be stripped if interpreter is run with -OO
  192. npt.assert_('scale' not in distfn.__doc__)