# basic_test.py -- Basic unit tests for Terminado
# Copyright (c) Jupyter Development Team
# Copyright (c) 2014, Ramalingam Saravanan <sarava@sarava.net>
# Distributed under the terms of the Simplified BSD License.
from __future__ import absolute_import, print_function

import datetime
import json
import logging
import os
import re
import unittest

from terminado import *
import tornado
import tornado.gen
import tornado.httpclient
import tornado.httpserver
import tornado.testing
import tornado.web
import tornado.websocket
from tornado.httpclient import HTTPError
from tornado.ioloop import IOLoop
  18. #
  19. # The timeout we use to assume no more messages are coming
  20. # from the sever.
  21. #
  22. DONE_TIMEOUT = 1.0
  23. os.environ['ASYNC_TEST_TIMEOUT'] = "20" # Global test case timeout
  24. MAX_TERMS = 3 # Testing thresholds
  25. class TestTermClient(object):
  26. """Test connection to a terminal manager"""
  27. def __init__(self, websocket):
  28. self.ws = websocket
  29. self.pending_read = None
  30. @tornado.gen.coroutine
  31. def read_msg(self):
  32. # Because the Tornado Websocket client has no way to cancel
  33. # a pending read, we have to keep track of them...
  34. if self.pending_read is None:
  35. self.pending_read = self.ws.read_message()
  36. response = yield self.pending_read
  37. self.pending_read = None
  38. if response:
  39. response = json.loads(response)
  40. raise tornado.gen.Return(response)
  41. @tornado.gen.coroutine
  42. def read_all_msg(self, timeout=DONE_TIMEOUT):
  43. """Read messages until read times out"""
  44. msglist = []
  45. delta = datetime.timedelta(seconds=timeout)
  46. while True:
  47. try:
  48. mf = self.read_msg()
  49. msg = yield tornado.gen.with_timeout(delta, mf)
  50. except tornado.gen.TimeoutError:
  51. raise tornado.gen.Return(msglist)
  52. msglist.append(msg)
  53. def write_msg(self, msg):
  54. self.ws.write_message(json.dumps(msg))
  55. @tornado.gen.coroutine
  56. def read_stdout(self, timeout=DONE_TIMEOUT):
  57. """Read standard output until timeout read reached,
  58. return stdout and any non-stdout msgs received."""
  59. msglist = yield self.read_all_msg(timeout)
  60. stdout = "".join([msg[1] for msg in msglist if msg[0] == 'stdout'])
  61. othermsg = [msg for msg in msglist if msg[0] != 'stdout']
  62. raise tornado.gen.Return((stdout, othermsg))
  63. def write_stdin(self, data):
  64. """Write to terminal stdin"""
  65. self.write_msg(['stdin', data])
  66. @tornado.gen.coroutine
  67. def get_pid(self):
  68. """Get process ID of terminal shell process"""
  69. yield self.read_stdout() # Clear out any pending
  70. self.write_stdin("echo $$\r")
  71. (stdout, extra) = yield self.read_stdout()
  72. if os.name == 'nt':
  73. match = re.search(r'echo \$\$\x1b\[0K\r\n(\d+)', stdout)
  74. pid = int(match.groups()[0])
  75. else:
  76. pid = int(stdout.split('\n')[1])
  77. raise tornado.gen.Return(pid)
  78. def close(self):
  79. self.ws.close()
  80. class TermTestCase(tornado.testing.AsyncHTTPTestCase):
  81. # Factory for TestTermClient, because it has to be a Tornado co-routine.
  82. # See: https://github.com/tornadoweb/tornado/issues/1161
  83. @tornado.gen.coroutine
  84. def get_term_client(self, path):
  85. port = self.get_http_port()
  86. url = 'ws://127.0.0.1:%d%s' % (port, path)
  87. request = tornado.httpclient.HTTPRequest(url,
  88. headers={'Origin' : 'http://127.0.0.1:%d' % port})
  89. ws = yield tornado.websocket.websocket_connect(request)
  90. raise tornado.gen.Return(TestTermClient(ws))
  91. @tornado.gen.coroutine
  92. def get_term_clients(self, paths):
  93. tms = yield [self.get_term_client(path) for path in paths]
  94. raise tornado.gen.Return(tms)
  95. @tornado.gen.coroutine
  96. def get_pids(self, tm_list):
  97. pids = []
  98. for tm in tm_list: # Must be sequential, in case terms are shared
  99. pid = yield tm.get_pid()
  100. pids.append(pid)
  101. raise tornado.gen.Return(pids)
  102. def get_app(self):
  103. self.named_tm = NamedTermManager(shell_command=['bash'],
  104. max_terminals=MAX_TERMS,
  105. ioloop=self.io_loop)
  106. self.single_tm = SingleTermManager(shell_command=['bash'],
  107. ioloop=self.io_loop)
  108. self.unique_tm = UniqueTermManager(shell_command=['bash'],
  109. max_terminals=MAX_TERMS,
  110. ioloop=self.io_loop)
  111. named_tm = self.named_tm
  112. class NewTerminalHandler(tornado.web.RequestHandler):
  113. """Create a new named terminal, return redirect"""
  114. def get(self):
  115. name, terminal = named_tm.new_named_terminal()
  116. self.redirect("/named/" + name, permanent=False)
  117. return tornado.web.Application([
  118. (r"/new", NewTerminalHandler),
  119. (r"/named/(\w+)", TermSocket, {'term_manager': self.named_tm}),
  120. (r"/single", TermSocket, {'term_manager': self.single_tm}),
  121. (r"/unique", TermSocket, {'term_manager': self.unique_tm})
  122. ], debug=True)
  123. test_urls = ('/named/term1', '/unique', '/single')
  124. class CommonTests(TermTestCase):
  125. @tornado.testing.gen_test
  126. def test_basic(self):
  127. for url in self.test_urls:
  128. tm = yield self.get_term_client(url)
  129. response = yield tm.read_msg()
  130. self.assertEqual(response, ['setup', {}])
  131. # Check for initial shell prompt
  132. response = yield tm.read_msg()
  133. self.assertEqual(response[0], 'stdout')
  134. self.assertGreater(len(response[1]), 0)
  135. tm.close()
  136. @tornado.testing.gen_test
  137. def test_basic_command(self):
  138. for url in self.test_urls:
  139. tm = yield self.get_term_client(url)
  140. yield tm.read_all_msg()
  141. tm.write_stdin("whoami\n")
  142. (stdout, other) = yield tm.read_stdout()
  143. if os.name == 'nt':
  144. assert 'whoami' in stdout
  145. else:
  146. assert stdout.startswith('who')
  147. assert other == []
  148. tm.close()
  149. class NamedTermTests(TermTestCase):
  150. def test_new(self):
  151. response = self.fetch("/new", follow_redirects=False)
  152. self.assertEqual(response.code, 302)
  153. url = response.headers["Location"]
  154. # Check that the new terminal was created
  155. name = url.split('/')[2]
  156. self.assertIn(name, self.named_tm.terminals)
  157. @tornado.testing.gen_test
  158. def test_namespace(self):
  159. names = ["/named/1"]*2 + ["/named/2"]*2
  160. tms = yield self.get_term_clients(names)
  161. pids = yield self.get_pids(tms)
  162. self.assertEqual(pids[0], pids[1])
  163. self.assertEqual(pids[2], pids[3])
  164. self.assertNotEqual(pids[0], pids[3])
  165. @tornado.testing.gen_test
  166. def test_max_terminals(self):
  167. urls = ["/named/%d" % i for i in range(MAX_TERMS+1)]
  168. tms = yield self.get_term_clients(urls[:MAX_TERMS])
  169. pids = yield self.get_pids(tms)
  170. # MAX_TERMS+1 should fail
  171. tm = yield self.get_term_client(urls[MAX_TERMS])
  172. msg = yield tm.read_msg()
  173. self.assertEqual(msg, None) # Connection closed
  174. class SingleTermTests(TermTestCase):
  175. @tornado.testing.gen_test
  176. def test_single_process(self):
  177. tms = yield self.get_term_clients(["/single", "/single"])
  178. pids = yield self.get_pids(tms)
  179. self.assertEqual(pids[0], pids[1])
  180. class UniqueTermTests(TermTestCase):
  181. @tornado.testing.gen_test
  182. def test_unique_processes(self):
  183. tms = yield self.get_term_clients(["/unique", "/unique"])
  184. pids = yield self.get_pids(tms)
  185. self.assertNotEqual(pids[0], pids[1])
  186. @tornado.testing.gen_test
  187. def test_max_terminals(self):
  188. tms = yield self.get_term_clients(['/unique'] * MAX_TERMS)
  189. pids = yield self.get_pids(tms)
  190. self.assertEqual(len(set(pids)), MAX_TERMS) # All PIDs unique
  191. # MAX_TERMS+1 should fail
  192. tm = yield self.get_term_client("/unique")
  193. msg = yield tm.read_msg()
  194. self.assertEqual(msg, None) # Connection closed
  195. # Close one
  196. tms[0].close()
  197. msg = yield tms[0].read_msg() # Closed
  198. self.assertEquals(msg, None)
  199. # Should be able to open back up to MAX_TERMS
  200. tm = yield self.get_term_client("/unique")
  201. msg = yield tm.read_msg()
  202. self.assertEquals(msg[0], 'setup')
  203. if __name__ == '__main__':
  204. unittest.main()