rewrite.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969
  1. """Rewrite assertion AST to produce nice error messages"""
  2. from __future__ import absolute_import, division, print_function
  3. import ast
  4. import errno
  5. import itertools
  6. import imp
  7. import marshal
  8. import os
  9. import re
  10. import six
  11. import struct
  12. import sys
  13. import types
  14. import atomicwrites
  15. import py
  16. from _pytest.assertion import util
  17. # pytest caches rewritten pycs in __pycache__.
  18. if hasattr(imp, "get_tag"):
  19. PYTEST_TAG = imp.get_tag() + "-PYTEST"
  20. else:
  21. if hasattr(sys, "pypy_version_info"):
  22. impl = "pypy"
  23. elif sys.platform == "java":
  24. impl = "jython"
  25. else:
  26. impl = "cpython"
  27. ver = sys.version_info
  28. PYTEST_TAG = "%s-%s%s-PYTEST" % (impl, ver[0], ver[1])
  29. del ver, impl
  30. PYC_EXT = ".py" + (__debug__ and "c" or "o")
  31. PYC_TAIL = "." + PYTEST_TAG + PYC_EXT
  32. ASCII_IS_DEFAULT_ENCODING = sys.version_info[0] < 3
  33. if sys.version_info >= (3, 5):
  34. ast_Call = ast.Call
  35. else:
  36. def ast_Call(a, b, c):
  37. return ast.Call(a, b, c, None, None)
  38. if sys.version_info >= (3, 4):
  39. from importlib.util import spec_from_file_location
  40. else:
  41. def spec_from_file_location(*_, **__):
  42. return None
  43. class AssertionRewritingHook(object):
  44. """PEP302 Import hook which rewrites asserts."""
  45. def __init__(self, config):
  46. self.config = config
  47. self.fnpats = config.getini("python_files")
  48. self.session = None
  49. self.modules = {}
  50. self._rewritten_names = set()
  51. self._register_with_pkg_resources()
  52. self._must_rewrite = set()
  53. # flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
  54. # which might result in infinite recursion (#3506)
  55. self._writing_pyc = False
  56. def set_session(self, session):
  57. self.session = session
  58. def find_module(self, name, path=None):
  59. if self._writing_pyc:
  60. return None
  61. state = self.config._assertstate
  62. state.trace("find_module called for: %s" % name)
  63. names = name.rsplit(".", 1)
  64. lastname = names[-1]
  65. pth = None
  66. if path is not None:
  67. # Starting with Python 3.3, path is a _NamespacePath(), which
  68. # causes problems if not converted to list.
  69. path = list(path)
  70. if len(path) == 1:
  71. pth = path[0]
  72. if pth is None:
  73. try:
  74. fd, fn, desc = imp.find_module(lastname, path)
  75. except ImportError:
  76. return None
  77. if fd is not None:
  78. fd.close()
  79. tp = desc[2]
  80. if tp == imp.PY_COMPILED:
  81. if hasattr(imp, "source_from_cache"):
  82. try:
  83. fn = imp.source_from_cache(fn)
  84. except ValueError:
  85. # Python 3 doesn't like orphaned but still-importable
  86. # .pyc files.
  87. fn = fn[:-1]
  88. else:
  89. fn = fn[:-1]
  90. elif tp != imp.PY_SOURCE:
  91. # Don't know what this is.
  92. return None
  93. else:
  94. fn = os.path.join(pth, name.rpartition(".")[2] + ".py")
  95. fn_pypath = py.path.local(fn)
  96. if not self._should_rewrite(name, fn_pypath, state):
  97. return None
  98. self._rewritten_names.add(name)
  99. # The requested module looks like a test file, so rewrite it. This is
  100. # the most magical part of the process: load the source, rewrite the
  101. # asserts, and load the rewritten source. We also cache the rewritten
  102. # module code in a special pyc. We must be aware of the possibility of
  103. # concurrent pytest processes rewriting and loading pycs. To avoid
  104. # tricky race conditions, we maintain the following invariant: The
  105. # cached pyc is always a complete, valid pyc. Operations on it must be
  106. # atomic. POSIX's atomic rename comes in handy.
  107. write = not sys.dont_write_bytecode
  108. cache_dir = os.path.join(fn_pypath.dirname, "__pycache__")
  109. if write:
  110. try:
  111. os.mkdir(cache_dir)
  112. except OSError:
  113. e = sys.exc_info()[1].errno
  114. if e == errno.EEXIST:
  115. # Either the __pycache__ directory already exists (the
  116. # common case) or it's blocked by a non-dir node. In the
  117. # latter case, we'll ignore it in _write_pyc.
  118. pass
  119. elif e in [errno.ENOENT, errno.ENOTDIR]:
  120. # One of the path components was not a directory, likely
  121. # because we're in a zip file.
  122. write = False
  123. elif e in [errno.EACCES, errno.EROFS, errno.EPERM]:
  124. state.trace("read only directory: %r" % fn_pypath.dirname)
  125. write = False
  126. else:
  127. raise
  128. cache_name = fn_pypath.basename[:-3] + PYC_TAIL
  129. pyc = os.path.join(cache_dir, cache_name)
  130. # Notice that even if we're in a read-only directory, I'm going
  131. # to check for a cached pyc. This may not be optimal...
  132. co = _read_pyc(fn_pypath, pyc, state.trace)
  133. if co is None:
  134. state.trace("rewriting %r" % (fn,))
  135. source_stat, co = _rewrite_test(self.config, fn_pypath)
  136. if co is None:
  137. # Probably a SyntaxError in the test.
  138. return None
  139. if write:
  140. self._writing_pyc = True
  141. try:
  142. _write_pyc(state, co, source_stat, pyc)
  143. finally:
  144. self._writing_pyc = False
  145. else:
  146. state.trace("found cached rewritten pyc for %r" % (fn,))
  147. self.modules[name] = co, pyc
  148. return self
  149. def _should_rewrite(self, name, fn_pypath, state):
  150. # always rewrite conftest files
  151. fn = str(fn_pypath)
  152. if fn_pypath.basename == "conftest.py":
  153. state.trace("rewriting conftest file: %r" % (fn,))
  154. return True
  155. if self.session is not None:
  156. if self.session.isinitpath(fn):
  157. state.trace("matched test file (was specified on cmdline): %r" % (fn,))
  158. return True
  159. # modules not passed explicitly on the command line are only
  160. # rewritten if they match the naming convention for test files
  161. for pat in self.fnpats:
  162. if fn_pypath.fnmatch(pat):
  163. state.trace("matched test file %r" % (fn,))
  164. return True
  165. for marked in self._must_rewrite:
  166. if name == marked or name.startswith(marked + "."):
  167. state.trace("matched marked file %r (from %r)" % (name, marked))
  168. return True
  169. return False
  170. def mark_rewrite(self, *names):
  171. """Mark import names as needing to be rewritten.
  172. The named module or package as well as any nested modules will
  173. be rewritten on import.
  174. """
  175. already_imported = (
  176. set(names).intersection(sys.modules).difference(self._rewritten_names)
  177. )
  178. for name in already_imported:
  179. if not AssertionRewriter.is_rewrite_disabled(
  180. sys.modules[name].__doc__ or ""
  181. ):
  182. self._warn_already_imported(name)
  183. self._must_rewrite.update(names)
  184. def _warn_already_imported(self, name):
  185. self.config.warn(
  186. "P1", "Module already imported so cannot be rewritten: %s" % name
  187. )
  188. def load_module(self, name):
  189. # If there is an existing module object named 'fullname' in
  190. # sys.modules, the loader must use that existing module. (Otherwise,
  191. # the reload() builtin will not work correctly.)
  192. if name in sys.modules:
  193. return sys.modules[name]
  194. co, pyc = self.modules.pop(name)
  195. # I wish I could just call imp.load_compiled here, but __file__ has to
  196. # be set properly. In Python 3.2+, this all would be handled correctly
  197. # by load_compiled.
  198. mod = sys.modules[name] = imp.new_module(name)
  199. try:
  200. mod.__file__ = co.co_filename
  201. # Normally, this attribute is 3.2+.
  202. mod.__cached__ = pyc
  203. mod.__loader__ = self
  204. # Normally, this attribute is 3.4+
  205. mod.__spec__ = spec_from_file_location(name, co.co_filename, loader=self)
  206. six.exec_(co, mod.__dict__)
  207. except: # noqa
  208. if name in sys.modules:
  209. del sys.modules[name]
  210. raise
  211. return sys.modules[name]
  212. def is_package(self, name):
  213. try:
  214. fd, fn, desc = imp.find_module(name)
  215. except ImportError:
  216. return False
  217. if fd is not None:
  218. fd.close()
  219. tp = desc[2]
  220. return tp == imp.PKG_DIRECTORY
  221. @classmethod
  222. def _register_with_pkg_resources(cls):
  223. """
  224. Ensure package resources can be loaded from this loader. May be called
  225. multiple times, as the operation is idempotent.
  226. """
  227. try:
  228. import pkg_resources
  229. # access an attribute in case a deferred importer is present
  230. pkg_resources.__name__
  231. except ImportError:
  232. return
  233. # Since pytest tests are always located in the file system, the
  234. # DefaultProvider is appropriate.
  235. pkg_resources.register_loader_type(cls, pkg_resources.DefaultProvider)
  236. def get_data(self, pathname):
  237. """Optional PEP302 get_data API.
  238. """
  239. with open(pathname, "rb") as f:
  240. return f.read()
  241. def _write_pyc(state, co, source_stat, pyc):
  242. # Technically, we don't have to have the same pyc format as
  243. # (C)Python, since these "pycs" should never be seen by builtin
  244. # import. However, there's little reason deviate, and I hope
  245. # sometime to be able to use imp.load_compiled to load them. (See
  246. # the comment in load_module above.)
  247. try:
  248. with atomicwrites.atomic_write(pyc, mode="wb", overwrite=True) as fp:
  249. fp.write(imp.get_magic())
  250. mtime = int(source_stat.mtime)
  251. size = source_stat.size & 0xFFFFFFFF
  252. fp.write(struct.pack("<ll", mtime, size))
  253. fp.write(marshal.dumps(co))
  254. except EnvironmentError as e:
  255. state.trace("error writing pyc file at %s: errno=%s" % (pyc, e.errno))
  256. # we ignore any failure to write the cache file
  257. # there are many reasons, permission-denied, __pycache__ being a
  258. # file etc.
  259. return False
  260. return True
  261. RN = "\r\n".encode("utf-8")
  262. N = "\n".encode("utf-8")
  263. cookie_re = re.compile(r"^[ \t\f]*#.*coding[:=][ \t]*[-\w.]+")
  264. BOM_UTF8 = "\xef\xbb\xbf"
def _rewrite_test(config, fn):
    """Try to read and rewrite *fn* and return the code object.

    *fn* is a py.path.local-like object; ``stat()``, ``read("rb")`` and
    ``strpath`` are used on it. Returns a ``(stat, code)`` pair on success.
    Returns ``(None, None)`` in any of these cases:
    - the file cannot be read;
    - on Python 2, it contains non-ASCII bytes without a coding cookie;
    - on Python 2, it is re-entered while decoding;
    - it fails to parse or to compile.
    """
    state = config._assertstate
    try:
        stat = fn.stat()
        source = fn.read("rb")
    except EnvironmentError:
        return None, None
    if ASCII_IS_DEFAULT_ENCODING:
        # ASCII is the default encoding in Python 2. Without a coding
        # declaration, Python 2 will complain about any bytes in the file
        # outside the ASCII range. Sadly, this behavior does not extend to
        # compile() or ast.parse(), which prefer to interpret the bytes as
        # latin-1. (At least they properly handle explicit coding cookies.) To
        # preserve this error behavior, we could force ast.parse() to use ASCII
        # as the encoding by inserting a coding cookie. Unfortunately, that
        # messes up line numbers. Thus, we have to check ourselves if anything
        # is outside the ASCII range in the case no encoding is explicitly
        # declared. For more context, see issue #269. Yay for Python 3 which
        # gets this right.
        # A coding cookie may only appear on the first or second line, so
        # only those two lines are checked.
        end1 = source.find("\n")
        end2 = source.find("\n", end1 + 1)
        if (
            not source.startswith(BOM_UTF8)
            and cookie_re.match(source[0:end1]) is None
            and cookie_re.match(source[end1 + 1 : end2]) is None
        ):
            if hasattr(state, "_indecode"):
                # encodings imported us again, so don't rewrite.
                return None, None
            # Re-entrancy guard: decoding may import the encodings package,
            # which can trigger this import hook again.
            state._indecode = True
            try:
                try:
                    source.decode("ascii")
                except UnicodeDecodeError:
                    # Let it fail in real import.
                    return None, None
            finally:
                del state._indecode
    try:
        tree = ast.parse(source)
    except SyntaxError:
        # Let this pop up again in the real import.
        state.trace("failed to parse: %r" % (fn,))
        return None, None
    # Mutates *tree* in place.
    rewrite_asserts(tree, fn, config)
    try:
        co = compile(tree, fn.strpath, "exec", dont_inherit=True)
    except SyntaxError:
        # It's possible that this error is from some bug in the
        # assertion rewriting, but I don't know of a fast way to tell.
        state.trace("failed to compile: %r" % (fn,))
        return None, None
    return stat, co
  319. def _read_pyc(source, pyc, trace=lambda x: None):
  320. """Possibly read a pytest pyc containing rewritten code.
  321. Return rewritten code if successful or None if not.
  322. """
  323. try:
  324. fp = open(pyc, "rb")
  325. except IOError:
  326. return None
  327. with fp:
  328. try:
  329. mtime = int(source.mtime())
  330. size = source.size()
  331. data = fp.read(12)
  332. except EnvironmentError as e:
  333. trace("_read_pyc(%s): EnvironmentError %s" % (source, e))
  334. return None
  335. # Check for invalid or out of date pyc file.
  336. if (
  337. len(data) != 12
  338. or data[:4] != imp.get_magic()
  339. or struct.unpack("<ll", data[4:]) != (mtime, size)
  340. ):
  341. trace("_read_pyc(%s): invalid or out of date pyc" % source)
  342. return None
  343. try:
  344. co = marshal.load(fp)
  345. except Exception as e:
  346. trace("_read_pyc(%s): marshal.load error %s" % (source, e))
  347. return None
  348. if not isinstance(co, types.CodeType):
  349. trace("_read_pyc(%s): not a code object" % source)
  350. return None
  351. return co
  352. def rewrite_asserts(mod, module_path=None, config=None):
  353. """Rewrite the assert statements in mod."""
  354. AssertionRewriter(module_path, config).run(mod)
  355. def _saferepr(obj):
  356. """Get a safe repr of an object for assertion error messages.
  357. The assertion formatting (util.format_explanation()) requires
  358. newlines to be escaped since they are a special character for it.
  359. Normally assertion.util.format_explanation() does this but for a
  360. custom repr it is possible to contain one of the special escape
  361. sequences, especially '\n{' and '\n}' are likely to be present in
  362. JSON reprs.
  363. """
  364. r = py.io.saferepr(obj)
  365. if isinstance(r, six.text_type):
  366. return r.replace(u"\n", u"\\n")
  367. else:
  368. return r.replace(b"\n", b"\\n")
  369. from _pytest.assertion.util import format_explanation as _format_explanation # noqa
  370. def _format_assertmsg(obj):
  371. """Format the custom assertion message given.
  372. For strings this simply replaces newlines with '\n~' so that
  373. util.format_explanation() will preserve them instead of escaping
  374. newlines. For other objects py.io.saferepr() is used first.
  375. """
  376. # reprlib appears to have a bug which means that if a string
  377. # contains a newline it gets escaped, however if an object has a
  378. # .__repr__() which contains newlines it does not get escaped.
  379. # However in either case we want to preserve the newline.
  380. replaces = [(u"\n", u"\n~"), (u"%", u"%%")]
  381. if not isinstance(obj, six.string_types):
  382. obj = py.io.saferepr(obj)
  383. replaces.append((u"\\n", u"\n~"))
  384. if isinstance(obj, bytes):
  385. replaces = [(r1.encode(), r2.encode()) for r1, r2 in replaces]
  386. for r1, r2 in replaces:
  387. obj = obj.replace(r1, r2)
  388. return obj
  389. def _should_repr_global_name(obj):
  390. return not hasattr(obj, "__name__") and not callable(obj)
  391. def _format_boolop(explanations, is_or):
  392. explanation = "(" + (is_or and " or " or " and ").join(explanations) + ")"
  393. if isinstance(explanation, six.text_type):
  394. return explanation.replace(u"%", u"%%")
  395. else:
  396. return explanation.replace(b"%", b"%%")
  397. def _call_reprcompare(ops, results, expls, each_obj):
  398. for i, res, expl in zip(range(len(ops)), results, expls):
  399. try:
  400. done = not res
  401. except Exception:
  402. done = True
  403. if done:
  404. break
  405. if util._reprcompare is not None:
  406. custom = util._reprcompare(ops[i], each_obj[i], each_obj[i + 1])
  407. if custom is not None:
  408. return custom
  409. return expl
  410. unary_map = {ast.Not: "not %s", ast.Invert: "~%s", ast.USub: "-%s", ast.UAdd: "+%s"}
  411. binop_map = {
  412. ast.BitOr: "|",
  413. ast.BitXor: "^",
  414. ast.BitAnd: "&",
  415. ast.LShift: "<<",
  416. ast.RShift: ">>",
  417. ast.Add: "+",
  418. ast.Sub: "-",
  419. ast.Mult: "*",
  420. ast.Div: "/",
  421. ast.FloorDiv: "//",
  422. ast.Mod: "%%", # escaped for string formatting
  423. ast.Eq: "==",
  424. ast.NotEq: "!=",
  425. ast.Lt: "<",
  426. ast.LtE: "<=",
  427. ast.Gt: ">",
  428. ast.GtE: ">=",
  429. ast.Pow: "**",
  430. ast.Is: "is",
  431. ast.IsNot: "is not",
  432. ast.In: "in",
  433. ast.NotIn: "not in",
  434. }
  435. # Python 3.5+ compatibility
  436. try:
  437. binop_map[ast.MatMult] = "@"
  438. except AttributeError:
  439. pass
  440. # Python 3.4+ compatibility
  441. if hasattr(ast, "NameConstant"):
  442. _NameConstant = ast.NameConstant
  443. else:
  444. def _NameConstant(c):
  445. return ast.Name(str(c), ast.Load())
  446. def set_location(node, lineno, col_offset):
  447. """Set node location information recursively."""
  448. def _fix(node, lineno, col_offset):
  449. if "lineno" in node._attributes:
  450. node.lineno = lineno
  451. if "col_offset" in node._attributes:
  452. node.col_offset = col_offset
  453. for child in ast.iter_child_nodes(node):
  454. _fix(child, lineno, col_offset)
  455. _fix(node, lineno, col_offset)
  456. return node
  457. class AssertionRewriter(ast.NodeVisitor):
  458. """Assertion rewriting implementation.
  459. The main entrypoint is to call .run() with an ast.Module instance,
  460. this will then find all the assert statements and rewrite them to
  461. provide intermediate values and a detailed assertion error. See
  462. http://pybites.blogspot.be/2011/07/behind-scenes-of-pytests-new-assertion.html
  463. for an overview of how this works.
  464. The entry point here is .run() which will iterate over all the
  465. statements in an ast.Module and for each ast.Assert statement it
  466. finds call .visit() with it. Then .visit_Assert() takes over and
  467. is responsible for creating new ast statements to replace the
  468. original assert statement: it rewrites the test of an assertion
  469. to provide intermediate values and replace it with an if statement
  470. which raises an assertion error with a detailed explanation in
  471. case the expression is false.
  472. For this .visit_Assert() uses the visitor pattern to visit all the
  473. AST nodes of the ast.Assert.test field, each visit call returning
  474. an AST node and the corresponding explanation string. During this
  475. state is kept in several instance attributes:
  476. :statements: All the AST statements which will replace the assert
  477. statement.
  478. :variables: This is populated by .variable() with each variable
  479. used by the statements so that they can all be set to None at
  480. the end of the statements.
  481. :variable_counter: Counter to create new unique variables needed
  482. by statements. Variables are created using .variable() and
  483. have the form of "@py_assert0".
  484. :on_failure: The AST statements which will be executed if the
  485. assertion test fails. This is the code which will construct
  486. the failure message and raises the AssertionError.
  487. :explanation_specifiers: A dict filled by .explanation_param()
  488. with %-formatting placeholders and their corresponding
  489. expressions to use in the building of an assertion message.
  490. This is used by .pop_format_context() to build a message.
  491. :stack: A stack of the explanation_specifiers dicts maintained by
  492. .push_format_context() and .pop_format_context() which allows
  493. to build another %-formatted string while already building one.
  494. This state is reset on every new assert statement visited and used
  495. by the other visitors.
  496. """
  497. def __init__(self, module_path, config):
  498. super(AssertionRewriter, self).__init__()
  499. self.module_path = module_path
  500. self.config = config
    def run(self, mod):
        """Find all assert statements in *mod* and rewrite them.

        *mod* is mutated in place:
        1. Two imports are inserted after the module docstring and any
           ``from __future__`` imports: the builtins module as
           ``@py_builtins`` and this module as ``@pytest_ar``.
        2. Every ast.Assert in a statement list is replaced by the
           statements returned from visit_Assert().

        Nothing happens for an empty module, or when the module docstring
        contains PYTEST_DONT_REWRITE.
        """
        if not mod.body:
            # Nothing to do.
            return
        # Insert some special imports at the top of the module but after any
        # docstrings and __future__ imports.
        aliases = [
            ast.alias(py.builtin.builtins.__name__, "@py_builtins"),
            ast.alias("_pytest.assertion.rewrite", "@pytest_ar"),
        ]
        # Some Python versions expose the docstring as a Module attribute
        # rather than as the first body statement; handle both.
        doc = getattr(mod, "docstring", None)
        expect_docstring = doc is None
        if doc is not None and self.is_rewrite_disabled(doc):
            return
        # pos: index in mod.body where the imports go.
        # lineno: line number given to the inserted imports.
        pos = 0
        lineno = 1
        for item in mod.body:
            if (
                expect_docstring
                and isinstance(item, ast.Expr)
                and isinstance(item.value, ast.Str)
            ):
                doc = item.value.s
                if self.is_rewrite_disabled(doc):
                    return
                expect_docstring = False
            elif (
                not isinstance(item, ast.ImportFrom)
                or item.level > 0
                or item.module != "__future__"
            ):
                # First statement that is neither the docstring nor a
                # __future__ import: insert before it.
                lineno = item.lineno
                break
            pos += 1
        else:
            # Body consists only of docstring/__future__ imports.
            lineno = item.lineno
        imports = [
            ast.Import([alias], lineno=lineno, col_offset=0) for alias in aliases
        ]
        mod.body[pos:pos] = imports
        # Collect asserts.
        # Depth-first walk over statement-bearing nodes using an explicit
        # stack; each list-valued field is rebuilt with asserts expanded.
        nodes = [mod]
        while nodes:
            node = nodes.pop()
            for name, field in ast.iter_fields(node):
                if isinstance(field, list):
                    new = []
                    for i, child in enumerate(field):
                        if isinstance(child, ast.Assert):
                            # Transform assert.
                            new.extend(self.visit(child))
                        else:
                            new.append(child)
                            if isinstance(child, ast.AST):
                                nodes.append(child)
                    setattr(node, name, new)
                elif (
                    isinstance(field, ast.AST)
                    and
                    # Don't recurse into expressions as they can't contain
                    # asserts.
                    not isinstance(field, ast.expr)
                ):
                    nodes.append(field)
  566. @staticmethod
  567. def is_rewrite_disabled(docstring):
  568. return "PYTEST_DONT_REWRITE" in docstring
  569. def variable(self):
  570. """Get a new variable."""
  571. # Use a character invalid in python identifiers to avoid clashing.
  572. name = "@py_assert" + str(next(self.variable_counter))
  573. self.variables.append(name)
  574. return name
  575. def assign(self, expr):
  576. """Give *expr* a name."""
  577. name = self.variable()
  578. self.statements.append(ast.Assign([ast.Name(name, ast.Store())], expr))
  579. return ast.Name(name, ast.Load())
  580. def display(self, expr):
  581. """Call py.io.saferepr on the expression."""
  582. return self.helper("saferepr", expr)
  583. def helper(self, name, *args):
  584. """Call a helper in this module."""
  585. py_name = ast.Name("@pytest_ar", ast.Load())
  586. attr = ast.Attribute(py_name, "_" + name, ast.Load())
  587. return ast_Call(attr, list(args), [])
  588. def builtin(self, name):
  589. """Return the builtin called *name*."""
  590. builtin_name = ast.Name("@py_builtins", ast.Load())
  591. return ast.Attribute(builtin_name, name, ast.Load())
  592. def explanation_param(self, expr):
  593. """Return a new named %-formatting placeholder for expr.
  594. This creates a %-formatting placeholder for expr in the
  595. current formatting context, e.g. ``%(py0)s``. The placeholder
  596. and expr are placed in the current format context so that it
  597. can be used on the next call to .pop_format_context().
  598. """
  599. specifier = "py" + str(next(self.variable_counter))
  600. self.explanation_specifiers[specifier] = expr
  601. return "%(" + specifier + ")s"
  602. def push_format_context(self):
  603. """Create a new formatting context.
  604. The format context is used for when an explanation wants to
  605. have a variable value formatted in the assertion message. In
  606. this case the value required can be added using
  607. .explanation_param(). Finally .pop_format_context() is used
  608. to format a string of %-formatted values as added by
  609. .explanation_param().
  610. """
  611. self.explanation_specifiers = {}
  612. self.stack.append(self.explanation_specifiers)
  613. def pop_format_context(self, expl_expr):
  614. """Format the %-formatted string with current format context.
  615. The expl_expr should be an ast.Str instance constructed from
  616. the %-placeholders created by .explanation_param(). This will
  617. add the required code to format said string to .on_failure and
  618. return the ast.Name instance of the formatted string.
  619. """
  620. current = self.stack.pop()
  621. if self.stack:
  622. self.explanation_specifiers = self.stack[-1]
  623. keys = [ast.Str(key) for key in current.keys()]
  624. format_dict = ast.Dict(keys, list(current.values()))
  625. form = ast.BinOp(expl_expr, ast.Mod(), format_dict)
  626. name = "@py_format" + str(next(self.variable_counter))
  627. self.on_failure.append(ast.Assign([ast.Name(name, ast.Store())], form))
  628. return ast.Name(name, ast.Load())
  629. def generic_visit(self, node):
  630. """Handle expressions we don't have custom code for."""
  631. assert isinstance(node, ast.expr)
  632. res = self.assign(node)
  633. return res, self.explanation_param(self.display(res))
  634. def visit_Assert(self, assert_):
  635. """Return the AST statements to replace the ast.Assert instance.
  636. This rewrites the test of an assertion to provide
  637. intermediate values and replace it with an if statement which
  638. raises an assertion error with a detailed explanation in case
  639. the expression is false.
  640. """
  641. if isinstance(assert_.test, ast.Tuple) and self.config is not None:
  642. fslocation = (self.module_path, assert_.lineno)
  643. self.config.warn(
  644. "R1",
  645. "assertion is always true, perhaps " "remove parentheses?",
  646. fslocation=fslocation,
  647. )
  648. self.statements = []
  649. self.variables = []
  650. self.variable_counter = itertools.count()
  651. self.stack = []
  652. self.on_failure = []
  653. self.push_format_context()
  654. # Rewrite assert into a bunch of statements.
  655. top_condition, explanation = self.visit(assert_.test)
  656. # Create failure message.
  657. body = self.on_failure
  658. negation = ast.UnaryOp(ast.Not(), top_condition)
  659. self.statements.append(ast.If(negation, body, []))
  660. if assert_.msg:
  661. assertmsg = self.helper("format_assertmsg", assert_.msg)
  662. explanation = "\n>assert " + explanation
  663. else:
  664. assertmsg = ast.Str("")
  665. explanation = "assert " + explanation
  666. template = ast.BinOp(assertmsg, ast.Add(), ast.Str(explanation))
  667. msg = self.pop_format_context(template)
  668. fmt = self.helper("format_explanation", msg)
  669. err_name = ast.Name("AssertionError", ast.Load())
  670. exc = ast_Call(err_name, [fmt], [])
  671. if sys.version_info[0] >= 3:
  672. raise_ = ast.Raise(exc, None)
  673. else:
  674. raise_ = ast.Raise(exc, None, None)
  675. body.append(raise_)
  676. # Clear temporary variables by setting them to None.
  677. if self.variables:
  678. variables = [ast.Name(name, ast.Store()) for name in self.variables]
  679. clear = ast.Assign(variables, _NameConstant(None))
  680. self.statements.append(clear)
  681. # Fix line numbers.
  682. for stmt in self.statements:
  683. set_location(stmt, assert_.lineno, assert_.col_offset)
  684. return self.statements
  685. def visit_Name(self, name):
  686. # Display the repr of the name if it's a local variable or
  687. # _should_repr_global_name() thinks it's acceptable.
  688. locs = ast_Call(self.builtin("locals"), [], [])
  689. inlocs = ast.Compare(ast.Str(name.id), [ast.In()], [locs])
  690. dorepr = self.helper("should_repr_global_name", name)
  691. test = ast.BoolOp(ast.Or(), [inlocs, dorepr])
  692. expr = ast.IfExp(test, self.display(name), ast.Str(name.id))
  693. return name, self.explanation_param(expr)
    def visit_BoolOp(self, boolop):
        """Rewrite an ``and``/``or`` chain preserving short-circuit evaluation.

        Each operand is evaluated inside an ast.If nested under the previous
        operand's condition, so later operands only run when Python would
        evaluate them. Each operand's result is stored in one shared result
        variable. Its explanation is appended to a runtime list, which the
        _format_boolop helper joins on failure. Returns the result variable
        and a placeholder for the combined explanation.
        """
        res_var = self.variable()
        # Runtime list that collects the explanation of every operand that
        # was actually evaluated.
        expl_list = self.assign(ast.List([], ast.Load()))
        app = ast.Attribute(expl_list, "append", ast.Load())
        is_or = int(isinstance(boolop.op, ast.Or))
        # Save the outer statement/failure lists; they are swapped for
        # nested If bodies below and restored afterwards.
        body = save = self.statements
        fail_save = self.on_failure
        levels = len(boolop.values) - 1
        self.push_format_context()
        # Process each operand, short-circuiting if needed.
        for i, v in enumerate(boolop.values):
            if i:
                fail_inner = []
                # cond is set in a prior loop iteration below
                self.on_failure.append(ast.If(cond, fail_inner, []))  # noqa
                self.on_failure = fail_inner
            self.push_format_context()
            res, expl = self.visit(v)
            body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
            expl_format = self.pop_format_context(ast.Str(expl))
            call = ast_Call(app, [expl_format], [])
            self.on_failure.append(ast.Expr(call))
            if i < levels:
                # Continue to the next operand only if this one did not
                # short-circuit: truthy for ``and``, falsy for ``or``.
                cond = res
                if is_or:
                    cond = ast.UnaryOp(ast.Not(), cond)
                inner = []
                self.statements.append(ast.If(cond, inner, []))
                self.statements = body = inner
        self.statements = save
        self.on_failure = fail_save
        expl_template = self.helper("format_boolop", expl_list, ast.Num(is_or))
        expl = self.pop_format_context(expl_template)
        return ast.Name(res_var, ast.Load()), self.explanation_param(expl)
  def visit_UnaryOp(self, unary):
      """Rewrite a unary operation (``not x``, ``-x``, ...) onto a temporary.

      Returns the temporary's node and an explanation built from the
      operator's ``unary_map`` pattern applied to the operand's explanation.
      """
      template = unary_map[unary.op.__class__]
      operand, operand_expl = self.visit(unary.operand)
      new_node = self.assign(ast.UnaryOp(unary.op, operand))
      return new_node, template % (operand_expl,)
  def visit_BinOp(self, binop):
      """Rewrite a binary operation onto a temporary.

      The explanation is ``"(<left> <symbol> <right>)"``, where the symbol
      comes from ``binop_map``.
      """
      symbol = binop_map[binop.op.__class__]
      left, left_expl = self.visit(binop.left)
      right, right_expl = self.visit(binop.right)
      new_node = self.assign(ast.BinOp(left, binop.op, right))
      return new_node, "(%s %s %s)" % (left_expl, symbol, right_expl)
  def visit_Call_35(self, call):
      """Rewrite an ``ast.Call`` node using the Python 3.5+ AST layout.

      Positional arguments (including ``Starred`` entries) live in
      ``call.args``. Keyword arguments live in ``call.keywords``; a
      ``**mapping`` entry is a keyword whose ``arg`` is None. The call result
      is stored in a temporary, and the explanation shows both the temporary's
      repr and the call expression.
      """
      new_func, func_expl = self.visit(call.func)
      new_args = []
      arg_expls = []
      for positional in call.args:
          value, value_expl = self.visit(positional)
          arg_expls.append(value_expl)
          new_args.append(value)
      new_kwargs = []
      for kw in call.keywords:
          value, value_expl = self.visit(kw.value)
          new_kwargs.append(ast.keyword(kw.arg, value))
          prefix = kw.arg + "=" if kw.arg else "**"
          arg_expls.append(prefix + value_expl)
      expl = "%s(%s)" % (func_expl, ", ".join(arg_expls))
      res = self.assign(ast.Call(new_func, new_args, new_kwargs))
      res_expl = self.explanation_param(self.display(res))
      return res, "%s\n{%s = %s\n}" % (res_expl, res_expl, expl)
  def visit_Starred(self, starred):
      """Rewrite a ``*iterable`` argument appearing in a function call.

      From Python 3.5, a Starred node can appear in ``ast.Call.args``. The
      starred value is rewritten like any other expression, and a NEW
      Starred node wrapping the rewritten value is returned.

      Returning the original ``starred`` node instead would make the
      rewritten call evaluate ``starred.value`` a second time, because
      ``self.visit`` has already evaluated it into a temporary. Any side
      effects (e.g. ``f(*g())``) would then run twice.

      Returns a ``(node, explanation)`` pair whose explanation is ``"*"``
      followed by the value's explanation.
      """
      res, expl = self.visit(starred.value)
      new_starred = ast.Starred(res, starred.ctx)
      return new_starred, "*" + expl
  def visit_Call_legacy(self, call):
      """Rewrite an ``ast.Call`` node using the Python 3.4-and-below AST layout.

      In this layout, ``*args`` and ``**kwargs`` are stored in the separate
      ``call.starargs`` and ``call.kwargs`` fields. The call result is stored
      in a temporary, and the explanation shows both the temporary's repr and
      the call expression.
      """
      new_func, func_expl = self.visit(call.func)
      new_args = []
      arg_expls = []
      for positional in call.args:
          value, value_expl = self.visit(positional)
          new_args.append(value)
          arg_expls.append(value_expl)
      new_kwargs = []
      for kw in call.keywords:
          value, value_expl = self.visit(kw.value)
          new_kwargs.append(ast.keyword(kw.arg, value))
          arg_expls.append(kw.arg + "=" + value_expl)
      new_star = None
      if call.starargs:
          new_star, star_expl = self.visit(call.starargs)
          arg_expls.append("*" + star_expl)
      new_kwarg = None
      if call.kwargs:
          new_kwarg, kwarg_expl = self.visit(call.kwargs)
          arg_expls.append("**" + kwarg_expl)
      expl = "%s(%s)" % (func_expl, ", ".join(arg_expls))
      new_call = ast.Call(new_func, new_args, new_kwargs, new_star, new_kwarg)
      res = self.assign(new_call)
      res_expl = self.explanation_param(self.display(res))
      return res, "%s\n{%s = %s\n}" % (res_expl, res_expl, expl)
  # The fields of ast.Call changed in Python 3.5: star-args moved into
  # ``args`` (as Starred nodes) and ``**kwargs`` into ``keywords``. Bind
  # whichever implementation matches the running interpreter to the name
  # ``visit_Call``.
  visit_Call = visit_Call_35 if sys.version_info >= (3, 5) else visit_Call_legacy
  def visit_Attribute(self, attr):
      """Rewrite an attribute load ``value.attr`` onto a temporary.

      Attributes in non-load contexts (store/delete) are handed to
      ``generic_visit`` unchanged.
      """
      if isinstance(attr.ctx, ast.Load):
          value, value_expl = self.visit(attr.value)
          res = self.assign(ast.Attribute(value, attr.attr, ast.Load()))
          res_expl = self.explanation_param(self.display(res))
          pat = "%s\n{%s = %s.%s\n}"
          return res, pat % (res_expl, res_expl, value_expl, attr.attr)
      return self.generic_visit(attr)
  def visit_Compare(self, comp):
      """Rewrite a (possibly chained) comparison such as ``a < b <= c``.

      Each pairwise link is assigned to its own result variable. The returned
      node is that single variable, or an ``and`` of all of them for a chain.
      The explanation is produced at runtime by the ``call_reprcompare``
      helper. Nested comparisons and boolean operations are parenthesised in
      the explanation text.
      """
      self.push_format_context()
      left_res, left_expl = self.visit(comp.left)
      if isinstance(comp.left, (ast.Compare, ast.BoolOp)):
          left_expl = "({})".format(left_expl)
      result_names = [self.variable() for _ in comp.ops]
      loads = [ast.Name(n, ast.Load()) for n in result_names]
      stores = [ast.Name(n, ast.Store()) for n in result_names]
      symbols = []
      explanations = []
      operands = [left_res]
      for store, op, right in zip(stores, comp.ops, comp.comparators):
          right_res, right_expl = self.visit(right)
          if isinstance(right, (ast.Compare, ast.BoolOp)):
              right_expl = "({})".format(right_expl)
          operands.append(right_res)
          sym = binop_map[op.__class__]
          symbols.append(ast.Str(sym))
          explanations.append(ast.Str("%s %s %s" % (left_expl, sym, right_expl)))
          link = ast.Compare(left_res, [op], [right_res])
          self.statements.append(ast.Assign([store], link))
          # The right operand becomes the left side of the next link.
          left_res, left_expl = right_res, right_expl
      # Use pytest.assertion.util._reprcompare if that's available.
      expl_call = self.helper(
          "call_reprcompare",
          ast.Tuple(symbols, ast.Load()),
          ast.Tuple(loads, ast.Load()),
          ast.Tuple(explanations, ast.Load()),
          ast.Tuple(operands, ast.Load()),
      )
      res = ast.BoolOp(ast.And(), loads) if len(comp.ops) > 1 else loads[0]
      return res, self.explanation_param(self.pop_format_context(expl_call))