django.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433
  1. """ Django support. """
  2. from __future__ import absolute_import
  3. import datetime as dt
  4. import decimal
  5. from os import path
  6. from types import GeneratorType
  7. from django.apps import apps
  8. from django.conf import settings
  9. from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation # noqa
  10. from django.contrib.contenttypes.models import ContentType
  11. from django.core.files.base import ContentFile
  12. from django.core.validators import validate_ipv4_address, validate_ipv6_address
  13. from django.db import models
  14. from .. import mix_types as t, _compat as _
  15. from ..main import (
  16. SKIP_VALUE, TypeMixerMeta as BaseTypeMixerMeta, TypeMixer as BaseTypeMixer,
  17. GenFactory as BaseFactory, Mixer as BaseMixer, partial, faker)
  18. get_contentfile = ContentFile
  19. MOCK_FILE = path.abspath(path.join(
  20. path.dirname(path.dirname(__file__)), 'resources', 'file.txt'
  21. ))
  22. MOCK_IMAGE = path.abspath(path.join(
  23. path.dirname(path.dirname(__file__)), 'resources', 'image.gif'
  24. ))
  25. class UTCZone(dt.tzinfo):
  26. """ Implement UTC timezone. """
  27. utcoffset = dst = lambda s, d: dt.timedelta(0)
  28. tzname = lambda s, d: "UTC"
  29. UTC = UTCZone()
  30. def get_file(filepath=MOCK_FILE, **kwargs):
  31. """ Generate a content file.
  32. :return ContentFile:
  33. """
  34. with open(filepath, 'rb') as f:
  35. name = path.basename(filepath)
  36. return get_contentfile(f.read(), name)
  37. def get_image(filepath=MOCK_IMAGE):
  38. """ Generate a content image.
  39. :return ContentFile:
  40. """
  41. return get_file(filepath)
  42. def get_relation(_scheme=None, _typemixer=None, **params):
  43. """ Function description. """
  44. scheme = _scheme.related_model
  45. if scheme is ContentType:
  46. choices = [m for m in apps.get_models() if m is not ContentType]
  47. return ContentType.objects.get_for_model(faker.random_element(choices))
  48. return TypeMixer(scheme, mixer=_typemixer._TypeMixer__mixer,
  49. factory=_typemixer._TypeMixer__factory,
  50. fake=_typemixer._TypeMixer__fake,).blend(**params)
  51. def get_datetime(**params):
  52. """ Support Django TZ support. """
  53. return faker.date_time(tzinfo=UTC if settings.USE_TZ else None)
  54. class GenFactory(BaseFactory):
  55. """ Map a django classes to simple types. """
  56. types = {
  57. (models.AutoField, models.PositiveIntegerField): t.PositiveInteger,
  58. models.BigIntegerField: t.BigInteger,
  59. models.BooleanField: bool,
  60. (models.CharField, models.SlugField): str,
  61. models.DateField: dt.date,
  62. models.DecimalField: decimal.Decimal,
  63. models.EmailField: t.EmailString,
  64. models.FloatField: float,
  65. models.GenericIPAddressField: t.IPString,
  66. models.IPAddressField: t.IP4String,
  67. models.IntegerField: int,
  68. models.PositiveSmallIntegerField: t.PositiveSmallInteger,
  69. models.SmallIntegerField: t.SmallInteger,
  70. models.TextField: t.Text,
  71. models.TimeField: dt.time,
  72. models.URLField: t.URL,
  73. }
  74. generators = {
  75. models.BinaryField: faker.pybytes,
  76. models.DateTimeField: get_datetime,
  77. models.FileField: get_file,
  78. models.FilePathField: lambda: MOCK_FILE,
  79. models.ForeignKey: get_relation,
  80. models.ImageField: get_image,
  81. models.ManyToManyField: get_relation,
  82. models.OneToOneField: get_relation,
  83. }
  84. class TypeMixerMeta(BaseTypeMixerMeta):
  85. """ Load django models from strings. """
  86. def __new__(mcs, name, bases, params):
  87. """ Associate Scheme with Django models.
  88. Cache Django models.
  89. :return mixer.backend.django.TypeMixer: A generated class.
  90. """
  91. params['models_cache'] = dict()
  92. cls = super(TypeMixerMeta, mcs).__new__(mcs, name, bases, params)
  93. return cls
  94. def __load_cls(cls, cls_type):
  95. if isinstance(cls_type, _.string_types):
  96. if '.' in cls_type:
  97. app_label, model_name = cls_type.split(".")
  98. return apps.get_model(app_label, model_name)
  99. else:
  100. try:
  101. if cls_type not in cls.models_cache:
  102. cls.__update_cache()
  103. return cls.models_cache[cls_type]
  104. except KeyError:
  105. raise ValueError('Model "%s" not found.' % cls_type)
  106. return cls_type
  107. def __update_cache(cls):
  108. """ Update apps cache for Django < 1.7. """
  109. for app in apps.all_models:
  110. for name, model in apps.all_models[app].items():
  111. cls.models_cache[name] = model
  112. class TypeMixer(_.with_metaclass(TypeMixerMeta, BaseTypeMixer)):
  113. """ TypeMixer for Django. """
  114. __metaclass__ = TypeMixerMeta
  115. factory = GenFactory
  116. def postprocess(self, target, postprocess_values):
  117. """ Fill postprocess_values. """
  118. for name, deffered in postprocess_values:
  119. if not isinstance(deffered.scheme, GenericForeignKey):
  120. continue
  121. name, value = self._get_value(name, deffered.value)
  122. setattr(target, name, value)
  123. if self.__mixer:
  124. target = self.__mixer.postprocess(target)
  125. for name, deffered in postprocess_values:
  126. if isinstance(deffered.scheme, GenericForeignKey) or not target.pk:
  127. continue
  128. name, value = self._get_value(name, deffered.value)
  129. # # If the ManyToMany relation has an intermediary model,
  130. # # the add and remove methods do not exist.
  131. through = deffered.scheme.remote_field.through
  132. if not through._meta.auto_created and self.__mixer: # noqa
  133. self.__mixer.blend(
  134. through, **{
  135. deffered.scheme.m2m_field_name(): target,
  136. deffered.scheme.m2m_reverse_field_name(): value})
  137. continue
  138. if not isinstance(value, (list, tuple)):
  139. value = [value]
  140. getattr(target, name).set(value)
  141. return target
  142. def get_value(self, name, value):
  143. """ Set value to generated instance.
  144. :return : None or (name, value) for later use
  145. """
  146. field = self.__fields.get(name)
  147. if field:
  148. if (field.scheme in self.__scheme._meta.local_many_to_many or
  149. isinstance(field.scheme, GenericForeignKey)):
  150. return name, t._Deffered(value, field.scheme)
  151. return self._get_value(name, value, field)
  152. return super(TypeMixer, self).get_value(name, value)
  153. def _get_value(self, name, value, field=None):
  154. if isinstance(value, GeneratorType):
  155. return self._get_value(name, next(value), field)
  156. if not isinstance(value, t.Mix) and value is not SKIP_VALUE:
  157. if callable(value):
  158. return self._get_value(name, value(), field)
  159. if field and not isinstance(field.scheme, models.ForeignKey):
  160. value = field.scheme.to_python(value)
  161. return name, value
  162. def gen_select(self, field_name, select):
  163. """ Select exists value from database.
  164. :param field_name: Name of field for generation.
  165. :return : None or (name, value) for later use
  166. """
  167. if field_name not in self.__fields:
  168. return field_name, None
  169. try:
  170. field = self.__fields[field_name]
  171. return field.name, field.scheme.remote_field.model.objects.filter(**select.params).\
  172. order_by('?')[0]
  173. except Exception:
  174. raise Exception("Cannot find a value for the field: '{0}'".format(field_name))
  175. def gen_field(self, field):
  176. """ Generate value by field.
  177. :param relation: Instance of :class:`Field`
  178. :return : None or (name, value) for later use
  179. """
  180. if isinstance(field.scheme, GenericForeignKey):
  181. return field.name, SKIP_VALUE
  182. if field.params and not field.scheme:
  183. raise ValueError('Invalid relation %s' % field.name)
  184. return super(TypeMixer, self).gen_field(field)
  185. def make_fabric(self, field, fname=None, fake=False, kwargs=None): # noqa
  186. """ Make a fabric for field.
  187. :param field: A mixer field
  188. :param fname: Field name
  189. :param fake: Force fake data
  190. :return function:
  191. """
  192. kwargs = {} if kwargs is None else kwargs
  193. fcls = type(field)
  194. stype = self.__factory.cls_to_simple(fcls)
  195. if fcls is models.CommaSeparatedIntegerField:
  196. return partial(
  197. faker.random_sample, range(0, field.max_length), length=field.max_length)
  198. if field and field.choices:
  199. try:
  200. choices, _ = list(zip(*field.choices))
  201. return partial(faker.random_element, choices)
  202. except ValueError:
  203. pass
  204. if stype in (str, t.Text):
  205. fab = super(TypeMixer, self).make_fabric(
  206. fcls, field_name=fname, fake=fake, kwargs=kwargs)
  207. return lambda: fab()[:field.max_length]
  208. if stype is decimal.Decimal:
  209. kwargs['left_digits'] = field.max_digits - field.decimal_places
  210. kwargs['right_digits'] = field.decimal_places
  211. elif stype is t.IPString:
  212. # Hack for support Django 1.4/1.5
  213. protocol = getattr(field, 'protocol', None)
  214. if not protocol:
  215. validator = field.default_validators[0]
  216. protocol = 'both'
  217. if validator is validate_ipv4_address:
  218. protocol = 'ipv4'
  219. elif validator is validate_ipv6_address:
  220. protocol = 'ipv6'
  221. # protocol matching is case insensitive
  222. # default address is either IPv4 or IPv6
  223. kwargs['protocol'] = protocol.lower()
  224. elif isinstance(field, models.fields.related.RelatedField):
  225. kwargs.update({'_typemixer': self, '_scheme': field})
  226. return super(TypeMixer, self).make_fabric(
  227. fcls, field_name=fname, fake=fake, kwargs=kwargs)
  228. @staticmethod
  229. def is_unique(field):
  230. """ Return True is field's value should be a unique.
  231. :return bool:
  232. """
  233. return field.scheme.unique
  234. @staticmethod
  235. def is_required(field):
  236. """ Return True is field's value should be defined.
  237. :return bool:
  238. """
  239. if field.params:
  240. return True
  241. if field.scheme.has_default() or field.scheme.null and field.scheme.blank:
  242. return False
  243. if field.scheme.auto_created:
  244. return False
  245. if isinstance(field.scheme, models.ManyToManyField):
  246. return False
  247. if isinstance(field.scheme, GenericRelation):
  248. return False
  249. return True
  250. def guard(self, *args, **kwargs):
  251. """ Look objects in database.
  252. :returns: A finded object or False
  253. """
  254. qs = self.__scheme.objects.filter(*args, **kwargs)
  255. count = qs.count()
  256. if count == 1:
  257. return qs.get()
  258. if count:
  259. return list(qs)
  260. return False
  261. def reload(self, obj):
  262. """ Reload object from database. """
  263. if not obj.pk:
  264. raise ValueError("Cannot load the object: %s" % obj)
  265. return self.__scheme._default_manager.get(pk=obj.pk)
  266. def __load_fields(self):
  267. private_fields = getattr(self.__scheme._meta, 'private_fields', [])
  268. for field in private_fields:
  269. yield field.name, t.Field(field, field.name)
  270. for field in self.__scheme._meta.fields:
  271. if isinstance(field, models.AutoField)\
  272. and self.__mixer and self.__mixer.params.get('commit'):
  273. continue
  274. yield field.name, t.Field(field, field.name)
  275. for field in self.__scheme._meta.local_many_to_many:
  276. yield field.name, t.Field(field, field.name)
  277. class Mixer(BaseMixer):
  278. """ Integration with Django. """
  279. type_mixer_cls = TypeMixer
  280. def __init__(self, commit=True, **params):
  281. """Initialize Mixer instance.
  282. :param commit: (True) Save object to database.
  283. """
  284. super(Mixer, self).__init__(**params)
  285. self.params['commit'] = commit
  286. def postprocess(self, target):
  287. """ Save objects in db.
  288. :return value: A generated value
  289. """
  290. if self.params.get('commit'):
  291. target.save()
  292. return target
  293. # Default mixer
  294. mixer = Mixer()
  295. # pylama:ignore=E1120