aggregation.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. # Copyright 2019-present MongoDB, Inc.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License"); you
  4. # may not use this file except in compliance with the License. You
  5. # may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  12. # implied. See the License for the specific language governing
  13. # permissions and limitations under the License.
  14. """Perform aggregation operations on a collection or database."""
  15. from bson.son import SON
  16. from pymongo import common
  17. from pymongo.collation import validate_collation_or_none
  18. from pymongo.errors import ConfigurationError
  19. from pymongo.read_preferences import ReadPreference
  20. class _AggregationCommand(object):
  21. """The internal abstract base class for aggregation cursors.
  22. Should not be called directly by application developers. Use
  23. :meth:`pymongo.collection.Collection.aggregate`, or
  24. :meth:`pymongo.database.Database.aggregate` instead.
  25. """
  26. def __init__(self, target, cursor_class, pipeline, options,
  27. explicit_session, user_fields=None, result_processor=None):
  28. if "explain" in options:
  29. raise ConfigurationError("The explain option is not supported. "
  30. "Use Database.command instead.")
  31. self._target = target
  32. common.validate_list('pipeline', pipeline)
  33. self._pipeline = pipeline
  34. self._performs_write = False
  35. if pipeline and ("$out" in pipeline[-1] or "$merge" in pipeline[-1]):
  36. self._performs_write = True
  37. common.validate_is_mapping('options', options)
  38. self._options = options
  39. # This is the batchSize that will be used for setting the initial
  40. # batchSize for the cursor, as well as the subsequent getMores.
  41. self._batch_size = common.validate_non_negative_integer_or_none(
  42. "batchSize", self._options.pop("batchSize", None))
  43. # If the cursor option is already specified, avoid overriding it.
  44. self._options.setdefault("cursor", {})
  45. # If the pipeline performs a write, we ignore the initial batchSize
  46. # since the server doesn't return results in this case.
  47. if self._batch_size is not None and not self._performs_write:
  48. self._options["cursor"]["batchSize"] = self._batch_size
  49. self._cursor_class = cursor_class
  50. self._explicit_session = explicit_session
  51. self._user_fields = user_fields
  52. self._result_processor = result_processor
  53. self._collation = validate_collation_or_none(
  54. options.pop('collation', None))
  55. self._max_await_time_ms = options.pop('maxAwaitTimeMS', None)
  56. @property
  57. def _aggregation_target(self):
  58. """The argument to pass to the aggregate command."""
  59. raise NotImplementedError
  60. @property
  61. def _cursor_namespace(self):
  62. """The namespace in which the aggregate command is run."""
  63. raise NotImplementedError
  64. @property
  65. def _cursor_collection(self, cursor_doc):
  66. """The Collection used for the aggregate command cursor."""
  67. raise NotImplementedError
  68. @property
  69. def _database(self):
  70. """The database against which the aggregation command is run."""
  71. raise NotImplementedError
  72. @staticmethod
  73. def _check_compat(sock_info):
  74. """Check whether the server version in-use supports aggregation."""
  75. pass
  76. def _process_result(
  77. self, result, session, server, sock_info, secondary_ok):
  78. if self._result_processor:
  79. self._result_processor(
  80. result, session, server, sock_info, secondary_ok)
  81. def get_read_preference(self, session):
  82. if self._performs_write:
  83. return ReadPreference.PRIMARY
  84. return self._target._read_preference_for(session)
  85. def get_cursor(self, session, server, sock_info, secondary_ok):
  86. # Ensure command compatibility.
  87. self._check_compat(sock_info)
  88. # Serialize command.
  89. cmd = SON([("aggregate", self._aggregation_target),
  90. ("pipeline", self._pipeline)])
  91. cmd.update(self._options)
  92. # Apply this target's read concern if:
  93. # readConcern has not been specified as a kwarg and either
  94. # - server version is >= 4.2 or
  95. # - server version is >= 3.2 and pipeline doesn't use $out
  96. if (('readConcern' not in cmd) and
  97. ((sock_info.max_wire_version >= 4 and
  98. not self._performs_write) or
  99. (sock_info.max_wire_version >= 8))):
  100. read_concern = self._target.read_concern
  101. else:
  102. read_concern = None
  103. # Apply this target's write concern if:
  104. # writeConcern has not been specified as a kwarg and pipeline doesn't
  105. # perform a write operation
  106. if 'writeConcern' not in cmd and self._performs_write:
  107. write_concern = self._target._write_concern_for(session)
  108. else:
  109. write_concern = None
  110. # Run command.
  111. result = sock_info.command(
  112. self._database.name,
  113. cmd,
  114. secondary_ok,
  115. self.get_read_preference(session),
  116. self._target.codec_options,
  117. parse_write_concern_error=True,
  118. read_concern=read_concern,
  119. write_concern=write_concern,
  120. collation=self._collation,
  121. session=session,
  122. client=self._database.client,
  123. user_fields=self._user_fields)
  124. self._process_result(result, session, server, sock_info, secondary_ok)
  125. # Extract cursor from result or mock/fake one if necessary.
  126. if 'cursor' in result:
  127. cursor = result['cursor']
  128. else:
  129. # Pre-MongoDB 2.6 or unacknowledged write. Fake a cursor.
  130. cursor = {
  131. "id": 0,
  132. "firstBatch": result.get("result", []),
  133. "ns": self._cursor_namespace,
  134. }
  135. # Create and return cursor instance.
  136. cmd_cursor = self._cursor_class(
  137. self._cursor_collection(cursor), cursor, sock_info.address,
  138. batch_size=self._batch_size or 0,
  139. max_await_time_ms=self._max_await_time_ms,
  140. session=session, explicit_session=self._explicit_session)
  141. cmd_cursor._maybe_pin_connection(sock_info)
  142. return cmd_cursor
  class _CollectionAggregationCommand(_AggregationCommand):
      """Aggregation command whose target is a collection."""

      def __init__(self, *args, **kwargs):
          """Initialize the command, honouring the extra ``use_cursor`` kwarg.

          ``use_cursor`` (default ``True``) is removed from `kwargs` before
          the remaining arguments are handed to the parent initializer.
          """
          wants_cursor = kwargs.pop("use_cursor", True)
          super(_CollectionAggregationCommand, self).__init__(*args, **kwargs)
          self._use_cursor = wants_cursor
          if wants_cursor:
              return
          # The caller opted out of cursors: drop the cursor document from
          # the command options.
          self._options.pop("cursor", None)

      @property
      def _aggregation_target(self):
          """The target collection's name, sent as the aggregate argument."""
          collection = self._target
          return collection.name

      @property
      def _cursor_namespace(self):
          """The target collection's full name."""
          collection = self._target
          return collection.full_name

      def _cursor_collection(self, cursor):
          """The Collection used for the aggregate command cursor."""
          # The cursor document is not consulted; the target itself is used.
          return self._target

      @property
      def _database(self):
          """The database that the target collection belongs to."""
          collection = self._target
          return collection.database
  class _CollectionRawAggregationCommand(_CollectionAggregationCommand):
      """Collection-level aggregation command used for raw batches."""

      def __init__(self, *args, **kwargs):
          """Initialize as the parent does, then zero the initial batchSize."""
          super(_CollectionRawAggregationCommand, self).__init__(*args, **kwargs)
          if not self._use_cursor or self._performs_write:
              # Either there is no cursor document to adjust, or the
              # pipeline performs a write: leave the options untouched.
              return
          # For raw-batches, we set the initial batchSize for the cursor to 0.
          cursor_options = self._options["cursor"]
          cursor_options["batchSize"] = 0
  class _DatabaseAggregationCommand(_AggregationCommand):
      """Aggregation command whose target is a database."""

      @property
      def _aggregation_target(self):
          """The aggregate argument for this command: always ``1``."""
          return 1

      @property
      def _cursor_namespace(self):
          """The ``<db>.$cmd.aggregate`` namespace of the target database."""
          db_name = self._target.name
          return "%s.$cmd.aggregate" % (db_name,)

      @property
      def _database(self):
          """The target itself, which is the database being aggregated."""
          return self._target

      def _cursor_collection(self, cursor):
          """The Collection used for the aggregate command cursor."""
          # Collection level aggregate may not always return the "ns" field
          # according to our MockupDB tests. Let's handle that case for db
          # level aggregate too by defaulting to the <db>.$cmd.aggregate
          # namespace.
          namespace = cursor.get("ns", self._cursor_namespace)
          _db_part, coll_part = namespace.split(".", 1)
          return self._database[coll_part]

      @staticmethod
      def _check_compat(sock_info):
          """Raise ConfigurationError if max_wire_version is below 6."""
          if sock_info.max_wire_version >= 6:
              return
          # Older server version don't raise a descriptive error, so we raise
          # one instead.
          raise ConfigurationError(
              "Database.aggregate() is only supported on MongoDB 3.6+.")