123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471 |
- from collections import defaultdict
- from bson import ObjectId, SON
- from bson.dbref import DBRef
- import pymongo
- import six
- from mongoengine.base import UPDATE_OPERATORS
- from mongoengine.common import _import_class
- from mongoengine.connection import get_connection
- from mongoengine.errors import InvalidQueryError
- from mongoengine.python_support import IS_PYMONGO_3
# Names exported by ``from <module> import *``.
__all__ = ('query', 'update')

# Operator suffixes that query() turns into a ``{'$<op>': value}`` clause;
# update() also accepts them as a trailing match operator on a key.
COMPARISON_OPERATORS = ('ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod',
                        'all', 'size', 'exists', 'not', 'elemMatch', 'type')

# Operator suffixes that query() hands to _geo_operator() for translation.
GEO_OPERATORS = ('within_distance', 'within_spherical_distance',
                 'within_box', 'within_polygon', 'near', 'near_sphere',
                 'max_distance', 'min_distance', 'geo_within', 'geo_within_box',
                 'geo_within_polygon', 'geo_within_center',
                 'geo_within_sphere', 'geo_intersects')

# Operator suffixes whose value is prepared by the field itself
# (prepare_query_value); query() does not wrap them in a ``$<op>`` key.
STRING_OPERATORS = ('contains', 'icontains', 'startswith',
                    'istartswith', 'endswith', 'iendswith',
                    'exact', 'iexact')

# Operators without a same-named Mongo operator; query() maps 'match'
# to ``$elemMatch``.
CUSTOM_OPERATORS = ('match',)

# Every suffix that query() recognises as a trailing operator in a key.
MATCH_OPERATORS = (COMPARISON_OPERATORS + GEO_OPERATORS +
                   STRING_OPERATORS + CUSTOM_OPERATORS)
- # TODO make this less complex
def query(_doc_cls=None, **kwargs):
    """Transform a query from Django-style format to Mongo format.

    Every keyword is a ``__``-separated path that may end in one of
    ``MATCH_OPERATORS``, optionally preceded by ``not`` to negate the
    clause (e.g. ``age__not__gt=5``). Purely numeric path segments are
    taken out before field lookup and put back at their original
    positions afterwards. The special key ``__raw__`` is merged into the
    result untouched.

    :param _doc_cls: optional document class. When given, path segments
        are resolved through ``_doc_cls._lookup_field`` and values are
        converted by the resolved field's ``prepare_query_value``.
    :param kwargs: the Django-style lookups to translate.
    :return: a dict (possibly containing ``SON`` values) usable as a
        Mongo query document.
    :raises InvalidQueryError: if ``_doc_cls._lookup_field`` fails.
    :raises NotImplementedError: for a custom operator that has no
        translation.
    """
    mongo_query = {}
    merge_query = defaultdict(list)
    for key, value in sorted(kwargs.items()):
        if key == '__raw__':
            mongo_query.update(value)
            continue

        parts = key.rsplit('__')
        indices = [(i, p) for i, p in enumerate(parts) if p.isdigit()]
        parts = [part for part in parts if not part.isdigit()]

        # Check for an operator and transform to mongo-style if there is
        op = None
        if len(parts) > 1 and parts[-1] in MATCH_OPERATORS:
            op = parts.pop()

        # Allow to escape operator-like field name by __
        if len(parts) > 1 and parts[-1] == '':
            parts.pop()

        negate = False
        if len(parts) > 1 and parts[-1] == 'not':
            parts.pop()
            negate = True

        # NOTE: `field` is only (re)bound inside this branch; the operator
        # handling further down that reads `field` relies on it.
        if _doc_cls:
            # Switch field names to proper names [set in Field(name='foo')]
            try:
                fields = _doc_cls._lookup_field(parts)
            except Exception as e:
                raise InvalidQueryError(e)
            parts = []

            CachedReferenceField = _import_class('CachedReferenceField')
            GenericReferenceField = _import_class('GenericReferenceField')

            cleaned_fields = []
            for field in fields:
                append_field = True
                if isinstance(field, six.string_types):
                    parts.append(field)
                    append_field = False
                # is last and CachedReferenceField
                elif isinstance(field, CachedReferenceField) and fields[-1] == field:
                    parts.append('%s._id' % field.db_field)
                else:
                    parts.append(field.db_field)

                if append_field:
                    cleaned_fields.append(field)

            # Convert value to proper value
            field = cleaned_fields[-1]

            singular_ops = [None, 'ne', 'gt', 'gte', 'lt', 'lte', 'not']
            singular_ops += STRING_OPERATORS
            if op in singular_ops:
                if isinstance(field, six.string_types):
                    if (op in STRING_OPERATORS and
                            isinstance(value, six.string_types)):
                        StringField = _import_class('StringField')
                        value = StringField.prepare_query_value(op, value)
                    else:
                        value = field
                else:
                    value = field.prepare_query_value(op, value)

                    if isinstance(field, CachedReferenceField) and value:
                        value = value['_id']

            elif op in ('in', 'nin', 'all', 'near') and not isinstance(value, dict):
                # Raise an error if the in/nin/all/near param is not iterable.
                value = _prepare_query_for_iterable(field, op, value)

            # If we're querying a GenericReferenceField, we need to alter the
            # key depending on the value:
            # * If the value is a DBRef, the key should be "field_name._ref".
            # * If the value is an ObjectId, the key should be "field_name._ref.$id".
            if isinstance(field, GenericReferenceField):
                if isinstance(value, DBRef):
                    parts[-1] += '._ref'
                elif isinstance(value, ObjectId):
                    parts[-1] += '._ref.$id'

        if op:
            if op in GEO_OPERATORS:
                value = _geo_operator(field, op, value)
            elif op in ('match', 'elemMatch'):
                ListField = _import_class('ListField')
                EmbeddedDocumentField = _import_class('EmbeddedDocumentField')
                if (
                    isinstance(value, dict) and
                    isinstance(field, ListField) and
                    isinstance(field.field, EmbeddedDocumentField)
                ):
                    # Translate the nested lookups against the embedded
                    # document class.
                    value = query(field.field.document_type, **value)
                else:
                    value = field.prepare_query_value(op, value)
                value = {'$elemMatch': value}
            elif op in CUSTOM_OPERATORS:
                # The exception used to be constructed but never raised.
                raise NotImplementedError('Custom method "%s" has not '
                                          'been implemented' % op)
            elif op not in STRING_OPERATORS:
                value = {'$' + op: value}

        if negate:
            value = {'$not': value}

        # Put the numeric path segments back where they were in the key.
        for i, part in indices:
            parts.insert(i, part)

        key = '.'.join(parts)

        if op is None or key not in mongo_query:
            mongo_query[key] = value
        elif isinstance(mongo_query[key], dict) and isinstance(value, dict):
            mongo_query[key].update(value)
            # $max/minDistance needs to come last - convert to SON
            value_dict = mongo_query[key]
            if ('$maxDistance' in value_dict or '$minDistance' in value_dict) and \
                    ('$near' in value_dict or '$nearSphere' in value_dict):
                value_son = SON()
                # dict.items() works on both Python 2 and 3; the former
                # iteritems() call only existed on Python 2.
                for k, v in value_dict.items():
                    if k == '$maxDistance' or k == '$minDistance':
                        continue
                    value_son[k] = v
                # Required for MongoDB >= 2.6, may fail when combining
                # PyMongo 3+ and MongoDB < 2.6
                near_embedded = False
                for near_op in ('$near', '$nearSphere'):
                    if isinstance(value_dict.get(near_op), dict) and (
                            IS_PYMONGO_3 or get_connection().max_wire_version > 1):
                        value_son[near_op] = SON(value_son[near_op])
                        if '$maxDistance' in value_dict:
                            value_son[near_op]['$maxDistance'] = value_dict['$maxDistance']
                        if '$minDistance' in value_dict:
                            value_son[near_op]['$minDistance'] = value_dict['$minDistance']
                        near_embedded = True
                if not near_embedded:
                    if '$maxDistance' in value_dict:
                        value_son['$maxDistance'] = value_dict['$maxDistance']
                    if '$minDistance' in value_dict:
                        value_son['$minDistance'] = value_dict['$minDistance']
                mongo_query[key] = value_son
        else:
            # Store for manually merging later
            merge_query[key].append(value)

    # The queryset has been filtered in such a way we must manually merge:
    # conflicting clauses for one key are moved into a top-level $and.
    for k, v in merge_query.items():
        # `v` is the same list object as merge_query[k], so the clause
        # appended here is part of the $and built below.
        merge_query[k].append(mongo_query[k])
        del mongo_query[k]
        if isinstance(v, list):
            value = [{k: val} for val in v]
            if '$and' in mongo_query:
                mongo_query['$and'].extend(value)
            else:
                mongo_query['$and'] = value

    return mongo_query
def update(_doc_cls=None, **update):
    """Transform an update spec from Django-style format to Mongo
    format.

    Every keyword is ``<operator>__<path>[__<match operator>]``. A key
    with fewer than three ``__``-separated parts whose first part is not
    in ``UPDATE_OPERATORS`` gets ``set`` prepended. The special key
    ``__raw__`` is merged into the result untouched.

    :param _doc_cls: optional document class. When given, path segments
        are resolved through ``_doc_cls._lookup_field`` and values are
        converted by the resolved field.
    :param update: the Django-style update instructions to translate.
    :return: a dict mapping ``$<operator>`` to its update document.
    :raises InvalidQueryError: if field lookup fails, if no operator can
        be determined, or if ``pullAll`` is used on a dotted path.
    """
    # Pythonic operator names and their Mongo equivalents. Operators that
    # are not listed keep their name.
    operator_map = {
        'push_all': 'pushAll',
        'pull_all': 'pullAll',
        'dec': 'inc',
        'add_to_set': 'addToSet',
        'set_on_insert': 'setOnInsert'
    }

    mongo_update = {}
    for key, value in update.items():
        if key == '__raw__':
            mongo_update.update(value)
            continue

        parts = key.split('__')

        # if there is no operator, default to 'set'
        if len(parts) < 3 and parts[0] not in UPDATE_OPERATORS:
            parts.insert(0, 'set')

        # Check for an operator and transform to mongo-style if there is
        op = None
        if parts[0] in UPDATE_OPERATORS:
            op = parts.pop(0)
            if op == 'dec':
                # Support decrement by flipping a positive value's sign
                # and using 'inc'
                value = -value
            op = operator_map.get(op, op)

        match = None
        if parts[-1] in COMPARISON_OPERATORS:
            match = parts.pop()

        # Allow to escape operator-like field name by __
        if len(parts) > 1 and parts[-1] == '':
            parts.pop()

        # Bound on every iteration so the dotted-`pull` handling below can
        # read it even when no document class was supplied.
        cleaned_fields = []
        if _doc_cls:
            # Switch field names to proper names [set in Field(name='foo')]
            try:
                fields = _doc_cls._lookup_field(parts)
            except Exception as e:
                raise InvalidQueryError(e)
            parts = []

            appended_sub_field = False
            for field in fields:
                append_field = True
                if isinstance(field, six.string_types):
                    # Convert the S operator to $
                    if field == 'S':
                        field = '$'
                    parts.append(field)
                    append_field = False
                else:
                    parts.append(field.db_field)
                if append_field:
                    appended_sub_field = False
                    cleaned_fields.append(field)
                    if hasattr(field, 'field'):
                        cleaned_fields.append(field.field)
                        appended_sub_field = True

            # Convert value to proper value. When the last resolved field
            # also contributed its inner `.field`, use the outer field.
            if appended_sub_field:
                field = cleaned_fields[-2]
            else:
                field = cleaned_fields[-1]

            GeoJsonBaseField = _import_class('GeoJsonBaseField')
            if isinstance(field, GeoJsonBaseField):
                value = field.to_mongo(value)

            if op == 'pull':
                if field.required or value is not None:
                    if match == 'in' and not isinstance(value, dict):
                        value = _prepare_query_for_iterable(field, op, value)
                    else:
                        value = field.prepare_query_value(op, value)
            elif op == 'push' and isinstance(value, (list, tuple, set)):
                value = [field.prepare_query_value(op, v) for v in value]
            elif op in (None, 'set', 'push'):
                if field.required or value is not None:
                    value = field.prepare_query_value(op, value)
            elif op in ('pushAll', 'pullAll'):
                value = [field.prepare_query_value(op, v) for v in value]
            elif op in ('addToSet', 'setOnInsert'):
                if isinstance(value, (list, tuple, set)):
                    value = [field.prepare_query_value(op, v) for v in value]
                elif field.required or value is not None:
                    value = field.prepare_query_value(op, value)
            elif op == 'unset':
                value = 1
            elif op == 'inc':
                value = field.prepare_query_value(op, value)

        if match:
            match = '$' + match
            value = {match: value}

        key = '.'.join(parts)

        if not op:
            raise InvalidQueryError('Updates must supply an operation '
                                    'eg: set__FIELD=value')

        if 'pull' in op and '.' in key:
            # Dot operators don't work on pull operations
            # unless they point to a list field
            # Otherwise it uses nested dict syntax
            if op == 'pullAll':
                raise InvalidQueryError('pullAll operations only support '
                                        'a single field depth')

            # Look for the last list field and use dot notation until there
            field_classes = [c.__class__ for c in cleaned_fields]
            field_classes.reverse()
            ListField = _import_class('ListField')
            EmbeddedDocumentListField = _import_class('EmbeddedDocumentListField')
            if ListField in field_classes or EmbeddedDocumentListField in field_classes:
                # Join all fields via dot notation to the last ListField or
                # EmbeddedDocumentListField, then process as normal
                if ListField in field_classes:
                    _check_field = ListField
                else:
                    _check_field = EmbeddedDocumentListField

                last_listField = len(
                    cleaned_fields) - field_classes.index(_check_field)
                key = '.'.join(parts[:last_listField])
                parts = parts[last_listField:]
                parts.insert(0, key)

            # Nest the remaining segments as dicts, innermost first.
            parts.reverse()
            for key in parts:
                value = {key: value}
        elif op == 'addToSet' and isinstance(value, list):
            value = {key: {'$each': value}}
        elif op in ('push', 'pushAll'):
            if parts[-1].isdigit():
                # A trailing numeric segment is the insert position.
                key = parts[0]
                position = int(parts[-1])
                # $position expects an iterable. If pushing a single value,
                # wrap it in a list.
                if not isinstance(value, (set, tuple, list)):
                    value = [value]
                value = {key: {'$each': value, '$position': position}}
            else:
                if op == 'pushAll':
                    op = 'push'  # convert to non-deprecated keyword
                    if not isinstance(value, (set, tuple, list)):
                        value = [value]
                    value = {key: {'$each': value}}
                else:
                    value = {key: value}
        else:
            value = {key: value}
        key = '$' + op

        # Merge with what an earlier keyword produced for this operator.
        if key not in mongo_update:
            mongo_update[key] = value
        elif isinstance(mongo_update[key], dict):
            mongo_update[key].update(value)

    return mongo_update
- def _geo_operator(field, op, value):
- """Helper to return the query for a given geo query."""
- if op == 'max_distance':
- value = {'$maxDistance': value}
- elif op == 'min_distance':
- value = {'$minDistance': value}
- elif field._geo_index == pymongo.GEO2D:
- if op == 'within_distance':
- value = {'$within': {'$center': value}}
- elif op == 'within_spherical_distance':
- value = {'$within': {'$centerSphere': value}}
- elif op == 'within_polygon':
- value = {'$within': {'$polygon': value}}
- elif op == 'near':
- value = {'$near': value}
- elif op == 'near_sphere':
- value = {'$nearSphere': value}
- elif op == 'within_box':
- value = {'$within': {'$box': value}}
- else:
- raise NotImplementedError('Geo method "%s" has not been '
- 'implemented for a GeoPointField' % op)
- else:
- if op == 'geo_within':
- value = {'$geoWithin': _infer_geometry(value)}
- elif op == 'geo_within_box':
- value = {'$geoWithin': {'$box': value}}
- elif op == 'geo_within_polygon':
- value = {'$geoWithin': {'$polygon': value}}
- elif op == 'geo_within_center':
- value = {'$geoWithin': {'$center': value}}
- elif op == 'geo_within_sphere':
- value = {'$geoWithin': {'$centerSphere': value}}
- elif op == 'geo_intersects':
- value = {'$geoIntersects': _infer_geometry(value)}
- elif op == 'near':
- value = {'$near': _infer_geometry(value)}
- else:
- raise NotImplementedError(
- 'Geo method "%s" has not been implemented for a %s '
- % (op, field._name)
- )
- return value
- def _infer_geometry(value):
- """Helper method that tries to infer the $geometry shape for a
- given value.
- """
- if isinstance(value, dict):
- if '$geometry' in value:
- return value
- elif 'coordinates' in value and 'type' in value:
- return {'$geometry': value}
- raise InvalidQueryError('Invalid $geometry dictionary should have '
- 'type and coordinates keys')
- elif isinstance(value, (list, set)):
- # TODO: shouldn't we test value[0][0][0][0] to see if it is MultiPolygon?
- try:
- value[0][0][0]
- return {'$geometry': {'type': 'Polygon', 'coordinates': value}}
- except (TypeError, IndexError):
- pass
- try:
- value[0][0]
- return {'$geometry': {'type': 'LineString', 'coordinates': value}}
- except (TypeError, IndexError):
- pass
- try:
- value[0]
- return {'$geometry': {'type': 'Point', 'coordinates': value}}
- except (TypeError, IndexError):
- pass
- raise InvalidQueryError('Invalid $geometry data. Can be either a '
- 'dictionary or (nested) lists of coordinate(s)')
def _prepare_query_for_iterable(field, op, value):
    """Convert every item of ``value`` with ``field.prepare_query_value``.

    :param field: the field whose ``prepare_query_value`` converts items.
    :param op: the operator name passed through to the field.
    :param value: the iterable of operands.
    :return: a list with one prepared value per item.
    :raises TypeError: if ``value`` is a ``BaseDocument`` or has no
        ``__iter__`` attribute.
    """
    # We need a special check for BaseDocument, because - although it's
    # iterable - using it as such in the context of this method is most
    # definitely a mistake.
    document_base = _import_class('BaseDocument')
    if isinstance(value, document_base):
        raise TypeError("When using the `in`, `nin`, `all`, or "
                        "`near`-operators you can\'t use a "
                        "`Document`, you must wrap your object "
                        "in a list (object -> [object]).")

    is_iterable = hasattr(value, '__iter__')
    if not is_iterable:
        raise TypeError("The `in`, `nin`, `all`, or "
                        "`near`-operators must be applied to an "
                        "iterable (e.g. a list).")

    prepared = []
    for item in value:
        prepared.append(field.prepare_query_value(op, item))
    return prepared
|