test_sputils.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. """unit tests for sparse utility functions"""
  2. from __future__ import division, print_function, absolute_import
  3. import numpy as np
  4. from numpy.testing import assert_equal, assert_raises
  5. from pytest import raises as assert_raises
  6. from scipy.sparse import sputils
  7. from scipy._lib._numpy_compat import suppress_warnings
  8. class TestSparseUtils(object):
  9. def test_upcast(self):
  10. assert_equal(sputils.upcast('intc'), np.intc)
  11. assert_equal(sputils.upcast('int32', 'float32'), np.float64)
  12. assert_equal(sputils.upcast('bool', complex, float), np.complex128)
  13. assert_equal(sputils.upcast('i', 'd'), np.float64)
  14. def test_getdtype(self):
  15. A = np.array([1], dtype='int8')
  16. assert_equal(sputils.getdtype(None, default=float), float)
  17. assert_equal(sputils.getdtype(None, a=A), np.int8)
  18. def test_isscalarlike(self):
  19. assert_equal(sputils.isscalarlike(3.0), True)
  20. assert_equal(sputils.isscalarlike(-4), True)
  21. assert_equal(sputils.isscalarlike(2.5), True)
  22. assert_equal(sputils.isscalarlike(1 + 3j), True)
  23. assert_equal(sputils.isscalarlike(np.array(3)), True)
  24. assert_equal(sputils.isscalarlike("16"), True)
  25. assert_equal(sputils.isscalarlike(np.array([3])), False)
  26. assert_equal(sputils.isscalarlike([[3]]), False)
  27. assert_equal(sputils.isscalarlike((1,)), False)
  28. assert_equal(sputils.isscalarlike((1, 2)), False)
  29. def test_isintlike(self):
  30. assert_equal(sputils.isintlike(-4), True)
  31. assert_equal(sputils.isintlike(np.array(3)), True)
  32. assert_equal(sputils.isintlike(np.array([3])), False)
  33. with suppress_warnings() as sup:
  34. sup.filter(DeprecationWarning,
  35. "Inexact indices into sparse matrices are deprecated")
  36. assert_equal(sputils.isintlike(3.0), True)
  37. assert_equal(sputils.isintlike(2.5), False)
  38. assert_equal(sputils.isintlike(1 + 3j), False)
  39. assert_equal(sputils.isintlike((1,)), False)
  40. assert_equal(sputils.isintlike((1, 2)), False)
  41. def test_isshape(self):
  42. assert_equal(sputils.isshape((1, 2)), True)
  43. assert_equal(sputils.isshape((5, 2)), True)
  44. assert_equal(sputils.isshape((1.5, 2)), False)
  45. assert_equal(sputils.isshape((2, 2, 2)), False)
  46. assert_equal(sputils.isshape(([2], 2)), False)
  47. assert_equal(sputils.isshape((-1, 2), nonneg=False),True)
  48. assert_equal(sputils.isshape((2, -1), nonneg=False),True)
  49. assert_equal(sputils.isshape((-1, 2), nonneg=True),False)
  50. assert_equal(sputils.isshape((2, -1), nonneg=True),False)
  51. def test_issequence(self):
  52. assert_equal(sputils.issequence((1,)), True)
  53. assert_equal(sputils.issequence((1, 2, 3)), True)
  54. assert_equal(sputils.issequence([1]), True)
  55. assert_equal(sputils.issequence([1, 2, 3]), True)
  56. assert_equal(sputils.issequence(np.array([1, 2, 3])), True)
  57. assert_equal(sputils.issequence(np.array([[1], [2], [3]])), False)
  58. assert_equal(sputils.issequence(3), False)
  59. def test_ismatrix(self):
  60. assert_equal(sputils.ismatrix(((),)), True)
  61. assert_equal(sputils.ismatrix([[1], [2]]), True)
  62. assert_equal(sputils.ismatrix(np.arange(3)[None]), True)
  63. assert_equal(sputils.ismatrix([1, 2]), False)
  64. assert_equal(sputils.ismatrix(np.arange(3)), False)
  65. assert_equal(sputils.ismatrix([[[1]]]), False)
  66. assert_equal(sputils.ismatrix(3), False)
  67. def test_isdense(self):
  68. assert_equal(sputils.isdense(np.array([1])), True)
  69. assert_equal(sputils.isdense(np.matrix([1])), True)
  70. def test_validateaxis(self):
  71. assert_raises(TypeError, sputils.validateaxis, (0, 1))
  72. assert_raises(TypeError, sputils.validateaxis, 1.5)
  73. assert_raises(ValueError, sputils.validateaxis, 3)
  74. # These function calls should not raise errors
  75. for axis in (-2, -1, 0, 1, None):
  76. sputils.validateaxis(axis)
  77. def test_get_index_dtype(self):
  78. imax = np.iinfo(np.int32).max
  79. too_big = imax + 1
  80. # Check that uint32's with no values too large doesn't return
  81. # int64
  82. a1 = np.ones(90, dtype='uint32')
  83. a2 = np.ones(90, dtype='uint32')
  84. assert_equal(
  85. np.dtype(sputils.get_index_dtype((a1, a2), check_contents=True)),
  86. np.dtype('int32')
  87. )
  88. # Check that if we can not convert but all values are less than or
  89. # equal to max that we can just convert to int32
  90. a1[-1] = imax
  91. assert_equal(
  92. np.dtype(sputils.get_index_dtype((a1, a2), check_contents=True)),
  93. np.dtype('int32')
  94. )
  95. # Check that if it can not convert directly and the contents are
  96. # too large that we return int64
  97. a1[-1] = too_big
  98. assert_equal(
  99. np.dtype(sputils.get_index_dtype((a1, a2), check_contents=True)),
  100. np.dtype('int64')
  101. )
  102. # test that if can not convert and didn't specify to check_contents
  103. # we return int64
  104. a1 = np.ones(89, dtype='uint32')
  105. a2 = np.ones(89, dtype='uint32')
  106. assert_equal(
  107. np.dtype(sputils.get_index_dtype((a1, a2))),
  108. np.dtype('int64')
  109. )
  110. # Check that even if we have arrays that can be converted directly
  111. # that if we specify a maxval directly it takes precedence
  112. a1 = np.ones(12, dtype='uint32')
  113. a2 = np.ones(12, dtype='uint32')
  114. assert_equal(
  115. np.dtype(sputils.get_index_dtype(
  116. (a1, a2), maxval=too_big, check_contents=True
  117. )),
  118. np.dtype('int64')
  119. )
  120. # Check that an array with a too max size and maxval set
  121. # still returns int64
  122. a1[-1] = too_big
  123. assert_equal(
  124. np.dtype(sputils.get_index_dtype((a1, a2), maxval=too_big)),
  125. np.dtype('int64')
  126. )