ops.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. import operator
  2. import pytest
  3. import pandas as pd
  4. from pandas.core import ops
  5. from .base import BaseExtensionTests
  class BaseOpsUtil(BaseExtensionTests):
      """Helpers shared by the arithmetic and comparison ops test suites."""

      def get_op_from_name(self, op_name):
          """Map a dunder name such as ``'__add__'`` to a binary callable.

          The surrounding underscores are stripped and the remainder is looked
          up in :mod:`operator`.  A name not found there is treated as a
          reflected operator (``'__radd__'`` -> ``operator.add``); the returned
          callable then applies that operator with its two arguments swapped.
          ``AttributeError`` propagates if neither lookup succeeds.
          """
          name = op_name.strip('_')
          forward = getattr(operator, name, None)
          if forward is not None:
              return forward

          # Reflected operator: drop the leading 'r' and swap the operands.
          base_op = getattr(operator, name[1:])

          def reflected(left, right):
              return base_op(right, left)

          return reflected

      def check_opname(self, s, op_name, other, exc=Exception):
          """Resolve ``op_name`` to a callable and delegate to ``_check_op``."""
          self._check_op(
              s, self.get_op_from_name(op_name), other, op_name, exc
          )

      def _check_op(self, s, op, other, op_name, exc=NotImplementedError):
          """Run ``op(s, other)`` and verify the outcome.

          When ``exc`` is an exception type, the call must raise it.  When
          ``exc`` is None, the call must succeed and its result must equal
          ``s.combine(other, op)``, i.e. the op applied element by element.
          ``op_name`` is not used by this implementation.
          """
          if exc is not None:
              with pytest.raises(exc):
                  op(s, other)
              return

          result = op(s, other)
          expected = s.combine(other, op)
          self.assert_series_equal(result, expected)

      def _check_divmod_op(self, s, op, other, exc=Exception):
          """Check a divmod-style ``op`` applied as ``op(s, other)``.

          divmod returns two values, so the quotient and remainder are
          compared separately.  For ``op is divmod`` they are checked against
          ``s // other`` and ``s % other``; for any other ``op`` against
          ``other // s`` and ``other % s``.  When ``exc`` is not None the call
          is instead expected to raise ``exc``.
          """
          if exc is not None:
              # NOTE(review): this branch always calls the builtin
              # divmod(s, other), never ``op``.  The success branch below
              # treats a non-divmod ``op`` as divmod with swapped operands
              # (expected ``other // s`` / ``other % s``), so for such an op
              # the two branches exercise different operand orders -- confirm
              # this is intended.
              with pytest.raises(exc):
                  divmod(s, other)
              return

          quotient, remainder = op(s, other)
          if op is divmod:
              expected = (s // other, s % other)
          else:
              expected = (other // s, other % s)
          self.assert_series_equal(quotient, expected[0])
          self.assert_series_equal(remainder, expected[1])
  class BaseArithmeticOpsTests(BaseOpsUtil):
      """Various Series and DataFrame arithmetic ops methods.

      Subclasses declare which ops they support by overriding the class
      variables below.  Each holds the exception type the corresponding op
      is expected to raise.  Set it to ``None`` when the op is supported; the
      check helpers then verify the result instead of expecting an error.

      * series_scalar_exc -- Series <op> scalar
      * frame_scalar_exc -- DataFrame <op> scalar
      * series_array_exc -- Series <op> Series
      * divmod_exc -- ``divmod`` with a scalar or with the array itself
      """
      series_scalar_exc = TypeError
      frame_scalar_exc = TypeError
      series_array_exc = TypeError
      divmod_exc = TypeError

      def test_arith_series_with_scalar(self, data, all_arithmetic_operators):
          """Series <op> scalar, using the Series' first element as scalar."""
          op_name = all_arithmetic_operators
          s = pd.Series(data)
          self.check_opname(s, op_name, s.iloc[0], exc=self.series_scalar_exc)

      @pytest.mark.xfail(run=False, reason="_reduce needs implementation")
      def test_arith_frame_with_scalar(self, data, all_arithmetic_operators):
          """DataFrame <op> scalar, using ``data[0]`` as the scalar."""
          op_name = all_arithmetic_operators
          df = pd.DataFrame({'A': data})
          self.check_opname(df, op_name, data[0], exc=self.frame_scalar_exc)

      def test_arith_series_with_array(self, data, all_arithmetic_operators):
          """Series <op> a same-length Series of its first element."""
          op_name = all_arithmetic_operators
          s = pd.Series(data)
          other = pd.Series([s.iloc[0]] * len(s))
          self.check_opname(s, op_name, other, exc=self.series_array_exc)

      def test_divmod(self, data):
          """divmod of a Series with the scalar 1, directly and via rdivmod."""
          s = pd.Series(data)
          self._check_divmod_op(s, divmod, 1, exc=self.divmod_exc)
          self._check_divmod_op(1, ops.rdivmod, s, exc=self.divmod_exc)

      def test_divmod_series_array(self, data):
          """divmod between a Series and the extension array it wraps.

          Uses ``divmod_exc`` (as ``test_divmod`` does) so the expectation
          follows the subclass's declared divmod support: an exception when
          it is an exception type, a checked result when it is None.
          """
          s = pd.Series(data)
          self._check_divmod_op(s, divmod, data, exc=self.divmod_exc)

      def test_add_series_with_extension_array(self, data):
          """Series + array must equal a Series built from array + array."""
          s = pd.Series(data)
          result = s + data
          expected = pd.Series(data + data)
          self.assert_series_equal(result, expected)

      def test_error(self, data, all_arithmetic_operators):
          """By default the array must not define the arithmetic dunders."""
          op_name = all_arithmetic_operators
          with pytest.raises(AttributeError):
              getattr(data, op_name)

      def test_direct_arith_with_series_returns_not_implemented(self, data):
          """``data.__add__(Series)`` must return NotImplemented.

          EAs should return NotImplemented for ops with Series; pandas takes
          care of unboxing the Series and calling the EA's op.  Skipped when
          the array has no ``__add__``.
          """
          other = pd.Series(data)
          if not hasattr(data, '__add__'):
              # pytest.skip raises by itself; no ``raise`` is needed.
              pytest.skip(
                  "{} does not implement add".format(data.__class__.__name__)
              )
          result = data.__add__(other)
          assert result is NotImplemented
  class BaseComparisonOpsTests(BaseOpsUtil):
      """Various Series and DataFrame comparison ops methods."""

      def _compare_other(self, s, data, op_name, other):
          """Check the comparison ``op_name`` on both ``data`` and ``s``.

          The array's own dunder must always return NotImplemented for
          ``other``.  On the Series ``s``: ``==`` must be False for at least
          one element, ``!=`` must be True for every element, and any other
          comparison must raise TypeError.
          """
          op = self.get_op_from_name(op_name)

          # array: the EA's dunder declines the comparison in every case.
          assert getattr(data, op_name)(other) is NotImplemented

          # series: use the caller-supplied ``s`` in every branch.
          if op_name == '__eq__':
              assert not op(s, other).all()
          elif op_name == '__ne__':
              assert op(s, other).all()
          else:
              with pytest.raises(TypeError):
                  op(s, other)

      def test_compare_scalar(self, data, all_compare_operators):
          """Compare a Series against the scalar 0."""
          op_name = all_compare_operators
          s = pd.Series(data)
          self._compare_other(s, data, op_name, 0)

      def test_compare_array(self, data, all_compare_operators):
          """Compare a Series against a Series of its first element."""
          op_name = all_compare_operators
          s = pd.Series(data)
          other = pd.Series([data[0]] * len(data))
          self._compare_other(s, data, op_name, other)

      def test_direct_arith_with_series_returns_not_implemented(self, data):
          """``data.__eq__(Series)`` must return NotImplemented.

          EAs should return NotImplemented for ops with Series; pandas takes
          care of unboxing the Series and calling the EA's op.
          """
          other = pd.Series(data)
          if not hasattr(data, '__eq__'):
              # Every Python object inherits ``__eq__`` from ``object``, so
              # this skip is not expected to trigger; pytest.skip raises by
              # itself, so no ``raise`` is needed.
              pytest.skip(
                  "{} does not implement __eq__".format(data.__class__.__name__)
              )
          result = data.__eq__(other)
          assert result is NotImplemented