123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247 |
- # basic_tests.py -- Basic unit tests for Terminado
- # Copyright (c) Jupyter Development Team
- # Copyright (c) 2014, Ramalingam Saravanan <sarava@sarava.net>
- # Distributed under the terms of the Simplified BSD License.
- from __future__ import absolute_import, print_function
- import unittest
- from terminado import *
- import tornado
- import tornado.httpserver
- from tornado.httpclient import HTTPError
- from tornado.ioloop import IOLoop
- import tornado.testing
- import datetime
- import logging
- import json
- import os
- import re
# How long (in seconds) to wait for another message before assuming
# that the server has nothing more to send.
DONE_TIMEOUT = 1.0

# Testing thresholds
MAX_TERMS = 3

# Global test case timeout (seconds), exported through the environment
# so that tornado.testing picks it up.
os.environ['ASYNC_TEST_TIMEOUT'] = "20"
class TestTermClient(object):
    """Thin test-side wrapper around a websocket connected to a terminal.

    Messages are exchanged as JSON-encoded lists whose first element is
    the message type (e.g. 'stdin', 'stdout', 'setup').
    """

    def __init__(self, websocket):
        """Wrap *websocket*; no read is outstanding initially."""
        self.ws = websocket
        self.pending_read = None

    @tornado.gen.coroutine
    def read_msg(self):
        """Read one message and return it JSON-decoded.

        A falsy raw message (such as None) is returned as-is, undecoded.
        """
        # The Tornado websocket client cannot cancel a read once it has
        # been issued, so an outstanding read is remembered and reused
        # instead of starting a second one.
        if self.pending_read is None:
            self.pending_read = self.ws.read_message()
        raw = yield self.pending_read
        self.pending_read = None
        raise tornado.gen.Return(json.loads(raw) if raw else raw)

    @tornado.gen.coroutine
    def read_all_msg(self, timeout=DONE_TIMEOUT):
        """Collect messages until a single read takes longer than *timeout*.

        Returns the list of messages received, in arrival order.
        """
        received = []
        wait = datetime.timedelta(seconds=timeout)
        while True:
            try:
                msg = yield tornado.gen.with_timeout(wait, self.read_msg())
            except tornado.gen.TimeoutError:
                break
            received.append(msg)
        raise tornado.gen.Return(received)

    def write_msg(self, msg):
        """Send *msg* to the server, JSON-encoded."""
        payload = json.dumps(msg)
        self.ws.write_message(payload)

    @tornado.gen.coroutine
    def read_stdout(self, timeout=DONE_TIMEOUT):
        """Read until the read times out, then split the messages by type.

        Returns a tuple ``(stdout, others)``: the concatenated payloads of
        all 'stdout' messages, and the list of every non-'stdout' message.
        """
        messages = yield self.read_all_msg(timeout)
        stdout_chunks = []
        others = []
        for msg in messages:
            if msg[0] == 'stdout':
                stdout_chunks.append(msg[1])
            else:
                others.append(msg)
        raise tornado.gen.Return(("".join(stdout_chunks), others))

    def write_stdin(self, data):
        """Send *data* to the terminal as a 'stdin' message."""
        self.write_msg(['stdin', data])

    @tornado.gen.coroutine
    def get_pid(self):
        """Return the process ID reported by the terminal's shell."""
        # Drain whatever output is already queued so that only the
        # response to the echo command is parsed below.
        yield self.read_stdout()
        self.write_stdin("echo $$\r")
        output, _ = yield self.read_stdout()
        if os.name != 'nt':
            # The second line of the output is taken as the PID.
            pid_text = output.split('\n')[1]
        else:
            found = re.search(r'echo \$\$\x1b\[0K\r\n(\d+)', output)
            pid_text = found.groups()[0]
        raise tornado.gen.Return(int(pid_text))

    def close(self):
        """Close the underlying websocket."""
        self.ws.close()
class TermTestCase(tornado.testing.AsyncHTTPTestCase):
    """Base test case serving named, single and unique terminal managers."""

    # Websocket paths, one per kind of terminal manager.
    test_urls = ('/named/term1', '/unique', '/single')

    # Factory for TestTermClient, because it has to be a Tornado co-routine.
    # See: https://github.com/tornadoweb/tornado/issues/1161
    @tornado.gen.coroutine
    def get_term_client(self, path):
        """Open a websocket to *path* on the test server and wrap it."""
        port = self.get_http_port()
        origin = 'http://127.0.0.1:%d' % port
        request = tornado.httpclient.HTTPRequest(
            'ws://127.0.0.1:%d%s' % (port, path),
            headers={'Origin': origin})
        socket = yield tornado.websocket.websocket_connect(request)
        raise tornado.gen.Return(TestTermClient(socket))

    @tornado.gen.coroutine
    def get_term_clients(self, paths):
        """Connect to every path in *paths*; return the clients in order."""
        connecting = [self.get_term_client(path) for path in paths]
        clients = yield connecting
        raise tornado.gen.Return(clients)

    @tornado.gen.coroutine
    def get_pids(self, tm_list):
        """Return the shell PID behind each client in *tm_list*."""
        collected = []
        # Must be sequential, in case terms are shared
        for client in tm_list:
            collected.append((yield client.get_pid()))
        raise tornado.gen.Return(collected)

    def get_app(self):
        """Create the three terminal managers and the app routing to them."""
        self.named_tm = NamedTermManager(
            shell_command=['bash'],
            max_terminals=MAX_TERMS,
            ioloop=self.io_loop)
        self.single_tm = SingleTermManager(
            shell_command=['bash'],
            ioloop=self.io_loop)
        self.unique_tm = UniqueTermManager(
            shell_command=['bash'],
            max_terminals=MAX_TERMS,
            ioloop=self.io_loop)

        # Captured by the handler class below, which has no access to
        # this test case instance.
        manager = self.named_tm

        class NewTerminalHandler(tornado.web.RequestHandler):
            """Create a new named terminal, return redirect"""

            def get(self):
                name, _ = manager.new_named_terminal()
                self.redirect("/named/" + name, permanent=False)

        routes = [
            (r"/new", NewTerminalHandler),
            (r"/named/(\w+)", TermSocket, {'term_manager': self.named_tm}),
            (r"/single", TermSocket, {'term_manager': self.single_tm}),
            (r"/unique", TermSocket, {'term_manager': self.unique_tm}),
        ]
        return tornado.web.Application(routes, debug=True)
class CommonTests(TermTestCase):
    """Checks run against every URL listed in ``test_urls``."""

    @tornado.testing.gen_test
    def test_basic(self):
        """Each endpoint sends a 'setup' message, then some shell output."""
        for url in self.test_urls:
            tm = yield self.get_term_client(url)
            try:
                response = yield tm.read_msg()
                self.assertEqual(response, ['setup', {}])

                # Check for initial shell prompt
                response = yield tm.read_msg()
                self.assertEqual(response[0], 'stdout')
                self.assertGreater(len(response[1]), 0)
            finally:
                # Close the connection even when an assertion fails, so a
                # failure on one URL does not leave the socket open.
                tm.close()

    @tornado.testing.gen_test
    def test_basic_command(self):
        """A command written to stdin produces output on stdout only."""
        for url in self.test_urls:
            tm = yield self.get_term_client(url)
            try:
                # Discard the setup message and initial output.
                yield tm.read_all_msg()
                tm.write_stdin("whoami\n")
                (stdout, other) = yield tm.read_stdout()
                # unittest assertions are used instead of bare ``assert``
                # so the checks still run under ``python -O`` and report
                # the offending value on failure.
                if os.name == 'nt':
                    self.assertIn('whoami', stdout)
                else:
                    self.assertTrue(stdout.startswith('who'), stdout)
                self.assertEqual(other, [])
            finally:
                tm.close()
class NamedTermTests(TermTestCase):
    """Tests specific to the named terminal manager."""

    def test_new(self):
        """GET /new redirects to a freshly created named terminal."""
        response = self.fetch("/new", follow_redirects=False)
        self.assertEqual(response.code, 302)
        url = response.headers["Location"]

        # Check that the new terminal was created
        name = url.split('/')[2]
        self.assertIn(name, self.named_tm.terminals)

    @tornado.testing.gen_test
    def test_namespace(self):
        """Clients on the same name share a shell; other names do not."""
        names = ["/named/1"] * 2 + ["/named/2"] * 2
        tms = yield self.get_term_clients(names)
        pids = yield self.get_pids(tms)

        self.assertEqual(pids[0], pids[1])
        self.assertEqual(pids[2], pids[3])
        self.assertNotEqual(pids[0], pids[3])

    @tornado.testing.gen_test
    def test_max_terminals(self):
        """Opening more than MAX_TERMS named terminals is refused."""
        urls = ["/named/%d" % i for i in range(MAX_TERMS + 1)]
        tms = yield self.get_term_clients(urls[:MAX_TERMS])
        # Round-trip a command through each terminal before probing the
        # limit; the PIDs themselves are not needed here.
        yield self.get_pids(tms)

        # MAX_TERMS+1 should fail
        tm = yield self.get_term_client(urls[MAX_TERMS])
        msg = yield tm.read_msg()
        self.assertIsNone(msg)  # Connection closed
class SingleTermTests(TermTestCase):
    """Tests specific to the single terminal manager."""

    @tornado.testing.gen_test
    def test_single_process(self):
        """Two connections to /single report the same shell PID."""
        clients = yield self.get_term_clients(["/single"] * 2)
        shell_pids = yield self.get_pids(clients)
        first_pid = shell_pids[0]
        second_pid = shell_pids[1]
        self.assertEqual(first_pid, second_pid)
class UniqueTermTests(TermTestCase):
    """Tests specific to the unique terminal manager."""

    @tornado.testing.gen_test
    def test_unique_processes(self):
        """Two connections to /unique report different shell PIDs."""
        tms = yield self.get_term_clients(["/unique", "/unique"])
        pids = yield self.get_pids(tms)
        self.assertNotEqual(pids[0], pids[1])

    @tornado.testing.gen_test
    def test_max_terminals(self):
        """The MAX_TERMS limit is enforced and frees up when a client closes."""
        tms = yield self.get_term_clients(['/unique'] * MAX_TERMS)
        pids = yield self.get_pids(tms)
        self.assertEqual(len(set(pids)), MAX_TERMS)  # All PIDs unique

        # MAX_TERMS+1 should fail
        tm = yield self.get_term_client("/unique")
        msg = yield tm.read_msg()
        self.assertIsNone(msg)  # Connection closed

        # Close one
        tms[0].close()
        msg = yield tms[0].read_msg()  # Closed
        # assertEqual/assertIsNone are used rather than the deprecated
        # assertEquals alias, which no longer exists in Python 3.12+.
        self.assertIsNone(msg)

        # Should be able to open back up to MAX_TERMS
        tm = yield self.get_term_client("/unique")
        msg = yield tm.read_msg()
        self.assertEqual(msg[0], 'setup')
# Run the whole suite when this module is executed as a script.
if __name__ == '__main__':
    unittest.main()
|