test_sf_error.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. from __future__ import division, print_function, absolute_import
  2. import warnings
  3. from numpy.testing import assert_, assert_equal
  4. from scipy._lib._numpy_compat import suppress_warnings
  5. import pytest
  6. from pytest import raises as assert_raises
  7. import scipy.special as sc
  8. from scipy.special._ufuncs import _sf_error_test_function
  9. _sf_error_code_map = {
  10. # skip 'ok'
  11. 'singular': 1,
  12. 'underflow': 2,
  13. 'overflow': 3,
  14. 'slow': 4,
  15. 'loss': 5,
  16. 'no_result': 6,
  17. 'domain': 7,
  18. 'arg': 8,
  19. 'other': 9
  20. }
  21. _sf_error_actions = [
  22. 'ignore',
  23. 'warn',
  24. 'raise'
  25. ]
  26. def _check_action(fun, args, action):
  27. if action == 'warn':
  28. with pytest.warns(sc.SpecialFunctionWarning):
  29. fun(*args)
  30. elif action == 'raise':
  31. with assert_raises(sc.SpecialFunctionError):
  32. fun(*args)
  33. else:
  34. # action == 'ignore', make sure there are no warnings/exceptions
  35. with warnings.catch_warnings():
  36. warnings.simplefilter("error")
  37. fun(*args)
  38. def test_geterr():
  39. err = sc.geterr()
  40. for key, value in err.items():
  41. assert_(key in _sf_error_code_map.keys())
  42. assert_(value in _sf_error_actions)
  43. def test_seterr():
  44. entry_err = sc.geterr()
  45. try:
  46. for category in _sf_error_code_map.keys():
  47. for action in _sf_error_actions:
  48. geterr_olderr = sc.geterr()
  49. seterr_olderr = sc.seterr(**{category: action})
  50. assert_(geterr_olderr == seterr_olderr)
  51. newerr = sc.geterr()
  52. assert_(newerr[category] == action)
  53. geterr_olderr.pop(category)
  54. newerr.pop(category)
  55. assert_(geterr_olderr == newerr)
  56. _check_action(_sf_error_test_function,
  57. (_sf_error_code_map[category],),
  58. action)
  59. finally:
  60. sc.seterr(**entry_err)
  61. def test_errstate_pyx_basic():
  62. olderr = sc.geterr()
  63. with sc.errstate(singular='raise'):
  64. with assert_raises(sc.SpecialFunctionError):
  65. sc.loggamma(0)
  66. assert_equal(olderr, sc.geterr())
  67. def test_errstate_c_basic():
  68. olderr = sc.geterr()
  69. with sc.errstate(domain='raise'):
  70. with assert_raises(sc.SpecialFunctionError):
  71. sc.spence(-1)
  72. assert_equal(olderr, sc.geterr())
  73. def test_errstate_cpp_basic():
  74. olderr = sc.geterr()
  75. with sc.errstate(underflow='raise'):
  76. with assert_raises(sc.SpecialFunctionError):
  77. sc.wrightomega(-1000)
  78. assert_equal(olderr, sc.geterr())
  79. def test_errstate():
  80. for category in _sf_error_code_map.keys():
  81. for action in _sf_error_actions:
  82. olderr = sc.geterr()
  83. with sc.errstate(**{category: action}):
  84. _check_action(_sf_error_test_function,
  85. (_sf_error_code_map[category],),
  86. action)
  87. assert_equal(olderr, sc.geterr())
  88. def test_errstate_all_but_one():
  89. olderr = sc.geterr()
  90. with sc.errstate(all='raise', singular='ignore'):
  91. sc.gammaln(0)
  92. with assert_raises(sc.SpecialFunctionError):
  93. sc.spence(-1.0)
  94. assert_equal(olderr, sc.geterr())