connection.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498
  1. #
  2. # A higher level module for using sockets (or Windows named pipes)
  3. #
  4. # multiprocessing/connection.py
  5. #
  6. # Copyright (c) 2006-2008, R Oudkerk
  7. # Licensed to PSF under a Contributor Agreement.
  8. #
  9. from __future__ import absolute_import
  10. import os
  11. import random
  12. import sys
  13. import socket
  14. import string
  15. import errno
  16. import time
  17. import tempfile
  18. import itertools
  19. from .. import AuthenticationError
  20. from .. import reduction
  21. from .._ext import _billiard, win32
  22. from ..compat import get_errno, setblocking, bytes as cbytes
  23. from ..five import monotonic
  24. from ..forking import duplicate, close
  25. from ..reduction import ForkingPickler
  26. from ..util import get_temp_dir, Finalize, sub_debug, debug
# ``WindowsError`` is only a builtin on Windows; elsewhere bind the name to
# None so it is always defined at module level (it is only used in
# ``except`` clauses of the win32-only code paths below).
try:
    WindowsError = WindowsError  # noqa
except NameError:
    WindowsError = None  # noqa
__all__ = ['Client', 'Listener', 'Pipe']
# global set later (imported lazily by XmlListener.accept() / XmlClient()).
xmlrpclib = None
# Connection classes come from the _billiard C extension; None when the
# extension does not provide them.
Connection = getattr(_billiard, 'Connection', None)
PipeConnection = getattr(_billiard, 'PipeConnection', None)
  36. #
  37. #
  38. #
  39. BUFSIZE = 8192
  40. # A very generous timeout when it comes to local connections...
  41. CONNECTION_TIMEOUT = 20.
  42. _mmap_counter = itertools.count()
  43. default_family = 'AF_INET'
  44. families = ['AF_INET']
  45. if hasattr(socket, 'AF_UNIX'):
  46. default_family = 'AF_UNIX'
  47. families += ['AF_UNIX']
  48. if sys.platform == 'win32':
  49. default_family = 'AF_PIPE'
  50. families += ['AF_PIPE']
  51. def _init_timeout(timeout=CONNECTION_TIMEOUT):
  52. return monotonic() + timeout
  53. def _check_timeout(t):
  54. return monotonic() > t
  55. #
  56. #
  57. #
  58. def arbitrary_address(family):
  59. '''
  60. Return an arbitrary free address for the given family
  61. '''
  62. if family == 'AF_INET':
  63. return ('localhost', 0)
  64. elif family == 'AF_UNIX':
  65. return tempfile.mktemp(prefix='listener-', dir=get_temp_dir())
  66. elif family == 'AF_PIPE':
  67. randomchars = ''.join(
  68. random.choice(string.ascii_lowercase + string.digits)
  69. for i in range(6)
  70. )
  71. return r'\\.\pipe\pyc-%d-%d-%s' % (
  72. os.getpid(), next(_mmap_counter), randomchars
  73. )
  74. else:
  75. raise ValueError('unrecognized family')
  76. def address_type(address):
  77. '''
  78. Return the types of the address
  79. This can be 'AF_INET', 'AF_UNIX', or 'AF_PIPE'
  80. '''
  81. if type(address) == tuple:
  82. return 'AF_INET'
  83. elif type(address) is str and address.startswith('\\\\'):
  84. return 'AF_PIPE'
  85. elif type(address) is str:
  86. return 'AF_UNIX'
  87. else:
  88. raise ValueError('address type of %r unrecognized' % address)
  89. #
  90. # Public functions
  91. #
  92. class Listener(object):
  93. '''
  94. Returns a listener object.
  95. This is a wrapper for a bound socket which is 'listening' for
  96. connections, or for a Windows named pipe.
  97. '''
  98. def __init__(self, address=None, family=None, backlog=1, authkey=None):
  99. family = (family or
  100. (address and address_type(address)) or
  101. default_family)
  102. address = address or arbitrary_address(family)
  103. if family == 'AF_PIPE':
  104. self._listener = PipeListener(address, backlog)
  105. else:
  106. self._listener = SocketListener(address, family, backlog)
  107. if authkey is not None and not isinstance(authkey, bytes):
  108. raise TypeError('authkey should be a byte string')
  109. self._authkey = authkey
  110. def accept(self):
  111. '''
  112. Accept a connection on the bound socket or named pipe of `self`.
  113. Returns a `Connection` object.
  114. '''
  115. if self._listener is None:
  116. raise IOError('listener is closed')
  117. c = self._listener.accept()
  118. if self._authkey:
  119. deliver_challenge(c, self._authkey)
  120. answer_challenge(c, self._authkey)
  121. return c
  122. def close(self):
  123. '''
  124. Close the bound socket or named pipe of `self`.
  125. '''
  126. if self._listener is not None:
  127. self._listener.close()
  128. self._listener = None
  129. address = property(lambda self: self._listener._address)
  130. last_accepted = property(lambda self: self._listener._last_accepted)
  131. def __enter__(self):
  132. return self
  133. def __exit__(self, *exc_args):
  134. self.close()
  135. def Client(address, family=None, authkey=None):
  136. '''
  137. Returns a connection to the address of a `Listener`
  138. '''
  139. family = family or address_type(address)
  140. if family == 'AF_PIPE':
  141. c = PipeClient(address)
  142. else:
  143. c = SocketClient(address)
  144. if authkey is not None and not isinstance(authkey, bytes):
  145. raise TypeError('authkey should be a byte string')
  146. if authkey is not None:
  147. answer_challenge(c, authkey)
  148. deliver_challenge(c, authkey)
  149. return c
if sys.platform != 'win32':
    def Pipe(duplex=True, rnonblock=False, wnonblock=False):
        '''
        Returns pair of connection objects at either end of a pipe

        If ``duplex`` is true both ends can send and receive; otherwise the
        first connection is receive-only and the second is send-only.
        ``rnonblock`` / ``wnonblock`` make the first / second end
        non-blocking.
        '''
        if duplex:
            # Bidirectional: a connected socket pair.
            s1, s2 = socket.socketpair()
            # Blocking mode is set on the sockets before their fds are
            # duplicated; os.dup() copies share the open file description,
            # so the duplicated fds keep that mode.
            s1.setblocking(not rnonblock)
            s2.setblocking(not wnonblock)
            c1 = Connection(os.dup(s1.fileno()))
            c2 = Connection(os.dup(s2.fileno()))
            # The Connections own the duplicated fds; the socket objects
            # are no longer needed.
            s1.close()
            s2.close()
        else:
            # One-way: fd1 is the read end, fd2 the write end.
            fd1, fd2 = os.pipe()
            # NOTE(review): compat.setblocking presumably toggles
            # O_NONBLOCK on the raw fd -- confirm in ..compat.
            if rnonblock:
                setblocking(fd1, 0)
            if wnonblock:
                setblocking(fd2, 0)
            c1 = Connection(fd1, writable=False)
            c2 = Connection(fd2, readable=False)
        return c1, c2
else:
    def Pipe(duplex=True, rnonblock=False, wnonblock=False): # noqa
        '''
        Returns pair of connection objects at either end of a pipe

        Windows variant built on a uniquely named pipe. If ``duplex`` is
        false the first connection is receive-only and the second is
        send-only. ``rnonblock`` and ``wnonblock`` are accepted for
        signature compatibility but are not used here.
        '''
        address = arbitrary_address('AF_PIPE')
        if duplex:
            openmode = win32.PIPE_ACCESS_DUPLEX
            access = win32.GENERIC_READ | win32.GENERIC_WRITE
            obsize, ibsize = BUFSIZE, BUFSIZE
        else:
            # Data flows client (h2, write-only) -> server (h1, inbound),
            # so no server output buffer is needed.
            openmode = win32.PIPE_ACCESS_INBOUND
            access = win32.GENERIC_WRITE
            obsize, ibsize = 0, BUFSIZE
        # Server end: a single pipe instance (max instances = 1), in
        # message mode.
        h1 = win32.CreateNamedPipe(
            address, openmode,
            win32.PIPE_TYPE_MESSAGE | win32.PIPE_READMODE_MESSAGE |
            win32.PIPE_WAIT,
            1, obsize, ibsize, win32.NMPWAIT_WAIT_FOREVER, win32.NULL
        )
        # Client end: open the pipe just created.
        # NOTE(review): if this or a later call raises, h1 is never closed.
        h2 = win32.CreateFile(
            address, access, 0, win32.NULL, win32.OPEN_EXISTING, 0, win32.NULL
        )
        win32.SetNamedPipeHandleState(
            h2, win32.PIPE_READMODE_MESSAGE, None, None
        )
        try:
            win32.ConnectNamedPipe(h1, win32.NULL)
        except WindowsError as exc:
            # ERROR_PIPE_CONNECTED: the client (h2) already connected before
            # ConnectNamedPipe was called, which is the expected case here.
            if exc.args[0] != win32.ERROR_PIPE_CONNECTED:
                raise
        c1 = PipeConnection(h1, writable=duplex)
        c2 = PipeConnection(h2, readable=duplex)
        return c1, c2
  206. #
  207. # Definitions for connections based on sockets
  208. #
  209. class SocketListener(object):
  210. '''
  211. Representation of a socket which is bound to an address and listening
  212. '''
  213. def __init__(self, address, family, backlog=1):
  214. self._socket = socket.socket(getattr(socket, family))
  215. try:
  216. # SO_REUSEADDR has different semantics on Windows (Issue #2550).
  217. if os.name == 'posix':
  218. self._socket.setsockopt(socket.SOL_SOCKET,
  219. socket.SO_REUSEADDR, 1)
  220. self._socket.bind(address)
  221. self._socket.listen(backlog)
  222. self._address = self._socket.getsockname()
  223. except OSError:
  224. self._socket.close()
  225. raise
  226. self._family = family
  227. self._last_accepted = None
  228. if family == 'AF_UNIX':
  229. self._unlink = Finalize(
  230. self, os.unlink, args=(address,), exitpriority=0
  231. )
  232. else:
  233. self._unlink = None
  234. def accept(self):
  235. s, self._last_accepted = self._socket.accept()
  236. fd = duplicate(s.fileno())
  237. conn = Connection(fd)
  238. s.close()
  239. return conn
  240. def close(self):
  241. self._socket.close()
  242. if self._unlink is not None:
  243. self._unlink()
  244. def SocketClient(address):
  245. '''
  246. Return a connection object connected to the socket given by `address`
  247. '''
  248. family = address_type(address)
  249. s = socket.socket(getattr(socket, family))
  250. t = _init_timeout()
  251. while 1:
  252. try:
  253. s.connect(address)
  254. except socket.error as exc:
  255. if get_errno(exc) != errno.ECONNREFUSED or _check_timeout(t):
  256. debug('failed to connect to address %s', address)
  257. raise
  258. time.sleep(0.01)
  259. else:
  260. break
  261. else:
  262. raise
  263. fd = duplicate(s.fileno())
  264. conn = Connection(fd)
  265. s.close()
  266. return conn
  267. #
  268. # Definitions for connections based on named pipes
  269. #
  270. if sys.platform == 'win32':
  271. class PipeListener(object):
  272. '''
  273. Representation of a named pipe
  274. '''
  275. def __init__(self, address, backlog=None):
  276. self._address = address
  277. handle = win32.CreateNamedPipe(
  278. address, win32.PIPE_ACCESS_DUPLEX,
  279. win32.PIPE_TYPE_MESSAGE | win32.PIPE_READMODE_MESSAGE |
  280. win32.PIPE_WAIT,
  281. win32.PIPE_UNLIMITED_INSTANCES, BUFSIZE, BUFSIZE,
  282. win32.NMPWAIT_WAIT_FOREVER, win32.NULL
  283. )
  284. self._handle_queue = [handle]
  285. self._last_accepted = None
  286. sub_debug('listener created with address=%r', self._address)
  287. self.close = Finalize(
  288. self, PipeListener._finalize_pipe_listener,
  289. args=(self._handle_queue, self._address), exitpriority=0
  290. )
  291. def accept(self):
  292. newhandle = win32.CreateNamedPipe(
  293. self._address, win32.PIPE_ACCESS_DUPLEX,
  294. win32.PIPE_TYPE_MESSAGE | win32.PIPE_READMODE_MESSAGE |
  295. win32.PIPE_WAIT,
  296. win32.PIPE_UNLIMITED_INSTANCES, BUFSIZE, BUFSIZE,
  297. win32.NMPWAIT_WAIT_FOREVER, win32.NULL
  298. )
  299. self._handle_queue.append(newhandle)
  300. handle = self._handle_queue.pop(0)
  301. try:
  302. win32.ConnectNamedPipe(handle, win32.NULL)
  303. except WindowsError as exc:
  304. if exc.args[0] != win32.ERROR_PIPE_CONNECTED:
  305. raise
  306. return PipeConnection(handle)
  307. @staticmethod
  308. def _finalize_pipe_listener(queue, address):
  309. sub_debug('closing listener with address=%r', address)
  310. for handle in queue:
  311. close(handle)
  312. def PipeClient(address):
  313. '''
  314. Return a connection object connected to the pipe given by `address`
  315. '''
  316. t = _init_timeout()
  317. while 1:
  318. try:
  319. win32.WaitNamedPipe(address, 1000)
  320. h = win32.CreateFile(
  321. address, win32.GENERIC_READ | win32.GENERIC_WRITE,
  322. 0, win32.NULL, win32.OPEN_EXISTING, 0, win32.NULL,
  323. )
  324. except WindowsError as exc:
  325. if exc.args[0] not in (
  326. win32.ERROR_SEM_TIMEOUT,
  327. win32.ERROR_PIPE_BUSY) or _check_timeout(t):
  328. raise
  329. else:
  330. break
  331. else:
  332. raise
  333. win32.SetNamedPipeHandleState(
  334. h, win32.PIPE_READMODE_MESSAGE, None, None
  335. )
  336. return PipeConnection(h)
  337. #
  338. # Authentication stuff
  339. #
  340. MESSAGE_LENGTH = 20
  341. CHALLENGE = cbytes('#CHALLENGE#', 'ascii')
  342. WELCOME = cbytes('#WELCOME#', 'ascii')
  343. FAILURE = cbytes('#FAILURE#', 'ascii')
  344. def deliver_challenge(connection, authkey):
  345. import hmac
  346. assert isinstance(authkey, bytes)
  347. message = os.urandom(MESSAGE_LENGTH)
  348. connection.send_bytes(CHALLENGE + message)
  349. digest = hmac.new(authkey, message).digest()
  350. response = connection.recv_bytes(256) # reject large message
  351. if response == digest:
  352. connection.send_bytes(WELCOME)
  353. else:
  354. connection.send_bytes(FAILURE)
  355. raise AuthenticationError('digest received was wrong')
  356. def answer_challenge(connection, authkey):
  357. import hmac
  358. assert isinstance(authkey, bytes)
  359. message = connection.recv_bytes(256) # reject large message
  360. assert message[:len(CHALLENGE)] == CHALLENGE, 'message = %r' % message
  361. message = message[len(CHALLENGE):]
  362. digest = hmac.new(authkey, message).digest()
  363. connection.send_bytes(digest)
  364. response = connection.recv_bytes(256) # reject large message
  365. if response != WELCOME:
  366. raise AuthenticationError('digest sent was rejected')
  367. #
  368. # Support for using xmlrpclib for serialization
  369. #
  370. class ConnectionWrapper(object):
  371. def __init__(self, conn, dumps, loads):
  372. self._conn = conn
  373. self._dumps = dumps
  374. self._loads = loads
  375. for attr in ('fileno', 'close', 'poll', 'recv_bytes', 'send_bytes'):
  376. obj = getattr(conn, attr)
  377. setattr(self, attr, obj)
  378. def send(self, obj):
  379. s = self._dumps(obj)
  380. self._conn.send_bytes(s)
  381. def recv(self):
  382. s = self._conn.recv_bytes()
  383. return self._loads(s)
  384. def _xml_dumps(obj):
  385. return xmlrpclib.dumps((obj,), None, None, None, 1).encode('utf8')
  386. def _xml_loads(s):
  387. (obj,), method = xmlrpclib.loads(s.decode('utf8'))
  388. return obj
  389. class XmlListener(Listener):
  390. def accept(self):
  391. global xmlrpclib
  392. import xmlrpclib # noqa
  393. obj = Listener.accept(self)
  394. return ConnectionWrapper(obj, _xml_dumps, _xml_loads)
  395. def XmlClient(*args, **kwds):
  396. global xmlrpclib
  397. import xmlrpclib # noqa
  398. return ConnectionWrapper(Client(*args, **kwds), _xml_dumps, _xml_loads)
  399. if sys.platform == 'win32':
  400. ForkingPickler.register(socket.socket, reduction.reduce_socket)
  401. ForkingPickler.register(Connection, reduction.reduce_connection)
  402. ForkingPickler.register(PipeConnection, reduction.reduce_pipe_connection)
  403. else:
  404. ForkingPickler.register(socket.socket, reduction.reduce_socket)
  405. ForkingPickler.register(Connection, reduction.reduce_connection)