schema.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. # -*- coding: utf-8 -*-
  2. import copy
  3. from mongoengine.base import BaseDocument
  4. import marshmallow as ma
  5. from marshmallow.compat import with_metaclass
  6. from marshmallow_mongoengine.convert import ModelConverter
  7. DEFAULT_SKIP_VALUES = (None, [], {})
  8. class SchemaOpts(ma.SchemaOpts):
  9. """Options class for `ModelSchema`.
  10. Adds the following options:
  11. - ``model``: The Mongoengine Document model to generate the `Schema`
  12. from (required).
  13. - ``model_fields_kwargs``: Dict of {field: kwargs} to provide as
  14. additionals argument during fields creation.
  15. - ``model_build_obj``: If true, :Schema load: returns a :model: objects
  16. instead of a dict (default: True).
  17. - ``model_converter``: `ModelConverter` class to use for converting the
  18. Mongoengine Document model to marshmallow fields.
  19. - ``model_dump_only_pk``: If the document autogenerate it primary_key
  20. (default behaviour in Mongoengine), ignore it from the incomming data
  21. (default: False)
  22. - ``model_skip_values``: Skip the field if it contains one of the given
  23. values (default: None, [] and {})
  24. """
  25. def __init__(self, meta, *args, **kwargs):
  26. super(SchemaOpts, self).__init__(meta, *args, **kwargs)
  27. self.model = getattr(meta, 'model', None)
  28. if self.model and not issubclass(self.model, BaseDocument):
  29. raise ValueError("`model` must be a subclass of mongoengine.base.BaseDocument")
  30. self.model_fields_kwargs = getattr(meta, 'model_fields_kwargs', {})
  31. self.model_dump_only_pk = getattr(meta, 'model_dump_only_pk', False)
  32. self.model_converter = getattr(meta, 'model_converter', ModelConverter)
  33. self.model_build_obj = getattr(meta, 'model_build_obj', True)
  34. self.model_skip_values = getattr(meta, 'model_skip_values', DEFAULT_SKIP_VALUES)
  35. class SchemaMeta(ma.schema.SchemaMeta):
  36. """Metaclass for `ModelSchema`."""
  37. # override SchemaMeta
  38. @classmethod
  39. def get_declared_fields(mcs, klass, *args, **kwargs):
  40. """Updates declared fields with fields converted from the
  41. Mongoengine model passed as the `model` class Meta option.
  42. """
  43. declared_fields = kwargs.get('dict_class', dict)()
  44. # Generate the fields provided through inheritance
  45. opts = klass.opts
  46. model = getattr(opts, 'model', None)
  47. if model:
  48. converter = opts.model_converter()
  49. declared_fields.update(converter.fields_for_model(
  50. model,
  51. fields=opts.fields
  52. ))
  53. # Generate the fields provided in the current class
  54. base_fields = super(SchemaMeta, mcs).get_declared_fields(
  55. klass, *args, **kwargs
  56. )
  57. declared_fields.update(base_fields)
  58. # Customize fields with provided kwargs
  59. for field_name, field_kwargs in klass.opts.model_fields_kwargs.items():
  60. field = declared_fields.get(field_name, None)
  61. if field:
  62. # Copy to prevent alteration of a possible parent class's field
  63. field = copy.copy(field)
  64. for key, value in field_kwargs.items():
  65. setattr(field, key, value)
  66. declared_fields[field_name] = field
  67. if opts.model_dump_only_pk and opts.model:
  68. # If primary key is automatically generated (nominal case), we
  69. # must make sure this field is read-only
  70. if opts.model._auto_id_field is True:
  71. field_name = opts.model._meta['id_field']
  72. id_field = declared_fields.get(field_name)
  73. if id_field:
  74. # Copy to prevent alteration of a possible parent class's field
  75. id_field = copy.copy(id_field)
  76. id_field.dump_only = True
  77. declared_fields[field_name] = id_field
  78. return declared_fields
  79. class ModelSchema(with_metaclass(SchemaMeta, ma.Schema)):
  80. """Base class for Mongoengine model-based Schemas.
  81. Example: ::
  82. from marshmallow_mongoengine import ModelSchema
  83. from mymodels import User
  84. class UserSchema(ModelSchema):
  85. class Meta:
  86. model = User
  87. """
  88. OPTIONS_CLASS = SchemaOpts
  89. @ma.post_dump
  90. def _remove_skip_values(self, data):
  91. to_skip = self.opts.model_skip_values
  92. return {
  93. key: value for key, value in data.items()
  94. if value not in to_skip
  95. }
  96. @ma.post_load
  97. def _make_object(self, data):
  98. if self.opts.model_build_obj and self.opts.model:
  99. return self.opts.model(**data)
  100. else:
  101. return data
  102. def update(self, obj, data):
  103. """Helper function to update an already existing document
  104. instead of creating a new one.
  105. :param obj: Mongoengine Document to update
  106. :param data: incomming payload to deserialize
  107. :return: an :class UnmarshallResult:
  108. Example: ::
  109. from marshmallow_mongoengine import ModelSchema
  110. from mymodels import User
  111. class UserSchema(ModelSchema):
  112. class Meta:
  113. model = User
  114. def update_obj(id, payload):
  115. user = User.objects(id=id).first()
  116. result = UserSchema().update(user, payload)
  117. result.data is user # True
  118. Note:
  119. Given the update is done on a existing object, the required param
  120. on the fields is ignored
  121. """
  122. # TODO: find a cleaner way to skip required validation on update
  123. required_fields = [k for k, f in self.fields.items() if f.required]
  124. for field in required_fields:
  125. self.fields[field].required = False
  126. loaded_data, errors = self._do_load(data, postprocess=False)
  127. for field in required_fields:
  128. self.fields[field].required = True
  129. if not errors:
  130. # Update the given obj fields
  131. for k, v in loaded_data.items():
  132. # Skip default values that have been automatically
  133. # added during unserialization
  134. if k in data:
  135. setattr(obj, k, v)
  136. return ma.UnmarshalResult(data=obj, errors=errors)