launchnotebook.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
  1. """Base class for notebook tests."""
  2. from __future__ import print_function
  3. from binascii import hexlify
  4. from contextlib import contextmanager
  5. import errno
  6. import os
  7. import sys
  8. from threading import Thread, Event
  9. import time
  10. from unittest import TestCase
  11. pjoin = os.path.join
  12. try:
  13. from unittest.mock import patch
  14. except ImportError:
  15. from mock import patch #py2
  16. import requests
  17. from tornado.ioloop import IOLoop
  18. import zmq
  19. import jupyter_core.paths
  20. from traitlets.config import Config
  21. from ..notebookapp import NotebookApp
  22. from ..utils import url_path_join
  23. from ipython_genutils.tempdir import TemporaryDirectory
  24. MAX_WAITTIME = 30 # seconds to wait for notebook server to start
  25. POLL_INTERVAL = 0.1 # time between attempts
  26. # TimeoutError is a builtin on Python 3. This can be removed when we stop
  27. # supporting Python 2.
  28. class TimeoutError(Exception):
  29. pass
  30. class NotebookTestBase(TestCase):
  31. """A base class for tests that need a running notebook.
  32. This create some empty config and runtime directories
  33. and then starts the notebook server with them.
  34. """
  35. port = 12341
  36. config = None
  37. # run with a base URL that would be escaped,
  38. # to test that we don't double-escape URLs
  39. url_prefix = '/a%40b/'
  40. @classmethod
  41. def wait_until_alive(cls):
  42. """Wait for the server to be alive"""
  43. url = cls.base_url() + 'api/contents'
  44. for _ in range(int(MAX_WAITTIME/POLL_INTERVAL)):
  45. try:
  46. requests.get(url)
  47. except Exception as e:
  48. if not cls.notebook_thread.is_alive():
  49. raise RuntimeError("The notebook server failed to start")
  50. time.sleep(POLL_INTERVAL)
  51. else:
  52. return
  53. raise TimeoutError("The notebook server didn't start up correctly.")
  54. @classmethod
  55. def wait_until_dead(cls):
  56. """Wait for the server process to terminate after shutdown"""
  57. cls.notebook_thread.join(timeout=MAX_WAITTIME)
  58. if cls.notebook_thread.is_alive():
  59. raise TimeoutError("Undead notebook server")
  60. @classmethod
  61. def auth_headers(cls):
  62. headers = {}
  63. if cls.token:
  64. headers['Authorization'] = 'token %s' % cls.token
  65. return headers
  66. @classmethod
  67. def request(cls, verb, path, **kwargs):
  68. """Send a request to my server
  69. with authentication and everything.
  70. """
  71. headers = kwargs.setdefault('headers', {})
  72. headers.update(cls.auth_headers())
  73. response = requests.request(verb,
  74. url_path_join(cls.base_url(), path),
  75. **kwargs)
  76. return response
  77. @classmethod
  78. def setup_class(cls):
  79. cls.tmp_dir = TemporaryDirectory()
  80. def tmp(*parts):
  81. path = os.path.join(cls.tmp_dir.name, *parts)
  82. try:
  83. os.makedirs(path)
  84. except OSError as e:
  85. if e.errno != errno.EEXIST:
  86. raise
  87. return path
  88. cls.home_dir = tmp('home')
  89. data_dir = cls.data_dir = tmp('data')
  90. config_dir = cls.config_dir = tmp('config')
  91. runtime_dir = cls.runtime_dir = tmp('runtime')
  92. cls.notebook_dir = tmp('notebooks')
  93. cls.env_patch = patch.dict('os.environ', {
  94. 'HOME': cls.home_dir,
  95. 'PYTHONPATH': os.pathsep.join(sys.path),
  96. 'IPYTHONDIR': pjoin(cls.home_dir, '.ipython'),
  97. 'JUPYTER_NO_CONFIG': '1', # needed in the future
  98. 'JUPYTER_CONFIG_DIR' : config_dir,
  99. 'JUPYTER_DATA_DIR' : data_dir,
  100. 'JUPYTER_RUNTIME_DIR': runtime_dir,
  101. })
  102. cls.env_patch.start()
  103. cls.path_patch = patch.multiple(
  104. jupyter_core.paths,
  105. SYSTEM_JUPYTER_PATH=[tmp('share', 'jupyter')],
  106. ENV_JUPYTER_PATH=[tmp('env', 'share', 'jupyter')],
  107. SYSTEM_CONFIG_PATH=[tmp('etc', 'jupyter')],
  108. ENV_CONFIG_PATH=[tmp('env', 'etc', 'jupyter')],
  109. )
  110. cls.path_patch.start()
  111. config = cls.config or Config()
  112. config.NotebookNotary.db_file = ':memory:'
  113. cls.token = hexlify(os.urandom(4)).decode('ascii')
  114. started = Event()
  115. def start_thread():
  116. if 'asyncio' in sys.modules:
  117. import asyncio
  118. asyncio.set_event_loop(asyncio.new_event_loop())
  119. app = cls.notebook = NotebookApp(
  120. port=cls.port,
  121. port_retries=0,
  122. open_browser=False,
  123. config_dir=cls.config_dir,
  124. data_dir=cls.data_dir,
  125. runtime_dir=cls.runtime_dir,
  126. notebook_dir=cls.notebook_dir,
  127. base_url=cls.url_prefix,
  128. config=config,
  129. allow_root=True,
  130. token=cls.token,
  131. )
  132. # don't register signal handler during tests
  133. app.init_signal = lambda : None
  134. # clear log handlers and propagate to root for nose to capture it
  135. # needs to be redone after initialize, which reconfigures logging
  136. app.log.propagate = True
  137. app.log.handlers = []
  138. app.initialize(argv=[])
  139. app.log.propagate = True
  140. app.log.handlers = []
  141. loop = IOLoop.current()
  142. loop.add_callback(started.set)
  143. try:
  144. app.start()
  145. finally:
  146. # set the event, so failure to start doesn't cause a hang
  147. started.set()
  148. app.session_manager.close()
  149. cls.notebook_thread = Thread(target=start_thread)
  150. cls.notebook_thread.daemon = True
  151. cls.notebook_thread.start()
  152. started.wait()
  153. cls.wait_until_alive()
  154. @classmethod
  155. def teardown_class(cls):
  156. cls.notebook.stop()
  157. cls.wait_until_dead()
  158. cls.env_patch.stop()
  159. cls.path_patch.stop()
  160. cls.tmp_dir.cleanup()
  161. # cleanup global zmq Context, to ensure we aren't leaving dangling sockets
  162. def cleanup_zmq():
  163. zmq.Context.instance().term()
  164. t = Thread(target=cleanup_zmq)
  165. t.daemon = True
  166. t.start()
  167. t.join(5) # give it a few seconds to clean up (this should be immediate)
  168. # if term never returned, there's zmq stuff still open somewhere, so shout about it.
  169. if t.is_alive():
  170. raise RuntimeError("Failed to teardown zmq Context, open sockets likely left lying around.")
  171. @classmethod
  172. def base_url(cls):
  173. return 'http://localhost:%i%s' % (cls.port, cls.url_prefix)
  174. @contextmanager
  175. def assert_http_error(status, msg=None):
  176. try:
  177. yield
  178. except requests.HTTPError as e:
  179. real_status = e.response.status_code
  180. assert real_status == status, \
  181. "Expected status %d, got %d" % (status, real_status)
  182. if msg:
  183. assert msg in str(e), e
  184. else:
  185. assert False, "Expected HTTP error status"