test_arithmetic.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. # -*- coding: utf-8 -*-
  2. import operator
  3. import numpy as np
  4. import pytest
  5. import pandas as pd
  6. from pandas import Series, compat
  7. from pandas.core.indexes.period import IncompatibleFrequency
  8. import pandas.util.testing as tm
  9. def _permute(obj):
  10. return obj.take(np.random.permutation(len(obj)))
  11. class TestSeriesFlexArithmetic(object):
  12. @pytest.mark.parametrize(
  13. 'ts',
  14. [
  15. (lambda x: x, lambda x: x * 2, False),
  16. (lambda x: x, lambda x: x[::2], False),
  17. (lambda x: x, lambda x: 5, True),
  18. (lambda x: tm.makeFloatSeries(),
  19. lambda x: tm.makeFloatSeries(),
  20. True)
  21. ])
  22. @pytest.mark.parametrize('opname', ['add', 'sub', 'mul', 'floordiv',
  23. 'truediv', 'div', 'pow'])
  24. def test_flex_method_equivalence(self, opname, ts):
  25. # check that Series.{opname} behaves like Series.__{opname}__,
  26. tser = tm.makeTimeSeries().rename('ts')
  27. series = ts[0](tser)
  28. other = ts[1](tser)
  29. check_reverse = ts[2]
  30. if opname == 'div' and compat.PY3:
  31. pytest.skip('div test only for Py3')
  32. op = getattr(Series, opname)
  33. if op == 'div':
  34. alt = operator.truediv
  35. else:
  36. alt = getattr(operator, opname)
  37. result = op(series, other)
  38. expected = alt(series, other)
  39. tm.assert_almost_equal(result, expected)
  40. if check_reverse:
  41. rop = getattr(Series, "r" + opname)
  42. result = rop(series, other)
  43. expected = alt(other, series)
  44. tm.assert_almost_equal(result, expected)
  45. class TestSeriesArithmetic(object):
  46. # Some of these may end up in tests/arithmetic, but are not yet sorted
  47. def test_add_series_with_period_index(self):
  48. rng = pd.period_range('1/1/2000', '1/1/2010', freq='A')
  49. ts = Series(np.random.randn(len(rng)), index=rng)
  50. result = ts + ts[::2]
  51. expected = ts + ts
  52. expected[1::2] = np.nan
  53. tm.assert_series_equal(result, expected)
  54. result = ts + _permute(ts[::2])
  55. tm.assert_series_equal(result, expected)
  56. msg = "Input has different freq=D from PeriodIndex\\(freq=A-DEC\\)"
  57. with pytest.raises(IncompatibleFrequency, match=msg):
  58. ts + ts.asfreq('D', how="end")
# ------------------------------------------------------------------
# Comparisons
  61. class TestSeriesFlexComparison(object):
  62. def test_comparison_flex_basic(self):
  63. left = pd.Series(np.random.randn(10))
  64. right = pd.Series(np.random.randn(10))
  65. tm.assert_series_equal(left.eq(right), left == right)
  66. tm.assert_series_equal(left.ne(right), left != right)
  67. tm.assert_series_equal(left.le(right), left < right)
  68. tm.assert_series_equal(left.lt(right), left <= right)
  69. tm.assert_series_equal(left.gt(right), left > right)
  70. tm.assert_series_equal(left.ge(right), left >= right)
  71. # axis
  72. for axis in [0, None, 'index']:
  73. tm.assert_series_equal(left.eq(right, axis=axis), left == right)
  74. tm.assert_series_equal(left.ne(right, axis=axis), left != right)
  75. tm.assert_series_equal(left.le(right, axis=axis), left < right)
  76. tm.assert_series_equal(left.lt(right, axis=axis), left <= right)
  77. tm.assert_series_equal(left.gt(right, axis=axis), left > right)
  78. tm.assert_series_equal(left.ge(right, axis=axis), left >= right)
  79. #
  80. msg = 'No axis named 1 for object type'
  81. for op in ['eq', 'ne', 'le', 'le', 'gt', 'ge']:
  82. with pytest.raises(ValueError, match=msg):
  83. getattr(left, op)(right, axis=1)
  84. class TestSeriesComparison(object):
  85. def test_comparison_different_length(self):
  86. a = Series(['a', 'b', 'c'])
  87. b = Series(['b', 'a'])
  88. with pytest.raises(ValueError):
  89. a < b
  90. a = Series([1, 2])
  91. b = Series([2, 3, 4])
  92. with pytest.raises(ValueError):
  93. a == b
  94. @pytest.mark.parametrize('opname', ['eq', 'ne', 'gt', 'lt', 'ge', 'le'])
  95. def test_ser_flex_cmp_return_dtypes(self, opname):
  96. # GH#15115
  97. ser = Series([1, 3, 2], index=range(3))
  98. const = 2
  99. result = getattr(ser, opname)(const).get_dtype_counts()
  100. tm.assert_series_equal(result, Series([1], ['bool']))
  101. @pytest.mark.parametrize('opname', ['eq', 'ne', 'gt', 'lt', 'ge', 'le'])
  102. def test_ser_flex_cmp_return_dtypes_empty(self, opname):
  103. # GH#15115 empty Series case
  104. ser = Series([1, 3, 2], index=range(3))
  105. empty = ser.iloc[:0]
  106. const = 2
  107. result = getattr(empty, opname)(const).get_dtype_counts()
  108. tm.assert_series_equal(result, Series([1], ['bool']))
  109. @pytest.mark.parametrize('op', [operator.eq, operator.ne,
  110. operator.le, operator.lt,
  111. operator.ge, operator.gt])
  112. @pytest.mark.parametrize('names', [(None, None, None),
  113. ('foo', 'bar', None),
  114. ('baz', 'baz', 'baz')])
  115. def test_ser_cmp_result_names(self, names, op):
  116. # datetime64 dtype
  117. dti = pd.date_range('1949-06-07 03:00:00',
  118. freq='H', periods=5, name=names[0])
  119. ser = Series(dti).rename(names[1])
  120. result = op(ser, dti)
  121. assert result.name == names[2]
  122. # datetime64tz dtype
  123. dti = dti.tz_localize('US/Central')
  124. ser = Series(dti).rename(names[1])
  125. result = op(ser, dti)
  126. assert result.name == names[2]
  127. # timedelta64 dtype
  128. tdi = dti - dti.shift(1)
  129. ser = Series(tdi).rename(names[1])
  130. result = op(ser, tdi)
  131. assert result.name == names[2]
  132. # categorical
  133. if op in [operator.eq, operator.ne]:
  134. # categorical dtype comparisons raise for inequalities
  135. cidx = tdi.astype('category')
  136. ser = Series(cidx).rename(names[1])
  137. result = op(ser, cidx)
  138. assert result.name == names[2]