import copy
import numbers
from functools import partial

from bson import ObjectId, json_util
from bson.dbref import DBRef
from bson.son import SON
import pymongo
import six

from mongoengine import signals
from mongoengine.base.common import get_document
from mongoengine.base.datastructures import (BaseDict, BaseList,
                                             EmbeddedDocumentList,
                                             LazyReference,
                                             StrictDict)
from mongoengine.base.fields import ComplexBaseField
from mongoengine.common import _import_class
from mongoengine.errors import (FieldDoesNotExist, InvalidDocumentError,
                                LookUpError, OperationError, ValidationError)
from mongoengine.python_support import Hashable

__all__ = ('BaseDocument', 'NON_FIELD_ERRORS')

NON_FIELD_ERRORS = '__all__'


class BaseDocument(object):
    __slots__ = ('_changed_fields', '_initialised', '_created', '_data',
                 '_dynamic_fields', '_auto_id_field', '_db_field_map',
                 '__weakref__')

    _dynamic = False
    _dynamic_lock = True
    STRICT = False

    def __init__(self, *args, **values):
        """
        Initialise a document or embedded document.

        :param __auto_convert: If True, try to convert supplied values to
            the field's Python type via each field's ``to_python`` method.
        :param values: A dictionary of values for the document
        """
        self._initialised = False
        self._created = True

        if args:
            # Combine positional arguments with named arguments.
            # We only want named arguments.
            field = iter(self._fields_ordered)
            # If it's an automatic id field then skip to the first defined field
            if getattr(self, '_auto_id_field', False):
                next(field)
            for value in args:
                name = next(field)
                if name in values:
                    raise TypeError(
                        'Multiple values for keyword argument "%s"' % name)
                values[name] = value

        __auto_convert = values.pop('__auto_convert', True)

        # 399: set default values only to fields loaded from DB
        __only_fields = set(values.pop('__only_fields', values))

        _created = values.pop('_created', True)

        signals.pre_init.send(self.__class__, document=self, values=values)

        # Check if there are undefined fields supplied to the constructor,
        # if so raise an Exception.
        if not self._dynamic and (self._meta.get('strict', True) or _created):
            _undefined_fields = set(values.keys()) - set(
                self._fields.keys() + ['id', 'pk', '_cls', '_text_score'])
            if _undefined_fields:
                msg = (
                    'The fields "{0}" do not exist on the document "{1}"'
                ).format(_undefined_fields, self._class_name)
                raise FieldDoesNotExist(msg)

        if self.STRICT and not self._dynamic:
            self._data = StrictDict.create(allowed_keys=self._fields_ordered)()
        else:
            self._data = {}

        self._dynamic_fields = SON()

        # Assign default values to instance
        for key, field in self._fields.iteritems():
            if self._db_field_map.get(key, key) in __only_fields:
                continue
            value = getattr(self, key, None)
            setattr(self, key, value)

        # Set passed values after initialisation
        if self._dynamic:
            dynamic_data = {}
            for key, value in values.iteritems():
                if key in self._fields or key == '_id':
                    setattr(self, key, value)
                else:
                    dynamic_data[key] = value
        else:
            FileField = _import_class('FileField')
            for key, value in values.iteritems():
                key = self._reverse_db_field_map.get(key, key)
                if key in self._fields or key in ('id', 'pk', '_cls'):
                    if __auto_convert and value is not None:
                        field = self._fields.get(key)
                        if field and not isinstance(field, FileField):
                            value = field.to_python(value)
                    setattr(self, key, value)
                else:
                    self._data[key] = value

        # Set any get_<field>_display methods
        self.__set_field_display()

        if self._dynamic:
            self._dynamic_lock = False
            for key, value in dynamic_data.iteritems():
                setattr(self, key, value)

        # Flag initialised
        self._initialised = True
        self._created = _created

        signals.post_init.send(self.__class__, document=self)

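    # Illustrative usage sketch (not part of the original module). Assuming a
    # hypothetical `User(Document)` subclass with `name = StringField()` and
    # `email = StringField()`, construction exercises the logic above:
    #
    #     user = User(name='Ada', email='ada@example.org')  # keyword values
    #     user = User('Ada', 'ada@example.org')             # positional args map onto
    #                                                       # _fields_ordered
    #     User(name='Ada', bogus=1)   # raises FieldDoesNotExist on a strict,
    #                                 # non-dynamic document
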
    def __delattr__(self, *args, **kwargs):
        """Handle deletions of fields"""
        field_name = args[0]
        if field_name in self._fields:
            default = self._fields[field_name].default
            if callable(default):
                default = default()
            setattr(self, field_name, default)
        else:
            super(BaseDocument, self).__delattr__(*args, **kwargs)

    def __setattr__(self, name, value):
        # Handle dynamic data only if an initialised dynamic document
        if self._dynamic and not self._dynamic_lock:

            if not hasattr(self, name) and not name.startswith('_'):
                DynamicField = _import_class('DynamicField')
                field = DynamicField(db_field=name, null=True)
                field.name = name
                self._dynamic_fields[name] = field
                self._fields_ordered += (name,)

            if not name.startswith('_'):
                value = self.__expand_dynamic_values(name, value)

            # Handle marking data as changed
            if name in self._dynamic_fields:
                self._data[name] = value
                if hasattr(self, '_changed_fields'):
                    self._mark_as_changed(name)

        try:
            self__created = self._created
        except AttributeError:
            self__created = True

        if (
            self._is_document and
            not self__created and
            name in self._meta.get('shard_key', tuple()) and
            self._data.get(name) != value
        ):
            msg = 'Shard Keys are immutable. Tried to update %s' % name
            raise OperationError(msg)

        try:
            self__initialised = self._initialised
        except AttributeError:
            self__initialised = False

        # Check if the user has created a new instance of a class
        if (self._is_document and self__initialised and
                self__created and name == self._meta.get('id_field')):
            super(BaseDocument, self).__setattr__('_created', False)

        super(BaseDocument, self).__setattr__(name, value)

    def __getstate__(self):
        data = {}
        for k in ('_changed_fields', '_initialised', '_created',
                  '_dynamic_fields', '_fields_ordered'):
            if hasattr(self, k):
                data[k] = getattr(self, k)
        data['_data'] = self.to_mongo()
        return data

    def __setstate__(self, data):
        if isinstance(data['_data'], SON):
            data['_data'] = self.__class__._from_son(data['_data'])._data
        for k in ('_changed_fields', '_initialised', '_created', '_data',
                  '_dynamic_fields'):
            if k in data:
                setattr(self, k, data[k])
        if '_fields_ordered' in data:
            if self._dynamic:
                setattr(self, '_fields_ordered', data['_fields_ordered'])
            else:
                _super_fields_ordered = type(self)._fields_ordered
                setattr(self, '_fields_ordered', _super_fields_ordered)

        dynamic_fields = data.get('_dynamic_fields') or SON()
        for k in dynamic_fields.keys():
            setattr(self, k, data['_data'].get(k))

    def __iter__(self):
        return iter(self._fields_ordered)

    def __getitem__(self, name):
        """Dictionary-style field access, return a field's value if present.
        """
        try:
            if name in self._fields_ordered:
                return getattr(self, name)
        except AttributeError:
            pass
        raise KeyError(name)

    def __setitem__(self, name, value):
        """Dictionary-style field access, set a field's value.
        """
        # Ensure that the field exists before setting its value
        if not self._dynamic and name not in self._fields:
            raise KeyError(name)
        return setattr(self, name, value)

    def __contains__(self, name):
        try:
            val = getattr(self, name)
            return val is not None
        except AttributeError:
            return False

    def __len__(self):
        return len(self._data)

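    # Illustrative sketch (reusing the hypothetical `User` document from the
    # constructor example above): the methods above give documents a minimal
    # mapping-style interface.
    #
    #     user = User(name='Ada')
    #     user['name']            # 'Ada' (same as user.name)
    #     user['name'] = 'Grace'  # routed through __setattr__
    #     'name' in user          # True only when the value is not None
    #     len(user)               # number of entries currently in user._data
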
    def __repr__(self):
        try:
            u = self.__str__()
        except (UnicodeEncodeError, UnicodeDecodeError):
            u = '[Bad Unicode data]'
        repr_type = str if u is None else type(u)
        return repr_type('<%s: %s>' % (self.__class__.__name__, u))

    def __str__(self):
        # TODO this could be simpler?
        if hasattr(self, '__unicode__'):
            if six.PY3:
                return self.__unicode__()
            else:
                return six.text_type(self).encode('utf-8')
        return six.text_type('%s object' % self.__class__.__name__)

    def __eq__(self, other):
        if isinstance(other, self.__class__) and hasattr(other, 'id') and other.id is not None:
            return self.id == other.id
        if isinstance(other, DBRef):
            return self._get_collection_name() == other.collection and self.id == other.id
        if self.id is None:
            return self is other
        return False

    def __ne__(self, other):
        return not self.__eq__(other)

    def clean(self):
        """
        Hook for doing document level data cleaning before validation is run.

        Any ValidationError raised by this method will not be associated with
        a particular field; it will have a special-case association with the
        field defined by NON_FIELD_ERRORS.
        """
        pass

    def get_text_score(self):
        """
        Get text score from text query
        """
        if '_text_score' not in self._data:
            raise InvalidDocumentError('This document is not originally built from a text query')
        return self._data['_text_score']

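    # Hedged sketch: `_text_score` is only populated on documents that come
    # back from a text-search queryset. Assuming a hypothetical `Article`
    # document with a text index and the queryset-level `search_text` helper:
    #
    #     article = Article.objects.search_text('mongodb').first()
    #     article.get_text_score()   # float score from the $text query
    #
    # Calling it on a plainly constructed document raises InvalidDocumentError.
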
    def to_mongo(self, use_db_field=True, fields=None):
        """
        Return as SON data ready for use with MongoDB.
        """
        if not fields:
            fields = []

        data = SON()
        data['_id'] = None
        data['_cls'] = self._class_name

        # only root fields ['test1.a', 'test2'] => ['test1', 'test2']
        root_fields = {f.split('.')[0] for f in fields}

        for field_name in self:
            if root_fields and field_name not in root_fields:
                continue

            value = self._data.get(field_name, None)
            field = self._fields.get(field_name)

            if field is None and self._dynamic:
                field = self._dynamic_fields.get(field_name)

            if value is not None:
                f_inputs = field.to_mongo.__code__.co_varnames
                ex_vars = {}
                if fields and 'fields' in f_inputs:
                    key = '%s.' % field_name
                    embedded_fields = [
                        i.replace(key, '') for i in fields
                        if i.startswith(key)]
                    ex_vars['fields'] = embedded_fields

                if 'use_db_field' in f_inputs:
                    ex_vars['use_db_field'] = use_db_field

                value = field.to_mongo(value, **ex_vars)

            # Handle self generating fields
            if value is None and field._auto_gen:
                value = field.generate()
                self._data[field_name] = value

            if (value is not None) or (field.null):
                if use_db_field:
                    data[field.db_field] = value
                else:
                    data[field.name] = value

        # Only add _cls if allow_inheritance is True
        if not self._meta.get('allow_inheritance'):
            data.pop('_cls')

        return data

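    # Illustrative sketch (hypothetical `User` document whose `name` field is
    # declared with `db_field='n'`): `to_mongo` returns a SON keyed by db
    # field names unless `use_db_field=False`.
    #
    #     User(name='Ada').to_mongo()                    # SON([('_id', None), ('n', 'Ada')])
    #     User(name='Ada').to_mongo(use_db_field=False)  # SON([('_id', None), ('name', 'Ada')])
    #
    # `_cls` is only kept in the output when `allow_inheritance` is enabled in
    # the document's meta.
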
    def validate(self, clean=True):
        """Ensure that all fields' values are valid and that required fields
        are present.
        """
        # Ensure that each field is matched to a valid value
        errors = {}
        if clean:
            try:
                self.clean()
            except ValidationError as error:
                errors[NON_FIELD_ERRORS] = error

        # Get a list of tuples of field names and their current values
        fields = [(self._fields.get(name, self._dynamic_fields.get(name)),
                   self._data.get(name)) for name in self._fields_ordered]

        EmbeddedDocumentField = _import_class('EmbeddedDocumentField')
        GenericEmbeddedDocumentField = _import_class(
            'GenericEmbeddedDocumentField')

        for field, value in fields:
            if value is not None:
                try:
                    if isinstance(field, (EmbeddedDocumentField,
                                          GenericEmbeddedDocumentField)):
                        field._validate(value, clean=clean)
                    else:
                        field._validate(value)
                except ValidationError as error:
                    errors[field.name] = error.errors or error
                except (ValueError, AttributeError, AssertionError) as error:
                    errors[field.name] = error
            elif field.required and not getattr(field, '_auto_gen', False):
                errors[field.name] = ValidationError('Field is required',
                                                     field_name=field.name)

        if errors:
            pk = 'None'
            if hasattr(self, 'pk'):
                pk = self.pk
            elif self._instance and hasattr(self._instance, 'pk'):
                pk = self._instance.pk
            message = 'ValidationError (%s:%s) ' % (self._class_name, pk)
            raise ValidationError(message, errors=errors)

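    # Illustrative sketch (hypothetical `User` document with a required
    # `email = StringField(required=True)` field): per-field errors are
    # collected and raised as a single ValidationError.
    #
    #     user = User(name='Ada')
    #     try:
    #         user.validate()
    #     except ValidationError as e:
    #         e.errors   # {'email': ValidationError('Field is required', ...)}
    #
    # Errors raised from `clean()` are stored under the NON_FIELD_ERRORS key.
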
    def to_json(self, *args, **kwargs):
        """Convert this document to JSON.

        :param use_db_field: Serialize field names as they appear in
            MongoDB (as opposed to attribute names on this document).
            Defaults to True.
        """
        use_db_field = kwargs.pop('use_db_field', True)
        return json_util.dumps(self.to_mongo(use_db_field), *args, **kwargs)

    @classmethod
    def from_json(cls, json_data, created=False):
        """Converts json data to a Document instance.

        :param json_data: The json data to load into the Document
        :param created: If True, the document will be considered as a brand
            new document.
            If False and an id is provided, it will consider that the data
            being loaded corresponds to what's already in the database (this
            has an impact on subsequent calls to ``.save()``).
            If False and no id is provided, it will consider the data as a
            new document (default ``False``).
        """
        return cls._from_son(json_util.loads(json_data), created=created)

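    # Illustrative round-trip sketch (hypothetical `User` document): the JSON
    # helpers simply wrap `to_mongo` / `_from_son` with bson's json_util.
    #
    #     payload = User(name='Ada').to_json()          # e.g. '{"_id": null, "name": "Ada"}'
    #     user = User.from_json(payload)                # reconstructed, treated as existing data
    #     user = User.from_json(payload, created=True)  # treated as a brand new document
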
    def __expand_dynamic_values(self, name, value):
        """Expand any dynamic values to their correct types / values."""
        if not isinstance(value, (dict, list, tuple)):
            return value

        # If the value is a dict with '_cls' in it, turn it into a document
        is_dict = isinstance(value, dict)
        if is_dict and '_cls' in value:
            cls = get_document(value['_cls'])
            return cls(**value)

        if is_dict:
            value = {
                k: self.__expand_dynamic_values(k, v)
                for k, v in value.items()
            }
        else:
            value = [self.__expand_dynamic_values(name, v) for v in value]

        # Convert lists / values so we can watch for any changes on them
        EmbeddedDocumentListField = _import_class('EmbeddedDocumentListField')
        if (isinstance(value, (list, tuple)) and
                not isinstance(value, BaseList)):
            if issubclass(type(self), EmbeddedDocumentListField):
                value = EmbeddedDocumentList(value, self, name)
            else:
                value = BaseList(value, self, name)
        elif isinstance(value, dict) and not isinstance(value, BaseDict):
            value = BaseDict(value, self, name)

        return value

    def _mark_as_changed(self, key):
        """Mark a key as explicitly changed by the user."""
        if not key:
            return

        if not hasattr(self, '_changed_fields'):
            return

        if '.' in key:
            key, rest = key.split('.', 1)
            key = self._db_field_map.get(key, key)
            key = '%s.%s' % (key, rest)
        else:
            key = self._db_field_map.get(key, key)

        if key not in self._changed_fields:
            levels, idx = key.split('.'), 1
            while idx <= len(levels):
                if '.'.join(levels[:idx]) in self._changed_fields:
                    break
                idx += 1
            else:
                self._changed_fields.append(key)
                # remove lower level changed fields
                level = '.'.join(levels[:idx]) + '.'
                remove = self._changed_fields.remove
                for field in self._changed_fields[:]:
                    if field.startswith(level):
                        remove(field)

    def _clear_changed_fields(self):
        """Using _get_changed_fields iterate and remove any fields that
        are marked as changed.
        """
        for changed in self._get_changed_fields():
            parts = changed.split('.')
            data = self
            for part in parts:
                if isinstance(data, list):
                    try:
                        data = data[int(part)]
                    except IndexError:
                        data = None
                elif isinstance(data, dict):
                    data = data.get(part, None)
                else:
                    data = getattr(data, part, None)

            if not isinstance(data, LazyReference) and hasattr(data, '_changed_fields'):
                if getattr(data, '_is_document', False):
                    continue

                data._changed_fields = []

        self._changed_fields = []

    def _nestable_types_changed_fields(self, changed_fields, base_key, data):
        """Inspect nested data for changed fields

        :param changed_fields: Previously collected changed fields
        :param base_key: The base key that must be used to prepend changes to this data
        :param data: data to inspect for changes
        """
        # Loop list / dict fields as they contain documents
        # Determine the iterator to use
        if not hasattr(data, 'items'):
            iterator = enumerate(data)
        else:
            iterator = data.iteritems()

        for index_or_key, value in iterator:
            item_key = '%s%s.' % (base_key, index_or_key)
            # don't check anything lower if this key is already marked
            # as changed.
            if item_key[:-1] in changed_fields:
                continue

            if hasattr(value, '_get_changed_fields'):
                changed = value._get_changed_fields()
                changed_fields += ['%s%s' % (item_key, k) for k in changed if k]
            elif isinstance(value, (list, tuple, dict)):
                self._nestable_types_changed_fields(
                    changed_fields, item_key, value)

    def _get_changed_fields(self):
        """Return a list of all fields that have explicitly been changed.
        """
        EmbeddedDocument = _import_class('EmbeddedDocument')
        ReferenceField = _import_class('ReferenceField')
        GenericReferenceField = _import_class('GenericReferenceField')
        SortedListField = _import_class('SortedListField')

        changed_fields = []
        changed_fields += getattr(self, '_changed_fields', [])

        for field_name in self._fields_ordered:
            db_field_name = self._db_field_map.get(field_name, field_name)
            key = '%s.' % db_field_name
            data = self._data.get(field_name, None)
            field = self._fields.get(field_name)

            if db_field_name in changed_fields:
                # Whole field already marked as changed, no need to go further
                continue

            if isinstance(field, ReferenceField):  # Don't follow referenced documents
                continue

            if isinstance(data, EmbeddedDocument):
                # Find all embedded fields that have been changed
                changed = data._get_changed_fields()
                changed_fields += ['%s%s' % (key, k) for k in changed if k]
            elif isinstance(data, (list, tuple, dict)):
                if (hasattr(field, 'field') and
                        isinstance(field.field, (ReferenceField, GenericReferenceField))):
                    continue
                elif isinstance(field, SortedListField) and field._ordering:
                    # if ordering is affected whole list is changed
                    if any(field._ordering in d._changed_fields for d in data):
                        changed_fields.append(db_field_name)
                        continue

                self._nestable_types_changed_fields(
                    changed_fields, key, data)

        return changed_fields

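    # Illustrative sketch (hypothetical `User` document loaded from the
    # database): mutating a field records its db field name via
    # `_mark_as_changed`, and `_get_changed_fields` additionally walks
    # embedded documents and container fields.
    #
    #     user = User.objects.first()
    #     user._get_changed_fields()   # []
    #     user.name = 'Grace'
    #     user._get_changed_fields()   # ['name'] (the field's db field name)
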
    def _delta(self):
        """Returns the delta (set, unset) of the changes for a document.
        Gets any values that have been explicitly changed.
        """
        # Handles cases where not loaded from_son but has _id
        doc = self.to_mongo()

        set_fields = self._get_changed_fields()
        unset_data = {}
        parts = []

        if hasattr(self, '_changed_fields'):
            set_data = {}
            # Fetch each set item from its path
            for path in set_fields:
                parts = path.split('.')
                d = doc
                new_path = []
                for p in parts:
                    if isinstance(d, (ObjectId, DBRef)):
                        break
                    elif isinstance(d, list) and p.lstrip('-').isdigit():
                        if p[0] == '-':
                            p = str(len(d) + int(p))
                        try:
                            d = d[int(p)]
                        except IndexError:
                            d = None
                    elif hasattr(d, 'get'):
                        d = d.get(p)
                    new_path.append(p)
                path = '.'.join(new_path)
                set_data[path] = d
        else:
            set_data = doc
            if '_id' in set_data:
                del set_data['_id']

        # Determine if any changed items were actually unset.
        for path, value in set_data.items():
            if value or isinstance(value, (numbers.Number, bool)):
                continue

            # If we've set a value that isn't the default value, don't unset it.
            default = None
            if (self._dynamic and len(parts) and parts[0] in
                    self._dynamic_fields):
                del set_data[path]
                unset_data[path] = 1
                continue
            elif path in self._fields:
                default = self._fields[path].default
            else:  # Perform a full lookup for lists / embedded lookups
                d = self
                parts = path.split('.')
                db_field_name = parts.pop()
                for p in parts:
                    if isinstance(d, list) and p.lstrip('-').isdigit():
                        if p[0] == '-':
                            p = str(len(d) + int(p))
                        d = d[int(p)]
                    elif (hasattr(d, '__getattribute__') and
                          not isinstance(d, dict)):
                        real_path = d._reverse_db_field_map.get(p, p)
                        d = getattr(d, real_path)
                    else:
                        d = d.get(p)

                if hasattr(d, '_fields'):
                    field_name = d._reverse_db_field_map.get(db_field_name,
                                                             db_field_name)
                    if field_name in d._fields:
                        default = d._fields.get(field_name).default
                    else:
                        default = None

            if default is not None:
                if callable(default):
                    default = default()

            if default != value:
                continue

            del set_data[path]
            unset_data[path] = 1

        return set_data, unset_data

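    # Illustrative sketch (continuing the hypothetical `User` example, and
    # assuming `email` has no default): `_delta` is what `.save()` uses to
    # build its update statement.
    #
    #     user.name = 'Grace'
    #     user.email = None
    #     user._delta()   # ({'name': 'Grace'}, {'email': 1})
    #                     # i.e. ($set payload, $unset payload)
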
    @classmethod
    def _get_collection_name(cls):
        """Return the collection name for this class. None for abstract
        class.
        """
        return cls._meta.get('collection', None)

    @classmethod
    def _from_son(cls, son, _auto_dereference=True, only_fields=None, created=False):
        """Create an instance of a Document (subclass) from a PyMongo
        SON.
        """
        if not only_fields:
            only_fields = []

        if son and not isinstance(son, dict):
            raise ValueError("The source SON object needs to be of type 'dict'")

        # Get the class name from the document, falling back to the given
        # class if unavailable
        class_name = son.get('_cls', cls._class_name)

        # Convert SON to a data dict, making sure each key is a string and
        # corresponds to the right db field.
        data = {}
        for key, value in son.iteritems():
            key = str(key)
            key = cls._db_field_map.get(key, key)
            data[key] = value

        # Return correct subclass for document type
        if class_name != cls._class_name:
            cls = get_document(class_name)

        changed_fields = []
        errors_dict = {}

        fields = cls._fields
        if not _auto_dereference:
            fields = copy.deepcopy(fields)

        for field_name, field in fields.iteritems():
            field._auto_dereference = _auto_dereference
            if field.db_field in data:
                value = data[field.db_field]
                try:
                    data[field_name] = (value if value is None
                                        else field.to_python(value))
                    if field_name != field.db_field:
                        del data[field.db_field]
                except (AttributeError, ValueError) as e:
                    errors_dict[field_name] = e

        if errors_dict:
            errors = '\n'.join(['%s - %s' % (k, v)
                                for k, v in errors_dict.items()])
            msg = ('Invalid data to create a `%s` instance.\n%s'
                   % (cls._class_name, errors))
            raise InvalidDocumentError(msg)

        # In STRICT documents, remove any keys that aren't in cls._fields
        if cls.STRICT:
            data = {k: v for k, v in data.iteritems() if k in cls._fields}

        obj = cls(__auto_convert=False, _created=created,
                  __only_fields=only_fields, **data)
        obj._changed_fields = changed_fields
        if not _auto_dereference:
            obj._fields = fields

        return obj

    @classmethod
    def _build_index_specs(cls, meta_indexes):
        """Generate and merge the full index specs."""
        geo_indices = cls._geo_indices()
        unique_indices = cls._unique_with_indexes()
        index_specs = [cls._build_index_spec(spec) for spec in meta_indexes]

        def merge_index_specs(index_specs, indices):
            """Helper method for merging index specs."""
            if not indices:
                return index_specs

            # Create a map of index fields to index spec. We're converting
            # the fields from a list to a tuple so that it's hashable.
            spec_fields = {
                tuple(index['fields']): index for index in index_specs
            }

            # For each new index, if there's an existing index with the same
            # fields list, update the existing spec with all data from the
            # new spec.
            for new_index in indices:
                candidate = spec_fields.get(tuple(new_index['fields']))
                if candidate is None:
                    index_specs.append(new_index)
                else:
                    candidate.update(new_index)

            return index_specs

        # Merge geo indexes and unique_with indexes into the meta index specs.
        index_specs = merge_index_specs(index_specs, geo_indices)
        index_specs = merge_index_specs(index_specs, unique_indices)
        return index_specs

    @classmethod
    def _build_index_spec(cls, spec):
        """Build a PyMongo index spec from a MongoEngine index spec."""
        if isinstance(spec, six.string_types):
            spec = {'fields': [spec]}
        elif isinstance(spec, (list, tuple)):
            spec = {'fields': list(spec)}
        elif isinstance(spec, dict):
            spec = dict(spec)

        index_list = []
        direction = None

        # Check to see if we need to include _cls
        allow_inheritance = cls._meta.get('allow_inheritance')
        include_cls = (
            allow_inheritance and
            not spec.get('sparse', False) and
            spec.get('cls', True) and
            '_cls' not in spec['fields']
        )

        # 733: don't include cls if index_cls is False unless there is an
        # explicit cls with the index
        include_cls = include_cls and (spec.get('cls', False) or
                                       cls._meta.get('index_cls', True))
        if 'cls' in spec:
            spec.pop('cls')

        for key in spec['fields']:
            # If inherited spec continue
            if isinstance(key, (list, tuple)):
                continue

            # ASCENDING from +
            # DESCENDING from -
            # TEXT from $
            # HASHED from #
            # GEOSPHERE from (
            # GEOHAYSTACK from )
            # GEO2D from *
            direction = pymongo.ASCENDING
            if key.startswith('-'):
                direction = pymongo.DESCENDING
            elif key.startswith('$'):
                direction = pymongo.TEXT
            elif key.startswith('#'):
                direction = pymongo.HASHED
            elif key.startswith('('):
                direction = pymongo.GEOSPHERE
            elif key.startswith(')'):
                direction = pymongo.GEOHAYSTACK
            elif key.startswith('*'):
                direction = pymongo.GEO2D

            if key.startswith(('+', '-', '*', '$', '#', '(', ')')):
                key = key[1:]

            # Use real field name, do it manually because we need field
            # objects for the next part (list field checking)
            parts = key.split('.')
            if parts in (['pk'], ['id'], ['_id']):
                key = '_id'
            else:
                fields = cls._lookup_field(parts)
                parts = []
                for field in fields:
                    try:
                        if field != '_id':
                            field = field.db_field
                    except AttributeError:
                        pass
                    parts.append(field)
                key = '.'.join(parts)
            index_list.append((key, direction))

        # Don't add cls to a geo index
        if include_cls and direction not in (
                pymongo.GEO2D, pymongo.GEOHAYSTACK, pymongo.GEOSPHERE):
            index_list.insert(0, ('_cls', 1))

        if index_list:
            spec['fields'] = index_list

        return spec

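    # Illustrative sketch of the index-spec shorthand handled above (the
    # prefixes come straight from the mapping in the comments). A hypothetical
    # document declaring
    #
    #     meta = {'indexes': ['title', '-created', '$body', '#slug']}
    #
    # would have each spec expanded into PyMongo tuples such as
    # [('title', pymongo.ASCENDING)], [('created', pymongo.DESCENDING)],
    # [('body', pymongo.TEXT)] and [('slug', pymongo.HASHED)], with
    # ('_cls', 1) prepended only when inheritance is allowed and the index
    # is not a geo index.
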
    @classmethod
    def _unique_with_indexes(cls, namespace=''):
        """Find unique indexes in the document schema and return them."""
        unique_indexes = []
        for field_name, field in cls._fields.items():
            sparse = field.sparse

            # Generate a list of indexes needed by uniqueness constraints
            if field.unique:
                unique_fields = [field.db_field]

                # Add any unique_with fields to the back of the index spec
                if field.unique_with:
                    if isinstance(field.unique_with, six.string_types):
                        field.unique_with = [field.unique_with]

                    # Convert unique_with field names to real field names
                    unique_with = []
                    for other_name in field.unique_with:
                        parts = other_name.split('.')

                        # Lookup real name
                        parts = cls._lookup_field(parts)
                        name_parts = [part.db_field for part in parts]
                        unique_with.append('.'.join(name_parts))

                        # Unique field should be required
                        parts[-1].required = True
                        sparse = (not sparse and
                                  parts[-1].name not in cls.__dict__)

                    unique_fields += unique_with

                # Add the new index to the list
                fields = [
                    ('%s%s' % (namespace, f), pymongo.ASCENDING)
                    for f in unique_fields
                ]
                index = {'fields': fields, 'unique': True, 'sparse': sparse}
                unique_indexes.append(index)

            if field.__class__.__name__ == 'ListField':
                field = field.field

            # Grab any embedded document field unique indexes
            if (field.__class__.__name__ == 'EmbeddedDocumentField' and
                    field.document_type != cls):
                field_namespace = '%s.' % field_name
                doc_cls = field.document_type
                unique_indexes += doc_cls._unique_with_indexes(field_namespace)

        return unique_indexes

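    # Illustrative sketch: a field declared on a hypothetical document as
    #
    #     username = StringField(unique=True, unique_with='tenant')
    #
    # yields a spec along the lines of
    # {'fields': [('username', 1), ('tenant', 1)], 'unique': True, ...},
    # which _build_index_specs then merges with the meta-declared indexes.
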
    @classmethod
    def _geo_indices(cls, inspected=None, parent_field=None):
        inspected = inspected or []
        geo_indices = []
        inspected.append(cls)

        geo_field_type_names = ('EmbeddedDocumentField', 'GeoPointField',
                                'PointField', 'LineStringField',
                                'PolygonField')

        geo_field_types = tuple([_import_class(field)
                                 for field in geo_field_type_names])

        for field in cls._fields.values():
            if not isinstance(field, geo_field_types):
                continue

            if hasattr(field, 'document_type'):
                field_cls = field.document_type
                if field_cls in inspected:
                    continue

                if hasattr(field_cls, '_geo_indices'):
                    geo_indices += field_cls._geo_indices(
                        inspected, parent_field=field.db_field)
            elif field._geo_index:
                field_name = field.db_field
                if parent_field:
                    field_name = '%s.%s' % (parent_field, field_name)

                geo_indices.append({
                    'fields': [(field_name, field._geo_index)]
                })

        return geo_indices

    @classmethod
    def _lookup_field(cls, parts):
        """Given the path to a given field, return a list containing
        the Field object associated with that field and all of its parent
        Field objects.

        Args:
            parts (str, list, or tuple) - path to the field. Should be a
            string for simple fields existing on this document or a list
            of strings for a field that exists deeper in embedded documents.

        Returns:
            A list of Field instances for fields that were found or
            strings for sub-fields that weren't.

        Example:
            >>> user._lookup_field('name')
            [<mongoengine.fields.StringField at 0x1119bff50>]

            >>> user._lookup_field('roles')
            [<mongoengine.fields.EmbeddedDocumentListField at 0x1119ec250>]

            >>> user._lookup_field(['roles', 'role'])
            [<mongoengine.fields.EmbeddedDocumentListField at 0x1119ec250>,
             <mongoengine.fields.StringField at 0x1119ec050>]

            >>> user._lookup_field('doesnt_exist')
            raises LookUpError

            >>> user._lookup_field(['roles', 'doesnt_exist'])
            [<mongoengine.fields.EmbeddedDocumentListField at 0x1119ec250>,
             'doesnt_exist']
        """
        # TODO this method is WAY too complicated. Simplify it.
        # TODO don't think returning a string for embedded non-existent fields is desired

        ListField = _import_class('ListField')
        DynamicField = _import_class('DynamicField')

        if not isinstance(parts, (list, tuple)):
            parts = [parts]

        fields = []
        field = None

        for field_name in parts:
            # Handle ListField indexing:
            if field_name.isdigit() and isinstance(field, ListField):
                fields.append(field_name)
                continue

            # Look up first field from the document
            if field is None:
                if field_name == 'pk':
                    # Deal with "primary key" alias
                    field_name = cls._meta['id_field']

                if field_name in cls._fields:
                    field = cls._fields[field_name]
                elif cls._dynamic:
                    field = DynamicField(db_field=field_name)
                elif cls._meta.get('allow_inheritance') or cls._meta.get('abstract', False):
                    # 744: in case the field is defined in a subclass
                    for subcls in cls.__subclasses__():
                        try:
                            field = subcls._lookup_field([field_name])[0]
                        except LookUpError:
                            continue

                        if field is not None:
                            break
                    else:
                        raise LookUpError('Cannot resolve field "%s"' % field_name)
                else:
                    raise LookUpError('Cannot resolve field "%s"' % field_name)
            else:
                ReferenceField = _import_class('ReferenceField')
                GenericReferenceField = _import_class('GenericReferenceField')

                # If previous field was a reference, throw an error (we
                # cannot look up fields that are on references).
                if isinstance(field, (ReferenceField, GenericReferenceField)):
                    raise LookUpError('Cannot perform join in mongoDB: %s' %
                                      '__'.join(parts))

                # If the parent field has a "field" attribute which has a
                # lookup_member method, call it to find the field
                # corresponding to this iteration.
                if hasattr(getattr(field, 'field', None), 'lookup_member'):
                    new_field = field.field.lookup_member(field_name)

                # If the parent field is a DynamicField or if it's part of
                # a DynamicDocument, mark current field as a DynamicField
                # with db_name equal to the field name.
                elif cls._dynamic and (isinstance(field, DynamicField) or
                                       getattr(getattr(field, 'document_type', None), '_dynamic', None)):
                    new_field = DynamicField(db_field=field_name)

                # Else, try to use the parent field's lookup_member method
                # to find the subfield.
                elif hasattr(field, 'lookup_member'):
                    new_field = field.lookup_member(field_name)

                # Raise a LookUpError if all the other conditions failed.
                else:
                    raise LookUpError(
                        'Cannot resolve subfield or operator {} '
                        'on the field {}'.format(field_name, field.name)
                    )

                # If current field still wasn't found and the parent field
                # is a ComplexBaseField, add the current field name and
                # move on.
                if not new_field and isinstance(field, ComplexBaseField):
                    fields.append(field_name)
                    continue
                elif not new_field:
                    raise LookUpError('Cannot resolve field "%s"' % field_name)

                field = new_field  # update field to the new field type

            fields.append(field)

        return fields

    @classmethod
    def _translate_field_name(cls, field, sep='.'):
        """Translate a field attribute name to a database field name.
        """
        parts = field.split(sep)
        parts = [f.db_field for f in cls._lookup_field(parts)]
        return '.'.join(parts)

    def __set_field_display(self):
        """For each field that specifies choices, create a
        get_<field>_display method.
        """
        fields_with_choices = [(n, f) for n, f in self._fields.items()
                               if f.choices]
        for attr_name, field in fields_with_choices:
            setattr(self,
                    'get_%s_display' % attr_name,
                    partial(self.__get_field_display, field=field))

    def __get_field_display(self, field):
        """Return the display value for a choice field"""
        value = getattr(self, field.name)
        if field.choices and isinstance(field.choices[0], (list, tuple)):
            if value is None:
                return None
            sep = getattr(field, 'display_sep', ' ')
            values = (value if field.__class__.__name__ in
                      ('ListField', 'SortedListField') else [value])
            return sep.join([
                six.text_type(dict(field.choices).get(val, val))
                for val in values or []])
        return value
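
    # Illustrative sketch (hypothetical `User` document with
    # `role = StringField(choices=(('adm', 'Admin'), ('usr', 'User')))`):
    # a `get_role_display` method is attached per instance at init time.
    #
    #     user = User(role='adm')
    #     user.get_role_display()   # 'Admin'
    #
    # For flat (non-pair) choices, the raw value is returned unchanged.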