- from __future__ import division
- import collections
- from collections import OrderedDict
- import copy
- from datetime import datetime
- import functools
- import itertools
- import json
- import math
- import threading
- import time
- import types
- import warnings
- try:
- from bson import json_util, SON, BSON
- except ImportError:
- json_util = SON = BSON = None
- try:
- import execjs
- except ImportError:
- execjs = None
- try:
- from pymongo import ReturnDocument
- except ImportError:
- class ReturnDocument(object):
- BEFORE = False
- AFTER = True
- from sentinels import NOTHING
- from six import iteritems
- from six import iterkeys
- from six import itervalues
- from six import MAXSIZE
- from six import string_types
- from six import text_type
- from mongomock.command_cursor import CommandCursor
- from mongomock import DuplicateKeyError, BulkWriteError
- from mongomock.filtering import filter_applies
- from mongomock.filtering import iter_key_candidates
- from mongomock import helpers
- from mongomock import InvalidOperation
- from mongomock import ObjectId
- from mongomock import OperationFailure
- from mongomock.results import BulkWriteResult
- from mongomock.results import DeleteResult
- from mongomock.results import InsertManyResult
- from mongomock.results import InsertOneResult
- from mongomock.results import UpdateResult
- from mongomock.write_concern import WriteConcern
- from mongomock import WriteError
- lock = threading.RLock()
- def validate_is_mapping(option, value):
- if not isinstance(value, collections.Mapping):
- raise TypeError('%s must be an instance of dict, bson.son.SON, or '
- 'other type that inherits from '
- 'collections.Mapping' % (option,))
- def validate_is_mutable_mapping(option, value):
- if not isinstance(value, collections.MutableMapping):
- raise TypeError('%s must be an instance of dict, bson.son.SON, or '
- 'other type that inherits from '
- 'collections.MutableMapping' % (option,))
- def validate_ok_for_replace(replacement):
- validate_is_mapping('replacement', replacement)
- if replacement:
- first = next(iter(replacement))
- if first.startswith('$'):
- raise ValueError('replacement can not include $ operators')
- def validate_ok_for_update(update):
- validate_is_mapping('update', update)
- if not update:
- raise ValueError('update only works with $ operators')
- first = next(iter(update))
- if not first.startswith('$'):
- raise ValueError('update only works with $ operators')
- def validate_write_concern_params(**params):
- if params:
- WriteConcern(**params)
- def get_value_by_dot(doc, key):
- """Get dictionary value using dotted key"""
- result = doc
- for i in key.split('.'):
- result = result[i]
- return result
- def set_value_by_dot(doc, key, value):
- """Set dictionary value using dotted key"""
- result = doc
- keys = key.split('.')
- for i in keys[:-1]:
- if i not in result:
- result[i] = {}
- result = result[i]
- result[keys[-1]] = value
- return doc
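- # Illustrative sketch (not part of the original module): how the dotted-key
- # helpers above behave on a plain nested dict.
- #
- #   doc = {'a': {'b': 1}}
- #   get_value_by_dot(doc, 'a.b')      # -> 1
- #   set_value_by_dot(doc, 'a.c', 2)   # -> {'a': {'b': 1, 'c': 2}}
- #
- # get_value_by_dot raises KeyError on a missing path, while set_value_by_dot
- # creates the intermediate dicts it needs.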
- class BulkWriteOperation(object):
- def __init__(self, builder, selector, is_upsert=False):
- self.builder = builder
- self.selector = selector
- self.is_upsert = is_upsert
- def upsert(self):
- assert not self.is_upsert
- return BulkWriteOperation(self.builder, self.selector, is_upsert=True)
- def register_remove_op(self, multi):
- collection = self.builder.collection
- selector = self.selector
- def exec_remove():
- op_result = collection.remove(selector, multi=multi)
- if op_result.get("ok"):
- return {'nRemoved': op_result.get('n')}
- err = op_result.get("err")
- if err:
- return {"writeErrors": [err]}
- return {}
- self.builder.executors.append(exec_remove)
- def remove(self):
- assert not self.is_upsert
- self.register_remove_op(multi=True)
- def remove_one(self):
- assert not self.is_upsert
- self.register_remove_op(multi=False)
- def register_update_op(self, document, multi, **extra_args):
- if not extra_args.get("remove"):
- validate_ok_for_update(document)
- collection = self.builder.collection
- selector = self.selector
- def exec_update():
- result = collection._update(spec=selector, document=document,
- multi=multi, upsert=self.is_upsert,
- **extra_args)
- ret_val = {}
- if result.get('upserted'):
- ret_val["upserted"] = result.get('upserted')
- ret_val["nUpserted"] = result.get('n')
- modified = result.get('nModified')
- if modified is not None:
- ret_val['nModified'] = modified
- ret_val['nMatched'] = modified
- if result.get('err'):
- ret_val['err'] = result.get('err')
- return ret_val
- self.builder.executors.append(exec_update)
- def update(self, document):
- self.register_update_op(document, multi=True)
- def update_one(self, document):
- self.register_update_op(document, multi=False)
- def replace_one(self, document):
- self.register_update_op(document, multi=False, remove=True)
- class BulkOperationBuilder(object):
- def __init__(self, collection, ordered=False):
- self.collection = collection
- self.ordered = ordered
- self.results = {}
- self.executors = []
- self.done = False
- self._insert_returns_nModified = True
- self._update_returns_nModified = True
- def find(self, selector):
- return BulkWriteOperation(self, selector)
- def insert(self, doc):
- def exec_insert():
- self.collection.insert(doc)
- return {'nInserted': 1}
- self.executors.append(exec_insert)
- def __aggregate_operation_result(self, total_result, key, value):
- agg_val = total_result.get(key)
- assert agg_val is not None, "Unknow operation result %s=%s" \
- " (unrecognized key)" % (key, value)
- if isinstance(agg_val, int):
- total_result[key] += value
- elif isinstance(agg_val, list):
- if key == "upserted":
- new_element = {"index": len(agg_val), "_id": value}
- agg_val.append(new_element)
- else:
- agg_val.append(value)
- else:
- assert False, "Fixme: missed aggreation rule for type: %s for" \
- " key {%s=%s}" % (type(agg_val), key, agg_val)
- def _set_nModified_policy(self, insert, update):
- self._insert_returns_nModified = insert
- self._update_returns_nModified = update
- def execute(self, write_concern=None):
- if not self.executors:
- raise InvalidOperation("Bulk operation empty!")
- if self.done:
- raise InvalidOperation("Bulk operation already executed!")
- self.done = True
- result = {'nModified': 0, 'nUpserted': 0, 'nMatched': 0,
- 'writeErrors': [], 'upserted': [], 'writeConcernErrors': [],
- 'nRemoved': 0, 'nInserted': 0}
- has_update = False
- has_insert = False
- broken_nModified_info = False
- for execute_func in self.executors:
- exec_name = execute_func.__name__
- op_result = execute_func()
- for (key, value) in op_result.items():
- self.__aggregate_operation_result(result, key, value)
- if exec_name == "exec_update":
- has_update = True
- if "nModified" not in op_result:
- broken_nModified_info = True
- has_insert |= exec_name == "exec_insert"
- if broken_nModified_info:
- result.pop('nModified')
- elif has_insert and self._insert_returns_nModified:
- pass
- elif has_update and self._update_returns_nModified:
- pass
- elif self._update_returns_nModified and self._insert_returns_nModified:
- pass
- else:
- result.pop('nModified')
- return result
- def add_insert(self, doc):
- self.insert(doc)
- def add_update(self, selector, doc, multi, upsert, collation=None):
- write_operation = BulkWriteOperation(self, selector, is_upsert=upsert)
- write_operation.register_update_op(doc, multi)
- def add_replace(self, selector, doc, upsert, collation=None):
- write_operation = BulkWriteOperation(self, selector, is_upsert=upsert)
- write_operation.replace_one(doc)
- def add_delete(self, selector, just_one, collation=None):
- write_operation = BulkWriteOperation(self, selector, is_upsert=False)
- write_operation.register_remove_op(not just_one)
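- # Usage sketch for the bulk API above (illustrative only; `collection` is any
- # mongomock Collection):
- #
- #   bulk = collection.initialize_unordered_bulk_op()
- #   bulk.insert({'x': 1})
- #   bulk.find({'x': 1}).update({'$set': {'y': 2}})
- #   bulk.find({'x': 99}).upsert().update_one({'$set': {'y': 3}})
- #   result = bulk.execute()
- #
- # `result` aggregates the per-op dicts returned by the executors, e.g. keys
- # such as 'nInserted', 'nMatched', 'nModified' and 'upserted'.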
- class Collection(object):
- def __init__(self, db, name):
- self.name = name
- self.full_name = "{0}.{1}".format(db.name, name)
- self.database = db
- self._documents = OrderedDict()
- self._force_created = False
- self._uniques = []
- def _is_created(self):
- return self._documents or self._uniques or self._force_created
- def __repr__(self):
- return "Collection({0}, '{1}')".format(self.database, self.name)
- def __getitem__(self, name):
- return self.database[self.name + '.' + name]
- def __getattr__(self, name):
- return self.__getitem__(name)
- def initialize_unordered_bulk_op(self):
- return BulkOperationBuilder(self, ordered=False)
- def initialize_ordered_bulk_op(self):
- return BulkOperationBuilder(self, ordered=True)
- def insert(self, data, manipulate=True, check_keys=True,
- continue_on_error=False, **kwargs):
- warnings.warn("insert is deprecated. Use insert_one or insert_many "
- "instead.", DeprecationWarning, stacklevel=2)
- validate_write_concern_params(**kwargs)
- return self._insert(data)
- def insert_one(self, document):
- validate_is_mutable_mapping('document', document)
- return InsertOneResult(self._insert(document), acknowledged=True)
- def insert_many(self, documents, ordered=True):
- if not isinstance(documents, collections.Iterable) or not documents:
- raise TypeError('documents must be a non-empty list')
- for document in documents:
- validate_is_mutable_mapping('document', document)
- try:
- return InsertManyResult(self._insert(documents), acknowledged=True)
- except DuplicateKeyError:
- raise BulkWriteError('batch op errors occurred')
- def _insert(self, data):
- if isinstance(data, list) or isinstance(data, types.GeneratorType):
- return [self._insert(item) for item in data]
- # Like pymongo, we fill in the _id of the inserted dict in place (odd
- # behavior, but we must stick to it), so we patch the data dict in-place
- for key in data.keys():
- data[key] = helpers.patch_datetime_awareness_in_document(data[key])
- if not all(isinstance(k, string_types) for k in data):
- raise ValueError("Document keys must be strings")
- if BSON:
- # bson validation
- BSON.encode(data, check_keys=True)
- if '_id' not in data:
- data['_id'] = ObjectId()
- object_id = data['_id']
- if isinstance(object_id, dict):
- object_id = helpers.hashdict(object_id)
- if object_id in self._documents:
- raise DuplicateKeyError("E11000 Duplicate Key Error", 11000)
- for unique, is_sparse in self._uniques:
- find_kwargs = {}
- for key, direction in unique:
- find_kwargs[key] = data.get(key, None)
- answer_count = len(list(self._iter_documents(find_kwargs)))
- if answer_count > 0 and not (is_sparse and find_kwargs[key] is None):
- raise DuplicateKeyError("E11000 Duplicate Key Error", 11000)
- with lock:
- self._documents[object_id] = self._internalize_dict(data)
- return data['_id']
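- # Behavior sketch for _insert (illustrative): as noted above, the _id is
- # patched into the caller's document in place, and reusing it is a duplicate.
- #
- #   doc = {'x': 1}
- #   collection.insert_one(doc)
- #   doc['_id']                  # now an ObjectId
- #   collection.insert_one(doc)  # raises DuplicateKeyError (same _id)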
- def _internalize_dict(self, d):
- return {k: copy.deepcopy(v) for k, v in iteritems(d)}
- def _has_key(self, doc, key):
- key_parts = key.split('.')
- sub_doc = doc
- for part in key_parts:
- if part not in sub_doc:
- return False
- sub_doc = sub_doc[part]
- return True
- def _remove_key(self, doc, key):
- key_parts = key.split('.')
- sub_doc = doc
- for part in key_parts[:-1]:
- sub_doc = sub_doc[part]
- del sub_doc[key_parts[-1]]
- def update_one(self, filter, update, upsert=False):
- validate_ok_for_update(update)
- return UpdateResult(self._update(filter, update, upsert=upsert),
- acknowledged=True)
- def update_many(self, filter, update, upsert=False):
- validate_ok_for_update(update)
- return UpdateResult(self._update(filter, update, upsert=upsert,
- multi=True),
- acknowledged=True)
- def replace_one(self, filter, replacement, upsert=False):
- validate_ok_for_replace(replacement)
- return UpdateResult(self._update(filter, replacement, upsert=upsert),
- acknowledged=True)
- def update(self, spec, document, upsert=False, manipulate=False,
- multi=False, check_keys=False, **kwargs):
- warnings.warn("update is deprecated. Use replace_one, update_one or "
- "update_many instead.", DeprecationWarning, stacklevel=2)
- return self._update(spec, document, upsert, manipulate, multi,
- check_keys, **kwargs)
- def _update(self, spec, document, upsert=False, manipulate=False,
- multi=False, check_keys=False, **kwargs):
- spec = helpers.patch_datetime_awareness_in_document(spec)
- document = helpers.patch_datetime_awareness_in_document(document)
- validate_is_mapping('spec', spec)
- validate_is_mapping('document', document)
- updated_existing = False
- upserted_id = None
- num_updated = 0
- for existing_document in itertools.chain(self._iter_documents(spec), [None]):
- # we need was_insert for the setOnInsert update operation
- was_insert = False
- # the sentinel document means we should do an upsert
- if existing_document is None:
- if not upsert or num_updated:
- continue
- # For an upsert we first have to create a fake existing_document,
- # update it like a regular one, then finally insert it
- if spec.get('_id') is not None:
- _id = spec['_id']
- elif document.get('_id') is not None:
- _id = document['_id']
- else:
- _id = ObjectId()
- to_insert = dict(spec, _id=_id)
- to_insert = self._expand_dots(to_insert)
- existing_document = to_insert
- was_insert = True
- else:
- updated_existing = True
- num_updated += 1
- first = True
- subdocument = None
- for k, v in iteritems(document):
- if k in _updaters.keys():
- updater = _updaters[k]
- subdocument = self._update_document_fields_with_positional_awareness(
- existing_document, v, spec, updater, subdocument)
- elif k == '$setOnInsert':
- if not was_insert:
- continue
- subdocument = self._update_document_fields_with_positional_awareness(
- existing_document, v, spec, _set_updater, subdocument)
- elif k == '$currentDate':
- for value in itervalues(v):
- if value == {'$type': 'timestamp'}:
- raise NotImplementedError('timestamp is not supported so far')
- subdocument = self._update_document_fields_with_positional_awareness(
- existing_document, v, spec, _current_date_updater, subdocument)
- elif k == '$addToSet':
- for field, value in iteritems(v):
- nested_field_list = field.rsplit('.')
- if len(nested_field_list) == 1:
- if field not in existing_document:
- existing_document[field] = []
- # document should be a list; append to it
- if isinstance(value, dict):
- if '$each' in value:
- # append the list to the field
- existing_document[field] += [
- obj for obj in list(value['$each'])
- if obj not in existing_document[field]]
- continue
- if value not in existing_document[field]:
- existing_document[field].append(value)
- continue
- # push to array in a nested attribute
- else:
- # create nested attributes if they do not exist
- subdocument = existing_document
- for field in nested_field_list[:-1]:
- if field not in subdocument:
- subdocument[field] = {}
- subdocument = subdocument[field]
- # we're pushing a list
- push_results = []
- if nested_field_list[-1] in subdocument:
- # if the list exists, then use that list
- push_results = subdocument[
- nested_field_list[-1]]
- if isinstance(value, dict) and '$each' in value:
- push_results += [
- obj for obj in list(value['$each'])
- if obj not in push_results]
- elif value not in push_results:
- push_results.append(value)
- subdocument[nested_field_list[-1]] = push_results
- elif k == '$pull':
- for field, value in iteritems(v):
- nested_field_list = field.rsplit('.')
- # the nested fields include a positional element;
- # we need to find that element
- if '$' in nested_field_list:
- if not subdocument:
- subdocument = self._get_subdocument(
- existing_document, spec, nested_field_list)
- # value should be a dictionary since we're pulling
- pull_results = []
- # and the last subdoc should be an array
- for obj in subdocument[nested_field_list[-1]]:
- if isinstance(obj, dict):
- for pull_key, pull_value in iteritems(value):
- if obj[pull_key] != pull_value:
- pull_results.append(obj)
- continue
- if obj != value:
- pull_results.append(obj)
- # cannot write to doc directly as it doesn't save to
- # existing_document
- subdocument[nested_field_list[-1]] = pull_results
- else:
- arr = existing_document
- for field in nested_field_list:
- if field not in arr:
- break
- arr = arr[field]
- if not isinstance(arr, list):
- continue
- arr_copy = copy.deepcopy(arr)
- if isinstance(value, dict):
- for obj in arr_copy:
- if filter_applies(value, obj):
- arr.remove(obj)
- else:
- for obj in arr_copy:
- if value == obj:
- arr.remove(obj)
- elif k == '$pullAll':
- for field, value in iteritems(v):
- nested_field_list = field.rsplit('.')
- if len(nested_field_list) == 1:
- if field in existing_document:
- arr = existing_document[field]
- existing_document[field] = [
- obj for obj in arr if obj not in value]
- continue
- else:
- subdocument = existing_document
- for nested_field in nested_field_list[:-1]:
- if nested_field not in subdocument:
- break
- subdocument = subdocument[nested_field]
- if nested_field_list[-1] in subdocument:
- arr = subdocument[nested_field_list[-1]]
- subdocument[nested_field_list[-1]] = [
- obj for obj in arr if obj not in value]
- elif k == '$push':
- for field, value in iteritems(v):
- nested_field_list = field.rsplit('.')
- if len(nested_field_list) == 1:
- if field not in existing_document:
- existing_document[field] = []
- # document should be a list
- # append to it
- if isinstance(value, dict):
- if '$each' in value:
- # append the list to the field
- existing_document[field] += list(value['$each'])
- continue
- existing_document[field].append(value)
- continue
- # the nested fields include a positional element;
- # we need to find that element
- elif '$' in nested_field_list:
- if not subdocument:
- subdocument = self._get_subdocument(
- existing_document, spec, nested_field_list)
- # we're pushing a list
- push_results = []
- if nested_field_list[-1] in subdocument:
- # if the list exists, then use that list
- push_results = subdocument[nested_field_list[-1]]
- if isinstance(value, dict):
- # check to see if we have the format
- # { '$each': [] }
- if '$each' in value:
- push_results += list(value['$each'])
- else:
- push_results.append(value)
- else:
- push_results.append(value)
- # cannot write to doc directly as it doesn't save to
- # existing_document
- subdocument[nested_field_list[-1]] = push_results
- # push to array in a nested attribute
- else:
- # create nested attributes if they do not exist
- subdocument = existing_document
- for field in nested_field_list[:-1]:
- if field not in subdocument:
- subdocument[field] = {}
- subdocument = subdocument[field]
- # we're pushing a list
- push_results = []
- if nested_field_list[-1] in subdocument:
- # if the list exists, then use that list
- push_results = subdocument[nested_field_list[-1]]
- if isinstance(value, dict) and '$each' in value:
- push_results += list(value['$each'])
- else:
- push_results.append(value)
- subdocument[nested_field_list[-1]] = push_results
- else:
- if first:
- # replace entire document
- for key in document.keys():
- if key.startswith('$'):
- # can't mix modifiers with non-modifiers in
- # update
- raise ValueError('field names cannot start with $ [{}]'.format(key))
- _id = spec.get('_id', existing_document.get('_id'))
- existing_document.clear()
- if _id:
- existing_document['_id'] = _id
- existing_document.update(self._internalize_dict(document))
- if existing_document['_id'] != _id:
- raise OperationFailure(
- "The _id field cannot be changed from {0} to {1}"
- .format(existing_document['_id'], _id))
- break
- else:
- # can't mix modifiers with non-modifiers in update
- raise ValueError(
- 'Invalid modifier specified: {}'.format(k))
- first = False
- # if an empty update document was given
- if len(document) == 0:
- _id = spec.get('_id', existing_document.get('_id'))
- existing_document.clear()
- if _id:
- existing_document['_id'] = _id
- if was_insert:
- upserted_id = self._insert(existing_document)
- if not multi:
- break
- return {
- text_type("connectionId"): self.database.client._id,
- text_type("err"): None,
- text_type("n"): num_updated,
- text_type("nModified"): num_updated if updated_existing else 0,
- text_type("ok"): 1,
- text_type("upserted"): upserted_id,
- text_type("updatedExisting"): updated_existing,
- }
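- # Result sketch (illustrative): _update returns a legacy-style result document,
- # which UpdateResult then wraps. For a matched $set on one document:
- #
- #   collection._update({'x': 1}, {'$set': {'y': 2}})
- #   # -> {'connectionId': ..., 'err': None, 'n': 1, 'nModified': 1,
- #   #     'ok': 1, 'upserted': None, 'updatedExisting': True}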
- def _get_subdocument(self, existing_document, spec, nested_field_list):
- """This method retrieves the subdocument of the existing_document.nested_field_list.
- It uses the spec to filter through the items. It will continue to grab nested documents
- until it can go no further. It will then return the subdocument that was last saved.
- '$' is the positional operator, so we use the $elemMatch in the spec to find the right
- subdocument in the array.
- """
- # current document in view
- doc = existing_document
- # previous document in view
- subdocument = existing_document
- # current spec in view
- subspec = spec
- # walk down the dictionary
- for subfield in nested_field_list:
- if subfield == '$':
- # positional element should have the equivalent elemMatch in the
- # query
- subspec = subspec['$elemMatch']
- for item in doc:
- # iterate through
- if filter_applies(subspec, item):
- # found the matching item save the parent
- subdocument = doc
- # save the item
- doc = item
- break
- continue
- subdocument = doc
- doc = doc[subfield]
- if subfield not in subspec:
- break
- subspec = subspec[subfield]
- return subdocument
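- # Positional-update sketch (illustrative): the spec must identify the array
- # element that '$' refers to, either via a dotted field or an $elemMatch.
- #
- #   collection.insert_one({'items': [{'k': 'a', 'v': 1}, {'k': 'b', 'v': 2}]})
- #   collection.update_one({'items.k': 'b'}, {'$set': {'items.$.v': 9}})
- #   # -> the element {'k': 'b', ...} now has v == 9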
- def _expand_dots(self, doc):
- expanded = {}
- paths = {}
- for k, v in iteritems(doc):
- key_parts = k.split('.')
- sub_doc = v
- for i in reversed(range(1, len(key_parts))):
- key = key_parts[i]
- sub_doc = {key: sub_doc}
- key = key_parts[0]
- if key in expanded:
- raise WriteError("cannot infer query fields to set, "
- "both paths '%s' and '%s' are matched"
- % (k, paths[key]))
- paths[key] = k
- expanded[key] = sub_doc
- return expanded
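- # Expansion sketch (illustrative):
- #
- #   collection._expand_dots({'a.b': 1, 'c': 2})
- #   # -> {'a': {'b': 1}, 'c': 2}
- #   collection._expand_dots({'a.b': 1, 'a.c': 2})
- #   # raises WriteError: both paths write into 'a'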
- def _discard_operators(self, doc):
- # TODO(this looks a little too naive...)
- return {k: v for k, v in iteritems(doc) if not k.startswith("$")}
- def find(self, filter=None, projection=None, skip=0, limit=0,
- no_cursor_timeout=False, cursor_type=None, sort=None,
- allow_partial_results=False, oplog_replay=False, modifiers=None,
- batch_size=0, manipulate=True, collation=None):
- spec = filter
- if spec is None:
- spec = {}
- validate_is_mapping('filter', spec)
- return Cursor(self, spec, sort, projection, skip, limit, collation=collation)
- def _get_dataset(self, spec, sort, fields, as_class):
- dataset = (self._copy_only_fields(document, fields, as_class)
- for document in self._iter_documents(spec))
- if sort:
- for sortKey, sortDirection in reversed(sort):
- dataset = iter(sorted(
- dataset, key=lambda x: _resolve_sort_key(sortKey, x),
- reverse=sortDirection < 0))
- return dataset
- def _copy_field(self, obj, container):
- if isinstance(obj, list):
- new = []
- for item in obj:
- new.append(self._copy_field(item, container))
- return new
- if isinstance(obj, dict):
- new = container()
- for key, value in obj.items():
- new[key] = self._copy_field(value, container)
- return new
- else:
- return copy.copy(obj)
- def _extract_projection_operators(self, fields):
- """Removes and returns fields with projection operators."""
- result = {}
- allowed_projection_operators = {'$elemMatch'}
- for key, value in iteritems(fields):
- if isinstance(value, dict):
- for op in value:
- if op not in allowed_projection_operators:
- raise ValueError('Unsupported projection option: {}'.format(op))
- result[key] = value
- for key in result:
- del fields[key]
- return result
- def _apply_projection_operators(self, ops, doc, doc_copy):
- """Applies projection operators to copied document."""
- for field, op in iteritems(ops):
- if field not in doc_copy:
- if field in doc:
- # field was not copied yet (since we are in include mode)
- doc_copy[field] = doc[field]
- else:
- # field doesn't exist in original document, no work to do
- continue
- if '$elemMatch' in op:
- if isinstance(doc_copy[field], list):
- # find the first item that matches
- matched = False
- for item in doc_copy[field]:
- if filter_applies(op['$elemMatch'], item):
- matched = True
- doc_copy[field] = [item]
- break
- # nothing has matched
- if not matched:
- del doc_copy[field]
- else:
- # remove the field since there is nothing to iterate
- del doc_copy[field]
- def _copy_only_fields(self, doc, fields, container):
- """Copy only the specified fields."""
- if fields is None:
- return self._copy_field(doc, container)
- else:
- if not fields:
- fields = {"_id": 1}
- if not isinstance(fields, dict):
- fields = helpers._fields_list_to_dict(fields)
- # we can pass in something like {"_id":0, "field":1}, so pull the id
- # value out and hang on to it until later
- id_value = fields.pop('_id', 1)
- # filter out fields with projection operators; we will take care of them later
- projection_operators = self._extract_projection_operators(fields)
- # other than the _id field, all fields must be either includes or
- # excludes; the common value may be 0 or 1
- if len(set(list(fields.values()))) > 1:
- raise ValueError(
- 'You cannot currently mix including and excluding fields.')
- # if no values were passed in, make a doc_copy based on the
- # id_value
- if len(list(fields.values())) == 0:
- if id_value == 1:
- doc_copy = container()
- else:
- doc_copy = self._copy_field(doc, container)
- # if 1 was passed in as the field values, include those fields
- elif list(fields.values())[0] == 1:
- doc_copy = container()
- for key in fields:
- key_parts = key.split('.')
- subdocument = doc
- subdocument_copy = doc_copy
- last_copy = subdocument_copy
- full_key_path_found = True
- for key_part in key_parts[:-1]:
- if key_part not in subdocument:
- full_key_path_found = False
- break
- subdocument = subdocument[key_part]
- last_copy = subdocument_copy
- subdocument_copy = subdocument_copy.setdefault(key_part, {})
- if full_key_path_found:
- last_key = key_parts[-1]
- if isinstance(subdocument, dict) and last_key in subdocument:
- subdocument_copy[last_key] = subdocument[last_key]
- elif isinstance(subdocument, (list, tuple)):
- subdocument = [{last_key: x[last_key]}
- for x in subdocument if last_key in x]
- if subdocument:
- last_copy[key_parts[-2]] = subdocument
- # otherwise, exclude the fields passed in
- else:
- doc_copy = self._copy_field(doc, container)
- for key in fields:
- key_parts = key.split('.')
- subdocument_copy = doc_copy
- full_key_path_found = True
- for key_part in key_parts[:-1]:
- if key_part not in subdocument_copy:
- full_key_path_found = False
- break
- subdocument_copy = subdocument_copy[key_part]
- if not full_key_path_found or key_parts[-1] not in subdocument_copy:
- continue
- del subdocument_copy[key_parts[-1]]
- # set the _id value if we requested it, otherwise remove it
- if id_value == 0:
- doc_copy.pop('_id', None)
- else:
- if '_id' in doc:
- doc_copy['_id'] = doc['_id']
- fields['_id'] = id_value # put _id back in fields
- # time to apply the projection operators and put back their fields
- self._apply_projection_operators(projection_operators, doc, doc_copy)
- for field, op in iteritems(projection_operators):
- fields[field] = op
- return doc_copy
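- # Projection sketch (illustrative): fields are all-include or all-exclude,
- # with _id handled separately.
- #
- #   collection.find_one({}, {'a.b': 1})        # only _id and a.b are copied
- #   collection.find_one({}, {'a': 0})          # everything except a
- #   collection.find_one({}, {'a': 1, 'b': 0})  # raises ValueError (mixed)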
- def _update_document_fields(self, doc, fields, updater):
- """Implements the $set behavior on an existing document"""
- for k, v in iteritems(fields):
- self._update_document_single_field(doc, k, v, updater)
- def _update_document_fields_positional(self, doc, fields, spec, updater,
- subdocument=None):
- """Implements the $set behavior on an existing document"""
- for k, v in iteritems(fields):
- if '$' in k:
- field_name_parts = k.split('.')
- if not subdocument:
- current_doc = doc
- subspec = spec
- for part in field_name_parts[:-1]:
- if part == '$':
- subspec = subspec.get('$elemMatch', subspec)
- for item in current_doc:
- if filter_applies(subspec, item):
- current_doc = item
- break
- continue
- new_spec = {}
- for el in subspec:
- if el.startswith(part):
- if len(el.split(".")) > 1:
- new_spec[".".join(
- el.split(".")[1:])] = subspec[el]
- else:
- new_spec = subspec[el]
- subspec = new_spec
- current_doc = current_doc[part]
- subdocument = current_doc
- if (field_name_parts[-1] == '$' and
- isinstance(subdocument, list)):
- for i, doc in enumerate(subdocument):
- if filter_applies(subspec, doc):
- subdocument[i] = v
- break
- continue
- updater(subdocument, field_name_parts[-1], v)
- continue
- # otherwise, we handle it the standard way
- self._update_document_single_field(doc, k, v, updater)
- return subdocument
- def _update_document_fields_with_positional_awareness(self, existing_document, v, spec,
- updater, subdocument):
- positional = any('$' in key for key in iterkeys(v))
- if positional:
- return self._update_document_fields_positional(
- existing_document, v, spec, updater, subdocument)
- self._update_document_fields(existing_document, v, updater)
- return subdocument
- def _update_document_single_field(self, doc, field_name, field_value, updater):
- field_name_parts = field_name.split(".")
- for part in field_name_parts[:-1]:
- if isinstance(doc, list):
- try:
- if part == '$':
- doc = doc[0]
- else:
- doc = doc[int(part)]
- continue
- except ValueError:
- pass
- elif isinstance(doc, dict):
- if updater is _unset_updater and part not in doc:
- # If the parent doesn't exist, neither does its child.
- return
- doc = doc.setdefault(part, {})
- else:
- return
- field_name = field_name_parts[-1]
- if isinstance(doc, list):
- try:
- doc[int(field_name)] = field_value
- except IndexError:
- pass
- else:
- updater(doc, field_name, field_value)
- def _iter_documents(self, filter=None):
- return (document for document in list(itervalues(self._documents))
- if filter_applies(filter, document))
- def find_one(self, filter=None, *args, **kwargs):
- # Allow calling find_one with a non-dict argument that gets used as
- # the id for the query.
- if filter is None:
- filter = {}
- if not isinstance(filter, collections.Mapping):
- filter = {'_id': filter}
- try:
- return next(self.find(filter, *args, **kwargs))
- except StopIteration:
- return None
- def find_one_and_delete(self, filter, projection=None, sort=None, **kwargs):
- kwargs['remove'] = True
- validate_is_mapping('filter', filter)
- return self._find_and_modify(filter, projection, sort=sort, **kwargs)
- def find_one_and_replace(self, filter, replacement,
- projection=None, sort=None, upsert=False,
- return_document=ReturnDocument.BEFORE, **kwargs):
- validate_is_mapping('filter', filter)
- validate_ok_for_replace(replacement)
- return self._find_and_modify(filter, projection, replacement, upsert,
- sort, return_document, **kwargs)
- def find_one_and_update(self, filter, update,
- projection=None, sort=None, upsert=False,
- return_document=ReturnDocument.BEFORE, **kwargs):
- validate_is_mapping('filter', filter)
- validate_ok_for_update(update)
- return self._find_and_modify(filter, projection, update, upsert,
- sort, return_document, **kwargs)
- def find_and_modify(self, query={}, update=None, upsert=False, sort=None,
- full_response=False, manipulate=False, fields=None, **kwargs):
- warnings.warn("find_and_modify is deprecated, use find_one_and_delete"
- ", find_one_and_replace, or find_one_and_update instead",
- DeprecationWarning, stacklevel=2)
- if 'projection' in kwargs:
- raise TypeError("find_and_modify() got an unexpected keyword argument 'projection'")
- return self._find_and_modify(query, update=update, upsert=upsert,
- sort=sort, projection=fields, **kwargs)
- def _find_and_modify(self, query, projection=None, update=None,
- upsert=False, sort=None,
- return_document=ReturnDocument.BEFORE, **kwargs):
- remove = kwargs.get("remove", False)
- if kwargs.get("new", False) and remove:
- # message from mongodb
- raise OperationFailure("remove and returnNew can't co-exist")
- if not (remove or update):
- raise ValueError("Must either update or remove")
- if remove and update:
- raise ValueError("Can't do both update and remove")
- old = self.find_one(query, projection=projection, sort=sort)
- if not old and not upsert:
- return
- if old and '_id' in old:
- query = {'_id': old['_id']}
- if remove:
- self.delete_one(query)
- else:
- self._update(query, update, upsert)
- if return_document is ReturnDocument.AFTER or kwargs.get('new'):
- return self.find_one(query, projection)
- return old
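- # Usage sketch (illustrative):
- #
- #   doc = collection.find_one_and_update(
- #       {'x': 1}, {'$inc': {'n': 1}},
- #       return_document=ReturnDocument.AFTER)
- #
- # With the default ReturnDocument.BEFORE, the pre-update document (or None
- # when nothing matched and upsert is False) is returned instead.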
- def save(self, to_save, manipulate=True, check_keys=True, **kwargs):
- warnings.warn("save is deprecated. Use insert_one or replace_one "
- "instead", DeprecationWarning, stacklevel=2)
- validate_is_mutable_mapping("to_save", to_save)
- validate_write_concern_params(**kwargs)
- if "_id" not in to_save:
- return self.insert(to_save)
- else:
- self._update({"_id": to_save["_id"]}, to_save, True,
- manipulate, check_keys=True, **kwargs)
- return to_save.get("_id", None)
- def delete_one(self, filter):
- validate_is_mapping('filter', filter)
- return DeleteResult(self._delete(filter), True)
- def delete_many(self, filter):
- validate_is_mapping('filter', filter)
- return DeleteResult(self._delete(filter, multi=True), True)
- def _delete(self, filter, multi=False):
- filter = helpers.patch_datetime_awareness_in_document(filter)
- if filter is None:
- filter = {}
- if not isinstance(filter, collections.Mapping):
- filter = {'_id': filter}
- to_delete = list(self.find(filter))
- deleted_count = 0
- for doc in to_delete:
- doc_id = doc['_id']
- if isinstance(doc_id, dict):
- doc_id = helpers.hashdict(doc_id)
- del self._documents[doc_id]
- deleted_count += 1
- if not multi:
- break
- return {
- "connectionId": self.database.client._id,
- "n": deleted_count,
- "ok": 1.0,
- "err": None,
- }
- def remove(self, spec_or_id=None, multi=True, **kwargs):
- warnings.warn("remove is deprecated. Use delete_one or delete_many "
- "instead.", DeprecationWarning, stacklevel=2)
- validate_write_concern_params(**kwargs)
- return self._delete(spec_or_id, multi=multi)
- def count(self, filter=None, **kwargs):
- if filter is None:
- return len(self._documents)
- else:
- return len(list(self._iter_documents(filter)))
- def drop(self):
- self.database.drop_collection(self.name)
- def ensure_index(self, key_or_list, cache_for=300, **kwargs):
- self.create_index(key_or_list, cache_for, **kwargs)
- def create_index(self, key_or_list, cache_for=300, **kwargs):
- if kwargs.pop('unique', False):
- self._uniques.append((helpers.index_list(key_or_list), kwargs.pop('sparse', False)))
- def drop_index(self, index_or_name):
- pass
- def drop_indexes(self):
- self._uniques = []
- def reindex(self):
- pass
- def list_indexes(self):
- return {}
- def index_information(self):
- return {}
- def map_reduce(self, map_func, reduce_func, out, full_response=False,
- query=None, limit=0):
- if execjs is None:
- raise NotImplementedError(
- "PyExecJS is required in order to run Map-Reduce. "
- "Use 'pip install pyexecjs pymongo' to support Map-Reduce mock."
- )
- if limit == 0:
- limit = None
- start_time = time.clock()
- out_collection = None
- reduced_rows = None
- full_dict = {
- 'counts': {
- 'input': 0,
- 'reduce': 0,
- 'emit': 0,
- 'output': 0},
- 'timeMillis': 0,
- 'ok': 1.0,
- 'result': None}
- map_ctx = execjs.compile("""
- function doMap(fnc, docList) {
- var mappedDict = {};
- function emit(key, val) {
- if (key['$oid']) {
- mapped_key = '$oid' + key['$oid'];
- }
- else {
- mapped_key = key;
- }
- if(!mappedDict[mapped_key]) {
- mappedDict[mapped_key] = [];
- }
- mappedDict[mapped_key].push(val);
- }
- mapper = eval('('+fnc+')');
- var mappedList = new Array();
- for(var i=0; i<docList.length; i++) {
- var thisDoc = eval('('+docList[i]+')');
- var mappedVal = (mapper).call(thisDoc);
- }
- return mappedDict;
- }
- """)
- reduce_ctx = execjs.compile("""
- function doReduce(fnc, docList) {
- var reducedList = new Array();
- reducer = eval('('+fnc+')');
- for(var key in docList) {
- var reducedVal = {'_id': key,
- 'value': reducer(key, docList[key])};
- reducedList.push(reducedVal);
- }
- return reducedList;
- }
- """)
- doc_list = [json.dumps(doc, default=json_util.default)
- for doc in self.find(query)]
- mapped_rows = map_ctx.call('doMap', map_func, doc_list)
- reduced_rows = reduce_ctx.call('doReduce', reduce_func, mapped_rows)[:limit]
- for reduced_row in reduced_rows:
- if reduced_row['_id'].startswith('$oid'):
- reduced_row['_id'] = ObjectId(reduced_row['_id'][4:])
- reduced_rows = sorted(reduced_rows, key=lambda x: x['_id'])
- if full_response:
- full_dict['counts']['input'] = len(doc_list)
- for key in mapped_rows.keys():
- emit_count = len(mapped_rows[key])
- full_dict['counts']['emit'] += emit_count
- if emit_count > 1:
- full_dict['counts']['reduce'] += 1
- full_dict['counts']['output'] = len(reduced_rows)
- if isinstance(out, (str, bytes)):
- out_collection = getattr(self.database, out)
- out_collection.drop()
- out_collection.insert(reduced_rows)
- ret_val = out_collection
- full_dict['result'] = out
- elif isinstance(out, SON) and out.get('replace') and out.get('db'):
- # Must be of the format SON([('replace','results'),('db','outdb')])
- out_db = getattr(self.database._client, out['db'])
- out_collection = getattr(out_db, out['replace'])
- out_collection.insert(reduced_rows)
- ret_val = out_collection
- full_dict['result'] = {'db': out['db'], 'collection': out['replace']}
- elif isinstance(out, dict) and out.get('inline'):
- ret_val = reduced_rows
- full_dict['result'] = reduced_rows
- else:
- raise TypeError("'out' must be an instance of string, dict or bson.SON")
- full_dict['timeMillis'] = int(round((time.clock() - start_time) * 1000))
- if full_response:
- ret_val = full_dict
- return ret_val
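- # Map-reduce sketch (illustrative; requires PyExecJS and pymongo's bson):
- #
- #   map_f = "function () { emit(this.user, 1); }"
- #   reduce_f = "function (key, values) { return values.length; }"
- #   collection.map_reduce(map_f, reduce_f, {'inline': 1})
- #   # -> [{'_id': <user>, 'value': <emit count>}, ...]
- #
- # Note that, unlike a real server, the reducer above runs once per key even
- # for keys with a single emit.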
- def inline_map_reduce(self, map_func, reduce_func, full_response=False,
- query=None, limit=0):
- return self.map_reduce(
- map_func, reduce_func, {'inline': 1}, full_response, query, limit)
- def distinct(self, key, filter=None):
- return self.find(filter).distinct(key)
- def group(self, key, condition, initial, reduce, finalize=None):
- if execjs is None:
- raise NotImplementedError(
- "PyExecJS is required in order to use group. "
- "Use 'pip install pyexecjs pymongo' to support group mock."
- )
- reduce_ctx = execjs.compile("""
- function doReduce(fnc, docList) {
- reducer = eval('('+fnc+')');
- for(var i=0, l=docList.length; i<l; i++) {
- try {
- reducedVal = reducer(docList[i-1], docList[i]);
- }
- catch (err) {
- continue;
- }
- }
- return docList[docList.length - 1];
- }
- """)
- ret_array = []
- doc_list_copy = []
- ret_array_copy = []
- reduced_val = {}
- doc_list = [doc for doc in self.find(condition)]
- for doc in doc_list:
- doc_copy = copy.deepcopy(doc)
- for k in doc:
- if isinstance(doc[k], ObjectId):
- doc_copy[k] = str(doc[k])
- if k not in key and k not in reduce:
- del doc_copy[k]
- for initial_key in initial:
- if initial_key in doc.keys():
- pass
- else:
- doc_copy[initial_key] = initial[initial_key]
- doc_list_copy.append(doc_copy)
- doc_list = doc_list_copy
- for k in key:
- doc_list = sorted(doc_list, key=lambda x: _resolve_key(k, x))
- for k in key:
- if not isinstance(k, helpers.basestring):
- raise TypeError(
- "Keys must be a list of key names, "
- "each an instance of %s" % helpers.basestring.__name__)
- for k2, group in itertools.groupby(doc_list, lambda item: item[k]):
- group_list = list(group)
- reduced_val = reduce_ctx.call('doReduce', reduce, group_list)
- ret_array.append(reduced_val)
- for doc in ret_array:
- doc_copy = copy.deepcopy(doc)
- for k in doc:
- if k not in key and k not in initial.keys():
- del doc_copy[k]
- ret_array_copy.append(doc_copy)
- ret_array = ret_array_copy
- return ret_array
- def aggregate(self, pipeline, **kwargs):
- pipeline_operators = [
- '$project',
- '$match',
- '$redact',
- '$limit',
- '$skip',
- '$unwind',
- '$group',
- '$sample',
- '$sort',
- '$geoNear',
- '$lookup',
- '$out',
- '$indexStats']
- group_operators = [
- '$addToSet',
- '$first',
- '$last',
- '$max',
- '$min',
- '$avg',
- '$push',
- '$sum',
- '$stdDevPop',
- '$stdDevSamp']
- project_operators = [
- '$max',
- '$min',
- '$avg',
- '$sum',
- '$stdDevPop',
- '$stdDevSamp',
- '$arrayElemAt'
- ]
- boolean_operators = ['$and', '$or', '$not'] # noqa
- set_operators = [ # noqa
- '$setEquals',
- '$setIntersection',
- '$setDifference',
- '$setUnion',
- '$setIsSubset',
- '$anyElementTrue',
- '$allElementsTrue']
- comparison_operators = [ # noqa
- '$cmp',
- '$eq',
- '$gt',
- '$gte',
- '$lt',
- '$lte',
- '$ne']
- arithmetic_operators = [ # noqa
- '$abs',
- '$add',
- '$ceil',
- '$divide',
- '$exp',
- '$floor',
- '$ln',
- '$log',
- '$log10',
- '$mod',
- '$multiply',
- '$pow',
- '$sqrt',
- '$subtract',
- '$trunc']
- string_operators = [ # noqa
- '$concat',
- '$strcasecmp',
- '$substr',
- '$toLower',
- '$toUpper']
- text_search_operators = ['$meta'] # noqa
- array_operators = [ # noqa
- '$arrayElemAt',
- '$concatArrays',
- '$filter',
- '$isArray',
- '$size',
- '$slice']
- projection_operators = ['$map', '$let', '$literal'] # noqa
- date_operators = [ # noqa
- '$dayOfYear',
- '$dayOfMonth',
- '$dayOfWeek',
- '$year',
- '$month',
- '$week',
- '$hour',
- '$minute',
- '$second',
- '$millisecond',
- '$dateToString']
- conditional_operators = ['$cond', '$ifNull'] # noqa
- def _handle_arithmetic_operator(operator, values, doc_dict):
- if operator == '$abs':
- return abs(_parse_expression(values, doc_dict))
- elif operator == '$ceil':
- return math.ceil(_parse_expression(values, doc_dict))
- elif operator == '$divide':
- assert len(values) == 2, 'divide must have only 2 items'
- return _parse_expression(values[0], doc_dict) / _parse_expression(values[1],
- doc_dict)
- elif operator == '$exp':
- return math.exp(_parse_expression(values, doc_dict))
- elif operator == '$floor':
- return math.floor(_parse_expression(values, doc_dict))
- elif operator == '$ln':
- return math.log(_parse_expression(values, doc_dict))
- elif operator == '$log':
- assert len(values) == 2, 'log must have only 2 items'
- return math.log(_parse_expression(values[0], doc_dict),
- _parse_expression(values[1], doc_dict))
- elif operator == '$log10':
- return math.log10(_parse_expression(values, doc_dict))
- elif operator == '$mod':
- assert len(values) == 2, 'mod must have only 2 items'
- return math.fmod(_parse_expression(values[0], doc_dict),
- _parse_expression(values[1], doc_dict))
- elif operator == '$pow':
- assert len(values) == 2, 'pow must have only 2 items'
- return math.pow(_parse_expression(values[0], doc_dict),
- _parse_expression(values[1], doc_dict))
- elif operator == '$sqrt':
- return math.sqrt(_parse_expression(values, doc_dict))
- elif operator == '$subtract':
- assert len(values) == 2, 'subtract must have only 2 items'
- return _parse_expression(values[0], doc_dict) - _parse_expression(values[1],
- doc_dict)
- else:
- raise NotImplementedError("Although '%s' is a valid aritmetic operator for the "
- "aggregation pipeline, it is currently not implemented "
- " in Mongomock." % operator)
- def _handle_comparison_operator(operator, values, doc_dict):
- assert len(values) == 2, 'Comparison requires two expressions'
- if operator == '$eq':
- return _parse_expression(values[0], doc_dict) == \
- _parse_expression(values[1], doc_dict)
- elif operator == '$gt':
- return _parse_expression(values[0], doc_dict) > \
- _parse_expression(values[1], doc_dict)
- elif operator == '$gte':
- return _parse_expression(values[0], doc_dict) >= \
- _parse_expression(values[1], doc_dict)
- elif operator == '$lt':
- return _parse_expression(values[0], doc_dict) < \
- _parse_expression(values[1], doc_dict)
- elif operator == '$lte':
- return _parse_expression(values[0], doc_dict) <= \
- _parse_expression(values[1], doc_dict)
- elif operator == '$ne':
- return _parse_expression(values[0], doc_dict) != \
- _parse_expression(values[1], doc_dict)
- else:
- raise NotImplementedError(
- "Although '%s' is a valid comparison operator for the "
- "aggregation pipeline, it is currently not implemented "
- " in Mongomock." % operator)
- def _handle_date_operator(operator, values, doc_dict):
- out_value = _parse_expression(values, doc_dict)
- if operator == '$dayOfYear':
- return out_value.timetuple().tm_yday
- elif operator == '$dayOfMonth':
- return out_value.day
- elif operator == '$dayOfWeek':
- return out_value.isoweekday()
- elif operator == '$year':
- return out_value.year
- elif operator == '$month':
- return out_value.month
- elif operator == '$week':
- return out_value.isocalendar()[1]
- elif operator == '$hour':
- return out_value.hour
- elif operator == '$minute':
- return out_value.minute
- elif operator == '$second':
- return out_value.second
- elif operator == '$millisecond':
- return int(out_value.microsecond / 1000)
- else:
- raise NotImplementedError(
- "Although '%s' is a valid date operator for the "
- "aggregation pipeline, it is currently not implemented "
- " in Mongomock." % operator)
- def _handle_array_operator(operator, values, doc_dict):
- out_value = _parse_expression(values, doc_dict)
- if operator == '$size':
- return len(out_value)
- else:
- raise NotImplementedError(
- "Although '%s' is a valid date operator for the "
- "aggregation pipeline, it is currently not implemented "
- " in Mongomock." % operator)
- def _handle_conditional_operator(operator, values, doc_dict):
- if operator == '$ifNull':
- field, fallback = values
- try:
- out_value = _parse_expression(field, doc_dict)
- except KeyError:
- return fallback
- return out_value if out_value is not None else fallback
- else:
- raise NotImplementedError(
- "Although '%s' is a valid date operator for the "
- "aggregation pipeline, it is currently not implemented "
- " in Mongomock." % operator)
- def _handle_project_operator(operator, values, doc_dict):
- if operator == '$min':
- if len(values) > 2:
- raise NotImplementedError("Although %d is a valid amount of elements in "
- "aggregation pipeline, it is currently not "
- " implemented in Mongomock" % len(values))
- return min(_parse_expression(values[0], doc_dict),
- _parse_expression(values[1], doc_dict))
- elif operator == '$arrayElemAt':
- key, index = values
- array = _parse_basic_expression(key, doc_dict)
- v = array[index]
- return v
- else:
- raise NotImplementedError("Although '%s' is a valid project operator for the "
- "aggregation pipeline, it is currently not implemented "
- "in Mongomock." % operator)
- def _parse_basic_expression(expression, doc_dict):
- if isinstance(expression, str) and expression.startswith('$'):
- get_value = helpers.embedded_item_getter(expression.replace('$', ''))
- return get_value(doc_dict)
- else:
- return expression
- def _parse_expression(expression, doc_dict):
- if not isinstance(expression, dict):
- return _parse_basic_expression(expression, doc_dict)
- value_dict = {}
- for k, v in iteritems(expression):
- if k in arithmetic_operators:
- return _handle_arithmetic_operator(k, v, doc_dict)
- elif k in project_operators:
- return _handle_project_operator(k, v, doc_dict)
- elif k in comparison_operators:
- return _handle_comparison_operator(k, v, doc_dict)
- elif k in date_operators:
- return _handle_date_operator(k, v, doc_dict)
- elif k in array_operators:
- return _handle_array_operator(k, v, doc_dict)
- elif k in conditional_operators:
- return _handle_conditional_operator(k, v, doc_dict)
- else:
- value_dict[k] = _parse_expression(v, doc_dict)
- return value_dict
- def _extend_collection(out_collection, field, expression):
- field_exists = False
- for doc in out_collection:
- if field in doc:
- field_exists = True
- break
- if not field_exists:
- for doc in out_collection:
- if isinstance(expression, str) and expression.startswith('$'):
- try:
- doc[field] = get_value_by_dot(doc, expression.lstrip('$'))
- except KeyError:
- pass
- else:
- # the expression is a dict: parse it as an operator expression
- doc[field] = _parse_expression(expression.copy(), doc)
- return out_collection
- out_collection = [doc for doc in self.find()]
- for stage in pipeline:
- for k, v in iteritems(stage):
- if k == '$match':
- out_collection = [doc for doc in out_collection
- if filter_applies(v, doc)]
- elif k == '$group':
- grouped_collection = []
- _id = stage['$group']['_id']
- if _id:
- key_getter = functools.partial(_parse_expression, _id)
- out_collection = sorted(out_collection, key=key_getter)
- grouped = itertools.groupby(out_collection, key_getter)
- else:
- grouped = [(None, out_collection)]
- for doc_id, group in grouped:
- group_list = list(group)
- doc_dict = {'_id': doc_id}
- for field, value in iteritems(v):
- if field == '_id':
- continue
- for operator, key in iteritems(value):
- if operator in (
- "$sum",
- "$avg",
- "$min",
- "$max",
- "$first",
- "$last",
- "$addToSet",
- '$push'
- ):
- key_getter = functools.partial(_parse_expression, key)
- values = [key_getter(doc) for doc in group_list]
- if operator == "$sum":
- val_it = (val or 0 for val in values)
- doc_dict[field] = sum(val_it)
- elif operator == "$avg":
- values = [val or 0 for val in values]
- doc_dict[field] = sum(values) / max(len(values), 1)
- elif operator == "$min":
- val_it = (val or MAXSIZE for val in values)
- doc_dict[field] = min(val_it)
- elif operator == "$max":
- val_it = (val or -MAXSIZE for val in values)
- doc_dict[field] = max(val_it)
- elif operator == "$first":
- doc_dict[field] = values[0]
- elif operator == "$last":
- doc_dict[field] = values[-1]
- elif operator == "$addToSet":
- val_it = (val or None for val in values)
- doc_dict[field] = set(val_it)
- elif operator == '$push':
- if field not in doc_dict:
- doc_dict[field] = []
- doc_dict[field].extend(values)
- else:
- if operator in group_operators:
- raise NotImplementedError(
- "Although %s is a valid group operator for the "
- "aggregation pipeline, it is currently not implemented "
- "in Mongomock." % operator)
- else:
- raise NotImplementedError(
- "%s is not a valid group operator for the aggregation "
- "pipeline. See http://docs.mongodb.org/manual/meta/"
- "aggregation-quick-reference/ for a complete list of "
- "valid operators." % operator)
- grouped_collection.append(doc_dict)
- out_collection = grouped_collection
- elif k == '$sort':
- sort_array = []
- for x, y in v.items():
- sort_array.append({x: y})
- for sort_pair in reversed(sort_array):
- for sortKey, sortDirection in sort_pair.items():
- out_collection = sorted(
- out_collection,
- key=lambda x: _resolve_sort_key(sortKey, x),
- reverse=sortDirection < 0)
- elif k == '$skip':
- out_collection = out_collection[v:]
- elif k == '$limit':
- out_collection = out_collection[:v]
                elif k == '$unwind':
                    if not isinstance(v, helpers.basestring) or v[0] != '$':
                        raise ValueError(
                            "$unwind failed: exception: field path references must be "
                            "prefixed with a '$' '%s'" % v)
                    unwound_collection = []
                    for doc in out_collection:
                        array_value = get_value_by_dot(doc, v[1:])
                        if array_value in (None, []):
                            continue
                        elif not isinstance(array_value, list):
                            raise TypeError(
                                '$unwind must specify an array field, field: '
                                '"%s", value found: %s' % (v, array_value))
                        for field_item in array_value:
                            unwound_collection.append(copy.deepcopy(doc))
                            unwound_collection[-1] = set_value_by_dot(
                                unwound_collection[-1], v[1:], field_item)
                    out_collection = unwound_collection
                elif k == '$project':
                    filter_list = ['_id']
                    for field, value in iteritems(v):
                        if field == '_id' and not value:
                            filter_list.remove('_id')
                        elif value:
                            filter_list.append(field)
                            out_collection = _extend_collection(out_collection, field, value)
                    out_collection = [{k: v for (k, v) in x.items() if k in filter_list}
                                      for x in out_collection]
                elif k == '$out':
                    # TODO(MetrodataTeam): should leave the origin collection unchanged
                    collection = self.database.get_collection(v)
                    if collection.count() > 0:
                        collection.drop()
                    collection.insert_many(out_collection)
                else:
                    if k in pipeline_operators:
                        raise NotImplementedError(
                            "Although '%s' is a valid operator for the aggregation pipeline, "
                            "it is currently not implemented in Mongomock." % k)
                    else:
                        raise NotImplementedError(
                            "%s is not a valid operator for the aggregation pipeline. "
                            "See http://docs.mongodb.org/manual/meta/aggregation-quick-reference/ "
                            "for a complete list of valid operators." % k)
        return CommandCursor(out_collection)
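    # Illustrative usage (added for clarity; not part of the original source),
    # assuming a collection obtained from mongomock.MongoClient():
    #     collection.insert_many([{'x': 1, 'g': 'a'}, {'x': 2, 'g': 'a'}])
    #     result = collection.aggregate([
    #         {'$match': {'g': 'a'}},
    #         {'$group': {'_id': '$g', 'total': {'$sum': '$x'}}},
    #     ])
    #     # list(result) should yield roughly [{'_id': 'a', 'total': 3}]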
    def with_options(self, codec_options=None, read_preference=None, write_concern=None,
                     read_concern=None):
        return self

    def rename(self, new_name, **kwargs):
        self.database.rename_collection(self.name, new_name, **kwargs)

    def bulk_write(self, operations):
        bulk = BulkOperationBuilder(self)
        for operation in operations:
            operation._add_to_bulk(bulk)
        return BulkWriteResult(bulk.execute(), True)
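    # Illustrative usage (added; not in the original source), assuming pymongo's
    # operation classes, whose _add_to_bulk hook this method relies on:
    #     from pymongo import InsertOne, DeleteOne
    #     collection.bulk_write([InsertOne({'x': 1}), DeleteOne({'x': 1})])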


def _resolve_key(key, doc):
    return next(iter(iter_key_candidates(key, doc)), NOTHING)


def _resolve_sort_key(key, doc):
    value = _resolve_key(key, doc)
    # see http://docs.mongodb.org/manual/reference/method/cursor.sort/#ascending-descending-sort
    if value is NOTHING:
        return 0, value
    return 1, value
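# Note (added for clarity; not in the original source): returning a
# (0, value) / (1, value) tuple makes documents that lack the sort key compare
# as smaller than any document that has it, mirroring MongoDB's rule that
# missing fields sort first in ascending order.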


class Cursor(object):

    def __init__(self, collection, spec=None, sort=None, projection=None, skip=0, limit=0,
                 collation=None):
        super(Cursor, self).__init__()
        self.collection = collection
        spec = helpers.patch_datetime_awareness_in_document(spec)
        self._spec = spec
        self._sort = sort
        self._projection = projection
        self._skip = skip
        self._factory_last_generated_results = None
        self._results = None
        self._factory = functools.partial(collection._get_dataset,
                                          spec, sort, projection, dict)
        # pymongo limit defaults to 0, returning everything
        self._limit = limit if limit != 0 else None
        self._collation = collation
        self.rewind()

    def _compute_results(self, with_limit_and_skip=False):
        # Recompute the results only if the query has changed
        if not self._results or self._factory_last_generated_results != self._factory:
            if self.collection.database.client._tz_aware:
                results = [helpers.make_datetime_timezone_aware_in_document(x)
                           for x in self._factory()]
            else:
                results = list(self._factory())
            self._factory_last_generated_results = self._factory
            self._results = results
        if with_limit_and_skip:
            results = self._results[self._skip:]
            if self._limit:
                results = results[:self._limit]
        else:
            results = self._results
        return results
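    # Note (added for clarity; not in the original source): the cached
    # self._results are reused until self._factory is replaced (e.g. by sort(),
    # which wraps the factory in a new sorting layer), so iterating the same
    # cursor twice does not re-run the underlying query.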

    def __iter__(self):
        return self

    def clone(self):
        cursor = Cursor(self.collection,
                        self._spec, self._sort, self._projection, self._skip, self._limit)
        cursor._factory = self._factory
        return cursor

    def __next__(self):
        try:
            doc = self._compute_results(with_limit_and_skip=True)[self._emitted]
            self._emitted += 1
            return doc
        except IndexError:
            raise StopIteration()

    next = __next__

    def rewind(self):
        self._emitted = 0

    def sort(self, key_or_list, direction=None):
        if direction is None:
            direction = 1

        def _make_sort_factory_layer(upper_factory, sortKey, sortDirection):
            def layer():
                return sorted(upper_factory(), key=lambda x: _resolve_sort_key(sortKey, x),
                              reverse=sortDirection < 0)
            return layer

        if isinstance(key_or_list, (tuple, list)):
            for sortKey, sortDirection in reversed(key_or_list):
                self._factory = _make_sort_factory_layer(self._factory, sortKey, sortDirection)
        else:
            self._factory = _make_sort_factory_layer(self._factory, key_or_list, direction)
        return self
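    # Illustrative usage (added; not in the original source): because each key
    # wraps the factory in its own stable-sort layer, applied in reverse order,
    #     cursor.sort([('age', -1), ('name', 1)])
    # sorts primarily by descending age and breaks ties by ascending name.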

    def count(self, with_limit_and_skip=False):
        results = self._compute_results(with_limit_and_skip)
        return len(results)

    def skip(self, count):
        self._skip = count
        return self

    def limit(self, count):
        self._limit = count if count != 0 else None
        return self

    def batch_size(self, count):
        return self

    def close(self):
        pass

    def distinct(self, key):
        if not isinstance(key, helpers.basestring):
            raise TypeError('cursor.distinct key must be a string')
        unique = set()
        unique_dict_vals = []
        for x in self._compute_results():
            value = _resolve_key(key, x)
            if value == NOTHING:
                continue
            if isinstance(value, dict):
                if any(dict_val == value for dict_val in unique_dict_vals):
                    continue
                unique_dict_vals.append(value)
            else:
                unique.update(
                    value if isinstance(value, (tuple, list)) else [value])
        return list(unique) + unique_dict_vals
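    # Note (added for clarity; not in the original source): list and tuple
    # values are flattened into the result set, while dict values, which are
    # not hashable, are deduplicated by equality in a separate list, e.g.
    #     collection.find().distinct('tags')
    # returns each tag once even when documents store 'tags' as arrays.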

    def __getitem__(self, index):
        if isinstance(index, slice):
            if index.step is not None:
                raise IndexError("Cursor instances do not support slice steps")
            skip = 0
            if index.start is not None:
                if index.start < 0:
                    raise IndexError("Cursor instances do not support "
                                     "negative indices")
                skip = index.start
            if index.stop is not None:
                limit = index.stop - skip
                if limit < 0:
                    raise IndexError("stop index must be greater than start "
                                     "index for slice %r" % index)
                if limit == 0:
                    self.__empty = True
            else:
                limit = 0
            self._skip = skip
            self._limit = limit
            return self
        elif not isinstance(index, int):
            raise TypeError("index '%s' cannot be applied to Cursor instances" % index)
        elif index < 0:
            raise IndexError('Cursor instances do not support negative indices')
        else:
            return self._compute_results(with_limit_and_skip=True)[index]
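    # Illustrative usage (added; not in the original source): slicing mutates
    # the cursor's skip/limit rather than copying it, so
    #     cursor[20:30]   # equivalent to cursor.skip(20).limit(10)
    #     cursor[5]       # materializes results and returns the sixth document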

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()


def _set_updater(doc, field_name, value):
    if isinstance(value, (tuple, list)):
        value = copy.deepcopy(value)
    if isinstance(doc, dict):
        doc[field_name] = value


def _unset_updater(doc, field_name, value):
    if isinstance(doc, dict):
        doc.pop(field_name, None)


def _inc_updater(doc, field_name, value):
    if isinstance(doc, dict):
        doc[field_name] = doc.get(field_name, 0) + value


def _max_updater(doc, field_name, value):
    if isinstance(doc, dict):
        doc[field_name] = max(doc.get(field_name, value), value)


def _min_updater(doc, field_name, value):
    if isinstance(doc, dict):
        doc[field_name] = min(doc.get(field_name, value), value)


def _sum_updater(doc, field_name, current, result):
    if isinstance(doc, dict):
        result = current + doc.get(field_name, 0)
        return result


def _current_date_updater(doc, field_name, value):
    if isinstance(doc, dict):
        doc[field_name] = datetime.utcnow()


_updaters = {
    '$set': _set_updater,
    '$unset': _unset_updater,
    '$inc': _inc_updater,
    '$max': _max_updater,
    '$min': _min_updater,
}
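
# Note (added for clarity; not in the original source): this table maps update
# operators to their handlers, so an update document such as
#     {'$inc': {'count': 1}, '$set': {'seen': True}}
# can be applied by looking up each operator and calling, e.g.,
#     _updaters['$inc'](doc, 'count', 1)
# _sum_updater and _current_date_updater are defined above but are not
# registered in this table and appear to be dispatched elsewhere.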