| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318 |
- from transitions import Transition
- from transitions.extensions.markup import MarkupMachine
- from transitions.core import listify
- import warnings
- import logging
- from functools import partial
- import copy
# Module logger. The NullHandler is attached so that this module does not
# require the application to have configured logging.
_LOGGER = logging.getLogger(__name__)
_LOGGER.addHandler(logging.NullHandler())

# Make deprecation warnings of transitions visible for module users.
warnings.filterwarnings('default', message=r".*transitions version.*")

# Workaround for dill issues when partials and super are used in conjunction;
# without it, Python 3.0 - 3.3 will not support pickling.
# https://github.com/pytransitions/transitions/issues/236
_super = super
class TransitionGraphSupport(Transition):
    """ Transition used in conjunction with (Nested)Graphs to update graphs whenever a transition is
        conducted.
    """

    def _change_state(self, event_data):
        """ Restyles the model's graph around the regular state change.

        The graph registered for ``event_data.model`` is reset and told about the transition
        (source, dest, event name) before the parent implementation changes the state.
        Afterwards every state found in the model's state attribute is styled as 'active'.

        Args:
            event_data: Object providing ``machine``, ``model`` and ``event`` of the current trigger.
        """
        machine = event_data.machine
        model = event_data.model

        graph = machine.model_graphs[model]
        graph.reset_styling()
        graph.set_previous_transition(self.source, self.dest, event_data.event.name)

        _super(TransitionGraphSupport, self)._change_state(event_data)  # pylint: disable=protected-access

        # the model attribute may hold a single state or a (nested) collection of states
        current = getattr(model, machine.model_attribute)
        for state in _flatten(listify(current)):
            if hasattr(state, 'name'):
                # objects carrying a 'name' are styled via this transition's destination
                node = self.dest
            else:
                node = state
            graph.set_node_style(node, 'active')
class GraphMachine(MarkupMachine):
    """ Extends transitions.core.Machine with graph support.
        Is also used as a mixin for HierarchicalMachine.
        Attributes:
            _pickle_blacklist (list): Objects that should not/do not need to be pickled.
            transition_cls (cls): TransitionGraphSupport
            machine_attributes (dict): Default graph attributes ('directed', 'strict', 'rankdir').
            hierarchical_machine_attributes (dict): Attributes merged over ``machine_attributes``
                (on the instance only) when a nested graph class is selected.
            style_attributes (dict): Style definitions keyed by element kind ('node', 'edge', 'graph')
                and then by style name ('default', 'active', 'previous', ...).
    """

    _pickle_blacklist = ['model_graphs']
    transition_cls = TransitionGraphSupport

    machine_attributes = {
        'directed': 'true',
        'strict': 'false',
        'rankdir': 'LR',
    }

    hierarchical_machine_attributes = {
        'rankdir': 'TB',
        'rank': 'source',
        'nodesep': '1.5',
        'compound': 'true'
    }

    style_attributes = {
        'node': {
            '': {},
            'default': {
                'shape': 'rectangle',
                'style': 'rounded, filled',
                'fillcolor': 'white',
                'color': 'black',
                'peripheries': '1'
            },
            'active': {
                'color': 'red',
                'fillcolor': 'darksalmon',
                'peripheries': '2'
            },
            'previous': {
                'color': 'blue',
                'fillcolor': 'azure2',
                'peripheries': '1'
            }
        },
        'edge': {
            '': {},
            'default': {
                'color': 'black'
            },
            'previous': {
                'color': 'blue'
            }
        },
        'graph': {
            '': {},
            'default': {
                'color': 'black',
                'fillcolor': 'white',
                'style': 'solid'
            },
            'parallel': {
                'color': 'black',
                'fillcolor': 'white',
                'style': 'dotted'
            },
            'previous': {
                'color': 'blue',
                'fillcolor': 'azure2',
                'style': 'filled'
            },
            'active': {
                'color': 'red',
                'fillcolor': 'darksalmon',
                'style': 'filled'
            },
        }
    }

    # model_graphs cannot be pickled. Omit them.
    def __getstate__(self):
        """ Returns the instance dict without the entries listed in ``_pickle_blacklist``. """
        return {k: v for k, v in self.__dict__.items() if k not in self._pickle_blacklist}

    def __setstate__(self, state):
        """ Restores the instance dict and rebuilds a graph for every model.

        A model whose graph cannot be created (AttributeError) is skipped with a warning
        instead of failing the whole unpickling.
        """
        self.__dict__.update(state)
        self.model_graphs = {}  # reinitialize new model_graphs
        for model in self.models:
            try:
                _ = self._get_graph(model, title=self.title)
            except AttributeError as err:
                _LOGGER.warning("Graph for model could not be initialized after pickling: %s", err)

    def __init__(self, *args, **kwargs):
        """ Consumes the graph related keywords and forwards everything else to the parent class.

        Keyword Args:
            title (str): Title of the graph; defaults to 'State Machine'.
            show_conditions (bool): Stored as ``self.show_conditions``; defaults to False.
            show_state_attributes (bool): Stored as ``self.show_state_attributes``; defaults to False.
            show_auto_transitions (bool): Legacy alias which is or-ed into 'auto_transitions_markup'.
            auto_transitions_markup (bool): Forwarded to the parent class; defaults to False.
            use_pygraphviz (bool): Try the pygraphviz backend first; defaults to True.
        """
        # remove graph config from keywords
        self.title = kwargs.pop('title', 'State Machine')
        self.show_conditions = kwargs.pop('show_conditions', False)
        self.show_state_attributes = kwargs.pop('show_state_attributes', False)

        # in MarkupMachine this switch is called 'with_auto_transitions'
        # keep 'auto_transitions_markup' for backwards compatibility
        # The alias has to be popped unconditionally: popping it on the right hand side of 'or'
        # is skipped when 'auto_transitions_markup' is already truthy and the stray keyword
        # would then be handed to the parent constructor.
        show_auto_transitions = kwargs.pop('show_auto_transitions', False)
        kwargs['auto_transitions_markup'] = kwargs.get('auto_transitions_markup', False) or show_auto_transitions
        self.model_graphs = {}

        # determine graph engine; if pygraphviz cannot be imported, fall back to graphviz
        use_pygraphviz = kwargs.pop('use_pygraphviz', True)
        if use_pygraphviz:
            try:
                import pygraphviz  # noqa: F401  # pylint: disable=unused-import,import-outside-toplevel
            except ImportError:
                use_pygraphviz = False
        self.graph_cls = self._init_graphviz_engine(use_pygraphviz)

        _LOGGER.debug("Using graph engine %s", self.graph_cls)
        _super(GraphMachine, self).__init__(*args, **kwargs)

        # for backwards compatibility assign get_combined_graph to get_graph
        # if model is not the machine
        if not hasattr(self, 'get_graph'):
            setattr(self, 'get_graph', self.get_combined_graph)

    def _init_graphviz_engine(self, use_pygraphviz):
        """ Selects the graph class of the requested backend.

        Args:
            use_pygraphviz (bool): If True, the pygraphviz backend is tried first and the graphviz
                backend is used when its import fails.
        Returns: The ``Graph`` class, or ``NestedGraph`` when the state class has a separator and
            the machine is a context manager.
        """
        # state class needs to have a separator and machine needs to be a context manager
        nested = hasattr(self.state_cls, 'separator') and hasattr(self, '__enter__')
        graph_cls = None
        if use_pygraphviz:
            try:
                if nested:
                    from .diagrams_pygraphviz import NestedGraph as graph_cls
                else:
                    from .diagrams_pygraphviz import Graph as graph_cls
            except ImportError:
                graph_cls = None
        if graph_cls is None:
            if nested:
                from .diagrams_graphviz import NestedGraph as graph_cls
            else:
                from .diagrams_graphviz import Graph as graph_cls
        if nested:
            # Merge into a per-instance copy. Updating the class level dict in place would leak
            # the hierarchical settings into every other GraphMachine sharing that dict.
            self.machine_attributes = dict(self.machine_attributes, **self.hierarchical_machine_attributes)
        return graph_cls

    def _get_graph(self, model, title=None, force_new=False, show_roi=False):
        """ Returns the graph of ``model`` and creates it first if necessary.

        Args:
            model (object): The model whose graph is requested.
            title (str): Title of the graph; ``self.title`` is used for a new graph when None.
            force_new (bool): If set to True, (re-)generate the model's graph.
            show_roi (bool): If set to True, the model's current state is stored as the graph's
                ``roi_state``; otherwise ``roi_state`` is set to None.
        Returns: The result of the graph object's ``get_graph(title=title)``.
        """
        if force_new:
            grph = self.graph_cls(self, title=title if title is not None else self.title)
            self.model_graphs[model] = grph
            try:
                state = getattr(model, self.model_attribute)
                self.model_graphs[model].set_node_style(state.name if hasattr(state, 'name') else state, 'active')
            except AttributeError:
                _LOGGER.info("Could not set active state of diagram")
        try:
            m = self.model_graphs[model]
        except KeyError:
            # no graph registered for this model yet; create one and look it up again
            _ = self._get_graph(model, title, force_new=True)
            m = self.model_graphs[model]
        m.roi_state = getattr(model, self.model_attribute) if show_roi else None
        return m.get_graph(title=title)

    def get_combined_graph(self, title=None, force_new=False, show_roi=False):
        """ This method is currently equivalent to 'get_graph' of the first machine's model.
        In future releases of transitions, this function will return a combined graph with active states
        of all models.
        Args:
            title (str): Title of the resulting graph.
            force_new (bool): If set to True, (re-)generate the model's graph.
            show_roi (bool): If set to True, only render states that are active and/or can be reached from
                the current state.
        Returns: AGraph of the first machine's model.
        """
        _LOGGER.info('Returning graph of the first model. In future releases, this '
                     'method will return a combined graph of all models.')
        return self._get_graph(self.models[0], title, force_new, show_roi)

    def add_model(self, model, initial=None):
        """ Adds the model(s) via the base method and binds a ``get_graph`` callable to each of them.

        Args:
            model (object or list): One model or a list of models; the string 'self' refers to
                the machine itself.
            initial: Passed on to the base method unchanged.
        Raises:
            AttributeError: If a model already has a ``get_graph`` attribute.
        """
        models = listify(model)
        # use the module level alias like every other parent call in this file (see pickling workaround)
        _super(GraphMachine, self).add_model(models, initial)
        for mod in models:
            mod = self if mod == 'self' else mod
            if hasattr(mod, 'get_graph'):
                raise AttributeError('Model already has a get_graph attribute. Graph retrieval cannot be bound.')
            setattr(mod, 'get_graph', partial(self._get_graph, mod))
            _ = mod.get_graph(title=self.title, force_new=True)  # initialises graph

    def add_states(self, states, on_enter=None, on_exit=None,
                   ignore_invalid_triggers=None, **kwargs):
        """ Calls the base method and regenerates all models's graphs. """
        _super(GraphMachine, self).add_states(states, on_enter=on_enter, on_exit=on_exit,
                                              ignore_invalid_triggers=ignore_invalid_triggers, **kwargs)
        for model in self.models:
            model.get_graph(force_new=True)

    def add_transition(self, trigger, source, dest, conditions=None,
                       unless=None, before=None, after=None, prepare=None, **kwargs):
        """ Calls the base method and regenerates all models's graphs. """
        _super(GraphMachine, self).add_transition(trigger, source, dest, conditions=conditions, unless=unless,
                                                  before=before, after=after, prepare=prepare, **kwargs)
        for model in self.models:
            model.get_graph(force_new=True)
class BaseGraph(object):
    """ Shared helpers for graph classes: label creation and collection of states and transitions
        from the machine's markup.

        ``generate`` is called by the constructor but is not defined in this class
        (presumably provided by the backend specific subclasses).
    """

    def __init__(self, machine, title=None):
        """
        Args:
            machine: The machine the graph is created for.
            title (str): Passed on to ``generate``.
        """
        self.machine = machine
        self.fsm_graph = None
        self.roi_state = None
        self.generate(title)

    def _convert_state_attributes(self, state):
        """ Creates the label of a state.

        Args:
            state (dict): State markup; requires 'name', may contain 'label', 'tags', 'on_enter',
                'on_exit', 'timeout' and 'on_timeout'.
        Returns: The state's 'label' (or 'name'), extended by the optional entries when the machine
            has ``show_state_attributes`` set.
        """
        label = state.get('label', state['name'])
        if self.machine.show_state_attributes:
            if 'tags' in state:
                label += ' [' + ', '.join(state['tags']) + ']'
            if 'on_enter' in state:
                label += r'\l- enter:\l + ' + r'\l + '.join(state['on_enter'])
            if 'on_exit' in state:
                label += r'\l- exit:\l + ' + r'\l + '.join(state['on_exit'])
            if 'timeout' in state:
                # str() keeps numeric timeouts from raising a TypeError during concatenation;
                # a timeout without callbacks results in an empty callback list
                label += r'\l- timeout(' + str(state['timeout']) + 's) -> (' \
                    + ', '.join(state.get('on_timeout', [])) + ')'
        return label

    def _transition_label(self, tran):
        """ Creates the label of an edge.

        Args:
            tran (dict): Transition markup; requires 'trigger', may contain 'label', 'dest',
                'conditions' and 'unless'.
        Returns: The 'label' (or 'trigger'), marked with ' [internal]' if there is no 'dest' and
            followed by the conditions when the machine has ``show_conditions`` set.
        """
        edge_label = tran.get('label', tran['trigger'])
        if 'dest' not in tran:
            edge_label += " [internal]"
        if self.machine.show_conditions and any(prop in tran for prop in ['conditions', 'unless']):
            return '{edge_label} [{conditions}]'.format(
                edge_label=edge_label,
                conditions=' & '.join(tran.get('conditions', []) + ['!' + u for u in tran.get('unless', [])]),
            )
        return edge_label

    def _get_global_name(self, path):
        """ Enters the machine's scopes along ``path`` and returns the machine's global name there.

        Args:
            path (list): State names from the outermost to the innermost scope. The list is
                not modified.
        Returns: The result of ``self.machine.get_global_name()`` inside the innermost scope.
        """
        if not path:
            return self.machine.get_global_name()
        # slice instead of pop(0) so the caller's list stays intact
        with self.machine(path[0]):
            return self._get_global_name(path[1:])

    def _get_elements(self):
        """ Collects states and transitions from the machine's markup.

        Nested scopes are visited breadth-first. Transitions of nested scopes are copied and
        their 'source' (and 'dest', if present) is prefixed with the names of the enclosing states.
        For every state with a single (non-list) 'initial' entry, an additional transition with
        an empty trigger from '<state>_anchor' to the initial child is added.

        Returns: Tuple ``(states, transitions)``; ``states`` only contains the top level state
            dicts. If the markup lacks a required key, an error is logged and the elements
            collected so far are returned.
        """
        states = []
        transitions = []
        try:
            separator = self.machine.state_cls.separator
            markup = self.machine.get_markup_config()
            queue = [([], markup)]

            while queue:
                prefix, scope = queue.pop(0)
                for transition in scope.get('transitions', []):
                    if prefix:
                        tran = copy.copy(transition)
                        tran['source'] = separator.join(prefix + [tran['source']])
                        # internal transitions have no 'dest'; they must not abort the collection
                        if 'dest' in tran:
                            tran['dest'] = separator.join(prefix + [tran['dest']])
                    else:
                        tran = transition
                    transitions.append(tran)
                for state in scope.get('children', []) + scope.get('states', []):
                    if not prefix:
                        states.append(state)

                    ini = state.get('initial', [])
                    if not isinstance(ini, list):
                        ini = getattr(ini, 'name', ini)
                        transitions.append(dict(trigger='',
                                                source=separator.join(prefix + [state['name']]) + '_anchor',
                                                dest=separator.join(prefix + [state['name'], ini])))
                    if state.get('children', []):
                        queue.append((prefix + [state['name']], state))
        except KeyError:
            _LOGGER.error("Graph creation incomplete!")
        return states, transitions
- def _flatten(item):
- for elem in item:
- if isinstance(elem, (list, tuple, set)):
- for res in _flatten(elem):
- yield res
- else:
- yield elem
|