test_protocol.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472
  1. # Copyright (c) Twisted Matrix Laboratories.
  2. # See LICENSE for details.
  3. """
  4. Tests for L{twisted.internet.protocol}.
  5. """
  6. from __future__ import division, absolute_import
  7. from zope.interface.verify import verifyObject
  8. from zope.interface import implementer
  9. from twisted.python.failure import Failure
  10. from twisted.internet.interfaces import (
  11. IProtocol, ILoggingContext, IProtocolFactory, IConsumer)
  12. from twisted.internet.defer import CancelledError
  13. from twisted.internet.protocol import (
  14. Protocol, ClientCreator, Factory, ProtocolToConsumerAdapter,
  15. ConsumerToProtocolAdapter)
  16. from twisted.trial.unittest import TestCase
  17. from twisted.test.proto_helpers import MemoryReactorClock, StringTransport
  18. from twisted.logger import LogLevel, globalLogPublisher
  19. class ClientCreatorTests(TestCase):
  20. """
  21. Tests for L{twisted.internet.protocol.ClientCreator}.
  22. """
  23. def _basicConnectTest(self, check):
  24. """
  25. Helper for implementing a test to verify that one of the I{connect}
  26. methods of L{ClientCreator} passes the right arguments to the right
  27. reactor method.
  28. @param check: A function which will be invoked with a reactor and a
  29. L{ClientCreator} instance and which should call one of the
  30. L{ClientCreator}'s I{connect} methods and assert that all of its
  31. arguments except for the factory are passed on as expected to the
  32. reactor. The factory should be returned.
  33. """
  34. class SomeProtocol(Protocol):
  35. pass
  36. reactor = MemoryReactorClock()
  37. cc = ClientCreator(reactor, SomeProtocol)
  38. factory = check(reactor, cc)
  39. protocol = factory.buildProtocol(None)
  40. self.assertIsInstance(protocol, SomeProtocol)
  41. def test_connectTCP(self):
  42. """
  43. L{ClientCreator.connectTCP} calls C{reactor.connectTCP} with the host
  44. and port information passed to it, and with a factory which will
  45. construct the protocol passed to L{ClientCreator.__init__}.
  46. """
  47. def check(reactor, cc):
  48. cc.connectTCP('example.com', 1234, 4321, ('1.2.3.4', 9876))
  49. host, port, factory, timeout, bindAddress = reactor.tcpClients.pop()
  50. self.assertEqual(host, 'example.com')
  51. self.assertEqual(port, 1234)
  52. self.assertEqual(timeout, 4321)
  53. self.assertEqual(bindAddress, ('1.2.3.4', 9876))
  54. return factory
  55. self._basicConnectTest(check)
  56. def test_connectUNIX(self):
  57. """
  58. L{ClientCreator.connectUNIX} calls C{reactor.connectUNIX} with the
  59. filename passed to it, and with a factory which will construct the
  60. protocol passed to L{ClientCreator.__init__}.
  61. """
  62. def check(reactor, cc):
  63. cc.connectUNIX('/foo/bar', 123, True)
  64. address, factory, timeout, checkPID = reactor.unixClients.pop()
  65. self.assertEqual(address, '/foo/bar')
  66. self.assertEqual(timeout, 123)
  67. self.assertTrue(checkPID)
  68. return factory
  69. self._basicConnectTest(check)
  70. def test_connectSSL(self):
  71. """
  72. L{ClientCreator.connectSSL} calls C{reactor.connectSSL} with the host,
  73. port, and context factory passed to it, and with a factory which will
  74. construct the protocol passed to L{ClientCreator.__init__}.
  75. """
  76. def check(reactor, cc):
  77. expectedContextFactory = object()
  78. cc.connectSSL('example.com', 1234, expectedContextFactory, 4321, ('4.3.2.1', 5678))
  79. host, port, factory, contextFactory, timeout, bindAddress = reactor.sslClients.pop()
  80. self.assertEqual(host, 'example.com')
  81. self.assertEqual(port, 1234)
  82. self.assertIs(contextFactory, expectedContextFactory)
  83. self.assertEqual(timeout, 4321)
  84. self.assertEqual(bindAddress, ('4.3.2.1', 5678))
  85. return factory
  86. self._basicConnectTest(check)
  87. def _cancelConnectTest(self, connect):
  88. """
  89. Helper for implementing a test to verify that cancellation of the
  90. L{Deferred} returned by one of L{ClientCreator}'s I{connect} methods is
  91. implemented to cancel the underlying connector.
  92. @param connect: A function which will be invoked with a L{ClientCreator}
  93. instance as an argument and which should call one its I{connect}
  94. methods and return the result.
  95. @return: A L{Deferred} which fires when the test is complete or fails if
  96. there is a problem.
  97. """
  98. reactor = MemoryReactorClock()
  99. cc = ClientCreator(reactor, Protocol)
  100. d = connect(cc)
  101. connector = reactor.connectors.pop()
  102. self.assertFalse(connector._disconnected)
  103. d.cancel()
  104. self.assertTrue(connector._disconnected)
  105. return self.assertFailure(d, CancelledError)
  106. def test_cancelConnectTCP(self):
  107. """
  108. The L{Deferred} returned by L{ClientCreator.connectTCP} can be cancelled
  109. to abort the connection attempt before it completes.
  110. """
  111. def connect(cc):
  112. return cc.connectTCP('example.com', 1234)
  113. return self._cancelConnectTest(connect)
  114. def test_cancelConnectUNIX(self):
  115. """
  116. The L{Deferred} returned by L{ClientCreator.connectTCP} can be cancelled
  117. to abort the connection attempt before it completes.
  118. """
  119. def connect(cc):
  120. return cc.connectUNIX('/foo/bar')
  121. return self._cancelConnectTest(connect)
  122. def test_cancelConnectSSL(self):
  123. """
  124. The L{Deferred} returned by L{ClientCreator.connectTCP} can be cancelled
  125. to abort the connection attempt before it completes.
  126. """
  127. def connect(cc):
  128. return cc.connectSSL('example.com', 1234, object())
  129. return self._cancelConnectTest(connect)
  130. def _cancelConnectTimeoutTest(self, connect):
  131. """
  132. Like L{_cancelConnectTest}, but for the case where the L{Deferred} is
  133. cancelled after the connection is set up but before it is fired with the
  134. resulting protocol instance.
  135. """
  136. reactor = MemoryReactorClock()
  137. cc = ClientCreator(reactor, Protocol)
  138. d = connect(reactor, cc)
  139. connector = reactor.connectors.pop()
  140. # Sanity check - there is an outstanding delayed call to fire the
  141. # Deferred.
  142. self.assertEqual(len(reactor.getDelayedCalls()), 1)
  143. # Cancel the Deferred, disconnecting the transport just set up and
  144. # cancelling the delayed call.
  145. d.cancel()
  146. self.assertEqual(reactor.getDelayedCalls(), [])
  147. # A real connector implementation is responsible for disconnecting the
  148. # transport as well. For our purposes, just check that someone told the
  149. # connector to disconnect.
  150. self.assertTrue(connector._disconnected)
  151. return self.assertFailure(d, CancelledError)
  152. def test_cancelConnectTCPTimeout(self):
  153. """
  154. L{ClientCreator.connectTCP} inserts a very short delayed call between
  155. the time the connection is established and the time the L{Deferred}
  156. returned from one of its connect methods actually fires. If the
  157. L{Deferred} is cancelled in this interval, the established connection is
  158. closed, the timeout is cancelled, and the L{Deferred} fails with
  159. L{CancelledError}.
  160. """
  161. def connect(reactor, cc):
  162. d = cc.connectTCP('example.com', 1234)
  163. host, port, factory, timeout, bindAddress = reactor.tcpClients.pop()
  164. protocol = factory.buildProtocol(None)
  165. transport = StringTransport()
  166. protocol.makeConnection(transport)
  167. return d
  168. return self._cancelConnectTimeoutTest(connect)
  169. def test_cancelConnectUNIXTimeout(self):
  170. """
  171. L{ClientCreator.connectUNIX} inserts a very short delayed call between
  172. the time the connection is established and the time the L{Deferred}
  173. returned from one of its connect methods actually fires. If the
  174. L{Deferred} is cancelled in this interval, the established connection is
  175. closed, the timeout is cancelled, and the L{Deferred} fails with
  176. L{CancelledError}.
  177. """
  178. def connect(reactor, cc):
  179. d = cc.connectUNIX('/foo/bar')
  180. address, factory, timeout, bindAddress = reactor.unixClients.pop()
  181. protocol = factory.buildProtocol(None)
  182. transport = StringTransport()
  183. protocol.makeConnection(transport)
  184. return d
  185. return self._cancelConnectTimeoutTest(connect)
  186. def test_cancelConnectSSLTimeout(self):
  187. """
  188. L{ClientCreator.connectSSL} inserts a very short delayed call between
  189. the time the connection is established and the time the L{Deferred}
  190. returned from one of its connect methods actually fires. If the
  191. L{Deferred} is cancelled in this interval, the established connection is
  192. closed, the timeout is cancelled, and the L{Deferred} fails with
  193. L{CancelledError}.
  194. """
  195. def connect(reactor, cc):
  196. d = cc.connectSSL('example.com', 1234, object())
  197. host, port, factory, contextFactory, timeout, bindADdress = reactor.sslClients.pop()
  198. protocol = factory.buildProtocol(None)
  199. transport = StringTransport()
  200. protocol.makeConnection(transport)
  201. return d
  202. return self._cancelConnectTimeoutTest(connect)
  203. def _cancelConnectFailedTimeoutTest(self, connect):
  204. """
  205. Like L{_cancelConnectTest}, but for the case where the L{Deferred} is
  206. cancelled after the connection attempt has failed but before it is fired
  207. with the resulting failure.
  208. """
  209. reactor = MemoryReactorClock()
  210. cc = ClientCreator(reactor, Protocol)
  211. d, factory = connect(reactor, cc)
  212. connector = reactor.connectors.pop()
  213. factory.clientConnectionFailed(
  214. connector, Failure(Exception("Simulated failure")))
  215. # Sanity check - there is an outstanding delayed call to fire the
  216. # Deferred.
  217. self.assertEqual(len(reactor.getDelayedCalls()), 1)
  218. # Cancel the Deferred, cancelling the delayed call.
  219. d.cancel()
  220. self.assertEqual(reactor.getDelayedCalls(), [])
  221. return self.assertFailure(d, CancelledError)
  222. def test_cancelConnectTCPFailedTimeout(self):
  223. """
  224. Similar to L{test_cancelConnectTCPTimeout}, but for the case where the
  225. connection attempt fails.
  226. """
  227. def connect(reactor, cc):
  228. d = cc.connectTCP('example.com', 1234)
  229. host, port, factory, timeout, bindAddress = reactor.tcpClients.pop()
  230. return d, factory
  231. return self._cancelConnectFailedTimeoutTest(connect)
  232. def test_cancelConnectUNIXFailedTimeout(self):
  233. """
  234. Similar to L{test_cancelConnectUNIXTimeout}, but for the case where the
  235. connection attempt fails.
  236. """
  237. def connect(reactor, cc):
  238. d = cc.connectUNIX('/foo/bar')
  239. address, factory, timeout, bindAddress = reactor.unixClients.pop()
  240. return d, factory
  241. return self._cancelConnectFailedTimeoutTest(connect)
  242. def test_cancelConnectSSLFailedTimeout(self):
  243. """
  244. Similar to L{test_cancelConnectSSLTimeout}, but for the case where the
  245. connection attempt fails.
  246. """
  247. def connect(reactor, cc):
  248. d = cc.connectSSL('example.com', 1234, object())
  249. host, port, factory, contextFactory, timeout, bindADdress = reactor.sslClients.pop()
  250. return d, factory
  251. return self._cancelConnectFailedTimeoutTest(connect)
  252. class ProtocolTests(TestCase):
  253. """
  254. Tests for L{twisted.internet.protocol.Protocol}.
  255. """
  256. def test_interfaces(self):
  257. """
  258. L{Protocol} instances provide L{IProtocol} and L{ILoggingContext}.
  259. """
  260. proto = Protocol()
  261. self.assertTrue(verifyObject(IProtocol, proto))
  262. self.assertTrue(verifyObject(ILoggingContext, proto))
  263. def test_logPrefix(self):
  264. """
  265. L{Protocol.logPrefix} returns the protocol class's name.
  266. """
  267. class SomeThing(Protocol):
  268. pass
  269. self.assertEqual("SomeThing", SomeThing().logPrefix())
  270. def test_makeConnection(self):
  271. """
  272. L{Protocol.makeConnection} sets the given transport on itself, and
  273. then calls C{connectionMade}.
  274. """
  275. result = []
  276. class SomeProtocol(Protocol):
  277. def connectionMade(self):
  278. result.append(self.transport)
  279. transport = object()
  280. protocol = SomeProtocol()
  281. protocol.makeConnection(transport)
  282. self.assertEqual(result, [transport])
  283. class FactoryTests(TestCase):
  284. """
  285. Tests for L{protocol.Factory}.
  286. """
  287. def test_interfaces(self):
  288. """
  289. L{Factory} instances provide both L{IProtocolFactory} and
  290. L{ILoggingContext}.
  291. """
  292. factory = Factory()
  293. self.assertTrue(verifyObject(IProtocolFactory, factory))
  294. self.assertTrue(verifyObject(ILoggingContext, factory))
  295. def test_logPrefix(self):
  296. """
  297. L{Factory.logPrefix} returns the name of the factory class.
  298. """
  299. class SomeKindOfFactory(Factory):
  300. pass
  301. self.assertEqual("SomeKindOfFactory", SomeKindOfFactory().logPrefix())
  302. def test_defaultBuildProtocol(self):
  303. """
  304. L{Factory.buildProtocol} by default constructs a protocol by calling
  305. its C{protocol} attribute, and attaches the factory to the result.
  306. """
  307. class SomeProtocol(Protocol):
  308. pass
  309. f = Factory()
  310. f.protocol = SomeProtocol
  311. protocol = f.buildProtocol(None)
  312. self.assertIsInstance(protocol, SomeProtocol)
  313. self.assertIs(protocol.factory, f)
  314. def test_forProtocol(self):
  315. """
  316. L{Factory.forProtocol} constructs a Factory, passing along any
  317. additional arguments, and sets its C{protocol} attribute to the given
  318. Protocol subclass.
  319. """
  320. class ArgTakingFactory(Factory):
  321. def __init__(self, *args, **kwargs):
  322. self.args, self.kwargs = args, kwargs
  323. factory = ArgTakingFactory.forProtocol(Protocol, 1, 2, foo=12)
  324. self.assertEqual(factory.protocol, Protocol)
  325. self.assertEqual(factory.args, (1, 2))
  326. self.assertEqual(factory.kwargs, {"foo": 12})
  327. def test_doStartLoggingStatement(self):
  328. """
  329. L{Factory.doStart} logs that it is starting a factory, followed by
  330. the L{repr} of the L{Factory} instance that is being started.
  331. """
  332. events = []
  333. globalLogPublisher.addObserver(events.append)
  334. self.addCleanup(
  335. lambda: globalLogPublisher.removeObserver(events.append))
  336. f = Factory()
  337. f.doStart()
  338. self.assertIs(events[0]['factory'], f)
  339. self.assertEqual(events[0]['log_level'], LogLevel.info)
  340. self.assertEqual(events[0]['log_format'],
  341. 'Starting factory {factory!r}')
  342. def test_doStopLoggingStatement(self):
  343. """
  344. L{Factory.doStop} logs that it is stopping a factory, followed by
  345. the L{repr} of the L{Factory} instance that is being stopped.
  346. """
  347. events = []
  348. globalLogPublisher.addObserver(events.append)
  349. self.addCleanup(
  350. lambda: globalLogPublisher.removeObserver(events.append))
  351. class MyFactory(Factory):
  352. numPorts = 1
  353. f = MyFactory()
  354. f.doStop()
  355. self.assertIs(events[0]['factory'], f)
  356. self.assertEqual(events[0]['log_level'], LogLevel.info)
  357. self.assertEqual(events[0]['log_format'],
  358. 'Stopping factory {factory!r}')
  359. class AdapterTests(TestCase):
  360. """
  361. Tests for L{ProtocolToConsumerAdapter} and L{ConsumerToProtocolAdapter}.
  362. """
  363. def test_protocolToConsumer(self):
  364. """
  365. L{IProtocol} providers can be adapted to L{IConsumer} providers using
  366. L{ProtocolToConsumerAdapter}.
  367. """
  368. result = []
  369. p = Protocol()
  370. p.dataReceived = result.append
  371. consumer = IConsumer(p)
  372. consumer.write(b"hello")
  373. self.assertEqual(result, [b"hello"])
  374. self.assertIsInstance(consumer, ProtocolToConsumerAdapter)
  375. def test_consumerToProtocol(self):
  376. """
  377. L{IConsumer} providers can be adapted to L{IProtocol} providers using
  378. L{ProtocolToConsumerAdapter}.
  379. """
  380. result = []
  381. @implementer(IConsumer)
  382. class Consumer(object):
  383. def write(self, d):
  384. result.append(d)
  385. c = Consumer()
  386. protocol = IProtocol(c)
  387. protocol.dataReceived(b"hello")
  388. self.assertEqual(result, [b"hello"])
  389. self.assertIsInstance(protocol, ConsumerToProtocolAdapter)