peewee.py 5.0 KB


""" Support for Peewee ORM.

::

    from mixer.backend.peewee import mixer
"""
  5. from __future__ import absolute_import
  6. from peewee import * # noqa
  7. try:
  8. from peewee import AutoField
  9. except ImportError:
  10. from peewee import PrimaryKeyField as AutoField
  11. import datetime
  12. import decimal
  13. from .. import mix_types as t
  14. from ..main import (
  15. TypeMixer as BaseTypeMixer, Mixer as BaseMixer, SKIP_VALUE,
  16. GenFactory as BaseFactory, partial, faker)
  17. def get_relation(_scheme=None, _typemixer=None, **params):
  18. """ Function description. """
  19. scheme = _scheme.rel_model
  20. return TypeMixer(
  21. scheme,
  22. mixer=_typemixer._TypeMixer__mixer,
  23. factory=_typemixer._TypeMixer__factory,
  24. fake=_typemixer._TypeMixer__fake,
  25. ).blend(**params)
  26. def get_blob(**kwargs):
  27. """ Generate value for BlobField. """
  28. raise NotImplementedError
  29. class GenFactory(BaseFactory):
  30. """ Map a peewee classes to simple types. """
  31. types = {
  32. AutoField: t.PositiveInteger,
  33. IntegerField: int,
  34. BigIntegerField: t.BigInteger,
  35. (FloatField, DoubleField): float,
  36. DecimalField: decimal.Decimal,
  37. CharField: str,
  38. TextField: t.Text,
  39. DateTimeField: datetime.datetime,
  40. DateField: datetime.date,
  41. TimeField: datetime.time,
  42. BooleanField: bool,
  43. # BlobField: None,
  44. }
  45. generators = {
  46. BlobField: get_blob,
  47. ForeignKeyField: get_relation,
  48. }
  49. class TypeMixer(BaseTypeMixer):
  50. """ TypeMixer for Peewee ORM. """
  51. factory = GenFactory
  52. def __load_fields(self):
  53. for field in self.__scheme._meta.sorted_fields:
  54. yield field.name, t.Field(field, field.name)
  55. def populate_target(self, values):
  56. """ Populate target. """
  57. return self.__scheme(**dict(values))
  58. def gen_field(self, field):
  59. """ Function description. """
  60. if isinstance(field.scheme, AutoField)\
  61. and self.__mixer and self.__mixer.params.get('commit'):
  62. return field.name, SKIP_VALUE
  63. return super(TypeMixer, self).gen_field(field)
  64. def gen_select(self, field_name, select):
  65. """ Select exists value from database.
  66. :param field_name: Name of field for generation.
  67. :return : None or (name, value) for later use
  68. """
  69. field = self.__fields[field_name]
  70. if not isinstance(field.scheme, ForeignKeyField):
  71. return field_name, SKIP_VALUE
  72. model = field.scheme.rel_model
  73. value = model.select().order_by(fn.Random()).get()
  74. return self.get_value(field_name, value)
  75. def is_required(self, field):
  76. """ Return True is field's value should be defined.
  77. :return bool:
  78. """
  79. return not field.scheme.null
  80. def is_unique(self, field):
  81. """ Return True is field's value should be a unique.
  82. :return bool:
  83. """
  84. return field.scheme.unique
  85. @staticmethod
  86. def get_default(field):
  87. """ Get default value from field.
  88. :return value:
  89. """
  90. return field.scheme.default is None and SKIP_VALUE or field.scheme.default # noqa
  91. def make_fabric(self, field, field_name=None, fake=False, kwargs=None): # noqa
  92. """ Make values fabric for column.
  93. :param column: SqlAlchemy column
  94. :param field_name: Field name
  95. :param fake: Force fake data
  96. :return function:
  97. """
  98. kwargs = {} if kwargs is None else kwargs
  99. if field.choices:
  100. try:
  101. choices, _ = list(zip(*field.choices))
  102. return partial(faker.random_element, choices)
  103. except ValueError:
  104. pass
  105. if isinstance(field, ForeignKeyField):
  106. kwargs.update({'_typemixer': self, '_scheme': field})
  107. return super(TypeMixer, self).make_fabric(
  108. type(field), field_name=field_name, fake=fake, kwargs=kwargs)
  109. def guard(self, *args, **kwargs):
  110. """ Look objects in database.
  111. :returns: A finded object or False
  112. """
  113. qs = self.__scheme.select().where(*args, **kwargs)
  114. count = qs.count()
  115. if count == 1:
  116. return qs.get()
  117. if count:
  118. return list(qs)
  119. return False
  120. def reload(self, obj):
  121. """ Reload object from database. """
  122. if not obj.get_id():
  123. raise ValueError("Cannot load the object: %s" % obj)
  124. return type(obj).select().where(obj._meta.primary_key == obj.get_id()).get()
  125. class Mixer(BaseMixer):
  126. """ Integration with Peewee ORM. """
  127. type_mixer_cls = TypeMixer
  128. def __init__(self, **params):
  129. """Initialize the Mixer instance."""
  130. params.setdefault('commit', True)
  131. super(Mixer, self).__init__(**params)
  132. def postprocess(self, target):
  133. """ Save objects in db.
  134. :return value: A generated value
  135. """
  136. if self.params.get('commit'):
  137. target.save()
  138. return target
# Default Peewee mixer. ``Mixer.__init__`` defaults ``commit`` to True, so
# objects blended with it are saved. Usage:
# ``from mixer.backend.peewee import mixer``.
mixer = Mixer()
# pylama:ignore=E1120