_pclass.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264
  1. import six
  2. from pyrsistent._checked_types import (InvariantException, CheckedType, _restore_pickle, store_invariants)
  3. from pyrsistent._field_common import (
  4. set_fields, check_type, is_field_ignore_extra_complaint, PFIELD_NO_INITIAL, serialize, check_global_invariants
  5. )
  6. from pyrsistent._transformations import transform
  7. def _is_pclass(bases):
  8. return len(bases) == 1 and bases[0] == CheckedType
  9. class PClassMeta(type):
  10. def __new__(mcs, name, bases, dct):
  11. set_fields(dct, bases, name='_pclass_fields')
  12. store_invariants(dct, bases, '_pclass_invariants', '__invariant__')
  13. dct['__slots__'] = ('_pclass_frozen',) + tuple(key for key in dct['_pclass_fields'])
  14. # There must only be one __weakref__ entry in the inheritance hierarchy,
  15. # lets put it on the top level class.
  16. if _is_pclass(bases):
  17. dct['__slots__'] += ('__weakref__',)
  18. return super(PClassMeta, mcs).__new__(mcs, name, bases, dct)
# Private sentinel meaning "this field has no value on the instance".
# A fresh object() is used instead of None so that None stays a legitimate
# field value; it serves as the default for getattr()/dict.get() lookups in
# this module and is tested for with `is` / `is not`.
_MISSING_VALUE = object()
  20. def _check_and_set_attr(cls, field, name, value, result, invariant_errors):
  21. check_type(cls, field, name, value)
  22. is_ok, error_code = field.invariant(value)
  23. if not is_ok:
  24. invariant_errors.append(error_code)
  25. else:
  26. setattr(result, name, value)
  27. @six.add_metaclass(PClassMeta)
  28. class PClass(CheckedType):
  29. """
  30. A PClass is a python class with a fixed set of specified fields. PClasses are declared as python classes inheriting
  31. from PClass. It is defined the same way that PRecords are and behaves like a PRecord in all aspects except that it
  32. is not a PMap and hence not a collection but rather a plain Python object.
  33. More documentation and examples of PClass usage is available at https://github.com/tobgu/pyrsistent
  34. """
  35. def __new__(cls, **kwargs): # Support *args?
  36. result = super(PClass, cls).__new__(cls)
  37. factory_fields = kwargs.pop('_factory_fields', None)
  38. ignore_extra = kwargs.pop('ignore_extra', None)
  39. missing_fields = []
  40. invariant_errors = []
  41. for name, field in cls._pclass_fields.items():
  42. if name in kwargs:
  43. if factory_fields is None or name in factory_fields:
  44. if is_field_ignore_extra_complaint(PClass, field, ignore_extra):
  45. value = field.factory(kwargs[name], ignore_extra=ignore_extra)
  46. else:
  47. value = field.factory(kwargs[name])
  48. else:
  49. value = kwargs[name]
  50. _check_and_set_attr(cls, field, name, value, result, invariant_errors)
  51. del kwargs[name]
  52. elif field.initial is not PFIELD_NO_INITIAL:
  53. initial = field.initial() if callable(field.initial) else field.initial
  54. _check_and_set_attr(
  55. cls, field, name, initial, result, invariant_errors)
  56. elif field.mandatory:
  57. missing_fields.append('{0}.{1}'.format(cls.__name__, name))
  58. if invariant_errors or missing_fields:
  59. raise InvariantException(tuple(invariant_errors), tuple(missing_fields), 'Field invariant failed')
  60. if kwargs:
  61. raise AttributeError("'{0}' are not among the specified fields for {1}".format(
  62. ', '.join(kwargs), cls.__name__))
  63. check_global_invariants(result, cls._pclass_invariants)
  64. result._pclass_frozen = True
  65. return result
  66. def set(self, *args, **kwargs):
  67. """
  68. Set a field in the instance. Returns a new instance with the updated value. The original instance remains
  69. unmodified. Accepts key-value pairs or single string representing the field name and a value.
  70. >>> from pyrsistent import PClass, field
  71. >>> class AClass(PClass):
  72. ... x = field()
  73. ...
  74. >>> a = AClass(x=1)
  75. >>> a2 = a.set(x=2)
  76. >>> a3 = a.set('x', 3)
  77. >>> a
  78. AClass(x=1)
  79. >>> a2
  80. AClass(x=2)
  81. >>> a3
  82. AClass(x=3)
  83. """
  84. if args:
  85. kwargs[args[0]] = args[1]
  86. factory_fields = set(kwargs)
  87. for key in self._pclass_fields:
  88. if key not in kwargs:
  89. value = getattr(self, key, _MISSING_VALUE)
  90. if value is not _MISSING_VALUE:
  91. kwargs[key] = value
  92. return self.__class__(_factory_fields=factory_fields, **kwargs)
  93. @classmethod
  94. def create(cls, kwargs, _factory_fields=None, ignore_extra=False):
  95. """
  96. Factory method. Will create a new PClass of the current type and assign the values
  97. specified in kwargs.
  98. :param ignore_extra: A boolean which when set to True will ignore any keys which appear in kwargs that are not
  99. in the set of fields on the PClass.
  100. """
  101. if isinstance(kwargs, cls):
  102. return kwargs
  103. if ignore_extra:
  104. kwargs = {k: kwargs[k] for k in cls._pclass_fields if k in kwargs}
  105. return cls(_factory_fields=_factory_fields, ignore_extra=ignore_extra, **kwargs)
  106. def serialize(self, format=None):
  107. """
  108. Serialize the current PClass using custom serializer functions for fields where
  109. such have been supplied.
  110. """
  111. result = {}
  112. for name in self._pclass_fields:
  113. value = getattr(self, name, _MISSING_VALUE)
  114. if value is not _MISSING_VALUE:
  115. result[name] = serialize(self._pclass_fields[name].serializer, format, value)
  116. return result
  117. def transform(self, *transformations):
  118. """
  119. Apply transformations to the currency PClass. For more details on transformations see
  120. the documentation for PMap. Transformations on PClasses do not support key matching
  121. since the PClass is not a collection. Apart from that the transformations available
  122. for other persistent types work as expected.
  123. """
  124. return transform(self, transformations)
  125. def __eq__(self, other):
  126. if isinstance(other, self.__class__):
  127. for name in self._pclass_fields:
  128. if getattr(self, name, _MISSING_VALUE) != getattr(other, name, _MISSING_VALUE):
  129. return False
  130. return True
  131. return NotImplemented
  132. def __ne__(self, other):
  133. return not self == other
  134. def __hash__(self):
  135. # May want to optimize this by caching the hash somehow
  136. return hash(tuple((key, getattr(self, key, _MISSING_VALUE)) for key in self._pclass_fields))
  137. def __setattr__(self, key, value):
  138. if getattr(self, '_pclass_frozen', False):
  139. raise AttributeError("Can't set attribute, key={0}, value={1}".format(key, value))
  140. super(PClass, self).__setattr__(key, value)
  141. def __delattr__(self, key):
  142. raise AttributeError("Can't delete attribute, key={0}, use remove()".format(key))
  143. def _to_dict(self):
  144. result = {}
  145. for key in self._pclass_fields:
  146. value = getattr(self, key, _MISSING_VALUE)
  147. if value is not _MISSING_VALUE:
  148. result[key] = value
  149. return result
  150. def __repr__(self):
  151. return "{0}({1})".format(self.__class__.__name__,
  152. ', '.join('{0}={1}'.format(k, repr(v)) for k, v in self._to_dict().items()))
  153. def __reduce__(self):
  154. # Pickling support
  155. data = dict((key, getattr(self, key)) for key in self._pclass_fields if hasattr(self, key))
  156. return _restore_pickle, (self.__class__, data,)
  157. def evolver(self):
  158. """
  159. Returns an evolver for this object.
  160. """
  161. return _PClassEvolver(self, self._to_dict())
  162. def remove(self, name):
  163. """
  164. Remove attribute given by name from the current instance. Raises AttributeError if the
  165. attribute doesn't exist.
  166. """
  167. evolver = self.evolver()
  168. del evolver[name]
  169. return evolver.persistent()
  170. class _PClassEvolver(object):
  171. __slots__ = ('_pclass_evolver_original', '_pclass_evolver_data', '_pclass_evolver_data_is_dirty', '_factory_fields')
  172. def __init__(self, original, initial_dict):
  173. self._pclass_evolver_original = original
  174. self._pclass_evolver_data = initial_dict
  175. self._pclass_evolver_data_is_dirty = False
  176. self._factory_fields = set()
  177. def __getitem__(self, item):
  178. return self._pclass_evolver_data[item]
  179. def set(self, key, value):
  180. if self._pclass_evolver_data.get(key, _MISSING_VALUE) is not value:
  181. self._pclass_evolver_data[key] = value
  182. self._factory_fields.add(key)
  183. self._pclass_evolver_data_is_dirty = True
  184. return self
  185. def __setitem__(self, key, value):
  186. self.set(key, value)
  187. def remove(self, item):
  188. if item in self._pclass_evolver_data:
  189. del self._pclass_evolver_data[item]
  190. self._factory_fields.discard(item)
  191. self._pclass_evolver_data_is_dirty = True
  192. return self
  193. raise AttributeError(item)
  194. def __delitem__(self, item):
  195. self.remove(item)
  196. def persistent(self):
  197. if self._pclass_evolver_data_is_dirty:
  198. return self._pclass_evolver_original.__class__(_factory_fields=self._factory_fields,
  199. **self._pclass_evolver_data)
  200. return self._pclass_evolver_original
  201. def __setattr__(self, key, value):
  202. if key not in self.__slots__:
  203. self.set(key, value)
  204. else:
  205. super(_PClassEvolver, self).__setattr__(key, value)
  206. def __getattr__(self, item):
  207. return self[item]