# pivot.py
  1. # pylint: disable=E1103
  2. import numpy as np
  3. from pandas.compat import lrange, range, zip
  4. from pandas.util._decorators import Appender, Substitution
  5. from pandas.core.dtypes.cast import maybe_downcast_to_dtype
  6. from pandas.core.dtypes.common import is_integer_dtype, is_list_like, is_scalar
  7. from pandas.core.dtypes.generic import ABCDataFrame, ABCSeries
  8. from pandas import compat
  9. import pandas.core.common as com
  10. from pandas.core.frame import _shared_docs
  11. from pandas.core.groupby import Grouper
  12. from pandas.core.index import Index, MultiIndex, _get_objs_combined_axis
  13. from pandas.core.reshape.concat import concat
  14. from pandas.core.reshape.util import cartesian_product
  15. from pandas.core.series import Series
  16. # Note: We need to make sure `frame` is imported before `pivot`, otherwise
  17. # _shared_docs['pivot_table'] will not yet exist. TODO: Fix this dependency
@Substitution('\ndata : DataFrame')
@Appender(_shared_docs['pivot_table'], indents=1)
def pivot_table(data, values=None, index=None, columns=None, aggfunc='mean',
                fill_value=None, margins=False, dropna=True,
                margins_name='All'):
    # NOTE: no docstring here on purpose — @Substitution %-formats __doc__ and
    # @Appender supplies the shared 'pivot_table' docstring.
    index = _convert_by(index)
    columns = _convert_by(columns)

    # A list of aggfuncs produces one table per function, concatenated
    # side-by-side under a key named after each function.
    if isinstance(aggfunc, list):
        pieces = []
        keys = []
        for func in aggfunc:
            table = pivot_table(data, values=values, index=index,
                                columns=columns,
                                fill_value=fill_value, aggfunc=func,
                                margins=margins, margins_name=margins_name)
            pieces.append(table)
            keys.append(getattr(func, '__name__', func))
        return concat(pieces, keys=keys, axis=1)

    keys = index + columns

    values_passed = values is not None
    if values_passed:
        if is_list_like(values):
            values_multi = True
            values = list(values)
        else:
            values_multi = False
            values = [values]

        # GH14938 Make sure value labels are in data
        for i in values:
            if i not in data:
                raise KeyError(i)

        # Keep only the columns actually used as keys or values; unhashable
        # keys (TypeError on `in data`) are silently skipped.
        to_filter = []
        for x in keys + values:
            if isinstance(x, Grouper):
                x = x.key
            try:
                if x in data:
                    to_filter.append(x)
            except TypeError:
                pass
        if len(to_filter) < len(data.columns):
            data = data[to_filter]

    else:
        # No values passed: aggregate every column that is not a grouping key.
        values = data.columns
        for key in keys:
            try:
                values = values.drop(key)
            except (TypeError, ValueError, KeyError):
                pass
        values = list(values)

    grouped = data.groupby(keys, observed=False)
    agged = grouped.agg(aggfunc)
    if dropna and isinstance(agged, ABCDataFrame) and len(agged.columns):
        agged = agged.dropna(how='all')

        # gh-21133
        # we want to down cast if
        # the original values are ints
        # as we grouped with a NaN value
        # and then dropped, coercing to floats
        for v in [v for v in values if v in data and v in agged]:
            if (is_integer_dtype(data[v]) and
                    not is_integer_dtype(agged[v])):
                agged[v] = maybe_downcast_to_dtype(agged[v], data[v].dtype)

    table = agged
    if table.index.nlevels > 1:
        # Related GH #17123
        # If index_names are integers, determine whether the integers refer
        # to the level position or name.
        index_names = agged.index.names[:len(index)]
        to_unstack = []
        for i in range(len(index), len(keys)):
            name = agged.index.names[i]
            if name is None or name in index_names:
                to_unstack.append(i)
            else:
                to_unstack.append(name)
        table = agged.unstack(to_unstack)

    if not dropna:
        # Reindex against the full cartesian product of the levels so that
        # empty group combinations are kept.
        from pandas import MultiIndex
        if table.index.nlevels > 1:
            m = MultiIndex.from_arrays(cartesian_product(table.index.levels),
                                       names=table.index.names)
            table = table.reindex(m, axis=0)

        if table.columns.nlevels > 1:
            m = MultiIndex.from_arrays(cartesian_product(table.columns.levels),
                                       names=table.columns.names)
            table = table.reindex(m, axis=1)

    if isinstance(table, ABCDataFrame):
        table = table.sort_index(axis=1)

    if fill_value is not None:
        table = table.fillna(value=fill_value, downcast='infer')

    if margins:
        if dropna:
            # Margins are computed from rows with no missing values.
            data = data[data.notna().all(axis=1)]
        table = _add_margins(table, data, values, rows=index,
                             cols=columns, aggfunc=aggfunc,
                             observed=dropna,
                             margins_name=margins_name, fill_value=fill_value)

    # discard the top level
    if (values_passed and not values_multi and not table.empty and
            (table.columns.nlevels > 1)):
        table = table[values[0]]

    if len(index) == 0 and len(columns) > 0:
        table = table.T

    # GH 15193 Make sure empty columns are removed if dropna=True
    if isinstance(table, ABCDataFrame) and dropna:
        table = table.dropna(how='all', axis=1)

    return table
def _add_margins(table, data, values, rows, cols, aggfunc,
                 observed=None, margins_name='All', fill_value=None):
    """Append the margins (subtotal) row/columns named `margins_name` to
    `table`, computing totals from `data` with `aggfunc`.

    Raises ValueError if `margins_name` is not a string or conflicts with an
    existing index/column label.
    """
    if not isinstance(margins_name, compat.string_types):
        raise ValueError('margins_name argument must be a string')

    msg = u'Conflicting name "{name}" in margins'.format(name=margins_name)
    for level in table.index.names:
        if margins_name in table.index.get_level_values(level):
            raise ValueError(msg)

    grand_margin = _compute_grand_margin(data, values, aggfunc, margins_name)

    # could be passed a Series object with no 'columns'
    if hasattr(table, 'columns'):
        for level in table.columns.names[1:]:
            if margins_name in table.columns.get_level_values(level):
                raise ValueError(msg)

    # Key for the margin row: pad with '' when the index has several levels.
    if len(rows) > 1:
        key = (margins_name,) + ('',) * (len(rows) - 1)
    else:
        key = margins_name

    if not values and isinstance(table, ABCSeries):
        # If there are no values and the table is a series, then there is only
        # one column in the data. Compute grand margin and return it.
        return table.append(Series({key: grand_margin[margins_name]}))

    if values:
        marginal_result_set = _generate_marginal_results(table, data, values,
                                                         rows, cols, aggfunc,
                                                         observed,
                                                         grand_margin,
                                                         margins_name)
        # A non-tuple result means the helper already produced the final
        # table; pass it straight through.
        if not isinstance(marginal_result_set, tuple):
            return marginal_result_set
        result, margin_keys, row_margin = marginal_result_set
    else:
        marginal_result_set = _generate_marginal_results_without_values(
            table, data, rows, cols, aggfunc, observed, margins_name)
        if not isinstance(marginal_result_set, tuple):
            return marginal_result_set
        result, margin_keys, row_margin = marginal_result_set

    row_margin = row_margin.reindex(result.columns, fill_value=fill_value)
    # populate grand margin
    for k in margin_keys:
        if isinstance(k, compat.string_types):
            row_margin[k] = grand_margin[k]
        else:
            row_margin[k] = grand_margin[k[0]]

    from pandas import DataFrame
    margin_dummy = DataFrame(row_margin, columns=[key]).T

    row_names = result.index.names
    try:
        # Cast the margin row back to the dtypes of the result columns so the
        # append does not upcast everything.
        for dtype in set(result.dtypes):
            cols = result.select_dtypes([dtype]).columns
            margin_dummy[cols] = margin_dummy[cols].astype(dtype)
        result = result.append(margin_dummy)
    except TypeError:

        # we cannot reshape, so coerce the axis
        result.index = result.index._to_safe_for_reshape()
        result = result.append(margin_dummy)
    result.index.names = row_names

    return result
def _compute_grand_margin(data, values, aggfunc,
                          margins_name='All'):
    """Return a dict of grand-total values, one entry per value column.

    When `values` is empty, aggregates the index instead and returns a
    single entry keyed by `margins_name`.
    """
    if values:
        grand_margin = {}
        for k, v in data[values].iteritems():
            try:
                # aggfunc can be a method name, a dict of per-column
                # funcs/names, or a callable applied to the column.
                if isinstance(aggfunc, compat.string_types):
                    grand_margin[k] = getattr(v, aggfunc)()
                elif isinstance(aggfunc, dict):
                    if isinstance(aggfunc[k], compat.string_types):
                        grand_margin[k] = getattr(v, aggfunc[k])()
                    else:
                        grand_margin[k] = aggfunc[k](v)
                else:
                    grand_margin[k] = aggfunc(v)
            except TypeError:
                # columns the aggregation cannot handle are skipped silently
                pass
        return grand_margin
    else:
        return {margins_name: aggfunc(data.index)}
def _generate_marginal_results(table, data, values, rows, cols, aggfunc,
                               observed,
                               grand_margin,
                               margins_name='All'):
    """Compute the per-group margin columns and the margin row for a pivot
    table that has value columns.

    Returns either a finished table (when there are columns but no rows) or
    a ``(result, margin_keys, row_margin)`` tuple for _add_margins.
    """
    if len(cols) > 0:
        # need to "interleave" the margins
        table_pieces = []
        margin_keys = []

        def _all_key(key):
            # Margin column label for `key`, padded to the column nlevels.
            return (key, margins_name) + ('',) * (len(cols) - 1)

        if len(rows) > 0:
            margin = data[rows + values].groupby(
                rows, observed=observed).agg(aggfunc)
            cat_axis = 1

            for key, piece in table.groupby(level=0,
                                            axis=cat_axis,
                                            observed=observed):
                all_key = _all_key(key)

                # we are going to mutate this, so need to copy!
                piece = piece.copy()
                try:
                    piece[all_key] = margin[key]
                except TypeError:

                    # we cannot reshape, so coerce the axis
                    piece.set_axis(piece._get_axis(
                        cat_axis)._to_safe_for_reshape(),
                        axis=cat_axis, inplace=True)
                    piece[all_key] = margin[key]

                table_pieces.append(piece)
                margin_keys.append(all_key)
        else:
            # No row groupers: interleave the grand margin as extra entries.
            margin = grand_margin
            cat_axis = 0
            for key, piece in table.groupby(level=0,
                                            axis=cat_axis,
                                            observed=observed):
                all_key = _all_key(key)
                table_pieces.append(piece)
                table_pieces.append(Series(margin[key], index=[all_key]))
                margin_keys.append(all_key)

        result = concat(table_pieces, axis=cat_axis)

        if len(rows) == 0:
            return result
    else:
        result = table
        margin_keys = table.columns

    if len(cols) > 0:
        row_margin = data[cols + values].groupby(
            cols, observed=observed).agg(aggfunc)
        row_margin = row_margin.stack()

        # slight hack: move the stacked level (the value name) to the front
        new_order = [len(cols)] + lrange(len(cols))
        row_margin.index = row_margin.index.reorder_levels(new_order)
    else:
        row_margin = Series(np.nan, index=result.columns)

    return result, margin_keys, row_margin
def _generate_marginal_results_without_values(
        table, data, rows, cols, aggfunc,
        observed, margins_name='All'):
    """Compute margin columns/row for a pivot table with no value columns.

    Mutates `table` in place by adding the margin column, and returns either
    the finished table (columns but no rows) or a
    ``(result, margin_keys, row_margin)`` tuple for _add_margins.
    """
    if len(cols) > 0:
        # need to "interleave" the margins
        margin_keys = []

        def _all_key():
            # Margin label, a plain string for single-level columns.
            if len(cols) == 1:
                return margins_name
            return (margins_name, ) + ('', ) * (len(cols) - 1)

        if len(rows) > 0:
            margin = data[rows].groupby(rows,
                                        observed=observed).apply(aggfunc)
            all_key = _all_key()
            table[all_key] = margin
            result = table
            margin_keys.append(all_key)

        else:
            margin = data.groupby(level=0,
                                  axis=0,
                                  observed=observed).apply(aggfunc)
            all_key = _all_key()
            table[all_key] = margin
            result = table
            margin_keys.append(all_key)
            # no row groupers: the table with the margin column IS the result
            return result
    else:
        result = table
        margin_keys = table.columns

    if len(cols):
        row_margin = data[cols].groupby(cols, observed=observed).apply(aggfunc)
    else:
        row_margin = Series(np.nan, index=result.columns)

    return result, margin_keys, row_margin
  294. def _convert_by(by):
  295. if by is None:
  296. by = []
  297. elif (is_scalar(by) or
  298. isinstance(by, (np.ndarray, Index, ABCSeries, Grouper)) or
  299. hasattr(by, '__call__')):
  300. by = [by]
  301. else:
  302. by = list(by)
  303. return by
@Substitution('\ndata : DataFrame')
@Appender(_shared_docs['pivot'], indents=1)
def pivot(data, index=None, columns=None, values=None):
    # NOTE: no docstring here on purpose — @Substitution %-formats __doc__ and
    # @Appender supplies the shared 'pivot' docstring.
    if values is None:
        # No values: set the (index, columns) labels as the index and let
        # unstack spread every remaining column.
        cols = [columns] if index is None else [index, columns]
        append = index is None
        indexed = data.set_index(cols, append=append)
    else:
        if index is None:
            index = data.index
        else:
            index = data[index]
        index = MultiIndex.from_arrays([index, data[columns]])

        if is_list_like(values) and not isinstance(values, tuple):
            # Exclude tuple because it is seen as a single column name
            indexed = data._constructor(data[values].values, index=index,
                                        columns=values)
        else:
            # single value column -> build a Series via _constructor_sliced
            indexed = data._constructor_sliced(data[values].values,
                                               index=index)
    return indexed.unstack(columns)
def crosstab(index, columns, values=None, rownames=None, colnames=None,
             aggfunc=None, margins=False, margins_name='All', dropna=True,
             normalize=False):
    """
    Compute a simple cross-tabulation of two (or more) factors. By default
    computes a frequency table of the factors unless an array of values and an
    aggregation function are passed

    Parameters
    ----------
    index : array-like, Series, or list of arrays/Series
        Values to group by in the rows
    columns : array-like, Series, or list of arrays/Series
        Values to group by in the columns
    values : array-like, optional
        Array of values to aggregate according to the factors.
        Requires `aggfunc` be specified.
    rownames : sequence, default None
        If passed, must match number of row arrays passed
    colnames : sequence, default None
        If passed, must match number of column arrays passed
    aggfunc : function, optional
        If specified, requires `values` be specified as well
    margins : boolean, default False
        Add row/column margins (subtotals)
    margins_name : string, default 'All'
        Name of the row / column that will contain the totals
        when margins is True.

        .. versionadded:: 0.21.0

    dropna : boolean, default True
        Do not include columns whose entries are all NaN
    normalize : boolean, {'all', 'index', 'columns'}, or {0,1}, default False
        Normalize by dividing all values by the sum of values.

        - If passed 'all' or `True`, will normalize over all values.
        - If passed 'index' will normalize over each row.
        - If passed 'columns' will normalize over each column.
        - If margins is `True`, will also normalize margin values.

        .. versionadded:: 0.18.1

    Returns
    -------
    crosstab : DataFrame

    Notes
    -----
    Any Series passed will have their name attributes used unless row or column
    names for the cross-tabulation are specified.

    Any input passed containing Categorical data will have **all** of its
    categories included in the cross-tabulation, even if the actual data does
    not contain any instances of a particular category.

    In the event that there aren't overlapping indexes an empty DataFrame will
    be returned.

    Examples
    --------
    >>> a = np.array(["foo", "foo", "foo", "foo", "bar", "bar",
    ...               "bar", "bar", "foo", "foo", "foo"], dtype=object)
    >>> b = np.array(["one", "one", "one", "two", "one", "one",
    ...               "one", "two", "two", "two", "one"], dtype=object)
    >>> c = np.array(["dull", "dull", "shiny", "dull", "dull", "shiny",
    ...               "shiny", "dull", "shiny", "shiny", "shiny"],
    ...              dtype=object)
    >>> pd.crosstab(a, [b, c], rownames=['a'], colnames=['b', 'c'])
    ... # doctest: +NORMALIZE_WHITESPACE
    b   one        two
    c   dull shiny dull shiny
    a
    bar    1     2    1     0
    foo    2     2    1     2

    >>> foo = pd.Categorical(['a', 'b'], categories=['a', 'b', 'c'])
    >>> bar = pd.Categorical(['d', 'e'], categories=['d', 'e', 'f'])
    >>> crosstab(foo, bar)  # 'c' and 'f' are not represented in the data,
    #                         and will not be shown in the output because
    #                         dropna is True by default. Set 'dropna=False'
    #                         to preserve categories with no data
    ... # doctest: +SKIP
    col_0  d  e
    row_0
    a      1  0
    b      0  1

    >>> crosstab(foo, bar, dropna=False)  # 'c' and 'f' are not represented
    #                                       in the data, but they still will be
    #                                       counted and shown in the output
    ... # doctest: +SKIP
    col_0  d  e  f
    row_0
    a      1  0  0
    b      0  1  0
    c      0  0  0
    """

    index = com.maybe_make_list(index)
    columns = com.maybe_make_list(columns)

    rownames = _get_names(index, rownames, prefix='row')
    colnames = _get_names(columns, colnames, prefix='col')

    # Only rows present in every input array survive (intersection).
    common_idx = _get_objs_combined_axis(index + columns, intersect=True,
                                         sort=False)

    data = {}
    data.update(zip(rownames, index))
    data.update(zip(colnames, columns))

    if values is None and aggfunc is not None:
        raise ValueError("aggfunc cannot be used without values.")

    if values is not None and aggfunc is None:
        raise ValueError("values cannot be used without an aggfunc.")

    from pandas import DataFrame
    df = DataFrame(data, index=common_idx)
    # Delegate to pivot_table over a synthetic '__dummy__' column: counted
    # with len for a frequency table, or carrying `values` for aggregation.
    if values is None:
        df['__dummy__'] = 0
        kwargs = {'aggfunc': len, 'fill_value': 0}
    else:
        df['__dummy__'] = values
        kwargs = {'aggfunc': aggfunc}

    table = df.pivot_table('__dummy__', index=rownames, columns=colnames,
                           margins=margins, margins_name=margins_name,
                           dropna=dropna, **kwargs)

    # Post-process
    if normalize is not False:
        table = _normalize(table, normalize=normalize, margins=margins,
                           margins_name=margins_name)

    return table
def _normalize(table, normalize, margins, margins_name='All'):
    """Normalize a crosstab result in place of raw counts.

    `normalize` may be True/'all', 'index', 'columns', or the axis numbers
    0/1 (mapped to 'index'/'columns'). When `margins` is True the margin
    row/column is stripped, the core is normalized recursively, and the
    margins are re-attached normalized.
    """
    if not isinstance(normalize, bool) and not isinstance(normalize,
                                                          compat.string_types):
        # Map axis numbers to their string equivalents.
        axis_subs = {0: 'index', 1: 'columns'}
        try:
            normalize = axis_subs[normalize]
        except KeyError:
            raise ValueError("Not a valid normalize argument")

    if margins is False:

        # Actual Normalizations
        normalizers = {
            'all': lambda x: x / x.sum(axis=1).sum(axis=0),
            'columns': lambda x: x / x.sum(),
            'index': lambda x: x.div(x.sum(axis=1), axis=0)
        }

        normalizers[True] = normalizers['all']

        try:
            f = normalizers[normalize]
        except KeyError:
            raise ValueError("Not a valid normalize argument")

        table = f(table)
        table = table.fillna(0)

    elif margins is True:

        # Detach the margins, normalize the core, then re-attach margins
        # that are themselves normalized.
        column_margin = table.loc[:, margins_name].drop(margins_name)
        index_margin = table.loc[margins_name, :].drop(margins_name)
        table = table.drop(margins_name, axis=1).drop(margins_name)
        # to keep index and columns names
        table_index_names = table.index.names
        table_columns_names = table.columns.names

        # Normalize core
        table = _normalize(table, normalize=normalize, margins=False)

        # Fix Margins
        if normalize == 'columns':
            column_margin = column_margin / column_margin.sum()
            table = concat([table, column_margin], axis=1)
            table = table.fillna(0)

        elif normalize == 'index':
            index_margin = index_margin / index_margin.sum()
            table = table.append(index_margin)
            table = table.fillna(0)

        elif normalize == "all" or normalize is True:
            column_margin = column_margin / column_margin.sum()
            index_margin = index_margin / index_margin.sum()
            # grand total always normalizes to 1
            index_margin.loc[margins_name] = 1
            table = concat([table, column_margin], axis=1)
            table = table.append(index_margin)
            table = table.fillna(0)

        else:
            raise ValueError("Not a valid normalize argument")

        table.index.names = table_index_names
        table.columns.names = table_columns_names

    else:
        raise ValueError("Not a valid margins argument")

    return table
  494. def _get_names(arrs, names, prefix='row'):
  495. if names is None:
  496. names = []
  497. for i, arr in enumerate(arrs):
  498. if isinstance(arr, ABCSeries) and arr.name is not None:
  499. names.append(arr.name)
  500. else:
  501. names.append('{prefix}_{i}'.format(prefix=prefix, i=i))
  502. else:
  503. if len(names) != len(arrs):
  504. raise AssertionError('arrays and names must have the same length')
  505. if not isinstance(names, list):
  506. names = list(names)
  507. return names