Pipeline.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368
  1. from __future__ import absolute_import
  2. import itertools
  3. from time import time
  4. from . import Errors
  5. from . import DebugFlags
  6. from . import Options
  7. from .Errors import CompileError, InternalError, AbortError
  8. from . import Naming
  9. #
  10. # Really small pipeline stages
  11. #
  12. def dumptree(t):
  13. # For quick debugging in pipelines
  14. print(t.dump())
  15. return t
  16. def abort_on_errors(node):
  17. # Stop the pipeline if there are any errors.
  18. if Errors.num_errors != 0:
  19. raise AbortError("pipeline break")
  20. return node
  21. def parse_stage_factory(context):
  22. def parse(compsrc):
  23. source_desc = compsrc.source_desc
  24. full_module_name = compsrc.full_module_name
  25. initial_pos = (source_desc, 1, 0)
  26. saved_cimport_from_pyx, Options.cimport_from_pyx = Options.cimport_from_pyx, False
  27. scope = context.find_module(full_module_name, pos = initial_pos, need_pxd = 0)
  28. Options.cimport_from_pyx = saved_cimport_from_pyx
  29. tree = context.parse(source_desc, scope, pxd = 0, full_module_name = full_module_name)
  30. tree.compilation_source = compsrc
  31. tree.scope = scope
  32. tree.is_pxd = False
  33. return tree
  34. return parse
  35. def parse_pxd_stage_factory(context, scope, module_name):
  36. def parse(source_desc):
  37. tree = context.parse(source_desc, scope, pxd=True,
  38. full_module_name=module_name)
  39. tree.scope = scope
  40. tree.is_pxd = True
  41. return tree
  42. return parse
  43. def generate_pyx_code_stage_factory(options, result):
  44. def generate_pyx_code_stage(module_node):
  45. module_node.process_implementation(options, result)
  46. result.compilation_source = module_node.compilation_source
  47. return result
  48. return generate_pyx_code_stage
  49. def inject_pxd_code_stage_factory(context):
  50. def inject_pxd_code_stage(module_node):
  51. for name, (statlistnode, scope) in context.pxds.items():
  52. module_node.merge_in(statlistnode, scope)
  53. return module_node
  54. return inject_pxd_code_stage
  55. def use_utility_code_definitions(scope, target, seen=None):
  56. if seen is None:
  57. seen = set()
  58. for entry in scope.entries.values():
  59. if entry in seen:
  60. continue
  61. seen.add(entry)
  62. if entry.used and entry.utility_code_definition:
  63. target.use_utility_code(entry.utility_code_definition)
  64. for required_utility in entry.utility_code_definition.requires:
  65. target.use_utility_code(required_utility)
  66. elif entry.as_module:
  67. use_utility_code_definitions(entry.as_module, target, seen)
  68. def sort_utility_codes(utilcodes):
  69. ranks = {}
  70. def get_rank(utilcode):
  71. if utilcode not in ranks:
  72. ranks[utilcode] = 0 # prevent infinite recursion on circular dependencies
  73. original_order = len(ranks)
  74. ranks[utilcode] = 1 + min([get_rank(dep) for dep in utilcode.requires or ()] or [-1]) + original_order * 1e-8
  75. return ranks[utilcode]
  76. for utilcode in utilcodes:
  77. get_rank(utilcode)
  78. return [utilcode for utilcode, _ in sorted(ranks.items(), key=lambda kv: kv[1])]
  79. def normalize_deps(utilcodes):
  80. deps = {}
  81. for utilcode in utilcodes:
  82. deps[utilcode] = utilcode
  83. def unify_dep(dep):
  84. if dep in deps:
  85. return deps[dep]
  86. else:
  87. deps[dep] = dep
  88. return dep
  89. for utilcode in utilcodes:
  90. utilcode.requires = [unify_dep(dep) for dep in utilcode.requires or ()]
  91. def inject_utility_code_stage_factory(context):
  92. def inject_utility_code_stage(module_node):
  93. module_node.prepare_utility_code()
  94. use_utility_code_definitions(context.cython_scope, module_node.scope)
  95. module_node.scope.utility_code_list = sort_utility_codes(module_node.scope.utility_code_list)
  96. normalize_deps(module_node.scope.utility_code_list)
  97. added = []
  98. # Note: the list might be extended inside the loop (if some utility code
  99. # pulls in other utility code, explicitly or implicitly)
  100. for utilcode in module_node.scope.utility_code_list:
  101. if utilcode in added:
  102. continue
  103. added.append(utilcode)
  104. if utilcode.requires:
  105. for dep in utilcode.requires:
  106. if dep not in added and dep not in module_node.scope.utility_code_list:
  107. module_node.scope.utility_code_list.append(dep)
  108. tree = utilcode.get_tree(cython_scope=context.cython_scope)
  109. if tree:
  110. module_node.merge_in(tree.body, tree.scope, merge_scope=True)
  111. return module_node
  112. return inject_utility_code_stage
  113. #
  114. # Pipeline factories
  115. #
  116. def create_pipeline(context, mode, exclude_classes=()):
  117. assert mode in ('pyx', 'py', 'pxd')
  118. from .Visitor import PrintTree
  119. from .ParseTreeTransforms import WithTransform, NormalizeTree, PostParse, PxdPostParse
  120. from .ParseTreeTransforms import ForwardDeclareTypes, AnalyseDeclarationsTransform
  121. from .ParseTreeTransforms import AnalyseExpressionsTransform, FindInvalidUseOfFusedTypes
  122. from .ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
  123. from .ParseTreeTransforms import TrackNumpyAttributes, InterpretCompilerDirectives, TransformBuiltinMethods
  124. from .ParseTreeTransforms import ExpandInplaceOperators, ParallelRangeTransform
  125. from .ParseTreeTransforms import CalculateQualifiedNamesTransform
  126. from .TypeInference import MarkParallelAssignments, MarkOverflowingArithmetic
  127. from .ParseTreeTransforms import AdjustDefByDirectives, AlignFunctionDefinitions
  128. from .ParseTreeTransforms import RemoveUnreachableCode, GilCheck
  129. from .FlowControl import ControlFlowAnalysis
  130. from .AnalysedTreeTransforms import AutoTestDictTransform
  131. from .AutoDocTransforms import EmbedSignature
  132. from .Optimize import FlattenInListTransform, SwitchTransform, IterationTransform
  133. from .Optimize import EarlyReplaceBuiltinCalls, OptimizeBuiltinCalls
  134. from .Optimize import InlineDefNodeCalls
  135. from .Optimize import ConstantFolding, FinalOptimizePhase
  136. from .Optimize import DropRefcountingTransform
  137. from .Optimize import ConsolidateOverflowCheck
  138. from .Buffer import IntroduceBufferAuxiliaryVars
  139. from .ModuleNode import check_c_declarations, check_c_declarations_pxd
  140. if mode == 'pxd':
  141. _check_c_declarations = check_c_declarations_pxd
  142. _specific_post_parse = PxdPostParse(context)
  143. else:
  144. _check_c_declarations = check_c_declarations
  145. _specific_post_parse = None
  146. if mode == 'py':
  147. _align_function_definitions = AlignFunctionDefinitions(context)
  148. else:
  149. _align_function_definitions = None
  150. # NOTE: This is the "common" parts of the pipeline, which is also
  151. # code in pxd files. So it will be run multiple times in a
  152. # compilation stage.
  153. stages = [
  154. NormalizeTree(context),
  155. PostParse(context),
  156. _specific_post_parse,
  157. TrackNumpyAttributes(),
  158. InterpretCompilerDirectives(context, context.compiler_directives),
  159. ParallelRangeTransform(context),
  160. AdjustDefByDirectives(context),
  161. WithTransform(context),
  162. MarkClosureVisitor(context),
  163. _align_function_definitions,
  164. RemoveUnreachableCode(context),
  165. ConstantFolding(),
  166. FlattenInListTransform(),
  167. DecoratorTransform(context),
  168. ForwardDeclareTypes(context),
  169. AnalyseDeclarationsTransform(context),
  170. AutoTestDictTransform(context),
  171. EmbedSignature(context),
  172. EarlyReplaceBuiltinCalls(context), ## Necessary?
  173. TransformBuiltinMethods(context),
  174. MarkParallelAssignments(context),
  175. ControlFlowAnalysis(context),
  176. RemoveUnreachableCode(context),
  177. # MarkParallelAssignments(context),
  178. MarkOverflowingArithmetic(context),
  179. IntroduceBufferAuxiliaryVars(context),
  180. _check_c_declarations,
  181. InlineDefNodeCalls(context),
  182. AnalyseExpressionsTransform(context),
  183. FindInvalidUseOfFusedTypes(context),
  184. ExpandInplaceOperators(context),
  185. IterationTransform(context),
  186. SwitchTransform(context),
  187. OptimizeBuiltinCalls(context), ## Necessary?
  188. CreateClosureClasses(context), ## After all lookups and type inference
  189. CalculateQualifiedNamesTransform(context),
  190. ConsolidateOverflowCheck(context),
  191. DropRefcountingTransform(),
  192. FinalOptimizePhase(context),
  193. GilCheck(),
  194. ]
  195. filtered_stages = []
  196. for s in stages:
  197. if s.__class__ not in exclude_classes:
  198. filtered_stages.append(s)
  199. return filtered_stages
  200. def create_pyx_pipeline(context, options, result, py=False, exclude_classes=()):
  201. if py:
  202. mode = 'py'
  203. else:
  204. mode = 'pyx'
  205. test_support = []
  206. if options.evaluate_tree_assertions:
  207. from ..TestUtils import TreeAssertVisitor
  208. test_support.append(TreeAssertVisitor())
  209. if options.gdb_debug:
  210. from ..Debugger import DebugWriter # requires Py2.5+
  211. from .ParseTreeTransforms import DebugTransform
  212. context.gdb_debug_outputwriter = DebugWriter.CythonDebugWriter(
  213. options.output_dir)
  214. debug_transform = [DebugTransform(context, options, result)]
  215. else:
  216. debug_transform = []
  217. return list(itertools.chain(
  218. [parse_stage_factory(context)],
  219. create_pipeline(context, mode, exclude_classes=exclude_classes),
  220. test_support,
  221. [inject_pxd_code_stage_factory(context),
  222. inject_utility_code_stage_factory(context),
  223. abort_on_errors],
  224. debug_transform,
  225. [generate_pyx_code_stage_factory(options, result)]))
  226. def create_pxd_pipeline(context, scope, module_name):
  227. from .CodeGeneration import ExtractPxdCode
  228. # The pxd pipeline ends up with a CCodeWriter containing the
  229. # code of the pxd, as well as a pxd scope.
  230. return [
  231. parse_pxd_stage_factory(context, scope, module_name)
  232. ] + create_pipeline(context, 'pxd') + [
  233. ExtractPxdCode()
  234. ]
  235. def create_py_pipeline(context, options, result):
  236. return create_pyx_pipeline(context, options, result, py=True)
  237. def create_pyx_as_pxd_pipeline(context, result):
  238. from .ParseTreeTransforms import AlignFunctionDefinitions, \
  239. MarkClosureVisitor, WithTransform, AnalyseDeclarationsTransform
  240. from .Optimize import ConstantFolding, FlattenInListTransform
  241. from .Nodes import StatListNode
  242. pipeline = []
  243. pyx_pipeline = create_pyx_pipeline(context, context.options, result,
  244. exclude_classes=[
  245. AlignFunctionDefinitions,
  246. MarkClosureVisitor,
  247. ConstantFolding,
  248. FlattenInListTransform,
  249. WithTransform
  250. ])
  251. for stage in pyx_pipeline:
  252. pipeline.append(stage)
  253. if isinstance(stage, AnalyseDeclarationsTransform):
  254. # This is the last stage we need.
  255. break
  256. def fake_pxd(root):
  257. for entry in root.scope.entries.values():
  258. if not entry.in_cinclude:
  259. entry.defined_in_pxd = 1
  260. if entry.name == entry.cname and entry.visibility != 'extern':
  261. # Always mangle non-extern cimported entries.
  262. entry.cname = entry.scope.mangle(Naming.func_prefix, entry.name)
  263. return StatListNode(root.pos, stats=[]), root.scope
  264. pipeline.append(fake_pxd)
  265. return pipeline
  266. def insert_into_pipeline(pipeline, transform, before=None, after=None):
  267. """
  268. Insert a new transform into the pipeline after or before an instance of
  269. the given class. e.g.
  270. pipeline = insert_into_pipeline(pipeline, transform,
  271. after=AnalyseDeclarationsTransform)
  272. """
  273. assert before or after
  274. cls = before or after
  275. for i, t in enumerate(pipeline):
  276. if isinstance(t, cls):
  277. break
  278. if after:
  279. i += 1
  280. return pipeline[:i] + [transform] + pipeline[i:]
  281. #
  282. # Running a pipeline
  283. #
  284. _pipeline_entry_points = {}
  285. def run_pipeline(pipeline, source, printtree=True):
  286. from .Visitor import PrintTree
  287. exec_ns = globals().copy() if DebugFlags.debug_verbose_pipeline else None
  288. def run(phase, data):
  289. return phase(data)
  290. error = None
  291. data = source
  292. try:
  293. try:
  294. for phase in pipeline:
  295. if phase is not None:
  296. if not printtree and isinstance(phase, PrintTree):
  297. continue
  298. if DebugFlags.debug_verbose_pipeline:
  299. t = time()
  300. print("Entering pipeline phase %r" % phase)
  301. # create a new wrapper for each step to show the name in profiles
  302. phase_name = getattr(phase, '__name__', type(phase).__name__)
  303. try:
  304. run = _pipeline_entry_points[phase_name]
  305. except KeyError:
  306. exec("def %s(phase, data): return phase(data)" % phase_name, exec_ns)
  307. run = _pipeline_entry_points[phase_name] = exec_ns[phase_name]
  308. data = run(phase, data)
  309. if DebugFlags.debug_verbose_pipeline:
  310. print(" %.3f seconds" % (time() - t))
  311. except CompileError as err:
  312. # err is set
  313. Errors.report_error(err, use_stack=False)
  314. error = err
  315. except InternalError as err:
  316. # Only raise if there was not an earlier error
  317. if Errors.num_errors == 0:
  318. raise
  319. error = err
  320. except AbortError as err:
  321. error = err
  322. return (error, data)