locking.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176
  1. """
  2. transitions.extensions.factory
  3. ------------------------------
  4. Adds locking to machine methods as well as model functions that trigger events.
  5. Additionally, the user can inject her/his own context manager into the machine if required.
  6. """
  7. from collections import defaultdict
  8. from functools import partial
  9. from threading import Lock
  10. import inspect
  11. import warnings
  12. import logging
  13. from transitions.core import Machine, Event, listify
  14. _LOGGER = logging.getLogger(__name__)
  15. _LOGGER.addHandler(logging.NullHandler())
  16. # this is a workaround for dill issues when partials and super is used in conjunction
  17. # without it, Python 3.0 - 3.3 will not support pickling
  18. # https://github.com/pytransitions/transitions/issues/236
  19. _super = super
  20. try:
  21. from contextlib import nested # Python 2
  22. from thread import get_ident
  23. # with nested statements now raise a DeprecationWarning. Should be replaced with ExitStack-like approaches.
  24. warnings.simplefilter('ignore', DeprecationWarning)
  25. except ImportError:
  26. from contextlib import ExitStack, contextmanager
  27. from threading import get_ident
  28. @contextmanager
  29. def nested(*contexts):
  30. """ Reimplementation of nested in Python 3. """
  31. with ExitStack() as stack:
  32. for ctx in contexts:
  33. stack.enter_context(ctx)
  34. yield contexts
  35. class PicklableLock(object):
  36. """ A wrapper for threading.Lock which discards its state during pickling and
  37. is reinitialized unlocked when unpickled.
  38. """
  39. def __init__(self):
  40. self.lock = Lock()
  41. def __getstate__(self):
  42. return ''
  43. def __setstate__(self, value):
  44. return self.__init__()
  45. def __enter__(self):
  46. self.lock.__enter__()
  47. def __exit__(self, exc_type, exc_val, exc_tb):
  48. self.lock.__exit__(exc_type, exc_val, exc_tb)
  49. class IdentManager:
  50. def __init__(self):
  51. self.current = 0
  52. def __enter__(self):
  53. self.current = get_ident()
  54. pass
  55. def __exit__(self, exc_type, exc_val, exc_tb):
  56. self.current = 0
  57. class LockedEvent(Event):
  58. """ An event type which uses the parent's machine context map when triggered. """
  59. def trigger(self, model, *args, **kwargs):
  60. """ Extends transitions.core.Event.trigger by using locks/machine contexts. """
  61. # pylint: disable=protected-access
  62. # noinspection PyProtectedMember
  63. # LockedMachine._locked should not be called somewhere else. That's why it should not be exposed
  64. # to Machine users.
  65. if self.machine._ident.current != get_ident():
  66. with nested(*self.machine.model_context_map[model]):
  67. return _super(LockedEvent, self).trigger(model, *args, **kwargs)
  68. else:
  69. return _super(LockedEvent, self).trigger(model, *args, **kwargs)
  70. class LockedMachine(Machine):
  71. """ Machine class which manages contexts. In it's default version the machine uses a `threading.Lock`
  72. context to lock access to its methods and event triggers bound to model objects.
  73. Attributes:
  74. machine_context (dict): A dict of context managers to be entered whenever a machine method is
  75. called or an event is triggered. Contexts are managed for each model individually.
  76. """
  77. event_cls = LockedEvent
  78. def __init__(self, *args, **kwargs):
  79. self._ident = IdentManager()
  80. try:
  81. self.machine_context = listify(kwargs.pop('machine_context'))
  82. except KeyError:
  83. self.machine_context = [PicklableLock()]
  84. self.machine_context.append(self._ident)
  85. self.model_context_map = defaultdict(list)
  86. _super(LockedMachine, self).__init__(*args, **kwargs)
  87. def add_model(self, model, initial=None, model_context=None):
  88. """ Extends `transitions.core.Machine.add_model` by `model_context` keyword.
  89. Args:
  90. model (list or object): A model (list) to be managed by the machine.
  91. initial (str, Enum or State): The initial state of the passed model[s].
  92. model_context (list or object): If passed, assign the context (list) to the machines
  93. model specific context map.
  94. """
  95. models = listify(model)
  96. model_context = listify(model_context) if model_context is not None else []
  97. output = _super(LockedMachine, self).add_model(models, initial)
  98. for mod in models:
  99. mod = self if mod == 'self' else mod
  100. self.model_context_map[mod].extend(self.machine_context)
  101. self.model_context_map[mod].extend(model_context)
  102. return output
  103. def remove_model(self, model):
  104. """ Extends `transitions.core.Machine.remove_model` by removing model specific context maps
  105. from the machine when the model itself is removed. """
  106. models = listify(model)
  107. for mod in models:
  108. del self.model_context_map[mod]
  109. return _super(LockedMachine, self).remove_model(models)
  110. def __getattribute__(self, item):
  111. get_attr = _super(LockedMachine, self).__getattribute__
  112. tmp = get_attr(item)
  113. if not item.startswith('_') and inspect.ismethod(tmp):
  114. return partial(get_attr('_locked_method'), tmp)
  115. return tmp
  116. def __getattr__(self, item):
  117. try:
  118. return _super(LockedMachine, self).__getattribute__(item)
  119. except AttributeError:
  120. return _super(LockedMachine, self).__getattr__(item)
  121. # Determine if the returned method is a partial and make sure the returned partial has
  122. # not been created by Machine.__getattr__.
  123. # https://github.com/tyarkoni/transitions/issues/214
  124. def _add_model_to_state(self, state, model):
  125. _super(LockedMachine, self)._add_model_to_state(state, model) # pylint: disable=protected-access
  126. for prefix in ['enter', 'exit']:
  127. callback = "on_{0}_".format(prefix) + state.name
  128. func = getattr(model, callback, None)
  129. if isinstance(func, partial) and func.func != state.add_callback:
  130. state.add_callback(prefix, callback)
  131. def _locked_method(self, func, *args, **kwargs):
  132. if self._ident.current != get_ident():
  133. with nested(*self.machine_context):
  134. return func(*args, **kwargs)
  135. else:
  136. return func(*args, **kwargs)