srv_resolver.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. # Copyright 2019-present MongoDB, Inc.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License"); you
  4. # may not use this file except in compliance with the License. You
  5. # may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  12. # implied. See the License for the specific language governing
  13. # permissions and limitations under the License.
  14. """Support for resolving hosts and options from mongodb+srv:// URIs."""
  15. try:
  16. from dns import resolver
  17. _HAVE_DNSPYTHON = True
  18. except ImportError:
  19. _HAVE_DNSPYTHON = False
  20. from bson.py3compat import PY3
  21. from pymongo.common import CONNECT_TIMEOUT
  22. from pymongo.errors import ConfigurationError
  23. from pymongo._ipaddress import is_ip_address
  if not PY3:
      def maybe_decode(text):
          """Return *text* unchanged (Python 2 branch)."""
          return text
  else:
      # dnspython can return bytes or str from various parts
      # of its API depending on version. We always want str.
      def maybe_decode(text):
          """Return *text* as ``str``.

          ``bytes`` are decoded with ``bytes.decode``'s default codec (UTF-8);
          any other value is passed through unchanged.
          """
          return text.decode() if isinstance(text, bytes) else text
  # PYTHON-2667 Lazily call dns.resolver methods for compatibility with eventlet.
  def _resolve(*args, **kwargs):
      """Run a DNS query through whichever API the installed dnspython offers.

      dnspython >= 2 provides ``resolver.resolve``; 1.X only provides
      ``resolver.query``. All positional and keyword arguments are forwarded
      unchanged and the query's result is returned.
      """
      # The attribute is looked up on every call (not bound at import time)
      # per PYTHON-2667, so whatever dns.resolver holds at call time is used.
      try:
          query_fn = resolver.resolve  # dnspython >= 2
      except AttributeError:
          query_fn = resolver.query  # dnspython 1.X
      return query_fn(*args, **kwargs)
  41. _INVALID_HOST_MSG = (
  42. "Invalid URI host: %s is not a valid hostname for 'mongodb+srv://'. "
  43. "Did you mean to use 'mongodb://'?")
  44. class _SrvResolver(object):
  45. def __init__(self, fqdn, connect_timeout=None):
  46. self.__fqdn = fqdn
  47. self.__connect_timeout = connect_timeout or CONNECT_TIMEOUT
  48. # Validate the fully qualified domain name.
  49. if is_ip_address(fqdn):
  50. raise ConfigurationError(_INVALID_HOST_MSG % ("an IP address",))
  51. try:
  52. self.__plist = self.__fqdn.split(".")[1:]
  53. except Exception:
  54. raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,))
  55. self.__slen = len(self.__plist)
  56. if self.__slen < 2:
  57. raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,))
  58. def get_options(self):
  59. try:
  60. results = _resolve(self.__fqdn, 'TXT',
  61. lifetime=self.__connect_timeout)
  62. except (resolver.NoAnswer, resolver.NXDOMAIN):
  63. # No TXT records
  64. return None
  65. except Exception as exc:
  66. raise ConfigurationError(str(exc))
  67. if len(results) > 1:
  68. raise ConfigurationError('Only one TXT record is supported')
  69. return (
  70. b'&'.join([b''.join(res.strings) for res in results])).decode(
  71. 'utf-8')
  72. def _resolve_uri(self, encapsulate_errors):
  73. try:
  74. results = _resolve('_mongodb._tcp.' + self.__fqdn, 'SRV',
  75. lifetime=self.__connect_timeout)
  76. except Exception as exc:
  77. if not encapsulate_errors:
  78. # Raise the original error.
  79. raise
  80. # Else, raise all errors as ConfigurationError.
  81. raise ConfigurationError(str(exc))
  82. return results
  83. def _get_srv_response_and_hosts(self, encapsulate_errors):
  84. results = self._resolve_uri(encapsulate_errors)
  85. # Construct address tuples
  86. nodes = [
  87. (maybe_decode(res.target.to_text(omit_final_dot=True)), res.port)
  88. for res in results]
  89. # Validate hosts
  90. for node in nodes:
  91. try:
  92. nlist = node[0].split(".")[1:][-self.__slen:]
  93. except Exception:
  94. raise ConfigurationError("Invalid SRV host: %s" % (node[0],))
  95. if self.__plist != nlist:
  96. raise ConfigurationError("Invalid SRV host: %s" % (node[0],))
  97. return results, nodes
  98. def get_hosts(self):
  99. _, nodes = self._get_srv_response_and_hosts(True)
  100. return nodes
  101. def get_hosts_and_min_ttl(self):
  102. results, nodes = self._get_srv_response_and_hosts(False)
  103. return nodes, results.rrset.ttl