missing.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748
  1. """
  2. Routines for filling missing data
  3. """
  4. from distutils.version import LooseVersion
  5. import operator
  6. import numpy as np
  7. from pandas._libs import algos, lib
  8. from pandas.compat import range, string_types
  9. from pandas.core.dtypes.cast import infer_dtype_from_array
  10. from pandas.core.dtypes.common import (
  11. ensure_float64, is_datetime64_dtype, is_datetime64tz_dtype, is_float_dtype,
  12. is_integer, is_integer_dtype, is_numeric_v_string_like, is_scalar,
  13. is_timedelta64_dtype, needs_i8_conversion)
  14. from pandas.core.dtypes.missing import isna
  15. def mask_missing(arr, values_to_mask):
  16. """
  17. Return a masking array of same size/shape as arr
  18. with entries equaling any member of values_to_mask set to True
  19. """
  20. dtype, values_to_mask = infer_dtype_from_array(values_to_mask)
  21. try:
  22. values_to_mask = np.array(values_to_mask, dtype=dtype)
  23. except Exception:
  24. values_to_mask = np.array(values_to_mask, dtype=object)
  25. na_mask = isna(values_to_mask)
  26. nonna = values_to_mask[~na_mask]
  27. mask = None
  28. for x in nonna:
  29. if mask is None:
  30. # numpy elementwise comparison warning
  31. if is_numeric_v_string_like(arr, x):
  32. mask = False
  33. else:
  34. mask = arr == x
  35. # if x is a string and arr is not, then we get False and we must
  36. # expand the mask to size arr.shape
  37. if is_scalar(mask):
  38. mask = np.zeros(arr.shape, dtype=bool)
  39. else:
  40. # numpy elementwise comparison warning
  41. if is_numeric_v_string_like(arr, x):
  42. mask |= False
  43. else:
  44. mask |= arr == x
  45. if na_mask.any():
  46. if mask is None:
  47. mask = isna(arr)
  48. else:
  49. mask |= isna(arr)
  50. # GH 21977
  51. if mask is None:
  52. mask = np.zeros(arr.shape, dtype=bool)
  53. return mask
  54. def clean_fill_method(method, allow_nearest=False):
  55. # asfreq is compat for resampling
  56. if method in [None, 'asfreq']:
  57. return None
  58. if isinstance(method, string_types):
  59. method = method.lower()
  60. if method == 'ffill':
  61. method = 'pad'
  62. elif method == 'bfill':
  63. method = 'backfill'
  64. valid_methods = ['pad', 'backfill']
  65. expecting = 'pad (ffill) or backfill (bfill)'
  66. if allow_nearest:
  67. valid_methods.append('nearest')
  68. expecting = 'pad (ffill), backfill (bfill) or nearest'
  69. if method not in valid_methods:
  70. msg = ('Invalid fill method. Expecting {expecting}. Got {method}'
  71. .format(expecting=expecting, method=method))
  72. raise ValueError(msg)
  73. return method
  74. def clean_interp_method(method, **kwargs):
  75. order = kwargs.get('order')
  76. valid = ['linear', 'time', 'index', 'values', 'nearest', 'zero', 'slinear',
  77. 'quadratic', 'cubic', 'barycentric', 'polynomial', 'krogh',
  78. 'piecewise_polynomial', 'pchip', 'akima', 'spline',
  79. 'from_derivatives']
  80. if method in ('spline', 'polynomial') and order is None:
  81. raise ValueError("You must specify the order of the spline or "
  82. "polynomial.")
  83. if method not in valid:
  84. raise ValueError("method must be one of {valid}. Got '{method}' "
  85. "instead.".format(valid=valid, method=method))
  86. return method
  87. def interpolate_1d(xvalues, yvalues, method='linear', limit=None,
  88. limit_direction='forward', limit_area=None, fill_value=None,
  89. bounds_error=False, order=None, **kwargs):
  90. """
  91. Logic for the 1-d interpolation. The result should be 1-d, inputs
  92. xvalues and yvalues will each be 1-d arrays of the same length.
  93. Bounds_error is currently hardcoded to False since non-scipy ones don't
  94. take it as an argumnet.
  95. """
  96. # Treat the original, non-scipy methods first.
  97. invalid = isna(yvalues)
  98. valid = ~invalid
  99. if not valid.any():
  100. # have to call np.asarray(xvalues) since xvalues could be an Index
  101. # which can't be mutated
  102. result = np.empty_like(np.asarray(xvalues), dtype=np.float64)
  103. result.fill(np.nan)
  104. return result
  105. if valid.all():
  106. return yvalues
  107. if method == 'time':
  108. if not getattr(xvalues, 'is_all_dates', None):
  109. # if not issubclass(xvalues.dtype.type, np.datetime64):
  110. raise ValueError('time-weighted interpolation only works '
  111. 'on Series or DataFrames with a '
  112. 'DatetimeIndex')
  113. method = 'values'
  114. valid_limit_directions = ['forward', 'backward', 'both']
  115. limit_direction = limit_direction.lower()
  116. if limit_direction not in valid_limit_directions:
  117. msg = ('Invalid limit_direction: expecting one of {valid!r}, '
  118. 'got {invalid!r}.')
  119. raise ValueError(msg.format(valid=valid_limit_directions,
  120. invalid=limit_direction))
  121. if limit_area is not None:
  122. valid_limit_areas = ['inside', 'outside']
  123. limit_area = limit_area.lower()
  124. if limit_area not in valid_limit_areas:
  125. raise ValueError('Invalid limit_area: expecting one of {}, got '
  126. '{}.'.format(valid_limit_areas, limit_area))
  127. # default limit is unlimited GH #16282
  128. if limit is None:
  129. # limit = len(xvalues)
  130. pass
  131. elif not is_integer(limit):
  132. raise ValueError('Limit must be an integer')
  133. elif limit < 1:
  134. raise ValueError('Limit must be greater than 0')
  135. from pandas import Series
  136. ys = Series(yvalues)
  137. # These are sets of index pointers to invalid values... i.e. {0, 1, etc...
  138. all_nans = set(np.flatnonzero(invalid))
  139. start_nans = set(range(ys.first_valid_index()))
  140. end_nans = set(range(1 + ys.last_valid_index(), len(valid)))
  141. mid_nans = all_nans - start_nans - end_nans
  142. # Like the sets above, preserve_nans contains indices of invalid values,
  143. # but in this case, it is the final set of indices that need to be
  144. # preserved as NaN after the interpolation.
  145. # For example if limit_direction='forward' then preserve_nans will
  146. # contain indices of NaNs at the beginning of the series, and NaNs that
  147. # are more than'limit' away from the prior non-NaN.
  148. # set preserve_nans based on direction using _interp_limit
  149. if limit_direction == 'forward':
  150. preserve_nans = start_nans | set(_interp_limit(invalid, limit, 0))
  151. elif limit_direction == 'backward':
  152. preserve_nans = end_nans | set(_interp_limit(invalid, 0, limit))
  153. else:
  154. # both directions... just use _interp_limit
  155. preserve_nans = set(_interp_limit(invalid, limit, limit))
  156. # if limit_area is set, add either mid or outside indices
  157. # to preserve_nans GH #16284
  158. if limit_area == 'inside':
  159. # preserve NaNs on the outside
  160. preserve_nans |= start_nans | end_nans
  161. elif limit_area == 'outside':
  162. # preserve NaNs on the inside
  163. preserve_nans |= mid_nans
  164. # sort preserve_nans and covert to list
  165. preserve_nans = sorted(preserve_nans)
  166. xvalues = getattr(xvalues, 'values', xvalues)
  167. yvalues = getattr(yvalues, 'values', yvalues)
  168. result = yvalues.copy()
  169. if method in ['linear', 'time', 'index', 'values']:
  170. if method in ('values', 'index'):
  171. inds = np.asarray(xvalues)
  172. # hack for DatetimeIndex, #1646
  173. if needs_i8_conversion(inds.dtype.type):
  174. inds = inds.view(np.int64)
  175. if inds.dtype == np.object_:
  176. inds = lib.maybe_convert_objects(inds)
  177. else:
  178. inds = xvalues
  179. result[invalid] = np.interp(inds[invalid], inds[valid], yvalues[valid])
  180. result[preserve_nans] = np.nan
  181. return result
  182. sp_methods = ['nearest', 'zero', 'slinear', 'quadratic', 'cubic',
  183. 'barycentric', 'krogh', 'spline', 'polynomial',
  184. 'from_derivatives', 'piecewise_polynomial', 'pchip', 'akima']
  185. if method in sp_methods:
  186. inds = np.asarray(xvalues)
  187. # hack for DatetimeIndex, #1646
  188. if issubclass(inds.dtype.type, np.datetime64):
  189. inds = inds.view(np.int64)
  190. result[invalid] = _interpolate_scipy_wrapper(inds[valid],
  191. yvalues[valid],
  192. inds[invalid],
  193. method=method,
  194. fill_value=fill_value,
  195. bounds_error=bounds_error,
  196. order=order, **kwargs)
  197. result[preserve_nans] = np.nan
  198. return result
  199. def _interpolate_scipy_wrapper(x, y, new_x, method, fill_value=None,
  200. bounds_error=False, order=None, **kwargs):
  201. """
  202. passed off to scipy.interpolate.interp1d. method is scipy's kind.
  203. Returns an array interpolated at new_x. Add any new methods to
  204. the list in _clean_interp_method
  205. """
  206. try:
  207. from scipy import interpolate
  208. # TODO: Why is DatetimeIndex being imported here?
  209. from pandas import DatetimeIndex # noqa
  210. except ImportError:
  211. raise ImportError('{method} interpolation requires SciPy'
  212. .format(method=method))
  213. new_x = np.asarray(new_x)
  214. # ignores some kwargs that could be passed along.
  215. alt_methods = {
  216. 'barycentric': interpolate.barycentric_interpolate,
  217. 'krogh': interpolate.krogh_interpolate,
  218. 'from_derivatives': _from_derivatives,
  219. 'piecewise_polynomial': _from_derivatives,
  220. }
  221. if getattr(x, 'is_all_dates', False):
  222. # GH 5975, scipy.interp1d can't hande datetime64s
  223. x, new_x = x._values.astype('i8'), new_x.astype('i8')
  224. if method == 'pchip':
  225. try:
  226. alt_methods['pchip'] = interpolate.pchip_interpolate
  227. except AttributeError:
  228. raise ImportError("Your version of Scipy does not support "
  229. "PCHIP interpolation.")
  230. elif method == 'akima':
  231. try:
  232. from scipy.interpolate import Akima1DInterpolator # noqa
  233. alt_methods['akima'] = _akima_interpolate
  234. except ImportError:
  235. raise ImportError("Your version of Scipy does not support "
  236. "Akima interpolation.")
  237. interp1d_methods = ['nearest', 'zero', 'slinear', 'quadratic', 'cubic',
  238. 'polynomial']
  239. if method in interp1d_methods:
  240. if method == 'polynomial':
  241. method = order
  242. terp = interpolate.interp1d(x, y, kind=method, fill_value=fill_value,
  243. bounds_error=bounds_error)
  244. new_y = terp(new_x)
  245. elif method == 'spline':
  246. # GH #10633
  247. if not order:
  248. raise ValueError("order needs to be specified and greater than 0")
  249. terp = interpolate.UnivariateSpline(x, y, k=order, **kwargs)
  250. new_y = terp(new_x)
  251. else:
  252. # GH 7295: need to be able to write for some reason
  253. # in some circumstances: check all three
  254. if not x.flags.writeable:
  255. x = x.copy()
  256. if not y.flags.writeable:
  257. y = y.copy()
  258. if not new_x.flags.writeable:
  259. new_x = new_x.copy()
  260. method = alt_methods[method]
  261. new_y = method(x, y, new_x, **kwargs)
  262. return new_y
  263. def _from_derivatives(xi, yi, x, order=None, der=0, extrapolate=False):
  264. """
  265. Convenience function for interpolate.BPoly.from_derivatives
  266. Construct a piecewise polynomial in the Bernstein basis, compatible
  267. with the specified values and derivatives at breakpoints.
  268. Parameters
  269. ----------
  270. xi : array_like
  271. sorted 1D array of x-coordinates
  272. yi : array_like or list of array-likes
  273. yi[i][j] is the j-th derivative known at xi[i]
  274. orders : None or int or array_like of ints. Default: None.
  275. Specifies the degree of local polynomials. If not None, some
  276. derivatives are ignored.
  277. der : int or list
  278. How many derivatives to extract; None for all potentially nonzero
  279. derivatives (that is a number equal to the number of points), or a
  280. list of derivatives to extract. This numberincludes the function
  281. value as 0th derivative.
  282. extrapolate : bool, optional
  283. Whether to extrapolate to ouf-of-bounds points based on first and last
  284. intervals, or to return NaNs. Default: True.
  285. See Also
  286. --------
  287. scipy.interpolate.BPoly.from_derivatives
  288. Returns
  289. -------
  290. y : scalar or array_like
  291. The result, of length R or length M or M by R,
  292. """
  293. import scipy
  294. from scipy import interpolate
  295. if LooseVersion(scipy.__version__) < LooseVersion('0.18.0'):
  296. try:
  297. method = interpolate.piecewise_polynomial_interpolate
  298. return method(xi, yi.reshape(-1, 1), x,
  299. orders=order, der=der)
  300. except AttributeError:
  301. pass
  302. # return the method for compat with scipy version & backwards compat
  303. method = interpolate.BPoly.from_derivatives
  304. m = method(xi, yi.reshape(-1, 1),
  305. orders=order, extrapolate=extrapolate)
  306. return m(x)
  307. def _akima_interpolate(xi, yi, x, der=0, axis=0):
  308. """
  309. Convenience function for akima interpolation.
  310. xi and yi are arrays of values used to approximate some function f,
  311. with ``yi = f(xi)``.
  312. See `Akima1DInterpolator` for details.
  313. Parameters
  314. ----------
  315. xi : array_like
  316. A sorted list of x-coordinates, of length N.
  317. yi : array_like
  318. A 1-D array of real values. `yi`'s length along the interpolation
  319. axis must be equal to the length of `xi`. If N-D array, use axis
  320. parameter to select correct axis.
  321. x : scalar or array_like
  322. Of length M.
  323. der : int or list, optional
  324. How many derivatives to extract; None for all potentially
  325. nonzero derivatives (that is a number equal to the number
  326. of points), or a list of derivatives to extract. This number
  327. includes the function value as 0th derivative.
  328. axis : int, optional
  329. Axis in the yi array corresponding to the x-coordinate values.
  330. See Also
  331. --------
  332. scipy.interpolate.Akima1DInterpolator
  333. Returns
  334. -------
  335. y : scalar or array_like
  336. The result, of length R or length M or M by R,
  337. """
  338. from scipy import interpolate
  339. try:
  340. P = interpolate.Akima1DInterpolator(xi, yi, axis=axis)
  341. except TypeError:
  342. # Scipy earlier than 0.17.0 missing axis
  343. P = interpolate.Akima1DInterpolator(xi, yi)
  344. if der == 0:
  345. return P(x)
  346. elif interpolate._isscalar(der):
  347. return P(x, der=der)
  348. else:
  349. return [P(x, nu) for nu in der]
  350. def interpolate_2d(values, method='pad', axis=0, limit=None, fill_value=None,
  351. dtype=None):
  352. """ perform an actual interpolation of values, values will be make 2-d if
  353. needed fills inplace, returns the result
  354. """
  355. transf = (lambda x: x) if axis == 0 else (lambda x: x.T)
  356. # reshape a 1 dim if needed
  357. ndim = values.ndim
  358. if values.ndim == 1:
  359. if axis != 0: # pragma: no cover
  360. raise AssertionError("cannot interpolate on a ndim == 1 with "
  361. "axis != 0")
  362. values = values.reshape(tuple((1,) + values.shape))
  363. if fill_value is None:
  364. mask = None
  365. else: # todo create faster fill func without masking
  366. mask = mask_missing(transf(values), fill_value)
  367. method = clean_fill_method(method)
  368. if method == 'pad':
  369. values = transf(pad_2d(
  370. transf(values), limit=limit, mask=mask, dtype=dtype))
  371. else:
  372. values = transf(backfill_2d(
  373. transf(values), limit=limit, mask=mask, dtype=dtype))
  374. # reshape back
  375. if ndim == 1:
  376. values = values[0]
  377. return values
  378. def _cast_values_for_fillna(values, dtype):
  379. """
  380. Cast values to a dtype that algos.pad and algos.backfill can handle.
  381. """
  382. # TODO: for int-dtypes we make a copy, but for everything else this
  383. # alters the values in-place. Is this intentional?
  384. if (is_datetime64_dtype(dtype) or is_datetime64tz_dtype(dtype) or
  385. is_timedelta64_dtype(dtype)):
  386. values = values.view(np.int64)
  387. elif is_integer_dtype(values):
  388. # NB: this check needs to come after the datetime64 check above
  389. values = ensure_float64(values)
  390. return values
  391. def _fillna_prep(values, mask=None, dtype=None):
  392. # boilerplate for pad_1d, backfill_1d, pad_2d, backfill_2d
  393. if dtype is None:
  394. dtype = values.dtype
  395. if mask is None:
  396. # This needs to occur before datetime/timedeltas are cast to int64
  397. mask = isna(values)
  398. values = _cast_values_for_fillna(values, dtype)
  399. mask = mask.view(np.uint8)
  400. return values, mask
  401. def pad_1d(values, limit=None, mask=None, dtype=None):
  402. values, mask = _fillna_prep(values, mask, dtype)
  403. algos.pad_inplace(values, mask, limit=limit)
  404. return values
  405. def backfill_1d(values, limit=None, mask=None, dtype=None):
  406. values, mask = _fillna_prep(values, mask, dtype)
  407. algos.backfill_inplace(values, mask, limit=limit)
  408. return values
  409. def pad_2d(values, limit=None, mask=None, dtype=None):
  410. values, mask = _fillna_prep(values, mask, dtype)
  411. if np.all(values.shape):
  412. algos.pad_2d_inplace(values, mask, limit=limit)
  413. else:
  414. # for test coverage
  415. pass
  416. return values
  417. def backfill_2d(values, limit=None, mask=None, dtype=None):
  418. values, mask = _fillna_prep(values, mask, dtype)
  419. if np.all(values.shape):
  420. algos.backfill_2d_inplace(values, mask, limit=limit)
  421. else:
  422. # for test coverage
  423. pass
  424. return values
  425. _fill_methods = {'pad': pad_1d, 'backfill': backfill_1d}
  426. def get_fill_func(method):
  427. method = clean_fill_method(method)
  428. return _fill_methods[method]
  429. def clean_reindex_fill_method(method):
  430. return clean_fill_method(method, allow_nearest=True)
  431. def fill_zeros(result, x, y, name, fill):
  432. """
  433. if this is a reversed op, then flip x,y
  434. if we have an integer value (or array in y)
  435. and we have 0's, fill them with the fill,
  436. return the result
  437. mask the nan's from x
  438. """
  439. if fill is None or is_float_dtype(result):
  440. return result
  441. if name.startswith(('r', '__r')):
  442. x, y = y, x
  443. is_variable_type = (hasattr(y, 'dtype') or hasattr(y, 'type'))
  444. is_scalar_type = is_scalar(y)
  445. if not is_variable_type and not is_scalar_type:
  446. return result
  447. if is_scalar_type:
  448. y = np.array(y)
  449. if is_integer_dtype(y):
  450. if (y == 0).any():
  451. # GH 7325, mask and nans must be broadcastable (also: PR 9308)
  452. # Raveling and then reshaping makes np.putmask faster
  453. mask = ((y == 0) & ~np.isnan(result)).ravel()
  454. shape = result.shape
  455. result = result.astype('float64', copy=False).ravel()
  456. np.putmask(result, mask, fill)
  457. # if we have a fill of inf, then sign it correctly
  458. # (GH 6178 and PR 9308)
  459. if np.isinf(fill):
  460. signs = y if name.startswith(('r', '__r')) else x
  461. signs = np.sign(signs.astype('float', copy=False))
  462. negative_inf_mask = (signs.ravel() < 0) & mask
  463. np.putmask(result, negative_inf_mask, -fill)
  464. if "floordiv" in name: # (PR 9308)
  465. nan_mask = ((y == 0) & (x == 0)).ravel()
  466. np.putmask(result, nan_mask, np.nan)
  467. result = result.reshape(shape)
  468. return result
  469. def mask_zero_div_zero(x, y, result, copy=False):
  470. """
  471. Set results of 0 / 0 or 0 // 0 to np.nan, regardless of the dtypes
  472. of the numerator or the denominator.
  473. Parameters
  474. ----------
  475. x : ndarray
  476. y : ndarray
  477. result : ndarray
  478. copy : bool (default False)
  479. Whether to always create a new array or try to fill in the existing
  480. array if possible.
  481. Returns
  482. -------
  483. filled_result : ndarray
  484. Examples
  485. --------
  486. >>> x = np.array([1, 0, -1], dtype=np.int64)
  487. >>> y = 0 # int 0; numpy behavior is different with float
  488. >>> result = x / y
  489. >>> result # raw numpy result does not fill division by zero
  490. array([0, 0, 0])
  491. >>> mask_zero_div_zero(x, y, result)
  492. array([ inf, nan, -inf])
  493. """
  494. if is_scalar(y):
  495. y = np.array(y)
  496. zmask = y == 0
  497. if zmask.any():
  498. shape = result.shape
  499. nan_mask = (zmask & (x == 0)).ravel()
  500. neginf_mask = (zmask & (x < 0)).ravel()
  501. posinf_mask = (zmask & (x > 0)).ravel()
  502. if nan_mask.any() or neginf_mask.any() or posinf_mask.any():
  503. # Fill negative/0 with -inf, positive/0 with +inf, 0/0 with NaN
  504. result = result.astype('float64', copy=copy).ravel()
  505. np.putmask(result, nan_mask, np.nan)
  506. np.putmask(result, posinf_mask, np.inf)
  507. np.putmask(result, neginf_mask, -np.inf)
  508. result = result.reshape(shape)
  509. return result
  510. def dispatch_missing(op, left, right, result):
  511. """
  512. Fill nulls caused by division by zero, casting to a diffferent dtype
  513. if necessary.
  514. Parameters
  515. ----------
  516. op : function (operator.add, operator.div, ...)
  517. left : object (Index for non-reversed ops)
  518. right : object (Index fof reversed ops)
  519. result : ndarray
  520. Returns
  521. -------
  522. result : ndarray
  523. """
  524. opstr = '__{opname}__'.format(opname=op.__name__).replace('____', '__')
  525. if op in [operator.truediv, operator.floordiv,
  526. getattr(operator, 'div', None)]:
  527. result = mask_zero_div_zero(left, right, result)
  528. elif op is operator.mod:
  529. result = fill_zeros(result, left, right, opstr, np.nan)
  530. elif op is divmod:
  531. res0 = mask_zero_div_zero(left, right, result[0])
  532. res1 = fill_zeros(result[1], left, right, opstr, np.nan)
  533. result = (res0, res1)
  534. return result
  535. def _interp_limit(invalid, fw_limit, bw_limit):
  536. """
  537. Get indexers of values that won't be filled
  538. because they exceed the limits.
  539. Parameters
  540. ----------
  541. invalid : boolean ndarray
  542. fw_limit : int or None
  543. forward limit to index
  544. bw_limit : int or None
  545. backward limit to index
  546. Returns
  547. -------
  548. set of indexers
  549. Notes
  550. -----
  551. This is equivalent to the more readable, but slower
  552. .. code-block:: python
  553. def _interp_limit(invalid, fw_limit, bw_limit):
  554. for x in np.where(invalid)[0]:
  555. if invalid[max(0, x - fw_limit):x + bw_limit + 1].all():
  556. yield x
  557. """
  558. # handle forward first; the backward direction is the same except
  559. # 1. operate on the reversed array
  560. # 2. subtract the returned indices from N - 1
  561. N = len(invalid)
  562. f_idx = set()
  563. b_idx = set()
  564. def inner(invalid, limit):
  565. limit = min(limit, N)
  566. windowed = _rolling_window(invalid, limit + 1).all(1)
  567. idx = (set(np.where(windowed)[0] + limit) |
  568. set(np.where((~invalid[:limit + 1]).cumsum() == 0)[0]))
  569. return idx
  570. if fw_limit is not None:
  571. if fw_limit == 0:
  572. f_idx = set(np.where(invalid)[0])
  573. else:
  574. f_idx = inner(invalid, fw_limit)
  575. if bw_limit is not None:
  576. if bw_limit == 0:
  577. # then we don't even need to care about backwards
  578. # just use forwards
  579. return f_idx
  580. else:
  581. b_idx = list(inner(invalid[::-1], bw_limit))
  582. b_idx = set(N - 1 - np.asarray(b_idx))
  583. if fw_limit == 0:
  584. return b_idx
  585. return f_idx & b_idx
  586. def _rolling_window(a, window):
  587. """
  588. [True, True, False, True, False], 2 ->
  589. [
  590. [True, True],
  591. [True, False],
  592. [False, True],
  593. [True, False],
  594. ]
  595. """
  596. # https://stackoverflow.com/a/6811241
  597. shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)
  598. strides = a.strides + (a.strides[-1],)
  599. return np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)