marshmallow.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. """ Support for Marshmallow.
  2. ::
  3. from mixer.backend.marshmallow import mixer
  4. """
  5. from __future__ import absolute_import
  6. import datetime as dt
  7. import decimal
  8. from marshmallow import fields, validate, missing
  9. from .. import mix_types as t
  10. from ..main import (
  11. TypeMixer as BaseTypeMixer, Mixer as BaseMixer, GenFactory as BaseFactory,
  12. LOGGER, faker, partial, SKIP_VALUE)
  13. def get_nested(_scheme=None, _typemixer=None, _many=False, **kwargs):
  14. """Create nested objects."""
  15. obj = TypeMixer(
  16. _scheme,
  17. mixer=_typemixer._TypeMixer__mixer,
  18. factory=_typemixer._TypeMixer__factory,
  19. fake=_typemixer._TypeMixer__fake,
  20. ).blend(**kwargs)
  21. if _many:
  22. return [obj]
  23. return obj
class GenFactory(BaseFactory):
    """Support for Marshmallow fields.

    Declares two lookup tables keyed by Marshmallow field classes:
    ``types`` gives the Python or ``mix_types`` type associated with each
    field class (a tuple key groups several field classes under one type),
    and ``generators`` gives a callable that produces the value directly.
    How the tables are consumed is defined by the base factory, which is
    not part of this module.
    """

    types = {
        (fields.Str, fields.String): str,
        fields.UUID: t.UUID,
        (fields.Number, fields.Integer, fields.Int): t.BigInteger,
        fields.Decimal: decimal.Decimal,
        (fields.Bool, fields.Boolean): bool,
        fields.Float: float,
        (fields.DateTime, fields.LocalDateTime): dt.datetime,
        fields.Time: dt.time,
        fields.Date: dt.date,
        (fields.URL, fields.Url): t.URL,
        fields.Email: t.EmailString,
        # Field classes that have no entry in this table:
        # fields.FormattedString
        # fields.TimeDelta
        # fields.Dict
        # fields.Method
        # fields.Function
        # fields.Constant
    }

    generators = {
        # DateTime values are generated as the ``isoformat()`` string of a
        # faker date-time rather than as a datetime object.
        fields.DateTime: lambda: faker.date_time().isoformat(),
        # Nested schemas are blended recursively by ``get_nested``.
        fields.Nested: get_nested,
    }
  49. class TypeMixer(BaseTypeMixer):
  50. """ TypeMixer for Marshmallow. """
  51. factory = GenFactory
  52. def __load_fields(self):
  53. for name, field in self.__scheme._declared_fields.items():
  54. yield name, t.Field(field, name)
  55. def is_required(self, field):
  56. """ Return True is field's value should be defined.
  57. :return bool:
  58. """
  59. return field.scheme.required or (
  60. self.__mixer.params['required'] and not field.scheme.dump_only)
  61. @staticmethod
  62. def get_default(field):
  63. """ Get default value from field.
  64. :return value:
  65. """
  66. return field.scheme.default is missing and SKIP_VALUE or field.scheme.default # noqa
  67. def populate_target(self, values):
  68. """ Populate target. """
  69. data, errors = self.__scheme().load(dict(values))
  70. if errors:
  71. LOGGER.error("Mixer-marshmallow: %r", errors)
  72. return data
  73. def make_fabric(self, field, field_name=None, fake=False, kwargs=None): # noqa
  74. kwargs = {} if kwargs is None else kwargs
  75. if isinstance(field, fields.Nested):
  76. kwargs.update({'_typemixer': self, '_scheme': type(field.schema), '_many': field.many})
  77. if isinstance(field, fields.List):
  78. fab = self.make_fabric(
  79. field.container, field_name=field_name, fake=fake, kwargs=kwargs)
  80. return lambda: [fab() for _ in range(faker.small_positive_integer(4))]
  81. for validator in field.validators:
  82. if isinstance(validator, validate.OneOf):
  83. return partial(faker.random_element, validator.choices)
  84. return super(TypeMixer, self).make_fabric(
  85. type(field), field_name=field_name, fake=fake, kwargs=kwargs)
  86. class Mixer(BaseMixer):
  87. """ Integration with Marshmallow. """
  88. type_mixer_cls = TypeMixer
  89. def __init__(self, *args, **kwargs):
  90. super(Mixer, self).__init__(*args, **kwargs)
  91. # All fields is required by default
  92. self.params.setdefault('required', True)
# Default instance, created at import time; use it via
# ``from mixer.backend.marshmallow import mixer``.
mixer = Mixer()