test_datetime.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. import numpy as np
  2. import pytest
  3. from pandas.core.dtypes.dtypes import DatetimeTZDtype
  4. import pandas as pd
  5. from pandas.core.arrays import DatetimeArray
  6. from pandas.tests.extension import base
  7. @pytest.fixture(params=["US/Central"])
  8. def dtype(request):
  9. return DatetimeTZDtype(unit="ns", tz=request.param)
  10. @pytest.fixture
  11. def data(dtype):
  12. data = DatetimeArray(pd.date_range("2000", periods=100, tz=dtype.tz),
  13. dtype=dtype)
  14. return data
  15. @pytest.fixture
  16. def data_missing(dtype):
  17. return DatetimeArray(
  18. np.array(['NaT', '2000-01-01'], dtype='datetime64[ns]'),
  19. dtype=dtype
  20. )
  21. @pytest.fixture
  22. def data_for_sorting(dtype):
  23. a = pd.Timestamp('2000-01-01')
  24. b = pd.Timestamp('2000-01-02')
  25. c = pd.Timestamp('2000-01-03')
  26. return DatetimeArray(np.array([b, c, a], dtype='datetime64[ns]'),
  27. dtype=dtype)
  28. @pytest.fixture
  29. def data_missing_for_sorting(dtype):
  30. a = pd.Timestamp('2000-01-01')
  31. b = pd.Timestamp('2000-01-02')
  32. return DatetimeArray(np.array([b, 'NaT', a], dtype='datetime64[ns]'),
  33. dtype=dtype)
  34. @pytest.fixture
  35. def data_for_grouping(dtype):
  36. """
  37. Expected to be like [B, B, NA, NA, A, A, B, C]
  38. Where A < B < C and NA is missing
  39. """
  40. a = pd.Timestamp('2000-01-01')
  41. b = pd.Timestamp('2000-01-02')
  42. c = pd.Timestamp('2000-01-03')
  43. na = 'NaT'
  44. return DatetimeArray(np.array([b, b, na, na, a, a, b, c],
  45. dtype='datetime64[ns]'),
  46. dtype=dtype)
  47. @pytest.fixture
  48. def na_cmp():
  49. def cmp(a, b):
  50. return a is pd.NaT and a is b
  51. return cmp
  52. @pytest.fixture
  53. def na_value():
  54. return pd.NaT
  55. # ----------------------------------------------------------------------------
  56. class BaseDatetimeTests(object):
  57. pass
  58. # ----------------------------------------------------------------------------
  59. # Tests
  60. class TestDatetimeDtype(BaseDatetimeTests, base.BaseDtypeTests):
  61. pass
  62. class TestConstructors(BaseDatetimeTests, base.BaseConstructorsTests):
  63. pass
  64. class TestGetitem(BaseDatetimeTests, base.BaseGetitemTests):
  65. pass
  66. class TestMethods(BaseDatetimeTests, base.BaseMethodsTests):
  67. @pytest.mark.skip(reason="Incorrect expected")
  68. def test_value_counts(self, all_data, dropna):
  69. pass
  70. def test_combine_add(self, data_repeated):
  71. # Timestamp.__add__(Timestamp) not defined
  72. pass
  73. class TestInterface(BaseDatetimeTests, base.BaseInterfaceTests):
  74. def test_array_interface(self, data):
  75. if data.tz:
  76. # np.asarray(DTA) is currently always tz-naive.
  77. pytest.skip("GH-23569")
  78. else:
  79. super(TestInterface, self).test_array_interface(data)
  80. class TestArithmeticOps(BaseDatetimeTests, base.BaseArithmeticOpsTests):
  81. implements = {'__sub__', '__rsub__'}
  82. def test_arith_series_with_scalar(self, data, all_arithmetic_operators):
  83. if all_arithmetic_operators in self.implements:
  84. s = pd.Series(data)
  85. self.check_opname(s, all_arithmetic_operators, s.iloc[0],
  86. exc=None)
  87. else:
  88. # ... but not the rest.
  89. super(TestArithmeticOps, self).test_arith_series_with_scalar(
  90. data, all_arithmetic_operators
  91. )
  92. def test_add_series_with_extension_array(self, data):
  93. # Datetime + Datetime not implemented
  94. s = pd.Series(data)
  95. msg = 'cannot add DatetimeArray and DatetimeArray'
  96. with pytest.raises(TypeError, match=msg):
  97. s + data
  98. def test_arith_series_with_array(self, data, all_arithmetic_operators):
  99. if all_arithmetic_operators in self.implements:
  100. s = pd.Series(data)
  101. self.check_opname(s, all_arithmetic_operators, s.iloc[0],
  102. exc=None)
  103. else:
  104. # ... but not the rest.
  105. super(TestArithmeticOps, self).test_arith_series_with_scalar(
  106. data, all_arithmetic_operators
  107. )
  108. def test_error(self, data, all_arithmetic_operators):
  109. pass
  110. @pytest.mark.xfail(reason="different implementation", strict=False)
  111. def test_direct_arith_with_series_returns_not_implemented(self, data):
  112. # Right now, we have trouble with this. Returning NotImplemented
  113. # fails other tests like
  114. # tests/arithmetic/test_datetime64::TestTimestampSeriesArithmetic::
  115. # test_dt64_seris_add_intlike
  116. return super(
  117. TestArithmeticOps,
  118. self
  119. ).test_direct_arith_with_series_returns_not_implemented(data)
  120. class TestCasting(BaseDatetimeTests, base.BaseCastingTests):
  121. pass
  122. class TestComparisonOps(BaseDatetimeTests, base.BaseComparisonOpsTests):
  123. def _compare_other(self, s, data, op_name, other):
  124. # the base test is not appropriate for us. We raise on comparison
  125. # with (some) integers, depending on the value.
  126. pass
  127. @pytest.mark.xfail(reason="different implementation", strict=False)
  128. def test_direct_arith_with_series_returns_not_implemented(self, data):
  129. return super(
  130. TestComparisonOps,
  131. self
  132. ).test_direct_arith_with_series_returns_not_implemented(data)
  133. class TestMissing(BaseDatetimeTests, base.BaseMissingTests):
  134. pass
  135. class TestReshaping(BaseDatetimeTests, base.BaseReshapingTests):
  136. @pytest.mark.skip(reason="We have DatetimeTZBlock")
  137. def test_concat(self, data, in_frame):
  138. pass
  139. def test_concat_mixed_dtypes(self, data):
  140. # concat(Series[datetimetz], Series[category]) uses a
  141. # plain np.array(values) on the DatetimeArray, which
  142. # drops the tz.
  143. super(TestReshaping, self).test_concat_mixed_dtypes(data)
  144. @pytest.mark.parametrize("obj", ["series", "frame"])
  145. def test_unstack(self, obj):
  146. # GH-13287: can't use base test, since building the expected fails.
  147. data = DatetimeArray._from_sequence(['2000', '2001', '2002', '2003'],
  148. tz='US/Central')
  149. index = pd.MultiIndex.from_product(([['A', 'B'], ['a', 'b']]),
  150. names=['a', 'b'])
  151. if obj == "series":
  152. ser = pd.Series(data, index=index)
  153. expected = pd.DataFrame({
  154. "A": data.take([0, 1]),
  155. "B": data.take([2, 3])
  156. }, index=pd.Index(['a', 'b'], name='b'))
  157. expected.columns.name = 'a'
  158. else:
  159. ser = pd.DataFrame({"A": data, "B": data}, index=index)
  160. expected = pd.DataFrame(
  161. {("A", "A"): data.take([0, 1]),
  162. ("A", "B"): data.take([2, 3]),
  163. ("B", "A"): data.take([0, 1]),
  164. ("B", "B"): data.take([2, 3])},
  165. index=pd.Index(['a', 'b'], name='b')
  166. )
  167. expected.columns.names = [None, 'a']
  168. result = ser.unstack(0)
  169. self.assert_equal(result, expected)
  170. class TestSetitem(BaseDatetimeTests, base.BaseSetitemTests):
  171. pass
  172. class TestGroupby(BaseDatetimeTests, base.BaseGroupbyTests):
  173. pass
  174. class TestPrinting(BaseDatetimeTests, base.BasePrintingTests):
  175. pass