apply.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411
  1. import warnings
  2. import numpy as np
  3. from pandas._libs import reduction
  4. import pandas.compat as compat
  5. from pandas.util._decorators import cache_readonly
  6. from pandas.core.dtypes.common import (
  7. is_dict_like, is_extension_type, is_list_like, is_sequence)
  8. from pandas.core.dtypes.generic import ABCSeries
  9. from pandas.io.formats.printing import pprint_thing
  10. def frame_apply(obj, func, axis=0, broadcast=None,
  11. raw=False, reduce=None, result_type=None,
  12. ignore_failures=False,
  13. args=None, kwds=None):
  14. """ construct and return a row or column based frame apply object """
  15. axis = obj._get_axis_number(axis)
  16. if axis == 0:
  17. klass = FrameRowApply
  18. elif axis == 1:
  19. klass = FrameColumnApply
  20. return klass(obj, func, broadcast=broadcast,
  21. raw=raw, reduce=reduce, result_type=result_type,
  22. ignore_failures=ignore_failures,
  23. args=args, kwds=kwds)
  24. class FrameApply(object):
  25. def __init__(self, obj, func, broadcast, raw, reduce, result_type,
  26. ignore_failures, args, kwds):
  27. self.obj = obj
  28. self.raw = raw
  29. self.ignore_failures = ignore_failures
  30. self.args = args or ()
  31. self.kwds = kwds or {}
  32. if result_type not in [None, 'reduce', 'broadcast', 'expand']:
  33. raise ValueError("invalid value for result_type, must be one "
  34. "of {None, 'reduce', 'broadcast', 'expand'}")
  35. if broadcast is not None:
  36. warnings.warn("The broadcast argument is deprecated and will "
  37. "be removed in a future version. You can specify "
  38. "result_type='broadcast' to broadcast the result "
  39. "to the original dimensions",
  40. FutureWarning, stacklevel=4)
  41. if broadcast:
  42. result_type = 'broadcast'
  43. if reduce is not None:
  44. warnings.warn("The reduce argument is deprecated and will "
  45. "be removed in a future version. You can specify "
  46. "result_type='reduce' to try to reduce the result "
  47. "to the original dimensions",
  48. FutureWarning, stacklevel=4)
  49. if reduce:
  50. if result_type is not None:
  51. raise ValueError(
  52. "cannot pass both reduce=True and result_type")
  53. result_type = 'reduce'
  54. self.result_type = result_type
  55. # curry if needed
  56. if ((kwds or args) and
  57. not isinstance(func, (np.ufunc, compat.string_types))):
  58. def f(x):
  59. return func(x, *args, **kwds)
  60. else:
  61. f = func
  62. self.f = f
  63. # results
  64. self.result = None
  65. self.res_index = None
  66. self.res_columns = None
  67. @property
  68. def columns(self):
  69. return self.obj.columns
  70. @property
  71. def index(self):
  72. return self.obj.index
  73. @cache_readonly
  74. def values(self):
  75. return self.obj.values
  76. @cache_readonly
  77. def dtypes(self):
  78. return self.obj.dtypes
  79. @property
  80. def agg_axis(self):
  81. return self.obj._get_agg_axis(self.axis)
  82. def get_result(self):
  83. """ compute the results """
  84. # dispatch to agg
  85. if is_list_like(self.f) or is_dict_like(self.f):
  86. return self.obj.aggregate(self.f, axis=self.axis,
  87. *self.args, **self.kwds)
  88. # all empty
  89. if len(self.columns) == 0 and len(self.index) == 0:
  90. return self.apply_empty_result()
  91. # string dispatch
  92. if isinstance(self.f, compat.string_types):
  93. # Support for `frame.transform('method')`
  94. # Some methods (shift, etc.) require the axis argument, others
  95. # don't, so inspect and insert if necessary.
  96. func = getattr(self.obj, self.f)
  97. sig = compat.signature(func)
  98. if 'axis' in sig.args:
  99. self.kwds['axis'] = self.axis
  100. return func(*self.args, **self.kwds)
  101. # ufunc
  102. elif isinstance(self.f, np.ufunc):
  103. with np.errstate(all='ignore'):
  104. results = self.obj._data.apply('apply', func=self.f)
  105. return self.obj._constructor(data=results, index=self.index,
  106. columns=self.columns, copy=False)
  107. # broadcasting
  108. if self.result_type == 'broadcast':
  109. return self.apply_broadcast()
  110. # one axis empty
  111. elif not all(self.obj.shape):
  112. return self.apply_empty_result()
  113. # raw
  114. elif self.raw and not self.obj._is_mixed_type:
  115. return self.apply_raw()
  116. return self.apply_standard()
  117. def apply_empty_result(self):
  118. """
  119. we have an empty result; at least 1 axis is 0
  120. we will try to apply the function to an empty
  121. series in order to see if this is a reduction function
  122. """
  123. # we are not asked to reduce or infer reduction
  124. # so just return a copy of the existing object
  125. if self.result_type not in ['reduce', None]:
  126. return self.obj.copy()
  127. # we may need to infer
  128. reduce = self.result_type == 'reduce'
  129. from pandas import Series
  130. if not reduce:
  131. EMPTY_SERIES = Series([])
  132. try:
  133. r = self.f(EMPTY_SERIES, *self.args, **self.kwds)
  134. reduce = not isinstance(r, Series)
  135. except Exception:
  136. pass
  137. if reduce:
  138. return self.obj._constructor_sliced(np.nan, index=self.agg_axis)
  139. else:
  140. return self.obj.copy()
  141. def apply_raw(self):
  142. """ apply to the values as a numpy array """
  143. try:
  144. result = reduction.reduce(self.values, self.f, axis=self.axis)
  145. except Exception:
  146. result = np.apply_along_axis(self.f, self.axis, self.values)
  147. # TODO: mixed type case
  148. if result.ndim == 2:
  149. return self.obj._constructor(result,
  150. index=self.index,
  151. columns=self.columns)
  152. else:
  153. return self.obj._constructor_sliced(result,
  154. index=self.agg_axis)
  155. def apply_broadcast(self, target):
  156. result_values = np.empty_like(target.values)
  157. # axis which we want to compare compliance
  158. result_compare = target.shape[0]
  159. for i, col in enumerate(target.columns):
  160. res = self.f(target[col])
  161. ares = np.asarray(res).ndim
  162. # must be a scalar or 1d
  163. if ares > 1:
  164. raise ValueError("too many dims to broadcast")
  165. elif ares == 1:
  166. # must match return dim
  167. if result_compare != len(res):
  168. raise ValueError("cannot broadcast result")
  169. result_values[:, i] = res
  170. # we *always* preserve the original index / columns
  171. result = self.obj._constructor(result_values,
  172. index=target.index,
  173. columns=target.columns)
  174. return result
  175. def apply_standard(self):
  176. # try to reduce first (by default)
  177. # this only matters if the reduction in values is of different dtype
  178. # e.g. if we want to apply to a SparseFrame, then can't directly reduce
  179. # we cannot reduce using non-numpy dtypes,
  180. # as demonstrated in gh-12244
  181. if (self.result_type in ['reduce', None] and
  182. not self.dtypes.apply(is_extension_type).any()):
  183. # Create a dummy Series from an empty array
  184. from pandas import Series
  185. values = self.values
  186. index = self.obj._get_axis(self.axis)
  187. labels = self.agg_axis
  188. empty_arr = np.empty(len(index), dtype=values.dtype)
  189. dummy = Series(empty_arr, index=index, dtype=values.dtype)
  190. try:
  191. result = reduction.reduce(values, self.f,
  192. axis=self.axis,
  193. dummy=dummy,
  194. labels=labels)
  195. return self.obj._constructor_sliced(result, index=labels)
  196. except Exception:
  197. pass
  198. # compute the result using the series generator
  199. self.apply_series_generator()
  200. # wrap results
  201. return self.wrap_results()
  202. def apply_series_generator(self):
  203. series_gen = self.series_generator
  204. res_index = self.result_index
  205. i = None
  206. keys = []
  207. results = {}
  208. if self.ignore_failures:
  209. successes = []
  210. for i, v in enumerate(series_gen):
  211. try:
  212. results[i] = self.f(v)
  213. keys.append(v.name)
  214. successes.append(i)
  215. except Exception:
  216. pass
  217. # so will work with MultiIndex
  218. if len(successes) < len(res_index):
  219. res_index = res_index.take(successes)
  220. else:
  221. try:
  222. for i, v in enumerate(series_gen):
  223. results[i] = self.f(v)
  224. keys.append(v.name)
  225. except Exception as e:
  226. if hasattr(e, 'args'):
  227. # make sure i is defined
  228. if i is not None:
  229. k = res_index[i]
  230. e.args = e.args + ('occurred at index %s' %
  231. pprint_thing(k), )
  232. raise
  233. self.results = results
  234. self.res_index = res_index
  235. self.res_columns = self.result_columns
  236. def wrap_results(self):
  237. results = self.results
  238. # see if we can infer the results
  239. if len(results) > 0 and is_sequence(results[0]):
  240. return self.wrap_results_for_axis()
  241. # dict of scalars
  242. result = self.obj._constructor_sliced(results)
  243. result.index = self.res_index
  244. return result
  245. class FrameRowApply(FrameApply):
  246. axis = 0
  247. def apply_broadcast(self):
  248. return super(FrameRowApply, self).apply_broadcast(self.obj)
  249. @property
  250. def series_generator(self):
  251. return (self.obj._ixs(i, axis=1)
  252. for i in range(len(self.columns)))
  253. @property
  254. def result_index(self):
  255. return self.columns
  256. @property
  257. def result_columns(self):
  258. return self.index
  259. def wrap_results_for_axis(self):
  260. """ return the results for the rows """
  261. results = self.results
  262. result = self.obj._constructor(data=results)
  263. if not isinstance(results[0], ABCSeries):
  264. try:
  265. result.index = self.res_columns
  266. except ValueError:
  267. pass
  268. try:
  269. result.columns = self.res_index
  270. except ValueError:
  271. pass
  272. return result
  273. class FrameColumnApply(FrameApply):
  274. axis = 1
  275. def apply_broadcast(self):
  276. result = super(FrameColumnApply, self).apply_broadcast(self.obj.T)
  277. return result.T
  278. @property
  279. def series_generator(self):
  280. constructor = self.obj._constructor_sliced
  281. return (constructor(arr, index=self.columns, name=name)
  282. for i, (arr, name) in enumerate(zip(self.values,
  283. self.index)))
  284. @property
  285. def result_index(self):
  286. return self.index
  287. @property
  288. def result_columns(self):
  289. return self.columns
  290. def wrap_results_for_axis(self):
  291. """ return the results for the columns """
  292. results = self.results
  293. # we have requested to expand
  294. if self.result_type == 'expand':
  295. result = self.infer_to_same_shape()
  296. # we have a non-series and don't want inference
  297. elif not isinstance(results[0], ABCSeries):
  298. from pandas import Series
  299. result = Series(results)
  300. result.index = self.res_index
  301. # we may want to infer results
  302. else:
  303. result = self.infer_to_same_shape()
  304. return result
  305. def infer_to_same_shape(self):
  306. """ infer the results to the same shape as the input object """
  307. results = self.results
  308. result = self.obj._constructor(data=results)
  309. result = result.T
  310. # set the index
  311. result.index = self.res_index
  312. # infer dtypes
  313. result = result.infer_objects()
  314. return result