test_fit.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. from __future__ import division, print_function, absolute_import
  2. import os
  3. import numpy as np
  4. from numpy.testing import assert_allclose
  5. from scipy._lib._numpy_compat import suppress_warnings
  6. import pytest
  7. from scipy import stats
  8. from .test_continuous_basic import distcont
# This is not a proper statistical test for convergence, but only
# verifies that the estimate and true values don't differ by too much.

fit_sizes = [1000, 5000]  # sample sizes to try

thresh_percent = 0.25  # percent of true parameters for fail cut-off
thresh_min = 0.75  # minimum difference estimate - true to fail test

# Distributions whose fit() results are known to be unreliable; these are
# xfail'ed below unless the SCIPY_XFAIL environment variable overrides it.
failing_fits = [
    'burr',
    'chi2',
    'gausshyper',
    'genexpon',
    'gengamma',
    'kappa4',
    'ksone',
    'mielke',
    'ncf',
    'ncx2',
    'pearson3',
    'powerlognorm',
    'truncexpon',
    'tukeylambda',
    'vonmises',
    'wrapcauchy',
    'levy_stable',
    'trapz'
]

# Don't run the fit test on these:
skip_fit = [
    'erlang',  # Subclass of gamma, generates a warning.
]
  38. def cases_test_cont_fit():
  39. # this tests the closeness of the estimated parameters to the true
  40. # parameters with fit method of continuous distributions
  41. # Note: is slow, some distributions don't converge with sample size <= 10000
  42. for distname, arg in distcont:
  43. if distname not in skip_fit:
  44. yield distname, arg
@pytest.mark.slow
@pytest.mark.parametrize('distname,arg', cases_test_cont_fit())
def test_cont_fit(distname, arg):
    # Check that distfn.fit() recovers the true shape parameters `arg`
    # (with loc=0, scale=1 appended) to within a loose tolerance. Not a
    # proper statistical convergence test.
    if distname in failing_fits:
        # Skip failing fits unless overridden
        try:
            xfail = not int(os.environ['SCIPY_XFAIL'])
        except Exception:
            # Variable unset or not parseable as int: default to xfail.
            xfail = True
        if xfail:
            msg = "Fitting %s doesn't work reliably yet" % distname
            msg += " [Set environment variable SCIPY_XFAIL=1 to run this test nevertheless.]"
            pytest.xfail(msg)

    distfn = getattr(stats, distname)

    # True parameter vector: the shape parameters plus loc=0, scale=1.
    truearg = np.hstack([arg, [0.0, 1.0]])
    # Per-parameter failure threshold: elementwise max of thresh_percent
    # of the true value and the absolute floor thresh_min.
    diffthreshold = np.max(np.vstack([truearg*thresh_percent,
                                      np.ones(distfn.numargs+2)*thresh_min]),
                           0)

    for fit_size in fit_sizes:
        # Note that if a fit succeeds, the other fit_sizes are skipped
        np.random.seed(1234)

        with np.errstate(all='ignore'), suppress_warnings() as sup:
            sup.filter(category=DeprecationWarning, message=".*frechet_")
            rvs = distfn.rvs(size=fit_size, *arg)
            est = distfn.fit(rvs)  # start with default values

        diff = est - truearg

        # threshold for location: scaled by the sample mean rather than
        # the true loc (which is 0), so the percent criterion is usable.
        diffthreshold[-2] = np.max([np.abs(rvs.mean())*thresh_percent,thresh_min])

        if np.any(np.isnan(est)):
            raise AssertionError('nan returned in fit')
        else:
            if np.all(np.abs(diff) <= diffthreshold):
                # Fit is good enough; skip the remaining sample sizes.
                break
    else:
        # No sample size produced an acceptable fit.
        txt = 'parameter: %s\n' % str(truearg)
        txt += 'estimated: %s\n' % str(est)
        txt += 'diff     : %s\n' % str(diff)
        raise AssertionError('fit not very good in %s\n' % distfn.name + txt)
  83. def _check_loc_scale_mle_fit(name, data, desired, atol=None):
  84. d = getattr(stats, name)
  85. actual = d.fit(data)[-2:]
  86. assert_allclose(actual, desired, atol=atol,
  87. err_msg='poor mle fit of (loc, scale) in %s' % name)
  88. def test_non_default_loc_scale_mle_fit():
  89. data = np.array([1.01, 1.78, 1.78, 1.78, 1.88, 1.88, 1.88, 2.00])
  90. _check_loc_scale_mle_fit('uniform', data, [1.01, 0.99], 1e-3)
  91. _check_loc_scale_mle_fit('expon', data, [1.01, 0.73875], 1e-3)
  92. def test_expon_fit():
  93. """gh-6167"""
  94. data = [0, 0, 0, 0, 2, 2, 2, 2]
  95. phat = stats.expon.fit(data, floc=0)
  96. assert_allclose(phat, [0, 1.0], atol=1e-3)