nesting.py 39 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831
  1. # -*- coding: utf-8 -*-
  2. from collections import OrderedDict, defaultdict
  3. import copy
  4. from functools import partial, reduce
  5. import inspect
  6. import logging
  7. from six import string_types
  8. from ..core import State, Machine, Transition, Event, listify, MachineError, Enum, EnumMeta, EventData
  9. _LOGGER = logging.getLogger(__name__)
  10. _LOGGER.addHandler(logging.NullHandler())
  11. # this is a workaround for dill issues when partials and super is used in conjunction
  12. # without it, Python 3.0 - 3.3 will not support pickling
  13. # https://github.com/pytransitions/transitions/issues/236
  14. _super = super
  15. # converts a hierarchical tree into a list of current states
  16. def _build_state_list(state_tree, separator, prefix=[]):
  17. res = []
  18. for key, value in state_tree.items():
  19. if value:
  20. res.append(_build_state_list(value, separator, prefix=prefix + [key]))
  21. else:
  22. res.append(separator.join(prefix + [key]))
  23. return res if len(res) > 1 else res[0]
  24. # custom breadth-first tree exploration
  25. # makes sure that ALL children are evaluated before parents in parallel states
  26. def _resolve_order(state_tree):
  27. s = state_tree
  28. q = []
  29. res = []
  30. p = []
  31. while True:
  32. for k in reversed(list(s.keys())):
  33. pk = p + [k]
  34. res.append(pk)
  35. if s[k]:
  36. q.append((pk, s[k]))
  37. if not q:
  38. break
  39. p, s = q.pop(0)
  40. return reversed(res)
  41. class FunctionWrapper(object):
  42. """ A wrapper to enable transitions' convenience function to_<state> for nested states.
  43. This allows to call model.to_A.s1.C() in case a custom separator has been chosen."""
  44. def __init__(self, func, path):
  45. """
  46. Args:
  47. func: Function to be called at the end of the path.
  48. path: If path is an empty string, assign function
  49. """
  50. if path:
  51. self.add(func, path)
  52. self._func = None
  53. else:
  54. self._func = func
  55. def add(self, func, path):
  56. """ Assigns a `FunctionWrapper` as an attribute named like the next segment of the substates
  57. path.
  58. Args:
  59. func (callable): Function to be called at the end of the path.
  60. path (string): Remaining segment of the substate path.
  61. """
  62. if path:
  63. name = path[0]
  64. if name[0].isdigit():
  65. name = 's' + name
  66. if hasattr(self, name):
  67. getattr(self, name).add(func, path[1:])
  68. else:
  69. setattr(self, name, FunctionWrapper(func, path[1:]))
  70. else:
  71. self._func = func
  72. def __call__(self, *args, **kwargs):
  73. return self._func(*args, **kwargs)
  74. class NestedEvent(Event):
  75. """ An event type to work with nested states.
  76. This subclass is NOT compatible with simple Machine instances.
  77. """
  78. def trigger(self, _model, _machine, *args, **kwargs):
  79. """ Serially execute all transitions that match the current state,
  80. halting as soon as one successfully completes. NOTE: This should only
  81. be called by HierarchicalMachine instances.
  82. Args:
  83. _model (object): model object to
  84. _machine (HierarchicalMachine): Since NestedEvents can be used in multiple machine instances, this one
  85. will be used to determine the current state separator.
  86. args and kwargs: Optional positional or named arguments that will
  87. be passed onto the EventData object, enabling arbitrary state
  88. information to be passed on to downstream triggered functions.
  89. Returns: boolean indicating whether or not a transition was
  90. successfully executed (True if successful, False if not).
  91. """
  92. func = partial(self._trigger, _model, _machine, *args, **kwargs)
  93. # pylint: disable=protected-access
  94. # noinspection PyProtectedMember
  95. # Machine._process should not be called somewhere else. That's why it should not be exposed
  96. # to Machine users.
  97. return _machine._process(func)
  98. def _trigger(self, _model, _machine, *args, **kwargs):
  99. state_tree = _machine._build_state_tree(getattr(_model, _machine.model_attribute), _machine.state_cls.separator)
  100. state_tree = reduce(dict.get, _machine.get_global_name(join=False), state_tree)
  101. ordered_states = _resolve_order(state_tree)
  102. done = []
  103. res = None
  104. for state_path in ordered_states:
  105. state_name = _machine.state_cls.separator.join(state_path)
  106. if state_name not in done and state_name in self.transitions:
  107. state = _machine.get_state(state_name)
  108. event_data = EventData(state, self, _machine, _model, args=args, kwargs=kwargs)
  109. event_data.source_name = state_name
  110. event_data.source_path = copy.copy(state_path)
  111. res = self._process(event_data)
  112. if res:
  113. elems = state_path
  114. while elems:
  115. done.append(_machine.state_cls.separator.join(elems))
  116. elems.pop()
  117. return res
  118. def _process(self, event_data):
  119. machine = event_data.machine
  120. machine.callbacks(event_data.machine.prepare_event, event_data)
  121. _LOGGER.debug("%sExecuted machine preparation callbacks before conditions.", machine.name)
  122. try:
  123. for trans in self.transitions[event_data.source_name]:
  124. event_data.transition = trans
  125. if trans.execute(event_data):
  126. event_data.result = True
  127. break
  128. except Exception as err:
  129. event_data.error = err
  130. raise
  131. finally:
  132. machine.callbacks(machine.finalize_event, event_data)
  133. _LOGGER.debug("%sExecuted machine finalize callbacks", machine.name)
  134. return event_data.result
  135. class NestedState(State):
  136. """ A state which allows substates.
  137. Attributes:
  138. states (OrderedDict): A list of substates of the current state.
  139. events (dict): A list of events defined for the nested state.
  140. initial (str): Name of a child which should be entered when the state is entered.
  141. exit_stack (defaultdict): A list of currently active substates
  142. """
  143. separator = '_'
  144. u""" Separator between the names of parent and child states. In case '_' is required for
  145. naming state, this value can be set to other values such as '.' or even unicode characters
  146. such as '↦' (limited to Python 3 though).
  147. """
  148. def __init__(self, name, on_enter=None, on_exit=None, ignore_invalid_triggers=None, initial=None):
  149. _super(NestedState, self).__init__(name=name, on_enter=on_enter, on_exit=on_exit,
  150. ignore_invalid_triggers=ignore_invalid_triggers)
  151. self.initial = initial
  152. self.events = {}
  153. self.states = OrderedDict()
  154. self._scope = []
  155. def add_substate(self, state):
  156. """ Adds a state as a substate.
  157. Args:
  158. state (NestedState): State to add to the current state.
  159. """
  160. self.add_substates(state)
  161. def add_substates(self, states):
  162. """ Adds a list of states to the current state.
  163. Args:
  164. states (list): List of states to add to the current state.
  165. """
  166. for state in listify(states):
  167. self.states[state.name] = state
  168. def scoped_enter(self, event_data, scope=[]):
  169. self._scope = scope
  170. self.enter(event_data)
  171. self._scope = []
  172. def scoped_exit(self, event_data, scope=[]):
  173. self._scope = scope
  174. self.exit(event_data)
  175. self._scope = []
  176. @property
  177. def name(self):
  178. return self.separator.join(self._scope + [_super(NestedState, self).name])
  179. class NestedTransition(Transition):
  180. """ A transition which handles entering and leaving nested states.
  181. Attributes:
  182. source (str): Source state of the transition.
  183. dest (str): Destination state of the transition.
  184. prepare (list): Callbacks executed before conditions checks.
  185. conditions (list): Callbacks evaluated to determine if
  186. the transition should be executed.
  187. before (list): Callbacks executed before the transition is executed
  188. but only if condition checks have been successful.
  189. after (list): Callbacks executed after the transition is executed
  190. but only if condition checks have been successful.
  191. """
  192. def _resolve_transition(self, event_data):
  193. machine = event_data.machine
  194. dst_name_path = machine.get_local_name(self.dest, join=False)
  195. _ = machine.get_state(dst_name_path)
  196. model_states = listify(getattr(event_data.model, machine.model_attribute))
  197. state_tree = machine._build_state_tree(model_states, machine.state_cls.separator)
  198. scope = machine.get_global_name(join=False)
  199. src_name_path = event_data.source_path
  200. if src_name_path == dst_name_path:
  201. root = src_name_path[:-1] # exit and enter the same state
  202. else:
  203. root = []
  204. while dst_name_path and src_name_path and src_name_path[0] == dst_name_path[0]:
  205. root.append(src_name_path.pop(0))
  206. dst_name_path.pop(0)
  207. scoped_tree = reduce(dict.get, scope + root, state_tree)
  208. exit_partials = []
  209. if src_name_path:
  210. for state_name in _resolve_order(scoped_tree):
  211. cb = partial(machine.get_state(root + state_name).scoped_exit,
  212. event_data,
  213. scope + root + state_name[:-1])
  214. exit_partials.append(cb)
  215. if dst_name_path:
  216. new_states, enter_partials = self._enter_nested(root, dst_name_path, scope + root, event_data)
  217. else:
  218. new_states, enter_partials = {}, []
  219. for key in scoped_tree:
  220. del scoped_tree[key]
  221. for new_key, value in new_states.items():
  222. scoped_tree[new_key] = value
  223. break
  224. return state_tree, exit_partials, enter_partials
  225. def _change_state(self, event_data):
  226. state_tree, exit_partials, enter_partials = self._resolve_transition(event_data)
  227. for func in exit_partials:
  228. func()
  229. self._update_model(event_data, state_tree)
  230. for func in enter_partials:
  231. func()
  232. def _enter_nested(self, root, dest, prefix_path, event_data):
  233. if root:
  234. state_name = root.pop(0)
  235. with event_data.machine(state_name):
  236. return self._enter_nested(root, dest, prefix_path, event_data)
  237. elif dest:
  238. new_states = OrderedDict()
  239. state_name = dest.pop(0)
  240. with event_data.machine(state_name):
  241. new_states[state_name], new_enter = self._enter_nested([], dest, prefix_path + [state_name], event_data)
  242. enter_partials = [partial(event_data.machine.scoped.scoped_enter, event_data, prefix_path)] + new_enter
  243. return new_states, enter_partials
  244. elif event_data.machine.scoped.initial:
  245. new_states = OrderedDict()
  246. enter_partials = []
  247. q = []
  248. prefix = prefix_path
  249. scoped_tree = new_states
  250. initial_states = [event_data.machine.scoped.states[i] for i in listify(event_data.machine.scoped.initial)]
  251. while True:
  252. event_data.scope = prefix
  253. for state in initial_states:
  254. enter_partials.append(partial(state.scoped_enter, event_data, prefix))
  255. scoped_tree[state.name] = OrderedDict()
  256. if state.initial:
  257. q.append((scoped_tree[state.name], prefix + [state.name],
  258. [state.states[i] for i in listify(state.initial)]))
  259. if not q:
  260. break
  261. scoped_tree, prefix, initial_states = q.pop(0)
  262. return new_states, enter_partials
  263. else:
  264. return {}, []
  265. @staticmethod
  266. def _update_model(event_data, tree):
  267. model_states = _build_state_list(tree, event_data.machine.state_cls.separator)
  268. with event_data.machine():
  269. event_data.machine.set_state(model_states, event_data.model)
  270. states = event_data.machine.get_states(listify(model_states))
  271. event_data.state = states[0] if len(states) == 1 else states
  272. # Prevent deep copying of callback lists since these include either references to callable or
  273. # strings. Deep copying a method reference would lead to the creation of an entire new (model) object
  274. # (see https://github.com/pytransitions/transitions/issues/248)
  275. def __deepcopy__(self, memo):
  276. cls = self.__class__
  277. result = cls.__new__(cls)
  278. memo[id(self)] = result
  279. for key, value in self.__dict__.items():
  280. if key in cls.dynamic_methods:
  281. setattr(result, key, copy.copy(value))
  282. else:
  283. setattr(result, key, copy.deepcopy(value, memo))
  284. return result
  285. class HierarchicalMachine(Machine):
  286. """ Extends transitions.core.Machine by capabilities to handle nested states.
  287. A hierarchical machine REQUIRES NestedStates (or any subclass of it) to operate.
  288. """
  289. state_cls = NestedState
  290. transition_cls = NestedTransition
  291. event_cls = NestedEvent
  292. def __init__(self, *args, **kwargs):
  293. self._stack = []
  294. self.scoped = self
  295. _super(HierarchicalMachine, self).__init__(*args, **kwargs)
  296. def __call__(self, to_scope=None):
  297. if isinstance(to_scope, string_types):
  298. state_name = to_scope.split(self.state_cls.separator)[0]
  299. state = self.states[state_name]
  300. to_scope = (state, state.states, state.events)
  301. elif isinstance(to_scope, Enum):
  302. state = self.states[to_scope.name]
  303. to_scope = (state, state.states, state.events)
  304. elif to_scope is None:
  305. if self._stack:
  306. to_scope = self._stack[0]
  307. else:
  308. to_scope = (self, self.states, self.events)
  309. self._next_scope = to_scope
  310. return self
  311. def __enter__(self):
  312. self._stack.append((self.scoped, self.states, self.events))
  313. self.scoped, self.states, self.events = self._next_scope
  314. self._next_scope = None
  315. def __exit__(self, exc_type, exc_val, exc_tb):
  316. self.scoped, self.states, self.events = self._stack.pop()
  317. def add_model(self, model, initial=None):
  318. """ Extends transitions.core.Machine.add_model by applying a custom 'to' function to
  319. the added model.
  320. """
  321. models = [mod if mod != 'self' else self for mod in listify(model)]
  322. _super(HierarchicalMachine, self).add_model(models, initial=initial)
  323. initial_name = getattr(models[0], self.model_attribute)
  324. if hasattr(initial_name, 'name'):
  325. initial_name = initial_name.name
  326. initial_states = self._resolve_initial(models, initial_name.split(self.state_cls.separator))
  327. for mod in models:
  328. self.set_state(initial_states, mod)
  329. if hasattr(mod, 'to'):
  330. _LOGGER.warning("%sModel already has a 'to'-method. It will NOT "
  331. "be overwritten by NestedMachine", self.name)
  332. else:
  333. to_func = partial(self.to_state, mod)
  334. setattr(mod, 'to', to_func)
  335. def add_ordered_transitions(self, states=None, trigger='next_state',
  336. loop=True, loop_includes_initial=True,
  337. conditions=None, unless=None, before=None,
  338. after=None, prepare=None, **kwargs):
  339. if states is None:
  340. states = self.get_nested_state_names()
  341. _super(HierarchicalMachine, self).add_ordered_transitions(states=states, trigger=trigger, loop=loop,
  342. loop_includes_initial=loop_includes_initial,
  343. conditions=conditions,
  344. unless=unless, before=before, after=after,
  345. prepare=prepare, **kwargs)
  346. def add_states(self, states, on_enter=None, on_exit=None, ignore_invalid_triggers=None, **kwargs):
  347. remap = kwargs.pop('remap', None)
  348. for state in listify(states):
  349. if isinstance(state, Enum) and isinstance(state.value, EnumMeta):
  350. state = {'name': state.name, 'children': state.value}
  351. if isinstance(state, string_types):
  352. if remap is not None and state in remap:
  353. return
  354. domains = state.split(self.state_cls.separator, 1)
  355. if len(domains) > 1:
  356. try:
  357. self.get_state(domains[0])
  358. except ValueError:
  359. self.add_state(domains[0], on_enter=on_enter, on_exit=on_exit, ignore_invalid_triggers=ignore_invalid_triggers, **kwargs)
  360. with self(domains[0]):
  361. self.add_states(domains[1], on_enter=on_enter, on_exit=on_exit, ignore_invalid_triggers=ignore_invalid_triggers, **kwargs)
  362. else:
  363. if state in self.states:
  364. raise ValueError("State {0} cannot be added since it already exists.".format(state))
  365. new_state = self._create_state(state)
  366. self.states[new_state.name] = new_state
  367. self._init_state(new_state)
  368. elif isinstance(state, Enum):
  369. if remap is not None and state.name in remap:
  370. return
  371. new_state = self._create_state(state)
  372. if state.name in self.states:
  373. raise ValueError("State {0} cannot be added since it already exists.".format(state.name))
  374. self.states[new_state.name] = new_state
  375. self._init_state(new_state)
  376. elif isinstance(state, dict):
  377. if remap is not None and state['name'] in remap:
  378. return
  379. state = state.copy() # prevent messing with the initially passed dict
  380. remap = state.pop('remap', None)
  381. state_children = state.pop('children', [])
  382. state_parallel = state.pop('parallel', [])
  383. transitions = state.pop('transitions', [])
  384. new_state = self._create_state(**state)
  385. self.states[new_state.name] = new_state
  386. self._init_state(new_state)
  387. remapped_transitions = []
  388. with self(new_state.name):
  389. if state_parallel:
  390. self.add_states(state_parallel, remap=remap, **kwargs)
  391. new_state.initial = [s if isinstance(s, string_types) else s['name'] for s in state_parallel]
  392. else:
  393. self.add_states(state_children, remap=remap, **kwargs)
  394. if remap is not None:
  395. drop_event = []
  396. for evt in self.events.values():
  397. self.events[evt.name] = copy.copy(evt)
  398. for trigger, event in self.events.items():
  399. drop_source = []
  400. event.transitions = copy.deepcopy(event.transitions)
  401. for source_name, trans_source in event.transitions.items():
  402. if source_name in remap:
  403. drop_source.append(source_name)
  404. continue
  405. drop_trans = []
  406. for trans in trans_source:
  407. if trans.dest in remap:
  408. conditions, unless = [], []
  409. for cond in trans.conditions:
  410. # split a list in two lists based on the accessors (cond.target) truth value
  411. (unless, conditions)[cond.target].append(cond.func)
  412. remapped_transitions.append({
  413. 'trigger': trigger,
  414. 'source': new_state.name + self.state_cls.separator + trans.source,
  415. 'dest': remap[trans.dest],
  416. 'conditions': conditions,
  417. 'unless': unless,
  418. 'prepare': trans.prepare,
  419. 'before': trans.before,
  420. 'after': trans.after})
  421. drop_trans.append(trans)
  422. for t in drop_trans:
  423. trans_source.remove(t)
  424. if not trans_source:
  425. drop_source.append(source_name)
  426. for s in drop_source:
  427. del event.transitions[s]
  428. if not event.transitions:
  429. drop_event.append(trigger)
  430. for e in drop_event:
  431. del self.events[e]
  432. if transitions:
  433. self.add_transitions(transitions)
  434. self.add_transitions(remapped_transitions)
  435. elif isinstance(state, NestedState):
  436. if state.name in self.states:
  437. raise ValueError("State {0} cannot be added since it already exists.".format(state.name))
  438. self.states[state.name] = state
  439. self._init_state(state)
  440. elif isinstance(state, Machine):
  441. new_states = [s for s in state.states.values() if remap is None or s not in remap]
  442. self.add_states(new_states)
  443. for ev in state.events.values():
  444. self.events[ev.name] = ev
  445. if self.scoped.initial is None:
  446. self.scoped.initial = state.initial
  447. else:
  448. raise ValueError("Cannot add state of type {0}.".format(type(state).__name__))
  449. def add_transition(self, trigger, source, dest, conditions=None,
  450. unless=None, before=None, after=None, prepare=None, **kwargs):
  451. if source != self.wildcard_all:
  452. source = [self.state_cls.separator.join(self._get_enum_path(s)) if isinstance(s, Enum) else s
  453. for s in listify(source)]
  454. if dest != self.wildcard_same:
  455. dest = self.state_cls.separator.join(self._get_enum_path(dest)) if isinstance(dest, Enum) else dest
  456. _super(HierarchicalMachine, self).add_transition(trigger, source, dest, conditions,
  457. unless, before, after, prepare, **kwargs)
  458. def get_global_name(self, state=None, join=True):
  459. local_stack = [s[0] for s in self._stack] + [self.scoped]
  460. local_stack_start = len(local_stack) - local_stack[::-1].index(self)
  461. domains = [s.name for s in local_stack[local_stack_start:]]
  462. if state:
  463. state_name = state.name if hasattr(state, 'name') else state
  464. if state_name in self.states:
  465. domains.append(state_name)
  466. else:
  467. raise ValueError("State '{0}' not found in local states.".format(state))
  468. return self.state_cls.separator.join(domains) if join else domains
  469. def get_local_name(self, state_name, join=True):
  470. state_name = state_name.split(self.state_cls.separator)
  471. local_stack = [s[0] for s in self._stack] + [self.scoped]
  472. local_stack_start = len(local_stack) - local_stack[::-1].index(self)
  473. domains = [s.name for s in local_stack[local_stack_start:]]
  474. if domains and state_name and state_name[0] != domains[0]:
  475. return self.state_cls.separator.join(state_name) if join else state_name
  476. return self.state_cls.separator.join(state_name) if join else state_name
  477. def get_nested_state_names(self):
  478. ordered_states = []
  479. for state in self.states.values():
  480. ordered_states.append(self.get_global_name(state))
  481. with self(state.name):
  482. ordered_states.extend(self.get_nested_state_names())
  483. return ordered_states
  484. def get_nested_triggers(self, dest=None):
  485. if dest:
  486. triggers = _super(HierarchicalMachine, self).get_triggers(dest)
  487. else:
  488. triggers = list(self.events.keys())
  489. for state in self.states.values():
  490. with self(state.name):
  491. triggers.extend(self.get_nested_triggers())
  492. return triggers
  493. def get_state(self, state, hint=None):
  494. """ Return the State instance with the passed name. """
  495. if isinstance(state, Enum):
  496. state = self._get_enum_path(state)
  497. elif isinstance(state, string_types):
  498. state = state.split(self.state_cls.separator)
  499. if not hint:
  500. state = copy.copy(state)
  501. hint = copy.copy(state)
  502. if len(state) > 1:
  503. child = state.pop(0)
  504. try:
  505. with self(child):
  506. return self.get_state(state, hint)
  507. except (KeyError, ValueError):
  508. try:
  509. with self():
  510. state = self
  511. for elem in hint:
  512. state = state.states[elem]
  513. return state
  514. except KeyError:
  515. raise ValueError("State '%s' is not a registered state." % self.state_cls.separator.join(hint))
  516. elif state[0] not in self.states:
  517. raise ValueError("State '%s' is not a registered state." % state)
  518. return self.states[state[0]]
  519. def get_states(self, states):
  520. res = []
  521. for state in states:
  522. if isinstance(state, list):
  523. res.append(self.get_states(state))
  524. else:
  525. res.append(self.get_state(state))
  526. return res
  527. def get_triggers(self, *args):
  528. """ Extends transitions.core.Machine.get_triggers to also include parent state triggers. """
  529. # add parents to state set
  530. triggers = []
  531. with self():
  532. for state_name in args:
  533. state_path = state_name.split(self.state_cls.separator)
  534. root = state_path[0]
  535. while state_path:
  536. triggers.extend(_super(HierarchicalMachine, self).get_triggers(self.state_cls.separator.join(state_path)))
  537. with self(root):
  538. triggers.extend(self.get_nested_triggers(self.state_cls.separator.join(state_path)))
  539. state_path.pop()
  540. return triggers
  541. def is_state(self, state_name, model, allow_substates=False):
  542. current_name = getattr(model, self.model_attribute)
  543. if allow_substates:
  544. return current_name.startswith(state_name.name if hasattr(state_name, 'name') else state_name)
  545. return current_name == state_name
  546. def on_enter(self, state_name, callback):
  547. """ Helper function to add callbacks to states in case a custom state separator is used.
  548. Args:
  549. state_name (str): Name of the state
  550. callback (str or callable): Function to be called. Strings will be resolved to model functions.
  551. """
  552. self.get_state(state_name).add_callback('enter', callback)
  553. def on_exit(self, state_name, callback):
  554. """ Helper function to add callbacks to states in case a custom state separator is used.
  555. Args:
  556. state_name (str): Name of the state
  557. callback (str or callable): Function to be called. Strings will be resolved to model functions.
  558. """
  559. self.get_state(state_name).add_callback('exit', callback)
  560. def set_state(self, states, model=None):
  561. """ Set the current state.
  562. Args:
  563. states (list of str or Enum or State): value of state(s) to be set
  564. model (optional[object]): targeted model; if not set, all models will be set to 'state'
  565. """
  566. values = [self._set_state(value) for value in listify(states)]
  567. models = self.models if model is None else listify(model)
  568. for mod in models:
  569. setattr(mod, self.model_attribute, values if len(values) > 1 else values[0])
  570. def to_state(self, model, state_name, *args, **kwargs):
  571. """ Helper function to add go to states in case a custom state separator is used.
  572. Args:
  573. model (class): The model that should be used.
  574. state_name (str): Name of the destination state.
  575. """
  576. current_state = getattr(model, self.model_attribute)
  577. if isinstance(current_state, list):
  578. raise MachineError("Cannot use 'to_state' from parallel state")
  579. event = EventData(self.get_state(current_state), Event('to', self), self,
  580. model, args=args, kwargs=kwargs)
  581. event.source_name = current_state
  582. event.source_path = current_state.split(self.state_cls.separator)
  583. self._create_transition(current_state, state_name).execute(event)
  584. def trigger_event(self, _model, _trigger, *args, **kwargs):
  585. """ Processes events recursively and forwards arguments if suitable events are found.
  586. This function is usually bound to models with model and trigger arguments already
  587. resolved as a partial. Execution will halt when a nested transition has been executed
  588. successfully.
  589. Args:
  590. _model (object): targeted model
  591. _trigger (str): event name
  592. *args: positional parameters passed to the event and its callbacks
  593. **kwargs: keyword arguments passed to the event and its callbacks
  594. Returns:
  595. bool: whether a transition has been executed successfully
  596. Raises:
  597. MachineError: When no suitable transition could be found and ignore_invalid_trigger
  598. is not True. Note that a transition which is not executed due to conditions
  599. is still considered valid.
  600. """
  601. with self():
  602. res = self._trigger_event(_model, _trigger, None, *args, **kwargs)
  603. return self._check_event_result(res, _model, _trigger)
  # Recursively attach per-state conveniences for `state` and all of its substates to `model`:
  # an `is_<global name>` partial of is_state, bound `<dynamic method>_<global name>` model
  # methods registered as state callbacks, and trigger functions for the events of the
  # state's scope. Only done when the separator is '_' or does not occur in the global name
  # (presumably so that the generated attribute names stay usable -- confirm).
  604. def _add_model_to_state(self, state, model):
  605. name = self.get_global_name(state)
  606. if self.state_cls.separator == '_' or self.state_cls.separator not in name:
  # Enum-backed states are checked against their Enum value, all others against the global name.
  607. value = state.value if isinstance(state.value, Enum) else name
  608. self._checked_assignment(model, 'is_%s' % name, partial(self.is_state, value, model))
  609. # Add dynamic method callbacks (enter/exit) if there are existing bound methods in the model
  610. # except if they are already mentioned in 'on_enter/exit' of the defined state
  611. for callback in self.state_cls.dynamic_methods:
  612. method = "{0}_{1}".format(callback, name)
  613. if hasattr(model, method) and inspect.ismethod(getattr(model, method)) and \
  614. method not in getattr(state, callback):
  615. state.add_callback(callback[3:], method)
  # Descend into the state's scope: its events and substates become visible there.
  616. with self(state.name):
  617. for event in self.events.values():
  618. if not hasattr(model, event.name):
  619. self._add_trigger_to_model(event.name, model)
  # NOTE(review): the loop variable shadows the `state` parameter.
  620. for state in self.states.values():
  621. self._add_model_to_state(state, model)
  622. def _add_trigger_to_model(self, trigger, model):
  623. trig_func = partial(self.trigger_event, model, trigger)
  624. # FunctionWrappers are only necessary if a custom separator is used
  625. if trigger.startswith('to_') and self.state_cls.separator != '_':
  626. path = trigger[3:].split(self.state_cls.separator)
  627. if hasattr(model, 'to_' + path[0]):
  628. # add path to existing function wrapper
  629. getattr(model, 'to_' + path[0]).add(trig_func, path[1:])
  630. else:
  631. # create a new function wrapper
  632. self._checked_assignment(model, 'to_' + path[0], FunctionWrapper(trig_func, path[1:]))
  633. else:
  634. self._checked_assignment(model, trigger, trig_func)
  635. # converts a list of current states into a hierarchical state tree
  636. def _build_state_tree(self, model_states, separator, tree=None):
  637. tree = tree if tree is not None else OrderedDict()
  638. if isinstance(model_states, list):
  639. for state in model_states:
  640. _ = self._build_state_tree(state, separator, tree)
  641. else:
  642. tmp = tree
  643. if isinstance(model_states, (Enum, EnumMeta)):
  644. with self():
  645. path = self._get_enum_path(model_states)
  646. else:
  647. path = model_states.split(separator)
  648. for elem in path:
  649. tmp = tmp.setdefault(elem.name if hasattr(elem, 'name') else elem, OrderedDict())
  650. return tree
  651. def _get_enum_path(self, enum_state, prefix=[]):
  652. if enum_state.name in self.states and self.states[enum_state.name].value == enum_state:
  653. return prefix + [enum_state.name]
  654. for name in self.states:
  655. with self(name):
  656. res = self._get_enum_path(enum_state, prefix=prefix + [name])
  657. if res:
  658. return res
  659. return []
  660. def _check_event_result(self, res, model, trigger):
  661. if res is None:
  662. state_name = getattr(model, self.model_attribute)
  663. msg = "%sCan't trigger event %s from state %s!" % (self.name, trigger, state_name)
  664. state = self.get_state(state_name)
  665. ignore = state.ignore_invalid_triggers if state.ignore_invalid_triggers is not None \
  666. else self.ignore_invalid_triggers
  667. if ignore:
  668. _LOGGER.warning(msg)
  669. res = False
  670. else:
  671. raise MachineError(msg)
  672. return res
  673. def _get_trigger(self, model, trigger_name, *args, **kwargs):
  674. """Convenience function added to the model to trigger events by name.
  675. Args:
  676. model (object): Model with assigned event trigger.
  677. trigger_name (str): Name of the trigger to be called.
  678. *args: Variable length argument list which is passed to the triggered event.
  679. **kwargs: Arbitrary keyword arguments which is passed to the triggered event.
  680. Returns:
  681. bool: True if a transitions has been conducted or the trigger event has been queued.
  682. """
  683. try:
  684. return self.trigger_event(model, trigger_name, *args, **kwargs)
  685. except MachineError:
  686. raise AttributeError("Do not know event named '%s'." % trigger_name)
  687. def _has_state(self, state, raise_error=False):
  688. """ This function
  689. Args:
  690. state (NestedState): state to be tested
  691. raise_error (bool): whether ValueError should be raised when the state
  692. is not registered
  693. Returns:
  694. bool: Whether state is registered in the machine
  695. Raises:
  696. ValueError: When raise_error is True and state is not registered
  697. """
  698. found = _super(HierarchicalMachine, self)._has_state(state)
  699. if not found:
  700. for a_state in self.states:
  701. with self(a_state):
  702. if self._has_state(state):
  703. return True
  704. if not found and raise_error:
  705. msg = 'State %s has not been added to the machine' % (state.name if hasattr(state, 'name') else state)
  706. raise ValueError(msg)
  707. return found
  708. def _init_state(self, state):
  709. for model in self.models:
  710. self._add_model_to_state(state, model)
  711. if self.auto_transitions:
  712. state_name = self.get_global_name(state.name)
  713. parent = state_name.split(self.state_cls.separator, 1)
  714. with self():
  715. for a_state in self.get_nested_state_names():
  716. if a_state == parent[0]:
  717. self.add_transition('to_%s' % state_name, self.wildcard_all, state_name)
  718. elif len(parent) == 1:
  719. self.add_transition('to_%s' % a_state, state_name, a_state)
  720. with self(state.name):
  721. for substate in self.states.values():
  722. self._init_state(substate)
  723. def _resolve_initial(self, models, state_name_path, prefix=[]):
  724. if state_name_path:
  725. state_name = state_name_path.pop(0)
  726. with self(state_name):
  727. return self._resolve_initial(models, state_name_path, prefix=prefix + [state_name])
  728. if self.scoped.initial:
  729. entered_states = []
  730. for initial_state_name in listify(self.scoped.initial):
  731. with self(initial_state_name):
  732. entered_states.append(self._resolve_initial(models, [], prefix=prefix + [self.scoped.name]))
  733. return entered_states if len(entered_states) > 1 else entered_states[0]
  734. return self.state_cls.separator.join(prefix)
  735. def _set_state(self, state_name):
  736. if isinstance(state_name, list):
  737. return [self._set_state(value) for value in state_name]
  738. else:
  739. a_state = self.get_state(state_name)
  740. return a_state.value if isinstance(a_state.value, Enum) else state_name
  741. def _trigger_event(self, _model, _trigger, _state_tree, *args, **kwargs):
  742. if _state_tree is None:
  743. _state_tree = self._build_state_tree(listify(getattr(_model, self.model_attribute)),
  744. self.state_cls.separator)
  745. res = {}
  746. for key, value in _state_tree.items():
  747. if value:
  748. with self(key):
  749. res[key] = self._trigger_event(_model, _trigger, value, *args, **kwargs)
  750. if not res.get(key, None) and _trigger in self.events:
  751. res[key] = self.events[_trigger].trigger(_model, self, *args, **kwargs)
  752. return None if not res or all([v is None for v in res.values()]) else any(res.values())