|
- # -*- coding: utf-8 -*-
- #!/usr/bin/env python
- import random
- import decimal
- import pprint
- import inspect
- from typing import TYPE_CHECKING, Type
- import faker
- import mongoengine as mg
- from bson.objectid import ObjectId
- from apps.web.core.db import MonetaryField, VirtualCoinField, RatioField, PercentField, StrictDictField
- from apps.web.dealer.models import Dealer
- from apps.web.agent.models import Agent
- from apps.web.management.models import Manager
- from apps.web.device.models import Device, Group
- if TYPE_CHECKING:
- from apps.web.core.db import Searchable
#: Shared faker generator used by the fake-value factories defined below.
factory = faker.Factory()
fake = factory.create()
class MultiKeyDict(object):
    """
    A map in which several keys can be bound to one value in a single assignment.

    Assigning with a ``tuple`` or ``list`` key stores the value under every
    element of that key: after ``d['a', 'b'] = 1`` both ``d['a']`` and
    ``d['b']`` return ``1``. Any other key is stored as-is. Lookup, deletion
    and membership tests always work on single keys.
    """

    def __init__(self, mapping_or_iterable=None, **kwargs):
        """
        Build the dictionary from an optional mapping or iterable of pairs, plus keyword arguments.

        :param mapping_or_iterable: a mapping (anything exposing ``.items()``) or an
            iterable of ``(keys, value)`` pairs; ``keys`` may be a tuple/list to bind
            several keys to the same value.
        :param kwargs: extra ``key=value`` entries, applied after ``mapping_or_iterable``.
        :raises ValueError: if an element of the iterable is not a 2-item pair.
        """
        self.items_dict = {}
        if mapping_or_iterable is not None:
            # Read every mapping (dict subclasses, OrderedDict, another MultiKeyDict)
            # through .items(). Iterating a mapping directly yields bare keys, and a
            # 2-character string key would then be silently unpacked as (key, value).
            if hasattr(mapping_or_iterable, 'items'):
                mapping_or_iterable = mapping_or_iterable.items()
            for kv in mapping_or_iterable:
                if len(kv) != 2:
                    raise ValueError('Iterable should contain tuples with exactly two values but specified: {0}.'.format(kv))
                self[kv[0]] = kv[1]
        for keys, value in kwargs.items():
            self[keys] = value

    def __getitem__(self, key):
        """Return the value stored under ``key``; raises KeyError if absent."""
        return self.items_dict.__getitem__(key)

    def __setitem__(self, keys, value):
        """
        Store ``value`` under ``keys``.

        Exactly ``tuple`` or ``list`` keys are spread, so each element becomes its own key.
        Any other object (including tuple subclasses) is used as a single key.
        """
        if type(keys) in [tuple, list]:
            for key in keys:
                self.items_dict[key] = value
        else:
            self.items_dict[keys] = value

    def __delitem__(self, key):
        """Remove the single key ``key``; other keys sharing its value are kept."""
        self.items_dict.__delitem__(key)

    def __contains__(self, key):
        """Return True if ``key`` is present."""
        return key in self.items_dict

    def __iter__(self):
        """Iterate over the keys, like a plain dict."""
        return iter(self.items_dict)

    def has_key(self, key):
        """Return True if ``key`` is present (Python 2 style alias of ``in``)."""
        return key in self.items_dict

    def items(self):
        """Return the ``(key, value)`` pairs, like ``dict.items()``."""
        return self.items_dict.items()

    def iteritems(self):
        """Yield ``(key, value)`` pairs."""
        for item in self.items_dict.items():
            yield item

    def iterkeys(self):
        """Yield the keys."""
        for keys in self.items_dict.keys():
            yield keys

    def itervalues(self):
        """Yield the values (a shared value appears once per key bound to it)."""
        for value in self.items_dict.values():
            yield value

    def keys(self):
        """Return the keys, like ``dict.keys()``."""
        return self.items_dict.keys()

    def values(self):
        """Return the values, like ``dict.values()`` (a list on Python 2, a view on Python 3)."""
        return self.items_dict.values()

    def __len__(self):
        """Return the number of individual keys stored."""
        return len(self.items_dict)

    def get(self, key, default=None):
        """Return the value stored under ``key``, or ``default`` if absent."""
        return self.items_dict.get(key, default)

    def __str__(self):
        return '<MultiKeyDict items=%s>' % (self.items_dict,)

    # Show the contents in the interactive prompt / pprint as well.
    __repr__ = __str__
class NoStrategyFound(Exception):
    """Raised when no fake-value generator matches a field's class hierarchy."""
class FieldMisMapped(Exception):
    """Exception type for a field whose mapping is inconsistent (not raised in this module's visible code)."""
#: Field names that refer to another collection, mapped to
#: ``(referenced model class, value type)``. A tuple key registers the same
#: target under several field names (MultiKeyDict spreads tuple keys).
#: NOTE(review): only commented-out code reads this map. The value types mix
#: ``mg.fields.ObjectId`` with ``mg.fields.StringField``. The former looks like
#: the bson ObjectId class rather than ``mg.fields.ObjectIdField``; confirm
#: which was intended before using these as get_field_mapping() keys.
referential_map = MultiKeyDict()
referential_map['dealerId', 'ownerId'] = (Dealer, mg.fields.ObjectId)
referential_map['agentId'] = (Agent, mg.fields.ObjectId)
referential_map['managerId'] = (Manager, mg.fields.ObjectId)
referential_map['devNo', 'logicalCode'] = (Device, mg.fields.StringField)
referential_map['groupId'] = (Group, mg.fields.StringField)
def bigger_float():
    """
    Return a random float whose integer part comes from faker's ``random_number()``
    and whose fractional part comes from ``random.random()`` (in [0, 1)).
    """
    whole = fake.random_number()
    fraction = random.random()
    return whole + fraction
def check_fields():
    """
    Verify that every field class exposed by ``apps.web.core.db`` has a generator.

    Collects every class in that module's namespace that subclasses mongoengine's
    ``BaseField``. Then checks that each one is a key of :func:`get_field_mapping`.

    :return: True when all are covered. Otherwise pretty-prints the uncovered
        classes and returns False.
    """
    import apps.web.core.db as db_mod

    def _is_field_class(member):
        return inspect.isclass(member) and issubclass(member, mg.fields.BaseField)

    field_classes = set(obj for _, obj in inspect.getmembers(db_mod, _is_field_class))
    uncovered = field_classes - set(get_field_mapping().keys())
    if uncovered:
        pprint.pprint('all_customized_models - _field_map_keys = %r' % (uncovered,))
        return False
    return True
__default_get_field_mapping = None


def get_field_mapping():
    # type: () -> dict
    """
    Return the table of fake-value generators, building it on first use.

    Keys are mongoengine field classes: the built-in ones plus the project's
    custom fields imported from ``apps.web.core.db``. Values are callables
    ``generator(field, **kwargs)`` that return a random value for that field.
    The table is cached in the module global ``__default_get_field_mapping``,
    and every later call returns the same dict object.
    """
    global __default_get_field_mapping

    def _fake_StringField(field, **kwargs):
        # Prefer one of the declared choices. mongoengine accepts choices as plain
        # values or as (value, label) pairs; only the value may be stored. list()
        # lets random.choice handle sets as well.
        if field.choices:
            choice = random.choice(list(field.choices))
            if isinstance(choice, (tuple, list)) and len(choice) == 2:
                choice = choice[0]
            return choice
        # Hyphenated faker text, capped at 15 characters and never longer than
        # the field's own max_length (a longer value would fail validation).
        limit = 15
        max_length = getattr(field, 'max_length', None)
        if max_length is not None:
            limit = min(limit, max_length)
        return fake.text().replace(' ', '-')[:limit]

    def _fake_IntField(field, **kwargs):
        # Honour explicit bounds, including 0. The previous ``max_value or N``
        # replaced max_value=0 with the default upper bound.
        low, high = field.min_value, field.max_value
        if low is None:
            low = 0 if high is None else min(0, high)
        if high is None:
            high = max(low, 999999999)
        return fake.random_int(min=low, max=high)

    def _fake_PointField(field, **kwargs):
        # mongo requires coordinates as (longitude, latitude), while local_latlng
        # yields (latitude, longitude), so reverse the pair. The comprehension is
        # needed because reversed() cannot consume map()'s iterator on Python 3.
        # https://stackoverflow.com/questions/41513112/cant-extract-geo-keys-longitude-latitude-is-out-of-bounds
        coords = fake.local_latlng("CN", coords_only=True)
        return [float(c) for c in reversed(coords)]

    if __default_get_field_mapping is None:
        __default_get_field_mapping = {
            mg.fields.IntField: _fake_IntField,
            mg.fields.StringField: _fake_StringField,
            mg.fields.LongField: lambda field, **kwargs: fake.random_number(10),
            mg.fields.BinaryField: lambda field, **kwargs: fake.binary(),
            mg.fields.BooleanField: lambda field, **kwargs: fake.boolean(),
            mg.fields.DateTimeField: lambda field, **kwargs: fake.date_time(),
            # bigger_float must be called; mapping the bare name returned the function object.
            mg.fields.FloatField: lambda field, **kwargs: bigger_float(),
            mg.fields.ListField: lambda field, **kwargs: [],
            mg.fields.DictField: lambda field, **kwargs: {},
            mg.fields.PointField: _fake_PointField,
            mg.fields.ObjectIdField: lambda field, **kwargs: ObjectId(),
            mg.fields.DecimalField: lambda field, **kwargs: decimal.Decimal(str(fake.random_number())),
            mg.fields.ReferenceField: lambda field, **kwargs: ObjectId(),
            StrictDictField: lambda field, **kwargs: {},
            MonetaryField: lambda field, **kwargs: bigger_float(),
            VirtualCoinField: lambda field, **kwargs: bigger_float(),
            RatioField: lambda field, **kwargs: random.random(),
            # randint includes both ends, so (0, 101) could produce 101.
            # Assumes PercentField expects 0..100; TODO confirm against its validator.
            PercentField: lambda field, **kwargs: random.randint(0, 100),
        }
    return __default_get_field_mapping
def get_next_type(type_):
    """
    Return the class that follows ``type_`` in its own method resolution order.

    When the MRO contains only ``type_`` itself (e.g. ``object``), ``type_`` is
    returned instead of raising IndexError.

    :param type_: any class.
    :return: ``inspect.getmro(type_)[1]``, or ``type_`` if it has no further class.
    """
    types = inspect.getmro(type_)
    # getmro() always includes type_ itself, so the old "len == 0" guard could
    # never fire and object raised IndexError. A length-1 MRO is the real
    # "no parent" case.
    if len(types) == 1:
        return types[0]
    return types[1]
def match_field(field_name, field_type, field, **kwargs):
    """
    Produce a fake value for one model field.

    Scans the method resolution order of ``field_type`` from the most specific
    class upward, and the first class with a generator in
    :func:`get_field_mapping` wins. ``EmbeddedDocumentField`` is handled
    specially: the embedded document is built from its own declared fields.

    :param field_name: the field's name on its document (used in the error message).
    :param field_type: the field's class; must subclass mongoengine's BaseField.
    :param field: the field instance, passed to the chosen generator.
    :param kwargs: extra keyword arguments forwarded to the generator.
    :return: the generated value.
    :raises AssertionError: if ``field_type`` is not a BaseField subclass.
    :raises NoStrategyFound: if no class below BaseField has a strategy.
    """
    assert issubclass(field_type, mg.fields.BaseField), 'field_type has to be a subclass of BaseField'
    field_mapping = get_field_mapping()
    # Walk the full MRO of the top type. Repeatedly taking "the next class in
    # each parent's own MRO" skips secondary bases under multiple inheritance,
    # and dead-ends at ``object`` when a non-field mixin is listed first.
    for candidate in inspect.getmro(field_type):
        if candidate == mg.fields.BaseField:
            # BaseField is the root of every field class, so nothing beyond it can match.
            break
        if candidate in field_mapping:
            # TODO: consult referential_map so that reference-like fields
            # (dealerId, devNo, ...) point at an existing document.
            return field_mapping[candidate](field, **kwargs)
        if candidate == mg.fields.EmbeddedDocumentField:
            # Build the embedded document from its own declared fields.
            return field.document_type(**generate_dict(field.document_type))
    raise NoStrategyFound('no strategy found for %r - %r' % (field_name, field_type))
# noinspection PyUnresolvedReferences
def generate_dict(model):
    # type: (Type[Searchable]) -> dict
    """
    Build a ``{field_name: fake_value}`` dict covering every field declared on ``model``.

    :param model: a mongoengine document or embedded-document class; its ``_fields`` mapping is read.
    :return: a dict suitable for ``model(**result)``.
    """
    # .items() works on Python 2 and 3; dict.iteritems() does not exist on Python 3.
    return {field_name: match_field(field_name=field_name, field_type=type(field), field=field)
            for field_name, field in model._fields.items()}
def generate_model(model):
    """
    Instantiate ``model`` with a fake value for each of its declared fields.

    The instance is returned unsaved; call ``.save()`` on it to persist it.
    """
    field_values = generate_dict(model)
    return model(**field_values)
|