visitor.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. import copy
  2. from mongoengine.errors import InvalidQueryError
  3. from mongoengine.queryset import transform
  4. __all__ = ('Q', 'QNode')
  5. class QNodeVisitor(object):
  6. """Base visitor class for visiting Q-object nodes in a query tree.
  7. """
  8. def visit_combination(self, combination):
  9. """Called by QCombination objects.
  10. """
  11. return combination
  12. def visit_query(self, query):
  13. """Called by (New)Q objects.
  14. """
  15. return query
  16. class DuplicateQueryConditionsError(InvalidQueryError):
  17. pass
  18. class SimplificationVisitor(QNodeVisitor):
  19. """Simplifies query trees by combining unnecessary 'and' connection nodes
  20. into a single Q-object.
  21. """
  22. def visit_combination(self, combination):
  23. if combination.operation == combination.AND:
  24. # The simplification only applies to 'simple' queries
  25. if all(isinstance(node, Q) for node in combination.children):
  26. queries = [n.query for n in combination.children]
  27. try:
  28. return Q(**self._query_conjunction(queries))
  29. except DuplicateQueryConditionsError:
  30. # Cannot be simplified
  31. pass
  32. return combination
  33. def _query_conjunction(self, queries):
  34. """Merges query dicts - effectively &ing them together.
  35. """
  36. query_ops = set()
  37. combined_query = {}
  38. for query in queries:
  39. ops = set(query.keys())
  40. # Make sure that the same operation isn't applied more than once
  41. # to a single field
  42. intersection = ops.intersection(query_ops)
  43. if intersection:
  44. raise DuplicateQueryConditionsError()
  45. query_ops.update(ops)
  46. combined_query.update(copy.deepcopy(query))
  47. return combined_query
  48. class QueryCompilerVisitor(QNodeVisitor):
  49. """Compiles the nodes in a query tree to a PyMongo-compatible query
  50. dictionary.
  51. """
  52. def __init__(self, document):
  53. self.document = document
  54. def visit_combination(self, combination):
  55. operator = '$and'
  56. if combination.operation == combination.OR:
  57. operator = '$or'
  58. return {operator: combination.children}
  59. def visit_query(self, query):
  60. return transform.query(self.document, **query.query)
  61. class QNode(object):
  62. """Base class for nodes in query trees."""
  63. AND = 0
  64. OR = 1
  65. def to_query(self, document):
  66. query = self.accept(SimplificationVisitor())
  67. query = query.accept(QueryCompilerVisitor(document))
  68. return query
  69. def accept(self, visitor):
  70. raise NotImplementedError
  71. def _combine(self, other, operation):
  72. """Combine this node with another node into a QCombination
  73. object.
  74. """
  75. if getattr(other, 'empty', True):
  76. return self
  77. if self.empty:
  78. return other
  79. return QCombination(operation, [self, other])
  80. @property
  81. def empty(self):
  82. return False
  83. def __or__(self, other):
  84. return self._combine(other, self.OR)
  85. def __and__(self, other):
  86. return self._combine(other, self.AND)
  87. class QCombination(QNode):
  88. """Represents the combination of several conditions by a given
  89. logical operator.
  90. """
  91. def __init__(self, operation, children):
  92. self.operation = operation
  93. self.children = []
  94. for node in children:
  95. # If the child is a combination of the same type, we can merge its
  96. # children directly into this combinations children
  97. if isinstance(node, QCombination) and node.operation == operation:
  98. self.children += node.children
  99. else:
  100. self.children.append(node)
  101. def __repr__(self):
  102. op = ' & ' if self.operation is self.AND else ' | '
  103. return '(%s)' % op.join([repr(node) for node in self.children])
  104. def accept(self, visitor):
  105. for i in range(len(self.children)):
  106. if isinstance(self.children[i], QNode):
  107. self.children[i] = self.children[i].accept(visitor)
  108. return visitor.visit_combination(self)
  109. @property
  110. def empty(self):
  111. return not bool(self.children)
  112. class Q(QNode):
  113. """A simple query object, used in a query tree to build up more complex
  114. query structures.
  115. """
  116. def __init__(self, **query):
  117. self.query = query
  118. def __repr__(self):
  119. return 'Q(**%s)' % repr(self.query)
  120. def accept(self, visitor):
  121. return visitor.visit_query(self)
  122. @property
  123. def empty(self):
  124. return not bool(self.query)