test_expressions.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463
  1. # -*- coding: utf-8 -*-
  2. from __future__ import print_function
  3. import operator
  4. import re
  5. from warnings import catch_warnings, simplefilter
  6. import numpy as np
  7. from numpy.random import randn
  8. import pytest
  9. from pandas import _np_version_under1p13, compat
  10. from pandas.core.api import DataFrame, Panel
  11. from pandas.core.computation import expressions as expr
  12. import pandas.util.testing as tm
  13. from pandas.util.testing import (
  14. assert_almost_equal, assert_frame_equal, assert_panel_equal,
  15. assert_series_equal)
  16. from pandas.io.formats.printing import pprint_thing
  17. # pylint: disable-msg=W0612,E1101
  18. _frame = DataFrame(randn(10000, 4), columns=list('ABCD'), dtype='float64')
  19. _frame2 = DataFrame(randn(100, 4), columns=list('ABCD'), dtype='float64')
  20. _mixed = DataFrame({'A': _frame['A'].copy(),
  21. 'B': _frame['B'].astype('float32'),
  22. 'C': _frame['C'].astype('int64'),
  23. 'D': _frame['D'].astype('int32')})
  24. _mixed2 = DataFrame({'A': _frame2['A'].copy(),
  25. 'B': _frame2['B'].astype('float32'),
  26. 'C': _frame2['C'].astype('int64'),
  27. 'D': _frame2['D'].astype('int32')})
  28. _integer = DataFrame(
  29. np.random.randint(1, 100,
  30. size=(10001, 4)),
  31. columns=list('ABCD'), dtype='int64')
  32. _integer2 = DataFrame(np.random.randint(1, 100, size=(101, 4)),
  33. columns=list('ABCD'), dtype='int64')
  34. with catch_warnings(record=True):
  35. simplefilter("ignore", FutureWarning)
  36. _frame_panel = Panel(dict(ItemA=_frame.copy(),
  37. ItemB=(_frame.copy() + 3),
  38. ItemC=_frame.copy(),
  39. ItemD=_frame.copy()))
  40. _frame2_panel = Panel(dict(ItemA=_frame2.copy(),
  41. ItemB=(_frame2.copy() + 3),
  42. ItemC=_frame2.copy(),
  43. ItemD=_frame2.copy()))
  44. _integer_panel = Panel(dict(ItemA=_integer,
  45. ItemB=(_integer + 34).astype('int64')))
  46. _integer2_panel = Panel(dict(ItemA=_integer2,
  47. ItemB=(_integer2 + 34).astype('int64')))
  48. _mixed_panel = Panel(dict(ItemA=_mixed, ItemB=(_mixed + 3)))
  49. _mixed2_panel = Panel(dict(ItemA=_mixed2, ItemB=(_mixed2 + 3)))
  50. @pytest.mark.skipif(not expr._USE_NUMEXPR, reason='not using numexpr')
  51. class TestExpressions(object):
  52. def setup_method(self, method):
  53. self.frame = _frame.copy()
  54. self.frame2 = _frame2.copy()
  55. self.mixed = _mixed.copy()
  56. self.mixed2 = _mixed2.copy()
  57. self.integer = _integer.copy()
  58. self._MIN_ELEMENTS = expr._MIN_ELEMENTS
  59. def teardown_method(self, method):
  60. expr._MIN_ELEMENTS = self._MIN_ELEMENTS
  61. def run_arithmetic(self, df, other, assert_func, check_dtype=False,
  62. test_flex=True):
  63. expr._MIN_ELEMENTS = 0
  64. operations = ['add', 'sub', 'mul', 'mod', 'truediv', 'floordiv']
  65. if not compat.PY3:
  66. operations.append('div')
  67. for arith in operations:
  68. operator_name = arith
  69. if arith == 'div':
  70. operator_name = 'truediv'
  71. if test_flex:
  72. op = lambda x, y: getattr(df, arith)(y)
  73. op.__name__ = arith
  74. else:
  75. op = getattr(operator, operator_name)
  76. expr.set_use_numexpr(False)
  77. expected = op(df, other)
  78. expr.set_use_numexpr(True)
  79. result = op(df, other)
  80. try:
  81. if check_dtype:
  82. if arith == 'truediv':
  83. assert expected.dtype.kind == 'f'
  84. assert_func(expected, result)
  85. except Exception:
  86. pprint_thing("Failed test with operator %r" % op.__name__)
  87. raise
  88. def test_integer_arithmetic(self):
  89. self.run_arithmetic(self.integer, self.integer,
  90. assert_frame_equal)
  91. self.run_arithmetic(self.integer.iloc[:, 0],
  92. self.integer.iloc[:, 0], assert_series_equal,
  93. check_dtype=True)
  94. def run_binary(self, df, other, assert_func, test_flex=False,
  95. numexpr_ops={'gt', 'lt', 'ge', 'le', 'eq', 'ne'}):
  96. """
  97. tests solely that the result is the same whether or not numexpr is
  98. enabled. Need to test whether the function does the correct thing
  99. elsewhere.
  100. """
  101. expr._MIN_ELEMENTS = 0
  102. expr.set_test_mode(True)
  103. operations = ['gt', 'lt', 'ge', 'le', 'eq', 'ne']
  104. for arith in operations:
  105. if test_flex:
  106. op = lambda x, y: getattr(df, arith)(y)
  107. op.__name__ = arith
  108. else:
  109. op = getattr(operator, arith)
  110. expr.set_use_numexpr(False)
  111. expected = op(df, other)
  112. expr.set_use_numexpr(True)
  113. expr.get_test_result()
  114. result = op(df, other)
  115. used_numexpr = expr.get_test_result()
  116. try:
  117. if arith in numexpr_ops:
  118. assert used_numexpr, "Did not use numexpr as expected."
  119. else:
  120. assert not used_numexpr, "Used numexpr unexpectedly."
  121. assert_func(expected, result)
  122. except Exception:
  123. pprint_thing("Failed test with operation %r" % arith)
  124. pprint_thing("test_flex was %r" % test_flex)
  125. raise
  126. def run_frame(self, df, other, binary_comp=None, run_binary=True,
  127. **kwargs):
  128. self.run_arithmetic(df, other, assert_frame_equal,
  129. test_flex=False, **kwargs)
  130. self.run_arithmetic(df, other, assert_frame_equal, test_flex=True,
  131. **kwargs)
  132. if run_binary:
  133. if binary_comp is None:
  134. expr.set_use_numexpr(False)
  135. binary_comp = other + 1
  136. expr.set_use_numexpr(True)
  137. self.run_binary(df, binary_comp, assert_frame_equal,
  138. test_flex=False, **kwargs)
  139. self.run_binary(df, binary_comp, assert_frame_equal,
  140. test_flex=True, **kwargs)
  141. def run_series(self, ser, other, binary_comp=None, **kwargs):
  142. self.run_arithmetic(ser, other, assert_series_equal,
  143. test_flex=False, **kwargs)
  144. self.run_arithmetic(ser, other, assert_almost_equal,
  145. test_flex=True, **kwargs)
  146. # series doesn't uses vec_compare instead of numexpr...
  147. # if binary_comp is None:
  148. # binary_comp = other + 1
  149. # self.run_binary(ser, binary_comp, assert_frame_equal,
  150. # test_flex=False, **kwargs)
  151. # self.run_binary(ser, binary_comp, assert_frame_equal,
  152. # test_flex=True, **kwargs)
  153. def run_panel(self, panel, other, binary_comp=None, run_binary=True,
  154. assert_func=assert_panel_equal, **kwargs):
  155. self.run_arithmetic(panel, other, assert_func, test_flex=False,
  156. **kwargs)
  157. self.run_arithmetic(panel, other, assert_func, test_flex=True,
  158. **kwargs)
  159. if run_binary:
  160. if binary_comp is None:
  161. binary_comp = other + 1
  162. self.run_binary(panel, binary_comp, assert_func,
  163. test_flex=False, **kwargs)
  164. self.run_binary(panel, binary_comp, assert_func,
  165. test_flex=True, **kwargs)
  166. def test_integer_arithmetic_frame(self):
  167. self.run_frame(self.integer, self.integer)
  168. def test_integer_arithmetic_series(self):
  169. self.run_series(self.integer.iloc[:, 0], self.integer.iloc[:, 0])
  170. @pytest.mark.slow
  171. @pytest.mark.filterwarnings("ignore:\\nPanel:FutureWarning")
  172. def test_integer_panel(self):
  173. self.run_panel(_integer2_panel, np.random.randint(1, 100))
  174. def test_float_arithemtic_frame(self):
  175. self.run_frame(self.frame2, self.frame2)
  176. def test_float_arithmetic_series(self):
  177. self.run_series(self.frame2.iloc[:, 0], self.frame2.iloc[:, 0])
  178. @pytest.mark.slow
  179. @pytest.mark.filterwarnings("ignore:\\nPanel:FutureWarning")
  180. def test_float_panel(self):
  181. self.run_panel(_frame2_panel, np.random.randn() + 0.1, binary_comp=0.8)
  182. def test_mixed_arithmetic_frame(self):
  183. # TODO: FIGURE OUT HOW TO GET IT TO WORK...
  184. # can't do arithmetic because comparison methods try to do *entire*
  185. # frame instead of by-column
  186. self.run_frame(self.mixed2, self.mixed2, run_binary=False)
  187. def test_mixed_arithmetic_series(self):
  188. for col in self.mixed2.columns:
  189. self.run_series(self.mixed2[col], self.mixed2[col], binary_comp=4)
  190. @pytest.mark.slow
  191. @pytest.mark.filterwarnings("ignore:\\nPanel:FutureWarning")
  192. def test_mixed_panel(self):
  193. self.run_panel(_mixed2_panel, np.random.randint(1, 100),
  194. binary_comp=-2)
  195. def test_float_arithemtic(self):
  196. self.run_arithmetic(self.frame, self.frame, assert_frame_equal)
  197. self.run_arithmetic(self.frame.iloc[:, 0], self.frame.iloc[:, 0],
  198. assert_series_equal, check_dtype=True)
  199. def test_mixed_arithmetic(self):
  200. self.run_arithmetic(self.mixed, self.mixed, assert_frame_equal)
  201. for col in self.mixed.columns:
  202. self.run_arithmetic(self.mixed[col], self.mixed[col],
  203. assert_series_equal)
  204. def test_integer_with_zeros(self):
  205. self.integer *= np.random.randint(0, 2, size=np.shape(self.integer))
  206. self.run_arithmetic(self.integer, self.integer,
  207. assert_frame_equal)
  208. self.run_arithmetic(self.integer.iloc[:, 0],
  209. self.integer.iloc[:, 0], assert_series_equal)
  210. def test_invalid(self):
  211. # no op
  212. result = expr._can_use_numexpr(operator.add, None, self.frame,
  213. self.frame, 'evaluate')
  214. assert not result
  215. # mixed
  216. result = expr._can_use_numexpr(operator.add, '+', self.mixed,
  217. self.frame, 'evaluate')
  218. assert not result
  219. # min elements
  220. result = expr._can_use_numexpr(operator.add, '+', self.frame2,
  221. self.frame2, 'evaluate')
  222. assert not result
  223. # ok, we only check on first part of expression
  224. result = expr._can_use_numexpr(operator.add, '+', self.frame,
  225. self.frame2, 'evaluate')
  226. assert result
  227. def test_binary_ops(self):
  228. def testit():
  229. for f, f2 in [(self.frame, self.frame2),
  230. (self.mixed, self.mixed2)]:
  231. for op, op_str in [('add', '+'), ('sub', '-'), ('mul', '*'),
  232. ('div', '/'), ('pow', '**')]:
  233. if op == 'pow':
  234. continue
  235. if op == 'div':
  236. op = getattr(operator, 'truediv', None)
  237. else:
  238. op = getattr(operator, op, None)
  239. if op is not None:
  240. result = expr._can_use_numexpr(op, op_str, f, f,
  241. 'evaluate')
  242. assert result != f._is_mixed_type
  243. result = expr.evaluate(op, op_str, f, f,
  244. use_numexpr=True)
  245. expected = expr.evaluate(op, op_str, f, f,
  246. use_numexpr=False)
  247. if isinstance(result, DataFrame):
  248. tm.assert_frame_equal(result, expected)
  249. else:
  250. tm.assert_numpy_array_equal(result,
  251. expected.values)
  252. result = expr._can_use_numexpr(op, op_str, f2, f2,
  253. 'evaluate')
  254. assert not result
  255. expr.set_use_numexpr(False)
  256. testit()
  257. expr.set_use_numexpr(True)
  258. expr.set_numexpr_threads(1)
  259. testit()
  260. expr.set_numexpr_threads()
  261. testit()
  262. def test_boolean_ops(self):
  263. def testit():
  264. for f, f2 in [(self.frame, self.frame2),
  265. (self.mixed, self.mixed2)]:
  266. f11 = f
  267. f12 = f + 1
  268. f21 = f2
  269. f22 = f2 + 1
  270. for op, op_str in [('gt', '>'), ('lt', '<'), ('ge', '>='),
  271. ('le', '<='), ('eq', '=='), ('ne', '!=')]:
  272. op = getattr(operator, op)
  273. result = expr._can_use_numexpr(op, op_str, f11, f12,
  274. 'evaluate')
  275. assert result != f11._is_mixed_type
  276. result = expr.evaluate(op, op_str, f11, f12,
  277. use_numexpr=True)
  278. expected = expr.evaluate(op, op_str, f11, f12,
  279. use_numexpr=False)
  280. if isinstance(result, DataFrame):
  281. tm.assert_frame_equal(result, expected)
  282. else:
  283. tm.assert_numpy_array_equal(result, expected.values)
  284. result = expr._can_use_numexpr(op, op_str, f21, f22,
  285. 'evaluate')
  286. assert not result
  287. expr.set_use_numexpr(False)
  288. testit()
  289. expr.set_use_numexpr(True)
  290. expr.set_numexpr_threads(1)
  291. testit()
  292. expr.set_numexpr_threads()
  293. testit()
  294. def test_where(self):
  295. def testit():
  296. for f in [self.frame, self.frame2, self.mixed, self.mixed2]:
  297. for cond in [True, False]:
  298. c = np.empty(f.shape, dtype=np.bool_)
  299. c.fill(cond)
  300. result = expr.where(c, f.values, f.values + 1)
  301. expected = np.where(c, f.values, f.values + 1)
  302. tm.assert_numpy_array_equal(result, expected)
  303. expr.set_use_numexpr(False)
  304. testit()
  305. expr.set_use_numexpr(True)
  306. expr.set_numexpr_threads(1)
  307. testit()
  308. expr.set_numexpr_threads()
  309. testit()
  310. def test_bool_ops_raise_on_arithmetic(self):
  311. df = DataFrame({'a': np.random.rand(10) > 0.5,
  312. 'b': np.random.rand(10) > 0.5})
  313. names = 'div', 'truediv', 'floordiv', 'pow'
  314. ops = '/', '/', '//', '**'
  315. msg = 'operator %r not implemented for bool dtypes'
  316. for op, name in zip(ops, names):
  317. if not compat.PY3 or name != 'div':
  318. f = getattr(operator, name)
  319. err_msg = re.escape(msg % op)
  320. with pytest.raises(NotImplementedError, match=err_msg):
  321. f(df, df)
  322. with pytest.raises(NotImplementedError, match=err_msg):
  323. f(df.a, df.b)
  324. with pytest.raises(NotImplementedError, match=err_msg):
  325. f(df.a, True)
  326. with pytest.raises(NotImplementedError, match=err_msg):
  327. f(False, df.a)
  328. with pytest.raises(NotImplementedError, match=err_msg):
  329. f(False, df)
  330. with pytest.raises(NotImplementedError, match=err_msg):
  331. f(df, True)
  332. def test_bool_ops_warn_on_arithmetic(self):
  333. n = 10
  334. df = DataFrame({'a': np.random.rand(n) > 0.5,
  335. 'b': np.random.rand(n) > 0.5})
  336. names = 'add', 'mul', 'sub'
  337. ops = '+', '*', '-'
  338. subs = {'+': '|', '*': '&', '-': '^'}
  339. sub_funcs = {'|': 'or_', '&': 'and_', '^': 'xor'}
  340. for op, name in zip(ops, names):
  341. f = getattr(operator, name)
  342. fe = getattr(operator, sub_funcs[subs[op]])
  343. # >= 1.13.0 these are now TypeErrors
  344. if op == '-' and not _np_version_under1p13:
  345. continue
  346. with tm.use_numexpr(True, min_elements=5):
  347. with tm.assert_produces_warning(check_stacklevel=False):
  348. r = f(df, df)
  349. e = fe(df, df)
  350. tm.assert_frame_equal(r, e)
  351. with tm.assert_produces_warning(check_stacklevel=False):
  352. r = f(df.a, df.b)
  353. e = fe(df.a, df.b)
  354. tm.assert_series_equal(r, e)
  355. with tm.assert_produces_warning(check_stacklevel=False):
  356. r = f(df.a, True)
  357. e = fe(df.a, True)
  358. tm.assert_series_equal(r, e)
  359. with tm.assert_produces_warning(check_stacklevel=False):
  360. r = f(False, df.a)
  361. e = fe(False, df.a)
  362. tm.assert_series_equal(r, e)
  363. with tm.assert_produces_warning(check_stacklevel=False):
  364. r = f(False, df)
  365. e = fe(False, df)
  366. tm.assert_frame_equal(r, e)
  367. with tm.assert_produces_warning(check_stacklevel=False):
  368. r = f(df, True)
  369. e = fe(df, True)
  370. tm.assert_frame_equal(r, e)
  371. @pytest.mark.parametrize("test_input,expected", [
  372. (DataFrame([[0, 1, 2, 'aa'], [0, 1, 2, 'aa']],
  373. columns=['a', 'b', 'c', 'dtype']),
  374. DataFrame([[False, False], [False, False]],
  375. columns=['a', 'dtype'])),
  376. (DataFrame([[0, 3, 2, 'aa'], [0, 4, 2, 'aa'], [0, 1, 1, 'bb']],
  377. columns=['a', 'b', 'c', 'dtype']),
  378. DataFrame([[False, False], [False, False],
  379. [False, False]], columns=['a', 'dtype'])),
  380. ])
  381. def test_bool_ops_column_name_dtype(self, test_input, expected):
  382. # GH 22383 - .ne fails if columns containing column name 'dtype'
  383. result = test_input.loc[:, ['a', 'dtype']].ne(
  384. test_input.loc[:, ['a', 'dtype']])
  385. assert_frame_equal(result, expected)