common_tests.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306
  1. from __future__ import division, print_function, absolute_import
  2. import pickle
  3. import numpy as np
  4. import numpy.testing as npt
  5. from numpy.testing import assert_allclose, assert_equal
  6. from scipy._lib._numpy_compat import suppress_warnings
  7. from pytest import raises as assert_raises
  8. import numpy.ma.testutils as ma_npt
  9. from scipy._lib._util import getargspec_no_self as _getargspec
  10. from scipy import stats
  11. def check_named_results(res, attributes, ma=False):
  12. for i, attr in enumerate(attributes):
  13. if ma:
  14. ma_npt.assert_equal(res[i], getattr(res, attr))
  15. else:
  16. npt.assert_equal(res[i], getattr(res, attr))
  17. def check_normalization(distfn, args, distname):
  18. norm_moment = distfn.moment(0, *args)
  19. npt.assert_allclose(norm_moment, 1.0)
  20. # this is a temporary plug: either ncf or expect is problematic;
  21. # best be marked as a knownfail, but I've no clue how to do it.
  22. if distname == "ncf":
  23. atol, rtol = 1e-5, 0
  24. else:
  25. atol, rtol = 1e-7, 1e-7
  26. normalization_expect = distfn.expect(lambda x: 1, args=args)
  27. npt.assert_allclose(normalization_expect, 1.0, atol=atol, rtol=rtol,
  28. err_msg=distname, verbose=True)
  29. normalization_cdf = distfn.cdf(distfn.b, *args)
  30. npt.assert_allclose(normalization_cdf, 1.0)
  31. def check_moment(distfn, arg, m, v, msg):
  32. m1 = distfn.moment(1, *arg)
  33. m2 = distfn.moment(2, *arg)
  34. if not np.isinf(m):
  35. npt.assert_almost_equal(m1, m, decimal=10, err_msg=msg +
  36. ' - 1st moment')
  37. else: # or np.isnan(m1),
  38. npt.assert_(np.isinf(m1),
  39. msg + ' - 1st moment -infinite, m1=%s' % str(m1))
  40. if not np.isinf(v):
  41. npt.assert_almost_equal(m2 - m1 * m1, v, decimal=10, err_msg=msg +
  42. ' - 2ndt moment')
  43. else: # or np.isnan(m2),
  44. npt.assert_(np.isinf(m2),
  45. msg + ' - 2nd moment -infinite, m2=%s' % str(m2))
  46. def check_mean_expect(distfn, arg, m, msg):
  47. if np.isfinite(m):
  48. m1 = distfn.expect(lambda x: x, arg)
  49. npt.assert_almost_equal(m1, m, decimal=5, err_msg=msg +
  50. ' - 1st moment (expect)')
  51. def check_var_expect(distfn, arg, m, v, msg):
  52. if np.isfinite(v):
  53. m2 = distfn.expect(lambda x: x*x, arg)
  54. npt.assert_almost_equal(m2, v + m*m, decimal=5, err_msg=msg +
  55. ' - 2st moment (expect)')
  56. def check_skew_expect(distfn, arg, m, v, s, msg):
  57. if np.isfinite(s):
  58. m3e = distfn.expect(lambda x: np.power(x-m, 3), arg)
  59. npt.assert_almost_equal(m3e, s * np.power(v, 1.5),
  60. decimal=5, err_msg=msg + ' - skew')
  61. else:
  62. npt.assert_(np.isnan(s))
  63. def check_kurt_expect(distfn, arg, m, v, k, msg):
  64. if np.isfinite(k):
  65. m4e = distfn.expect(lambda x: np.power(x-m, 4), arg)
  66. npt.assert_allclose(m4e, (k + 3.) * np.power(v, 2), atol=1e-5, rtol=1e-5,
  67. err_msg=msg + ' - kurtosis')
  68. elif not np.isposinf(k):
  69. npt.assert_(np.isnan(k))
  70. def check_entropy(distfn, arg, msg):
  71. ent = distfn.entropy(*arg)
  72. npt.assert_(not np.isnan(ent), msg + 'test Entropy is nan')
  73. def check_private_entropy(distfn, args, superclass):
  74. # compare a generic _entropy with the distribution-specific implementation
  75. npt.assert_allclose(distfn._entropy(*args),
  76. superclass._entropy(distfn, *args))
  77. def check_entropy_vect_scale(distfn, arg):
  78. # check 2-d
  79. sc = np.asarray([[1, 2], [3, 4]])
  80. v_ent = distfn.entropy(*arg, scale=sc)
  81. s_ent = [distfn.entropy(*arg, scale=s) for s in sc.ravel()]
  82. s_ent = np.asarray(s_ent).reshape(v_ent.shape)
  83. assert_allclose(v_ent, s_ent, atol=1e-14)
  84. # check invalid value, check cast
  85. sc = [1, 2, -3]
  86. v_ent = distfn.entropy(*arg, scale=sc)
  87. s_ent = [distfn.entropy(*arg, scale=s) for s in sc]
  88. s_ent = np.asarray(s_ent).reshape(v_ent.shape)
  89. assert_allclose(v_ent, s_ent, atol=1e-14)
  90. def check_edge_support(distfn, args):
  91. # Make sure that x=self.a and self.b are handled correctly.
  92. x = [distfn.a, distfn.b]
  93. if isinstance(distfn, stats.rv_discrete):
  94. x = [distfn.a - 1, distfn.b]
  95. npt.assert_equal(distfn.cdf(x, *args), [0.0, 1.0])
  96. npt.assert_equal(distfn.sf(x, *args), [1.0, 0.0])
  97. if distfn.name not in ('skellam', 'dlaplace'):
  98. # with a = -inf, log(0) generates warnings
  99. npt.assert_equal(distfn.logcdf(x, *args), [-np.inf, 0.0])
  100. npt.assert_equal(distfn.logsf(x, *args), [0.0, -np.inf])
  101. npt.assert_equal(distfn.ppf([0.0, 1.0], *args), x)
  102. npt.assert_equal(distfn.isf([0.0, 1.0], *args), x[::-1])
  103. # out-of-bounds for isf & ppf
  104. npt.assert_(np.isnan(distfn.isf([-1, 2], *args)).all())
  105. npt.assert_(np.isnan(distfn.ppf([-1, 2], *args)).all())
  106. def check_named_args(distfn, x, shape_args, defaults, meths):
  107. ## Check calling w/ named arguments.
  108. # check consistency of shapes, numargs and _parse signature
  109. signature = _getargspec(distfn._parse_args)
  110. npt.assert_(signature.varargs is None)
  111. npt.assert_(signature.keywords is None)
  112. npt.assert_(list(signature.defaults) == list(defaults))
  113. shape_argnames = signature.args[:-len(defaults)] # a, b, loc=0, scale=1
  114. if distfn.shapes:
  115. shapes_ = distfn.shapes.replace(',', ' ').split()
  116. else:
  117. shapes_ = ''
  118. npt.assert_(len(shapes_) == distfn.numargs)
  119. npt.assert_(len(shapes_) == len(shape_argnames))
  120. # check calling w/ named arguments
  121. shape_args = list(shape_args)
  122. vals = [meth(x, *shape_args) for meth in meths]
  123. npt.assert_(np.all(np.isfinite(vals)))
  124. names, a, k = shape_argnames[:], shape_args[:], {}
  125. while names:
  126. k.update({names.pop(): a.pop()})
  127. v = [meth(x, *a, **k) for meth in meths]
  128. npt.assert_array_equal(vals, v)
  129. if 'n' not in k.keys():
  130. # `n` is first parameter of moment(), so can't be used as named arg
  131. npt.assert_equal(distfn.moment(1, *a, **k),
  132. distfn.moment(1, *shape_args))
  133. # unknown arguments should not go through:
  134. k.update({'kaboom': 42})
  135. assert_raises(TypeError, distfn.cdf, x, **k)
  136. def check_random_state_property(distfn, args):
  137. # check the random_state attribute of a distribution *instance*
  138. # This test fiddles with distfn.random_state. This breaks other tests,
  139. # hence need to save it and then restore.
  140. rndm = distfn.random_state
  141. # baseline: this relies on the global state
  142. np.random.seed(1234)
  143. distfn.random_state = None
  144. r0 = distfn.rvs(*args, size=8)
  145. # use an explicit instance-level random_state
  146. distfn.random_state = 1234
  147. r1 = distfn.rvs(*args, size=8)
  148. npt.assert_equal(r0, r1)
  149. distfn.random_state = np.random.RandomState(1234)
  150. r2 = distfn.rvs(*args, size=8)
  151. npt.assert_equal(r0, r2)
  152. # can override the instance-level random_state for an individual .rvs call
  153. distfn.random_state = 2
  154. orig_state = distfn.random_state.get_state()
  155. r3 = distfn.rvs(*args, size=8, random_state=np.random.RandomState(1234))
  156. npt.assert_equal(r0, r3)
  157. # ... and that does not alter the instance-level random_state!
  158. npt.assert_equal(distfn.random_state.get_state(), orig_state)
  159. # finally, restore the random_state
  160. distfn.random_state = rndm
  161. def check_meth_dtype(distfn, arg, meths):
  162. q0 = [0.25, 0.5, 0.75]
  163. x0 = distfn.ppf(q0, *arg)
  164. x_cast = [x0.astype(tp) for tp in
  165. (np.int_, np.float16, np.float32, np.float64)]
  166. for x in x_cast:
  167. # casting may have clipped the values, exclude those
  168. distfn._argcheck(*arg)
  169. x = x[(distfn.a < x) & (x < distfn.b)]
  170. for meth in meths:
  171. val = meth(x, *arg)
  172. npt.assert_(val.dtype == np.float_)
  173. def check_ppf_dtype(distfn, arg):
  174. q0 = np.asarray([0.25, 0.5, 0.75])
  175. q_cast = [q0.astype(tp) for tp in (np.float16, np.float32, np.float64)]
  176. for q in q_cast:
  177. for meth in [distfn.ppf, distfn.isf]:
  178. val = meth(q, *arg)
  179. npt.assert_(val.dtype == np.float_)
  180. def check_cmplx_deriv(distfn, arg):
  181. # Distributions allow complex arguments.
  182. def deriv(f, x, *arg):
  183. x = np.asarray(x)
  184. h = 1e-10
  185. return (f(x + h*1j, *arg)/h).imag
  186. x0 = distfn.ppf([0.25, 0.51, 0.75], *arg)
  187. x_cast = [x0.astype(tp) for tp in
  188. (np.int_, np.float16, np.float32, np.float64)]
  189. for x in x_cast:
  190. # casting may have clipped the values, exclude those
  191. distfn._argcheck(*arg)
  192. x = x[(distfn.a < x) & (x < distfn.b)]
  193. pdf, cdf, sf = distfn.pdf(x, *arg), distfn.cdf(x, *arg), distfn.sf(x, *arg)
  194. assert_allclose(deriv(distfn.cdf, x, *arg), pdf, rtol=1e-5)
  195. assert_allclose(deriv(distfn.logcdf, x, *arg), pdf/cdf, rtol=1e-5)
  196. assert_allclose(deriv(distfn.sf, x, *arg), -pdf, rtol=1e-5)
  197. assert_allclose(deriv(distfn.logsf, x, *arg), -pdf/sf, rtol=1e-5)
  198. assert_allclose(deriv(distfn.logpdf, x, *arg),
  199. deriv(distfn.pdf, x, *arg) / distfn.pdf(x, *arg),
  200. rtol=1e-5)
  201. def check_pickling(distfn, args):
  202. # check that a distribution instance pickles and unpickles
  203. # pay special attention to the random_state property
  204. # save the random_state (restore later)
  205. rndm = distfn.random_state
  206. distfn.random_state = 1234
  207. distfn.rvs(*args, size=8)
  208. s = pickle.dumps(distfn)
  209. r0 = distfn.rvs(*args, size=8)
  210. unpickled = pickle.loads(s)
  211. r1 = unpickled.rvs(*args, size=8)
  212. npt.assert_equal(r0, r1)
  213. # also smoke test some methods
  214. medians = [distfn.ppf(0.5, *args), unpickled.ppf(0.5, *args)]
  215. npt.assert_equal(medians[0], medians[1])
  216. npt.assert_equal(distfn.cdf(medians[0], *args),
  217. unpickled.cdf(medians[1], *args))
  218. # restore the random_state
  219. distfn.random_state = rndm
  220. def check_rvs_broadcast(distfunc, distname, allargs, shape, shape_only, otype):
  221. np.random.seed(123)
  222. with suppress_warnings() as sup:
  223. # frechet_l and frechet_r are deprecated, so all their
  224. # methods generate DeprecationWarnings.
  225. sup.filter(category=DeprecationWarning, message=".*frechet_")
  226. sample = distfunc.rvs(*allargs)
  227. assert_equal(sample.shape, shape, "%s: rvs failed to broadcast" % distname)
  228. if not shape_only:
  229. rvs = np.vectorize(lambda *allargs: distfunc.rvs(*allargs), otypes=otype)
  230. np.random.seed(123)
  231. expected = rvs(*allargs)
  232. assert_allclose(sample, expected, rtol=1e-15)