socket_checker.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
# Copyright 2020-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. """Select / poll helper"""
  15. import errno
  16. import select
  17. import sys
  18. # PYTHON-2320: Jython does not fully support poll on SSL sockets,
  19. # https://bugs.jython.org/issue2900
  20. _HAVE_POLL = hasattr(select, "poll") and not sys.platform.startswith('java')
  21. _SelectError = getattr(select, "error", OSError)
  22. def _errno_from_exception(exc):
  23. if hasattr(exc, 'errno'):
  24. return exc.errno
  25. if exc.args:
  26. return exc.args[0]
  27. return None
  28. class SocketChecker(object):
  29. def __init__(self):
  30. if _HAVE_POLL:
  31. self._poller = select.poll()
  32. else:
  33. self._poller = None
  34. def select(self, sock, read=False, write=False, timeout=0):
  35. """Select for reads or writes with a timeout in seconds (or None).
  36. Returns True if the socket is readable/writable, False on timeout.
  37. """
  38. while True:
  39. try:
  40. if self._poller:
  41. mask = select.POLLERR | select.POLLHUP
  42. if read:
  43. mask = mask | select.POLLIN | select.POLLPRI
  44. if write:
  45. mask = mask | select.POLLOUT
  46. self._poller.register(sock, mask)
  47. try:
  48. # poll() timeout is in milliseconds. select()
  49. # timeout is in seconds.
  50. timeout_ = None if timeout is None else timeout * 1000
  51. res = self._poller.poll(timeout_)
  52. # poll returns a possibly-empty list containing
  53. # (fd, event) 2-tuples for the descriptors that have
  54. # events or errors to report. Return True if the list
  55. # is not empty.
  56. return bool(res)
  57. finally:
  58. self._poller.unregister(sock)
  59. else:
  60. rlist = [sock] if read else []
  61. wlist = [sock] if write else []
  62. res = select.select(rlist, wlist, [sock], timeout)
  63. # select returns a 3-tuple of lists of objects that are
  64. # ready: subsets of the first three arguments. Return
  65. # True if any of the lists are not empty.
  66. return any(res)
  67. except (_SelectError, IOError) as exc:
  68. if _errno_from_exception(exc) in (errno.EINTR, errno.EAGAIN):
  69. continue
  70. raise
  71. def socket_closed(self, sock):
  72. """Return True if we know socket has been closed, False otherwise.
  73. """
  74. try:
  75. return self.select(sock, read=True)
  76. except (RuntimeError, KeyError):
  77. # RuntimeError is raised during a concurrent poll. KeyError
  78. # is raised by unregister if the socket is not in the poller.
  79. # These errors should not be possible since we protect the
  80. # poller with a mutex.
  81. raise
  82. except ValueError:
  83. # ValueError is raised by register/unregister/select if the
  84. # socket file descriptor is negative or outside the range for
  85. # select (> 1023).
  86. return True
  87. except Exception:
  88. # Any other exceptions should be attributed to a closed
  89. # or invalid socket.
  90. return True