test_tls.py 71 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887
  1. # Copyright (c) Twisted Matrix Laboratories.
  2. # See LICENSE for details.
  3. """
  4. Tests for L{twisted.protocols.tls}.
  5. """
  6. from __future__ import division, absolute_import
  7. from zope.interface.verify import verifyObject
  8. from zope.interface import Interface, directlyProvides, implementer
  9. from twisted.python.compat import iterbytes
  10. try:
  11. from twisted.protocols.tls import TLSMemoryBIOProtocol, TLSMemoryBIOFactory
  12. from twisted.protocols.tls import _PullToPush, _ProducerMembrane
  13. from OpenSSL.crypto import X509Type
  14. from OpenSSL.SSL import (TLSv1_METHOD, TLSv1_1_METHOD, TLSv1_2_METHOD,
  15. Error, Context, ConnectionType,
  16. WantReadError)
  17. except ImportError:
  18. # Skip the whole test module if it can't be imported.
  19. skip = "pyOpenSSL 0.10 or newer required for twisted.protocol.tls"
  20. TLSv1_METHOD = TLSv1_1_METHOD = TLSv1_2_METHOD = None
  21. else:
  22. from twisted.internet.ssl import PrivateCertificate, optionsForClientTLS
  23. from twisted.test.ssl_helpers import (ClientTLSContext, ServerTLSContext,
  24. certPath)
  25. from twisted.test.test_sslverify import certificatesForAuthorityAndServer
  26. from twisted.test.iosim import connectedServerAndClient
  27. from twisted.python.filepath import FilePath
  28. from twisted.python.failure import Failure
  29. from twisted.python import log
  30. from twisted.internet.interfaces import (
  31. ISystemHandle, ISSLTransport,
  32. IPushProducer, IProtocolNegotiationFactory, IHandshakeListener,
  33. IOpenSSLServerConnectionCreator, IOpenSSLClientConnectionCreator
  34. )
  35. from twisted.internet.error import ConnectionDone, ConnectionLost
  36. from twisted.internet.defer import Deferred, gatherResults
  37. from twisted.internet.protocol import Protocol, ClientFactory, ServerFactory
  38. from twisted.internet.task import TaskStopped
  39. from twisted.protocols.loopback import loopbackAsync, collapsingPumpPolicy
  40. from twisted.trial.unittest import TestCase, SynchronousTestCase
  41. from twisted.test.test_tcp import ConnectionLostNotifyingProtocol
  42. from twisted.test.proto_helpers import StringTransport, NonStreamingProducer
  43. class HandshakeCallbackContextFactory:
  44. """
  45. L{HandshakeCallbackContextFactory} is a factory for SSL contexts which
  46. allows applications to get notification when the SSL handshake completes.
  47. @ivar _finished: A L{Deferred} which will be called back when the handshake
  48. is done.
  49. """
  50. # pyOpenSSL needs to expose this.
  51. # https://bugs.launchpad.net/pyopenssl/+bug/372832
  52. SSL_CB_HANDSHAKE_DONE = 0x20
  53. def __init__(self, method=TLSv1_METHOD):
  54. self._finished = Deferred()
  55. self._method = method
  56. def factoryAndDeferred(cls):
  57. """
  58. Create a new L{HandshakeCallbackContextFactory} and return a two-tuple
  59. of it and a L{Deferred} which will fire when a connection created with
  60. it completes a TLS handshake.
  61. """
  62. contextFactory = cls()
  63. return contextFactory, contextFactory._finished
  64. factoryAndDeferred = classmethod(factoryAndDeferred)
  65. def _info(self, connection, where, ret):
  66. """
  67. This is the "info callback" on the context. It will be called
  68. periodically by pyOpenSSL with information about the state of a
  69. connection. When it indicates the handshake is complete, it will fire
  70. C{self._finished}.
  71. """
  72. if where & self.SSL_CB_HANDSHAKE_DONE:
  73. self._finished.callback(None)
  74. def getContext(self):
  75. """
  76. Create and return an SSL context configured to use L{self._info} as the
  77. info callback.
  78. """
  79. context = Context(self._method)
  80. context.set_info_callback(self._info)
  81. return context
  82. class AccumulatingProtocol(Protocol):
  83. """
  84. A protocol which collects the bytes it receives and closes its connection
  85. after receiving a certain minimum of data.
  86. @ivar howMany: The number of bytes of data to wait for before closing the
  87. connection.
  88. @ivar received: A L{list} of L{bytes} of the bytes received so far.
  89. """
  90. def __init__(self, howMany):
  91. self.howMany = howMany
  92. def connectionMade(self):
  93. self.received = []
  94. def dataReceived(self, data):
  95. self.received.append(data)
  96. if sum(map(len, self.received)) >= self.howMany:
  97. self.transport.loseConnection()
  98. def connectionLost(self, reason):
  99. if not reason.check(ConnectionDone):
  100. log.err(reason)
  101. def buildTLSProtocol(server=False, transport=None, fakeConnection=None):
  102. """
  103. Create a protocol hooked up to a TLS transport hooked up to a
  104. StringTransport.
  105. """
  106. # We want to accumulate bytes without disconnecting, so set high limit:
  107. clientProtocol = AccumulatingProtocol(999999999999)
  108. clientFactory = ClientFactory()
  109. clientFactory.protocol = lambda: clientProtocol
  110. if fakeConnection:
  111. @implementer(IOpenSSLServerConnectionCreator,
  112. IOpenSSLClientConnectionCreator)
  113. class HardCodedConnection(object):
  114. def clientConnectionForTLS(self, tlsProtocol):
  115. return fakeConnection
  116. serverConnectionForTLS = clientConnectionForTLS
  117. contextFactory = HardCodedConnection()
  118. else:
  119. if server:
  120. contextFactory = ServerTLSContext()
  121. else:
  122. contextFactory = ClientTLSContext()
  123. wrapperFactory = TLSMemoryBIOFactory(
  124. contextFactory, not server, clientFactory)
  125. sslProtocol = wrapperFactory.buildProtocol(None)
  126. if transport is None:
  127. transport = StringTransport()
  128. sslProtocol.makeConnection(transport)
  129. return clientProtocol, sslProtocol
  130. class TLSMemoryBIOFactoryTests(TestCase):
  131. """
  132. Ensure TLSMemoryBIOFactory logging acts correctly.
  133. """
  134. def test_quiet(self):
  135. """
  136. L{TLSMemoryBIOFactory.doStart} and L{TLSMemoryBIOFactory.doStop} do
  137. not log any messages.
  138. """
  139. contextFactory = ServerTLSContext()
  140. logs = []
  141. logger = logs.append
  142. log.addObserver(logger)
  143. self.addCleanup(log.removeObserver, logger)
  144. wrappedFactory = ServerFactory()
  145. # Disable logging on the wrapped factory:
  146. wrappedFactory.doStart = lambda: None
  147. wrappedFactory.doStop = lambda: None
  148. factory = TLSMemoryBIOFactory(contextFactory, False, wrappedFactory)
  149. factory.doStart()
  150. factory.doStop()
  151. self.assertEqual(logs, [])
  152. def test_logPrefix(self):
  153. """
  154. L{TLSMemoryBIOFactory.logPrefix} amends the wrapped factory's log prefix
  155. with a short string (C{"TLS"}) indicating the wrapping, rather than its
  156. full class name.
  157. """
  158. contextFactory = ServerTLSContext()
  159. factory = TLSMemoryBIOFactory(contextFactory, False, ServerFactory())
  160. self.assertEqual("ServerFactory (TLS)", factory.logPrefix())
  161. def test_logPrefixFallback(self):
  162. """
  163. If the wrapped factory does not provide L{ILoggingContext},
  164. L{TLSMemoryBIOFactory.logPrefix} uses the wrapped factory's class name.
  165. """
  166. class NoFactory(object):
  167. pass
  168. contextFactory = ServerTLSContext()
  169. factory = TLSMemoryBIOFactory(contextFactory, False, NoFactory())
  170. self.assertEqual("NoFactory (TLS)", factory.logPrefix())
  171. def handshakingClientAndServer(clientGreetingData=None,
  172. clientAbortAfterHandshake=False):
  173. """
  174. Construct a client and server L{TLSMemoryBIOProtocol} connected by an IO
  175. pump.
  176. @param greetingData: The data which should be written in L{connectionMade}.
  177. @type greetingData: L{bytes}
  178. @return: 3-tuple of client, server, L{twisted.test.iosim.IOPump}
  179. """
  180. authCert, serverCert = certificatesForAuthorityAndServer()
  181. @implementer(IHandshakeListener)
  182. class Client(AccumulatingProtocol, object):
  183. handshook = False
  184. peerAfterHandshake = None
  185. def connectionMade(self):
  186. super(Client, self).connectionMade()
  187. if clientGreetingData is not None:
  188. self.transport.write(clientGreetingData)
  189. def handshakeCompleted(self):
  190. self.handshook = True
  191. self.peerAfterHandshake = self.transport.getPeerCertificate()
  192. if clientAbortAfterHandshake:
  193. self.transport.abortConnection()
  194. def connectionLost(self, reason):
  195. pass
  196. @implementer(IHandshakeListener)
  197. class Server(AccumulatingProtocol, object):
  198. handshaked = False
  199. def handshakeCompleted(self):
  200. self.handshaked = True
  201. def connectionLost(self, reason):
  202. pass
  203. clientF = TLSMemoryBIOFactory(
  204. optionsForClientTLS(u"example.com", trustRoot=authCert),
  205. isClient=True,
  206. wrappedFactory=ClientFactory.forProtocol(lambda: Client(999999))
  207. )
  208. serverF = TLSMemoryBIOFactory(
  209. serverCert.options(), isClient=False,
  210. wrappedFactory=ServerFactory.forProtocol(lambda: Server(999999))
  211. )
  212. client, server, pump = connectedServerAndClient(
  213. lambda: serverF.buildProtocol(None),
  214. lambda: clientF.buildProtocol(None),
  215. greet=False,
  216. )
  217. return client, server, pump
  218. class DeterministicTLSMemoryBIOTests(SynchronousTestCase):
  219. """
  220. Test for the implementation of L{ISSLTransport} which runs over another
  221. transport.
  222. @note: Prefer to add test cases to this suite, in this style, using
  223. L{connectedServerAndClient}, rather than returning L{Deferred}s.
  224. """
  225. def test_handshakeNotification(self):
  226. """
  227. The completion of the TLS handshake calls C{handshakeCompleted} on
  228. L{Protocol} objects that provide L{IHandshakeListener}. At the time
  229. C{handshakeCompleted} is invoked, the transport's peer certificate will
  230. have been initialized.
  231. """
  232. client, server, pump = handshakingClientAndServer()
  233. self.assertEqual(client.wrappedProtocol.handshook, False)
  234. self.assertEqual(server.wrappedProtocol.handshaked, False)
  235. pump.flush()
  236. self.assertEqual(client.wrappedProtocol.handshook, True)
  237. self.assertEqual(server.wrappedProtocol.handshaked, True)
  238. self.assertIsNot(client.wrappedProtocol.peerAfterHandshake, None)
  239. def test_handshakeStopWriting(self):
  240. """
  241. If some data is written to the transport in C{connectionMade}, but
  242. C{handshakeDone} doesn't like something it sees about the handshake, it
  243. can use C{abortConnection} to ensure that the application never
  244. receives that data.
  245. """
  246. client, server, pump = handshakingClientAndServer(b"untrustworthy",
  247. True)
  248. pump.flush()
  249. self.assertEqual(server.wrappedProtocol.received, [])
  250. class TLSMemoryBIOTests(TestCase):
  251. """
  252. Tests for the implementation of L{ISSLTransport} which runs over another
  253. L{ITransport}.
  254. """
  255. def test_interfaces(self):
  256. """
  257. L{TLSMemoryBIOProtocol} instances provide L{ISSLTransport} and
  258. L{ISystemHandle}.
  259. """
  260. proto = TLSMemoryBIOProtocol(None, None)
  261. self.assertTrue(ISSLTransport.providedBy(proto))
  262. self.assertTrue(ISystemHandle.providedBy(proto))
  263. def test_wrappedProtocolInterfaces(self):
  264. """
  265. L{TLSMemoryBIOProtocol} instances provide the interfaces provided by
  266. the transport they wrap.
  267. """
  268. class ITransport(Interface):
  269. pass
  270. class MyTransport(object):
  271. def write(self, data):
  272. pass
  273. clientFactory = ClientFactory()
  274. contextFactory = ClientTLSContext()
  275. wrapperFactory = TLSMemoryBIOFactory(
  276. contextFactory, True, clientFactory)
  277. transport = MyTransport()
  278. directlyProvides(transport, ITransport)
  279. tlsProtocol = TLSMemoryBIOProtocol(wrapperFactory, Protocol())
  280. tlsProtocol.makeConnection(transport)
  281. self.assertTrue(ITransport.providedBy(tlsProtocol))
  282. def test_getHandle(self):
  283. """
  284. L{TLSMemoryBIOProtocol.getHandle} returns the L{OpenSSL.SSL.Connection}
  285. instance it uses to actually implement TLS.
  286. This may seem odd. In fact, it is. The L{OpenSSL.SSL.Connection} is
  287. not actually the "system handle" here, nor even an object the reactor
  288. knows about directly. However, L{twisted.internet.ssl.Certificate}'s
  289. C{peerFromTransport} and C{hostFromTransport} methods depend on being
  290. able to get an L{OpenSSL.SSL.Connection} object in order to work
  291. properly. Implementing L{ISystemHandle.getHandle} like this is the
  292. easiest way for those APIs to be made to work. If they are changed,
  293. then it may make sense to get rid of this implementation of
  294. L{ISystemHandle} and return the underlying socket instead.
  295. """
  296. factory = ClientFactory()
  297. contextFactory = ClientTLSContext()
  298. wrapperFactory = TLSMemoryBIOFactory(contextFactory, True, factory)
  299. proto = TLSMemoryBIOProtocol(wrapperFactory, Protocol())
  300. transport = StringTransport()
  301. proto.makeConnection(transport)
  302. self.assertIsInstance(proto.getHandle(), ConnectionType)
  303. def test_makeConnection(self):
  304. """
  305. When L{TLSMemoryBIOProtocol} is connected to a transport, it connects
  306. the protocol it wraps to a transport.
  307. """
  308. clientProtocol = Protocol()
  309. clientFactory = ClientFactory()
  310. clientFactory.protocol = lambda: clientProtocol
  311. contextFactory = ClientTLSContext()
  312. wrapperFactory = TLSMemoryBIOFactory(
  313. contextFactory, True, clientFactory)
  314. sslProtocol = wrapperFactory.buildProtocol(None)
  315. transport = StringTransport()
  316. sslProtocol.makeConnection(transport)
  317. self.assertIsNotNone(clientProtocol.transport)
  318. self.assertIsNot(clientProtocol.transport, transport)
  319. self.assertIs(clientProtocol.transport, sslProtocol)
  320. def handshakeProtocols(self):
  321. """
  322. Start handshake between TLS client and server.
  323. """
  324. clientFactory = ClientFactory()
  325. clientFactory.protocol = Protocol
  326. clientContextFactory, handshakeDeferred = (
  327. HandshakeCallbackContextFactory.factoryAndDeferred())
  328. wrapperFactory = TLSMemoryBIOFactory(
  329. clientContextFactory, True, clientFactory)
  330. sslClientProtocol = wrapperFactory.buildProtocol(None)
  331. serverFactory = ServerFactory()
  332. serverFactory.protocol = Protocol
  333. serverContextFactory = ServerTLSContext()
  334. wrapperFactory = TLSMemoryBIOFactory(
  335. serverContextFactory, False, serverFactory)
  336. sslServerProtocol = wrapperFactory.buildProtocol(None)
  337. connectionDeferred = loopbackAsync(sslServerProtocol, sslClientProtocol)
  338. return (sslClientProtocol, sslServerProtocol, handshakeDeferred,
  339. connectionDeferred)
  340. def test_handshake(self):
  341. """
  342. The TLS handshake is performed when L{TLSMemoryBIOProtocol} is
  343. connected to a transport.
  344. """
  345. tlsClient, tlsServer, handshakeDeferred, _ = self.handshakeProtocols()
  346. # Only wait for the handshake to complete. Anything after that isn't
  347. # important here.
  348. return handshakeDeferred
  349. def test_handshakeFailure(self):
  350. """
  351. L{TLSMemoryBIOProtocol} reports errors in the handshake process to the
  352. application-level protocol object using its C{connectionLost} method
  353. and disconnects the underlying transport.
  354. """
  355. clientConnectionLost = Deferred()
  356. clientFactory = ClientFactory()
  357. clientFactory.protocol = (
  358. lambda: ConnectionLostNotifyingProtocol(
  359. clientConnectionLost))
  360. clientContextFactory = HandshakeCallbackContextFactory()
  361. wrapperFactory = TLSMemoryBIOFactory(
  362. clientContextFactory, True, clientFactory)
  363. sslClientProtocol = wrapperFactory.buildProtocol(None)
  364. serverConnectionLost = Deferred()
  365. serverFactory = ServerFactory()
  366. serverFactory.protocol = (
  367. lambda: ConnectionLostNotifyingProtocol(
  368. serverConnectionLost))
  369. # This context factory rejects any clients which do not present a
  370. # certificate.
  371. certificateData = FilePath(certPath).getContent()
  372. certificate = PrivateCertificate.loadPEM(certificateData)
  373. serverContextFactory = certificate.options(certificate)
  374. wrapperFactory = TLSMemoryBIOFactory(
  375. serverContextFactory, False, serverFactory)
  376. sslServerProtocol = wrapperFactory.buildProtocol(None)
  377. connectionDeferred = loopbackAsync(sslServerProtocol, sslClientProtocol)
  378. def cbConnectionLost(protocol):
  379. # The connection should close on its own in response to the error
  380. # induced by the client not supplying the required certificate.
  381. # After that, check to make sure the protocol's connectionLost was
  382. # called with the right thing.
  383. protocol.lostConnectionReason.trap(Error)
  384. clientConnectionLost.addCallback(cbConnectionLost)
  385. serverConnectionLost.addCallback(cbConnectionLost)
  386. # Additionally, the underlying transport should have been told to
  387. # go away.
  388. return gatherResults([
  389. clientConnectionLost, serverConnectionLost,
  390. connectionDeferred])
  391. def test_getPeerCertificate(self):
  392. """
  393. L{TLSMemoryBIOProtocol.getPeerCertificate} returns the
  394. L{OpenSSL.crypto.X509Type} instance representing the peer's
  395. certificate.
  396. """
  397. # Set up a client and server so there's a certificate to grab.
  398. clientFactory = ClientFactory()
  399. clientFactory.protocol = Protocol
  400. clientContextFactory, handshakeDeferred = (
  401. HandshakeCallbackContextFactory.factoryAndDeferred())
  402. wrapperFactory = TLSMemoryBIOFactory(
  403. clientContextFactory, True, clientFactory)
  404. sslClientProtocol = wrapperFactory.buildProtocol(None)
  405. serverFactory = ServerFactory()
  406. serverFactory.protocol = Protocol
  407. serverContextFactory = ServerTLSContext()
  408. wrapperFactory = TLSMemoryBIOFactory(
  409. serverContextFactory, False, serverFactory)
  410. sslServerProtocol = wrapperFactory.buildProtocol(None)
  411. loopbackAsync(sslServerProtocol, sslClientProtocol)
  412. # Wait for the handshake
  413. def cbHandshook(ignored):
  414. # Grab the server's certificate and check it out
  415. cert = sslClientProtocol.getPeerCertificate()
  416. self.assertIsInstance(cert, X509Type)
  417. self.assertEqual(
  418. cert.digest('sha1'),
  419. # openssl x509 -noout -sha1 -fingerprint -in server.pem
  420. b'45:DD:FD:E2:BD:BF:8B:D0:00:B7:D2:7A:BB:20:F5:34:05:4B:15:80')
  421. handshakeDeferred.addCallback(cbHandshook)
  422. return handshakeDeferred
  423. def test_writeAfterHandshake(self):
  424. """
  425. Bytes written to L{TLSMemoryBIOProtocol} before the handshake is
  426. complete are received by the protocol on the other side of the
  427. connection once the handshake succeeds.
  428. """
  429. data = b"some bytes"
  430. clientProtocol = Protocol()
  431. clientFactory = ClientFactory()
  432. clientFactory.protocol = lambda: clientProtocol
  433. clientContextFactory, handshakeDeferred = (
  434. HandshakeCallbackContextFactory.factoryAndDeferred())
  435. wrapperFactory = TLSMemoryBIOFactory(
  436. clientContextFactory, True, clientFactory)
  437. sslClientProtocol = wrapperFactory.buildProtocol(None)
  438. serverProtocol = AccumulatingProtocol(len(data))
  439. serverFactory = ServerFactory()
  440. serverFactory.protocol = lambda: serverProtocol
  441. serverContextFactory = ServerTLSContext()
  442. wrapperFactory = TLSMemoryBIOFactory(
  443. serverContextFactory, False, serverFactory)
  444. sslServerProtocol = wrapperFactory.buildProtocol(None)
  445. connectionDeferred = loopbackAsync(sslServerProtocol, sslClientProtocol)
  446. # Wait for the handshake to finish before writing anything.
  447. def cbHandshook(ignored):
  448. clientProtocol.transport.write(data)
  449. # The server will drop the connection once it gets the bytes.
  450. return connectionDeferred
  451. handshakeDeferred.addCallback(cbHandshook)
  452. # Once the connection is lost, make sure the server received the
  453. # expected bytes.
  454. def cbDisconnected(ignored):
  455. self.assertEqual(b"".join(serverProtocol.received), data)
  456. handshakeDeferred.addCallback(cbDisconnected)
  457. return handshakeDeferred
  458. def writeBeforeHandshakeTest(self, sendingProtocol, data):
  459. """
  460. Run test where client sends data before handshake, given the sending
  461. protocol and expected bytes.
  462. """
  463. clientFactory = ClientFactory()
  464. clientFactory.protocol = sendingProtocol
  465. clientContextFactory, handshakeDeferred = (
  466. HandshakeCallbackContextFactory.factoryAndDeferred())
  467. wrapperFactory = TLSMemoryBIOFactory(
  468. clientContextFactory, True, clientFactory)
  469. sslClientProtocol = wrapperFactory.buildProtocol(None)
  470. serverProtocol = AccumulatingProtocol(len(data))
  471. serverFactory = ServerFactory()
  472. serverFactory.protocol = lambda: serverProtocol
  473. serverContextFactory = ServerTLSContext()
  474. wrapperFactory = TLSMemoryBIOFactory(
  475. serverContextFactory, False, serverFactory)
  476. sslServerProtocol = wrapperFactory.buildProtocol(None)
  477. connectionDeferred = loopbackAsync(sslServerProtocol, sslClientProtocol)
  478. # Wait for the connection to end, then make sure the server received
  479. # the bytes sent by the client.
  480. def cbConnectionDone(ignored):
  481. self.assertEqual(b"".join(serverProtocol.received), data)
  482. connectionDeferred.addCallback(cbConnectionDone)
  483. return connectionDeferred
  484. def test_writeBeforeHandshake(self):
  485. """
  486. Bytes written to L{TLSMemoryBIOProtocol} before the handshake is
  487. complete are received by the protocol on the other side of the
  488. connection once the handshake succeeds.
  489. """
  490. data = b"some bytes"
  491. class SimpleSendingProtocol(Protocol):
  492. def connectionMade(self):
  493. self.transport.write(data)
  494. return self.writeBeforeHandshakeTest(SimpleSendingProtocol, data)
  495. def test_writeSequence(self):
  496. """
  497. Bytes written to L{TLSMemoryBIOProtocol} with C{writeSequence} are
  498. received by the protocol on the other side of the connection.
  499. """
  500. data = b"some bytes"
  501. class SimpleSendingProtocol(Protocol):
  502. def connectionMade(self):
  503. self.transport.writeSequence(list(iterbytes(data)))
  504. return self.writeBeforeHandshakeTest(SimpleSendingProtocol, data)
  505. def test_writeAfterLoseConnection(self):
  506. """
  507. Bytes written to L{TLSMemoryBIOProtocol} after C{loseConnection} is
  508. called are not transmitted (unless there is a registered producer,
  509. which will be tested elsewhere).
  510. """
  511. data = b"some bytes"
  512. class SimpleSendingProtocol(Protocol):
  513. def connectionMade(self):
  514. self.transport.write(data)
  515. self.transport.loseConnection()
  516. self.transport.write(b"hello")
  517. self.transport.writeSequence([b"world"])
  518. return self.writeBeforeHandshakeTest(SimpleSendingProtocol, data)
  519. def test_writeUnicodeRaisesTypeError(self):
  520. """
  521. Writing C{unicode} to L{TLSMemoryBIOProtocol} throws a C{TypeError}.
  522. """
  523. notBytes = u"hello"
  524. result = []
  525. class SimpleSendingProtocol(Protocol):
  526. def connectionMade(self):
  527. try:
  528. self.transport.write(notBytes)
  529. except TypeError:
  530. result.append(True)
  531. self.transport.write(b"bytes")
  532. self.transport.loseConnection()
  533. d = self.writeBeforeHandshakeTest(SimpleSendingProtocol, b"bytes")
  534. return d.addCallback(lambda ign: self.assertEqual(result, [True]))
  535. def test_multipleWrites(self):
  536. """
  537. If multiple separate TLS messages are received in a single chunk from
  538. the underlying transport, all of the application bytes from each
  539. message are delivered to the application-level protocol.
  540. """
  541. data = [b'a', b'b', b'c', b'd', b'e', b'f', b'g', b'h', b'i']
  542. class SimpleSendingProtocol(Protocol):
  543. def connectionMade(self):
  544. for b in data:
  545. self.transport.write(b)
  546. clientFactory = ClientFactory()
  547. clientFactory.protocol = SimpleSendingProtocol
  548. clientContextFactory = HandshakeCallbackContextFactory()
  549. wrapperFactory = TLSMemoryBIOFactory(
  550. clientContextFactory, True, clientFactory)
  551. sslClientProtocol = wrapperFactory.buildProtocol(None)
  552. serverProtocol = AccumulatingProtocol(sum(map(len, data)))
  553. serverFactory = ServerFactory()
  554. serverFactory.protocol = lambda: serverProtocol
  555. serverContextFactory = ServerTLSContext()
  556. wrapperFactory = TLSMemoryBIOFactory(
  557. serverContextFactory, False, serverFactory)
  558. sslServerProtocol = wrapperFactory.buildProtocol(None)
  559. connectionDeferred = loopbackAsync(sslServerProtocol, sslClientProtocol, collapsingPumpPolicy)
  560. # Wait for the connection to end, then make sure the server received
  561. # the bytes sent by the client.
  562. def cbConnectionDone(ignored):
  563. self.assertEqual(b"".join(serverProtocol.received), b''.join(data))
  564. connectionDeferred.addCallback(cbConnectionDone)
  565. return connectionDeferred
  566. def hugeWrite(self, method=TLSv1_METHOD):
  567. """
  568. If a very long string is passed to L{TLSMemoryBIOProtocol.write}, any
  569. trailing part of it which cannot be send immediately is buffered and
  570. sent later.
  571. """
  572. data = b"some bytes"
  573. factor = 2 ** 20
  574. class SimpleSendingProtocol(Protocol):
  575. def connectionMade(self):
  576. self.transport.write(data * factor)
  577. clientFactory = ClientFactory()
  578. clientFactory.protocol = SimpleSendingProtocol
  579. clientContextFactory = HandshakeCallbackContextFactory(method=method)
  580. wrapperFactory = TLSMemoryBIOFactory(
  581. clientContextFactory, True, clientFactory)
  582. sslClientProtocol = wrapperFactory.buildProtocol(None)
  583. serverProtocol = AccumulatingProtocol(len(data) * factor)
  584. serverFactory = ServerFactory()
  585. serverFactory.protocol = lambda: serverProtocol
  586. serverContextFactory = ServerTLSContext(method=method)
  587. wrapperFactory = TLSMemoryBIOFactory(
  588. serverContextFactory, False, serverFactory)
  589. sslServerProtocol = wrapperFactory.buildProtocol(None)
  590. connectionDeferred = loopbackAsync(sslServerProtocol, sslClientProtocol)
  591. # Wait for the connection to end, then make sure the server received
  592. # the bytes sent by the client.
  593. def cbConnectionDone(ignored):
  594. self.assertEqual(b"".join(serverProtocol.received), data * factor)
  595. connectionDeferred.addCallback(cbConnectionDone)
  596. return connectionDeferred
  597. def test_hugeWrite_TLSv1(self):
  598. return self.hugeWrite()
  599. def test_hugeWrite_TLSv1_1(self):
  600. return self.hugeWrite(method=TLSv1_1_METHOD)
  601. def test_hugeWrite_TLSv1_2(self):
  602. return self.hugeWrite(method=TLSv1_2_METHOD)
  603. def test_disorderlyShutdown(self):
  604. """
  605. If a L{TLSMemoryBIOProtocol} loses its connection unexpectedly, this is
  606. reported to the application.
  607. """
  608. clientConnectionLost = Deferred()
  609. clientFactory = ClientFactory()
  610. clientFactory.protocol = (
  611. lambda: ConnectionLostNotifyingProtocol(
  612. clientConnectionLost))
  613. clientContextFactory = HandshakeCallbackContextFactory()
  614. wrapperFactory = TLSMemoryBIOFactory(
  615. clientContextFactory, True, clientFactory)
  616. sslClientProtocol = wrapperFactory.buildProtocol(None)
  617. # Client speaks first, so the server can be dumb.
  618. serverProtocol = Protocol()
  619. loopbackAsync(serverProtocol, sslClientProtocol)
  620. # Now destroy the connection.
  621. serverProtocol.transport.loseConnection()
  622. # And when the connection completely dies, check the reason.
  623. def cbDisconnected(clientProtocol):
  624. clientProtocol.lostConnectionReason.trap(Error, ConnectionLost)
  625. clientConnectionLost.addCallback(cbDisconnected)
  626. return clientConnectionLost
  627. def test_loseConnectionAfterHandshake(self):
  628. """
  629. L{TLSMemoryBIOProtocol.loseConnection} sends a TLS close alert and
  630. shuts down the underlying connection cleanly on both sides, after
  631. transmitting all buffered data.
  632. """
  633. class NotifyingProtocol(ConnectionLostNotifyingProtocol):
  634. def __init__(self, onConnectionLost):
  635. ConnectionLostNotifyingProtocol.__init__(self,
  636. onConnectionLost)
  637. self.data = []
  638. def dataReceived(self, data):
  639. self.data.append(data)
  640. clientConnectionLost = Deferred()
  641. clientFactory = ClientFactory()
  642. clientProtocol = NotifyingProtocol(clientConnectionLost)
  643. clientFactory.protocol = lambda: clientProtocol
  644. clientContextFactory, handshakeDeferred = (
  645. HandshakeCallbackContextFactory.factoryAndDeferred())
  646. wrapperFactory = TLSMemoryBIOFactory(
  647. clientContextFactory, True, clientFactory)
  648. sslClientProtocol = wrapperFactory.buildProtocol(None)
  649. serverConnectionLost = Deferred()
  650. serverProtocol = NotifyingProtocol(serverConnectionLost)
  651. serverFactory = ServerFactory()
  652. serverFactory.protocol = lambda: serverProtocol
  653. serverContextFactory = ServerTLSContext()
  654. wrapperFactory = TLSMemoryBIOFactory(
  655. serverContextFactory, False, serverFactory)
  656. sslServerProtocol = wrapperFactory.buildProtocol(None)
  657. loopbackAsync(sslServerProtocol, sslClientProtocol)
  658. chunkOfBytes = b"123456890" * 100000
  659. # Wait for the handshake before dropping the connection.
  660. def cbHandshake(ignored):
  661. # Write more than a single bio_read, to ensure client will still
  662. # have some data it needs to write when it receives the TLS close
  663. # alert, and that simply doing a single bio_read won't be
  664. # sufficient. Thus we will verify that any amount of buffered data
  665. # will be written out before the connection is closed, rather than
  666. # just small amounts that can be returned in a single bio_read:
  667. clientProtocol.transport.write(chunkOfBytes)
  668. serverProtocol.transport.write(b'x')
  669. serverProtocol.transport.loseConnection()
  670. # Now wait for the client and server to notice.
  671. return gatherResults([clientConnectionLost, serverConnectionLost])
  672. handshakeDeferred.addCallback(cbHandshake)
  673. # Wait for the connection to end, then make sure the client and server
  674. # weren't notified of a handshake failure that would cause the test to
  675. # fail.
  676. def cbConnectionDone(result):
  677. (clientProtocol, serverProtocol) = result
  678. clientProtocol.lostConnectionReason.trap(ConnectionDone)
  679. serverProtocol.lostConnectionReason.trap(ConnectionDone)
  680. # The server should have received all bytes sent by the client:
  681. self.assertEqual(b"".join(serverProtocol.data), chunkOfBytes)
  682. # The server should have closed its underlying transport, in
  683. # addition to whatever it did to shut down the TLS layer.
  684. self.assertTrue(serverProtocol.transport.q.disconnect)
  685. # The client should also have closed its underlying transport once
  686. # it saw the server shut down the TLS layer, so as to avoid relying
  687. # on the server to close the underlying connection.
  688. self.assertTrue(clientProtocol.transport.q.disconnect)
  689. handshakeDeferred.addCallback(cbConnectionDone)
  690. return handshakeDeferred
  691. def test_connectionLostOnlyAfterUnderlyingCloses(self):
  692. """
  693. The user protocol's connectionLost is only called when transport
  694. underlying TLS is disconnected.
  695. """
  696. class LostProtocol(Protocol):
  697. disconnected = None
  698. def connectionLost(self, reason):
  699. self.disconnected = reason
  700. wrapperFactory = TLSMemoryBIOFactory(ClientTLSContext(),
  701. True, ClientFactory())
  702. protocol = LostProtocol()
  703. tlsProtocol = TLSMemoryBIOProtocol(wrapperFactory, protocol)
  704. transport = StringTransport()
  705. tlsProtocol.makeConnection(transport)
  706. # Pretend TLS shutdown finished cleanly; the underlying transport
  707. # should be told to close, but the user protocol should not yet be
  708. # notified:
  709. tlsProtocol._tlsShutdownFinished(None)
  710. self.assertTrue(transport.disconnecting)
  711. self.assertIsNone(protocol.disconnected)
  712. # Now close the underlying connection; the user protocol should be
  713. # notified with the given reason (since TLS closed cleanly):
  714. tlsProtocol.connectionLost(Failure(ConnectionLost("ono")))
  715. self.assertTrue(protocol.disconnected.check(ConnectionLost))
  716. self.assertEqual(protocol.disconnected.value.args, ("ono",))
  717. def test_loseConnectionTwice(self):
  718. """
  719. If TLSMemoryBIOProtocol.loseConnection is called multiple times, all
  720. but the first call have no effect.
  721. """
  722. tlsClient, tlsServer, handshakeDeferred, disconnectDeferred = (
  723. self.handshakeProtocols())
  724. self.successResultOf(handshakeDeferred)
  725. # Make sure loseConnection calls _shutdownTLS the first time (mostly
  726. # to make sure we've overriding it correctly):
  727. calls = []
  728. def _shutdownTLS(shutdown=tlsClient._shutdownTLS):
  729. calls.append(1)
  730. return shutdown()
  731. tlsClient._shutdownTLS = _shutdownTLS
  732. tlsClient.write(b'x')
  733. tlsClient.loseConnection()
  734. self.assertTrue(tlsClient.disconnecting)
  735. self.assertEqual(calls, [1])
  736. # Make sure _shutdownTLS isn't called a second time:
  737. tlsClient.loseConnection()
  738. self.assertEqual(calls, [1])
  739. # We do successfully disconnect at some point:
  740. return disconnectDeferred
  741. def test_unexpectedEOF(self):
  742. """
  743. Unexpected disconnects get converted to ConnectionLost errors.
  744. """
  745. tlsClient, tlsServer, handshakeDeferred, disconnectDeferred = (
  746. self.handshakeProtocols())
  747. serverProtocol = tlsServer.wrappedProtocol
  748. data = []
  749. reason = []
  750. serverProtocol.dataReceived = data.append
  751. serverProtocol.connectionLost = reason.append
  752. # Write data, then disconnect *underlying* transport, resulting in an
  753. # unexpected TLS disconnect:
  754. def handshakeDone(ign):
  755. tlsClient.write(b"hello")
  756. tlsClient.transport.loseConnection()
  757. handshakeDeferred.addCallback(handshakeDone)
  758. # Receiver should be disconnected, with ConnectionLost notification
  759. # (masking the Unexpected EOF SSL error):
  760. def disconnected(ign):
  761. self.assertTrue(reason[0].check(ConnectionLost), reason[0])
  762. disconnectDeferred.addCallback(disconnected)
  763. return disconnectDeferred
  764. def test_errorWriting(self):
  765. """
  766. Errors while writing cause the protocols to be disconnected.
  767. """
  768. tlsClient, tlsServer, handshakeDeferred, disconnectDeferred = (
  769. self.handshakeProtocols())
  770. reason = []
  771. tlsClient.wrappedProtocol.connectionLost = reason.append
  772. # Pretend TLS connection is unhappy sending:
  773. class Wrapper(object):
  774. def __init__(self, wrapped):
  775. self._wrapped = wrapped
  776. def __getattr__(self, attr):
  777. return getattr(self._wrapped, attr)
  778. def send(self, *args):
  779. raise Error("ONO!")
  780. tlsClient._tlsConnection = Wrapper(tlsClient._tlsConnection)
  781. # Write some data:
  782. def handshakeDone(ign):
  783. tlsClient.write(b"hello")
  784. handshakeDeferred.addCallback(handshakeDone)
  785. # Failed writer should be disconnected with SSL error:
  786. def disconnected(ign):
  787. self.assertTrue(reason[0].check(Error), reason[0])
  788. disconnectDeferred.addCallback(disconnected)
  789. return disconnectDeferred
  790. class TLSProducerTests(TestCase):
  791. """
  792. The TLS transport must support the IConsumer interface.
  793. """
  794. def drain(self, transport, allowEmpty=False):
  795. """
  796. Drain the bytes currently pending write from a L{StringTransport}, then
  797. clear it, since those bytes have been consumed.
  798. @param transport: The L{StringTransport} to get the bytes from.
  799. @type transport: L{StringTransport}
  800. @param allowEmpty: Allow the test to pass even if the transport has no
  801. outgoing bytes in it.
  802. @type allowEmpty: L{bool}
  803. @return: the outgoing bytes from the given transport
  804. @rtype: L{bytes}
  805. """
  806. value = transport.value()
  807. transport.clear()
  808. self.assertEqual(bool(allowEmpty or value), True)
  809. return value
  810. def setupStreamingProducer(self, transport=None, fakeConnection=None,
  811. server=False):
  812. class HistoryStringTransport(StringTransport):
  813. def __init__(self):
  814. StringTransport.__init__(self)
  815. self.producerHistory = []
  816. def pauseProducing(self):
  817. self.producerHistory.append("pause")
  818. StringTransport.pauseProducing(self)
  819. def resumeProducing(self):
  820. self.producerHistory.append("resume")
  821. StringTransport.resumeProducing(self)
  822. def stopProducing(self):
  823. self.producerHistory.append("stop")
  824. StringTransport.stopProducing(self)
  825. applicationProtocol, tlsProtocol = buildTLSProtocol(
  826. transport=transport, fakeConnection=fakeConnection,
  827. server=server)
  828. producer = HistoryStringTransport()
  829. applicationProtocol.transport.registerProducer(producer, True)
  830. self.assertTrue(tlsProtocol.transport.streaming)
  831. return applicationProtocol, tlsProtocol, producer
  832. def flushTwoTLSProtocols(self, tlsProtocol, serverTLSProtocol):
  833. """
  834. Transfer bytes back and forth between two TLS protocols.
  835. """
  836. # We want to make sure all bytes are passed back and forth; JP
  837. # estimated that 3 rounds should be enough:
  838. for i in range(3):
  839. clientData = self.drain(tlsProtocol.transport, True)
  840. if clientData:
  841. serverTLSProtocol.dataReceived(clientData)
  842. serverData = self.drain(serverTLSProtocol.transport, True)
  843. if serverData:
  844. tlsProtocol.dataReceived(serverData)
  845. if not serverData and not clientData:
  846. break
  847. self.assertEqual(tlsProtocol.transport.value(), b"")
  848. self.assertEqual(serverTLSProtocol.transport.value(), b"")
  849. def test_producerDuringRenegotiation(self):
  850. """
  851. If we write some data to a TLS connection that is blocked waiting for a
  852. renegotiation with its peer, it will pause and resume its registered
  853. producer exactly once.
  854. """
  855. c, ct, cp = self.setupStreamingProducer()
  856. s, st, sp = self.setupStreamingProducer(server=True)
  857. self.flushTwoTLSProtocols(ct, st)
  858. # no public API for this yet because it's (mostly) unnecessary, but we
  859. # have to be prepared for a peer to do it to us
  860. tlsc = ct._tlsConnection
  861. tlsc.renegotiate()
  862. self.assertRaises(WantReadError, tlsc.do_handshake)
  863. ct._flushSendBIO()
  864. st.dataReceived(self.drain(ct.transport))
  865. payload = b'payload'
  866. s.transport.write(payload)
  867. s.transport.loseConnection()
  868. # give the client the server the client's response...
  869. ct.dataReceived(self.drain(st.transport))
  870. messageThatUnblocksTheServer = self.drain(ct.transport)
  871. # split it into just enough chunks that it would provoke the producer
  872. # with an incorrect implementation...
  873. for fragment in (messageThatUnblocksTheServer[0:1],
  874. messageThatUnblocksTheServer[1:2],
  875. messageThatUnblocksTheServer[2:]):
  876. st.dataReceived(fragment)
  877. self.assertEqual(st.transport.disconnecting, False)
  878. s.transport.unregisterProducer()
  879. self.flushTwoTLSProtocols(ct, st)
  880. self.assertEqual(st.transport.disconnecting, True)
  881. self.assertEqual(b''.join(c.received), payload)
  882. self.assertEqual(sp.producerHistory, ['pause', 'resume'])
  883. def test_streamingProducerPausedInNormalMode(self):
  884. """
  885. When the TLS transport is not blocked on reads, it correctly calls
  886. pauseProducing on the registered producer.
  887. """
  888. _, tlsProtocol, producer = self.setupStreamingProducer()
  889. # The TLS protocol's transport pretends to be full, pausing its
  890. # producer:
  891. tlsProtocol.transport.producer.pauseProducing()
  892. self.assertEqual(producer.producerState, 'paused')
  893. self.assertEqual(producer.producerHistory, ['pause'])
  894. self.assertTrue(tlsProtocol._producer._producerPaused)
  895. def test_streamingProducerResumedInNormalMode(self):
  896. """
  897. When the TLS transport is not blocked on reads, it correctly calls
  898. resumeProducing on the registered producer.
  899. """
  900. _, tlsProtocol, producer = self.setupStreamingProducer()
  901. tlsProtocol.transport.producer.pauseProducing()
  902. self.assertEqual(producer.producerHistory, ['pause'])
  903. # The TLS protocol's transport pretends to have written everything
  904. # out, so it resumes its producer:
  905. tlsProtocol.transport.producer.resumeProducing()
  906. self.assertEqual(producer.producerState, 'producing')
  907. self.assertEqual(producer.producerHistory, ['pause', 'resume'])
  908. self.assertFalse(tlsProtocol._producer._producerPaused)
  909. def test_streamingProducerPausedInWriteBlockedOnReadMode(self):
  910. """
  911. When the TLS transport is blocked on reads, it correctly calls
  912. pauseProducing on the registered producer.
  913. """
  914. clientProtocol, tlsProtocol, producer = self.setupStreamingProducer()
  915. # Write to TLS transport. Because we do this before the initial TLS
  916. # handshake is finished, writing bytes triggers a WantReadError,
  917. # indicating that until bytes are read for the handshake, more bytes
  918. # cannot be written. Thus writing bytes before the handshake should
  919. # cause the producer to be paused:
  920. clientProtocol.transport.write(b"hello")
  921. self.assertEqual(producer.producerState, 'paused')
  922. self.assertEqual(producer.producerHistory, ['pause'])
  923. self.assertTrue(tlsProtocol._producer._producerPaused)
  924. def test_streamingProducerResumedInWriteBlockedOnReadMode(self):
  925. """
  926. When the TLS transport is blocked on reads, it correctly calls
  927. resumeProducing on the registered producer.
  928. """
  929. clientProtocol, tlsProtocol, producer = self.setupStreamingProducer()
  930. # Write to TLS transport, triggering WantReadError; this should cause
  931. # the producer to be paused. We use a large chunk of data to make sure
  932. # large writes don't trigger multiple pauses:
  933. clientProtocol.transport.write(b"hello world" * 320000)
  934. self.assertEqual(producer.producerHistory, ['pause'])
  935. # Now deliver bytes that will fix the WantRead condition; this should
  936. # unpause the producer:
  937. serverProtocol, serverTLSProtocol = buildTLSProtocol(server=True)
  938. self.flushTwoTLSProtocols(tlsProtocol, serverTLSProtocol)
  939. self.assertEqual(producer.producerHistory, ['pause', 'resume'])
  940. self.assertFalse(tlsProtocol._producer._producerPaused)
  941. # Make sure we haven't disconnected for some reason:
  942. self.assertFalse(tlsProtocol.transport.disconnecting)
  943. self.assertEqual(producer.producerState, 'producing')
  944. def test_streamingProducerTwice(self):
  945. """
  946. Registering a streaming producer twice throws an exception.
  947. """
  948. clientProtocol, tlsProtocol, producer = self.setupStreamingProducer()
  949. originalProducer = tlsProtocol._producer
  950. producer2 = object()
  951. self.assertRaises(RuntimeError,
  952. clientProtocol.transport.registerProducer, producer2, True)
  953. self.assertIs(tlsProtocol._producer, originalProducer)
  954. def test_streamingProducerUnregister(self):
  955. """
  956. Unregistering a streaming producer removes it, reverting to initial state.
  957. """
  958. clientProtocol, tlsProtocol, producer = self.setupStreamingProducer()
  959. clientProtocol.transport.unregisterProducer()
  960. self.assertIsNone(tlsProtocol._producer)
  961. self.assertIsNone(tlsProtocol.transport.producer)
  962. def loseConnectionWithProducer(self, writeBlockedOnRead):
  963. """
  964. Common code for tests involving writes by producer after
  965. loseConnection is called.
  966. """
  967. clientProtocol, tlsProtocol, producer = self.setupStreamingProducer()
  968. serverProtocol, serverTLSProtocol = buildTLSProtocol(server=True)
  969. if not writeBlockedOnRead:
  970. # Do the initial handshake before write:
  971. self.flushTwoTLSProtocols(tlsProtocol, serverTLSProtocol)
  972. else:
  973. # In this case the write below will trigger write-blocked-on-read
  974. # condition...
  975. pass
  976. # Now write, then lose connection:
  977. clientProtocol.transport.write(b"x ")
  978. clientProtocol.transport.loseConnection()
  979. self.flushTwoTLSProtocols(tlsProtocol, serverTLSProtocol)
  980. # Underlying transport should not have loseConnection called yet, nor
  981. # should producer be stopped:
  982. self.assertFalse(tlsProtocol.transport.disconnecting)
  983. self.assertFalse("stop" in producer.producerHistory)
  984. # Writes from client to server should continue to go through, since we
  985. # haven't unregistered producer yet:
  986. clientProtocol.transport.write(b"hello")
  987. clientProtocol.transport.writeSequence([b" ", b"world"])
  988. # Unregister producer; this should trigger TLS shutdown:
  989. clientProtocol.transport.unregisterProducer()
  990. self.assertNotEqual(tlsProtocol.transport.value(), b"")
  991. self.assertFalse(tlsProtocol.transport.disconnecting)
  992. # Additional writes should not go through:
  993. clientProtocol.transport.write(b"won't")
  994. clientProtocol.transport.writeSequence([b"won't!"])
  995. # Finish TLS close handshake:
  996. self.flushTwoTLSProtocols(tlsProtocol, serverTLSProtocol)
  997. self.assertTrue(tlsProtocol.transport.disconnecting)
  998. # Bytes made it through, as long as they were written before producer
  999. # was unregistered:
  1000. self.assertEqual(b"".join(serverProtocol.received), b"x hello world")
  1001. def test_streamingProducerLoseConnectionWithProducer(self):
  1002. """
  1003. loseConnection() waits for the producer to unregister itself, then
  1004. does a clean TLS close alert, then closes the underlying connection.
  1005. """
  1006. return self.loseConnectionWithProducer(False)
  1007. def test_streamingProducerLoseConnectionWithProducerWBOR(self):
  1008. """
  1009. Even when writes are blocked on reading, loseConnection() waits for
  1010. the producer to unregister itself, then does a clean TLS close alert,
  1011. then closes the underlying connection.
  1012. """
  1013. return self.loseConnectionWithProducer(True)
  1014. def test_streamingProducerBothTransportsDecideToPause(self):
  1015. """
  1016. pauseProducing() events can come from both the TLS transport layer and
  1017. the underlying transport. In this case, both decide to pause,
  1018. underlying first.
  1019. """
  1020. class PausingStringTransport(StringTransport):
  1021. _didPause = False
  1022. def write(self, data):
  1023. if not self._didPause and self.producer is not None:
  1024. self._didPause = True
  1025. self.producer.pauseProducing()
  1026. StringTransport.write(self, data)
  1027. class TLSConnection(object):
  1028. def __init__(self):
  1029. self.l = []
  1030. def send(self, data):
  1031. # on first write, don't send all bytes:
  1032. if not self.l:
  1033. data = data[:-1]
  1034. # pause on second write:
  1035. if len(self.l) == 1:
  1036. self.l.append("paused")
  1037. raise WantReadError()
  1038. # otherwise just take in data:
  1039. self.l.append(data)
  1040. return len(data)
  1041. def set_connect_state(self):
  1042. pass
  1043. def do_handshake(self):
  1044. pass
  1045. def bio_write(self, data):
  1046. pass
  1047. def bio_read(self, size):
  1048. return b'X'
  1049. def recv(self, size):
  1050. raise WantReadError()
  1051. transport = PausingStringTransport()
  1052. clientProtocol, tlsProtocol, producer = self.setupStreamingProducer(
  1053. transport, fakeConnection=TLSConnection())
  1054. self.assertEqual(producer.producerState, 'producing')
  1055. # Shove in fake TLSConnection that will raise WantReadError the second
  1056. # time send() is called. This will allow us to have bytes written to
  1057. # to the PausingStringTransport, so it will pause the producer. Then,
  1058. # WantReadError will be thrown, triggering the TLS transport's
  1059. # producer code path.
  1060. clientProtocol.transport.write(b"hello")
  1061. self.assertEqual(producer.producerState, 'paused')
  1062. self.assertEqual(producer.producerHistory, ['pause'])
  1063. # Now, underlying transport resumes, and then we deliver some data to
  1064. # TLS transport so that it will resume:
  1065. tlsProtocol.transport.producer.resumeProducing()
  1066. self.assertEqual(producer.producerState, 'producing')
  1067. self.assertEqual(producer.producerHistory, ['pause', 'resume'])
  1068. tlsProtocol.dataReceived(b"hello")
  1069. self.assertEqual(producer.producerState, 'producing')
  1070. self.assertEqual(producer.producerHistory, ['pause', 'resume'])
  1071. def test_streamingProducerStopProducing(self):
  1072. """
  1073. If the underlying transport tells its producer to stopProducing(),
  1074. this is passed on to the high-level producer.
  1075. """
  1076. _, tlsProtocol, producer = self.setupStreamingProducer()
  1077. tlsProtocol.transport.producer.stopProducing()
  1078. self.assertEqual(producer.producerState, 'stopped')
  1079. def test_nonStreamingProducer(self):
  1080. """
  1081. Non-streaming producers get wrapped as streaming producers.
  1082. """
  1083. clientProtocol, tlsProtocol = buildTLSProtocol()
  1084. producer = NonStreamingProducer(clientProtocol.transport)
  1085. # Register non-streaming producer:
  1086. clientProtocol.transport.registerProducer(producer, False)
  1087. streamingProducer = tlsProtocol.transport.producer._producer
  1088. # Verify it was wrapped into streaming producer:
  1089. self.assertIsInstance(streamingProducer, _PullToPush)
  1090. self.assertEqual(streamingProducer._producer, producer)
  1091. self.assertEqual(streamingProducer._consumer, clientProtocol.transport)
  1092. self.assertTrue(tlsProtocol.transport.streaming)
  1093. # Verify the streaming producer was started, and ran until the end:
  1094. def done(ignore):
  1095. # Our own producer is done:
  1096. self.assertIsNone(producer.consumer)
  1097. # The producer has been unregistered:
  1098. self.assertIsNone(tlsProtocol.transport.producer)
  1099. # The streaming producer wrapper knows it's done:
  1100. self.assertTrue(streamingProducer._finished)
  1101. producer.result.addCallback(done)
  1102. serverProtocol, serverTLSProtocol = buildTLSProtocol(server=True)
  1103. self.flushTwoTLSProtocols(tlsProtocol, serverTLSProtocol)
  1104. return producer.result
  1105. def test_interface(self):
  1106. """
  1107. L{_ProducerMembrane} implements L{IPushProducer}.
  1108. """
  1109. producer = StringTransport()
  1110. membrane = _ProducerMembrane(producer)
  1111. self.assertTrue(verifyObject(IPushProducer, membrane))
  1112. def registerProducerAfterConnectionLost(self, streaming):
  1113. """
  1114. If a producer is registered after the transport has disconnected, the
  1115. producer is not used, and its stopProducing method is called.
  1116. """
  1117. clientProtocol, tlsProtocol = buildTLSProtocol()
  1118. clientProtocol.connectionLost = lambda reason: reason.trap(
  1119. Error, ConnectionLost)
  1120. class Producer(object):
  1121. stopped = False
  1122. def resumeProducing(self):
  1123. return 1/0 # this should never be called
  1124. def stopProducing(self):
  1125. self.stopped = True
  1126. # Disconnect the transport:
  1127. tlsProtocol.connectionLost(Failure(ConnectionDone()))
  1128. # Register the producer; startProducing should not be called, but
  1129. # stopProducing will:
  1130. producer = Producer()
  1131. tlsProtocol.registerProducer(producer, False)
  1132. self.assertIsNone(tlsProtocol.transport.producer)
  1133. self.assertTrue(producer.stopped)
  1134. def test_streamingProducerAfterConnectionLost(self):
  1135. """
  1136. If a streaming producer is registered after the transport has
  1137. disconnected, the producer is not used, and its stopProducing method
  1138. is called.
  1139. """
  1140. self.registerProducerAfterConnectionLost(True)
  1141. def test_nonStreamingProducerAfterConnectionLost(self):
  1142. """
  1143. If a non-streaming producer is registered after the transport has
  1144. disconnected, the producer is not used, and its stopProducing method
  1145. is called.
  1146. """
  1147. self.registerProducerAfterConnectionLost(False)
  1148. class NonStreamingProducerTests(TestCase):
  1149. """
  1150. Non-streaming producers can be adapted into being streaming producers.
  1151. """
  1152. def streamUntilEnd(self, consumer):
  1153. """
  1154. Verify the consumer writes out all its data, but is not called after
  1155. that.
  1156. """
  1157. nsProducer = NonStreamingProducer(consumer)
  1158. streamingProducer = _PullToPush(nsProducer, consumer)
  1159. consumer.registerProducer(streamingProducer, True)
  1160. # The producer will call unregisterProducer(), and we need to hook
  1161. # that up so the streaming wrapper is notified; the
  1162. # TLSMemoryBIOProtocol will have to do this itself, which is tested
  1163. # elsewhere:
  1164. def unregister(orig=consumer.unregisterProducer):
  1165. orig()
  1166. streamingProducer.stopStreaming()
  1167. consumer.unregisterProducer = unregister
  1168. done = nsProducer.result
  1169. def doneStreaming(_):
  1170. # All data was streamed, and the producer unregistered itself:
  1171. self.assertEqual(consumer.value(), b"0123456789")
  1172. self.assertIsNone(consumer.producer)
  1173. # And the streaming wrapper stopped:
  1174. self.assertTrue(streamingProducer._finished)
  1175. done.addCallback(doneStreaming)
  1176. # Now, start streaming:
  1177. streamingProducer.startStreaming()
  1178. return done
  1179. def test_writeUntilDone(self):
  1180. """
  1181. When converted to a streaming producer, the non-streaming producer
  1182. writes out all its data, but is not called after that.
  1183. """
  1184. consumer = StringTransport()
  1185. return self.streamUntilEnd(consumer)
  1186. def test_pause(self):
  1187. """
  1188. When the streaming producer is paused, the underlying producer stops
  1189. getting resumeProducing calls.
  1190. """
  1191. class PausingStringTransport(StringTransport):
  1192. writes = 0
  1193. def __init__(self):
  1194. StringTransport.__init__(self)
  1195. self.paused = Deferred()
  1196. def write(self, data):
  1197. self.writes += 1
  1198. StringTransport.write(self, data)
  1199. if self.writes == 3:
  1200. self.producer.pauseProducing()
  1201. d = self.paused
  1202. del self.paused
  1203. d.callback(None)
  1204. consumer = PausingStringTransport()
  1205. nsProducer = NonStreamingProducer(consumer)
  1206. streamingProducer = _PullToPush(nsProducer, consumer)
  1207. consumer.registerProducer(streamingProducer, True)
  1208. # Make sure the consumer does not continue:
  1209. def shouldNotBeCalled(ignore):
  1210. self.fail("BUG: The producer should not finish!")
  1211. nsProducer.result.addCallback(shouldNotBeCalled)
  1212. done = consumer.paused
  1213. def paused(ignore):
  1214. # The CooperatorTask driving the producer was paused:
  1215. self.assertEqual(streamingProducer._coopTask._pauseCount, 1)
  1216. done.addCallback(paused)
  1217. # Now, start streaming:
  1218. streamingProducer.startStreaming()
  1219. return done
  1220. def test_resume(self):
  1221. """
  1222. When the streaming producer is paused and then resumed, the underlying
  1223. producer starts getting resumeProducing calls again after the resume.
  1224. The test will never finish (or rather, time out) if the resume
  1225. producing call is not working.
  1226. """
  1227. class PausingStringTransport(StringTransport):
  1228. writes = 0
  1229. def write(self, data):
  1230. self.writes += 1
  1231. StringTransport.write(self, data)
  1232. if self.writes == 3:
  1233. self.producer.pauseProducing()
  1234. self.producer.resumeProducing()
  1235. consumer = PausingStringTransport()
  1236. return self.streamUntilEnd(consumer)
  def test_stopProducing(self):
      """
      If the consumer stops the streaming producer, the wrapped producer is
      stopped as well and the streaming wrapper finishes.
      """
      class StopAfterThreeWrites(StringTransport):
          # Number of write() calls seen so far.
          writeCount = 0

          def write(self, data):
              self.writeCount += 1
              StringTransport.write(self, data)
              if self.writeCount == 3:
                  self.producer.stopProducing()

      transport = StopAfterThreeWrites()
      pullProducer = NonStreamingProducer(transport)
      pushProducer = _PullToPush(pullProducer, transport)
      transport.registerProducer(pushProducer, True)
      whenDone = pullProducer.result

      def verifyStopped(result):
          # Only the first three chunks reached the consumer, and the
          # wrapped producer was told to stop:
          self.assertEqual(transport.value(), b"012")
          self.assertTrue(pullProducer.stopped)
          # The push-producer wrapper finished too:
          self.assertTrue(pushProducer._finished)

      whenDone.addCallback(verifyStopped)
      # Kick off streaming:
      pushProducer.startStreaming()
      return whenDone
  def resumeProducingRaises(self, consumer, expectedExceptions):
      """
      Shared driver for tests in which the wrapped producer raises from its
      C{resumeProducing}.

      The wrapped producer behaves like L{NonStreamingProducer} until its
      C{counter} reaches 2, at which point C{resumeProducing} raises
      L{ZeroDivisionError}.  The returned L{Deferred} fires when the
      cooperative task driving the producer is done (a L{TaskStopped}
      failure is trapped).  At that point it checks that only C{b"01"} was
      written, that exactly the expected errors were logged, and that the
      wrapper finished.

      @param consumer: The consumer to stream into.
      @param expectedExceptions: C{(exception type, message fragment)} pairs,
          matched positionally against the logged errors and log events.
      @return: A L{Deferred} that fires after the checks above have run.
      """
      class ThrowingProducer(NonStreamingProducer):
          def resumeProducing(self):
              if self.counter != 2:
                  NonStreamingProducer.resumeProducing(self)
                  return None
              # Deliberately blow up with ZeroDivisionError.
              return 1 / 0

      producer = ThrowingProducer(consumer)
      pushProducer = _PullToPush(producer, consumer)
      consumer.registerProducer(pushProducer, True)

      # Capture log events so their "why" text can be inspected later.  The
      # same bound method is used for adding and removing the observer.
      logEvents = []
      recordEvent = logEvents.append
      log.addObserver(recordEvent)
      self.addCleanup(log.removeObserver, recordEvent)

      # Mimic TLSMemoryBIOProtocol: unregistering the producer from the
      # consumer also stops the streaming wrapper.
      originalUnregister = consumer.unregisterProducer

      def unregisterAndStop():
          originalUnregister()
          pushProducer.stopStreaming()

      consumer.unregisterProducer = unregisterAndStop

      pushProducer.startStreaming()
      whenDone = pushProducer._coopTask.whenDone()

      def ignoreTaskStopped(reason):
          return reason.trap(TaskStopped)

      whenDone.addErrback(ignoreTaskStopped)

      def checkOutcome(result):
          self.assertEqual(consumer.value(), b"01")
          # Every exception raised along the way was logged:
          loggedFailures = self.flushLoggedErrors()
          self.assertEqual(len(loggedFailures), len(expectedExceptions))
          for failure, (exceptionType, why), event in zip(
                  loggedFailures, expectedExceptions, logEvents):
              self.assertTrue(failure.check(exceptionType))
              self.assertIn(why, event['why'])
          # The streaming wrapper is finished as well:
          self.assertTrue(pushProducer._finished)

      whenDone.addCallback(checkOutcome)
      return whenDone
  def test_resumeProducingRaises(self):
      """
      An exception raised by the wrapped producer's C{resumeProducing} is
      logged, and the streaming wrapper unregisters itself from the consumer
      and stops streaming.
      """
      transport = StringTransport()
      expected = [(ZeroDivisionError, "failed, producing will be stopped")]

      def assertUnregistered(result):
          # The wrapper removed itself from the consumer:
          self.assertIsNone(transport.producer)

      whenStopped = self.resumeProducingRaises(transport, expected)
      whenStopped.addCallback(assertUnregistered)
      return whenStopped
  def test_resumeProducingRaiseAndUnregisterProducerRaises(self):
      """
      If the underlying producer raises an exception when resumeProducing is
      called, the streaming wrapper should log the error, unregister from
      the consumer and stop streaming, even if the unregisterProducer call
      also raises.

      Both failures are expected to be logged, listed below in the order in
      which the shared helper matches them against the logged errors: the
      L{ZeroDivisionError} from C{resumeProducing}, then the
      L{RuntimeError} from C{unregisterProducer}.
      """
      consumer = StringTransport()

      def raiser():
          # Stands in for the consumer's unregisterProducer so that the
          # wrapper's attempt to unregister itself fails.
          raise RuntimeError()

      consumer.unregisterProducer = raiser
      return self.resumeProducingRaises(
          consumer,
          [(ZeroDivisionError, "failed, producing will be stopped"),
           (RuntimeError, "failed to unregister producer")])
  def test_stopStreamingTwice(self):
      """
      Calling stopStreaming() a second time is harmless.  This matters for
      error-handling code paths that may stop the wrapper more than once.
      """
      transport = StringTransport()
      pushProducer = _PullToPush(NonStreamingProducer(transport), transport)
      pushProducer.startStreaming()
      for _ in range(2):
          pushProducer.stopStreaming()
      self.assertTrue(pushProducer._finished)
  def test_interface(self):
      """
      Instances of L{_PullToPush} provide L{IPushProducer}.
      """
      transport = StringTransport()
      wrapper = _PullToPush(NonStreamingProducer(transport), transport)
      self.assertTrue(verifyObject(IPushProducer, wrapper))
  @implementer(IProtocolNegotiationFactory)
  class ClientNegotiationFactory(ClientFactory):
      """
      A client-side L{ClientFactory} configured with the protocols it is
      willing to negotiate via NPN/ALPN.
      """

      def __init__(self, acceptableProtocols):
          """
          @param acceptableProtocols: The protocols the client is willing to
              speak once the TLS handshake has completed.
          @type acceptableProtocols: L{list} of L{bytes}
          """
          self._acceptableProtocols = acceptableProtocols

      def acceptableProtocols(self):
          """
          Report the protocols this factory can speak, expressed as ALPN
          tokens as defined by the IANA ALPN protocol ID registry.

          @return: The ALPN tokens given to the constructor, in order of
              preference.
          @rtype: L{list} of L{bytes}
          """
          return self._acceptableProtocols
  @implementer(IProtocolNegotiationFactory)
  class ServerNegotiationFactory(ServerFactory):
      """
      A server-side L{ServerFactory} configured with the protocols it is
      willing to negotiate via NPN/ALPN.
      """

      def __init__(self, acceptableProtocols):
          """
          @param acceptableProtocols: The protocols the server is willing to
              speak once the TLS handshake has completed.
          @type acceptableProtocols: L{list} of L{bytes}
          """
          self._acceptableProtocols = acceptableProtocols

      def acceptableProtocols(self):
          """
          Report the protocols this factory can speak, expressed as ALPN
          tokens as defined by the IANA ALPN protocol ID registry.

          @return: The ALPN tokens given to the constructor, in order of
              preference.
          @rtype: L{list} of L{bytes}
          """
          return self._acceptableProtocols
  1400. class IProtocolNegotiationFactoryTests(TestCase):
  1401. """
  1402. Tests for L{IProtocolNegotiationFactory} inside L{TLSMemoryBIOFactory}.
  1403. These tests expressly don't include the case where both server and client
  1404. advertise protocols but don't have any overlap. This is because the
  1405. behaviour here is platform-dependent and changes from version to version.
  1406. Prior to version 1.1.0 of OpenSSL, failing the ALPN negotiation does not
  1407. fail the handshake. At least in 1.0.2h, failing NPN *does* fail the
  1408. handshake, at least with the callback implemented by PyOpenSSL.
  1409. This is sufficiently painful to test that we simply don't. It's not
  1410. necessary to validate that our offering logic works anyway: all we need to
  1411. see is that it works in the successful case and that it degrades properly.
  1412. """
  def handshakeProtocols(self, clientProtocols, serverProtocols):
      """
      Connect a TLS client and a TLS server over an in-memory loopback and
      start their handshake.

      @param clientProtocols: The protocols the client is willing to speak
          once the TLS handshake has completed.
      @type clientProtocols: L{list} of L{bytes}

      @param serverProtocols: The protocols the server is willing to speak
          once the TLS handshake has completed.
      @type serverProtocols: L{list} of L{bytes}

      @return: A 4-tuple of the client TLS L{Protocol}, the server TLS
          L{Protocol}, a L{Deferred} fired when the client first receives
          application bytes (meaning the TLS connection is up), and a
          L{Deferred} fired when the server first receives application
          bytes.
      @rtype: L{tuple} of (L{Protocol}, L{Protocol}, L{Deferred},
          L{Deferred})
      """
      payload = b'some bytes'

      class NotifyingSender(Protocol):
          # Writes the payload one byte at a time on connect, and fires
          # its notifier with itself the first time any bytes arrive.
          def __init__(self, notifier):
              self.notifier = notifier

          def connectionMade(self):
              self.transport.writeSequence(list(iterbytes(payload)))

          def dataReceived(self, data):
              if self.notifier is None:
                  return
              self.notifier.callback(self)
              self.notifier = None

      def buildWrapped(contextFactory, isClient, factory, notifier):
          # Give the negotiation factory a notifying protocol, wrap it in
          # TLS, and build the TLS-level protocol instance.
          factory.protocol = lambda: NotifyingSender(notifier)
          tlsFactory = TLSMemoryBIOFactory(contextFactory, isClient, factory)
          return tlsFactory.buildProtocol(None)

      clientDataReceived = Deferred()
      clientFactory = ClientNegotiationFactory(clientProtocols)
      clientContextFactory, _ = (
          HandshakeCallbackContextFactory.factoryAndDeferred())
      sslClientProtocol = buildWrapped(
          clientContextFactory, True, clientFactory, clientDataReceived)

      serverDataReceived = Deferred()
      serverFactory = ServerNegotiationFactory(serverProtocols)
      sslServerProtocol = buildWrapped(
          ServerTLSContext(), False, serverFactory, serverDataReceived)

      loopbackAsync(sslServerProtocol, sslClientProtocol)
      return (sslClientProtocol, sslServerProtocol, clientDataReceived,
              serverDataReceived)
  def test_negotiationWithNoProtocols(self):
      """
      When factories support L{IProtocolNegotiationFactory} but don't
      advertise support for any protocols, no protocols are negotiated.
      """
      client, server, clientDataReceived, serverDataReceived = (
          self.handshakeProtocols([], [])
      )

      def checkNegotiatedProtocol(ignored):
          self.assertIsNone(client.negotiatedProtocol)
          self.assertIsNone(server.negotiatedProtocol)

      # Wait until both sides have received data, then run the check on the
      # Deferred handed back to trial so an assertion failure fails this
      # test.  Adding the check to serverDataReceived instead would miss the
      # returned chain once serverDataReceived has already passed its result
      # on to clientDataReceived (e.g. when the loopback completes
      # synchronously).
      clientDataReceived.addCallback(lambda ignored: serverDataReceived)
      clientDataReceived.addCallback(checkNegotiatedProtocol)
      return clientDataReceived
  def test_negotiationWithProtocolOverlap(self):
      """
      When factories support L{IProtocolNegotiationFactory} and support
      overlapping protocols, the first protocol is negotiated.
      """
      client, server, clientDataReceived, serverDataReceived = (
          self.handshakeProtocols([b'h2', b'http/1.1'], [b'h2', b'http/1.1'])
      )

      def checkNegotiatedProtocol(ignored):
          self.assertEqual(client.negotiatedProtocol, b'h2')
          self.assertEqual(server.negotiatedProtocol, b'h2')

      # Wait until both sides have received data, then run the check on the
      # Deferred handed back to trial so an assertion failure fails this
      # test.  Adding the check to serverDataReceived instead would miss the
      # returned chain once serverDataReceived has already passed its result
      # on to clientDataReceived (e.g. when the loopback completes
      # synchronously).
      clientDataReceived.addCallback(lambda ignored: serverDataReceived)
      clientDataReceived.addCallback(checkNegotiatedProtocol)
      return clientDataReceived
  def test_negotiationClientOnly(self):
      """
      When factories support L{IProtocolNegotiationFactory} and only the
      client advertises, nothing is negotiated.
      """
      client, server, clientDataReceived, serverDataReceived = (
          self.handshakeProtocols([b'h2', b'http/1.1'], [])
      )

      def checkNegotiatedProtocol(ignored):
          self.assertIsNone(client.negotiatedProtocol)
          self.assertIsNone(server.negotiatedProtocol)

      # Wait until both sides have received data, then run the check on the
      # Deferred handed back to trial so an assertion failure fails this
      # test.  Adding the check to serverDataReceived instead would miss the
      # returned chain once serverDataReceived has already passed its result
      # on to clientDataReceived (e.g. when the loopback completes
      # synchronously).
      clientDataReceived.addCallback(lambda ignored: serverDataReceived)
      clientDataReceived.addCallback(checkNegotiatedProtocol)
      return clientDataReceived
  def test_negotiationServerOnly(self):
      """
      When factories support L{IProtocolNegotiationFactory} and only the
      server advertises, nothing is negotiated.
      """
      client, server, clientDataReceived, serverDataReceived = (
          self.handshakeProtocols([], [b'h2', b'http/1.1'])
      )

      def checkNegotiatedProtocol(ignored):
          self.assertIsNone(client.negotiatedProtocol)
          self.assertIsNone(server.negotiatedProtocol)

      # Wait until both sides have received data, then run the check on the
      # Deferred handed back to trial so an assertion failure fails this
      # test.  Adding the check to serverDataReceived instead would miss the
      # returned chain once serverDataReceived has already passed its result
      # on to clientDataReceived (e.g. when the loopback completes
      # synchronously).
      clientDataReceived.addCallback(lambda ignored: serverDataReceived)
      clientDataReceived.addCallback(checkNegotiatedProtocol)
      return clientDataReceived