test_set_ops.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372
  1. # -*- coding: utf-8 -*-
  2. import numpy as np
  3. import pytest
  4. import pandas as pd
  5. from pandas import MultiIndex, Series
  6. import pandas.util.testing as tm
  7. @pytest.mark.parametrize("case", [0.5, "xxx"])
  8. @pytest.mark.parametrize("sort", [None, False])
  9. @pytest.mark.parametrize("method", ["intersection", "union",
  10. "difference", "symmetric_difference"])
  11. def test_set_ops_error_cases(idx, case, sort, method):
  12. # non-iterable input
  13. msg = "Input must be Index or array-like"
  14. with pytest.raises(TypeError, match=msg):
  15. getattr(idx, method)(case, sort=sort)
  16. @pytest.mark.parametrize("sort", [None, False])
  17. def test_intersection_base(idx, sort):
  18. first = idx[:5]
  19. second = idx[:3]
  20. intersect = first.intersection(second, sort=sort)
  21. if sort is None:
  22. tm.assert_index_equal(intersect, second.sort_values())
  23. assert tm.equalContents(intersect, second)
  24. # GH 10149
  25. cases = [klass(second.values)
  26. for klass in [np.array, Series, list]]
  27. for case in cases:
  28. result = first.intersection(case, sort=sort)
  29. if sort is None:
  30. tm.assert_index_equal(result, second.sort_values())
  31. assert tm.equalContents(result, second)
  32. msg = "other must be a MultiIndex or a list of tuples"
  33. with pytest.raises(TypeError, match=msg):
  34. first.intersection([1, 2, 3], sort=sort)
  35. @pytest.mark.parametrize("sort", [None, False])
  36. def test_union_base(idx, sort):
  37. first = idx[3:]
  38. second = idx[:5]
  39. everything = idx
  40. union = first.union(second, sort=sort)
  41. if sort is None:
  42. tm.assert_index_equal(union, everything.sort_values())
  43. assert tm.equalContents(union, everything)
  44. # GH 10149
  45. cases = [klass(second.values)
  46. for klass in [np.array, Series, list]]
  47. for case in cases:
  48. result = first.union(case, sort=sort)
  49. if sort is None:
  50. tm.assert_index_equal(result, everything.sort_values())
  51. assert tm.equalContents(result, everything)
  52. msg = "other must be a MultiIndex or a list of tuples"
  53. with pytest.raises(TypeError, match=msg):
  54. first.union([1, 2, 3], sort=sort)
  55. @pytest.mark.parametrize("sort", [None, False])
  56. def test_difference_base(idx, sort):
  57. second = idx[4:]
  58. answer = idx[:4]
  59. result = idx.difference(second, sort=sort)
  60. if sort is None:
  61. answer = answer.sort_values()
  62. assert result.equals(answer)
  63. tm.assert_index_equal(result, answer)
  64. # GH 10149
  65. cases = [klass(second.values)
  66. for klass in [np.array, Series, list]]
  67. for case in cases:
  68. result = idx.difference(case, sort=sort)
  69. tm.assert_index_equal(result, answer)
  70. msg = "other must be a MultiIndex or a list of tuples"
  71. with pytest.raises(TypeError, match=msg):
  72. idx.difference([1, 2, 3], sort=sort)
  73. @pytest.mark.parametrize("sort", [None, False])
  74. def test_symmetric_difference(idx, sort):
  75. first = idx[1:]
  76. second = idx[:-1]
  77. answer = idx[[-1, 0]]
  78. result = first.symmetric_difference(second, sort=sort)
  79. if sort is None:
  80. answer = answer.sort_values()
  81. tm.assert_index_equal(result, answer)
  82. # GH 10149
  83. cases = [klass(second.values)
  84. for klass in [np.array, Series, list]]
  85. for case in cases:
  86. result = first.symmetric_difference(case, sort=sort)
  87. tm.assert_index_equal(result, answer)
  88. msg = "other must be a MultiIndex or a list of tuples"
  89. with pytest.raises(TypeError, match=msg):
  90. first.symmetric_difference([1, 2, 3], sort=sort)
  91. def test_empty(idx):
  92. # GH 15270
  93. assert not idx.empty
  94. assert idx[:0].empty
  95. @pytest.mark.parametrize("sort", [None, False])
  96. def test_difference(idx, sort):
  97. first = idx
  98. result = first.difference(idx[-3:], sort=sort)
  99. vals = idx[:-3].values
  100. if sort is None:
  101. vals = sorted(vals)
  102. expected = MultiIndex.from_tuples(vals,
  103. sortorder=0,
  104. names=idx.names)
  105. assert isinstance(result, MultiIndex)
  106. assert result.equals(expected)
  107. assert result.names == idx.names
  108. tm.assert_index_equal(result, expected)
  109. # empty difference: reflexive
  110. result = idx.difference(idx, sort=sort)
  111. expected = idx[:0]
  112. assert result.equals(expected)
  113. assert result.names == idx.names
  114. # empty difference: superset
  115. result = idx[-3:].difference(idx, sort=sort)
  116. expected = idx[:0]
  117. assert result.equals(expected)
  118. assert result.names == idx.names
  119. # empty difference: degenerate
  120. result = idx[:0].difference(idx, sort=sort)
  121. expected = idx[:0]
  122. assert result.equals(expected)
  123. assert result.names == idx.names
  124. # names not the same
  125. chunklet = idx[-3:]
  126. chunklet.names = ['foo', 'baz']
  127. result = first.difference(chunklet, sort=sort)
  128. assert result.names == (None, None)
  129. # empty, but non-equal
  130. result = idx.difference(idx.sortlevel(1)[0], sort=sort)
  131. assert len(result) == 0
  132. # raise Exception called with non-MultiIndex
  133. result = first.difference(first.values, sort=sort)
  134. assert result.equals(first[:0])
  135. # name from empty array
  136. result = first.difference([], sort=sort)
  137. assert first.equals(result)
  138. assert first.names == result.names
  139. # name from non-empty array
  140. result = first.difference([('foo', 'one')], sort=sort)
  141. expected = pd.MultiIndex.from_tuples([('bar', 'one'), ('baz', 'two'), (
  142. 'foo', 'two'), ('qux', 'one'), ('qux', 'two')])
  143. expected.names = first.names
  144. assert first.names == result.names
  145. msg = "other must be a MultiIndex or a list of tuples"
  146. with pytest.raises(TypeError, match=msg):
  147. first.difference([1, 2, 3, 4, 5], sort=sort)
  148. def test_difference_sort_special():
  149. # GH-24959
  150. idx = pd.MultiIndex.from_product([[1, 0], ['a', 'b']])
  151. # sort=None, the default
  152. result = idx.difference([])
  153. tm.assert_index_equal(result, idx)
  154. @pytest.mark.xfail(reason="Not implemented.")
  155. def test_difference_sort_special_true():
  156. # TODO decide on True behaviour
  157. idx = pd.MultiIndex.from_product([[1, 0], ['a', 'b']])
  158. result = idx.difference([], sort=True)
  159. expected = pd.MultiIndex.from_product([[0, 1], ['a', 'b']])
  160. tm.assert_index_equal(result, expected)
  161. def test_difference_sort_incomparable():
  162. # GH-24959
  163. idx = pd.MultiIndex.from_product([[1, pd.Timestamp('2000'), 2],
  164. ['a', 'b']])
  165. other = pd.MultiIndex.from_product([[3, pd.Timestamp('2000'), 4],
  166. ['c', 'd']])
  167. # sort=None, the default
  168. # MultiIndex.difference deviates here from other difference
  169. # implementations in not catching the TypeError
  170. with pytest.raises(TypeError):
  171. result = idx.difference(other)
  172. # sort=False
  173. result = idx.difference(other, sort=False)
  174. tm.assert_index_equal(result, idx)
  175. @pytest.mark.xfail(reason="Not implemented.")
  176. def test_difference_sort_incomparable_true():
  177. # TODO decide on True behaviour
  178. # # sort=True, raises
  179. idx = pd.MultiIndex.from_product([[1, pd.Timestamp('2000'), 2],
  180. ['a', 'b']])
  181. other = pd.MultiIndex.from_product([[3, pd.Timestamp('2000'), 4],
  182. ['c', 'd']])
  183. with pytest.raises(TypeError):
  184. idx.difference(other, sort=True)
  185. @pytest.mark.parametrize("sort", [None, False])
  186. def test_union(idx, sort):
  187. piece1 = idx[:5][::-1]
  188. piece2 = idx[3:]
  189. the_union = piece1.union(piece2, sort=sort)
  190. if sort is None:
  191. tm.assert_index_equal(the_union, idx.sort_values())
  192. assert tm.equalContents(the_union, idx)
  193. # corner case, pass self or empty thing:
  194. the_union = idx.union(idx, sort=sort)
  195. assert the_union is idx
  196. the_union = idx.union(idx[:0], sort=sort)
  197. assert the_union is idx
  198. # won't work in python 3
  199. # tuples = _index.values
  200. # result = _index[:4] | tuples[4:]
  201. # assert result.equals(tuples)
  202. # not valid for python 3
  203. # def test_union_with_regular_index(self):
  204. # other = Index(['A', 'B', 'C'])
  205. # result = other.union(idx)
  206. # assert ('foo', 'one') in result
  207. # assert 'B' in result
  208. # result2 = _index.union(other)
  209. # assert result.equals(result2)
  210. @pytest.mark.parametrize("sort", [None, False])
  211. def test_intersection(idx, sort):
  212. piece1 = idx[:5][::-1]
  213. piece2 = idx[3:]
  214. the_int = piece1.intersection(piece2, sort=sort)
  215. if sort is None:
  216. tm.assert_index_equal(the_int, idx[3:5])
  217. assert tm.equalContents(the_int, idx[3:5])
  218. # corner case, pass self
  219. the_int = idx.intersection(idx, sort=sort)
  220. assert the_int is idx
  221. # empty intersection: disjoint
  222. empty = idx[:2].intersection(idx[2:], sort=sort)
  223. expected = idx[:0]
  224. assert empty.equals(expected)
  225. # can't do in python 3
  226. # tuples = _index.values
  227. # result = _index & tuples
  228. # assert result.equals(tuples)
  229. def test_intersect_equal_sort():
  230. # GH-24959
  231. idx = pd.MultiIndex.from_product([[1, 0], ['a', 'b']])
  232. tm.assert_index_equal(idx.intersection(idx, sort=False), idx)
  233. tm.assert_index_equal(idx.intersection(idx, sort=None), idx)
  234. @pytest.mark.xfail(reason="Not implemented.")
  235. def test_intersect_equal_sort_true():
  236. # TODO decide on True behaviour
  237. idx = pd.MultiIndex.from_product([[1, 0], ['a', 'b']])
  238. sorted_ = pd.MultiIndex.from_product([[0, 1], ['a', 'b']])
  239. tm.assert_index_equal(idx.intersection(idx, sort=True), sorted_)
  240. @pytest.mark.parametrize('slice_', [slice(None), slice(0)])
  241. def test_union_sort_other_empty(slice_):
  242. # https://github.com/pandas-dev/pandas/issues/24959
  243. idx = pd.MultiIndex.from_product([[1, 0], ['a', 'b']])
  244. # default, sort=None
  245. other = idx[slice_]
  246. tm.assert_index_equal(idx.union(other), idx)
  247. # MultiIndex does not special case empty.union(idx)
  248. # tm.assert_index_equal(other.union(idx), idx)
  249. # sort=False
  250. tm.assert_index_equal(idx.union(other, sort=False), idx)
  251. @pytest.mark.xfail(reason="Not implemented.")
  252. def test_union_sort_other_empty_sort(slice_):
  253. # TODO decide on True behaviour
  254. # # sort=True
  255. idx = pd.MultiIndex.from_product([[1, 0], ['a', 'b']])
  256. other = idx[:0]
  257. result = idx.union(other, sort=True)
  258. expected = pd.MultiIndex.from_product([[0, 1], ['a', 'b']])
  259. tm.assert_index_equal(result, expected)
  260. def test_union_sort_other_incomparable():
  261. # https://github.com/pandas-dev/pandas/issues/24959
  262. idx = pd.MultiIndex.from_product([[1, pd.Timestamp('2000')], ['a', 'b']])
  263. # default, sort=None
  264. result = idx.union(idx[:1])
  265. tm.assert_index_equal(result, idx)
  266. # sort=False
  267. result = idx.union(idx[:1], sort=False)
  268. tm.assert_index_equal(result, idx)
  269. @pytest.mark.xfail(reason="Not implemented.")
  270. def test_union_sort_other_incomparable_sort():
  271. # TODO decide on True behaviour
  272. # # sort=True
  273. idx = pd.MultiIndex.from_product([[1, pd.Timestamp('2000')], ['a', 'b']])
  274. with pytest.raises(TypeError, match='Cannot compare'):
  275. idx.union(idx[:1], sort=True)
  276. @pytest.mark.parametrize("method", ['union', 'intersection', 'difference',
  277. 'symmetric_difference'])
  278. def test_setops_disallow_true(method):
  279. idx1 = pd.MultiIndex.from_product([['a', 'b'], [1, 2]])
  280. idx2 = pd.MultiIndex.from_product([['b', 'c'], [1, 2]])
  281. with pytest.raises(ValueError, match="The 'sort' keyword only takes"):
  282. getattr(idx1, method)(idx2, sort=True)