case_classes.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271
  1. """Macro providing an extremely concise way of declaring classes"""
  2. from macropy.core.macros import *
  3. from macropy.core.hquotes import macros, hq, name, unhygienic, u
# This module's own macro registry: the ``@macros.decorator`` registrations
# of ``case`` and ``enum`` in this file attach to it.  It rebinds the
# ``macros`` name from the hquotes import line; NOTE(review): presumably that
# imported name only serves MacroPy's "which macros does this import use"
# convention -- confirm before renaming either.
macros = Macros()
  5. def apply(f):
  6. return f()
  7. class CaseClass(object):
  8. __slots__ = []
  9. def copy(self, **kwargs):
  10. old = map(lambda a: (a, getattr(self, a)), self._fields)
  11. new = kwargs.items()
  12. return self.__class__(**dict(old + new))
  13. def __str__(self):
  14. return self.__class__.__name__ + "(" + ", ".join(str(getattr(self, x)) for x in self.__class__._fields) + ")"
  15. def __repr__(self):
  16. return self.__str__()
  17. def __eq__(self, other):
  18. try:
  19. return self.__class__ == other.__class__ \
  20. and all(getattr(self, x) == getattr(other, x) for x in self.__class__._fields)
  21. except AttributeError:
  22. return False
  23. def __ne__(self, other):
  24. return not self.__eq__(other)
  25. def __iter__(self):
  26. for x in self.__class__._fields:
  27. yield getattr(self, x)
  28. class Enum(object):
  29. def __new__(cls, *args, **kw):
  30. if not hasattr(cls, "all"):
  31. cls.all = []
  32. thing = object.__new__(cls, *args, **kw)
  33. cls.all.append(thing)
  34. return thing
  35. @property
  36. def next(self):
  37. return self.__class__.all[(self.id + 1) % len(self.__class__.all)]
  38. @property
  39. def prev(self):
  40. return self.__class__.all[(self.id - 1) % len(self.__class__.all)]
  41. def __str__(self):
  42. return self.__class__.__name__ + "." + self.name
  43. def __repr__(self):
  44. return self.__str__()
  45. def __iter__(self):
  46. for x in self.__class__._fields:
  47. yield getattr(self, x)
  48. def enum_new(cls, **kw):
  49. if len(kw) != 1:
  50. raise TypeError("Enum selection can only take exactly 1 named argument: " + len(kw) + " found.")
  51. [(k, v)] = kw.items()
  52. for value in cls.all:
  53. if getattr(value, k) == v:
  54. return value
  55. raise ValueError("No Enum found for %s=%s" % (k, v))
  56. def noop_init(*args, **kw):
  57. pass
  58. def extract_args(bases):
  59. args = []
  60. vararg = None
  61. kwarg = None
  62. defaults = []
  63. for base in bases:
  64. if type(base) is Name:
  65. args.append(base.id)
  66. elif type(base) is List:
  67. vararg = base.elts[0].id
  68. elif type(base) is Set:
  69. kwarg = base.elts[0].id
  70. elif type(base) is BinOp and type(base.op) is BitOr:
  71. args.append(base.left.id)
  72. defaults.append(base.right)
  73. else:
  74. assert False, "Illegal expression in case class signature: " + unparse(base)
  75. all_args = args[:]
  76. if vararg:
  77. all_args.append(vararg)
  78. if kwarg:
  79. all_args.append(kwarg)
  80. return args, vararg, kwarg, defaults, all_args
  81. @Walker
  82. def find_member_assignments(tree, collect, stop, **kw):
  83. if type(tree) in [GeneratorExp, Lambda, ListComp, DictComp, SetComp, FunctionDef, ClassDef]:
  84. stop()
  85. if type(tree) is Assign:
  86. self_assigns = [
  87. t.attr for t in tree.targets
  88. if type(t) is Attribute
  89. and type(t.value) is Name
  90. and t.value.id == "self"
  91. ]
  92. map(collect, self_assigns)
  93. def split_body(tree, gen_sym):
  94. new_body = []
  95. outer = []
  96. init_body = []
  97. for statement in tree.body:
  98. if type(statement) is ClassDef:
  99. outer.append(case_transform(statement, gen_sym, [Name(id=tree.name)]))
  100. with hq as a:
  101. name[tree.name].b = name[statement.name]
  102. a_old = a[0]
  103. a_old.targets[0].attr = statement.name
  104. a_new = parse_stmt(unparse(a[0]))[0]
  105. outer.append(a_new)
  106. elif type(statement) is FunctionDef:
  107. new_body.append(statement)
  108. else:
  109. init_body.append(statement)
  110. return new_body, outer, init_body
def prep_initialization(init_fun, args, vararg, kwarg, defaults, all_args):
    """Give the generated ``__init__`` its real signature and field setup.

    Mutates *init_fun* (the template ``__init__`` FunctionDef) in place, using
    the values returned by ``extract_args``:

    - its parameters become ``self``, then *args* (with *defaults* attached
      to the trailing ones), then ``*vararg`` and ``**kwarg``;
    - one ``self.<name> = <name>`` statement is appended per name in
      *all_args*.
    """
    init_fun.args = arguments(
        args = [Name(id="self")] + [Name(id = id) for id in args],
        vararg = vararg,
        kwarg = kwarg,
        defaults = defaults
    )
    for x in all_args:
        with hq as a:
            unhygienic[self.x] = name[x]
        # ``self.x`` in the quasiquote is a fixed placeholder attribute
        # (wrapped in unhygienic, presumably so ``self`` keeps its name);
        # rename it to the real field name.
        a[0].targets[0].attr = x
        init_fun.body.append(a[0])
  123. def shared_transform(tree, gen_sym, additional_args=[]):
  124. with hq as methods:
  125. def __init__(self, *args, **kwargs):
  126. pass
  127. _fields = []
  128. _varargs = None
  129. _kwargs = None
  130. __slots__ = []
  131. init_fun, set_fields, set_varargs, set_kwargs, set_slots, = methods
  132. args, vararg, kwarg, defaults, all_args = extract_args(
  133. [Name(id=x) for x in additional_args] + tree.bases
  134. )
  135. if vararg:
  136. set_varargs.value = Str(vararg)
  137. if kwarg:
  138. set_kwargs.value = Str(kwarg)
  139. additional_members = find_member_assignments.collect(tree.body)
  140. prep_initialization(init_fun, args, vararg, kwarg, defaults, all_args)
  141. set_fields.value.elts = map(Str, args)
  142. set_slots.value.elts = map(Str, all_args + additional_members)
  143. new_body, outer, init_body = split_body(tree, gen_sym)
  144. init_fun.body.extend(init_body)
  145. tree.body = new_body
  146. tree.body = methods + tree.body
  147. return outer
  148. def case_transform(tree, gen_sym, parents):
  149. outer = shared_transform(tree, gen_sym)
  150. tree.bases = parents
  151. assign = FunctionDef(
  152. gen_sym("prepare_"+tree.name),
  153. arguments([], None, None, []),
  154. outer,
  155. [hq[apply]]
  156. )
  157. return [tree] + ([assign] if len(outer) > 0 else [])
  158. @macros.decorator
  159. def case(tree, gen_sym, **kw):
  160. """Macro providing an extremely concise way of declaring classes"""
  161. x = case_transform(tree, gen_sym, [hq[CaseClass]])
  162. return x
  163. @macros.decorator
  164. def enum(tree, gen_sym, exact_src, **kw):
  165. count = [0]
  166. new_assigns = []
  167. new_body = []
  168. def handle(expr):
  169. assert type(expr) in (Name, Call), stmt.value
  170. if type(expr) is Name:
  171. expr.ctx = Store()
  172. self_ref = Attribute(value=Name(id=tree.name), attr=expr.id)
  173. with hq as code:
  174. ast[self_ref] = name[tree.name](u[count[0]], u[expr.id])
  175. new_assigns.extend(code)
  176. count[0] += 1
  177. elif type(expr) is Call:
  178. assert type(expr.func) is Name
  179. self_ref = Attribute(value=Name(id=tree.name), attr=expr.func.id)
  180. id = expr.func.id
  181. expr.func = Name(id=tree.name)
  182. expr.args = [Num(count[0]), Str(id)] + expr.args
  183. new_assigns.append(Assign([self_ref], expr))
  184. count[0] += 1
  185. for stmt in tree.body:
  186. try:
  187. if type(stmt) is Expr:
  188. assert type(stmt.value) in (Tuple, Name, Call)
  189. if type(stmt.value) is Tuple:
  190. map(handle, stmt.value.elts)
  191. else:
  192. handle(stmt.value)
  193. elif type(stmt) is FunctionDef:
  194. new_body.append(stmt)
  195. else:
  196. assert False
  197. except AssertionError as e:
  198. assert False, "Can't have `%s` in body of enum" % unparse(stmt).strip("\n")
  199. tree.body = new_body + [Pass()]
  200. shared_transform(tree, gen_sym, additional_args=["id", "name"])
  201. with hq as code:
  202. name[tree.name].__new__ = staticmethod(enum_new)
  203. name[tree.name].__init__ = noop_init
  204. tree.bases = [hq[Enum]]
  205. return [tree] + new_assigns + code