srvconnect.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  1. # -*- test-case-name: twisted.names.test.test_srvconnect -*-
  2. # Copyright (c) Twisted Matrix Laboratories.
  3. # See LICENSE for details.
  4. from __future__ import absolute_import, division
  5. import random
  6. from zope.interface import implementer
  7. from twisted.internet import error, interfaces
  8. from twisted.names import client, dns
  9. from twisted.names.error import DNSNameError
  10. from twisted.python.compat import nativeString, unicode
  11. class _SRVConnector_ClientFactoryWrapper:
  12. def __init__(self, connector, wrappedFactory):
  13. self.__connector = connector
  14. self.__wrappedFactory = wrappedFactory
  15. def startedConnecting(self, connector):
  16. self.__wrappedFactory.startedConnecting(self.__connector)
  17. def clientConnectionFailed(self, connector, reason):
  18. self.__connector.connectionFailed(reason)
  19. def clientConnectionLost(self, connector, reason):
  20. self.__connector.connectionLost(reason)
  21. def __getattr__(self, key):
  22. return getattr(self.__wrappedFactory, key)
  23. @implementer(interfaces.IConnector)
  24. class SRVConnector:
  25. """
  26. A connector that looks up DNS SRV records.
  27. RFC 2782 details how SRV records should be interpreted and selected
  28. for subsequent connection attempts. The algorithm for using the records'
  29. priority and weight is implemented in L{pickServer}.
  30. @ivar servers: List of candidate server records for future connection
  31. attempts.
  32. @type servers: L{list} of L{dns.Record_SRV}
  33. @ivar orderedServers: List of server records that have already been tried
  34. in this round of connection attempts.
  35. @type orderedServers: L{list} of L{dns.Record_SRV}
  36. """
  37. stopAfterDNS=0
  38. def __init__(self, reactor, service, domain, factory,
  39. protocol='tcp', connectFuncName='connectTCP',
  40. connectFuncArgs=(),
  41. connectFuncKwArgs={},
  42. defaultPort=None,
  43. ):
  44. """
  45. @param domain: The domain to connect to. If passed as a unicode
  46. string, it will be encoded using C{idna} encoding.
  47. @type domain: L{bytes} or L{unicode}
  48. @param defaultPort: Optional default port number to be used when SRV
  49. lookup fails and the service name is unknown. This should be the
  50. port number associated with the service name as defined by the IANA
  51. registry.
  52. @type defaultPort: L{int}
  53. """
  54. self.reactor = reactor
  55. self.service = service
  56. if isinstance(domain, unicode):
  57. domain = domain.encode('idna')
  58. self.domain = nativeString(domain)
  59. self.factory = factory
  60. self.protocol = protocol
  61. self.connectFuncName = connectFuncName
  62. self.connectFuncArgs = connectFuncArgs
  63. self.connectFuncKwArgs = connectFuncKwArgs
  64. self._defaultPort = defaultPort
  65. self.connector = None
  66. self.servers = None
  67. self.orderedServers = None # list of servers already used in this round
  68. def connect(self):
  69. """Start connection to remote server."""
  70. self.factory.doStart()
  71. self.factory.startedConnecting(self)
  72. if not self.servers:
  73. if self.domain is None:
  74. self.connectionFailed(error.DNSLookupError("Domain is not defined."))
  75. return
  76. d = client.lookupService('_%s._%s.%s' %
  77. (nativeString(self.service),
  78. nativeString(self.protocol),
  79. self.domain))
  80. d.addCallbacks(self._cbGotServers, self._ebGotServers)
  81. d.addCallback(lambda x, self=self: self._reallyConnect())
  82. if self._defaultPort:
  83. d.addErrback(self._ebServiceUnknown)
  84. d.addErrback(self.connectionFailed)
  85. elif self.connector is None:
  86. self._reallyConnect()
  87. else:
  88. self.connector.connect()
  89. def _ebGotServers(self, failure):
  90. failure.trap(DNSNameError)
  91. # Some DNS servers reply with NXDOMAIN when in fact there are
  92. # just no SRV records for that domain. Act as if we just got an
  93. # empty response and use fallback.
  94. self.servers = []
  95. self.orderedServers = []
  96. def _cbGotServers(self, result):
  97. answers, auth, add = result
  98. if len(answers) == 1 and answers[0].type == dns.SRV \
  99. and answers[0].payload \
  100. and answers[0].payload.target == dns.Name(b'.'):
  101. # decidedly not available
  102. raise error.DNSLookupError("Service %s not available for domain %s."
  103. % (repr(self.service), repr(self.domain)))
  104. self.servers = []
  105. self.orderedServers = []
  106. for a in answers:
  107. if a.type != dns.SRV or not a.payload:
  108. continue
  109. self.orderedServers.append(a.payload)
  110. def _ebServiceUnknown(self, failure):
  111. """
  112. Connect to the default port when the service name is unknown.
  113. If no SRV records were found, the service name will be passed as the
  114. port. If resolving the name fails with
  115. L{error.ServiceNameUnknownError}, a final attempt is done using the
  116. default port.
  117. """
  118. failure.trap(error.ServiceNameUnknownError)
  119. self.servers = [dns.Record_SRV(0, 0, self._defaultPort, self.domain)]
  120. self.orderedServers = []
  121. self.connect()
  122. def pickServer(self):
  123. """
  124. Pick the next server.
  125. This selects the next server from the list of SRV records according
  126. to their priority and weight values, as set out by the default
  127. algorithm specified in RFC 2782.
  128. At the beginning of a round, L{servers} is populated with
  129. L{orderedServers}, and the latter is made empty. L{servers}
  130. is the list of candidates, and L{orderedServers} is the list of servers
  131. that have already been tried.
  132. First, all records are ordered by priority and weight in ascending
  133. order. Then for each priority level, a running sum is calculated
  134. over the sorted list of records for that priority. Then a random value
  135. between 0 and the final sum is compared to each record in order. The
  136. first record that is greater than or equal to that random value is
  137. chosen and removed from the list of candidates for this round.
  138. @return: A tuple of target hostname and port from the chosen DNS SRV
  139. record.
  140. @rtype: L{tuple} of native L{str} and L{int}
  141. """
  142. assert self.servers is not None
  143. assert self.orderedServers is not None
  144. if not self.servers and not self.orderedServers:
  145. # no SRV record, fall back..
  146. return self.domain, self.service
  147. if not self.servers and self.orderedServers:
  148. # start new round
  149. self.servers = self.orderedServers
  150. self.orderedServers = []
  151. assert self.servers
  152. self.servers.sort(key=lambda record: (record.priority, record.weight))
  153. minPriority = self.servers[0].priority
  154. index = 0
  155. weightSum = 0
  156. weightIndex = []
  157. for x in self.servers:
  158. if x.priority == minPriority:
  159. weightSum += x.weight
  160. weightIndex.append((index, weightSum))
  161. index += 1
  162. rand = random.randint(0, weightSum)
  163. for index, weight in weightIndex:
  164. if weight >= rand:
  165. chosen = self.servers[index]
  166. del self.servers[index]
  167. self.orderedServers.append(chosen)
  168. return str(chosen.target), chosen.port
  169. raise RuntimeError(
  170. 'Impossible %s pickServer result.' % (self.__class__.__name__,))
  171. def _reallyConnect(self):
  172. if self.stopAfterDNS:
  173. self.stopAfterDNS=0
  174. return
  175. self.host, self.port = self.pickServer()
  176. assert self.host is not None, 'Must have a host to connect to.'
  177. assert self.port is not None, 'Must have a port to connect to.'
  178. connectFunc = getattr(self.reactor, self.connectFuncName)
  179. self.connector=connectFunc(
  180. self.host, self.port,
  181. _SRVConnector_ClientFactoryWrapper(self, self.factory),
  182. *self.connectFuncArgs, **self.connectFuncKwArgs)
  183. def stopConnecting(self):
  184. """Stop attempting to connect."""
  185. if self.connector:
  186. self.connector.stopConnecting()
  187. else:
  188. self.stopAfterDNS=1
  189. def disconnect(self):
  190. """Disconnect whatever our are state is."""
  191. if self.connector is not None:
  192. self.connector.disconnect()
  193. else:
  194. self.stopConnecting()
  195. def getDestination(self):
  196. assert self.connector
  197. return self.connector.getDestination()
  198. def connectionFailed(self, reason):
  199. self.factory.clientConnectionFailed(self, reason)
  200. self.factory.doStop()
  201. def connectionLost(self, reason):
  202. self.factory.clientConnectionLost(self, reason)
  203. self.factory.doStop()