diagrams.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318
  1. from transitions import Transition
  2. from transitions.extensions.markup import MarkupMachine
  3. from transitions.core import listify
  4. import warnings
  5. import logging
  6. from functools import partial
  7. import copy
  8. _LOGGER = logging.getLogger(__name__)
  9. _LOGGER.addHandler(logging.NullHandler())
  10. # make deprecation warnings of transition visible for module users
  11. warnings.filterwarnings(action='default', message=r".*transitions version.*")
  12. # this is a workaround for dill issues when partials and super is used in conjunction
  13. # without it, Python 3.0 - 3.3 will not support pickling
  14. # https://github.com/pytransitions/transitions/issues/236
  15. _super = super
  16. class TransitionGraphSupport(Transition):
  17. """ Transition used in conjunction with (Nested)Graphs to update graphs whenever a transition is
  18. conducted.
  19. """
  20. def _change_state(self, event_data):
  21. graph = event_data.machine.model_graphs[event_data.model]
  22. graph.reset_styling()
  23. graph.set_previous_transition(self.source, self.dest, event_data.event.name)
  24. _super(TransitionGraphSupport, self)._change_state(event_data) # pylint: disable=protected-access
  25. for state in _flatten(listify(getattr(event_data.model, event_data.machine.model_attribute))):
  26. graph.set_node_style(self.dest if hasattr(state, 'name') else state, 'active')
  27. class GraphMachine(MarkupMachine):
  28. """ Extends transitions.core.Machine with graph support.
  29. Is also used as a mixin for HierarchicalMachine.
  30. Attributes:
  31. _pickle_blacklist (list): Objects that should not/do not need to be pickled.
  32. transition_cls (cls): TransitionGraphSupport
  33. """
  34. _pickle_blacklist = ['model_graphs']
  35. transition_cls = TransitionGraphSupport
  36. machine_attributes = {
  37. 'directed': 'true',
  38. 'strict': 'false',
  39. 'rankdir': 'LR',
  40. }
  41. hierarchical_machine_attributes = {
  42. 'rankdir': 'TB',
  43. 'rank': 'source',
  44. 'nodesep': '1.5',
  45. 'compound': 'true'
  46. }
  47. style_attributes = {
  48. 'node': {
  49. '': {},
  50. 'default': {
  51. 'shape': 'rectangle',
  52. 'style': 'rounded, filled',
  53. 'fillcolor': 'white',
  54. 'color': 'black',
  55. 'peripheries': '1'
  56. },
  57. 'active': {
  58. 'color': 'red',
  59. 'fillcolor': 'darksalmon',
  60. 'peripheries': '2'
  61. },
  62. 'previous': {
  63. 'color': 'blue',
  64. 'fillcolor': 'azure2',
  65. 'peripheries': '1'
  66. }
  67. },
  68. 'edge': {
  69. '': {},
  70. 'default': {
  71. 'color': 'black'
  72. },
  73. 'previous': {
  74. 'color': 'blue'
  75. }
  76. },
  77. 'graph': {
  78. '': {},
  79. 'default': {
  80. 'color': 'black',
  81. 'fillcolor': 'white',
  82. 'style': 'solid'
  83. },
  84. 'parallel': {
  85. 'color': 'black',
  86. 'fillcolor': 'white',
  87. 'style': 'dotted'
  88. },
  89. 'previous': {
  90. 'color': 'blue',
  91. 'fillcolor': 'azure2',
  92. 'style': 'filled'
  93. },
  94. 'active': {
  95. 'color': 'red',
  96. 'fillcolor': 'darksalmon',
  97. 'style': 'filled'
  98. },
  99. }
  100. }
  101. # model_graphs cannot be pickled. Omit them.
  102. def __getstate__(self):
  103. # self.pkl_graphs = [(g.markup, g.custom_styles) for g in self.model_graphs]
  104. return {k: v for k, v in self.__dict__.items() if k not in self._pickle_blacklist}
  105. def __setstate__(self, state):
  106. self.__dict__.update(state)
  107. self.model_graphs = {} # reinitialize new model_graphs
  108. for model in self.models:
  109. try:
  110. _ = self._get_graph(model, title=self.title)
  111. except AttributeError as e:
  112. _LOGGER.warning("Graph for model could not be initialized after pickling: %s", e)
  113. def __init__(self, *args, **kwargs):
  114. # remove graph config from keywords
  115. self.title = kwargs.pop('title', 'State Machine')
  116. self.show_conditions = kwargs.pop('show_conditions', False)
  117. self.show_state_attributes = kwargs.pop('show_state_attributes', False)
  118. # in MarkupMachine this switch is called 'with_auto_transitions'
  119. # keep 'auto_transitions_markup' for backwards compatibility
  120. kwargs['auto_transitions_markup'] = kwargs.get('auto_transitions_markup', False) or \
  121. kwargs.pop('show_auto_transitions', False)
  122. self.model_graphs = {}
  123. # determine graph engine; if pygraphviz cannot be imported, fall back to graphviz
  124. use_pygraphviz = kwargs.pop('use_pygraphviz', True)
  125. if use_pygraphviz:
  126. try:
  127. import pygraphviz
  128. except ImportError:
  129. use_pygraphviz = False
  130. self.graph_cls = self._init_graphviz_engine(use_pygraphviz)
  131. _LOGGER.debug("Using graph engine %s", self.graph_cls)
  132. _super(GraphMachine, self).__init__(*args, **kwargs)
  133. # for backwards compatibility assign get_combined_graph to get_graph
  134. # if model is not the machine
  135. if not hasattr(self, 'get_graph'):
  136. setattr(self, 'get_graph', self.get_combined_graph)
  137. def _init_graphviz_engine(self, use_pygraphviz):
  138. if use_pygraphviz:
  139. try:
  140. # state class needs to have a separator and machine needs to be a context manager
  141. if hasattr(self.state_cls, 'separator') and hasattr(self, '__enter__'):
  142. from .diagrams_pygraphviz import NestedGraph as Graph
  143. self.machine_attributes.update(self.hierarchical_machine_attributes)
  144. else:
  145. from .diagrams_pygraphviz import Graph
  146. return Graph
  147. except ImportError:
  148. pass
  149. if hasattr(self.state_cls, 'separator') and hasattr(self, '__enter__'):
  150. from .diagrams_graphviz import NestedGraph as Graph
  151. self.machine_attributes.update(self.hierarchical_machine_attributes)
  152. else:
  153. from .diagrams_graphviz import Graph
  154. return Graph
  155. def _get_graph(self, model, title=None, force_new=False, show_roi=False):
  156. if force_new:
  157. grph = self.graph_cls(self, title=title if title is not None else self.title)
  158. self.model_graphs[model] = grph
  159. try:
  160. state = getattr(model, self.model_attribute)
  161. self.model_graphs[model].set_node_style(state.name if hasattr(state, 'name') else state, 'active')
  162. except AttributeError:
  163. _LOGGER.info("Could not set active state of diagram")
  164. try:
  165. m = self.model_graphs[model]
  166. except KeyError:
  167. _ = self._get_graph(model, title, force_new=True)
  168. m = self.model_graphs[model]
  169. m.roi_state = getattr(model, self.model_attribute) if show_roi else None
  170. return m.get_graph(title=title)
  171. def get_combined_graph(self, title=None, force_new=False, show_roi=False):
  172. """ This method is currently equivalent to 'get_graph' of the first machine's model.
  173. In future releases of transitions, this function will return a combined graph with active states
  174. of all models.
  175. Args:
  176. title (str): Title of the resulting graph.
  177. force_new (bool): If set to True, (re-)generate the model's graph.
  178. show_roi (bool): If set to True, only render states that are active and/or can be reached from
  179. the current state.
  180. Returns: AGraph of the first machine's model.
  181. """
  182. _LOGGER.info('Returning graph of the first model. In future releases, this '
  183. 'method will return a combined graph of all models.')
  184. return self._get_graph(self.models[0], title, force_new, show_roi)
  185. def add_model(self, model, initial=None):
  186. models = listify(model)
  187. super(GraphMachine, self).add_model(models, initial)
  188. for mod in models:
  189. mod = self if mod == 'self' else mod
  190. if hasattr(mod, 'get_graph'):
  191. raise AttributeError('Model already has a get_graph attribute. Graph retrieval cannot be bound.')
  192. setattr(mod, 'get_graph', partial(self._get_graph, mod))
  193. _ = mod.get_graph(title=self.title, force_new=True) # initialises graph
  194. def add_states(self, states, on_enter=None, on_exit=None,
  195. ignore_invalid_triggers=None, **kwargs):
  196. """ Calls the base method and regenerates all models's graphs. """
  197. _super(GraphMachine, self).add_states(states, on_enter=on_enter, on_exit=on_exit,
  198. ignore_invalid_triggers=ignore_invalid_triggers, **kwargs)
  199. for model in self.models:
  200. model.get_graph(force_new=True)
  201. def add_transition(self, trigger, source, dest, conditions=None,
  202. unless=None, before=None, after=None, prepare=None, **kwargs):
  203. """ Calls the base method and regenerates all models's graphs. """
  204. _super(GraphMachine, self).add_transition(trigger, source, dest, conditions=conditions, unless=unless,
  205. before=before, after=after, prepare=prepare, **kwargs)
  206. for model in self.models:
  207. model.get_graph(force_new=True)
  208. class BaseGraph(object):
  209. def __init__(self, machine, title=None):
  210. self.machine = machine
  211. self.fsm_graph = None
  212. self.roi_state = None
  213. self.generate(title)
  214. def _convert_state_attributes(self, state):
  215. label = state.get('label', state['name'])
  216. if self.machine.show_state_attributes:
  217. if 'tags' in state:
  218. label += ' [' + ', '.join(state['tags']) + ']'
  219. if 'on_enter' in state:
  220. label += r'\l- enter:\l + ' + r'\l + '.join(state['on_enter'])
  221. if 'on_exit' in state:
  222. label += r'\l- exit:\l + ' + r'\l + '.join(state['on_exit'])
  223. if 'timeout' in state:
  224. label += r'\l- timeout(' + state['timeout'] + 's) -> (' + ', '.join(state['on_timeout']) + ')'
  225. return label
  226. def _transition_label(self, tran):
  227. edge_label = tran.get('label', tran['trigger'])
  228. if 'dest' not in tran:
  229. edge_label += " [internal]"
  230. if self.machine.show_conditions and any(prop in tran for prop in ['conditions', 'unless']):
  231. x = '{edge_label} [{conditions}]'.format(
  232. edge_label=edge_label,
  233. conditions=' & '.join(tran.get('conditions', []) + ['!' + u for u in tran.get('unless', [])]),
  234. )
  235. return x
  236. return edge_label
  237. def _get_global_name(self, path):
  238. if path:
  239. state = path.pop(0)
  240. with self.machine(state):
  241. return self._get_global_name(path)
  242. else:
  243. return self.machine.get_global_name()
  244. def _get_elements(self):
  245. states = []
  246. transitions = []
  247. try:
  248. markup = self.machine.get_markup_config()
  249. q = [([], markup)]
  250. while q:
  251. prefix, scope = q.pop(0)
  252. for transition in scope.get('transitions', []):
  253. if prefix:
  254. t = copy.copy(transition)
  255. t['source'] = self.machine.state_cls.separator.join(prefix + [t['source']])
  256. t['dest'] = self.machine.state_cls.separator.join(prefix + [t['dest']])
  257. else:
  258. t = transition
  259. transitions.append(t)
  260. for state in scope.get('children', []) + scope.get('states', []):
  261. if not prefix:
  262. s = state
  263. states.append(s)
  264. ini = state.get('initial', [])
  265. if not isinstance(ini, list):
  266. ini = ini.name if hasattr(ini, 'name') else ini
  267. t = dict(trigger='',
  268. source=self.machine.state_cls.separator.join(prefix + [state['name']]) + '_anchor',
  269. dest=self.machine.state_cls.separator.join(prefix + [state['name'], ini]))
  270. transitions.append(t)
  271. if state.get('children', []):
  272. q.append((prefix + [state['name']], state))
  273. except KeyError as e:
  274. _LOGGER.error("Graph creation incomplete!")
  275. return states, transitions
  276. def _flatten(item):
  277. for elem in item:
  278. if isinstance(elem, (list, tuple, set)):
  279. for res in _flatten(elem):
  280. yield res
  281. else:
  282. yield elem