# transform.py
from collections import defaultdict

from bson import ObjectId, SON
from bson.dbref import DBRef
import pymongo
import six

from mongoengine.base import UPDATE_OPERATORS
from mongoengine.common import _import_class
from mongoengine.connection import get_connection
from mongoengine.errors import InvalidQueryError
from mongoengine.python_support import IS_PYMONGO_3
  11. __all__ = ('query', 'update')
  12. COMPARISON_OPERATORS = ('ne', 'gt', 'gte', 'lt', 'lte', 'in', 'nin', 'mod',
  13. 'all', 'size', 'exists', 'not', 'elemMatch', 'type')
  14. GEO_OPERATORS = ('within_distance', 'within_spherical_distance',
  15. 'within_box', 'within_polygon', 'near', 'near_sphere',
  16. 'max_distance', 'min_distance', 'geo_within', 'geo_within_box',
  17. 'geo_within_polygon', 'geo_within_center',
  18. 'geo_within_sphere', 'geo_intersects')
  19. STRING_OPERATORS = ('contains', 'icontains', 'startswith',
  20. 'istartswith', 'endswith', 'iendswith',
  21. 'exact', 'iexact')
  22. CUSTOM_OPERATORS = ('match',)
  23. MATCH_OPERATORS = (COMPARISON_OPERATORS + GEO_OPERATORS +
  24. STRING_OPERATORS + CUSTOM_OPERATORS)
  25. # TODO make this less complex
  26. def query(_doc_cls=None, **kwargs):
  27. """Transform a query from Django-style format to Mongo format."""
  28. mongo_query = {}
  29. merge_query = defaultdict(list)
  30. for key, value in sorted(kwargs.items()):
  31. if key == '__raw__':
  32. mongo_query.update(value)
  33. continue
  34. parts = key.rsplit('__')
  35. indices = [(i, p) for i, p in enumerate(parts) if p.isdigit()]
  36. parts = [part for part in parts if not part.isdigit()]
  37. # Check for an operator and transform to mongo-style if there is
  38. op = None
  39. if len(parts) > 1 and parts[-1] in MATCH_OPERATORS:
  40. op = parts.pop()
  41. # Allow to escape operator-like field name by __
  42. if len(parts) > 1 and parts[-1] == '':
  43. parts.pop()
  44. negate = False
  45. if len(parts) > 1 and parts[-1] == 'not':
  46. parts.pop()
  47. negate = True
  48. if _doc_cls:
  49. # Switch field names to proper names [set in Field(name='foo')]
  50. try:
  51. fields = _doc_cls._lookup_field(parts)
  52. except Exception as e:
  53. raise InvalidQueryError(e)
  54. parts = []
  55. CachedReferenceField = _import_class('CachedReferenceField')
  56. GenericReferenceField = _import_class('GenericReferenceField')
  57. cleaned_fields = []
  58. for field in fields:
  59. append_field = True
  60. if isinstance(field, six.string_types):
  61. parts.append(field)
  62. append_field = False
  63. # is last and CachedReferenceField
  64. elif isinstance(field, CachedReferenceField) and fields[-1] == field:
  65. parts.append('%s._id' % field.db_field)
  66. else:
  67. parts.append(field.db_field)
  68. if append_field:
  69. cleaned_fields.append(field)
  70. # Convert value to proper value
  71. field = cleaned_fields[-1]
  72. singular_ops = [None, 'ne', 'gt', 'gte', 'lt', 'lte', 'not']
  73. singular_ops += STRING_OPERATORS
  74. if op in singular_ops:
  75. if isinstance(field, six.string_types):
  76. if (op in STRING_OPERATORS and
  77. isinstance(value, six.string_types)):
  78. StringField = _import_class('StringField')
  79. value = StringField.prepare_query_value(op, value)
  80. else:
  81. value = field
  82. else:
  83. value = field.prepare_query_value(op, value)
  84. if isinstance(field, CachedReferenceField) and value:
  85. value = value['_id']
  86. elif op in ('in', 'nin', 'all', 'near') and not isinstance(value, dict):
  87. # Raise an error if the in/nin/all/near param is not iterable.
  88. value = _prepare_query_for_iterable(field, op, value)
  89. # If we're querying a GenericReferenceField, we need to alter the
  90. # key depending on the value:
  91. # * If the value is a DBRef, the key should be "field_name._ref".
  92. # * If the value is an ObjectId, the key should be "field_name._ref.$id".
  93. if isinstance(field, GenericReferenceField):
  94. if isinstance(value, DBRef):
  95. parts[-1] += '._ref'
  96. elif isinstance(value, ObjectId):
  97. parts[-1] += '._ref.$id'
  98. # if op and op not in COMPARISON_OPERATORS:
  99. if op:
  100. if op in GEO_OPERATORS:
  101. value = _geo_operator(field, op, value)
  102. elif op in ('match', 'elemMatch'):
  103. ListField = _import_class('ListField')
  104. EmbeddedDocumentField = _import_class('EmbeddedDocumentField')
  105. if (
  106. isinstance(value, dict) and
  107. isinstance(field, ListField) and
  108. isinstance(field.field, EmbeddedDocumentField)
  109. ):
  110. value = query(field.field.document_type, **value)
  111. else:
  112. value = field.prepare_query_value(op, value)
  113. value = {'$elemMatch': value}
  114. elif op in CUSTOM_OPERATORS:
  115. NotImplementedError('Custom method "%s" has not '
  116. 'been implemented' % op)
  117. elif op not in STRING_OPERATORS:
  118. value = {'$' + op: value}
  119. if negate:
  120. value = {'$not': value}
  121. for i, part in indices:
  122. parts.insert(i, part)
  123. key = '.'.join(parts)
  124. if op is None or key not in mongo_query:
  125. mongo_query[key] = value
  126. elif key in mongo_query:
  127. if isinstance(mongo_query[key], dict) and isinstance(value, dict):
  128. mongo_query[key].update(value)
  129. # $max/minDistance needs to come last - convert to SON
  130. value_dict = mongo_query[key]
  131. if ('$maxDistance' in value_dict or '$minDistance' in value_dict) and \
  132. ('$near' in value_dict or '$nearSphere' in value_dict):
  133. value_son = SON()
  134. for k, v in value_dict.iteritems():
  135. if k == '$maxDistance' or k == '$minDistance':
  136. continue
  137. value_son[k] = v
  138. # Required for MongoDB >= 2.6, may fail when combining
  139. # PyMongo 3+ and MongoDB < 2.6
  140. near_embedded = False
  141. for near_op in ('$near', '$nearSphere'):
  142. if isinstance(value_dict.get(near_op), dict) and (
  143. IS_PYMONGO_3 or get_connection().max_wire_version > 1):
  144. value_son[near_op] = SON(value_son[near_op])
  145. if '$maxDistance' in value_dict:
  146. value_son[near_op][
  147. '$maxDistance'] = value_dict['$maxDistance']
  148. if '$minDistance' in value_dict:
  149. value_son[near_op][
  150. '$minDistance'] = value_dict['$minDistance']
  151. near_embedded = True
  152. if not near_embedded:
  153. if '$maxDistance' in value_dict:
  154. value_son['$maxDistance'] = value_dict['$maxDistance']
  155. if '$minDistance' in value_dict:
  156. value_son['$minDistance'] = value_dict['$minDistance']
  157. mongo_query[key] = value_son
  158. else:
  159. # Store for manually merging later
  160. merge_query[key].append(value)
  161. # The queryset has been filter in such a way we must manually merge
  162. for k, v in merge_query.items():
  163. merge_query[k].append(mongo_query[k])
  164. del mongo_query[k]
  165. if isinstance(v, list):
  166. value = [{k: val} for val in v]
  167. if '$and' in mongo_query.keys():
  168. mongo_query['$and'].extend(value)
  169. else:
  170. mongo_query['$and'] = value
  171. return mongo_query
  172. def update(_doc_cls=None, **update):
  173. """Transform an update spec from Django-style format to Mongo
  174. format.
  175. """
  176. mongo_update = {}
  177. for key, value in update.items():
  178. if key == '__raw__':
  179. mongo_update.update(value)
  180. continue
  181. parts = key.split('__')
  182. # if there is no operator, default to 'set'
  183. if len(parts) < 3 and parts[0] not in UPDATE_OPERATORS:
  184. parts.insert(0, 'set')
  185. # Check for an operator and transform to mongo-style if there is
  186. op = None
  187. if parts[0] in UPDATE_OPERATORS:
  188. op = parts.pop(0)
  189. # Convert Pythonic names to Mongo equivalents
  190. operator_map = {
  191. 'push_all': 'pushAll',
  192. 'pull_all': 'pullAll',
  193. 'dec': 'inc',
  194. 'add_to_set': 'addToSet',
  195. 'set_on_insert': 'setOnInsert'
  196. }
  197. if op == 'dec':
  198. # Support decrement by flipping a positive value's sign
  199. # and using 'inc'
  200. value = -value
  201. # If the operator doesn't found from operator map, the op value
  202. # will stay unchanged
  203. op = operator_map.get(op, op)
  204. match = None
  205. if parts[-1] in COMPARISON_OPERATORS:
  206. match = parts.pop()
  207. # Allow to escape operator-like field name by __
  208. if len(parts) > 1 and parts[-1] == '':
  209. parts.pop()
  210. if _doc_cls:
  211. # Switch field names to proper names [set in Field(name='foo')]
  212. try:
  213. fields = _doc_cls._lookup_field(parts)
  214. except Exception as e:
  215. raise InvalidQueryError(e)
  216. parts = []
  217. cleaned_fields = []
  218. appended_sub_field = False
  219. for field in fields:
  220. append_field = True
  221. if isinstance(field, six.string_types):
  222. # Convert the S operator to $
  223. if field == 'S':
  224. field = '$'
  225. parts.append(field)
  226. append_field = False
  227. else:
  228. parts.append(field.db_field)
  229. if append_field:
  230. appended_sub_field = False
  231. cleaned_fields.append(field)
  232. if hasattr(field, 'field'):
  233. cleaned_fields.append(field.field)
  234. appended_sub_field = True
  235. # Convert value to proper value
  236. if appended_sub_field:
  237. field = cleaned_fields[-2]
  238. else:
  239. field = cleaned_fields[-1]
  240. GeoJsonBaseField = _import_class('GeoJsonBaseField')
  241. if isinstance(field, GeoJsonBaseField):
  242. value = field.to_mongo(value)
  243. if op == 'pull':
  244. if field.required or value is not None:
  245. if match == 'in' and not isinstance(value, dict):
  246. value = _prepare_query_for_iterable(field, op, value)
  247. else:
  248. value = field.prepare_query_value(op, value)
  249. elif op == 'push' and isinstance(value, (list, tuple, set)):
  250. value = [field.prepare_query_value(op, v) for v in value]
  251. elif op in (None, 'set', 'push'):
  252. if field.required or value is not None:
  253. value = field.prepare_query_value(op, value)
  254. elif op in ('pushAll', 'pullAll'):
  255. value = [field.prepare_query_value(op, v) for v in value]
  256. elif op in ('addToSet', 'setOnInsert'):
  257. if isinstance(value, (list, tuple, set)):
  258. value = [field.prepare_query_value(op, v) for v in value]
  259. elif field.required or value is not None:
  260. value = field.prepare_query_value(op, value)
  261. elif op == 'unset':
  262. value = 1
  263. elif op == 'inc':
  264. value = field.prepare_query_value(op, value)
  265. if match:
  266. match = '$' + match
  267. value = {match: value}
  268. key = '.'.join(parts)
  269. if not op:
  270. raise InvalidQueryError('Updates must supply an operation '
  271. 'eg: set__FIELD=value')
  272. if 'pull' in op and '.' in key:
  273. # Dot operators don't work on pull operations
  274. # unless they point to a list field
  275. # Otherwise it uses nested dict syntax
  276. if op == 'pullAll':
  277. raise InvalidQueryError('pullAll operations only support '
  278. 'a single field depth')
  279. # Look for the last list field and use dot notation until there
  280. field_classes = [c.__class__ for c in cleaned_fields]
  281. field_classes.reverse()
  282. ListField = _import_class('ListField')
  283. EmbeddedDocumentListField = _import_class('EmbeddedDocumentListField')
  284. if ListField in field_classes or EmbeddedDocumentListField in field_classes:
  285. # Join all fields via dot notation to the last ListField or EmbeddedDocumentListField
  286. # Then process as normal
  287. if ListField in field_classes:
  288. _check_field = ListField
  289. else:
  290. _check_field = EmbeddedDocumentListField
  291. last_listField = len(
  292. cleaned_fields) - field_classes.index(_check_field)
  293. key = '.'.join(parts[:last_listField])
  294. parts = parts[last_listField:]
  295. parts.insert(0, key)
  296. parts.reverse()
  297. for key in parts:
  298. value = {key: value}
  299. elif op == 'addToSet' and isinstance(value, list):
  300. value = {key: {'$each': value}}
  301. elif op in ('push', 'pushAll'):
  302. if parts[-1].isdigit():
  303. key = parts[0]
  304. position = int(parts[-1])
  305. # $position expects an iterable. If pushing a single value,
  306. # wrap it in a list.
  307. if not isinstance(value, (set, tuple, list)):
  308. value = [value]
  309. value = {key: {'$each': value, '$position': position}}
  310. else:
  311. if op == 'pushAll':
  312. op = 'push' # convert to non-deprecated keyword
  313. if not isinstance(value, (set, tuple, list)):
  314. value = [value]
  315. value = {key: {'$each': value}}
  316. else:
  317. value = {key: value}
  318. else:
  319. value = {key: value}
  320. key = '$' + op
  321. if key not in mongo_update:
  322. mongo_update[key] = value
  323. elif key in mongo_update and isinstance(mongo_update[key], dict):
  324. mongo_update[key].update(value)
  325. return mongo_update
  326. def _geo_operator(field, op, value):
  327. """Helper to return the query for a given geo query."""
  328. if op == 'max_distance':
  329. value = {'$maxDistance': value}
  330. elif op == 'min_distance':
  331. value = {'$minDistance': value}
  332. elif field._geo_index == pymongo.GEO2D:
  333. if op == 'within_distance':
  334. value = {'$within': {'$center': value}}
  335. elif op == 'within_spherical_distance':
  336. value = {'$within': {'$centerSphere': value}}
  337. elif op == 'within_polygon':
  338. value = {'$within': {'$polygon': value}}
  339. elif op == 'near':
  340. value = {'$near': value}
  341. elif op == 'near_sphere':
  342. value = {'$nearSphere': value}
  343. elif op == 'within_box':
  344. value = {'$within': {'$box': value}}
  345. else:
  346. raise NotImplementedError('Geo method "%s" has not been '
  347. 'implemented for a GeoPointField' % op)
  348. else:
  349. if op == 'geo_within':
  350. value = {'$geoWithin': _infer_geometry(value)}
  351. elif op == 'geo_within_box':
  352. value = {'$geoWithin': {'$box': value}}
  353. elif op == 'geo_within_polygon':
  354. value = {'$geoWithin': {'$polygon': value}}
  355. elif op == 'geo_within_center':
  356. value = {'$geoWithin': {'$center': value}}
  357. elif op == 'geo_within_sphere':
  358. value = {'$geoWithin': {'$centerSphere': value}}
  359. elif op == 'geo_intersects':
  360. value = {'$geoIntersects': _infer_geometry(value)}
  361. elif op == 'near':
  362. value = {'$near': _infer_geometry(value)}
  363. else:
  364. raise NotImplementedError(
  365. 'Geo method "%s" has not been implemented for a %s '
  366. % (op, field._name)
  367. )
  368. return value
  369. def _infer_geometry(value):
  370. """Helper method that tries to infer the $geometry shape for a
  371. given value.
  372. """
  373. if isinstance(value, dict):
  374. if '$geometry' in value:
  375. return value
  376. elif 'coordinates' in value and 'type' in value:
  377. return {'$geometry': value}
  378. raise InvalidQueryError('Invalid $geometry dictionary should have '
  379. 'type and coordinates keys')
  380. elif isinstance(value, (list, set)):
  381. # TODO: shouldn't we test value[0][0][0][0] to see if it is MultiPolygon?
  382. try:
  383. value[0][0][0]
  384. return {'$geometry': {'type': 'Polygon', 'coordinates': value}}
  385. except (TypeError, IndexError):
  386. pass
  387. try:
  388. value[0][0]
  389. return {'$geometry': {'type': 'LineString', 'coordinates': value}}
  390. except (TypeError, IndexError):
  391. pass
  392. try:
  393. value[0]
  394. return {'$geometry': {'type': 'Point', 'coordinates': value}}
  395. except (TypeError, IndexError):
  396. pass
  397. raise InvalidQueryError('Invalid $geometry data. Can be either a '
  398. 'dictionary or (nested) lists of coordinate(s)')
  399. def _prepare_query_for_iterable(field, op, value):
  400. # We need a special check for BaseDocument, because - although it's iterable - using
  401. # it as such in the context of this method is most definitely a mistake.
  402. BaseDocument = _import_class('BaseDocument')
  403. if isinstance(value, BaseDocument):
  404. raise TypeError("When using the `in`, `nin`, `all`, or "
  405. "`near`-operators you can\'t use a "
  406. "`Document`, you must wrap your object "
  407. "in a list (object -> [object]).")
  408. if not hasattr(value, '__iter__'):
  409. raise TypeError("The `in`, `nin`, `all`, or "
  410. "`near`-operators must be applied to an "
  411. "iterable (e.g. a list).")
  412. return [field.prepare_query_value(op, v) for v in value]