diagrams_pygraphviz.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  1. """
  2. transitions.extensions.diagrams
  3. -------------------------------
  4. Graphviz support for (nested) machines. This also includes partial views
  5. of currently valid transitions.
  6. """
  7. import logging
  8. import copy
  9. from .nesting import NestedState
  10. from .diagrams import BaseGraph
  11. try:
  12. import pygraphviz as pgv
  13. except ImportError: # pragma: no cover
  14. pgv = None
  15. _LOGGER = logging.getLogger(__name__)
  16. _LOGGER.addHandler(logging.NullHandler())
  17. # this is a workaround for dill issues when partials and super is used in conjunction
  18. # without it, Python 3.0 - 3.3 will not support pickling
  19. # https://github.com/pytransitions/transitions/issues/236
  20. _super = super
  21. class Graph(BaseGraph):
  22. """ Graph creation for transitions.core.Machine.
  23. Attributes:
  24. machine (object): Reference to the related machine.
  25. """
  26. def _add_nodes(self, states, container):
  27. for state in states:
  28. shape = self.machine.style_attributes['node']['default']['shape']
  29. container.add_node(state['name'], label=self._convert_state_attributes(state), shape=shape)
  30. def _add_edges(self, transitions, container):
  31. for transition in transitions:
  32. src = transition['source']
  33. edge_attr = {'label': self._transition_label(transition)}
  34. try:
  35. dst = transition['dest']
  36. except KeyError:
  37. dst = src
  38. if container.has_edge(src, dst):
  39. edge = container.get_edge(src, dst)
  40. edge.attr['label'] = edge.attr['label'] + ' | ' + edge_attr['label']
  41. else:
  42. container.add_edge(src, dst, **edge_attr)
  43. def generate(self, title=None):
  44. """ Generate a DOT graph with pygraphviz, returns an AGraph object """
  45. if not pgv: # pragma: no cover
  46. raise Exception('AGraph diagram requires pygraphviz')
  47. title = '' if not title else title
  48. self.fsm_graph = pgv.AGraph(label=title, **self.machine.machine_attributes)
  49. self.fsm_graph.node_attr.update(self.machine.style_attributes['node']['default'])
  50. self.fsm_graph.edge_attr.update(self.machine.style_attributes['edge']['default'])
  51. states, transitions = self._get_elements()
  52. self._add_nodes(states, self.fsm_graph)
  53. self._add_edges(transitions, self.fsm_graph)
  54. setattr(self.fsm_graph, 'style_attributes', self.machine.style_attributes)
  55. return self.fsm_graph
  56. def get_graph(self, title=None):
  57. if title:
  58. self.fsm_graph.graph_attr['label'] = title
  59. if self.roi_state:
  60. filtered = self.fsm_graph.copy()
  61. kept_nodes = set()
  62. active_state = self.roi_state if filtered.has_node(self.roi_state) else self.roi_state + '_anchor'
  63. kept_nodes.add(active_state)
  64. # remove all edges that have no connection to the currently active state
  65. for edge in filtered.edges():
  66. if active_state not in edge:
  67. filtered.delete_edge(edge)
  68. # find the ingoing edge by color; remove the rest
  69. for edge in filtered.in_edges(active_state):
  70. if edge.attr['color'] == self.fsm_graph.style_attributes['edge']['previous']['color']:
  71. kept_nodes.add(edge[0])
  72. else:
  73. filtered.delete_edge(edge)
  74. # remove outgoing edges from children
  75. for edge in filtered.out_edges_iter(active_state):
  76. kept_nodes.add(edge[1])
  77. for node in filtered.nodes():
  78. if node not in kept_nodes:
  79. filtered.delete_node(node)
  80. return filtered
  81. else:
  82. return self.fsm_graph
  83. def set_node_style(self, state, style):
  84. node = self.fsm_graph.get_node(state)
  85. style_attr = self.fsm_graph.style_attributes.get('node', {}).get(style)
  86. node.attr.update(style_attr)
  87. def set_previous_transition(self, src, dst, key=None):
  88. try:
  89. edge = self.fsm_graph.get_edge(src, dst)
  90. except KeyError:
  91. self.fsm_graph.add_edge(src, dst)
  92. edge = self.fsm_graph.get_edge(src, dst)
  93. style_attr = self.fsm_graph.style_attributes.get('edge', {}).get('previous')
  94. edge.attr.update(style_attr)
  95. self.set_node_style(src, 'previous')
  96. self.set_node_style(dst, 'active')
  97. def reset_styling(self):
  98. for edge in self.fsm_graph.edges_iter():
  99. style_attr = self.fsm_graph.style_attributes.get('edge', {}).get('default')
  100. edge.attr.update(style_attr)
  101. for node in self.fsm_graph.nodes_iter():
  102. if 'point' not in node.attr['shape']:
  103. style_attr = self.fsm_graph.style_attributes.get('node', {}).get('default')
  104. node.attr.update(style_attr)
  105. for sub_graph in self.fsm_graph.subgraphs_iter():
  106. style_attr = self.fsm_graph.style_attributes.get('graph', {}).get('default')
  107. sub_graph.graph_attr.update(style_attr)
  108. class NestedGraph(Graph):
  109. """ Graph creation support for transitions.extensions.nested.HierarchicalGraphMachine. """
  110. def __init__(self, *args, **kwargs):
  111. self.seen_transitions = []
  112. _super(NestedGraph, self).__init__(*args, **kwargs)
  113. # self.style_attributes['edge']['default']['minlen'] = 2
  114. def _add_nodes(self, states, container, prefix='', default_style='default'):
  115. for state in states:
  116. name = prefix + state['name']
  117. label = self._convert_state_attributes(state)
  118. if 'children' in state:
  119. cluster_name = "cluster_" + name
  120. is_parallel = isinstance(state.get('initial', ''), list)
  121. sub = container.add_subgraph(name=cluster_name, label=label, rank='source',
  122. **self.machine.style_attributes['graph'][default_style])
  123. root_container = sub.add_subgraph(name=cluster_name + '_root', label='', color=None, rank='min')
  124. width = '0' if is_parallel else '0.1'
  125. root_container.add_node(name + "_anchor", shape='point', fillcolor='black', width=width)
  126. self._add_nodes(state['children'], sub, prefix=prefix + state['name'] + NestedState.separator,
  127. default_style='parallel' if is_parallel else 'default')
  128. else:
  129. container.add_node(name, label=label, shape=self.machine.style_attributes['node']['default']['shape'])
  130. def _add_edges(self, transitions, container):
  131. for transition in transitions:
  132. # enable customizable labels
  133. label_pos = 'label'
  134. src = transition['source']
  135. try:
  136. dst = transition['dest']
  137. except KeyError:
  138. dst = src
  139. edge_attr = {}
  140. if _get_subgraph(container, 'cluster_' + src) is not None:
  141. edge_attr['ltail'] = 'cluster_' + src
  142. src_name = src + "_anchor"
  143. label_pos = 'headlabel'
  144. else:
  145. src_name = src
  146. dst_graph = _get_subgraph(container, 'cluster_' + dst)
  147. if dst_graph is not None:
  148. if not src.startswith(dst):
  149. edge_attr['lhead'] = "cluster_" + dst
  150. label_pos = 'taillabel' if label_pos.startswith('l') else 'label'
  151. dst_name = dst + '_anchor'
  152. else:
  153. dst_name = dst
  154. # remove ltail when dst is a child of src
  155. if 'ltail' in edge_attr:
  156. if _get_subgraph(container, edge_attr['ltail']).has_node(dst_name):
  157. del edge_attr['ltail']
  158. edge_attr[label_pos] = self._transition_label(transition)
  159. if container.has_edge(src_name, dst_name):
  160. edge = container.get_edge(src_name, dst_name)
  161. edge.attr[label_pos] += ' | ' + edge_attr[label_pos]
  162. else:
  163. container.add_edge(src_name, dst_name, **edge_attr)
  164. def set_node_style(self, state, style):
  165. try:
  166. node = self.fsm_graph.get_node(state)
  167. style_attr = self.fsm_graph.style_attributes.get('node', {}).get(style)
  168. node.attr.update(style_attr)
  169. except KeyError:
  170. subgraph = _get_subgraph(self.fsm_graph, 'cluster_' + state)
  171. style_attr = self.fsm_graph.style_attributes.get('graph', {}).get(style)
  172. subgraph.graph_attr.update(style_attr)
  173. def set_previous_transition(self, src, dst, key=None):
  174. src = self._get_global_name(src.split(self.machine.state_cls.separator))
  175. dst = self._get_global_name(dst.split(self.machine.state_cls.separator))
  176. edge_attr = self.fsm_graph.style_attributes.get('edge', {}).get('previous').copy()
  177. try:
  178. edge = self.fsm_graph.get_edge(src, dst)
  179. except KeyError:
  180. _src = src
  181. _dst = dst
  182. if _get_subgraph(self.fsm_graph, 'cluster_' + src):
  183. edge_attr['ltail'] = 'cluster_' + src
  184. _src += '_anchor'
  185. if _get_subgraph(self.fsm_graph, 'cluster_' + dst):
  186. edge_attr['lhead'] = "cluster_" + dst
  187. _dst += '_anchor'
  188. try:
  189. edge = self.fsm_graph.get_edge(_src, _dst)
  190. except KeyError:
  191. self.fsm_graph.add_edge(_src, _dst)
  192. edge = self.fsm_graph.get_edge(_src, _dst)
  193. edge.attr.update(edge_attr)
  194. self.set_node_style(src, 'previous')
  195. def _get_subgraph(graph, name):
  196. """ Searches for subgraphs in a graph.
  197. Args:
  198. g (AGraph): Container to be searched.
  199. name (str): Name of the cluster.
  200. Returns: AGraph if a cluster called 'name' exists else None
  201. """
  202. sub_graph = graph.get_subgraph(name)
  203. if sub_graph:
  204. return sub_graph
  205. for sub in graph.subgraphs_iter():
  206. sub_graph = _get_subgraph(sub, name)
  207. if sub_graph:
  208. return sub_graph
  209. return None