db.py 23 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727
  1. # -*- coding: utf-8 -*-
  2. # !/usr/bin/env python
  3. import datetime
  4. import decimal
  5. import json
  6. import logging
  7. import re
  8. import time
  9. from collections import Callable
  10. from copy import deepcopy
  11. from decimal import Decimal
  12. from functools import reduce
  13. import inspect
  14. import bson
  15. import six
  16. from bson import DBRef, ObjectId
  17. from bson.decimal128 import Decimal128
  18. from bson.regex import Regex
  19. from django.http.request import QueryDict
  20. from mongoengine import Q, Document, DynamicDocument, QuerySet, DecimalField, DictField, SequenceField, BooleanField
  21. from mongoengine.base import BaseField
  22. from mongoengine.common import _import_class
  23. from typing import Union, TYPE_CHECKING
  24. from apilib.monetary import RMB, VirtualCoin, Ratio, Percent, Money, Permillage, AccuracyRMB
  25. from apilib.quantity import Quantity
  26. from apilib.utils_string import md5
  27. from apps.web.constant import Const
  28. from apps.web.core.exceptions import ParameterError, ImproperlyConfigured
  29. if TYPE_CHECKING:
  30. from pymongo.collection import Collection
  31. logger = logging.getLogger(__name__)
  def copy_document_classes(clazz, new_name, db_alias):
      """
      Return a copy of the document class ``clazz`` whose meta points at ``db_alias``.

      Copies are cached in ``clazz.new_cls_list`` under ``'<db_alias>_<clazz.__name__>'``;
      when the key already exists the cached class is returned and ``new_name`` is not used.

      :param clazz: source class; must provide ``new_cls_list``, ``_origin_meta`` and ``_shard_key``
      :param new_name: name given to the generated class
      :param db_alias: value written to ``meta['db_alias']`` of the generated class
      :return: the cached or freshly built class
      :raises AssertionError: when ``clazz`` has no ``_origin_meta``
      """
      # Attribute names that are never copied over.
      # presumably mongoengine's metaclass rebuilds these for the new class -- TODO confirm
      exclude = ['_fields', '_db_field_map', '_reverse_db_field_map', '_fields_ordered', '_is_document',
                 'MultipleObjectsReturned', '_superclasses', '_subclasses', '_types', '_class_name',
                 '_meta', '__doc__', '__module__', '_collection', '_is_base_cls', '_auto_id_field', 'id',
                 'DoesNotExist', 'objects', '_cached_reference_fields']
      new_cls_key = '{}_{}'.format(db_alias, clazz.__name__)
      if new_cls_key not in clazz.new_cls_list.keys():
          # First pass: collect the field objects that go into the new class namespace.
          dicts = {}
          for name, field in clazz.__dict__.iteritems():
              if name in exclude:
                  continue
              else:
                  if isinstance(field, BaseField):
                      # NOTE: the very same field instance is reused by the copy, so clearing
                      # ``unique`` here also clears it on the field object held by ``clazz``.
                      if field.unique:
                          field.unique = False
                      # NOTE(review): nesting assumed -- only BaseField instances are collected
                      # here; everything else is attached in the second pass below.
                      dicts[name] = field
          assert hasattr(clazz, '_origin_meta'), u'必须提供_origin_meta'
          # Start from the original meta, then drop index definitions and disable automatic
          # index creation for the copy.
          dicts['meta'] = deepcopy(clazz._origin_meta)
          dicts['meta'].pop('indexes', None)
          dicts['meta']['index_background'] = True
          dicts['meta']['auto_create_index'] = False
          dicts['meta']['shard_key'] = clazz._shard_key
          dicts['__module__'] = clazz.__module__
          dicts['meta']['db_alias'] = db_alias
          new_cls = type(new_name, clazz.__bases__, dicts)
          # Second pass: attach every remaining (non-field) attribute to the new class.
          for name, field in clazz.__dict__.iteritems():
              if name in exclude:
                  continue
              else:
                  # Re-read through getattr so the value is resolved the way attribute access sees it.
                  field = getattr(clazz, name)
                  if isinstance(field, BaseField):
                      continue
                  else:
                      # ``im_self`` / ``im_func`` are Python 2 method attributes.
                      if inspect.ismethod(field):
                          if not field.im_self:
                              # No bound object: re-attach the underlying function as-is.
                              setattr(new_cls, name, field.im_func)
                          else:
                              # Bound to something: rebind to the new class as a classmethod.
                              setattr(new_cls, name, classmethod(field.im_func))
                      elif inspect.isfunction(field):
                          # Plain function seen through the class: kept as a staticmethod.
                          setattr(new_cls, name, staticmethod(field))
                      else:
                          setattr(new_cls, name, field)
          clazz.new_cls_list[new_cls_key] = new_cls
      else:
          pass
      return clazz.new_cls_list[new_cls_key]
  78. # Customized fields
  class MonetaryField(DecimalField):
      """
      Decimal field whose Python-side value is an ``RMB`` and whose stored value is a
      ``Decimal128`` built from the ``amount`` attribute of that value.
      """

      def __init__(self, min_value = None, max_value = None, force_string = False,
                   precision = 4, rounding = decimal.ROUND_HALF_UP, **kwargs):
          #: The field precision defaults to four places (two to three are displayed).
          super(MonetaryField, self).__init__(
              min_value = min_value,
              max_value = max_value,
              force_string = force_string,
              precision = precision,
              rounding = rounding,
              **kwargs
          )

      def to_mongo(self, value):
          """Convert ``value`` with ``to_python`` and store its ``amount`` as ``Decimal128``."""
          money = self.to_python(value)  # type: Union[RMB, VirtualCoin]
          return Decimal128(money.amount)

      def to_python(self, value, precision=3):
          """
          Wrap ``value`` in ``RMB``.

          :param value: raw value; ``None`` is returned unchanged
          :param precision: number of decimal places to quantize to; ``self.precision`` is
              used only when this argument is falsy
          :return: ``RMB`` of the quantized decimal, or ``RMB(value)`` of the untouched input
              when it cannot be parsed as a decimal
          """
          if value is None:
              return value
          try:
              number = decimal.Decimal('%s' % value)
          except (TypeError, ValueError, decimal.InvalidOperation):
              return RMB(value)
          places = precision or self.precision
          quantum = decimal.Decimal('.%s' % ('0' * places))
          return RMB(number.quantize(quantum, rounding = self.rounding))

      def validate(self, value):
          """Report an error when the converted value lies outside ``min_value`` / ``max_value``."""
          money = self.to_python(value)
          lower, upper = self.min_value, self.max_value
          if lower is not None and money < lower:
              self.error('Monetary value is too small')
          if upper is not None and money > upper:
              self.error('Monetary value is too large')
  class AccuracyMoneyField(MonetaryField):
      """Monetary field whose Python-side value is an ``AccuracyRMB`` (five places by default)."""

      def to_mongo(self, value):
          """Convert ``value`` with ``to_python`` and store its ``amount`` as ``Decimal128``."""
          money = self.to_python(value)  # type: Union[AccuracyRMB]
          return Decimal128(money.amount)

      def to_python(self, value, precision=5):
          """
          Wrap ``value`` in ``AccuracyRMB``.

          ``None`` is returned unchanged; input that cannot be parsed as a decimal is handed to
          ``AccuracyRMB`` untouched. ``self.precision`` is used only when ``precision`` is falsy.
          """
          if value is None:
              return value
          try:
              number = decimal.Decimal('%s' % value)
          except (TypeError, ValueError, decimal.InvalidOperation):
              return AccuracyRMB(value)
          places = precision or self.precision
          quantum = decimal.Decimal('.%s' % ('0' * places))
          return AccuracyRMB(number.quantize(quantum, rounding=self.rounding))
  class VirtualCoinField(MonetaryField):
      """Monetary field whose Python-side value is a ``VirtualCoin`` (two places by default)."""

      def to_python(self, value, precision = 2):
          """
          Wrap ``value`` in ``VirtualCoin``.

          ``None`` is returned unchanged; input that cannot be parsed as a decimal is handed to
          ``VirtualCoin`` untouched. ``self.precision`` is used only when ``precision`` is falsy.
          """
          if value is None:
              return value
          # Go through a string first (python 2.6) before casting to Decimal.
          try:
              number = decimal.Decimal('%s' % value)
          except (TypeError, ValueError, decimal.InvalidOperation):
              return VirtualCoin(value)
          places = precision or self.precision
          quantum = decimal.Decimal('.%s' % ('0' * places))
          return VirtualCoin(number.quantize(quantum, rounding = self.rounding))
  class RatioField(MonetaryField):
      """Monetary-style field whose Python-side value is a ``Ratio`` (three places by default)."""

      def to_python(self, value, precision = 3):
          """
          Wrap ``value`` in ``Ratio``.

          ``None`` is returned unchanged; input that cannot be parsed as a decimal is handed to
          ``Ratio`` untouched. ``self.precision`` is used only when ``precision`` is falsy.
          """
          if value is None:
              return value
          # Go through a string first (python 2.6) before casting to Decimal.
          try:
              number = decimal.Decimal('%s' % value)
          except (TypeError, ValueError, decimal.InvalidOperation):
              return Ratio(value)
          places = precision or self.precision
          quantum = decimal.Decimal('.%s' % ('0' * places))
          return Ratio(number.quantize(quantum, rounding = self.rounding))
  class PercentField(MonetaryField):
      """
      Percent: monetary-style field whose Python-side value is a ``Percent``.
      """

      def to_python(self, value, precision = 2):
          """
          Wrap ``value`` in ``Percent``.

          ``None`` is returned unchanged; input that cannot be parsed as a decimal is handed to
          ``Percent`` untouched. ``self.precision`` is used only when ``precision`` is falsy.
          """
          if value is None:
              return value
          # Go through a string first (python 2.6) before casting to Decimal.
          try:
              number = decimal.Decimal('%s' % value)
          except (TypeError, ValueError, decimal.InvalidOperation):
              return Percent(value)
          places = precision or self.precision
          quantum = decimal.Decimal('.%s' % ('0' * places))
          return Percent(number.quantize(quantum, rounding = self.rounding))
  class PermillageField(MonetaryField):
      """
      Monetary-style field whose Python-side value is a ``Permillage``.
      """

      def to_python(self, value, precision = 2):
          """
          Wrap ``value`` in ``Permillage``.

          ``None`` is returned unchanged; input that cannot be parsed as a decimal is handed to
          ``Permillage`` untouched. ``self.precision`` is used only when ``precision`` is falsy.
          """
          if value is None:
              return value
          # Go through a string first (python 2.6) before casting to Decimal.
          try:
              number = decimal.Decimal('%s' % value)
          except (TypeError, ValueError, decimal.InvalidOperation):
              return Permillage(value)
          places = precision or self.precision
          quantum = decimal.Decimal('.%s' % ('0' * places))
          return Permillage(number.quantize(quantum, rounding = self.rounding))
  class StrictDictField(DictField):
      """DictField that also serialises a few in-house value types mongo cannot store natively."""

      def to_mongo(self, value, use_db_field = True, fields = None):
          # type:(dict, bool, list)->dict
          """
          Override of ComplexField's ``to_mongo`` that converts types Mongo does not support.

          Strings are returned as-is, objects exposing ``to_mongo`` are delegated to, and every
          other value is treated as a mapping (sequences are enumerated first) whose items are
          converted one by one. Sequences are turned back into lists before returning, so a
          list nested inside the dict stays a list instead of becoming a dict keyed by integers.

          :param value: value to convert
          :param use_db_field: forwarded to nested ``to_mongo`` calls
          :param fields: forwarded to nested ``to_mongo`` calls
          :return: the mongo-compatible representation of ``value``
          """
          #: Mapping from commonly used in-house Python types to their mongo representation.
          converter_map = {
              Decimal: lambda _: Decimal128(_),
              Percent: lambda _: _.mongo_amount,
              Money: lambda _: _.mongo_amount,
              Ratio: lambda _: _.mongo_amount,
              Quantity: lambda _: _.mongo_amount,
          }

          Document = _import_class('Document')
          EmbeddedDocument = _import_class('EmbeddedDocument')
          GenericReferenceField = _import_class('GenericReferenceField')

          if isinstance(value, six.string_types):
              return value

          if hasattr(value, 'to_mongo'):
              if isinstance(value, Document):
                  return GenericReferenceField().to_mongo(value)
              cls = value.__class__
              val = value.to_mongo(use_db_field, fields)
              # If it's a document that is not inherited add _cls
              if isinstance(value, EmbeddedDocument):
                  val['_cls'] = cls.__name__
              return val

          # Sequences are processed as {index: item}; remember that so the result can be
          # converted back to a list (integer keys are not valid document keys).
          is_list = False
          if not hasattr(value, 'items'):
              try:
                  value = dict(enumerate(value))
              except TypeError:  # Not iterable: return the value untouched
                  return value
              is_list = True

          if self.field:
              value_dict = {
                  key: self.field._to_mongo_safe_call(item, use_db_field, fields)
                  for key, item in value.items()
              }
          else:
              value_dict = {}
              for k, v in value.items():
                  if isinstance(v, Document):
                      # We need the id from the saved object to create the DBRef
                      if v.pk is None:
                          self.error('You can only reference documents once they'
                                     ' have been saved to the database')
                      # If its a document that is not inheritable it won't have
                      # any _cls data so make it a generic reference allows
                      # us to dereference
                      meta = getattr(v, '_meta', {})
                      allow_inheritance = meta.get('allow_inheritance')
                      if not allow_inheritance and not self.field:
                          value_dict[k] = GenericReferenceField().to_mongo(v)
                      else:
                          collection = v._get_collection_name()
                          value_dict[k] = DBRef(collection, v.pk)
                  elif hasattr(v, 'to_mongo'):
                      cls = v.__class__
                      val = v.to_mongo(use_db_field, fields)
                      # If it's a document that is not inherited add _cls
                      if isinstance(v, (Document, EmbeddedDocument)):
                          val['_cls'] = cls.__name__
                      value_dict[k] = val
                  else:
                      for type_, func in converter_map.items():
                          if isinstance(v, type_):
                              value_dict[k] = func(v)
                              break
                      else:
                          # No dedicated converter: recurse (handles nested dicts and lists).
                          value_dict[k] = self.to_mongo(v, use_db_field, fields)

          if is_list:
              # Keys are the enumerate() indexes, so sorting them restores the original order.
              return [value_dict[index] for index in sorted(value_dict)]
          return value_dict
  class CustomizedSequenceField(SequenceField):
      """SequenceField variant that queries on the value exactly as given."""

      def prepare_query_value(self, op, value):
          """
          Return ``value`` unchanged.

          The parent class returns decorator_fn(value) (e.g. fn: x = x + 10000), but what is
          stored is already the decorated value, so the stored value is returned directly.

          :param op: query operator (unused)
          :param value: value to query with
          :return: ``value`` as passed in
          """
          return value

      def inverse_value_decorator(self, fn, value):
          # type: (Callable[int, int], int)->int
          """Subtract from ``value`` the offset that ``fn`` adds to it."""
          return value - (fn(value) - value)
  class BooleanIntField(BooleanField):
      """
      Stores a boolean as an int, which is convenient for parameters sent by the front end.
      """
      __interval = (0, 1)

      def to_python(self, value):
          """Return ``int(bool(value))``; on ``ValueError`` the value is returned unchanged."""
          try:
              return int(bool(value))
          except ValueError:
              return value

      def to_mongo(self, value):
          """
          Return ``value`` itself when it equals 0 or 1, otherwise ``int(value)``.

          Note that ``True`` / ``False`` compare equal to 1 / 0, so bools pass the membership
          test and are returned as bools rather than ints.
          """
          return value if value in self.__interval else int(value)

      def validate(self, value):
          """Report an error unless ``value`` is 0, 1 or a bool."""
          if not (value in [0, 1] or isinstance(value, bool)):
              self.error('BooleanIntField only accepts 0/1 or bool value')
  def prepare_condition(condition):
      """
      Turn one condition dict into a single-entry keyword dict for ``Q``.

      The key is ``field`` and ``operator`` joined by ``'__'``; a falsy part is left out, so an
      empty operator yields the bare field name.

      :param condition: dict with the keys ``field``, ``operator`` and ``value``
      :return: ``{'<field>__<operator>': value}``
      """
      parts = (condition['field'], condition['operator'])
      lookup = '__'.join(filter(None, parts))
      return {lookup: condition['value']}
  def prepare_conditions(row):
      """Lazily yield one ``Q`` object per condition dict in ``row``."""
      return (Q(**prepare_condition(spec)) for spec in row)
  def join_conditions(row):
      """OR together the ``Q`` objects built from every condition in ``row``."""
      def _either(left, right):
          return left | right

      return reduce(_either, prepare_conditions(row))
  def join_rows(rows):
      """AND together all items of ``rows`` (left to right, no initial value)."""
      def _both(left, right):
          return left & right

      return reduce(_both, rows)
  def dynamic_query(query_input):
      """
      Build one query from a list of rows of condition dicts.

      Conditions inside a row are OR-ed together; the rows themselves are AND-ed.
      """
      row_queries = (join_conditions(row) for row in query_input)
      return join_rows(row_queries)
  def search_query(fields, search_text):
      """
      Assemble the search query.

      Every field other than ``id`` is matched case-insensitively against ``search_text``
      taken literally; when ``id`` is requested and the text is a valid ObjectId, an exact
      ``_id`` match is added as well.

      :param fields: fields that may be searched
      :param search_text: text to search for
      :return: a ``Q`` wrapping the raw ``$or`` filter
      """
      # Escape regex metacharacters so user input such as "(" or "*" is matched literally
      # instead of raising re.error or changing the meaning of the pattern.
      literal = re.sub(r'([\\.^$*+?{}\[\]|()])', r'\\\1', search_text)
      pattern = Regex.from_native(re.compile('.*' + literal + '.*', re.IGNORECASE))
      clauses = [{field: {"$regex": pattern}} for field in fields if field != 'id']

      if 'id' in fields:
          # Best effort: text that is not a valid ObjectId simply adds no _id clause.
          try:
              clauses.insert(0, {'_id': ObjectId(str(search_text))})
          except Exception as e:
              logger.error("[search_query] get an error = {}, fields = {}, search_text = {}".format(e, fields, search_text))

      if not clauses:
          # MongoDB rejects an empty $or array; nothing can match, so select nothing.
          return Q(__raw__ = {'_id': {'$in': []}})
      return Q(__raw__ = {"$or": clauses})
  class CustomQuerySet(QuerySet):
      """QuerySet with pagination, text search and aggregate helpers."""

      def paginate(self, pageIndex, pageSize):
          # type: (int, int)->CustomQuerySet
          """Return page ``pageIndex`` (1-based; anything below 1 means the first page)."""
          front = (pageIndex - 1) if pageIndex >= 1 else 0
          return self.skip(front * pageSize).limit(pageSize)

      def search(self, search_key, fields = None):
          # type:(str, list)->CustomQuerySet
          """
          Search interface. There is room for optimisation: once the data volume grows,
          consider a dedicated search backend such as `ElasticSearch`.

          :param search_key: text to search for; a falsy value returns the queryset unchanged
          :param fields: optional list/tuple of fields to search; defaults to the document's
              ``search_fields``
          :return: the filtered queryset
          :raises ImproperlyConfigured: when no usable search fields are available
          :raises TypeError: when ``fields`` is neither a list nor a tuple
          """
          if not search_key:
              return self

          if fields is None:
              fields = self._document.search_fields

          # Explicit check instead of ``assert`` so validation also happens under ``python -O``.
          try:
              has_fields = len(fields) >= 1
          except TypeError:
              has_fields = False
          if not has_fields:
              raise ImproperlyConfigured(
                  'in order to search, search_fields has to be defined(length >= 1) or supplied, model = %r' % self._document
              )
          if not isinstance(fields, (tuple, list)):
              raise TypeError('parameter fields has to be a list')

          return self(search_query(fields = fields, search_text = search_key))

      def rev_by_time(self):
          """
          Order newest first by the document's time field.

          Uses ``time_field`` of the document class when set, otherwise
          ``Const.DEFAULT_TIME_FIELD_NAME`` if the document class has such an attribute.

          :return: the ordered queryset
          :raises ValueError: when neither is available
          """
          if self._document.time_field is not None:
              key = self._document.time_field
          elif Const.DEFAULT_TIME_FIELD_NAME in dir(self._document):
              key = Const.DEFAULT_TIME_FIELD_NAME
          else:
              raise ValueError('cannot call rev_by_time without time_field set on document class')

          return self().order_by('-{key}'.format(key = key))

      def head(self, default = None):
          """
          Return the first result, or ``default`` when the queryset is empty.

          :param default: fallback value; when callable it is called and its result returned
          :return: the first document or the fallback
          """
          queryset = self.clone()
          try:
              result = queryset[0]
          except IndexError:
              if callable(default):
                  result = default()
              else:
                  result = default
          return result

      def sum_and_count(self, field):
          """Sum over the values of the specified field and count the grouped entries.

          # TODO zjl: rewrite using count_document

          :param field: the field to sum over; use dot notation to refer to
              embedded document fields
          :return: ``(count, total)``; ``(0, 0)`` when nothing matched
          """
          db_field = self._fields_to_dbfields([field]).pop()
          pipeline = [
              {'$match': self._query},
              {'$group': {'_id': 'sum', 'count': {'$sum': 1}, 'total': {'$sum': '$' + db_field}}}
          ]

          # if we're performing a sum over a list field, we sum up all the
          # elements in the list, hence we need to $unwind the arrays first
          ListField = _import_class('ListField')
          field_parts = field.split('.')
          field_instances = self._document._lookup_field(field_parts)
          if isinstance(field_instances[-1], ListField):
              # Unwind the stored path, i.e. the same one the $group stage sums over.
              pipeline.insert(1, {'$unwind': '$' + db_field})

          result = self._document._get_collection().aggregate(pipeline)
          from mongoengine.python_support import IS_PYMONGO_3
          if IS_PYMONGO_3:
              result = tuple(result)
          else:
              result = result.get('result')

          if result:
              return int(result[0]['count']), result[0]['total']
          return 0, 0

      def _count_documents(self, with_limit_and_skip = False):
          """Count matching documents through the collection's ``count_documents``."""
          if (self._limit == 0 and with_limit_and_skip is False) or self._none:
              return 0

          if with_limit_and_skip:
              kwargs = {}
              if self._limit:
                  kwargs["limit"] = self._limit
              if self._skip:
                  kwargs["skip"] = self._skip
              return self._collection.count_documents(filter = self._query, **kwargs)
          else:
              return self._collection.count_documents(filter = self._query)

      def count(self, with_limit_and_skip = False):
          """
          Count matching documents.

          The count that honours limit/skip is cached in ``self._len``; the plain count is
          computed on every call.
          """
          if with_limit_and_skip is False:
              return self._count_documents(with_limit_and_skip)
          else:
              if self._len is None:
                  self._len = self._count_documents(with_limit_and_skip)
              return self._len
  class UtilMixin(object):
      """Date/time formatting helpers shared by document classes."""

      def to_js_timestamp(self, datetime_):
          """Return ``datetime_`` as a JavaScript timestamp (JS timestamps are in milliseconds)."""
          seconds = time.mktime(datetime_.timetuple())
          return seconds * 1000

      def to_datetime_str(self, datetime_):
          """
          Format ``datetime_`` with ``Const.DATETIME_FMT``.

          A falsy value gives ``''``; anything that is not a ``datetime.datetime`` is passed
          through ``str``.
          """
          if not datetime_:
              return ''
          if not isinstance(datetime_, datetime.datetime):
              return str(datetime_)
          return datetime_.strftime(Const.DATETIME_FMT)

      def to_date_str(self, datetime_):
          """Format ``datetime_`` with ``Const.DATE_FMT``."""
          date_format = Const.DATE_FMT
          return datetime_.strftime(date_format)
  class BaseDocument(DynamicDocument, UtilMixin):
      """Abstract base document wired to ``CustomQuerySet`` with automatic index creation off."""

      meta = {
          'abstract': True,
          'queryset_class': CustomQuerySet,
          'index_background': True,
          'auto_create_index': False
      }

      # Name of the field used by ``rev_by_time``; ``None`` means not configured.
      time_field = None

      # Cache filled by ``copy_document_classes``; defined here, so shared unless overridden.
      new_cls_list = {}

      def __repr__(self):
          return self.__str__()

      def __str__(self):
          class_name = self.__class__.__name__
          return '{}<id={}>'.format(class_name, str(self.id))

      @classmethod
      def get_collection(cls):
          # type: ()->Collection
          """Return the underlying collection of this document class."""
          return cls._get_collection()

      # Short alias for ``get_collection``.
      clt = get_collection

      @classmethod
      def agg(cls, pipeline, **kwargs):
          """Run ``pipeline`` through ``aggregate`` on the underlying collection."""
          collection = cls.get_collection()
          return collection.aggregate(pipeline, **kwargs)

      @classmethod
      def last1(cls):
          """Return the first document when ordered newest first, via ``rev_by_time``."""
          newest_first = cls.objects().rev_by_time()
          return newest_first.first()

      @classmethod
      def first1(cls):
          """Return the first document of the default queryset."""
          return cls.objects().first()

      def to_dict(self):
          """
          Return the declared fields as a plain dict.

          ``id`` is converted with ``str``; datetimes are formatted as ``%Y-%m-%d %H:%M:%S``.
          """
          payload = {}
          for name in self.__class__._fields_ordered:
              value = getattr(self, name)
              if name == 'id':
                  value = str(value)
              elif isinstance(value, datetime.datetime):
                  value = value.strftime('%Y-%m-%d %H:%M:%S')
              payload[name] = value
          return payload
  class Searchable(BaseDocument):
      """Abstract document that can be searched through ``search_fields``."""

      # Fields searched when the caller does not pass ``fields`` explicitly.
      search_fields = ()

      meta = {
          'abstract': True,
      }

      @classmethod
      def search(cls, search_key, fields = None):
          """
          Search documents of this class.

          :param search_key: text to search for
          :param fields: optional specific fields to search
          :return: whatever the queryset's ``search`` returns
          """
          manager = cls.objects
          return manager.search(search_key, fields = fields)
  class RoleBaseDocument(Searchable):
      """Abstract searchable document that carries a role name."""

      meta = {
          'abstract': True,
      }

      @property
      def __role__(self):
          """Role name: the lower-cased class name."""
          return type(self).__name__.lower()

      @property
      def role(self):
          """Return ``__role__``; raise ``ImproperlyConfigured`` when it is ``None``."""
          value = getattr(self, '__role__', None)
          if value is not None:
              return value
          raise ImproperlyConfigured('no __role__ attr provided')

      @property
      def logName(self):
          """Identifier used in logs, e.g. ``ClassName<id=...>``."""
          class_name = self.__class__.__name__
          return '{}<id={}>'.format(class_name, str(self.id))

      @property
      def request_limit_key(self):
          """Key made of the role and the id, joined by an underscore."""
          role_part = str(self.role)
          id_part = str(self.id)
          return '{}_{}'.format(role_part, id_part)
  def paginate(queryset, pageIndex, pageSize):
      """
      Return page ``pageIndex`` of ``queryset``.

      :param queryset: a list (sliced) or an object offering ``skip`` / ``limit``
      :param pageIndex: 1-based page number; anything below 1 is treated as the first page,
          matching ``CustomQuerySet.paginate``
      :param pageSize: number of items per page
      :return: the list slice, or the result of ``skip(...).limit(...)``
      """
      # Clamp so a page index below 1 cannot turn into a negative slice start / skip value.
      offset = max(pageIndex - 1, 0) * pageSize
      if isinstance(queryset, list):
          return queryset[offset: offset + pageSize]
      return queryset.skip(offset).limit(pageSize)
  class Query(object):
      """
      Normalised list query: the raw request parameters, the filter attributes derived from
      them, and the paging information.
      """

      def __init__(self, raw, attrs, pageIndex = 1, pageSize = 10, channel = 'unknown'):
          """
          :param raw: request parameters as received (after dropping ignored/empty ones)
          :param attrs: processed parameters
          :param pageIndex: page number
          :param pageSize: page size
          :param channel: prefix used in ``cache_key``
          """
          self.raw = raw
          self.attrs = attrs
          self.pageIndex = pageIndex
          self.pageSize = pageSize
          self.channel = channel

      def to_string(self):
          """Return the raw parameters as ``k=v`` pairs sorted by key and joined with ``&``."""
          # ``items()`` exists on both Python 2 and 3 dicts (``iteritems`` is Python 2 only).
          return '&'.join('%s=%s' % (k, v) for k, v in sorted(self.raw.items()))

      @property
      def cache_key(self):
          """Cache key of the form ``<channel>:query:<to_string()>``."""
          return '%s:query:%s' % (self.channel, self.to_string(),)

      @property
      def hashed(self):
          """``md5`` of the canonical parameter string."""
          return md5(self.to_string())

      @classmethod
      def from_query_dict(cls, queryDict, ignoreKeys = None):
          """
          Build a ``Query`` from a Django ``QueryDict``.

          The first value of every key is kept unless it is empty or the key is ignored.
          ``pageIndex`` / ``pageSize`` are popped (defaults 1 / 10). ``startTime`` / ``endTime``
          (default: today) are expanded to the start / end of the day and stored as
          ``dateTimeAdded__gte`` / ``dateTimeAdded__lte``.

          :param queryDict: the request parameters
          :param ignoreKeys: keys to drop from the payload
          :return: the new ``Query``
          :raises TypeError: when ``queryDict`` is not a ``QueryDict``
          """
          ignoreKeys = [] if ignoreKeys is None else ignoreKeys

          # Fail with a clear message instead of continuing with an unbound payload.
          if not isinstance(queryDict, QueryDict):
              raise TypeError('queryDict has to be a QueryDict, got %r' % type(queryDict))

          # ``v and v[0]``: a key with an empty value list is skipped instead of raising IndexError.
          unprocessedPayload = {
              k: v[0] for k, v in dict(queryDict).items()
              if v and v[0] and k not in ignoreKeys
          }

          processedPayload = unprocessedPayload.copy()
          pageIndex = int(processedPayload.pop('pageIndex', 1))
          pageSize = int(processedPayload.pop('pageSize', 10))

          todayDate = datetime.datetime.now().strftime(Const.DATE_FMT)
          startTime = datetime.datetime.strptime(str(processedPayload.pop('startTime', todayDate)) + ' 00:00:00',
                                                 Const.DATETIME_FMT)
          endTime = datetime.datetime.strptime(str(processedPayload.pop('endTime', todayDate)) + ' 23:59:59',
                                               Const.DATETIME_FMT)
          processedPayload['dateTimeAdded__lte'] = endTime
          processedPayload['dateTimeAdded__gte'] = startTime

          return cls(raw = unprocessedPayload, attrs = processedPayload, pageIndex = pageIndex, pageSize = pageSize)
  def prepare_query(queryDict, ignoreKeys = None):
      """Shortcut for ``Query.from_query_dict(queryDict, ignoreKeys)``."""
      query = Query.from_query_dict(queryDict, ignoreKeys)
      return query
  def calc_doc_size(doc, unit = 'bytes'):
      # type: (Union[dict, Document], str)->Union[int, float]
      """
      Return the BSON-encoded size of ``doc``.

      :param doc: a raw dict, or a mongoengine ``Document`` (encoded via ``to_mongo()``)
      :param unit: ``'bytes'``, ``'kilobytes'`` or ``'megabytes'``
      :return: the size in the requested unit (an int for bytes, a float otherwise)
      :raises ParameterError: when ``doc`` is neither a dict nor a ``Document``
      :raises KeyError: when ``unit`` is unknown and ``doc`` is of a supported type
      """
      def _as_bytes(size):
          return size

      def _as_kilobytes(size):
          return size / 1024.

      def _as_megabytes(size):
          return size / 1024. / 1024.

      scale_by_unit = {
          'bytes': _as_bytes,
          'kilobytes': _as_kilobytes,
          'megabytes': _as_megabytes,
      }

      def _measure(payload):
          # The unit is looked up only once something is actually measured.
          scale = scale_by_unit[unit]
          return scale(len(bson.BSON.encode(payload)))

      if isinstance(doc, dict):
          return _measure(doc)
      if isinstance(doc, Document):
          return _measure(doc.to_mongo())
      raise ParameterError('only mongoengine Document or raw dicts are supported')
  if __name__ == '__main__':
      #: test
      #: the snippet adapted from https://blog.sneawo.com/blog/2017/03/26/how-to-build-a-dynamic-query-with-mongoengine/
      # First row: three alternatives on the same field (OR-ed together).
      first_row = [
          {"field": "some_field", "operator": "gt", "value": 30},
          {"field": "some_field", "operator": "lt", "value": 40},
          {"field": "some_field", "operator": "", "value": 35},
      ]
      # Second row: a single condition, AND-ed with the first row.
      second_row = [
          {"field": "another_field", "operator": "istartswith", "value": "test"},
      ]
      query_input = [first_row, second_row]

      demo_query = dynamic_query(query_input)
      print(json.dumps(demo_query.to_query(None)))