mongoengine.py 7.3 KB


""" Support for Mongoengine ODM.

.. note:: Support for Mongoengine_ is in early development.

::

    from mixer.backend.mongoengine import mixer

    class User(Document):
        created_at = DateTimeField(default=datetime.datetime.now)
        email = EmailField(required=True)
        first_name = StringField(max_length=50)
        last_name = StringField(max_length=50)

    class Post(Document):
        title = StringField(max_length=120, required=True)
        author = ReferenceField(User)
        tags = ListField(StringField(max_length=30))

    post = mixer.blend(Post, author__first_name='foo')

"""
  16. from __future__ import absolute_import
  17. import datetime
  18. import decimal
  19. from bson import ObjectId
  20. from mongoengine import (
  21. BooleanField,
  22. DateTimeField,
  23. DecimalField,
  24. Document,
  25. EmailField,
  26. EmbeddedDocumentField,
  27. FloatField,
  28. GenericReferenceField,
  29. GeoPointField,
  30. IntField,
  31. LineStringField,
  32. ListField,
  33. ObjectIdField,
  34. PointField,
  35. PolygonField,
  36. ReferenceField,
  37. StringField,
  38. URLField,
  39. UUIDField,
  40. )
  41. from .. import mix_types as t
  42. from ..main import (
  43. SKIP_VALUE, TypeMixer as BaseTypeMixer, GenFactory as BaseFactory,
  44. Mixer as BaseMixer, partial, faker
  45. )
  46. def get_objectid(**kwargs):
  47. """ Create a new ObjectId instance.
  48. :return ObjectId:
  49. """
  50. return ObjectId()
  51. def get_pointfield(**kwargs):
  52. """ Get a Point structure.
  53. :return dict:
  54. """
  55. return dict(type='Point', coordinates=faker.coordinates())
  56. def get_linestring(length=5, **kwargs):
  57. """ Get a LineString structure.
  58. :return dict:
  59. """
  60. return dict(type='LineString', coordinates=[faker.coordinates() for _ in range(length)])
  61. def get_polygon(length=5, **kwargs):
  62. """ Get a Poligon structure.
  63. :return dict:
  64. """
  65. lines = []
  66. for _ in range(length):
  67. line = get_linestring()['coordinates']
  68. if lines:
  69. line.insert(0, lines[-1][-1])
  70. lines.append(line)
  71. if lines:
  72. lines[0].insert(0, lines[-1][-1])
  73. return dict(type='Poligon', coordinates=lines)
  74. def get_generic_reference(_typemixer=None, **params):
  75. """ Choose a GenericRelation. """
  76. meta = type(_typemixer)
  77. scheme = faker.random_element([
  78. m for (_, m, _, _) in meta.mixers.keys()
  79. if issubclass(m, Document) and m is not _typemixer._TypeMixer__scheme # noqa
  80. ])
  81. return TypeMixer(scheme, mixer=_typemixer._TypeMixer__mixer,
  82. factory=_typemixer._TypeMixer__factory,
  83. fake=_typemixer._TypeMixer__fake).blend(**params)
  84. class GenFactory(BaseFactory):
  85. """ Map a mongoengine classes to simple types. """
  86. types = {
  87. BooleanField: bool,
  88. DateTimeField: datetime.datetime,
  89. DecimalField: decimal.Decimal,
  90. EmailField: t.EmailString,
  91. FloatField: float,
  92. IntField: int,
  93. StringField: str,
  94. URLField: t.URL,
  95. UUIDField: t.UUID,
  96. }
  97. generators = {
  98. GenericReferenceField: get_generic_reference,
  99. GeoPointField: faker.coordinates,
  100. LineStringField: get_linestring,
  101. ObjectIdField: get_objectid,
  102. PointField: get_pointfield,
  103. PolygonField: get_polygon,
  104. }
  105. class TypeMixer(BaseTypeMixer):
  106. """ TypeMixer for Mongoengine. """
  107. factory = GenFactory
  108. def make_fabric(self, me_field, field_name=None, fake=None, kwargs=None): # noqa
  109. """ Make a fabric for field.
  110. :param me_field: Mongoengine field's instance
  111. :param field_name: Field name
  112. :param fake: Force fake data
  113. :return function:
  114. """
  115. ftype = type(me_field)
  116. kwargs = {} if kwargs is None else kwargs
  117. if me_field.choices:
  118. if isinstance(me_field.choices[0], tuple):
  119. choices, _ = list(zip(*me_field.choices))
  120. else:
  121. choices = list(me_field.choices)
  122. return partial(faker.random_element, choices)
  123. if ftype is StringField:
  124. fab = super(TypeMixer, self).make_fabric(
  125. ftype, field_name=field_name, fake=fake, kwargs=kwargs)
  126. return lambda: fab()[:me_field.max_length]
  127. if ftype is ListField:
  128. fab = self.make_fabric(me_field.field, kwargs=kwargs)
  129. return lambda: [fab() for _ in range(3)]
  130. if isinstance(me_field, (EmbeddedDocumentField, ReferenceField)):
  131. ftype = me_field.document_type
  132. elif ftype is GenericReferenceField:
  133. kwargs.update({'_typemixer': self})
  134. elif ftype is DecimalField:
  135. kwargs['right_digits'] = me_field.precision
  136. return super(TypeMixer, self).make_fabric(
  137. ftype, field_name=field_name, fake=fake, kwargs=kwargs)
  138. @staticmethod
  139. def get_default(field):
  140. """ Get default value from field.
  141. :return value: A default value or NO_VALUE
  142. """
  143. if not field.scheme.default:
  144. return SKIP_VALUE
  145. if callable(field.scheme.default):
  146. return field.scheme.default()
  147. return field.scheme.default
  148. @staticmethod
  149. def is_unique(field):
  150. """ Return True is field's value should be a unique.
  151. :return bool:
  152. """
  153. return field.scheme.unique
  154. @staticmethod
  155. def is_required(field):
  156. """ Return True is field's value should be defined.
  157. :return bool:
  158. """
  159. if isinstance(field.scheme, ReferenceField):
  160. return True
  161. return field.scheme.required or isinstance(field.scheme, ObjectIdField)
  162. def gen_select(self, field_name, select):
  163. """ Select related document from mongo. """
  164. field = self.__fields.get(field_name)
  165. if not field:
  166. return super(TypeMixer, self).gen_select(field_name, select)
  167. return field.name, field.scheme.document_type.objects.filter(**select.params).first()
  168. def guard(self, *args, **kwargs):
  169. """ Ensure for an objects are exist in DB. """
  170. qs = self.__scheme.objects(*args, **kwargs)
  171. count = len(qs)
  172. if count == 1:
  173. return qs[0]
  174. return qs
  175. def reload(self, obj):
  176. """ Reload object from storage. """
  177. return self.__scheme.get(id=obj.id)
  178. def __load_fields(self):
  179. for fname, field in self.__scheme._fields.items():
  180. yield fname, t.Field(field, fname)
  181. class Mixer(BaseMixer):
  182. """ Mixer class for mongoengine.
  183. Default mixer (desnt save a generated instances to db)
  184. ::
  185. from mixer.backend.mongoengine import mixer
  186. user = mixer.blend(User)
  187. You can initialize the Mixer by manual:
  188. ::
  189. from mixer.backend.mongoengine import Mixer
  190. mixer = Mixer(commit=True)
  191. user = mixer.blend(User)
  192. """
  193. type_mixer_cls = TypeMixer
  194. def __init__(self, commit=True, **params):
  195. """ Initialize the Mongoengine Mixer.
  196. :param fake: (True) Generate fake data instead of random data.
  197. :param commit: (True) Save object to Mongo DB.
  198. """
  199. super(Mixer, self).__init__(**params)
  200. self.params['commit'] = commit
  201. def postprocess(self, target):
  202. """ Save instance to DB.
  203. :return instance:
  204. """
  205. if self.params.get('commit') and isinstance(target, Document):
  206. target.save()
  207. return target
  208. mixer = Mixer()
  209. # pylama:ignore=E1120