123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207 |
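# -*- coding: utf-8 -*-
# !/usr/bin/env python
"""该代码从 hypothesis.extra.django.models 演变而来，
目的是方便地直接生成所需的一系列数据模型实例；
测试需要数据时可直接调用，生成的数据与模型定义保持对应。

Adapted from ``hypothesis.extra.django.models``: builds Hypothesis
strategies that generate model instances directly, so a test that needs
data can draw it on demand and the generated values stay consistent with
the model's field definitions.

Example::

    models(Device).example()
"""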
- from __future__ import division, print_function, absolute_import
- import string
- from decimal import Decimal
- from typing import Union
- import django.db.models as dm
- from django.db import IntegrityError
- from django.conf import settings as django_settings
- from django.core.exceptions import ValidationError
- import mongoengine as mg
- import mongoengine.fields as mg_fields
- import hypothesis.strategies as st
- from hypothesis.errors import InvalidArgument
- from hypothesis.extra.pytz import timezones
- from hypothesis.utils.conventions import UniqueIdentifier
- from hypothesis.searchstrategy.strategies import SearchStrategy
class ModelNotSupported(Exception):
    """Error type for models this module cannot handle.

    NOTE(review): not raised anywhere in the visible part of this file;
    presumably kept so callers can signal an unsupported model.
    """
def _related_model(field):
    """Return the model that the relational Django ``field`` points to.

    Django 1.9 introduced ``field.remote_field.model`` and deprecated
    ``field.rel.to``; Django 2.0 removed ``rel``. Prefer the new API and
    fall back to the old one on Django versions that predate it.
    """
    remote = getattr(field, 'remote_field', None)
    if remote is not None:
        return remote.model
    return field.rel.to


def referenced_models(model, seen=None):
    """Return every model reachable from ``model`` through foreign keys.

    Walks the ``ForeignKey`` fields among ``model._meta.concrete_fields``
    recursively. ``seen`` accumulates the visited target models and is
    returned; checking it before recursing also stops reference cycles.
    ``model`` itself only ends up in the result if it is reachable from
    one of its own references.
    """
    if seen is None:
        seen = set()
    for f in model._meta.concrete_fields:
        if not isinstance(f, dm.ForeignKey):
            continue
        target = _related_model(f)
        if target not in seen:
            seen.add(target)
            referenced_models(target, seen)
    return seen
def get_datetime_strat():
    """Return a Hypothesis strategy for ``datetime`` values.

    When Django's ``USE_TZ`` setting is truthy, the datetimes are drawn
    with a timezone from ``timezones()``. Otherwise, including when the
    setting is absent, plain ``st.datetimes()`` is used.
    """
    use_tz = getattr(django_settings, 'USE_TZ', False)
    if not use_tz:
        return st.datetimes()
    return st.datetimes(timezones=timezones())
# Lazily-built cache for field_mappings(); None until the first call.
__default_field_mappings = None


def field_mappings():
    """Return the mapping from field type to default Hypothesis strategy.

    The dict is built on the first call and cached in the module-level
    ``__default_field_mappings``, so ``add_default_field_mapping`` can extend
    it. Note that ``get_datetime_strat()`` is therefore evaluated only once;
    later changes to ``USE_TZ`` are not picked up.
    """
    global __default_field_mappings
    if __default_field_mappings is None:
        __default_field_mappings = {
            # Signed 32-bit range.
            mg_fields.IntField: st.integers(-2147483648, 2147483647),
            # Signed 64-bit range.
            mg_fields.LongField:
                st.integers(-9223372036854775808, 9223372036854775807),
            mg_fields.BinaryField: st.binary(),
            mg_fields.BooleanField: st.booleans(),
            mg_fields.DateTimeField: get_datetime_strat(),
            mg_fields.FloatField: st.floats(),
            # The element type of a ListField depends on the field instance,
            # which a per-type mapping cannot see, so generate a fresh empty
            # list per draw. (A bare st.lists() meant exactly this in old
            # Hypothesis versions and is an error in newer ones.)
            mg_fields.ListField: st.builds(list),
            # (longitude, latitude), as in GeoJSON. Bounded so no NaN,
            # infinity, or out-of-range coordinates are produced.
            mg_fields.PointField:
                st.tuples(st.floats(-180, 180), st.floats(-90, 90)),
        }
    return __default_field_mappings
def add_default_field_mapping(field_type, strategy):
    """Register ``strategy`` as the default for fields of type ``field_type``.

    Writes into the shared dict returned by ``field_mappings()``, replacing
    any strategy previously registered for that type.
    """
    mappings = field_mappings()
    mappings[field_type] = strategy


# Sentinel value: passing ``<field>=default_value`` to ``models()`` drops that
# field from the generated arguments, so nothing is generated for it.
default_value = UniqueIdentifier(u'default_value')
class UnmappedFieldError(Exception):
    """No strategy is known for a non-nullable field.

    Raised by ``_get_strategy_for_field`` with the field as its argument;
    ``models()`` catches it and requires the caller to supply that field.
    """
def validator_to_filter(f):
    """Turn ``f.run_validators`` into a predicate usable with ``filter``.

    The returned callable gives ``True`` when ``f.run_validators(value)``
    completes and ``False`` when it raises Django's ``ValidationError``;
    any other exception propagates.
    """
    def is_valid(value):
        try:
            f.run_validators(value)
        except ValidationError:
            return False
        else:
            return True
    return is_valid
# Alphabet for the local part (before the '@') of generated addresses.
safe_letters = string.ascii_letters + string.digits + '_-'
# Alphabet for domain labels. Hostname labels (RFC 1123) may not contain '_'
# and may not start or end with '-', so only letters and digits are used;
# labels are also capped at the 63-character DNS label limit.
domain_letters = string.ascii_letters + string.digits
# One or more labels followed by a common top-level domain, e.g. "ab.c1.com".
domains = st.builds(
    lambda labels, tld: '.'.join(labels + [tld]),
    st.lists(st.text(domain_letters, min_size=1, max_size=63), min_size=1),
    st.sampled_from([
        'com', 'net', 'org', 'biz', 'info',
    ]),
)
# Either a generated domain or a well-known mail provider.
email_domains = st.one_of(
    domains,
    st.sampled_from(['gmail.com', 'yahoo.com', 'hotmail.com', 'qq.com']),
)
# Local parts: a plain token, or "token+tag".
base_emails = st.text(safe_letters, min_size=1)
emails_with_plus = st.builds(
    lambda x, y: '%s+%s' % (x, y), base_emails, base_emails
)
# Complete "local@domain" addresses.
emails = st.builds(
    lambda x, y: '%s@%s' % (x, y),
    st.one_of(base_emails, emails_with_plus), email_domains
)
def _choice_values(f):
    """Return the plain values allowed by ``f.choices``.

    ``choices`` may be a flat sequence of values or a sequence of
    ``(value, label)`` pairs; which shape is used is decided from the first
    entry.
    """
    choices = list(f.choices)
    if choices and isinstance(choices[0], (list, tuple)):
        return [value for value, _label in choices]
    return choices


def _decimal_strategy(f):
    """Build a strategy for a mongoengine ``DecimalField`` ``f``.

    Draws an integer number of ``10 ** -f.precision`` units and converts it
    to a ``Decimal`` quantized to ``f.precision`` places. ``f.min_value`` and
    ``f.max_value`` are honoured when they are set.
    """
    from decimal import ROUND_CEILING, ROUND_FLOOR

    places = f.precision
    scale = 10 ** places
    # quantize() only uses the exponent of its argument, i.e. 10 ** -places.
    quantum = Decimal('1.' + ('0' * places))
    # Span used for an unbounded side, in scaled units (15 significant
    # digits, what a float round-trips). NOTE(review): assumes mongoengine
    # stores DecimalField as float unless force_string=True -- confirm.
    limit = 10 ** 15 - 1

    def to_units(bound, rounding):
        # Round inwards so the converted value still satisfies the bound.
        if bound is None:
            return None
        return int((Decimal(bound) * scale).to_integral_value(rounding=rounding))

    low = to_units(f.min_value, ROUND_CEILING)
    high = to_units(f.max_value, ROUND_FLOOR)
    if low is None:
        low = min(0 if high is None else high, 0) - limit
    if high is None:
        high = max(low, 0) + limit
    return st.integers(min_value=low, max_value=high).map(
        lambda n: (Decimal(n) / scale).quantize(quantum))


def _get_strategy_for_field(f):
    # type: (object) -> Union[SearchStrategy, None]
    """Return a Hypothesis strategy for the single model field ``f``.

    Returns ``None`` when no strategy is known for the field's type but the
    field is nullable; ``models()`` then leaves the field out. Nullable
    fields otherwise get ``None`` mixed into their strategy.

    Raises:
        UnmappedFieldError: no strategy is known for the field's type and
            the field is not nullable.
    """
    # TODO: finish migrating from Django field types to mongoengine ones.
    if f.choices:
        # No Django-style blank ('') choice is added: presumably mongoengine
        # validates '' against the choice list like any other value.
        strategy = st.sampled_from(_choice_values(f))
    elif isinstance(f, mg_fields.EmailField):
        strategy = emails
    elif type(f) is mg_fields.StringField:
        # Exact type check: subclasses such as URLField need more specific
        # data than arbitrary text, so they go through field_mappings().
        strategy = st.text(min_size=f.min_length or 0,
                           max_size=f.max_length)
    elif isinstance(f, mg_fields.DecimalField):
        strategy = _decimal_strategy(f)
    else:
        try:
            strategy = field_mappings()[type(f)]
        except KeyError:
            if f.null:
                return None
            raise UnmappedFieldError(f)
    # NOTE: field validators are not applied to generated values yet
    # (validator_to_filter is the intended hook).
    if f.null:
        strategy = st.one_of(st.none(), strategy)
    return strategy
def _model_fields(model):
    """Return the fields of ``model`` that strategies are generated for.

    Django models expose ``_meta.concrete_fields``. Any other model (e.g. a
    mongoengine document, whose ``_fields`` name->field mapping is assumed
    here) falls back to ``model._fields``.
    """
    concrete = getattr(getattr(model, '_meta', None), 'concrete_fields', None)
    if concrete is not None:
        return list(concrete)
    return list(model._fields.values())


def models(model, **extra):
    """Return a strategy producing instances of ``model``.

    Each field gets the strategy from ``_get_strategy_for_field``. Keyword
    arguments in ``extra`` add or override strategies by field name, and a
    value of ``default_value`` removes that field so nothing is generated
    for it. Instances are obtained with ``model.objects.get_or_create``.

    Raises:
        InvalidArgument: some field has no known strategy and was not
            supplied in ``extra``.
    """
    result = {}
    mandatory = set()
    for f in _model_fields(model):
        try:
            strategy = _get_strategy_for_field(f)
        except UnmappedFieldError:
            mandatory.add(f.name)
            continue
        if strategy is not None:
            result[f.name] = strategy
    # Sorted so the error message is deterministic across runs.
    missed = sorted(mandatory.difference(extra))
    if missed:
        raise InvalidArgument(
            u'Missing arguments for mandatory field%s %s for model %s' % (
                u's' if len(missed) > 1 else u'',
                u', '.join(missed),
                model.__name__,
            ))
    result.update(extra)
    # Remove default_values so we don't try to generate anything for those.
    result = {k: v for k, v in result.items() if v is not default_value}
    return ModelStrategy(model, result)
class ModelStrategy(SearchStrategy):
    """Strategy yielding instances from ``model.objects.get_or_create``.

    ``mappings`` maps field names to strategies. They are combined with
    ``st.fixed_dictionaries`` into the keyword arguments of each
    ``get_or_create`` call.
    """

    def __init__(self, model, mappings):
        super(ModelStrategy, self).__init__()
        self.model = model
        self.arg_strategy = st.fixed_dictionaries(mappings)

    def __repr__(self):
        return u'ModelStrategy(%s)' % (self.model.__name__,)

    def do_draw(self, data):
        """Draw field values and return the matching (possibly new) instance.

        A Django ``IntegrityError`` or mongoengine ``NotUniqueError`` (e.g. a
        generated value clashing with a unique constraint) marks the example
        invalid instead of failing the test.
        """
        try:
            # data.draw() rather than arg_strategy.do_draw(): lets Hypothesis
            # validate and record the nested draw.
            kwargs = data.draw(self.arg_strategy)
            result, _ = self.model.objects.get_or_create(**kwargs)
        except (IntegrityError, mg.NotUniqueError):
            data.mark_invalid()
        else:
            return result
|