# test_interval_new.py
  1. from __future__ import division
  2. import numpy as np
  3. import pytest
  4. from pandas import Int64Index, Interval, IntervalIndex
  5. import pandas.util.testing as tm
# Module-level pytest mark: every test collected from this file is skipped,
# with the reason string below shown in the test report.
pytestmark = pytest.mark.skip(reason="new indexing tests for issue 16316")
  7. class TestIntervalIndex(object):
  8. @pytest.mark.parametrize("side", ['right', 'left', 'both', 'neither'])
  9. def test_get_loc_interval(self, closed, side):
  10. idx = IntervalIndex.from_tuples([(0, 1), (2, 3)], closed=closed)
  11. for bound in [[0, 1], [1, 2], [2, 3], [3, 4],
  12. [0, 2], [2.5, 3], [-1, 4]]:
  13. # if get_loc is supplied an interval, it should only search
  14. # for exact matches, not overlaps or covers, else KeyError.
  15. if closed == side:
  16. if bound == [0, 1]:
  17. assert idx.get_loc(Interval(0, 1, closed=side)) == 0
  18. elif bound == [2, 3]:
  19. assert idx.get_loc(Interval(2, 3, closed=side)) == 1
  20. else:
  21. with pytest.raises(KeyError):
  22. idx.get_loc(Interval(*bound, closed=side))
  23. else:
  24. with pytest.raises(KeyError):
  25. idx.get_loc(Interval(*bound, closed=side))
  26. @pytest.mark.parametrize("scalar", [-0.5, 0, 0.5, 1, 1.5, 2, 2.5, 3, 3.5])
  27. def test_get_loc_scalar(self, closed, scalar):
  28. # correct = {side: {query: answer}}.
  29. # If query is not in the dict, that query should raise a KeyError
  30. correct = {'right': {0.5: 0, 1: 0, 2.5: 1, 3: 1},
  31. 'left': {0: 0, 0.5: 0, 2: 1, 2.5: 1},
  32. 'both': {0: 0, 0.5: 0, 1: 0, 2: 1, 2.5: 1, 3: 1},
  33. 'neither': {0.5: 0, 2.5: 1}}
  34. idx = IntervalIndex.from_tuples([(0, 1), (2, 3)], closed=closed)
  35. # if get_loc is supplied a scalar, it should return the index of
  36. # the interval which contains the scalar, or KeyError.
  37. if scalar in correct[closed].keys():
  38. assert idx.get_loc(scalar) == correct[closed][scalar]
  39. else:
  40. pytest.raises(KeyError, idx.get_loc, scalar)
  41. def test_slice_locs_with_interval(self):
  42. # increasing monotonically
  43. index = IntervalIndex.from_tuples([(0, 2), (1, 3), (2, 4)])
  44. assert index.slice_locs(
  45. start=Interval(0, 2), end=Interval(2, 4)) == (0, 3)
  46. assert index.slice_locs(start=Interval(0, 2)) == (0, 3)
  47. assert index.slice_locs(end=Interval(2, 4)) == (0, 3)
  48. assert index.slice_locs(end=Interval(0, 2)) == (0, 1)
  49. assert index.slice_locs(
  50. start=Interval(2, 4), end=Interval(0, 2)) == (2, 1)
  51. # decreasing monotonically
  52. index = IntervalIndex.from_tuples([(2, 4), (1, 3), (0, 2)])
  53. assert index.slice_locs(
  54. start=Interval(0, 2), end=Interval(2, 4)) == (2, 1)
  55. assert index.slice_locs(start=Interval(0, 2)) == (2, 3)
  56. assert index.slice_locs(end=Interval(2, 4)) == (0, 1)
  57. assert index.slice_locs(end=Interval(0, 2)) == (0, 3)
  58. assert index.slice_locs(
  59. start=Interval(2, 4), end=Interval(0, 2)) == (0, 3)
  60. # sorted duplicates
  61. index = IntervalIndex.from_tuples([(0, 2), (0, 2), (2, 4)])
  62. assert index.slice_locs(
  63. start=Interval(0, 2), end=Interval(2, 4)) == (0, 3)
  64. assert index.slice_locs(start=Interval(0, 2)) == (0, 3)
  65. assert index.slice_locs(end=Interval(2, 4)) == (0, 3)
  66. assert index.slice_locs(end=Interval(0, 2)) == (0, 2)
  67. assert index.slice_locs(
  68. start=Interval(2, 4), end=Interval(0, 2)) == (2, 2)
  69. # unsorted duplicates
  70. index = IntervalIndex.from_tuples([(0, 2), (2, 4), (0, 2)])
  71. pytest.raises(KeyError, index.slice_locs(
  72. start=Interval(0, 2), end=Interval(2, 4)))
  73. pytest.raises(KeyError, index.slice_locs(start=Interval(0, 2)))
  74. assert index.slice_locs(end=Interval(2, 4)) == (0, 2)
  75. pytest.raises(KeyError, index.slice_locs(end=Interval(0, 2)))
  76. pytest.raises(KeyError, index.slice_locs(
  77. start=Interval(2, 4), end=Interval(0, 2)))
  78. # another unsorted duplicates
  79. index = IntervalIndex.from_tuples([(0, 2), (0, 2), (2, 4), (1, 3)])
  80. assert index.slice_locs(
  81. start=Interval(0, 2), end=Interval(2, 4)) == (0, 3)
  82. assert index.slice_locs(start=Interval(0, 2)) == (0, 4)
  83. assert index.slice_locs(end=Interval(2, 4)) == (0, 3)
  84. assert index.slice_locs(end=Interval(0, 2)) == (0, 2)
  85. assert index.slice_locs(
  86. start=Interval(2, 4), end=Interval(0, 2)) == (2, 2)
  87. def test_slice_locs_with_ints_and_floats_succeeds(self):
  88. # increasing non-overlapping
  89. index = IntervalIndex.from_tuples([(0, 1), (1, 2), (3, 4)])
  90. assert index.slice_locs(0, 1) == (0, 1)
  91. assert index.slice_locs(0, 2) == (0, 2)
  92. assert index.slice_locs(0, 3) == (0, 2)
  93. assert index.slice_locs(3, 1) == (2, 1)
  94. assert index.slice_locs(3, 4) == (2, 3)
  95. assert index.slice_locs(0, 4) == (0, 3)
  96. # decreasing non-overlapping
  97. index = IntervalIndex.from_tuples([(3, 4), (1, 2), (0, 1)])
  98. assert index.slice_locs(0, 1) == (3, 2)
  99. assert index.slice_locs(0, 2) == (3, 1)
  100. assert index.slice_locs(0, 3) == (3, 1)
  101. assert index.slice_locs(3, 1) == (1, 2)
  102. assert index.slice_locs(3, 4) == (1, 0)
  103. assert index.slice_locs(0, 4) == (3, 0)
  104. @pytest.mark.parametrize("query", [
  105. [0, 1], [0, 2], [0, 3], [3, 1], [3, 4], [0, 4]])
  106. @pytest.mark.parametrize("tuples", [
  107. [(0, 2), (1, 3), (2, 4)], [(2, 4), (1, 3), (0, 2)],
  108. [(0, 2), (0, 2), (2, 4)], [(0, 2), (2, 4), (0, 2)],
  109. [(0, 2), (0, 2), (2, 4), (1, 3)]])
  110. def test_slice_locs_with_ints_and_floats_errors(self, tuples, query):
  111. index = IntervalIndex.from_tuples(tuples)
  112. with pytest.raises(KeyError):
  113. index.slice_locs(query)
  114. @pytest.mark.parametrize('query, expected', [
  115. ([Interval(1, 3, closed='right')], [1]),
  116. ([Interval(1, 3, closed='left')], [-1]),
  117. ([Interval(1, 3, closed='both')], [-1]),
  118. ([Interval(1, 3, closed='neither')], [-1]),
  119. ([Interval(1, 4, closed='right')], [-1]),
  120. ([Interval(0, 4, closed='right')], [-1]),
  121. ([Interval(1, 2, closed='right')], [-1]),
  122. ([Interval(2, 4, closed='right'), Interval(1, 3, closed='right')],
  123. [2, 1]),
  124. ([Interval(1, 3, closed='right'), Interval(0, 2, closed='right')],
  125. [1, -1]),
  126. ([Interval(1, 3, closed='right'), Interval(1, 3, closed='left')],
  127. [1, -1])])
  128. def test_get_indexer_with_interval(self, query, expected):
  129. tuples = [(0, 2.5), (1, 3), (2, 4)]
  130. index = IntervalIndex.from_tuples(tuples, closed='right')
  131. result = index.get_indexer(query)
  132. expected = np.array(expected, dtype='intp')
  133. tm.assert_numpy_array_equal(result, expected)
  134. @pytest.mark.parametrize('query, expected', [
  135. ([-0.5], [-1]),
  136. ([0], [-1]),
  137. ([0.5], [0]),
  138. ([1], [0]),
  139. ([1.5], [1]),
  140. ([2], [1]),
  141. ([2.5], [-1]),
  142. ([3], [-1]),
  143. ([3.5], [2]),
  144. ([4], [2]),
  145. ([4.5], [-1]),
  146. ([1, 2], [0, 1]),
  147. ([1, 2, 3], [0, 1, -1]),
  148. ([1, 2, 3, 4], [0, 1, -1, 2]),
  149. ([1, 2, 3, 4, 2], [0, 1, -1, 2, 1])])
  150. def test_get_indexer_with_int_and_float(self, query, expected):
  151. tuples = [(0, 1), (1, 2), (3, 4)]
  152. index = IntervalIndex.from_tuples(tuples, closed='right')
  153. result = index.get_indexer(query)
  154. expected = np.array(expected, dtype='intp')
  155. tm.assert_numpy_array_equal(result, expected)
  156. @pytest.mark.parametrize('tuples, closed', [
  157. ([(0, 2), (1, 3), (3, 4)], 'neither'),
  158. ([(0, 5), (1, 4), (6, 7)], 'left'),
  159. ([(0, 1), (0, 1), (1, 2)], 'right'),
  160. ([(0, 1), (2, 3), (3, 4)], 'both')])
  161. def test_get_indexer_errors(self, tuples, closed):
  162. # IntervalIndex needs non-overlapping for uniqueness when querying
  163. index = IntervalIndex.from_tuples(tuples, closed=closed)
  164. msg = ('cannot handle overlapping indices; use '
  165. 'IntervalIndex.get_indexer_non_unique')
  166. with pytest.raises(ValueError, match=msg):
  167. index.get_indexer([0, 2])
  168. @pytest.mark.parametrize('query, expected', [
  169. ([-0.5], ([-1], [0])),
  170. ([0], ([0], [])),
  171. ([0.5], ([0], [])),
  172. ([1], ([0, 1], [])),
  173. ([1.5], ([0, 1], [])),
  174. ([2], ([0, 1, 2], [])),
  175. ([2.5], ([1, 2], [])),
  176. ([3], ([2], [])),
  177. ([3.5], ([2], [])),
  178. ([4], ([-1], [0])),
  179. ([4.5], ([-1], [0])),
  180. ([1, 2], ([0, 1, 0, 1, 2], [])),
  181. ([1, 2, 3], ([0, 1, 0, 1, 2, 2], [])),
  182. ([1, 2, 3, 4], ([0, 1, 0, 1, 2, 2, -1], [3])),
  183. ([1, 2, 3, 4, 2], ([0, 1, 0, 1, 2, 2, -1, 0, 1, 2], [3]))])
  184. def test_get_indexer_non_unique_with_int_and_float(self, query, expected):
  185. tuples = [(0, 2.5), (1, 3), (2, 4)]
  186. index = IntervalIndex.from_tuples(tuples, closed='left')
  187. result_indexer, result_missing = index.get_indexer_non_unique(query)
  188. expected_indexer = Int64Index(expected[0])
  189. expected_missing = np.array(expected[1], dtype='intp')
  190. tm.assert_index_equal(result_indexer, expected_indexer)
  191. tm.assert_numpy_array_equal(result_missing, expected_missing)
  192. # TODO we may also want to test get_indexer for the case when
  193. # the intervals are duplicated, decreasing, non-monotonic, etc..
  194. def test_contains(self):
  195. index = IntervalIndex.from_arrays([0, 1], [1, 2], closed='right')
  196. # __contains__ requires perfect matches to intervals.
  197. assert 0 not in index
  198. assert 1 not in index
  199. assert 2 not in index
  200. assert Interval(0, 1, closed='right') in index
  201. assert Interval(0, 2, closed='right') not in index
  202. assert Interval(0, 0.5, closed='right') not in index
  203. assert Interval(3, 5, closed='right') not in index
  204. assert Interval(-1, 0, closed='left') not in index
  205. assert Interval(0, 1, closed='left') not in index
  206. assert Interval(0, 1, closed='both') not in index
  207. def test_contains_method(self):
  208. index = IntervalIndex.from_arrays([0, 1], [1, 2], closed='right')
  209. assert not index.contains(0)
  210. assert index.contains(0.1)
  211. assert index.contains(0.5)
  212. assert index.contains(1)
  213. assert index.contains(Interval(0, 1, closed='right'))
  214. assert not index.contains(Interval(0, 1, closed='left'))
  215. assert not index.contains(Interval(0, 1, closed='both'))
  216. assert not index.contains(Interval(0, 2, closed='right'))
  217. assert not index.contains(Interval(0, 3, closed='right'))
  218. assert not index.contains(Interval(1, 3, closed='right'))
  219. assert not index.contains(20)
  220. assert not index.contains(-20)