policies.py 20 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743
  1. # -*- test-case-name: twisted.test.test_policies -*-
  2. # Copyright (c) Twisted Matrix Laboratories.
  3. # See LICENSE for details.
  4. """
  5. Resource limiting policies.
  6. @seealso: See also L{twisted.protocols.htb} for rate limiting.
  7. """
  8. from __future__ import division, absolute_import
  9. # system imports
  10. import sys
  11. from zope.interface import directlyProvides, providedBy
  12. # twisted imports
  13. from twisted.internet.protocol import ServerFactory, Protocol, ClientFactory
  14. from twisted.internet import error
  15. from twisted.internet.interfaces import ILoggingContext
  16. from twisted.python import log
  def _wrappedLogPrefix(wrapper, wrapped):
      """
      Compute a log prefix for a wrapper and the object it wraps.

      The result has the form C{"<inner> (<wrapper class name>)"}, where
      C{<inner>} is C{wrapped.logPrefix()} when C{wrapped} provides
      L{ILoggingContext}, and the class name of C{wrapped} otherwise.

      @rtype: C{str}
      """
      # __class__ (not type()) so old-style instances report their real class.
      wrapperName = wrapper.__class__.__name__
      if not ILoggingContext.providedBy(wrapped):
          return "%s (%s)" % (wrapped.__class__.__name__, wrapperName)
      return "%s (%s)" % (wrapped.logPrefix(), wrapperName)
  class ProtocolWrapper(Protocol):
      """
      Wraps protocol instances and acts as their transport as well.

      Events coming from the real transport are relayed to the wrapped
      protocol, and transport calls made by the wrapped protocol are relayed
      to the real transport.  Any attribute not found on the wrapper is looked
      up on the real transport.

      @ivar wrappedProtocol: An L{IProtocol<twisted.internet.interfaces.IProtocol>}
          provider to which L{IProtocol<twisted.internet.interfaces.IProtocol>}
          method calls onto this L{ProtocolWrapper} will be proxied.

      @ivar factory: The L{WrappingFactory} which created this
          L{ProtocolWrapper}.
      """

      disconnecting = 0

      def __init__(self, factory, wrappedProtocol):
          self.factory = factory
          self.wrappedProtocol = wrappedProtocol

      def logPrefix(self):
          """
          Return a log prefix naming both the wrapped protocol and this
          wrapper.
          """
          return _wrappedLogPrefix(self, self.wrappedProtocol)

      def makeConnection(self, transport):
          """
          Hook this wrapper up between C{transport} and the wrapped protocol.

          The wrapper takes on the interfaces the real transport provides,
          records the transport, registers itself with its factory, and
          finally connects the wrapped protocol to the wrapper so that every
          transport call the wrapped protocol makes goes through here.
          """
          transportInterfaces = providedBy(transport)
          directlyProvides(self, transportInterfaces)
          Protocol.makeConnection(self, transport)
          self.factory.registerProtocol(self)
          self.wrappedProtocol.makeConnection(self)

      # -- Transport relaying: calls from the wrapped protocol ---------------

      def write(self, data):
          """Relay C{data} to the real transport."""
          self.transport.write(data)

      def writeSequence(self, data):
          """Relay a sequence of chunks to the real transport."""
          self.transport.writeSequence(data)

      def loseConnection(self):
          """Mark this wrapper as disconnecting and close the real transport."""
          self.disconnecting = 1
          self.transport.loseConnection()

      def getPeer(self):
          """Return the real transport's peer address."""
          return self.transport.getPeer()

      def getHost(self):
          """Return the real transport's host address."""
          return self.transport.getHost()

      def registerProducer(self, producer, streaming):
          """Register C{producer} with the real transport."""
          self.transport.registerProducer(producer, streaming)

      def unregisterProducer(self):
          """Unregister the producer from the real transport."""
          self.transport.unregisterProducer()

      def stopConsuming(self):
          """Relay C{stopConsuming} to the real transport."""
          self.transport.stopConsuming()

      def __getattr__(self, name):
          # Only reached when normal lookup fails: anything the wrapper does
          # not define itself is looked up on the real transport.
          return getattr(self.transport, name)

      # -- Protocol relaying: events from the real transport -----------------

      def dataReceived(self, data):
          """Hand received bytes to the wrapped protocol."""
          self.wrappedProtocol.dataReceived(data)

      def connectionLost(self, reason):
          """Unregister from the factory, then notify the wrapped protocol."""
          self.factory.unregisterProtocol(self)
          self.wrappedProtocol.connectionLost(reason)
  class WrappingFactory(ClientFactory):
      """
      Wraps a factory and its protocols, and keeps track of them.

      @ivar wrappedFactory: The factory whose protocols are wrapped.
      @ivar protocols: A C{dict} whose keys are the currently registered
          wrapper protocols (values are always C{1}).
      """

      protocol = ProtocolWrapper

      def __init__(self, wrappedFactory):
          self.wrappedFactory = wrappedFactory
          self.protocols = {}

      def logPrefix(self):
          """
          Generate a log prefix mentioning both the wrapped factory and this one.
          """
          return _wrappedLogPrefix(self, self.wrappedFactory)

      def doStart(self):
          self.wrappedFactory.doStart()
          ClientFactory.doStart(self)

      def doStop(self):
          self.wrappedFactory.doStop()
          ClientFactory.doStop(self)

      def startedConnecting(self, connector):
          self.wrappedFactory.startedConnecting(connector)

      def clientConnectionFailed(self, connector, reason):
          self.wrappedFactory.clientConnectionFailed(connector, reason)

      def clientConnectionLost(self, connector, reason):
          self.wrappedFactory.clientConnectionLost(connector, reason)

      def buildProtocol(self, addr):
          """
          Build a protocol from the wrapped factory and wrap it.

          @return: An instance of C{self.protocol} wrapping the wrapped
              factory's protocol, or L{None} if the wrapped factory returned
              L{None} (refusing the connection).
          """
          wrappedProtocol = self.wrappedFactory.buildProtocol(addr)
          if wrappedProtocol is None:
              # Propagate the refusal.  Wrapping None would only fail later,
              # once the wrapper tries to forward calls (e.g. makeConnection)
              # to it.
              return None
          return self.protocol(self, wrappedProtocol)

      def registerProtocol(self, p):
          """
          Called by protocol to register itself.
          """
          self.protocols[p] = 1

      def unregisterProtocol(self, p):
          """
          Called by protocols when they go away.
          """
          del self.protocols[p]
  class ThrottlingProtocol(ProtocolWrapper):
      """
      Protocol for L{ThrottlingFactory}.

      Reports traffic volume to the factory and, on the factory's request,
      pauses/resumes the real transport (reads) or the producer registered
      through this wrapper (writes).

      @ivar producer: The producer registered through this wrapper, or
          L{None} if there is none.
      """

      # Declared on the class so that reading ``self.producer`` never falls
      # through to ProtocolWrapper.__getattr__.  That fallback would return
      # the *transport's* attribute of the same name, which may exist (and
      # may be None) even when no producer was registered here.
      producer = None

      # wrap API for tracking bandwidth

      def write(self, data):
          self.factory.registerWritten(len(data))
          ProtocolWrapper.write(self, data)

      def writeSequence(self, seq):
          # NOTE(review): iterates seq once here and again downstream, so seq
          # must be re-iterable (e.g. a list).
          self.factory.registerWritten(sum(map(len, seq)))
          ProtocolWrapper.writeSequence(self, seq)

      def dataReceived(self, data):
          self.factory.registerRead(len(data))
          ProtocolWrapper.dataReceived(self, data)

      def registerProducer(self, producer, streaming):
          self.producer = producer
          ProtocolWrapper.registerProducer(self, producer, streaming)

      def unregisterProducer(self):
          # Reset rather than ``del`` so this works even if nothing was
          # registered, and later lookups still see None.
          self.producer = None
          ProtocolWrapper.unregisterProducer(self)

      def throttleReads(self):
          self.transport.pauseProducing()

      def unthrottleReads(self):
          self.transport.resumeProducing()

      def throttleWrites(self):
          if self.producer is not None:
              self.producer.pauseProducing()

      def unthrottleWrites(self):
          if self.producer is not None:
              self.producer.resumeProducing()
  class ThrottlingFactory(WrappingFactory):
      """
      Throttles bandwidth and number of connections.

      Write bandwidth will only be throttled if there is a producer
      registered.

      Bandwidth checks run once per second while at least one connection is
      open, and are cancelled when the last connection goes away.
      """

      protocol = ThrottlingProtocol

      def __init__(self, wrappedFactory, maxConnectionCount=sys.maxsize,
                   readLimit=None, writeLimit=None):
          WrappingFactory.__init__(self, wrappedFactory)
          self.connectionCount = 0
          self.maxConnectionCount = maxConnectionCount
          self.readLimit = readLimit  # max bytes we should read per second
          self.writeLimit = writeLimit  # max bytes we should write per second
          self.readThisSecond = 0
          self.writtenThisSecond = 0
          self.unthrottleReadsID = None
          self.checkReadBandwidthID = None
          self.unthrottleWritesID = None
          self.checkWriteBandwidthID = None

      def callLater(self, period, func):
          """
          Wrapper around
          L{reactor.callLater<twisted.internet.interfaces.IReactorTime.callLater>}
          for test purpose.
          """
          from twisted.internet import reactor
          return reactor.callLater(period, func)

      def registerWritten(self, length):
          """
          Called by protocol to tell us more bytes were written.
          """
          self.writtenThisSecond += length

      def registerRead(self, length):
          """
          Called by protocol to tell us more bytes were read.
          """
          self.readThisSecond += length

      def checkReadBandwidth(self):
          """
          Checks if we've passed bandwidth limits.

          If more than C{readLimit} bytes were read since the last check,
          reads are throttled on all protocols and an unthrottle is scheduled
          C{readThisSecond / readLimit - 1} seconds later.  The counter is
          then reset and this check re-schedules itself one second later.
          """
          if self.readThisSecond > self.readLimit:
              self.throttleReads()
              throttleTime = (float(self.readThisSecond) / self.readLimit) - 1.0
              self.unthrottleReadsID = self.callLater(throttleTime,
                                                      self.unthrottleReads)
          self.readThisSecond = 0
          self.checkReadBandwidthID = self.callLater(1, self.checkReadBandwidth)

      def checkWriteBandwidth(self):
          """
          Checks if we've passed the write bandwidth limit.

          Same scheme as L{checkReadBandwidth}, applied to C{writeLimit} and
          C{writtenThisSecond}.
          """
          if self.writtenThisSecond > self.writeLimit:
              self.throttleWrites()
              throttleTime = (float(self.writtenThisSecond) / self.writeLimit) - 1.0
              self.unthrottleWritesID = self.callLater(throttleTime,
                                                       self.unthrottleWrites)
          # reset for next round
          self.writtenThisSecond = 0
          self.checkWriteBandwidthID = self.callLater(1, self.checkWriteBandwidth)

      def throttleReads(self):
          """
          Throttle reads on all protocols.
          """
          log.msg("Throttling reads on %s" % self)
          for p in self.protocols.keys():
              p.throttleReads()

      def unthrottleReads(self):
          """
          Stop throttling reads on all protocols.
          """
          self.unthrottleReadsID = None
          log.msg("Stopped throttling reads on %s" % self)
          for p in self.protocols.keys():
              p.unthrottleReads()

      def throttleWrites(self):
          """
          Throttle writes on all protocols.
          """
          log.msg("Throttling writes on %s" % self)
          for p in self.protocols.keys():
              p.throttleWrites()

      def unthrottleWrites(self):
          """
          Stop throttling writes on all protocols.
          """
          self.unthrottleWritesID = None
          log.msg("Stopped throttling writes on %s" % self)
          for p in self.protocols.keys():
              p.unthrottleWrites()

      def buildProtocol(self, addr):
          """
          Build a wrapped protocol unless the connection limit is reached.

          The first accepted connection starts the periodic bandwidth checks
          for whichever limits are configured.

          @return: The wrapping protocol, or L{None} if the connection is
              refused (limit reached or the wrapped factory refused).
          """
          # Refuse before starting any timers: otherwise a refused connection
          # at count 0 (maxConnectionCount <= 0) would start self-rescheduling
          # checks that nothing ever cancels.
          if self.connectionCount >= self.maxConnectionCount:
              log.msg("Max connection count reached!")
              return None
          if self.connectionCount == 0:
              if self.readLimit is not None:
                  self.checkReadBandwidth()
              if self.writeLimit is not None:
                  self.checkWriteBandwidth()
          self.connectionCount += 1
          protocol = WrappingFactory.buildProtocol(self, addr)
          if protocol is None:
              # The wrapped side refused: no protocol will ever unregister,
              # so release the slot (and the timers, if now idle) here.
              self.connectionCount -= 1
              if self.connectionCount == 0:
                  self._cancelDelayedCalls()
          return protocol

      def unregisterProtocol(self, p):
          WrappingFactory.unregisterProtocol(self, p)
          self.connectionCount -= 1
          if self.connectionCount == 0:
              self._cancelDelayedCalls()

      def _cancelDelayedCalls(self):
          """
          Cancel every still-pending bandwidth timer and forget its ID.

          Clearing the IDs matters: a cancelled call left in place would be
          cancelled again on the next idle transition and raise.
          """
          for attrName in ("unthrottleReadsID", "checkReadBandwidthID",
                           "unthrottleWritesID", "checkWriteBandwidthID"):
              delayedCall = getattr(self, attrName)
              if delayedCall is not None and delayedCall.active():
                  delayedCall.cancel()
              setattr(self, attrName, None)
  class SpewingProtocol(ProtocolWrapper):
      """
      Protocol wrapper that logs every chunk of data received and sent.
      """

      def dataReceived(self, data):
          log.msg("Received: %r" % data)
          ProtocolWrapper.dataReceived(self,data)

      def write(self, data):
          log.msg("Sending: %r" % data)
          ProtocolWrapper.write(self,data)

      def writeSequence(self, data):
          """
          Log each chunk of a sequence write, then relay the sequence.

          Without this override, data sent via C{writeSequence} would bypass
          the logging done in L{write}, because
          L{ProtocolWrapper.writeSequence} forwards straight to the real
          transport.  C{data} is iterated twice, so it must be re-iterable
          (e.g. a list).
          """
          for chunk in data:
              log.msg("Sending: %r" % chunk)
          ProtocolWrapper.writeSequence(self, data)
  class SpewingFactory(WrappingFactory):
      """
      Wrapping factory whose protocols log the traffic they relay.

      Each connection is wrapped in a L{SpewingProtocol}.
      """

      protocol = SpewingProtocol
  class LimitConnectionsByPeer(WrappingFactory):
      """
      Limit the number of simultaneous connections from any single peer host.

      @cvar maxConnectionsPerPeer: Connections from one host beyond this
          count are refused.
      @ivar peerConnections: A C{dict} mapping peer host to its number of
          open connections (hosts with no connections are removed).
      """

      maxConnectionsPerPeer = 5

      def startFactory(self):
          self.peerConnections = {}

      def _peerHost(self, addr):
          """
          Extract the host part of a peer address.

          Accepts an address object with a C{host} attribute (such as
          L{twisted.internet.address.IPv4Address}) or a C{(host, port)}
          sequence.  The same extraction is used when counting a new
          connection and when releasing it, so the two always agree.
          """
          host = getattr(addr, "host", None)
          if host is None:
              host = addr[0]
          return host

      def buildProtocol(self, addr):
          peerHost = self._peerHost(addr)
          connectionCount = self.peerConnections.get(peerHost, 0)
          if connectionCount >= self.maxConnectionsPerPeer:
              return None
          self.peerConnections[peerHost] = connectionCount + 1
          return WrappingFactory.buildProtocol(self, addr)

      def unregisterProtocol(self, p):
          # Let the base class drop p from self.protocols as well; otherwise
          # every connection would stay referenced forever.
          WrappingFactory.unregisterProtocol(self, p)
          peerHost = self._peerHost(p.getPeer())
          self.peerConnections[peerHost] -= 1
          if self.peerConnections[peerHost] == 0:
              del self.peerConnections[peerHost]
  class LimitTotalConnectionsFactory(ServerFactory):
      """
      Factory that limits the number of simultaneous connections.

      Every protocol it builds (normal or overflow) is wrapped in a
      L{ProtocolWrapper} and counted in C{connectionCount}; the count drops
      again when the wrapper unregisters.

      @type connectionCount: C{int}
      @ivar connectionCount: number of current connections.
      @type connectionLimit: C{int} or L{None}
      @cvar connectionLimit: maximum number of connections.
      @type overflowProtocol: L{Protocol} or L{None}
      @cvar overflowProtocol: Protocol to use for new connections when
          connectionLimit is exceeded.  If L{None} (the default value), excess
          connections will be closed immediately.
      """

      connectionCount = 0
      connectionLimit = None
      overflowProtocol = None

      def buildProtocol(self, addr):
          limit = self.connectionLimit
          if limit is None or self.connectionCount < limit:
              # Room available: build the normal protocol.
              wrappedProtocol = self.protocol()
          elif self.overflowProtocol is not None:
              # Over the limit, but an overflow protocol is configured.
              wrappedProtocol = self.overflowProtocol()
          else:
              # Over the limit with no overflow protocol: refuse.
              return None

          wrappedProtocol.factory = self
          wrapper = ProtocolWrapper(self, wrappedProtocol)
          self.connectionCount += 1
          return wrapper

      def registerProtocol(self, p):
          # Counting already happened in buildProtocol.
          pass

      def unregisterProtocol(self, p):
          self.connectionCount -= 1
  class TimeoutProtocol(ProtocolWrapper):
      """
      Protocol that automatically disconnects when the connection is idle.

      Any write or received data postpones the timeout by C{timeoutPeriod}
      seconds.
      """

      def __init__(self, factory, wrappedProtocol, timeoutPeriod):
          """
          Constructor.

          @param factory: An L{protocol.Factory}; it must provide a
              C{callLater(period, func)} method used to schedule the timeout.
          @param wrappedProtocol: A L{Protocol} to wrap.
          @param timeoutPeriod: Number of seconds to wait for activity before
              timing out.
          """
          ProtocolWrapper.__init__(self, factory, wrappedProtocol)
          self.timeoutCall = None
          self.setTimeout(timeoutPeriod)

      def setTimeout(self, timeoutPeriod=None):
          """
          Set a timeout.

          This will cancel any existing timeouts.

          @param timeoutPeriod: If not L{None}, change the timeout period.
              Otherwise, use the existing value.
          """
          self.cancelTimeout()
          if timeoutPeriod is not None:
              self.timeoutPeriod = timeoutPeriod
          self.timeoutCall = self.factory.callLater(self.timeoutPeriod, self.timeoutFunc)

      def cancelTimeout(self):
          """
          Cancel the timeout.

          If the timeout was already cancelled, or has already fired, this
          does nothing.
          """
          if self.timeoutCall:
              try:
                  self.timeoutCall.cancel()
              except (error.AlreadyCalled, error.AlreadyCancelled):
                  # Already consumed: there is nothing left to cancel.
                  pass
              self.timeoutCall = None

      def resetTimeout(self):
          """
          Reset the timeout, usually because some activity just happened.

          Does nothing when no timeout is pending.  In particular,
          C{timeoutFunc} does not clear C{timeoutCall}, so after the timeout
          has fired (e.g. while the connection is still closing, or when
          C{timeoutFunc} is overridden) later activity must not try to reset
          the consumed call, which would raise.
          """
          if self.timeoutCall and self.timeoutCall.active():
              self.timeoutCall.reset(self.timeoutPeriod)

      def write(self, data):
          self.resetTimeout()
          ProtocolWrapper.write(self, data)

      def writeSequence(self, seq):
          self.resetTimeout()
          ProtocolWrapper.writeSequence(self, seq)

      def dataReceived(self, data):
          self.resetTimeout()
          ProtocolWrapper.dataReceived(self, data)

      def connectionLost(self, reason):
          self.cancelTimeout()
          ProtocolWrapper.connectionLost(self, reason)

      def timeoutFunc(self):
          """
          This method is called when the timeout is triggered.

          By default it calls I{loseConnection}.  Override this if you want
          something else to happen.
          """
          self.loseConnection()
  class TimeoutFactory(WrappingFactory):
      """
      Factory for L{TimeoutProtocol}.

      Wraps every protocol built by the wrapped factory in a
      L{TimeoutProtocol} that closes the connection after C{timeoutPeriod}
      seconds without activity.
      """

      protocol = TimeoutProtocol

      def __init__(self, wrappedFactory, timeoutPeriod=30*60):
          self.timeoutPeriod = timeoutPeriod
          WrappingFactory.__init__(self, wrappedFactory)

      def buildProtocol(self, addr):
          """
          Build a timeout-wrapped protocol.

          @return: The wrapping protocol, or L{None} if the wrapped factory
              refused the connection.
          """
          wrappedProtocol = self.wrappedFactory.buildProtocol(addr)
          if wrappedProtocol is None:
              # Propagate the refusal.  Wrapping None would also schedule a
              # timeout (the wrapper's constructor does so) for a connection
              # that will never be used.
              return None
          return self.protocol(self, wrappedProtocol,
                               timeoutPeriod=self.timeoutPeriod)

      def callLater(self, period, func):
          """
          Wrapper around
          L{reactor.callLater<twisted.internet.interfaces.IReactorTime.callLater>}
          for test purpose.
          """
          from twisted.internet import reactor
          return reactor.callLater(period, func)
  class TrafficLoggingProtocol(ProtocolWrapper):
      """
      Protocol wrapper that writes a line to a log file for every event.

      Line prefixes: C{*} connection made, C{C} data from the peer or
      connection lost, C{S} data sent or connection closed by us, C{SV}
      sequence write.

      NOTE(review): this class never closes C{logfile}; whoever opened it
      owns it.
      """

      def __init__(self, factory, wrappedProtocol, logfile, lengthLimit=None,
                   number=0):
          """
          @param factory: factory which created this protocol.
          @type factory: L{protocol.Factory}.
          @param wrappedProtocol: the underlying protocol.
          @type wrappedProtocol: C{protocol.Protocol}.
          @param logfile: file opened for writing used to write log messages.
          @type logfile: C{file}
          @param lengthLimit: maximum size of the datareceived logged.
          @type lengthLimit: C{int}
          @param number: identifier of the connection.
          @type number: C{int}.
          """
          ProtocolWrapper.__init__(self, factory, wrappedProtocol)
          self.logfile = logfile
          self.lengthLimit = lengthLimit
          self._number = number

      def _log(self, line):
          self.logfile.write(line + '\n')
          self.logfile.flush()

      def _mungeData(self, data):
          """
          Truncate C{data} to at most C{lengthLimit} items, marking the cut.

          The marker has the same type as C{data} (bytes vs. text) so the
          concatenation works for both.  If C{lengthLimit} is shorter than
          the marker, only the marker is kept rather than slicing with a
          negative index.
          """
          if self.lengthLimit and len(data) > self.lengthLimit:
              if isinstance(data, bytes):
                  marker = b'<... elided>'
              else:
                  marker = '<... elided>'
              keep = max(0, self.lengthLimit - len(marker))
              data = data[:keep] + marker
          return data

      # IProtocol

      def connectionMade(self):
          self._log('*')
          return ProtocolWrapper.connectionMade(self)

      def dataReceived(self, data):
          self._log('C %d: %r' % (self._number, self._mungeData(data)))
          return ProtocolWrapper.dataReceived(self, data)

      def connectionLost(self, reason):
          self._log('C %d: %r' % (self._number, reason))
          return ProtocolWrapper.connectionLost(self, reason)

      # ITransport

      def write(self, data):
          self._log('S %d: %r' % (self._number, self._mungeData(data)))
          return ProtocolWrapper.write(self, data)

      def writeSequence(self, iovec):
          self._log('SV %d: %r' % (self._number, [self._mungeData(d) for d in iovec]))
          return ProtocolWrapper.writeSequence(self, iovec)

      def loseConnection(self):
          self._log('S %d: *' % (self._number,))
          return ProtocolWrapper.loseConnection(self)
  class TrafficLoggingFactory(WrappingFactory):
      """
      Wrapping factory that logs each connection's traffic to its own file.

      Connection N (counting from 1) is logged to a file named
      C{"<logfilePrefix>-N"}, opened through L{open}.
      """

      protocol = TrafficLoggingProtocol

      _counter = 0

      def __init__(self, wrappedFactory, logfilePrefix, lengthLimit=None):
          self.logfilePrefix = logfilePrefix
          self.lengthLimit = lengthLimit
          WrappingFactory.__init__(self, wrappedFactory)

      def open(self, name):
          """
          Open the log file called C{name} for writing (overridable in tests).
          """
          return open(name, 'w')

      def buildProtocol(self, addr):
          self._counter += 1
          logfileName = '-'.join((self.logfilePrefix, str(self._counter)))
          logfile = self.open(logfileName)
          wrappedProtocol = self.wrappedFactory.buildProtocol(addr)
          return self.protocol(self, wrappedProtocol, logfile,
                               self.lengthLimit, self._counter)

      def resetCounter(self):
          """
          Reset the value of the counter used to identify connections.
          """
          self._counter = 0
  class TimeoutMixin:
      """
      Mixin for protocols which wish to timeout connections.

      Protocols that mix this in have a single timeout, set using L{setTimeout}.
      When the timeout is hit, L{timeoutConnection} is called, which, by
      default, closes the connection.

      @cvar timeOut: The number of seconds after which to timeout the connection.
      """

      timeOut = None

      # The pending delayed call for the timeout, or None when none is
      # scheduled (never set, disabled via setTimeout(None), or already fired).
      __timeoutCall = None

      def callLater(self, period, func):
          """
          Wrapper around
          L{reactor.callLater<twisted.internet.interfaces.IReactorTime.callLater>}
          for test purpose.
          """
          from twisted.internet import reactor
          return reactor.callLater(period, func)

      def resetTimeout(self):
          """
          Reset the timeout count down.

          If the connection has already timed out, then do nothing.  If the
          timeout has been cancelled (probably using C{setTimeout(None)}), also
          do nothing.

          It's often a good idea to call this when the protocol has received
          some meaningful input from the other end of the connection.  "I've got
          some data, they're still there, reset the timeout".
          """
          if self.__timeoutCall is not None and self.timeOut is not None:
              self.__timeoutCall.reset(self.timeOut)

      def setTimeout(self, period):
          """
          Change the timeout period

          @type period: C{int} or L{None}
          @param period: The period, in seconds, to change the timeout to, or
              L{None} to disable the timeout.

          @return: The previous timeout period.
          """
          prev = self.timeOut
          self.timeOut = period

          if self.__timeoutCall is not None:
              if period is None:
                  try:
                      self.__timeoutCall.cancel()
                  except (error.AlreadyCancelled, error.AlreadyCalled):
                      # The call was already consumed elsewhere; disabling
                      # the timeout is still the requested outcome.
                      pass
                  self.__timeoutCall = None
              else:
                  self.__timeoutCall.reset(period)
          elif period is not None:
              self.__timeoutCall = self.callLater(period, self.__timedOut)

          return prev

      def __timedOut(self):
          # Forget the call first: it has fired, so it must not be reset or
          # cancelled afterwards.
          self.__timeoutCall = None
          self.timeoutConnection()

      def timeoutConnection(self):
          """
          Called when the connection times out.

          Override to define behavior other than dropping the connection.
          """
          self.transport.loseConnection()