test_interval_tree.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. from __future__ import division
  2. from itertools import permutations
  3. import numpy as np
  4. import pytest
  5. from pandas._libs.interval import IntervalTree
  6. from pandas import compat
  7. import pandas.util.testing as tm
  8. def skipif_32bit(param):
  9. """
  10. Skip parameters in a parametrize on 32bit systems. Specifically used
  11. here to skip leaf_size parameters related to GH 23440.
  12. """
  13. marks = pytest.mark.skipif(compat.is_platform_32bit(),
  14. reason='GH 23440: int type mismatch on 32bit')
  15. return pytest.param(param, marks=marks)
  16. @pytest.fixture(
  17. scope='class', params=['int32', 'int64', 'float32', 'float64', 'uint64'])
  18. def dtype(request):
  19. return request.param
  20. @pytest.fixture(params=[skipif_32bit(1), skipif_32bit(2), 10])
  21. def leaf_size(request):
  22. """
  23. Fixture to specify IntervalTree leaf_size parameter; to be used with the
  24. tree fixture.
  25. """
  26. return request.param
  27. @pytest.fixture(params=[
  28. np.arange(5, dtype='int64'),
  29. np.arange(5, dtype='int32'),
  30. np.arange(5, dtype='uint64'),
  31. np.arange(5, dtype='float64'),
  32. np.arange(5, dtype='float32'),
  33. np.array([0, 1, 2, 3, 4, np.nan], dtype='float64'),
  34. np.array([0, 1, 2, 3, 4, np.nan], dtype='float32')])
  35. def tree(request, leaf_size):
  36. left = request.param
  37. return IntervalTree(left, left + 2, leaf_size=leaf_size)
  38. class TestIntervalTree(object):
  39. def test_get_loc(self, tree):
  40. result = tree.get_loc(1)
  41. expected = np.array([0], dtype='intp')
  42. tm.assert_numpy_array_equal(result, expected)
  43. result = np.sort(tree.get_loc(2))
  44. expected = np.array([0, 1], dtype='intp')
  45. tm.assert_numpy_array_equal(result, expected)
  46. with pytest.raises(KeyError):
  47. tree.get_loc(-1)
  48. def test_get_indexer(self, tree):
  49. result = tree.get_indexer(np.array([1.0, 5.5, 6.5]))
  50. expected = np.array([0, 4, -1], dtype='intp')
  51. tm.assert_numpy_array_equal(result, expected)
  52. with pytest.raises(KeyError):
  53. tree.get_indexer(np.array([3.0]))
  54. def test_get_indexer_non_unique(self, tree):
  55. indexer, missing = tree.get_indexer_non_unique(
  56. np.array([1.0, 2.0, 6.5]))
  57. result = indexer[:1]
  58. expected = np.array([0], dtype='intp')
  59. tm.assert_numpy_array_equal(result, expected)
  60. result = np.sort(indexer[1:3])
  61. expected = np.array([0, 1], dtype='intp')
  62. tm.assert_numpy_array_equal(result, expected)
  63. result = np.sort(indexer[3:])
  64. expected = np.array([-1], dtype='intp')
  65. tm.assert_numpy_array_equal(result, expected)
  66. result = missing
  67. expected = np.array([2], dtype='intp')
  68. tm.assert_numpy_array_equal(result, expected)
  69. def test_duplicates(self, dtype):
  70. left = np.array([0, 0, 0], dtype=dtype)
  71. tree = IntervalTree(left, left + 1)
  72. result = np.sort(tree.get_loc(0.5))
  73. expected = np.array([0, 1, 2], dtype='intp')
  74. tm.assert_numpy_array_equal(result, expected)
  75. with pytest.raises(KeyError):
  76. tree.get_indexer(np.array([0.5]))
  77. indexer, missing = tree.get_indexer_non_unique(np.array([0.5]))
  78. result = np.sort(indexer)
  79. expected = np.array([0, 1, 2], dtype='intp')
  80. tm.assert_numpy_array_equal(result, expected)
  81. result = missing
  82. expected = np.array([], dtype='intp')
  83. tm.assert_numpy_array_equal(result, expected)
  84. def test_get_loc_closed(self, closed):
  85. tree = IntervalTree([0], [1], closed=closed)
  86. for p, errors in [(0, tree.open_left),
  87. (1, tree.open_right)]:
  88. if errors:
  89. with pytest.raises(KeyError):
  90. tree.get_loc(p)
  91. else:
  92. result = tree.get_loc(p)
  93. expected = np.array([0], dtype='intp')
  94. tm.assert_numpy_array_equal(result, expected)
  95. @pytest.mark.parametrize('leaf_size', [
  96. skipif_32bit(1), skipif_32bit(10), skipif_32bit(100), 10000])
  97. def test_get_indexer_closed(self, closed, leaf_size):
  98. x = np.arange(1000, dtype='float64')
  99. found = x.astype('intp')
  100. not_found = (-1 * np.ones(1000)).astype('intp')
  101. tree = IntervalTree(x, x + 0.5, closed=closed, leaf_size=leaf_size)
  102. tm.assert_numpy_array_equal(found, tree.get_indexer(x + 0.25))
  103. expected = found if tree.closed_left else not_found
  104. tm.assert_numpy_array_equal(expected, tree.get_indexer(x + 0.0))
  105. expected = found if tree.closed_right else not_found
  106. tm.assert_numpy_array_equal(expected, tree.get_indexer(x + 0.5))
  107. @pytest.mark.parametrize('left, right, expected', [
  108. (np.array([0, 1, 4]), np.array([2, 3, 5]), True),
  109. (np.array([0, 1, 2]), np.array([5, 4, 3]), True),
  110. (np.array([0, 1, np.nan]), np.array([5, 4, np.nan]), True),
  111. (np.array([0, 2, 4]), np.array([1, 3, 5]), False),
  112. (np.array([0, 2, np.nan]), np.array([1, 3, np.nan]), False)])
  113. @pytest.mark.parametrize('order', map(list, permutations(range(3))))
  114. def test_is_overlapping(self, closed, order, left, right, expected):
  115. # GH 23309
  116. tree = IntervalTree(left[order], right[order], closed=closed)
  117. result = tree.is_overlapping
  118. assert result is expected
  119. @pytest.mark.parametrize('order', map(list, permutations(range(3))))
  120. def test_is_overlapping_endpoints(self, closed, order):
  121. """shared endpoints are marked as overlapping"""
  122. # GH 23309
  123. left, right = np.arange(3), np.arange(1, 4)
  124. tree = IntervalTree(left[order], right[order], closed=closed)
  125. result = tree.is_overlapping
  126. expected = closed is 'both'
  127. assert result is expected
  128. @pytest.mark.parametrize('left, right', [
  129. (np.array([], dtype='int64'), np.array([], dtype='int64')),
  130. (np.array([0], dtype='int64'), np.array([1], dtype='int64')),
  131. (np.array([np.nan]), np.array([np.nan])),
  132. (np.array([np.nan] * 3), np.array([np.nan] * 3))])
  133. def test_is_overlapping_trivial(self, closed, left, right):
  134. # GH 23309
  135. tree = IntervalTree(left, right, closed=closed)
  136. assert tree.is_overlapping is False
  137. @pytest.mark.skipif(compat.is_platform_32bit(), reason='GH 23440')
  138. def test_construction_overflow(self):
  139. # GH 25485
  140. left, right = np.arange(101), [np.iinfo(np.int64).max] * 101
  141. tree = IntervalTree(left, right)
  142. # pivot should be average of left/right medians
  143. result = tree.root.pivot
  144. expected = (50 + np.iinfo(np.int64).max) / 2
  145. assert result == expected