_future.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545
  1. """Future-returning APIs for coroutines."""
  2. # Copyright (c) PyZMQ Developers.
  3. # Distributed under the terms of the Modified BSD License.
  4. from collections import namedtuple, deque
  5. from itertools import chain
  6. from zmq import EVENTS, POLLOUT, POLLIN
  7. import zmq as _zmq
  8. _FutureEvent = namedtuple('_FutureEvent', ('future', 'kind', 'kwargs', 'msg'))
# These are incomplete classes and need a Mixin for compatibility with an event loop
# defining the following attributes:
#
# _Future
# _READ
# _WRITE
# _default_loop()
#
# _AsyncPoller additionally relies on _socket_class, _watch_raw_socket()
# and _unwatch_raw_sockets() being provided.
  16. class _AsyncPoller(_zmq.Poller):
  17. """Poller that returns a Future on poll, instead of blocking."""
  18. def poll(self, timeout=-1):
  19. """Return a Future for a poll event"""
  20. future = self._Future()
  21. if timeout == 0:
  22. try:
  23. result = super(_AsyncPoller, self).poll(0)
  24. except Exception as e:
  25. future.set_exception(e)
  26. else:
  27. future.set_result(result)
  28. return future
  29. loop = self._default_loop()
  30. # register Future to be called as soon as any event is available on any socket
  31. watcher = self._Future()
  32. # watch raw sockets:
  33. raw_sockets = []
  34. def wake_raw(*args):
  35. if not watcher.done():
  36. watcher.set_result(None)
  37. watcher.add_done_callback(lambda f: self._unwatch_raw_sockets(loop, *raw_sockets))
  38. for socket, mask in self.sockets:
  39. if isinstance(socket, _zmq.Socket):
  40. if not isinstance(socket, self._socket_class):
  41. # it's a blocking zmq.Socket, wrap it in async
  42. socket = self._socket_class.from_socket(socket)
  43. if mask & _zmq.POLLIN:
  44. socket._add_recv_event('poll', future=watcher)
  45. if mask & _zmq.POLLOUT:
  46. socket._add_send_event('poll', future=watcher)
  47. else:
  48. raw_sockets.append(socket)
  49. evt = 0
  50. if mask & _zmq.POLLIN:
  51. evt |= self._READ
  52. if mask & _zmq.POLLOUT:
  53. evt |= self._WRITE
  54. self._watch_raw_socket(loop, socket, evt, wake_raw)
  55. def on_poll_ready(f):
  56. if future.done():
  57. return
  58. if watcher.cancelled():
  59. try:
  60. future.cancel()
  61. except RuntimeError:
  62. # RuntimeError may be called during teardown
  63. pass
  64. return
  65. if watcher.exception():
  66. future.set_exception(watcher.exception())
  67. else:
  68. try:
  69. result = super(_AsyncPoller, self).poll(0)
  70. except Exception as e:
  71. future.set_exception(e)
  72. else:
  73. future.set_result(result)
  74. watcher.add_done_callback(on_poll_ready)
  75. if timeout is not None and timeout > 0:
  76. # schedule cancel to fire on poll timeout, if any
  77. def trigger_timeout():
  78. if not watcher.done():
  79. watcher.set_result(None)
  80. timeout_handle = loop.call_later(
  81. 1e-3 * timeout,
  82. trigger_timeout
  83. )
  84. def cancel_timeout(f):
  85. if hasattr(timeout_handle, 'cancel'):
  86. timeout_handle.cancel()
  87. else:
  88. loop.remove_timeout(timeout_handle)
  89. future.add_done_callback(cancel_timeout)
  90. def cancel_watcher(f):
  91. if not watcher.done():
  92. watcher.cancel()
  93. future.add_done_callback(cancel_watcher)
  94. return future
  95. class _AsyncSocket(_zmq.Socket):
  96. # Warning : these class variables are only here to allow to call super().__setattr__.
  97. # They be overridden at instance initialization and not shared in the whole class
  98. _recv_futures = None
  99. _send_futures = None
  100. _state = 0
  101. _shadow_sock = None
  102. _poller_class = _AsyncPoller
  103. io_loop = None
  104. _fd = None
  105. def __init__(self, context=None, socket_type=-1, io_loop=None, **kwargs):
  106. if isinstance(context, _zmq.Socket):
  107. context, from_socket = (None, context)
  108. else:
  109. from_socket = kwargs.pop('_from_socket', None)
  110. if from_socket is not None:
  111. super(_AsyncSocket, self).__init__(shadow=from_socket.underlying)
  112. self._shadow_sock = from_socket
  113. else:
  114. super(_AsyncSocket, self).__init__(context, socket_type, **kwargs)
  115. self._shadow_sock = _zmq.Socket.shadow(self.underlying)
  116. self.io_loop = io_loop or self._default_loop()
  117. self._recv_futures = deque()
  118. self._send_futures = deque()
  119. self._state = 0
  120. self._fd = self._shadow_sock.FD
  121. self._init_io_state()
  122. @classmethod
  123. def from_socket(cls, socket, io_loop=None):
  124. """Create an async socket from an existing Socket"""
  125. return cls(_from_socket=socket, io_loop=io_loop)
  126. def close(self, linger=None):
  127. if not self.closed:
  128. for event in list(chain(self._recv_futures, self._send_futures)):
  129. if not event.future.done():
  130. try:
  131. event.future.cancel()
  132. except RuntimeError:
  133. # RuntimeError may be called during teardown
  134. pass
  135. self._clear_io_state()
  136. super(_AsyncSocket, self).close(linger=linger)
  137. close.__doc__ = _zmq.Socket.close.__doc__
  138. def get(self, key):
  139. result = super(_AsyncSocket, self).get(key)
  140. if key == EVENTS:
  141. self._schedule_remaining_events(result)
  142. return result
  143. get.__doc__ = _zmq.Socket.get.__doc__
  144. def recv_multipart(self, flags=0, copy=True, track=False):
  145. """Receive a complete multipart zmq message.
  146. Returns a Future whose result will be a multipart message.
  147. """
  148. return self._add_recv_event('recv_multipart',
  149. dict(flags=flags, copy=copy, track=track)
  150. )
  151. def recv(self, flags=0, copy=True, track=False):
  152. """Receive a single zmq frame.
  153. Returns a Future, whose result will be the received frame.
  154. Recommend using recv_multipart instead.
  155. """
  156. return self._add_recv_event('recv',
  157. dict(flags=flags, copy=copy, track=track)
  158. )
  159. def send_multipart(self, msg, flags=0, copy=True, track=False, **kwargs):
  160. """Send a complete multipart zmq message.
  161. Returns a Future that resolves when sending is complete.
  162. """
  163. kwargs['flags'] = flags
  164. kwargs['copy'] = copy
  165. kwargs['track'] = track
  166. return self._add_send_event('send_multipart', msg=msg, kwargs=kwargs)
  167. def send(self, msg, flags=0, copy=True, track=False, **kwargs):
  168. """Send a single zmq frame.
  169. Returns a Future that resolves when sending is complete.
  170. Recommend using send_multipart instead.
  171. """
  172. kwargs['flags'] = flags
  173. kwargs['copy'] = copy
  174. kwargs['track'] = track
  175. kwargs.update(dict(flags=flags, copy=copy, track=track))
  176. return self._add_send_event('send', msg=msg, kwargs=kwargs)
  177. def _deserialize(self, recvd, load):
  178. """Deserialize with Futures"""
  179. f = self._Future()
  180. def _chain(_):
  181. """Chain result through serialization to recvd"""
  182. if f.done():
  183. return
  184. if recvd.exception():
  185. f.set_exception(recvd.exception())
  186. else:
  187. buf = recvd.result()
  188. try:
  189. loaded = load(buf)
  190. except Exception as e:
  191. f.set_exception(e)
  192. else:
  193. f.set_result(loaded)
  194. recvd.add_done_callback(_chain)
  195. def _chain_cancel(_):
  196. """Chain cancellation from f to recvd"""
  197. if recvd.done():
  198. return
  199. if f.cancelled():
  200. recvd.cancel()
  201. f.add_done_callback(_chain_cancel)
  202. return f
  203. def poll(self, timeout=None, flags=_zmq.POLLIN):
  204. """poll the socket for events
  205. returns a Future for the poll results.
  206. """
  207. if self.closed:
  208. raise _zmq.ZMQError(_zmq.ENOTSUP)
  209. p = self._poller_class()
  210. p.register(self, flags)
  211. f = p.poll(timeout)
  212. future = self._Future()
  213. def unwrap_result(f):
  214. if future.done():
  215. return
  216. if f.cancelled():
  217. try:
  218. future.cancel()
  219. except RuntimeError:
  220. # RuntimeError may be called during teardown
  221. pass
  222. return
  223. if f.exception():
  224. future.set_exception(f.exception())
  225. else:
  226. evts = dict(f.result())
  227. future.set_result(evts.get(self, 0))
  228. if f.done():
  229. # hook up result if
  230. unwrap_result(f)
  231. else:
  232. f.add_done_callback(unwrap_result)
  233. return future
  234. def _add_timeout(self, future, timeout):
  235. """Add a timeout for a send or recv Future"""
  236. def future_timeout():
  237. if future.done():
  238. # future already resolved, do nothing
  239. return
  240. # raise EAGAIN
  241. future.set_exception(_zmq.Again())
  242. self._call_later(timeout, future_timeout)
  243. def _call_later(self, delay, callback):
  244. """Schedule a function to be called later
  245. Override for different IOLoop implementations
  246. Tornado and asyncio happen to both have ioloop.call_later
  247. with the same signature.
  248. """
  249. self.io_loop.call_later(delay, callback)
  250. @staticmethod
  251. def _remove_finished_future(future, event_list):
  252. """Make sure that futures are removed from the event list when they resolve
  253. Avoids delaying cleanup until the next send/recv event,
  254. which may never come.
  255. """
  256. for f_idx, (f, kind, kwargs, _) in enumerate(event_list):
  257. if f is future:
  258. break
  259. else:
  260. return
  261. # "future" instance is shared between sockets, but each socket has its own event list.
  262. event_list.remove(event_list[f_idx])
  263. def _add_recv_event(self, kind, kwargs=None, future=None):
  264. """Add a recv event, returning the corresponding Future"""
  265. f = future or self._Future()
  266. if kind.startswith('recv') and kwargs.get('flags', 0) & _zmq.DONTWAIT:
  267. # short-circuit non-blocking calls
  268. recv = getattr(self._shadow_sock, kind)
  269. try:
  270. r = recv(**kwargs)
  271. except Exception as e:
  272. f.set_exception(e)
  273. else:
  274. f.set_result(r)
  275. return f
  276. # we add it to the list of futures before we add the timeout as the
  277. # timeout will remove the future from recv_futures to avoid leaks
  278. self._recv_futures.append(
  279. _FutureEvent(f, kind, kwargs, msg=None)
  280. )
  281. # Don't let the Future sit in _recv_events after it's done
  282. f.add_done_callback(lambda f: self._remove_finished_future(f, self._recv_futures))
  283. if hasattr(_zmq, 'RCVTIMEO'):
  284. timeout_ms = self._shadow_sock.rcvtimeo
  285. if timeout_ms >= 0:
  286. self._add_timeout(f, timeout_ms * 1e-3)
  287. if self._shadow_sock.get(EVENTS) & POLLIN:
  288. # recv immediately, if we can
  289. self._handle_recv()
  290. if self._recv_futures:
  291. self._add_io_state(POLLIN)
  292. return f
  293. def _add_send_event(self, kind, msg=None, kwargs=None, future=None):
  294. """Add a send event, returning the corresponding Future"""
  295. f = future or self._Future()
  296. # attempt send with DONTWAIT if no futures are waiting
  297. # short-circuit for sends that will resolve immediately
  298. # only call if no send Futures are waiting
  299. if (
  300. kind in ('send', 'send_multipart')
  301. and not self._send_futures
  302. ):
  303. flags = kwargs.get('flags', 0)
  304. nowait_kwargs = kwargs.copy()
  305. nowait_kwargs['flags'] = flags | _zmq.DONTWAIT
  306. # short-circuit non-blocking calls
  307. send = getattr(self._shadow_sock, kind)
  308. # track if the send resolved or not
  309. # (EAGAIN if DONTWAIT is not set should proceed with)
  310. finish_early = True
  311. try:
  312. r = send(msg, **nowait_kwargs)
  313. except _zmq.Again as e:
  314. if flags & _zmq.DONTWAIT:
  315. f.set_exception(e)
  316. else:
  317. # EAGAIN raised and DONTWAIT not requested,
  318. # proceed with async send
  319. finish_early = False
  320. except Exception as e:
  321. f.set_exception(e)
  322. else:
  323. f.set_result(r)
  324. if finish_early:
  325. # short-circuit resolved, return finished Future
  326. # schedule wake for recv if there are any receivers waiting
  327. if self._recv_futures:
  328. self._schedule_remaining_events()
  329. return f
  330. # we add it to the list of futures before we add the timeout as the
  331. # timeout will remove the future from recv_futures to avoid leaks
  332. self._send_futures.append(
  333. _FutureEvent(f, kind, kwargs=kwargs, msg=msg)
  334. )
  335. # Don't let the Future sit in _send_futures after it's done
  336. f.add_done_callback(lambda f: self._remove_finished_future(f, self._send_futures))
  337. if hasattr(_zmq, 'SNDTIMEO'):
  338. timeout_ms = self._shadow_sock.get(_zmq.SNDTIMEO)
  339. if timeout_ms >= 0:
  340. self._add_timeout(f, timeout_ms * 1e-3)
  341. self._add_io_state(POLLOUT)
  342. return f
  343. def _handle_recv(self):
  344. """Handle recv events"""
  345. if not self._shadow_sock.get(EVENTS) & POLLIN:
  346. # event triggered, but state may have been changed between trigger and callback
  347. return
  348. f = None
  349. while self._recv_futures:
  350. f, kind, kwargs, _ = self._recv_futures.popleft()
  351. # skip any cancelled futures
  352. if f.done():
  353. f = None
  354. else:
  355. break
  356. if not self._recv_futures:
  357. self._drop_io_state(POLLIN)
  358. if f is None:
  359. return
  360. if kind == 'poll':
  361. # on poll event, just signal ready, nothing else.
  362. f.set_result(None)
  363. return
  364. elif kind == 'recv_multipart':
  365. recv = self._shadow_sock.recv_multipart
  366. elif kind == 'recv':
  367. recv = self._shadow_sock.recv
  368. else:
  369. raise ValueError("Unhandled recv event type: %r" % kind)
  370. kwargs['flags'] |= _zmq.DONTWAIT
  371. try:
  372. result = recv(**kwargs)
  373. except Exception as e:
  374. f.set_exception(e)
  375. else:
  376. f.set_result(result)
  377. def _handle_send(self):
  378. if not self._shadow_sock.get(EVENTS) & POLLOUT:
  379. # event triggered, but state may have been changed between trigger and callback
  380. return
  381. f = None
  382. while self._send_futures:
  383. f, kind, kwargs, msg = self._send_futures.popleft()
  384. # skip any cancelled futures
  385. if f.done():
  386. f = None
  387. else:
  388. break
  389. if not self._send_futures:
  390. self._drop_io_state(POLLOUT)
  391. if f is None:
  392. return
  393. if kind == 'poll':
  394. # on poll event, just signal ready, nothing else.
  395. f.set_result(None)
  396. return
  397. elif kind == 'send_multipart':
  398. send = self._shadow_sock.send_multipart
  399. elif kind == 'send':
  400. send = self._shadow_sock.send
  401. else:
  402. raise ValueError("Unhandled send event type: %r" % kind)
  403. kwargs['flags'] |= _zmq.DONTWAIT
  404. try:
  405. result = send(msg, **kwargs)
  406. except Exception as e:
  407. f.set_exception(e)
  408. else:
  409. f.set_result(result)
  410. # event masking from ZMQStream
  411. def _handle_events(self, fd=0, events=0):
  412. """Dispatch IO events to _handle_recv, etc."""
  413. zmq_events = self._shadow_sock.get(EVENTS)
  414. if zmq_events & _zmq.POLLIN:
  415. self._handle_recv()
  416. if zmq_events & _zmq.POLLOUT:
  417. self._handle_send()
  418. self._schedule_remaining_events()
  419. def _schedule_remaining_events(self, events=None):
  420. """Schedule a call to handle_events next loop iteration
  421. If there are still events to handle.
  422. """
  423. # edge-triggered handling
  424. # allow passing events in, in case this is triggered by retrieving events,
  425. # so we don't have to retrieve it twice.
  426. if self._state == 0:
  427. # not watching for anything, nothing to schedule
  428. return
  429. if events is None:
  430. events = self._shadow_sock.get(EVENTS)
  431. if events & self._state:
  432. self._call_later(0, self._handle_events)
  433. def _add_io_state(self, state):
  434. """Add io_state to poller."""
  435. if self._state != state:
  436. state = self._state = self._state | state
  437. self._update_handler(self._state)
  438. def _drop_io_state(self, state):
  439. """Stop poller from watching an io_state."""
  440. if self._state & state:
  441. self._state = self._state & (~state)
  442. self._update_handler(self._state)
  443. def _update_handler(self, state):
  444. """Update IOLoop handler with state.
  445. zmq FD is always read-only.
  446. """
  447. self._schedule_remaining_events()
  448. def _init_io_state(self):
  449. """initialize the ioloop event handler"""
  450. self.io_loop.add_handler(self._shadow_sock, self._handle_events, self._READ)
  451. self._call_later(0, self._handle_events)
  452. def _clear_io_state(self):
  453. """unregister the ioloop event handler
  454. called once during close
  455. """
  456. fd = self._shadow_sock
  457. if self._shadow_sock.closed:
  458. fd = self._fd
  459. self.io_loop.remove_handler(fd)