walkers.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. """Implementation of Walkers, a nice way of transforming and traversing ASTs.
  2. @Walker decorates a function of the form:
  3. @Walker
  4. def transform(tree, **kw):
  5. ...
  6. return new_tree
  7. Which is used via:
  8. new_tree = transform.recurse(old_tree, initial_ctx)
  9. new_tree = transform.recurse(old_tree)
  10. new_tree, collected = transform.recurse_collect(old_tree, initial_ctx)
  11. new_tree, collected = transform.recurse_collect(old_tree)
  12. collected = transform.collect(old_tree, initial_ctx)
  13. collected = transform.collect(old_tree)
  14. The `transform` function takes the tree to be transformed, in addition to
  15. a set of `**kw` which provides additional functionality:
  16. - `ctx`: this is the value that is (optionally) passed in to the `recurse`
  17. and `recurse_collect` methods.
  18. - `set_ctx`: this is a function, used via `set_ctx(new_ctx)` anywhere in
  19. `transform`, which will cause any children of `tree` to receive `new_ctx`
  20. as their `ctx` variable.
  21. - `collect`: this is a function used via `collect(thing)`, which adds
  22. `thing` to the `collected` list returned by `recurse_collect`.
  23. - `stop`: when called via `stop()`, this prevents recursion on children
  24. of the current tree.
  25. These additional arguments can be declared in the signature, e.g.:
  26. @Walker
  27. def transform(tree, ctx, set_ctx, **kw):
  28. ... do stuff with ctx ...
  29. set_ctx(...)
  30. return new_tree
  31. for ease of use.
  32. """
  33. from macropy.core import *
  34. from ast import *
  35. class Walker(object):
  36. def __init__(self, func):
  37. self.func = func
  38. def walk_children(self, tree, ctx=None):
  39. if isinstance(tree, AST):
  40. aggregates = []
  41. for field, old_value in iter_fields(tree):
  42. old_value = getattr(tree, field, None)
  43. new_value, new_aggregate = self.recurse_collect(old_value, ctx)
  44. aggregates.extend(new_aggregate)
  45. setattr(tree, field, new_value)
  46. return aggregates
  47. elif isinstance(tree, list) and len(tree) > 0:
  48. aggregates = []
  49. new_tree = []
  50. for t in tree:
  51. new_t, new_a = self.recurse_collect(t, ctx)
  52. if type(new_t) is list:
  53. new_tree.extend(new_t)
  54. else:
  55. new_tree.append(new_t)
  56. aggregates.extend(new_a)
  57. tree[:] = new_tree
  58. return aggregates
  59. else:
  60. return []
  61. def recurse(self, tree, ctx=None):
  62. """Traverse the given AST and return the transformed tree."""
  63. return self.recurse_collect(tree, ctx)[0]
  64. def collect(self, tree, ctx=None):
  65. """Traverse the given AST and return the transformed tree."""
  66. return self.recurse_collect(tree, ctx)[1]
  67. def recurse_collect(self, tree, ctx=None):
  68. """Traverse the given AST and return the transformed tree together
  69. with any values which were collected along with way."""
  70. if isinstance(tree, AST) or type(tree) is Literal or type(tree) is Captured:
  71. aggregates = []
  72. stop_now = [False]
  73. def stop():
  74. stop_now[0] = True
  75. new_ctx = [ctx]
  76. def set_ctx(new):
  77. new_ctx[0] = new
  78. # Provide the function with a bunch of controls, in addition to
  79. # the tree itself.
  80. new_tree = self.func(
  81. tree=tree,
  82. ctx=ctx,
  83. collect=aggregates.append,
  84. set_ctx=set_ctx,
  85. stop=stop
  86. )
  87. if new_tree is not None:
  88. tree = new_tree
  89. if not stop_now[0]:
  90. aggregates.extend(self.walk_children(tree, new_ctx[0]))
  91. else:
  92. aggregates = self.walk_children(tree, ctx)
  93. return tree, aggregates