stack_context_test.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297
  1. from __future__ import absolute_import, division, print_function
  2. from tornado import gen
  3. from tornado.ioloop import IOLoop
  4. from tornado.log import app_log
  5. from tornado.stack_context import (StackContext, wrap, NullContext, StackContextInconsistentError,
  6. ExceptionStackContext, run_with_stack_context, _state)
  7. from tornado.testing import AsyncHTTPTestCase, AsyncTestCase, ExpectLog, gen_test
  8. from tornado.test.util import unittest, ignore_deprecation
  9. from tornado.web import asynchronous, Application, RequestHandler
  10. import contextlib
  11. import functools
  12. import logging
  13. import warnings
  14. class TestRequestHandler(RequestHandler):
  15. def __init__(self, app, request):
  16. super(TestRequestHandler, self).__init__(app, request)
  17. with ignore_deprecation():
  18. @asynchronous
  19. def get(self):
  20. logging.debug('in get()')
  21. # call self.part2 without a self.async_callback wrapper. Its
  22. # exception should still get thrown
  23. IOLoop.current().add_callback(self.part2)
  24. def part2(self):
  25. logging.debug('in part2()')
  26. # Go through a third layer to make sure that contexts once restored
  27. # are again passed on to future callbacks
  28. IOLoop.current().add_callback(self.part3)
  29. def part3(self):
  30. logging.debug('in part3()')
  31. raise Exception('test exception')
  32. def write_error(self, status_code, **kwargs):
  33. if 'exc_info' in kwargs and str(kwargs['exc_info'][1]) == 'test exception':
  34. self.write('got expected exception')
  35. else:
  36. self.write('unexpected failure')
  37. class HTTPStackContextTest(AsyncHTTPTestCase):
  38. def get_app(self):
  39. return Application([('/', TestRequestHandler)])
  40. def test_stack_context(self):
  41. with ExpectLog(app_log, "Uncaught exception GET /"):
  42. with ignore_deprecation():
  43. self.http_client.fetch(self.get_url('/'), self.handle_response)
  44. self.wait()
  45. self.assertEqual(self.response.code, 500)
  46. self.assertTrue(b'got expected exception' in self.response.body)
  47. def handle_response(self, response):
  48. self.response = response
  49. self.stop()
  50. class StackContextTest(AsyncTestCase):
  51. def setUp(self):
  52. super(StackContextTest, self).setUp()
  53. self.active_contexts = []
  54. self.warning_catcher = warnings.catch_warnings()
  55. self.warning_catcher.__enter__()
  56. warnings.simplefilter('ignore', DeprecationWarning)
  57. def tearDown(self):
  58. self.warning_catcher.__exit__(None, None, None)
  59. super(StackContextTest, self).tearDown()
  60. @contextlib.contextmanager
  61. def context(self, name):
  62. self.active_contexts.append(name)
  63. yield
  64. self.assertEqual(self.active_contexts.pop(), name)
  65. # Simulates the effect of an asynchronous library that uses its own
  66. # StackContext internally and then returns control to the application.
  67. def test_exit_library_context(self):
  68. def library_function(callback):
  69. # capture the caller's context before introducing our own
  70. callback = wrap(callback)
  71. with StackContext(functools.partial(self.context, 'library')):
  72. self.io_loop.add_callback(
  73. functools.partial(library_inner_callback, callback))
  74. def library_inner_callback(callback):
  75. self.assertEqual(self.active_contexts[-2:],
  76. ['application', 'library'])
  77. callback()
  78. def final_callback():
  79. # implementation detail: the full context stack at this point
  80. # is ['application', 'library', 'application']. The 'library'
  81. # context was not removed, but is no longer innermost so
  82. # the application context takes precedence.
  83. self.assertEqual(self.active_contexts[-1], 'application')
  84. self.stop()
  85. with StackContext(functools.partial(self.context, 'application')):
  86. library_function(final_callback)
  87. self.wait()
  88. def test_deactivate(self):
  89. deactivate_callbacks = []
  90. def f1():
  91. with StackContext(functools.partial(self.context, 'c1')) as c1:
  92. deactivate_callbacks.append(c1)
  93. self.io_loop.add_callback(f2)
  94. def f2():
  95. with StackContext(functools.partial(self.context, 'c2')) as c2:
  96. deactivate_callbacks.append(c2)
  97. self.io_loop.add_callback(f3)
  98. def f3():
  99. with StackContext(functools.partial(self.context, 'c3')) as c3:
  100. deactivate_callbacks.append(c3)
  101. self.io_loop.add_callback(f4)
  102. def f4():
  103. self.assertEqual(self.active_contexts, ['c1', 'c2', 'c3'])
  104. deactivate_callbacks[1]()
  105. # deactivating a context doesn't remove it immediately,
  106. # but it will be missing from the next iteration
  107. self.assertEqual(self.active_contexts, ['c1', 'c2', 'c3'])
  108. self.io_loop.add_callback(f5)
  109. def f5():
  110. self.assertEqual(self.active_contexts, ['c1', 'c3'])
  111. self.stop()
  112. self.io_loop.add_callback(f1)
  113. self.wait()
  114. def test_deactivate_order(self):
  115. # Stack context deactivation has separate logic for deactivation at
  116. # the head and tail of the stack, so make sure it works in any order.
  117. def check_contexts():
  118. # Make sure that the full-context array and the exception-context
  119. # linked lists are consistent with each other.
  120. full_contexts, chain = _state.contexts
  121. exception_contexts = []
  122. while chain is not None:
  123. exception_contexts.append(chain)
  124. chain = chain.old_contexts[1]
  125. self.assertEqual(list(reversed(full_contexts)), exception_contexts)
  126. return list(self.active_contexts)
  127. def make_wrapped_function():
  128. """Wraps a function in three stack contexts, and returns
  129. the function along with the deactivation functions.
  130. """
  131. # Remove the test's stack context to make sure we can cover
  132. # the case where the last context is deactivated.
  133. with NullContext():
  134. partial = functools.partial
  135. with StackContext(partial(self.context, 'c0')) as c0:
  136. with StackContext(partial(self.context, 'c1')) as c1:
  137. with StackContext(partial(self.context, 'c2')) as c2:
  138. return (wrap(check_contexts), [c0, c1, c2])
  139. # First make sure the test mechanism works without any deactivations
  140. func, deactivate_callbacks = make_wrapped_function()
  141. self.assertEqual(func(), ['c0', 'c1', 'c2'])
  142. # Deactivate the tail
  143. func, deactivate_callbacks = make_wrapped_function()
  144. deactivate_callbacks[0]()
  145. self.assertEqual(func(), ['c1', 'c2'])
  146. # Deactivate the middle
  147. func, deactivate_callbacks = make_wrapped_function()
  148. deactivate_callbacks[1]()
  149. self.assertEqual(func(), ['c0', 'c2'])
  150. # Deactivate the head
  151. func, deactivate_callbacks = make_wrapped_function()
  152. deactivate_callbacks[2]()
  153. self.assertEqual(func(), ['c0', 'c1'])
  154. def test_isolation_nonempty(self):
  155. # f2 and f3 are a chain of operations started in context c1.
  156. # f2 is incidentally run under context c2, but that context should
  157. # not be passed along to f3.
  158. def f1():
  159. with StackContext(functools.partial(self.context, 'c1')):
  160. wrapped = wrap(f2)
  161. with StackContext(functools.partial(self.context, 'c2')):
  162. wrapped()
  163. def f2():
  164. self.assertIn('c1', self.active_contexts)
  165. self.io_loop.add_callback(f3)
  166. def f3():
  167. self.assertIn('c1', self.active_contexts)
  168. self.assertNotIn('c2', self.active_contexts)
  169. self.stop()
  170. self.io_loop.add_callback(f1)
  171. self.wait()
  172. def test_isolation_empty(self):
  173. # Similar to test_isolation_nonempty, but here the f2/f3 chain
  174. # is started without any context. Behavior should be equivalent
  175. # to the nonempty case (although historically it was not)
  176. def f1():
  177. with NullContext():
  178. wrapped = wrap(f2)
  179. with StackContext(functools.partial(self.context, 'c2')):
  180. wrapped()
  181. def f2():
  182. self.io_loop.add_callback(f3)
  183. def f3():
  184. self.assertNotIn('c2', self.active_contexts)
  185. self.stop()
  186. self.io_loop.add_callback(f1)
  187. self.wait()
  188. def test_yield_in_with(self):
  189. @gen.engine
  190. def f():
  191. self.callback = yield gen.Callback('a')
  192. with StackContext(functools.partial(self.context, 'c1')):
  193. # This yield is a problem: the generator will be suspended
  194. # and the StackContext's __exit__ is not called yet, so
  195. # the context will be left on _state.contexts for anything
  196. # that runs before the yield resolves.
  197. yield gen.Wait('a')
  198. with self.assertRaises(StackContextInconsistentError):
  199. f()
  200. self.wait()
  201. # Cleanup: to avoid GC warnings (which for some reason only seem
  202. # to show up on py33-asyncio), invoke the callback (which will do
  203. # nothing since the gen.Runner is already finished) and delete it.
  204. self.callback()
  205. del self.callback
  206. @gen_test
  207. def test_yield_outside_with(self):
  208. # This pattern avoids the problem in the previous test.
  209. cb = yield gen.Callback('k1')
  210. with StackContext(functools.partial(self.context, 'c1')):
  211. self.io_loop.add_callback(cb)
  212. yield gen.Wait('k1')
  213. def test_yield_in_with_exception_stack_context(self):
  214. # As above, but with ExceptionStackContext instead of StackContext.
  215. @gen.engine
  216. def f():
  217. with ExceptionStackContext(lambda t, v, tb: False):
  218. yield gen.Task(self.io_loop.add_callback)
  219. with self.assertRaises(StackContextInconsistentError):
  220. f()
  221. self.wait()
  222. @gen_test
  223. def test_yield_outside_with_exception_stack_context(self):
  224. cb = yield gen.Callback('k1')
  225. with ExceptionStackContext(lambda t, v, tb: False):
  226. self.io_loop.add_callback(cb)
  227. yield gen.Wait('k1')
  228. @gen_test
  229. def test_run_with_stack_context(self):
  230. @gen.coroutine
  231. def f1():
  232. self.assertEqual(self.active_contexts, ['c1'])
  233. yield run_with_stack_context(
  234. StackContext(functools.partial(self.context, 'c2')),
  235. f2)
  236. self.assertEqual(self.active_contexts, ['c1'])
  237. @gen.coroutine
  238. def f2():
  239. self.assertEqual(self.active_contexts, ['c1', 'c2'])
  240. yield gen.Task(self.io_loop.add_callback)
  241. self.assertEqual(self.active_contexts, ['c1', 'c2'])
  242. self.assertEqual(self.active_contexts, [])
  243. yield run_with_stack_context(
  244. StackContext(functools.partial(self.context, 'c1')),
  245. f1)
  246. self.assertEqual(self.active_contexts, [])
  247. if __name__ == '__main__':
  248. unittest.main()