models.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
  1. # coding=utf-8
  2. #
  3. # This file is part of Hypothesis, which may be found at
  4. # https://github.com/HypothesisWorks/hypothesis-python
  5. #
  6. # Most of this work is copyright (C) 2013-2018 David R. MacIver
  7. # (david@drmaciver.com), but it contains contributions by others. See
  8. # CONTRIBUTING.rst for a full list of people who may hold copyright, and
  9. # consult the git log if you need to determine who owns an individual
  10. # contribution.
  11. #
  12. # This Source Code Form is subject to the terms of the Mozilla Public License,
  13. # v. 2.0. If a copy of the MPL was not distributed with this file, You can
  14. # obtain one at http://mozilla.org/MPL/2.0/.
  15. #
  16. # END HEADER
  17. from __future__ import division, print_function, absolute_import
  18. import string
  19. from decimal import Decimal
  20. from datetime import timedelta
  21. import django.db.models as dm
  22. from django.db import IntegrityError
  23. from django.conf import settings as django_settings
  24. from django.core.exceptions import ValidationError
  25. import hypothesis.strategies as st
  26. from hypothesis import reject
  27. from hypothesis.errors import InvalidArgument
  28. from hypothesis.extra.pytz import timezones
  29. from hypothesis.provisional import emails, ip4_addr_strings, \
  30. ip6_addr_strings
  31. from hypothesis.utils.conventions import UniqueIdentifier
  32. def get_tz_strat():
  33. if getattr(django_settings, 'USE_TZ', False):
  34. return timezones()
  35. return st.none()
  36. __default_field_mappings = None
  37. def field_mappings():
  38. global __default_field_mappings
  39. if __default_field_mappings is None:
  40. # Sized fields are handled in _get_strategy_for_field()
  41. # URL fields are not yet handled
  42. __default_field_mappings = {
  43. dm.SmallIntegerField: st.integers(-32768, 32767),
  44. dm.IntegerField: st.integers(-2147483648, 2147483647),
  45. dm.BigIntegerField:
  46. st.integers(-9223372036854775808, 9223372036854775807),
  47. dm.PositiveIntegerField: st.integers(0, 2147483647),
  48. dm.PositiveSmallIntegerField: st.integers(0, 32767),
  49. dm.BinaryField: st.binary(),
  50. dm.BooleanField: st.booleans(),
  51. dm.DateField: st.dates(),
  52. dm.DateTimeField: st.datetimes(timezones=get_tz_strat()),
  53. dm.DurationField: st.timedeltas(),
  54. dm.EmailField: emails(),
  55. dm.FloatField: st.floats(),
  56. dm.NullBooleanField: st.one_of(st.none(), st.booleans()),
  57. dm.TimeField: st.times(timezones=get_tz_strat()),
  58. dm.UUIDField: st.uuids(),
  59. }
  60. # SQLite does not support timezone-aware times, or timedeltas that
  61. # don't fit in six bytes of microseconds, so we override those
  62. db = getattr(django_settings, 'DATABASES', {}).get('default', {})
  63. if db.get('ENGINE', '').endswith('.sqlite3'): # pragma: no branch
  64. sqlite_delta = timedelta(microseconds=2 ** 47 - 1)
  65. __default_field_mappings.update({
  66. dm.TimeField: st.times(),
  67. dm.DurationField: st.timedeltas(-sqlite_delta, sqlite_delta),
  68. })
  69. return __default_field_mappings
  70. def add_default_field_mapping(field_type, strategy):
  71. field_mappings()[field_type] = strategy
# Sentinel for models(): passing ``field_name=default_value`` drops that
# field from the keyword arguments given to get_or_create, and models() also
# skips inferring a strategy for it, so the value is left to Django/the model
# (presumably the field's own default -- confirm for your model).
default_value = UniqueIdentifier(u'default_value')
  73. def validator_to_filter(f):
  74. """Converts the field run_validators method to something suitable for use
  75. in filter."""
  76. def validate(value):
  77. try:
  78. f.run_validators(value)
  79. return True
  80. except ValidationError:
  81. return False
  82. return validate
  83. def _get_strategy_for_field(f):
  84. if f.choices:
  85. choices = []
  86. for value, name_or_optgroup in f.choices:
  87. if isinstance(name_or_optgroup, (list, tuple)):
  88. choices.extend(key for key, _ in name_or_optgroup)
  89. else:
  90. choices.append(value)
  91. if isinstance(f, (dm.CharField, dm.TextField)) and f.blank:
  92. choices.insert(0, u'')
  93. strategy = st.sampled_from(choices)
  94. elif type(f) == dm.SlugField:
  95. strategy = st.text(alphabet=string.ascii_letters + string.digits,
  96. min_size=(None if f.blank else 1),
  97. max_size=f.max_length)
  98. elif type(f) == dm.GenericIPAddressField:
  99. lookup = {'both': ip4_addr_strings() | ip6_addr_strings(),
  100. 'ipv4': ip4_addr_strings(), 'ipv6': ip6_addr_strings()}
  101. strategy = lookup[f.protocol.lower()]
  102. elif type(f) in (dm.TextField, dm.CharField):
  103. strategy = st.text(min_size=(None if f.blank else 1),
  104. max_size=f.max_length)
  105. elif type(f) == dm.DecimalField:
  106. bound = Decimal(10 ** f.max_digits - 1) / (10 ** f.decimal_places)
  107. strategy = st.decimals(min_value=-bound, max_value=bound,
  108. places=f.decimal_places)
  109. else:
  110. strategy = field_mappings().get(type(f), st.nothing())
  111. if f.validators:
  112. strategy = strategy.filter(validator_to_filter(f))
  113. if f.null:
  114. strategy = st.one_of(st.none(), strategy)
  115. return strategy
  116. def models(model, **extra):
  117. """Return a strategy for instances of a model."""
  118. result = {k: v for k, v in extra.items() if v is not default_value}
  119. missed = []
  120. for f in model._meta.concrete_fields:
  121. if not (f.name in extra or isinstance(f, dm.AutoField)):
  122. result[f.name] = _get_strategy_for_field(f)
  123. if result[f.name].is_empty:
  124. missed.append(f.name)
  125. if missed:
  126. raise InvalidArgument(
  127. u'Missing arguments for mandatory field%s %s for model %s'
  128. % (u's' if missed else u'', u', '.join(missed), model.__name__))
  129. return _models_impl(st.builds(model.objects.get_or_create, **result))
  130. @st.composite
  131. def _models_impl(draw, strat):
  132. """Handle the nasty part of drawing a value for models()"""
  133. try:
  134. return draw(strat)[0]
  135. except IntegrityError:
  136. reject()