pattern.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371
  1. import inspect
  2. from abc import ABCMeta, abstractmethod
  3. from ast import *
  4. from macropy.core import util
  5. from macropy.core.macros import *
  6. from macropy.core.quotes import macros, q
  7. from macropy.core.hquotes import macros, hq
# Registry for the block macros defined in this module (@macros.block below).
# NOTE(review): this rebinds the `macros` name imported together with q/hq
# above; presumably those imports exist only so MacroPy expands the q/hq
# quasiquotes inside this file — confirm.
macros = Macros()
  9. class PatternMatchException(Exception):
  10. """Thrown when a nonrefutable pattern match fails"""
  11. pass
  12. class PatternVarConflict(Exception):
  13. """Thrown when a pattern attempts to match a variable more than once."""
  14. pass
  15. def _vars_are_disjoint(var_names):
  16. return len(var_names)== len(set(var_names))
  17. class Matcher(object):
  18. __metaclass__ = ABCMeta
  19. @abstractmethod
  20. def var_names(self):
  21. """
  22. Returns a container of the variable names which may be modified upon a
  23. successful match.
  24. """
  25. pass
  26. @abstractmethod
  27. def match(self, matchee):
  28. """
  29. Returns ([(varname, value)...]) if there is a match. Otherwise,
  30. raise PatternMatchException(). This should be stateless.
  31. """
  32. pass
  33. def _match_value(self, matchee):
  34. """
  35. Match against matchee and produce an internal dictionary of the values
  36. for each variable.
  37. """
  38. self.var_dict = {}
  39. for (varname, value) in self.match(matchee):
  40. self.var_dict[varname] = value
  41. def get_var(self, var_name):
  42. return self.var_dict[var_name]
  43. class LiteralMatcher(Matcher):
  44. def __init__(self, val):
  45. self.val = val
  46. def var_names(self):
  47. return []
  48. def match(self, matchee):
  49. if self.val != matchee:
  50. raise PatternMatchException("Literal match failed")
  51. return []
  52. class TupleMatcher(Matcher):
  53. def __init__(self, *matchers):
  54. self.matchers = matchers
  55. if not _vars_are_disjoint(util.flatten([m.var_names() for m in
  56. matchers])):
  57. raise PatternVarConflict()
  58. def var_names(self):
  59. return util.flatten([matcher.var_names() for matcher in self.matchers])
  60. def match(self, matchee):
  61. updates = []
  62. if (not isinstance(matchee, tuple) or
  63. len(matchee) != len(self.matchers)):
  64. raise PatternMatchException("Expected tuple of %d elements" %
  65. (len(self.matchers),))
  66. for (matcher, sub_matchee) in zip(self.matchers, matchee):
  67. match = matcher.match(sub_matchee)
  68. updates.extend(match)
  69. return updates
  70. class ParallelMatcher(Matcher):
  71. def __init__(self, matcher1, matcher2):
  72. self.matcher1 = matcher1
  73. self.matcher2 = matcher2
  74. if not _vars_are_disjoint(util.flatten([matcher1.var_names(),
  75. matcher2.var_names()])):
  76. raise PatternVarConflict()
  77. def var_names(self):
  78. return util.flatten([self.matcher1.var_names(),
  79. self.matcher2.var_names()])
  80. def match(self, matchee):
  81. updates = []
  82. for matcher in [self.matcher1, self.matcher2]:
  83. match = matcher.match(matchee)
  84. updates.extend(match)
  85. return updates
  86. class ListMatcher(Matcher):
  87. def __init__(self, *matchers):
  88. self.matchers = matchers
  89. if not _vars_are_disjoint(util.flatten([m.var_names() for m in
  90. matchers])):
  91. raise PatternVarConflict()
  92. def var_names(self):
  93. return util.flatten([matcher.var_names() for matcher in self.matchers])
  94. def match(self, matchee):
  95. updates = []
  96. if (not isinstance(matchee, list) or len(matchee) != len(self.matchers)):
  97. raise PatternMatchException("Expected list of length %d" %
  98. (len(self.matchers),))
  99. for (matcher, sub_matchee) in zip(self.matchers, matchee):
  100. match = matcher.match(sub_matchee)
  101. updates.extend(match)
  102. return updates
  103. class NameMatcher(Matcher):
  104. def __init__(self, name):
  105. self.name = name
  106. def var_names(self):
  107. return [self.name]
  108. def match(self, matchee):
  109. return [(self.name, matchee)]
  110. class WildcardMatcher(Matcher):
  111. def __init__(self):
  112. pass
  113. def var_names(self):
  114. return ['_']
  115. def match(self, matchee):
  116. return [('_', 3)]
  117. class ClassMatcher(Matcher):
  118. def __init__(self, clazz, positionalMatchers, **kwMatchers):
  119. self.clazz = clazz
  120. self.positionalMatchers = positionalMatchers
  121. self.kwMatchers = kwMatchers
  122. # This stores which fields of the object we will need to look at.
  123. if not _vars_are_disjoint(util.flatten([m.var_names() for m in
  124. positionalMatchers + kwMatchers.values()])):
  125. raise PatternVarConflict()
  126. def var_names(self):
  127. return (util.flatten([matcher.var_names()
  128. for matcher in self.positionalMatchers + self.kwMatchers.values()]))
  129. def default_unapply(self, matchee, kw_keys):
  130. if not isinstance(matchee, self.clazz):
  131. raise PatternMatchException("Matchee should be of type %r" %
  132. (self.clazz,))
  133. pos_values = []
  134. kw_dict = {}
  135. # We don't get the argspec unless there are actually positional matchers
  136. def genPosValues():
  137. arg_spec = inspect.getargspec(self.clazz.__init__)
  138. for arg in arg_spec.args:
  139. if arg != 'self':
  140. yield(getattr(matchee, arg, None))
  141. pos_values = genPosValues()
  142. for kw_key in kw_keys:
  143. if not hasattr(matchee, kw_key):
  144. raise PatternMatchException("Keyword argument match failed: no"
  145. + " attribute %r" % (kw_key,))
  146. kw_dict[kw_key] = getattr(matchee, kw_key)
  147. return pos_values, kw_dict
  148. def match(self, matchee):
  149. updates = []
  150. if hasattr(self.clazz, '__unapply__'):
  151. pos_vals, kw_dict = self.clazz.__unapply__(matchee,
  152. self.kwMatchers.keys())
  153. else:
  154. pos_vals, kw_dict = self.default_unapply(matchee,
  155. self.kwMatchers.keys())
  156. for (matcher, sub_matchee) in zip(self.positionalMatchers,
  157. pos_vals):
  158. updates.extend(matcher.match(sub_matchee))
  159. for key, val in kw_dict.items():
  160. updates.extend(self.kwMatchers[key].match(val))
  161. return updates
  162. def build_matcher(tree, modified):
  163. if isinstance(tree, Num):
  164. return hq[LiteralMatcher(u[tree.n])]
  165. if isinstance(tree, Str):
  166. return hq[LiteralMatcher(u[tree.s])]
  167. if isinstance(tree, Name):
  168. if tree.id in ['True', 'False', 'None']:
  169. return hq[LiteralMatcher(ast[tree])]
  170. elif tree.id in ['_']:
  171. return hq[WildcardMatcher()]
  172. modified.add(tree.id)
  173. return hq[NameMatcher(u[tree.id])]
  174. if isinstance(tree, List):
  175. sub_matchers = []
  176. for child in tree.elts:
  177. sub_matchers.append(build_matcher(child, modified))
  178. return Call(Name('ListMatcher', Load()), sub_matchers, [], None, None)
  179. if isinstance(tree, Tuple):
  180. sub_matchers = []
  181. for child in tree.elts:
  182. sub_matchers.append(build_matcher(child, modified))
  183. return Call(Name('TupleMatcher', Load()), sub_matchers, [], None, None)
  184. if isinstance(tree, Call):
  185. sub_matchers = []
  186. for child in tree.args:
  187. sub_matchers.append(build_matcher(child, modified))
  188. positional_matchers = List(sub_matchers, Load())
  189. kw_matchers = []
  190. for kw in tree.keywords:
  191. kw_matchers.append(
  192. keyword(kw.arg, build_matcher(kw.value, modified)))
  193. return Call(Name('ClassMatcher', Load()), [tree.func,
  194. positional_matchers], kw_matchers, None, None)
  195. if (isinstance(tree, BinOp) and isinstance(tree.op, BitAnd)):
  196. sub1 = build_matcher(tree.left, modified)
  197. sub2 = build_matcher(tree.right, modified)
  198. return Call(Name('ParallelMatcher', Load()), [sub1, sub2], [], None,
  199. None)
  200. raise Exception("Unrecognized tree " + repr(tree))
  201. def _is_pattern_match_stmt(tree):
  202. return (isinstance(tree, Expr) and
  203. _is_pattern_match_expr(tree.value))
  204. def _is_pattern_match_expr(tree):
  205. return (isinstance(tree, BinOp) and
  206. isinstance(tree.op, LShift))
@macros.block
def _matching(tree, gen_sym, **kw):
    """
    Enable non-refutable pattern matching inside the block.

    Every expression statement of the form ``pattern << value`` is replaced by:
      * an assignment of the Matcher built by ``build_matcher(pattern)`` to a
        fresh name obtained from ``gen_sym``;
      * a call to that matcher's ``_match_value(value)``, which raises
        PatternMatchException if the match fails;
      * one assignment per variable bound by the pattern, read back through
        the matcher's ``get_var``.

    All other statements are left unchanged.
    """
    @Walker
    def func(tree, **kw):
        # Only statement-level `pattern << value` is rewritten.
        if _is_pattern_match_stmt(tree):
            # build_matcher adds every name the pattern binds to this set.
            modified = set()
            matcher = build_matcher(tree.value.left, modified)
            # Name for the temporary matcher; presumably gen_sym guarantees it
            # cannot clash with user variables — confirm.
            temp = gen_sym()
            with hq as assignment:
                name[temp] = ast[matcher]
            # NOTE(review): `assignment` comes from a block quasiquote; confirm
            # whether it is a single statement or a list, and that Walker splices
            # the returned `statements` list into the enclosing body.
            statements = [assignment, Expr(hq[name[temp]._match_value(ast[tree.value.right])])]
            # Set iteration order is arbitrary; harmless because each
            # assignment is independent of the others.
            for var_name in modified:
                statements.append(Assign([Name(var_name, Store())], hq[name[temp].get_var(u[var_name])]))
            return statements
        else:
            return tree
    # The return value of recurse is discarded; this presumes it rewrites
    # `tree` in place — confirm.
    func.recurse(tree)
    return [tree]
  230. def _rewrite_if(tree, var_name=None, **kw_args):
  231. # TODO refactor into a _rewrite_switch and a _rewrite_if
  232. """
  233. Rewrite if statements to treat pattern matches as boolean expressions.
  234. Recall that normally a pattern match is a statement which will throw a
  235. PatternMatchException if the match fails. We can therefore use try-blocks
  236. to produce the desired branching behavior.
  237. var_name is an optional parameter used for rewriting switch statements. If
  238. present, it will transform predicates which are expressions into pattern
  239. matches.
  240. """
  241. # with q as rewritten:
  242. # try:
  243. # with matching:
  244. # u%(matchPattern)
  245. # u%(successBody)
  246. # except PatternMatchException:
  247. # u%(_maybe_rewrite_if(failBody))
  248. # return rewritten
  249. if not isinstance(tree, If):
  250. return tree
  251. if var_name:
  252. tree.test = BinOp(tree.test, LShift(), Name(var_name, Load()))
  253. elif not (isinstance(tree.test, BinOp) and isinstance(tree.test.op, LShift)):
  254. return tree
  255. handler = ExceptHandler(hq[PatternMatchException], None, tree.orelse)
  256. try_stmt = TryExcept(tree.body, [handler], [])
  257. macroed_match = With(Name('_matching', Load()), None, [Expr(tree.test)])
  258. try_stmt.body = [macroed_match] + try_stmt.body
  259. if len(handler.body) == 1: # (== tree.orelse)
  260. # Might be an elif
  261. handler.body = [_rewrite_if(handler.body[0], var_name)]
  262. elif not handler.body:
  263. handler.body = [Pass()]
  264. return try_stmt
  265. @macros.block
  266. def switch(tree, args, gen_sym, **kw):
  267. """
  268. If supplied one argument x, switch will treat the predicates of any
  269. top-level if statements as patten matches against x.
  270. Pattern matches elsewhere are ignored. The advantage of this is the
  271. limited reach ensures less interference with existing code.
  272. """
  273. new_id = gen_sym()
  274. for i in xrange(len(tree)):
  275. tree[i] = _rewrite_if(tree[i], new_id)
  276. tree = [Assign([Name(new_id, Store())], args[0])] + tree
  277. return tree
@macros.block
def patterns(tree, **kw):
    """
    Enable pattern matching throughout the block.

    The block is wrapped in ``with _matching:``, so every expression
    statement ``pattern << value`` becomes a non-refutable match.  In
    addition, every ``if`` whose test is ``pattern << value`` is rewritten by
    ``_rewrite_if`` into a branch that is taken only when the match succeeds.

    NB: inside this block, a left shift used as an expression statement or as
    an if/elif test is treated as a pattern match, so real left shifts cannot
    be used in those positions.
    """
    # Template: a `with _matching:` block whose placeholder body is replaced below.
    with q as new:
        with _matching:
            None
    # Apply _rewrite_if to the block (via Walker, presumably at every nesting
    # level — confirm) and install the result as the template's body.
    new[0].body = Walker(lambda tree, **kw: _rewrite_if(tree)).recurse(tree)
    return new
  286. None
  287. new[0].body = Walker(lambda tree, **kw: _rewrite_if(tree)).recurse(tree)
  288. return new