12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730 |
- # Copyright (c) Twisted Matrix Laboratories.
- # See LICENSE for details.
- """
- Tests for twisted SSL support.
- """
- from __future__ import division, absolute_import
- from twisted.python.filepath import FilePath
- from twisted.trial import unittest
- from twisted.internet import protocol, reactor, interfaces, defer
- from twisted.internet.error import ConnectionDone
- from twisted.protocols import basic
- from twisted.python.reflect import requireModule
- from twisted.python.runtime import platform
- from twisted.test.test_tcp import ProperlyCloseFilesMixin
- from twisted.test.proto_helpers import waitUntilAllDisconnected
- import os, errno
- try:
- from OpenSSL import SSL, crypto
- from twisted.internet import ssl
- from twisted.test.ssl_helpers import ClientTLSContext, certPath
- except ImportError:
- def _noSSL():
- # ugh, make pyflakes happy.
- global SSL
- global ssl
- SSL = ssl = None
- _noSSL()
- try:
- from twisted.protocols import tls as newTLS
- except ImportError:
- # Assuming SSL exists, we're using old version in reactor (i.e. non-protocol)
- newTLS = None
- from zope.interface import implementer
class UnintelligentProtocol(basic.LineReceiver):
    """
    Line-based peer that sends a few lines in the clear, switches to TLS
    when the other side answers C{READY}, then sends a few more lines and
    disconnects.

    @ivar deferred: fired with C{None} when the connection is lost.
    @type deferred: L{defer.Deferred}

    @cvar pretext: lines sent before TLS is set up; the last one is
        C{STARTTLS}.
    @type pretext: C{list} of C{bytes}

    @cvar posttext: lines sent after TLS is set up.
    @type posttext: C{list} of C{bytes}
    """

    pretext = [
        b"first line",
        b"last thing before tls starts",
        b"STARTTLS",
    ]

    posttext = [
        b"first thing after tls started",
        b"last thing ever",
    ]

    def __init__(self):
        self.deferred = defer.Deferred()

    def _sendAll(self, lines):
        """
        Send every element of C{lines} as its own line.
        """
        for outgoing in lines:
            self.sendLine(outgoing)

    def connectionMade(self):
        self._sendAll(self.pretext)

    def lineReceived(self, line):
        if line != b"READY":
            return
        # The factory's C{client} attribute is forwarded as startTLS's second
        # argument; the test case decides its value when building the factory.
        self.transport.startTLS(ClientTLSContext(), self.factory.client)
        self._sendAll(self.posttext)
        self.transport.loseConnection()

    def connectionLost(self, reason):
        self.deferred.callback(None)
class LineCollector(basic.LineReceiver):
    """
    Line-based peer that records every received line on its factory. When it
    sees C{STARTTLS} it answers C{READY}, then either starts TLS itself or
    switches to raw mode to capture whatever bytes follow.

    @ivar deferred: fired with C{None} when the connection is lost.
    @type deferred: L{defer.Deferred}

    @ivar doTLS: whether this protocol should start TLS after C{STARTTLS};
        if not, the following bytes are collected raw on C{factory.rawdata}.
    @type doTLS: C{bool}

    @ivar fillBuffer: if C{True}, queue 500 lines of 1000 bytes each before
        replying C{READY}.
    @type fillBuffer: C{bool}
    """

    def __init__(self, doTLS, fillBuffer=False):
        self.doTLS = doTLS
        self.fillBuffer = fillBuffer
        self.deferred = defer.Deferred()

    def connectionMade(self):
        factory = self.factory
        factory.rawdata = b''
        factory.lines = []

    def lineReceived(self, line):
        self.factory.lines.append(line)
        if line != b'STARTTLS':
            return

        if self.fillBuffer:
            filler = b'X' * 1000
            for _ in range(500):
                self.sendLine(filler)
        self.sendLine(b'READY')

        if not self.doTLS:
            self.setRawMode()
            return

        serverContext = ServerTLSContext(
            privateKeyFileName=certPath,
            certificateFileName=certPath,
        )
        self.transport.startTLS(serverContext, self.factory.server)

    def rawDataReceived(self, data):
        self.factory.rawdata += data
        self.transport.loseConnection()

    def connectionLost(self, reason):
        self.deferred.callback(None)
class SingleLineServerProtocol(protocol.Protocol):
    """
    Server protocol that writes a single greeting line as soon as the
    connection is made, then queries the transport's peer certificate.
    """

    def connectionMade(self):
        transport = self.transport
        transport.write(b"+OK <some crap>\r\n")
        # The returned certificate is discarded; presumably the point is that
        # calling getPeerCertificate() here must not break the connection --
        # TODO confirm intent.
        transport.getPeerCertificate()
class RecordingClientProtocol(protocol.Protocol):
    """
    Client protocol that reports the first chunk of data it receives.

    @ivar deferred: a deferred that will fire with first received content.
    @type deferred: L{defer.Deferred}
    """

    def __init__(self):
        self.deferred = defer.Deferred()

    def connectionMade(self):
        # The returned certificate is discarded; only the call is exercised.
        self.transport.getPeerCertificate()

    def dataReceived(self, data):
        # A Deferred may only be fired once.  If the payload is delivered in
        # more than one chunk, calling callback() again would raise
        # AlreadyCalledError from inside dataReceived, so only the first
        # chunk is reported, as the docstring promises.
        if not self.deferred.called:
            self.deferred.callback(data)
@implementer(interfaces.IHandshakeListener)
class ImmediatelyDisconnectingProtocol(protocol.Protocol):
    """
    A protocol that closes its connection as soon as the TLS handshake
    completes, and fires its factory's C{connectionDisconnected} deferred
    with C{None} once the connection is lost.
    """

    def connectionLost(self, reason):
        self.factory.connectionDisconnected.callback(None)

    def handshakeCompleted(self):
        self.transport.loseConnection()
def generateCertificateObjects(organization, organizationalUnit):
    """
    Create a self-signed certificate for given C{organization} and
    C{organizationalUnit}.

    A 2048-bit RSA key is generated, and both the request and the
    certificate are signed with SHA-256. Weaker parameters (1024-bit RSA
    keys, MD5 digests) can be refused by OpenSSL builds running at a higher
    security level, which would make the tests using these certificates
    fail for reasons unrelated to what they test.

    @param organization: value for the subject's C{O} field.
    @param organizationalUnit: value for the subject's C{OU} field.

    @return: a tuple of (key, request, certificate) objects.
    """
    pkey = crypto.PKey()
    pkey.generate_key(crypto.TYPE_RSA, 2048)

    req = crypto.X509Req()
    subject = req.get_subject()
    subject.O = organization
    subject.OU = organizationalUnit
    req.set_pubkey(pkey)
    req.sign(pkey, "sha256")

    # Here comes the actual certificate: self-signed (issuer == subject).
    cert = crypto.X509()
    cert.set_serial_number(1)
    cert.gmtime_adj_notBefore(0)
    cert.gmtime_adj_notAfter(60)  # Testing certificates need not be long lived
    cert.set_issuer(req.get_subject())
    cert.set_subject(req.get_subject())
    cert.set_pubkey(req.get_pubkey())
    cert.sign(pkey, "sha256")
    return pkey, req, cert
def generateCertificateFiles(basename, organization, organizationalUnit):
    """
    Write PEM-encoded key, request and certificate files for the given
    C{organization} and C{organizationalUnit}.

    Each file is named C{basename} joined with C{key}, C{req} or C{cert}
    by L{os.extsep}.
    """
    pkey, req, cert = generateCertificateObjects(organization, organizationalUnit)
    dumpers = (
        ('key', crypto.dump_privatekey, pkey),
        ('req', crypto.dump_certificate_request, req),
        ('cert', crypto.dump_certificate, cert),
    )
    for suffix, dump, obj in dumpers:
        pem = dump(crypto.FILETYPE_PEM, obj)
        target = FilePath(os.extsep.join((basename, suffix)).encode("utf-8"))
        target.setContent(pem)
class ContextGeneratingMixin:
    """
    Build L{ssl.DefaultOpenSSLContextFactory} instances, backed by freshly
    generated certificate files, for both a client and a server.

    @ivar clientBase: file prefix of the client certificate files.
    @type clientBase: C{str}

    @ivar serverBase: file prefix of the server certificate files.
    @type serverBase: C{str}

    @ivar clientCtxFactory: context factory intended for
        L{IReactorSSL.connectSSL}.
    @type clientCtxFactory: L{ssl.DefaultOpenSSLContextFactory}

    @ivar serverCtxFactory: context factory intended for
        L{IReactorSSL.listenSSL}.
    @type serverCtxFactory: L{ssl.DefaultOpenSSLContextFactory}
    """

    def makeContextFactory(self, org, orgUnit, *args, **kwArgs):
        """
        Generate certificate files under a fresh temporary prefix and build a
        context factory from the key and certificate files.

        Extra positional and keyword arguments are forwarded to
        L{ssl.DefaultOpenSSLContextFactory}.

        @return: a 2-tuple of the file prefix and the context factory.
        """
        base = self.mktemp()
        generateCertificateFiles(base, org, orgUnit)
        keyFile = os.extsep.join((base, 'key'))
        certFile = os.extsep.join((base, 'cert'))
        contextFactory = ssl.DefaultOpenSSLContextFactory(
            keyFile, certFile, *args, **kwArgs)
        return base, contextFactory

    def setupServerAndClient(self, clientArgs, clientKwArgs, serverArgs,
                             serverKwArgs):
        """
        Populate C{clientBase}/C{clientCtxFactory} and then
        C{serverBase}/C{serverCtxFactory} via L{makeContextFactory}.
        """
        clientBase, clientFactory = self.makeContextFactory(
            *clientArgs, **clientKwArgs)
        self.clientBase = clientBase
        self.clientCtxFactory = clientFactory

        serverBase, serverFactory = self.makeContextFactory(
            *serverArgs, **serverKwArgs)
        self.serverBase = serverBase
        self.serverCtxFactory = serverFactory
if SSL is not None:
    class ServerTLSContext(ssl.DefaultOpenSSLContextFactory):
        """
        Server-side L{ssl.DefaultOpenSSLContextFactory} whose SSL method is
        always L{OpenSSL.SSL.TLSv1_METHOD}, overriding any C{sslmethod} the
        caller passes.
        """
        isClient = False

        def __init__(self, *args, **kw):
            options = dict(kw, sslmethod=SSL.TLSv1_METHOD)
            ssl.DefaultOpenSSLContextFactory.__init__(self, *args, **options)
class StolenTCPTests(ProperlyCloseFilesMixin, unittest.TestCase):
    """
    Re-run, over SSL transports, many of the checks that are performed for
    plain TCP transports.
    """

    def createServer(self, address, portNumber, factory):
        """
        Listen with L{IReactorSSL.listenSSL}, using the test certificate.
        """
        pemData = FilePath(certPath).getContent()
        serverOptions = ssl.PrivateCertificate.loadPEM(pemData).options()
        return reactor.listenSSL(
            portNumber, factory, serverOptions, interface=address)

    def connectClient(self, address, portNumber, clientCreator):
        """
        Connect with L{IReactorSSL.connectSSL} and default certificate
        options.
        """
        return clientCreator.connectSSL(
            address, portNumber, ssl.CertificateOptions())

    def getHandleExceptionType(self):
        """
        Writing to a closed L{OpenSSL.SSL.Connection} is expected to raise
        L{OpenSSL.SSL.Error}.
        """
        return SSL.Error

    def getHandleErrorCode(self):
        """
        Return the argument L{OpenSSL.SSL.Error} is expected to be constructed
        with in this case. This is an OpenSSL implementation detail; ideally
        the test would not depend on it.
        """
        # Windows 2000 SP 4 and Windows XP SP 2 report WSAENOTSOCK from
        # SSL.Connection.write.  The twisted.protocols.tls implementation of
        # IReactorSSL is isolated from the Windows I/O layer and does not
        # show this.  If test_properlyCloseFiles waited for the handshake and
        # shut down in an orderly way, writing to a shut-down SSL connection
        # would have a better-defined failure mode.
        #
        # twisted.protocols.tls is used whenever it can be imported.
        legacySSL = requireModule('twisted.protocols.tls') is None
        if legacySSL and platform.getType() == 'win32':
            return errno.WSAENOTSOCK
        # Everywhere else: an error about writing to a shutdown connection.
        # This is terribly implementation-specific.
        return [('SSL routines', 'SSL_write', 'protocol is shutdown')]
class TLSTests(unittest.TestCase):
    """
    Tests for startTLS support.

    @ivar fillBuffer: passed on to L{LineCollector} as its C{fillBuffer}.
    @type fillBuffer: C{bool}
    """
    fillBuffer = False

    clientProto = None
    serverProto = None

    def tearDown(self):
        for proto in (self.clientProto, self.serverProto):
            if proto.transport is not None:
                proto.transport.loseConnection()

    def _runTest(self, clientProto, serverProto, clientIsServer=False):
        """
        Connect C{clientProto} to C{serverProto} over loopback TCP.

        @param clientProto: protocol instance for the connecting side.
        @param serverProto: protocol instance for the listening side.
        @param clientIsServer: if C{True}, the client side initiates
            startTLS instead of the server side.

        @return: a L{defer.Deferred} firing once both protocols' deferreds
            have fired.
        """
        self.clientProto = clientProto
        self.clientFactory = clientFactory = protocol.ClientFactory()
        clientFactory.protocol = lambda: clientProto

        self.serverProto = serverProto
        self.serverFactory = serverFactory = protocol.ServerFactory()
        serverFactory.protocol = lambda: serverProto

        # The protocols read factory.client / factory.server and pass the
        # value as startTLS's second argument.
        if clientIsServer:
            clientFactory.server = False
            serverFactory.client = False
        else:
            clientFactory.client = True
            serverFactory.server = True

        port = reactor.listenTCP(0, serverFactory, interface="127.0.0.1")
        self.addCleanup(port.stopListening)
        reactor.connectTCP('127.0.0.1', port.getHost().port, clientFactory)

        return defer.gatherResults([clientProto.deferred, serverProto.deferred])

    def test_TLS(self):
        """
        When both sides start TLS, the server receives every line the client
        sent, both before and after the switch.
        """
        expected = UnintelligentProtocol.pretext + UnintelligentProtocol.posttext
        d = self._runTest(UnintelligentProtocol(),
                          LineCollector(True, self.fillBuffer))

        def check(ignored):
            self.assertEqual(self.serverFactory.lines, expected)

        d.addCallback(check)
        return d

    def test_unTLS(self):
        """
        When the server does not start TLS although the client does, the
        bytes sent after the switch arrive at the server as raw data.
        """
        d = self._runTest(UnintelligentProtocol(),
                          LineCollector(False, self.fillBuffer))

        def check(ignored):
            self.assertEqual(self.serverFactory.lines,
                             UnintelligentProtocol.pretext)
            self.assertTrue(self.serverFactory.rawdata,
                            "No encrypted bytes received")

        d.addCallback(check)
        return d

    def test_backwardsTLS(self):
        """
        startTLS also works when it is first initiated by the client.
        """
        expected = UnintelligentProtocol.pretext + UnintelligentProtocol.posttext
        d = self._runTest(LineCollector(True, self.fillBuffer),
                          UnintelligentProtocol(), clientIsServer=True)

        def check(ignored):
            self.assertEqual(self.clientFactory.lines, expected)

        d.addCallback(check)
        return d
class SpammyTLSTests(TLSTests):
    """
    Run every L{TLSTests} test again, with the line collector queueing a
    large amount of outgoing data before it replies C{READY}, so that bytes
    are sitting in the out buffer when TLS starts.
    """
    fillBuffer = True
class BufferingTests(unittest.TestCase):
    """
    Check that a line written by an SSL server in C{connectionMade} reaches
    the client intact.
    """
    serverProto = None
    clientProto = None

    def tearDown(self):
        protocols = [self.serverProto, self.clientProto]
        for proto in protocols:
            if proto.transport is not None:
                proto.transport.loseConnection()
        return waitUntilAllDisconnected(reactor, protocols)

    def test_openSSLBuffering(self):
        self.serverProto = serverProto = SingleLineServerProtocol()
        self.clientProto = clientProto = RecordingClientProtocol()

        serverFactory = protocol.ServerFactory()
        serverFactory.protocol = lambda: serverProto
        self.client = clientFactory = protocol.ClientFactory()
        clientFactory.protocol = lambda: clientProto

        serverContext = ssl.DefaultOpenSSLContextFactory(certPath, certPath)
        clientContext = ssl.ClientContextFactory()

        port = reactor.listenSSL(0, serverFactory, serverContext,
                                 interface='127.0.0.1')
        self.addCleanup(port.stopListening)

        connector = reactor.connectSSL('127.0.0.1', port.getHost().port,
                                       clientFactory, clientContext)
        self.addCleanup(connector.disconnect)

        received = clientProto.deferred
        received.addCallback(self.assertEqual, b"+OK <some crap>\r\n")
        return received
class ConnectionLostTests(unittest.TestCase, ContextGeneratingMixin):
    """
    SSL connection closing tests.
    """

    def testImmediateDisconnect(self):
        """
        A client that disconnects as soon as its handshake completes gets its
        C{connectionLost} called.
        """
        org = "twisted.test.test_ssl"
        self.setupServerAndClient(
            (org, org + ", client"), {},
            (org, org + ", server"), {})

        # Set up a server, connect to it with a client, which should work since
        # our verifiers allow anything, then disconnect.
        serverProtocolFactory = protocol.ServerFactory()
        serverProtocolFactory.protocol = protocol.Protocol
        self.serverPort = serverPort = reactor.listenSSL(0,
            serverProtocolFactory, self.serverCtxFactory)
        # Registered as a cleanup (rather than chained on the result) so the
        # port is closed even if the disconnection never happens or the test
        # fails; otherwise trial would also report a dirty reactor.
        self.addCleanup(serverPort.stopListening)

        clientProtocolFactory = protocol.ClientFactory()
        clientProtocolFactory.protocol = ImmediatelyDisconnectingProtocol
        clientProtocolFactory.connectionDisconnected = defer.Deferred()
        reactor.connectSSL('127.0.0.1',
            serverPort.getHost().port, clientProtocolFactory,
            self.clientCtxFactory)

        return clientProtocolFactory.connectionDisconnected

    def test_bothSidesLoseConnection(self):
        """
        Both sides of SSL connection close connection; the connections should
        close cleanly, and only after the underlying TCP connection has
        disconnected.
        """
        @implementer(interfaces.IHandshakeListener)
        class CloseAfterHandshake(protocol.Protocol):
            """
            Close the connection once the handshake completes and errback
            C{done} with the reason the connection was lost.
            """

            def __init__(self):
                self.done = defer.Deferred()

            def handshakeCompleted(self):
                self.transport.loseConnection()

            def connectionLost(self, reason):
                self.done.errback(reason)
                del self.done

        org = "twisted.test.test_ssl"
        self.setupServerAndClient(
            (org, org + ", client"), {},
            (org, org + ", server"), {})

        serverProtocol = CloseAfterHandshake()
        serverProtocolFactory = protocol.ServerFactory()
        serverProtocolFactory.protocol = lambda: serverProtocol
        serverPort = reactor.listenSSL(0,
            serverProtocolFactory, self.serverCtxFactory)
        self.addCleanup(serverPort.stopListening)

        clientProtocol = CloseAfterHandshake()
        clientProtocolFactory = protocol.ClientFactory()
        clientProtocolFactory.protocol = lambda: clientProtocol
        reactor.connectSSL('127.0.0.1',
            serverPort.getHost().port, clientProtocolFactory,
            self.clientCtxFactory)

        def checkResult(failure):
            # A clean close is reported as ConnectionDone; anything else
            # propagates and fails the test.
            failure.trap(ConnectionDone)

        return defer.gatherResults(
            [clientProtocol.done.addErrback(checkResult),
             serverProtocol.done.addErrback(checkResult)])

    if newTLS is None:
        test_bothSidesLoseConnection.skip = "Old SSL code doesn't always close cleanly."

    def testFailedVerify(self):
        """
        When the client's verify callback rejects the server certificate,
        both sides lose the connection with an SSL error.
        """
        org = "twisted.test.test_ssl"
        self.setupServerAndClient(
            (org, org + ", client"), {},
            (org, org + ", server"), {})

        def verify(*a):
            return False

        self.clientCtxFactory.getContext().set_verify(SSL.VERIFY_PEER, verify)

        # connectionLost receives a Failure; handing it to callback() is
        # expected to send the Deferred down its errback path, which is why
        # _cbLostConns expects both entries to be unsuccessful.
        serverConnLost = defer.Deferred()
        serverProtocol = protocol.Protocol()
        serverProtocol.connectionLost = serverConnLost.callback
        serverProtocolFactory = protocol.ServerFactory()
        serverProtocolFactory.protocol = lambda: serverProtocol
        self.serverPort = serverPort = reactor.listenSSL(0,
            serverProtocolFactory, self.serverCtxFactory)
        # Stop listening even if the checks in _cbLostConns fail.
        self.addCleanup(serverPort.stopListening)

        clientConnLost = defer.Deferred()
        clientProtocol = protocol.Protocol()
        clientProtocol.connectionLost = clientConnLost.callback
        clientProtocolFactory = protocol.ClientFactory()
        clientProtocolFactory.protocol = lambda: clientProtocol
        reactor.connectSSL('127.0.0.1',
            serverPort.getHost().port, clientProtocolFactory,
            self.clientCtxFactory)

        dl = defer.DeferredList([serverConnLost, clientConnLost],
                                consumeErrors=True)
        return dl.addCallback(self._cbLostConns)

    def _cbLostConns(self, results):
        """
        Check that both sides lost the connection with an acceptable error.

        @param results: the result list of a L{defer.DeferredList} over the
            server's and then the client's connection-lost deferreds.
        """
        (sSuccess, sResult), (cSuccess, cResult) = results

        self.assertFalse(sSuccess)
        self.assertFalse(cSuccess)

        acceptableErrors = [SSL.Error]

        # Rather than getting a verification failure on Windows, we are getting
        # a connection failure.  Without something like sslverify proxying
        # in-between we can't fix up the platform's errors, so let's just
        # specifically say it is only OK in this one case to keep the tests
        # passing.  Normally we'd like to be as strict as possible here, so
        # we're not going to allow this to report errors incorrectly on any
        # other platforms.
        if platform.isWindows():
            from twisted.internet.error import ConnectionLost
            acceptableErrors.append(ConnectionLost)

        sResult.trap(*acceptableErrors)
        cResult.trap(*acceptableErrors)
class FakeContext:
    """
    Stand-in for L{OpenSSL.SSL.Context} that remembers the method it was
    created with and the options set on it, so tests can inspect them.
    """

    def __init__(self, method):
        self._method = method
        self._options = 0

    def set_options(self, options):
        # Options accumulate as a bit mask across calls.
        self._options = self._options | options

    def use_certificate_file(self, fileName):
        # The file is never read; the call is accepted and ignored.
        return None

    def use_privatekey_file(self, fileName):
        # The file is never read; the call is accepted and ignored.
        return None
class DefaultOpenSSLContextFactoryTests(unittest.TestCase):
    """
    Tests for L{ssl.DefaultOpenSSLContextFactory}.
    """

    def setUp(self):
        # Real pyOpenSSL Context objects do not expose their method or
        # options, so the factory is built around FakeContext instead.
        factory = ssl.DefaultOpenSSLContextFactory(
            certPath, certPath, _contextFactory=FakeContext)
        self.contextFactory = factory
        self.context = factory.getContext()

    def test_method(self):
        """
        L{ssl.DefaultOpenSSLContextFactory.getContext} returns an SSL context
        which can use SSLv3 or TLSv1 but not SSLv2.
        """
        context = self.context
        # SSLv23_METHOD permits SSLv2, SSLv3 or TLSv1 ...
        self.assertEqual(context._method, SSL.SSLv23_METHOD)

        # ... OP_NO_SSLv2 then removes SSLv2 from that set ...
        self.assertEqual(context._options & SSL.OP_NO_SSLv2, SSL.OP_NO_SSLv2)

        # ... while SSLv3 and TLSv1 stay enabled.
        self.assertFalse(context._options & SSL.OP_NO_SSLv3)
        self.assertFalse(context._options & SSL.OP_NO_TLSv1)

    def test_missingCertificateFile(self):
        """
        L{ssl.DefaultOpenSSLContextFactory} raises L{OpenSSL.SSL.Error} when
        its certificate filename does not name an existing file.
        """
        missing = self.mktemp()
        self.assertRaises(
            SSL.Error, ssl.DefaultOpenSSLContextFactory, certPath, missing)

    def test_missingPrivateKeyFile(self):
        """
        L{ssl.DefaultOpenSSLContextFactory} raises L{OpenSSL.SSL.Error} when
        its private key filename does not name an existing file.
        """
        missing = self.mktemp()
        self.assertRaises(
            SSL.Error, ssl.DefaultOpenSSLContextFactory, missing, certPath)
class ClientContextFactoryTests(unittest.TestCase):
    """
    Tests for L{ssl.ClientContextFactory}.
    """

    def setUp(self):
        factory = ssl.ClientContextFactory()
        # Swap in FakeContext so the created context can be inspected.
        factory._contextFactory = FakeContext
        self.contextFactory = factory
        self.context = factory.getContext()

    def test_method(self):
        """
        L{ssl.ClientContextFactory.getContext} returns a context which can use
        SSLv3 or TLSv1 but not SSLv2.
        """
        options = self.context._options
        self.assertEqual(self.context._method, SSL.SSLv23_METHOD)
        self.assertEqual(options & SSL.OP_NO_SSLv2, SSL.OP_NO_SSLv2)
        self.assertFalse(options & SSL.OP_NO_SSLv3)
        self.assertFalse(options & SSL.OP_NO_TLSv1)
# Every SSL-dependent test case is skipped when the running reactor does not
# provide IReactorSSL.
if interfaces.IReactorSSL(reactor, None) is None:
    for tCase in (
        StolenTCPTests,
        TLSTests,
        SpammyTLSTests,
        BufferingTests,
        ConnectionLostTests,
        DefaultOpenSSLContextFactoryTests,
        ClientContextFactoryTests,
    ):
        tCase.skip = "Reactor does not support SSL, cannot run SSL tests"
|