test_spfun_stats.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. from __future__ import division, print_function, absolute_import
  2. import numpy as np
  3. from numpy.testing import (assert_array_equal,
  4. assert_array_almost_equal_nulp, assert_almost_equal)
  5. from pytest import raises as assert_raises
  6. from scipy.special import gammaln, multigammaln
  7. class TestMultiGammaLn(object):
  8. def test1(self):
  9. # A test of the identity
  10. # Gamma_1(a) = Gamma(a)
  11. np.random.seed(1234)
  12. a = np.abs(np.random.randn())
  13. assert_array_equal(multigammaln(a, 1), gammaln(a))
  14. def test2(self):
  15. # A test of the identity
  16. # Gamma_2(a) = sqrt(pi) * Gamma(a) * Gamma(a - 0.5)
  17. a = np.array([2.5, 10.0])
  18. result = multigammaln(a, 2)
  19. expected = np.log(np.sqrt(np.pi)) + gammaln(a) + gammaln(a - 0.5)
  20. assert_almost_equal(result, expected)
  21. def test_bararg(self):
  22. assert_raises(ValueError, multigammaln, 0.5, 1.2)
  23. def _check_multigammaln_array_result(a, d):
  24. # Test that the shape of the array returned by multigammaln
  25. # matches the input shape, and that all the values match
  26. # the value computed when multigammaln is called with a scalar.
  27. result = multigammaln(a, d)
  28. assert_array_equal(a.shape, result.shape)
  29. a1 = a.ravel()
  30. result1 = result.ravel()
  31. for i in range(a.size):
  32. assert_array_almost_equal_nulp(result1[i], multigammaln(a1[i], d))
  33. def test_multigammaln_array_arg():
  34. # Check that the array returned by multigammaln has the correct
  35. # shape and contains the correct values. The cases have arrays
  36. # with several differnent shapes.
  37. # The cases include a regression test for ticket #1849
  38. # (a = np.array([2.0]), an array with a single element).
  39. np.random.seed(1234)
  40. cases = [
  41. # a, d
  42. (np.abs(np.random.randn(3, 2)) + 5, 5),
  43. (np.abs(np.random.randn(1, 2)) + 5, 5),
  44. (np.arange(10.0, 18.0).reshape(2, 2, 2), 3),
  45. (np.array([2.0]), 3),
  46. (np.float64(2.0), 3),
  47. ]
  48. for a, d in cases:
  49. _check_multigammaln_array_result(a, d)