diagrams_graphviz.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. """
  2. transitions.extensions.diagrams
  3. -------------------------------
  4. Graphviz support for (nested) machines. This also includes partial views
  5. of currently valid transitions.
  6. """
  7. import logging
  8. from functools import partial
  9. from collections import defaultdict
  10. from os.path import splitext
  11. import copy
  12. from .diagrams import BaseGraph
  13. from ..core import listify
  14. try:
  15. import graphviz as pgv
  16. except ImportError: # pragma: no cover
  17. pgv = None
# Module-level logger; the NullHandler is attached so the logger always has
# at least one handler.
_LOGGER = logging.getLogger(__name__)
_LOGGER.addHandler(logging.NullHandler())

# Workaround for dill issues when partials and super are used in conjunction:
# without this alias, Python 3.0 - 3.3 will not support pickling.
# https://github.com/pytransitions/transitions/issues/236
_super = super
  24. class Graph(BaseGraph):
  25. """ Graph creation for transitions.core.Machine.
  26. Attributes:
  27. machine (object): Reference to the related machine.
  28. """
  29. def __init__(self, machine, title=None):
  30. self.reset_styling()
  31. _super(Graph, self).__init__(machine, title)
  32. def set_previous_transition(self, src, dst, key=None):
  33. self.custom_styles['edge'][src][dst] = 'previous'
  34. self.set_node_style(src, 'previous')
  35. def set_node_style(self, state, style):
  36. self.custom_styles['node'][state] = style
  37. def reset_styling(self):
  38. self.custom_styles = {'edge': defaultdict(lambda: defaultdict(str)),
  39. 'node': defaultdict(str)}
  40. def _add_nodes(self, states, container):
  41. for state in states:
  42. style = self.custom_styles['node'][state['name']]
  43. container.node(state['name'], label=self._convert_state_attributes(state),
  44. **self.machine.style_attributes['node'][style])
  45. def _add_edges(self, transitions, container):
  46. edge_labels = defaultdict(lambda: defaultdict(list))
  47. for transition in transitions:
  48. try:
  49. dst = transition['dest']
  50. except KeyError:
  51. dst = transition['source']
  52. edge_labels[transition['source']][dst].append(self._transition_label(transition))
  53. for src, dests in edge_labels.items():
  54. for dst, labels in dests.items():
  55. style = self.custom_styles['edge'][src][dst]
  56. container.edge(src, dst, label=' | '.join(labels), **self.machine.style_attributes['edge'][style])
  57. def generate(self, title=None, roi_state=None):
  58. """ Generate a DOT graph with graphviz
  59. Args:
  60. roi_state (str): Optional, show only custom states and edges from roi_state
  61. """
  62. if not pgv: # pragma: no cover
  63. raise Exception('AGraph diagram requires graphviz')
  64. title = self.machine.title if not title else title
  65. fsm_graph = pgv.Digraph(name=title, node_attr=self.machine.style_attributes['node']['default'],
  66. edge_attr=self.machine.style_attributes['edge']['default'],
  67. graph_attr=self.machine.style_attributes['graph']['default'])
  68. fsm_graph.graph_attr.update(**self.machine.machine_attributes)
  69. fsm_graph.graph_attr['label'] = title
  70. # For each state, draw a circle
  71. states, transitions = self._get_elements()
  72. if roi_state:
  73. transitions = [t for t in transitions
  74. if t['source'] == roi_state or self.custom_styles['edge'][t['source']][t['dest']]]
  75. state_names = [t for trans in transitions
  76. for t in [trans['source'], trans.get('dest', trans['source'])]]
  77. state_names += [k for k, style in self.custom_styles['node'].items() if style]
  78. states = _filter_states(states, state_names, self.machine.state_cls)
  79. self._add_nodes(states, fsm_graph)
  80. self._add_edges(transitions, fsm_graph)
  81. setattr(fsm_graph, 'draw', partial(self.draw, fsm_graph))
  82. return fsm_graph
  83. def get_graph(self, title=None):
  84. return self.generate(title, roi_state=self.roi_state)
  85. @staticmethod
  86. def draw(graph, filename, format=None, prog='dot', args=''):
  87. """ Generates and saves an image of the state machine using graphviz.
  88. Args:
  89. filename (str): path and name of image output
  90. format (str): Optional format of the output file
  91. Returns:
  92. """
  93. graph.engine = prog
  94. try:
  95. filename, ext = splitext(filename)
  96. format = format if format is not None else ext[1:]
  97. graph.render(filename, format=format if format else 'png', cleanup=True)
  98. except TypeError:
  99. if format is None:
  100. raise ValueError("Parameter 'format' must not be None when filename is no valid file path.")
  101. filename.write(graph.pipe(format))
  102. class NestedGraph(Graph):
  103. """ Graph creation support for transitions.extensions.nested.HierarchicalGraphMachine. """
  104. def __init__(self, *args, **kwargs):
  105. self._cluster_states = []
  106. _super(NestedGraph, self).__init__(*args, **kwargs)
  107. def set_previous_transition(self, src, dst, key=None):
  108. src_name = self._get_global_name(src.split(self.machine.state_cls.separator))
  109. dst_name = self._get_global_name(dst.split(self.machine.state_cls.separator))
  110. _super(NestedGraph, self).set_previous_transition(src_name, dst_name, key)
  111. def _add_nodes(self, states, container, prefix='', default_style='default'):
  112. for state in states:
  113. name = prefix + state['name']
  114. label = self._convert_state_attributes(state)
  115. if state.get('children', []):
  116. cluster_name = "cluster_" + name
  117. with container.subgraph(name=cluster_name,
  118. graph_attr=self.machine.style_attributes['graph']['default']) as sub:
  119. style = self.custom_styles['node'][name] or default_style
  120. sub.graph_attr.update(label=label, rank='source', **self.machine.style_attributes['graph'][style])
  121. self._cluster_states.append(name)
  122. is_parallel = isinstance(state.get('initial', ''), list)
  123. width = '0.0' if is_parallel else '0.1'
  124. with sub.subgraph(name=cluster_name + '_root',
  125. graph_attr={'label': '', 'color': 'None', 'rank': 'min'}) as root:
  126. root.node(name + "_anchor", shape='point', fillcolor='black', width=width)
  127. self._add_nodes(state['children'], sub, default_style='parallel' if is_parallel else 'default',
  128. prefix=prefix + state['name'] + self.machine.state_cls.separator)
  129. else:
  130. style = self.custom_styles['node'][name] or default_style
  131. container.node(name, label=label, **self.machine.style_attributes['node'][style])
  132. def _add_edges(self, transitions, container, prefix=''):
  133. edges_attr = defaultdict(lambda: defaultdict(dict))
  134. for transition in transitions:
  135. # enable customizable labels
  136. label_pos = 'label'
  137. src = prefix + transition['source']
  138. try:
  139. dst = prefix + transition['dest']
  140. except KeyError:
  141. dst = src
  142. if edges_attr[src][dst]:
  143. attr = edges_attr[src][dst]
  144. attr[attr['label_pos']] = ' | '.join([edges_attr[src][dst][attr['label_pos']],
  145. self._transition_label(transition)])
  146. continue
  147. else:
  148. attr = {}
  149. if src in self._cluster_states:
  150. attr['ltail'] = 'cluster_' + src
  151. src_name = src + "_anchor"
  152. label_pos = 'headlabel'
  153. else:
  154. src_name = src
  155. if dst in self._cluster_states:
  156. if not src.startswith(dst):
  157. attr['lhead'] = "cluster_" + dst
  158. label_pos = 'taillabel' if label_pos.startswith('l') else 'label'
  159. dst_name = dst + '_anchor'
  160. else:
  161. dst_name = dst
  162. # remove ltail when dst (ltail always starts with 'cluster_') is a child of src
  163. if 'ltail' in attr and dst_name.startswith(attr['ltail'][8:]):
  164. del attr['ltail']
  165. # # remove ltail when dst is a child of src
  166. # if 'ltail' in edge_attr:
  167. # if _get_subgraph(container, edge_attr['ltail']).has_node(dst_name):
  168. # del edge_attr['ltail']
  169. attr[label_pos] = self._transition_label(transition)
  170. attr['label_pos'] = label_pos
  171. attr['source'] = src_name
  172. attr['dest'] = dst_name
  173. edges_attr[src][dst] = attr
  174. for src, dests in edges_attr.items():
  175. for dst, attr in dests.items():
  176. del attr['label_pos']
  177. style = self.custom_styles['edge'][src][dst]
  178. attr.update(**self.machine.style_attributes['edge'][style])
  179. container.edge(attr.pop('source'), attr.pop('dest'), **attr)
  180. def _filter_states(states, state_names, state_cls, prefix=None):
  181. prefix = prefix or []
  182. result = []
  183. for state in states:
  184. pref = prefix + [state['name']]
  185. if 'children' in state:
  186. state['children'] = _filter_states(state['children'], state_names, state_cls, prefix=pref)
  187. result.append(state)
  188. elif getattr(state_cls, 'separator', '_').join(pref) in state_names:
  189. result.append(state)
  190. return result