source.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373
  1. from __future__ import absolute_import, division, print_function
  2. import ast
  3. from ast import PyCF_ONLY_AST as _AST_FLAG
  4. from bisect import bisect_right
  5. import linecache
  6. import sys
  7. import six
  8. import inspect
  9. import tokenize
  10. import py
# Module-level alias for the builtin compile(); Source.compile() and
# compile_() below call the builtin through this name.
cpy_compile = compile
  12. class Source(object):
  13. """ an immutable object holding a source code fragment,
  14. possibly deindenting it.
  15. """
  16. _compilecounter = 0
  17. def __init__(self, *parts, **kwargs):
  18. self.lines = lines = []
  19. de = kwargs.get("deindent", True)
  20. rstrip = kwargs.get("rstrip", True)
  21. for part in parts:
  22. if not part:
  23. partlines = []
  24. elif isinstance(part, Source):
  25. partlines = part.lines
  26. elif isinstance(part, (tuple, list)):
  27. partlines = [x.rstrip("\n") for x in part]
  28. elif isinstance(part, six.string_types):
  29. partlines = part.split("\n")
  30. if rstrip:
  31. while partlines:
  32. if partlines[-1].strip():
  33. break
  34. partlines.pop()
  35. else:
  36. partlines = getsource(part, deindent=de).lines
  37. if de:
  38. partlines = deindent(partlines)
  39. lines.extend(partlines)
  40. def __eq__(self, other):
  41. try:
  42. return self.lines == other.lines
  43. except AttributeError:
  44. if isinstance(other, str):
  45. return str(self) == other
  46. return False
  47. __hash__ = None
  48. def __getitem__(self, key):
  49. if isinstance(key, int):
  50. return self.lines[key]
  51. else:
  52. if key.step not in (None, 1):
  53. raise IndexError("cannot slice a Source with a step")
  54. newsource = Source()
  55. newsource.lines = self.lines[key.start : key.stop]
  56. return newsource
  57. def __len__(self):
  58. return len(self.lines)
  59. def strip(self):
  60. """ return new source object with trailing
  61. and leading blank lines removed.
  62. """
  63. start, end = 0, len(self)
  64. while start < end and not self.lines[start].strip():
  65. start += 1
  66. while end > start and not self.lines[end - 1].strip():
  67. end -= 1
  68. source = Source()
  69. source.lines[:] = self.lines[start:end]
  70. return source
  71. def putaround(self, before="", after="", indent=" " * 4):
  72. """ return a copy of the source object with
  73. 'before' and 'after' wrapped around it.
  74. """
  75. before = Source(before)
  76. after = Source(after)
  77. newsource = Source()
  78. lines = [(indent + line) for line in self.lines]
  79. newsource.lines = before.lines + lines + after.lines
  80. return newsource
  81. def indent(self, indent=" " * 4):
  82. """ return a copy of the source object with
  83. all lines indented by the given indent-string.
  84. """
  85. newsource = Source()
  86. newsource.lines = [(indent + line) for line in self.lines]
  87. return newsource
  88. def getstatement(self, lineno):
  89. """ return Source statement which contains the
  90. given linenumber (counted from 0).
  91. """
  92. start, end = self.getstatementrange(lineno)
  93. return self[start:end]
  94. def getstatementrange(self, lineno):
  95. """ return (start, end) tuple which spans the minimal
  96. statement region which containing the given lineno.
  97. """
  98. if not (0 <= lineno < len(self)):
  99. raise IndexError("lineno out of range")
  100. ast, start, end = getstatementrange_ast(lineno, self)
  101. return start, end
  102. def deindent(self, offset=None):
  103. """ return a new source object deindented by offset.
  104. If offset is None then guess an indentation offset from
  105. the first non-blank line. Subsequent lines which have a
  106. lower indentation offset will be copied verbatim as
  107. they are assumed to be part of multilines.
  108. """
  109. # XXX maybe use the tokenizer to properly handle multiline
  110. # strings etc.pp?
  111. newsource = Source()
  112. newsource.lines[:] = deindent(self.lines, offset)
  113. return newsource
  114. def isparseable(self, deindent=True):
  115. """ return True if source is parseable, heuristically
  116. deindenting it by default.
  117. """
  118. from parser import suite as syntax_checker
  119. if deindent:
  120. source = str(self.deindent())
  121. else:
  122. source = str(self)
  123. try:
  124. # compile(source+'\n', "x", "exec")
  125. syntax_checker(source + "\n")
  126. except KeyboardInterrupt:
  127. raise
  128. except Exception:
  129. return False
  130. else:
  131. return True
  132. def __str__(self):
  133. return "\n".join(self.lines)
  134. def compile(
  135. self, filename=None, mode="exec", flag=0, dont_inherit=0, _genframe=None
  136. ):
  137. """ return compiled code object. if filename is None
  138. invent an artificial filename which displays
  139. the source/line position of the caller frame.
  140. """
  141. if not filename or py.path.local(filename).check(file=0):
  142. if _genframe is None:
  143. _genframe = sys._getframe(1) # the caller
  144. fn, lineno = _genframe.f_code.co_filename, _genframe.f_lineno
  145. base = "<%d-codegen " % self._compilecounter
  146. self.__class__._compilecounter += 1
  147. if not filename:
  148. filename = base + "%s:%d>" % (fn, lineno)
  149. else:
  150. filename = base + "%r %s:%d>" % (filename, fn, lineno)
  151. source = "\n".join(self.lines) + "\n"
  152. try:
  153. co = cpy_compile(source, filename, mode, flag)
  154. except SyntaxError:
  155. ex = sys.exc_info()[1]
  156. # re-represent syntax errors from parsing python strings
  157. msglines = self.lines[: ex.lineno]
  158. if ex.offset:
  159. msglines.append(" " * ex.offset + "^")
  160. msglines.append("(code was compiled probably from here: %s)" % filename)
  161. newex = SyntaxError("\n".join(msglines))
  162. newex.offset = ex.offset
  163. newex.lineno = ex.lineno
  164. newex.text = ex.text
  165. raise newex
  166. else:
  167. if flag & _AST_FLAG:
  168. return co
  169. lines = [(x + "\n") for x in self.lines]
  170. linecache.cache[filename] = (1, None, lines, filename)
  171. return co
  172. #
  173. # public API shortcut functions
  174. #
  175. def compile_(source, filename=None, mode="exec", flags=0, dont_inherit=0):
  176. """ compile the given source to a raw code object,
  177. and maintain an internal cache which allows later
  178. retrieval of the source code for the code object
  179. and any recursively created code objects.
  180. """
  181. if isinstance(source, ast.AST):
  182. # XXX should Source support having AST?
  183. return cpy_compile(source, filename, mode, flags, dont_inherit)
  184. _genframe = sys._getframe(1) # the caller
  185. s = Source(source)
  186. co = s.compile(filename, mode, flags, _genframe=_genframe)
  187. return co
  188. def getfslineno(obj):
  189. """ Return source location (path, lineno) for the given object.
  190. If the source cannot be determined return ("", -1)
  191. """
  192. from .code import Code
  193. try:
  194. code = Code(obj)
  195. except TypeError:
  196. try:
  197. fn = inspect.getsourcefile(obj) or inspect.getfile(obj)
  198. except TypeError:
  199. return "", -1
  200. fspath = fn and py.path.local(fn) or None
  201. lineno = -1
  202. if fspath:
  203. try:
  204. _, lineno = findsource(obj)
  205. except IOError:
  206. pass
  207. else:
  208. fspath = code.path
  209. lineno = code.firstlineno
  210. assert isinstance(lineno, int)
  211. return fspath, lineno
  212. #
  213. # helper functions
  214. #
  215. def findsource(obj):
  216. try:
  217. sourcelines, lineno = inspect.findsource(obj)
  218. except py.builtin._sysex:
  219. raise
  220. except: # noqa
  221. return None, -1
  222. source = Source()
  223. source.lines = [line.rstrip() for line in sourcelines]
  224. return source, lineno
  225. def getsource(obj, **kwargs):
  226. from .code import getrawcode
  227. obj = getrawcode(obj)
  228. try:
  229. strsrc = inspect.getsource(obj)
  230. except IndentationError:
  231. strsrc = '"Buggy python version consider upgrading, cannot get source"'
  232. assert isinstance(strsrc, str)
  233. return Source(strsrc, **kwargs)
  234. def deindent(lines, offset=None):
  235. if offset is None:
  236. for line in lines:
  237. line = line.expandtabs()
  238. s = line.lstrip()
  239. if s:
  240. offset = len(line) - len(s)
  241. break
  242. else:
  243. offset = 0
  244. if offset == 0:
  245. return list(lines)
  246. newlines = []
  247. def readline_generator(lines):
  248. for line in lines:
  249. yield line + "\n"
  250. it = readline_generator(lines)
  251. try:
  252. for _, _, (sline, _), (eline, _), _ in tokenize.generate_tokens(
  253. lambda: next(it)
  254. ):
  255. if sline > len(lines):
  256. break # End of input reached
  257. if sline > len(newlines):
  258. line = lines[sline - 1].expandtabs()
  259. if line.lstrip() and line[:offset].isspace():
  260. line = line[offset:] # Deindent
  261. newlines.append(line)
  262. for i in range(sline, eline):
  263. # Don't deindent continuing lines of
  264. # multiline tokens (i.e. multiline strings)
  265. newlines.append(lines[i])
  266. except (IndentationError, tokenize.TokenError):
  267. pass
  268. # Add any lines we didn't see. E.g. if an exception was raised.
  269. newlines.extend(lines[len(newlines) :])
  270. return newlines
  271. def get_statement_startend2(lineno, node):
  272. import ast
  273. # flatten all statements and except handlers into one lineno-list
  274. # AST's line numbers start indexing at 1
  275. values = []
  276. for x in ast.walk(node):
  277. if isinstance(x, (ast.stmt, ast.ExceptHandler)):
  278. values.append(x.lineno - 1)
  279. for name in ("finalbody", "orelse"):
  280. val = getattr(x, name, None)
  281. if val:
  282. # treat the finally/orelse part as its own statement
  283. values.append(val[0].lineno - 1 - 1)
  284. values.sort()
  285. insert_index = bisect_right(values, lineno)
  286. start = values[insert_index - 1]
  287. if insert_index >= len(values):
  288. end = None
  289. else:
  290. end = values[insert_index]
  291. return start, end
  292. def getstatementrange_ast(lineno, source, assertion=False, astnode=None):
  293. if astnode is None:
  294. content = str(source)
  295. astnode = compile(content, "source", "exec", 1024) # 1024 for AST
  296. start, end = get_statement_startend2(lineno, astnode)
  297. # we need to correct the end:
  298. # - ast-parsing strips comments
  299. # - there might be empty lines
  300. # - we might have lesser indented code blocks at the end
  301. if end is None:
  302. end = len(source.lines)
  303. if end > start + 1:
  304. # make sure we don't span differently indented code blocks
  305. # by using the BlockFinder helper used which inspect.getsource() uses itself
  306. block_finder = inspect.BlockFinder()
  307. # if we start with an indented line, put blockfinder to "started" mode
  308. block_finder.started = source.lines[start][0].isspace()
  309. it = ((x + "\n") for x in source.lines[start:end])
  310. try:
  311. for tok in tokenize.generate_tokens(lambda: next(it)):
  312. block_finder.tokeneater(*tok)
  313. except (inspect.EndOfBlock, IndentationError):
  314. end = block_finder.last + start
  315. except Exception:
  316. pass
  317. # the end might still point to a comment or empty line, correct it
  318. while end:
  319. line = source.lines[end - 1].lstrip()
  320. if line.startswith("#") or not line:
  321. end -= 1
  322. else:
  323. break
  324. return astnode, start, end