123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219 |
- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- from __future__ import unicode_literals
- __all__ = ['singledispatch']
- from functools import update_wrapper
- from weakref import WeakKeyDictionary
- from singledispatch_helpers import MappingProxyType, get_cache_token
- ################################################################################
- ### singledispatch() - single-dispatch generic function decorator
- ################################################################################
def _c3_merge(sequences):
    """Merges MROs in *sequences* to a single MRO using the C3 algorithm.

    Adapted from http://www.python.org/download/releases/2.3/mro/.

    *sequences* is an iterable of sequences of classes (MROs and/or lists of
    direct bases).  Each input sequence is copied before merging, so the
    caller's lists are left untouched and tuples are accepted as well.

    Returns the merged linearization as a list.
    Raises RuntimeError("Inconsistent hierarchy") if no head of any sequence
    can be chosen, i.e. no consistent linearization exists.
    """
    # The algorithm consumes the heads of the sequences; work on private
    # copies so callers never see their lists emptied out.
    sequences = [list(s) for s in sequences]
    result = []
    while True:
        sequences = [s for s in sequences if s]  # purge empty sequences
        if not sequences:
            return result
        for s1 in sequences:  # find merge candidates among seq heads
            candidate = s1[0]
            for s2 in sequences:
                if candidate in s2[1:]:
                    candidate = None
                    break  # reject the current head, it appears later
            else:
                break
        # Compare against None explicitly: a class whose metaclass defines
        # __len__ or __bool__ can be falsy and must still be accepted.
        if candidate is None:
            raise RuntimeError("Inconsistent hierarchy")
        result.append(candidate)
        # remove the chosen candidate
        for seq in sequences:
            if seq[0] == candidate:
                del seq[0]
def _c3_mro(cls, abcs=None):
    """Return the C3 linearization of *cls*, optionally weaving in ABCs.

    Without *abcs* the result matches the ordinary C3 linearization used for
    method resolution.

    *abcs*, when given, is a collection of abstract base classes to splice
    into the result; ABCs unrelated to *cls* are left out.  An ABC is placed
    at the class that introduces its behaviour, i.e. where
    ``issubclass(klass, abc)`` is True for the class itself but False for
    every one of its direct bases.  Such implicit ABCs (registered, or
    inferred from a special method like ``__len__``) are merged in directly
    after the last direct base that is itself an ABC (an object with an
    ``__abstractmethods__`` attribute).  When two implicit ABCs end up next
    to each other, their relative order follows *abcs*.
    """
    direct_bases = cls.__bases__

    # Position just past the right-most direct base that is an ABC; bases up
    # to that point are considered before the implicit ABCs.
    boundary = 0
    for position in range(len(direct_bases), 0, -1):
        if hasattr(direct_bases[position - 1], '__abstractmethods__'):
            boundary = position
            break

    remaining = list(abcs) if abcs else []
    explicit_bases = list(direct_bases[:boundary])
    other_bases = list(direct_bases[boundary:])

    # ABCs whose behaviour *cls* itself introduces: cls is a subclass of the
    # ABC, yet none of its direct bases is.
    abstract_bases = [
        candidate for candidate in remaining
        if issubclass(cls, candidate)
        and not any(issubclass(b, candidate) for b in direct_bases)
    ]
    # Those ABCs are inserted at this level, so they must not be inserted
    # again further up the hierarchy.
    for candidate in abstract_bases:
        remaining.remove(candidate)

    def linearize_all(bases):
        return [_c3_mro(b, abcs=remaining) for b in bases]

    sub_mros = (linearize_all(explicit_bases)
                + linearize_all(abstract_bases)
                + linearize_all(other_bases))
    return _c3_merge(
        [[cls]] + sub_mros + [explicit_bases, abstract_bases, other_bases]
    )
def _compose_mro(cls, types):
    """Return the MRO of *cls* extended with relevant ABCs from *types*.

    Only entries of *types* that *cls* is a subclass of but that are absent
    from ``cls.__mro__`` (for instance ABCs *cls* was registered with) are
    considered.  Of those, any that already appear in another candidate's
    MRO are dropped, since they will end up in the result anyway.  The
    survivors are ordered, using their subclasses that *cls* also implements
    as hints, and woven in with a modified C3 linearization (_c3_mro).
    """
    own_mro = set(cls.__mro__)

    def is_relevant(typ):
        # Skip entries already in cls's MRO, non-classes (no __mro__), and
        # anything cls is not a subclass of.
        if typ in own_mro or not hasattr(typ, '__mro__'):
            return False
        return issubclass(cls, typ)

    candidates = [typ for typ in types if is_relevant(typ)]

    def is_strict_base(typ):
        # True when another candidate already carries *typ* in its MRO.
        return any(typ != other and typ in other.__mro__
                   for other in candidates)

    leaves = [typ for typ in candidates if not is_strict_base(typ)]
    leaf_set = set(leaves)

    # Subclasses of a leaf ABC that cls also implements are used to
    # stabilize the ordering of the leaves.
    ordered = []
    for leaf in leaves:
        hints = [
            [s for s in sub.__mro__ if s in leaf_set]
            for sub in leaf.__subclasses__()
            if sub not in own_mro and issubclass(cls, sub)
        ]
        if not hints:
            ordered.append(leaf)
            continue
        # Favor the subclasses that account for the most useful bases.
        hints.sort(key=len, reverse=True)
        for hint in hints:
            for member in hint:
                if member not in ordered:
                    ordered.append(member)
    return _c3_mro(cls, abcs=ordered)
def _find_impl(cls, registry):
    """Return the best matching implementation from *registry* for *cls*.

    *registry* maps types to implementations.  The MRO of *cls* is extended
    with the relevant registry keys (see _compose_mro), and the first type
    along it that has an entry in *registry* wins.

    May return None when nothing along that MRO is registered, which can
    only happen if *registry* has no entry for ``object``.

    Raises RuntimeError when the winner is not in ``cls.__mro__`` (an
    implicit ABC) and the very next entry is also registered, also absent
    from ``cls.__mro__``, and not a superclass of the winner.  In that case
    two unrelated implicit ABCs match equally well.
    """
    mro = _compose_mro(cls, registry.keys())
    walker = iter(mro)

    match = None
    for entry in walker:
        if entry in registry:
            match = entry
            break

    if match is not None:
        # Inspect only the entry right after the winner (if any): refuse the
        # temptation to guess between equally matching implicit ABCs.
        for rival in walker:
            if (rival in registry and rival not in cls.__mro__
                    and match not in cls.__mro__
                    and not issubclass(match, rival)):
                raise RuntimeError("Ambiguous dispatch: {0} or {1}".format(
                    match, rival))
            break

    return registry.get(match)
def singledispatch(func):
    """Single-dispatch generic function decorator.

    Transforms a function into a generic function, which can have different
    behaviours depending upon the type of its first argument. The decorated
    function acts as the default implementation, and additional
    implementations can be registered using the register() attribute of the
    generic function.

    The returned wrapper exposes:

    * ``register(cls, func=None)``: register *func* for *cls*.  When *func*
      is omitted, it returns a decorator.
    * ``dispatch(cls)``: the implementation that would be used for *cls*.
    * ``registry``: a read-only view of the type -> implementation map.
    * ``_clear_cache()``: drop all cached dispatch decisions.

    Calling the wrapper without any positional argument raises TypeError,
    since there is nothing to dispatch on.
    """
    registry = {}
    # Dispatch decisions keyed by class, held weakly so that caching does
    # not keep classes alive.
    dispatch_cache = WeakKeyDictionary()

    # Plain function used as a mutable attribute namespace, so the nested
    # functions can rebind cache_token without ``nonlocal`` (Python 2).
    def ns(): pass
    ns.cache_token = None

    funcname = getattr(func, '__name__', 'singledispatch function')

    def dispatch(cls):
        """generic_func.dispatch(cls) -> <function implementation>

        Runs the dispatch algorithm to return the best available implementation
        for the given *cls* registered on *generic_func*.
        """
        if ns.cache_token is not None:
            # Token tracking only starts once an ABC has been registered.
            # A changed token presumably means ABC registrations changed,
            # so cached decisions may be stale.
            current_token = get_cache_token()
            if ns.cache_token != current_token:
                dispatch_cache.clear()
                ns.cache_token = current_token
        try:
            impl = dispatch_cache[cls]
        except KeyError:
            try:
                impl = registry[cls]
            except KeyError:
                impl = _find_impl(cls, registry)
            dispatch_cache[cls] = impl
        return impl

    def register(cls, func=None):
        """generic_func.register(cls, func) -> func

        Registers a new implementation for the given *cls* on a *generic_func*.
        """
        if func is None:
            # Used as ``@generic_func.register(cls)``.
            return lambda f: register(cls, f)
        registry[cls] = func
        if ns.cache_token is None and hasattr(cls, '__abstractmethods__'):
            ns.cache_token = get_cache_token()
        # Any new registration can change earlier dispatch decisions.
        dispatch_cache.clear()
        return func

    def wrapper(*args, **kw):
        if not args:
            raise TypeError('{0} requires at least '
                            '1 positional argument'.format(funcname))
        return dispatch(args[0].__class__)(*args, **kw)

    registry[object] = func
    wrapper.register = register
    wrapper.dispatch = dispatch
    wrapper.registry = MappingProxyType(registry)
    wrapper._clear_cache = dispatch_cache.clear
    update_wrapper(wrapper, func)
    return wrapper
|