test_nan_inputs.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. """Test how the ufuncs in special handle nan inputs.
  2. """
  3. from __future__ import division, print_function, absolute_import
  4. import numpy as np
  5. from numpy.testing import assert_array_equal, assert_
  6. import pytest
  7. import scipy.special as sc
  8. from scipy._lib._numpy_compat import suppress_warnings
  9. KNOWNFAILURES = {}
  10. POSTPROCESSING = {}
  11. def _get_ufuncs():
  12. ufuncs = []
  13. ufunc_names = []
  14. for name in sorted(sc.__dict__):
  15. obj = sc.__dict__[name]
  16. if not isinstance(obj, np.ufunc):
  17. continue
  18. msg = KNOWNFAILURES.get(obj)
  19. if msg is None:
  20. ufuncs.append(obj)
  21. ufunc_names.append(name)
  22. else:
  23. fail = pytest.mark.xfail(run=False, reason=msg)
  24. ufuncs.append(pytest.param(obj, marks=fail))
  25. ufunc_names.append(name)
  26. return ufuncs, ufunc_names
  27. UFUNCS, UFUNC_NAMES = _get_ufuncs()
  28. @pytest.mark.parametrize("func", UFUNCS, ids=UFUNC_NAMES)
  29. def test_nan_inputs(func):
  30. args = (np.nan,)*func.nin
  31. with suppress_warnings() as sup:
  32. # Ignore warnings about unsafe casts from legacy wrappers
  33. sup.filter(RuntimeWarning,
  34. "floating point number truncated to an integer")
  35. try:
  36. res = func(*args)
  37. except TypeError:
  38. # One of the arguments doesn't take real inputs
  39. return
  40. if func in POSTPROCESSING:
  41. res = POSTPROCESSING[func](*res)
  42. msg = "got {} instead of nan".format(res)
  43. assert_array_equal(np.isnan(res), True, err_msg=msg)
  44. def test_legacy_cast():
  45. with suppress_warnings() as sup:
  46. sup.filter(RuntimeWarning,
  47. "floating point number truncated to an integer")
  48. res = sc.bdtrc(np.nan, 1, 0.5)
  49. assert_(np.isnan(res))