123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127 |
- """Implementation of Walkers, a nice way of transforming and traversing ASTs.
- @Walker decorates a function of the form:
- @Walker
- def transform(tree, **kw):
- ...
- return new_tree
- Which is used via:
- new_tree = transform.recurse(old_tree, initial_ctx)
- new_tree = transform.recurse(old_tree)
- new_tree, collected = transform.recurse_collect(old_tree, initial_ctx)
- new_tree, collected = transform.recurse_collect(old_tree)
- collected = transform.collect(old_tree, initial_ctx)
- collected = transform.collect(old_tree)
- The `transform` function takes the tree to be transformed, in addition to
- a set of `**kw` which provides additional functionality:
- - `ctx`: this is the value that is (optionally) passed in to the `recurse`,
- `recurse_collect` and `collect` methods.
- - `set_ctx`: this is a function, used via `set_ctx(new_ctx)` anywhere in
- `transform`, which will cause any children of `tree` to receive `new_ctx`
- as their `ctx` variable.
- - `collect`: this is a function used via `collect(thing)`, which adds
- `thing` to the `collected` list returned by `recurse_collect` and `collect`.
- - `stop`: when called via `stop()`, this prevents recursion on children
- of the current tree.
- These additional arguments can be declared in the signature, e.g.:
- @Walker
- def transform(tree, ctx, set_ctx, **kw):
- ... do stuff with ctx ...
- set_ctx(...)
- return new_tree
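- for ease of use.
- If `transform` returns `None`, the current tree is kept as-is. Any other
- return value replaces the current tree; when the current tree is an element
- of a list and `transform` returns a list, the returned elements are spliced
- into the parent list in its place (so returning `[]` deletes the element).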
- """
- from macropy.core import *
- from ast import *
class Walker(object):
    """Decorator that turns ``func(tree, **kw)`` into a reusable AST walker.

    The wrapped function is invoked on every AST node (and on every
    ``Literal`` / ``Captured`` value) reached during a traversal. It receives
    ``tree``, ``ctx``, ``collect``, ``set_ctx`` and ``stop`` as keyword
    arguments and may return a replacement for the node it was given.
    """

    def __init__(self, func):
        # The user-supplied transform, called once per visited node.
        self.func = func

    def walk_children(self, tree, ctx=None):
        """Walk the children of `tree`, replacing them in place.

        - For an AST node, every field yielded by ``ast.iter_fields`` is
          walked with ``recurse_collect`` and re-assigned on the node.
        - For a non-empty list, every element is walked and the list is
          rebuilt in place; an element whose replacement is a ``list`` is
          spliced into the surrounding list instead of being nested.
        - Anything else is treated as having no children.

        :param tree: the value whose children should be walked.
        :param ctx: the context handed to each child's transform call.
        :returns: the list of values collected while walking the children.
        """
        if isinstance(tree, AST):
            aggregates = []
            # iter_fields already yields each field's current value, so there
            # is no need to fetch it again with getattr.
            for field, old_value in iter_fields(tree):
                new_value, new_aggregate = self.recurse_collect(old_value, ctx)
                aggregates.extend(new_aggregate)
                setattr(tree, field, new_value)
            return aggregates
        elif isinstance(tree, list) and len(tree) > 0:
            aggregates = []
            new_tree = []
            for item in tree:
                new_item, new_aggregate = self.recurse_collect(item, ctx)
                if type(new_item) is list:
                    # A transform returning a list expands into several
                    # siblings (an empty list removes the element).
                    new_tree.extend(new_item)
                else:
                    new_tree.append(new_item)
                aggregates.extend(new_aggregate)
            # Slice assignment mutates the existing list object, so any other
            # reference to this list observes the rewritten elements.
            tree[:] = new_tree
            return aggregates
        else:
            return []

    def recurse(self, tree, ctx=None):
        """Traverse the given AST and return the transformed tree."""
        return self.recurse_collect(tree, ctx)[0]

    def collect(self, tree, ctx=None):
        """Traverse the given AST and return the list of collected values.

        The traversal is identical to ``recurse_collect``: child nodes are
        still rewritten in place as a side effect, but only the collected
        values are returned.
        """
        return self.recurse_collect(tree, ctx)[1]

    def recurse_collect(self, tree, ctx=None):
        """Traverse the given AST and return the transformed tree together
        with any values which were collected along the way.

        ``self.func`` is only invoked on AST nodes and on values whose exact
        type is ``Literal`` or ``Captured``; lists and other values are
        handed straight to ``walk_children``.

        :param tree: the AST node, list, or other value to traverse.
        :param ctx: the context passed to the transform for this node.
        :returns: a ``(new_tree, collected)`` tuple.
        """
        if isinstance(tree, AST) or type(tree) is Literal or type(tree) is Captured:
            aggregates = []

            # Single-element lists serve as mutable cells that the nested
            # callbacks can write to without `nonlocal` (presumably kept for
            # Python 2 compatibility).
            stop_now = [False]

            def stop():
                stop_now[0] = True

            new_ctx = [ctx]

            def set_ctx(new):
                new_ctx[0] = new

            # Provide the function with a bunch of controls, in addition to
            # the tree itself.
            new_tree = self.func(
                tree=tree,
                ctx=ctx,
                collect=aggregates.append,
                set_ctx=set_ctx,
                stop=stop
            )
            # Returning None means "keep the current node unchanged".
            if new_tree is not None:
                tree = new_tree
            # Children see whatever context set_ctx() installed (or the
            # original ctx if it was never called).
            if not stop_now[0]:
                aggregates.extend(self.walk_children(tree, new_ctx[0]))
        else:
            aggregates = self.walk_children(tree, ctx)
        return tree, aggregates
|