test_policies.py 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918
  1. # Copyright (c) Twisted Matrix Laboratories.
  2. # See LICENSE for details.
  3. """
  4. Test code for policies.
  5. """
  6. from __future__ import division, absolute_import
  7. from zope.interface import Interface, implementer, implementedBy
  8. from twisted.python.compat import NativeStringIO
  9. from twisted.trial import unittest
  10. from twisted.test.proto_helpers import StringTransport
  11. from twisted.test.proto_helpers import StringTransportWithDisconnection
  12. from twisted.internet import protocol, reactor, address, defer, task
  13. from twisted.protocols import policies
  14. try:
  15. import builtins
  16. except ImportError:
  17. import __builtin__ as builtins
  18. class SimpleProtocol(protocol.Protocol):
  19. connected = disconnected = 0
  20. buffer = b""
  21. def __init__(self):
  22. self.dConnected = defer.Deferred()
  23. self.dDisconnected = defer.Deferred()
  24. def connectionMade(self):
  25. self.connected = 1
  26. self.dConnected.callback('')
  27. def connectionLost(self, reason):
  28. self.disconnected = 1
  29. self.dDisconnected.callback('')
  30. def dataReceived(self, data):
  31. self.buffer += data
  32. class SillyFactory(protocol.ClientFactory):
  33. def __init__(self, p):
  34. self.p = p
  35. def buildProtocol(self, addr):
  36. return self.p
  37. class EchoProtocol(protocol.Protocol):
  38. paused = False
  39. def pauseProducing(self):
  40. self.paused = True
  41. def resumeProducing(self):
  42. self.paused = False
  43. def stopProducing(self):
  44. pass
  45. def dataReceived(self, data):
  46. self.transport.write(data)
  47. class Server(protocol.ServerFactory):
  48. """
  49. A simple server factory using L{EchoProtocol}.
  50. """
  51. protocol = EchoProtocol
  52. class TestableThrottlingFactory(policies.ThrottlingFactory):
  53. """
  54. L{policies.ThrottlingFactory} using a L{task.Clock} for tests.
  55. """
  56. def __init__(self, clock, *args, **kwargs):
  57. """
  58. @param clock: object providing a callLater method that can be used
  59. for tests.
  60. @type clock: C{task.Clock} or alike.
  61. """
  62. policies.ThrottlingFactory.__init__(self, *args, **kwargs)
  63. self.clock = clock
  64. def callLater(self, period, func):
  65. """
  66. Forward to the testable clock.
  67. """
  68. return self.clock.callLater(period, func)
  69. class TestableTimeoutFactory(policies.TimeoutFactory):
  70. """
  71. L{policies.TimeoutFactory} using a L{task.Clock} for tests.
  72. """
  73. def __init__(self, clock, *args, **kwargs):
  74. """
  75. @param clock: object providing a callLater method that can be used
  76. for tests.
  77. @type clock: C{task.Clock} or alike.
  78. """
  79. policies.TimeoutFactory.__init__(self, *args, **kwargs)
  80. self.clock = clock
  81. def callLater(self, period, func):
  82. """
  83. Forward to the testable clock.
  84. """
  85. return self.clock.callLater(period, func)
  86. class WrapperTests(unittest.TestCase):
  87. """
  88. Tests for L{WrappingFactory} and L{ProtocolWrapper}.
  89. """
  90. def test_protocolFactoryAttribute(self):
  91. """
  92. Make sure protocol.factory is the wrapped factory, not the wrapping
  93. factory.
  94. """
  95. f = Server()
  96. wf = policies.WrappingFactory(f)
  97. p = wf.buildProtocol(address.IPv4Address('TCP', '127.0.0.1', 35))
  98. self.assertIs(p.wrappedProtocol.factory, f)
  99. def test_transportInterfaces(self):
  100. """
  101. The transport wrapper passed to the wrapped protocol's
  102. C{makeConnection} provides the same interfaces as are provided by the
  103. original transport.
  104. """
  105. class IStubTransport(Interface):
  106. pass
  107. @implementer(IStubTransport)
  108. class StubTransport:
  109. pass
  110. # Looking up what ProtocolWrapper implements also mutates the class.
  111. # It adds __implemented__ and __providedBy__ attributes to it. These
  112. # prevent __getattr__ from causing the IStubTransport.providedBy call
  113. # below from returning True. If, by accident, nothing else causes
  114. # these attributes to be added to ProtocolWrapper, the test will pass,
  115. # but the interface will only be provided until something does trigger
  116. # their addition. So we just trigger it right now to be sure.
  117. implementedBy(policies.ProtocolWrapper)
  118. proto = protocol.Protocol()
  119. wrapper = policies.ProtocolWrapper(policies.WrappingFactory(None), proto)
  120. wrapper.makeConnection(StubTransport())
  121. self.assertTrue(IStubTransport.providedBy(proto.transport))
  122. def test_factoryLogPrefix(self):
  123. """
  124. L{WrappingFactory.logPrefix} is customized to mention both the original
  125. factory and the wrapping factory.
  126. """
  127. server = Server()
  128. factory = policies.WrappingFactory(server)
  129. self.assertEqual("Server (WrappingFactory)", factory.logPrefix())
  130. def test_factoryLogPrefixFallback(self):
  131. """
  132. If the wrapped factory doesn't have a L{logPrefix} method,
  133. L{WrappingFactory.logPrefix} falls back to the factory class name.
  134. """
  135. class NoFactory(object):
  136. pass
  137. server = NoFactory()
  138. factory = policies.WrappingFactory(server)
  139. self.assertEqual("NoFactory (WrappingFactory)", factory.logPrefix())
  140. def test_protocolLogPrefix(self):
  141. """
  142. L{ProtocolWrapper.logPrefix} is customized to mention both the original
  143. protocol and the wrapper.
  144. """
  145. server = Server()
  146. factory = policies.WrappingFactory(server)
  147. protocol = factory.buildProtocol(
  148. address.IPv4Address('TCP', '127.0.0.1', 35))
  149. self.assertEqual("EchoProtocol (ProtocolWrapper)",
  150. protocol.logPrefix())
  151. def test_protocolLogPrefixFallback(self):
  152. """
  153. If the wrapped protocol doesn't have a L{logPrefix} method,
  154. L{ProtocolWrapper.logPrefix} falls back to the protocol class name.
  155. """
  156. class NoProtocol(object):
  157. pass
  158. server = Server()
  159. server.protocol = NoProtocol
  160. factory = policies.WrappingFactory(server)
  161. protocol = factory.buildProtocol(
  162. address.IPv4Address('TCP', '127.0.0.1', 35))
  163. self.assertEqual("NoProtocol (ProtocolWrapper)",
  164. protocol.logPrefix())
  165. def _getWrapper(self):
  166. """
  167. Return L{policies.ProtocolWrapper} that has been connected to a
  168. L{StringTransport}.
  169. """
  170. wrapper = policies.ProtocolWrapper(policies.WrappingFactory(Server()),
  171. protocol.Protocol())
  172. transport = StringTransport()
  173. wrapper.makeConnection(transport)
  174. return wrapper
  175. def test_getHost(self):
  176. """
  177. L{policies.ProtocolWrapper.getHost} calls C{getHost} on the underlying
  178. transport.
  179. """
  180. wrapper = self._getWrapper()
  181. self.assertEqual(wrapper.getHost(), wrapper.transport.getHost())
  182. def test_getPeer(self):
  183. """
  184. L{policies.ProtocolWrapper.getPeer} calls C{getPeer} on the underlying
  185. transport.
  186. """
  187. wrapper = self._getWrapper()
  188. self.assertEqual(wrapper.getPeer(), wrapper.transport.getPeer())
  189. def test_registerProducer(self):
  190. """
  191. L{policies.ProtocolWrapper.registerProducer} calls C{registerProducer}
  192. on the underlying transport.
  193. """
  194. wrapper = self._getWrapper()
  195. producer = object()
  196. wrapper.registerProducer(producer, True)
  197. self.assertIs(wrapper.transport.producer, producer)
  198. self.assertTrue(wrapper.transport.streaming)
  199. def test_unregisterProducer(self):
  200. """
  201. L{policies.ProtocolWrapper.unregisterProducer} calls
  202. C{unregisterProducer} on the underlying transport.
  203. """
  204. wrapper = self._getWrapper()
  205. producer = object()
  206. wrapper.registerProducer(producer, True)
  207. wrapper.unregisterProducer()
  208. self.assertIsNone(wrapper.transport.producer)
  209. self.assertIsNone(wrapper.transport.streaming)
  210. def test_stopConsuming(self):
  211. """
  212. L{policies.ProtocolWrapper.stopConsuming} calls C{stopConsuming} on
  213. the underlying transport.
  214. """
  215. wrapper = self._getWrapper()
  216. result = []
  217. wrapper.transport.stopConsuming = lambda: result.append(True)
  218. wrapper.stopConsuming()
  219. self.assertEqual(result, [True])
  220. def test_startedConnecting(self):
  221. """
  222. L{policies.WrappingFactory.startedConnecting} calls
  223. C{startedConnecting} on the underlying factory.
  224. """
  225. result = []
  226. class Factory(object):
  227. def startedConnecting(self, connector):
  228. result.append(connector)
  229. wrapper = policies.WrappingFactory(Factory())
  230. connector = object()
  231. wrapper.startedConnecting(connector)
  232. self.assertEqual(result, [connector])
  233. def test_clientConnectionLost(self):
  234. """
  235. L{policies.WrappingFactory.clientConnectionLost} calls
  236. C{clientConnectionLost} on the underlying factory.
  237. """
  238. result = []
  239. class Factory(object):
  240. def clientConnectionLost(self, connector, reason):
  241. result.append((connector, reason))
  242. wrapper = policies.WrappingFactory(Factory())
  243. connector = object()
  244. reason = object()
  245. wrapper.clientConnectionLost(connector, reason)
  246. self.assertEqual(result, [(connector, reason)])
  247. def test_clientConnectionFailed(self):
  248. """
  249. L{policies.WrappingFactory.clientConnectionFailed} calls
  250. C{clientConnectionFailed} on the underlying factory.
  251. """
  252. result = []
  253. class Factory(object):
  254. def clientConnectionFailed(self, connector, reason):
  255. result.append((connector, reason))
  256. wrapper = policies.WrappingFactory(Factory())
  257. connector = object()
  258. reason = object()
  259. wrapper.clientConnectionFailed(connector, reason)
  260. self.assertEqual(result, [(connector, reason)])
  261. class WrappingFactory(policies.WrappingFactory):
  262. protocol = lambda s, f, p: p
  263. def startFactory(self):
  264. policies.WrappingFactory.startFactory(self)
  265. self.deferred.callback(None)
  266. class ThrottlingTests(unittest.TestCase):
  267. """
  268. Tests for L{policies.ThrottlingFactory}.
  269. """
  270. def test_limit(self):
  271. """
  272. Full test using a custom server limiting number of connections.
  273. """
  274. server = Server()
  275. c1, c2, c3, c4 = [SimpleProtocol() for i in range(4)]
  276. tServer = policies.ThrottlingFactory(server, 2)
  277. wrapTServer = WrappingFactory(tServer)
  278. wrapTServer.deferred = defer.Deferred()
  279. # Start listening
  280. p = reactor.listenTCP(0, wrapTServer, interface="127.0.0.1")
  281. n = p.getHost().port
  282. def _connect123(results):
  283. reactor.connectTCP("127.0.0.1", n, SillyFactory(c1))
  284. c1.dConnected.addCallback(
  285. lambda r: reactor.connectTCP("127.0.0.1", n, SillyFactory(c2)))
  286. c2.dConnected.addCallback(
  287. lambda r: reactor.connectTCP("127.0.0.1", n, SillyFactory(c3)))
  288. return c3.dDisconnected
  289. def _check123(results):
  290. self.assertEqual([c.connected for c in (c1, c2, c3)], [1, 1, 1])
  291. self.assertEqual([c.disconnected for c in (c1, c2, c3)], [0, 0, 1])
  292. self.assertEqual(len(tServer.protocols.keys()), 2)
  293. return results
  294. def _lose1(results):
  295. # disconnect one protocol and now another should be able to connect
  296. c1.transport.loseConnection()
  297. return c1.dDisconnected
  298. def _connect4(results):
  299. reactor.connectTCP("127.0.0.1", n, SillyFactory(c4))
  300. return c4.dConnected
  301. def _check4(results):
  302. self.assertEqual(c4.connected, 1)
  303. self.assertEqual(c4.disconnected, 0)
  304. return results
  305. def _cleanup(results):
  306. for c in c2, c4:
  307. c.transport.loseConnection()
  308. return defer.DeferredList([
  309. defer.maybeDeferred(p.stopListening),
  310. c2.dDisconnected,
  311. c4.dDisconnected])
  312. wrapTServer.deferred.addCallback(_connect123)
  313. wrapTServer.deferred.addCallback(_check123)
  314. wrapTServer.deferred.addCallback(_lose1)
  315. wrapTServer.deferred.addCallback(_connect4)
  316. wrapTServer.deferred.addCallback(_check4)
  317. wrapTServer.deferred.addCallback(_cleanup)
  318. return wrapTServer.deferred
  319. def test_writeSequence(self):
  320. """
  321. L{ThrottlingProtocol.writeSequence} is called on the underlying factory.
  322. """
  323. server = Server()
  324. tServer = TestableThrottlingFactory(task.Clock(), server)
  325. protocol = tServer.buildProtocol(
  326. address.IPv4Address('TCP', '127.0.0.1', 0))
  327. transport = StringTransportWithDisconnection()
  328. transport.protocol = protocol
  329. protocol.makeConnection(transport)
  330. protocol.writeSequence([b'bytes'] * 4)
  331. self.assertEqual(transport.value(), b"bytesbytesbytesbytes")
  332. self.assertEqual(tServer.writtenThisSecond, 20)
  333. def test_writeLimit(self):
  334. """
  335. Check the writeLimit parameter: write data, and check for the pause
  336. status.
  337. """
  338. server = Server()
  339. tServer = TestableThrottlingFactory(task.Clock(), server, writeLimit=10)
  340. port = tServer.buildProtocol(address.IPv4Address('TCP', '127.0.0.1', 0))
  341. tr = StringTransportWithDisconnection()
  342. tr.protocol = port
  343. port.makeConnection(tr)
  344. port.producer = port.wrappedProtocol
  345. port.dataReceived(b"0123456789")
  346. port.dataReceived(b"abcdefghij")
  347. self.assertEqual(tr.value(), b"0123456789abcdefghij")
  348. self.assertEqual(tServer.writtenThisSecond, 20)
  349. self.assertFalse(port.wrappedProtocol.paused)
  350. # at this point server should've written 20 bytes, 10 bytes
  351. # above the limit so writing should be paused around 1 second
  352. # from 'now', and resumed a second after that
  353. tServer.clock.advance(1.05)
  354. self.assertEqual(tServer.writtenThisSecond, 0)
  355. self.assertTrue(port.wrappedProtocol.paused)
  356. tServer.clock.advance(1.05)
  357. self.assertEqual(tServer.writtenThisSecond, 0)
  358. self.assertFalse(port.wrappedProtocol.paused)
  359. def test_readLimit(self):
  360. """
  361. Check the readLimit parameter: read data and check for the pause
  362. status.
  363. """
  364. server = Server()
  365. tServer = TestableThrottlingFactory(task.Clock(), server, readLimit=10)
  366. port = tServer.buildProtocol(address.IPv4Address('TCP', '127.0.0.1', 0))
  367. tr = StringTransportWithDisconnection()
  368. tr.protocol = port
  369. port.makeConnection(tr)
  370. port.dataReceived(b"0123456789")
  371. port.dataReceived(b"abcdefghij")
  372. self.assertEqual(tr.value(), b"0123456789abcdefghij")
  373. self.assertEqual(tServer.readThisSecond, 20)
  374. tServer.clock.advance(1.05)
  375. self.assertEqual(tServer.readThisSecond, 0)
  376. self.assertEqual(tr.producerState, 'paused')
  377. tServer.clock.advance(1.05)
  378. self.assertEqual(tServer.readThisSecond, 0)
  379. self.assertEqual(tr.producerState, 'producing')
  380. tr.clear()
  381. port.dataReceived(b"0123456789")
  382. port.dataReceived(b"abcdefghij")
  383. self.assertEqual(tr.value(), b"0123456789abcdefghij")
  384. self.assertEqual(tServer.readThisSecond, 20)
  385. tServer.clock.advance(1.05)
  386. self.assertEqual(tServer.readThisSecond, 0)
  387. self.assertEqual(tr.producerState, 'paused')
  388. tServer.clock.advance(1.05)
  389. self.assertEqual(tServer.readThisSecond, 0)
  390. self.assertEqual(tr.producerState, 'producing')
  391. class TimeoutFactoryTests(unittest.TestCase):
  392. """
  393. Tests for L{policies.TimeoutFactory}.
  394. """
  395. def setUp(self):
  396. """
  397. Create a testable, deterministic clock, and a set of
  398. server factory/protocol/transport.
  399. """
  400. self.clock = task.Clock()
  401. wrappedFactory = protocol.ServerFactory()
  402. wrappedFactory.protocol = SimpleProtocol
  403. self.factory = TestableTimeoutFactory(self.clock, wrappedFactory, 3)
  404. self.proto = self.factory.buildProtocol(
  405. address.IPv4Address('TCP', '127.0.0.1', 12345))
  406. self.transport = StringTransportWithDisconnection()
  407. self.transport.protocol = self.proto
  408. self.proto.makeConnection(self.transport)
  409. def test_timeout(self):
  410. """
  411. Make sure that when a TimeoutFactory accepts a connection, it will
  412. time out that connection if no data is read or written within the
  413. timeout period.
  414. """
  415. # Let almost 3 time units pass
  416. self.clock.pump([0.0, 0.5, 1.0, 1.0, 0.4])
  417. self.assertFalse(self.proto.wrappedProtocol.disconnected)
  418. # Now let the timer elapse
  419. self.clock.pump([0.0, 0.2])
  420. self.assertTrue(self.proto.wrappedProtocol.disconnected)
  421. def test_sendAvoidsTimeout(self):
  422. """
  423. Make sure that writing data to a transport from a protocol
  424. constructed by a TimeoutFactory resets the timeout countdown.
  425. """
  426. # Let half the countdown period elapse
  427. self.clock.pump([0.0, 0.5, 1.0])
  428. self.assertFalse(self.proto.wrappedProtocol.disconnected)
  429. # Send some data (self.proto is the /real/ proto's transport, so this
  430. # is the write that gets called)
  431. self.proto.write(b'bytes bytes bytes')
  432. # More time passes, putting us past the original timeout
  433. self.clock.pump([0.0, 1.0, 1.0])
  434. self.assertFalse(self.proto.wrappedProtocol.disconnected)
  435. # Make sure writeSequence delays timeout as well
  436. self.proto.writeSequence([b'bytes'] * 3)
  437. # Tick tock
  438. self.clock.pump([0.0, 1.0, 1.0])
  439. self.assertFalse(self.proto.wrappedProtocol.disconnected)
  440. # Don't write anything more, just let the timeout expire
  441. self.clock.pump([0.0, 2.0])
  442. self.assertTrue(self.proto.wrappedProtocol.disconnected)
  443. def test_receiveAvoidsTimeout(self):
  444. """
  445. Make sure that receiving data also resets the timeout countdown.
  446. """
  447. # Let half the countdown period elapse
  448. self.clock.pump([0.0, 1.0, 0.5])
  449. self.assertFalse(self.proto.wrappedProtocol.disconnected)
  450. # Some bytes arrive, they should reset the counter
  451. self.proto.dataReceived(b'bytes bytes bytes')
  452. # We pass the original timeout
  453. self.clock.pump([0.0, 1.0, 1.0])
  454. self.assertFalse(self.proto.wrappedProtocol.disconnected)
  455. # Nothing more arrives though, the new timeout deadline is passed,
  456. # the connection should be dropped.
  457. self.clock.pump([0.0, 1.0, 1.0])
  458. self.assertTrue(self.proto.wrappedProtocol.disconnected)
  459. class TimeoutTester(protocol.Protocol, policies.TimeoutMixin):
  460. """
  461. A testable protocol with timeout facility.
  462. @ivar timedOut: set to C{True} if a timeout has been detected.
  463. @type timedOut: C{bool}
  464. """
  465. timeOut = 3
  466. timedOut = False
  467. def __init__(self, clock):
  468. """
  469. Initialize the protocol with a C{task.Clock} object.
  470. """
  471. self.clock = clock
  472. def connectionMade(self):
  473. """
  474. Upon connection, set the timeout.
  475. """
  476. self.setTimeout(self.timeOut)
  477. def dataReceived(self, data):
  478. """
  479. Reset the timeout on data.
  480. """
  481. self.resetTimeout()
  482. protocol.Protocol.dataReceived(self, data)
  483. def connectionLost(self, reason=None):
  484. """
  485. On connection lost, cancel all timeout operations.
  486. """
  487. self.setTimeout(None)
  488. def timeoutConnection(self):
  489. """
  490. Flags the timedOut variable to indicate the timeout of the connection.
  491. """
  492. self.timedOut = True
  493. def callLater(self, timeout, func, *args, **kwargs):
  494. """
  495. Override callLater to use the deterministic clock.
  496. """
  497. return self.clock.callLater(timeout, func, *args, **kwargs)
  498. class TimeoutMixinTests(unittest.TestCase):
  499. """
  500. Tests for L{policies.TimeoutMixin}.
  501. """
  502. def setUp(self):
  503. """
  504. Create a testable, deterministic clock and a C{TimeoutTester} instance.
  505. """
  506. self.clock = task.Clock()
  507. self.proto = TimeoutTester(self.clock)
  508. def test_overriddenCallLater(self):
  509. """
  510. Test that the callLater of the clock is used instead of
  511. L{reactor.callLater<twisted.internet.interfaces.IReactorTime.callLater>}
  512. """
  513. self.proto.setTimeout(10)
  514. self.assertEqual(len(self.clock.calls), 1)
  515. def test_timeout(self):
  516. """
  517. Check that the protocol does timeout at the time specified by its
  518. C{timeOut} attribute.
  519. """
  520. self.proto.makeConnection(StringTransport())
  521. # timeOut value is 3
  522. self.clock.pump([0, 0.5, 1.0, 1.0])
  523. self.assertFalse(self.proto.timedOut)
  524. self.clock.pump([0, 1.0])
  525. self.assertTrue(self.proto.timedOut)
  526. def test_noTimeout(self):
  527. """
  528. Check that receiving data is delaying the timeout of the connection.
  529. """
  530. self.proto.makeConnection(StringTransport())
  531. self.clock.pump([0, 0.5, 1.0, 1.0])
  532. self.assertFalse(self.proto.timedOut)
  533. self.proto.dataReceived(b'hello there')
  534. self.clock.pump([0, 1.0, 1.0, 0.5])
  535. self.assertFalse(self.proto.timedOut)
  536. self.clock.pump([0, 1.0])
  537. self.assertTrue(self.proto.timedOut)
  538. def test_resetTimeout(self):
  539. """
  540. Check that setting a new value for timeout cancel the previous value
  541. and install a new timeout.
  542. """
  543. self.proto.timeOut = None
  544. self.proto.makeConnection(StringTransport())
  545. self.proto.setTimeout(1)
  546. self.assertEqual(self.proto.timeOut, 1)
  547. self.clock.pump([0, 0.9])
  548. self.assertFalse(self.proto.timedOut)
  549. self.clock.pump([0, 0.2])
  550. self.assertTrue(self.proto.timedOut)
  551. def test_cancelTimeout(self):
  552. """
  553. Setting the timeout to L{None} cancel any timeout operations.
  554. """
  555. self.proto.timeOut = 5
  556. self.proto.makeConnection(StringTransport())
  557. self.proto.setTimeout(None)
  558. self.assertIsNone(self.proto.timeOut)
  559. self.clock.pump([0, 5, 5, 5])
  560. self.assertFalse(self.proto.timedOut)
  561. def test_return(self):
  562. """
  563. setTimeout should return the value of the previous timeout.
  564. """
  565. self.proto.timeOut = 5
  566. self.assertEqual(self.proto.setTimeout(10), 5)
  567. self.assertEqual(self.proto.setTimeout(None), 10)
  568. self.assertIsNone(self.proto.setTimeout(1))
  569. self.assertEqual(self.proto.timeOut, 1)
  570. # Clean up the DelayedCall
  571. self.proto.setTimeout(None)
  572. class LimitTotalConnectionsFactoryTests(unittest.TestCase):
  573. """Tests for policies.LimitTotalConnectionsFactory"""
  574. def testConnectionCounting(self):
  575. # Make a basic factory
  576. factory = policies.LimitTotalConnectionsFactory()
  577. factory.protocol = protocol.Protocol
  578. # connectionCount starts at zero
  579. self.assertEqual(0, factory.connectionCount)
  580. # connectionCount increments as connections are made
  581. p1 = factory.buildProtocol(None)
  582. self.assertEqual(1, factory.connectionCount)
  583. p2 = factory.buildProtocol(None)
  584. self.assertEqual(2, factory.connectionCount)
  585. # and decrements as they are lost
  586. p1.connectionLost(None)
  587. self.assertEqual(1, factory.connectionCount)
  588. p2.connectionLost(None)
  589. self.assertEqual(0, factory.connectionCount)
  590. def testConnectionLimiting(self):
  591. # Make a basic factory with a connection limit of 1
  592. factory = policies.LimitTotalConnectionsFactory()
  593. factory.protocol = protocol.Protocol
  594. factory.connectionLimit = 1
  595. # Make a connection
  596. p = factory.buildProtocol(None)
  597. self.assertIsNotNone(p)
  598. self.assertEqual(1, factory.connectionCount)
  599. # Try to make a second connection, which will exceed the connection
  600. # limit. This should return None, because overflowProtocol is None.
  601. self.assertIsNone(factory.buildProtocol(None))
  602. self.assertEqual(1, factory.connectionCount)
  603. # Define an overflow protocol
  604. class OverflowProtocol(protocol.Protocol):
  605. def connectionMade(self):
  606. factory.overflowed = True
  607. factory.overflowProtocol = OverflowProtocol
  608. factory.overflowed = False
  609. # Try to make a second connection again, now that we have an overflow
  610. # protocol. Note that overflow connections count towards the connection
  611. # count.
  612. op = factory.buildProtocol(None)
  613. op.makeConnection(None) # to trigger connectionMade
  614. self.assertTrue(factory.overflowed)
  615. self.assertEqual(2, factory.connectionCount)
  616. # Close the connections.
  617. p.connectionLost(None)
  618. self.assertEqual(1, factory.connectionCount)
  619. op.connectionLost(None)
  620. self.assertEqual(0, factory.connectionCount)
  621. class WriteSequenceEchoProtocol(EchoProtocol):
  622. def dataReceived(self, bytes):
  623. if bytes.find(b'vector!') != -1:
  624. self.transport.writeSequence([bytes])
  625. else:
  626. EchoProtocol.dataReceived(self, bytes)
  627. class TestLoggingFactory(policies.TrafficLoggingFactory):
  628. openFile = None
  629. def open(self, name):
  630. assert self.openFile is None, "open() called too many times"
  631. self.openFile = NativeStringIO()
  632. return self.openFile
  633. class LoggingFactoryTests(unittest.TestCase):
  634. """
  635. Tests for L{policies.TrafficLoggingFactory}.
  636. """
  637. def test_thingsGetLogged(self):
  638. """
  639. Check the output produced by L{policies.TrafficLoggingFactory}.
  640. """
  641. wrappedFactory = Server()
  642. wrappedFactory.protocol = WriteSequenceEchoProtocol
  643. t = StringTransportWithDisconnection()
  644. f = TestLoggingFactory(wrappedFactory, 'test')
  645. p = f.buildProtocol(('1.2.3.4', 5678))
  646. t.protocol = p
  647. p.makeConnection(t)
  648. v = f.openFile.getvalue()
  649. self.assertIn('*', v)
  650. self.assertFalse(t.value())
  651. p.dataReceived(b'here are some bytes')
  652. v = f.openFile.getvalue()
  653. self.assertIn("C 1: %r" % (b'here are some bytes',), v)
  654. self.assertIn("S 1: %r" % (b'here are some bytes',), v)
  655. self.assertEqual(t.value(), b'here are some bytes')
  656. t.clear()
  657. p.dataReceived(b'prepare for vector! to the extreme')
  658. v = f.openFile.getvalue()
  659. self.assertIn("SV 1: %r" % ([b'prepare for vector! to the extreme'],), v)
  660. self.assertEqual(t.value(), b'prepare for vector! to the extreme')
  661. p.loseConnection()
  662. v = f.openFile.getvalue()
  663. self.assertIn('ConnectionDone', v)
  664. def test_counter(self):
  665. """
  666. Test counter management with the resetCounter method.
  667. """
  668. wrappedFactory = Server()
  669. f = TestLoggingFactory(wrappedFactory, 'test')
  670. self.assertEqual(f._counter, 0)
  671. f.buildProtocol(('1.2.3.4', 5678))
  672. self.assertEqual(f._counter, 1)
  673. # Reset log file
  674. f.openFile = None
  675. f.buildProtocol(('1.2.3.4', 5679))
  676. self.assertEqual(f._counter, 2)
  677. f.resetCounter()
  678. self.assertEqual(f._counter, 0)
  679. def test_loggingFactoryOpensLogfileAutomatically(self):
  680. """
  681. When the L{policies.TrafficLoggingFactory} builds a protocol, it
  682. automatically opens a unique log file for that protocol and attaches
  683. the logfile to the built protocol.
  684. """
  685. open_calls = []
  686. open_rvalues = []
  687. def mocked_open(*args, **kwargs):
  688. """
  689. Mock for the open call to prevent actually opening a log file.
  690. """
  691. open_calls.append((args, kwargs))
  692. io = NativeStringIO()
  693. io.name = args[0]
  694. open_rvalues.append(io)
  695. return io
  696. self.patch(builtins, 'open', mocked_open)
  697. wrappedFactory = protocol.ServerFactory()
  698. wrappedFactory.protocol = SimpleProtocol
  699. factory = policies.TrafficLoggingFactory(wrappedFactory, 'test')
  700. first_proto = factory.buildProtocol(address.IPv4Address('TCP',
  701. '127.0.0.1',
  702. 12345))
  703. second_proto = factory.buildProtocol(address.IPv4Address('TCP',
  704. '127.0.0.1',
  705. 12346))
  706. # We expect open to be called twice, with the files passed to the
  707. # protocols.
  708. first_call = (('test-1', 'w'), {})
  709. second_call = (('test-2', 'w'), {})
  710. self.assertEqual([first_call, second_call], open_calls)
  711. self.assertEqual(
  712. [first_proto.logfile, second_proto.logfile], open_rvalues
  713. )