# -*- coding: utf-8 -*- # !/usr/bin/env python """该代码从 hypothesis.extra.django.models 演变而来 目的为了直接方便的生成一系列需要的数据模型 当测试需要数据的时候可直接操作,生成的数据保持与模型的对应 例 models(Device).example() """ from __future__ import division, print_function, absolute_import import string from decimal import Decimal from typing import Union import django.db.models as dm from django.db import IntegrityError from django.conf import settings as django_settings from django.core.exceptions import ValidationError import mongoengine as mg import mongoengine.fields as mg_fields import hypothesis.strategies as st from hypothesis.errors import InvalidArgument from hypothesis.extra.pytz import timezones from hypothesis.utils.conventions import UniqueIdentifier from hypothesis.searchstrategy.strategies import SearchStrategy class ModelNotSupported(Exception): pass def referenced_models(model, seen=None): if seen is None: seen = set() for f in model._meta.concrete_fields: if isinstance(f, dm.ForeignKey): t = f.rel.to if t not in seen: seen.add(t) referenced_models(t, seen) return seen def get_datetime_strat(): if getattr(django_settings, 'USE_TZ', False): return st.datetimes(timezones=timezones()) return st.datetimes() __default_field_mappings = None def field_mappings(): global __default_field_mappings if __default_field_mappings is None: __default_field_mappings = { mg.fields.IntField: st.integers(-2147483648, 2147483647), mg.fields.LongField: st.integers(-9223372036854775808, 9223372036854775807), mg.fields.BinaryField: st.binary(), mg.fields.BooleanField: st.booleans(), mg.fields.DateTimeField: get_datetime_strat(), mg.fields.FloatField: st.floats(), mg.fields.ListField: st.lists(), mg.fields.PointField: st.tuples(st.floats(), st.floats()), } return __default_field_mappings def add_default_field_mapping(field_type, strategy): field_mappings()[field_type] = strategy default_value = UniqueIdentifier(u'default_value') class UnmappedFieldError(Exception): pass def validator_to_filter(f): """Converts the field run_validators method to something suitable for use in filter.""" def validate(value): try: f.run_validators(value) return True except ValidationError: return False return validate safe_letters = string.ascii_letters + string.digits + '_-' domains = st.builds( lambda x, y: '.'.join(x + [y]), st.lists(st.text(safe_letters, min_size=1), min_size=1), st.sampled_from([ 'com', 'net', 'org', 'biz', 'info', ]) ) email_domains = st.one_of( domains, st.sampled_from(['gmail.com', 'yahoo.com', 'hotmail.com', 'qq.com']) ) base_emails = st.text(safe_letters, min_size=1) emails_with_plus = st.builds( lambda x, y: '%s+%s' % (x, y), base_emails, base_emails ) emails = st.builds( lambda x, y: '%s@%s' % (x, y), st.one_of(base_emails, emails_with_plus), email_domains ) def _get_strategy_for_field(f): # type: () -> Union[SearchStrategy, None, UniqueIdentifier] #: TODO to replace with mongoengine fields if f.choices: choices = [value for (value, name) in f.choices] if isinstance(f, (mg_fields.StringField, mg_fields.URLField)): choices.append(u'') strategy = st.sampled_from(choices) elif isinstance(f, mg_fields.EmailField): return emails elif type(f) in (mg_fields.StringField, ): strategy = st.text(min_size=f.min_length, max_size=f.max_length) elif type(f) == mg.DecimalField: m = 10 ** f.max_value - 1 div = 10 ** f.precision q = Decimal('1.' + ('0' * f.decimal_places)) strategy = ( st.integers(min_value=-m, max_value=m) .map(lambda n: (Decimal(n) / div).quantize(q))) else: try: strategy = field_mappings()[type(f)] except KeyError: if f.null: return None else: raise UnmappedFieldError(f) #if f.validators: # strategy = strategy.filter(validator_to_filter(f)) if f.null: strategy = st.one_of(st.none(), strategy) return strategy def models(model, **extra): result = {} mandatory = set() for f in model._meta.concrete_fields: try: strategy = _get_strategy_for_field(f) except UnmappedFieldError: mandatory.add(f.name) continue if strategy is not None: result[f.name] = strategy missed = {x for x in mandatory if x not in extra} if missed: raise InvalidArgument(( u'Missing arguments for mandatory field%s %s for model %s' % ( u's' if len(missed) > 1 else u'', u', '.join(missed), model.__name__, ))) result.update(extra) # Remove default_values so we don't try to generate anything for those. result = {k: v for k, v in result.items() if v is not default_value} return ModelStrategy(model, result) class ModelStrategy(SearchStrategy): def __init__(self, model, mappings): super(ModelStrategy, self).__init__() self.model = model self.arg_strategy = st.fixed_dictionaries(mappings) def __repr__(self): return u'ModelStrategy(%s)' % (self.model.__name__,) def do_draw(self, data): try: result, _ = self.model.objects.get_or_create( **self.arg_strategy.do_draw(data) ) return result except IntegrityError: data.mark_invalid()