hypothesis_mongoengine.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. # -*- coding: utf-8 -*-
  2. # !/usr/bin/env python
  3. """该代码从 hypothesis.extra.django.models 演变而来
  4. 目的为了直接方便的生成一系列需要的数据模型
  5. 当测试需要数据的时候可直接操作,生成的数据保持与模型的对应
  6. models(Device).example()
  7. """
  8. from __future__ import division, print_function, absolute_import
  9. import string
  10. from decimal import Decimal
  11. from typing import Union
  12. import django.db.models as dm
  13. from django.db import IntegrityError
  14. from django.conf import settings as django_settings
  15. from django.core.exceptions import ValidationError
  16. import mongoengine as mg
  17. import mongoengine.fields as mg_fields
  18. import hypothesis.strategies as st
  19. from hypothesis.errors import InvalidArgument
  20. from hypothesis.extra.pytz import timezones
  21. from hypothesis.utils.conventions import UniqueIdentifier
  22. from hypothesis.searchstrategy.strategies import SearchStrategy
  23. class ModelNotSupported(Exception):
  24. pass
  25. def referenced_models(model, seen=None):
  26. if seen is None:
  27. seen = set()
  28. for f in model._meta.concrete_fields:
  29. if isinstance(f, dm.ForeignKey):
  30. t = f.rel.to
  31. if t not in seen:
  32. seen.add(t)
  33. referenced_models(t, seen)
  34. return seen
  35. def get_datetime_strat():
  36. if getattr(django_settings, 'USE_TZ', False):
  37. return st.datetimes(timezones=timezones())
  38. return st.datetimes()
  39. __default_field_mappings = None
  40. def field_mappings():
  41. global __default_field_mappings
  42. if __default_field_mappings is None:
  43. __default_field_mappings = {
  44. mg.fields.IntField: st.integers(-2147483648, 2147483647),
  45. mg.fields.LongField:
  46. st.integers(-9223372036854775808, 9223372036854775807),
  47. mg.fields.BinaryField: st.binary(),
  48. mg.fields.BooleanField: st.booleans(),
  49. mg.fields.DateTimeField: get_datetime_strat(),
  50. mg.fields.FloatField: st.floats(),
  51. mg.fields.ListField: st.lists(),
  52. mg.fields.PointField: st.tuples(st.floats(), st.floats()),
  53. }
  54. return __default_field_mappings
  55. def add_default_field_mapping(field_type, strategy):
  56. field_mappings()[field_type] = strategy
  57. default_value = UniqueIdentifier(u'default_value')
  58. class UnmappedFieldError(Exception):
  59. pass
  60. def validator_to_filter(f):
  61. """Converts the field run_validators method to something suitable for use
  62. in filter."""
  63. def validate(value):
  64. try:
  65. f.run_validators(value)
  66. return True
  67. except ValidationError:
  68. return False
  69. return validate
  70. safe_letters = string.ascii_letters + string.digits + '_-'
  71. domains = st.builds(
  72. lambda x, y: '.'.join(x + [y]),
  73. st.lists(st.text(safe_letters, min_size=1), min_size=1), st.sampled_from([
  74. 'com', 'net', 'org', 'biz', 'info',
  75. ])
  76. )
  77. email_domains = st.one_of(
  78. domains,
  79. st.sampled_from(['gmail.com', 'yahoo.com', 'hotmail.com', 'qq.com'])
  80. )
  81. base_emails = st.text(safe_letters, min_size=1)
  82. emails_with_plus = st.builds(
  83. lambda x, y: '%s+%s' % (x, y), base_emails, base_emails
  84. )
  85. emails = st.builds(
  86. lambda x, y: '%s@%s' % (x, y),
  87. st.one_of(base_emails, emails_with_plus), email_domains
  88. )
  89. def _get_strategy_for_field(f):
  90. # type: () -> Union[SearchStrategy, None, UniqueIdentifier]
  91. #: TODO to replace with mongoengine fields
  92. if f.choices:
  93. choices = [value for (value, name) in f.choices]
  94. if isinstance(f, (mg_fields.StringField, mg_fields.URLField)):
  95. choices.append(u'')
  96. strategy = st.sampled_from(choices)
  97. elif isinstance(f, mg_fields.EmailField):
  98. return emails
  99. elif type(f) in (mg_fields.StringField, ):
  100. strategy = st.text(min_size=f.min_length,
  101. max_size=f.max_length)
  102. elif type(f) == mg.DecimalField:
  103. m = 10 ** f.max_value - 1
  104. div = 10 ** f.precision
  105. q = Decimal('1.' + ('0' * f.decimal_places))
  106. strategy = (
  107. st.integers(min_value=-m, max_value=m)
  108. .map(lambda n: (Decimal(n) / div).quantize(q)))
  109. else:
  110. try:
  111. strategy = field_mappings()[type(f)]
  112. except KeyError:
  113. if f.null:
  114. return None
  115. else:
  116. raise UnmappedFieldError(f)
  117. #if f.validators:
  118. # strategy = strategy.filter(validator_to_filter(f))
  119. if f.null:
  120. strategy = st.one_of(st.none(), strategy)
  121. return strategy
  122. def models(model, **extra):
  123. result = {}
  124. mandatory = set()
  125. for f in model._meta.concrete_fields:
  126. try:
  127. strategy = _get_strategy_for_field(f)
  128. except UnmappedFieldError:
  129. mandatory.add(f.name)
  130. continue
  131. if strategy is not None:
  132. result[f.name] = strategy
  133. missed = {x for x in mandatory if x not in extra}
  134. if missed:
  135. raise InvalidArgument((
  136. u'Missing arguments for mandatory field%s %s for model %s' % (
  137. u's' if len(missed) > 1 else u'',
  138. u', '.join(missed),
  139. model.__name__,
  140. )))
  141. result.update(extra)
  142. # Remove default_values so we don't try to generate anything for those.
  143. result = {k: v for k, v in result.items() if v is not default_value}
  144. return ModelStrategy(model, result)
  145. class ModelStrategy(SearchStrategy):
  146. def __init__(self, model, mappings):
  147. super(ModelStrategy, self).__init__()
  148. self.model = model
  149. self.arg_strategy = st.fixed_dictionaries(mappings)
  150. def __repr__(self):
  151. return u'ModelStrategy(%s)' % (self.model.__name__,)
  152. def do_draw(self, data):
  153. try:
  154. result, _ = self.model.objects.get_or_create(
  155. **self.arg_strategy.do_draw(data)
  156. )
  157. return result
  158. except IntegrityError:
  159. data.mark_invalid()