123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264 |
- import six
- from pyrsistent._checked_types import (InvariantException, CheckedType, _restore_pickle, store_invariants)
- from pyrsistent._field_common import (
- set_fields, check_type, is_field_ignore_extra_complaint, PFIELD_NO_INITIAL, serialize, check_global_invariants
- )
- from pyrsistent._transformations import transform
- def _is_pclass(bases):
- return len(bases) == 1 and bases[0] == CheckedType
class PClassMeta(type):
    """
    Metaclass for PClass.

    Collects the declared fields into ``_pclass_fields`` and the invariants into
    ``_pclass_invariants``, then gives the new class ``__slots__`` made up of a
    ``_pclass_frozen`` flag followed by one slot per field.
    """
    def __new__(mcs, name, bases, dct):
        set_fields(dct, bases, name='_pclass_fields')
        store_invariants(dct, bases, '_pclass_invariants', '__invariant__')

        slot_names = ['_pclass_frozen']
        slot_names.extend(dct['_pclass_fields'])

        # Only one __weakref__ slot may exist in an inheritance hierarchy, so it
        # is added solely to the top-level class (the one built directly on
        # CheckedType).
        if _is_pclass(bases):
            slot_names.append('__weakref__')

        dct['__slots__'] = tuple(slot_names)
        return super(PClassMeta, mcs).__new__(mcs, name, bases, dct)
- _MISSING_VALUE = object()
def _check_and_set_attr(cls, field, name, value, result, invariant_errors):
    """
    Type-check *value* for *field* and assign it to ``result.<name>`` if the
    field invariant holds.

    A failing invariant does not raise here: its error code is appended to
    *invariant_errors* and the attribute is left unset, so the caller can report
    every failure at once. Any exception raised by check_type() propagates.
    """
    check_type(cls, field, name, value)
    invariant_holds, error_code = field.invariant(value)
    if invariant_holds:
        setattr(result, name, value)
        return
    invariant_errors.append(error_code)
@six.add_metaclass(PClassMeta)
class PClass(CheckedType):
    """
    A PClass is a python class with a fixed set of specified fields. PClasses are declared as python classes inheriting
    from PClass. It is defined the same way that PRecords are and behaves like a PRecord in all aspects except that it
    is not a PMap and hence not a collection but rather a plain Python object.

    More documentation and examples of PClass usage is available at https://github.com/tobgu/pyrsistent
    """
    def __new__(cls, **kwargs):  # Support *args?
        """
        Create a new, frozen instance from keyword arguments.

        Two reserved keywords are removed from kwargs before field processing:

        - ``_factory_fields``: when not None, only the fields named in it are
          passed through their field factory; the other supplied values are
          used as given.
        - ``ignore_extra``: forwarded to a field's factory when
          is_field_ignore_extra_complaint() says the field accepts it.

        A field not supplied falls back to its ``initial`` value (called first
        if it is callable). Raises InvariantException when a field invariant
        fails or a mandatory field is missing, and AttributeError when keywords
        that are not declared fields remain. check_global_invariants() is then
        applied, and finally the instance is frozen.
        """
        result = super(PClass, cls).__new__(cls)
        factory_fields = kwargs.pop('_factory_fields', None)
        ignore_extra = kwargs.pop('ignore_extra', None)
        missing_fields = []
        invariant_errors = []
        for name, field in cls._pclass_fields.items():
            if name in kwargs:
                if factory_fields is None or name in factory_fields:
                    if is_field_ignore_extra_complaint(PClass, field, ignore_extra):
                        value = field.factory(kwargs[name], ignore_extra=ignore_extra)
                    else:
                        value = field.factory(kwargs[name])
                else:
                    # Not listed in factory_fields (e.g. carried over from an
                    # existing instance by set() or an evolver): use as-is.
                    value = kwargs[name]
                _check_and_set_attr(cls, field, name, value, result, invariant_errors)
                # Consume the keyword so leftovers can be reported as unknown.
                del kwargs[name]
            elif field.initial is not PFIELD_NO_INITIAL:
                initial = field.initial() if callable(field.initial) else field.initial
                _check_and_set_attr(
                    cls, field, name, initial, result, invariant_errors)
            elif field.mandatory:
                missing_fields.append('{0}.{1}'.format(cls.__name__, name))

        # Report every invariant failure and missing field in one exception.
        if invariant_errors or missing_fields:
            raise InvariantException(tuple(invariant_errors), tuple(missing_fields), 'Field invariant failed')

        # Anything still left did not match a declared field.
        if kwargs:
            raise AttributeError("'{0}' are not among the specified fields for {1}".format(
                ', '.join(kwargs), cls.__name__))

        check_global_invariants(result, cls._pclass_invariants)

        # From now on __setattr__ rejects every assignment.
        result._pclass_frozen = True
        return result

    def set(self, *args, **kwargs):
        """
        Set a field in the instance. Returns a new instance with the updated value. The original instance remains
        unmodified. Accepts key-value pairs or single string representing the field name and a value.

        >>> from pyrsistent import PClass, field
        >>> class AClass(PClass):
        ...     x = field()
        ...
        >>> a = AClass(x=1)
        >>> a2 = a.set(x=2)
        >>> a3 = a.set('x', 3)
        >>> a
        AClass(x=1)
        >>> a2
        AClass(x=2)
        >>> a3
        AClass(x=3)
        """
        if args:
            kwargs[args[0]] = args[1]

        # Only the explicitly set fields go through their factories; values
        # copied from this instance below are passed through unchanged.
        factory_fields = set(kwargs)

        for key in self._pclass_fields:
            if key not in kwargs:
                value = getattr(self, key, _MISSING_VALUE)
                if value is not _MISSING_VALUE:
                    kwargs[key] = value

        return self.__class__(_factory_fields=factory_fields, **kwargs)

    @classmethod
    def create(cls, kwargs, _factory_fields=None, ignore_extra=False):
        """
        Factory method. Will create a new PClass of the current type and assign the values
        specified in kwargs.

        :param kwargs: A mapping of field names to values. If it already is an instance of this class it is
                       returned unchanged.
        :param _factory_fields: Internal. When not None, only the named fields are passed through their factories.
        :param ignore_extra: A boolean which when set to True will ignore any keys which appear in kwargs that are not
                             in the set of fields on the PClass.
        """
        if isinstance(kwargs, cls):
            return kwargs

        if ignore_extra:
            kwargs = {k: kwargs[k] for k in cls._pclass_fields if k in kwargs}

        return cls(_factory_fields=_factory_fields, ignore_extra=ignore_extra, **kwargs)

    def serialize(self, format=None):
        """
        Serialize the current PClass using custom serializer functions for fields where
        such have been supplied.

        Fields that are not set on the instance are left out of the result.
        """
        result = {}
        for name, value in self._to_dict().items():
            result[name] = serialize(self._pclass_fields[name].serializer, format, value)
        return result

    def transform(self, *transformations):
        """
        Apply transformations to the current PClass. For more details on transformations see
        the documentation for PMap. Transformations on PClasses do not support key matching
        since the PClass is not a collection. Apart from that the transformations available
        for other persistent types work as expected.
        """
        return transform(self, transformations)

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            # Unset fields compare equal only to unset fields (both yield the sentinel).
            for name in self._pclass_fields:
                if getattr(self, name, _MISSING_VALUE) != getattr(other, name, _MISSING_VALUE):
                    return False
            return True

        return NotImplemented

    def __ne__(self, other):
        # Explicit definition needed on Python 2, which does not derive __ne__ from __eq__.
        return not self == other

    def __hash__(self):
        # May want to optimize this by caching the hash somehow
        return hash(tuple((key, getattr(self, key, _MISSING_VALUE)) for key in self._pclass_fields))

    def __setattr__(self, key, value):
        # Assignments are allowed only while the instance is being built in __new__.
        if getattr(self, '_pclass_frozen', False):
            raise AttributeError("Can't set attribute, key={0}, value={1}".format(key, value))

        super(PClass, self).__setattr__(key, value)

    def __delattr__(self, key):
        raise AttributeError("Can't delete attribute, key={0}, use remove()".format(key))

    def _to_dict(self):
        """Return a plain dict of the fields that are currently set on this instance."""
        result = {}
        for key in self._pclass_fields:
            value = getattr(self, key, _MISSING_VALUE)
            if value is not _MISSING_VALUE:
                result[key] = value

        return result

    def __repr__(self):
        return "{0}({1})".format(self.__class__.__name__,
                                 ', '.join('{0}={1}'.format(k, repr(v)) for k, v in self._to_dict().items()))

    def __reduce__(self):
        # Pickling support: rebuild from the class and the fields that are set.
        data = self._to_dict()
        return _restore_pickle, (self.__class__, data,)

    def evolver(self):
        """
        Returns an evolver for this object.
        """
        return _PClassEvolver(self, self._to_dict())

    def remove(self, name):
        """
        Remove attribute given by name from the current instance. Raises AttributeError if the
        attribute doesn't exist.
        """
        evolver = self.evolver()
        del evolver[name]
        return evolver.persistent()
- class _PClassEvolver(object):
- __slots__ = ('_pclass_evolver_original', '_pclass_evolver_data', '_pclass_evolver_data_is_dirty', '_factory_fields')
- def __init__(self, original, initial_dict):
- self._pclass_evolver_original = original
- self._pclass_evolver_data = initial_dict
- self._pclass_evolver_data_is_dirty = False
- self._factory_fields = set()
- def __getitem__(self, item):
- return self._pclass_evolver_data[item]
- def set(self, key, value):
- if self._pclass_evolver_data.get(key, _MISSING_VALUE) is not value:
- self._pclass_evolver_data[key] = value
- self._factory_fields.add(key)
- self._pclass_evolver_data_is_dirty = True
- return self
- def __setitem__(self, key, value):
- self.set(key, value)
- def remove(self, item):
- if item in self._pclass_evolver_data:
- del self._pclass_evolver_data[item]
- self._factory_fields.discard(item)
- self._pclass_evolver_data_is_dirty = True
- return self
- raise AttributeError(item)
- def __delitem__(self, item):
- self.remove(item)
- def persistent(self):
- if self._pclass_evolver_data_is_dirty:
- return self._pclass_evolver_original.__class__(_factory_fields=self._factory_fields,
- **self._pclass_evolver_data)
- return self._pclass_evolver_original
- def __setattr__(self, key, value):
- if key not in self.__slots__:
- self.set(key, value)
- else:
- super(_PClassEvolver, self).__setattr__(key, value)
- def __getattr__(self, item):
- return self[item]
|