_precord.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. import six
  2. from pyrsistent._checked_types import CheckedType, _restore_pickle, InvariantException, store_invariants
  3. from pyrsistent._field_common import (
  4. set_fields, check_type, is_field_ignore_extra_complaint, PFIELD_NO_INITIAL, serialize, check_global_invariants
  5. )
  6. from pyrsistent._pmap import PMap, pmap
  7. class _PRecordMeta(type):
  8. def __new__(mcs, name, bases, dct):
  9. set_fields(dct, bases, name='_precord_fields')
  10. store_invariants(dct, bases, '_precord_invariants', '__invariant__')
  11. dct['_precord_mandatory_fields'] = \
  12. set(name for name, field in dct['_precord_fields'].items() if field.mandatory)
  13. dct['_precord_initial_values'] = \
  14. dict((k, field.initial) for k, field in dct['_precord_fields'].items() if field.initial is not PFIELD_NO_INITIAL)
  15. dct['__slots__'] = ()
  16. return super(_PRecordMeta, mcs).__new__(mcs, name, bases, dct)
  17. @six.add_metaclass(_PRecordMeta)
  18. class PRecord(PMap, CheckedType):
  19. """
  20. A PRecord is a PMap with a fixed set of specified fields. Records are declared as python classes inheriting
  21. from PRecord. Because it is a PMap it has full support for all Mapping methods such as iteration and element
  22. access using subscript notation.
  23. More documentation and examples of PRecord usage is available at https://github.com/tobgu/pyrsistent
  24. """
  25. def __new__(cls, **kwargs):
  26. # Hack total! If these two special attributes exist that means we can create
  27. # ourselves. Otherwise we need to go through the Evolver to create the structures
  28. # for us.
  29. if '_precord_size' in kwargs and '_precord_buckets' in kwargs:
  30. return super(PRecord, cls).__new__(cls, kwargs['_precord_size'], kwargs['_precord_buckets'])
  31. factory_fields = kwargs.pop('_factory_fields', None)
  32. ignore_extra = kwargs.pop('_ignore_extra', False)
  33. initial_values = kwargs
  34. if cls._precord_initial_values:
  35. initial_values = dict((k, v() if callable(v) else v)
  36. for k, v in cls._precord_initial_values.items())
  37. initial_values.update(kwargs)
  38. e = _PRecordEvolver(cls, pmap(), _factory_fields=factory_fields, _ignore_extra=ignore_extra)
  39. for k, v in initial_values.items():
  40. e[k] = v
  41. return e.persistent()
  42. def set(self, *args, **kwargs):
  43. """
  44. Set a field in the record. This set function differs slightly from that in the PMap
  45. class. First of all it accepts key-value pairs. Second it accepts multiple key-value
  46. pairs to perform one, atomic, update of multiple fields.
  47. """
  48. # The PRecord set() can accept kwargs since all fields that have been declared are
  49. # valid python identifiers. Also allow multiple fields to be set in one operation.
  50. if args:
  51. return super(PRecord, self).set(args[0], args[1])
  52. return self.update(kwargs)
  53. def evolver(self):
  54. """
  55. Returns an evolver of this object.
  56. """
  57. return _PRecordEvolver(self.__class__, self)
  58. def __repr__(self):
  59. return "{0}({1})".format(self.__class__.__name__,
  60. ', '.join('{0}={1}'.format(k, repr(v)) for k, v in self.items()))
  61. @classmethod
  62. def create(cls, kwargs, _factory_fields=None, ignore_extra=False):
  63. """
  64. Factory method. Will create a new PRecord of the current type and assign the values
  65. specified in kwargs.
  66. :param ignore_extra: A boolean which when set to True will ignore any keys which appear in kwargs that are not
  67. in the set of fields on the PRecord.
  68. """
  69. if isinstance(kwargs, cls):
  70. return kwargs
  71. if ignore_extra:
  72. kwargs = {k: kwargs[k] for k in cls._precord_fields if k in kwargs}
  73. return cls(_factory_fields=_factory_fields, _ignore_extra=ignore_extra, **kwargs)
  74. def __reduce__(self):
  75. # Pickling support
  76. return _restore_pickle, (self.__class__, dict(self),)
  77. def serialize(self, format=None):
  78. """
  79. Serialize the current PRecord using custom serializer functions for fields where
  80. such have been supplied.
  81. """
  82. return dict((k, serialize(self._precord_fields[k].serializer, format, v)) for k, v in self.items())
  83. class _PRecordEvolver(PMap._Evolver):
  84. __slots__ = ('_destination_cls', '_invariant_error_codes', '_missing_fields', '_factory_fields', '_ignore_extra')
  85. def __init__(self, cls, original_pmap, _factory_fields=None, _ignore_extra=False):
  86. super(_PRecordEvolver, self).__init__(original_pmap)
  87. self._destination_cls = cls
  88. self._invariant_error_codes = []
  89. self._missing_fields = []
  90. self._factory_fields = _factory_fields
  91. self._ignore_extra = _ignore_extra
  92. def __setitem__(self, key, original_value):
  93. self.set(key, original_value)
  94. def set(self, key, original_value):
  95. field = self._destination_cls._precord_fields.get(key)
  96. if field:
  97. if self._factory_fields is None or field in self._factory_fields:
  98. try:
  99. if is_field_ignore_extra_complaint(PRecord, field, self._ignore_extra):
  100. value = field.factory(original_value, ignore_extra=self._ignore_extra)
  101. else:
  102. value = field.factory(original_value)
  103. except InvariantException as e:
  104. self._invariant_error_codes += e.invariant_errors
  105. self._missing_fields += e.missing_fields
  106. return self
  107. else:
  108. value = original_value
  109. check_type(self._destination_cls, field, key, value)
  110. is_ok, error_code = field.invariant(value)
  111. if not is_ok:
  112. self._invariant_error_codes.append(error_code)
  113. return super(_PRecordEvolver, self).set(key, value)
  114. else:
  115. raise AttributeError("'{0}' is not among the specified fields for {1}".format(key, self._destination_cls.__name__))
  116. def persistent(self):
  117. cls = self._destination_cls
  118. is_dirty = self.is_dirty()
  119. pm = super(_PRecordEvolver, self).persistent()
  120. if is_dirty or not isinstance(pm, cls):
  121. result = cls(_precord_buckets=pm._buckets, _precord_size=pm._size)
  122. else:
  123. result = pm
  124. if cls._precord_mandatory_fields:
  125. self._missing_fields += tuple('{0}.{1}'.format(cls.__name__, f) for f
  126. in (cls._precord_mandatory_fields - set(result.keys())))
  127. if self._invariant_error_codes or self._missing_fields:
  128. raise InvariantException(tuple(self._invariant_error_codes), tuple(self._missing_fields),
  129. 'Field invariant failed')
  130. check_global_invariants(result, cls._precord_invariants)
  131. return result