test_mixins.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. from __future__ import division, absolute_import, print_function
  2. import numbers
  3. import operator
  4. import sys
  5. import numpy as np
  6. from numpy.testing import assert_, assert_equal, assert_raises
  7. PY2 = sys.version_info.major < 3
  8. # NOTE: This class should be kept as an exact copy of the example from the
  9. # docstring for NDArrayOperatorsMixin.
class ArrayLike(np.lib.mixins.NDArrayOperatorsMixin):
    # Fixture used by the tests in this module: it stores its input as an
    # ndarray in ``self.value`` and implements __array_ufunc__ by unwrapping
    # ArrayLike operands, calling the ufunc, and re-wrapping the result.
    # Documentation is added as comments (not docstrings) so the code itself
    # stays identical to the NDArrayOperatorsMixin docstring example.
    def __init__(self, value):
        self.value = np.asarray(value)

    # One might also consider adding the built-in list type to this
    # list, to support operations like np.add(array_like, list)
    _HANDLED_TYPES = (np.ndarray, numbers.Number)

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # ``out`` defaults to an empty tuple so that it can be concatenated
        # with ``inputs`` and type-checked in the same loop.
        out = kwargs.get('out', ())
        for x in inputs + out:
            # Only support operations with instances of _HANDLED_TYPES.
            # Use ArrayLike instead of type(self) for isinstance to
            # allow subclasses that don't override __array_ufunc__ to
            # handle ArrayLike objects.
            if not isinstance(x, self._HANDLED_TYPES + (ArrayLike,)):
                return NotImplemented

        # Defer to the implementation of the ufunc on unwrapped values.
        inputs = tuple(x.value if isinstance(x, ArrayLike) else x
                       for x in inputs)
        if out:
            # Output arguments are unwrapped the same way as the inputs.
            kwargs['out'] = tuple(
                x.value if isinstance(x, ArrayLike) else x
                for x in out)
        # ``method`` is the name of the ufunc attribute to invoke; it is
        # looked up on the ufunc and called with the unwrapped operands.
        result = getattr(ufunc, method)(*inputs, **kwargs)

        # Results are wrapped with type(self) so that a subclass instance
        # produces subclass results.
        if type(result) is tuple:
            # multiple return values
            return tuple(type(self)(x) for x in result)
        elif method == 'at':
            # no return value
            return None
        else:
            # one return value
            return type(self)(result)

    def __repr__(self):
        # Rendered as ``ClassName(<repr of the wrapped array>)``.
        return '%s(%r)' % (type(self).__name__, self.value)
  44. def wrap_array_like(result):
  45. if type(result) is tuple:
  46. return tuple(ArrayLike(r) for r in result)
  47. else:
  48. return ArrayLike(result)
  49. def _assert_equal_type_and_value(result, expected, err_msg=None):
  50. assert_equal(type(result), type(expected), err_msg=err_msg)
  51. if isinstance(result, tuple):
  52. assert_equal(len(result), len(expected), err_msg=err_msg)
  53. for result_item, expected_item in zip(result, expected):
  54. _assert_equal_type_and_value(result_item, expected_item, err_msg)
  55. else:
  56. assert_equal(result.value, expected.value, err_msg=err_msg)
  57. assert_equal(getattr(result.value, 'dtype', None),
  58. getattr(expected.value, 'dtype', None), err_msg=err_msg)
  59. _ALL_BINARY_OPERATORS = [
  60. operator.lt,
  61. operator.le,
  62. operator.eq,
  63. operator.ne,
  64. operator.gt,
  65. operator.ge,
  66. operator.add,
  67. operator.sub,
  68. operator.mul,
  69. operator.truediv,
  70. operator.floordiv,
  71. # TODO: test div on Python 2, only
  72. operator.mod,
  73. divmod,
  74. pow,
  75. operator.lshift,
  76. operator.rshift,
  77. operator.and_,
  78. operator.xor,
  79. operator.or_,
  80. ]
  81. class TestNDArrayOperatorsMixin(object):
  82. def test_array_like_add(self):
  83. def check(result):
  84. _assert_equal_type_and_value(result, ArrayLike(0))
  85. check(ArrayLike(0) + 0)
  86. check(0 + ArrayLike(0))
  87. check(ArrayLike(0) + np.array(0))
  88. check(np.array(0) + ArrayLike(0))
  89. check(ArrayLike(np.array(0)) + 0)
  90. check(0 + ArrayLike(np.array(0)))
  91. check(ArrayLike(np.array(0)) + np.array(0))
  92. check(np.array(0) + ArrayLike(np.array(0)))
  93. def test_inplace(self):
  94. array_like = ArrayLike(np.array([0]))
  95. array_like += 1
  96. _assert_equal_type_and_value(array_like, ArrayLike(np.array([1])))
  97. array = np.array([0])
  98. array += ArrayLike(1)
  99. _assert_equal_type_and_value(array, ArrayLike(np.array([1])))
  100. def test_opt_out(self):
  101. class OptOut(object):
  102. """Object that opts out of __array_ufunc__."""
  103. __array_ufunc__ = None
  104. def __add__(self, other):
  105. return self
  106. def __radd__(self, other):
  107. return self
  108. array_like = ArrayLike(1)
  109. opt_out = OptOut()
  110. # supported operations
  111. assert_(array_like + opt_out is opt_out)
  112. assert_(opt_out + array_like is opt_out)
  113. # not supported
  114. with assert_raises(TypeError):
  115. # don't use the Python default, array_like = array_like + opt_out
  116. array_like += opt_out
  117. with assert_raises(TypeError):
  118. array_like - opt_out
  119. with assert_raises(TypeError):
  120. opt_out - array_like
  121. def test_subclass(self):
  122. class SubArrayLike(ArrayLike):
  123. """Should take precedence over ArrayLike."""
  124. x = ArrayLike(0)
  125. y = SubArrayLike(1)
  126. _assert_equal_type_and_value(x + y, y)
  127. _assert_equal_type_and_value(y + x, y)
  128. def test_object(self):
  129. x = ArrayLike(0)
  130. obj = object()
  131. with assert_raises(TypeError):
  132. x + obj
  133. with assert_raises(TypeError):
  134. obj + x
  135. with assert_raises(TypeError):
  136. x += obj
  137. def test_unary_methods(self):
  138. array = np.array([-1, 0, 1, 2])
  139. array_like = ArrayLike(array)
  140. for op in [operator.neg,
  141. operator.pos,
  142. abs,
  143. operator.invert]:
  144. _assert_equal_type_and_value(op(array_like), ArrayLike(op(array)))
  145. def test_forward_binary_methods(self):
  146. array = np.array([-1, 0, 1, 2])
  147. array_like = ArrayLike(array)
  148. for op in _ALL_BINARY_OPERATORS:
  149. expected = wrap_array_like(op(array, 1))
  150. actual = op(array_like, 1)
  151. err_msg = 'failed for operator {}'.format(op)
  152. _assert_equal_type_and_value(expected, actual, err_msg=err_msg)
  153. def test_reflected_binary_methods(self):
  154. for op in _ALL_BINARY_OPERATORS:
  155. expected = wrap_array_like(op(2, 1))
  156. actual = op(2, ArrayLike(1))
  157. err_msg = 'failed for operator {}'.format(op)
  158. _assert_equal_type_and_value(expected, actual, err_msg=err_msg)
  159. def test_matmul(self):
  160. array = np.array([1, 2], dtype=np.float64)
  161. array_like = ArrayLike(array)
  162. expected = ArrayLike(np.float64(5))
  163. _assert_equal_type_and_value(expected, np.matmul(array_like, array))
  164. if not PY2:
  165. _assert_equal_type_and_value(
  166. expected, operator.matmul(array_like, array))
  167. _assert_equal_type_and_value(
  168. expected, operator.matmul(array, array_like))
  169. def test_ufunc_at(self):
  170. array = ArrayLike(np.array([1, 2, 3, 4]))
  171. assert_(np.negative.at(array, np.array([0, 1])) is None)
  172. _assert_equal_type_and_value(array, ArrayLike([-1, -2, 3, 4]))
  173. def test_ufunc_two_outputs(self):
  174. mantissa, exponent = np.frexp(2 ** -3)
  175. expected = (ArrayLike(mantissa), ArrayLike(exponent))
  176. _assert_equal_type_and_value(
  177. np.frexp(ArrayLike(2 ** -3)), expected)
  178. _assert_equal_type_and_value(
  179. np.frexp(ArrayLike(np.array(2 ** -3))), expected)