test_channel.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355
  1. # Copyright Twisted Matrix Laboratories.
  2. # See LICENSE for details.
  3. """
  4. Test ssh/channel.py.
  5. """
  6. from __future__ import division, absolute_import
  7. from zope.interface.verify import verifyObject
  8. try:
  9. from twisted.conch.ssh import channel
  10. from twisted.conch.ssh.address import SSHTransportAddress
  11. from twisted.conch.ssh.transport import SSHServerTransport
  12. from twisted.conch.ssh.service import SSHService
  13. from twisted.internet import interfaces
  14. from twisted.internet.address import IPv4Address
  15. from twisted.test.proto_helpers import StringTransport
  16. skipTest = None
  17. except ImportError:
  18. skipTest = 'Conch SSH not supported.'
  19. SSHService = object
  20. from twisted.trial import unittest
  21. from twisted.python.compat import intToBytes
  22. class MockConnection(SSHService):
  23. """
  24. A mock for twisted.conch.ssh.connection.SSHConnection. Record the data
  25. that channels send, and when they try to close the connection.
  26. @ivar data: a L{dict} mapping channel id #s to lists of data sent by that
  27. channel.
  28. @ivar extData: a L{dict} mapping channel id #s to lists of 2-tuples
  29. (extended data type, data) sent by that channel.
  30. @ivar closes: a L{dict} mapping channel id #s to True if that channel sent
  31. a close message.
  32. """
  33. def __init__(self):
  34. self.data = {}
  35. self.extData = {}
  36. self.closes = {}
  37. def logPrefix(self):
  38. """
  39. Return our logging prefix.
  40. """
  41. return "MockConnection"
  42. def sendData(self, channel, data):
  43. """
  44. Record the sent data.
  45. """
  46. self.data.setdefault(channel, []).append(data)
  47. def sendExtendedData(self, channel, type, data):
  48. """
  49. Record the sent extended data.
  50. """
  51. self.extData.setdefault(channel, []).append((type, data))
  52. def sendClose(self, channel):
  53. """
  54. Record that the channel sent a close message.
  55. """
  56. self.closes[channel] = True
  57. def connectSSHTransport(service, hostAddress=None, peerAddress=None):
  58. """
  59. Connect a SSHTransport which is already connected to a remote peer to
  60. the channel under test.
  61. @param service: Service used over the connected transport.
  62. @type service: L{SSHService}
  63. @param hostAddress: Local address of the connected transport.
  64. @type hostAddress: L{interfaces.IAddress}
  65. @param peerAddress: Remote address of the connected transport.
  66. @type peerAddress: L{interfaces.IAddress}
  67. """
  68. transport = SSHServerTransport()
  69. transport.makeConnection(StringTransport(
  70. hostAddress=hostAddress, peerAddress=peerAddress))
  71. transport.setService(service)
  72. class ChannelTests(unittest.TestCase):
  73. """
  74. Tests for L{SSHChannel}.
  75. """
  76. skip = skipTest
  77. def setUp(self):
  78. """
  79. Initialize the channel. remoteMaxPacket is 10 so that data is able
  80. to be sent (the default of 0 means no data is sent because no packets
  81. are made).
  82. """
  83. self.conn = MockConnection()
  84. self.channel = channel.SSHChannel(conn=self.conn,
  85. remoteMaxPacket=10)
  86. self.channel.name = b'channel'
  87. def test_interface(self):
  88. """
  89. L{SSHChannel} instances provide L{interfaces.ITransport}.
  90. """
  91. self.assertTrue(verifyObject(interfaces.ITransport, self.channel))
  92. def test_init(self):
  93. """
  94. Test that SSHChannel initializes correctly. localWindowSize defaults
  95. to 131072 (2**17) and localMaxPacket to 32768 (2**15) as reasonable
  96. defaults (what OpenSSH uses for those variables).
  97. The values in the second set of assertions are meaningless; they serve
  98. only to verify that the instance variables are assigned in the correct
  99. order.
  100. """
  101. c = channel.SSHChannel(conn=self.conn)
  102. self.assertEqual(c.localWindowSize, 131072)
  103. self.assertEqual(c.localWindowLeft, 131072)
  104. self.assertEqual(c.localMaxPacket, 32768)
  105. self.assertEqual(c.remoteWindowLeft, 0)
  106. self.assertEqual(c.remoteMaxPacket, 0)
  107. self.assertEqual(c.conn, self.conn)
  108. self.assertIsNone(c.data)
  109. self.assertIsNone(c.avatar)
  110. c2 = channel.SSHChannel(1, 2, 3, 4, 5, 6, 7)
  111. self.assertEqual(c2.localWindowSize, 1)
  112. self.assertEqual(c2.localWindowLeft, 1)
  113. self.assertEqual(c2.localMaxPacket, 2)
  114. self.assertEqual(c2.remoteWindowLeft, 3)
  115. self.assertEqual(c2.remoteMaxPacket, 4)
  116. self.assertEqual(c2.conn, 5)
  117. self.assertEqual(c2.data, 6)
  118. self.assertEqual(c2.avatar, 7)
  119. def test_str(self):
  120. """
  121. Test that str(SSHChannel) works gives the channel name and local and
  122. remote windows at a glance..
  123. """
  124. self.assertEqual(
  125. str(self.channel), '<SSHChannel channel (lw 131072 rw 0)>')
  126. self.assertEqual(
  127. str(channel.SSHChannel(localWindow=1)),
  128. '<SSHChannel None (lw 1 rw 0)>')
  129. def test_bytes(self):
  130. """
  131. Test that bytes(SSHChannel) works, gives the channel name and
  132. local and remote windows at a glance..
  133. """
  134. self.assertEqual(
  135. self.channel.__bytes__(),
  136. b'<SSHChannel channel (lw 131072 rw 0)>')
  137. self.assertEqual(
  138. channel.SSHChannel(localWindow=1).__bytes__(),
  139. b'<SSHChannel None (lw 1 rw 0)>')
  140. def test_logPrefix(self):
  141. """
  142. Test that SSHChannel.logPrefix gives the name of the channel, the
  143. local channel ID and the underlying connection.
  144. """
  145. self.assertEqual(self.channel.logPrefix(), 'SSHChannel channel '
  146. '(unknown) on MockConnection')
  147. def test_addWindowBytes(self):
  148. """
  149. Test that addWindowBytes adds bytes to the window and resumes writing
  150. if it was paused.
  151. """
  152. cb = [False]
  153. def stubStartWriting():
  154. cb[0] = True
  155. self.channel.startWriting = stubStartWriting
  156. self.channel.write(b'test')
  157. self.channel.writeExtended(1, b'test')
  158. self.channel.addWindowBytes(50)
  159. self.assertEqual(self.channel.remoteWindowLeft, 50 - 4 - 4)
  160. self.assertTrue(self.channel.areWriting)
  161. self.assertTrue(cb[0])
  162. self.assertEqual(self.channel.buf, b'')
  163. self.assertEqual(self.conn.data[self.channel], [b'test'])
  164. self.assertEqual(self.channel.extBuf, [])
  165. self.assertEqual(self.conn.extData[self.channel], [(1, b'test')])
  166. cb[0] = False
  167. self.channel.addWindowBytes(20)
  168. self.assertFalse(cb[0])
  169. self.channel.write(b'a'*80)
  170. self.channel.loseConnection()
  171. self.channel.addWindowBytes(20)
  172. self.assertFalse(cb[0])
  173. def test_requestReceived(self):
  174. """
  175. Test that requestReceived handles requests by dispatching them to
  176. request_* methods.
  177. """
  178. self.channel.request_test_method = lambda data: data == b''
  179. self.assertTrue(self.channel.requestReceived(b'test-method', b''))
  180. self.assertFalse(self.channel.requestReceived(b'test-method', b'a'))
  181. self.assertFalse(self.channel.requestReceived(b'bad-method', b''))
  182. def test_closeReceieved(self):
  183. """
  184. Test that the default closeReceieved closes the connection.
  185. """
  186. self.assertFalse(self.channel.closing)
  187. self.channel.closeReceived()
  188. self.assertTrue(self.channel.closing)
  189. def test_write(self):
  190. """
  191. Test that write handles data correctly. Send data up to the size
  192. of the remote window, splitting the data into packets of length
  193. remoteMaxPacket.
  194. """
  195. cb = [False]
  196. def stubStopWriting():
  197. cb[0] = True
  198. # no window to start with
  199. self.channel.stopWriting = stubStopWriting
  200. self.channel.write(b'd')
  201. self.channel.write(b'a')
  202. self.assertFalse(self.channel.areWriting)
  203. self.assertTrue(cb[0])
  204. # regular write
  205. self.channel.addWindowBytes(20)
  206. self.channel.write(b'ta')
  207. data = self.conn.data[self.channel]
  208. self.assertEqual(data, [b'da', b'ta'])
  209. self.assertEqual(self.channel.remoteWindowLeft, 16)
  210. # larger than max packet
  211. self.channel.write(b'12345678901')
  212. self.assertEqual(data, [b'da', b'ta', b'1234567890', b'1'])
  213. self.assertEqual(self.channel.remoteWindowLeft, 5)
  214. # running out of window
  215. cb[0] = False
  216. self.channel.write(b'123456')
  217. self.assertFalse(self.channel.areWriting)
  218. self.assertTrue(cb[0])
  219. self.assertEqual(data, [b'da', b'ta', b'1234567890', b'1', b'12345'])
  220. self.assertEqual(self.channel.buf, b'6')
  221. self.assertEqual(self.channel.remoteWindowLeft, 0)
  222. def test_writeExtended(self):
  223. """
  224. Test that writeExtended handles data correctly. Send extended data
  225. up to the size of the window, splitting the extended data into packets
  226. of length remoteMaxPacket.
  227. """
  228. cb = [False]
  229. def stubStopWriting():
  230. cb[0] = True
  231. # no window to start with
  232. self.channel.stopWriting = stubStopWriting
  233. self.channel.writeExtended(1, b'd')
  234. self.channel.writeExtended(1, b'a')
  235. self.channel.writeExtended(2, b't')
  236. self.assertFalse(self.channel.areWriting)
  237. self.assertTrue(cb[0])
  238. # regular write
  239. self.channel.addWindowBytes(20)
  240. self.channel.writeExtended(2, b'a')
  241. data = self.conn.extData[self.channel]
  242. self.assertEqual(data, [(1, b'da'), (2, b't'), (2, b'a')])
  243. self.assertEqual(self.channel.remoteWindowLeft, 16)
  244. # larger than max packet
  245. self.channel.writeExtended(3, b'12345678901')
  246. self.assertEqual(data, [(1, b'da'), (2, b't'), (2, b'a'),
  247. (3, b'1234567890'), (3, b'1')])
  248. self.assertEqual(self.channel.remoteWindowLeft, 5)
  249. # running out of window
  250. cb[0] = False
  251. self.channel.writeExtended(4, b'123456')
  252. self.assertFalse(self.channel.areWriting)
  253. self.assertTrue(cb[0])
  254. self.assertEqual(data, [(1, b'da'), (2, b't'), (2, b'a'),
  255. (3, b'1234567890'), (3, b'1'), (4, b'12345')])
  256. self.assertEqual(self.channel.extBuf, [[4, b'6']])
  257. self.assertEqual(self.channel.remoteWindowLeft, 0)
  258. def test_writeSequence(self):
  259. """
  260. Test that writeSequence is equivalent to write(''.join(sequece)).
  261. """
  262. self.channel.addWindowBytes(20)
  263. self.channel.writeSequence(map(intToBytes, range(10)))
  264. self.assertEqual(self.conn.data[self.channel], [b'0123456789'])
  265. def test_loseConnection(self):
  266. """
  267. Tesyt that loseConnection() doesn't close the channel until all
  268. the data is sent.
  269. """
  270. self.channel.write(b'data')
  271. self.channel.writeExtended(1, b'datadata')
  272. self.channel.loseConnection()
  273. self.assertIsNone(self.conn.closes.get(self.channel))
  274. self.channel.addWindowBytes(4) # send regular data
  275. self.assertIsNone(self.conn.closes.get(self.channel))
  276. self.channel.addWindowBytes(8) # send extended data
  277. self.assertTrue(self.conn.closes.get(self.channel))
  278. def test_getPeer(self):
  279. """
  280. L{SSHChannel.getPeer} returns the same object as the underlying
  281. transport's C{getPeer} method returns.
  282. """
  283. peer = IPv4Address('TCP', '192.168.0.1', 54321)
  284. connectSSHTransport(service=self.channel.conn, peerAddress=peer)
  285. self.assertEqual(SSHTransportAddress(peer), self.channel.getPeer())
  286. def test_getHost(self):
  287. """
  288. L{SSHChannel.getHost} returns the same object as the underlying
  289. transport's C{getHost} method returns.
  290. """
  291. host = IPv4Address('TCP', '127.0.0.1', 12345)
  292. connectSSHTransport(service=self.channel.conn, hostAddress=host)
  293. self.assertEqual(SSHTransportAddress(host), self.channel.getHost())