test_comparisons.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. # -*- coding: utf-8 -*-
  2. from datetime import datetime
  3. import operator
  4. import numpy as np
  5. import pytest
  6. from pandas.compat import PY2, long
  7. from pandas import Timestamp
  8. class TestTimestampComparison(object):
  9. def test_comparison_object_array(self):
  10. # GH#15183
  11. ts = Timestamp('2011-01-03 00:00:00-0500', tz='US/Eastern')
  12. other = Timestamp('2011-01-01 00:00:00-0500', tz='US/Eastern')
  13. naive = Timestamp('2011-01-01 00:00:00')
  14. arr = np.array([other, ts], dtype=object)
  15. res = arr == ts
  16. expected = np.array([False, True], dtype=bool)
  17. assert (res == expected).all()
  18. # 2D case
  19. arr = np.array([[other, ts],
  20. [ts, other]],
  21. dtype=object)
  22. res = arr != ts
  23. expected = np.array([[True, False], [False, True]], dtype=bool)
  24. assert res.shape == expected.shape
  25. assert (res == expected).all()
  26. # tzaware mismatch
  27. arr = np.array([naive], dtype=object)
  28. with pytest.raises(TypeError):
  29. arr < ts
  30. def test_comparison(self):
  31. # 5-18-2012 00:00:00.000
  32. stamp = long(1337299200000000000)
  33. val = Timestamp(stamp)
  34. assert val == val
  35. assert not val != val
  36. assert not val < val
  37. assert val <= val
  38. assert not val > val
  39. assert val >= val
  40. other = datetime(2012, 5, 18)
  41. assert val == other
  42. assert not val != other
  43. assert not val < other
  44. assert val <= other
  45. assert not val > other
  46. assert val >= other
  47. other = Timestamp(stamp + 100)
  48. assert val != other
  49. assert val != other
  50. assert val < other
  51. assert val <= other
  52. assert other > val
  53. assert other >= val
  54. def test_compare_invalid(self):
  55. # GH#8058
  56. val = Timestamp('20130101 12:01:02')
  57. assert not val == 'foo'
  58. assert not val == 10.0
  59. assert not val == 1
  60. assert not val == long(1)
  61. assert not val == []
  62. assert not val == {'foo': 1}
  63. assert not val == np.float64(1)
  64. assert not val == np.int64(1)
  65. assert val != 'foo'
  66. assert val != 10.0
  67. assert val != 1
  68. assert val != long(1)
  69. assert val != []
  70. assert val != {'foo': 1}
  71. assert val != np.float64(1)
  72. assert val != np.int64(1)
  73. def test_cant_compare_tz_naive_w_aware(self, utc_fixture):
  74. # see GH#1404
  75. a = Timestamp('3/12/2012')
  76. b = Timestamp('3/12/2012', tz=utc_fixture)
  77. with pytest.raises(TypeError):
  78. a == b
  79. with pytest.raises(TypeError):
  80. a != b
  81. with pytest.raises(TypeError):
  82. a < b
  83. with pytest.raises(TypeError):
  84. a <= b
  85. with pytest.raises(TypeError):
  86. a > b
  87. with pytest.raises(TypeError):
  88. a >= b
  89. with pytest.raises(TypeError):
  90. b == a
  91. with pytest.raises(TypeError):
  92. b != a
  93. with pytest.raises(TypeError):
  94. b < a
  95. with pytest.raises(TypeError):
  96. b <= a
  97. with pytest.raises(TypeError):
  98. b > a
  99. with pytest.raises(TypeError):
  100. b >= a
  101. if PY2:
  102. with pytest.raises(TypeError):
  103. a == b.to_pydatetime()
  104. with pytest.raises(TypeError):
  105. a.to_pydatetime() == b
  106. else:
  107. assert not a == b.to_pydatetime()
  108. assert not a.to_pydatetime() == b
  109. def test_timestamp_compare_scalars(self):
  110. # case where ndim == 0
  111. lhs = np.datetime64(datetime(2013, 12, 6))
  112. rhs = Timestamp('now')
  113. nat = Timestamp('nat')
  114. ops = {'gt': 'lt',
  115. 'lt': 'gt',
  116. 'ge': 'le',
  117. 'le': 'ge',
  118. 'eq': 'eq',
  119. 'ne': 'ne'}
  120. for left, right in ops.items():
  121. left_f = getattr(operator, left)
  122. right_f = getattr(operator, right)
  123. expected = left_f(lhs, rhs)
  124. result = right_f(rhs, lhs)
  125. assert result == expected
  126. expected = left_f(rhs, nat)
  127. result = right_f(nat, rhs)
  128. assert result == expected
  129. def test_timestamp_compare_with_early_datetime(self):
  130. # e.g. datetime.min
  131. stamp = Timestamp('2012-01-01')
  132. assert not stamp == datetime.min
  133. assert not stamp == datetime(1600, 1, 1)
  134. assert not stamp == datetime(2700, 1, 1)
  135. assert stamp != datetime.min
  136. assert stamp != datetime(1600, 1, 1)
  137. assert stamp != datetime(2700, 1, 1)
  138. assert stamp > datetime(1600, 1, 1)
  139. assert stamp >= datetime(1600, 1, 1)
  140. assert stamp < datetime(2700, 1, 1)
  141. assert stamp <= datetime(2700, 1, 1)