test_period.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. import numpy as np
  2. import pytest
  3. from pandas._libs.tslib import iNaT
  4. from pandas.core.dtypes.dtypes import PeriodDtype
  5. import pandas as pd
  6. from pandas.core.arrays import PeriodArray
  7. from pandas.tests.extension import base
  8. @pytest.fixture
  9. def dtype():
  10. return PeriodDtype(freq='D')
  11. @pytest.fixture
  12. def data(dtype):
  13. return PeriodArray(np.arange(1970, 2070), freq=dtype.freq)
  14. @pytest.fixture
  15. def data_for_sorting(dtype):
  16. return PeriodArray([2018, 2019, 2017], freq=dtype.freq)
  17. @pytest.fixture
  18. def data_missing(dtype):
  19. return PeriodArray([iNaT, 2017], freq=dtype.freq)
  20. @pytest.fixture
  21. def data_missing_for_sorting(dtype):
  22. return PeriodArray([2018, iNaT, 2017], freq=dtype.freq)
  23. @pytest.fixture
  24. def data_for_grouping(dtype):
  25. B = 2018
  26. NA = iNaT
  27. A = 2017
  28. C = 2019
  29. return PeriodArray([B, B, NA, NA, A, A, B, C], freq=dtype.freq)
  30. @pytest.fixture
  31. def na_value():
  32. return pd.NaT
  33. class BasePeriodTests(object):
  34. pass
  35. class TestPeriodDtype(BasePeriodTests, base.BaseDtypeTests):
  36. pass
  37. class TestConstructors(BasePeriodTests, base.BaseConstructorsTests):
  38. pass
  39. class TestGetitem(BasePeriodTests, base.BaseGetitemTests):
  40. pass
  41. class TestMethods(BasePeriodTests, base.BaseMethodsTests):
  42. def test_combine_add(self, data_repeated):
  43. # Period + Period is not defined.
  44. pass
  45. class TestInterface(BasePeriodTests, base.BaseInterfaceTests):
  46. pass
  47. class TestArithmeticOps(BasePeriodTests, base.BaseArithmeticOpsTests):
  48. implements = {'__sub__', '__rsub__'}
  49. def test_arith_series_with_scalar(self, data, all_arithmetic_operators):
  50. # we implement substitution...
  51. if all_arithmetic_operators in self.implements:
  52. s = pd.Series(data)
  53. self.check_opname(s, all_arithmetic_operators, s.iloc[0],
  54. exc=None)
  55. else:
  56. # ... but not the rest.
  57. super(TestArithmeticOps, self).test_arith_series_with_scalar(
  58. data, all_arithmetic_operators
  59. )
  60. def test_arith_series_with_array(self, data, all_arithmetic_operators):
  61. if all_arithmetic_operators in self.implements:
  62. s = pd.Series(data)
  63. self.check_opname(s, all_arithmetic_operators, s.iloc[0],
  64. exc=None)
  65. else:
  66. # ... but not the rest.
  67. super(TestArithmeticOps, self).test_arith_series_with_scalar(
  68. data, all_arithmetic_operators
  69. )
  70. def _check_divmod_op(self, s, op, other, exc=NotImplementedError):
  71. super(TestArithmeticOps, self)._check_divmod_op(
  72. s, op, other, exc=TypeError
  73. )
  74. def test_add_series_with_extension_array(self, data):
  75. # we don't implement + for Period
  76. s = pd.Series(data)
  77. msg = (r"unsupported operand type\(s\) for \+: "
  78. r"\'PeriodArray\' and \'PeriodArray\'")
  79. with pytest.raises(TypeError, match=msg):
  80. s + data
  81. def test_error(self):
  82. pass
  83. def test_direct_arith_with_series_returns_not_implemented(self, data):
  84. # Override to use __sub__ instead of __add__
  85. other = pd.Series(data)
  86. result = data.__sub__(other)
  87. assert result is NotImplemented
  88. class TestCasting(BasePeriodTests, base.BaseCastingTests):
  89. pass
  90. class TestComparisonOps(BasePeriodTests, base.BaseComparisonOpsTests):
  91. def _compare_other(self, s, data, op_name, other):
  92. # the base test is not appropriate for us. We raise on comparison
  93. # with (some) integers, depending on the value.
  94. pass
  95. class TestMissing(BasePeriodTests, base.BaseMissingTests):
  96. pass
  97. class TestReshaping(BasePeriodTests, base.BaseReshapingTests):
  98. pass
  99. class TestSetitem(BasePeriodTests, base.BaseSetitemTests):
  100. pass
  101. class TestGroupby(BasePeriodTests, base.BaseGroupbyTests):
  102. pass
  103. class TestPrinting(BasePeriodTests, base.BasePrintingTests):
  104. pass
  105. class TestParsing(BasePeriodTests, base.BaseParsingTests):
  106. @pytest.mark.parametrize('engine', ['c', 'python'])
  107. def test_EA_types(self, engine, data):
  108. expected_msg = r'.*must implement _from_sequence_of_strings.*'
  109. with pytest.raises(NotImplementedError, match=expected_msg):
  110. super(TestParsing, self).test_EA_types(engine, data)