ops.py 29 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898
  1. """
  2. Provide classes to perform the groupby aggregate operations.
  3. These are not exposed to the user and provide implementations of the grouping
  4. operations, primarily in cython. These classes (BaseGrouper and BinGrouper)
  5. are contained *in* the SeriesGroupBy and DataFrameGroupBy objects.
  6. """
  7. import collections
  8. import numpy as np
  9. from pandas._libs import NaT, groupby as libgroupby, iNaT, lib, reduction
  10. from pandas.compat import lzip, range, zip
  11. from pandas.errors import AbstractMethodError
  12. from pandas.util._decorators import cache_readonly
  13. from pandas.core.dtypes.common import (
  14. ensure_float64, ensure_int64, ensure_int64_or_float64, ensure_object,
  15. ensure_platform_int, is_bool_dtype, is_categorical_dtype, is_complex_dtype,
  16. is_datetime64_any_dtype, is_integer_dtype, is_numeric_dtype,
  17. is_timedelta64_dtype, needs_i8_conversion)
  18. from pandas.core.dtypes.missing import _maybe_fill, isna
  19. import pandas.core.algorithms as algorithms
  20. from pandas.core.base import SelectionMixin
  21. import pandas.core.common as com
  22. from pandas.core.frame import DataFrame
  23. from pandas.core.generic import NDFrame
  24. from pandas.core.groupby import base
  25. from pandas.core.index import Index, MultiIndex, ensure_index
  26. from pandas.core.series import Series
  27. from pandas.core.sorting import (
  28. compress_group_index, decons_obs_group_ids, get_flattened_iterator,
  29. get_group_index, get_group_index_sorter, get_indexer_dict)
  def generate_bins_generic(values, binner, closed):
      """
      Generate bin edge offsets and bin labels for one array using another array
      which has bin edge values. Both arrays must be sorted.

      Parameters
      ----------
      values : array of values
      binner : a comparable array of values representing bins into which to bin
          the first array. Note, 'values' end-points must fall within 'binner'
          end-points.
      closed : which end of bin is closed; left (default), right

      Returns
      -------
      bins : array of offsets (into 'values' argument) of bins.
          Zero and last edge are excluded in result, so for instance the first
          bin is values[0:bin[0]] and the last is values[bin[-1]:]

      Raises
      ------
      ValueError
          If ``values`` or ``binner`` is empty, if the first value is smaller
          than the first bin edge, or if the last value is larger than the
          last bin edge.
      """
      lenidx = len(values)
      lenbin = len(binner)

      if lenidx <= 0 or lenbin <= 0:
          raise ValueError("Invalid length for values or for binner")

      # check binner fits data
      if values[0] < binner[0]:
          raise ValueError("Values falls before first bin")

      if values[lenidx - 1] > binner[lenbin - 1]:
          raise ValueError("Values falls after last bin")

      bins = np.empty(lenbin - 1, dtype=np.int64)

      j = 0  # index into values; never reset, so the scan is one linear pass

      # linear scan, presume nothing about values/binner except that it fits ok
      for i in range(lenbin - 1):
          r_bin = binner[i + 1]

          # count values in current bin, advance to next bin; a value equal to
          # the right edge only belongs to this bin when closed == 'right'
          while j < lenidx and (values[j] < r_bin or
                                (closed == 'right' and values[j] == r_bin)):
              j += 1

          bins[i] = j

      return bins
  class BaseGrouper(object):
      """
      This is an internal Grouper class, which actually holds
      the generated groups

      Parameters
      ----------
      axis : int
          the axis to group
          NOTE(review): ``groups`` calls ``self.axis.groupby(...)``, so callers
          appear to pass an Index-like object here rather than an int - confirm.
      groupings : array of grouping
          all the grouping instances to handle in this grouper
          for example for grouper list to groupby, need to pass the list
      sort : boolean, default True
          whether this grouper will give sorted result or not
      group_keys : boolean, default True
      mutated : boolean, default False
      indexer : intp array, optional
          the indexer created by Grouper
          some groupers (TimeGrouper) will sort its axis and its
          group_info is also sorted, so need the indexer to reorder
      """

      def __init__(self, axis, groupings, sort=True, group_keys=True,
                   mutated=False, indexer=None):
          # both flags are True whenever there is not exactly one grouping
          self._filter_empty_groups = self.compressed = len(groupings) != 1
          self.axis = axis
          self.groupings = groupings
          self.sort = sort
          self.group_keys = group_keys
          self.mutated = mutated
          self.indexer = indexer

      @property
      def shape(self):
          """ tuple holding ``ngroups`` of every grouping """
          return tuple(ping.ngroups for ping in self.groupings)

      def __iter__(self):
          return iter(self.indices)

      @property
      def nkeys(self):
          """ number of groupings """
          return len(self.groupings)

      def get_iterator(self, data, axis=0):
          """
          Groupby iterator

          Returns
          -------
          Generator yielding sequence of (name, subsetted object)
          for each group
          """
          splitter = self._get_splitter(data, axis=axis)
          keys = self._get_group_keys()
          # the splitter yields (position, group); only the group is passed on
          for key, (_, group) in zip(keys, splitter):
              yield key, group

      def _get_splitter(self, data, axis=0):
          """ build the splitter for ``data`` from the cached ``group_info`` """
          comp_ids, _, ngroups = self.group_info
          return get_splitter(data, comp_ids, ngroups, axis=axis)

      def _get_group_keys(self):
          """ group keys: the single level, or a flattened multi-key iterator """
          if len(self.groupings) == 1:
              return self.levels[0]

          comp_ids, _, ngroups = self.group_info
          # provide "flattened" iterator for multi-group setting
          return get_flattened_iterator(comp_ids,
                                        ngroups,
                                        self.levels,
                                        self.labels)

      def apply(self, f, data, axis=0):
          """
          Call ``f`` on every group of ``data``.

          Returns
          -------
          (group_keys, result_values, mutated)
              ``mutated`` becomes True when a result is not indexed like the
              group it was computed from.
          """
          mutated = self.mutated
          splitter = self._get_splitter(data, axis=axis)
          group_keys = self._get_group_keys()

          # oh boy
          f_name = com.get_callable_name(f)
          if (f_name not in base.plotting_methods and
                  hasattr(splitter, 'fast_apply') and axis == 0):
              try:
                  values, mutated = splitter.fast_apply(f, group_keys)
                  return group_keys, values, mutated
              except reduction.InvalidApply:
                  # we detect a mutation of some kind
                  # so take slow path
                  pass
              except Exception:
                  # swallowed on purpose: the slow path below calls ``f`` again
                  # without a guard, so an error raised by ``f`` reaches the
                  # caller from there
                  pass

          result_values = []
          for key, (_, group) in zip(group_keys, splitter):
              object.__setattr__(group, 'name', key)

              # group might be modified by ``f``; capture its axes beforehand
              group_axes = _get_axes(group)
              res = f(group)
              if not _is_indexed_like(res, group_axes):
                  mutated = True
              result_values.append(res)

          return group_keys, result_values, mutated

      @cache_readonly
      def indices(self):
          """ dict {group name -> group indices} """
          if len(self.groupings) == 1:
              return self.groupings[0].indices

          label_list = [ping.labels for ping in self.groupings]
          keys = [com.values_from_object(ping.group_index)
                  for ping in self.groupings]
          return get_indexer_dict(label_list, keys)

      @property
      def labels(self):
          return [ping.labels for ping in self.groupings]

      @property
      def levels(self):
          return [ping.group_index for ping in self.groupings]

      @property
      def names(self):
          return [ping.name for ping in self.groupings]

      def size(self):
          """
          Compute group sizes
          """
          ids, _, ngroup = self.group_info
          ids = ensure_platform_int(ids)
          if ngroup:
              # ids equal to -1 are excluded from the counts
              out = np.bincount(ids[ids != -1], minlength=ngroup)
          else:
              out = ids
          return Series(out,
                        index=self.result_index,
                        dtype='int64')

      @cache_readonly
      def groups(self):
          """ dict {group name -> group labels} """
          if len(self.groupings) == 1:
              return self.groupings[0].groups

          to_groupby = lzip(*(ping.grouper for ping in self.groupings))
          to_groupby = Index(to_groupby)
          return self.axis.groupby(to_groupby)

      @cache_readonly
      def is_monotonic(self):
          # return if my group orderings are monotonic
          return Index(self.group_info[0]).is_monotonic

      @cache_readonly
      def group_info(self):
          """ tuple of (comp_ids as int64, obs_group_ids, ngroups) """
          comp_ids, obs_group_ids = self._get_compressed_labels()

          ngroups = len(obs_group_ids)
          comp_ids = ensure_int64(comp_ids)
          return comp_ids, obs_group_ids, ngroups

      @cache_readonly
      def label_info(self):
          # return the labels of items in original grouped axis
          labels, _, _ = self.group_info
          if self.indexer is not None:
              sorter = np.lexsort((labels, self.indexer))
              labels = labels[sorter]
          return labels

      def _get_compressed_labels(self):
          """ return (labels, observed group ids) for the groupings """
          all_labels = [ping.labels for ping in self.groupings]
          if len(all_labels) > 1:
              group_index = get_group_index(all_labels, self.shape,
                                            sort=True, xnull=True)
              return compress_group_index(group_index, sort=self.sort)

          ping = self.groupings[0]
          return ping.labels, np.arange(len(ping.group_index))

      @cache_readonly
      def ngroups(self):
          return len(self.result_index)

      @property
      def recons_labels(self):
          comp_ids, obs_ids, _ = self.group_info
          labels = (ping.labels for ping in self.groupings)
          return decons_obs_group_ids(
              comp_ids, obs_ids, self.shape, labels, xnull=True)

      @cache_readonly
      def result_index(self):
          """ index of the result: plain for one grouping, else a MultiIndex """
          if not self.compressed and len(self.groupings) == 1:
              return self.groupings[0].result_index.rename(self.names[0])

          codes = self.recons_labels
          levels = [ping.result_index for ping in self.groupings]
          result = MultiIndex(levels=levels,
                              codes=codes,
                              verify_integrity=False,
                              names=self.names)
          return result

      def get_group_levels(self):
          """ list with one index of group labels per grouping """
          if not self.compressed and len(self.groupings) == 1:
              return [self.groupings[0].result_index]

          name_list = []
          for ping, labels in zip(self.groupings, self.recons_labels):
              labels = ensure_platform_int(labels)
              levels = ping.result_index.take(labels)

              name_list.append(levels)

          return name_list

      # ------------------------------------------------------------
      # Aggregation functions

      # kind -> how -> libgroupby function name, or a dict holding the name and
      # an optional adapter 'f' that is called with the looked-up function first
      _cython_functions = {
          'aggregate': {
              'add': 'group_add',
              'prod': 'group_prod',
              'min': 'group_min',
              'max': 'group_max',
              'mean': 'group_mean',
              'median': {
                  'name': 'group_median'
              },
              'var': 'group_var',
              'first': {
                  'name': 'group_nth',
                  'f': lambda func, a, b, c, d, e: func(a, b, c, d, 1, -1)
              },
              'last': 'group_last',
              'ohlc': 'group_ohlc',
          },

          'transform': {
              'cumprod': 'group_cumprod',
              'cumsum': 'group_cumsum',
              'cummin': 'group_cummin',
              'cummax': 'group_cummax',
              'rank': {
                  'name': 'group_rank',
                  'f': lambda func, a, b, c, d, **kwargs: func(
                      a, b, c, d,
                      kwargs.get('ties_method', 'average'),
                      kwargs.get('ascending', True),
                      kwargs.get('pct', False),
                      kwargs.get('na_option', 'keep')
                  )
              }
          }
      }

      _cython_arity = {
          'ohlc': 4,  # OHLC
      }

      _name_functions = {
          'ohlc': lambda *args: ['open', 'high', 'low', 'close']
      }

      def _is_builtin_func(self, arg):
          """
          if we define an builtin function for this argument, return it,
          otherwise return the arg
          """
          return SelectionMixin._builtin_table.get(arg, arg)

      def _get_cython_function(self, kind, how, values, is_numeric):
          """
          Look up the libgroupby function implementing ``how`` for ``values``.

          Raises
          ------
          NotImplementedError
              If no matching function is found for the dtype of ``values``.
          """
          dtype_str = values.dtype.name

          def get_func(fname):
              # see if there is a fused-type version of function
              # only valid for numeric
              f = getattr(libgroupby, fname, None)
              if f is not None and is_numeric:
                  return f

              # otherwise find dtype-specific version, falling back to object;
              # the lookup must use the loop variable, otherwise the 'object'
              # fallback is never actually tried
              for dt in [dtype_str, 'object']:
                  f = getattr(libgroupby, "{fname}_{dtype_str}".format(
                      fname=fname, dtype_str=dt), None)
                  if f is not None:
                      return f

          ftype = self._cython_functions[kind][how]

          if isinstance(ftype, dict):
              func = afunc = get_func(ftype['name'])

              # a sub-function
              f = ftype.get('f')
              if f is not None:

                  def wrapper(*args, **kwargs):
                      return f(afunc, *args, **kwargs)

                  # need to curry our sub-function
                  func = wrapper

          else:
              func = get_func(ftype)

          if func is None:
              raise NotImplementedError(
                  "function is not implemented for this dtype: "
                  "[how->{how},dtype->{dtype_str}]".format(how=how,
                                                           dtype_str=dtype_str))

          return func

      def _cython_operation(self, kind, values, how, axis, min_count=-1,
                            **kwargs):
          """
          Run the cython aggregation or transformation ``how`` on ``values``.

          Parameters
          ----------
          kind : {'aggregate', 'transform'}
          values : array
          how : str
              key into ``_cython_functions[kind]``
          axis : int
          min_count : int, default -1
              forwarded to the aggregation function
          **kwargs
              forwarded to the transformation function

          Returns
          -------
          (result, names)
              ``names`` is None unless ``how`` has an entry in
              ``_name_functions``.

          Raises
          ------
          NotImplementedError
              For categorical values, for unsupported datetime64/timedelta64
              operations, for arity > 1 on values that are not 1-d, or when no
              cython function exists for the dtype.
          """
          assert kind in ['transform', 'aggregate']

          # can we do this operation with our cython functions
          # if not raise NotImplementedError

          # we raise NotImplemented if this is an invalid operation
          # entirely, e.g. adding datetimes

          # categoricals are only 1d, so we
          # are not setup for dim transforming
          if is_categorical_dtype(values):
              raise NotImplementedError(
                  "categoricals are not support in cython ops ATM")
          elif is_datetime64_any_dtype(values):
              if how in ['add', 'prod', 'cumsum', 'cumprod']:
                  raise NotImplementedError(
                      "datetime64 type does not support {} "
                      "operations".format(how))
          elif is_timedelta64_dtype(values):
              if how in ['prod', 'cumprod']:
                  raise NotImplementedError(
                      "timedelta64 type does not support {} "
                      "operations".format(how))

          arity = self._cython_arity.get(how, 1)

          vdim = values.ndim
          swapped = False
          if vdim == 1:
              # work on a 2-d column; squeezed back before returning
              values = values[:, None]
              out_shape = (self.ngroups, arity)
          else:
              if axis > 0:
                  swapped = True
                  values = values.swapaxes(0, axis)
              if arity > 1:
                  raise NotImplementedError("arity of more than 1 is not "
                                            "supported for the 'how' argument")
              out_shape = (self.ngroups,) + values.shape[1:]

          is_datetimelike = needs_i8_conversion(values.dtype)
          is_numeric = is_numeric_dtype(values.dtype)

          if is_datetimelike:
              values = values.view('int64')
              is_numeric = True
          elif is_bool_dtype(values.dtype):
              values = ensure_float64(values)
          elif is_integer_dtype(values):
              # we use iNaT for the missing value on ints
              # so pre-convert to guard this condition
              if (values == iNaT).any():
                  values = ensure_float64(values)
              else:
                  values = ensure_int64_or_float64(values)
          elif is_numeric and not is_complex_dtype(values):
              values = ensure_float64(values)
          else:
              values = values.astype(object)

          try:
              func = self._get_cython_function(
                  kind, how, values, is_numeric)
          except NotImplementedError:
              if is_numeric:
                  # retry once with float64 values
                  values = ensure_float64(values)
                  func = self._get_cython_function(
                      kind, how, values, is_numeric)
              else:
                  raise

          if how == 'rank':
              out_dtype = 'float'
          else:
              if is_numeric:
                  out_dtype = '{kind}{itemsize}'.format(
                      kind=values.dtype.kind, itemsize=values.dtype.itemsize)
              else:
                  out_dtype = 'object'

          labels, _, _ = self.group_info

          if kind == 'aggregate':
              result = _maybe_fill(np.empty(out_shape, dtype=out_dtype),
                                   fill_value=np.nan)
              counts = np.zeros(self.ngroups, dtype=np.int64)
              result = self._aggregate(
                  result, counts, values, labels, func, is_numeric,
                  is_datetimelike, min_count)
          elif kind == 'transform':
              result = _maybe_fill(np.empty_like(values, dtype=out_dtype),
                                   fill_value=np.nan)

              # TODO: min_count
              result = self._transform(
                  result, values, labels, func, is_numeric, is_datetimelike,
                  **kwargs)

          if is_integer_dtype(result) and not is_datetimelike:
              # iNaT entries in an integer result become NaN in a float result
              mask = result == iNaT
              if mask.any():
                  result = result.astype('float64')
                  result[mask] = np.nan

          if (kind == 'aggregate' and
                  self._filter_empty_groups and not counts.all()):
              # keep only the rows of groups with a non-zero count
              if result.ndim == 2:
                  try:
                      result = lib.row_bool_subset(
                          result, (counts > 0).view(np.uint8))
                  except ValueError:
                      result = lib.row_bool_subset_object(
                          ensure_object(result),
                          (counts > 0).view(np.uint8))
              else:
                  result = result[counts > 0]

          if vdim == 1 and arity == 1:
              result = result[:, 0]

          if how in self._name_functions:
              # TODO
              names = self._name_functions[how]()
          else:
              names = None

          if swapped:
              result = result.swapaxes(0, axis)

          return result, names

      def aggregate(self, values, how, axis=0, min_count=-1):
          """ cython aggregation; see ``_cython_operation`` """
          return self._cython_operation('aggregate', values, how, axis,
                                        min_count=min_count)

      def transform(self, values, how, axis=0, **kwargs):
          """ cython transformation; see ``_cython_operation`` """
          return self._cython_operation('transform', values, how, axis, **kwargs)

      def _aggregate(self, result, counts, values, comp_ids, agg_func,
                     is_numeric, is_datetimelike, min_count=-1):
          """ call ``agg_func`` to fill ``result`` and ``counts`` in place """
          if values.ndim > 3:
              # punting for now
              raise NotImplementedError("number of dimensions is currently "
                                        "limited to 3")
          elif values.ndim > 2:
              for i, chunk in enumerate(values.transpose(2, 0, 1)):

                  chunk = chunk.squeeze()
                  agg_func(result[:, :, i], counts, chunk, comp_ids,
                           min_count)
          else:
              agg_func(result, counts, values, comp_ids, min_count)

          return result

      def _transform(self, result, values, comp_ids, transform_func,
                     is_numeric, is_datetimelike, **kwargs):
          """ call ``transform_func`` to fill ``result`` in place """
          # the ``comp_ids`` argument is superseded by the cached group_info
          comp_ids, _, ngroups = self.group_info
          if values.ndim > 3:
              # punting for now
              raise NotImplementedError("number of dimensions is currently "
                                        "limited to 3")
          elif values.ndim > 2:
              for i, chunk in enumerate(values.transpose(2, 0, 1)):
                  # hand over the slice matching result[:, :, i], not the
                  # whole 3-d array
                  transform_func(result[:, :, i], chunk,
                                 comp_ids, is_datetimelike, **kwargs)
          else:
              transform_func(result, values, comp_ids, is_datetimelike, **kwargs)

          return result

      def agg_series(self, obj, func):
          """ aggregate ``obj`` with ``func``: fast path first, then python """
          try:
              return self._aggregate_series_fast(obj, func)
          except Exception:
              return self._aggregate_series_pure_python(obj, func)

      def _aggregate_series_fast(self, obj, func):
          func = self._is_builtin_func(func)

          if obj.index._has_complex_internals:
              raise TypeError('Incompatible index for Cython grouper')

          group_index, _, ngroups = self.group_info

          # avoids object / Series creation overhead
          dummy = obj._get_values(slice(None, 0)).to_dense()
          indexer = get_group_index_sorter(group_index, ngroups)
          obj = obj._take(indexer).to_dense()
          group_index = algorithms.take_nd(
              group_index, indexer, allow_fill=False)
          grouper = reduction.SeriesGrouper(obj, func, group_index, ngroups,
                                            dummy)
          result, counts = grouper.get_result()
          return result, counts

      def _aggregate_series_pure_python(self, obj, func):
          """
          Aggregate ``obj`` group by group in python.

          Raises
          ------
          ValueError
              If ``func`` returns a Series, Index or ndarray for the first
              group, i.e. it does not reduce.
          """
          group_index, _, ngroups = self.group_info

          counts = np.zeros(ngroups, dtype=int)
          result = None

          splitter = get_splitter(obj, group_index, ngroups, axis=self.axis)

          for label, group in splitter:
              res = func(group)
              if result is None:
                  # only the first result is checked before allocating
                  if (isinstance(res, (Series, Index, np.ndarray))):
                      raise ValueError('Function does not reduce')
                  result = np.empty(ngroups, dtype='O')

              counts[label] = group.shape[0]
              result[label] = res

          result = lib.maybe_convert_objects(result, try_float=0)
          return result, counts
  class BinGrouper(BaseGrouper):
      """
      This is an internal Grouper class

      Parameters
      ----------
      bins : the split index of binlabels to group the item of axis
      binlabels : the label list
      filter_empty : boolean, default False
      mutated : boolean, default False
      indexer : a intp array

      Examples
      --------
      bins: [2, 4, 6, 8, 10]
      binlabels: DatetimeIndex(['2005-01-01', '2005-01-03',
          '2005-01-05', '2005-01-07', '2005-01-09'],
          dtype='datetime64[ns]', freq='2D')

      the group_info, which contains the label of each item in grouped
      axis, the index of label in label list, group number, is

      (array([0, 0, 1, 1, 2, 2, 3, 3, 4, 4]), array([0, 1, 2, 3, 4]), 5)

      means that, the grouped axis has 10 items, can be grouped into 5
      labels, the first and second items belong to the first label, the
      third and forth items belong to the second label, and so on
      """

      def __init__(self, bins, binlabels, filter_empty=False, mutated=False,
                   indexer=None):
          self.bins = ensure_int64(bins)
          self.binlabels = ensure_index(binlabels)
          self._filter_empty_groups = filter_empty
          self.mutated = mutated
          self.indexer = indexer

      @cache_readonly
      def groups(self):
          """ dict {group name -> group labels} """

          # this is mainly for compat
          # GH 3881
          result = {key: value for key, value in zip(self.binlabels, self.bins)
                    if key is not NaT}
          return result

      @property
      def nkeys(self):
          return 1

      def get_iterator(self, data, axis=0):
          """
          Groupby iterator

          Returns
          -------
          Generator yielding sequence of (name, subsetted object)
          for each group
          """
          if isinstance(data, NDFrame):
              slicer = lambda start, edge: data._slice(
                  slice(start, edge), axis=axis)
              length = len(data.axes[axis])
          else:
              slicer = lambda start, edge: data[slice(start, edge)]
              length = len(data)

          start = 0
          for edge, label in zip(self.bins, self.binlabels):
              # bins labelled NaT are not yielded, but still advance ``start``
              if label is not NaT:
                  yield label, slicer(start, edge)
              start = edge

          # whatever lies past the last edge is yielded under the last label
          if start < length:
              yield self.binlabels[-1], slicer(start, None)

      @cache_readonly
      def indices(self):
          """ defaultdict {bin label -> list of positions in that bin} """
          indices = collections.defaultdict(list)

          i = 0
          for label, bin_edge in zip(self.binlabels, self.bins):
              # empty bins (edge not past the current position) are skipped
              if i < bin_edge:
                  if label is not NaT:
                      indices[label] = list(range(i, bin_edge))
                  i = bin_edge
          return indices

      @cache_readonly
      def group_info(self):
          """ tuple of (comp_ids, obs_group_ids, ngroups), ids as int64 """
          ngroups = self.ngroups
          obs_group_ids = np.arange(ngroups)
          # number of items falling into each bin
          rep = np.diff(np.r_[0, self.bins])

          rep = ensure_platform_int(rep)
          if ngroups == len(self.bins):
              comp_ids = np.repeat(np.arange(ngroups), rep)
          else:
              # otherwise the items of the first bin get the id -1
              comp_ids = np.repeat(np.r_[-1, np.arange(ngroups)], rep)

          return (comp_ids.astype('int64', copy=False),
                  obs_group_ids.astype('int64', copy=False),
                  ngroups)

      @cache_readonly
      def result_index(self):
          """ the bin labels, without the first one when it is missing """
          if len(self.binlabels) != 0 and isna(self.binlabels[0]):
              return self.binlabels[1:]

          return self.binlabels

      @property
      def levels(self):
          return [self.binlabels]

      @property
      def names(self):
          return [self.binlabels.name]

      @property
      def groupings(self):
          # imported here rather than at the top of the module
          from pandas.core.groupby.grouper import Grouping
          return [Grouping(lvl, lvl, in_axis=False, level=None, name=name)
                  for lvl, name in zip(self.levels, self.names)]

      def agg_series(self, obj, func):
          """ aggregate ``obj`` with ``func`` via reduction.SeriesBinGrouper """
          dummy = obj[:0]
          grouper = reduction.SeriesBinGrouper(obj, func, self.bins, dummy)
          return grouper.get_result()
  624. def _get_axes(group):
  625. if isinstance(group, Series):
  626. return [group.index]
  627. else:
  628. return group.axes
  629. def _is_indexed_like(obj, axes):
  630. if isinstance(obj, Series):
  631. if len(axes) > 1:
  632. return False
  633. return obj.index.equals(axes[0])
  634. elif isinstance(obj, DataFrame):
  635. return obj.index.equals(axes[0])
  636. return False
  637. # ----------------------------------------------------------------------
  638. # Splitting / application
  class DataSplitter(object):
      """
      Split ``data`` into one chunk per group.

      Parameters
      ----------
      data : object to split
      labels : array of group ids, one per item; stored as int64
      ngroups : int
          number of groups
      axis : int, default 0
          axis along which ``data`` is reordered before slicing
      """

      def __init__(self, data, labels, ngroups, axis=0):
          self.data = data
          self.labels = ensure_int64(labels)
          self.ngroups = ngroups

          self.axis = axis

      @cache_readonly
      def slabels(self):
          # Sorted labels
          return algorithms.take_nd(self.labels, self.sort_idx, allow_fill=False)

      @cache_readonly
      def sort_idx(self):
          # Counting sort indexer
          return get_group_index_sorter(self.labels, self.ngroups)

      def __iter__(self):
          """ yield (group position, chunk of the sorted data) pairs """
          sdata = self._get_sorted_data()

          if self.ngroups == 0:
              # we are inside a generator, rather than raise StopIteration
              # we merely return signal the end
              return

          starts, ends = lib.generate_slices(self.slabels, self.ngroups)

          for i, (start, end) in enumerate(zip(starts, ends)):
              # Since I'm now compressing the group ids, it's now not "possible"
              # to produce empty slices because such groups would not be observed
              # in the data
              yield i, self._chop(sdata, slice(start, end))

      def _get_sorted_data(self):
          """ ``data`` reordered by ``sort_idx`` along ``axis`` """
          return self.data._take(self.sort_idx, axis=self.axis)

      def _chop(self, sdata, slice_obj):
          """ positional slice of the sorted data """
          return sdata.iloc[slice_obj]

      def apply(self, f):
          raise AbstractMethodError(self)
  class SeriesSplitter(DataSplitter):
      """ DataSplitter whose chunks are dense slices taken via ``_get_values`` """

      def _chop(self, sdata, slice_obj):
          return sdata._get_values(slice_obj).to_dense()
  class FrameSplitter(DataSplitter):
      """ DataSplitter with a ``fast_apply`` path and axis-aware chopping """

      def fast_apply(self, f, names):
          """
          Apply ``f`` to every group through ``reduction.apply_frame_axis0``.

          Returns
          -------
          (results, mutated)
              ``([], True)`` when the group slices cannot be generated.
          """
          try:
              starts, ends = lib.generate_slices(self.slabels, self.ngroups)
          except Exception:
              # fails when all -1
              return [], True

          sdata = self._get_sorted_data()
          results, mutated = reduction.apply_frame_axis0(sdata, f, names,
                                                         starts, ends)

          return results, mutated

      def _chop(self, sdata, slice_obj):
          if self.axis == 0:
              return sdata.iloc[slice_obj]
          else:
              return sdata._slice(slice_obj, axis=1)  # .loc[:, slice_obj]
  class NDFrameSplitter(DataSplitter):
      """ DataSplitter working on ``data._data`` and rebuilding each chunk """

      def __init__(self, data, labels, ngroups, axis=0):
          super(NDFrameSplitter, self).__init__(data, labels, ngroups, axis=axis)

          # used by ``_chop`` to wrap every slice again
          self.factory = data._constructor

      def _get_sorted_data(self):
          # this is the BlockManager
          data = self.data._data

          # this is sort of wasteful but...
          sorted_axis = data.axes[self.axis].take(self.sort_idx)
          sorted_data = data.reindex_axis(sorted_axis, axis=self.axis)

          return sorted_data

      def _chop(self, sdata, slice_obj):
          return self.factory(sdata.get_slice(slice_obj, axis=self.axis))
  def get_splitter(data, *args, **kwargs):
      """
      Build the splitter matching the type of ``data``.

      A Series gets a SeriesSplitter, a DataFrame a FrameSplitter and anything
      else an NDFrameSplitter. ``*args`` and ``**kwargs`` are handed to the
      splitter constructor unchanged.
      """
      if isinstance(data, Series):
          klass = SeriesSplitter
      elif isinstance(data, DataFrame):
          klass = FrameSplitter
      else:
          klass = NDFrameSplitter

      return klass(data, *args, **kwargs)