test_indexing_slow.py 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. # -*- coding: utf-8 -*-
  2. import warnings
  3. import numpy as np
  4. import pytest
  5. import pandas as pd
  6. from pandas import DataFrame, MultiIndex, Series
  7. import pandas.util.testing as tm
  8. @pytest.mark.slow
  9. @pytest.mark.filterwarnings("ignore::pandas.errors.PerformanceWarning")
  10. def test_multiindex_get_loc(): # GH7724, GH2646
  11. with warnings.catch_warnings(record=True):
  12. # test indexing into a multi-index before & past the lexsort depth
  13. from numpy.random import randint, choice, randn
  14. cols = ['jim', 'joe', 'jolie', 'joline', 'jolia']
  15. def validate(mi, df, key):
  16. mask = np.ones(len(df)).astype('bool')
  17. # test for all partials of this key
  18. for i, k in enumerate(key):
  19. mask &= df.iloc[:, i] == k
  20. if not mask.any():
  21. assert key[:i + 1] not in mi.index
  22. continue
  23. assert key[:i + 1] in mi.index
  24. right = df[mask].copy()
  25. if i + 1 != len(key): # partial key
  26. right.drop(cols[:i + 1], axis=1, inplace=True)
  27. right.set_index(cols[i + 1:-1], inplace=True)
  28. tm.assert_frame_equal(mi.loc[key[:i + 1]], right)
  29. else: # full key
  30. right.set_index(cols[:-1], inplace=True)
  31. if len(right) == 1: # single hit
  32. right = Series(right['jolia'].values,
  33. name=right.index[0],
  34. index=['jolia'])
  35. tm.assert_series_equal(mi.loc[key[:i + 1]], right)
  36. else: # multi hit
  37. tm.assert_frame_equal(mi.loc[key[:i + 1]], right)
  38. def loop(mi, df, keys):
  39. for key in keys:
  40. validate(mi, df, key)
  41. n, m = 1000, 50
  42. vals = [randint(0, 10, n), choice(
  43. list('abcdefghij'), n), choice(
  44. pd.date_range('20141009', periods=10).tolist(), n), choice(
  45. list('ZYXWVUTSRQ'), n), randn(n)]
  46. vals = list(map(tuple, zip(*vals)))
  47. # bunch of keys for testing
  48. keys = [randint(0, 11, m), choice(
  49. list('abcdefghijk'), m), choice(
  50. pd.date_range('20141009', periods=11).tolist(), m), choice(
  51. list('ZYXWVUTSRQP'), m)]
  52. keys = list(map(tuple, zip(*keys)))
  53. keys += list(map(lambda t: t[:-1], vals[::n // m]))
  54. # covers both unique index and non-unique index
  55. df = DataFrame(vals, columns=cols)
  56. a, b = pd.concat([df, df]), df.drop_duplicates(subset=cols[:-1])
  57. for frame in a, b:
  58. for i in range(5): # lexsort depth
  59. df = frame.copy() if i == 0 else frame.sort_values(
  60. by=cols[:i])
  61. mi = df.set_index(cols[:-1])
  62. assert not mi.index.lexsort_depth < i
  63. loop(mi, df, keys)
  64. @pytest.mark.slow
  65. def test_large_mi_dataframe_indexing():
  66. # GH10645
  67. result = MultiIndex.from_arrays([range(10 ** 6), range(10 ** 6)])
  68. assert (not (10 ** 6, 0) in result)