fields.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. import bson
  2. from marshmallow import ValidationError, fields, missing
  3. from mongoengine import ValidationError as MongoValidationError, NotRegistered
  4. from mongoengine.base import get_document
  5. # Republish default fields...
  6. from marshmallow.fields import * # noqa
  7. # ...and add custom ones for mongoengine
  8. class ObjectId(fields.Field):
  9. def _deserialize(self, value, attr, data):
  10. try:
  11. return bson.ObjectId(value)
  12. except Exception:
  13. raise ValidationError('invalid ObjectId `%s`' % value)
  14. def _serialize(self, value, attr, obj):
  15. if value is None:
  16. return missing
  17. return str(value)
  18. class Point(fields.Field):
  19. def _deserialize(self, value, attr, data):
  20. try:
  21. return dict(
  22. type='Point',
  23. coordinates=[float(value['x']), float(value['y'])]
  24. )
  25. except Exception:
  26. raise ValidationError('invalid Point `%s`' % value)
  27. def _serialize(self, value, attr, obj):
  28. if value is None:
  29. return missing
  30. return dict(
  31. x=value['coordinates'][0],
  32. y=value['coordinates'][1]
  33. )
  34. class Reference(fields.Field):
  35. """
  36. Marshmallow custom field to map with :class Mongoengine.ReferenceField:
  37. """
  38. def __init__(self, document_type_obj, *args, **kwargs):
  39. self.document_type_obj = document_type_obj
  40. super(Reference, self).__init__(*args, **kwargs)
  41. @property
  42. def document_type(self):
  43. if isinstance(self.document_type_obj, str):
  44. self.document_type_obj = get_document(self.document_type_obj)
  45. return self.document_type_obj
  46. def _deserialize(self, value, attr, data):
  47. document_type = self.document_type
  48. try:
  49. return document_type.objects.get(pk=value)
  50. except (document_type.DoesNotExist, MongoValidationError, ValueError, TypeError):
  51. raise ValidationError('unknown document %s `%s`' %
  52. (document_type._class_name, value))
  53. return value
  54. def _serialize(self, value, attr, obj):
  55. # Only return the id of the document for serialization
  56. if value is None:
  57. return missing
  58. return str(value.id) if isinstance(value.id, bson.ObjectId) else value.id
  59. class GenericReference(fields.Field):
  60. """
  61. Marshmallow custom field to map with :class Mongoengine.GenericReferenceField:
  62. :param choices: List of Mongoengine document class (or class name) allowed
  63. .. note:: Without `choices` param, this field allow to reference to
  64. any document in the application which can be a security issue.
  65. """
  66. def __init__(self, *args, **kwargs):
  67. self.document_class_choices = []
  68. choices = kwargs.pop('choices', None)
  69. if choices:
  70. # Temporary fix for https://github.com/MongoEngine/mongoengine/pull/1060
  71. for choice in choices:
  72. if hasattr(choice, '_class_name'):
  73. self.document_class_choices.append(choice._class_name)
  74. else:
  75. self.document_class_choices.append(choice)
  76. super(GenericReference, self).__init__(*args, **kwargs)
  77. def _deserialize(self, value, attr, data):
  78. # To deserialize a generic reference, we need a _cls field in addition
  79. # with the id field
  80. if not isinstance(value, dict) or not value.get('id') or not value.get('_cls'):
  81. raise ValidationError("Need a dict with 'id' and '_cls' fields")
  82. doc_id = value['id']
  83. doc_cls_name = value['_cls']
  84. if self.document_class_choices and doc_cls_name not in self.document_class_choices:
  85. raise ValidationError("Invalid _cls field `%s`, must be one of %s" %
  86. (doc_cls_name, self.document_class_choices))
  87. try:
  88. doc_cls = get_document(doc_cls_name)
  89. except NotRegistered:
  90. raise ValidationError("Invalid _cls field `%s`" % doc_cls_name)
  91. try:
  92. doc = doc_cls.objects.get(pk=doc_id)
  93. except (doc_cls.DoesNotExist, MongoValidationError, ValueError, TypeError):
  94. raise ValidationError('unknown document %s `%s`' %
  95. (doc_cls_name, value))
  96. return doc
  97. def _serialize(self, value, attr, obj):
  98. # Only return the id of the document for serialization
  99. if value is None:
  100. return missing
  101. return value.id
  102. class GenericEmbeddedDocument(fields.Field):
  103. """
  104. Dynamic embedded document
  105. """
  106. def _deserialize(self, value, attr, data):
  107. # Cannot deserialize given we have no way knowing wich kind of
  108. # document is given...
  109. return missing
  110. def _serialize(self, value, attr, obj):
  111. # Create the schema at serialize time to be dynamic
  112. from marshmallow_mongoengine.schema import ModelSchema
  113. class NestedSchema(ModelSchema):
  114. class Meta:
  115. model = type(value)
  116. data, errors = NestedSchema().dump(value)
  117. if errors:
  118. raise ValidationError(errors)
  119. return data
  120. class Map(fields.Field):
  121. """
  122. Marshmallow custom field to map with :class Mongoengine.Map:
  123. """
  124. def __init__(self, mapped, **kwargs):
  125. self.mapped = mapped
  126. self.schema = getattr(mapped, "schema", None)
  127. super(Map, self).__init__(**kwargs)
  128. def _schema_process(self, action, value):
  129. func = getattr(self.schema, action)
  130. total = {}
  131. for k, v in value.items():
  132. data, errors = func(v)
  133. if errors:
  134. raise ValidationError(errors)
  135. total[k] = data
  136. return total
  137. def _serialize(self, value, attr, obj):
  138. if self.schema:
  139. return self._schema_process('dump', value)
  140. else:
  141. return value
  142. def _deserialize(self, value, attr, data):
  143. if self.schema:
  144. return self._schema_process('load', value)
  145. else:
  146. return value
  147. class Skip(fields.Field):
  148. """
  149. Marshmallow custom field that just ignore the current field
  150. """
  151. def _deserialize(self, value, attr, data):
  152. return missing
  153. def _serialize(self, value, attr, obj):
  154. return missing