collection.py 79 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944
  1. from __future__ import division
  2. import collections
  3. from collections import OrderedDict
  4. import copy
  5. from datetime import datetime
  6. import functools
  7. import itertools
  8. import json
  9. import math
  10. import threading
  11. import time
  12. import types
  13. import warnings
  14. try:
  15. from bson import json_util, SON, BSON
  16. except ImportError:
  17. json_utils = SON = BSON = None
  18. try:
  19. import execjs
  20. except ImportError:
  21. execjs = None
  22. try:
  23. from pymongo import ReturnDocument
  24. except ImportError:
  25. class ReturnDocument(object):
  26. BEFORE = False
  27. AFTER = True
  28. from sentinels import NOTHING
  29. from six import iteritems
  30. from six import iterkeys
  31. from six import itervalues
  32. from six import MAXSIZE
  33. from six import string_types
  34. from six import text_type
  35. from mongomock.command_cursor import CommandCursor
  36. from mongomock import DuplicateKeyError, BulkWriteError
  37. from mongomock.filtering import filter_applies
  38. from mongomock.filtering import iter_key_candidates
  39. from mongomock import helpers
  40. from mongomock import InvalidOperation
  41. from mongomock import ObjectId
  42. from mongomock import OperationFailure
  43. from mongomock.results import BulkWriteResult
  44. from mongomock.results import DeleteResult
  45. from mongomock.results import InsertManyResult
  46. from mongomock.results import InsertOneResult
  47. from mongomock.results import UpdateResult
  48. from mongomock.write_concern import WriteConcern
  49. from mongomock import WriteError
  50. lock = threading.RLock()
  51. def validate_is_mapping(option, value):
  52. if not isinstance(value, collections.Mapping):
  53. raise TypeError('%s must be an instance of dict, bson.son.SON, or '
  54. 'other type that inherits from '
  55. 'collections.Mapping' % (option,))
  56. def validate_is_mutable_mapping(option, value):
  57. if not isinstance(value, collections.MutableMapping):
  58. raise TypeError('%s must be an instance of dict, bson.son.SON, or '
  59. 'other type that inherits from '
  60. 'collections.MutableMapping' % (option,))
  61. def validate_ok_for_replace(replacement):
  62. validate_is_mapping('replacement', replacement)
  63. if replacement:
  64. first = next(iter(replacement))
  65. if first.startswith('$'):
  66. raise ValueError('replacement can not include $ operators')
  67. def validate_ok_for_update(update):
  68. validate_is_mapping('update', update)
  69. if not update:
  70. raise ValueError('update only works with $ operators')
  71. first = next(iter(update))
  72. if not first.startswith('$'):
  73. raise ValueError('update only works with $ operators')
  74. def validate_write_concern_params(**params):
  75. if params:
  76. WriteConcern(**params)
  77. def get_value_by_dot(doc, key):
  78. """Get dictionary value using dotted key"""
  79. result = doc
  80. for i in key.split('.'):
  81. result = result[i]
  82. return result
  83. def set_value_by_dot(doc, key, value):
  84. """Set dictionary value using dotted key"""
  85. result = doc
  86. keys = key.split('.')
  87. for i in keys[:-1]:
  88. if i not in result:
  89. result[i] = {}
  90. result = result[i]
  91. result[keys[-1]] = value
  92. return doc
class BulkWriteOperation(object):
    """One find()-scoped slot of a bulk batch.

    Holds the selector chosen via ``BulkOperationBuilder.find()`` and
    registers deferred callables on the owning builder's ``executors`` list;
    nothing touches the collection until the builder's ``execute()`` runs them.
    """

    def __init__(self, builder, selector, is_upsert=False):
        self.builder = builder      # owning BulkOperationBuilder
        self.selector = selector    # filter the queued ops will match against
        self.is_upsert = is_upsert  # queued updates insert when nothing matches

    def upsert(self):
        """Return a copy of this operation flagged as an upsert."""
        assert not self.is_upsert
        return BulkWriteOperation(self.builder, self.selector, is_upsert=True)

    def register_remove_op(self, multi):
        """Queue a deferred remove of documents matching the selector.

        The closure reports either ``nRemoved`` or ``writeErrors`` in the
        per-op result dict aggregated by the builder.
        """
        collection = self.builder.collection
        selector = self.selector

        def exec_remove():
            op_result = collection.remove(selector, multi=multi)
            if op_result.get("ok"):
                return {'nRemoved': op_result.get('n')}
            err = op_result.get("err")
            if err:
                return {"writeErrors": [err]}
            return {}
        self.builder.executors.append(exec_remove)

    def remove(self):
        """Queue removal of every document matching the selector."""
        assert not self.is_upsert
        self.register_remove_op(multi=True)

    def remove_one(self,):
        """Queue removal of a single document matching the selector."""
        assert not self.is_upsert
        self.register_remove_op(multi=False)

    def register_update_op(self, document, multi, **extra_args):
        """Queue a deferred update.

        Passing ``remove=True`` in *extra_args* skips the $-operator
        validation — that is how replace_one() sends a whole replacement
        document through this path.
        """
        if not extra_args.get("remove"):
            validate_ok_for_update(document)

        collection = self.builder.collection
        selector = self.selector

        def exec_update():
            result = collection._update(spec=selector, document=document,
                                        multi=multi, upsert=self.is_upsert,
                                        **extra_args)
            ret_val = {}
            if result.get('upserted'):
                ret_val["upserted"] = result.get('upserted')
                ret_val["nUpserted"] = result.get('n')
            modified = result.get('nModified')
            if modified is not None:
                ret_val['nModified'] = modified
                ret_val['nMatched'] = modified
            if result.get('err'):
                ret_val['err'] = result.get('err')
            return ret_val
        self.builder.executors.append(exec_update)

    def update(self, document):
        """Queue an update of every matching document."""
        self.register_update_op(document, multi=True)

    def update_one(self, document):
        """Queue an update of a single matching document."""
        self.register_update_op(document, multi=False)

    def replace_one(self, document):
        """Queue a whole-document replacement of a single matching document."""
        self.register_update_op(document, multi=False, remove=True)
class BulkOperationBuilder(object):
    """Accumulates deferred write operations and replays them on execute().

    Mirrors pymongo's bulk API: each queued op is a zero-argument closure
    appended to ``executors``; execute() runs them in order and merges their
    per-op result dicts into one pymongo-style bulk result.
    """

    def __init__(self, collection, ordered=False):
        self.collection = collection
        self.ordered = ordered  # NOTE(review): stored but never read here
        self.results = {}
        self.executors = []     # queued zero-arg closures, run by execute()
        self.done = False       # guards against double execution
        # Policy flags controlling whether nModified appears in the result.
        self._insert_returns_nModified = True
        self._update_returns_nModified = True

    def find(self, selector):
        """Start a write op scoped to *selector*."""
        return BulkWriteOperation(self, selector)

    def insert(self, doc):
        """Queue an insert of *doc*."""
        def exec_insert():
            self.collection.insert(doc)
            return {'nInserted': 1}
        self.executors.append(exec_insert)

    def __aggregate_operation_result(self, total_result, key, value):
        """Fold one per-op result entry into the accumulated bulk result.

        Ints are summed; lists are appended to ('upserted' entries gain an
        index); any other type trips the assert.
        NOTE(review): the "Unknow"/"aggreation" typos below are in runtime
        assert text and are left untouched by this documentation pass.
        """
        agg_val = total_result.get(key)
        assert agg_val is not None, "Unknow operation result %s=%s" \
            " (unrecognized key)" % (key, value)
        if isinstance(agg_val, int):
            total_result[key] += value
        elif isinstance(agg_val, list):
            if key == "upserted":
                new_element = {"index": len(agg_val), "_id": value}
                agg_val.append(new_element)
            else:
                agg_val.append(value)
        else:
            assert False, "Fixme: missed aggreation rule for type: %s for" \
                " key {%s=%s}" % (type(agg_val), key, agg_val)

    def _set_nModified_policy(self, insert, update):
        """Configure which op kinds report nModified in the final result."""
        self._insert_returns_nModified = insert
        self._update_returns_nModified = update

    def execute(self, write_concern=None):
        """Run every queued op once and return the merged bulk result.

        :raises InvalidOperation: if nothing was queued or already executed.
        """
        if not self.executors:
            raise InvalidOperation("Bulk operation empty!")
        if self.done:
            raise InvalidOperation("Bulk operation already executed!")
        self.done = True
        result = {'nModified': 0, 'nUpserted': 0, 'nMatched': 0,
                  'writeErrors': [], 'upserted': [], 'writeConcernErrors': [],
                  'nRemoved': 0, 'nInserted': 0}

        has_update = False
        has_insert = False
        broken_nModified_info = False
        for execute_func in self.executors:
            exec_name = execute_func.__name__
            op_result = execute_func()
            for (key, value) in op_result.items():
                self.__aggregate_operation_result(result, key, value)
            if exec_name == "exec_update":
                has_update = True
                if "nModified" not in op_result:
                    broken_nModified_info = True
            has_insert |= exec_name == "exec_insert"

        # Drop nModified when an update op failed to report it, or when the
        # configured policy says neither inserts nor updates should carry it.
        if broken_nModified_info:
            result.pop('nModified')
        elif has_insert and self._insert_returns_nModified:
            pass
        elif has_update and self._update_returns_nModified:
            pass
        elif self._update_returns_nModified and self._insert_returns_nModified:
            pass
        else:
            result.pop('nModified')
        return result

    def add_insert(self, doc):
        """pymongo-internal entry point: queue an insert."""
        self.insert(doc)

    def add_update(self, selector, doc, multi, upsert, collation=None):
        """pymongo-internal entry point: queue an update (collation ignored)."""
        write_operation = BulkWriteOperation(self, selector, is_upsert=upsert)
        write_operation.register_update_op(doc, multi)

    def add_replace(self, selector, doc, upsert, collation=None):
        """pymongo-internal entry point: queue a replace (collation ignored)."""
        write_operation = BulkWriteOperation(self, selector, is_upsert=upsert)
        write_operation.replace_one(doc)

    def add_delete(self, selector, just_one, collation=None):
        """pymongo-internal entry point: queue a delete (collation ignored)."""
        write_operation = BulkWriteOperation(self, selector, is_upsert=False)
        write_operation.register_remove_op(not just_one)
  224. class Collection(object):
    def __init__(self, db, name):
        """Create an in-memory collection *name* belonging to database *db*."""
        self.name = name
        self.full_name = "{0}.{1}".format(db.name, name)
        self.database = db
        # _id -> internalized (deep-copied) document; dict _ids are stored
        # under a hashable wrapper (see _insert).
        self._documents = OrderedDict()
        # Set when the collection is explicitly created while still empty.
        self._force_created = False
        # List of (index_key_tuple, is_sparse) pairs enforced on insert.
        self._uniques = []
  232. def _is_created(self):
  233. return self._documents or self._uniques or self._force_created
  234. def __repr__(self):
  235. return "Collection({0}, '{1}')".format(self.database, self.name)
  236. def __getitem__(self, name):
  237. return self.database[self.name + '.' + name]
  238. def __getattr__(self, name):
  239. return self.__getitem__(name)
  240. def initialize_unordered_bulk_op(self):
  241. return BulkOperationBuilder(self, ordered=False)
  242. def initialize_ordered_bulk_op(self):
  243. return BulkOperationBuilder(self, ordered=True)
  244. def insert(self, data, manipulate=True, check_keys=True,
  245. continue_on_error=False, **kwargs):
  246. warnings.warn("insert is deprecated. Use insert_one or insert_many "
  247. "instead.", DeprecationWarning, stacklevel=2)
  248. validate_write_concern_params(**kwargs)
  249. return self._insert(data)
  250. def insert_one(self, document):
  251. validate_is_mutable_mapping('document', document)
  252. return InsertOneResult(self._insert(document), acknowledged=True)
  253. def insert_many(self, documents, ordered=True):
  254. if not isinstance(documents, collections.Iterable) or not documents:
  255. raise TypeError('documents must be a non-empty list')
  256. for document in documents:
  257. validate_is_mutable_mapping('document', document)
  258. try:
  259. return InsertManyResult(self._insert(documents), acknowledged=True)
  260. except DuplicateKeyError:
  261. raise BulkWriteError('batch op errors occurred')
    def _insert(self, data):
        """Store *data* (a document, or a list/generator of documents).

        Returns the inserted _id, or a list of _ids for a batch.
        Raises DuplicateKeyError on an _id clash or a unique-index violation.
        """
        # A list/generator is a batch: insert each item, collect the ids.
        if isinstance(data, list) or isinstance(data, types.GeneratorType):
            return [self._insert(item) for item in data]

        # Like pymongo, we should fill the _id in the inserted dict (odd behavior,
        # but we need to stick to it), so we must patch in-place the data dict
        for key in data.keys():
            data[key] = helpers.patch_datetime_awareness_in_document(data[key])

        if not all(isinstance(k, string_types) for k in data):
            raise ValueError("Document keys must be strings")

        if BSON:
            # bson validation (e.g. rejects '.'/'$' in keys) — only when
            # bson is importable.
            BSON.encode(data, check_keys=True)

        if '_id' not in data:
            data['_id'] = ObjectId()
        object_id = data['_id']
        # dict _ids are unhashable; wrap them so they can key _documents.
        if isinstance(object_id, dict):
            object_id = helpers.hashdict(object_id)
        if object_id in self._documents:
            raise DuplicateKeyError("E11000 Duplicate Key Error", 11000)
        # Enforce declared unique indexes by probing for existing matches.
        for unique, is_sparse in self._uniques:
            find_kwargs = {}
            for key, direction in unique:
                find_kwargs[key] = data.get(key, None)
            answer_count = len(list(self._iter_documents(find_kwargs)))
            # NOTE(review): `key` here is the last key of the (possibly
            # compound) index — the sparse check only inspects that one field.
            if answer_count > 0 and not (is_sparse and find_kwargs[key] is None):
                raise DuplicateKeyError("E11000 Duplicate Key Error", 11000)
        with lock:
            # Store a deep copy so caller-side mutations don't leak in.
            self._documents[object_id] = self._internalize_dict(data)
        return data['_id']
  291. def _internalize_dict(self, d):
  292. return {k: copy.deepcopy(v) for k, v in iteritems(d)}
  293. def _has_key(self, doc, key):
  294. key_parts = key.split('.')
  295. sub_doc = doc
  296. for part in key_parts:
  297. if part not in sub_doc:
  298. return False
  299. sub_doc = sub_doc[part]
  300. return True
  301. def _remove_key(self, doc, key):
  302. key_parts = key.split('.')
  303. sub_doc = doc
  304. for part in key_parts[:-1]:
  305. sub_doc = sub_doc[part]
  306. del sub_doc[key_parts[-1]]
  307. def update_one(self, filter, update, upsert=False):
  308. validate_ok_for_update(update)
  309. return UpdateResult(self._update(filter, update, upsert=upsert),
  310. acknowledged=True)
  311. def update_many(self, filter, update, upsert=False):
  312. validate_ok_for_update(update)
  313. return UpdateResult(self._update(filter, update, upsert=upsert,
  314. multi=True),
  315. acknowledged=True)
  316. def replace_one(self, filter, replacement, upsert=False):
  317. validate_ok_for_replace(replacement)
  318. return UpdateResult(self._update(filter, replacement, upsert=upsert),
  319. acknowledged=True)
  320. def update(self, spec, document, upsert=False, manipulate=False,
  321. multi=False, check_keys=False, **kwargs):
  322. warnings.warn("update is deprecated. Use replace_one, update_one or "
  323. "update_many instead.", DeprecationWarning, stacklevel=2)
  324. return self._update(spec, document, upsert, manipulate, multi,
  325. check_keys, **kwargs)
    def _update(self, spec, document, upsert=False, manipulate=False,
                multi=False, check_keys=False, **kwargs):
        """Core update engine behind update()/update_one()/update_many()/replace_one().

        Applies *document* — either a dict of $-operators or a whole
        replacement document — to every document matching *spec*; when
        *upsert* is true and nothing matched, a new document is built from
        the spec and inserted.  Returns a pymongo-style raw result dict
        (n, nModified, upserted, updatedExisting, ...).

        NOTE: manipulate, check_keys and **kwargs are accepted for pymongo
        API parity but never read in this body (the bulk path passes
        remove=True through kwargs; it is ignored here).
        """
        spec = helpers.patch_datetime_awareness_in_document(spec)
        document = helpers.patch_datetime_awareness_in_document(document)
        validate_is_mapping('spec', spec)
        validate_is_mapping('document', document)

        updated_existing = False
        upserted_id = None
        num_updated = 0
        # A trailing None sentinel triggers the upsert branch when reached.
        for existing_document in itertools.chain(self._iter_documents(spec), [None]):
            # we need was_insert for the setOnInsert update operation
            was_insert = False
            # the sentinel document means we should do an upsert
            if existing_document is None:
                if not upsert or num_updated:
                    continue
                # For upsert operation we have first to create a fake existing_document,
                # update it like a regular one, then finally insert it
                if spec.get('_id') is not None:
                    _id = spec['_id']
                elif document.get('_id') is not None:
                    _id = document['_id']
                else:
                    _id = ObjectId()
                to_insert = dict(spec, _id=_id)
                to_insert = self._expand_dots(to_insert)
                existing_document = to_insert
                was_insert = True
            else:
                updated_existing = True
            num_updated += 1
            # first tracks whether we are still allowed to treat *document*
            # as a whole-document replacement (no operator seen yet).
            first = True
            # subdocument caches the positional ($) target between operators.
            subdocument = None
            for k, v in iteritems(document):
                if k in _updaters.keys():
                    # Generic $-operators ($set, $inc, ...) share one
                    # positional-aware field walker.
                    updater = _updaters[k]
                    subdocument = self._update_document_fields_with_positional_awareness(
                        existing_document, v, spec, updater, subdocument)

                elif k == '$setOnInsert':
                    if not was_insert:
                        continue
                    subdocument = self._update_document_fields_with_positional_awareness(
                        existing_document, v, spec, _set_updater, subdocument)

                elif k == '$currentDate':
                    for value in itervalues(v):
                        if value == {'$type': 'timestamp'}:
                            raise NotImplementedError('timestamp is not supported so far')

                    subdocument = self._update_document_fields_with_positional_awareness(
                        existing_document, v, spec, _current_date_updater, subdocument)

                elif k == '$addToSet':
                    for field, value in iteritems(v):
                        nested_field_list = field.rsplit('.')
                        if len(nested_field_list) == 1:
                            if field not in existing_document:
                                existing_document[field] = []
                            # document should be a list append to it
                            if isinstance(value, dict):
                                if '$each' in value:
                                    # append the list to the field,
                                    # skipping values already present
                                    existing_document[field] += [
                                        obj for obj in list(value['$each'])
                                        if obj not in existing_document[field]]
                                    continue
                            if value not in existing_document[field]:
                                existing_document[field].append(value)
                            continue
                        # push to array in a nested attribute
                        else:
                            # create nested attributes if they do not exist
                            subdocument = existing_document
                            for field in nested_field_list[:-1]:
                                if field not in subdocument:
                                    subdocument[field] = {}
                                subdocument = subdocument[field]

                            # we're pushing a list
                            push_results = []
                            if nested_field_list[-1] in subdocument:
                                # if the list exists, then use that list
                                push_results = subdocument[
                                    nested_field_list[-1]]

                            if isinstance(value, dict) and '$each' in value:
                                push_results += [
                                    obj for obj in list(value['$each'])
                                    if obj not in push_results]
                            elif value not in push_results:
                                push_results.append(value)

                            subdocument[nested_field_list[-1]] = push_results
                elif k == '$pull':
                    for field, value in iteritems(v):
                        nested_field_list = field.rsplit('.')
                        # nested fields includes a positional element
                        # need to find that element
                        if '$' in nested_field_list:
                            if not subdocument:
                                subdocument = self._get_subdocument(
                                    existing_document, spec, nested_field_list)

                            # value should be a dictionary since we're pulling
                            pull_results = []
                            # and the last subdoc should be an array
                            for obj in subdocument[nested_field_list[-1]]:
                                if isinstance(obj, dict):
                                    for pull_key, pull_value in iteritems(value):
                                        if obj[pull_key] != pull_value:
                                            pull_results.append(obj)
                                    continue
                                if obj != value:
                                    pull_results.append(obj)

                            # cannot write to doc directly as it doesn't save to
                            # existing_document
                            subdocument[nested_field_list[-1]] = pull_results
                        else:
                            # walk down to the targeted array, bailing out
                            # (via break/continue) when the path is missing
                            # or not a list
                            arr = existing_document
                            for field in nested_field_list:
                                if field not in arr:
                                    break
                                arr = arr[field]
                            if not isinstance(arr, list):
                                continue

                            # remove matching items in place; iterate a copy
                            # since arr is mutated during the loop
                            arr_copy = copy.deepcopy(arr)
                            if isinstance(value, dict):
                                for obj in arr_copy:
                                    if filter_applies(value, obj):
                                        arr.remove(obj)
                            else:
                                for obj in arr_copy:
                                    if value == obj:
                                        arr.remove(obj)
                elif k == '$pullAll':
                    for field, value in iteritems(v):
                        nested_field_list = field.rsplit('.')
                        if len(nested_field_list) == 1:
                            if field in existing_document:
                                arr = existing_document[field]
                                existing_document[field] = [
                                    obj for obj in arr if obj not in value]
                            continue
                        else:
                            subdocument = existing_document
                            for nested_field in nested_field_list[:-1]:
                                if nested_field not in subdocument:
                                    break
                                subdocument = subdocument[nested_field]

                            if nested_field_list[-1] in subdocument:
                                arr = subdocument[nested_field_list[-1]]
                                subdocument[nested_field_list[-1]] = [
                                    obj for obj in arr if obj not in value]
                elif k == '$push':
                    for field, value in iteritems(v):
                        nested_field_list = field.rsplit('.')
                        if len(nested_field_list) == 1:
                            if field not in existing_document:
                                existing_document[field] = []
                            # document should be a list
                            # append to it
                            if isinstance(value, dict):
                                if '$each' in value:
                                    # append the list to the field
                                    existing_document[field] += list(value['$each'])
                                    continue
                            existing_document[field].append(value)
                            continue
                        # nested fields includes a positional element
                        # need to find that element
                        elif '$' in nested_field_list:
                            if not subdocument:
                                subdocument = self._get_subdocument(
                                    existing_document, spec, nested_field_list)

                            # we're pushing a list
                            push_results = []
                            if nested_field_list[-1] in subdocument:
                                # if the list exists, then use that list
                                push_results = subdocument[nested_field_list[-1]]

                            if isinstance(value, dict):
                                # check to see if we have the format
                                # { '$each': [] }
                                if '$each' in value:
                                    push_results += list(value['$each'])
                                else:
                                    push_results.append(value)
                            else:
                                push_results.append(value)

                            # cannot write to doc directly as it doesn't save to
                            # existing_document
                            subdocument[nested_field_list[-1]] = push_results
                        # push to array in a nested attribute
                        else:
                            # create nested attributes if they do not exist
                            subdocument = existing_document
                            for field in nested_field_list[:-1]:
                                if field not in subdocument:
                                    subdocument[field] = {}
                                subdocument = subdocument[field]

                            # we're pushing a list
                            push_results = []
                            if nested_field_list[-1] in subdocument:
                                # if the list exists, then use that list
                                push_results = subdocument[nested_field_list[-1]]

                            if isinstance(value, dict) and '$each' in value:
                                push_results += list(value['$each'])
                            else:
                                push_results.append(value)

                            subdocument[nested_field_list[-1]] = push_results
                else:
                    if first:
                        # replace entire document
                        for key in document.keys():
                            if key.startswith('$'):
                                # can't mix modifiers with non-modifiers in
                                # update
                                raise ValueError('field names cannot start with $ [{}]'.format(k))
                        _id = spec.get('_id', existing_document.get('_id'))
                        existing_document.clear()
                        if _id:
                            existing_document['_id'] = _id
                        existing_document.update(self._internalize_dict(document))
                        if existing_document['_id'] != _id:
                            raise OperationFailure(
                                "The _id field cannot be changed from {0} to {1}"
                                .format(existing_document['_id'], _id))
                        break
                    else:
                        # can't mix modifiers with non-modifiers in update
                        raise ValueError(
                            'Invalid modifier specified: {}'.format(k))
                first = False
            # if empty document comes, strip everything but the _id
            if len(document) == 0:
                _id = spec.get('_id', existing_document.get('_id'))
                existing_document.clear()
                if _id:
                    existing_document['_id'] = _id
            if was_insert:
                upserted_id = self._insert(existing_document)
            if not multi:
                break

        return {
            text_type("connectionId"): self.database.client._id,
            text_type("err"): None,
            text_type("n"): num_updated,
            text_type("nModified"): num_updated if updated_existing else 0,
            text_type("ok"): 1,
            text_type("upserted"): upserted_id,
            text_type("updatedExisting"): updated_existing,
        }
    def _get_subdocument(self, existing_document, spec, nested_field_list):
        """This method retrieves the subdocument of the existing_document.nested_field_list.

        It uses the spec to filter through the items. It will continue to grab nested documents
        until it can go no further. It will then return the subdocument that was last saved.
        '$' is the positional operator, so we use the $elemMatch in the spec to find the right
        subdocument in the array.
        """
        # current document in view
        doc = existing_document
        # previous document in view — this is what gets returned
        subdocument = existing_document
        # current spec in view
        subspec = spec
        # walk down the dictionary
        for subfield in nested_field_list:
            if subfield == '$':
                # positional element should have the equivalent elemMatch in the
                # query
                subspec = subspec['$elemMatch']
                for item in doc:
                    # iterate through the array looking for the first match
                    if filter_applies(subspec, item):
                        # found the matching item save the parent
                        subdocument = doc
                        # save the item
                        doc = item
                        break
                continue

            subdocument = doc
            doc = doc[subfield]
            # once the spec no longer mirrors the path, stop descending
            if subfield not in subspec:
                break
            subspec = subspec[subfield]

        return subdocument
  604. def _expand_dots(self, doc):
  605. expanded = {}
  606. paths = {}
  607. for k, v in iteritems(doc):
  608. key_parts = k.split('.')
  609. sub_doc = v
  610. for i in reversed(range(1, len(key_parts))):
  611. key = key_parts[i]
  612. sub_doc = {key: sub_doc}
  613. key = key_parts[0]
  614. if key in expanded:
  615. raise WriteError("cannot infer query fields to set, "
  616. "both paths '%s' and '%s' are matched"
  617. % (k, paths[key]))
  618. paths[key] = k
  619. expanded[key] = sub_doc
  620. return expanded
  621. def _discard_operators(self, doc):
  622. # TODO(this looks a little too naive...)
  623. return {k: v for k, v in iteritems(doc) if not k.startswith("$")}
  624. def find(self, filter=None, projection=None, skip=0, limit=0,
  625. no_cursor_timeout=False, cursor_type=None, sort=None,
  626. allow_partial_results=False, oplog_replay=False, modifiers=None,
  627. batch_size=0, manipulate=True, collation=None):
  628. spec = filter
  629. if spec is None:
  630. spec = {}
  631. validate_is_mapping('filter', spec)
  632. return Cursor(self, spec, sort, projection, skip, limit, collation=collation)
  633. def _get_dataset(self, spec, sort, fields, as_class):
  634. dataset = (self._copy_only_fields(document, fields, as_class)
  635. for document in self._iter_documents(spec))
  636. if sort:
  637. for sortKey, sortDirection in reversed(sort):
  638. dataset = iter(sorted(
  639. dataset, key=lambda x: _resolve_sort_key(sortKey, x),
  640. reverse=sortDirection < 0))
  641. return dataset
  642. def _copy_field(self, obj, container):
  643. if isinstance(obj, list):
  644. new = []
  645. for item in obj:
  646. new.append(self._copy_field(item, container))
  647. return new
  648. if isinstance(obj, dict):
  649. new = container()
  650. for key, value in obj.items():
  651. new[key] = self._copy_field(value, container)
  652. return new
  653. else:
  654. return copy.copy(obj)
  655. def _extract_projection_operators(self, fields):
  656. """Removes and returns fields with projection operators."""
  657. result = {}
  658. allowed_projection_operators = {'$elemMatch'}
  659. for key, value in iteritems(fields):
  660. if isinstance(value, dict):
  661. for op in value:
  662. if op not in allowed_projection_operators:
  663. raise ValueError('Unsupported projection option: {}'.format(op))
  664. result[key] = value
  665. for key in result:
  666. del fields[key]
  667. return result
    def _apply_projection_operators(self, ops, doc, doc_copy):
        """Applies projection operators to copied document.

        *ops* maps field names to operator dicts previously removed by
        _extract_projection_operators; *doc* is the original document and
        *doc_copy* the projected copy being built (mutated in place).
        """
        for field, op in iteritems(ops):
            if field not in doc_copy:
                if field in doc:
                    # field was not copied yet (since we are in include mode)
                    doc_copy[field] = doc[field]
                else:
                    # field doesn't exist in original document, no work to do
                    continue
            if '$elemMatch' in op:
                if isinstance(doc_copy[field], list):
                    # find the first item that matches
                    matched = False
                    for item in doc_copy[field]:
                        if filter_applies(op['$elemMatch'], item):
                            matched = True
                            doc_copy[field] = [item]
                            break
                    # nothing have matched
                    if not matched:
                        del doc_copy[field]
                else:
                    # remove the field since there is nothing to iterate
                    del doc_copy[field]
    def _copy_only_fields(self, doc, fields, container):
        """Copy only the specified fields.

        *fields* may be None (copy everything), a list of field names, or a
        dict of ``{field: 0/1}`` include/exclude flags; dotted paths are
        supported.  ``container`` is the mapping class used for the copy.
        Note: *fields* is temporarily mutated (``_id`` and operator fields
        are popped) but restored before returning.
        """
        if fields is None:
            return self._copy_field(doc, container)
        else:
            if not fields:
                fields = {"_id": 1}
            if not isinstance(fields, dict):
                fields = helpers._fields_list_to_dict(fields)

            # we can pass in something like {"_id":0, "field":1}, so pull the id
            # value out and hang on to it until later
            id_value = fields.pop('_id', 1)

            # filter out fields with projection operators, we will take care of them later
            projection_operators = self._extract_projection_operators(fields)

            # other than the _id field, all fields must be either includes or
            # excludes, this can evaluate to 0
            if len(set(list(fields.values()))) > 1:
                raise ValueError(
                    'You cannot currently mix including and excluding fields.')

            # if we have novalues passed in, make a doc_copy based on the
            # id_value
            if len(list(fields.values())) == 0:
                if id_value == 1:
                    doc_copy = container()
                else:
                    doc_copy = self._copy_field(doc, container)
            # if 1 was passed in as the field values, include those fields
            elif list(fields.values())[0] == 1:
                doc_copy = container()
                for key in fields:
                    key_parts = key.split('.')
                    subdocument = doc
                    subdocument_copy = doc_copy
                    last_copy = subdocument_copy
                    full_key_path_found = True
                    # walk the dotted path in parallel over the original and
                    # the copy, creating intermediate dicts in the copy
                    for key_part in key_parts[:-1]:
                        if key_part not in subdocument:
                            full_key_path_found = False
                            break
                        subdocument = subdocument[key_part]
                        last_copy = subdocument_copy
                        subdocument_copy = subdocument_copy.setdefault(key_part, {})
                    if full_key_path_found:
                        last_key = key_parts[-1]
                        if isinstance(subdocument, dict) and last_key in subdocument:
                            subdocument_copy[last_key] = subdocument[last_key]
                        elif isinstance(subdocument, (list, tuple)):
                            # project the leaf key out of each list element
                            subdocument = [{last_key: x[last_key]}
                                           for x in subdocument if last_key in x]
                            if subdocument:
                                last_copy[key_parts[-2]] = subdocument
            # otherwise, exclude the fields passed in
            else:
                doc_copy = self._copy_field(doc, container)
                for key in fields:
                    key_parts = key.split('.')
                    subdocument_copy = doc_copy
                    full_key_path_found = True
                    for key_part in key_parts[:-1]:
                        if key_part not in subdocument_copy:
                            full_key_path_found = False
                            break
                        subdocument_copy = subdocument_copy[key_part]
                    if not full_key_path_found or key_parts[-1] not in subdocument_copy:
                        continue
                    del subdocument_copy[key_parts[-1]]

            # set the _id value if we requested it, otherwise remove it
            if id_value == 0:
                doc_copy.pop('_id', None)
            else:
                if '_id' in doc:
                    doc_copy['_id'] = doc['_id']
            fields['_id'] = id_value  # put _id back in fields

            # time to apply the projection operators and put back their fields
            self._apply_projection_operators(projection_operators, doc, doc_copy)
            for field, op in iteritems(projection_operators):
                fields[field] = op
            return doc_copy
  771. def _update_document_fields(self, doc, fields, updater):
  772. """Implements the $set behavior on an existing document"""
  773. for k, v in iteritems(fields):
  774. self._update_document_single_field(doc, k, v, updater)
    def _update_document_fields_positional(self, doc, fields, spec, updater,
                                           subdocument=None):
        """Implements the $set behavior on an existing document.

        Handles the positional operator ``$`` in dotted field names by
        resolving it against the query *spec* (or its ``$elemMatch``
        sub-spec).  Returns the sub-document the positional match resolved
        to, so repeated calls can reuse it.
        """
        for k, v in iteritems(fields):
            if '$' in k:
                field_name_parts = k.split('.')
                if not subdocument:
                    current_doc = doc
                    subspec = spec
                    for part in field_name_parts[:-1]:
                        if part == '$':
                            # positional element: use the equivalent
                            # $elemMatch condition from the query, if any
                            subspec = subspec.get('$elemMatch', subspec)
                            for item in current_doc:
                                if filter_applies(subspec, item):
                                    current_doc = item
                                    break
                            continue
                        # narrow the spec to the conditions addressing this
                        # path component (stripping the leading "part.")
                        new_spec = {}
                        for el in subspec:
                            if el.startswith(part):
                                if len(el.split(".")) > 1:
                                    new_spec[".".join(
                                        el.split(".")[1:])] = subspec[el]
                                else:
                                    new_spec = subspec[el]
                        subspec = new_spec
                        current_doc = current_doc[part]
                    subdocument = current_doc
                    if (field_name_parts[-1] == '$' and
                            isinstance(subdocument, list)):
                        # trailing '$': replace the first matching element
                        # (NOTE: the loop variable shadows the outer 'doc')
                        for i, doc in enumerate(subdocument):
                            if filter_applies(subspec, doc):
                                subdocument[i] = v
                                break
                        continue
                updater(subdocument, field_name_parts[-1], v)
                continue
            # otherwise, we handle it the standard way
            self._update_document_single_field(doc, k, v, updater)
        return subdocument
  815. def _update_document_fields_with_positional_awareness(self, existing_document, v, spec,
  816. updater, subdocument):
  817. positional = any('$' in key for key in iterkeys(v))
  818. if positional:
  819. return self._update_document_fields_positional(
  820. existing_document, v, spec, updater, subdocument)
  821. self._update_document_fields(existing_document, v, updater)
  822. return subdocument
    def _update_document_single_field(self, doc, field_name, field_value, updater):
        """Apply *updater* to a single (possibly dotted) field of *doc*.

        Walks intermediate path components, indexing lists numerically (or
        taking the first element for ``$``) and creating missing dicts —
        except for $unset, which bails out when a parent is missing.
        """
        field_name_parts = field_name.split(".")
        for part in field_name_parts[:-1]:
            if isinstance(doc, list):
                try:
                    if part == '$':
                        doc = doc[0]
                    else:
                        doc = doc[int(part)]
                    continue
                except ValueError:
                    # non-numeric part on a list: fall through to the next
                    # path component without descending
                    pass
            elif isinstance(doc, dict):
                if updater is _unset_updater and part not in doc:
                    # If the parent doesn't exist, neither does its child:
                    # nothing to unset.
                    return
                doc = doc.setdefault(part, {})
            else:
                # cannot descend into a scalar; silently give up
                return
        field_name = field_name_parts[-1]
        if isinstance(doc, list):
            try:
                doc[int(field_name)] = field_value
            except IndexError:
                pass
        else:
            updater(doc, field_name, field_value)
  850. def _iter_documents(self, filter=None):
  851. return (document for document in list(itervalues(self._documents))
  852. if filter_applies(filter, document))
  853. def find_one(self, filter=None, *args, **kwargs):
  854. # Allow calling find_one with a non-dict argument that gets used as
  855. # the id for the query.
  856. if filter is None:
  857. filter = {}
  858. if not isinstance(filter, collections.Mapping):
  859. filter = {'_id': filter}
  860. try:
  861. return next(self.find(filter, *args, **kwargs))
  862. except StopIteration:
  863. return None
  864. def find_one_and_delete(self, filter, projection=None, sort=None, **kwargs):
  865. kwargs['remove'] = True
  866. validate_is_mapping('filter', filter)
  867. return self._find_and_modify(filter, projection, sort=sort, **kwargs)
  868. def find_one_and_replace(self, filter, replacement,
  869. projection=None, sort=None, upsert=False,
  870. return_document=ReturnDocument.BEFORE, **kwargs):
  871. validate_is_mapping('filter', filter)
  872. validate_ok_for_replace(replacement)
  873. return self._find_and_modify(filter, projection, replacement, upsert,
  874. sort, return_document, **kwargs)
  875. def find_one_and_update(self, filter, update,
  876. projection=None, sort=None, upsert=False,
  877. return_document=ReturnDocument.BEFORE, **kwargs):
  878. validate_is_mapping('filter', filter)
  879. validate_ok_for_update(update)
  880. return self._find_and_modify(filter, projection, update, upsert,
  881. sort, return_document, **kwargs)
  882. def find_and_modify(self, query={}, update=None, upsert=False, sort=None,
  883. full_response=False, manipulate=False, fields=None, **kwargs):
  884. warnings.warn("find_and_modify is deprecated, use find_one_and_delete"
  885. ", find_one_and_replace, or find_one_and_update instead",
  886. DeprecationWarning, stacklevel=2)
  887. if 'projection' in kwargs:
  888. raise TypeError("find_and_modify() got an unexpected keyword argument 'projection'")
  889. return self._find_and_modify(query, update=update, upsert=upsert,
  890. sort=sort, projection=fields, **kwargs)
    def _find_and_modify(self, query, projection=None, update=None,
                         upsert=False, sort=None,
                         return_document=ReturnDocument.BEFORE, **kwargs):
        """Shared backend for the find_one_and_* family and find_and_modify.

        Exactly one of *update* or ``kwargs['remove']`` must be given.
        Returns the document before (default) or after modification, or
        None when nothing matched and *upsert* is False.
        """
        remove = kwargs.get("remove", False)
        if kwargs.get("new", False) and remove:
            # message from mongodb
            raise OperationFailure("remove and returnNew can't co-exist")
        if not (remove or update):
            raise ValueError("Must either update or remove")
        if remove and update:
            raise ValueError("Can't do both update and remove")
        old = self.find_one(query, projection=projection, sort=sort)
        if not old and not upsert:
            return
        if old and '_id' in old:
            # pin the query to the matched document so the update/delete
            # cannot hit a different one
            query = {'_id': old['_id']}
        if remove:
            self.delete_one(query)
        else:
            self._update(query, update, upsert)
        if return_document is ReturnDocument.AFTER or kwargs.get('new'):
            return self.find_one(query, projection)
        return old
  914. def save(self, to_save, manipulate=True, check_keys=True, **kwargs):
  915. warnings.warn("save is deprecated. Use insert_one or replace_one "
  916. "instead", DeprecationWarning, stacklevel=2)
  917. validate_is_mutable_mapping("to_save", to_save)
  918. validate_write_concern_params(**kwargs)
  919. if "_id" not in to_save:
  920. return self.insert(to_save)
  921. else:
  922. self._update({"_id": to_save["_id"]}, to_save, True,
  923. manipulate, check_keys=True, **kwargs)
  924. return to_save.get("_id", None)
  925. def delete_one(self, filter):
  926. validate_is_mapping('filter', filter)
  927. return DeleteResult(self._delete(filter), True)
  928. def delete_many(self, filter):
  929. validate_is_mapping('filter', filter)
  930. return DeleteResult(self._delete(filter, multi=True), True)
  931. def _delete(self, filter, multi=False):
  932. filter = helpers.patch_datetime_awareness_in_document(filter)
  933. if filter is None:
  934. filter = {}
  935. if not isinstance(filter, collections.Mapping):
  936. filter = {'_id': filter}
  937. to_delete = list(self.find(filter))
  938. deleted_count = 0
  939. for doc in to_delete:
  940. doc_id = doc['_id']
  941. if isinstance(doc_id, dict):
  942. doc_id = helpers.hashdict(doc_id)
  943. del self._documents[doc_id]
  944. deleted_count += 1
  945. if not multi:
  946. break
  947. return {
  948. "connectionId": self.database.client._id,
  949. "n": deleted_count,
  950. "ok": 1.0,
  951. "err": None,
  952. }
  953. def remove(self, spec_or_id=None, multi=True, **kwargs):
  954. warnings.warn("remove is deprecated. Use delete_one or delete_many "
  955. "instead.", DeprecationWarning, stacklevel=2)
  956. validate_write_concern_params(**kwargs)
  957. return self._delete(spec_or_id, multi=multi)
  958. def count(self, filter=None, **kwargs):
  959. if filter is None:
  960. return len(self._documents)
  961. else:
  962. return len(list(self._iter_documents(filter)))
  963. def drop(self):
  964. self.database.drop_collection(self.name)
  965. def ensure_index(self, key_or_list, cache_for=300, **kwargs):
  966. self.create_index(key_or_list, cache_for, **kwargs)
  967. def create_index(self, key_or_list, cache_for=300, **kwargs):
  968. if kwargs.pop('unique', False):
  969. self._uniques.append((helpers.index_list(key_or_list), kwargs.pop('sparse', False)))
  970. def drop_index(self, index_or_name):
  971. pass
  972. def drop_indexes(self):
  973. self._uniques = []
  974. def reindex(self):
  975. pass
  976. def list_indexes(self):
  977. return {}
  978. def index_information(self):
  979. return {}
  980. def map_reduce(self, map_func, reduce_func, out, full_response=False,
  981. query=None, limit=0):
  982. if execjs is None:
  983. raise NotImplementedError(
  984. "PyExecJS is required in order to run Map-Reduce. "
  985. "Use 'pip install pyexecjs pymongo' to support Map-Reduce mock."
  986. )
  987. if limit == 0:
  988. limit = None
  989. start_time = time.clock()
  990. out_collection = None
  991. reduced_rows = None
  992. full_dict = {
  993. 'counts': {
  994. 'input': 0,
  995. 'reduce': 0,
  996. 'emit': 0,
  997. 'output': 0},
  998. 'timeMillis': 0,
  999. 'ok': 1.0,
  1000. 'result': None}
  1001. map_ctx = execjs.compile("""
  1002. function doMap(fnc, docList) {
  1003. var mappedDict = {};
  1004. function emit(key, val) {
  1005. if (key['$oid']) {
  1006. mapped_key = '$oid' + key['$oid'];
  1007. }
  1008. else {
  1009. mapped_key = key;
  1010. }
  1011. if(!mappedDict[mapped_key]) {
  1012. mappedDict[mapped_key] = [];
  1013. }
  1014. mappedDict[mapped_key].push(val);
  1015. }
  1016. mapper = eval('('+fnc+')');
  1017. var mappedList = new Array();
  1018. for(var i=0; i<docList.length; i++) {
  1019. var thisDoc = eval('('+docList[i]+')');
  1020. var mappedVal = (mapper).call(thisDoc);
  1021. }
  1022. return mappedDict;
  1023. }
  1024. """)
  1025. reduce_ctx = execjs.compile("""
  1026. function doReduce(fnc, docList) {
  1027. var reducedList = new Array();
  1028. reducer = eval('('+fnc+')');
  1029. for(var key in docList) {
  1030. var reducedVal = {'_id': key,
  1031. 'value': reducer(key, docList[key])};
  1032. reducedList.push(reducedVal);
  1033. }
  1034. return reducedList;
  1035. }
  1036. """)
  1037. doc_list = [json.dumps(doc, default=json_util.default)
  1038. for doc in self.find(query)]
  1039. mapped_rows = map_ctx.call('doMap', map_func, doc_list)
  1040. reduced_rows = reduce_ctx.call('doReduce', reduce_func, mapped_rows)[:limit]
  1041. for reduced_row in reduced_rows:
  1042. if reduced_row['_id'].startswith('$oid'):
  1043. reduced_row['_id'] = ObjectId(reduced_row['_id'][4:])
  1044. reduced_rows = sorted(reduced_rows, key=lambda x: x['_id'])
  1045. if full_response:
  1046. full_dict['counts']['input'] = len(doc_list)
  1047. for key in mapped_rows.keys():
  1048. emit_count = len(mapped_rows[key])
  1049. full_dict['counts']['emit'] += emit_count
  1050. if emit_count > 1:
  1051. full_dict['counts']['reduce'] += 1
  1052. full_dict['counts']['output'] = len(reduced_rows)
  1053. if isinstance(out, (str, bytes)):
  1054. out_collection = getattr(self.database, out)
  1055. out_collection.drop()
  1056. out_collection.insert(reduced_rows)
  1057. ret_val = out_collection
  1058. full_dict['result'] = out
  1059. elif isinstance(out, SON) and out.get('replace') and out.get('db'):
  1060. # Must be of the format SON([('replace','results'),('db','outdb')])
  1061. out_db = getattr(self.database._client, out['db'])
  1062. out_collection = getattr(out_db, out['replace'])
  1063. out_collection.insert(reduced_rows)
  1064. ret_val = out_collection
  1065. full_dict['result'] = {'db': out['db'], 'collection': out['replace']}
  1066. elif isinstance(out, dict) and out.get('inline'):
  1067. ret_val = reduced_rows
  1068. full_dict['result'] = reduced_rows
  1069. else:
  1070. raise TypeError("'out' must be an instance of string, dict or bson.SON")
  1071. full_dict['timeMillis'] = int(round((time.clock() - start_time) * 1000))
  1072. if full_response:
  1073. ret_val = full_dict
  1074. return ret_val
  1075. def inline_map_reduce(self, map_func, reduce_func, full_response=False,
  1076. query=None, limit=0):
  1077. return self.map_reduce(
  1078. map_func, reduce_func, {'inline': 1}, full_response, query, limit)
  1079. def distinct(self, key, filter=None):
  1080. return self.find(filter).distinct(key)
    def group(self, key, condition, initial, reduce, finalize=None):
        """Mock of the deprecated ``group`` command, backed by PyExecJS.

        Groups documents matching *condition* by the field names in *key*
        and applies the JavaScript *reduce* function to each group.
        *finalize* is accepted for API compatibility but not applied.
        Raises NotImplementedError when PyExecJS is not installed.
        """
        if execjs is None:
            raise NotImplementedError(
                "PyExecJS is required in order to use group. "
                "Use 'pip install pyexecjs pymongo' to support group mock."
            )
        reduce_ctx = execjs.compile("""
            function doReduce(fnc, docList) {
                reducer = eval('('+fnc+')');
                for(var i=0, l=docList.length; i<l; i++) {
                    try {
                        reducedVal = reducer(docList[i-1], docList[i]);
                    }
                    catch (err) {
                        continue;
                    }
                }
                return docList[docList.length - 1];
            }
        """)
        ret_array = []
        doc_list_copy = []
        ret_array_copy = []
        reduced_val = {}
        doc_list = [doc for doc in self.find(condition)]
        for doc in doc_list:
            doc_copy = copy.deepcopy(doc)
            for k in doc:
                # stringify ObjectIds so they survive the JSON/JS round trip
                if isinstance(doc[k], ObjectId):
                    doc_copy[k] = str(doc[k])
                # keep only fields referenced by the key list or the reduce
                # function source
                if k not in key and k not in reduce:
                    del doc_copy[k]
            for initial_key in initial:
                if initial_key in doc.keys():
                    pass
                else:
                    # seed missing fields from the 'initial' accumulator
                    doc_copy[initial_key] = initial[initial_key]
            doc_list_copy.append(doc_copy)
        doc_list = doc_list_copy
        # groupby requires its input sorted by the same key
        for k in key:
            doc_list = sorted(doc_list, key=lambda x: _resolve_key(k, x))
        for k in key:
            if not isinstance(k, helpers.basestring):
                raise TypeError(
                    "Keys must be a list of key names, "
                    "each an instance of %s" % helpers.basestring.__name__)
            for k2, group in itertools.groupby(doc_list, lambda item: item[k]):
                group_list = ([x for x in group])
                reduced_val = reduce_ctx.call('doReduce', reduce, group_list)
                ret_array.append(reduced_val)
        # strip fields that are neither grouping keys nor accumulators
        for doc in ret_array:
            doc_copy = copy.deepcopy(doc)
            for k in doc:
                if k not in key and k not in initial.keys():
                    del doc_copy[k]
            ret_array_copy.append(doc_copy)
        ret_array = ret_array_copy
        return ret_array
  1139. def aggregate(self, pipeline, **kwargs):
  1140. pipeline_operators = [
  1141. '$project',
  1142. '$match',
  1143. '$redact',
  1144. '$limit',
  1145. '$skip',
  1146. '$unwind',
  1147. '$group',
  1148. '$sample'
  1149. '$sort',
  1150. '$geoNear',
  1151. '$lookup',
  1152. '$out',
  1153. '$indexStats']
  1154. group_operators = [
  1155. '$addToSet',
  1156. '$first',
  1157. '$last',
  1158. '$max',
  1159. '$min',
  1160. '$avg',
  1161. '$push',
  1162. '$sum',
  1163. '$stdDevPop',
  1164. '$stdDevSamp']
  1165. project_operators = [
  1166. '$max',
  1167. '$min',
  1168. '$avg',
  1169. '$sum',
  1170. '$stdDevPop',
  1171. '$stdDevSamp',
  1172. '$arrayElemAt'
  1173. ]
  1174. boolean_operators = ['$and', '$or', '$not'] # noqa
  1175. set_operators = [ # noqa
  1176. '$setEquals',
  1177. '$setIntersection',
  1178. '$setDifference',
  1179. '$setUnion',
  1180. '$setIsSubset',
  1181. '$anyElementTrue',
  1182. '$allElementsTrue']
  1183. comparison_operators = [ # noqa
  1184. '$cmp',
  1185. '$eq',
  1186. '$gt',
  1187. '$gte',
  1188. '$lt',
  1189. '$lte',
  1190. '$ne']
  1191. arithmetic_operators = [ # noqa
  1192. '$abs',
  1193. '$add',
  1194. '$ceil',
  1195. '$divide',
  1196. '$exp',
  1197. '$floor',
  1198. '$ln',
  1199. '$log',
  1200. '$log10',
  1201. '$mod',
  1202. '$multiply',
  1203. '$pow',
  1204. '$sqrt',
  1205. '$subtract',
  1206. '$trunc']
  1207. string_operators = [ # noqa
  1208. '$concat',
  1209. '$strcasecmp',
  1210. '$substr',
  1211. '$toLower',
  1212. '$toUpper']
  1213. text_search_operators = ['$meta'] # noqa
  1214. array_operators = [ # noqa
  1215. '$arrayElemAt',
  1216. '$concatArrays',
  1217. '$filter',
  1218. '$isArray',
  1219. '$size',
  1220. '$slice']
  1221. projection_operators = ['$map', '$let', '$literal'] # noqa
  1222. date_operators = [ # noqa
  1223. '$dayOfYear',
  1224. '$dayOfMonth',
  1225. '$dayOfWeek',
  1226. '$year',
  1227. '$month',
  1228. '$week',
  1229. '$hour',
  1230. '$minute',
  1231. '$second',
  1232. '$millisecond',
  1233. '$dateToString']
  1234. conditional_operators = ['$cond', '$ifNull'] # noqa
  1235. def _handle_arithmetic_operator(operator, values, doc_dict):
  1236. if operator == '$abs':
  1237. return abs(_parse_expression(values, doc_dict))
  1238. elif operator == '$ceil':
  1239. return math.ceil(_parse_expression(values, doc_dict))
  1240. elif operator == '$divide':
  1241. assert len(values) == 2, 'divide must have only 2 items'
  1242. return _parse_expression(values[0], doc_dict) / _parse_expression(values[1],
  1243. doc_dict)
  1244. elif operator == '$exp':
  1245. return math.exp(_parse_expression(values, doc_dict))
  1246. elif operator == '$floor':
  1247. return math.floor(_parse_expression(values, doc_dict))
  1248. elif operator == '$ln':
  1249. return math.log(_parse_expression(values, doc_dict))
  1250. elif operator == '$log':
  1251. assert len(values) == 2, 'log must have only 2 items'
  1252. return math.log(_parse_expression(values[0], doc_dict),
  1253. _parse_expression(values[1], doc_dict))
  1254. elif operator == '$log10':
  1255. return math.log10(_parse_expression(values, doc_dict))
  1256. elif operator == '$mod':
  1257. assert len(values) == 2, 'mod must have only 2 items'
  1258. return math.fmod(_parse_expression(values[0], doc_dict),
  1259. _parse_expression(values[1], doc_dict))
  1260. elif operator == '$pow':
  1261. assert len(values) == 2, 'pow must have only 2 items'
  1262. return math.pow(_parse_expression(values[0], doc_dict),
  1263. _parse_expression(values[1], doc_dict))
  1264. elif operator == '$sqrt':
  1265. return math.sqrt(_parse_expression(values, doc_dict))
  1266. elif operator == '$subtract':
  1267. assert len(values) == 2, 'subtract must have only 2 items'
  1268. return _parse_expression(values[0], doc_dict) - _parse_expression(values[1],
  1269. doc_dict)
  1270. else:
  1271. raise NotImplementedError("Although '%s' is a valid aritmetic operator for the "
  1272. "aggregation pipeline, it is currently not implemented "
  1273. " in Mongomock." % operator)
  1274. def _handle_comparison_operator(operator, values, doc_dict):
  1275. assert len(values) == 2, 'Comparison requires two expressions'
  1276. if operator == '$eq':
  1277. return _parse_expression(values[0], doc_dict) == \
  1278. _parse_expression(values[1], doc_dict)
  1279. elif operator == '$gt':
  1280. return _parse_expression(values[0], doc_dict) > \
  1281. _parse_expression(values[1], doc_dict)
  1282. elif operator == '$gte':
  1283. return _parse_expression(values[0], doc_dict) >= \
  1284. _parse_expression(values[1], doc_dict)
  1285. elif operator == '$lt':
  1286. return _parse_expression(values[0], doc_dict) < \
  1287. _parse_expression(values[1], doc_dict)
  1288. elif operator == '$lte':
  1289. return _parse_expression(values[0], doc_dict) <= \
  1290. _parse_expression(values[1], doc_dict)
  1291. elif operator == '$ne':
  1292. return _parse_expression(values[0], doc_dict) != \
  1293. _parse_expression(values[1], doc_dict)
  1294. else:
  1295. raise NotImplementedError(
  1296. "Although '%s' is a valid comparison operator for the "
  1297. "aggregation pipeline, it is currently not implemented "
  1298. " in Mongomock." % operator)
  1299. def _handle_date_operator(operator, values, doc_dict):
  1300. out_value = _parse_expression(values, doc_dict)
  1301. if operator == '$dayOfYear':
  1302. return out_value.timetuple().tm_yday
  1303. elif operator == '$dayOfMonth':
  1304. return out_value.day
  1305. elif operator == '$dayOfWeek':
  1306. return out_value.isoweekday()
  1307. elif operator == '$year':
  1308. return out_value.year
  1309. elif operator == '$month':
  1310. return out_value.month
  1311. elif operator == '$week':
  1312. return out_value.isocalendar()[1]
  1313. elif operator == '$hour':
  1314. return out_value.hour
  1315. elif operator == '$minute':
  1316. return out_value.minute
  1317. elif operator == '$second':
  1318. return out_value.second
  1319. elif operator == '$millisecond':
  1320. return int(out_value.microsecond / 1000)
  1321. else:
  1322. raise NotImplementedError(
  1323. "Although '%s' is a valid date operator for the "
  1324. "aggregation pipeline, it is currently not implemented "
  1325. " in Mongomock." % operator)
  1326. def _handle_array_operator(operator, values, doc_dict):
  1327. out_value = _parse_expression(values, doc_dict)
  1328. if operator == '$size':
  1329. return len(out_value)
  1330. else:
  1331. raise NotImplementedError(
  1332. "Although '%s' is a valid date operator for the "
  1333. "aggregation pipeline, it is currently not implemented "
  1334. " in Mongomock." % operator)
  1335. def _handle_conditional_operator(operator, values, doc_dict):
  1336. if operator == '$ifNull':
  1337. field, fallback = values
  1338. try:
  1339. out_value = _parse_expression(field, doc_dict)
  1340. except KeyError:
  1341. return fallback
  1342. return out_value if out_value is not None else fallback
  1343. else:
  1344. raise NotImplementedError(
  1345. "Although '%s' is a valid date operator for the "
  1346. "aggregation pipeline, it is currently not implemented "
  1347. " in Mongomock." % operator)
  1348. def _handle_project_operator(operator, values, doc_dict):
  1349. if operator == '$min':
  1350. if len(values) > 2:
  1351. raise NotImplementedError("Although %d is a valid amount of elements in "
  1352. "aggregation pipeline, it is currently not "
  1353. " implemented in Mongomock" % len(values))
  1354. return min(_parse_expression(values[0], doc_dict),
  1355. _parse_expression(values[1], doc_dict))
  1356. elif operator == '$arrayElemAt':
  1357. key, index = values
  1358. array = _parse_basic_expression(key, doc_dict)
  1359. v = array[index]
  1360. return v
  1361. else:
  1362. raise NotImplementedError("Although '%s' is a valid project operator for the "
  1363. "aggregation pipeline, it is currently not implemented "
  1364. "in Mongomock." % operator)
  1365. def _parse_basic_expression(expression, doc_dict):
  1366. if isinstance(expression, str) and expression.startswith('$'):
  1367. get_value = helpers.embedded_item_getter(expression.replace('$', ''))
  1368. return get_value(doc_dict)
  1369. else:
  1370. return expression
  1371. def _parse_expression(expression, doc_dict):
  1372. if not isinstance(expression, dict):
  1373. return _parse_basic_expression(expression, doc_dict)
  1374. value_dict = {}
  1375. for k, v in iteritems(expression):
  1376. if k in arithmetic_operators:
  1377. return _handle_arithmetic_operator(k, v, doc_dict)
  1378. elif k in project_operators:
  1379. return _handle_project_operator(k, v, doc_dict)
  1380. elif k in comparison_operators:
  1381. return _handle_comparison_operator(k, v, doc_dict)
  1382. elif k in date_operators:
  1383. return _handle_date_operator(k, v, doc_dict)
  1384. elif k in array_operators:
  1385. return _handle_array_operator(k, v, doc_dict)
  1386. elif k in conditional_operators:
  1387. return _handle_conditional_operator(k, v, doc_dict)
  1388. else:
  1389. value_dict[k] = _parse_expression(v, doc_dict)
  1390. return value_dict
  1391. def _extend_collection(out_collection, field, expression):
  1392. field_exists = False
  1393. for doc in out_collection:
  1394. if field in doc:
  1395. field_exists = True
  1396. break
  1397. if not field_exists:
  1398. for doc in out_collection:
  1399. if isinstance(expression, str) and expression.startswith('$'):
  1400. try:
  1401. doc[field] = get_value_by_dot(doc, expression.lstrip('$'))
  1402. except KeyError:
  1403. pass
  1404. else:
  1405. # verify expression has operator as first
  1406. doc[field] = _parse_expression(expression.copy(), doc)
  1407. return out_collection
  1408. out_collection = [doc for doc in self.find()]
  1409. for stage in pipeline:
  1410. for k, v in iteritems(stage):
  1411. if k == '$match':
  1412. out_collection = [doc for doc in out_collection
  1413. if filter_applies(v, doc)]
  1414. elif k == '$group':
  1415. grouped_collection = []
  1416. _id = stage['$group']['_id']
  1417. if _id:
  1418. key_getter = functools.partial(_parse_expression, _id)
  1419. out_collection = sorted(out_collection, key=key_getter)
  1420. grouped = itertools.groupby(out_collection, key_getter)
  1421. else:
  1422. grouped = [(None, out_collection)]
  1423. for doc_id, group in grouped:
  1424. group_list = ([x for x in group])
  1425. doc_dict = {'_id': doc_id}
  1426. for field, value in iteritems(v):
  1427. if field == '_id':
  1428. continue
  1429. for operator, key in iteritems(value):
  1430. if operator in (
  1431. "$sum",
  1432. "$avg",
  1433. "$min",
  1434. "$max",
  1435. "$first",
  1436. "$last",
  1437. "$addToSet",
  1438. '$push'
  1439. ):
  1440. key_getter = functools.partial(_parse_expression, key)
  1441. values = [key_getter(doc) for doc in group_list]
  1442. if operator == "$sum":
  1443. val_it = (val or 0 for val in values)
  1444. doc_dict[field] = sum(val_it)
  1445. elif operator == "$avg":
  1446. values = [val or 0 for val in values]
  1447. doc_dict[field] = sum(values) / max(len(values), 1)
  1448. elif operator == "$min":
  1449. val_it = (val or MAXSIZE for val in values)
  1450. doc_dict[field] = min(val_it)
  1451. elif operator == "$max":
  1452. val_it = (val or -MAXSIZE for val in values)
  1453. doc_dict[field] = max(val_it)
  1454. elif operator == "$first":
  1455. doc_dict[field] = values[0]
  1456. elif operator == "$last":
  1457. doc_dict[field] = values[-1]
  1458. elif operator == "$addToSet":
  1459. val_it = (val or None for val in values)
  1460. doc_dict[field] = set(val_it)
  1461. elif operator == '$push':
  1462. if field not in doc_dict:
  1463. doc_dict[field] = []
  1464. doc_dict[field].extend(values)
  1465. else:
  1466. if operator in group_operators:
  1467. raise NotImplementedError(
  1468. "Although %s is a valid group operator for the "
  1469. "aggregation pipeline, it is currently not implemented "
  1470. "in Mongomock." % operator)
  1471. else:
  1472. raise NotImplementedError(
  1473. "%s is not a valid group operator for the aggregation "
  1474. "pipeline. See http://docs.mongodb.org/manual/meta/"
  1475. "aggregation-quick-reference/ for a complete list of "
  1476. "valid operators." % operator)
  1477. grouped_collection.append(doc_dict)
  1478. out_collection = grouped_collection
  1479. elif k == '$sort':
  1480. sort_array = []
  1481. for x, y in v.items():
  1482. sort_array.append({x: y})
  1483. for sort_pair in reversed(sort_array):
  1484. for sortKey, sortDirection in sort_pair.items():
  1485. out_collection = sorted(
  1486. out_collection,
  1487. key=lambda x: _resolve_sort_key(sortKey, x),
  1488. reverse=sortDirection < 0)
  1489. elif k == '$skip':
  1490. out_collection = out_collection[v:]
  1491. elif k == '$limit':
  1492. out_collection = out_collection[:v]
  1493. elif k == '$unwind':
  1494. if not isinstance(v, helpers.basestring) or v[0] != '$':
  1495. raise ValueError(
  1496. "$unwind failed: exception: field path references must be prefixed "
  1497. "with a '$' '%s'" % v)
  1498. unwound_collection = []
  1499. for doc in out_collection:
  1500. array_value = get_value_by_dot(doc, v[1:])
  1501. if array_value in (None, []):
  1502. continue
  1503. elif not isinstance(array_value, list):
  1504. raise TypeError(
  1505. '$unwind must specify an array field, field: '
  1506. '"%s", value found: %s' % (v, array_value))
  1507. for field_item in array_value:
  1508. unwound_collection.append(copy.deepcopy(doc))
  1509. unwound_collection[-1] = set_value_by_dot(
  1510. unwound_collection[-1], v[1:], field_item)
  1511. out_collection = unwound_collection
  1512. elif k == '$project':
  1513. filter_list = ['_id']
  1514. for field, value in iteritems(v):
  1515. if field == '_id' and not value:
  1516. filter_list.remove('_id')
  1517. elif value:
  1518. filter_list.append(field)
  1519. out_collection = _extend_collection(out_collection, field, value)
  1520. out_collection = [{k: v for (k, v) in x.items() if k in filter_list}
  1521. for x in out_collection]
  1522. elif k == '$out':
  1523. # TODO(MetrodataTeam): should leave the origin collection unchanged
  1524. collection = self.database.get_collection(v)
  1525. if collection.count() > 0:
  1526. collection.drop()
  1527. collection.insert_many(out_collection)
  1528. else:
  1529. if k in pipeline_operators:
  1530. raise NotImplementedError(
  1531. "Although '%s' is a valid operator for the aggregation pipeline, it is "
  1532. "currently not implemented in Mongomock." % k)
  1533. else:
  1534. raise NotImplementedError(
  1535. "%s is not a valid operator for the aggregation pipeline. "
  1536. "See http://docs.mongodb.org/manual/meta/aggregation-quick-reference/ "
  1537. "for a complete list of valid operators." % k)
  1538. return CommandCursor(out_collection)
  1539. def with_options(
  1540. self, codec_options=None, read_preference=None, write_concern=None, read_concern=None):
  1541. return self
  1542. def rename(self, new_name, **kwargs):
  1543. self.database.rename_collection(self.name, new_name, **kwargs)
  1544. def bulk_write(self, operations):
  1545. bulk = BulkOperationBuilder(self)
  1546. for operation in operations:
  1547. operation._add_to_bulk(bulk)
  1548. return BulkWriteResult(bulk.execute(), True)
  1549. def _resolve_key(key, doc):
  1550. return next(iter(iter_key_candidates(key, doc)), NOTHING)
  1551. def _resolve_sort_key(key, doc):
  1552. value = _resolve_key(key, doc)
  1553. # see http://docs.mongodb.org/manual/reference/method/cursor.sort/#ascending-descending-sort
  1554. if value is NOTHING:
  1555. return 0, value
  1556. return 1, value
  1557. class Cursor(object):
  1558. def __init__(self, collection, spec=None, sort=None, projection=None, skip=0, limit=0,
  1559. collation=None):
  1560. super(Cursor, self).__init__()
  1561. self.collection = collection
  1562. spec = helpers.patch_datetime_awareness_in_document(spec)
  1563. self._spec = spec
  1564. self._sort = sort
  1565. self._projection = projection
  1566. self._skip = skip
  1567. self._factory_last_generated_results = None
  1568. self._results = None
  1569. self._factory = functools.partial(collection._get_dataset,
  1570. spec, sort, projection, dict)
  1571. # pymongo limit defaults to 0, returning everything
  1572. self._limit = limit if limit != 0 else None
  1573. self._collation = collation
  1574. self.rewind()
  1575. def _compute_results(self, with_limit_and_skip=False):
  1576. # Recompute the result only if the query has changed
  1577. if not self._results or self._factory_last_generated_results != self._factory:
  1578. if self.collection.database.client._tz_aware:
  1579. results = [helpers.make_datetime_timezone_aware_in_document(x)
  1580. for x in self._factory()]
  1581. else:
  1582. results = list(self._factory())
  1583. self._factory_last_generated_results = self._factory
  1584. self._results = results
  1585. if with_limit_and_skip:
  1586. results = self._results[self._skip:]
  1587. if self._limit:
  1588. results = results[:self._limit]
  1589. else:
  1590. results = self._results
  1591. return results
  1592. def __iter__(self):
  1593. return self
  1594. def clone(self):
  1595. cursor = Cursor(self.collection,
  1596. self._spec, self._sort, self._projection, self._skip, self._limit)
  1597. cursor._factory = self._factory
  1598. return cursor
  1599. def __next__(self):
  1600. try:
  1601. doc = self._compute_results(with_limit_and_skip=True)[self._emitted]
  1602. self._emitted += 1
  1603. return doc
  1604. except IndexError:
  1605. raise StopIteration()
  1606. next = __next__
  1607. def rewind(self):
  1608. self._emitted = 0
  1609. def sort(self, key_or_list, direction=None):
  1610. if direction is None:
  1611. direction = 1
  1612. def _make_sort_factory_layer(upper_factory, sortKey, sortDirection):
  1613. def layer():
  1614. return sorted(upper_factory(), key=lambda x: _resolve_sort_key(sortKey, x),
  1615. reverse=sortDirection < 0)
  1616. return layer
  1617. if isinstance(key_or_list, (tuple, list)):
  1618. for sortKey, sortDirection in reversed(key_or_list):
  1619. self._factory = _make_sort_factory_layer(self._factory, sortKey, sortDirection)
  1620. else:
  1621. self._factory = _make_sort_factory_layer(self._factory, key_or_list, direction)
  1622. return self
  1623. def count(self, with_limit_and_skip=False):
  1624. results = self._compute_results(with_limit_and_skip)
  1625. return len(results)
  1626. def skip(self, count):
  1627. self._skip = count
  1628. return self
  1629. def limit(self, count):
  1630. self._limit = count if count != 0 else None
  1631. return self
  1632. def batch_size(self, count):
  1633. return self
  1634. def close(self):
  1635. pass
  1636. def distinct(self, key):
  1637. if not isinstance(key, helpers.basestring):
  1638. raise TypeError('cursor.distinct key must be a string')
  1639. unique = set()
  1640. unique_dict_vals = []
  1641. for x in self._compute_results():
  1642. value = _resolve_key(key, x)
  1643. if value == NOTHING:
  1644. continue
  1645. if isinstance(value, dict):
  1646. if any(dict_val == value for dict_val in unique_dict_vals):
  1647. continue
  1648. unique_dict_vals.append(value)
  1649. else:
  1650. unique.update(
  1651. value if isinstance(
  1652. value, (tuple, list)) else [value])
  1653. return list(unique) + unique_dict_vals
  1654. def __getitem__(self, index):
  1655. if isinstance(index, slice):
  1656. if index.step is not None:
  1657. raise IndexError("Cursor instances do not support slice steps")
  1658. skip = 0
  1659. if index.start is not None:
  1660. if index.start < 0:
  1661. raise IndexError("Cursor instances do not support"
  1662. "negative indices")
  1663. skip = index.start
  1664. if index.stop is not None:
  1665. limit = index.stop - skip
  1666. if limit < 0:
  1667. raise IndexError("stop index must be greater than start"
  1668. "index for slice %r" % index)
  1669. if limit == 0:
  1670. self.__empty = True
  1671. else:
  1672. limit = 0
  1673. self._skip = skip
  1674. self._limit = limit
  1675. return self
  1676. elif not isinstance(index, int):
  1677. raise TypeError("index '%s' cannot be applied to Cursor instances" % index)
  1678. elif index < 0:
  1679. raise IndexError('Cursor instances do not support negativeindices')
  1680. else:
  1681. return self._compute_results(with_limit_and_skip=True)[index]
  1682. def __enter__(self):
  1683. return self
  1684. def __exit__(self, exc_type, exc_val, exc_tb):
  1685. self.close()
  1686. def _set_updater(doc, field_name, value):
  1687. if isinstance(value, (tuple, list)):
  1688. value = copy.deepcopy(value)
  1689. if isinstance(doc, dict):
  1690. doc[field_name] = value
  1691. def _unset_updater(doc, field_name, value):
  1692. if isinstance(doc, dict):
  1693. doc.pop(field_name, None)
  1694. def _inc_updater(doc, field_name, value):
  1695. if isinstance(doc, dict):
  1696. doc[field_name] = doc.get(field_name, 0) + value
  1697. def _max_updater(doc, field_name, value):
  1698. if isinstance(doc, dict):
  1699. doc[field_name] = max(doc.get(field_name, value), value)
  1700. def _min_updater(doc, field_name, value):
  1701. if isinstance(doc, dict):
  1702. doc[field_name] = min(doc.get(field_name, value), value)
  1703. def _sum_updater(doc, field_name, current, result):
  1704. if isinstance(doc, dict):
  1705. result = current + doc.get[field_name, 0]
  1706. return result
  1707. def _current_date_updater(doc, field_name, value):
  1708. if isinstance(doc, dict):
  1709. doc[field_name] = datetime.utcnow()
# Dispatch table mapping MongoDB update operators to their handler functions.
# NOTE(review): _sum_updater and _current_date_updater are defined above but
# not registered here — presumably dispatched elsewhere; verify against callers.
_updaters = {
    '$set': _set_updater,
    '$unset': _unset_updater,
    '$inc': _inc_updater,
    '$max': _max_updater,
    '$min': _min_updater,
}