__init__.py 8.1 KB


  1. # -*- coding: utf-8 -*-
  2. #!/usr/bin/env python
  3. import random
  4. import decimal
  5. import pprint
  6. import inspect
  7. from typing import TYPE_CHECKING, Type
  8. import faker
  9. import mongoengine as mg
  10. from bson.objectid import ObjectId
  11. from apps.web.core.db import MonetaryField, VirtualCoinField, RatioField, PercentField, StrictDictField
  12. from apps.web.dealer.models import Dealer
  13. from apps.web.agent.models import Agent
  14. from apps.web.management.models import Manager
  15. from apps.web.device.models import Device, Group
  16. if TYPE_CHECKING:
  17. from apps.web.core.db import Searchable
  # Module-level Faker generator; the fake-value strategies below read ``fake``.
  factory = faker.Factory()
  fake = factory.create()
  class MultiKeyDict(object):
      """
      A dict-like map where one assignment can bind several keys to one value.

      Assigning with a ``tuple`` or ``list`` key stores the value under each of
      its elements separately, so ``d[('a', 'b')] = 1`` makes both ``d['a']``
      and ``d['b']`` equal ``1``. Reads, deletes and membership tests always
      use a single key.
      """

      def __init__(self, mapping_or_iterable=None, **kwargs):
          """Populate from an optional mapping or iterable of pairs, then kwargs.

          :param mapping_or_iterable: either a mapping (any object with a
              ``keys()`` method and item access -- the same test ``dict.update``
              uses) or an iterable of ``(keys, value)`` pairs. ``keys`` may be a
              tuple/list to bind several keys at once.
          :param kwargs: additional ``key=value`` bindings, applied last.
          :raises ValueError: if an item of the iterable is not a pair.
          """
          self.items_dict = {}
          if mapping_or_iterable is not None:
              if hasattr(mapping_or_iterable, 'keys'):
                  # Accept any mapping, not only an exact ``dict``: iterating an
                  # OrderedDict / dict subclass directly would yield bare keys.
                  pairs = [(key, mapping_or_iterable[key])
                           for key in mapping_or_iterable.keys()]
              else:
                  pairs = mapping_or_iterable
              for kv in pairs:
                  if len(kv) != 2:
                      raise ValueError('Iterable should contain tuples with exactly two values but specified: {0}.'.format(kv))
                  self[kv[0]] = kv[1]
          for keys, value in kwargs.items():
              self[keys] = value

      def __getitem__(self, key):
          """Return the value stored under ``key``; raise KeyError if absent."""
          return self.items_dict[key]

      def __setitem__(self, keys, value):
          """Store ``value`` under ``keys``.

          An exact ``tuple`` or ``list`` is treated as several keys; anything
          else (including tuple subclasses such as namedtuples) is one key.
          """
          if type(keys) in (tuple, list):
              for key in keys:
                  self.items_dict[key] = value
          else:
              self.items_dict[keys] = value

      def __delitem__(self, key):
          """Remove the single entry ``key``; raise KeyError if absent."""
          del self.items_dict[key]

      def __contains__(self, key):
          """Return True if ``key`` has a value."""
          return key in self.items_dict

      def __iter__(self):
          """Iterate over the keys, like a dict.

          Without this method Python falls back to the legacy protocol of
          calling ``self[0]``, ``self[1]``, ... which raises KeyError here.
          """
          return iter(self.items_dict)

      def __len__(self):
          """Return the number of individual keys stored."""
          return len(self.items_dict)

      def has_key(self, key):
          """Python 2 style alias of ``key in self``."""
          return key in self.items_dict

      def get(self, key, default=None):
          """Return the value for ``key``, or ``default`` if absent."""
          return self.items_dict.get(key, default)

      def keys(self):
          """Return the underlying dict's ``keys()``."""
          return self.items_dict.keys()

      def values(self):
          """Return the underlying dict's ``values()`` (a live view on Python 3, not a copy)."""
          return self.items_dict.values()

      def items(self):
          """Return the underlying dict's ``items()``."""
          return self.items_dict.items()

      def iteritems(self):
          """Yield ``(key, value)`` pairs."""
          for item in self.items_dict.items():
              yield item

      def iterkeys(self):
          """Yield the keys."""
          for key in self.items_dict.keys():
              yield key

      def itervalues(self):
          """Yield the values."""
          for value in self.items_dict.values():
              yield value

      def __str__(self):
          return '<MultiKeyDict items=%s>' % (self.items_dict,)
  class NoStrategyFound(Exception):
      """Raised by match_field when no fake-value strategy fits a field type or its bases."""
  class FieldMisMapped(Exception):
      """Error type for a field that is mapped to the wrong strategy or type.

      NOTE(review): nothing in this module raises it; presumably used by callers.
      """
  #: Field name -> (referenced model, field type used to fake the referencing value).
  #: The tuple keys let several field names share one entry (see MultiKeyDict).
  #: The second element is meant to index ``get_field_mapping()``, so it must be a
  #: field class: ``mg.fields.ObjectId`` is only bson's ObjectId class re-exported
  #: by mongoengine and is never a key there, hence ``ObjectIdField``.
  #: NOTE(review): nothing in this module reads this map at present.
  referential_map = MultiKeyDict({
      ('dealerId', 'ownerId'): (Dealer, mg.fields.ObjectIdField),
      ('agentId',): (Agent, mg.fields.ObjectIdField),
      ('managerId',): (Manager, mg.fields.ObjectIdField),
      ('devNo', 'logicalCode'): (Device, mg.fields.StringField),
      ('groupId',): (Group, mg.fields.StringField),
  })
  def bigger_float():
      """Return Faker's random integer plus a uniform fraction from ``random.random()``.

      The result therefore has a fractional part and is not confined to [0, 1).
      """
      # Faker is drawn first, then ``random``, matching the original evaluation order.
      whole_part = fake.random_number()
      return whole_part + random.random()
  def check_fields():
      """Report whether every field class in ``apps.web.core.db`` has a strategy.

      Collects every class attribute of ``apps.web.core.db`` that subclasses
      mongoengine's ``BaseField`` and compares that set with the keys of
      ``get_field_mapping()``.

      :return: True when all of them are mapped; otherwise pretty-prints the
          unmapped classes and returns False.
      """
      import apps.web.core.db as db_mod
      base_field = mg.fields.BaseField
      field_classes = set()
      for _member_name, member in inspect.getmembers(db_mod):
          if inspect.isclass(member) and issubclass(member, base_field):
              field_classes.add(member)
      mapped_classes = set(get_field_mapping().keys())
      unmapped = field_classes - mapped_classes
      if unmapped:
          pprint.pprint('all_customized_models - _field_map_keys = %r' % (unmapped,))
          return False
      return True
  #: Module-level cache for get_field_mapping(); stays None until its first call.
  __default_get_field_mapping = None


  def get_field_mapping():
      # type: () -> dict
      """Return the cached map from field class to fake-value strategy.

      Every value is a callable ``strategy(field, **kwargs)`` that returns a fake
      value for the field instance ``field``. The dict is built on the first
      call, stored in the module-level ``__default_get_field_mapping`` and the
      same object is returned on every later call.
      """
      global __default_get_field_mapping

      def _fake_StringField(field, **kwargs):
          """Pick one of ``field.choices`` when declared, else short dash-joined text."""
          if field.choices:
              # list() so sets/other iterables work with random.choice.
              choices = list(field.choices)
              choice = random.choice(choices)
              # mongoengine accepts choices as plain values or as (value, label)
              # pairs and validates only against the values; detect the pair
              # form from the first entry, as mongoengine itself does.
              if isinstance(choices[0], (list, tuple)):
                  choice = choice[0]
              return choice
          text = fake.text().replace(' ', '-')[:15]
          max_length = getattr(field, 'max_length', None)
          if max_length is not None:
              # Keep StringField(max_length=...) values valid.
              text = text[:max_length]
          return text

      def _fake_IntField(field, **kwargs):
          """Random int within the field's bounds, defaulting to 0..999999999."""
          low, high = field.min_value, field.max_value
          # Test bounds against None, not truthiness: 0 is a real bound and
          # ``max_value or 999999999`` silently discarded it.
          if low is None:
              low = 0 if high is None else min(0, high)
          if high is None:
              high = max(999999999, low)
          return fake.random_int(min=low, max=high)

      if __default_get_field_mapping is None:
          __default_get_field_mapping = {
              mg.fields.IntField: _fake_IntField,
              mg.fields.StringField: _fake_StringField,
              mg.fields.LongField: lambda field, **kwargs: fake.random_number(10),
              mg.fields.BinaryField: lambda field, **kwargs: fake.binary(),
              mg.fields.BooleanField: lambda field, **kwargs: fake.boolean(),
              mg.fields.DateTimeField: lambda field, **kwargs: fake.date_time(),
              # Call bigger_float(); returning the bare function object is not a float.
              mg.fields.FloatField: lambda field, **kwargs: bigger_float(),
              mg.fields.ListField: lambda field, **kwargs: [],
              mg.fields.DictField: lambda field, **kwargs: {},
              #: mongo force coordinate to be set as (longitude, latitude)
              #: https://stackoverflow.com/questions/41513112/cant-extract-geo-keys-longitude-latitude-is-out-of-bounds
              # local_latlng presumably yields (latitude, longitude); build a real
              # list because reversed() rejects a Python 3 ``map`` object.
              mg.fields.PointField: lambda field, **kwargs: list(reversed(
                  [float(coord) for coord in fake.local_latlng("CN", coords_only=True)])),
              mg.fields.ObjectIdField: lambda field, **kwargs: ObjectId(),
              mg.fields.DecimalField: lambda field, **kwargs: decimal.Decimal(str(fake.random_number())),
              mg.fields.ReferenceField: lambda field, **kwargs: ObjectId(),
              StrictDictField: lambda field, **kwargs: {},
              MonetaryField: lambda field, **kwargs: bigger_float(),
              VirtualCoinField: lambda field, **kwargs: bigger_float(),
              RatioField: lambda field, **kwargs: random.random(),
              # randint is inclusive on both ends, so 0..100 keeps it a percentage.
              PercentField: lambda field, **kwargs: random.randint(0, 100),
          }
      return __default_get_field_mapping
  def get_next_type(type_):
      """Return the class that follows ``type_`` in its method resolution order.

      :param type_: any class.
      :return: ``type_.__mro__[1]``; a class with nothing after it in its MRO
          (e.g. ``object``) is returned unchanged instead of raising IndexError.
      """
      mro = inspect.getmro(type_)
      # getmro() always starts with type_ itself, so a ``len == 0`` guard can
      # never fire; the real edge case is a one-element MRO.
      return mro[1] if len(mro) > 1 else mro[0]
  def match_field(field_name, field_type, field, **kwargs):
      """
      Generate a fake value for ``field`` by finding a strategy for its type.

      The classes in ``field_type``'s MRO are tried in order: the first one that
      is a key of ``get_field_mapping()`` supplies the value; an
      ``EmbeddedDocumentField`` is filled by recursively faking the fields of
      its ``document_type``; reaching ``BaseField`` means nothing matched.

      :param field_name: name of the field on the document (used in the error).
      :param field_type: the field's class; must subclass mongoengine ``BaseField``.
      :param field: the field instance, passed to the chosen strategy.
      :param kwargs: forwarded unchanged to the chosen strategy.
      :return: the generated value.
      :raises NoStrategyFound: if no class before ``BaseField`` in the MRO matches.
      """
      assert issubclass(field_type, mg.fields.BaseField), 'field_type has to be a subclass of BaseField'
      field_mapping = get_field_mapping()
      # Walk the full MRO once. Repeatedly taking mro[1] of the previous class
      # goes astray with mixins: for ``class F(Mixin, StringField)`` it chains
      # F -> Mixin -> object and never reaches StringField.
      for candidate in inspect.getmro(field_type):
          if candidate is mg.fields.BaseField:
              break
          if candidate in field_mapping:
              # NOTE(review): a disabled draft here looked the field name up in
              # referential_map and created the referenced document when none
              # matched the generated value; that is not implemented.
              return field_mapping[candidate](field, **kwargs)
          if candidate is mg.fields.EmbeddedDocumentField:
              #: generate and return an EmbeddedDocument based on its own fields
              return field.document_type(**generate_dict(field.document_type))
      raise NoStrategyFound('no strategy found for %r - %r' % (field_name, field_type))
  # noinspection PyUnresolvedReferences
  def generate_dict(model):
      # type: (Type[Searchable]) -> dict
      """Build ``{field_name: fake_value}`` for every field declared on ``model``.

      :param model: a document class (typed as ``Type[Searchable]``) whose
          ``_fields`` dict maps field names to field instances.
      :return: a dict suitable for ``model(**result)``.
      :raises NoStrategyFound: propagated from match_field for unsupported fields.
      """
      # .items() works on Python 2 and 3; dict.iteritems() is Python 2 only.
      return {
          field_name: match_field(field_name=field_name, field_type=type(field), field=field)
          for field_name, field in model._fields.items()
      }
  def generate_model(model):
      """Instantiate ``model`` with fake values for every declared field.

      The instance is only constructed here, not saved.
      """
      field_values = generate_dict(model)
      return model(**field_values)