test_logit.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. from __future__ import division, print_function, absolute_import
  2. import numpy as np
  3. from numpy.testing import (assert_equal, assert_almost_equal,
  4. assert_allclose)
  5. from scipy.special import logit, expit
  6. class TestLogit(object):
  7. def check_logit_out(self, dtype, expected):
  8. a = np.linspace(0,1,10)
  9. a = np.array(a, dtype=dtype)
  10. olderr = np.seterr(divide='ignore')
  11. try:
  12. actual = logit(a)
  13. finally:
  14. np.seterr(**olderr)
  15. assert_almost_equal(actual, expected)
  16. assert_equal(actual.dtype, np.dtype(dtype))
  17. def test_float32(self):
  18. expected = np.array([-np.inf, -2.07944155,
  19. -1.25276291, -0.69314718,
  20. -0.22314353, 0.22314365,
  21. 0.6931473, 1.25276303,
  22. 2.07944155, np.inf], dtype=np.float32)
  23. self.check_logit_out('f4', expected)
  24. def test_float64(self):
  25. expected = np.array([-np.inf, -2.07944154,
  26. -1.25276297, -0.69314718,
  27. -0.22314355, 0.22314355,
  28. 0.69314718, 1.25276297,
  29. 2.07944154, np.inf])
  30. self.check_logit_out('f8', expected)
  31. def test_nan(self):
  32. expected = np.array([np.nan]*4)
  33. olderr = np.seterr(invalid='ignore')
  34. try:
  35. actual = logit(np.array([-3., -2., 2., 3.]))
  36. finally:
  37. np.seterr(**olderr)
  38. assert_equal(expected, actual)
  39. class TestExpit(object):
  40. def check_expit_out(self, dtype, expected):
  41. a = np.linspace(-4,4,10)
  42. a = np.array(a, dtype=dtype)
  43. actual = expit(a)
  44. assert_almost_equal(actual, expected)
  45. assert_equal(actual.dtype, np.dtype(dtype))
  46. def test_float32(self):
  47. expected = np.array([0.01798621, 0.04265125,
  48. 0.09777259, 0.20860852,
  49. 0.39068246, 0.60931754,
  50. 0.79139149, 0.9022274,
  51. 0.95734876, 0.98201376], dtype=np.float32)
  52. self.check_expit_out('f4',expected)
  53. def test_float64(self):
  54. expected = np.array([0.01798621, 0.04265125,
  55. 0.0977726, 0.20860853,
  56. 0.39068246, 0.60931754,
  57. 0.79139147, 0.9022274,
  58. 0.95734875, 0.98201379])
  59. self.check_expit_out('f8', expected)
  60. def test_large(self):
  61. for dtype in (np.float32, np.float64, np.longdouble):
  62. for n in (88, 89, 709, 710, 11356, 11357):
  63. n = np.array(n, dtype=dtype)
  64. assert_allclose(expit(n), 1.0, atol=1e-20)
  65. assert_allclose(expit(-n), 0.0, atol=1e-20)
  66. assert_equal(expit(n).dtype, dtype)
  67. assert_equal(expit(-n).dtype, dtype)