test_tls.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357
  1. # Copyright (c) Twisted Matrix Laboratories.
  2. # See LICENSE for details.
  3. """
  4. Tests for implementations of L{ITLSTransport}.
  5. """
  6. from __future__ import division, absolute_import
  7. __metaclass__ = type
  8. from zope.interface import implementer
  9. from twisted.python.compat import networkString
  10. from twisted.python.filepath import FilePath
  11. from twisted.internet.test.reactormixins import ReactorBuilder
  12. from twisted.internet.protocol import ServerFactory, ClientFactory, Protocol
  13. from twisted.internet.interfaces import (
  14. IReactorSSL, ITLSTransport, IStreamClientEndpoint)
  15. from twisted.internet.defer import Deferred, DeferredList
  16. from twisted.internet.endpoints import (
  17. SSL4ServerEndpoint, SSL4ClientEndpoint, TCP4ClientEndpoint)
  18. from twisted.internet.error import ConnectionClosed
  19. from twisted.internet.task import Cooperator
  20. from twisted.trial.unittest import SkipTest
  21. from twisted.python.runtime import platform
  22. from twisted.internet.test.test_core import ObjectModelIntegrationMixin
  23. from twisted.internet.test.test_tcp import (
  24. StreamTransportTestsMixin, AbortConnectionMixin)
  25. from twisted.internet.test.connectionmixins import (
  26. EndpointCreator, ConnectionTestsMixin, BrokenContextFactory)
  27. try:
  28. from OpenSSL.crypto import FILETYPE_PEM
  29. except ImportError:
  30. FILETYPE_PEM = None
  31. else:
  32. from twisted.internet.ssl import PrivateCertificate, KeyPair
  33. from twisted.internet.ssl import ClientContextFactory
  34. class TLSMixin:
  35. requiredInterfaces = [IReactorSSL]
  36. if platform.isWindows():
  37. msg = (
  38. "For some reason, these reactors don't deal with SSL "
  39. "disconnection correctly on Windows. See #3371.")
  40. skippedReactors = {
  41. "twisted.internet.glib2reactor.Glib2Reactor": msg,
  42. "twisted.internet.gtk2reactor.Gtk2Reactor": msg}
  43. class ContextGeneratingMixin(object):
  44. import twisted
  45. _pem = FilePath(
  46. networkString(twisted.__file__)).sibling(b"test").child(b"server.pem")
  47. del twisted
  48. def getServerContext(self):
  49. """
  50. Return a new SSL context suitable for use in a test server.
  51. """
  52. pem = self._pem.getContent()
  53. cert = PrivateCertificate.load(
  54. pem, KeyPair.load(pem, FILETYPE_PEM), FILETYPE_PEM)
  55. return cert.options()
  56. def getClientContext(self):
  57. return ClientContextFactory()
  58. @implementer(IStreamClientEndpoint)
  59. class StartTLSClientEndpoint(object):
  60. """
  61. An endpoint which wraps another one and adds a TLS layer immediately when
  62. connections are set up.
  63. @ivar wrapped: A L{IStreamClientEndpoint} provider which will be used to
  64. really set up connections.
  65. @ivar contextFactory: A L{ContextFactory} to use to do TLS.
  66. """
  67. def __init__(self, wrapped, contextFactory):
  68. self.wrapped = wrapped
  69. self.contextFactory = contextFactory
  70. def connect(self, factory):
  71. """
  72. Establish a connection using a protocol build by C{factory} and
  73. immediately start TLS on it. Return a L{Deferred} which fires with the
  74. protocol instance.
  75. """
  76. # This would be cleaner when we have ITransport.switchProtocol, which
  77. # will be added with ticket #3204:
  78. class WrapperFactory(ServerFactory):
  79. def buildProtocol(wrapperSelf, addr):
  80. protocol = factory.buildProtocol(addr)
  81. def connectionMade(orig=protocol.connectionMade):
  82. protocol.transport.startTLS(self.contextFactory)
  83. orig()
  84. protocol.connectionMade = connectionMade
  85. return protocol
  86. return self.wrapped.connect(WrapperFactory())
  87. class StartTLSClientCreator(EndpointCreator, ContextGeneratingMixin):
  88. """
  89. Create L{ITLSTransport.startTLS} endpoint for the client, and normal SSL
  90. for server just because it's easier.
  91. """
  92. def server(self, reactor):
  93. """
  94. Construct an SSL server endpoint. This should be constructing a TCP
  95. server endpoint which immediately calls C{startTLS} instead, but that
  96. is hard.
  97. """
  98. return SSL4ServerEndpoint(reactor, 0, self.getServerContext())
  99. def client(self, reactor, serverAddress):
  100. """
  101. Construct a TCP client endpoint wrapped to immediately start TLS.
  102. """
  103. return StartTLSClientEndpoint(
  104. TCP4ClientEndpoint(
  105. reactor, '127.0.0.1', serverAddress.port),
  106. ClientContextFactory())
  107. class BadContextTestsMixin(object):
  108. """
  109. Mixin for L{ReactorBuilder} subclasses which defines a helper for testing
  110. the handling of broken context factories.
  111. """
  112. def _testBadContext(self, useIt):
  113. """
  114. Assert that the exception raised by a broken context factory's
  115. C{getContext} method is raised by some reactor method. If it is not, an
  116. exception will be raised to fail the test.
  117. @param useIt: A two-argument callable which will be called with a
  118. reactor and a broken context factory and which is expected to raise
  119. the same exception as the broken context factory's C{getContext}
  120. method.
  121. """
  122. reactor = self.buildReactor()
  123. exc = self.assertRaises(
  124. ValueError, useIt, reactor, BrokenContextFactory())
  125. self.assertEqual(BrokenContextFactory.message, str(exc))
  126. class StartTLSClientTestsMixin(TLSMixin, ReactorBuilder, ConnectionTestsMixin):
  127. """
  128. Tests for TLS connections established using L{ITLSTransport.startTLS} (as
  129. opposed to L{IReactorSSL.connectSSL} or L{IReactorSSL.listenSSL}).
  130. """
  131. endpoints = StartTLSClientCreator()
  132. class SSLCreator(EndpointCreator, ContextGeneratingMixin):
  133. """
  134. Create SSL endpoints.
  135. """
  136. def server(self, reactor):
  137. """
  138. Create an SSL server endpoint on a TCP/IP-stack allocated port.
  139. """
  140. return SSL4ServerEndpoint(reactor, 0, self.getServerContext())
  141. def client(self, reactor, serverAddress):
  142. """
  143. Create an SSL client endpoint which will connect localhost on
  144. the port given by C{serverAddress}.
  145. @type serverAddress: L{IPv4Address}
  146. """
  147. return SSL4ClientEndpoint(
  148. reactor, '127.0.0.1', serverAddress.port,
  149. ClientContextFactory())
  150. class SSLClientTestsMixin(TLSMixin, ReactorBuilder, ContextGeneratingMixin,
  151. ConnectionTestsMixin, BadContextTestsMixin):
  152. """
  153. Mixin defining tests relating to L{ITLSTransport}.
  154. """
  155. endpoints = SSLCreator()
  156. def test_badContext(self):
  157. """
  158. If the context factory passed to L{IReactorSSL.connectSSL} raises an
  159. exception from its C{getContext} method, that exception is raised by
  160. L{IReactorSSL.connectSSL}.
  161. """
  162. def useIt(reactor, contextFactory):
  163. return reactor.connectSSL(
  164. "127.0.0.1", 1234, ClientFactory(), contextFactory)
  165. self._testBadContext(useIt)
  166. def test_disconnectAfterWriteAfterStartTLS(self):
  167. """
  168. L{ITCPTransport.loseConnection} ends a connection which was set up with
  169. L{ITLSTransport.startTLS} and which has recently been written to. This
  170. is intended to verify that a socket send error masked by the TLS
  171. implementation doesn't prevent the connection from being reported as
  172. closed.
  173. """
  174. class ShortProtocol(Protocol):
  175. def connectionMade(self):
  176. if not ITLSTransport.providedBy(self.transport):
  177. # Functionality isn't available to be tested.
  178. finished = self.factory.finished
  179. self.factory.finished = None
  180. finished.errback(SkipTest("No ITLSTransport support"))
  181. return
  182. # Switch the transport to TLS.
  183. self.transport.startTLS(self.factory.context)
  184. # Force TLS to really get negotiated. If nobody talks, nothing
  185. # will happen.
  186. self.transport.write(b"x")
  187. def dataReceived(self, data):
  188. # Stuff some bytes into the socket. This mostly has the effect
  189. # of causing the next write to fail with ENOTCONN or EPIPE.
  190. # With the pyOpenSSL implementation of ITLSTransport, the error
  191. # is swallowed outside of the control of Twisted.
  192. self.transport.write(b"y")
  193. # Now close the connection, which requires a TLS close alert to
  194. # be sent.
  195. self.transport.loseConnection()
  196. def connectionLost(self, reason):
  197. # This is the success case. The client and the server want to
  198. # get here.
  199. finished = self.factory.finished
  200. if finished is not None:
  201. self.factory.finished = None
  202. finished.callback(reason)
  203. reactor = self.buildReactor()
  204. serverFactory = ServerFactory()
  205. serverFactory.finished = Deferred()
  206. serverFactory.protocol = ShortProtocol
  207. serverFactory.context = self.getServerContext()
  208. clientFactory = ClientFactory()
  209. clientFactory.finished = Deferred()
  210. clientFactory.protocol = ShortProtocol
  211. clientFactory.context = self.getClientContext()
  212. clientFactory.context.method = serverFactory.context.method
  213. lostConnectionResults = []
  214. finished = DeferredList(
  215. [serverFactory.finished, clientFactory.finished],
  216. consumeErrors=True)
  217. def cbFinished(results):
  218. lostConnectionResults.extend([results[0][1], results[1][1]])
  219. finished.addCallback(cbFinished)
  220. port = reactor.listenTCP(0, serverFactory, interface='127.0.0.1')
  221. self.addCleanup(port.stopListening)
  222. connector = reactor.connectTCP(
  223. port.getHost().host, port.getHost().port, clientFactory)
  224. self.addCleanup(connector.disconnect)
  225. finished.addCallback(lambda ign: reactor.stop())
  226. self.runReactor(reactor)
  227. lostConnectionResults[0].trap(ConnectionClosed)
  228. lostConnectionResults[1].trap(ConnectionClosed)
  229. class TLSPortTestsBuilder(TLSMixin, ContextGeneratingMixin,
  230. ObjectModelIntegrationMixin, BadContextTestsMixin,
  231. StreamTransportTestsMixin, ReactorBuilder):
  232. """
  233. Tests for L{IReactorSSL.listenSSL}
  234. """
  235. def getListeningPort(self, reactor, factory):
  236. """
  237. Get a TLS port from a reactor.
  238. """
  239. return reactor.listenSSL(0, factory, self.getServerContext())
  240. def getExpectedStartListeningLogMessage(self, port, factory):
  241. """
  242. Get the message expected to be logged when a TLS port starts listening.
  243. """
  244. return "%s (TLS) starting on %d" % (factory, port.getHost().port)
  245. def getExpectedConnectionLostLogMsg(self, port):
  246. """
  247. Get the expected connection lost message for a TLS port.
  248. """
  249. return "(TLS Port %s Closed)" % (port.getHost().port,)
  250. def test_badContext(self):
  251. """
  252. If the context factory passed to L{IReactorSSL.listenSSL} raises an
  253. exception from its C{getContext} method, that exception is raised by
  254. L{IReactorSSL.listenSSL}.
  255. """
  256. def useIt(reactor, contextFactory):
  257. return reactor.listenSSL(0, ServerFactory(), contextFactory)
  258. self._testBadContext(useIt)
  259. globals().update(SSLClientTestsMixin.makeTestCaseClasses())
  260. globals().update(StartTLSClientTestsMixin.makeTestCaseClasses())
  261. globals().update(TLSPortTestsBuilder().makeTestCaseClasses())
  262. class AbortSSLConnectionTests(ReactorBuilder, AbortConnectionMixin, ContextGeneratingMixin):
  263. """
  264. C{abortConnection} tests using SSL.
  265. """
  266. requiredInterfaces = (IReactorSSL,)
  267. endpoints = SSLCreator()
  268. def buildReactor(self):
  269. reactor = ReactorBuilder.buildReactor(self)
  270. from twisted.internet import _producer_helpers
  271. # Patch twisted.protocols.tls to use this reactor, until we get
  272. # around to fixing #5206, or the TLS code uses an explicit reactor:
  273. cooperator = Cooperator(
  274. scheduler=lambda x: reactor.callLater(0.00001, x))
  275. self.patch(_producer_helpers, "cooperate", cooperator.cooperate)
  276. return reactor
  277. def setUp(self):
  278. if FILETYPE_PEM is None:
  279. raise SkipTest("OpenSSL not available.")
  280. globals().update(AbortSSLConnectionTests.makeTestCaseClasses())