123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160 |
- # -*- coding: utf-8 -*-
- import copy
- from mongoengine.base import BaseDocument
- import marshmallow as ma
- from marshmallow.compat import with_metaclass
- from marshmallow_mongoengine.convert import ModelConverter
- DEFAULT_SKIP_VALUES = (None, [], {})
- class SchemaOpts(ma.SchemaOpts):
- """Options class for `ModelSchema`.
- Adds the following options:
- - ``model``: The Mongoengine Document model to generate the `Schema`
- from (required).
- - ``model_fields_kwargs``: Dict of {field: kwargs} to provide as
- additionals argument during fields creation.
- - ``model_build_obj``: If true, :Schema load: returns a :model: objects
- instead of a dict (default: True).
- - ``model_converter``: `ModelConverter` class to use for converting the
- Mongoengine Document model to marshmallow fields.
- - ``model_dump_only_pk``: If the document autogenerate it primary_key
- (default behaviour in Mongoengine), ignore it from the incomming data
- (default: False)
- - ``model_skip_values``: Skip the field if it contains one of the given
- values (default: None, [] and {})
- """
- def __init__(self, meta, *args, **kwargs):
- super(SchemaOpts, self).__init__(meta, *args, **kwargs)
- self.model = getattr(meta, 'model', None)
- if self.model and not issubclass(self.model, BaseDocument):
- raise ValueError("`model` must be a subclass of mongoengine.base.BaseDocument")
- self.model_fields_kwargs = getattr(meta, 'model_fields_kwargs', {})
- self.model_dump_only_pk = getattr(meta, 'model_dump_only_pk', False)
- self.model_converter = getattr(meta, 'model_converter', ModelConverter)
- self.model_build_obj = getattr(meta, 'model_build_obj', True)
- self.model_skip_values = getattr(meta, 'model_skip_values', DEFAULT_SKIP_VALUES)
- class SchemaMeta(ma.schema.SchemaMeta):
- """Metaclass for `ModelSchema`."""
- # override SchemaMeta
- @classmethod
- def get_declared_fields(mcs, klass, *args, **kwargs):
- """Updates declared fields with fields converted from the
- Mongoengine model passed as the `model` class Meta option.
- """
- declared_fields = kwargs.get('dict_class', dict)()
- # Generate the fields provided through inheritance
- opts = klass.opts
- model = getattr(opts, 'model', None)
- if model:
- converter = opts.model_converter()
- declared_fields.update(converter.fields_for_model(
- model,
- fields=opts.fields
- ))
- # Generate the fields provided in the current class
- base_fields = super(SchemaMeta, mcs).get_declared_fields(
- klass, *args, **kwargs
- )
- declared_fields.update(base_fields)
- # Customize fields with provided kwargs
- for field_name, field_kwargs in klass.opts.model_fields_kwargs.items():
- field = declared_fields.get(field_name, None)
- if field:
- # Copy to prevent alteration of a possible parent class's field
- field = copy.copy(field)
- for key, value in field_kwargs.items():
- setattr(field, key, value)
- declared_fields[field_name] = field
- if opts.model_dump_only_pk and opts.model:
- # If primary key is automatically generated (nominal case), we
- # must make sure this field is read-only
- if opts.model._auto_id_field is True:
- field_name = opts.model._meta['id_field']
- id_field = declared_fields.get(field_name)
- if id_field:
- # Copy to prevent alteration of a possible parent class's field
- id_field = copy.copy(id_field)
- id_field.dump_only = True
- declared_fields[field_name] = id_field
- return declared_fields
- class ModelSchema(with_metaclass(SchemaMeta, ma.Schema)):
- """Base class for Mongoengine model-based Schemas.
- Example: ::
- from marshmallow_mongoengine import ModelSchema
- from mymodels import User
- class UserSchema(ModelSchema):
- class Meta:
- model = User
- """
- OPTIONS_CLASS = SchemaOpts
- @ma.post_dump
- def _remove_skip_values(self, data):
- to_skip = self.opts.model_skip_values
- return {
- key: value for key, value in data.items()
- if value not in to_skip
- }
- @ma.post_load
- def _make_object(self, data):
- if self.opts.model_build_obj and self.opts.model:
- return self.opts.model(**data)
- else:
- return data
- def update(self, obj, data):
- """Helper function to update an already existing document
- instead of creating a new one.
- :param obj: Mongoengine Document to update
- :param data: incomming payload to deserialize
- :return: an :class UnmarshallResult:
- Example: ::
- from marshmallow_mongoengine import ModelSchema
- from mymodels import User
- class UserSchema(ModelSchema):
- class Meta:
- model = User
- def update_obj(id, payload):
- user = User.objects(id=id).first()
- result = UserSchema().update(user, payload)
- result.data is user # True
- Note:
- Given the update is done on a existing object, the required param
- on the fields is ignored
- """
- # TODO: find a cleaner way to skip required validation on update
- required_fields = [k for k, f in self.fields.items() if f.required]
- for field in required_fields:
- self.fields[field].required = False
- loaded_data, errors = self._do_load(data, postprocess=False)
- for field in required_fields:
- self.fields[field].required = True
- if not errors:
- # Update the given obj fields
- for k, v in loaded_data.items():
- # Skip default values that have been automatically
- # added during unserialization
- if k in data:
- setattr(obj, k, v)
- return ma.UnmarshalResult(data=obj, errors=errors)
|