websocket_test.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775
  1. from __future__ import absolute_import, division, print_function
  2. import functools
  3. import sys
  4. import traceback
  5. from tornado.concurrent import Future
  6. from tornado import gen
  7. from tornado.httpclient import HTTPError, HTTPRequest
  8. from tornado.locks import Event
  9. from tornado.log import gen_log, app_log
  10. from tornado.simple_httpclient import SimpleAsyncHTTPClient
  11. from tornado.template import DictLoader
  12. from tornado.testing import AsyncHTTPTestCase, gen_test, bind_unused_port, ExpectLog
  13. from tornado.test.util import unittest, skipBefore35, exec_test
  14. from tornado.web import Application, RequestHandler
  15. try:
  16. import tornado.websocket # noqa
  17. from tornado.util import _websocket_mask_python
  18. except ImportError:
  19. # The unittest module presents misleading errors on ImportError
  20. # (it acts as if websocket_test could not be found, hiding the underlying
  21. # error). If we get an ImportError here (which could happen due to
  22. # TORNADO_EXTENSION=1), print some extra information before failing.
  23. traceback.print_exc()
  24. raise
  25. from tornado.websocket import (
  26. WebSocketHandler, websocket_connect, WebSocketError, WebSocketClosedError,
  27. )
  28. try:
  29. from tornado import speedups
  30. except ImportError:
  31. speedups = None
  32. class TestWebSocketHandler(WebSocketHandler):
  33. """Base class for testing handlers that exposes the on_close event.
  34. This allows for deterministic cleanup of the associated socket.
  35. """
  36. def initialize(self, close_future, compression_options=None):
  37. self.close_future = close_future
  38. self.compression_options = compression_options
  39. def get_compression_options(self):
  40. return self.compression_options
  41. def on_close(self):
  42. self.close_future.set_result((self.close_code, self.close_reason))
  43. class EchoHandler(TestWebSocketHandler):
  44. @gen.coroutine
  45. def on_message(self, message):
  46. try:
  47. yield self.write_message(message, isinstance(message, bytes))
  48. except WebSocketClosedError:
  49. pass
  50. class ErrorInOnMessageHandler(TestWebSocketHandler):
  51. def on_message(self, message):
  52. 1 / 0
  53. class HeaderHandler(TestWebSocketHandler):
  54. def open(self):
  55. methods_to_test = [
  56. functools.partial(self.write, 'This should not work'),
  57. functools.partial(self.redirect, 'http://localhost/elsewhere'),
  58. functools.partial(self.set_header, 'X-Test', ''),
  59. functools.partial(self.set_cookie, 'Chocolate', 'Chip'),
  60. functools.partial(self.set_status, 503),
  61. self.flush,
  62. self.finish,
  63. ]
  64. for method in methods_to_test:
  65. try:
  66. # In a websocket context, many RequestHandler methods
  67. # raise RuntimeErrors.
  68. method()
  69. raise Exception("did not get expected exception")
  70. except RuntimeError:
  71. pass
  72. self.write_message(self.request.headers.get('X-Test', ''))
  73. class HeaderEchoHandler(TestWebSocketHandler):
  74. def set_default_headers(self):
  75. self.set_header("X-Extra-Response-Header", "Extra-Response-Value")
  76. def prepare(self):
  77. for k, v in self.request.headers.get_all():
  78. if k.lower().startswith('x-test'):
  79. self.set_header(k, v)
  80. class NonWebSocketHandler(RequestHandler):
  81. def get(self):
  82. self.write('ok')
  83. class CloseReasonHandler(TestWebSocketHandler):
  84. def open(self):
  85. self.on_close_called = False
  86. self.close(1001, "goodbye")
  87. class AsyncPrepareHandler(TestWebSocketHandler):
  88. @gen.coroutine
  89. def prepare(self):
  90. yield gen.moment
  91. def on_message(self, message):
  92. self.write_message(message)
  93. class PathArgsHandler(TestWebSocketHandler):
  94. def open(self, arg):
  95. self.write_message(arg)
  96. class CoroutineOnMessageHandler(TestWebSocketHandler):
  97. def initialize(self, close_future, compression_options=None):
  98. super(CoroutineOnMessageHandler, self).initialize(close_future,
  99. compression_options)
  100. self.sleeping = 0
  101. @gen.coroutine
  102. def on_message(self, message):
  103. if self.sleeping > 0:
  104. self.write_message('another coroutine is already sleeping')
  105. self.sleeping += 1
  106. yield gen.sleep(0.01)
  107. self.sleeping -= 1
  108. self.write_message(message)
  109. class RenderMessageHandler(TestWebSocketHandler):
  110. def on_message(self, message):
  111. self.write_message(self.render_string('message.html', message=message))
  112. class SubprotocolHandler(TestWebSocketHandler):
  113. def initialize(self, **kwargs):
  114. super(SubprotocolHandler, self).initialize(**kwargs)
  115. self.select_subprotocol_called = False
  116. def select_subprotocol(self, subprotocols):
  117. if self.select_subprotocol_called:
  118. raise Exception("select_subprotocol called twice")
  119. self.select_subprotocol_called = True
  120. if 'goodproto' in subprotocols:
  121. return 'goodproto'
  122. return None
  123. def open(self):
  124. if not self.select_subprotocol_called:
  125. raise Exception("select_subprotocol not called")
  126. self.write_message("subprotocol=%s" % self.selected_subprotocol)
  127. class OpenCoroutineHandler(TestWebSocketHandler):
  128. def initialize(self, test, **kwargs):
  129. super(OpenCoroutineHandler, self).initialize(**kwargs)
  130. self.test = test
  131. self.open_finished = False
  132. @gen.coroutine
  133. def open(self):
  134. yield self.test.message_sent.wait()
  135. yield gen.sleep(0.010)
  136. self.open_finished = True
  137. def on_message(self, message):
  138. if not self.open_finished:
  139. raise Exception('on_message called before open finished')
  140. self.write_message('ok')
  141. class WebSocketBaseTestCase(AsyncHTTPTestCase):
  142. @gen.coroutine
  143. def ws_connect(self, path, **kwargs):
  144. ws = yield websocket_connect(
  145. 'ws://127.0.0.1:%d%s' % (self.get_http_port(), path),
  146. **kwargs)
  147. raise gen.Return(ws)
  148. @gen.coroutine
  149. def close(self, ws):
  150. """Close a websocket connection and wait for the server side.
  151. If we don't wait here, there are sometimes leak warnings in the
  152. tests.
  153. """
  154. ws.close()
  155. yield self.close_future
  156. class WebSocketTest(WebSocketBaseTestCase):
  157. def get_app(self):
  158. self.close_future = Future()
  159. return Application([
  160. ('/echo', EchoHandler, dict(close_future=self.close_future)),
  161. ('/non_ws', NonWebSocketHandler),
  162. ('/header', HeaderHandler, dict(close_future=self.close_future)),
  163. ('/header_echo', HeaderEchoHandler,
  164. dict(close_future=self.close_future)),
  165. ('/close_reason', CloseReasonHandler,
  166. dict(close_future=self.close_future)),
  167. ('/error_in_on_message', ErrorInOnMessageHandler,
  168. dict(close_future=self.close_future)),
  169. ('/async_prepare', AsyncPrepareHandler,
  170. dict(close_future=self.close_future)),
  171. ('/path_args/(.*)', PathArgsHandler,
  172. dict(close_future=self.close_future)),
  173. ('/coroutine', CoroutineOnMessageHandler,
  174. dict(close_future=self.close_future)),
  175. ('/render', RenderMessageHandler,
  176. dict(close_future=self.close_future)),
  177. ('/subprotocol', SubprotocolHandler,
  178. dict(close_future=self.close_future)),
  179. ('/open_coroutine', OpenCoroutineHandler,
  180. dict(close_future=self.close_future, test=self)),
  181. ], template_loader=DictLoader({
  182. 'message.html': '<b>{{ message }}</b>',
  183. }))
  184. def get_http_client(self):
  185. # These tests require HTTP/1; force the use of SimpleAsyncHTTPClient.
  186. return SimpleAsyncHTTPClient()
  187. def tearDown(self):
  188. super(WebSocketTest, self).tearDown()
  189. RequestHandler._template_loaders.clear()
  190. def test_http_request(self):
  191. # WS server, HTTP client.
  192. response = self.fetch('/echo')
  193. self.assertEqual(response.code, 400)
  194. def test_missing_websocket_key(self):
  195. response = self.fetch('/echo',
  196. headers={'Connection': 'Upgrade',
  197. 'Upgrade': 'WebSocket',
  198. 'Sec-WebSocket-Version': '13'})
  199. self.assertEqual(response.code, 400)
  200. def test_bad_websocket_version(self):
  201. response = self.fetch('/echo',
  202. headers={'Connection': 'Upgrade',
  203. 'Upgrade': 'WebSocket',
  204. 'Sec-WebSocket-Version': '12'})
  205. self.assertEqual(response.code, 426)
  206. @gen_test
  207. def test_websocket_gen(self):
  208. ws = yield self.ws_connect('/echo')
  209. yield ws.write_message('hello')
  210. response = yield ws.read_message()
  211. self.assertEqual(response, 'hello')
  212. yield self.close(ws)
  213. def test_websocket_callbacks(self):
  214. websocket_connect(
  215. 'ws://127.0.0.1:%d/echo' % self.get_http_port(),
  216. callback=self.stop)
  217. ws = self.wait().result()
  218. ws.write_message('hello')
  219. ws.read_message(self.stop)
  220. response = self.wait().result()
  221. self.assertEqual(response, 'hello')
  222. self.close_future.add_done_callback(lambda f: self.stop())
  223. ws.close()
  224. self.wait()
  225. @gen_test
  226. def test_binary_message(self):
  227. ws = yield self.ws_connect('/echo')
  228. ws.write_message(b'hello \xe9', binary=True)
  229. response = yield ws.read_message()
  230. self.assertEqual(response, b'hello \xe9')
  231. yield self.close(ws)
  232. @gen_test
  233. def test_unicode_message(self):
  234. ws = yield self.ws_connect('/echo')
  235. ws.write_message(u'hello \u00e9')
  236. response = yield ws.read_message()
  237. self.assertEqual(response, u'hello \u00e9')
  238. yield self.close(ws)
  239. @gen_test
  240. def test_render_message(self):
  241. ws = yield self.ws_connect('/render')
  242. ws.write_message('hello')
  243. response = yield ws.read_message()
  244. self.assertEqual(response, '<b>hello</b>')
  245. yield self.close(ws)
  246. @gen_test
  247. def test_error_in_on_message(self):
  248. ws = yield self.ws_connect('/error_in_on_message')
  249. ws.write_message('hello')
  250. with ExpectLog(app_log, "Uncaught exception"):
  251. response = yield ws.read_message()
  252. self.assertIs(response, None)
  253. yield self.close(ws)
  254. @gen_test
  255. def test_websocket_http_fail(self):
  256. with self.assertRaises(HTTPError) as cm:
  257. yield self.ws_connect('/notfound')
  258. self.assertEqual(cm.exception.code, 404)
  259. @gen_test
  260. def test_websocket_http_success(self):
  261. with self.assertRaises(WebSocketError):
  262. yield self.ws_connect('/non_ws')
  263. @gen_test
  264. def test_websocket_network_fail(self):
  265. sock, port = bind_unused_port()
  266. sock.close()
  267. with self.assertRaises(IOError):
  268. with ExpectLog(gen_log, ".*"):
  269. yield websocket_connect(
  270. 'ws://127.0.0.1:%d/' % port,
  271. connect_timeout=3600)
  272. @gen_test
  273. def test_websocket_close_buffered_data(self):
  274. ws = yield websocket_connect(
  275. 'ws://127.0.0.1:%d/echo' % self.get_http_port())
  276. ws.write_message('hello')
  277. ws.write_message('world')
  278. # Close the underlying stream.
  279. ws.stream.close()
  280. yield self.close_future
  281. @gen_test
  282. def test_websocket_headers(self):
  283. # Ensure that arbitrary headers can be passed through websocket_connect.
  284. ws = yield websocket_connect(
  285. HTTPRequest('ws://127.0.0.1:%d/header' % self.get_http_port(),
  286. headers={'X-Test': 'hello'}))
  287. response = yield ws.read_message()
  288. self.assertEqual(response, 'hello')
  289. yield self.close(ws)
  290. @gen_test
  291. def test_websocket_header_echo(self):
  292. # Ensure that headers can be returned in the response.
  293. # Specifically, that arbitrary headers passed through websocket_connect
  294. # can be returned.
  295. ws = yield websocket_connect(
  296. HTTPRequest('ws://127.0.0.1:%d/header_echo' % self.get_http_port(),
  297. headers={'X-Test-Hello': 'hello'}))
  298. self.assertEqual(ws.headers.get('X-Test-Hello'), 'hello')
  299. self.assertEqual(ws.headers.get('X-Extra-Response-Header'), 'Extra-Response-Value')
  300. yield self.close(ws)
  301. @gen_test
  302. def test_server_close_reason(self):
  303. ws = yield self.ws_connect('/close_reason')
  304. msg = yield ws.read_message()
  305. # A message of None means the other side closed the connection.
  306. self.assertIs(msg, None)
  307. self.assertEqual(ws.close_code, 1001)
  308. self.assertEqual(ws.close_reason, "goodbye")
  309. # The on_close callback is called no matter which side closed.
  310. code, reason = yield self.close_future
  311. # The client echoed the close code it received to the server,
  312. # so the server's close code (returned via close_future) is
  313. # the same.
  314. self.assertEqual(code, 1001)
  315. @gen_test
  316. def test_client_close_reason(self):
  317. ws = yield self.ws_connect('/echo')
  318. ws.close(1001, 'goodbye')
  319. code, reason = yield self.close_future
  320. self.assertEqual(code, 1001)
  321. self.assertEqual(reason, 'goodbye')
  322. @gen_test
  323. def test_write_after_close(self):
  324. ws = yield self.ws_connect('/close_reason')
  325. msg = yield ws.read_message()
  326. self.assertIs(msg, None)
  327. with self.assertRaises(WebSocketClosedError):
  328. ws.write_message('hello')
  329. @gen_test
  330. def test_async_prepare(self):
  331. # Previously, an async prepare method triggered a bug that would
  332. # result in a timeout on test shutdown (and a memory leak).
  333. ws = yield self.ws_connect('/async_prepare')
  334. ws.write_message('hello')
  335. res = yield ws.read_message()
  336. self.assertEqual(res, 'hello')
  337. @gen_test
  338. def test_path_args(self):
  339. ws = yield self.ws_connect('/path_args/hello')
  340. res = yield ws.read_message()
  341. self.assertEqual(res, 'hello')
  342. @gen_test
  343. def test_coroutine(self):
  344. ws = yield self.ws_connect('/coroutine')
  345. # Send both messages immediately, coroutine must process one at a time.
  346. yield ws.write_message('hello1')
  347. yield ws.write_message('hello2')
  348. res = yield ws.read_message()
  349. self.assertEqual(res, 'hello1')
  350. res = yield ws.read_message()
  351. self.assertEqual(res, 'hello2')
  352. @gen_test
  353. def test_check_origin_valid_no_path(self):
  354. port = self.get_http_port()
  355. url = 'ws://127.0.0.1:%d/echo' % port
  356. headers = {'Origin': 'http://127.0.0.1:%d' % port}
  357. ws = yield websocket_connect(HTTPRequest(url, headers=headers))
  358. ws.write_message('hello')
  359. response = yield ws.read_message()
  360. self.assertEqual(response, 'hello')
  361. yield self.close(ws)
  362. @gen_test
  363. def test_check_origin_valid_with_path(self):
  364. port = self.get_http_port()
  365. url = 'ws://127.0.0.1:%d/echo' % port
  366. headers = {'Origin': 'http://127.0.0.1:%d/something' % port}
  367. ws = yield websocket_connect(HTTPRequest(url, headers=headers))
  368. ws.write_message('hello')
  369. response = yield ws.read_message()
  370. self.assertEqual(response, 'hello')
  371. yield self.close(ws)
  372. @gen_test
  373. def test_check_origin_invalid_partial_url(self):
  374. port = self.get_http_port()
  375. url = 'ws://127.0.0.1:%d/echo' % port
  376. headers = {'Origin': '127.0.0.1:%d' % port}
  377. with self.assertRaises(HTTPError) as cm:
  378. yield websocket_connect(HTTPRequest(url, headers=headers))
  379. self.assertEqual(cm.exception.code, 403)
  380. @gen_test
  381. def test_check_origin_invalid(self):
  382. port = self.get_http_port()
  383. url = 'ws://127.0.0.1:%d/echo' % port
  384. # Host is 127.0.0.1, which should not be accessible from some other
  385. # domain
  386. headers = {'Origin': 'http://somewhereelse.com'}
  387. with self.assertRaises(HTTPError) as cm:
  388. yield websocket_connect(HTTPRequest(url, headers=headers))
  389. self.assertEqual(cm.exception.code, 403)
  390. @gen_test
  391. def test_check_origin_invalid_subdomains(self):
  392. port = self.get_http_port()
  393. url = 'ws://localhost:%d/echo' % port
  394. # Subdomains should be disallowed by default. If we could pass a
  395. # resolver to websocket_connect we could test sibling domains as well.
  396. headers = {'Origin': 'http://subtenant.localhost'}
  397. with self.assertRaises(HTTPError) as cm:
  398. yield websocket_connect(HTTPRequest(url, headers=headers))
  399. self.assertEqual(cm.exception.code, 403)
  400. @gen_test
  401. def test_subprotocols(self):
  402. ws = yield self.ws_connect('/subprotocol', subprotocols=['badproto', 'goodproto'])
  403. self.assertEqual(ws.selected_subprotocol, 'goodproto')
  404. res = yield ws.read_message()
  405. self.assertEqual(res, 'subprotocol=goodproto')
  406. yield self.close(ws)
  407. @gen_test
  408. def test_subprotocols_not_offered(self):
  409. ws = yield self.ws_connect('/subprotocol')
  410. self.assertIs(ws.selected_subprotocol, None)
  411. res = yield ws.read_message()
  412. self.assertEqual(res, 'subprotocol=None')
  413. yield self.close(ws)
  414. @gen_test
  415. def test_open_coroutine(self):
  416. self.message_sent = Event()
  417. ws = yield self.ws_connect('/open_coroutine')
  418. yield ws.write_message('hello')
  419. self.message_sent.set()
  420. res = yield ws.read_message()
  421. self.assertEqual(res, 'ok')
  422. yield self.close(ws)
  423. if sys.version_info >= (3, 5):
  424. NativeCoroutineOnMessageHandler = exec_test(globals(), locals(), """
  425. class NativeCoroutineOnMessageHandler(TestWebSocketHandler):
  426. def initialize(self, close_future, compression_options=None):
  427. super().initialize(close_future, compression_options)
  428. self.sleeping = 0
  429. async def on_message(self, message):
  430. if self.sleeping > 0:
  431. self.write_message('another coroutine is already sleeping')
  432. self.sleeping += 1
  433. await gen.sleep(0.01)
  434. self.sleeping -= 1
  435. self.write_message(message)""")['NativeCoroutineOnMessageHandler']
  436. class WebSocketNativeCoroutineTest(WebSocketBaseTestCase):
  437. def get_app(self):
  438. self.close_future = Future()
  439. return Application([
  440. ('/native', NativeCoroutineOnMessageHandler,
  441. dict(close_future=self.close_future))])
  442. @skipBefore35
  443. @gen_test
  444. def test_native_coroutine(self):
  445. ws = yield self.ws_connect('/native')
  446. # Send both messages immediately, coroutine must process one at a time.
  447. yield ws.write_message('hello1')
  448. yield ws.write_message('hello2')
  449. res = yield ws.read_message()
  450. self.assertEqual(res, 'hello1')
  451. res = yield ws.read_message()
  452. self.assertEqual(res, 'hello2')
  453. class CompressionTestMixin(object):
  454. MESSAGE = 'Hello world. Testing 123 123'
  455. def get_app(self):
  456. self.close_future = Future()
  457. class LimitedHandler(TestWebSocketHandler):
  458. @property
  459. def max_message_size(self):
  460. return 1024
  461. def on_message(self, message):
  462. self.write_message(str(len(message)))
  463. return Application([
  464. ('/echo', EchoHandler, dict(
  465. close_future=self.close_future,
  466. compression_options=self.get_server_compression_options())),
  467. ('/limited', LimitedHandler, dict(
  468. close_future=self.close_future,
  469. compression_options=self.get_server_compression_options())),
  470. ])
  471. def get_server_compression_options(self):
  472. return None
  473. def get_client_compression_options(self):
  474. return None
  475. @gen_test
  476. def test_message_sizes(self):
  477. ws = yield self.ws_connect(
  478. '/echo',
  479. compression_options=self.get_client_compression_options())
  480. # Send the same message three times so we can measure the
  481. # effect of the context_takeover options.
  482. for i in range(3):
  483. ws.write_message(self.MESSAGE)
  484. response = yield ws.read_message()
  485. self.assertEqual(response, self.MESSAGE)
  486. self.assertEqual(ws.protocol._message_bytes_out, len(self.MESSAGE) * 3)
  487. self.assertEqual(ws.protocol._message_bytes_in, len(self.MESSAGE) * 3)
  488. self.verify_wire_bytes(ws.protocol._wire_bytes_in,
  489. ws.protocol._wire_bytes_out)
  490. yield self.close(ws)
  491. @gen_test
  492. def test_size_limit(self):
  493. ws = yield self.ws_connect(
  494. '/limited',
  495. compression_options=self.get_client_compression_options())
  496. # Small messages pass through.
  497. ws.write_message('a' * 128)
  498. response = yield ws.read_message()
  499. self.assertEqual(response, '128')
  500. # This message is too big after decompression, but it compresses
  501. # down to a size that will pass the initial checks.
  502. ws.write_message('a' * 2048)
  503. response = yield ws.read_message()
  504. self.assertIsNone(response)
  505. yield self.close(ws)
  506. class UncompressedTestMixin(CompressionTestMixin):
  507. """Specialization of CompressionTestMixin when we expect no compression."""
  508. def verify_wire_bytes(self, bytes_in, bytes_out):
  509. # Bytes out includes the 4-byte mask key per message.
  510. self.assertEqual(bytes_out, 3 * (len(self.MESSAGE) + 6))
  511. self.assertEqual(bytes_in, 3 * (len(self.MESSAGE) + 2))
  512. class NoCompressionTest(UncompressedTestMixin, WebSocketBaseTestCase):
  513. pass
  514. # If only one side tries to compress, the extension is not negotiated.
  515. class ServerOnlyCompressionTest(UncompressedTestMixin, WebSocketBaseTestCase):
  516. def get_server_compression_options(self):
  517. return {}
  518. class ClientOnlyCompressionTest(UncompressedTestMixin, WebSocketBaseTestCase):
  519. def get_client_compression_options(self):
  520. return {}
  521. class DefaultCompressionTest(CompressionTestMixin, WebSocketBaseTestCase):
  522. def get_server_compression_options(self):
  523. return {}
  524. def get_client_compression_options(self):
  525. return {}
  526. def verify_wire_bytes(self, bytes_in, bytes_out):
  527. self.assertLess(bytes_out, 3 * (len(self.MESSAGE) + 6))
  528. self.assertLess(bytes_in, 3 * (len(self.MESSAGE) + 2))
  529. # Bytes out includes the 4 bytes mask key per message.
  530. self.assertEqual(bytes_out, bytes_in + 12)
  531. class MaskFunctionMixin(object):
  532. # Subclasses should define self.mask(mask, data)
  533. def test_mask(self):
  534. self.assertEqual(self.mask(b'abcd', b''), b'')
  535. self.assertEqual(self.mask(b'abcd', b'b'), b'\x03')
  536. self.assertEqual(self.mask(b'abcd', b'54321'), b'TVPVP')
  537. self.assertEqual(self.mask(b'ZXCV', b'98765432'), b'c`t`olpd')
  538. # Include test cases with \x00 bytes (to ensure that the C
  539. # extension isn't depending on null-terminated strings) and
  540. # bytes with the high bit set (to smoke out signedness issues).
  541. self.assertEqual(self.mask(b'\x00\x01\x02\x03',
  542. b'\xff\xfb\xfd\xfc\xfe\xfa'),
  543. b'\xff\xfa\xff\xff\xfe\xfb')
  544. self.assertEqual(self.mask(b'\xff\xfb\xfd\xfc',
  545. b'\x00\x01\x02\x03\x04\x05'),
  546. b'\xff\xfa\xff\xff\xfb\xfe')
  547. class PythonMaskFunctionTest(MaskFunctionMixin, unittest.TestCase):
  548. def mask(self, mask, data):
  549. return _websocket_mask_python(mask, data)
  550. @unittest.skipIf(speedups is None, "tornado.speedups module not present")
  551. class CythonMaskFunctionTest(MaskFunctionMixin, unittest.TestCase):
  552. def mask(self, mask, data):
  553. return speedups.websocket_mask(mask, data)
  554. class ServerPeriodicPingTest(WebSocketBaseTestCase):
  555. def get_app(self):
  556. class PingHandler(TestWebSocketHandler):
  557. def on_pong(self, data):
  558. self.write_message("got pong")
  559. self.close_future = Future()
  560. return Application([
  561. ('/', PingHandler, dict(close_future=self.close_future)),
  562. ], websocket_ping_interval=0.01)
  563. @gen_test
  564. def test_server_ping(self):
  565. ws = yield self.ws_connect('/')
  566. for i in range(3):
  567. response = yield ws.read_message()
  568. self.assertEqual(response, "got pong")
  569. yield self.close(ws)
  570. # TODO: test that the connection gets closed if ping responses stop.
  571. class ClientPeriodicPingTest(WebSocketBaseTestCase):
  572. def get_app(self):
  573. class PingHandler(TestWebSocketHandler):
  574. def on_ping(self, data):
  575. self.write_message("got ping")
  576. self.close_future = Future()
  577. return Application([
  578. ('/', PingHandler, dict(close_future=self.close_future)),
  579. ])
  580. @gen_test
  581. def test_client_ping(self):
  582. ws = yield self.ws_connect('/', ping_interval=0.01)
  583. for i in range(3):
  584. response = yield ws.read_message()
  585. self.assertEqual(response, "got ping")
  586. yield self.close(ws)
  587. # TODO: test that the connection gets closed if ping responses stop.
  588. class ManualPingTest(WebSocketBaseTestCase):
  589. def get_app(self):
  590. class PingHandler(TestWebSocketHandler):
  591. def on_ping(self, data):
  592. self.write_message(data, binary=isinstance(data, bytes))
  593. self.close_future = Future()
  594. return Application([
  595. ('/', PingHandler, dict(close_future=self.close_future)),
  596. ])
  597. @gen_test
  598. def test_manual_ping(self):
  599. ws = yield self.ws_connect('/')
  600. self.assertRaises(ValueError, ws.ping, 'a' * 126)
  601. ws.ping('hello')
  602. resp = yield ws.read_message()
  603. # on_ping always sees bytes.
  604. self.assertEqual(resp, b'hello')
  605. ws.ping(b'binary hello')
  606. resp = yield ws.read_message()
  607. self.assertEqual(resp, b'binary hello')
  608. yield self.close(ws)
  609. class MaxMessageSizeTest(WebSocketBaseTestCase):
  610. def get_app(self):
  611. self.close_future = Future()
  612. return Application([
  613. ('/', EchoHandler, dict(close_future=self.close_future)),
  614. ], websocket_max_message_size=1024)
  615. @gen_test
  616. def test_large_message(self):
  617. ws = yield self.ws_connect('/')
  618. # Write a message that is allowed.
  619. msg = 'a' * 1024
  620. ws.write_message(msg)
  621. resp = yield ws.read_message()
  622. self.assertEqual(resp, msg)
  623. # Write a message that is too large.
  624. ws.write_message(msg + 'b')
  625. resp = yield ws.read_message()
  626. # A message of None means the other side closed the connection.
  627. self.assertIs(resp, None)
  628. self.assertEqual(ws.close_code, 1009)
  629. self.assertEqual(ws.close_reason, "message too big")
  630. # TODO: Needs tests of messages split over multiple
  631. # continuation frames.