nesting_legacy.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485
  1. # -*- coding: utf-8 -*-
  2. """
  3. transitions.extensions.nesting
  4. ------------------------------
  5. Adds the capability to work with nested states also known as hierarchical state machines.
  6. """
  7. from copy import copy, deepcopy
  8. from functools import partial
  9. import logging
  10. from six import string_types
  11. from ..core import Machine, Transition, State, Event, listify, MachineError, EventData, Enum
  12. from .nesting import FunctionWrapper
  13. _LOGGER = logging.getLogger(__name__)
  14. _LOGGER.addHandler(logging.NullHandler())
  15. # This is a workaround for dill issues when partials and super is used in conjunction
  16. # without it, Python 3.0 - 3.3 will not support pickling
  17. # https://github.com/pytransitions/transitions/issues/236
  18. _super = super
  19. class NestedState(State):
  20. """ A state which allows substates.
  21. Attributes:
  22. parent (NestedState): The parent of the current state.
  23. children (list): A list of child states of the current state.
  24. """
  25. separator = '_'
  26. u""" Separator between the names of parent and child states. In case '_' is required for
  27. naming state, this value can be set to other values such as '.' or even unicode characters
  28. such as '↦' (limited to Python 3 though).
  29. """
  30. def __init__(self, name, on_enter=None, on_exit=None, ignore_invalid_triggers=None, parent=None, initial=None):
  31. if parent is not None and isinstance(name, Enum):
  32. raise AttributeError("NestedState does not support nested enumerations.")
  33. self._initial = initial
  34. self._parent = None
  35. self.parent = parent
  36. _super(NestedState, self).__init__(name=name, on_enter=on_enter, on_exit=on_exit,
  37. ignore_invalid_triggers=ignore_invalid_triggers)
  38. self.children = []
  39. @property
  40. def parent(self):
  41. """ The parent state of this state. """
  42. return self._parent
  43. @parent.setter
  44. def parent(self, value):
  45. if value is not None:
  46. self._parent = value
  47. self._parent.children.append(self)
  48. @property
  49. def initial(self):
  50. """ When this state is entered it will automatically enter
  51. the child with this name if not None. """
  52. return self.name + self.separator + self._initial if self._initial else self._initial
  53. @initial.setter
  54. def initial(self, value):
  55. self._initial = value
  56. @property
  57. def level(self):
  58. """ Tracks how deeply nested this state is. This property is calculated from
  59. the state's parent (+1) or 0 when there is no parent. """
  60. return self.parent.level + 1 if self.parent is not None else 0
  61. @property
  62. def name(self):
  63. """ The computed name of this state. """
  64. if self.parent:
  65. return self.parent.name + self.separator + _super(NestedState, self).name
  66. return _super(NestedState, self).name
  67. @name.setter
  68. def name(self, value):
  69. self._name = value
  70. @property
  71. def value(self):
  72. return self.name if isinstance(self._name, string_types) else _super(NestedState, self).value
  73. def is_substate_of(self, state_name):
  74. """Check whether this state is a substate of a state named `state_name`
  75. Args:
  76. state_name (str): Name of the parent state to be checked
  77. Returns: bool True when `state_name` is a parent of this state
  78. """
  79. temp_state = self
  80. while not temp_state.value == state_name and temp_state.level > 0:
  81. temp_state = temp_state.parent
  82. return temp_state.value == state_name
  83. def exit_nested(self, event_data, target_state):
  84. """ Tracks child states to exit when the states is exited itself. This should not
  85. be triggered by the user but will be handled by the hierarchical machine.
  86. Args:
  87. event_data (EventData): Event related data.
  88. target_state (NestedState): The state to be entered.
  89. Returns: int level of the currently investigated (sub)state.
  90. """
  91. if self == target_state:
  92. self.exit(event_data)
  93. return self.level
  94. elif self.level > target_state.level:
  95. self.exit(event_data)
  96. return self.parent.exit_nested(event_data, target_state)
  97. elif self.level <= target_state.level:
  98. tmp_state = target_state
  99. while self.level != tmp_state.level:
  100. tmp_state = tmp_state.parent
  101. tmp_self = self
  102. while tmp_self.level > 0 and tmp_state.parent.name != tmp_self.parent.name:
  103. tmp_self.exit(event_data)
  104. tmp_self = tmp_self.parent
  105. tmp_state = tmp_state.parent
  106. if tmp_self == tmp_state:
  107. return tmp_self.level + 1
  108. tmp_self.exit(event_data)
  109. return tmp_self.level
  110. def enter_nested(self, event_data, level=None):
  111. """ Tracks parent states to be entered when the states is entered itself. This should not
  112. be triggered by the user but will be handled by the hierarchical machine.
  113. Args:
  114. event_data (EventData): Event related data.
  115. level (int): The level of the currently entered parent.
  116. """
  117. if level is not None and level <= self.level:
  118. if level != self.level:
  119. self.parent.enter_nested(event_data, level)
  120. self.enter(event_data)
  121. # Prevent deep copying of callback lists since these include either references to callables or
  122. # strings. Deep copying a method reference would lead to the creation of an entire new (model) object
  123. # (see https://github.com/pytransitions/transitions/issues/248)
  124. def __deepcopy__(self, memo):
  125. cls = self.__class__
  126. result = cls.__new__(cls)
  127. memo[id(self)] = result
  128. for key, value in self.__dict__.items():
  129. if key in cls.dynamic_methods:
  130. setattr(result, key, copy(value))
  131. else:
  132. setattr(result, key, deepcopy(value, memo))
  133. return result
  134. class NestedTransition(Transition):
  135. """ A transition which handles entering and leaving nested states.
  136. Attributes:
  137. dest (NestedState): The resolved transition destination in respect
  138. to initial states of nested states.
  139. """
  140. def execute(self, event_data):
  141. """ Extends transitions.core.transitions to handle nested states. """
  142. if self.dest is None:
  143. return _super(NestedTransition, self).execute(event_data)
  144. dest_state = event_data.machine.get_state(self.dest)
  145. while dest_state.initial:
  146. dest_state = event_data.machine.get_state(dest_state.initial)
  147. self.dest = dest_state.name
  148. return _super(NestedTransition, self).execute(event_data)
  149. # The actual state change method 'execute' in Transition was restructured to allow overriding
  150. def _change_state(self, event_data):
  151. machine = event_data.machine
  152. model = event_data.model
  153. dest_state = machine.get_state(self.dest)
  154. source_state = machine.get_model_state(model)
  155. lvl = source_state.exit_nested(event_data, dest_state)
  156. event_data.machine.set_state(self.dest, model)
  157. event_data.update(dest_state)
  158. dest_state.enter_nested(event_data, lvl)
  159. # Prevent deep copying of callback lists since these include either references to callable or
  160. # strings. Deep copying a method reference would lead to the creation of an entire new (model) object
  161. # (see https://github.com/pytransitions/transitions/issues/248)
  162. def __deepcopy__(self, memo):
  163. cls = self.__class__
  164. result = cls.__new__(cls)
  165. memo[id(self)] = result
  166. for key, value in self.__dict__.items():
  167. if key in cls.dynamic_methods:
  168. setattr(result, key, copy(value))
  169. else:
  170. setattr(result, key, deepcopy(value, memo))
  171. return result
  172. class NestedEvent(Event):
  173. """ An event type to work with nested states. """
  174. def _trigger(self, model, *args, **kwargs):
  175. state = self.machine.get_model_state(model)
  176. while state.parent and state.name not in self.transitions:
  177. state = state.parent
  178. if state.name not in self.transitions:
  179. msg = "%sCan't trigger event %s from state %s!" % (self.machine.name, self.name,
  180. self.machine.get_model_state(model))
  181. if self.machine.get_model_state(model).ignore_invalid_triggers:
  182. _LOGGER.warning(msg)
  183. else:
  184. raise MachineError(msg)
  185. event_data = EventData(state, self, self.machine,
  186. model, args=args, kwargs=kwargs)
  187. return self._process(event_data)
  188. class HierarchicalMachine(Machine):
  189. """ Extends transitions.core.Machine by capabilities to handle nested states.
  190. A hierarchical machine REQUIRES NestedStates (or any subclass of it) to operate.
  191. """
  192. state_cls = NestedState
  193. transition_cls = NestedTransition
  194. event_cls = NestedEvent
  195. def __init__(self, *args, **kwargs):
  196. self._buffered_transitions = []
  197. _super(HierarchicalMachine, self).__init__(*args, **kwargs)
  198. @Machine.initial.setter
  199. def initial(self, value):
  200. if isinstance(value, NestedState):
  201. if value.name not in self.states:
  202. self.add_state(value)
  203. else:
  204. assert self._has_state(value)
  205. state = value
  206. else:
  207. state_name = value.name if isinstance(value, Enum) else value
  208. if state_name not in self.states:
  209. self.add_state(state_name)
  210. state = self.get_state(state_name)
  211. if state.initial:
  212. self.initial = state.initial
  213. else:
  214. self._initial = state.name
  215. def add_model(self, model, initial=None):
  216. """ Extends transitions.core.Machine.add_model by applying a custom 'to' function to
  217. the added model.
  218. """
  219. _super(HierarchicalMachine, self).add_model(model, initial=initial)
  220. models = listify(model)
  221. for mod in models:
  222. mod = self if mod == 'self' else mod
  223. # TODO: Remove 'mod != self' in 0.7.0
  224. if hasattr(mod, 'to') and mod != self:
  225. _LOGGER.warning("%sModel already has a 'to'-method. It will NOT "
  226. "be overwritten by NestedMachine", self.name)
  227. else:
  228. to_func = partial(self.to_state, mod)
  229. setattr(mod, 'to', to_func)
  230. def is_state(self, state_name, model, allow_substates=False):
  231. """ Extends transitions.core.Machine.is_state with an additional parameter (allow_substates)
  232. to
  233. Args:
  234. state_name (str): Name of the checked state.
  235. model (class): The model to be investigated.
  236. allow_substates (bool): Whether substates should be allowed or not.
  237. Returns: bool Whether the passed model is in queried state (or a substate of it) or not.
  238. """
  239. if not allow_substates:
  240. return getattr(model, self.model_attribute) == state_name
  241. return self.get_model_state(model).is_substate_of(state_name)
  242. def _traverse(self, states, on_enter=None, on_exit=None,
  243. ignore_invalid_triggers=None, parent=None, remap=None):
  244. """ Parses passed value to build a nested state structure recursively.
  245. Args:
  246. states (list, str, dict, or State): a list, a State instance, the
  247. name of a new state, or a dict with keywords to pass on to the
  248. State initializer. If a list, each element can be of any of the
  249. latter three types.
  250. on_enter (str or list): callbacks to trigger when the state is
  251. entered. Only valid if first argument is string.
  252. on_exit (str or list): callbacks to trigger when the state is
  253. exited. Only valid if first argument is string.
  254. ignore_invalid_triggers: when True, any calls to trigger methods
  255. that are not valid for the present state (e.g., calling an
  256. a_to_b() trigger when the current state is c) will be silently
  257. ignored rather than raising an invalid transition exception.
  258. Note that this argument takes precedence over the same
  259. argument defined at the Machine level, and is in turn
  260. overridden by any ignore_invalid_triggers explicitly
  261. passed in an individual state's initialization arguments.
  262. parent (NestedState or str): parent state for nested states.
  263. remap (dict): reassigns transitions named `key from nested machines to parent state `value`.
  264. Returns: list of new `NestedState` objects
  265. """
  266. states = listify(states)
  267. new_states = []
  268. ignore = ignore_invalid_triggers
  269. remap = {} if remap is None else remap
  270. parent = self.get_state(parent) if isinstance(parent, (string_types, Enum)) else parent
  271. if ignore is None:
  272. ignore = self.ignore_invalid_triggers
  273. for state in states:
  274. tmp_states = []
  275. # other state representations are handled almost like in the base class but a parent parameter is added
  276. if isinstance(state, (string_types, Enum)):
  277. if state in remap:
  278. continue
  279. tmp_states.append(self._create_state(state, on_enter=on_enter, on_exit=on_exit, parent=parent,
  280. ignore_invalid_triggers=ignore))
  281. elif isinstance(state, dict):
  282. if state['name'] in remap:
  283. continue
  284. # shallow copy the dictionary to alter/add some parameters
  285. state = copy(state)
  286. if 'ignore_invalid_triggers' not in state:
  287. state['ignore_invalid_triggers'] = ignore
  288. if 'parent' not in state:
  289. state['parent'] = parent
  290. try:
  291. state_children = state.pop('children') # throws KeyError when no children set
  292. state_remap = state.pop('remap', None)
  293. state_parent = self._create_state(**state)
  294. nested = self._traverse(state_children, parent=state_parent, remap=state_remap)
  295. tmp_states.append(state_parent)
  296. tmp_states.extend(nested)
  297. except KeyError:
  298. tmp_states.insert(0, self._create_state(**state))
  299. elif isinstance(state, HierarchicalMachine):
  300. # set initial state of parent if it is None
  301. if parent.initial is None:
  302. parent.initial = state.initial
  303. # (deep) copy only states not mentioned in remap
  304. copied_states = [s for s in deepcopy(state.states).values() if s.name not in remap]
  305. # inner_states are the root states of the passed machine
  306. # which have be attached to the parent
  307. inner_states = [s for s in copied_states if s.level == 0]
  308. for inner in inner_states:
  309. inner.parent = parent
  310. tmp_states.extend(copied_states)
  311. for trigger, event in state.events.items():
  312. if trigger.startswith('to_'):
  313. path = trigger[3:].split(self.state_cls.separator)
  314. # do not copy auto_transitions since they would not be valid anymore;
  315. # trigger and destination do not exist in the new environment
  316. if path[0] in remap:
  317. continue
  318. ppath = parent.name.split(self.state_cls.separator)
  319. path = ['to_' + ppath[0]] + ppath[1:] + path
  320. trigger = '.'.join(path)
  321. # (deep) copy transitions and
  322. # adjust all transition start and end points to new state names
  323. for transitions in deepcopy(event.transitions).values():
  324. for transition in transitions:
  325. src = transition.source
  326. # transitions from remapped states will be filtered to prevent
  327. # unexpected behaviour in the parent machine
  328. if src in remap:
  329. continue
  330. dst = parent.name + self.state_cls.separator + transition.dest\
  331. if transition.dest not in remap else remap[transition.dest]
  332. conditions, unless = [], []
  333. for cond in transition.conditions:
  334. # split a list in two lists based on the accessors (cond.target) truth value
  335. (unless, conditions)[cond.target].append(cond.func)
  336. self._buffered_transitions.append({'trigger': trigger,
  337. 'source': parent.name + self.state_cls.separator + src,
  338. 'dest': dst,
  339. 'conditions': conditions,
  340. 'unless': unless,
  341. 'prepare': transition.prepare,
  342. 'before': transition.before,
  343. 'after': transition.after})
  344. elif isinstance(state, NestedState):
  345. tmp_states.append(state)
  346. if state.children:
  347. tmp_states.extend(self._traverse(state.children, on_enter=on_enter, on_exit=on_exit,
  348. ignore_invalid_triggers=ignore_invalid_triggers,
  349. parent=state, remap=remap))
  350. else:
  351. raise ValueError("%s is not an instance or subclass of NestedState "
  352. "required by HierarchicalMachine." % state)
  353. new_states.extend(tmp_states)
  354. duplicate_check = []
  355. for new in new_states:
  356. if new.name in duplicate_check:
  357. # collect state names for the following error message
  358. state_names = [s.name for s in new_states]
  359. raise ValueError("State %s cannot be added since it is already in state list %s."
  360. % (new.name, state_names))
  361. else:
  362. duplicate_check.append(new.name)
  363. return new_states
  364. def add_states(self, states, on_enter=None, on_exit=None,
  365. ignore_invalid_triggers=None, **kwargs):
  366. """ Extends transitions.core.Machine.add_states by calling traverse to parse possible
  367. substates first."""
  368. # preprocess states to flatten the configuration and resolve nesting
  369. new_states = self._traverse(states, on_enter=on_enter, on_exit=on_exit,
  370. ignore_invalid_triggers=ignore_invalid_triggers, **kwargs)
  371. _super(HierarchicalMachine, self).add_states(new_states, on_enter=on_enter, on_exit=on_exit,
  372. ignore_invalid_triggers=ignore_invalid_triggers,
  373. **kwargs)
  374. while self._buffered_transitions:
  375. args = self._buffered_transitions.pop(0)
  376. self.add_transition(**args)
  377. def get_nested_state_names(self):
  378. """ Returns all states of the state machine. """
  379. return self.states
  380. def get_triggers(self, *args):
  381. """ Extends transitions.core.Machine.get_triggers to also include parent state triggers. """
  382. # add parents to state set
  383. states = []
  384. for state_name in args:
  385. state = self.get_state(state_name)
  386. while state.parent:
  387. states.append(state.parent.name)
  388. state = state.parent
  389. states.extend(args)
  390. return _super(HierarchicalMachine, self).get_triggers(*states)
  391. def _add_trigger_to_model(self, trigger, model):
  392. # FunctionWrappers are only necessary if a custom separator is used
  393. if trigger.startswith('to_') and self.state_cls.separator != '_':
  394. path = trigger[3:].split(self.state_cls.separator)
  395. trig_func = partial(self.events[trigger].trigger, model)
  396. if hasattr(model, 'to_' + path[0]):
  397. # add path to existing function wrapper
  398. getattr(model, 'to_' + path[0]).add(trig_func, path[1:])
  399. else:
  400. # create a new function wrapper
  401. setattr(model, 'to_' + path[0], FunctionWrapper(trig_func, path[1:]))
  402. else:
  403. _super(HierarchicalMachine, self)._add_trigger_to_model(trigger, model) # pylint: disable=protected-access
  404. def on_enter(self, state_name, callback):
  405. """ Helper function to add callbacks to states in case a custom state separator is used.
  406. Args:
  407. state_name (str): Name of the state
  408. callback (str or callable): Function to be called. Strings will be resolved to model functions.
  409. """
  410. self.get_state(state_name).add_callback('enter', callback)
  411. def on_exit(self, state_name, callback):
  412. """ Helper function to add callbacks to states in case a custom state separator is used.
  413. Args:
  414. state_name (str): Name of the state
  415. callback (str or callable): Function to be called. Strings will be resolved to model functions.
  416. """
  417. self.get_state(state_name).add_callback('exit', callback)
  418. def to_state(self, model, state_name, *args, **kwargs):
  419. """ Helper function to add go to states in case a custom state separator is used.
  420. Args:
  421. model (class): The model that should be used.
  422. state_name (str): Name of the destination state.
  423. """
  424. event = EventData(self.get_model_state(model), Event('to', self), self,
  425. model, args=args, kwargs=kwargs)
  426. self._create_transition(getattr(model, self.model_attribute), state_name).execute(event)