context_managers.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  1. from contextlib import contextmanager
  2. from pymongo.write_concern import WriteConcern
  3. from mongoengine.common import _import_class
  4. from mongoengine.connection import DEFAULT_CONNECTION_NAME, get_db
  5. __all__ = ('switch_db', 'switch_collection', 'no_dereference',
  6. 'no_sub_classes', 'query_counter', 'set_write_concern')
  7. class switch_db(object):
  8. """switch_db alias context manager.
  9. Example ::
  10. # Register connections
  11. register_connection('default', 'mongoenginetest')
  12. register_connection('testdb-1', 'mongoenginetest2')
  13. class Group(Document):
  14. name = StringField()
  15. Group(name='test').save() # Saves in the default db
  16. with switch_db(Group, 'testdb-1') as Group:
  17. Group(name='hello testdb!').save() # Saves in testdb-1
  18. """
  19. def __init__(self, cls, db_alias):
  20. """Construct the switch_db context manager
  21. :param cls: the class to change the registered db
  22. :param db_alias: the name of the specific database to use
  23. """
  24. self.cls = cls
  25. self.collection = cls._get_collection()
  26. self.db_alias = db_alias
  27. self.ori_db_alias = cls._meta.get('db_alias', DEFAULT_CONNECTION_NAME)
  28. def __enter__(self):
  29. """Change the db_alias and clear the cached collection."""
  30. self.cls._meta['db_alias'] = self.db_alias
  31. self.cls._collection = None
  32. return self.cls
  33. def __exit__(self, t, value, traceback):
  34. """Reset the db_alias and collection."""
  35. self.cls._meta['db_alias'] = self.ori_db_alias
  36. self.cls._collection = self.collection
  37. class switch_collection(object):
  38. """switch_collection alias context manager.
  39. Example ::
  40. class Group(Document):
  41. name = StringField()
  42. Group(name='test').save() # Saves in the default db
  43. with switch_collection(Group, 'group1') as Group:
  44. Group(name='hello testdb!').save() # Saves in group1 collection
  45. """
  46. def __init__(self, cls, collection_name):
  47. """Construct the switch_collection context manager.
  48. :param cls: the class to change the registered db
  49. :param collection_name: the name of the collection to use
  50. """
  51. self.cls = cls
  52. self.ori_collection = cls._get_collection()
  53. self.ori_get_collection_name = cls._get_collection_name
  54. self.collection_name = collection_name
  55. def __enter__(self):
  56. """Change the _get_collection_name and clear the cached collection."""
  57. @classmethod
  58. def _get_collection_name(cls):
  59. return self.collection_name
  60. self.cls._get_collection_name = _get_collection_name
  61. self.cls._collection = None
  62. return self.cls
  63. def __exit__(self, t, value, traceback):
  64. """Reset the collection."""
  65. self.cls._collection = self.ori_collection
  66. self.cls._get_collection_name = self.ori_get_collection_name
  67. class no_dereference(object):
  68. """no_dereference context manager.
  69. Turns off all dereferencing in Documents for the duration of the context
  70. manager::
  71. with no_dereference(Group) as Group:
  72. Group.objects.find()
  73. """
  74. def __init__(self, cls):
  75. """Construct the no_dereference context manager.
  76. :param cls: the class to turn dereferencing off on
  77. """
  78. self.cls = cls
  79. ReferenceField = _import_class('ReferenceField')
  80. GenericReferenceField = _import_class('GenericReferenceField')
  81. ComplexBaseField = _import_class('ComplexBaseField')
  82. self.deref_fields = [k for k, v in self.cls._fields.iteritems()
  83. if isinstance(v, (ReferenceField,
  84. GenericReferenceField,
  85. ComplexBaseField))]
  86. def __enter__(self):
  87. """Change the objects default and _auto_dereference values."""
  88. for field in self.deref_fields:
  89. self.cls._fields[field]._auto_dereference = False
  90. return self.cls
  91. def __exit__(self, t, value, traceback):
  92. """Reset the default and _auto_dereference values."""
  93. for field in self.deref_fields:
  94. self.cls._fields[field]._auto_dereference = True
  95. return self.cls
  96. class no_sub_classes(object):
  97. """no_sub_classes context manager.
  98. Only returns instances of this class and no sub (inherited) classes::
  99. with no_sub_classes(Group) as Group:
  100. Group.objects.find()
  101. """
  102. def __init__(self, cls):
  103. """Construct the no_sub_classes context manager.
  104. :param cls: the class to turn querying sub classes on
  105. """
  106. self.cls = cls
  107. self.cls_initial_subclasses = None
  108. def __enter__(self):
  109. """Change the objects default and _auto_dereference values."""
  110. self.cls_initial_subclasses = self.cls._subclasses
  111. self.cls._subclasses = (self.cls._class_name,)
  112. return self.cls
  113. def __exit__(self, t, value, traceback):
  114. """Reset the default and _auto_dereference values."""
  115. self.cls._subclasses = self.cls_initial_subclasses
  116. class query_counter(object):
  117. """Query_counter context manager to get the number of queries.
  118. This works by updating the `profiling_level` of the database so that all queries get logged,
  119. resetting the db.system.profile collection at the beginnig of the context and counting the new entries.
  120. This was designed for debugging purpose. In fact it is a global counter so queries issued by other threads/processes
  121. can interfere with it
  122. Be aware that:
  123. - Iterating over large amount of documents (>101) makes pymongo issue `getmore` queries to fetch the next batch of
  124. documents (https://docs.mongodb.com/manual/tutorial/iterate-a-cursor/#cursor-batches)
  125. - Some queries are ignored by default by the counter (killcursors, db.system.indexes)
  126. """
  127. def __init__(self):
  128. """Construct the query_counter
  129. """
  130. self.db = get_db()
  131. self.initial_profiling_level = None
  132. self._ctx_query_counter = 0 # number of queries issued by the context
  133. self._ignored_query = {
  134. 'ns':
  135. {'$ne': '%s.system.indexes' % self.db.name},
  136. 'op': # MONGODB < 3.2
  137. {'$ne': 'killcursors'},
  138. 'command.killCursors': # MONGODB >= 3.2
  139. {'$exists': False}
  140. }
  141. def _turn_on_profiling(self):
  142. self.initial_profiling_level = self.db.profiling_level()
  143. self.db.set_profiling_level(0)
  144. self.db.system.profile.drop()
  145. self.db.set_profiling_level(2)
  146. def _resets_profiling(self):
  147. self.db.set_profiling_level(self.initial_profiling_level)
  148. def __enter__(self):
  149. self._turn_on_profiling()
  150. return self
  151. def __exit__(self, t, value, traceback):
  152. self._resets_profiling()
  153. def __eq__(self, value):
  154. counter = self._get_count()
  155. return value == counter
  156. def __ne__(self, value):
  157. return not self.__eq__(value)
  158. def __lt__(self, value):
  159. return self._get_count() < value
  160. def __le__(self, value):
  161. return self._get_count() <= value
  162. def __gt__(self, value):
  163. return self._get_count() > value
  164. def __ge__(self, value):
  165. return self._get_count() >= value
  166. def __int__(self):
  167. return self._get_count()
  168. def __repr__(self):
  169. """repr query_counter as the number of queries."""
  170. return u"%s" % self._get_count()
  171. def _get_count(self):
  172. """Get the number of queries by counting the current number of entries in db.system.profile
  173. and substracting the queries issued by this context. In fact everytime this is called, 1 query is
  174. issued so we need to balance that
  175. """
  176. count = self.db.system.profile.find(self._ignored_query).count() - self._ctx_query_counter
  177. self._ctx_query_counter += 1 # Account for the query we just issued to gather the information
  178. return count
  179. @contextmanager
  180. def set_write_concern(collection, write_concerns):
  181. combined_concerns = dict(collection.write_concern.document.items())
  182. combined_concerns.update(write_concerns)
  183. yield collection.with_options(write_concern=WriteConcern(**combined_concerns))