aggregates.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. """
  2. Classes to represent the default SQL aggregate functions
  3. """
  4. import copy
  5. from django.db.models.fields import IntegerField, FloatField
  6. from django.db.models.lookups import RegisterLookupMixin
  7. from django.utils.functional import cached_property
  8. __all__ = ['Aggregate', 'Avg', 'Count', 'Max', 'Min', 'StdDev', 'Sum', 'Variance']
  9. class Aggregate(RegisterLookupMixin):
  10. """
  11. Default SQL Aggregate.
  12. """
  13. is_ordinal = False
  14. is_computed = False
  15. sql_template = '%(function)s(%(field)s)'
  16. def __init__(self, col, source=None, is_summary=False, **extra):
  17. """Instantiate an SQL aggregate
  18. * col is a column reference describing the subject field
  19. of the aggregate. It can be an alias, or a tuple describing
  20. a table and column name.
  21. * source is the underlying field or aggregate definition for
  22. the column reference. If the aggregate is not an ordinal or
  23. computed type, this reference is used to determine the coerced
  24. output type of the aggregate.
  25. * extra is a dictionary of additional data to provide for the
  26. aggregate definition
  27. Also utilizes the class variables:
  28. * sql_function, the name of the SQL function that implements the
  29. aggregate.
  30. * sql_template, a template string that is used to render the
  31. aggregate into SQL.
  32. * is_ordinal, a boolean indicating if the output of this aggregate
  33. is an integer (e.g., a count)
  34. * is_computed, a boolean indicating if this output of this aggregate
  35. is a computed float (e.g., an average), regardless of the input
  36. type.
  37. """
  38. self.col = col
  39. self.source = source
  40. self.is_summary = is_summary
  41. self.extra = extra
  42. # Follow the chain of aggregate sources back until you find an
  43. # actual field, or an aggregate that forces a particular output
  44. # type. This type of this field will be used to coerce values
  45. # retrieved from the database.
  46. tmp = self
  47. while tmp and isinstance(tmp, Aggregate):
  48. if getattr(tmp, 'is_ordinal', False):
  49. tmp = self._ordinal_aggregate_field
  50. elif getattr(tmp, 'is_computed', False):
  51. tmp = self._computed_aggregate_field
  52. else:
  53. tmp = tmp.source
  54. self.field = tmp
  55. # Two fake fields used to identify aggregate types in data-conversion operations.
  56. @cached_property
  57. def _ordinal_aggregate_field(self):
  58. return IntegerField()
  59. @cached_property
  60. def _computed_aggregate_field(self):
  61. return FloatField()
  62. def relabeled_clone(self, change_map):
  63. clone = copy.copy(self)
  64. if isinstance(self.col, (list, tuple)):
  65. clone.col = (change_map.get(self.col[0], self.col[0]), self.col[1])
  66. return clone
  67. def as_sql(self, qn, connection):
  68. "Return the aggregate, rendered as SQL with parameters."
  69. params = []
  70. if hasattr(self.col, 'as_sql'):
  71. field_name, params = self.col.as_sql(qn, connection)
  72. elif isinstance(self.col, (list, tuple)):
  73. field_name = '.'.join(qn(c) for c in self.col)
  74. else:
  75. field_name = qn(self.col)
  76. substitutions = {
  77. 'function': self.sql_function,
  78. 'field': field_name
  79. }
  80. substitutions.update(self.extra)
  81. return self.sql_template % substitutions, params
  82. def get_group_by_cols(self):
  83. return []
  84. @property
  85. def output_field(self):
  86. return self.field
  87. class Avg(Aggregate):
  88. is_computed = True
  89. sql_function = 'AVG'
  90. class Count(Aggregate):
  91. is_ordinal = True
  92. sql_function = 'COUNT'
  93. sql_template = '%(function)s(%(distinct)s%(field)s)'
  94. def __init__(self, col, distinct=False, **extra):
  95. super(Count, self).__init__(col, distinct='DISTINCT ' if distinct else '', **extra)
  96. class Max(Aggregate):
  97. sql_function = 'MAX'
  98. class Min(Aggregate):
  99. sql_function = 'MIN'
  100. class StdDev(Aggregate):
  101. is_computed = True
  102. def __init__(self, col, sample=False, **extra):
  103. super(StdDev, self).__init__(col, **extra)
  104. self.sql_function = 'STDDEV_SAMP' if sample else 'STDDEV_POP'
  105. class Sum(Aggregate):
  106. sql_function = 'SUM'
  107. class Variance(Aggregate):
  108. is_computed = True
  109. def __init__(self, col, sample=False, **extra):
  110. super(Variance, self).__init__(col, **extra)
  111. self.sql_function = 'VAR_SAMP' if sample else 'VAR_POP'