# test_newtls.py (6.5 KB)
  1. # Copyright (c) Twisted Matrix Laboratories.
  2. # See LICENSE for details.
  3. """
  4. Tests for L{twisted.internet._newtls}.
  5. """
  6. from __future__ import division, absolute_import
  7. from twisted.trial import unittest
  8. from twisted.internet import interfaces
  9. from twisted.internet.test.reactormixins import ReactorBuilder
  10. from twisted.internet.test.connectionmixins import (
  11. ConnectableProtocol, runProtocolsWithReactor)
  12. from twisted.internet.test.test_tls import SSLCreator, TLSMixin
  13. from twisted.internet.test.test_tls import StartTLSClientCreator
  14. from twisted.internet.test.test_tls import ContextGeneratingMixin
  15. from twisted.internet.test.test_tcp import TCPCreator
  16. try:
  17. from twisted.protocols import tls
  18. from twisted.internet import _newtls
  19. except ImportError:
  20. _newtls = None
  21. from zope.interface import implementer
  22. class BypassTLSTests(unittest.TestCase):
  23. """
  24. Tests for the L{_newtls._BypassTLS} class.
  25. """
  26. if not _newtls:
  27. skip = "Couldn't import _newtls, perhaps pyOpenSSL is old or missing"
  28. def test_loseConnectionPassThrough(self):
  29. """
  30. C{_BypassTLS.loseConnection} calls C{loseConnection} on the base
  31. class, while preserving any default argument in the base class'
  32. C{loseConnection} implementation.
  33. """
  34. default = object()
  35. result = []
  36. class FakeTransport(object):
  37. def loseConnection(self, _connDone=default):
  38. result.append(_connDone)
  39. bypass = _newtls._BypassTLS(FakeTransport, FakeTransport())
  40. # The default from FakeTransport is used:
  41. bypass.loseConnection()
  42. self.assertEqual(result, [default])
  43. # And we can pass our own:
  44. notDefault = object()
  45. bypass.loseConnection(notDefault)
  46. self.assertEqual(result, [default, notDefault])
  47. class FakeProducer(object):
  48. """
  49. A producer that does nothing.
  50. """
  51. def pauseProducing(self):
  52. pass
  53. def resumeProducing(self):
  54. pass
  55. def stopProducing(self):
  56. pass
  57. @implementer(interfaces.IHandshakeListener)
  58. class ProducerProtocol(ConnectableProtocol):
  59. """
  60. Register a producer, unregister it, and verify the producer hooks up to
  61. innards of C{TLSMemoryBIOProtocol}.
  62. """
  63. def __init__(self, producer, result):
  64. self.producer = producer
  65. self.result = result
  66. def handshakeCompleted(self):
  67. if not isinstance(self.transport.protocol,
  68. tls.TLSMemoryBIOProtocol):
  69. # Either the test or the code have a bug...
  70. raise RuntimeError("TLSMemoryBIOProtocol not hooked up.")
  71. self.transport.registerProducer(self.producer, True)
  72. # The producer was registered with the TLSMemoryBIOProtocol:
  73. self.result.append(self.transport.protocol._producer._producer)
  74. self.transport.unregisterProducer()
  75. # The producer was unregistered from the TLSMemoryBIOProtocol:
  76. self.result.append(self.transport.protocol._producer)
  77. self.transport.loseConnection()
  78. class ProducerTestsMixin(ReactorBuilder, TLSMixin, ContextGeneratingMixin):
  79. """
  80. Test the new TLS code integrates C{TLSMemoryBIOProtocol} correctly.
  81. """
  82. if not _newtls:
  83. skip = "Could not import twisted.internet._newtls"
  84. def test_producerSSLFromStart(self):
  85. """
  86. C{registerProducer} and C{unregisterProducer} on TLS transports
  87. created as SSL from the get go are passed to the
  88. C{TLSMemoryBIOProtocol}, not the underlying transport directly.
  89. """
  90. result = []
  91. producer = FakeProducer()
  92. runProtocolsWithReactor(self, ConnectableProtocol(),
  93. ProducerProtocol(producer, result),
  94. SSLCreator())
  95. self.assertEqual(result, [producer, None])
  96. def test_producerAfterStartTLS(self):
  97. """
  98. C{registerProducer} and C{unregisterProducer} on TLS transports
  99. created by C{startTLS} are passed to the C{TLSMemoryBIOProtocol}, not
  100. the underlying transport directly.
  101. """
  102. result = []
  103. producer = FakeProducer()
  104. runProtocolsWithReactor(self, ConnectableProtocol(),
  105. ProducerProtocol(producer, result),
  106. StartTLSClientCreator())
  107. self.assertEqual(result, [producer, None])
  108. def startTLSAfterRegisterProducer(self, streaming):
  109. """
  110. When a producer is registered, and then startTLS is called,
  111. the producer is re-registered with the C{TLSMemoryBIOProtocol}.
  112. """
  113. clientContext = self.getClientContext()
  114. serverContext = self.getServerContext()
  115. result = []
  116. producer = FakeProducer()
  117. class RegisterTLSProtocol(ConnectableProtocol):
  118. def connectionMade(self):
  119. self.transport.registerProducer(producer, streaming)
  120. self.transport.startTLS(serverContext)
  121. # Store TLSMemoryBIOProtocol and underlying transport producer
  122. # status:
  123. if streaming:
  124. # _ProducerMembrane -> producer:
  125. result.append(self.transport.protocol._producer._producer)
  126. result.append(self.transport.producer._producer)
  127. else:
  128. # _ProducerMembrane -> _PullToPush -> producer:
  129. result.append(
  130. self.transport.protocol._producer._producer._producer)
  131. result.append(self.transport.producer._producer._producer)
  132. self.transport.unregisterProducer()
  133. self.transport.loseConnection()
  134. class StartTLSProtocol(ConnectableProtocol):
  135. def connectionMade(self):
  136. self.transport.startTLS(clientContext)
  137. runProtocolsWithReactor(self, RegisterTLSProtocol(),
  138. StartTLSProtocol(), TCPCreator())
  139. self.assertEqual(result, [producer, producer])
  140. def test_startTLSAfterRegisterProducerStreaming(self):
  141. """
  142. When a streaming producer is registered, and then startTLS is called,
  143. the producer is re-registered with the C{TLSMemoryBIOProtocol}.
  144. """
  145. self.startTLSAfterRegisterProducer(True)
  146. def test_startTLSAfterRegisterProducerNonStreaming(self):
  147. """
  148. When a non-streaming producer is registered, and then startTLS is
  149. called, the producer is re-registered with the
  150. C{TLSMemoryBIOProtocol}.
  151. """
  152. self.startTLSAfterRegisterProducer(False)
  153. globals().update(ProducerTestsMixin.makeTestCaseClasses())