# Optimize.py (203 KB)


  1. from __future__ import absolute_import
  2. import re
  3. import sys
  4. import copy
  5. import codecs
  6. import itertools
  7. from . import TypeSlots
  8. from .ExprNodes import not_a_constant
  9. import cython
# Declare these module-level names to Cython as generic Python objects.
cython.declare(UtilityCode=object, EncodedString=object, bytes_literal=object, encoded_string=object,
               Nodes=object, ExprNodes=object, PyrexTypes=object, Builtin=object,
               UtilNodes=object, _py_int_types=object)

# The integer and string types to test against depend on the major Python
# version that runs this module.
if sys.version_info[0] >= 3:
    _py_int_types = int
    _py_string_types = (bytes, str)
else:
    _py_int_types = (int, long)
    _py_string_types = (bytes, unicode)
  19. from . import Nodes
  20. from . import ExprNodes
  21. from . import PyrexTypes
  22. from . import Visitor
  23. from . import Builtin
  24. from . import UtilNodes
  25. from . import Options
  26. from .Code import UtilityCode, TempitaUtilityCode
  27. from .StringEncoding import EncodedString, bytes_literal, encoded_string
  28. from .Errors import error, warning
  29. from .ParseTreeTransforms import SkipDeclarations
  30. try:
  31. from __builtin__ import reduce
  32. except ImportError:
  33. from functools import reduce
  34. try:
  35. from __builtin__ import basestring
  36. except ImportError:
  37. basestring = str # Python 3
  38. def load_c_utility(name):
  39. return UtilityCode.load_cached(name, "Optimize.c")
  40. def unwrap_coerced_node(node, coercion_nodes=(ExprNodes.CoerceToPyTypeNode, ExprNodes.CoerceFromPyTypeNode)):
  41. if isinstance(node, coercion_nodes):
  42. return node.arg
  43. return node
  44. def unwrap_node(node):
  45. while isinstance(node, UtilNodes.ResultRefNode):
  46. node = node.expression
  47. return node
  48. def is_common_value(a, b):
  49. a = unwrap_node(a)
  50. b = unwrap_node(b)
  51. if isinstance(a, ExprNodes.NameNode) and isinstance(b, ExprNodes.NameNode):
  52. return a.name == b.name
  53. if isinstance(a, ExprNodes.AttributeNode) and isinstance(b, ExprNodes.AttributeNode):
  54. return not a.is_py_attr and is_common_value(a.obj, b.obj) and a.attribute == b.attribute
  55. return False
  56. def filter_none_node(node):
  57. if node is not None and node.constant_result is None:
  58. return None
  59. return node
  60. class _YieldNodeCollector(Visitor.TreeVisitor):
  61. """
  62. YieldExprNode finder for generator expressions.
  63. """
  64. def __init__(self):
  65. Visitor.TreeVisitor.__init__(self)
  66. self.yield_stat_nodes = {}
  67. self.yield_nodes = []
  68. visit_Node = Visitor.TreeVisitor.visitchildren
  69. def visit_YieldExprNode(self, node):
  70. self.yield_nodes.append(node)
  71. self.visitchildren(node)
  72. def visit_ExprStatNode(self, node):
  73. self.visitchildren(node)
  74. if node.expr in self.yield_nodes:
  75. self.yield_stat_nodes[node.expr] = node
  76. # everything below these nodes is out of scope:
  77. def visit_GeneratorExpressionNode(self, node):
  78. pass
  79. def visit_LambdaNode(self, node):
  80. pass
  81. def visit_FuncDefNode(self, node):
  82. pass
  83. def _find_single_yield_expression(node):
  84. yield_statements = _find_yield_statements(node)
  85. if len(yield_statements) != 1:
  86. return None, None
  87. return yield_statements[0]
  88. def _find_yield_statements(node):
  89. collector = _YieldNodeCollector()
  90. collector.visitchildren(node)
  91. try:
  92. yield_statements = [
  93. (yield_node.arg, collector.yield_stat_nodes[yield_node])
  94. for yield_node in collector.yield_nodes
  95. ]
  96. except KeyError:
  97. # found YieldExprNode without ExprStatNode (i.e. a non-statement usage of 'yield')
  98. yield_statements = []
  99. return yield_statements
  100. class IterationTransform(Visitor.EnvTransform):
  101. """Transform some common for-in loop patterns into efficient C loops:
  102. - for-in-dict loop becomes a while loop calling PyDict_Next()
  103. - for-in-enumerate is replaced by an external counter variable
  104. - for-in-range loop becomes a plain C for loop
  105. """
    def visit_PrimaryCmpNode(self, node):
        """Expand comparisons for which ``node.is_ptr_contains()`` is true
        into an explicit search loop.

        The generated code iterates over ``operand2``, assigns a true value
        to a result reference and breaks on the first item equal to
        ``operand1``; the loop's else clause assigns a false value.  For the
        ``not_in`` operator the result is wrapped in a NotNode.  All other
        comparison nodes only get their children visited.
        """
        if node.is_ptr_contains():

            # for t in operand2:
            #     if operand1 == t:
            #         res = True
            #         break
            # else:
            #     res = False

            pos = node.pos
            result_ref = UtilNodes.ResultRefNode(node)
            # The loop variable gets the item type of the iterated operand.
            if node.operand2.is_subscript:
                base_type = node.operand2.base.type.base_type
            else:
                base_type = node.operand2.type.base_type
            target_handle = UtilNodes.TempHandle(base_type)
            target = target_handle.ref(pos)
            cmp_node = ExprNodes.PrimaryCmpNode(
                pos, operator=u'==', operand1=node.operand1, operand2=target)
            if_body = Nodes.StatListNode(
                pos,
                stats = [Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=1)),
                         Nodes.BreakStatNode(pos)])
            if_node = Nodes.IfStatNode(
                pos,
                if_clauses=[Nodes.IfClauseNode(pos, condition=cmp_node, body=if_body)],
                else_clause=None)
            for_loop = UtilNodes.TempsBlockNode(
                pos,
                temps = [target_handle],
                body = Nodes.ForInStatNode(
                    pos,
                    target=target,
                    iterator=ExprNodes.IteratorNode(node.operand2.pos, sequence=node.operand2),
                    body=if_node,
                    else_clause=Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=0))))
            for_loop = for_loop.analyse_expressions(self.current_env())
            # Run this transform over the generated loop as well.
            for_loop = self.visit(for_loop)
            new_node = UtilNodes.TempResultFromStatNode(result_ref, for_loop)

            if node.operator == 'not_in':
                new_node = ExprNodes.NotNode(pos, operand=new_node)
            return new_node

        else:
            self.visitchildren(node)
            return node
  150. def visit_ForInStatNode(self, node):
  151. self.visitchildren(node)
  152. return self._optimise_for_loop(node, node.iterator.sequence)
  153. def _optimise_for_loop(self, node, iterable, reversed=False):
  154. annotation_type = None
  155. if (iterable.is_name or iterable.is_attribute) and iterable.entry and iterable.entry.annotation:
  156. annotation = iterable.entry.annotation
  157. if annotation.is_subscript:
  158. annotation = annotation.base # container base type
  159. # FIXME: generalise annotation evaluation => maybe provide a "qualified name" also for imported names?
  160. if annotation.is_name:
  161. if annotation.entry and annotation.entry.qualified_name == 'typing.Dict':
  162. annotation_type = Builtin.dict_type
  163. elif annotation.name == 'Dict':
  164. annotation_type = Builtin.dict_type
  165. if annotation.entry and annotation.entry.qualified_name in ('typing.Set', 'typing.FrozenSet'):
  166. annotation_type = Builtin.set_type
  167. elif annotation.name in ('Set', 'FrozenSet'):
  168. annotation_type = Builtin.set_type
  169. if Builtin.dict_type in (iterable.type, annotation_type):
  170. # like iterating over dict.keys()
  171. if reversed:
  172. # CPython raises an error here: not a sequence
  173. return node
  174. return self._transform_dict_iteration(
  175. node, dict_obj=iterable, method=None, keys=True, values=False)
  176. if (Builtin.set_type in (iterable.type, annotation_type) or
  177. Builtin.frozenset_type in (iterable.type, annotation_type)):
  178. if reversed:
  179. # CPython raises an error here: not a sequence
  180. return node
  181. return self._transform_set_iteration(node, iterable)
  182. # C array (slice) iteration?
  183. if iterable.type.is_ptr or iterable.type.is_array:
  184. return self._transform_carray_iteration(node, iterable, reversed=reversed)
  185. if iterable.type is Builtin.bytes_type:
  186. return self._transform_bytes_iteration(node, iterable, reversed=reversed)
  187. if iterable.type is Builtin.unicode_type:
  188. return self._transform_unicode_iteration(node, iterable, reversed=reversed)
  189. # the rest is based on function calls
  190. if not isinstance(iterable, ExprNodes.SimpleCallNode):
  191. return node
  192. if iterable.args is None:
  193. arg_count = iterable.arg_tuple and len(iterable.arg_tuple.args) or 0
  194. else:
  195. arg_count = len(iterable.args)
  196. if arg_count and iterable.self is not None:
  197. arg_count -= 1
  198. function = iterable.function
  199. # dict iteration?
  200. if function.is_attribute and not reversed and not arg_count:
  201. base_obj = iterable.self or function.obj
  202. method = function.attribute
  203. # in Py3, items() is equivalent to Py2's iteritems()
  204. is_safe_iter = self.global_scope().context.language_level >= 3
  205. if not is_safe_iter and method in ('keys', 'values', 'items'):
  206. # try to reduce this to the corresponding .iter*() methods
  207. if isinstance(base_obj, ExprNodes.CallNode):
  208. inner_function = base_obj.function
  209. if (inner_function.is_name and inner_function.name == 'dict'
  210. and inner_function.entry
  211. and inner_function.entry.is_builtin):
  212. # e.g. dict(something).items() => safe to use .iter*()
  213. is_safe_iter = True
  214. keys = values = False
  215. if method == 'iterkeys' or (is_safe_iter and method == 'keys'):
  216. keys = True
  217. elif method == 'itervalues' or (is_safe_iter and method == 'values'):
  218. values = True
  219. elif method == 'iteritems' or (is_safe_iter and method == 'items'):
  220. keys = values = True
  221. if keys or values:
  222. return self._transform_dict_iteration(
  223. node, base_obj, method, keys, values)
  224. # enumerate/reversed ?
  225. if iterable.self is None and function.is_name and \
  226. function.entry and function.entry.is_builtin:
  227. if function.name == 'enumerate':
  228. if reversed:
  229. # CPython raises an error here: not a sequence
  230. return node
  231. return self._transform_enumerate_iteration(node, iterable)
  232. elif function.name == 'reversed':
  233. if reversed:
  234. # CPython raises an error here: not a sequence
  235. return node
  236. return self._transform_reversed_iteration(node, iterable)
  237. # range() iteration?
  238. if Options.convert_range and arg_count >= 1 and (
  239. iterable.self is None and
  240. function.is_name and function.name in ('range', 'xrange') and
  241. function.entry and function.entry.is_builtin):
  242. if node.target.type.is_int or node.target.type.is_enum:
  243. return self._transform_range_iteration(node, iterable, reversed=reversed)
  244. if node.target.type.is_pyobject:
  245. # Assume that small integer ranges (C long >= 32bit) are best handled in C as well.
  246. for arg in (iterable.arg_tuple.args if iterable.args is None else iterable.args):
  247. if isinstance(arg, ExprNodes.IntNode):
  248. if arg.has_constant_result() and -2**30 <= arg.constant_result < 2**30:
  249. continue
  250. break
  251. else:
  252. return self._transform_range_iteration(node, iterable, reversed=reversed)
  253. return node
  254. def _transform_reversed_iteration(self, node, reversed_function):
  255. args = reversed_function.arg_tuple.args
  256. if len(args) == 0:
  257. error(reversed_function.pos,
  258. "reversed() requires an iterable argument")
  259. return node
  260. elif len(args) > 1:
  261. error(reversed_function.pos,
  262. "reversed() takes exactly 1 argument")
  263. return node
  264. arg = args[0]
  265. # reversed(list/tuple) ?
  266. if arg.type in (Builtin.tuple_type, Builtin.list_type):
  267. node.iterator.sequence = arg.as_none_safe_node("'NoneType' object is not iterable")
  268. node.iterator.reversed = True
  269. return node
  270. return self._optimise_for_loop(node, arg, reversed=True)
    # C function type used for the "PyBytes_AS_STRING" call in
    # _transform_bytes_iteration(): one bytes argument "s", char* result type.
    PyBytes_AS_STRING_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_char_ptr_type, [
            PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None)
            ])

    # C function type used for the "PyBytes_GET_SIZE" call in
    # _transform_bytes_iteration(): one bytes argument "s", Py_ssize_t result type.
    PyBytes_GET_SIZE_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type, [
            PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None)
            ])
  279. def _transform_bytes_iteration(self, node, slice_node, reversed=False):
  280. target_type = node.target.type
  281. if not target_type.is_int and target_type is not Builtin.bytes_type:
  282. # bytes iteration returns bytes objects in Py2, but
  283. # integers in Py3
  284. return node
  285. unpack_temp_node = UtilNodes.LetRefNode(
  286. slice_node.as_none_safe_node("'NoneType' is not iterable"))
  287. slice_base_node = ExprNodes.PythonCapiCallNode(
  288. slice_node.pos, "PyBytes_AS_STRING",
  289. self.PyBytes_AS_STRING_func_type,
  290. args = [unpack_temp_node],
  291. is_temp = 0,
  292. )
  293. len_node = ExprNodes.PythonCapiCallNode(
  294. slice_node.pos, "PyBytes_GET_SIZE",
  295. self.PyBytes_GET_SIZE_func_type,
  296. args = [unpack_temp_node],
  297. is_temp = 0,
  298. )
  299. return UtilNodes.LetNode(
  300. unpack_temp_node,
  301. self._transform_carray_iteration(
  302. node,
  303. ExprNodes.SliceIndexNode(
  304. slice_node.pos,
  305. base = slice_base_node,
  306. start = None,
  307. step = None,
  308. stop = len_node,
  309. type = slice_base_node.type,
  310. is_temp = 1,
  311. ),
  312. reversed = reversed))
    # C function type used for the "__Pyx_PyUnicode_READ" call in
    # _transform_unicode_iteration(): (kind, data, index) arguments with a
    # Py_UCS4 result type.
    PyUnicode_READ_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ucs4_type, [
            PyrexTypes.CFuncTypeArg("kind", PyrexTypes.c_int_type, None),
            PyrexTypes.CFuncTypeArg("data", PyrexTypes.c_void_ptr_type, None),
            PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_py_ssize_t_type, None)
        ])

    # C function type used for the "__Pyx_init_unicode_iteration" call: takes
    # the string object plus pointers for length, data and kind, has a C int
    # result type and is declared with the exception value '-1'.
    init_unicode_iteration_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_int_type, [
            PyrexTypes.CFuncTypeArg("s", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("length", PyrexTypes.c_py_ssize_t_ptr_type, None),
            PyrexTypes.CFuncTypeArg("data", PyrexTypes.c_void_ptr_ptr_type, None),
            PyrexTypes.CFuncTypeArg("kind", PyrexTypes.c_int_ptr_type, None)
        ],
        exception_value = '-1')
    def _transform_unicode_iteration(self, node, slice_node, reversed=False):
        """Replace iteration over a unicode string by an index based loop.

        String literals that can be encoded as Latin-1 are iterated as a C
        byte array through ``_transform_carray_iteration()``.  For everything
        else, a call to "__Pyx_init_unicode_iteration" fills temporaries for
        the length, data pointer and kind of the string, and a
        ForFromStatNode counts an index over the length while reading each
        character with "__Pyx_PyUnicode_READ".  With ``reversed``, the loop
        bounds and relations are swapped.
        """
        if slice_node.is_literal:
            # try to reduce to byte iteration for plain Latin-1 strings
            try:
                bytes_value = bytes_literal(slice_node.value.encode('latin1'), 'iso8859-1')
            except UnicodeEncodeError:
                pass
            else:
                bytes_slice = ExprNodes.SliceIndexNode(
                    slice_node.pos,
                    base=ExprNodes.BytesNode(
                        slice_node.pos, value=bytes_value,
                        constant_result=bytes_value,
                        type=PyrexTypes.c_const_char_ptr_type).coerce_to(
                            PyrexTypes.c_const_uchar_ptr_type, self.current_env()),
                    start=None,
                    stop=ExprNodes.IntNode(
                        slice_node.pos, value=str(len(bytes_value)),
                        constant_result=len(bytes_value),
                        type=PyrexTypes.c_py_ssize_t_type),
                    type=Builtin.unicode_type,  # hint for Python conversion
                )
                return self._transform_carray_iteration(node, bytes_slice, reversed)

        # Keep the (None-checked) string in a temporary for the setup call.
        unpack_temp_node = UtilNodes.LetRefNode(
            slice_node.as_none_safe_node("'NoneType' is not iterable"))

        start_node = ExprNodes.IntNode(
            node.pos, value='0', constant_result=0, type=PyrexTypes.c_py_ssize_t_type)
        length_temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
        end_node = length_temp.ref(node.pos)
        if reversed:
            # count from the length down to 0: length > index >= 0
            relation1, relation2 = '>', '>='
            start_node, end_node = end_node, start_node
        else:
            # count from 0 up to the length: 0 <= index < length
            relation1, relation2 = '<=', '<'

        kind_temp = UtilNodes.TempHandle(PyrexTypes.c_int_type)
        data_temp = UtilNodes.TempHandle(PyrexTypes.c_void_ptr_type)
        counter_temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)

        # Read the character at the current index.
        target_value = ExprNodes.PythonCapiCallNode(
            slice_node.pos, "__Pyx_PyUnicode_READ",
            self.PyUnicode_READ_func_type,
            args = [kind_temp.ref(slice_node.pos),
                    data_temp.ref(slice_node.pos),
                    counter_temp.ref(node.target.pos)],
            is_temp = False,
            )
        if target_value.type != node.target.type:
            target_value = target_value.coerce_to(node.target.type,
                                                  self.current_env())
        target_assign = Nodes.SingleAssignmentNode(
            pos = node.target.pos,
            lhs = node.target,
            rhs = target_value)
        # The loop target is assigned first, then the original body runs.
        body = Nodes.StatListNode(
            node.pos,
            stats = [target_assign, node.body])

        loop_node = Nodes.ForFromStatNode(
            node.pos,
            bound1=start_node, relation1=relation1,
            target=counter_temp.ref(node.target.pos),
            relation2=relation2, bound2=end_node,
            step=None, body=body,
            else_clause=node.else_clause,
            from_range=True)

        # Statement that initialises the length, data and kind temporaries.
        setup_node = Nodes.ExprStatNode(
            node.pos,
            expr = ExprNodes.PythonCapiCallNode(
                slice_node.pos, "__Pyx_init_unicode_iteration",
                self.init_unicode_iteration_func_type,
                args = [unpack_temp_node,
                        ExprNodes.AmpersandNode(slice_node.pos, operand=length_temp.ref(slice_node.pos),
                                                type=PyrexTypes.c_py_ssize_t_ptr_type),
                        ExprNodes.AmpersandNode(slice_node.pos, operand=data_temp.ref(slice_node.pos),
                                                type=PyrexTypes.c_void_ptr_ptr_type),
                        ExprNodes.AmpersandNode(slice_node.pos, operand=kind_temp.ref(slice_node.pos),
                                                type=PyrexTypes.c_int_ptr_type),
                        ],
                is_temp = True,
                result_is_used = False,
                utility_code=UtilityCode.load_cached("unicode_iter", "Optimize.c"),
                ))
        return UtilNodes.LetNode(
            unpack_temp_node,
            UtilNodes.TempsBlockNode(
                node.pos, temps=[counter_temp, length_temp, data_temp, kind_temp],
                body=Nodes.StatListNode(node.pos, stats=[setup_node, loop_node])))
    def _transform_carray_iteration(self, node, slice_node, reversed=False):
        """Replace iteration over a C array or pointer slice by a pointer loop.

        ``slice_node`` may be a SliceIndexNode, a subscript with a SliceNode
        index (optionally with a constant step), or a C array of known size.
        Start and stop are coerced to Py_ssize_t and added to the base
        pointer, and a ForFromStatNode then runs a temporary pointer between
        the two resulting bounds, assigning the loop target from it before
        the original body.

        Returns ``node`` unchanged when the bounds or the step cannot be
        determined; an error is reported in that case unless the sliced
        object is a Python object.
        """
        neg_step = False
        if isinstance(slice_node, ExprNodes.SliceIndexNode):
            slice_base = slice_node.base
            start = filter_none_node(slice_node.start)
            stop = filter_none_node(slice_node.stop)
            step = None
            if not stop:
                if not slice_base.type.is_pyobject:
                    error(slice_node.pos, "C array iteration requires known end index")
                return node

        elif slice_node.is_subscript:
            assert isinstance(slice_node.index, ExprNodes.SliceNode)
            slice_base = slice_node.base
            index = slice_node.index
            start = filter_none_node(index.start)
            stop = filter_none_node(index.stop)
            step = filter_none_node(index.step)
            if step:
                # The step must be a non-zero integer constant, and the bound
                # that ends the iteration in its direction must be given.
                if not isinstance(step.constant_result, _py_int_types) \
                       or step.constant_result == 0 \
                       or step.constant_result > 0 and not stop \
                       or step.constant_result < 0 and not start:
                    if not slice_base.type.is_pyobject:
                        error(step.pos, "C array iteration requires known step size and end index")
                    return node
                else:
                    # step sign is handled internally by ForFromStatNode
                    step_value = step.constant_result
                    if reversed:
                        step_value = -step_value
                    neg_step = step_value < 0
                    step = ExprNodes.IntNode(step.pos, type=PyrexTypes.c_py_ssize_t_type,
                                             value=str(abs(step_value)),
                                             constant_result=abs(step_value))

        elif slice_node.type.is_array:
            if slice_node.type.size is None:
                error(slice_node.pos, "C array iteration requires known end index")
                return node
            # Iterate the complete array: [0:size]
            slice_base = slice_node
            start = None
            stop = ExprNodes.IntNode(
                slice_node.pos, value=str(slice_node.type.size),
                type=PyrexTypes.c_py_ssize_t_type, constant_result=slice_node.type.size)
            step = None

        else:
            if not slice_node.type.is_pyobject:
                error(slice_node.pos, "C array iteration requires known end index")
            return node

        if start:
            start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
        if stop:
            stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
        if stop is None:
            if neg_step:
                stop = ExprNodes.IntNode(
                    slice_node.pos, value='-1', type=PyrexTypes.c_py_ssize_t_type, constant_result=-1)
            else:
                error(slice_node.pos, "C array iteration requires known step size and end index")
                return node

        if reversed:
            if not start:
                start = ExprNodes.IntNode(slice_node.pos, value="0", constant_result=0,
                                          type=PyrexTypes.c_py_ssize_t_type)
            # if step was provided, it was already negated above
            start, stop = stop, start

        ptr_type = slice_base.type
        if ptr_type.is_array:
            ptr_type = ptr_type.element_ptr_type()
        carray_ptr = slice_base.coerce_to_simple(self.current_env())

        # Lower loop bound: base pointer plus start offset (if any).
        if start and start.constant_result != 0:
            start_ptr_node = ExprNodes.AddNode(
                start.pos,
                operand1=carray_ptr,
                operator='+',
                operand2=start,
                type=ptr_type)
        else:
            start_ptr_node = carray_ptr

        # Upper loop bound: a clone of the base pointer plus stop offset (if any).
        if stop and stop.constant_result != 0:
            stop_ptr_node = ExprNodes.AddNode(
                stop.pos,
                operand1=ExprNodes.CloneNode(carray_ptr),
                operator='+',
                operand2=stop,
                type=ptr_type
                ).coerce_to_simple(self.current_env())
        else:
            stop_ptr_node = ExprNodes.CloneNode(carray_ptr)

        # Temporary pointer that serves as the loop counter.
        counter = UtilNodes.TempHandle(ptr_type)
        counter_temp = counter.ref(node.target.pos)

        if slice_base.type.is_string and node.target.type.is_pyobject:
            # special case: char* -> bytes/unicode
            if slice_node.type is Builtin.unicode_type:
                target_value = ExprNodes.CastNode(
                    ExprNodes.DereferenceNode(
                        node.target.pos, operand=counter_temp,
                        type=ptr_type.base_type),
                    PyrexTypes.c_py_ucs4_type).coerce_to(
                        node.target.type, self.current_env())
            else:
                # char* -> bytes coercion requires slicing, not indexing
                target_value = ExprNodes.SliceIndexNode(
                    node.target.pos,
                    start=ExprNodes.IntNode(node.target.pos, value='0',
                                            constant_result=0,
                                            type=PyrexTypes.c_int_type),
                    stop=ExprNodes.IntNode(node.target.pos, value='1',
                                           constant_result=1,
                                           type=PyrexTypes.c_int_type),
                    base=counter_temp,
                    type=Builtin.bytes_type,
                    is_temp=1)
        elif node.target.type.is_ptr and not node.target.type.assignable_from(ptr_type.base_type):
            # Allow iteration with pointer target to avoid copy.
            target_value = counter_temp
        else:
            # TODO: can this safely be replaced with DereferenceNode() as above?
            target_value = ExprNodes.IndexNode(
                node.target.pos,
                index=ExprNodes.IntNode(node.target.pos, value='0',
                                        constant_result=0,
                                        type=PyrexTypes.c_int_type),
                base=counter_temp,
                type=ptr_type.base_type)

        if target_value.type != node.target.type:
            target_value = target_value.coerce_to(node.target.type,
                                                  self.current_env())

        target_assign = Nodes.SingleAssignmentNode(
            pos = node.target.pos,
            lhs = node.target,
            rhs = target_value)

        # The loop target is assigned first, then the original body runs.
        body = Nodes.StatListNode(
            node.pos,
            stats = [target_assign, node.body])

        relation1, relation2 = self._find_for_from_node_relations(neg_step, reversed)

        for_node = Nodes.ForFromStatNode(
            node.pos,
            bound1=start_ptr_node, relation1=relation1,
            target=counter_temp,
            relation2=relation2, bound2=stop_ptr_node,
            step=step, body=body,
            else_clause=node.else_clause,
            from_range=True)

        return UtilNodes.TempsBlockNode(
            node.pos, temps=[counter],
            body=for_node)
    def _transform_enumerate_iteration(self, node, enumerate_function):
        """Replace ``enumerate()`` in a for-in loop by an explicit counter.

        Only loops with a two-item target sequence whose counter target is a
        Python object or a C integer are transformed.  The counter is kept
        in a temporary that starts at the second ``enumerate()`` argument (or
        0).  Two assignments are prepended to the loop body: the temporary is
        copied to the counter target and then incremented by 1.  The loop
        itself is changed in place to iterate over the first ``enumerate()``
        argument and is passed to ``_optimise_for_loop()`` again.

        Reports an error for a wrong number of ``enumerate()`` arguments and
        returns ``node`` unchanged whenever the transformation is not applied.
        """
        args = enumerate_function.arg_tuple.args
        if len(args) == 0:
            error(enumerate_function.pos,
                  "enumerate() requires an iterable argument")
            return node
        elif len(args) > 2:
            error(enumerate_function.pos,
                  "enumerate() takes at most 2 arguments")
            return node

        if not node.target.is_sequence_constructor:
            # leave this untouched for now
            return node
        targets = node.target.args
        if len(targets) != 2:
            # leave this untouched for now
            return node

        enumerate_target, iterable_target = targets
        counter_type = enumerate_target.type

        if not counter_type.is_pyobject and not counter_type.is_int:
            # nothing we can do here, I guess
            return node

        if len(args) == 2:
            start = unwrap_coerced_node(args[1]).coerce_to(counter_type, self.current_env())
        else:
            start = ExprNodes.IntNode(enumerate_function.pos,
                                      value='0',
                                      type=counter_type,
                                      constant_result=0)
        temp = UtilNodes.LetRefNode(start)

        inc_expression = ExprNodes.AddNode(
            enumerate_function.pos,
            operand1 = temp,
            operand2 = ExprNodes.IntNode(node.pos, value='1',
                                         type=counter_type,
                                         constant_result=1),
            operator = '+',
            type = counter_type,
            #inplace = True,   # not worth using in-place operation for Py ints
            is_temp = counter_type.is_pyobject
            )

        # counter_target = temp; temp = temp + 1
        loop_body = [
            Nodes.SingleAssignmentNode(
                pos = enumerate_target.pos,
                lhs = enumerate_target,
                rhs = temp),
            Nodes.SingleAssignmentNode(
                pos = enumerate_target.pos,
                lhs = temp,
                rhs = inc_expression)
            ]

        if isinstance(node.body, Nodes.StatListNode):
            node.body.stats = loop_body + node.body.stats
        else:
            loop_body.append(node.body)
            node.body = Nodes.StatListNode(
                node.body.pos,
                stats = loop_body)

        # The loop now iterates over the wrapped iterable directly.
        node.target = iterable_target
        node.item = node.item.coerce_to(iterable_target.type, self.current_env())
        node.iterator.sequence = args[0]

        # recurse into loop to check for further optimisations
        return UtilNodes.LetNode(temp, self._optimise_for_loop(node, node.iterator.sequence))
  622. def _find_for_from_node_relations(self, neg_step_value, reversed):
  623. if reversed:
  624. if neg_step_value:
  625. return '<', '<='
  626. else:
  627. return '>', '>='
  628. else:
  629. if neg_step_value:
  630. return '>=', '>'
  631. else:
  632. return '<=', '<'
    def _transform_range_iteration(self, node, range_function, reversed=False):
        """Replace a loop over ``range()`` by a ForFromStatNode.

        The step must be a non-zero integer constant (1 if not given),
        otherwise ``node`` is returned unchanged.  The bounds are coerced to
        integers.  For ``reversed`` ranges, the bounds are swapped and, for
        step sizes other than +/-1, the new first bound is recalculated:
        directly for constant bounds, otherwise through
        ``_build_range_step_calculation()`` at runtime.  A non-literal second
        bound is kept in a temporary.
        """
        args = range_function.arg_tuple.args
        if len(args) < 3:
            step_pos = range_function.pos
            step_value = 1
            step = ExprNodes.IntNode(step_pos, value='1', constant_result=1)
        else:
            step = args[2]
            step_pos = step.pos
            if not isinstance(step.constant_result, _py_int_types):
                # cannot determine step direction
                return node
            step_value = step.constant_result
            if step_value == 0:
                # will lead to an error elsewhere
                return node
            step = ExprNodes.IntNode(step_pos, value=str(step_value),
                                     constant_result=step_value)

        if len(args) == 1:
            # range(stop) starts at 0
            bound1 = ExprNodes.IntNode(range_function.pos, value='0',
                                       constant_result=0)
            bound2 = args[0].coerce_to_integer(self.current_env())
        else:
            bound1 = args[0].coerce_to_integer(self.current_env())
            bound2 = args[1].coerce_to_integer(self.current_env())

        relation1, relation2 = self._find_for_from_node_relations(step_value < 0, reversed)

        bound2_ref_node = None
        if reversed:
            bound1, bound2 = bound2, bound1
            abs_step = abs(step_value)
            if abs_step != 1:
                if (isinstance(bound1.constant_result, _py_int_types) and
                        isinstance(bound2.constant_result, _py_int_types)):
                    # calculate final bounds now
                    if step_value < 0:
                        begin_value = bound2.constant_result
                        end_value = bound1.constant_result
                        bound1_value = begin_value - abs_step * ((begin_value - end_value - 1) // abs_step) - 1
                    else:
                        begin_value = bound1.constant_result
                        end_value = bound2.constant_result
                        bound1_value = end_value + abs_step * ((begin_value - end_value - 1) // abs_step) + 1

                    bound1 = ExprNodes.IntNode(
                        bound1.pos, value=str(bound1_value), constant_result=bound1_value,
                        type=PyrexTypes.spanning_type(bound1.type, bound2.type))
                else:
                    # evaluate the same expression as above at runtime
                    bound2_ref_node = UtilNodes.LetRefNode(bound2)
                    bound1 = self._build_range_step_calculation(
                        bound1, bound2_ref_node, step, step_value)

        # The step node always carries the absolute step value; the loop
        # direction is expressed by the relations chosen above.
        if step_value < 0:
            step_value = -step_value
        step.value = str(step_value)
        step.constant_result = step_value
        step = step.coerce_to_integer(self.current_env())

        if not bound2.is_literal:
            # stop bound must be immutable => keep it in a temp var
            bound2_is_temp = True
            bound2 = bound2_ref_node or UtilNodes.LetRefNode(bound2)
        else:
            bound2_is_temp = False

        for_node = Nodes.ForFromStatNode(
            node.pos,
            target=node.target,
            bound1=bound1, relation1=relation1,
            relation2=relation2, bound2=bound2,
            step=step, body=node.body,
            else_clause=node.else_clause,
            from_range=True)
        for_node.set_up_loop(self.current_env())

        if bound2_is_temp:
            for_node = UtilNodes.LetNode(bound2, for_node)

        return for_node
    def _build_range_step_calculation(self, bound1, bound2_ref_node, step, step_value):
        """Build the expression tree that calculates the first bound of a
        reversed range loop at runtime.

        With ``op`` being '-' for a negative ``step_value`` and '+' otherwise,
        the resulting expression is::

            bound2 op abs(step) * ((begin - end - 1) // abs(step)) op 1

        where ``(begin, end)`` is ``(bound2, bound1)`` for a negative step
        and ``(bound1, bound2)`` otherwise.
        """
        abs_step = abs(step_value)
        spanning_type = PyrexTypes.spanning_type(bound1.type, bound2_ref_node.type)
        if step.type.is_int and abs_step < 0x7FFF:
            # Avoid loss of integer precision warnings.
            spanning_step_type = PyrexTypes.spanning_type(spanning_type, PyrexTypes.c_int_type)
        else:
            spanning_step_type = PyrexTypes.spanning_type(spanning_type, step.type)

        if step_value < 0:
            begin_value = bound2_ref_node
            end_value = bound1
            final_op = '-'
        else:
            begin_value = bound1
            end_value = bound2_ref_node
            final_op = '+'

        step_calculation_node = ExprNodes.binop_node(
            bound1.pos,
            operand1=ExprNodes.binop_node(
                bound1.pos,
                operand1=bound2_ref_node,
                operator=final_op,  # +/-
                operand2=ExprNodes.MulNode(
                    bound1.pos,
                    operand1=ExprNodes.IntNode(
                        bound1.pos,
                        value=str(abs_step),
                        constant_result=abs_step,
                        type=spanning_step_type),
                    operator='*',
                    operand2=ExprNodes.DivNode(
                        bound1.pos,
                        operand1=ExprNodes.SubNode(
                            bound1.pos,
                            operand1=ExprNodes.SubNode(
                                bound1.pos,
                                operand1=begin_value,
                                operator='-',
                                operand2=end_value,
                                type=spanning_type),
                            operator='-',
                            operand2=ExprNodes.IntNode(
                                bound1.pos,
                                value='1',
                                constant_result=1),
                            type=spanning_step_type),
                        operator='//',
                        operand2=ExprNodes.IntNode(
                            bound1.pos,
                            value=str(abs_step),
                            constant_result=abs_step,
                            type=spanning_step_type),
                        type=spanning_step_type),
                    type=spanning_step_type),
                type=spanning_step_type),
            operator=final_op,  # +/-
            operand2=ExprNodes.IntNode(
                bound1.pos,
                value='1',
                constant_result=1),
            type=spanning_type)
        return step_calculation_node
  def _transform_dict_iteration(self, node, dict_obj, method, keys, values):
      """Rewrite a for-loop over a dict into a while-loop that is driven by
      a DictIterationNextNode.

      node      -- the loop node to replace; its target, body and
                   else_clause are reused
      dict_obj  -- expression node for the object that is iterated over
      method    -- name of the method handed to __Pyx_dict_iterator();
                   if false, a NULL node is passed instead
      keys, values -- flags that decide whether the loop target receives
                   the key, the value, or both

      Returns a TempsBlockNode that wraps the generated statements.  'node'
      is returned unchanged when both keys and values are requested and the
      target is a sequence constructor that does not have exactly two items.
      """
      temps = []

      def new_temp(c_type):
          # allocate a temp handle and register it with the temps block
          handle = UtilNodes.TempHandle(c_type)
          temps.append(handle)
          return handle

      dict_temp = new_temp(PyrexTypes.py_object_type).ref(dict_obj.pos)
      pos_temp = new_temp(PyrexTypes.c_py_ssize_t_type).ref(node.pos)

      target = node.target
      key_target = value_target = tuple_target = None
      if keys and values:
          if not target.is_sequence_constructor:
              tuple_target = target
          elif len(target.args) == 2:
              key_target, value_target = target.args
          else:
              # unusual case that may or may not lead to an error
              return node
      elif keys:
          key_target = target
      else:
          value_target = target

      body = node.body
      if not isinstance(body, Nodes.StatListNode):
          body = Nodes.StatListNode(pos=body.pos, stats=[body])

      # keep original length to guard against dict modification
      length_handle = new_temp(PyrexTypes.c_py_ssize_t_type)
      length_addr = ExprNodes.AmpersandNode(
          node.pos,
          operand=length_handle.ref(dict_obj.pos),
          type=PyrexTypes.c_ptr_type(length_handle.type))

      flag_handle = new_temp(PyrexTypes.c_int_type)
      is_dict_temp = flag_handle.ref(node.pos)
      is_dict_addr = ExprNodes.AmpersandNode(
          node.pos,
          operand=is_dict_temp,
          type=PyrexTypes.c_ptr_type(flag_handle.type))

      # the "fetch next item" step becomes the first statement of the loop body
      next_item = Nodes.DictIterationNextNode(
          dict_temp, length_handle.ref(dict_obj.pos), pos_temp,
          key_target, value_target, tuple_target,
          is_dict_temp)
      next_item = next_item.analyse_expressions(self.current_env())
      body.stats.insert(0, next_item)

      if method:
          method_node = ExprNodes.StringNode(
              dict_obj.pos, is_identifier=True, value=method)
          # '%.30s' for names of up to 30 characters, plain '%s' otherwise
          precision = '.30' if len(method) <= 30 else ''
          dict_obj = dict_obj.as_none_safe_node(
              "'NoneType' object has no attribute '%{0}s'".format(precision),
              error="PyExc_AttributeError",
              format_args=[method])
      else:
          method_node = ExprNodes.NullNode(dict_obj.pos)
          dict_obj = dict_obj.as_none_safe_node("'NoneType' object is not iterable")

      def flag_node(value):
          # turn a truth value into a C integer literal (0 or 1)
          flag = 1 if value else 0
          return ExprNodes.IntNode(node.pos, value=str(flag), constant_result=flag)

      reset_position = Nodes.SingleAssignmentNode(
          node.pos,
          lhs=pos_temp,
          rhs=ExprNodes.IntNode(node.pos, value='0', constant_result=0))

      setup_iterator = Nodes.SingleAssignmentNode(
          dict_obj.pos,
          lhs=dict_temp,
          rhs=ExprNodes.PythonCapiCallNode(
              dict_obj.pos,
              "__Pyx_dict_iterator",
              self.PyDict_Iterator_func_type,
              utility_code=UtilityCode.load_cached("dict_iter", "Optimize.c"),
              args=[
                  dict_obj,
                  flag_node(dict_obj.type is Builtin.dict_type),
                  method_node,
                  length_addr,
                  is_dict_addr,
              ],
              is_temp=True))

      # no loop condition is set on the while-loop itself
      loop = Nodes.WhileStatNode(
          node.pos,
          condition=None,
          body=body,
          else_clause=node.else_clause)

      return UtilNodes.TempsBlockNode(
          node.pos,
          temps=temps,
          body=Nodes.StatListNode(
              node.pos,
              stats=[reset_position, setup_iterator, loop]))
  # Signature used for calls to __Pyx_dict_iterator():
  #   object (dict, is_dict, method_name, p_orig_length, p_is_dict)
  PyDict_Iterator_func_type = PyrexTypes.CFuncType(
      PyrexTypes.py_object_type,
      [
          PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
          PyrexTypes.CFuncTypeArg("is_dict", PyrexTypes.c_int_type, None),
          PyrexTypes.CFuncTypeArg("method_name", PyrexTypes.py_object_type, None),
          PyrexTypes.CFuncTypeArg("p_orig_length", PyrexTypes.c_py_ssize_t_ptr_type, None),
          PyrexTypes.CFuncTypeArg("p_is_dict", PyrexTypes.c_int_ptr_type, None),
      ])

  # Signature used for calls to __Pyx_set_iterator():
  #   object (set, is_set, p_orig_length, p_is_set)
  PySet_Iterator_func_type = PyrexTypes.CFuncType(
      PyrexTypes.py_object_type,
      [
          PyrexTypes.CFuncTypeArg("set", PyrexTypes.py_object_type, None),
          PyrexTypes.CFuncTypeArg("is_set", PyrexTypes.c_int_type, None),
          PyrexTypes.CFuncTypeArg("p_orig_length", PyrexTypes.c_py_ssize_t_ptr_type, None),
          PyrexTypes.CFuncTypeArg("p_is_set", PyrexTypes.c_int_ptr_type, None),
      ])
  def _transform_set_iteration(self, node, set_obj):
      """Rewrite a for-loop over a set into a while-loop that is driven by
      a SetIterationNextNode.

      node    -- the loop node to replace; its target, body and else_clause
                 are reused
      set_obj -- expression node for the object that is iterated over

      Returns a TempsBlockNode that wraps the generated statements.
      """
      temps = []

      def new_temp(c_type):
          # allocate a temp handle and register it with the temps block
          handle = UtilNodes.TempHandle(c_type)
          temps.append(handle)
          return handle

      set_temp = new_temp(PyrexTypes.py_object_type).ref(set_obj.pos)
      pos_temp = new_temp(PyrexTypes.c_py_ssize_t_type).ref(node.pos)

      body = node.body
      if not isinstance(body, Nodes.StatListNode):
          body = Nodes.StatListNode(pos=body.pos, stats=[body])

      # keep original length to guard against set modification
      length_handle = new_temp(PyrexTypes.c_py_ssize_t_type)
      length_addr = ExprNodes.AmpersandNode(
          node.pos,
          operand=length_handle.ref(set_obj.pos),
          type=PyrexTypes.c_ptr_type(length_handle.type))

      flag_handle = new_temp(PyrexTypes.c_int_type)
      is_set_temp = flag_handle.ref(node.pos)
      is_set_addr = ExprNodes.AmpersandNode(
          node.pos,
          operand=is_set_temp,
          type=PyrexTypes.c_ptr_type(flag_handle.type))

      # the "fetch next item" step becomes the first statement of the loop body
      next_item = Nodes.SetIterationNextNode(
          set_temp, length_handle.ref(set_obj.pos), pos_temp, node.target, is_set_temp)
      next_item = next_item.analyse_expressions(self.current_env())
      body.stats.insert(0, next_item)

      def flag_node(value):
          # turn a truth value into a C integer literal (0 or 1)
          flag = 1 if value else 0
          return ExprNodes.IntNode(node.pos, value=str(flag), constant_result=flag)

      reset_position = Nodes.SingleAssignmentNode(
          node.pos,
          lhs=pos_temp,
          rhs=ExprNodes.IntNode(node.pos, value='0', constant_result=0))

      setup_iterator = Nodes.SingleAssignmentNode(
          set_obj.pos,
          lhs=set_temp,
          rhs=ExprNodes.PythonCapiCallNode(
              set_obj.pos,
              "__Pyx_set_iterator",
              self.PySet_Iterator_func_type,
              utility_code=UtilityCode.load_cached("set_iter", "Optimize.c"),
              args=[
                  set_obj,
                  flag_node(set_obj.type is Builtin.set_type),
                  length_addr,
                  is_set_addr,
              ],
              is_temp=True))

      # no loop condition is set on the while-loop itself
      loop = Nodes.WhileStatNode(
          node.pos,
          condition=None,
          body=body,
          else_clause=node.else_clause)

      return UtilNodes.TempsBlockNode(
          node.pos,
          temps=temps,
          body=Nodes.StatListNode(
              node.pos,
              stats=[reset_position, setup_iterator, loop]))
  class SwitchTransform(Visitor.EnvTransform):
      """
      This transformation tries to turn long if statements into C switch
      statements.  The requirement is that every clause be an (or of)
      var == value, where the var is common among all clauses and both var
      and value are ints (or enums).

      It only does anything while the 'optimize.use_switch' directive is
      enabled.
      """
      # result of the extract_*() methods when a condition cannot be used
      NO_MATCH = (None, None, None)

      def extract_conditions(self, cond, allow_not_in):
          """Decompose 'cond' into a triple (not_in, variable, [values]).

          'not_in' is true for negated tests ('!=' and 'not in'), which are
          only accepted when 'allow_not_in' is set.  Returns NO_MATCH for
          any condition that does not have a supported shape.
          """
          # peel off wrapper nodes until the actual test is reached
          while True:
              if isinstance(cond, (ExprNodes.CoerceToTempNode,
                                   ExprNodes.CoerceToBooleanNode)):
                  cond = cond.arg
              elif isinstance(cond, ExprNodes.BoolBinopResultNode):
                  cond = cond.arg.arg
              elif isinstance(cond, UtilNodes.EvalWithTempExprNode):
                  # this is what we get from the FlattenInListTransform
                  cond = cond.subexpression
              elif isinstance(cond, ExprNodes.TypecastNode):
                  cond = cond.operand
              else:
                  break

          if isinstance(cond, ExprNodes.PrimaryCmpNode):
              return self._extract_from_comparison(cond, allow_not_in)
          if isinstance(cond, ExprNodes.BoolBinopNode):
              return self._extract_from_bool_binop(cond, allow_not_in)
          return self.NO_MATCH

      def _extract_from_comparison(self, cond, allow_not_in):
          """extract_conditions() for a single PrimaryCmpNode."""
          if cond.cascade is not None:
              return self.NO_MATCH

          if cond.is_c_string_contains() and isinstance(
                  cond.operand2, (ExprNodes.UnicodeNode, ExprNodes.BytesNode)):
              # "char in 'literal'": one switch value per distinct character
              not_in = cond.operator == 'not_in'
              if not_in and not allow_not_in:
                  return self.NO_MATCH
              literal = cond.operand2
              if isinstance(literal, ExprNodes.UnicodeNode) and literal.contains_surrogates():
                  # dealing with surrogates leads to different
                  # behaviour on wide and narrow Unicode
                  # platforms => refuse to optimise this case
                  return self.NO_MATCH
              return not_in, cond.operand1, self.extract_in_string_conditions(literal)

          if cond.is_python_comparison():
              return self.NO_MATCH

          if cond.operator == '==':
              not_in = False
          elif allow_not_in and cond.operator == '!=':
              not_in = True
          else:
              return self.NO_MATCH

          # Try "var <op> value" first, then the mirrored "value <op> var".
          # Comparing an operand with itself looks somewhat silly, but it
          # does the right checks for NameNode and AttributeNode.
          operand_pairs = (
              (cond.operand1, cond.operand2),
              (cond.operand2, cond.operand1),
          )
          for variable, value in operand_pairs:
              if not is_common_value(variable, variable):
                  continue
              if value.is_literal:
                  return not_in, variable, [value]
              if getattr(value, 'entry', None) and value.entry.is_const:
                  return not_in, variable, [value]
          return self.NO_MATCH

      def _extract_from_bool_binop(self, cond, allow_not_in):
          """extract_conditions() for a BoolBinopNode ('or' / 'and')."""
          operator = cond.operator
          if not (operator == 'or' or (allow_not_in and operator == 'and')):
              return self.NO_MATCH

          # negated tests are only allowed below an 'and'
          allow_not_in = (operator == 'and')
          not_in_1, var1, values1 = self.extract_conditions(cond.operand1, allow_not_in)
          not_in_2, var2, values2 = self.extract_conditions(cond.operand2, allow_not_in)
          if var1 is not None and not_in_1 == not_in_2 and is_common_value(var1, var2):
              if (not not_in_1) or allow_not_in:
                  return not_in_1, var1, values1 + values2
          return self.NO_MATCH

      def extract_in_string_conditions(self, string_literal):
          """Return one literal node per distinct character of the string
          literal, in sorted order: IntNodes holding the code points for a
          UnicodeNode, CharNodes otherwise.
          """
          pos = string_literal.pos
          if isinstance(string_literal, ExprNodes.UnicodeNode):
              code_points = sorted(map(ord, set(string_literal.value)))
              return [
                  ExprNodes.IntNode(pos, value=str(code_point), constant_result=code_point)
                  for code_point in code_points
              ]

          # this is a bit tricky as Py3's bytes type returns
          # integers on iteration, whereas Py2 returns 1-char byte
          # strings => slice out 1-char strings explicitly
          raw = string_literal.value
          characters = sorted(set([raw[i:i+1] for i in range(len(raw))]))
          return [
              ExprNodes.CharNode(pos, value=character, constant_result=character)
              for character in characters
          ]

      def extract_common_conditions(self, common_var, condition, allow_not_in):
          """Like extract_conditions(), but additionally requires that the
          tested variable matches 'common_var' (unless that is None) and
          that the variable and all values have an int or enum type.
          """
          not_in, var, conditions = self.extract_conditions(condition, allow_not_in)
          if var is None:
              return self.NO_MATCH
          if common_var is not None and not is_common_value(var, common_var):
              return self.NO_MATCH
          if not (var.type.is_int or var.type.is_enum):
              return self.NO_MATCH
          unsupported = [
              cond for cond in conditions
              if not (cond.type.is_int or cond.type.is_enum)
          ]
          if unsupported:
              return self.NO_MATCH
          return not_in, var, conditions

      def has_duplicate_values(self, condition_values):
          """Return True if two of the given value nodes may be equal."""
          # duplicated values don't work in a switch statement
          seen = set()
          for value in condition_values:
              if value.has_constant_result():
                  key = value.constant_result
              else:
                  # this isn't completely safe as we don't know the
                  # final C value, but this is about the best we can do
                  try:
                      key = value.entry.cname
                  except AttributeError:
                      return True  # play safe
              if key in seen:
                  return True
              seen.add(key)
          return False

      def _visit_unchanged(self, node):
          """Process the children of 'node' and keep the node itself."""
          self.visitchildren(node)
          return node

      def _switchable_conditions(self, test):
          """Return (not_in, var, conditions) if the expression 'test' can
          be turned into a switch statement, None otherwise.
          """
          not_in, common_var, conditions = self.extract_common_conditions(
              None, test, True)
          if common_var is None:
              return None
          if len(conditions) < 2 or self.has_duplicate_values(conditions):
              return None
          return not_in, common_var, conditions

      def visit_IfStatNode(self, node):
          if not self.current_directives.get('optimize.use_switch'):
              return self._visit_unchanged(node)

          common_var = None
          cases = []
          for clause in node.if_clauses:
              _, common_var, conditions = self.extract_common_conditions(
                  common_var, clause.condition, False)
              if common_var is None:
                  return self._visit_unchanged(node)
              cases.append(Nodes.SwitchCaseNode(
                  pos=clause.pos,
                  conditions=conditions,
                  body=clause.body))

          condition_values = []
          for case in cases:
              condition_values.extend(case.conditions)
          if len(condition_values) < 2 or self.has_duplicate_values(condition_values):
              return self._visit_unchanged(node)

          # Recurse into body subtrees that we left untouched so far.
          self.visitchildren(node, 'else_clause')
          for case in cases:
              self.visitchildren(case, 'body')

          return Nodes.SwitchStatNode(
              pos=node.pos,
              test=unwrap_node(common_var),
              cases=cases,
              else_clause=node.else_clause)

      def visit_CondExprNode(self, node):
          if not self.current_directives.get('optimize.use_switch'):
              return self._visit_unchanged(node)

          match = self._switchable_conditions(node.test)
          if match is None:
              return self._visit_unchanged(node)

          not_in, common_var, conditions = match
          return self.build_simple_switch_statement(
              node, common_var, conditions, not_in,
              node.true_val, node.false_val)

      def visit_BoolBinopNode(self, node):
          if not self.current_directives.get('optimize.use_switch'):
              return self._visit_unchanged(node)

          match = self._switchable_conditions(node)
          if match is None:
              self.visitchildren(node)
              node.wrap_operands(self.current_env())  # in case we changed the operands
              return node

          not_in, common_var, conditions = match
          return self.build_simple_switch_statement(
              node, common_var, conditions, not_in,
              ExprNodes.BoolNode(node.pos, value=True, constant_result=True),
              ExprNodes.BoolNode(node.pos, value=False, constant_result=False))

      def visit_PrimaryCmpNode(self, node):
          if not self.current_directives.get('optimize.use_switch'):
              return self._visit_unchanged(node)

          match = self._switchable_conditions(node)
          if match is None:
              return self._visit_unchanged(node)

          not_in, common_var, conditions = match
          return self.build_simple_switch_statement(
              node, common_var, conditions, not_in,
              ExprNodes.BoolNode(node.pos, value=True, constant_result=True),
              ExprNodes.BoolNode(node.pos, value=False, constant_result=False))

      def build_simple_switch_statement(self, node, common_var, conditions,
                                        not_in, true_val, false_val):
          """Build an expression that switches on 'common_var' and evaluates
          to 'true_val' if one of 'conditions' matches and to 'false_val'
          otherwise (the other way round if 'not_in' is set).  Both values
          are coerced to the type of 'node'.
          """
          result_ref = UtilNodes.ResultRefNode(node)

          def assign_result(value):
              return Nodes.SingleAssignmentNode(
                  node.pos,
                  lhs=result_ref,
                  rhs=value.coerce_to(node.type, self.current_env()),
                  first=True)

          matched_body = assign_result(true_val)
          default_body = assign_result(false_val)
          if not_in:
              matched_body, default_body = default_body, matched_body

          switch_case = Nodes.SwitchCaseNode(
              pos=node.pos,
              conditions=conditions,
              body=matched_body)
          switch_node = Nodes.SwitchStatNode(
              pos=node.pos,
              test=unwrap_node(common_var),
              cases=[switch_case],
              else_clause=default_body)
          return UtilNodes.TempResultFromStatNode(result_ref, switch_node)

      def visit_EvalWithTempExprNode(self, node):
          if not self.current_directives.get('optimize.use_switch'):
              return self._visit_unchanged(node)

          # drop unused expression temp from FlattenInListTransform
          original_expr = node.subexpression
          temp_ref = node.lazy_temp
          self.visitchildren(node)
          restructured = node.subexpression is not original_expr
          # node was restructured => check if temp is still used
          if restructured and not Visitor.tree_contains(node.subexpression, temp_ref):
              return node.subexpression
          return node

      visit_Node = Visitor.VisitorTransform.recurse_to_children
  class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations):
      """
      This transformation flattens "x in [val1, ..., valn]" (and the
      "not in" variant, also for tuple and set literals) into a sequential
      list of comparisons.
      """

      def visit_PrimaryCmpNode(self, node):
          self.visitchildren(node)
          if node.cascade is not None:
              return node

          # 'in' becomes an or-chain of '==' tests,
          # 'not in' an and-chain of '!=' tests
          rewrite_rules = {
              'in': ('or', '=='),
              'not_in': ('and', '!='),
          }
          rule = rewrite_rules.get(node.operator)
          if rule is None:
              return node
          conjunction, eq_or_neq = rule

          container = node.operand2
          literal_types = (ExprNodes.TupleNode, ExprNodes.ListNode, ExprNodes.SetNode)
          if not isinstance(container, literal_types):
              return node

          args = container.args
          if not args:
              # note: lhs may have side effects
              return node

          lhs = UtilNodes.ResultRefNode(node.operand1)

          comparisons = []
          temps = []
          for arg in args:
              try:
                  # Trial optimisation to avoid redundant temp
                  # assignments.  However, since is_simple() is meant to
                  # be called after type analysis, we ignore any errors
                  # and just play safe in that case.
                  needs_temp = not arg.is_simple()
              except Exception:
                  needs_temp = True
              if needs_temp:
                  # must evaluate all non-simple RHS before doing the comparisons
                  arg = UtilNodes.LetRefNode(arg)
                  temps.append(arg)
              comparison = ExprNodes.PrimaryCmpNode(
                  pos=node.pos,
                  operand1=lhs,
                  operator=eq_or_neq,
                  operand2=arg,
                  cascade=None)
              comparisons.append(ExprNodes.TypecastNode(
                  pos=node.pos,
                  operand=comparison,
                  type=PyrexTypes.c_bint_type))

          # left fold: ((c1 <op> c2) <op> c3) ...
          condition = comparisons[0]
          for comparison in comparisons[1:]:
              condition = ExprNodes.BoolBinopNode(
                  pos=node.pos,
                  operator=conjunction,
                  operand1=condition,
                  operand2=comparison)

          # wrap the temps around the condition, the first temp outermost
          new_node = UtilNodes.EvalWithTempExprNode(lhs, condition)
          for temp in reversed(temps):
              new_node = UtilNodes.EvalWithTempExprNode(temp, new_node)
          return new_node

      visit_Node = Visitor.VisitorTransform.recurse_to_children
  class DropRefcountingTransform(Visitor.VisitorTransform):
      """Drop ref-counting in safe places.
      """
      visit_Node = Visitor.VisitorTransform.recurse_to_children

      def visit_ParallelAssignmentNode(self, node):
          """
          Parallel swap assignments like 'a,b = b,a' are safe.

          Clears 'use_managed_ref' on the operands if the names on both
          sides are permutations of each other.  The (same) node is
          returned in all cases.
          """
          left_names, right_names = [], []
          left_indices, right_indices = [], []
          temps = []

          for stat in node.stats:
              if not isinstance(stat, Nodes.SingleAssignmentNode):
                  # FIXME: CascadedAssignmentNode is not handled either
                  return node
              if not self._extract_operand(stat.lhs, left_names,
                                           left_indices, temps):
                  return node
              if not self._extract_operand(stat.rhs, right_names,
                                           right_indices, temps):
                  return node

          if left_names or right_names:
              # lhs/rhs names must be a non-redundant permutation
              left_paths = set([path for path, _ in left_names])
              right_paths = set([path for path, _ in right_names])
              if left_paths != right_paths:
                  return node
              if len(left_paths) != len(right_names):
                  return node

          if left_indices or right_indices:
              # base name and index of index nodes must be a
              # non-redundant permutation
              left_ids = []
              for lhs_node in left_indices:
                  index_id = self._extract_index_id(lhs_node)
                  if not index_id:
                      return node
                  left_ids.append(index_id)
              right_ids = []
              for rhs_node in right_indices:
                  index_id = self._extract_index_id(rhs_node)
                  if not index_id:
                      return node
                  right_ids.append(index_id)

              if set(left_ids) != set(right_ids):
                  return node
              if len(set(left_ids)) != len(right_indices):
                  return node

              # really supporting IndexNode requires support in
              # __Pyx_GetItemInt(), so let's stop short for now
              return node

          temp_args = [temp.arg for temp in temps]
          for temp in temps:
              temp.use_managed_ref = False

          for _, name_node in left_names + right_names:
              # nodes wrapped in a temp are left alone, only the temp is changed
              if name_node not in temp_args:
                  name_node.use_managed_ref = False

          for index_node in left_indices + right_indices:
              index_node.use_managed_ref = False

          return node

      def _extract_operand(self, node, names, indices, temps):
          """Classify one assignment operand.

          Appends a (dotted path, node) pair to 'names' for (attributes of)
          plain names, the node itself to 'indices' for integer subscripts
          of a list-typed name, and CoerceToTempNode wrappers to 'temps'.
          Returns False for unsupported operands, True otherwise.
          """
          node = unwrap_node(node)
          if not node.type.is_pyobject:
              return False
          if isinstance(node, ExprNodes.CoerceToTempNode):
              temps.append(node)
              node = node.arg

          # walk down the attribute chain, innermost object last
          reversed_path = []
          obj_node = node
          while obj_node.is_attribute:
              if obj_node.is_py_attr:
                  return False
              reversed_path.append(obj_node.member)
              obj_node = obj_node.obj

          if obj_node.is_name:
              reversed_path.append(obj_node.name)
              names.append(('.'.join(reversed(reversed_path)), node))
              return True

          if not node.is_subscript:
              return False
          if node.base.type != Builtin.list_type:
              return False
          if not node.index.type.is_int:
              return False
          if not node.base.is_name:
              return False
          indices.append(node)
          return True

      def _extract_index_id(self, index_node):
          """Return (base name, index name) for subscripts that are indexed
          by a plain name, None for everything else.
          """
          base = index_node.base
          index = index_node.index
          if not isinstance(index, ExprNodes.NameNode):
              # FIXME: constant indices (ConstNode) are not supported
              return None
          return (base.name, index.name)
  1329. class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
  1330. """Optimize some common calls to builtin types *before* the type
  1331. analysis phase and *after* the declarations analysis phase.
  1332. This transform cannot make use of any argument types, but it can
  1333. restructure the tree in a way that the type analysis phase can
  1334. respond to.
  1335. Introducing C function calls here may not be a good idea. Move
  1336. them to the OptimizeBuiltinCalls transform instead, which runs
  1337. after type analysis.
  1338. """
  1339. # only intercept on call nodes
  1340. visit_Node = Visitor.VisitorTransform.recurse_to_children
  def visit_SimpleCallNode(self, node):
      """Hand calls of builtin names without keyword arguments over to the
      matching handler method; other calls are returned unchanged.
      """
      self.visitchildren(node)
      function = node.function
      if self._function_is_builtin_name(function):
          return self._dispatch_to_handler(node, function, node.args)
      return node
  def visit_GeneralCallNode(self, node):
      """Hand calls of builtin names with keyword arguments over to the
      matching handler method.

      The call is returned unchanged if the function is not a builtin name
      or if the positional arguments are not given as a TupleNode.
      """
      self.visitchildren(node)
      function = node.function
      if not self._function_is_builtin_name(function):
          return node

      positional = node.positional_args
      if isinstance(positional, ExprNodes.TupleNode):
          return self._dispatch_to_handler(
              node, function, positional.args, node.keyword_args)
      return node
  def _function_is_builtin_name(self, function):
      """Return True if 'function' is a name node whose entry in the
      current scope is the very entry that the builtin scope holds for it.
      """
      if not function.is_name:
          return False
      env = self.current_env()
      name = function.name
      visible_entry = env.lookup(name)
      builtin_entry = env.builtin_scope().lookup_here(name)
      # if both are None, it's at least an undeclared name, so likely builtin
      return visible_entry is builtin_entry
  def _dispatch_to_handler(self, node, function, args, kwargs=None):
      """Call the handler method for the builtin named 'function.name'.

      '_handle_simple_function_<name>(node, args)' is looked up if 'kwargs'
      is None, '_handle_general_function_<name>(node, args, kwargs)'
      otherwise.  Returns the handler's result, or 'node' if there is no
      such handler.
      """
      is_simple_call = kwargs is None
      if is_simple_call:
          handler_name = '_handle_simple_function_%s' % function.name
      else:
          handler_name = '_handle_general_function_%s' % function.name

      handler = getattr(self, handler_name, None)
      if handler is None:
          return node
      if is_simple_call:
          return handler(node, args)
      return handler(node, args, kwargs)
  def _inject_capi_function(self, node, cname, func_type, utility_code=None):
      """Replace 'node.function' in place by a PythonCapiFunctionNode that
      keeps the position and name of the old function node and refers to
      the C function 'cname' of type 'func_type'.
      """
      old_function = node.function
      node.function = ExprNodes.PythonCapiFunctionNode(
          old_function.pos, old_function.name, cname, func_type,
          utility_code=utility_code)
  def _error_wrong_arg_count(self, function_name, node, args, expected=None):
      """Report a call of 'function_name' with a wrong number of arguments
      as an error at the position of 'node'.

      'expected' is the expected argument count (a number or a string) and
      appears in the message unless it is None; 'args' is only used for
      its length.
      """
      # sketch of the expected argument list for the message
      if not expected:  # None or 0
          arg_str = ''
      elif isinstance(expected, basestring) or expected > 1:
          arg_str = '...'
      else:
          arg_str = 'x' if expected == 1 else ''

      expected_str = '' if expected is None else 'expected %s, ' % expected

      message = "%s(%s) called with wrong number of args, %sfound %d" % (
          function_name, arg_str, expected_str, len(args))
      error(node.pos, message)
  1398. # specific handlers for simple call nodes
  def _handle_simple_function_float(self, node, pos_args):
      """Simplify trivial float() calls.

      float() without arguments becomes the literal 0.0, and float(x) is
      replaced by x itself if x is already known to be a C double or a
      Python float.  A call with more than one argument is reported as an
      error and left unchanged, all other calls are returned as they are.
      """
      if not pos_args:
          return ExprNodes.FloatNode(node.pos, value='0.0')
      if len(pos_args) > 1:
          self._error_wrong_arg_count('float', node, pos_args, 1)
          # Do not optimise an invalid call: returning pos_args[0] from
          # here would silently drop the surplus arguments.
          return node
      arg = pos_args[0]
      arg_type = getattr(arg, 'type', None)
      if arg_type in (PyrexTypes.c_double_type, Builtin.float_type):
          return arg
      return node
  def _handle_simple_function_slice(self, node, pos_args):
      """Replace slice(stop), slice(start, stop) and slice(start, stop, step)
      by a SliceNode.  Missing start/step values become NoneNodes.  Any
      other argument count is reported as an error and the call is
      returned unchanged.
      """
      arg_count = len(pos_args)
      if arg_count < 1 or arg_count > 3:
          self._error_wrong_arg_count('slice', node, pos_args)
          return node

      start = step = None
      if arg_count == 1:
          stop = pos_args[0]
      elif arg_count == 2:
          start, stop = pos_args
      else:
          start, stop, step = pos_args

      return ExprNodes.SliceNode(
          node.pos,
          start=start or ExprNodes.NoneNode(node.pos),
          stop=stop,
          step=step or ExprNodes.NoneNode(node.pos))
  def _handle_simple_function_ord(self, node, pos_args):
      """Unpack ord('X').

      A call on a one-character unicode or bytes literal becomes a C long
      integer literal.  A one-character plain string literal becomes a C
      int literal if its character value is at most 255.  All other calls
      are returned unchanged.
      """
      if len(pos_args) != 1:
          return node

      arg = pos_args[0]
      character = None
      if isinstance(arg, (ExprNodes.UnicodeNode, ExprNodes.BytesNode)):
          if len(arg.value) == 1:
              character = arg.value
              int_type = PyrexTypes.c_long_type
      elif isinstance(arg, ExprNodes.StringNode):
          text = arg.unicode_value
          if text and len(text) == 1 and ord(text) <= 255:  # Py2/3 portability
              character = text
              int_type = PyrexTypes.c_int_type

      if character is None:
          return node

      char_value = ord(character)
      return ExprNodes.IntNode(
          arg.pos, type=int_type,
          value=str(char_value),
          constant_result=char_value)
  1447. # sequence processing
  def _handle_simple_function_all(self, node, pos_args):
      """Inline all(genexpr) as a loop.  That is, turn

          _result = all(p(x) for L in LL for x in L)

      into

          for L in LL:
              for x in L:
                  if not p(x):
                      return False
          else:
              return True
      """
      return self._transform_any_all(node, pos_args, is_any=False)
  def _handle_simple_function_any(self, node, pos_args):
      """Inline any(genexpr) as a loop.  That is, turn

          _result = any(p(x) for L in LL for x in L)

      into

          for L in LL:
              for x in L:
                  if p(x):
                      return True
          else:
              return False
      """
      return self._transform_any_all(node, pos_args, is_any=True)
  def _transform_any_all(self, node, pos_args, is_any):
      """Shared implementation of the any()/all() handlers.

      Only calls with a single generator expression argument that has a
      single yield expression are transformed.  The yield statement is
      replaced by a conditional return of 'is_any' and the loop gets an
      else clause that returns 'not is_any'.  The result is wrapped in an
      InlinedGeneratorExpressionNode; unsupported calls are returned
      unchanged.
      """
      if len(pos_args) != 1:
          return node
      gen_expr_node = pos_args[0]
      if not isinstance(gen_expr_node, ExprNodes.GeneratorExpressionNode):
          return node

      loop_node = gen_expr_node.def_node.gbody.body
      yield_expression, yield_stat_node = _find_single_yield_expression(loop_node)
      if yield_expression is None:
          return node

      expr_pos = yield_expression.pos
      if is_any:
          condition = yield_expression
      else:
          # all() stops at the first false item
          condition = ExprNodes.NotNode(expr_pos, operand=yield_expression)

      early_return = Nodes.ReturnStatNode(
          node.pos,
          value=ExprNodes.BoolNode(expr_pos, value=is_any, constant_result=is_any))
      test_node = Nodes.IfStatNode(
          expr_pos,
          else_clause=None,
          if_clauses=[
              Nodes.IfClauseNode(expr_pos, condition=condition, body=early_return),
          ])

      # result if the loop runs to its end
      loop_node.else_clause = Nodes.ReturnStatNode(
          node.pos,
          value=ExprNodes.BoolNode(expr_pos, value=not is_any, constant_result=not is_any))

      Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, test_node)

      return ExprNodes.InlinedGeneratorExpressionNode(
          gen_expr_node.pos, gen=gen_expr_node,
          orig_func='any' if is_any else 'all')
  # Signature used for calls to PySequence_List(): list (object it)
  PySequence_List_func_type = PyrexTypes.CFuncType(
      Builtin.list_type,
      [
          PyrexTypes.CFuncTypeArg("it", PyrexTypes.py_object_type, None),
      ])
  def _handle_simple_function_sorted(self, node, pos_args):
      """Transform sorted(genexpr) and sorted([listcomp]) into
      [listcomp].sort().  CPython just reads the iterable into a
      list and calls .sort() on it.  Expanding the iterable in a
      listcomp is still faster and the result can be sorted in
      place.

      Only calls with exactly one argument are handled.  Sequence
      constructors are turned into a list first and any other argument
      is passed through PySequence_List().
      """
      if len(pos_args) != 1:
          return node

      arg = pos_args[0]
      if isinstance(arg, ExprNodes.ComprehensionNode) and arg.type is Builtin.list_type:
          list_node = arg
          pos_source = list_node.loop

      elif isinstance(arg, ExprNodes.GeneratorExpressionNode):
          gen_expr_node = arg
          pos_source = gen_expr_node.loop
          yield_statements = _find_yield_statements(pos_source)
          if not yield_statements:
              return node

          list_node = ExprNodes.InlinedGeneratorExpressionNode(
              node.pos, gen_expr_node, orig_func='sorted',
              comprehension_type=Builtin.list_type)

          # every yield becomes an append to the result list
          for yield_expression, yield_stat_node in yield_statements:
              append_node = ExprNodes.ComprehensionAppendNode(
                  yield_expression.pos,
                  expr=yield_expression,
                  target=list_node.target)
              Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node)

      elif arg.is_sequence_constructor:
          # sorted([a, b, c]) or sorted((a, b, c)).  The result is always a list,
          # so starting off with a fresh one is more efficient.
          list_node = pos_source = arg.as_list()

      else:
          # Interestingly, PySequence_List works on a lot of non-sequence
          # things as well.
          list_node = pos_source = ExprNodes.PythonCapiCallNode(
              node.pos, "PySequence_List", self.PySequence_List_func_type,
              args=pos_args, is_temp=True)

      result_node = UtilNodes.ResultRefNode(
          pos=pos_source.pos, type=Builtin.list_type, may_hold_none=False)
      assign_list = Nodes.SingleAssignmentNode(
          node.pos, lhs=result_node, rhs=list_node, first=True)

      sort_method = ExprNodes.AttributeNode(
          node.pos, obj=result_node, attribute=EncodedString('sort'),
          # entry ? type ?
          needs_none_check=False)
      sort_call = ExprNodes.SimpleCallNode(
          node.pos, function=sort_method, args=[])
      sort_stat = Nodes.ExprStatNode(node.pos, expr=sort_call)
      sort_stat.analyse_declarations(self.current_env())

      return UtilNodes.TempResultFromStatNode(
          result_node,
          Nodes.StatListNode(node.pos, stats=[assign_list, sort_stat]))
  def __handle_simple_function_sum(self, node, pos_args):
      """Transform sum(genexpr) into an equivalent inlined aggregation loop.

      Generator expression arguments are currently always left alone (see
      the FIXME below).  For comprehension arguments, only those that
      append an integer literal are rewritten.
      """
      if len(pos_args) not in (1, 2):
          return node
      gen_expr_node = pos_args[0]
      if not isinstance(gen_expr_node, (ExprNodes.GeneratorExpressionNode,
                                        ExprNodes.ComprehensionNode)):
          return node
      loop_node = gen_expr_node.loop

      if isinstance(gen_expr_node, ExprNodes.GeneratorExpressionNode):
          _find_single_yield_expression(loop_node)
          # FIXME: currently nonfunctional
          return node

      # ComprehensionNode
      yield_stat_node = gen_expr_node.append
      yield_expression = yield_stat_node.expr
      try:
          is_int_literal = yield_expression.is_literal and yield_expression.type.is_int
      except AttributeError:
          return node  # in case we don't have a type yet
      if not is_int_literal:
          return node
      # special case: old Py2 backwards compatible "sum([int_const for ...])"
      # can safely be unpacked into a genexpr

      if len(pos_args) == 1:
          start = ExprNodes.IntNode(node.pos, value='0', constant_result=0)
      else:
          start = pos_args[1]

      result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.py_object_type)

      # the append is replaced by "result = result + item"
      add_node = Nodes.SingleAssignmentNode(
          yield_expression.pos,
          lhs=result_ref,
          rhs=ExprNodes.binop_node(node.pos, '+', result_ref, yield_expression))
      Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, add_node)

      init_result = Nodes.SingleAssignmentNode(
          start.pos,
          lhs=UtilNodes.ResultRefNode(pos=node.pos, expression=result_ref),
          rhs=start,
          first=True)
      exec_code = Nodes.StatListNode(node.pos, stats=[init_result, loop_node])

      return ExprNodes.InlinedGeneratorExpressionNode(
          gen_expr_node.pos, loop=exec_code, result_node=result_ref,
          expr_scope=gen_expr_node.expr_scope, orig_func='sum',
          has_local_scope=gen_expr_node.has_local_scope)
  1610. def _handle_simple_function_min(self, node, pos_args):
  1611. return self._optimise_min_max(node, pos_args, '<')
  1612. def _handle_simple_function_max(self, node, pos_args):
  1613. return self._optimise_min_max(node, pos_args, '>')
  1614. def _optimise_min_max(self, node, args, operator):
  1615. """Replace min(a,b,...) and max(a,b,...) by explicit comparison code.
  1616. """
  1617. if len(args) <= 1:
  1618. if len(args) == 1 and args[0].is_sequence_constructor:
  1619. args = args[0].args
  1620. if len(args) <= 1:
  1621. # leave this to Python
  1622. return node
  1623. cascaded_nodes = list(map(UtilNodes.ResultRefNode, args[1:]))
  1624. last_result = args[0]
  1625. for arg_node in cascaded_nodes:
  1626. result_ref = UtilNodes.ResultRefNode(last_result)
  1627. last_result = ExprNodes.CondExprNode(
  1628. arg_node.pos,
  1629. true_val = arg_node,
  1630. false_val = result_ref,
  1631. test = ExprNodes.PrimaryCmpNode(
  1632. arg_node.pos,
  1633. operand1 = arg_node,
  1634. operator = operator,
  1635. operand2 = result_ref,
  1636. )
  1637. )
  1638. last_result = UtilNodes.EvalWithTempExprNode(result_ref, last_result)
  1639. for ref_node in cascaded_nodes[::-1]:
  1640. last_result = UtilNodes.EvalWithTempExprNode(ref_node, last_result)
  1641. return last_result
  1642. # builtin type creation
  1643. def _DISABLED_handle_simple_function_tuple(self, node, pos_args):
  1644. if not pos_args:
  1645. return ExprNodes.TupleNode(node.pos, args=[], constant_result=())
  1646. # This is a bit special - for iterables (including genexps),
  1647. # Python actually overallocates and resizes a newly created
  1648. # tuple incrementally while reading items, which we can't
  1649. # easily do without explicit node support. Instead, we read
  1650. # the items into a list and then copy them into a tuple of the
  1651. # final size. This takes up to twice as much memory, but will
  1652. # have to do until we have real support for genexps.
  1653. result = self._transform_list_set_genexpr(node, pos_args, Builtin.list_type)
  1654. if result is not node:
  1655. return ExprNodes.AsTupleNode(node.pos, arg=result)
  1656. return node
  1657. def _handle_simple_function_frozenset(self, node, pos_args):
  1658. """Replace frozenset([...]) by frozenset((...)) as tuples are more efficient.
  1659. """
  1660. if len(pos_args) != 1:
  1661. return node
  1662. if pos_args[0].is_sequence_constructor and not pos_args[0].args:
  1663. del pos_args[0]
  1664. elif isinstance(pos_args[0], ExprNodes.ListNode):
  1665. pos_args[0] = pos_args[0].as_tuple()
  1666. return node
  1667. def _handle_simple_function_list(self, node, pos_args):
  1668. if not pos_args:
  1669. return ExprNodes.ListNode(node.pos, args=[], constant_result=[])
  1670. return self._transform_list_set_genexpr(node, pos_args, Builtin.list_type)
  1671. def _handle_simple_function_set(self, node, pos_args):
  1672. if not pos_args:
  1673. return ExprNodes.SetNode(node.pos, args=[], constant_result=set())
  1674. return self._transform_list_set_genexpr(node, pos_args, Builtin.set_type)
  1675. def _transform_list_set_genexpr(self, node, pos_args, target_type):
  1676. """Replace set(genexpr) and list(genexpr) by an inlined comprehension.
  1677. """
  1678. if len(pos_args) > 1:
  1679. return node
  1680. if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
  1681. return node
  1682. gen_expr_node = pos_args[0]
  1683. loop_node = gen_expr_node.loop
  1684. yield_statements = _find_yield_statements(loop_node)
  1685. if not yield_statements:
  1686. return node
  1687. result_node = ExprNodes.InlinedGeneratorExpressionNode(
  1688. node.pos, gen_expr_node,
  1689. orig_func='set' if target_type is Builtin.set_type else 'list',
  1690. comprehension_type=target_type)
  1691. for yield_expression, yield_stat_node in yield_statements:
  1692. append_node = ExprNodes.ComprehensionAppendNode(
  1693. yield_expression.pos,
  1694. expr=yield_expression,
  1695. target=result_node.target)
  1696. Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node)
  1697. return result_node
  1698. def _handle_simple_function_dict(self, node, pos_args):
  1699. """Replace dict( (a,b) for ... ) by an inlined { a:b for ... }
  1700. """
  1701. if len(pos_args) == 0:
  1702. return ExprNodes.DictNode(node.pos, key_value_pairs=[], constant_result={})
  1703. if len(pos_args) > 1:
  1704. return node
  1705. if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
  1706. return node
  1707. gen_expr_node = pos_args[0]
  1708. loop_node = gen_expr_node.loop
  1709. yield_statements = _find_yield_statements(loop_node)
  1710. if not yield_statements:
  1711. return node
  1712. for yield_expression, _ in yield_statements:
  1713. if not isinstance(yield_expression, ExprNodes.TupleNode):
  1714. return node
  1715. if len(yield_expression.args) != 2:
  1716. return node
  1717. result_node = ExprNodes.InlinedGeneratorExpressionNode(
  1718. node.pos, gen_expr_node, orig_func='dict',
  1719. comprehension_type=Builtin.dict_type)
  1720. for yield_expression, yield_stat_node in yield_statements:
  1721. append_node = ExprNodes.DictComprehensionAppendNode(
  1722. yield_expression.pos,
  1723. key_expr=yield_expression.args[0],
  1724. value_expr=yield_expression.args[1],
  1725. target=result_node.target)
  1726. Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node)
  1727. return result_node
  1728. # specific handlers for general call nodes
  1729. def _handle_general_function_dict(self, node, pos_args, kwargs):
  1730. """Replace dict(a=b,c=d,...) by the underlying keyword dict
  1731. construction which is done anyway.
  1732. """
  1733. if len(pos_args) > 0:
  1734. return node
  1735. if not isinstance(kwargs, ExprNodes.DictNode):
  1736. return node
  1737. return kwargs
  1738. class InlineDefNodeCalls(Visitor.NodeRefCleanupMixin, Visitor.EnvTransform):
  1739. visit_Node = Visitor.VisitorTransform.recurse_to_children
  1740. def get_constant_value_node(self, name_node):
  1741. if name_node.cf_state is None:
  1742. return None
  1743. if name_node.cf_state.cf_is_null:
  1744. return None
  1745. entry = self.current_env().lookup(name_node.name)
  1746. if not entry or (not entry.cf_assignments
  1747. or len(entry.cf_assignments) != 1):
  1748. # not just a single assignment in all closures
  1749. return None
  1750. return entry.cf_assignments[0].rhs
  1751. def visit_SimpleCallNode(self, node):
  1752. self.visitchildren(node)
  1753. if not self.current_directives.get('optimize.inline_defnode_calls'):
  1754. return node
  1755. function_name = node.function
  1756. if not function_name.is_name:
  1757. return node
  1758. function = self.get_constant_value_node(function_name)
  1759. if not isinstance(function, ExprNodes.PyCFunctionNode):
  1760. return node
  1761. inlined = ExprNodes.InlinedDefNodeCallNode(
  1762. node.pos, function_name=function_name,
  1763. function=function, args=node.args)
  1764. if inlined.can_be_inlined():
  1765. return self.replace(node, inlined)
  1766. return node
  1767. class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin,
  1768. Visitor.MethodDispatcherTransform):
  1769. """Optimize some common methods calls and instantiation patterns
  1770. for builtin types *after* the type analysis phase.
  1771. Running after type analysis, this transform can only perform
  1772. function replacements that do not alter the function return type
  1773. in a way that was not anticipated by the type analysis.
  1774. """
  1775. ### cleanup to avoid redundant coercions to/from Python types
  1776. def visit_PyTypeTestNode(self, node):
  1777. """Flatten redundant type checks after tree changes.
  1778. """
  1779. self.visitchildren(node)
  1780. return node.reanalyse()
  1781. def _visit_TypecastNode(self, node):
  1782. # disabled - the user may have had a reason to put a type
  1783. # cast, even if it looks redundant to Cython
  1784. """
  1785. Drop redundant type casts.
  1786. """
  1787. self.visitchildren(node)
  1788. if node.type == node.operand.type:
  1789. return node.operand
  1790. return node
  1791. def visit_ExprStatNode(self, node):
  1792. """
  1793. Drop dead code and useless coercions.
  1794. """
  1795. self.visitchildren(node)
  1796. if isinstance(node.expr, ExprNodes.CoerceToPyTypeNode):
  1797. node.expr = node.expr.arg
  1798. expr = node.expr
  1799. if expr is None or expr.is_none or expr.is_literal:
  1800. # Expression was removed or is dead code => remove ExprStatNode as well.
  1801. return None
  1802. if expr.is_name and expr.entry and (expr.entry.is_local or expr.entry.is_arg):
  1803. # Ignore dead references to local variables etc.
  1804. return None
  1805. return node
  1806. def visit_CoerceToBooleanNode(self, node):
  1807. """Drop redundant conversion nodes after tree changes.
  1808. """
  1809. self.visitchildren(node)
  1810. arg = node.arg
  1811. if isinstance(arg, ExprNodes.PyTypeTestNode):
  1812. arg = arg.arg
  1813. if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
  1814. if arg.type in (PyrexTypes.py_object_type, Builtin.bool_type):
  1815. return arg.arg.coerce_to_boolean(self.current_env())
  1816. return node
# C function signature used for the "__Pyx_PyNumber_Float" call that
# visit_CoerceToPyTypeNode() generates: one Python object argument "o",
# Python object result.
PyNumber_Float_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type, [
        PyrexTypes.CFuncTypeArg("o", PyrexTypes.py_object_type, None)
        ])
  1821. def visit_CoerceToPyTypeNode(self, node):
  1822. """Drop redundant conversion nodes after tree changes."""
  1823. self.visitchildren(node)
  1824. arg = node.arg
  1825. if isinstance(arg, ExprNodes.CoerceFromPyTypeNode):
  1826. arg = arg.arg
  1827. if isinstance(arg, ExprNodes.PythonCapiCallNode):
  1828. if arg.function.name == 'float' and len(arg.args) == 1:
  1829. # undo redundant Py->C->Py coercion
  1830. func_arg = arg.args[0]
  1831. if func_arg.type is Builtin.float_type:
  1832. return func_arg.as_none_safe_node("float() argument must be a string or a number, not 'NoneType'")
  1833. elif func_arg.type.is_pyobject:
  1834. return ExprNodes.PythonCapiCallNode(
  1835. node.pos, '__Pyx_PyNumber_Float', self.PyNumber_Float_func_type,
  1836. args=[func_arg],
  1837. py_name='float',
  1838. is_temp=node.is_temp,
  1839. result_is_used=node.result_is_used,
  1840. ).coerce_to(node.type, self.current_env())
  1841. return node
def visit_CoerceFromPyTypeNode(self, node):
    """Drop redundant conversion nodes after tree changes.

    Also, optimise away calls to Python's builtin int() and
    float() if the result is going to be coerced back into a C
    type anyway.

    Returns a replacement node where one of the cases below applies and
    ``node`` itself otherwise.
    """
    self.visitchildren(node)
    arg = node.arg
    if not arg.type.is_pyobject:
        # no Python conversion left at all, just do a C coercion instead
        if node.type != arg.type:
            arg = arg.coerce_to(node.type, self.current_env())
        return arg
    # look through a type test when matching the cases below
    if isinstance(arg, ExprNodes.PyTypeTestNode):
        arg = arg.arg
    if arg.is_literal:
        # int/bool literal -> C integer, float literal -> C float:
        # coerce the literal itself
        if (node.type.is_int and isinstance(arg, ExprNodes.IntNode) or
                node.type.is_float and isinstance(arg, ExprNodes.FloatNode) or
                node.type.is_int and isinstance(arg, ExprNodes.BoolNode)):
            return arg.coerce_to(node.type, self.current_env())
    elif isinstance(arg, ExprNodes.CoerceToPyTypeNode):
        if arg.type is PyrexTypes.py_object_type:
            if node.type.assignable_from(arg.arg.type):
                # completely redundant C->Py->C coercion
                return arg.arg.coerce_to(node.type, self.current_env())
        elif arg.type is Builtin.unicode_type:
            # unicode character -> unicode object -> unicode character
            if arg.arg.type.is_unicode_char and node.type.is_unicode_char:
                return arg.arg.coerce_to(node.type, self.current_env())
    elif isinstance(arg, ExprNodes.SimpleCallNode):
        # possibly int(...)/float(...) with a C numeric target type
        if node.type.is_int or node.type.is_float:
            return self._optimise_numeric_cast_call(node, arg)
    elif arg.is_subscript:
        # obj[index] with a C integer index (possibly wrapped in a coercion)
        index_node = arg.index
        if isinstance(index_node, ExprNodes.CoerceToPyTypeNode):
            index_node = index_node.arg
        if index_node.type.is_int:
            return self._optimise_int_indexing(node, arg, index_node)
    return node
# C function signature of the "__Pyx_PyBytes_GetItemInt" helper used by
# _optimise_int_indexing(): (bytes object, Py_ssize_t index, int check_bounds)
# -> char.  Declared with exception value "((char)-1)" and exception_check
# enabled.
PyBytes_GetItemInt_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_char_type, [
        PyrexTypes.CFuncTypeArg("bytes", Builtin.bytes_type, None),
        PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_py_ssize_t_type, None),
        PyrexTypes.CFuncTypeArg("check_bounds", PyrexTypes.c_int_type, None),
        ],
    exception_value = "((char)-1)",
    exception_check = True)
  1888. def _optimise_int_indexing(self, coerce_node, arg, index_node):
  1889. env = self.current_env()
  1890. bound_check_bool = env.directives['boundscheck'] and 1 or 0
  1891. if arg.base.type is Builtin.bytes_type:
  1892. if coerce_node.type in (PyrexTypes.c_char_type, PyrexTypes.c_uchar_type):
  1893. # bytes[index] -> char
  1894. bound_check_node = ExprNodes.IntNode(
  1895. coerce_node.pos, value=str(bound_check_bool),
  1896. constant_result=bound_check_bool)
  1897. node = ExprNodes.PythonCapiCallNode(
  1898. coerce_node.pos, "__Pyx_PyBytes_GetItemInt",
  1899. self.PyBytes_GetItemInt_func_type,
  1900. args=[
  1901. arg.base.as_none_safe_node("'NoneType' object is not subscriptable"),
  1902. index_node.coerce_to(PyrexTypes.c_py_ssize_t_type, env),
  1903. bound_check_node,
  1904. ],
  1905. is_temp=True,
  1906. utility_code=UtilityCode.load_cached(
  1907. 'bytes_index', 'StringTools.c'))
  1908. if coerce_node.type is not PyrexTypes.c_char_type:
  1909. node = node.coerce_to(coerce_node.type, env)
  1910. return node
  1911. return coerce_node
# Maps each C floating point type (float, double, long double) to the
# signature of a C function that takes and returns a value of that type.
# _optimise_numeric_cast_call() looks up the signature of the trunc*() call
# here.
float_float_func_types = dict(
    (float_type, PyrexTypes.CFuncType(
        float_type, [
            PyrexTypes.CFuncTypeArg("arg", float_type, None)
        ]))
    for float_type in (PyrexTypes.c_float_type, PyrexTypes.c_double_type, PyrexTypes.c_longdouble_type))
def _optimise_numeric_cast_call(self, node, arg):
    """Try to replace a single-argument int(x) / float(x) call whose result
    is coerced to a C type by a C-level cast or a trunc*() call.

    ``node`` is the coercion node (its type is the C target type) and
    ``arg`` the call node that it wraps.  Returns ``node`` unchanged when
    none of the cases below applies.
    """
    function = arg.function
    args = None
    # find the positional arguments of the call, depending on its node type
    if isinstance(arg, ExprNodes.PythonCapiCallNode):
        args = arg.args
    elif isinstance(function, ExprNodes.NameNode):
        if function.type.is_builtin_type and isinstance(arg.arg_tuple, ExprNodes.TupleNode):
            args = arg.arg_tuple.args

    if args is None or len(args) != 1:
        return node
    func_arg = args[0]
    if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
        # use the value from before its conversion to a Python object
        func_arg = func_arg.arg
    elif func_arg.type.is_pyobject:
        # play it safe: Python conversion might work on all sorts of things
        return node

    if function.name == 'int':
        if func_arg.type.is_int or node.type.is_int:
            if func_arg.type == node.type:
                # same type on both sides => no conversion needed
                return func_arg
            elif node.type.assignable_from(func_arg.type) or func_arg.type.is_float:
                return ExprNodes.TypecastNode(node.pos, operand=func_arg, type=node.type)
        elif func_arg.type.is_float and node.type.is_numeric:
            # floating point argument, non-integer numeric target:
            # call the trunc*() variant for the argument type
            if func_arg.type.math_h_modifier == 'l':
                # Work around missing Cygwin definition.
                truncl = '__Pyx_truncl'
            else:
                truncl = 'trunc' + func_arg.type.math_h_modifier
            return ExprNodes.PythonCapiCallNode(
                node.pos, truncl,
                func_type=self.float_float_func_types[func_arg.type],
                args=[func_arg],
                py_name='int',
                is_temp=node.is_temp,
                result_is_used=node.result_is_used,
            ).coerce_to(node.type, self.current_env())
    elif function.name == 'float':
        if func_arg.type.is_float or node.type.is_float:
            if func_arg.type == node.type:
                # same type on both sides => no conversion needed
                return func_arg
            elif node.type.assignable_from(func_arg.type) or func_arg.type.is_float:
                return ExprNodes.TypecastNode(
                    node.pos, operand=func_arg, type=node.type)
    return node
  1962. def _error_wrong_arg_count(self, function_name, node, args, expected=None):
  1963. if not expected: # None or 0
  1964. arg_str = ''
  1965. elif isinstance(expected, basestring) or expected > 1:
  1966. arg_str = '...'
  1967. elif expected == 1:
  1968. arg_str = 'x'
  1969. else:
  1970. arg_str = ''
  1971. if expected is not None:
  1972. expected_str = 'expected %s, ' % expected
  1973. else:
  1974. expected_str = ''
  1975. error(node.pos, "%s(%s) called with wrong number of args, %sfound %d" % (
  1976. function_name, arg_str, expected_str, len(args)))
  1977. ### generic fallbacks
def _handle_function(self, node, function_name, function, arg_list, kwargs):
    """Generic fallback for function calls: return the call node unchanged."""
    return node
  1980. def _handle_method(self, node, type_name, attr_name, function,
  1981. arg_list, is_unbound_method, kwargs):
  1982. """
  1983. Try to inject C-API calls for unbound method calls to builtin types.
  1984. While the method declarations in Builtin.py already handle this, we
  1985. can additionally resolve bound and unbound methods here that were
  1986. assigned to variables ahead of time.
  1987. """
  1988. if kwargs:
  1989. return node
  1990. if not function or not function.is_attribute or not function.obj.is_name:
  1991. # cannot track unbound method calls over more than one indirection as
  1992. # the names might have been reassigned in the meantime
  1993. return node
  1994. type_entry = self.current_env().lookup(type_name)
  1995. if not type_entry:
  1996. return node
  1997. method = ExprNodes.AttributeNode(
  1998. node.function.pos,
  1999. obj=ExprNodes.NameNode(
  2000. function.pos,
  2001. name=type_name,
  2002. entry=type_entry,
  2003. type=type_entry.type),
  2004. attribute=attr_name,
  2005. is_called=True).analyse_as_type_attribute(self.current_env())
  2006. if method is None:
  2007. return self._optimise_generic_builtin_method_call(
  2008. node, attr_name, function, arg_list, is_unbound_method)
  2009. args = node.args
  2010. if args is None and node.arg_tuple:
  2011. args = node.arg_tuple.args
  2012. call_node = ExprNodes.SimpleCallNode(
  2013. node.pos,
  2014. function=method,
  2015. args=args)
  2016. if not is_unbound_method:
  2017. call_node.self = function.obj
  2018. call_node.analyse_c_function_call(self.current_env())
  2019. call_node.analysed = True
  2020. return call_node.coerce_to(node.type, self.current_env())
  2021. ### builtin types
  2022. def _optimise_generic_builtin_method_call(self, node, attr_name, function, arg_list, is_unbound_method):
  2023. """
  2024. Try to inject an unbound method call for a call to a method of a known builtin type.
  2025. This enables caching the underlying C function of the method at runtime.
  2026. """
  2027. arg_count = len(arg_list)
  2028. if is_unbound_method or arg_count >= 3 or not (function.is_attribute and function.is_py_attr):
  2029. return node
  2030. if not function.obj.type.is_builtin_type:
  2031. return node
  2032. if function.obj.type.name in ('basestring', 'type'):
  2033. # these allow different actual types => unsafe
  2034. return node
  2035. return ExprNodes.CachedBuiltinMethodCallNode(
  2036. node, function.obj, attr_name, arg_list)
# C function signature shared by the "__Pyx_PyUnicode_Unicode" and
# "__Pyx_PyObject_Unicode" calls generated by
# _handle_simple_function_unicode(): Python object in, unicode object out.
PyObject_Unicode_func_type = PyrexTypes.CFuncType(
    Builtin.unicode_type, [
        PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None)
        ])
  2041. def _handle_simple_function_unicode(self, node, function, pos_args):
  2042. """Optimise single argument calls to unicode().
  2043. """
  2044. if len(pos_args) != 1:
  2045. if len(pos_args) == 0:
  2046. return ExprNodes.UnicodeNode(node.pos, value=EncodedString(), constant_result=u'')
  2047. return node
  2048. arg = pos_args[0]
  2049. if arg.type is Builtin.unicode_type:
  2050. if not arg.may_be_none():
  2051. return arg
  2052. cname = "__Pyx_PyUnicode_Unicode"
  2053. utility_code = UtilityCode.load_cached('PyUnicode_Unicode', 'StringTools.c')
  2054. else:
  2055. cname = "__Pyx_PyObject_Unicode"
  2056. utility_code = UtilityCode.load_cached('PyObject_Unicode', 'StringTools.c')
  2057. return ExprNodes.PythonCapiCallNode(
  2058. node.pos, cname, self.PyObject_Unicode_func_type,
  2059. args=pos_args,
  2060. is_temp=node.is_temp,
  2061. utility_code=utility_code,
  2062. py_name="unicode")
  2063. def visit_FormattedValueNode(self, node):
  2064. """Simplify or avoid plain string formatting of a unicode value.
  2065. This seems misplaced here, but plain unicode formatting is essentially
  2066. a call to the unicode() builtin, which is optimised right above.
  2067. """
  2068. self.visitchildren(node)
  2069. if node.value.type is Builtin.unicode_type and not node.c_format_spec and not node.format_spec:
  2070. if not node.conversion_char or node.conversion_char == 's':
  2071. # value is definitely a unicode string and we don't format it any special
  2072. return self._handle_simple_function_unicode(node, None, [node.value])
  2073. return node
# C function signature of PyDict_Copy() as used by
# _handle_simple_function_dict(): dict in, dict out.
PyDict_Copy_func_type = PyrexTypes.CFuncType(
    Builtin.dict_type, [
        PyrexTypes.CFuncTypeArg("dict", Builtin.dict_type, None)
        ])
  2078. def _handle_simple_function_dict(self, node, function, pos_args):
  2079. """Replace dict(some_dict) by PyDict_Copy(some_dict).
  2080. """
  2081. if len(pos_args) != 1:
  2082. return node
  2083. arg = pos_args[0]
  2084. if arg.type is Builtin.dict_type:
  2085. arg = arg.as_none_safe_node("'NoneType' is not iterable")
  2086. return ExprNodes.PythonCapiCallNode(
  2087. node.pos, "PyDict_Copy", self.PyDict_Copy_func_type,
  2088. args = [arg],
  2089. is_temp = node.is_temp
  2090. )
  2091. return node
# C function signature of PySequence_List() as used by
# _handle_simple_function_list(): Python object in, list out.
PySequence_List_func_type = PyrexTypes.CFuncType(
    Builtin.list_type,
    [PyrexTypes.CFuncTypeArg("it", PyrexTypes.py_object_type, None)])
  2095. def _handle_simple_function_list(self, node, function, pos_args):
  2096. """Turn list(ob) into PySequence_List(ob).
  2097. """
  2098. if len(pos_args) != 1:
  2099. return node
  2100. arg = pos_args[0]
  2101. return ExprNodes.PythonCapiCallNode(
  2102. node.pos, "PySequence_List", self.PySequence_List_func_type,
  2103. args=pos_args, is_temp=node.is_temp)
# C function signature of PyList_AsTuple() as used by
# _handle_simple_function_tuple(): list in, tuple out.
PyList_AsTuple_func_type = PyrexTypes.CFuncType(
    Builtin.tuple_type, [
        PyrexTypes.CFuncTypeArg("list", Builtin.list_type, None)
        ])
  2108. def _handle_simple_function_tuple(self, node, function, pos_args):
  2109. """Replace tuple([...]) by PyList_AsTuple or PySequence_Tuple.
  2110. """
  2111. if len(pos_args) != 1 or not node.is_temp:
  2112. return node
  2113. arg = pos_args[0]
  2114. if arg.type is Builtin.tuple_type and not arg.may_be_none():
  2115. return arg
  2116. if arg.type is Builtin.list_type:
  2117. pos_args[0] = arg.as_none_safe_node(
  2118. "'NoneType' object is not iterable")
  2119. return ExprNodes.PythonCapiCallNode(
  2120. node.pos, "PyList_AsTuple", self.PyList_AsTuple_func_type,
  2121. args=pos_args, is_temp=node.is_temp)
  2122. else:
  2123. return ExprNodes.AsTupleNode(node.pos, arg=arg, type=Builtin.tuple_type)
# C function signature of PySet_New() as used by
# _handle_simple_function_set(): Python object in, set out.
PySet_New_func_type = PyrexTypes.CFuncType(
    Builtin.set_type, [
        PyrexTypes.CFuncTypeArg("it", PyrexTypes.py_object_type, None)
    ])
  2128. def _handle_simple_function_set(self, node, function, pos_args):
  2129. if len(pos_args) != 1:
  2130. return node
  2131. if pos_args[0].is_sequence_constructor:
  2132. # We can optimise set([x,y,z]) safely into a set literal,
  2133. # but only if we create all items before adding them -
  2134. # adding an item may raise an exception if it is not
  2135. # hashable, but creating the later items may have
  2136. # side-effects.
  2137. args = []
  2138. temps = []
  2139. for arg in pos_args[0].args:
  2140. if not arg.is_simple():
  2141. arg = UtilNodes.LetRefNode(arg)
  2142. temps.append(arg)
  2143. args.append(arg)
  2144. result = ExprNodes.SetNode(node.pos, is_temp=1, args=args)
  2145. self.replace(node, result)
  2146. for temp in temps[::-1]:
  2147. result = UtilNodes.EvalWithTempExprNode(temp, result)
  2148. return result
  2149. else:
  2150. # PySet_New(it) is better than a generic Python call to set(it)
  2151. return self.replace(node, ExprNodes.PythonCapiCallNode(
  2152. node.pos, "PySet_New",
  2153. self.PySet_New_func_type,
  2154. args=pos_args,
  2155. is_temp=node.is_temp,
  2156. py_name="set"))
# C function signature of the "__Pyx_PyFrozenSet_New" helper used by
# _handle_simple_function_frozenset(): Python object in, frozenset out.
PyFrozenSet_New_func_type = PyrexTypes.CFuncType(
    Builtin.frozenset_type, [
        PyrexTypes.CFuncTypeArg("it", PyrexTypes.py_object_type, None)
    ])
  2161. def _handle_simple_function_frozenset(self, node, function, pos_args):
  2162. if not pos_args:
  2163. pos_args = [ExprNodes.NullNode(node.pos)]
  2164. elif len(pos_args) > 1:
  2165. return node
  2166. elif pos_args[0].type is Builtin.frozenset_type and not pos_args[0].may_be_none():
  2167. return pos_args[0]
  2168. # PyFrozenSet_New(it) is better than a generic Python call to frozenset(it)
  2169. return ExprNodes.PythonCapiCallNode(
  2170. node.pos, "__Pyx_PyFrozenSet_New",
  2171. self.PyFrozenSet_New_func_type,
  2172. args=pos_args,
  2173. is_temp=node.is_temp,
  2174. utility_code=UtilityCode.load_cached('pyfrozenset_new', 'Builtins.c'),
  2175. py_name="frozenset")
# C function signature of the "__Pyx_PyObject_AsDouble" helper used by
# _handle_simple_function_float(): Python object in, C double out.  Declared
# with exception value "((double)-1)" and exception_check enabled.
PyObject_AsDouble_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_double_type, [
        PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
        ],
    exception_value = "((double)-1)",
    exception_check = True)
  2182. def _handle_simple_function_float(self, node, function, pos_args):
  2183. """Transform float() into either a C type cast or a faster C
  2184. function call.
  2185. """
  2186. # Note: this requires the float() function to be typed as
  2187. # returning a C 'double'
  2188. if len(pos_args) == 0:
  2189. return ExprNodes.FloatNode(
  2190. node, value="0.0", constant_result=0.0
  2191. ).coerce_to(Builtin.float_type, self.current_env())
  2192. elif len(pos_args) != 1:
  2193. self._error_wrong_arg_count('float', node, pos_args, '0 or 1')
  2194. return node
  2195. func_arg = pos_args[0]
  2196. if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
  2197. func_arg = func_arg.arg
  2198. if func_arg.type is PyrexTypes.c_double_type:
  2199. return func_arg
  2200. elif node.type.assignable_from(func_arg.type) or func_arg.type.is_numeric:
  2201. return ExprNodes.TypecastNode(
  2202. node.pos, operand=func_arg, type=node.type)
  2203. return ExprNodes.PythonCapiCallNode(
  2204. node.pos, "__Pyx_PyObject_AsDouble",
  2205. self.PyObject_AsDouble_func_type,
  2206. args = pos_args,
  2207. is_temp = node.is_temp,
  2208. utility_code = load_c_utility('pyobject_as_double'),
  2209. py_name = "float")
# C function signature of the "__Pyx_PyNumber_Int" call generated by
# _handle_simple_function_int(): Python object in, Python object out.
PyNumber_Int_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type, [
        PyrexTypes.CFuncTypeArg("o", PyrexTypes.py_object_type, None)
        ])

# C function signature of the "__Pyx_PyInt_FromDouble" helper used by
# _handle_simple_function_int(): C double in, Python object out.
PyInt_FromDouble_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type, [
        PyrexTypes.CFuncTypeArg("value", PyrexTypes.c_double_type, None)
        ])
  2218. def _handle_simple_function_int(self, node, function, pos_args):
  2219. """Transform int() into a faster C function call.
  2220. """
  2221. if len(pos_args) == 0:
  2222. return ExprNodes.IntNode(node.pos, value="0", constant_result=0,
  2223. type=PyrexTypes.py_object_type)
  2224. elif len(pos_args) != 1:
  2225. return node # int(x, base)
  2226. func_arg = pos_args[0]
  2227. if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
  2228. if func_arg.arg.type.is_float:
  2229. return ExprNodes.PythonCapiCallNode(
  2230. node.pos, "__Pyx_PyInt_FromDouble", self.PyInt_FromDouble_func_type,
  2231. args=[func_arg.arg], is_temp=True, py_name='int',
  2232. utility_code=UtilityCode.load_cached("PyIntFromDouble", "TypeConversion.c"))
  2233. else:
  2234. return node # handled in visit_CoerceFromPyTypeNode()
  2235. if func_arg.type.is_pyobject and node.type.is_pyobject:
  2236. return ExprNodes.PythonCapiCallNode(
  2237. node.pos, "__Pyx_PyNumber_Int", self.PyNumber_Int_func_type,
  2238. args=pos_args, is_temp=True, py_name='int')
  2239. return node
  2240. def _handle_simple_function_bool(self, node, function, pos_args):
  2241. """Transform bool(x) into a type coercion to a boolean.
  2242. """
  2243. if len(pos_args) == 0:
  2244. return ExprNodes.BoolNode(
  2245. node.pos, value=False, constant_result=False
  2246. ).coerce_to(Builtin.bool_type, self.current_env())
  2247. elif len(pos_args) != 1:
  2248. self._error_wrong_arg_count('bool', node, pos_args, '0 or 1')
  2249. return node
  2250. else:
  2251. # => !!<bint>(x) to make sure it's exactly 0 or 1
  2252. operand = pos_args[0].coerce_to_boolean(self.current_env())
  2253. operand = ExprNodes.NotNode(node.pos, operand = operand)
  2254. operand = ExprNodes.NotNode(node.pos, operand = operand)
  2255. # coerce back to Python object as that's the result we are expecting
  2256. return operand.coerce_to_pyobject(self.current_env())
  2257. ### builtin functions
# C function signatures used by _handle_simple_function_len().

# strlen(): const char* in, size_t out.
Pyx_strlen_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_size_t_type, [
        PyrexTypes.CFuncTypeArg("bytes", PyrexTypes.c_const_char_ptr_type, None)
    ])

# __Pyx_Py_UNICODE_strlen(): const Py_UNICODE* in, size_t out.
Pyx_Py_UNICODE_strlen_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_size_t_type, [
        PyrexTypes.CFuncTypeArg("unicode", PyrexTypes.c_const_py_unicode_ptr_type, None)
    ])

# Shared by all object size functions: Python object in, Py_ssize_t out,
# declared with exception value "-1".
PyObject_Size_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_py_ssize_t_type, [
        PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None)
    ],
    exception_value="-1")
# Lookup function (the bound .get of the dict) that maps a builtin type to
# the name of the C function/macro that returns its length; returns None
# for types that are not listed.
_map_to_capi_len_function = {
    Builtin.unicode_type:    "__Pyx_PyUnicode_GET_LENGTH",
    Builtin.bytes_type:      "PyBytes_GET_SIZE",
    Builtin.bytearray_type:  'PyByteArray_GET_SIZE',
    Builtin.list_type:       "PyList_GET_SIZE",
    Builtin.tuple_type:      "PyTuple_GET_SIZE",
    Builtin.set_type:        "PySet_GET_SIZE",
    Builtin.frozenset_type:  "PySet_GET_SIZE",
    Builtin.dict_type:       "PyDict_Size",
}.get

# Qualified names of types for which _handle_simple_function_len() uses
# Py_SIZE() to get the length.
_ext_types_with_pysize = set(["cpython.array.array"])
def _handle_simple_function_len(self, node, function, pos_args):
    """Replace len(char*) by the equivalent call to strlen(),
    len(Py_UNICODE*) by the equivalent Py_UNICODE_strlen() and
    len(known_builtin_type) by an equivalent C-API call.

    Also handles memoryview slices (__Pyx_MemoryView_Len) and single
    unicode characters (constant 1).  A wrong number of arguments is
    reported as an error; unsupported argument types leave the call
    unchanged.
    """
    if len(pos_args) != 1:
        self._error_wrong_arg_count('len', node, pos_args, 1)
        return node
    arg = pos_args[0]
    if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
        # look at the value from before its conversion to a Python object
        arg = arg.arg
    if arg.type.is_string:
        new_node = ExprNodes.PythonCapiCallNode(
            node.pos, "strlen", self.Pyx_strlen_func_type,
            args = [arg],
            is_temp = node.is_temp,
            utility_code = UtilityCode.load_cached("IncludeStringH", "StringTools.c"))
    elif arg.type.is_pyunicode_ptr:
        new_node = ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_Py_UNICODE_strlen", self.Pyx_Py_UNICODE_strlen_func_type,
            args = [arg],
            is_temp = node.is_temp)
    elif arg.type.is_memoryviewslice:
        # the signature depends on the concrete slice type, so build it here
        func_type = PyrexTypes.CFuncType(
            PyrexTypes.c_size_t_type, [
                PyrexTypes.CFuncTypeArg("memoryviewslice", arg.type, None)
            ], nogil=True)
        new_node = ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_MemoryView_Len", func_type,
            args=[arg], is_temp=node.is_temp)
    elif arg.type.is_pyobject:
        cfunc_name = self._map_to_capi_len_function(arg.type)
        if cfunc_name is None:
            # not one of the mapped builtin types: fall back to Py_SIZE()
            # for the types listed in _ext_types_with_pysize
            arg_type = arg.type
            if ((arg_type.is_extension_type or arg_type.is_builtin_type)
                    and arg_type.entry.qualified_name in self._ext_types_with_pysize):
                cfunc_name = 'Py_SIZE'
            else:
                return node
        arg = arg.as_none_safe_node(
            "object of type 'NoneType' has no len()")
        new_node = ExprNodes.PythonCapiCallNode(
            node.pos, cfunc_name, self.PyObject_Size_func_type,
            args=[arg], is_temp=node.is_temp)
    elif arg.type.is_unicode_char:
        # a single character always has length 1
        return ExprNodes.IntNode(node.pos, value='1', constant_result=1,
                                 type=node.type)
    else:
        return node
    # convert the C size result unless the call is already typed as a C size
    if node.type not in (PyrexTypes.c_size_t_type, PyrexTypes.c_py_ssize_t_type):
        new_node = new_node.coerce_to(node.type, self.current_env())
    return new_node
# Signature used for the Py_TYPE() call: takes a Python object and is
# declared as returning the builtin 'type' type.
Pyx_Type_func_type = PyrexTypes.CFuncType(
    Builtin.type_type, [
        PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None)
    ])
  2338. def _handle_simple_function_type(self, node, function, pos_args):
  2339. """Replace type(o) by a macro call to Py_TYPE(o).
  2340. """
  2341. if len(pos_args) != 1:
  2342. return node
  2343. node = ExprNodes.PythonCapiCallNode(
  2344. node.pos, "Py_TYPE", self.Pyx_Type_func_type,
  2345. args = pos_args,
  2346. is_temp = False)
  2347. return ExprNodes.CastNode(node, PyrexTypes.py_object_type)
# Signature used for the C-level type check calls generated for
# isinstance(): declares one Python object argument and a C 'bint' result.
Py_type_check_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_bint_type, [
        PyrexTypes.CFuncTypeArg("arg", PyrexTypes.py_object_type, None)
    ])
  2352. def _handle_simple_function_isinstance(self, node, function, pos_args):
  2353. """Replace isinstance() checks against builtin types by the
  2354. corresponding C-API call.
  2355. """
  2356. if len(pos_args) != 2:
  2357. return node
  2358. arg, types = pos_args
  2359. temps = []
  2360. if isinstance(types, ExprNodes.TupleNode):
  2361. types = types.args
  2362. if len(types) == 1 and not types[0].type is Builtin.type_type:
  2363. return node # nothing to improve here
  2364. if arg.is_attribute or not arg.is_simple():
  2365. arg = UtilNodes.ResultRefNode(arg)
  2366. temps.append(arg)
  2367. elif types.type is Builtin.type_type:
  2368. types = [types]
  2369. else:
  2370. return node
  2371. tests = []
  2372. test_nodes = []
  2373. env = self.current_env()
  2374. for test_type_node in types:
  2375. builtin_type = None
  2376. if test_type_node.is_name:
  2377. if test_type_node.entry:
  2378. entry = env.lookup(test_type_node.entry.name)
  2379. if entry and entry.type and entry.type.is_builtin_type:
  2380. builtin_type = entry.type
  2381. if builtin_type is Builtin.type_type:
  2382. # all types have type "type", but there's only one 'type'
  2383. if entry.name != 'type' or not (
  2384. entry.scope and entry.scope.is_builtin_scope):
  2385. builtin_type = None
  2386. if builtin_type is not None:
  2387. type_check_function = entry.type.type_check_function(exact=False)
  2388. if type_check_function in tests:
  2389. continue
  2390. tests.append(type_check_function)
  2391. type_check_args = [arg]
  2392. elif test_type_node.type is Builtin.type_type:
  2393. type_check_function = '__Pyx_TypeCheck'
  2394. type_check_args = [arg, test_type_node]
  2395. else:
  2396. if not test_type_node.is_literal:
  2397. test_type_node = UtilNodes.ResultRefNode(test_type_node)
  2398. temps.append(test_type_node)
  2399. type_check_function = 'PyObject_IsInstance'
  2400. type_check_args = [arg, test_type_node]
  2401. test_nodes.append(
  2402. ExprNodes.PythonCapiCallNode(
  2403. test_type_node.pos, type_check_function, self.Py_type_check_func_type,
  2404. args=type_check_args,
  2405. is_temp=True,
  2406. ))
  2407. def join_with_or(a, b, make_binop_node=ExprNodes.binop_node):
  2408. or_node = make_binop_node(node.pos, 'or', a, b)
  2409. or_node.type = PyrexTypes.c_bint_type
  2410. or_node.wrap_operands(env)
  2411. return or_node
  2412. test_node = reduce(join_with_or, test_nodes).coerce_to(node.type, env)
  2413. for temp in temps[::-1]:
  2414. test_node = UtilNodes.EvalWithTempExprNode(temp, test_node)
  2415. return test_node
  2416. def _handle_simple_function_ord(self, node, function, pos_args):
  2417. """Unpack ord(Py_UNICODE) and ord('X').
  2418. """
  2419. if len(pos_args) != 1:
  2420. return node
  2421. arg = pos_args[0]
  2422. if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
  2423. if arg.arg.type.is_unicode_char:
  2424. return ExprNodes.TypecastNode(
  2425. arg.pos, operand=arg.arg, type=PyrexTypes.c_long_type
  2426. ).coerce_to(node.type, self.current_env())
  2427. elif isinstance(arg, ExprNodes.UnicodeNode):
  2428. if len(arg.value) == 1:
  2429. return ExprNodes.IntNode(
  2430. arg.pos, type=PyrexTypes.c_int_type,
  2431. value=str(ord(arg.value)),
  2432. constant_result=ord(arg.value)
  2433. ).coerce_to(node.type, self.current_env())
  2434. elif isinstance(arg, ExprNodes.StringNode):
  2435. if arg.unicode_value and len(arg.unicode_value) == 1 \
  2436. and ord(arg.unicode_value) <= 255: # Py2/3 portability
  2437. return ExprNodes.IntNode(
  2438. arg.pos, type=PyrexTypes.c_int_type,
  2439. value=str(ord(arg.unicode_value)),
  2440. constant_result=ord(arg.unicode_value)
  2441. ).coerce_to(node.type, self.current_env())
  2442. return node
  2443. ### special methods
# Signature used for __Pyx_tp_new(type, args): 'args' is declared as a
# tuple and the result is a Python object.
Pyx_tp_new_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type, [
        PyrexTypes.CFuncTypeArg("type", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("args", Builtin.tuple_type, None),
    ])

# Signature used for __Pyx_tp_new_kwargs(type, args, kwargs): as above with
# an additional 'kwargs' argument declared as a dict.
Pyx_tp_new_kwargs_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type, [
        PyrexTypes.CFuncTypeArg("type", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("args", Builtin.tuple_type, None),
        PyrexTypes.CFuncTypeArg("kwargs", Builtin.dict_type, None),
    ])
def _handle_any_slot__new__(self, node, function, args,
                            is_unbound_method, kwargs=None):
    """Replace 'exttype.__new__(exttype, ...)' by a call to exttype->tp_new()

    The call is only rewritten when it is an unbound method call whose
    receiver and first argument are both plain names of type 'type' that
    refer to the same type.  In every other case 'node' is returned
    unchanged.
    """
    obj = function.obj
    if not is_unbound_method or len(args) < 1:
        return node
    type_arg = args[0]
    if not obj.is_name or not type_arg.is_name:
        # play safe
        return node
    if obj.type != Builtin.type_type or type_arg.type != Builtin.type_type:
        # not a known type, play safe
        return node
    if not type_arg.type_entry or not obj.type_entry:
        if obj.name != type_arg.name:
            return node
        # otherwise, we know it's a type and we know it's the same
        # type for both - that should do
    elif type_arg.type_entry != obj.type_entry:
        # different types - may or may not lead to an error at runtime
        return node

    # The remaining arguments (after the type) are packed into a tuple.
    args_tuple = ExprNodes.TupleNode(node.pos, args=args[1:])
    args_tuple = args_tuple.analyse_types(
        self.current_env(), skip_children=True)

    if type_arg.type_entry:
        ext_type = type_arg.type_entry.type
        if (ext_type.is_extension_type and ext_type.typeobj_cname and
                ext_type.scope.global_scope() == self.current_env().global_scope()):
            # known type in current module
            tp_slot = TypeSlots.ConstructorSlot("tp_new", '__new__')
            slot_func_cname = TypeSlots.get_slot_function(ext_type.scope, tp_slot)
            if slot_func_cname:
                # Call the slot function directly by its C name, using a
                # signature that returns the extension type itself.
                cython_scope = self.context.cython_scope
                PyTypeObjectPtr = PyrexTypes.CPtrType(
                    cython_scope.lookup('PyTypeObject').type)
                pyx_tp_new_kwargs_func_type = PyrexTypes.CFuncType(
                    ext_type, [
                        PyrexTypes.CFuncTypeArg("type", PyTypeObjectPtr, None),
                        PyrexTypes.CFuncTypeArg("args", PyrexTypes.py_object_type, None),
                        PyrexTypes.CFuncTypeArg("kwargs", PyrexTypes.py_object_type, None),
                    ])

                type_arg = ExprNodes.CastNode(type_arg, PyTypeObjectPtr)
                if not kwargs:
                    # No keyword arguments: pass NULL for 'kwargs'.
                    kwargs = ExprNodes.NullNode(node.pos, type=PyrexTypes.py_object_type)  # hack?
                return ExprNodes.PythonCapiCallNode(
                    node.pos, slot_func_cname,
                    pyx_tp_new_kwargs_func_type,
                    args=[type_arg, args_tuple, kwargs],
                    may_return_none=False,
                    is_temp=True)
    else:
        # arbitrary variable, needs a None check for safety
        type_arg = type_arg.as_none_safe_node(
            "object.__new__(X): X is not a type object (NoneType)")

    # Generic path: call the helpers provided by the 'tp_new' utility code.
    utility_code = UtilityCode.load_cached('tp_new', 'ObjectHandling.c')
    if kwargs:
        return ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_tp_new_kwargs", self.Pyx_tp_new_kwargs_func_type,
            args=[type_arg, args_tuple, kwargs],
            utility_code=utility_code,
            is_temp=node.is_temp
            )
    else:
        return ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_tp_new", self.Pyx_tp_new_func_type,
            args=[type_arg, args_tuple],
            utility_code=utility_code,
            is_temp=node.is_temp
            )
  2525. ### methods of builtin types
# Signature shared by the append helper calls: (list, item) with a C return
# code as result; -1 is declared as the exception value.
PyObject_Append_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_returncode_type, [
        PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None),
    ],
    exception_value="-1")
  2532. def _handle_simple_method_object_append(self, node, function, args, is_unbound_method):
  2533. """Optimistic optimisation as X.append() is almost always
  2534. referring to a list.
  2535. """
  2536. if len(args) != 2 or node.result_is_used:
  2537. return node
  2538. return ExprNodes.PythonCapiCallNode(
  2539. node.pos, "__Pyx_PyObject_Append", self.PyObject_Append_func_type,
  2540. args=args,
  2541. may_return_none=False,
  2542. is_temp=node.is_temp,
  2543. result_is_used=False,
  2544. utility_code=load_c_utility('append')
  2545. )
def _handle_simple_method_list_extend(self, node, function, args, is_unbound_method):
    """Replace list.extend([...]) for short sequence literals values by sequential appends
    to avoid creating an intermediate sequence argument.

    Only sequence constructors without a multiplication factor and with at
    most 8 items are rewritten; other calls are returned unchanged.
    """
    if len(args) != 2:
        return node
    obj, value = args
    if not value.is_sequence_constructor:
        return node
    items = list(value.args)
    if value.mult_factor is not None or len(items) > 8:
        # Appending wins for short sequences but slows down when multiple resize operations are needed.
        # This seems to be a good enough limit that avoids repeated resizing.
        if False and isinstance(value, ExprNodes.ListNode):
            # One would expect that tuples are more efficient here, but benchmarking with
            # Py3.5 and Py3.7 suggests that they are not.  Probably worth revisiting at some point.
            # Might be related to the usage of PySequence_FAST() in CPython's list.extend(),
            # which is probably tuned more towards lists than tuples (and rightly so).
            tuple_node = args[1].as_tuple().analyse_types(self.current_env(), skip_children=True)
            Visitor.recursively_replace_node(node, args[1], tuple_node)
        return node
    wrapped_obj = self._wrap_self_arg(obj, function, is_unbound_method, 'extend')
    if not items:
        # Empty sequences are not likely to occur, but why waste a call to list.extend() for them?
        wrapped_obj.result_is_used = node.result_is_used
        return wrapped_obj
    cloned_obj = obj = wrapped_obj
    if len(items) > 1 and not obj.is_simple():
        # The list is referenced once per item: keep it in a temp.
        cloned_obj = UtilNodes.LetRefNode(obj)
    # Use ListComp_Append() for all but the last item and finish with PyList_Append()
    # to shrink the list storage size at the very end if necessary.
    temps = []
    arg = items[-1]
    if not arg.is_simple():
        arg = UtilNodes.LetRefNode(arg)
        temps.append(arg)
    new_node = ExprNodes.PythonCapiCallNode(
        node.pos, "__Pyx_PyList_Append", self.PyObject_Append_func_type,
        args=[cloned_obj, arg],
        is_temp=True,
        utility_code=load_c_utility("ListAppend"))
    # Walk the remaining items from last to first; each new append call
    # becomes the left operand of a '|' whose right operand is the chain
    # built so far.
    for arg in items[-2::-1]:
        if not arg.is_simple():
            arg = UtilNodes.LetRefNode(arg)
            temps.append(arg)
        new_node = ExprNodes.binop_node(
            node.pos, '|',
            ExprNodes.PythonCapiCallNode(
                node.pos, "__Pyx_ListComp_Append", self.PyObject_Append_func_type,
                args=[cloned_obj, arg], py_name="extend",
                is_temp=True,
                utility_code=load_c_utility("ListCompAppend")),
            new_node,
            type=PyrexTypes.c_returncode_type,
        )
    new_node.result_is_used = node.result_is_used
    if cloned_obj is not obj:
        # The list temp is appended last, so it ends up as the outermost wrapper.
        temps.append(cloned_obj)
    for temp in temps:
        new_node = UtilNodes.EvalWithTempExprNode(temp, new_node)
        new_node.result_is_used = node.result_is_used
    return new_node
# Signature for appending a C int value to a bytearray; returns a C return
# code with -1 declared as the exception value.
PyByteArray_Append_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_returncode_type, [
        PyrexTypes.CFuncTypeArg("bytearray", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("value", PyrexTypes.c_int_type, None),
    ],
    exception_value="-1")

# Same, but the appended value is passed as a Python object.
PyByteArray_AppendObject_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_returncode_type, [
        PyrexTypes.CFuncTypeArg("bytearray", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("value", PyrexTypes.py_object_type, None),
    ],
    exception_value="-1")
  2620. def _handle_simple_method_bytearray_append(self, node, function, args, is_unbound_method):
  2621. if len(args) != 2:
  2622. return node
  2623. func_name = "__Pyx_PyByteArray_Append"
  2624. func_type = self.PyByteArray_Append_func_type
  2625. value = unwrap_coerced_node(args[1])
  2626. if value.type.is_int or isinstance(value, ExprNodes.IntNode):
  2627. value = value.coerce_to(PyrexTypes.c_int_type, self.current_env())
  2628. utility_code = UtilityCode.load_cached("ByteArrayAppend", "StringTools.c")
  2629. elif value.is_string_literal:
  2630. if not value.can_coerce_to_char_literal():
  2631. return node
  2632. value = value.coerce_to(PyrexTypes.c_char_type, self.current_env())
  2633. utility_code = UtilityCode.load_cached("ByteArrayAppend", "StringTools.c")
  2634. elif value.type.is_pyobject:
  2635. func_name = "__Pyx_PyByteArray_AppendObject"
  2636. func_type = self.PyByteArray_AppendObject_func_type
  2637. utility_code = UtilityCode.load_cached("ByteArrayAppendObject", "StringTools.c")
  2638. else:
  2639. return node
  2640. new_node = ExprNodes.PythonCapiCallNode(
  2641. node.pos, func_name, func_type,
  2642. args=[args[0], value],
  2643. may_return_none=False,
  2644. is_temp=node.is_temp,
  2645. utility_code=utility_code,
  2646. )
  2647. if node.result_is_used:
  2648. new_node = new_node.coerce_to(node.type, self.current_env())
  2649. return new_node
# Signature for the no-index pop helpers: (list) -> Python object.
PyObject_Pop_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type, [
        PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
    ])

# Signature for the indexed pop helpers: the index is passed both as a
# Python object and as a Py_ssize_t, plus a flag for its signedness.
PyObject_PopIndex_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type, [
        PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("py_index", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("c_index", PyrexTypes.c_py_ssize_t_type, None),
        PyrexTypes.CFuncTypeArg("is_signed", PyrexTypes.c_int_type, None),
    ],
    has_varargs=True)  # to fake the additional macro args that lack a proper C type
  2662. def _handle_simple_method_list_pop(self, node, function, args, is_unbound_method):
  2663. return self._handle_simple_method_object_pop(
  2664. node, function, args, is_unbound_method, is_list=True)
def _handle_simple_method_object_pop(self, node, function, args, is_unbound_method, is_list=False):
    """Optimistic optimisation as X.pop([n]) is almost always
    referring to a list.

    With is_list=True the '__Pyx_PyList_*' helpers are called and the object
    gets a None check; otherwise the '__Pyx_PyObject_*' helpers are used.
    Calls that cannot be handled are returned unchanged.
    """
    if not args:
        return node
    obj = args[0]
    if is_list:
        type_name = 'List'
        obj = obj.as_none_safe_node(
            "'NoneType' object has no attribute '%.30s'",
            error="PyExc_AttributeError",
            format_args=['pop'])
    else:
        type_name = 'Object'
    if len(args) == 1:
        # pop() without an index
        return ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_Py%s_Pop" % type_name,
            self.PyObject_Pop_func_type,
            args=[obj],
            may_return_none=True,
            is_temp=node.is_temp,
            utility_code=load_c_utility('pop'),
        )
    elif len(args) == 2:
        # pop(index): the helper receives the index as a Python object
        # ('py_index', None unless set below) and as a C value ('index').
        index = unwrap_coerced_node(args[1])
        py_index = ExprNodes.NoneNode(index.pos)
        orig_index_type = index.type
        if not index.type.is_int:
            if isinstance(index, ExprNodes.IntNode):
                py_index = index.coerce_to_pyobject(self.current_env())
                index = index.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
            elif is_list:
                if index.type.is_pyobject:
                    py_index = index.coerce_to_simple(self.current_env())
                    index = ExprNodes.CloneNode(py_index)
                index = index.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
            else:
                return node
        elif not PyrexTypes.numeric_type_fits(index.type, PyrexTypes.c_py_ssize_t_type):
            return node
        elif isinstance(index, ExprNodes.IntNode):
            py_index = index.coerce_to_pyobject(self.current_env())
        # real type might still be larger at runtime
        if not orig_index_type.is_int:
            orig_index_type = index.type
        if not orig_index_type.create_to_py_utility_code(self.current_env()):
            return node
        convert_func = orig_index_type.to_py_function
        conversion_type = PyrexTypes.CFuncType(
            PyrexTypes.py_object_type, [PyrexTypes.CFuncTypeArg("intval", orig_index_type, None)])
        # Besides the declared arguments, the original index type's C
        # declaration and its to-Python conversion function are passed as
        # raw C names.
        return ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_Py%s_PopIndex" % type_name,
            self.PyObject_PopIndex_func_type,
            args=[obj, py_index, index,
                  ExprNodes.IntNode(index.pos, value=str(orig_index_type.signed and 1 or 0),
                                    constant_result=orig_index_type.signed and 1 or 0,
                                    type=PyrexTypes.c_int_type),
                  ExprNodes.RawCNameExprNode(index.pos, PyrexTypes.c_void_type,
                                             orig_index_type.empty_declaration_code()),
                  ExprNodes.RawCNameExprNode(index.pos, conversion_type, convert_func)],
            may_return_none=True,
            is_temp=node.is_temp,
            utility_code=load_c_utility("pop_index"),
        )

    return node
# Signature for C functions that take one Python object and return a C
# return code; -1 is declared as the exception value.
single_param_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_returncode_type, [
        PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
    ],
    exception_value = "-1")
  2736. def _handle_simple_method_list_sort(self, node, function, args, is_unbound_method):
  2737. """Call PyList_Sort() instead of the 0-argument l.sort().
  2738. """
  2739. if len(args) != 1:
  2740. return node
  2741. return self._substitute_method_call(
  2742. node, function, "PyList_Sort", self.single_param_func_type,
  2743. 'sort', is_unbound_method, args).coerce_to(node.type, self.current_env)
# Signature used for __Pyx_PyDict_GetItemDefault(dict, key, default):
# all arguments and the result are Python objects.
Pyx_PyDict_GetItem_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type, [
        PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
    ])
  2750. def _handle_simple_method_dict_get(self, node, function, args, is_unbound_method):
  2751. """Replace dict.get() by a call to PyDict_GetItem().
  2752. """
  2753. if len(args) == 2:
  2754. args.append(ExprNodes.NoneNode(node.pos))
  2755. elif len(args) != 3:
  2756. self._error_wrong_arg_count('dict.get', node, args, "2 or 3")
  2757. return node
  2758. return self._substitute_method_call(
  2759. node, function,
  2760. "__Pyx_PyDict_GetItemDefault", self.Pyx_PyDict_GetItem_func_type,
  2761. 'get', is_unbound_method, args,
  2762. may_return_none = True,
  2763. utility_code = load_c_utility("dict_getitem_default"))
# Signature used for __Pyx_PyDict_SetDefault(dict, key, default,
# is_safe_type): three Python objects plus a C int flag, returning a
# Python object.
Pyx_PyDict_SetDefault_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type, [
        PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("is_safe_type", PyrexTypes.c_int_type, None),
    ])
  2771. def _handle_simple_method_dict_setdefault(self, node, function, args, is_unbound_method):
  2772. """Replace dict.setdefault() by calls to PyDict_GetItem() and PyDict_SetItem().
  2773. """
  2774. if len(args) == 2:
  2775. args.append(ExprNodes.NoneNode(node.pos))
  2776. elif len(args) != 3:
  2777. self._error_wrong_arg_count('dict.setdefault', node, args, "2 or 3")
  2778. return node
  2779. key_type = args[1].type
  2780. if key_type.is_builtin_type:
  2781. is_safe_type = int(key_type.name in
  2782. 'str bytes unicode float int long bool')
  2783. elif key_type is PyrexTypes.py_object_type:
  2784. is_safe_type = -1 # don't know
  2785. else:
  2786. is_safe_type = 0 # definitely not
  2787. args.append(ExprNodes.IntNode(
  2788. node.pos, value=str(is_safe_type), constant_result=is_safe_type))
  2789. return self._substitute_method_call(
  2790. node, function,
  2791. "__Pyx_PyDict_SetDefault", self.Pyx_PyDict_SetDefault_func_type,
  2792. 'setdefault', is_unbound_method, args,
  2793. may_return_none=True,
  2794. utility_code=load_c_utility('dict_setdefault'))
# Signature used for __Pyx_PyDict_Pop(dict, key, default): all arguments
# and the result are Python objects.
PyDict_Pop_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type, [
        PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
    ])
  2801. def _handle_simple_method_dict_pop(self, node, function, args, is_unbound_method):
  2802. """Replace dict.pop() by a call to _PyDict_Pop().
  2803. """
  2804. if len(args) == 2:
  2805. args.append(ExprNodes.NullNode(node.pos))
  2806. elif len(args) != 3:
  2807. self._error_wrong_arg_count('dict.pop', node, args, "2 or 3")
  2808. return node
  2809. return self._substitute_method_call(
  2810. node, function,
  2811. "__Pyx_PyDict_Pop", self.PyDict_Pop_func_type,
  2812. 'pop', is_unbound_method, args,
  2813. may_return_none=True,
  2814. utility_code=load_c_utility('py_dict_pop'))
# Signature of the integer binop helpers: both operands as Python objects,
# the constant operand once more as a C long, and an 'inplace' flag.
Pyx_PyInt_BinopInt_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type, [
        PyrexTypes.CFuncTypeArg("op1", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("op2", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("intval", PyrexTypes.c_long_type, None),
        PyrexTypes.CFuncTypeArg("inplace", PyrexTypes.c_bint_type, None),
    ])

# Signature of the float binop helpers: as above, with the constant passed
# as a C double.
Pyx_PyFloat_BinopInt_func_type = PyrexTypes.CFuncType(
    PyrexTypes.py_object_type, [
        PyrexTypes.CFuncTypeArg("op1", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("op2", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("fval", PyrexTypes.c_double_type, None),
        PyrexTypes.CFuncTypeArg("inplace", PyrexTypes.c_bint_type, None),
    ])
  2829. def _handle_simple_method_object___add__(self, node, function, args, is_unbound_method):
  2830. return self._optimise_num_binop('Add', node, function, args, is_unbound_method)
  2831. def _handle_simple_method_object___sub__(self, node, function, args, is_unbound_method):
  2832. return self._optimise_num_binop('Subtract', node, function, args, is_unbound_method)
  2833. def _handle_simple_method_object___eq__(self, node, function, args, is_unbound_method):
  2834. return self._optimise_num_binop('Eq', node, function, args, is_unbound_method)
  2835. def _handle_simple_method_object___neq__(self, node, function, args, is_unbound_method):
  2836. return self._optimise_num_binop('Ne', node, function, args, is_unbound_method)
  2837. def _handle_simple_method_object___and__(self, node, function, args, is_unbound_method):
  2838. return self._optimise_num_binop('And', node, function, args, is_unbound_method)
  2839. def _handle_simple_method_object___or__(self, node, function, args, is_unbound_method):
  2840. return self._optimise_num_binop('Or', node, function, args, is_unbound_method)
  2841. def _handle_simple_method_object___xor__(self, node, function, args, is_unbound_method):
  2842. return self._optimise_num_binop('Xor', node, function, args, is_unbound_method)
  2843. def _handle_simple_method_object___rshift__(self, node, function, args, is_unbound_method):
  2844. if len(args) != 2 or not isinstance(args[1], ExprNodes.IntNode):
  2845. return node
  2846. if not args[1].has_constant_result() or not (1 <= args[1].constant_result <= 63):
  2847. return node
  2848. return self._optimise_num_binop('Rshift', node, function, args, is_unbound_method)
  2849. def _handle_simple_method_object___lshift__(self, node, function, args, is_unbound_method):
  2850. if len(args) != 2 or not isinstance(args[1], ExprNodes.IntNode):
  2851. return node
  2852. if not args[1].has_constant_result() or not (1 <= args[1].constant_result <= 63):
  2853. return node
  2854. return self._optimise_num_binop('Lshift', node, function, args, is_unbound_method)
  2855. def _handle_simple_method_object___mod__(self, node, function, args, is_unbound_method):
  2856. return self._optimise_num_div('Remainder', node, function, args, is_unbound_method)
  2857. def _handle_simple_method_object___floordiv__(self, node, function, args, is_unbound_method):
  2858. return self._optimise_num_div('FloorDivide', node, function, args, is_unbound_method)
  2859. def _handle_simple_method_object___truediv__(self, node, function, args, is_unbound_method):
  2860. return self._optimise_num_div('TrueDivide', node, function, args, is_unbound_method)
  2861. def _handle_simple_method_object___div__(self, node, function, args, is_unbound_method):
  2862. return self._optimise_num_div('Divide', node, function, args, is_unbound_method)
  2863. def _optimise_num_div(self, operator, node, function, args, is_unbound_method):
  2864. if len(args) != 2 or not args[1].has_constant_result() or args[1].constant_result == 0:
  2865. return node
  2866. if isinstance(args[1], ExprNodes.IntNode):
  2867. if not (-2**30 <= args[1].constant_result <= 2**30):
  2868. return node
  2869. elif isinstance(args[1], ExprNodes.FloatNode):
  2870. if not (-2**53 <= args[1].constant_result <= 2**53):
  2871. return node
  2872. else:
  2873. return node
  2874. return self._optimise_num_binop(operator, node, function, args, is_unbound_method)
  2875. def _handle_simple_method_float___add__(self, node, function, args, is_unbound_method):
  2876. return self._optimise_num_binop('Add', node, function, args, is_unbound_method)
  2877. def _handle_simple_method_float___sub__(self, node, function, args, is_unbound_method):
  2878. return self._optimise_num_binop('Subtract', node, function, args, is_unbound_method)
  2879. def _handle_simple_method_float___truediv__(self, node, function, args, is_unbound_method):
  2880. return self._optimise_num_binop('TrueDivide', node, function, args, is_unbound_method)
  2881. def _handle_simple_method_float___div__(self, node, function, args, is_unbound_method):
  2882. return self._optimise_num_binop('Divide', node, function, args, is_unbound_method)
  2883. def _handle_simple_method_float___mod__(self, node, function, args, is_unbound_method):
  2884. return self._optimise_num_binop('Remainder', node, function, args, is_unbound_method)
  2885. def _handle_simple_method_float___eq__(self, node, function, args, is_unbound_method):
  2886. return self._optimise_num_binop('Eq', node, function, args, is_unbound_method)
  2887. def _handle_simple_method_float___neq__(self, node, function, args, is_unbound_method):
  2888. return self._optimise_num_binop('Ne', node, function, args, is_unbound_method)
  2889. def _optimise_num_binop(self, operator, node, function, args, is_unbound_method):
  2890. """
  2891. Optimise math operators for (likely) float or small integer operations.
  2892. """
  2893. if len(args) != 2:
  2894. return node
  2895. if not node.type.is_pyobject:
  2896. return node
  2897. # When adding IntNode/FloatNode to something else, assume other operand is also numeric.
  2898. # Prefer constants on RHS as they allows better size control for some operators.
  2899. num_nodes = (ExprNodes.IntNode, ExprNodes.FloatNode)
  2900. if isinstance(args[1], num_nodes):
  2901. if args[0].type is not PyrexTypes.py_object_type:
  2902. return node
  2903. numval = args[1]
  2904. arg_order = 'ObjC'
  2905. elif isinstance(args[0], num_nodes):
  2906. if args[1].type is not PyrexTypes.py_object_type:
  2907. return node
  2908. numval = args[0]
  2909. arg_order = 'CObj'
  2910. else:
  2911. return node
  2912. if not numval.has_constant_result():
  2913. return node
  2914. is_float = isinstance(numval, ExprNodes.FloatNode)
  2915. if is_float:
  2916. if operator not in ('Add', 'Subtract', 'Remainder', 'TrueDivide', 'Divide', 'Eq', 'Ne'):
  2917. return node
  2918. elif operator == 'Divide':
  2919. # mixed old-/new-style division is not currently optimised for integers
  2920. return node
  2921. elif abs(numval.constant_result) > 2**30:
  2922. return node
  2923. args = list(args)
  2924. args.append((ExprNodes.FloatNode if is_float else ExprNodes.IntNode)(
  2925. numval.pos, value=numval.value, constant_result=numval.constant_result,
  2926. type=PyrexTypes.c_double_type if is_float else PyrexTypes.c_long_type))
  2927. inplace = node.inplace if isinstance(node, ExprNodes.NumBinopNode) else False
  2928. args.append(ExprNodes.BoolNode(node.pos, value=inplace, constant_result=inplace))
  2929. utility_code = TempitaUtilityCode.load_cached(
  2930. "PyFloatBinop" if is_float else "PyIntBinop", "Optimize.c",
  2931. context=dict(op=operator, order=arg_order))
  2932. return self._substitute_method_call(
  2933. node, function, "__Pyx_Py%s_%s%s" % ('Float' if is_float else 'Int', operator, arg_order),
  2934. self.Pyx_PyFloat_BinopInt_func_type if is_float else self.Pyx_PyInt_BinopInt_func_type,
  2935. '__%s__' % operator[:3].lower(), is_unbound_method, args,
  2936. may_return_none=True,
  2937. with_none_check=False,
  2938. utility_code=utility_code)
  2939. ### unicode type methods
# Signature of the unicode character predicates: takes a Py_UCS4
# character and returns a C 'bint'.
PyUnicode_uchar_predicate_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_bint_type, [
        PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_ucs4_type, None),
    ])
  2944. def _inject_unicode_predicate(self, node, function, args, is_unbound_method):
  2945. if is_unbound_method or len(args) != 1:
  2946. return node
  2947. ustring = args[0]
  2948. if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \
  2949. not ustring.arg.type.is_unicode_char:
  2950. return node
  2951. uchar = ustring.arg
  2952. method_name = function.attribute
  2953. if method_name == 'istitle':
  2954. # istitle() doesn't directly map to Py_UNICODE_ISTITLE()
  2955. utility_code = UtilityCode.load_cached(
  2956. "py_unicode_istitle", "StringTools.c")
  2957. function_name = '__Pyx_Py_UNICODE_ISTITLE'
  2958. else:
  2959. utility_code = None
  2960. function_name = 'Py_UNICODE_%s' % method_name.upper()
  2961. func_call = self._substitute_method_call(
  2962. node, function,
  2963. function_name, self.PyUnicode_uchar_predicate_func_type,
  2964. method_name, is_unbound_method, [uchar],
  2965. utility_code = utility_code)
  2966. if node.type.is_pyobject:
  2967. func_call = func_call.coerce_to_pyobject(self.current_env)
  2968. return func_call
  2969. _handle_simple_method_unicode_isalnum = _inject_unicode_predicate
  2970. _handle_simple_method_unicode_isalpha = _inject_unicode_predicate
  2971. _handle_simple_method_unicode_isdecimal = _inject_unicode_predicate
  2972. _handle_simple_method_unicode_isdigit = _inject_unicode_predicate
  2973. _handle_simple_method_unicode_islower = _inject_unicode_predicate
  2974. _handle_simple_method_unicode_isnumeric = _inject_unicode_predicate
  2975. _handle_simple_method_unicode_isspace = _inject_unicode_predicate
  2976. _handle_simple_method_unicode_istitle = _inject_unicode_predicate
  2977. _handle_simple_method_unicode_isupper = _inject_unicode_predicate
  2978. PyUnicode_uchar_conversion_func_type = PyrexTypes.CFuncType(
  2979. PyrexTypes.c_py_ucs4_type, [
  2980. PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_ucs4_type, None),
  2981. ])
  2982. def _inject_unicode_character_conversion(self, node, function, args, is_unbound_method):
  2983. if is_unbound_method or len(args) != 1:
  2984. return node
  2985. ustring = args[0]
  2986. if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \
  2987. not ustring.arg.type.is_unicode_char:
  2988. return node
  2989. uchar = ustring.arg
  2990. method_name = function.attribute
  2991. function_name = 'Py_UNICODE_TO%s' % method_name.upper()
  2992. func_call = self._substitute_method_call(
  2993. node, function,
  2994. function_name, self.PyUnicode_uchar_conversion_func_type,
  2995. method_name, is_unbound_method, [uchar])
  2996. if node.type.is_pyobject:
  2997. func_call = func_call.coerce_to_pyobject(self.current_env)
  2998. return func_call
  2999. _handle_simple_method_unicode_lower = _inject_unicode_character_conversion
  3000. _handle_simple_method_unicode_upper = _inject_unicode_character_conversion
  3001. _handle_simple_method_unicode_title = _inject_unicode_character_conversion
  3002. PyUnicode_Splitlines_func_type = PyrexTypes.CFuncType(
  3003. Builtin.list_type, [
  3004. PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
  3005. PyrexTypes.CFuncTypeArg("keepends", PyrexTypes.c_bint_type, None),
  3006. ])
  3007. def _handle_simple_method_unicode_splitlines(self, node, function, args, is_unbound_method):
  3008. """Replace unicode.splitlines(...) by a direct call to the
  3009. corresponding C-API function.
  3010. """
  3011. if len(args) not in (1,2):
  3012. self._error_wrong_arg_count('unicode.splitlines', node, args, "1 or 2")
  3013. return node
  3014. self._inject_bint_default_argument(node, args, 1, False)
  3015. return self._substitute_method_call(
  3016. node, function,
  3017. "PyUnicode_Splitlines", self.PyUnicode_Splitlines_func_type,
  3018. 'splitlines', is_unbound_method, args)
  3019. PyUnicode_Split_func_type = PyrexTypes.CFuncType(
  3020. Builtin.list_type, [
  3021. PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
  3022. PyrexTypes.CFuncTypeArg("sep", PyrexTypes.py_object_type, None),
  3023. PyrexTypes.CFuncTypeArg("maxsplit", PyrexTypes.c_py_ssize_t_type, None),
  3024. ]
  3025. )
  3026. def _handle_simple_method_unicode_split(self, node, function, args, is_unbound_method):
  3027. """Replace unicode.split(...) by a direct call to the
  3028. corresponding C-API function.
  3029. """
  3030. if len(args) not in (1,2,3):
  3031. self._error_wrong_arg_count('unicode.split', node, args, "1-3")
  3032. return node
  3033. if len(args) < 2:
  3034. args.append(ExprNodes.NullNode(node.pos))
  3035. self._inject_int_default_argument(
  3036. node, args, 2, PyrexTypes.c_py_ssize_t_type, "-1")
  3037. return self._substitute_method_call(
  3038. node, function,
  3039. "PyUnicode_Split", self.PyUnicode_Split_func_type,
  3040. 'split', is_unbound_method, args)
  3041. PyUnicode_Join_func_type = PyrexTypes.CFuncType(
  3042. Builtin.unicode_type, [
  3043. PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
  3044. PyrexTypes.CFuncTypeArg("seq", PyrexTypes.py_object_type, None),
  3045. ])
  3046. def _handle_simple_method_unicode_join(self, node, function, args, is_unbound_method):
  3047. """
  3048. unicode.join() builds a list first => see if we can do this more efficiently
  3049. """
  3050. if len(args) != 2:
  3051. self._error_wrong_arg_count('unicode.join', node, args, "2")
  3052. return node
  3053. if isinstance(args[1], ExprNodes.GeneratorExpressionNode):
  3054. gen_expr_node = args[1]
  3055. loop_node = gen_expr_node.loop
  3056. yield_statements = _find_yield_statements(loop_node)
  3057. if yield_statements:
  3058. inlined_genexpr = ExprNodes.InlinedGeneratorExpressionNode(
  3059. node.pos, gen_expr_node, orig_func='list',
  3060. comprehension_type=Builtin.list_type)
  3061. for yield_expression, yield_stat_node in yield_statements:
  3062. append_node = ExprNodes.ComprehensionAppendNode(
  3063. yield_expression.pos,
  3064. expr=yield_expression,
  3065. target=inlined_genexpr.target)
  3066. Visitor.recursively_replace_node(gen_expr_node, yield_stat_node, append_node)
  3067. args[1] = inlined_genexpr
  3068. return self._substitute_method_call(
  3069. node, function,
  3070. "PyUnicode_Join", self.PyUnicode_Join_func_type,
  3071. 'join', is_unbound_method, args)
    # C signature shared by the tail-match helper calls built in
    # _inject_tailmatch():
    #   bint f(object str, object substring,
    #          Py_ssize_t start, Py_ssize_t end, int direction)
    # The function type is declared with an exception value of -1.
    PyString_Tailmatch_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_bint_type, [
            PyrexTypes.CFuncTypeArg("str", PyrexTypes.py_object_type, None),  # bytes/str/unicode
            PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None),
            ],
        exception_value = '-1')
  3081. def _handle_simple_method_unicode_endswith(self, node, function, args, is_unbound_method):
  3082. return self._inject_tailmatch(
  3083. node, function, args, is_unbound_method, 'unicode', 'endswith',
  3084. unicode_tailmatch_utility_code, +1)
  3085. def _handle_simple_method_unicode_startswith(self, node, function, args, is_unbound_method):
  3086. return self._inject_tailmatch(
  3087. node, function, args, is_unbound_method, 'unicode', 'startswith',
  3088. unicode_tailmatch_utility_code, -1)
  3089. def _inject_tailmatch(self, node, function, args, is_unbound_method, type_name,
  3090. method_name, utility_code, direction):
  3091. """Replace unicode.startswith(...) and unicode.endswith(...)
  3092. by a direct call to the corresponding C-API function.
  3093. """
  3094. if len(args) not in (2,3,4):
  3095. self._error_wrong_arg_count('%s.%s' % (type_name, method_name), node, args, "2-4")
  3096. return node
  3097. self._inject_int_default_argument(
  3098. node, args, 2, PyrexTypes.c_py_ssize_t_type, "0")
  3099. self._inject_int_default_argument(
  3100. node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX")
  3101. args.append(ExprNodes.IntNode(
  3102. node.pos, value=str(direction), type=PyrexTypes.c_int_type))
  3103. method_call = self._substitute_method_call(
  3104. node, function,
  3105. "__Pyx_Py%s_Tailmatch" % type_name.capitalize(),
  3106. self.PyString_Tailmatch_func_type,
  3107. method_name, is_unbound_method, args,
  3108. utility_code = utility_code)
  3109. return method_call.coerce_to(Builtin.bool_type, self.current_env())
    # C signature used for the PyUnicode_Find() calls built in
    # _inject_unicode_find():
    #   Py_ssize_t f(unicode str, object substring,
    #                Py_ssize_t start, Py_ssize_t end, int direction)
    # The function type is declared with an exception value of -2.
    PyUnicode_Find_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type, [
            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None),
            ],
        exception_value = '-2')
  3119. def _handle_simple_method_unicode_find(self, node, function, args, is_unbound_method):
  3120. return self._inject_unicode_find(
  3121. node, function, args, is_unbound_method, 'find', +1)
  3122. def _handle_simple_method_unicode_rfind(self, node, function, args, is_unbound_method):
  3123. return self._inject_unicode_find(
  3124. node, function, args, is_unbound_method, 'rfind', -1)
  3125. def _inject_unicode_find(self, node, function, args, is_unbound_method,
  3126. method_name, direction):
  3127. """Replace unicode.find(...) and unicode.rfind(...) by a
  3128. direct call to the corresponding C-API function.
  3129. """
  3130. if len(args) not in (2,3,4):
  3131. self._error_wrong_arg_count('unicode.%s' % method_name, node, args, "2-4")
  3132. return node
  3133. self._inject_int_default_argument(
  3134. node, args, 2, PyrexTypes.c_py_ssize_t_type, "0")
  3135. self._inject_int_default_argument(
  3136. node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX")
  3137. args.append(ExprNodes.IntNode(
  3138. node.pos, value=str(direction), type=PyrexTypes.c_int_type))
  3139. method_call = self._substitute_method_call(
  3140. node, function, "PyUnicode_Find", self.PyUnicode_Find_func_type,
  3141. method_name, is_unbound_method, args)
  3142. return method_call.coerce_to_pyobject(self.current_env())
  3143. PyUnicode_Count_func_type = PyrexTypes.CFuncType(
  3144. PyrexTypes.c_py_ssize_t_type, [
  3145. PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
  3146. PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
  3147. PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
  3148. PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
  3149. ],
  3150. exception_value = '-1')
  3151. def _handle_simple_method_unicode_count(self, node, function, args, is_unbound_method):
  3152. """Replace unicode.count(...) by a direct call to the
  3153. corresponding C-API function.
  3154. """
  3155. if len(args) not in (2,3,4):
  3156. self._error_wrong_arg_count('unicode.count', node, args, "2-4")
  3157. return node
  3158. self._inject_int_default_argument(
  3159. node, args, 2, PyrexTypes.c_py_ssize_t_type, "0")
  3160. self._inject_int_default_argument(
  3161. node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX")
  3162. method_call = self._substitute_method_call(
  3163. node, function, "PyUnicode_Count", self.PyUnicode_Count_func_type,
  3164. 'count', is_unbound_method, args)
  3165. return method_call.coerce_to_pyobject(self.current_env())
  3166. PyUnicode_Replace_func_type = PyrexTypes.CFuncType(
  3167. Builtin.unicode_type, [
  3168. PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
  3169. PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
  3170. PyrexTypes.CFuncTypeArg("replstr", PyrexTypes.py_object_type, None),
  3171. PyrexTypes.CFuncTypeArg("maxcount", PyrexTypes.c_py_ssize_t_type, None),
  3172. ])
  3173. def _handle_simple_method_unicode_replace(self, node, function, args, is_unbound_method):
  3174. """Replace unicode.replace(...) by a direct call to the
  3175. corresponding C-API function.
  3176. """
  3177. if len(args) not in (3,4):
  3178. self._error_wrong_arg_count('unicode.replace', node, args, "3-4")
  3179. return node
  3180. self._inject_int_default_argument(
  3181. node, args, 3, PyrexTypes.c_py_ssize_t_type, "-1")
  3182. return self._substitute_method_call(
  3183. node, function, "PyUnicode_Replace", self.PyUnicode_Replace_func_type,
  3184. 'replace', is_unbound_method, args)
    # C signature of the generic encoder call:
    #   bytes f(unicode obj, const char* encoding, const char* errors)
    PyUnicode_AsEncodedString_func_type = PyrexTypes.CFuncType(
        Builtin.bytes_type, [
            PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_const_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_const_char_ptr_type, None),
            ])

    # C signature of the codec specific "PyUnicode_As<Codec>String" calls:
    #   bytes f(unicode obj)
    PyUnicode_AsXyzString_func_type = PyrexTypes.CFuncType(
        Builtin.bytes_type, [
            PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
            ])

    # Codec names that _find_special_codec_name() can return (names that
    # contain '_' are returned in CamelCase form).  The encode()/decode()
    # handlers paste the returned name into a C function name.
    _special_encodings = ['UTF8', 'UTF16', 'UTF-16LE', 'UTF-16BE', 'Latin1', 'ASCII',
                          'unicode_escape', 'raw_unicode_escape']

    # (name, encoder function) pairs; a requested encoding is matched by
    # comparing the encoder that the codecs module returns for it.
    _special_codecs = [ (name, codecs.getencoder(name))
                        for name in _special_encodings ]
  3199. def _handle_simple_method_unicode_encode(self, node, function, args, is_unbound_method):
  3200. """Replace unicode.encode(...) by a direct C-API call to the
  3201. corresponding codec.
  3202. """
  3203. if len(args) < 1 or len(args) > 3:
  3204. self._error_wrong_arg_count('unicode.encode', node, args, '1-3')
  3205. return node
  3206. string_node = args[0]
  3207. if len(args) == 1:
  3208. null_node = ExprNodes.NullNode(node.pos)
  3209. return self._substitute_method_call(
  3210. node, function, "PyUnicode_AsEncodedString",
  3211. self.PyUnicode_AsEncodedString_func_type,
  3212. 'encode', is_unbound_method, [string_node, null_node, null_node])
  3213. parameters = self._unpack_encoding_and_error_mode(node.pos, args)
  3214. if parameters is None:
  3215. return node
  3216. encoding, encoding_node, error_handling, error_handling_node = parameters
  3217. if encoding and isinstance(string_node, ExprNodes.UnicodeNode):
  3218. # constant, so try to do the encoding at compile time
  3219. try:
  3220. value = string_node.value.encode(encoding, error_handling)
  3221. except:
  3222. # well, looks like we can't
  3223. pass
  3224. else:
  3225. value = bytes_literal(value, encoding)
  3226. return ExprNodes.BytesNode(string_node.pos, value=value, type=Builtin.bytes_type)
  3227. if encoding and error_handling == 'strict':
  3228. # try to find a specific encoder function
  3229. codec_name = self._find_special_codec_name(encoding)
  3230. if codec_name is not None and '-' not in codec_name:
  3231. encode_function = "PyUnicode_As%sString" % codec_name
  3232. return self._substitute_method_call(
  3233. node, function, encode_function,
  3234. self.PyUnicode_AsXyzString_func_type,
  3235. 'encode', is_unbound_method, [string_node])
  3236. return self._substitute_method_call(
  3237. node, function, "PyUnicode_AsEncodedString",
  3238. self.PyUnicode_AsEncodedString_func_type,
  3239. 'encode', is_unbound_method,
  3240. [string_node, encoding_node, error_handling_node])
    # Pointer to a codec specific decoder function:
    #   unicode (*f)(const char* string, Py_ssize_t size, const char* errors)
    PyUnicode_DecodeXyz_func_ptr_type = PyrexTypes.CPtrType(PyrexTypes.CFuncType(
        Builtin.unicode_type, [
            PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_const_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_const_char_ptr_type, None),
            ]))

    # Signature of the decode helper used for C strings (const char*).
    _decode_c_string_func_type = PyrexTypes.CFuncType(
        Builtin.unicode_type, [
            PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_const_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_const_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_const_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("decode_func", PyUnicode_DecodeXyz_func_ptr_type, None),
            ])

    # Signature of the decode helper used for Python bytes/bytearray objects;
    # same as above except that the string is passed as a Python object.
    _decode_bytes_func_type = PyrexTypes.CFuncType(
        Builtin.unicode_type, [
            PyrexTypes.CFuncTypeArg("string", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_const_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_const_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("decode_func", PyUnicode_DecodeXyz_func_ptr_type, None),
            ])

    # The C++ string variant needs the actual C++ string type and is therefore
    # only created on first use in _handle_simple_method_bytes_decode().
    _decode_cpp_string_func_type = None  # lazy init
    def _handle_simple_method_bytes_decode(self, node, function, args, is_unbound_method):
        """Replace char*.decode() by a direct C-API call to the
        corresponding codec, possibly resolving a slice on the char*.

        Handles Python bytes/bytearray objects, C strings and C++ strings;
        for any other object type, the original node is returned.  The
        original node is also returned on a wrong argument count or when
        the encoding/errors arguments cannot be unpacked.

        The replacement is a call to one of the '__Pyx_decode_*' helper
        functions from StringTools.c, which receive the string, the
        start/stop offsets, the encoding and error handling as C strings
        (or NULL) and, if a special codec was recognised, a pointer to the
        codec specific decoder function.
        """
        if not (1 <= len(args) <= 3):
            self._error_wrong_arg_count('bytes.decode', node, args, '1-3')
            return node

        # normalise input nodes
        string_node = args[0]
        start = stop = None
        if isinstance(string_node, ExprNodes.SliceIndexNode):
            # decode the base string directly and pass the slice bounds on
            index_node = string_node
            string_node = index_node.base
            start, stop = index_node.start, index_node.stop
            if not start or start.constant_result == 0:
                start = None
        if isinstance(string_node, ExprNodes.CoerceToPyTypeNode):
            string_node = string_node.arg

        string_type = string_node.type
        if string_type in (Builtin.bytes_type, Builtin.bytearray_type):
            # Python objects may be None => keep the error that the
            # original method call would have raised
            if is_unbound_method:
                string_node = string_node.as_none_safe_node(
                    "descriptor '%s' requires a '%s' object but received a 'NoneType'",
                    format_args=['decode', string_type.name])
            else:
                string_node = string_node.as_none_safe_node(
                    "'NoneType' object has no attribute '%.30s'",
                    error="PyExc_AttributeError",
                    format_args=['decode'])
        elif not string_type.is_string and not string_type.is_cpp_string:
            # nothing to optimise here
            return node

        parameters = self._unpack_encoding_and_error_mode(node.pos, args)
        if parameters is None:
            return node
        encoding, encoding_node, error_handling, error_handling_node = parameters

        # make sure that start/stop are C integers
        if not start:
            start = ExprNodes.IntNode(node.pos, value='0', constant_result=0)
        elif not start.type.is_int:
            start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
        if stop and not stop.type.is_int:
            stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())

        # try to find a specific decoder function
        codec_name = None
        if encoding is not None:
            codec_name = self._find_special_codec_name(encoding)
        if codec_name is not None:
            if codec_name in ('UTF16', 'UTF-16LE', 'UTF-16BE'):
                # the UTF-16 variants go through '__Pyx_' prefixed functions
                codec_cname = "__Pyx_PyUnicode_Decode%s" % codec_name.replace('-', '')
            else:
                codec_cname = "PyUnicode_Decode%s" % codec_name
            decode_function = ExprNodes.RawCNameExprNode(
                node.pos, type=self.PyUnicode_DecodeXyz_func_ptr_type, cname=codec_cname)
            # the decoder function is passed instead of the encoding name
            encoding_node = ExprNodes.NullNode(node.pos)
        else:
            decode_function = ExprNodes.NullNode(node.pos)

        # build the helper function call
        temps = []
        if string_type.is_string:
            # C string
            if not stop:
                # use strlen() to find the string length, just as CPython would
                if not string_node.is_name:
                    string_node = UtilNodes.LetRefNode(string_node)  # used twice
                    temps.append(string_node)
                stop = ExprNodes.PythonCapiCallNode(
                    string_node.pos, "strlen", self.Pyx_strlen_func_type,
                    args=[string_node],
                    is_temp=False,
                    utility_code=UtilityCode.load_cached("IncludeStringH", "StringTools.c"),
                ).coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
            helper_func_type = self._decode_c_string_func_type
            utility_code_name = 'decode_c_string'
        elif string_type.is_cpp_string:
            # C++ std::string
            if not stop:
                stop = ExprNodes.IntNode(node.pos, value='PY_SSIZE_T_MAX',
                                         constant_result=ExprNodes.not_a_constant)
            if self._decode_cpp_string_func_type is None:
                # lazy init to reuse the C++ string type
                self._decode_cpp_string_func_type = PyrexTypes.CFuncType(
                    Builtin.unicode_type, [
                        PyrexTypes.CFuncTypeArg("string", string_type, None),
                        PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
                        PyrexTypes.CFuncTypeArg("stop", PyrexTypes.c_py_ssize_t_type, None),
                        PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_const_char_ptr_type, None),
                        PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_const_char_ptr_type, None),
                        PyrexTypes.CFuncTypeArg("decode_func", self.PyUnicode_DecodeXyz_func_ptr_type, None),
                    ])
            helper_func_type = self._decode_cpp_string_func_type
            utility_code_name = 'decode_cpp_string'
        else:
            # Python bytes/bytearray object
            if not stop:
                stop = ExprNodes.IntNode(node.pos, value='PY_SSIZE_T_MAX',
                                         constant_result=ExprNodes.not_a_constant)
            helper_func_type = self._decode_bytes_func_type
            if string_type is Builtin.bytes_type:
                utility_code_name = 'decode_bytes'
            else:
                utility_code_name = 'decode_bytearray'

        node = ExprNodes.PythonCapiCallNode(
            node.pos, '__Pyx_%s' % utility_code_name, helper_func_type,
            args=[string_node, start, stop, encoding_node, error_handling_node, decode_function],
            is_temp=node.is_temp,
            utility_code=UtilityCode.load_cached(utility_code_name, 'StringTools.c'),
        )

        # wrap the call in the temps collected above, last one innermost
        for temp in temps[::-1]:
            node = UtilNodes.EvalWithTempExprNode(temp, node)
        return node

    # bytearray.decode() is handled by the same code
    _handle_simple_method_bytearray_decode = _handle_simple_method_bytes_decode
  3377. def _find_special_codec_name(self, encoding):
  3378. try:
  3379. requested_codec = codecs.getencoder(encoding)
  3380. except LookupError:
  3381. return None
  3382. for name, codec in self._special_codecs:
  3383. if codec == requested_codec:
  3384. if '_' in name:
  3385. name = ''.join([s.capitalize()
  3386. for s in name.split('_')])
  3387. return name
  3388. return None
  3389. def _unpack_encoding_and_error_mode(self, pos, args):
  3390. null_node = ExprNodes.NullNode(pos)
  3391. if len(args) >= 2:
  3392. encoding, encoding_node = self._unpack_string_and_cstring_node(args[1])
  3393. if encoding_node is None:
  3394. return None
  3395. else:
  3396. encoding = None
  3397. encoding_node = null_node
  3398. if len(args) == 3:
  3399. error_handling, error_handling_node = self._unpack_string_and_cstring_node(args[2])
  3400. if error_handling_node is None:
  3401. return None
  3402. if error_handling == 'strict':
  3403. error_handling_node = null_node
  3404. else:
  3405. error_handling = 'strict'
  3406. error_handling_node = null_node
  3407. return (encoding, encoding_node, error_handling, error_handling_node)
  3408. def _unpack_string_and_cstring_node(self, node):
  3409. if isinstance(node, ExprNodes.CoerceToPyTypeNode):
  3410. node = node.arg
  3411. if isinstance(node, ExprNodes.UnicodeNode):
  3412. encoding = node.value
  3413. node = ExprNodes.BytesNode(
  3414. node.pos, value=encoding.as_utf8_string(), type=PyrexTypes.c_const_char_ptr_type)
  3415. elif isinstance(node, (ExprNodes.StringNode, ExprNodes.BytesNode)):
  3416. encoding = node.value.decode('ISO-8859-1')
  3417. node = ExprNodes.BytesNode(
  3418. node.pos, value=node.value, type=PyrexTypes.c_const_char_ptr_type)
  3419. elif node.type is Builtin.bytes_type:
  3420. encoding = None
  3421. node = node.coerce_to(PyrexTypes.c_const_char_ptr_type, self.current_env())
  3422. elif node.type.is_string:
  3423. encoding = None
  3424. else:
  3425. encoding = node = None
  3426. return encoding, node
  3427. def _handle_simple_method_str_endswith(self, node, function, args, is_unbound_method):
  3428. return self._inject_tailmatch(
  3429. node, function, args, is_unbound_method, 'str', 'endswith',
  3430. str_tailmatch_utility_code, +1)
  3431. def _handle_simple_method_str_startswith(self, node, function, args, is_unbound_method):
  3432. return self._inject_tailmatch(
  3433. node, function, args, is_unbound_method, 'str', 'startswith',
  3434. str_tailmatch_utility_code, -1)
  3435. def _handle_simple_method_bytes_endswith(self, node, function, args, is_unbound_method):
  3436. return self._inject_tailmatch(
  3437. node, function, args, is_unbound_method, 'bytes', 'endswith',
  3438. bytes_tailmatch_utility_code, +1)
  3439. def _handle_simple_method_bytes_startswith(self, node, function, args, is_unbound_method):
  3440. return self._inject_tailmatch(
  3441. node, function, args, is_unbound_method, 'bytes', 'startswith',
  3442. bytes_tailmatch_utility_code, -1)
  3443. ''' # disabled for now, enable when we consider it worth it (see StringTools.c)
  3444. def _handle_simple_method_bytearray_endswith(self, node, function, args, is_unbound_method):
  3445. return self._inject_tailmatch(
  3446. node, function, args, is_unbound_method, 'bytearray', 'endswith',
  3447. bytes_tailmatch_utility_code, +1)
  3448. def _handle_simple_method_bytearray_startswith(self, node, function, args, is_unbound_method):
  3449. return self._inject_tailmatch(
  3450. node, function, args, is_unbound_method, 'bytearray', 'startswith',
  3451. bytes_tailmatch_utility_code, -1)
  3452. '''
  3453. ### helpers
  3454. def _substitute_method_call(self, node, function, name, func_type,
  3455. attr_name, is_unbound_method, args=(),
  3456. utility_code=None, is_temp=None,
  3457. may_return_none=ExprNodes.PythonCapiCallNode.may_return_none,
  3458. with_none_check=True):
  3459. args = list(args)
  3460. if with_none_check and args:
  3461. args[0] = self._wrap_self_arg(args[0], function, is_unbound_method, attr_name)
  3462. if is_temp is None:
  3463. is_temp = node.is_temp
  3464. return ExprNodes.PythonCapiCallNode(
  3465. node.pos, name, func_type,
  3466. args = args,
  3467. is_temp = is_temp,
  3468. utility_code = utility_code,
  3469. may_return_none = may_return_none,
  3470. result_is_used = node.result_is_used,
  3471. )
  3472. def _wrap_self_arg(self, self_arg, function, is_unbound_method, attr_name):
  3473. if self_arg.is_literal:
  3474. return self_arg
  3475. if is_unbound_method:
  3476. self_arg = self_arg.as_none_safe_node(
  3477. "descriptor '%s' requires a '%s' object but received a 'NoneType'",
  3478. format_args=[attr_name, self_arg.type.name])
  3479. else:
  3480. self_arg = self_arg.as_none_safe_node(
  3481. "'NoneType' object has no attribute '%{0}s'".format('.30' if len(attr_name) <= 30 else ''),
  3482. error="PyExc_AttributeError",
  3483. format_args=[attr_name])
  3484. return self_arg
  3485. def _inject_int_default_argument(self, node, args, arg_index, type, default_value):
  3486. assert len(args) >= arg_index
  3487. if len(args) == arg_index:
  3488. args.append(ExprNodes.IntNode(node.pos, value=str(default_value),
  3489. type=type, constant_result=default_value))
  3490. else:
  3491. args[arg_index] = args[arg_index].coerce_to(type, self.current_env())
  3492. def _inject_bint_default_argument(self, node, args, arg_index, default_value):
  3493. assert len(args) >= arg_index
  3494. if len(args) == arg_index:
  3495. default_value = bool(default_value)
  3496. args.append(ExprNodes.BoolNode(node.pos, value=default_value,
  3497. constant_result=default_value))
  3498. else:
  3499. args[arg_index] = args[arg_index].coerce_to_boolean(self.current_env())
# Cached utility code sections from StringTools.c that the
# startswith()/endswith() handlers pass to _inject_tailmatch().
unicode_tailmatch_utility_code = UtilityCode.load_cached('unicode_tailmatch', 'StringTools.c')
bytes_tailmatch_utility_code = UtilityCode.load_cached('bytes_tailmatch', 'StringTools.c')
str_tailmatch_utility_code = UtilityCode.load_cached('str_tailmatch', 'StringTools.c')
  3503. class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
  3504. """Calculate the result of constant expressions to store it in
  3505. ``expr_node.constant_result``, and replace trivial cases by their
  3506. constant result.
  3507. General rules:
  3508. - We calculate float constants to make them available to the
  3509. compiler, but we do not aggregate them into a single literal
  3510. node to prevent any loss of precision.
  3511. - We recursively calculate constants from non-literal nodes to
  3512. make them available to the compiler, but we only aggregate
  3513. literal nodes at each step. Non-literal nodes are never merged
  3514. into a single node.
  3515. """
  3516. def __init__(self, reevaluate=False):
  3517. """
  3518. The reevaluate argument specifies whether constant values that were
  3519. previously computed should be recomputed.
  3520. """
  3521. super(ConstantFolding, self).__init__()
  3522. self.reevaluate = reevaluate
  3523. def _calculate_const(self, node):
  3524. if (not self.reevaluate and
  3525. node.constant_result is not ExprNodes.constant_value_not_set):
  3526. return
  3527. # make sure we always set the value
  3528. not_a_constant = ExprNodes.not_a_constant
  3529. node.constant_result = not_a_constant
  3530. # check if all children are constant
  3531. children = self.visitchildren(node)
  3532. for child_result in children.values():
  3533. if type(child_result) is list:
  3534. for child in child_result:
  3535. if getattr(child, 'constant_result', not_a_constant) is not_a_constant:
  3536. return
  3537. elif getattr(child_result, 'constant_result', not_a_constant) is not_a_constant:
  3538. return
  3539. # now try to calculate the real constant value
  3540. try:
  3541. node.calculate_constant_result()
  3542. # if node.constant_result is not ExprNodes.not_a_constant:
  3543. # print node.__class__.__name__, node.constant_result
  3544. except (ValueError, TypeError, KeyError, IndexError, AttributeError, ArithmeticError):
  3545. # ignore all 'normal' errors here => no constant result
  3546. pass
  3547. except Exception:
  3548. # this looks like a real error
  3549. import traceback, sys
  3550. traceback.print_exc(file=sys.stdout)
  3551. NODE_TYPE_ORDER = [ExprNodes.BoolNode, ExprNodes.CharNode,
  3552. ExprNodes.IntNode, ExprNodes.FloatNode]
  3553. def _widest_node_class(self, *nodes):
  3554. try:
  3555. return self.NODE_TYPE_ORDER[
  3556. max(map(self.NODE_TYPE_ORDER.index, map(type, nodes)))]
  3557. except ValueError:
  3558. return None
  3559. def _bool_node(self, node, value):
  3560. value = bool(value)
  3561. return ExprNodes.BoolNode(node.pos, value=value, constant_result=value)
  3562. def visit_ExprNode(self, node):
  3563. self._calculate_const(node)
  3564. return node
  3565. def visit_UnopNode(self, node):
  3566. self._calculate_const(node)
  3567. if not node.has_constant_result():
  3568. if node.operator == '!':
  3569. return self._handle_NotNode(node)
  3570. return node
  3571. if not node.operand.is_literal:
  3572. return node
  3573. if node.operator == '!':
  3574. return self._bool_node(node, node.constant_result)
  3575. elif isinstance(node.operand, ExprNodes.BoolNode):
  3576. return ExprNodes.IntNode(node.pos, value=str(int(node.constant_result)),
  3577. type=PyrexTypes.c_int_type,
  3578. constant_result=int(node.constant_result))
  3579. elif node.operator == '+':
  3580. return self._handle_UnaryPlusNode(node)
  3581. elif node.operator == '-':
  3582. return self._handle_UnaryMinusNode(node)
  3583. return node
# Maps a comparison operator to its logical negation.  This is the bound
# dict.get(), so operators without an entry here yield None.
# Used by _handle_NotNode().
_negate_operator = {
    'in': 'not_in',
    'not_in': 'in',
    'is': 'is_not',
    'is_not': 'is'
}.get
  3590. def _handle_NotNode(self, node):
  3591. operand = node.operand
  3592. if isinstance(operand, ExprNodes.PrimaryCmpNode):
  3593. operator = self._negate_operator(operand.operator)
  3594. if operator:
  3595. node = copy.copy(operand)
  3596. node.operator = operator
  3597. node = self.visit_PrimaryCmpNode(node)
  3598. return node
  3599. def _handle_UnaryMinusNode(self, node):
  3600. def _negate(value):
  3601. if value.startswith('-'):
  3602. value = value[1:]
  3603. else:
  3604. value = '-' + value
  3605. return value
  3606. node_type = node.operand.type
  3607. if isinstance(node.operand, ExprNodes.FloatNode):
  3608. # this is a safe operation
  3609. return ExprNodes.FloatNode(node.pos, value=_negate(node.operand.value),
  3610. type=node_type,
  3611. constant_result=node.constant_result)
  3612. if node_type.is_int and node_type.signed or \
  3613. isinstance(node.operand, ExprNodes.IntNode) and node_type.is_pyobject:
  3614. return ExprNodes.IntNode(node.pos, value=_negate(node.operand.value),
  3615. type=node_type,
  3616. longness=node.operand.longness,
  3617. constant_result=node.constant_result)
  3618. return node
  3619. def _handle_UnaryPlusNode(self, node):
  3620. if (node.operand.has_constant_result() and
  3621. node.constant_result == node.operand.constant_result):
  3622. return node.operand
  3623. return node
  3624. def visit_BoolBinopNode(self, node):
  3625. self._calculate_const(node)
  3626. if not node.operand1.has_constant_result():
  3627. return node
  3628. if node.operand1.constant_result:
  3629. if node.operator == 'and':
  3630. return node.operand2
  3631. else:
  3632. return node.operand1
  3633. else:
  3634. if node.operator == 'and':
  3635. return node.operand1
  3636. else:
  3637. return node.operand2
def visit_BinopNode(self, node):
    """
    Replace a binary operation on two literal operands by a new literal node
    that holds the calculated constant result.

    Returns 'node' unchanged if there is no constant result, if the result is
    a float, if an operand is not a literal or has no type, or if the class
    of an operand is not listed in NODE_TYPE_ORDER.
    """
    self._calculate_const(node)
    if node.constant_result is ExprNodes.not_a_constant:
        return node
    # float results are never replaced by a new literal node here
    if isinstance(node.constant_result, float):
        return node
    operand1, operand2 = node.operand1, node.operand2
    if not operand1.is_literal or not operand2.is_literal:
        return node

    # now inject a new constant node with the calculated value
    try:
        type1, type2 = operand1.type, operand2.type
        if type1 is None or type2 is None:
            return node
    except AttributeError:
        # operand without a 'type' attribute
        return node

    # result type: widest numeric type of both operands, otherwise a Python object
    if type1.is_numeric and type2.is_numeric:
        widest_type = PyrexTypes.widest_numeric_type(type1, type2)
    else:
        widest_type = PyrexTypes.py_object_type

    # result node class: widest literal class of both operands
    target_class = self._widest_node_class(operand1, operand2)
    if target_class is None:
        return node
    # NOTE: the operator checks below are substring tests against the
    # concatenated operator characters, so e.g. '*' and '/' match as well.
    elif target_class is ExprNodes.BoolNode and node.operator in '+-//<<%**>>':
        # C arithmetic results in at least an int type
        target_class = ExprNodes.IntNode
    elif target_class is ExprNodes.CharNode and node.operator in '+-//<<%**>>&|^':
        # C arithmetic results in at least an int type
        target_class = ExprNodes.IntNode

    if target_class is ExprNodes.IntNode:
        # the result is only unsigned if both operands are
        unsigned = getattr(operand1, 'unsigned', '') and \
                   getattr(operand2, 'unsigned', '')
        # use the longer 'longness' suffix of both operands ('', 'L' or 'LL')
        longness = "LL"[:max(len(getattr(operand1, 'longness', '')),
                             len(getattr(operand2, 'longness', '')))]
        new_node = ExprNodes.IntNode(pos=node.pos,
                                     unsigned=unsigned, longness=longness,
                                     value=str(int(node.constant_result)),
                                     constant_result=int(node.constant_result))
        # IntNode is smart about the type it chooses, so we just
        # make sure we were not smarter this time
        if widest_type.is_pyobject or new_node.type.is_pyobject:
            new_node.type = PyrexTypes.py_object_type
        else:
            new_node.type = PyrexTypes.widest_numeric_type(widest_type, new_node.type)
    else:
        # BoolNode keeps the result object as its value,
        # all other node classes get its string representation
        if target_class is ExprNodes.BoolNode:
            node_value = node.constant_result
        else:
            node_value = str(node.constant_result)
        new_node = target_class(pos=node.pos, type = widest_type,
                                value = node_value,
                                constant_result = node.constant_result)
    return new_node
  3691. def visit_AddNode(self, node):
  3692. self._calculate_const(node)
  3693. if node.constant_result is ExprNodes.not_a_constant:
  3694. return node
  3695. if node.operand1.is_string_literal and node.operand2.is_string_literal:
  3696. # some people combine string literals with a '+'
  3697. str1, str2 = node.operand1, node.operand2
  3698. if isinstance(str1, ExprNodes.UnicodeNode) and isinstance(str2, ExprNodes.UnicodeNode):
  3699. bytes_value = None
  3700. if str1.bytes_value is not None and str2.bytes_value is not None:
  3701. if str1.bytes_value.encoding == str2.bytes_value.encoding:
  3702. bytes_value = bytes_literal(
  3703. str1.bytes_value + str2.bytes_value,
  3704. str1.bytes_value.encoding)
  3705. string_value = EncodedString(node.constant_result)
  3706. return ExprNodes.UnicodeNode(
  3707. str1.pos, value=string_value, constant_result=node.constant_result, bytes_value=bytes_value)
  3708. elif isinstance(str1, ExprNodes.BytesNode) and isinstance(str2, ExprNodes.BytesNode):
  3709. if str1.value.encoding == str2.value.encoding:
  3710. bytes_value = bytes_literal(node.constant_result, str1.value.encoding)
  3711. return ExprNodes.BytesNode(str1.pos, value=bytes_value, constant_result=node.constant_result)
  3712. # all other combinations are rather complicated
  3713. # to get right in Py2/3: encodings, unicode escapes, ...
  3714. return self.visit_BinopNode(node)
  3715. def visit_MulNode(self, node):
  3716. self._calculate_const(node)
  3717. if node.operand1.is_sequence_constructor:
  3718. return self._calculate_constant_seq(node, node.operand1, node.operand2)
  3719. if isinstance(node.operand1, ExprNodes.IntNode) and \
  3720. node.operand2.is_sequence_constructor:
  3721. return self._calculate_constant_seq(node, node.operand2, node.operand1)
  3722. if node.operand1.is_string_literal:
  3723. return self._multiply_string(node, node.operand1, node.operand2)
  3724. elif node.operand2.is_string_literal:
  3725. return self._multiply_string(node, node.operand2, node.operand1)
  3726. return self.visit_BinopNode(node)
  3727. def _multiply_string(self, node, string_node, multiplier_node):
  3728. multiplier = multiplier_node.constant_result
  3729. if not isinstance(multiplier, _py_int_types):
  3730. return node
  3731. if not (node.has_constant_result() and isinstance(node.constant_result, _py_string_types)):
  3732. return node
  3733. if len(node.constant_result) > 256:
  3734. # Too long for static creation, leave it to runtime. (-> arbitrary limit)
  3735. return node
  3736. build_string = encoded_string
  3737. if isinstance(string_node, ExprNodes.BytesNode):
  3738. build_string = bytes_literal
  3739. elif isinstance(string_node, ExprNodes.StringNode):
  3740. if string_node.unicode_value is not None:
  3741. string_node.unicode_value = encoded_string(
  3742. string_node.unicode_value * multiplier,
  3743. string_node.unicode_value.encoding)
  3744. elif isinstance(string_node, ExprNodes.UnicodeNode):
  3745. if string_node.bytes_value is not None:
  3746. string_node.bytes_value = bytes_literal(
  3747. string_node.bytes_value * multiplier,
  3748. string_node.bytes_value.encoding)
  3749. else:
  3750. assert False, "unknown string node type: %s" % type(string_node)
  3751. string_node.value = build_string(
  3752. string_node.value * multiplier,
  3753. string_node.value.encoding)
  3754. return string_node
  3755. def _calculate_constant_seq(self, node, sequence_node, factor):
  3756. if factor.constant_result != 1 and sequence_node.args:
  3757. if isinstance(factor.constant_result, _py_int_types) and factor.constant_result <= 0:
  3758. del sequence_node.args[:]
  3759. sequence_node.mult_factor = None
  3760. elif sequence_node.mult_factor is not None:
  3761. if (isinstance(factor.constant_result, _py_int_types) and
  3762. isinstance(sequence_node.mult_factor.constant_result, _py_int_types)):
  3763. value = sequence_node.mult_factor.constant_result * factor.constant_result
  3764. sequence_node.mult_factor = ExprNodes.IntNode(
  3765. sequence_node.mult_factor.pos,
  3766. value=str(value), constant_result=value)
  3767. else:
  3768. # don't know if we can combine the factors, so don't
  3769. return self.visit_BinopNode(node)
  3770. else:
  3771. sequence_node.mult_factor = factor
  3772. return sequence_node
  3773. def visit_ModNode(self, node):
  3774. self.visitchildren(node)
  3775. if isinstance(node.operand1, ExprNodes.UnicodeNode) and isinstance(node.operand2, ExprNodes.TupleNode):
  3776. if not node.operand2.mult_factor:
  3777. fstring = self._build_fstring(node.operand1.pos, node.operand1.value, node.operand2.args)
  3778. if fstring is not None:
  3779. return fstring
  3780. return self.visit_BinopNode(node)
# Pattern that _build_fstring() passes to re.split().  The whole pattern is
# one capturing group, so the matched '%...' placeholders are kept as separate
# items in the split result, between the literal text parts.
_parse_string_format_regex = (
    u'(%(?:'              # %...
    u'(?:[0-9]+|[ ])?'    # width (optional) or space prefix fill character (optional)
    u'(?:[.][0-9]+)?'     # precision (optional)
    u')?.)'               # format type (or something different for unsupported formats)
)
def _build_fstring(self, pos, ustring, format_args):
    """
    Convert a '%'-format string and its argument nodes into a JoinedStrNode.

    'pos' is the source position used for the created nodes and warnings,
    'ustring' the format string, 'format_args' the sequence of argument nodes.

    Returns the result of visit_JoinedStrNode() for the new node, or None if
    the format string cannot be converted (unsupported or incomplete format,
    integer format with precision, wrong number of arguments).
    """
    # Issues formatting warnings instead of errors since we really only catch a few errors by accident.
    args = iter(format_args)
    substrings = []
    can_be_optimised = True
    for s in re.split(self._parse_string_format_regex, ustring):
        if not s:
            continue
        if s == u'%%':
            # escaped percent sign => literal '%'
            substrings.append(ExprNodes.UnicodeNode(pos, value=EncodedString(u'%'), constant_result=u'%'))
            continue
        if s[0] != u'%':
            # literal text between placeholders
            if s[-1] == u'%':
                warning(pos, "Incomplete format: '...%s'" % s[-3:], level=1)
                can_be_optimised = False
            substrings.append(ExprNodes.UnicodeNode(pos, value=EncodedString(s), constant_result=s))
            continue
        # a '%...' placeholder: its last character is the format type
        format_type = s[-1]
        try:
            arg = next(args)
        except StopIteration:
            warning(pos, "Too few arguments for format placeholders", level=1)
            can_be_optimised = False
            break
        if format_type in u'srfdoxX':
            format_spec = s[1:]
            if format_type in u'doxX' and u'.' in format_spec:
                # Precision is not allowed for integers in format(), but ok in %-formatting.
                can_be_optimised = False
            elif format_type in u'rs':
                # 'r' and 's' are passed as conversion char, not as part of the format spec
                format_spec = format_spec[:-1]
            substrings.append(ExprNodes.FormattedValueNode(
                arg.pos, value=arg,
                conversion_char=format_type if format_type in u'rs' else None,
                format_spec=ExprNodes.UnicodeNode(
                    pos, value=EncodedString(format_spec), constant_result=format_spec)
                    if format_spec else None,
            ))
        else:
            # keep it simple for now ...
            can_be_optimised = False

    if not can_be_optimised:
        # Print all warnings we can find before finally giving up here.
        return None

    # all placeholders were consumed => there must not be any arguments left
    try:
        next(args)
    except StopIteration: pass
    else:
        warning(pos, "Too many arguments for format placeholders", level=1)
        return None

    node = ExprNodes.JoinedStrNode(pos, values=substrings)
    return self.visit_JoinedStrNode(node)
  3839. def visit_FormattedValueNode(self, node):
  3840. self.visitchildren(node)
  3841. conversion_char = node.conversion_char or 's'
  3842. if isinstance(node.format_spec, ExprNodes.UnicodeNode) and not node.format_spec.value:
  3843. node.format_spec = None
  3844. if node.format_spec is None and isinstance(node.value, ExprNodes.IntNode):
  3845. value = EncodedString(node.value.value)
  3846. if value.isdigit():
  3847. return ExprNodes.UnicodeNode(node.value.pos, value=value, constant_result=value)
  3848. if node.format_spec is None and conversion_char == 's':
  3849. value = None
  3850. if isinstance(node.value, ExprNodes.UnicodeNode):
  3851. value = node.value.value
  3852. elif isinstance(node.value, ExprNodes.StringNode):
  3853. value = node.value.unicode_value
  3854. if value is not None:
  3855. return ExprNodes.UnicodeNode(node.value.pos, value=value, constant_result=value)
  3856. return node
  3857. def visit_JoinedStrNode(self, node):
  3858. """
  3859. Clean up after the parser by discarding empty Unicode strings and merging
  3860. substring sequences. Empty or single-value join lists are not uncommon
  3861. because f-string format specs are always parsed into JoinedStrNodes.
  3862. """
  3863. self.visitchildren(node)
  3864. unicode_node = ExprNodes.UnicodeNode
  3865. values = []
  3866. for is_unode_group, substrings in itertools.groupby(node.values, lambda v: isinstance(v, unicode_node)):
  3867. if is_unode_group:
  3868. substrings = list(substrings)
  3869. unode = substrings[0]
  3870. if len(substrings) > 1:
  3871. value = EncodedString(u''.join(value.value for value in substrings))
  3872. unode = ExprNodes.UnicodeNode(unode.pos, value=value, constant_result=value)
  3873. # ignore empty Unicode strings
  3874. if unode.value:
  3875. values.append(unode)
  3876. else:
  3877. values.extend(substrings)
  3878. if not values:
  3879. value = EncodedString('')
  3880. node = ExprNodes.UnicodeNode(node.pos, value=value, constant_result=value)
  3881. elif len(values) == 1:
  3882. node = values[0]
  3883. elif len(values) == 2:
  3884. # reduce to string concatenation
  3885. node = ExprNodes.binop_node(node.pos, '+', *values)
  3886. else:
  3887. node.values = values
  3888. return node
  3889. def visit_MergedDictNode(self, node):
  3890. """Unpack **args in place if we can."""
  3891. self.visitchildren(node)
  3892. args = []
  3893. items = []
  3894. def add(arg):
  3895. if arg.is_dict_literal:
  3896. if items:
  3897. items[0].key_value_pairs.extend(arg.key_value_pairs)
  3898. else:
  3899. items.append(arg)
  3900. elif isinstance(arg, ExprNodes.MergedDictNode):
  3901. for child_arg in arg.keyword_args:
  3902. add(child_arg)
  3903. else:
  3904. if items:
  3905. args.append(items[0])
  3906. del items[:]
  3907. args.append(arg)
  3908. for arg in node.keyword_args:
  3909. add(arg)
  3910. if items:
  3911. args.append(items[0])
  3912. if len(args) == 1:
  3913. arg = args[0]
  3914. if arg.is_dict_literal or isinstance(arg, ExprNodes.MergedDictNode):
  3915. return arg
  3916. node.keyword_args[:] = args
  3917. self._calculate_const(node)
  3918. return node
  3919. def visit_MergedSequenceNode(self, node):
  3920. """Unpack *args in place if we can."""
  3921. self.visitchildren(node)
  3922. is_set = node.type is Builtin.set_type
  3923. args = []
  3924. values = []
  3925. def add(arg):
  3926. if (is_set and arg.is_set_literal) or (arg.is_sequence_constructor and not arg.mult_factor):
  3927. if values:
  3928. values[0].args.extend(arg.args)
  3929. else:
  3930. values.append(arg)
  3931. elif isinstance(arg, ExprNodes.MergedSequenceNode):
  3932. for child_arg in arg.args:
  3933. add(child_arg)
  3934. else:
  3935. if values:
  3936. args.append(values[0])
  3937. del values[:]
  3938. args.append(arg)
  3939. for arg in node.args:
  3940. add(arg)
  3941. if values:
  3942. args.append(values[0])
  3943. if len(args) == 1:
  3944. arg = args[0]
  3945. if ((is_set and arg.is_set_literal) or
  3946. (arg.is_sequence_constructor and arg.type is node.type) or
  3947. isinstance(arg, ExprNodes.MergedSequenceNode)):
  3948. return arg
  3949. node.args[:] = args
  3950. self._calculate_const(node)
  3951. return node
  3952. def visit_SequenceNode(self, node):
  3953. """Unpack *args in place if we can."""
  3954. self.visitchildren(node)
  3955. args = []
  3956. for arg in node.args:
  3957. if not arg.is_starred:
  3958. args.append(arg)
  3959. elif arg.target.is_sequence_constructor and not arg.target.mult_factor:
  3960. args.extend(arg.target.args)
  3961. else:
  3962. args.append(arg)
  3963. node.args[:] = args
  3964. self._calculate_const(node)
  3965. return node
def visit_PrimaryCmpNode(self, node):
    """
    Fold the constant parts of a (possibly cascaded) comparison.

    A simple comparison with a constant result becomes a BoolNode.  A cascade
    is split at its constant comparisons: constant true parts are dropped, a
    constant false part ends the cascade, and the remaining non-constant
    parts are rebuilt as comparisons that are joined by 'and'.
    """
    # calculate constant partial results in the comparison cascade
    self.visitchildren(node, ['operand1'])
    left_node = node.operand1
    cmp_node = node
    while cmp_node is not None:
        self.visitchildren(cmp_node, ['operand2'])
        right_node = cmp_node.operand2
        # reset first, the result is only set again if it can be calculated below
        cmp_node.constant_result = not_a_constant
        if left_node.has_constant_result() and right_node.has_constant_result():
            try:
                cmp_node.calculate_cascaded_constant_result(left_node.constant_result)
            except (ValueError, TypeError, KeyError, IndexError, AttributeError, ArithmeticError):
                pass  # ignore all 'normal' errors here => no constant result
        # the right operand is the left operand of the next comparison
        left_node = right_node
        cmp_node = cmp_node.cascade

    if not node.cascade:
        # simple comparison, no cascade
        if node.has_constant_result():
            return self._bool_node(node, node.constant_result)
        return node

    # collect partial cascades: [[value, CmpNode...], [value, CmpNode, ...], ...]
    cascades = [[node.operand1]]
    # holds at most one BoolNode(False), set if a constant false comparison was found
    final_false_result = []

    def split_cascades(cmp_node):
        if cmp_node.has_constant_result():
            if not cmp_node.constant_result:
                # False => short-circuit
                final_false_result.append(self._bool_node(cmp_node, False))
                return
            else:
                # True => discard and start new cascade
                cascades.append([cmp_node.operand2])
        else:
            # not constant => append to current cascade
            cascades[-1].append(cmp_node)
        if cmp_node.cascade:
            split_cascades(cmp_node.cascade)

    split_cascades(node)

    # rebuild one comparison node per partial cascade
    cmp_nodes = []
    for cascade in cascades:
        if len(cascade) < 2:
            # only a start value without any comparison
            continue
        cmp_node = cascade[1]
        pcmp_node = ExprNodes.PrimaryCmpNode(
            cmp_node.pos,
            operand1=cascade[0],
            operator=cmp_node.operator,
            operand2=cmp_node.operand2,
            constant_result=not_a_constant)
        cmp_nodes.append(pcmp_node)

        # chain the remaining comparisons of this cascade and terminate the chain
        last_cmp_node = pcmp_node
        for cmp_node in cascade[2:]:
            last_cmp_node.cascade = cmp_node
            last_cmp_node = cmp_node
        last_cmp_node.cascade = None

    if final_false_result:
        # last cascade was constant False
        cmp_nodes.append(final_false_result[0])
    elif not cmp_nodes:
        # only constants, but no False result
        return self._bool_node(node, True)
    node = cmp_nodes[0]
    if len(cmp_nodes) == 1:
        if node.has_constant_result():
            return self._bool_node(node, node.constant_result)
    else:
        # join all partial results with 'and'
        for cmp_node in cmp_nodes[1:]:
            node = ExprNodes.BoolBinopNode(
                node.pos,
                operand1=node,
                operator='and',
                operand2=cmp_node,
                constant_result=not_a_constant)
    return node
  4040. def visit_CondExprNode(self, node):
  4041. self._calculate_const(node)
  4042. if not node.test.has_constant_result():
  4043. return node
  4044. if node.test.constant_result:
  4045. return node.true_val
  4046. else:
  4047. return node.false_val
  4048. def visit_IfStatNode(self, node):
  4049. self.visitchildren(node)
  4050. # eliminate dead code based on constant condition results
  4051. if_clauses = []
  4052. for if_clause in node.if_clauses:
  4053. condition = if_clause.condition
  4054. if condition.has_constant_result():
  4055. if condition.constant_result:
  4056. # always true => subsequent clauses can safely be dropped
  4057. node.else_clause = if_clause.body
  4058. break
  4059. # else: false => drop clause
  4060. else:
  4061. # unknown result => normal runtime evaluation
  4062. if_clauses.append(if_clause)
  4063. if if_clauses:
  4064. node.if_clauses = if_clauses
  4065. return node
  4066. elif node.else_clause:
  4067. return node.else_clause
  4068. else:
  4069. return Nodes.StatListNode(node.pos, stats=[])
  4070. def visit_SliceIndexNode(self, node):
  4071. self._calculate_const(node)
  4072. # normalise start/stop values
  4073. if node.start is None or node.start.constant_result is None:
  4074. start = node.start = None
  4075. else:
  4076. start = node.start.constant_result
  4077. if node.stop is None or node.stop.constant_result is None:
  4078. stop = node.stop = None
  4079. else:
  4080. stop = node.stop.constant_result
  4081. # cut down sliced constant sequences
  4082. if node.constant_result is not not_a_constant:
  4083. base = node.base
  4084. if base.is_sequence_constructor and base.mult_factor is None:
  4085. base.args = base.args[start:stop]
  4086. return base
  4087. elif base.is_string_literal:
  4088. base = base.as_sliced_node(start, stop)
  4089. if base is not None:
  4090. return base
  4091. return node
  4092. def visit_ComprehensionNode(self, node):
  4093. self.visitchildren(node)
  4094. if isinstance(node.loop, Nodes.StatListNode) and not node.loop.stats:
  4095. # loop was pruned already => transform into literal
  4096. if node.type is Builtin.list_type:
  4097. return ExprNodes.ListNode(
  4098. node.pos, args=[], constant_result=[])
  4099. elif node.type is Builtin.set_type:
  4100. return ExprNodes.SetNode(
  4101. node.pos, args=[], constant_result=set())
  4102. elif node.type is Builtin.dict_type:
  4103. return ExprNodes.DictNode(
  4104. node.pos, key_value_pairs=[], constant_result={})
  4105. return node
  4106. def visit_ForInStatNode(self, node):
  4107. self.visitchildren(node)
  4108. sequence = node.iterator.sequence
  4109. if isinstance(sequence, ExprNodes.SequenceNode):
  4110. if not sequence.args:
  4111. if node.else_clause:
  4112. return node.else_clause
  4113. else:
  4114. # don't break list comprehensions
  4115. return Nodes.StatListNode(node.pos, stats=[])
  4116. # iterating over a list literal? => tuples are more efficient
  4117. if isinstance(sequence, ExprNodes.ListNode):
  4118. node.iterator.sequence = sequence.as_tuple()
  4119. return node
  4120. def visit_WhileStatNode(self, node):
  4121. self.visitchildren(node)
  4122. if node.condition and node.condition.has_constant_result():
  4123. if node.condition.constant_result:
  4124. node.condition = None
  4125. node.else_clause = None
  4126. else:
  4127. return node.else_clause
  4128. return node
  4129. def visit_ExprStatNode(self, node):
  4130. self.visitchildren(node)
  4131. if not isinstance(node.expr, ExprNodes.ExprNode):
  4132. # ParallelRangeTransform does this ...
  4133. return node
  4134. # drop unused constant expressions
  4135. if node.expr.has_constant_result():
  4136. return None
  4137. return node
# in the future, other nodes can have their own handler method here
# that can replace them with a constant result node

# Fallback handler for nodes without a more specific visit_*() method above:
# delegates to the recurse_to_children() method of Visitor.VisitorTransform.
visit_Node = Visitor.VisitorTransform.recurse_to_children
  4141. class FinalOptimizePhase(Visitor.EnvTransform, Visitor.NodeRefCleanupMixin):
  4142. """
  4143. This visitor handles several commuting optimizations, and is run
  4144. just before the C code generation phase.
  4145. The optimizations currently implemented in this class are:
  4146. - eliminate None assignment and refcounting for first assignment.
  4147. - isinstance -> typecheck for cdef types
  4148. - eliminate checks for None and/or types that became redundant after tree changes
  4149. - eliminate useless string formatting steps
  4150. - replace Python function calls that look like method calls by a faster PyMethodCallNode
  4151. """
  4152. in_loop = False
  4153. def visit_SingleAssignmentNode(self, node):
  4154. """Avoid redundant initialisation of local variables before their
  4155. first assignment.
  4156. """
  4157. self.visitchildren(node)
  4158. if node.first:
  4159. lhs = node.lhs
  4160. lhs.lhs_of_first_assignment = True
  4161. return node
  4162. def visit_SimpleCallNode(self, node):
  4163. """
  4164. Replace generic calls to isinstance(x, type) by a more efficient type check.
  4165. Replace likely Python method calls by a specialised PyMethodCallNode.
  4166. """
  4167. self.visitchildren(node)
  4168. function = node.function
  4169. if function.type.is_cfunction and function.is_name:
  4170. if function.name == 'isinstance' and len(node.args) == 2:
  4171. type_arg = node.args[1]
  4172. if type_arg.type.is_builtin_type and type_arg.type.name == 'type':
  4173. cython_scope = self.context.cython_scope
  4174. function.entry = cython_scope.lookup('PyObject_TypeCheck')
  4175. function.type = function.entry.type
  4176. PyTypeObjectPtr = PyrexTypes.CPtrType(cython_scope.lookup('PyTypeObject').type)
  4177. node.args[1] = ExprNodes.CastNode(node.args[1], PyTypeObjectPtr)
  4178. elif (node.is_temp and function.type.is_pyobject and self.current_directives.get(
  4179. "optimize.unpack_method_calls_in_pyinit"
  4180. if not self.in_loop and self.current_env().is_module_scope
  4181. else "optimize.unpack_method_calls")):
  4182. # optimise simple Python methods calls
  4183. if isinstance(node.arg_tuple, ExprNodes.TupleNode) and not (
  4184. node.arg_tuple.mult_factor or (node.arg_tuple.is_literal and node.arg_tuple.args)):
  4185. # simple call, now exclude calls to objects that are definitely not methods
  4186. may_be_a_method = True
  4187. if function.type is Builtin.type_type:
  4188. may_be_a_method = False
  4189. elif function.is_attribute:
  4190. if function.entry and function.entry.type.is_cfunction:
  4191. # optimised builtin method
  4192. may_be_a_method = False
  4193. elif function.is_name:
  4194. entry = function.entry
  4195. if entry.is_builtin or entry.type.is_cfunction:
  4196. may_be_a_method = False
  4197. elif entry.cf_assignments:
  4198. # local functions/classes are definitely not methods
  4199. non_method_nodes = (ExprNodes.PyCFunctionNode, ExprNodes.ClassNode, ExprNodes.Py3ClassNode)
  4200. may_be_a_method = any(
  4201. assignment.rhs and not isinstance(assignment.rhs, non_method_nodes)
  4202. for assignment in entry.cf_assignments)
  4203. if may_be_a_method:
  4204. if (node.self and function.is_attribute and
  4205. isinstance(function.obj, ExprNodes.CloneNode) and function.obj.arg is node.self):
  4206. # function self object was moved into a CloneNode => undo
  4207. function.obj = function.obj.arg
  4208. node = self.replace(node, ExprNodes.PyMethodCallNode.from_node(
  4209. node, function=function, arg_tuple=node.arg_tuple, type=node.type))
  4210. return node
  4211. def visit_NumPyMethodCallNode(self, node):
  4212. # Exclude from replacement above.
  4213. self.visitchildren(node)
  4214. return node
  4215. def visit_PyTypeTestNode(self, node):
  4216. """Remove tests for alternatively allowed None values from
  4217. type tests when we know that the argument cannot be None
  4218. anyway.
  4219. """
  4220. self.visitchildren(node)
  4221. if not node.notnone:
  4222. if not node.arg.may_be_none():
  4223. node.notnone = True
  4224. return node
  4225. def visit_NoneCheckNode(self, node):
  4226. """Remove None checks from expressions that definitely do not
  4227. carry a None value.
  4228. """
  4229. self.visitchildren(node)
  4230. if not node.arg.may_be_none():
  4231. return node.arg
  4232. return node
  4233. def visit_LoopNode(self, node):
  4234. """Remember when we enter a loop as some expensive optimisations might still be worth it there.
  4235. """
  4236. old_val = self.in_loop
  4237. self.in_loop = True
  4238. self.visitchildren(node)
  4239. self.in_loop = old_val
  4240. return node
  4241. class ConsolidateOverflowCheck(Visitor.CythonTransform):
  4242. """
  4243. This class facilitates the sharing of overflow checking among all nodes
  4244. of a nested arithmetic expression. For example, given the expression
  4245. a*b + c, where a, b, and x are all possibly overflowing ints, the entire
  4246. sequence will be evaluated and the overflow bit checked only at the end.
  4247. """
  4248. overflow_bit_node = None
  4249. def visit_Node(self, node):
  4250. if self.overflow_bit_node is not None:
  4251. saved = self.overflow_bit_node
  4252. self.overflow_bit_node = None
  4253. self.visitchildren(node)
  4254. self.overflow_bit_node = saved
  4255. else:
  4256. self.visitchildren(node)
  4257. return node
  4258. def visit_NumBinopNode(self, node):
  4259. if node.overflow_check and node.overflow_fold:
  4260. top_level_overflow = self.overflow_bit_node is None
  4261. if top_level_overflow:
  4262. self.overflow_bit_node = node
  4263. else:
  4264. node.overflow_bit_node = self.overflow_bit_node
  4265. node.overflow_check = False
  4266. self.visitchildren(node)
  4267. if top_level_overflow:
  4268. self.overflow_bit_node = None
  4269. else:
  4270. self.visitchildren(node)
  4271. return node