"""
Expressions
-----------
Offer fast expression evaluation through numexpr
"""
  6. import warnings
  7. import numpy as np
  8. from pandas.core.dtypes.generic import ABCDataFrame
  9. import pandas.core.common as com
  10. from pandas.core.computation.check import _NUMEXPR_INSTALLED
  11. from pandas.core.config import get_option
  12. if _NUMEXPR_INSTALLED:
  13. import numexpr as ne
# test-mode bookkeeping; both are (re)assigned by set_test_mode()
_TEST_MODE = None
_TEST_RESULT = None

# whether numexpr should be used; starts out as "is numexpr installed"
_USE_NUMEXPR = _NUMEXPR_INSTALLED

# the active implementations; bound by set_use_numexpr()
_evaluate = None
_where = None

# the set of dtypes that we will allow pass to numexpr,
# keyed by the kind of check ('evaluate' or 'where')
_ALLOWED_DTYPES = {
    'evaluate': {'int64', 'int32', 'float64', 'float32', 'bool'},
    'where': {'int64', 'float64', 'bool'}
}

# the minimum prod shape that we will use numexpr
# (_can_use_numexpr requires strictly more elements than this)
_MIN_ELEMENTS = 10000
  26. def set_use_numexpr(v=True):
  27. # set/unset to use numexpr
  28. global _USE_NUMEXPR
  29. if _NUMEXPR_INSTALLED:
  30. _USE_NUMEXPR = v
  31. # choose what we are going to do
  32. global _evaluate, _where
  33. if not _USE_NUMEXPR:
  34. _evaluate = _evaluate_standard
  35. _where = _where_standard
  36. else:
  37. _evaluate = _evaluate_numexpr
  38. _where = _where_numexpr
  39. def set_numexpr_threads(n=None):
  40. # if we are using numexpr, set the threads to n
  41. # otherwise reset
  42. if _NUMEXPR_INSTALLED and _USE_NUMEXPR:
  43. if n is None:
  44. n = ne.detect_number_of_cores()
  45. ne.set_num_threads(n)
  46. def _evaluate_standard(op, op_str, a, b, **eval_kwargs):
  47. """ standard evaluation """
  48. if _TEST_MODE:
  49. _store_test_result(False)
  50. with np.errstate(all='ignore'):
  51. return op(a, b)
  52. def _can_use_numexpr(op, op_str, a, b, dtype_check):
  53. """ return a boolean if we WILL be using numexpr """
  54. if op_str is not None:
  55. # required min elements (otherwise we are adding overhead)
  56. if np.prod(a.shape) > _MIN_ELEMENTS:
  57. # check for dtype compatibility
  58. dtypes = set()
  59. for o in [a, b]:
  60. if hasattr(o, 'get_dtype_counts'):
  61. s = o.get_dtype_counts()
  62. if len(s) > 1:
  63. return False
  64. dtypes |= set(s.index)
  65. elif isinstance(o, np.ndarray):
  66. dtypes |= {o.dtype.name}
  67. # allowed are a superset
  68. if not len(dtypes) or _ALLOWED_DTYPES[dtype_check] >= dtypes:
  69. return True
  70. return False
  71. def _evaluate_numexpr(op, op_str, a, b, truediv=True,
  72. reversed=False, **eval_kwargs):
  73. result = None
  74. if _can_use_numexpr(op, op_str, a, b, 'evaluate'):
  75. try:
  76. # we were originally called by a reversed op
  77. # method
  78. if reversed:
  79. a, b = b, a
  80. a_value = getattr(a, "values", a)
  81. b_value = getattr(b, "values", b)
  82. result = ne.evaluate('a_value {op} b_value'.format(op=op_str),
  83. local_dict={'a_value': a_value,
  84. 'b_value': b_value},
  85. casting='safe', truediv=truediv,
  86. **eval_kwargs)
  87. except ValueError as detail:
  88. if 'unknown type object' in str(detail):
  89. pass
  90. if _TEST_MODE:
  91. _store_test_result(result is not None)
  92. if result is None:
  93. result = _evaluate_standard(op, op_str, a, b)
  94. return result
  95. def _where_standard(cond, a, b):
  96. return np.where(com.values_from_object(cond), com.values_from_object(a),
  97. com.values_from_object(b))
  98. def _where_numexpr(cond, a, b):
  99. result = None
  100. if _can_use_numexpr(None, 'where', a, b, 'where'):
  101. try:
  102. cond_value = getattr(cond, 'values', cond)
  103. a_value = getattr(a, 'values', a)
  104. b_value = getattr(b, 'values', b)
  105. result = ne.evaluate('where(cond_value, a_value, b_value)',
  106. local_dict={'cond_value': cond_value,
  107. 'a_value': a_value,
  108. 'b_value': b_value},
  109. casting='safe')
  110. except ValueError as detail:
  111. if 'unknown type object' in str(detail):
  112. pass
  113. except Exception as detail:
  114. raise TypeError(str(detail))
  115. if result is None:
  116. result = _where_standard(cond, a, b)
  117. return result
# turn myself on: runs at import time and binds _evaluate/_where according
# to the 'compute.use_numexpr' option
set_use_numexpr(get_option('compute.use_numexpr'))
  120. def _has_bool_dtype(x):
  121. try:
  122. if isinstance(x, ABCDataFrame):
  123. return 'bool' in x.dtypes
  124. else:
  125. return x.dtype == bool
  126. except AttributeError:
  127. return isinstance(x, (bool, np.bool_))
  128. def _bool_arith_check(op_str, a, b, not_allowed=frozenset(('/', '//', '**')),
  129. unsupported=None):
  130. if unsupported is None:
  131. unsupported = {'+': '|', '*': '&', '-': '^'}
  132. if _has_bool_dtype(a) and _has_bool_dtype(b):
  133. if op_str in unsupported:
  134. warnings.warn("evaluating in Python space because the {op!r} "
  135. "operator is not supported by numexpr for "
  136. "the bool dtype, use {alt_op!r} instead"
  137. .format(op=op_str, alt_op=unsupported[op_str]))
  138. return False
  139. if op_str in not_allowed:
  140. raise NotImplementedError("operator {op!r} not implemented for "
  141. "bool dtypes".format(op=op_str))
  142. return True
  143. def evaluate(op, op_str, a, b, use_numexpr=True,
  144. **eval_kwargs):
  145. """ evaluate and return the expression of the op on a and b
  146. Parameters
  147. ----------
  148. op : the actual operand
  149. op_str: the string version of the op
  150. a : left operand
  151. b : right operand
  152. use_numexpr : whether to try to use numexpr (default True)
  153. """
  154. use_numexpr = use_numexpr and _bool_arith_check(op_str, a, b)
  155. if use_numexpr:
  156. return _evaluate(op, op_str, a, b, **eval_kwargs)
  157. return _evaluate_standard(op, op_str, a, b)
  158. def where(cond, a, b, use_numexpr=True):
  159. """ evaluate the where condition cond on a and b
  160. Parameters
  161. ----------
  162. cond : a boolean array
  163. a : return if cond is True
  164. b : return if cond is False
  165. use_numexpr : whether to try to use numexpr (default True)
  166. """
  167. if use_numexpr:
  168. return _where(cond, a, b)
  169. return _where_standard(cond, a, b)
  170. def set_test_mode(v=True):
  171. """
  172. Keeps track of whether numexpr was used. Stores an additional ``True``
  173. for every successful use of evaluate with numexpr since the last
  174. ``get_test_result``
  175. """
  176. global _TEST_MODE, _TEST_RESULT
  177. _TEST_MODE = v
  178. _TEST_RESULT = []
  179. def _store_test_result(used_numexpr):
  180. global _TEST_RESULT
  181. if used_numexpr:
  182. _TEST_RESULT.append(used_numexpr)
  183. def get_test_result():
  184. """get test result and reset test_results"""
  185. global _TEST_RESULT
  186. res = _TEST_RESULT
  187. _TEST_RESULT = []
  188. return res