macros.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. """The main source of all things MacroPy"""
  2. import sys
  3. import imp
  4. import ast
  5. import itertools
  6. from ast import *
  7. from util import *
  8. from walkers import *
  # Monkey Patching pickle to pickle module objects properly.
  # The module type (type(pickle)) is mapped to Pickler.save_global in the
  # Pickler's class-level `dispatch` table, so module objects are saved the
  # same way globals are. This runs once, at import time, and affects every
  # Pickler in the process.
  # NOTE(review): `Pickler.dispatch` is a table on the pure-Python Pickler
  # (Python 2's `pickle`); presumably the patch exists because module objects
  # are not picklable by default -- confirm before porting to Python 3.
  import pickle
  pickle.Pickler.dispatch[type(pickle)] = pickle.Pickler.save_global
  class WrappedFunction(object):
      """Callable proxy around a function that macro expansion is expected to
      handle (and remove).

      Calling the proxy forwards every argument to the wrapped function.
      Subscripting it (`proxy[...]`) is treated as misuse: it raises a TypeError
      whose text is `msg` with every "%s" replaced by the function's name."""

      def __init__(self, func, msg):
          self.func, self.msg = func, msg
          from functools import update_wrapper
          # Copy the function's metadata (name, docstring, module, ...) onto the
          # proxy so it reads like the function it wraps.
          update_wrapper(self, func)

      def __call__(self, *args, **kwargs):
          target = self.func
          return target(*args, **kwargs)

      def __getitem__(self, i):
          macro_name = self.func.__name__
          complaint = self.msg.replace("%s", macro_name)
          raise TypeError(complaint)
  def macro_function(func):
      """Wrap a macro so that the common mistake of importing it without
      triggering macro expansion produces a readable error.

      The returned WrappedFunction still forwards calls to `func`; subscripting
      it raises a TypeError naming the macro and the expected
      `from ... import macros, <name>` form."""
      error_template = (
          "Macro `%s` illegally invoked at runtime; did you import it "
          "properly using `from ... import macros, %s`?"
      )
      return WrappedFunction(func, error_template)
  def macro_stub(func):
      """Wrap a stub: a function meant to be used by macros, not subscripted
      directly at runtime.

      The returned WrappedFunction forwards calls to `func`; subscripting it
      raises a TypeError asking whether the stub is used within a macro."""
      error_template = (
          "Stub `%s` illegally invoked at runtime; is it used "
          "properly within a macro?"
      )
      return WrappedFunction(func, error_template)
  class Macros(object):
      """A registry of macros belonging to a module; used via

      ```python
      macros = Macros()

      @macros.expr
      def my_macro(tree):
          ...
      ```

      Where the decorators are used to register functions as macros belonging
      to that module. `expr`, `block` and `decorator` each hold one kind of
      macro; `expose_unhygienic` holds plain (unwrapped) names.
      """

      class Registry(object):
          """Name -> function table whose instances are used as decorators."""

          def __init__(self, wrap = lambda x: x):
              # Table of registered (wrapped) functions, keyed by name.
              self.registry = {}
              # Applied to each registered function; identity by default.
              self.wrap = wrap

          def __call__(self, f, name=None):
              """Register `f` and return its wrapped form.

              `f` is stored under `name` (when given) and under its own
              function name. `wrap` is applied exactly once, so every registry
              entry is the same object that is returned to the caller.
              """
              wrapped = self.wrap(f)
              if name is not None:
                  self.registry[name] = wrapped
              # Python 2 functions expose `func_name`; `__name__` covers any
              # other callable that carries a name.
              if hasattr(f, "func_name"):
                  self.registry[f.func_name] = wrapped
              if hasattr(f, "__name__"):
                  self.registry[f.__name__] = wrapped
              return wrapped

      def __init__(self):
          # Different kinds of macros; macro functions are wrapped with
          # macro_function so stray runtime subscripting fails with a clear
          # message.
          self.expr = Macros.Registry(macro_function)
          self.block = Macros.Registry(macro_function)
          self.decorator = Macros.Registry(macro_function)
          self.expose_unhygienic = Macros.Registry()
  # Extension points: other modules append to these lists to hook into
  # MacroPy's workflow while this module stays unaware of their presence.
  #   injected_vars   -- functions called once per expanded file; each result
  #                      is passed to macros/filters/post-processors as a
  #                      keyword argument named after the function
  #   filters         -- functions applied (last-registered first) to every
  #                      macro-expanded snippet
  #   post_processing -- functions applied, in order, to every
  #                      macro-expanded file
  injected_vars, filters, post_processing = [], [], []
  def expand_entire_ast(tree, src, bindings):
      """Expand every macro invocation in a module's AST.

      Args:
          tree: the module's AST.
          src: the module's source text; passed as `src` to every macro,
              filter, injected-var function and post-processor.
          bindings: sequence of `(module, [(name, asname), ...])` pairs. Each
              `module` must have a `macros` attribute (a `Macros` instance);
              `name` is looked up in its block/expr/decorator registries and
              the macro is made available in this file under `asname`.

      Returns:
          The expanded tree, after every function in `post_processing` has
          been applied to it in order.

      Every function in `injected_vars` is called once, up front; its result
      is passed to each macro, filter and post-processor as a keyword argument
      named after the function. An exception raised by a macro is caught and
      handed through `filters`; for block macros, an exception that survives
      the filters is re-raised.
      """
      def expand_macros(tree):
          """Go through an AST, hunting for macro invocations and expanding
          any that are found"""

          def expand_if_in_registry(macro_tree, body_tree, args, registry, **kwargs):
              """Check if `macro_tree` names a macro in `registry`, and if so
              use it to expand `body_tree`.

              `macro_tree` is either a Name (the macro itself) or a Call, whose
              positional arguments are appended to `args` before recursing
              into the callee. `kwargs` (e.g. the `target` of a `with` block)
              are forwarded to the macro and to the filters.

              Returns None when no macro matches; otherwise the macro's result
              (or the exception it raised) after all `filters` have run.
              """
              if isinstance(macro_tree, Name) and macro_tree.id in registry:
                  (the_macro, _) = registry[macro_tree.id]
                  # file_vars win over kwargs on a name clash (same precedence
                  # as the former `dict(kwargs.items() + file_vars.items())`),
                  # built once and shared by the macro and the filters.
                  extra = dict(kwargs)
                  extra.update(file_vars)
                  try:
                      new_tree = the_macro(
                          tree=body_tree,
                          args=args,
                          src=src,
                          expand_macros=expand_macros,
                          **extra
                      )
                  except Exception as e:
                      # Passed on to the filters rather than aborting here.
                      new_tree = e
                  for fltr in reversed(filters):
                      new_tree = fltr(
                          tree=new_tree,
                          args=args,
                          src=src,
                          expand_macros=expand_macros,
                          lineno=macro_tree.lineno,
                          col_offset=macro_tree.col_offset,
                          **extra
                      )
                  return new_tree
              elif isinstance(macro_tree, Call):
                  args.extend(macro_tree.args)
                  # Forward kwargs so `with macro(x) as target:` still
                  # delivers `target` to the macro.
                  return expand_if_in_registry(macro_tree.func, body_tree, args, registry, **kwargs)

          def preserve_line_numbers(func):
              """Decorates a tree-transformer function to stick the original
              line numbers onto the transformed tree (onto the first node when
              the result is a, possibly nested, list of nodes)."""
              def run(tree):
                  pos = (tree.lineno, tree.col_offset) if hasattr(tree, "lineno") and hasattr(tree, "col_offset") else None
                  new_tree = func(tree)
                  if pos:
                      t = new_tree
                      while type(t) is list and t:
                          t = t[0]
                      # An empty list has no node to carry the position.
                      if type(t) is not list:
                          (t.lineno, t.col_offset) = pos
                  return new_tree
              return run

          @preserve_line_numbers
          def macro_expand(tree):
              """Tail Recursively expands all macros in a single AST node"""
              if isinstance(tree, With):
                  assert isinstance(tree.body, list), real_repr(tree.body)
                  new_tree = expand_if_in_registry(tree.context_expr, tree.body, [], block_registry, target=tree.optional_vars)
                  if new_tree:
                      if isinstance(new_tree, expr):
                          new_tree = [Expr(new_tree)]
                      if isinstance(new_tree, Exception): raise new_tree
                      assert isinstance(new_tree, list), type(new_tree)
                      return macro_expand(new_tree)

              if isinstance(tree, Subscript) and type(tree.slice) is Index:
                  new_tree = expand_if_in_registry(tree.value, tree.slice.value, [], expr_registry)
                  if new_tree:
                      assert isinstance(new_tree, expr), type(new_tree)
                      return macro_expand(new_tree)

              if isinstance(tree, ClassDef) or isinstance(tree, FunctionDef):
                  seen_decs = []   # non-macro decorators, kept on the result
                  additions = []   # extra statements emitted after the node
                  while tree.decorator_list != []:
                      dec = tree.decorator_list[0]
                      tree.decorator_list = tree.decorator_list[1:]
                      new_tree = expand_if_in_registry(dec, tree, [], decorator_registry)
                      if new_tree is None:
                          seen_decs.append(dec)
                      else:
                          tree = macro_expand(new_tree)
                          if type(tree) is list:
                              # Splice: statements from this macro come before
                              # those produced by earlier decorator macros,
                              # which must not be discarded.
                              additions = tree[1:] + additions
                              tree = tree[0]
                          elif isinstance(tree, expr):
                              tree = [Expr(tree)]
                              break
                  if type(tree) is ClassDef or type(tree) is FunctionDef:
                      tree.decorator_list = seen_decs
                  if len(additions) == 0:
                      return tree
                  # `tree` is already a statement list when a decorator macro
                  # produced an expression; avoid nesting it.
                  head = tree if type(tree) is list else [tree]
                  return head + additions

              return tree

          @Walker
          def macro_searcher(tree, **kw):
              return macro_expand(tree)

          return macro_searcher.recurse(tree)

      file_vars = {}
      for v in injected_vars:
          # `__name__` equals Python 2's `func_name` for functions.
          file_vars[v.__name__] = v(tree=tree, src=src, expand_macros=expand_macros, **file_vars)

      allnames = [
          (m, name, asname)
          for m, names in bindings
          for name, asname in names
      ]

      def extract_macros(pick_registry):
          """Map each imported alias to `(macro, module)` for the bound names
          present in the registry `pick_registry` selects from a module's
          `macros`."""
          return {
              asname: (registry[name], ma)
              for ma, name, asname in allnames
              for registry in [pick_registry(ma.macros).registry]
              if name in registry
          }

      block_registry = extract_macros(lambda x: x.block)
      expr_registry = extract_macros(lambda x: x.expr)
      decorator_registry = extract_macros(lambda x: x.decorator)

      tree = expand_macros(tree)

      for post in post_processing:
          tree = post(
              tree=tree,
              src=src,
              expand_macros=expand_macros,
              **file_vars
          )

      return tree
  def detect_macros(tree):
      """Look for macro imports at the top level of `tree`, transform them and
      extract the list of macro modules.

      A macro import is an `ImportFrom` whose first name is an un-aliased
      `macros`. Each such module is imported. Its block/expr/decorator macro
      names are removed from the statement, and every name in the module's
      `expose_unhygienic` registry is appended as `name as name`.

      Returns a list of `(module_name, [(name, asname), ...])` pairs, one per
      macro import, covering every alias after `macros` (asname defaults to
      the name itself).
      """
      found = []
      for node in tree.body:
          if not isinstance(node, ImportFrom):
              continue
          head = node.names[0]
          if head.name != 'macros' or head.asname is not None:
              continue

          __import__(node.module)
          module = sys.modules[node.module]
          found.append((
              node.module,
              [(a.name, a.asname or a.name) for a in node.names[1:]],
          ))

          spec = module.macros
          macro_tables = (
              spec.block.registry,
              spec.expr.registry,
              spec.decorator.registry,
          )
          node.names = [
              a for a in node.names
              if all(a.name not in table for table in macro_tables)
          ]
          node.names.extend(
              alias(n, n) for n in spec.expose_unhygienic.registry
          )
      return found
  def check_annotated(tree):
      """Shorthand for checking if an AST is of the form `name[...]`.

      Returns `(identifier, inner_expr)` when `tree` is a Subscript whose slice
      is an Index and whose subscripted value is a bare Name; otherwise returns
      None.
      """
      if not isinstance(tree, Subscript):
          return None
      if type(tree.slice) is not Index:
          return None
      if type(tree.value) is not Name:
          return None
      return tree.value.id, tree.slice.value
  225. # import other modules in order to register their hooks