iosim.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553
  1. # -*- test-case-name: twisted.test.test_amp,twisted.test.test_iosim -*-
  2. # Copyright (c) Twisted Matrix Laboratories.
  3. # See LICENSE for details.
  4. """
  5. Utilities and helpers for simulating a network
  6. """
  7. from __future__ import absolute_import, division, print_function
  8. import itertools
  9. try:
  10. from OpenSSL.SSL import Error as NativeOpenSSLError
  11. except ImportError:
  12. pass
  13. from zope.interface import implementer, directlyProvides
  14. from twisted.internet.endpoints import TCP4ClientEndpoint, TCP4ServerEndpoint
  15. from twisted.internet.protocol import Factory, Protocol
  16. from twisted.internet.error import ConnectionRefusedError
  17. from twisted.python.failure import Failure
  18. from twisted.internet import error
  19. from twisted.internet import interfaces
  20. from .proto_helpers import MemoryReactorClock
  21. class TLSNegotiation:
  22. def __init__(self, obj, connectState):
  23. self.obj = obj
  24. self.connectState = connectState
  25. self.sent = False
  26. self.readyToSend = connectState
  27. def __repr__(self):
  28. return 'TLSNegotiation(%r)' % (self.obj,)
  29. def pretendToVerify(self, other, tpt):
  30. # Set the transport problems list here? disconnections?
  31. # hmmmmm... need some negative path tests.
  32. if not self.obj.iosimVerify(other.obj):
  33. tpt.disconnectReason = NativeOpenSSLError()
  34. tpt.loseConnection()
  35. @implementer(interfaces.IAddress)
  36. class FakeAddress(object):
  37. """
  38. The default address type for the host and peer of L{FakeTransport}
  39. connections.
  40. """
  41. @implementer(interfaces.ITransport,
  42. interfaces.ITLSTransport)
  43. class FakeTransport:
  44. """
  45. A wrapper around a file-like object to make it behave as a Transport.
  46. This doesn't actually stream the file to the attached protocol,
  47. and is thus useful mainly as a utility for debugging protocols.
  48. """
  49. _nextserial = staticmethod(lambda counter=itertools.count(): next(counter))
  50. closed = 0
  51. disconnecting = 0
  52. disconnected = 0
  53. disconnectReason = error.ConnectionDone("Connection done")
  54. producer = None
  55. streamingProducer = 0
  56. tls = None
  57. def __init__(self, protocol, isServer, hostAddress=None, peerAddress=None):
  58. """
  59. @param protocol: This transport will deliver bytes to this protocol.
  60. @type protocol: L{IProtocol} provider
  61. @param isServer: C{True} if this is the accepting side of the
  62. connection, C{False} if it is the connecting side.
  63. @type isServer: L{bool}
  64. @param hostAddress: The value to return from C{getHost}. L{None}
  65. results in a new L{FakeAddress} being created to use as the value.
  66. @type hostAddress: L{IAddress} provider or L{None}
  67. @param peerAddress: The value to return from C{getPeer}. L{None}
  68. results in a new L{FakeAddress} being created to use as the value.
  69. @type peerAddress: L{IAddress} provider or L{None}
  70. """
  71. self.protocol = protocol
  72. self.isServer = isServer
  73. self.stream = []
  74. self.serial = self._nextserial()
  75. if hostAddress is None:
  76. hostAddress = FakeAddress()
  77. self.hostAddress = hostAddress
  78. if peerAddress is None:
  79. peerAddress = FakeAddress()
  80. self.peerAddress = peerAddress
  81. def __repr__(self):
  82. return 'FakeTransport<%s,%s,%s>' % (
  83. self.isServer and 'S' or 'C', self.serial,
  84. self.protocol.__class__.__name__)
  85. def write(self, data):
  86. if self.tls is not None:
  87. self.tlsbuf.append(data)
  88. else:
  89. self.stream.append(data)
  90. def _checkProducer(self):
  91. # Cheating; this is called at "idle" times to allow producers to be
  92. # found and dealt with
  93. if self.producer and not self.streamingProducer:
  94. self.producer.resumeProducing()
  95. def registerProducer(self, producer, streaming):
  96. """
  97. From abstract.FileDescriptor
  98. """
  99. self.producer = producer
  100. self.streamingProducer = streaming
  101. if not streaming:
  102. producer.resumeProducing()
  103. def unregisterProducer(self):
  104. self.producer = None
  105. def stopConsuming(self):
  106. self.unregisterProducer()
  107. self.loseConnection()
  108. def writeSequence(self, iovec):
  109. self.write(b"".join(iovec))
  110. def loseConnection(self):
  111. self.disconnecting = True
  112. def abortConnection(self):
  113. """
  114. For the time being, this is the same as loseConnection; no buffered
  115. data will be lost.
  116. """
  117. self.disconnecting = True
  118. def reportDisconnect(self):
  119. if self.tls is not None:
  120. # We were in the middle of negotiating! Must have been a TLS
  121. # problem.
  122. err = NativeOpenSSLError()
  123. else:
  124. err = self.disconnectReason
  125. self.protocol.connectionLost(Failure(err))
  126. def logPrefix(self):
  127. """
  128. Identify this transport/event source to the logging system.
  129. """
  130. return "iosim"
  131. def getPeer(self):
  132. return self.peerAddress
  133. def getHost(self):
  134. return self.hostAddress
  135. def resumeProducing(self):
  136. # Never sends data anyways
  137. pass
  138. def pauseProducing(self):
  139. # Never sends data anyways
  140. pass
  141. def stopProducing(self):
  142. self.loseConnection()
  143. def startTLS(self, contextFactory, beNormal=True):
  144. # Nothing's using this feature yet, but startTLS has an undocumented
  145. # second argument which defaults to true; if set to False, servers will
  146. # behave like clients and clients will behave like servers.
  147. connectState = self.isServer ^ beNormal
  148. self.tls = TLSNegotiation(contextFactory, connectState)
  149. self.tlsbuf = []
  150. def getOutBuffer(self):
  151. """
  152. Get the pending writes from this transport, clearing them from the
  153. pending buffer.
  154. @return: the bytes written with C{transport.write}
  155. @rtype: L{bytes}
  156. """
  157. S = self.stream
  158. if S:
  159. self.stream = []
  160. return b''.join(S)
  161. elif self.tls is not None:
  162. if self.tls.readyToSend:
  163. # Only _send_ the TLS negotiation "packet" if I'm ready to.
  164. self.tls.sent = True
  165. return self.tls
  166. else:
  167. return None
  168. else:
  169. return None
  170. def bufferReceived(self, buf):
  171. if isinstance(buf, TLSNegotiation):
  172. assert self.tls is not None # By the time you're receiving a
  173. # negotiation, you have to have called
  174. # startTLS already.
  175. if self.tls.sent:
  176. self.tls.pretendToVerify(buf, self)
  177. self.tls = None # We're done with the handshake if we've gotten
  178. # this far... although maybe it failed...?
  179. # TLS started! Unbuffer...
  180. b, self.tlsbuf = self.tlsbuf, None
  181. self.writeSequence(b)
  182. directlyProvides(self, interfaces.ISSLTransport)
  183. else:
  184. # We haven't sent our own TLS negotiation: time to do that!
  185. self.tls.readyToSend = True
  186. else:
  187. self.protocol.dataReceived(buf)
  188. def makeFakeClient(clientProtocol):
  189. """
  190. Create and return a new in-memory transport hooked up to the given protocol.
  191. @param clientProtocol: The client protocol to use.
  192. @type clientProtocol: L{IProtocol} provider
  193. @return: The transport.
  194. @rtype: L{FakeTransport}
  195. """
  196. return FakeTransport(clientProtocol, isServer=False)
  197. def makeFakeServer(serverProtocol):
  198. """
  199. Create and return a new in-memory transport hooked up to the given protocol.
  200. @param serverProtocol: The server protocol to use.
  201. @type serverProtocol: L{IProtocol} provider
  202. @return: The transport.
  203. @rtype: L{FakeTransport}
  204. """
  205. return FakeTransport(serverProtocol, isServer=True)
  206. class IOPump:
  207. """
  208. Utility to pump data between clients and servers for protocol testing.
  209. Perhaps this is a utility worthy of being in protocol.py?
  210. """
  211. def __init__(self, client, server, clientIO, serverIO, debug):
  212. self.client = client
  213. self.server = server
  214. self.clientIO = clientIO
  215. self.serverIO = serverIO
  216. self.debug = debug
  217. def flush(self, debug=False):
  218. """
  219. Pump until there is no more input or output.
  220. Returns whether any data was moved.
  221. """
  222. result = False
  223. for x in range(1000):
  224. if self.pump(debug):
  225. result = True
  226. else:
  227. break
  228. else:
  229. assert 0, "Too long"
  230. return result
  231. def pump(self, debug=False):
  232. """
  233. Move data back and forth.
  234. Returns whether any data was moved.
  235. """
  236. if self.debug or debug:
  237. print('-- GLUG --')
  238. sData = self.serverIO.getOutBuffer()
  239. cData = self.clientIO.getOutBuffer()
  240. self.clientIO._checkProducer()
  241. self.serverIO._checkProducer()
  242. if self.debug or debug:
  243. print('.')
  244. # XXX slightly buggy in the face of incremental output
  245. if cData:
  246. print('C: ' + repr(cData))
  247. if sData:
  248. print('S: ' + repr(sData))
  249. if cData:
  250. self.serverIO.bufferReceived(cData)
  251. if sData:
  252. self.clientIO.bufferReceived(sData)
  253. if cData or sData:
  254. return True
  255. if (self.serverIO.disconnecting and
  256. not self.serverIO.disconnected):
  257. if self.debug or debug:
  258. print('* C')
  259. self.serverIO.disconnected = True
  260. self.clientIO.disconnecting = True
  261. self.clientIO.reportDisconnect()
  262. return True
  263. if self.clientIO.disconnecting and not self.clientIO.disconnected:
  264. if self.debug or debug:
  265. print('* S')
  266. self.clientIO.disconnected = True
  267. self.serverIO.disconnecting = True
  268. self.serverIO.reportDisconnect()
  269. return True
  270. return False
  271. def connect(serverProtocol, serverTransport, clientProtocol, clientTransport,
  272. debug=False, greet=True):
  273. """
  274. Create a new L{IOPump} connecting two protocols.
  275. @param serverProtocol: The protocol to use on the accepting side of the
  276. connection.
  277. @type serverProtocol: L{IProtocol} provider
  278. @param serverTransport: The transport to associate with C{serverProtocol}.
  279. @type serverTransport: L{FakeTransport}
  280. @param clientProtocol: The protocol to use on the initiating side of the
  281. connection.
  282. @type clientProtocol: L{IProtocol} provider
  283. @param clientTransport: The transport to associate with C{clientProtocol}.
  284. @type clientTransport: L{FakeTransport}
  285. @param debug: A flag indicating whether to log information about what the
  286. L{IOPump} is doing.
  287. @type debug: L{bool}
  288. @param greet: Should the L{IOPump} be L{flushed <IOPump.flush>} once before
  289. returning to put the protocols into their post-handshake or
  290. post-server-greeting state?
  291. @type greet: L{bool}
  292. @return: An L{IOPump} which connects C{serverProtocol} and
  293. C{clientProtocol} and delivers bytes between them when it is pumped.
  294. @rtype: L{IOPump}
  295. """
  296. serverProtocol.makeConnection(serverTransport)
  297. clientProtocol.makeConnection(clientTransport)
  298. pump = IOPump(
  299. clientProtocol, serverProtocol, clientTransport, serverTransport, debug
  300. )
  301. if greet:
  302. # Kick off server greeting, etc
  303. pump.flush()
  304. return pump
  305. def connectedServerAndClient(ServerClass, ClientClass,
  306. clientTransportFactory=makeFakeClient,
  307. serverTransportFactory=makeFakeServer,
  308. debug=False, greet=True):
  309. """
  310. Connect a given server and client class to each other.
  311. @param ServerClass: a callable that produces the server-side protocol.
  312. @type ServerClass: 0-argument callable returning L{IProtocol} provider.
  313. @param ClientClass: like C{ServerClass} but for the other side of the
  314. connection.
  315. @type ClientClass: 0-argument callable returning L{IProtocol} provider.
  316. @param clientTransportFactory: a callable that produces the transport which
  317. will be attached to the protocol returned from C{ClientClass}.
  318. @type clientTransportFactory: callable taking (L{IProtocol}) and returning
  319. L{FakeTransport}
  320. @param serverTransportFactory: a callable that produces the transport which
  321. will be attached to the protocol returned from C{ServerClass}.
  322. @type serverTransportFactory: callable taking (L{IProtocol}) and returning
  323. L{FakeTransport}
  324. @param debug: Should this dump an escaped version of all traffic on this
  325. connection to stdout for inspection?
  326. @type debug: L{bool}
  327. @param greet: Should the L{IOPump} be L{flushed <IOPump.flush>} once before
  328. returning to put the protocols into their post-handshake or
  329. post-server-greeting state?
  330. @type greet: L{bool}
  331. @return: the client protocol, the server protocol, and an L{IOPump} which,
  332. when its C{pump} and C{flush} methods are called, will move data
  333. between the created client and server protocol instances.
  334. @rtype: 3-L{tuple} of L{IProtocol}, L{IProtocol}, L{IOPump}
  335. """
  336. c = ClientClass()
  337. s = ServerClass()
  338. cio = clientTransportFactory(c)
  339. sio = serverTransportFactory(s)
  340. return c, s, connect(s, sio, c, cio, debug, greet)
  341. def _factoriesShouldConnect(clientInfo, serverInfo):
  342. """
  343. Should the client and server described by the arguments be connected to
  344. each other, i.e. do their port numbers match?
  345. @param clientInfo: the args for connectTCP
  346. @type clientInfo: L{tuple}
  347. @param serverInfo: the args for listenTCP
  348. @type serverInfo: L{tuple}
  349. @return: If they do match, return factories for the client and server that
  350. should connect; otherwise return L{None}, indicating they shouldn't be
  351. connected.
  352. @rtype: L{None} or 2-L{tuple} of (L{ClientFactory},
  353. L{IProtocolFactory})
  354. """
  355. (clientHost, clientPort, clientFactory, clientTimeout,
  356. clientBindAddress) = clientInfo
  357. (serverPort, serverFactory, serverBacklog,
  358. serverInterface) = serverInfo
  359. if serverPort == clientPort:
  360. return clientFactory, serverFactory
  361. else:
  362. return None
  363. class ConnectionCompleter(object):
  364. """
  365. A L{ConnectionCompleter} can cause synthetic TCP connections established by
  366. L{MemoryReactor.connectTCP} and L{MemoryReactor.listenTCP} to succeed or
  367. fail.
  368. """
  369. def __init__(self, memoryReactor):
  370. """
  371. Create a L{ConnectionCompleter} from a L{MemoryReactor}.
  372. @param memoryReactor: The reactor to attach to.
  373. @type memoryReactor: L{MemoryReactor}
  374. """
  375. self._reactor = memoryReactor
  376. def succeedOnce(self, debug=False):
  377. """
  378. Complete a single TCP connection established on this
  379. L{ConnectionCompleter}'s L{MemoryReactor}.
  380. @param debug: A flag; whether to dump output from the established
  381. connection to stdout.
  382. @type debug: L{bool}
  383. @return: a pump for the connection, or L{None} if no connection could
  384. be established.
  385. @rtype: L{IOPump} or L{None}
  386. """
  387. memoryReactor = self._reactor
  388. for clientIdx, clientInfo in enumerate(memoryReactor.tcpClients):
  389. for serverInfo in memoryReactor.tcpServers:
  390. factories = _factoriesShouldConnect(clientInfo, serverInfo)
  391. if factories:
  392. memoryReactor.tcpClients.remove(clientInfo)
  393. memoryReactor.connectors.pop(clientIdx)
  394. clientFactory, serverFactory = factories
  395. clientProtocol = clientFactory.buildProtocol(None)
  396. serverProtocol = serverFactory.buildProtocol(None)
  397. serverTransport = makeFakeServer(serverProtocol)
  398. clientTransport = makeFakeClient(clientProtocol)
  399. return connect(serverProtocol, serverTransport,
  400. clientProtocol, clientTransport,
  401. debug)
  402. def failOnce(self, reason=Failure(ConnectionRefusedError())):
  403. """
  404. Fail a single TCP connection established on this
  405. L{ConnectionCompleter}'s L{MemoryReactor}.
  406. @param reason: the reason to provide that the connection failed.
  407. @type reason: L{Failure}
  408. """
  409. self._reactor.tcpClients.pop(0)[2].clientConnectionFailed(
  410. self._reactor.connectors.pop(0), reason
  411. )
  412. def connectableEndpoint(debug=False):
  413. """
  414. Create an endpoint that can be fired on demand.
  415. @param debug: A flag; whether to dump output from the established
  416. connection to stdout.
  417. @type debug: L{bool}
  418. @return: A client endpoint, and an object that will cause one of the
  419. L{Deferred}s returned by that client endpoint.
  420. @rtype: 2-L{tuple} of (L{IStreamClientEndpoint}, L{ConnectionCompleter})
  421. """
  422. reactor = MemoryReactorClock()
  423. clientEndpoint = TCP4ClientEndpoint(reactor, "0.0.0.0", 4321)
  424. serverEndpoint = TCP4ServerEndpoint(reactor, 4321)
  425. serverEndpoint.listen(Factory.forProtocol(Protocol))
  426. return clientEndpoint, ConnectionCompleter(reactor)