concurrent_test.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496
  1. #
  2. # Copyright 2012 Facebook
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License"); you may
  5. # not use this file except in compliance with the License. You may obtain
  6. # a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
  12. # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
  13. # License for the specific language governing permissions and limitations
  14. # under the License.
  15. from __future__ import absolute_import, division, print_function
  16. import gc
  17. import logging
  18. import re
  19. import socket
  20. import sys
  21. import traceback
  22. import warnings
  23. from tornado.concurrent import (Future, return_future, ReturnValueIgnoredError,
  24. run_on_executor, future_set_result_unless_cancelled)
  25. from tornado.escape import utf8, to_unicode
  26. from tornado import gen
  27. from tornado.ioloop import IOLoop
  28. from tornado.iostream import IOStream
  29. from tornado.log import app_log
  30. from tornado import stack_context
  31. from tornado.tcpserver import TCPServer
  32. from tornado.testing import AsyncTestCase, ExpectLog, bind_unused_port, gen_test
  33. from tornado.test.util import unittest, skipBefore35, exec_test, ignore_deprecation
  34. try:
  35. from concurrent import futures
  36. except ImportError:
  37. futures = None
  38. class MiscFutureTest(AsyncTestCase):
  39. def test_future_set_result_unless_cancelled(self):
  40. fut = Future()
  41. future_set_result_unless_cancelled(fut, 42)
  42. self.assertEqual(fut.result(), 42)
  43. self.assertFalse(fut.cancelled())
  44. fut = Future()
  45. fut.cancel()
  46. is_cancelled = fut.cancelled()
  47. future_set_result_unless_cancelled(fut, 42)
  48. self.assertEqual(fut.cancelled(), is_cancelled)
  49. if not is_cancelled:
  50. self.assertEqual(fut.result(), 42)
class ReturnFutureTest(AsyncTestCase):
    """Tests for the deprecated ``@return_future`` decorator.

    The helper methods are decorated inside ``ignore_deprecation()`` so that
    deprecation warnings raised while applying ``@return_future`` are
    silenced. The tests then call them with and without a ``callback``.
    """
    with ignore_deprecation():
        @return_future
        def sync_future(self, callback):
            # Resolves the future before the call returns.
            callback(42)

        @return_future
        def async_future(self, callback):
            # Resolves the future on a later IOLoop iteration.
            self.io_loop.add_callback(callback, 42)

        @return_future
        def immediate_failure(self, callback):
            # Fails synchronously, before the callback is ever invoked.
            1 / 0

        @return_future
        def delayed_failure(self, callback):
            # Fails later from an IOLoop callback; `callback` is never invoked.
            self.io_loop.add_callback(lambda: 1 / 0)

        @return_future
        def return_value(self, callback):
            # Note that the result of both running the callback and returning
            # a value (or raising an exception) is unspecified; with current
            # implementations the last event prior to callback resolution wins.
            return 42

        @return_future
        def no_result_future(self, callback):
            # Invokes the callback without a result value.
            callback()

    def test_immediate_failure(self):
        with self.assertRaises(ZeroDivisionError):
            # The caller sees the error just like a normal function.
            self.immediate_failure(callback=self.stop)
        # The callback is not run because the function failed synchronously.
        self.io_loop.add_timeout(self.io_loop.time() + 0.05, self.stop)
        result = self.wait()
        self.assertIs(result, None)

    def test_return_value(self):
        # Returning a value from a @return_future function is reported as an
        # error instead of being silently dropped.
        with self.assertRaises(ReturnValueIgnoredError):
            self.return_value(callback=self.stop)

    def test_callback_kw(self):
        # A keyword callback is run with the result, and the returned future
        # resolves to the same value.
        with ignore_deprecation():
            future = self.sync_future(callback=self.stop)
        result = self.wait()
        self.assertEqual(result, 42)
        self.assertEqual(future.result(), 42)

    def test_callback_positional(self):
        # When the callback is passed in positionally, future_wrap shouldn't
        # add another callback in the kwargs.
        with ignore_deprecation():
            future = self.sync_future(self.stop)
        result = self.wait()
        self.assertEqual(result, 42)
        self.assertEqual(future.result(), 42)

    def test_no_callback(self):
        # Without a callback the result is only available via the future.
        future = self.sync_future()
        self.assertEqual(future.result(), 42)

    def test_none_callback_kw(self):
        # explicitly pass None as callback
        future = self.sync_future(callback=None)
        self.assertEqual(future.result(), 42)

    def test_none_callback_pos(self):
        # Same as above, with None passed positionally.
        future = self.sync_future(None)
        self.assertEqual(future.result(), 42)

    def test_async_future(self):
        future = self.async_future()
        # Not resolved until the IOLoop runs the scheduled callback.
        self.assertFalse(future.done())
        self.io_loop.add_future(future, self.stop)
        future2 = self.wait()
        self.assertIs(future, future2)
        self.assertEqual(future.result(), 42)

    @gen_test
    def test_async_future_gen(self):
        # The returned future can be yielded directly from a coroutine.
        result = yield self.async_future()
        self.assertEqual(result, 42)

    def test_delayed_failure(self):
        future = self.delayed_failure()
        with ignore_deprecation():
            self.io_loop.add_future(future, self.stop)
            future2 = self.wait()
        self.assertIs(future, future2)
        # The asynchronous error is delivered through the future.
        with self.assertRaises(ZeroDivisionError):
            future.result()

    def test_kw_only_callback(self):
        # @return_future must also find `callback` when the wrapped function
        # only accepts **kwargs.
        with ignore_deprecation():
            @return_future
            def f(**kwargs):
                kwargs['callback'](42)
        future = f()
        self.assertEqual(future.result(), 42)

    def test_error_in_callback(self):
        # NOTE: the lambda is called with the result value (42), not a
        # future, despite its parameter name.
        with ignore_deprecation():
            self.sync_future(callback=lambda future: 1 / 0)
            # The exception gets caught by our StackContext and will be re-raised
            # when we wait.
            self.assertRaises(ZeroDivisionError, self.wait)

    def test_no_result_future(self):
        with ignore_deprecation():
            future = self.no_result_future(self.stop)
        result = self.wait()
        self.assertIs(result, None)
        # result of this future is undefined, but not an error
        future.result()

    def test_no_result_future_callback(self):
        # Same as above, with a zero-argument keyword callback.
        with ignore_deprecation():
            future = self.no_result_future(callback=lambda: self.stop())
        result = self.wait()
        self.assertIs(result, None)
        future.result()

    @gen_test
    def test_future_traceback_legacy(self):
        # The caller's traceback must include the frame where the error was
        # raised inside the legacy gen.engine/@return_future function.
        with ignore_deprecation():
            @return_future
            @gen.engine
            def f(callback):
                yield gen.Task(self.io_loop.add_callback)
                try:
                    1 / 0
                except ZeroDivisionError:
                    self.expected_frame = traceback.extract_tb(
                        sys.exc_info()[2], limit=1)[0]
                    raise
            try:
                yield f()
                self.fail("didn't get expected exception")
            except ZeroDivisionError:
                tb = traceback.extract_tb(sys.exc_info()[2])
                self.assertIn(self.expected_frame, tb)

    @gen_test
    def test_future_traceback(self):
        # Same check as above, for a modern @gen.coroutine.
        @gen.coroutine
        def f():
            yield gen.moment
            try:
                1 / 0
            except ZeroDivisionError:
                self.expected_frame = traceback.extract_tb(
                    sys.exc_info()[2], limit=1)[0]
                raise
        try:
            yield f()
            self.fail("didn't get expected exception")
        except ZeroDivisionError:
            tb = traceback.extract_tb(sys.exc_info()[2])
            self.assertIn(self.expected_frame, tb)

    @gen_test
    def test_uncaught_exception_log(self):
        # A failed coroutine future that nobody inspects must be logged when
        # it is garbage-collected.
        if IOLoop.configured_class().__name__.endswith('AsyncIOLoop'):
            # Install an exception handler that mirrors our
            # non-asyncio logging behavior.
            def exc_handler(loop, context):
                app_log.error('%s: %s', context['message'],
                              type(context.get('exception')))
            self.io_loop.asyncio_loop.set_exception_handler(exc_handler)

        @gen.coroutine
        def f():
            yield gen.moment
            1 / 0

        g = f()

        with ExpectLog(app_log,
                       "(?s)Future.* exception was never retrieved:"
                       ".*ZeroDivisionError"):
            yield gen.moment
            yield gen.moment
            # For some reason, TwistedIOLoop and pypy3 need a third iteration
            # in order to drain references to the future
            yield gen.moment
            # Dropping the last reference triggers the "never retrieved" log.
            del g
            gc.collect()  # for PyPy
# The following series of classes demonstrates and tests various styles
# of use, with and without generators and futures.
  216. class CapServer(TCPServer):
  217. @gen.coroutine
  218. def handle_stream(self, stream, address):
  219. data = yield stream.read_until(b"\n")
  220. data = to_unicode(data)
  221. if data == data.upper():
  222. stream.write(b"error\talready capitalized\n")
  223. else:
  224. # data already has \n
  225. stream.write(utf8("ok\t%s" % data.upper()))
  226. stream.close()
  227. class CapError(Exception):
  228. pass
  229. class BaseCapClient(object):
  230. def __init__(self, port):
  231. self.port = port
  232. def process_response(self, data):
  233. status, message = re.match('(.*)\t(.*)\n', to_unicode(data)).groups()
  234. if status == 'ok':
  235. return message
  236. else:
  237. raise CapError(message)
  238. class ManualCapClient(BaseCapClient):
  239. def capitalize(self, request_data, callback=None):
  240. logging.debug("capitalize")
  241. self.request_data = request_data
  242. self.stream = IOStream(socket.socket())
  243. self.stream.connect(('127.0.0.1', self.port),
  244. callback=self.handle_connect)
  245. self.future = Future()
  246. if callback is not None:
  247. self.future.add_done_callback(
  248. stack_context.wrap(lambda future: callback(future.result())))
  249. return self.future
  250. def handle_connect(self):
  251. logging.debug("handle_connect")
  252. self.stream.write(utf8(self.request_data + "\n"))
  253. self.stream.read_until(b'\n', callback=self.handle_read)
  254. def handle_read(self, data):
  255. logging.debug("handle_read")
  256. self.stream.close()
  257. try:
  258. self.future.set_result(self.process_response(data))
  259. except CapError as e:
  260. self.future.set_exception(e)
class DecoratorCapClient(BaseCapClient):
    """Client built on the deprecated ``@return_future`` decorator.

    ``capitalize`` only starts the request. The decorator supplies
    ``callback`` and returns a future, which resolves when ``handle_read``
    invokes that callback.
    """
    with ignore_deprecation():
        @return_future
        def capitalize(self, request_data, callback):
            logging.debug("capitalize")
            self.request_data = request_data
            self.stream = IOStream(socket.socket())
            self.stream.connect(('127.0.0.1', self.port),
                                callback=self.handle_connect)
            # Keep the decorator-supplied callback for handle_read.
            self.callback = callback

    def handle_connect(self):
        # Send the request line, then wait for the one-line reply.
        logging.debug("handle_connect")
        self.stream.write(utf8(self.request_data + "\n"))
        self.stream.read_until(b'\n', callback=self.handle_read)

    def handle_read(self, data):
        logging.debug("handle_read")
        self.stream.close()
        # A CapError from process_response propagates out of this stream
        # callback. NOTE(review): the shared error tests expect it to reach
        # the caller, presumably via the stack context @return_future sets up.
        self.callback(self.process_response(data))
  279. class GeneratorCapClient(BaseCapClient):
  280. @gen.coroutine
  281. def capitalize(self, request_data):
  282. logging.debug('capitalize')
  283. stream = IOStream(socket.socket())
  284. logging.debug('connecting')
  285. yield stream.connect(('127.0.0.1', self.port))
  286. stream.write(utf8(request_data + '\n'))
  287. logging.debug('reading')
  288. data = yield stream.read_until(b'\n')
  289. logging.debug('returning')
  290. stream.close()
  291. raise gen.Return(self.process_response(data))
  292. class ClientTestMixin(object):
  293. def setUp(self):
  294. super(ClientTestMixin, self).setUp() # type: ignore
  295. self.server = CapServer()
  296. sock, port = bind_unused_port()
  297. self.server.add_sockets([sock])
  298. self.client = self.client_class(port=port)
  299. def tearDown(self):
  300. self.server.stop()
  301. super(ClientTestMixin, self).tearDown() # type: ignore
  302. def test_callback(self):
  303. with ignore_deprecation():
  304. self.client.capitalize("hello", callback=self.stop)
  305. result = self.wait()
  306. self.assertEqual(result, "HELLO")
  307. def test_callback_error(self):
  308. with ignore_deprecation():
  309. self.client.capitalize("HELLO", callback=self.stop)
  310. self.assertRaisesRegexp(CapError, "already capitalized", self.wait)
  311. def test_future(self):
  312. future = self.client.capitalize("hello")
  313. self.io_loop.add_future(future, self.stop)
  314. self.wait()
  315. self.assertEqual(future.result(), "HELLO")
  316. def test_future_error(self):
  317. future = self.client.capitalize("HELLO")
  318. self.io_loop.add_future(future, self.stop)
  319. self.wait()
  320. self.assertRaisesRegexp(CapError, "already capitalized", future.result)
  321. def test_generator(self):
  322. @gen.coroutine
  323. def f():
  324. result = yield self.client.capitalize("hello")
  325. self.assertEqual(result, "HELLO")
  326. self.io_loop.run_sync(f)
  327. def test_generator_error(self):
  328. @gen.coroutine
  329. def f():
  330. with self.assertRaisesRegexp(CapError, "already capitalized"):
  331. yield self.client.capitalize("HELLO")
  332. self.io_loop.run_sync(f)
  333. class ManualClientTest(ClientTestMixin, AsyncTestCase):
  334. client_class = ManualCapClient
  335. def setUp(self):
  336. self.warning_catcher = warnings.catch_warnings()
  337. self.warning_catcher.__enter__()
  338. warnings.simplefilter('ignore', DeprecationWarning)
  339. super(ManualClientTest, self).setUp()
  340. def tearDown(self):
  341. super(ManualClientTest, self).tearDown()
  342. self.warning_catcher.__exit__(None, None, None)
  343. class DecoratorClientTest(ClientTestMixin, AsyncTestCase):
  344. client_class = DecoratorCapClient
  345. def setUp(self):
  346. self.warning_catcher = warnings.catch_warnings()
  347. self.warning_catcher.__enter__()
  348. warnings.simplefilter('ignore', DeprecationWarning)
  349. super(DecoratorClientTest, self).setUp()
  350. def tearDown(self):
  351. super(DecoratorClientTest, self).tearDown()
  352. self.warning_catcher.__exit__(None, None, None)
  353. class GeneratorClientTest(ClientTestMixin, AsyncTestCase):
  354. client_class = GeneratorCapClient
  355. @unittest.skipIf(futures is None, "concurrent.futures module not present")
  356. class RunOnExecutorTest(AsyncTestCase):
  357. @gen_test
  358. def test_no_calling(self):
  359. class Object(object):
  360. def __init__(self):
  361. self.executor = futures.thread.ThreadPoolExecutor(1)
  362. @run_on_executor
  363. def f(self):
  364. return 42
  365. o = Object()
  366. answer = yield o.f()
  367. self.assertEqual(answer, 42)
  368. @gen_test
  369. def test_call_with_no_args(self):
  370. class Object(object):
  371. def __init__(self):
  372. self.executor = futures.thread.ThreadPoolExecutor(1)
  373. @run_on_executor()
  374. def f(self):
  375. return 42
  376. o = Object()
  377. answer = yield o.f()
  378. self.assertEqual(answer, 42)
  379. @gen_test
  380. def test_call_with_executor(self):
  381. class Object(object):
  382. def __init__(self):
  383. self.__executor = futures.thread.ThreadPoolExecutor(1)
  384. @run_on_executor(executor='_Object__executor')
  385. def f(self):
  386. return 42
  387. o = Object()
  388. answer = yield o.f()
  389. self.assertEqual(answer, 42)
  390. @skipBefore35
  391. @gen_test
  392. def test_async_await(self):
  393. class Object(object):
  394. def __init__(self):
  395. self.executor = futures.thread.ThreadPoolExecutor(1)
  396. @run_on_executor()
  397. def f(self):
  398. return 42
  399. o = Object()
  400. namespace = exec_test(globals(), locals(), """
  401. async def f():
  402. answer = await o.f()
  403. return answer
  404. """)
  405. result = yield namespace['f']()
  406. self.assertEqual(result, 42)
  407. if __name__ == '__main__':
  408. unittest.main()