asyncio.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430
  1. import itertools
  2. import logging
  3. import asyncio
  4. import contextvars
  5. from functools import partial, reduce
  6. import copy
  7. from ..core import State, Condition, Transition, EventData, listify
  8. from ..core import Event, MachineError, Machine
  9. from .nesting import HierarchicalMachine, NestedState, NestedEvent, NestedTransition, _resolve_order
# Module-level logger. The NullHandler is attached so that this module does not
# emit log output on its own when the application has not configured logging.
_LOGGER = logging.getLogger(__name__)
_LOGGER.addHandler(logging.NullHandler())
# Context variable (default ``False``). ``AsyncTransition.execute`` stores the
# task it registers for a model in here and compares a still-running task
# against this value before cancelling it.
is_subtask = contextvars.ContextVar('is_subtask', default=False)
  13. class AsyncState(State):
  14. """A persistent representation of a state managed by a ``Machine``. Callback execution is done asynchronously.
  15. Attributes:
  16. name (str): State name which is also assigned to the model(s).
  17. on_enter (list): Callbacks awaited when a state is entered.
  18. on_exit (list): Callbacks awaited when a state is entered.
  19. ignore_invalid_triggers (bool): Indicates if unhandled/invalid triggers should raise an exception.
  20. """
  21. async def enter(self, event_data):
  22. """ Triggered when a state is entered. """
  23. _LOGGER.debug("%sEntering state %s. Processing callbacks...", event_data.machine.name, self.name)
  24. await event_data.machine.callbacks(self.on_enter, event_data)
  25. _LOGGER.info("%sEntered state %s", event_data.machine.name, self.name)
  26. async def exit(self, event_data):
  27. """ Triggered when a state is exited. """
  28. _LOGGER.debug("%sExiting state %s. Processing callbacks...", event_data.machine.name, self.name)
  29. await event_data.machine.callbacks(self.on_exit, event_data)
  30. _LOGGER.info("%sExited state %s", event_data.machine.name, self.name)
  31. class NestedAsyncState(NestedState, AsyncState):
  32. async def scoped_enter(self, event_data, scope=[]):
  33. self._scope = scope
  34. await self.enter(event_data)
  35. self._scope = []
  36. async def scoped_exit(self, event_data, scope=[]):
  37. self._scope = scope
  38. await self.exit(event_data)
  39. self._scope = []
  40. class AsyncCondition(Condition):
  41. """ A helper class to await condition checks in the intended way.
  42. Attributes:
  43. func (callable): The function to call for the condition check
  44. target (bool): Indicates the target state--i.e., when True,
  45. the condition-checking callback should return True to pass,
  46. and when False, the callback should return False to pass.
  47. """
  48. async def check(self, event_data):
  49. """ Check whether the condition passes.
  50. Args:
  51. event_data (EventData): An EventData instance to pass to the
  52. condition (if event sending is enabled) or to extract arguments
  53. from (if event sending is disabled). Also contains the data
  54. model attached to the current machine which is used to invoke
  55. the condition.
  56. """
  57. predicate = event_data.machine.resolve_callable(self.func, event_data)
  58. if asyncio.iscoroutinefunction(predicate):
  59. if event_data.machine.send_event:
  60. return await predicate(event_data) == self.target
  61. else:
  62. return await predicate(*event_data.args, **event_data.kwargs) == self.target
  63. else:
  64. return super(AsyncCondition, self).check(event_data)
  65. class AsyncTransition(Transition):
  66. """ Representation of an asynchronous transition managed by a ``AsyncMachine`` instance.
  67. Attributes:
  68. source (str): Source state of the transition.
  69. dest (str): Destination state of the transition.
  70. prepare (list): Callbacks executed before conditions checks.
  71. conditions (list): Callbacks evaluated to determine if
  72. the transition should be executed.
  73. before (list): Callbacks executed before the transition is executed
  74. but only if condition checks have been successful.
  75. after (list): Callbacks executed after the transition is executed
  76. but only if condition checks have been successful.
  77. """
  78. condition_cls = AsyncCondition
  79. async def _eval_conditions(self, event_data):
  80. res = await asyncio.gather(*[cond.check(event_data) for cond in self.conditions])
  81. if not all(res):
  82. _LOGGER.debug("%sTransition condition failed: Transition halted.", event_data.machine.name)
  83. return False
  84. return True
  85. async def execute(self, event_data):
  86. """ Executes the transition.
  87. Args:
  88. event_data: An instance of class EventData.
  89. Returns: boolean indicating whether or not the transition was
  90. successfully executed (True if successful, False if not).
  91. """
  92. _LOGGER.debug("%sInitiating transition from state %s to state %s...",
  93. event_data.machine.name, self.source, self.dest)
  94. await event_data.machine.callbacks(self.prepare, event_data)
  95. _LOGGER.debug("%sExecuted callbacks before conditions.", event_data.machine.name)
  96. if not await self._eval_conditions(event_data):
  97. return False
  98. # cancel running tasks since the transition will happen
  99. machine = event_data.machine
  100. model = event_data.model
  101. if model in machine.async_tasks and not machine.async_tasks[model].done():
  102. parent = machine.async_tasks[model]
  103. check = is_subtask.get()
  104. if parent != check:
  105. machine.async_tasks[model].cancel()
  106. else:
  107. current = asyncio.current_task()
  108. is_subtask.set(current)
  109. machine.async_tasks[model] = current
  110. await event_data.machine.callbacks(itertools.chain(event_data.machine.before_state_change, self.before), event_data)
  111. _LOGGER.debug("%sExecuted callback before transition.", event_data.machine.name)
  112. if self.dest: # if self.dest is None this is an internal transition with no actual state change
  113. await self._change_state(event_data)
  114. await event_data.machine.callbacks(itertools.chain(self.after, event_data.machine.after_state_change), event_data)
  115. _LOGGER.debug("%sExecuted callback after transition.", event_data.machine.name)
  116. return True
  117. async def _change_state(self, event_data):
  118. if hasattr(event_data.machine, "model_graphs"):
  119. graph = event_data.machine.model_graphs[event_data.model]
  120. graph.reset_styling()
  121. graph.set_previous_transition(self.source, self.dest)
  122. await event_data.machine.get_state(self.source).exit(event_data)
  123. event_data.machine.set_state(self.dest, event_data.model)
  124. event_data.update(event_data.model.state)
  125. await event_data.machine.get_state(self.dest).enter(event_data)
  126. class NestedAsyncTransition(AsyncTransition, NestedTransition):
  127. async def _change_state(self, event_data):
  128. if hasattr(event_data.machine, "model_graphs"):
  129. graph = event_data.machine.model_graphs[event_data.model]
  130. graph.reset_styling()
  131. graph.set_previous_transition(self.source, self.dest)
  132. state_tree, exit_partials, enter_partials = self._resolve_transition(event_data)
  133. for func in exit_partials:
  134. await func()
  135. self._update_model(event_data, state_tree)
  136. for func in enter_partials:
  137. await func()
  138. class AsyncEvent(Event):
  139. """ A collection of transitions assigned to the same trigger """
  140. async def trigger(self, model, *args, **kwargs):
  141. """ Serially execute all transitions that match the current state,
  142. halting as soon as one successfully completes. Note that `AsyncEvent` triggers must be awaited.
  143. Args:
  144. args and kwargs: Optional positional or named arguments that will
  145. be passed onto the EventData object, enabling arbitrary state
  146. information to be passed on to downstream triggered functions.
  147. Returns: boolean indicating whether or not a transition was
  148. successfully executed (True if successful, False if not).
  149. """
  150. func = partial(self._trigger, model, *args, **kwargs)
  151. t = asyncio.create_task(self.machine._process(func))
  152. try:
  153. return await t
  154. except asyncio.CancelledError:
  155. return False
  156. async def _trigger(self, model, *args, **kwargs):
  157. state = self.machine.get_state(model.state)
  158. if state.name not in self.transitions:
  159. msg = "%sCan't trigger event %s from state %s!" % (self.machine.name, self.name,
  160. state.name)
  161. ignore = state.ignore_invalid_triggers if state.ignore_invalid_triggers is not None \
  162. else self.machine.ignore_invalid_triggers
  163. if ignore:
  164. _LOGGER.warning(msg)
  165. return False
  166. else:
  167. raise MachineError(msg)
  168. event_data = EventData(state, self, self.machine, model, args=args, kwargs=kwargs)
  169. return await self._process(event_data)
  170. async def _process(self, event_data):
  171. await self.machine.callbacks(self.machine.prepare_event, event_data)
  172. _LOGGER.debug("%sExecuted machine preparation callbacks before conditions.", self.machine.name)
  173. try:
  174. for trans in self.transitions[event_data.state.name]:
  175. event_data.transition = trans
  176. if await trans.execute(event_data):
  177. event_data.result = True
  178. break
  179. except Exception as err:
  180. event_data.error = err
  181. raise
  182. finally:
  183. await self.machine.callbacks(self.machine.finalize_event, event_data)
  184. _LOGGER.debug("%sExecuted machine finalize callbacks", self.machine.name)
  185. return event_data.result
  186. class NestedAsyncEvent(NestedEvent):
  187. async def trigger(self, _model, _machine, *args, **kwargs):
  188. """ Serially execute all transitions that match the current state,
  189. halting as soon as one successfully completes. NOTE: This should only
  190. be called by HierarchicalMachine instances.
  191. Args:
  192. _model (object): model object to
  193. machine (HierarchicalMachine): Since NestedEvents can be used in multiple machine instances, this one
  194. will be used to determine the current state separator.
  195. args and kwargs: Optional positional or named arguments that will
  196. be passed onto the EventData object, enabling arbitrary state
  197. information to be passed on to downstream triggered functions.
  198. Returns: boolean indicating whether or not a transition was
  199. successfully executed (True if successful, False if not).
  200. """
  201. func = partial(self._trigger, _model, _machine, *args, **kwargs)
  202. t = asyncio.create_task(_machine._process(func))
  203. try:
  204. return await t
  205. except asyncio.CancelledError:
  206. return False
  207. async def _trigger(self, _model, _machine, *args, **kwargs):
  208. state_tree = _machine._build_state_tree(getattr(_model, _machine.model_attribute), _machine.state_cls.separator)
  209. state_tree = reduce(dict.get, _machine.get_global_name(join=False), state_tree)
  210. ordered_states = _resolve_order(state_tree)
  211. done = []
  212. res = None
  213. for state_path in ordered_states:
  214. state_name = _machine.state_cls.separator.join(state_path)
  215. if state_name not in done and state_name in self.transitions:
  216. state = _machine.get_state(state_name)
  217. event_data = EventData(state, self, _machine, _model, args=args, kwargs=kwargs)
  218. event_data.source_name = state_name
  219. event_data.source_path = copy.copy(state_path)
  220. res = await self._process(event_data)
  221. if res:
  222. elems = state_path
  223. while elems:
  224. done.append(_machine.state_cls.separator.join(elems))
  225. elems.pop()
  226. return res
  227. async def _process(self, event_data):
  228. machine = event_data.machine
  229. await machine.callbacks(event_data.machine.prepare_event, event_data)
  230. _LOGGER.debug("%sExecuted machine preparation callbacks before conditions.", machine.name)
  231. try:
  232. for trans in self.transitions[event_data.source_name]:
  233. event_data.transition = trans
  234. if await trans.execute(event_data):
  235. event_data.result = True
  236. break
  237. except Exception as err:
  238. event_data.error = err
  239. raise
  240. finally:
  241. await machine.callbacks(machine.finalize_event, event_data)
  242. _LOGGER.debug("%sExecuted machine finalize callbacks", machine.name)
  243. return event_data.result
  244. class AsyncMachine(Machine):
  245. """ Machine manages states, transitions and models. In case it is initialized without a specific model
  246. (or specifically no model), it will also act as a model itself. Machine takes also care of decorating
  247. models with conveniences functions related to added transitions and states during runtime.
  248. Attributes:
  249. states (OrderedDict): Collection of all registered states.
  250. events (dict): Collection of transitions ordered by trigger/event.
  251. models (list): List of models attached to the machine.
  252. initial (str): Name of the initial state for new models.
  253. prepare_event (list): Callbacks executed when an event is triggered.
  254. before_state_change (list): Callbacks executed after condition checks but before transition is conducted.
  255. Callbacks will be executed BEFORE the custom callbacks assigned to the transition.
  256. after_state_change (list): Callbacks executed after the transition has been conducted.
  257. Callbacks will be executed AFTER the custom callbacks assigned to the transition.
  258. finalize_event (list): Callbacks will be executed after all transitions callbacks have been executed.
  259. Callbacks mentioned here will also be called if a transition or condition check raised an error.
  260. queued (bool): Whether transitions in callbacks should be executed immediately (False) or sequentially.
  261. send_event (bool): When True, any arguments passed to trigger methods will be wrapped in an EventData
  262. object, allowing indirect and encapsulated access to data. When False, all positional and keyword
  263. arguments will be passed directly to all callback methods.
  264. auto_transitions (bool): When True (default), every state will automatically have an associated
  265. to_{state}() convenience trigger in the base model.
  266. ignore_invalid_triggers (bool): When True, any calls to trigger methods that are not valid for the
  267. present state (e.g., calling an a_to_b() trigger when the current state is c) will be silently
  268. ignored rather than raising an invalid transition exception.
  269. name (str): Name of the ``Machine`` instance mainly used for easier log message distinction.
  270. """
  271. state_cls = AsyncState
  272. transition_cls = AsyncTransition
  273. event_cls = AsyncEvent
  274. async_tasks = {}
  275. async def dispatch(self, trigger, *args, **kwargs): # ToDo: not tested
  276. """ Trigger an event on all models assigned to the machine.
  277. Args:
  278. trigger (str): Event name
  279. *args (list): List of arguments passed to the event trigger
  280. **kwargs (dict): Dictionary of keyword arguments passed to the event trigger
  281. Returns:
  282. bool The truth value of all triggers combined with AND
  283. """
  284. results = await asyncio.gather(*[getattr(model, trigger)(*args, **kwargs) for model in self.models])
  285. return all(results)
  286. async def callbacks(self, funcs, event_data):
  287. """ Triggers a list of callbacks """
  288. await asyncio.gather(*[event_data.machine.callback(func, event_data) for func in funcs])
  289. async def callback(self, func, event_data):
  290. """ Trigger a callback function with passed event_data parameters. In case func is a string,
  291. the callable will be resolved from the passed model in event_data. This function is not intended to
  292. be called directly but through state and transition callback definitions.
  293. Args:
  294. func (string, callable): The callback function.
  295. 1. First, if the func is callable, just call it
  296. 2. Second, we try to import string assuming it is a path to a func
  297. 3. Fallback to a model attribute
  298. event_data (EventData): An EventData instance to pass to the
  299. callback (if event sending is enabled) or to extract arguments
  300. from (if event sending is disabled).
  301. """
  302. func = self.resolve_callable(func, event_data)
  303. if self.send_event:
  304. if asyncio.iscoroutinefunction(func) or asyncio.iscoroutinefunction(getattr(func, 'func', None)):
  305. await func(event_data)
  306. else:
  307. func(event_data)
  308. else:
  309. if asyncio.iscoroutinefunction(func) or asyncio.iscoroutinefunction(getattr(func, 'func', None)):
  310. await func(*event_data.args, **event_data.kwargs)
  311. else:
  312. func(*event_data.args, **event_data.kwargs)
  313. async def _process(self, trigger):
  314. # default processing
  315. if not self.has_queue:
  316. if not self._transition_queue:
  317. # if trigger raises an Error, it has to be handled by the Machine.process caller
  318. return await trigger()
  319. else:
  320. raise MachineError("Attempt to process events synchronously while transition queue is not empty!")
  321. self._transition_queue.append(trigger)
  322. # another entry in the queue implies a running transition; skip immediate execution
  323. if len(self._transition_queue) > 1:
  324. return True
  325. # execute as long as transition queue is not empty ToDo: not tested!
  326. while self._transition_queue:
  327. try:
  328. await self._transition_queue[0]()
  329. self._transition_queue.popleft()
  330. except Exception:
  331. # if a transition raises an exception, clear queue and delegate exception handling
  332. self._transition_queue.clear()
  333. raise
  334. return True
  335. class HierarchicalAsyncMachine(HierarchicalMachine, AsyncMachine):
  336. state_cls = NestedAsyncState
  337. transition_cls = NestedAsyncTransition
  338. event_cls = NestedAsyncEvent
  339. async def trigger_event(self, _model, _trigger, *args, **kwargs):
  340. """ Processes events recursively and forwards arguments if suitable events are found.
  341. This function is usually bound to models with model and trigger arguments already
  342. resolved as a partial. Execution will halt when a nested transition has been executed
  343. successfully.
  344. Args:
  345. _model (object): targeted model
  346. _trigger (str): event name
  347. *args: positional parameters passed to the event and its callbacks
  348. **kwargs: keyword arguments passed to the event and its callbacks
  349. Returns:
  350. bool: whether a transition has been executed successfully
  351. Raises:
  352. MachineError: When no suitable transition could be found and ignore_invalid_trigger
  353. is not True. Note that a transition which is not executed due to conditions
  354. is still considered valid.
  355. """
  356. with self():
  357. res = await self._trigger_event(_model, _trigger, None, *args, **kwargs)
  358. return self._check_event_result(res, _model, _trigger)
  359. async def _trigger_event(self, _model, _trigger, _state_tree, *args, **kwargs):
  360. if _state_tree is None:
  361. _state_tree = self._build_state_tree(listify(getattr(_model, self.model_attribute)), self.state_cls.separator)
  362. res = {}
  363. for key, value in _state_tree.items():
  364. if value:
  365. with self(key):
  366. res[key] = await self._trigger_event(_model, _trigger, value, *args, **kwargs)
  367. if not res.get(key, None) and _trigger in self.events:
  368. res[key] = await self.events[_trigger].trigger(_model, self, *args, **kwargs)
  369. return None if not res or all([v is None for v in res.values()]) else any(res.values())