test_whitelist.py 8.3 KB


"""
Test methods relating to generic function evaluation on groupby objects:
the so-called white/black lists.
"""
  5. from string import ascii_lowercase
  6. import numpy as np
  7. import pytest
  8. from pandas import DataFrame, Index, MultiIndex, Series, compat, date_range
  9. from pandas.util import testing as tm
  10. AGG_FUNCTIONS = ['sum', 'prod', 'min', 'max', 'median', 'mean', 'skew',
  11. 'mad', 'std', 'var', 'sem']
  12. AGG_FUNCTIONS_WITH_SKIPNA = ['skew', 'mad']
  13. df_whitelist = [
  14. 'quantile',
  15. 'fillna',
  16. 'mad',
  17. 'take',
  18. 'idxmax',
  19. 'idxmin',
  20. 'tshift',
  21. 'skew',
  22. 'plot',
  23. 'hist',
  24. 'dtypes',
  25. 'corrwith',
  26. 'corr',
  27. 'cov',
  28. 'diff',
  29. ]
  30. @pytest.fixture(params=df_whitelist)
  31. def df_whitelist_fixture(request):
  32. return request.param
  33. s_whitelist = [
  34. 'quantile',
  35. 'fillna',
  36. 'mad',
  37. 'take',
  38. 'idxmax',
  39. 'idxmin',
  40. 'tshift',
  41. 'skew',
  42. 'plot',
  43. 'hist',
  44. 'dtype',
  45. 'corr',
  46. 'cov',
  47. 'diff',
  48. 'unique',
  49. 'nlargest',
  50. 'nsmallest',
  51. 'is_monotonic_increasing',
  52. 'is_monotonic_decreasing',
  53. ]
  54. @pytest.fixture(params=s_whitelist)
  55. def s_whitelist_fixture(request):
  56. return request.param
  57. @pytest.fixture
  58. def mframe():
  59. index = MultiIndex(levels=[['foo', 'bar', 'baz', 'qux'], ['one', 'two',
  60. 'three']],
  61. codes=[[0, 0, 0, 1, 1, 2, 2, 3, 3, 3],
  62. [0, 1, 2, 0, 1, 1, 2, 0, 1, 2]],
  63. names=['first', 'second'])
  64. return DataFrame(np.random.randn(10, 3), index=index,
  65. columns=['A', 'B', 'C'])
  66. @pytest.fixture
  67. def df():
  68. return DataFrame(
  69. {'A': ['foo', 'bar', 'foo', 'bar', 'foo', 'bar', 'foo', 'foo'],
  70. 'B': ['one', 'one', 'two', 'three', 'two', 'two', 'one', 'three'],
  71. 'C': np.random.randn(8),
  72. 'D': np.random.randn(8)})
  73. @pytest.fixture
  74. def df_letters():
  75. letters = np.array(list(ascii_lowercase))
  76. N = 10
  77. random_letters = letters.take(np.random.randint(0, 26, N))
  78. df = DataFrame({'floats': N / 10 * Series(np.random.random(N)),
  79. 'letters': Series(random_letters)})
  80. return df
  81. @pytest.mark.parametrize("whitelist", [df_whitelist, s_whitelist])
  82. def test_groupby_whitelist(df_letters, whitelist):
  83. df = df_letters
  84. if whitelist == df_whitelist:
  85. # dataframe
  86. obj = df_letters
  87. else:
  88. obj = df_letters['floats']
  89. gb = obj.groupby(df.letters)
  90. assert set(whitelist) == set(gb._apply_whitelist)
  91. def check_whitelist(obj, df, m):
  92. # check the obj for a particular whitelist m
  93. gb = obj.groupby(df.letters)
  94. f = getattr(type(gb), m)
  95. # name
  96. try:
  97. n = f.__name__
  98. except AttributeError:
  99. return
  100. assert n == m
  101. # qualname
  102. if compat.PY3:
  103. try:
  104. n = f.__qualname__
  105. except AttributeError:
  106. return
  107. assert n.endswith(m)
  108. def test_groupby_series_whitelist(df_letters, s_whitelist_fixture):
  109. m = s_whitelist_fixture
  110. df = df_letters
  111. check_whitelist(df.letters, df, m)
  112. def test_groupby_frame_whitelist(df_letters, df_whitelist_fixture):
  113. m = df_whitelist_fixture
  114. df = df_letters
  115. check_whitelist(df, df, m)
  116. @pytest.fixture
  117. def raw_frame():
  118. index = MultiIndex(levels=[['foo', 'bar', 'baz', 'qux'], ['one', 'two',
  119. 'three']],
  120. codes=[[0, 0, 0, 1, 1, 2, 2, 3, 3, 3],
  121. [0, 1, 2, 0, 1, 1, 2, 0, 1, 2]],
  122. names=['first', 'second'])
  123. raw_frame = DataFrame(np.random.randn(10, 3), index=index,
  124. columns=Index(['A', 'B', 'C'], name='exp'))
  125. raw_frame.iloc[1, [1, 2]] = np.nan
  126. raw_frame.iloc[7, [0, 1]] = np.nan
  127. return raw_frame
  128. @pytest.mark.parametrize('op', AGG_FUNCTIONS)
  129. @pytest.mark.parametrize('level', [0, 1])
  130. @pytest.mark.parametrize('axis', [0, 1])
  131. @pytest.mark.parametrize('skipna', [True, False])
  132. @pytest.mark.parametrize('sort', [True, False])
  133. def test_regression_whitelist_methods(
  134. raw_frame, op, level,
  135. axis, skipna, sort):
  136. # GH6944
  137. # GH 17537
  138. # explicitly test the whitelist methods
  139. if axis == 0:
  140. frame = raw_frame
  141. else:
  142. frame = raw_frame.T
  143. if op in AGG_FUNCTIONS_WITH_SKIPNA:
  144. grouped = frame.groupby(level=level, axis=axis, sort=sort)
  145. result = getattr(grouped, op)(skipna=skipna)
  146. expected = getattr(frame, op)(level=level, axis=axis,
  147. skipna=skipna)
  148. if sort:
  149. expected = expected.sort_index(axis=axis, level=level)
  150. tm.assert_frame_equal(result, expected)
  151. else:
  152. grouped = frame.groupby(level=level, axis=axis, sort=sort)
  153. result = getattr(grouped, op)()
  154. expected = getattr(frame, op)(level=level, axis=axis)
  155. if sort:
  156. expected = expected.sort_index(axis=axis, level=level)
  157. tm.assert_frame_equal(result, expected)
  158. def test_groupby_blacklist(df_letters):
  159. df = df_letters
  160. s = df_letters.floats
  161. blacklist = [
  162. 'eval', 'query', 'abs', 'where',
  163. 'mask', 'align', 'groupby', 'clip', 'astype',
  164. 'at', 'combine', 'consolidate', 'convert_objects',
  165. ]
  166. to_methods = [method for method in dir(df) if method.startswith('to_')]
  167. blacklist.extend(to_methods)
  168. # e.g., to_csv
  169. defined_but_not_allowed = ("(?:^Cannot.+{0!r}.+{1!r}.+try using the "
  170. "'apply' method$)")
  171. # e.g., query, eval
  172. not_defined = "(?:^{1!r} object has no attribute {0!r}$)"
  173. fmt = defined_but_not_allowed + '|' + not_defined
  174. for bl in blacklist:
  175. for obj in (df, s):
  176. gb = obj.groupby(df.letters)
  177. msg = fmt.format(bl, type(gb).__name__)
  178. with pytest.raises(AttributeError, match=msg):
  179. getattr(gb, bl)
  180. def test_tab_completion(mframe):
  181. grp = mframe.groupby(level='second')
  182. results = {v for v in dir(grp) if not v.startswith('_')}
  183. expected = {
  184. 'A', 'B', 'C', 'agg', 'aggregate', 'apply', 'boxplot', 'filter',
  185. 'first', 'get_group', 'groups', 'hist', 'indices', 'last', 'max',
  186. 'mean', 'median', 'min', 'ngroups', 'nth', 'ohlc', 'plot',
  187. 'prod', 'size', 'std', 'sum', 'transform', 'var', 'sem', 'count',
  188. 'nunique', 'head', 'describe', 'cummax', 'quantile',
  189. 'rank', 'cumprod', 'tail', 'resample', 'cummin', 'fillna',
  190. 'cumsum', 'cumcount', 'ngroup', 'all', 'shift', 'skew',
  191. 'take', 'tshift', 'pct_change', 'any', 'mad', 'corr', 'corrwith',
  192. 'cov', 'dtypes', 'ndim', 'diff', 'idxmax', 'idxmin',
  193. 'ffill', 'bfill', 'pad', 'backfill', 'rolling', 'expanding', 'pipe',
  194. }
  195. assert results == expected
  196. def test_groupby_function_rename(mframe):
  197. grp = mframe.groupby(level='second')
  198. for name in ['sum', 'prod', 'min', 'max', 'first', 'last']:
  199. f = getattr(grp, name)
  200. assert f.__name__ == name
  201. def test_groupby_selection_with_methods(df):
  202. # some methods which require DatetimeIndex
  203. rng = date_range('2014', periods=len(df))
  204. df.index = rng
  205. g = df.groupby(['A'])[['C']]
  206. g_exp = df[['C']].groupby(df['A'])
  207. # TODO check groupby with > 1 col ?
  208. # methods which are called as .foo()
  209. methods = ['count',
  210. 'corr',
  211. 'cummax',
  212. 'cummin',
  213. 'cumprod',
  214. 'describe',
  215. 'rank',
  216. 'quantile',
  217. 'diff',
  218. 'shift',
  219. 'all',
  220. 'any',
  221. 'idxmin',
  222. 'idxmax',
  223. 'ffill',
  224. 'bfill',
  225. 'pct_change',
  226. 'tshift']
  227. for m in methods:
  228. res = getattr(g, m)()
  229. exp = getattr(g_exp, m)()
  230. # should always be frames!
  231. tm.assert_frame_equal(res, exp)
  232. # methods which aren't just .foo()
  233. tm.assert_frame_equal(g.fillna(0), g_exp.fillna(0))
  234. tm.assert_frame_equal(g.dtypes, g_exp.dtypes)
  235. tm.assert_frame_equal(g.apply(lambda x: x.sum()),
  236. g_exp.apply(lambda x: x.sum()))
  237. tm.assert_frame_equal(g.resample('D').mean(), g_exp.resample('D').mean())
  238. tm.assert_frame_equal(g.resample('D').ohlc(),
  239. g_exp.resample('D').ohlc())
  240. tm.assert_frame_equal(g.filter(lambda x: len(x) == 3),
  241. g_exp.filter(lambda x: len(x) == 3))