test_boolean.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634
  1. # coding=utf-8
  2. # pylint: disable-msg=E1101,W0612
  3. import numpy as np
  4. import pytest
  5. from pandas.compat import lrange, range
  6. from pandas.core.dtypes.common import is_integer
  7. import pandas as pd
  8. from pandas import Index, Series, Timestamp, date_range, isna
  9. from pandas.core.indexing import IndexingError
  10. import pandas.util.testing as tm
  11. from pandas.util.testing import assert_series_equal
  12. from pandas.tseries.offsets import BDay
  13. def test_getitem_boolean(test_data):
  14. s = test_data.series
  15. mask = s > s.median()
  16. # passing list is OK
  17. result = s[list(mask)]
  18. expected = s[mask]
  19. assert_series_equal(result, expected)
  20. tm.assert_index_equal(result.index, s.index[mask])
  21. def test_getitem_boolean_empty():
  22. s = Series([], dtype=np.int64)
  23. s.index.name = 'index_name'
  24. s = s[s.isna()]
  25. assert s.index.name == 'index_name'
  26. assert s.dtype == np.int64
  27. # GH5877
  28. # indexing with empty series
  29. s = Series(['A', 'B'])
  30. expected = Series(np.nan, index=['C'], dtype=object)
  31. result = s[Series(['C'], dtype=object)]
  32. assert_series_equal(result, expected)
  33. s = Series(['A', 'B'])
  34. expected = Series(dtype=object, index=Index([], dtype='int64'))
  35. result = s[Series([], dtype=object)]
  36. assert_series_equal(result, expected)
  37. # invalid because of the boolean indexer
  38. # that's empty or not-aligned
  39. msg = (r"Unalignable boolean Series provided as indexer \(index of"
  40. r" the boolean Series and of the indexed object do not match")
  41. with pytest.raises(IndexingError, match=msg):
  42. s[Series([], dtype=bool)]
  43. with pytest.raises(IndexingError, match=msg):
  44. s[Series([True], dtype=bool)]
  45. def test_getitem_boolean_object(test_data):
  46. # using column from DataFrame
  47. s = test_data.series
  48. mask = s > s.median()
  49. omask = mask.astype(object)
  50. # getitem
  51. result = s[omask]
  52. expected = s[mask]
  53. assert_series_equal(result, expected)
  54. # setitem
  55. s2 = s.copy()
  56. cop = s.copy()
  57. cop[omask] = 5
  58. s2[mask] = 5
  59. assert_series_equal(cop, s2)
  60. # nans raise exception
  61. omask[5:10] = np.nan
  62. msg = "cannot index with vector containing NA / NaN values"
  63. with pytest.raises(ValueError, match=msg):
  64. s[omask]
  65. with pytest.raises(ValueError, match=msg):
  66. s[omask] = 5
  67. def test_getitem_setitem_boolean_corner(test_data):
  68. ts = test_data.ts
  69. mask_shifted = ts.shift(1, freq=BDay()) > ts.median()
  70. # these used to raise...??
  71. msg = (r"Unalignable boolean Series provided as indexer \(index of"
  72. r" the boolean Series and of the indexed object do not match")
  73. with pytest.raises(IndexingError, match=msg):
  74. ts[mask_shifted]
  75. with pytest.raises(IndexingError, match=msg):
  76. ts[mask_shifted] = 1
  77. with pytest.raises(IndexingError, match=msg):
  78. ts.loc[mask_shifted]
  79. with pytest.raises(IndexingError, match=msg):
  80. ts.loc[mask_shifted] = 1
  81. def test_setitem_boolean(test_data):
  82. mask = test_data.series > test_data.series.median()
  83. # similar indexed series
  84. result = test_data.series.copy()
  85. result[mask] = test_data.series * 2
  86. expected = test_data.series * 2
  87. assert_series_equal(result[mask], expected[mask])
  88. # needs alignment
  89. result = test_data.series.copy()
  90. result[mask] = (test_data.series * 2)[0:5]
  91. expected = (test_data.series * 2)[0:5].reindex_like(test_data.series)
  92. expected[-mask] = test_data.series[mask]
  93. assert_series_equal(result[mask], expected[mask])
  94. def test_get_set_boolean_different_order(test_data):
  95. ordered = test_data.series.sort_values()
  96. # setting
  97. copy = test_data.series.copy()
  98. copy[ordered > 0] = 0
  99. expected = test_data.series.copy()
  100. expected[expected > 0] = 0
  101. assert_series_equal(copy, expected)
  102. # getting
  103. sel = test_data.series[ordered > 0]
  104. exp = test_data.series[test_data.series > 0]
  105. assert_series_equal(sel, exp)
  106. def test_where_unsafe_int(sint_dtype):
  107. s = Series(np.arange(10), dtype=sint_dtype)
  108. mask = s < 5
  109. s[mask] = lrange(2, 7)
  110. expected = Series(lrange(2, 7) + lrange(5, 10), dtype=sint_dtype)
  111. assert_series_equal(s, expected)
  112. def test_where_unsafe_float(float_dtype):
  113. s = Series(np.arange(10), dtype=float_dtype)
  114. mask = s < 5
  115. s[mask] = lrange(2, 7)
  116. expected = Series(lrange(2, 7) + lrange(5, 10), dtype=float_dtype)
  117. assert_series_equal(s, expected)
  118. @pytest.mark.parametrize("dtype,expected_dtype", [
  119. (np.int8, np.float64),
  120. (np.int16, np.float64),
  121. (np.int32, np.float64),
  122. (np.int64, np.float64),
  123. (np.float32, np.float32),
  124. (np.float64, np.float64)
  125. ])
  126. def test_where_unsafe_upcast(dtype, expected_dtype):
  127. # see gh-9743
  128. s = Series(np.arange(10), dtype=dtype)
  129. values = [2.5, 3.5, 4.5, 5.5, 6.5]
  130. mask = s < 5
  131. expected = Series(values + lrange(5, 10), dtype=expected_dtype)
  132. s[mask] = values
  133. assert_series_equal(s, expected)
  134. def test_where_unsafe():
  135. # see gh-9731
  136. s = Series(np.arange(10), dtype="int64")
  137. values = [2.5, 3.5, 4.5, 5.5]
  138. mask = s > 5
  139. expected = Series(lrange(6) + values, dtype="float64")
  140. s[mask] = values
  141. assert_series_equal(s, expected)
  142. # see gh-3235
  143. s = Series(np.arange(10), dtype='int64')
  144. mask = s < 5
  145. s[mask] = lrange(2, 7)
  146. expected = Series(lrange(2, 7) + lrange(5, 10), dtype='int64')
  147. assert_series_equal(s, expected)
  148. assert s.dtype == expected.dtype
  149. s = Series(np.arange(10), dtype='int64')
  150. mask = s > 5
  151. s[mask] = [0] * 4
  152. expected = Series([0, 1, 2, 3, 4, 5] + [0] * 4, dtype='int64')
  153. assert_series_equal(s, expected)
  154. s = Series(np.arange(10))
  155. mask = s > 5
  156. msg = "cannot assign mismatch length to masked array"
  157. with pytest.raises(ValueError, match=msg):
  158. s[mask] = [5, 4, 3, 2, 1]
  159. with pytest.raises(ValueError, match=msg):
  160. s[mask] = [0] * 5
  161. # dtype changes
  162. s = Series([1, 2, 3, 4])
  163. result = s.where(s > 2, np.nan)
  164. expected = Series([np.nan, np.nan, 3, 4])
  165. assert_series_equal(result, expected)
  166. # GH 4667
  167. # setting with None changes dtype
  168. s = Series(range(10)).astype(float)
  169. s[8] = None
  170. result = s[8]
  171. assert isna(result)
  172. s = Series(range(10)).astype(float)
  173. s[s > 8] = None
  174. result = s[isna(s)]
  175. expected = Series(np.nan, index=[9])
  176. assert_series_equal(result, expected)
  177. def test_where_raise_on_error_deprecation():
  178. # gh-14968
  179. # deprecation of raise_on_error
  180. s = Series(np.random.randn(5))
  181. cond = s > 0
  182. with tm.assert_produces_warning(FutureWarning):
  183. s.where(cond, raise_on_error=True)
  184. with tm.assert_produces_warning(FutureWarning):
  185. s.mask(cond, raise_on_error=True)
  186. def test_where():
  187. s = Series(np.random.randn(5))
  188. cond = s > 0
  189. rs = s.where(cond).dropna()
  190. rs2 = s[cond]
  191. assert_series_equal(rs, rs2)
  192. rs = s.where(cond, -s)
  193. assert_series_equal(rs, s.abs())
  194. rs = s.where(cond)
  195. assert (s.shape == rs.shape)
  196. assert (rs is not s)
  197. # test alignment
  198. cond = Series([True, False, False, True, False], index=s.index)
  199. s2 = -(s.abs())
  200. expected = s2[cond].reindex(s2.index[:3]).reindex(s2.index)
  201. rs = s2.where(cond[:3])
  202. assert_series_equal(rs, expected)
  203. expected = s2.abs()
  204. expected.iloc[0] = s2[0]
  205. rs = s2.where(cond[:3], -s2)
  206. assert_series_equal(rs, expected)
  207. def test_where_error():
  208. s = Series(np.random.randn(5))
  209. cond = s > 0
  210. msg = "Array conditional must be same shape as self"
  211. with pytest.raises(ValueError, match=msg):
  212. s.where(1)
  213. with pytest.raises(ValueError, match=msg):
  214. s.where(cond[:3].values, -s)
  215. # GH 2745
  216. s = Series([1, 2])
  217. s[[True, False]] = [0, 1]
  218. expected = Series([0, 2])
  219. assert_series_equal(s, expected)
  220. # failures
  221. msg = "cannot assign mismatch length to masked array"
  222. with pytest.raises(ValueError, match=msg):
  223. s[[True, False]] = [0, 2, 3]
  224. msg = ("NumPy boolean array indexing assignment cannot assign 0 input"
  225. " values to the 1 output values where the mask is true")
  226. with pytest.raises(ValueError, match=msg):
  227. s[[True, False]] = []
  228. @pytest.mark.parametrize('klass', [list, tuple, np.array, Series])
  229. def test_where_array_like(klass):
  230. # see gh-15414
  231. s = Series([1, 2, 3])
  232. cond = [False, True, True]
  233. expected = Series([np.nan, 2, 3])
  234. result = s.where(klass(cond))
  235. assert_series_equal(result, expected)
  236. @pytest.mark.parametrize('cond', [
  237. [1, 0, 1],
  238. Series([2, 5, 7]),
  239. ["True", "False", "True"],
  240. [Timestamp("2017-01-01"), pd.NaT, Timestamp("2017-01-02")]
  241. ])
  242. def test_where_invalid_input(cond):
  243. # see gh-15414: only boolean arrays accepted
  244. s = Series([1, 2, 3])
  245. msg = "Boolean array expected for the condition"
  246. with pytest.raises(ValueError, match=msg):
  247. s.where(cond)
  248. msg = "Array conditional must be same shape as self"
  249. with pytest.raises(ValueError, match=msg):
  250. s.where([True])
  251. def test_where_ndframe_align():
  252. msg = "Array conditional must be same shape as self"
  253. s = Series([1, 2, 3])
  254. cond = [True]
  255. with pytest.raises(ValueError, match=msg):
  256. s.where(cond)
  257. expected = Series([1, np.nan, np.nan])
  258. out = s.where(Series(cond))
  259. tm.assert_series_equal(out, expected)
  260. cond = np.array([False, True, False, True])
  261. with pytest.raises(ValueError, match=msg):
  262. s.where(cond)
  263. expected = Series([np.nan, 2, np.nan])
  264. out = s.where(Series(cond))
  265. tm.assert_series_equal(out, expected)
  266. def test_where_setitem_invalid():
  267. # GH 2702
  268. # make sure correct exceptions are raised on invalid list assignment
  269. msg = ("cannot set using a {} indexer with a different length than"
  270. " the value")
  271. # slice
  272. s = Series(list('abc'))
  273. with pytest.raises(ValueError, match=msg.format('slice')):
  274. s[0:3] = list(range(27))
  275. s[0:3] = list(range(3))
  276. expected = Series([0, 1, 2])
  277. assert_series_equal(s.astype(np.int64), expected, )
  278. # slice with step
  279. s = Series(list('abcdef'))
  280. with pytest.raises(ValueError, match=msg.format('slice')):
  281. s[0:4:2] = list(range(27))
  282. s = Series(list('abcdef'))
  283. s[0:4:2] = list(range(2))
  284. expected = Series([0, 'b', 1, 'd', 'e', 'f'])
  285. assert_series_equal(s, expected)
  286. # neg slices
  287. s = Series(list('abcdef'))
  288. with pytest.raises(ValueError, match=msg.format('slice')):
  289. s[:-1] = list(range(27))
  290. s[-3:-1] = list(range(2))
  291. expected = Series(['a', 'b', 'c', 0, 1, 'f'])
  292. assert_series_equal(s, expected)
  293. # list
  294. s = Series(list('abc'))
  295. with pytest.raises(ValueError, match=msg.format('list-like')):
  296. s[[0, 1, 2]] = list(range(27))
  297. s = Series(list('abc'))
  298. with pytest.raises(ValueError, match=msg.format('list-like')):
  299. s[[0, 1, 2]] = list(range(2))
  300. # scalar
  301. s = Series(list('abc'))
  302. s[0] = list(range(10))
  303. expected = Series([list(range(10)), 'b', 'c'])
  304. assert_series_equal(s, expected)
  305. @pytest.mark.parametrize('size', range(2, 6))
  306. @pytest.mark.parametrize('mask', [
  307. [True, False, False, False, False],
  308. [True, False],
  309. [False]
  310. ])
  311. @pytest.mark.parametrize('item', [
  312. 2.0, np.nan, np.finfo(np.float).max, np.finfo(np.float).min
  313. ])
  314. # Test numpy arrays, lists and tuples as the input to be
  315. # broadcast
  316. @pytest.mark.parametrize('box', [
  317. lambda x: np.array([x]),
  318. lambda x: [x],
  319. lambda x: (x,)
  320. ])
  321. def test_broadcast(size, mask, item, box):
  322. selection = np.resize(mask, size)
  323. data = np.arange(size, dtype=float)
  324. # Construct the expected series by taking the source
  325. # data or item based on the selection
  326. expected = Series([item if use_item else data[
  327. i] for i, use_item in enumerate(selection)])
  328. s = Series(data)
  329. s[selection] = box(item)
  330. assert_series_equal(s, expected)
  331. s = Series(data)
  332. result = s.where(~selection, box(item))
  333. assert_series_equal(result, expected)
  334. s = Series(data)
  335. result = s.mask(selection, box(item))
  336. assert_series_equal(result, expected)
  337. def test_where_inplace():
  338. s = Series(np.random.randn(5))
  339. cond = s > 0
  340. rs = s.copy()
  341. rs.where(cond, inplace=True)
  342. assert_series_equal(rs.dropna(), s[cond])
  343. assert_series_equal(rs, s.where(cond))
  344. rs = s.copy()
  345. rs.where(cond, -s, inplace=True)
  346. assert_series_equal(rs, s.where(cond, -s))
  347. def test_where_dups():
  348. # GH 4550
  349. # where crashes with dups in index
  350. s1 = Series(list(range(3)))
  351. s2 = Series(list(range(3)))
  352. comb = pd.concat([s1, s2])
  353. result = comb.where(comb < 2)
  354. expected = Series([0, 1, np.nan, 0, 1, np.nan],
  355. index=[0, 1, 2, 0, 1, 2])
  356. assert_series_equal(result, expected)
  357. # GH 4548
  358. # inplace updating not working with dups
  359. comb[comb < 1] = 5
  360. expected = Series([5, 1, 2, 5, 1, 2], index=[0, 1, 2, 0, 1, 2])
  361. assert_series_equal(comb, expected)
  362. comb[comb < 2] += 10
  363. expected = Series([5, 11, 2, 5, 11, 2], index=[0, 1, 2, 0, 1, 2])
  364. assert_series_equal(comb, expected)
  365. def test_where_numeric_with_string():
  366. # GH 9280
  367. s = pd.Series([1, 2, 3])
  368. w = s.where(s > 1, 'X')
  369. assert not is_integer(w[0])
  370. assert is_integer(w[1])
  371. assert is_integer(w[2])
  372. assert isinstance(w[0], str)
  373. assert w.dtype == 'object'
  374. w = s.where(s > 1, ['X', 'Y', 'Z'])
  375. assert not is_integer(w[0])
  376. assert is_integer(w[1])
  377. assert is_integer(w[2])
  378. assert isinstance(w[0], str)
  379. assert w.dtype == 'object'
  380. w = s.where(s > 1, np.array(['X', 'Y', 'Z']))
  381. assert not is_integer(w[0])
  382. assert is_integer(w[1])
  383. assert is_integer(w[2])
  384. assert isinstance(w[0], str)
  385. assert w.dtype == 'object'
  386. def test_where_timedelta_coerce():
  387. s = Series([1, 2], dtype='timedelta64[ns]')
  388. expected = Series([10, 10])
  389. mask = np.array([False, False])
  390. rs = s.where(mask, [10, 10])
  391. assert_series_equal(rs, expected)
  392. rs = s.where(mask, 10)
  393. assert_series_equal(rs, expected)
  394. rs = s.where(mask, 10.0)
  395. assert_series_equal(rs, expected)
  396. rs = s.where(mask, [10.0, 10.0])
  397. assert_series_equal(rs, expected)
  398. rs = s.where(mask, [10.0, np.nan])
  399. expected = Series([10, None], dtype='object')
  400. assert_series_equal(rs, expected)
  401. def test_where_datetime_conversion():
  402. s = Series(date_range('20130102', periods=2))
  403. expected = Series([10, 10])
  404. mask = np.array([False, False])
  405. rs = s.where(mask, [10, 10])
  406. assert_series_equal(rs, expected)
  407. rs = s.where(mask, 10)
  408. assert_series_equal(rs, expected)
  409. rs = s.where(mask, 10.0)
  410. assert_series_equal(rs, expected)
  411. rs = s.where(mask, [10.0, 10.0])
  412. assert_series_equal(rs, expected)
  413. rs = s.where(mask, [10.0, np.nan])
  414. expected = Series([10, None], dtype='object')
  415. assert_series_equal(rs, expected)
  416. # GH 15701
  417. timestamps = ['2016-12-31 12:00:04+00:00',
  418. '2016-12-31 12:00:04.010000+00:00']
  419. s = Series([pd.Timestamp(t) for t in timestamps])
  420. rs = s.where(Series([False, True]))
  421. expected = Series([pd.NaT, s[1]])
  422. assert_series_equal(rs, expected)
  423. def test_where_dt_tz_values(tz_naive_fixture):
  424. ser1 = pd.Series(pd.DatetimeIndex(['20150101', '20150102', '20150103'],
  425. tz=tz_naive_fixture))
  426. ser2 = pd.Series(pd.DatetimeIndex(['20160514', '20160515', '20160516'],
  427. tz=tz_naive_fixture))
  428. mask = pd.Series([True, True, False])
  429. result = ser1.where(mask, ser2)
  430. exp = pd.Series(pd.DatetimeIndex(['20150101', '20150102', '20160516'],
  431. tz=tz_naive_fixture))
  432. assert_series_equal(exp, result)
  433. def test_mask():
  434. # compare with tested results in test_where
  435. s = Series(np.random.randn(5))
  436. cond = s > 0
  437. rs = s.where(~cond, np.nan)
  438. assert_series_equal(rs, s.mask(cond))
  439. rs = s.where(~cond)
  440. rs2 = s.mask(cond)
  441. assert_series_equal(rs, rs2)
  442. rs = s.where(~cond, -s)
  443. rs2 = s.mask(cond, -s)
  444. assert_series_equal(rs, rs2)
  445. cond = Series([True, False, False, True, False], index=s.index)
  446. s2 = -(s.abs())
  447. rs = s2.where(~cond[:3])
  448. rs2 = s2.mask(cond[:3])
  449. assert_series_equal(rs, rs2)
  450. rs = s2.where(~cond[:3], -s2)
  451. rs2 = s2.mask(cond[:3], -s2)
  452. assert_series_equal(rs, rs2)
  453. msg = "Array conditional must be same shape as self"
  454. with pytest.raises(ValueError, match=msg):
  455. s.mask(1)
  456. with pytest.raises(ValueError, match=msg):
  457. s.mask(cond[:3].values, -s)
  458. # dtype changes
  459. s = Series([1, 2, 3, 4])
  460. result = s.mask(s > 2, np.nan)
  461. expected = Series([1, 2, np.nan, np.nan])
  462. assert_series_equal(result, expected)
  463. # see gh-21891
  464. s = Series([1, 2])
  465. res = s.mask([True, False])
  466. exp = Series([np.nan, 2])
  467. tm.assert_series_equal(res, exp)
  468. def test_mask_inplace():
  469. s = Series(np.random.randn(5))
  470. cond = s > 0
  471. rs = s.copy()
  472. rs.mask(cond, inplace=True)
  473. assert_series_equal(rs.dropna(), s[~cond])
  474. assert_series_equal(rs, s.mask(cond))
  475. rs = s.copy()
  476. rs.mask(cond, -s, inplace=True)
  477. assert_series_equal(rs, s.mask(cond, -s))