dereference.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253
  1. from bson import DBRef, SON
  2. import six
  3. from mongoengine.base import (BaseDict, BaseList, EmbeddedDocumentList,
  4. TopLevelDocumentMetaclass, get_document)
  5. from mongoengine.base.datastructures import LazyReference
  6. from mongoengine.connection import get_db
  7. from mongoengine.document import Document, EmbeddedDocument
  8. from mongoengine.fields import DictField, ListField, MapField, ReferenceField
  9. from mongoengine.queryset import QuerySet
  10. class DeReference(object):
  11. def __call__(self, items, max_depth=1, instance=None, name=None):
  12. """
  13. Cheaply dereferences the items to a set depth.
  14. Also handles the conversion of complex data types.
  15. :param items: The iterable (dict, list, queryset) to be dereferenced.
  16. :param max_depth: The maximum depth to recurse to
  17. :param instance: The owning instance used for tracking changes by
  18. :class:`~mongoengine.base.ComplexBaseField`
  19. :param name: The name of the field, used for tracking changes by
  20. :class:`~mongoengine.base.ComplexBaseField`
  21. :param get: A boolean determining if being called by __get__
  22. """
  23. if items is None or isinstance(items, six.string_types):
  24. return items
  25. # cheapest way to convert a queryset to a list
  26. # list(queryset) uses a count() query to determine length
  27. if isinstance(items, QuerySet):
  28. items = [i for i in items]
  29. self.max_depth = max_depth
  30. doc_type = None
  31. if instance and isinstance(instance, (Document, EmbeddedDocument,
  32. TopLevelDocumentMetaclass)):
  33. doc_type = instance._fields.get(name)
  34. while hasattr(doc_type, 'field'):
  35. doc_type = doc_type.field
  36. if isinstance(doc_type, ReferenceField):
  37. field = doc_type
  38. doc_type = doc_type.document_type
  39. is_list = not hasattr(items, 'items')
  40. if is_list and all([i.__class__ == doc_type for i in items]):
  41. return items
  42. elif not is_list and all(
  43. [i.__class__ == doc_type for i in items.values()]):
  44. return items
  45. elif not field.dbref:
  46. if not hasattr(items, 'items'):
  47. def _get_items(items):
  48. new_items = []
  49. for v in items:
  50. if isinstance(v, list):
  51. new_items.append(_get_items(v))
  52. elif not isinstance(v, (DBRef, Document)):
  53. new_items.append(field.to_python(v))
  54. else:
  55. new_items.append(v)
  56. return new_items
  57. items = _get_items(items)
  58. else:
  59. items = {
  60. k: (v if isinstance(v, (DBRef, Document))
  61. else field.to_python(v))
  62. for k, v in items.iteritems()
  63. }
  64. self.reference_map = self._find_references(items)
  65. self.object_map = self._fetch_objects(doc_type=doc_type)
  66. return self._attach_objects(items, 0, instance, name)
  67. def _find_references(self, items, depth=0):
  68. """
  69. Recursively finds all db references to be dereferenced
  70. :param items: The iterable (dict, list, queryset)
  71. :param depth: The current depth of recursion
  72. """
  73. reference_map = {}
  74. if not items or depth >= self.max_depth:
  75. return reference_map
  76. # Determine the iterator to use
  77. if isinstance(items, dict):
  78. iterator = items.values()
  79. else:
  80. iterator = items
  81. # Recursively find dbreferences
  82. depth += 1
  83. for item in iterator:
  84. if isinstance(item, (Document, EmbeddedDocument)):
  85. for field_name, field in item._fields.iteritems():
  86. v = item._data.get(field_name, None)
  87. if isinstance(v, LazyReference):
  88. # LazyReference inherits DBRef but should not be dereferenced here !
  89. continue
  90. elif isinstance(v, DBRef):
  91. reference_map.setdefault(field.document_type, set()).add(v.id)
  92. elif isinstance(v, (dict, SON)) and '_ref' in v:
  93. reference_map.setdefault(get_document(v['_cls']), set()).add(v['_ref'].id)
  94. elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth:
  95. field_cls = getattr(getattr(field, 'field', None), 'document_type', None)
  96. references = self._find_references(v, depth)
  97. for key, refs in references.iteritems():
  98. if isinstance(field_cls, (Document, TopLevelDocumentMetaclass)):
  99. key = field_cls
  100. reference_map.setdefault(key, set()).update(refs)
  101. elif isinstance(item, LazyReference):
  102. # LazyReference inherits DBRef but should not be dereferenced here !
  103. continue
  104. elif isinstance(item, DBRef):
  105. reference_map.setdefault(item.collection, set()).add(item.id)
  106. elif isinstance(item, (dict, SON)) and '_ref' in item:
  107. reference_map.setdefault(get_document(item['_cls']), set()).add(item['_ref'].id)
  108. elif isinstance(item, (dict, list, tuple)) and depth - 1 <= self.max_depth:
  109. references = self._find_references(item, depth - 1)
  110. for key, refs in references.iteritems():
  111. reference_map.setdefault(key, set()).update(refs)
  112. return reference_map
  113. def _fetch_objects(self, doc_type=None):
  114. """Fetch all references and convert to their document objects
  115. """
  116. object_map = {}
  117. for collection, dbrefs in self.reference_map.iteritems():
  118. # we use getattr instead of hasattr because hasattr swallows any exception under python2
  119. # so it could hide nasty things without raising exceptions (cfr bug #1688))
  120. ref_document_cls_exists = (getattr(collection, 'objects', None) is not None)
  121. if ref_document_cls_exists:
  122. col_name = collection._get_collection_name()
  123. refs = [dbref for dbref in dbrefs
  124. if (col_name, dbref) not in object_map]
  125. references = collection.objects.in_bulk(refs)
  126. for key, doc in references.iteritems():
  127. object_map[(col_name, key)] = doc
  128. else: # Generic reference: use the refs data to convert to document
  129. if isinstance(doc_type, (ListField, DictField, MapField)):
  130. continue
  131. refs = [dbref for dbref in dbrefs
  132. if (collection, dbref) not in object_map]
  133. if doc_type:
  134. references = doc_type._get_db()[collection].find({'_id': {'$in': refs}})
  135. for ref in references:
  136. doc = doc_type._from_son(ref)
  137. object_map[(collection, doc.id)] = doc
  138. else:
  139. references = get_db()[collection].find({'_id': {'$in': refs}})
  140. for ref in references:
  141. if '_cls' in ref:
  142. doc = get_document(ref['_cls'])._from_son(ref)
  143. elif doc_type is None:
  144. doc = get_document(
  145. ''.join(x.capitalize()
  146. for x in collection.split('_')))._from_son(ref)
  147. else:
  148. doc = doc_type._from_son(ref)
  149. object_map[(collection, doc.id)] = doc
  150. return object_map
  151. def _attach_objects(self, items, depth=0, instance=None, name=None):
  152. """
  153. Recursively finds all db references to be dereferenced
  154. :param items: The iterable (dict, list, queryset)
  155. :param depth: The current depth of recursion
  156. :param instance: The owning instance used for tracking changes by
  157. :class:`~mongoengine.base.ComplexBaseField`
  158. :param name: The name of the field, used for tracking changes by
  159. :class:`~mongoengine.base.ComplexBaseField`
  160. """
  161. if not items:
  162. if isinstance(items, (BaseDict, BaseList)):
  163. return items
  164. if instance:
  165. if isinstance(items, dict):
  166. return BaseDict(items, instance, name)
  167. else:
  168. return BaseList(items, instance, name)
  169. if isinstance(items, (dict, SON)):
  170. if '_ref' in items:
  171. return self.object_map.get(
  172. (items['_ref'].collection, items['_ref'].id), items)
  173. elif '_cls' in items:
  174. doc = get_document(items['_cls'])._from_son(items)
  175. _cls = doc._data.pop('_cls', None)
  176. del items['_cls']
  177. doc._data = self._attach_objects(doc._data, depth, doc, None)
  178. if _cls is not None:
  179. doc._data['_cls'] = _cls
  180. return doc
  181. if not hasattr(items, 'items'):
  182. is_list = True
  183. list_type = BaseList
  184. if isinstance(items, EmbeddedDocumentList):
  185. list_type = EmbeddedDocumentList
  186. as_tuple = isinstance(items, tuple)
  187. iterator = enumerate(items)
  188. data = []
  189. else:
  190. is_list = False
  191. iterator = items.iteritems()
  192. data = {}
  193. depth += 1
  194. for k, v in iterator:
  195. if is_list:
  196. data.append(v)
  197. else:
  198. data[k] = v
  199. if k in self.object_map and not is_list:
  200. data[k] = self.object_map[k]
  201. elif isinstance(v, (Document, EmbeddedDocument)):
  202. for field_name in v._fields:
  203. v = data[k]._data.get(field_name, None)
  204. if isinstance(v, DBRef):
  205. data[k]._data[field_name] = self.object_map.get(
  206. (v.collection, v.id), v)
  207. elif isinstance(v, (dict, SON)) and '_ref' in v:
  208. data[k]._data[field_name] = self.object_map.get(
  209. (v['_ref'].collection, v['_ref'].id), v)
  210. elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth:
  211. item_name = six.text_type('{0}.{1}.{2}').format(name, k, field_name)
  212. data[k]._data[field_name] = self._attach_objects(v, depth, instance=instance, name=item_name)
  213. elif isinstance(v, (dict, list, tuple)) and depth <= self.max_depth:
  214. item_name = '%s.%s' % (name, k) if name else name
  215. data[k] = self._attach_objects(v, depth - 1, instance=instance, name=item_name)
  216. elif isinstance(v, DBRef) and hasattr(v, 'id'):
  217. data[k] = self.object_map.get((v.collection, v.id), v)
  218. if instance and name:
  219. if is_list:
  220. return tuple(data) if as_tuple else list_type(data, instance, name)
  221. return BaseDict(data, instance, name)
  222. depth += 1
  223. return data