test_interactivshell.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. # -*- coding: utf-8 -*-
  2. """Tests for the TerminalInteractiveShell and related pieces."""
  3. #-----------------------------------------------------------------------------
  4. # Copyright (C) 2011 The IPython Development Team
  5. #
  6. # Distributed under the terms of the BSD License. The full license is in
  7. # the file COPYING, distributed as part of this software.
  8. #-----------------------------------------------------------------------------
  9. import sys
  10. import unittest
  11. from IPython.core.inputtransformer import InputTransformer
  12. from IPython.testing import tools as tt
  13. # Decorator for interaction loop tests -----------------------------------------
  14. class mock_input_helper(object):
  15. """Machinery for tests of the main interact loop.
  16. Used by the mock_input decorator.
  17. """
  18. def __init__(self, testgen):
  19. self.testgen = testgen
  20. self.exception = None
  21. self.ip = get_ipython()
  22. def __enter__(self):
  23. self.orig_prompt_for_code = self.ip.prompt_for_code
  24. self.ip.prompt_for_code = self.fake_input
  25. return self
  26. def __exit__(self, etype, value, tb):
  27. self.ip.prompt_for_code = self.orig_prompt_for_code
  28. def fake_input(self):
  29. try:
  30. return next(self.testgen)
  31. except StopIteration:
  32. self.ip.keep_running = False
  33. return u''
  34. except:
  35. self.exception = sys.exc_info()
  36. self.ip.keep_running = False
  37. return u''
  38. def mock_input(testfunc):
  39. """Decorator for tests of the main interact loop.
  40. Write the test as a generator, yield-ing the input strings, which IPython
  41. will see as if they were typed in at the prompt.
  42. """
  43. def test_method(self):
  44. testgen = testfunc(self)
  45. with mock_input_helper(testgen) as mih:
  46. mih.ip.interact()
  47. if mih.exception is not None:
  48. # Re-raise captured exception
  49. etype, value, tb = mih.exception
  50. import traceback
  51. traceback.print_tb(tb, file=sys.stdout)
  52. del tb # Avoid reference loop
  53. raise value
  54. return test_method
  55. # Test classes -----------------------------------------------------------------
  56. class InteractiveShellTestCase(unittest.TestCase):
  57. def rl_hist_entries(self, rl, n):
  58. """Get last n readline history entries as a list"""
  59. return [rl.get_history_item(rl.get_current_history_length() - x)
  60. for x in range(n - 1, -1, -1)]
  61. @mock_input
  62. def test_inputtransformer_syntaxerror(self):
  63. ip = get_ipython()
  64. transformer = SyntaxErrorTransformer()
  65. ip.input_splitter.python_line_transforms.append(transformer)
  66. ip.input_transformer_manager.python_line_transforms.append(transformer)
  67. try:
  68. #raise Exception
  69. with tt.AssertPrints('4', suppress=False):
  70. yield u'print(2*2)'
  71. with tt.AssertPrints('SyntaxError: input contains', suppress=False):
  72. yield u'print(2345) # syntaxerror'
  73. with tt.AssertPrints('16', suppress=False):
  74. yield u'print(4*4)'
  75. finally:
  76. ip.input_splitter.python_line_transforms.remove(transformer)
  77. ip.input_transformer_manager.python_line_transforms.remove(transformer)
  78. def test_plain_text_only(self):
  79. ip = get_ipython()
  80. formatter = ip.display_formatter
  81. assert formatter.active_types == ['text/plain']
  82. class SyntaxErrorTransformer(InputTransformer):
  83. def push(self, line):
  84. pos = line.find('syntaxerror')
  85. if pos >= 0:
  86. e = SyntaxError('input contains "syntaxerror"')
  87. e.text = line
  88. e.offset = pos + 1
  89. raise e
  90. return line
  91. def reset(self):
  92. pass
  93. class TerminalMagicsTestCase(unittest.TestCase):
  94. def test_paste_magics_blankline(self):
  95. """Test that code with a blank line doesn't get split (gh-3246)."""
  96. ip = get_ipython()
  97. s = ('def pasted_func(a):\n'
  98. ' b = a+1\n'
  99. '\n'
  100. ' return b')
  101. tm = ip.magics_manager.registry['TerminalMagics']
  102. tm.store_or_execute(s, name=None)
  103. self.assertEqual(ip.user_ns['pasted_func'](54), 55)