#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Shared mongoengine building blocks: custom field types, a QuerySet with pagination,
search and aggregation helpers, abstract base documents, and small query utilities.
"""

import datetime
import decimal
import inspect
import json
import logging
import re
import time
from copy import copy, deepcopy
from decimal import Decimal
from functools import reduce
from typing import Callable, Union, TYPE_CHECKING

import bson
import six
from bson import DBRef, ObjectId
from bson.decimal128 import Decimal128
from bson.regex import Regex
from django.http.request import QueryDict
from mongoengine import Q, Document, DynamicDocument, QuerySet, DecimalField, DictField, SequenceField, BooleanField
from mongoengine.base import BaseField
from mongoengine.common import _import_class

from apilib.monetary import RMB, VirtualCoin, Ratio, Percent, Money, Permillage, AccuracyRMB
from apilib.quantity import Quantity
from apilib.utils_string import md5
from apps.web.constant import Const
from apps.web.core.exceptions import ParameterError, ImproperlyConfigured

if TYPE_CHECKING:
    from pymongo.collection import Collection

logger = logging.getLogger(__name__)

#: Class attributes that copy_document_classes never carries over to the copied class
#: (mongoengine bookkeeping that the document metaclass rebuilds for the new class).
_COPY_EXCLUDED_ATTRS = frozenset([
    '_fields', '_db_field_map', '_reverse_db_field_map', '_fields_ordered',
    '_is_document', 'MultipleObjectsReturned', '_superclasses', '_subclasses',
    '_types', '_class_name', '_meta', '__doc__', '__module__', '_collection',
    '_is_base_cls', '_auto_id_field', 'id', 'DoesNotExist', 'objects',
    '_cached_reference_fields',
])


def copy_document_classes(clazz, new_name, db_alias):
    """
    Build (once) and return a copy of the document class ``clazz`` bound to ``db_alias``.

    The copy is created with ``type(new_name, clazz.__bases__, ...)`` and cached in
    ``clazz.new_cls_list`` under ``'<db_alias>_<clazz.__name__>'``; later calls with the
    same alias and class return the cached class (``new_name`` only matters the first time).

    The copy receives:
      * the fields declared directly on ``clazz``, with ``unique`` switched off;
      * a deep copy of ``clazz._origin_meta`` as its meta, without ``indexes`` and with
        ``index_background=True``, ``auto_create_index=False``,
        ``shard_key=clazz._shard_key`` and ``db_alias=db_alias``;
      * every other attribute of ``clazz.__dict__`` not in ``_COPY_EXCLUDED_ATTRS``
        (methods, class/static methods, properties, constants), attached after creation.

    :param clazz: document class to copy; must define ``_origin_meta`` and ``_shard_key``
        and expose a ``new_cls_list`` dict (``BaseDocument`` provides one).
    :param new_name: class name for the copy.
    :param db_alias: mongoengine connection alias the copy is bound to.
    :return: the (possibly cached) copied class.
    :raises ImproperlyConfigured: if ``clazz`` has no ``_origin_meta``.
    """
    new_cls_key = '{}_{}'.format(db_alias, clazz.__name__)
    if new_cls_key in clazz.new_cls_list:
        return clazz.new_cls_list[new_cls_key]

    # Explicit raise instead of assert so the check survives ``python -O``.
    if not hasattr(clazz, '_origin_meta'):
        raise ImproperlyConfigured(u'必须提供_origin_meta')

    attrs = {}
    for name, value in six.iteritems(clazz.__dict__):
        if name in _COPY_EXCLUDED_ATTRS or not isinstance(value, BaseField):
            continue
        if value.unique:
            # Copy before flipping the flag: the field instance is owned by ``clazz``
            # too, and mutating it in place would change the original class's field.
            value = copy(value)
            value.unique = False
        attrs[name] = value

    meta = deepcopy(clazz._origin_meta)
    meta.pop('indexes', None)
    meta['index_background'] = True
    meta['auto_create_index'] = False
    meta['shard_key'] = clazz._shard_key
    meta['db_alias'] = db_alias
    attrs['meta'] = meta
    attrs['__module__'] = clazz.__module__

    new_cls = type(new_name, clazz.__bases__, attrs)

    # Non-field attributes are attached after class creation. The raw ``__dict__``
    # values (plain functions, classmethod/staticmethod objects, properties and other
    # descriptors) are reused as declared, so they bind to ``new_cls`` exactly as they
    # bound to ``clazz``; this works the same on Python 2 and 3.
    for name, value in six.iteritems(clazz.__dict__):
        if name in _COPY_EXCLUDED_ATTRS or isinstance(value, BaseField):
            continue
        setattr(new_cls, name, value)

    clazz.new_cls_list[new_cls_key] = new_cls
    return new_cls


# Customized fields

class MonetaryField(DecimalField):
    """
    Decimal field whose Python value is wrapped in a monetary type.

    Values are written to MongoDB as ``Decimal128`` and read back as ``_wrapper``
    instances (``RMB`` here; subclasses override ``_wrapper`` and the default
    precision of ``to_python``).
    """

    #: Type that ``to_python`` wraps quantized values in.
    _wrapper = RMB

    def __init__(self, min_value=None, max_value=None, force_string=False, precision=4,
                 rounding=decimal.ROUND_HALF_UP, **kwargs):
        #: Default precision is four decimal places; values are displayed with two or three.
        super(MonetaryField, self).__init__(min_value=min_value, max_value=max_value,
                                            force_string=force_string, precision=precision,
                                            rounding=rounding, **kwargs)

    def _wrap(self, value, precision):
        """
        Convert ``value`` to ``self._wrapper``, quantized to ``precision`` decimal places.

        ``None`` is returned unchanged. A value that cannot be parsed as a Decimal is
        passed to the wrapper as-is. A falsy ``precision`` falls back to the field-level
        ``self.precision``.
        """
        if value is None:
            return value
        try:
            # Go through a string before casting to Decimal (needed for Python 2.6 floats).
            value = decimal.Decimal('%s' % value)
        except (TypeError, ValueError, decimal.InvalidOperation):
            return self._wrapper(value)
        exponent = decimal.Decimal('.%s' % ('0' * (precision or self.precision)))
        return self._wrapper(value.quantize(exponent, rounding=self.rounding))

    def to_mongo(self, value):
        """Return ``value`` as ``Decimal128`` (``None`` stays ``None``)."""
        value = self.to_python(value)  # type: Union[RMB, VirtualCoin]
        if value is None:
            # DecimalField.prepare_query_value routes query values (including None)
            # through to_mongo; None has no ``amount``.
            return None
        return Decimal128(value.amount)

    def to_python(self, value, precision=3):
        return self._wrap(value, precision)

    def validate(self, value):
        value = self.to_python(value)
        if self.min_value is not None and value < self.min_value:
            self.error('Monetary value is too small')

        if self.max_value is not None and value > self.max_value:
            self.error('Monetary value is too large')


class AccuracyMoneyField(MonetaryField):
    """Monetary field read back as ``AccuracyRMB`` with five decimal places by default."""

    _wrapper = AccuracyRMB

    def to_python(self, value, precision=5):
        return self._wrap(value, precision)


class VirtualCoinField(MonetaryField):
    """Monetary field read back as ``VirtualCoin`` with two decimal places by default."""

    _wrapper = VirtualCoin

    def to_python(self, value, precision=2):
        return self._wrap(value, precision)


class RatioField(MonetaryField):
    """Monetary field read back as ``Ratio`` with three decimal places by default."""

    _wrapper = Ratio

    def to_python(self, value, precision=3):
        return self._wrap(value, precision)


class PercentField(MonetaryField):
    """
    Percentage field, read back as ``Percent`` with two decimal places by default.
    """

    _wrapper = Percent

    def to_python(self, value, precision=2):
        return self._wrap(value, precision)


class PermillageField(MonetaryField):
    """
    Per-mille field, read back as ``Permillage`` with two decimal places by default.
    """

    _wrapper = Permillage

    def to_python(self, value, precision=2):
        return self._wrap(value, precision)


#: (type, converter) pairs StrictDictField applies to values MongoDB cannot store
#: natively. Checked in order, so the project wrapper types take precedence over the
#: generic ``Decimal`` entry.
_MONGO_VALUE_CONVERTERS = (
    (Percent, lambda v: v.mongo_amount),
    (Money, lambda v: v.mongo_amount),
    (Ratio, lambda v: v.mongo_amount),
    (Quantity, lambda v: v.mongo_amount),
    (Decimal, Decimal128),
)


class StrictDictField(DictField):

    def to_mongo(self, value, use_db_field=True, fields=None):
        # type: (dict, bool, list) -> dict
        """
        Override of ComplexBaseField.to_mongo that also converts value types MongoDB
        does not support (Decimal, Percent, Money, Ratio, Quantity).

        Nested lists/iterables are converted element by element and returned as lists.

        :param value: value to convert (a dict at the top level; any value when recursing)
        :param use_db_field: passed through to nested ``to_mongo`` calls
        :param fields: passed through to nested ``to_mongo`` calls
        :return: the MongoDB-compatible value
        """
        Document = _import_class('Document')
        EmbeddedDocument = _import_class('EmbeddedDocument')
        GenericReferenceField = _import_class('GenericReferenceField')

        if isinstance(value, six.string_types):
            return value

        if hasattr(value, 'to_mongo'):
            if isinstance(value, Document):
                return GenericReferenceField().to_mongo(value)
            val = value.to_mongo(use_db_field, fields)
            # If it's a document that is not inherited add _cls
            if isinstance(value, EmbeddedDocument):
                val['_cls'] = value.__class__.__name__
            return val

        is_list = False
        if not hasattr(value, 'items'):
            try:
                value = dict(enumerate(value))
            except TypeError:  # Not iterable: return the value
                return value
            is_list = True

        if self.field:
            value_dict = {
                key: self.field._to_mongo_safe_call(item, use_db_field, fields)
                for key, item in six.iteritems(value)
            }
        else:
            value_dict = {}
            for key, item in six.iteritems(value):
                if isinstance(item, Document):
                    # We need the id from the saved object to create the DBRef
                    if item.pk is None:
                        self.error('You can only reference documents once they'
                                   ' have been saved to the database')

                    # If it's a document that is not inheritable it won't have any _cls
                    # data, so make it a generic reference that allows dereferencing.
                    meta = getattr(item, '_meta', {})
                    allow_inheritance = meta.get('allow_inheritance')
                    if not allow_inheritance and not self.field:
                        value_dict[key] = GenericReferenceField().to_mongo(item)
                    else:
                        collection = item._get_collection_name()
                        value_dict[key] = DBRef(collection, item.pk)
                elif hasattr(item, 'to_mongo'):
                    val = item.to_mongo(use_db_field, fields)
                    # If it's a document that is not inherited add _cls
                    if isinstance(item, (Document, EmbeddedDocument)):
                        val['_cls'] = item.__class__.__name__
                    value_dict[key] = val
                else:
                    for type_, converter in _MONGO_VALUE_CONVERTERS:
                        if isinstance(item, type_):
                            value_dict[key] = converter(item)
                            break
                    else:
                        value_dict[key] = self.to_mongo(item, use_db_field, fields)

        if is_list:
            # Convert back to a list in the original order: BSON documents only accept
            # string keys, so the int-keyed dict built above could not be stored.
            return [value_dict[index] for index in range(len(value_dict))]
        return value_dict


class CustomizedSequenceField(SequenceField):

    def prepare_query_value(self, op, value):
        """
        The parent returns ``value_decorator(value)`` (e.g. ``fn: x = x + 10000``), but
        what is actually stored is the already-decorated value, so the query value is
        returned unchanged.

        :param op: query operator (unused)
        :param value: query value
        :return: ``value`` unchanged
        """
        return value

    def inverse_value_decorator(self, fn, value):
        # type: (Callable[[int], int], int) -> int
        """
        Undo ``fn`` on ``value`` assuming ``fn`` adds a constant offset.

        The offset is measured as ``fn(value) - value``; the result is only a true
        inverse when ``fn`` has the form ``x + c``.
        """
        offset = fn(value) - value
        return value - offset


class BooleanIntField(BooleanField):
    """
    Boolean field stored as an int (0/1), which makes it easier for the front end to
    pass as a parameter.
    """

    __interval = (0, 1)

    def to_python(self, value):
        try:
            value = int(bool(value))
        except ValueError:
            # Leave values whose truthiness raises ValueError unchanged.
            pass
        return value

    def to_mongo(self, value):
        # NOTE(review): ``True in (0, 1)`` is true because True == 1, so bool values are
        # returned (and stored) as bools rather than ints. Changing this would alter how
        # existing data is written and could change which documents queries match, so
        # it is intentionally left as-is.
        if value in self.__interval:
            return value
        return int(value)

    def validate(self, value):
        if value not in [0, 1] and not isinstance(value, bool):
            self.error('BooleanIntField only accepts 0/1 or bool value')


def prepare_condition(condition):
    """
    Turn one condition dict into mongoengine filter kwargs.

    ``{'field': 'age', 'operator': 'gt', 'value': 30}`` becomes ``{'age__gt': 30}``;
    an empty ``operator`` gives a plain equality filter ``{'age': 30}``.
    """
    lookup = '__'.join(part for part in (condition['field'], condition['operator']) if part)
    return {lookup: condition['value']}


def prepare_conditions(row):
    """Yield one ``Q`` object per condition dict in ``row``."""
    return (Q(**prepare_condition(condition)) for condition in row)


def join_conditions(row):
    """OR together the conditions of one row (``row`` must be non-empty)."""
    return reduce(lambda a, b: a | b, prepare_conditions(row))


def join_rows(rows):
    """AND together already-built ``Q`` objects (``rows`` must be non-empty)."""
    return reduce(lambda a, b: a & b, rows)


def dynamic_query(query_input):
    """
    Build a ``Q`` from a list of rows of condition dicts.

    Conditions inside a row are OR-ed; rows are AND-ed together.
    """
    return join_rows(join_conditions(row) for row in query_input)


def search_query(fields, search_text):
    """
    Build a case-insensitive "contains" search query.

    :param fields: fields to search; ``'id'`` is matched against ``_id`` only when
        ``search_text`` is a valid ObjectId
    :param search_text: text to search for; matched literally (regex metacharacters
        are escaped)
    :return: ``Q(__raw__=...)`` OR-ing the per-field conditions; when no condition
        applies the query matches no documents
    """
    # Escape user input so it is matched literally: unescaped text can fail to compile
    # (e.g. a lone '(') or be abused as an expensive pattern.
    pattern = Regex.from_native(re.compile('.*' + re.escape(search_text) + '.*', re.IGNORECASE))
    clauses = [{field: {"$regex": pattern}} for field in fields if field != 'id']
    if 'id' in fields and ObjectId.is_valid(search_text):
        clauses.insert(0, {'_id': ObjectId(search_text)})

    if clauses:
        filter_dict = {"$or": clauses}
    else:
        # MongoDB rejects an empty $or array; match nothing instead.
        filter_dict = {'_id': {'$in': []}}
    return Q(__raw__=filter_dict)


class CustomQuerySet(QuerySet):

    def paginate(self, pageIndex, pageSize):
        # type: (int, int) -> CustomQuerySet
        """Return page ``pageIndex`` (1-based; values below 1 mean the first page)."""
        front = (pageIndex - 1) if pageIndex >= 1 else 0
        return self.skip(front * pageSize).limit(pageSize)

    def search(self, search_key, fields=None):
        # type: (str, list) -> CustomQuerySet
        """
        Filter by a text search over ``fields``.

        This scans with regexes; with larger data volumes consider a dedicated search
        backend such as ElasticSearch.

        :param search_key: text to search for; a falsy value returns the queryset unchanged
        :param fields: fields to search; defaults to the document's ``search_fields``
        :return: the filtered queryset
        :raises ImproperlyConfigured: if no non-empty field list is available
        :raises TypeError: if ``fields`` is not a list or tuple
        """
        if not search_key:
            return self

        if fields is None:
            fields = getattr(self._document, 'search_fields', None)

        try:
            has_fields = len(fields) >= 1
        except TypeError:
            has_fields = False
        if not has_fields:
            raise ImproperlyConfigured(
                'in order to search, search_fields has to be defined(length >= 1) or supplied, model = %r'
                % self._document
            )

        if not isinstance(fields, (tuple, list)):
            raise TypeError('parameter fields has to be a list')

        return self(search_query(fields=fields, search_text=search_key))

    def rev_by_time(self):
        """
        Order by the document's time field, newest first.

        Uses the document's ``time_field`` if set, otherwise
        ``Const.DEFAULT_TIME_FIELD_NAME`` when the document has that attribute.

        :raises ValueError: if neither is available
        """
        time_field = getattr(self._document, 'time_field', None)
        if time_field is not None:
            key = time_field
        elif Const.DEFAULT_TIME_FIELD_NAME in dir(self._document):
            key = Const.DEFAULT_TIME_FIELD_NAME
        else:
            raise ValueError('cannot call rev_by_time without time_field set on document class')

        return self().order_by('-{key}'.format(key=key))

    def head(self, default=None):
        """
        Return the first document, or ``default`` when the queryset is empty.

        A callable ``default`` is called and its result returned.
        """
        queryset = self.clone()
        try:
            result = queryset[0]
        except IndexError:
            result = default() if callable(default) else default
        return result

    def sum_and_count(self, field):
        """
        Return ``(count, total)`` over the documents matched by this queryset.

        ``count`` is the number of aggregated items and ``total`` the sum of ``field``.
        For list fields the arrays are unwound first, so every element is counted and
        summed. Returns ``(0, 0)`` when nothing matches.

        TODO zjl: rewrite using count_documents.

        :param field: the field to sum over; use dot notation to refer to embedded
            document fields
        """
        db_field = self._fields_to_dbfields([field]).pop()
        pipeline = [
            {'$match': self._query},
            {'$group': {'_id': 'sum', 'count': {'$sum': 1}, 'total': {'$sum': '$' + db_field}}}
        ]

        # If we're summing over a list field, sum all the elements of the list, which
        # requires unwinding the arrays first.
        ListField = _import_class('ListField')
        field_instances = self._document._lookup_field(field.split('.'))
        if isinstance(field_instances[-1], ListField):
            # Unwind the database path, the same path $group sums over.
            pipeline.insert(1, {'$unwind': '$' + db_field})

        result = self._document._get_collection().aggregate(pipeline)
        if isinstance(result, dict):
            # Older PyMongo returns the raw command response.
            result = result.get('result')
        else:
            # PyMongo 3+ returns a cursor.
            result = tuple(result)

        if result:
            return int(result[0]['count']), result[0]['total']
        return 0, 0

    def _count_documents(self, with_limit_and_skip=False):
        """
        Count matching documents with ``count_documents``.

        Returns 0 without querying when ``(limit == 0 and not with_limit_and_skip)`` or
        the queryset's ``_none`` flag is set. When ``with_limit_and_skip`` is true, the
        queryset's limit/skip are passed on.
        """
        if self._limit == 0 and with_limit_and_skip is False or self._none:
            return 0

        if with_limit_and_skip:
            kwargs = {}
            if self._limit:
                kwargs["limit"] = self._limit
            if self._skip:
                kwargs["skip"] = self._skip
            return self._collection.count_documents(filter=self._query, **kwargs)

        return self._collection.count_documents(filter=self._query)

    def count(self, with_limit_and_skip=False):
        """
        Count matching documents.

        The result is cached in ``self._len`` only when ``with_limit_and_skip`` is true.
        """
        if with_limit_and_skip is False:
            return self._count_documents(with_limit_and_skip)

        if self._len is None:
            self._len = self._count_documents(with_limit_and_skip)
        return self._len


class UtilMixin(object):

    def to_js_timestamp(self, datetime_):
        """Return ``datetime_`` (read as local time) as a JavaScript timestamp in milliseconds."""
        #: JS timestamps are in milliseconds
        return time.mktime(datetime_.timetuple()) * 1000

    def to_datetime_str(self, datetime_):
        """Format a datetime with ``Const.DATETIME_FMT``; falsy -> '', other types -> str()."""
        if not datetime_:
            return ''
        if isinstance(datetime_, datetime.datetime):
            return datetime_.strftime(Const.DATETIME_FMT)
        return str(datetime_)

    def to_date_str(self, datetime_):
        """Format a date/datetime with ``Const.DATE_FMT``; falsy values give ''."""
        if not datetime_:
            return ''
        return datetime_.strftime(Const.DATE_FMT)


class BaseDocument(DynamicDocument, UtilMixin):
    meta = {
        'abstract': True,
        'queryset_class': CustomQuerySet,
        'index_background': True,
        'auto_create_index': False
    }

    # Time field used by CustomQuerySet.rev_by_time; None falls back to
    # Const.DEFAULT_TIME_FIELD_NAME when the document has that attribute.
    time_field = None

    # Cache of classes built by copy_document_classes, keyed by '<db_alias>_<class name>'.
    # The dict object lives on this base class, so all subclasses share it.
    new_cls_list = {}

    def __repr__(self):
        return self.__str__()

    def __str__(self):
        # The original format string had no placeholder for the id, silently dropping it.
        return '{}<{}>'.format(self.__class__.__name__, str(self.id))

    @classmethod
    def get_collection(cls):
        # type: () -> Collection
        """Return the pymongo collection backing this document class."""
        return cls._get_collection()

    clt = get_collection

    @classmethod
    def agg(cls, pipeline, **kwargs):
        """Run an aggregation pipeline directly on the underlying collection."""
        return cls.get_collection().aggregate(pipeline, **kwargs)

    @classmethod
    def last1(cls):
        """Return the newest document according to ``rev_by_time``, or None."""
        return cls.objects().rev_by_time().first()

    @classmethod
    def first1(cls):
        """Return the first document of the default queryset, or None."""
        return cls.objects().first()

    def to_dict(self):
        """
        Return the declared fields as a dict.

        ``id`` is stringified and datetime values are formatted as '%Y-%m-%d %H:%M:%S'.
        """
        result = {}
        for field in self.__class__._fields_ordered:
            value = getattr(self, field)
            if field == 'id':
                value = str(value)
            if isinstance(value, datetime.datetime):
                value = value.strftime('%Y-%m-%d %H:%M:%S')
            result[field] = value
        return result


class Searchable(BaseDocument):
    # Default fields searched by CustomQuerySet.search when none are supplied.
    search_fields = ()

    meta = {
        'abstract': True,
    }

    @classmethod
    def search(cls, search_key, fields=None):
        """
        Search documents of this class.

        :param search_key: text to search for
        :param fields: optional explicit list of fields to search
        :return: the filtered queryset
        """
        return cls.objects.search(search_key, fields=fields)


class RoleBaseDocument(Searchable):
    meta = {
        'abstract': True,
    }

    @property
    def __role__(self):
        return self.__class__.__name__.lower()

    @property
    def role(self):
        """Return ``__role__`` (lower-cased class name by default)."""
        role = getattr(self, '__role__', None)
        if role is None:
            raise ImproperlyConfigured('no __role__ attr provided')
        return role

    @property
    def logName(self):
        # The original format string had no placeholder for the id, silently dropping it.
        return '{}<{}>'.format(self.__class__.__name__, str(self.id))

    @property
    def request_limit_key(self):
        return '{}_{}'.format(str(self.role), str(self.id))


def paginate(queryset, pageIndex, pageSize):
    """
    Return page ``pageIndex`` (1-based) of a list or queryset.

    Page indexes below 1 are treated as the first page, matching
    ``CustomQuerySet.paginate``.

    TODO: move callers to CustomQuerySet.paginate.
    """
    front = max(pageIndex - 1, 0)
    if isinstance(queryset, list):
        return queryset[front * pageSize: (front + 1) * pageSize]
    return queryset.skip(front * pageSize).limit(pageSize)


class Query(object):
    """Parsed list-view query: raw parameters, filter attrs, pagination and channel."""

    def __init__(self, raw, attrs, pageIndex=1, pageSize=10, channel='unknown'):
        self.raw = raw
        self.attrs = attrs
        self.pageIndex = pageIndex
        self.pageSize = pageSize
        self.channel = channel

    def to_string(self):
        """Return the raw parameters as 'k=v' pairs sorted by key and joined with '&'."""
        return '&'.join('%s=%s' % (k, v) for k, v in sorted(self.raw.items()))

    @property
    def cache_key(self):
        return '%s:query:%s' % (self.channel, self.to_string(),)

    @property
    def hashed(self):
        return md5(self.to_string())

    @classmethod
    def from_query_dict(cls, queryDict, ignoreKeys=None):
        """
        Build a Query from a Django ``QueryDict``.

        Only the first value of each key is used; keys with an empty first value or
        listed in ``ignoreKeys`` are dropped. ``pageIndex``/``pageSize`` (defaults 1/10)
        are removed from the filter attrs. ``startTime``/``endTime`` (dates in
        ``Const.DATE_FMT``, default today) become a ``dateTimeAdded`` range covering
        whole days.

        NOTE(review): inputs that are not a QueryDict return None.
        """
        ignoreKeys = [] if ignoreKeys is None else ignoreKeys
        if isinstance(queryDict, QueryDict):
            unprocessedPayload = {k: v[0] for k, v in dict(queryDict).items() if v[0] and k not in ignoreKeys}
            processedPayload = unprocessedPayload.copy()
            pageIndex = int(processedPayload.pop('pageIndex', 1))
            pageSize = int(processedPayload.pop('pageSize', 10))

            todayDate = datetime.datetime.now().strftime(Const.DATE_FMT)
            startTime = datetime.datetime.strptime(
                str(processedPayload.pop('startTime', todayDate)) + ' 00:00:00', Const.DATETIME_FMT)
            endTime = datetime.datetime.strptime(
                str(processedPayload.pop('endTime', todayDate)) + ' 23:59:59', Const.DATETIME_FMT)
            processedPayload['dateTimeAdded__lte'] = endTime
            processedPayload['dateTimeAdded__gte'] = startTime

            return cls(raw=unprocessedPayload, attrs=processedPayload, pageIndex=pageIndex, pageSize=pageSize)


def prepare_query(queryDict, ignoreKeys=None):
    """Shortcut for ``Query.from_query_dict``."""
    return Query.from_query_dict(queryDict, ignoreKeys)


def calc_doc_size(doc, unit='bytes'):
    # type: (Union[dict, Document], str) -> Union[int, float]
    """
    Return the BSON-encoded size of ``doc``.

    :param doc: a raw dict or a mongoengine Document (encoded via ``to_mongo()``)
    :param unit: 'bytes' (int), 'kilobytes' or 'megabytes' (float)
    :raises ParameterError: for an unsupported ``doc`` type or ``unit``
    """
    unit_converters = {
        'bytes': lambda size: size,
        'kilobytes': lambda size: size / 1024.,
        'megabytes': lambda size: size / 1024. / 1024.,
    }
    if unit not in unit_converters:
        raise ParameterError('unsupported unit: %r' % (unit,))

    if isinstance(doc, dict):
        raw = doc
    elif isinstance(doc, Document):
        raw = doc.to_mongo()
    else:
        raise ParameterError('only mongoengine Document or raw dicts are supported')

    return unit_converters[unit](len(bson.BSON.encode(raw)))


if __name__ == '__main__':
    #: test
    #: the snippet adapted from https://blog.sneawo.com/blog/2017/03/26/how-to-build-a-dynamic-query-with-mongoengine/
    query_input = [
        [
            {
                "field": "some_field",
                "operator": "gt",
                "value": 30
            },
            {
                "field": "some_field",
                "operator": "lt",
                "value": 40
            },
            {
                "field": "some_field",
                "operator": "",
                "value": 35
            }
        ],
        [
            {
                "field": "another_field",
                "operator": "istartswith",
                "value": "test"
            }
        ]
    ]
    print(json.dumps(dynamic_query(query_input).to_query(None)))