compiler.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. from __future__ import absolute_import, unicode_literals
  2. import re
  3. from django.db.models.sql import compiler
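# SQL Server-specific query compilation. The compiler classes below derive
# from Django's django.db.models.sql.compiler classes and override:
# ...insert queries to add "SET IDENTITY_INSERT" if needed (and to return the
#    new primary key through an OUTPUT clause).
# ...select queries to emulate LIMIT/OFFSET for sliced queries.
# ...update queries so SQL Server reports the affected row count.
#
# NOTE(review): this comment previously described a 'query_class' hook and a
# custom 'SqlServerQuery' class derived from django.db.models.sql.query.Query;
# neither is defined in this module.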
  11. # Pattern to scan a column data type string and split the data type from any
  12. # constraints or other included parts of a column definition. Based upon
  13. # <column_definition> from http://msdn.microsoft.com/en-us/library/ms174979.aspx
  14. _re_data_type_terminator = re.compile(
  15. r'\s*\b(?:' +
  16. r'filestream|collate|sparse|not|null|constraint|default|identity|rowguidcol' +
  17. r'|primary|unique|clustered|nonclustered|with|on|foreign|references|check' +
  18. ')',
  19. re.IGNORECASE,
  20. )
  21. _re_constant = re.compile(r'\s*\(?\s*\d+\s*\)?\s*')
  22. class SQLCompiler(compiler.SQLCompiler):
  23. def as_sql(self, with_limits=True, with_col_aliases=False, subquery=False):
  24. # Get out of the way if we're not a select query or there's no limiting involved.
  25. has_limit_offset = with_limits and (self.query.low_mark or self.query.high_mark is not None)
  26. try:
  27. if not has_limit_offset:
  28. # The ORDER BY clause is invalid in views, inline functions,
  29. # derived tables, subqueries, and common table expressions,
  30. # unless TOP or FOR XML is also specified.
  31. setattr(self.query, '_mssql_ordering_not_allowed', with_col_aliases)
  32. # let the base do its thing, but we'll handle limit/offset
  33. sql, fields = super(SQLCompiler, self).as_sql(
  34. with_limits=False,
  35. with_col_aliases=with_col_aliases,
  36. subquery=subquery,
  37. )
  38. if has_limit_offset:
  39. if ' order by ' not in sql.lower():
  40. # Must have an ORDER BY to slice using OFFSET/FETCH. If
  41. # there is none, use the first column, which is typically a
  42. # PK
  43. sql += ' ORDER BY 1'
  44. sql += ' OFFSET %d ROWS' % (self.query.low_mark or 0)
  45. if self.query.high_mark is not None:
  46. sql += ' FETCH NEXT %d ROWS ONLY' % (self.query.high_mark - self.query.low_mark)
  47. finally:
  48. if not has_limit_offset:
  49. # remove in case query is ever reused
  50. delattr(self.query, '_mssql_ordering_not_allowed')
  51. return sql, fields
  52. def get_ordering(self):
  53. # The ORDER BY clause is invalid in views, inline functions,
  54. # derived tables, subqueries, and common table expressions,
  55. # unless TOP or FOR XML is also specified.
  56. if getattr(self.query, '_mssql_ordering_not_allowed', False):
  57. return (None, [], [])
  58. return super(SQLCompiler, self).get_ordering()
  59. def collapse_group_by(self, expressions, having):
  60. expressions = super(SQLCompiler, self).collapse_group_by(expressions, having)
  61. # MSSQL doesn't support having constants in the GROUP BY clause. Django
  62. # does this for exists() queries that have GROUP BY.
  63. return [x for x in expressions if not _re_constant.match(getattr(x, 'sql', ''))]
  64. class SQLInsertCompiler(compiler.SQLInsertCompiler, SQLCompiler):
  65. # search for after table/column list
  66. _re_values_sub = re.compile(
  67. r'(?P<prefix>\)|\])(?P<default>\s*|\s*default\s*)values(?P<suffix>\s*|\s+\()?',
  68. re.IGNORECASE
  69. )
  70. # ... and insert the OUTPUT clause between it and the values list (or DEFAULT VALUES).
  71. _values_repl = r'\g<prefix> OUTPUT INSERTED.{col} INTO @sqlserver_ado_return_id\g<default>VALUES\g<suffix>'
  72. def as_sql(self, *args, **kwargs):
  73. # Fix for Django ticket #14019
  74. if not hasattr(self, 'return_id'):
  75. self.return_id = False
  76. result = super(SQLInsertCompiler, self).as_sql(*args, **kwargs)
  77. return [self._fix_insert(x[0], x[1]) for x in result]
  78. def _fix_insert(self, sql, params):
  79. """
  80. Wrap the passed SQL with IDENTITY_INSERT statements and apply
  81. other necessary fixes.
  82. """
  83. meta = self.query.get_meta()
  84. if meta.has_auto_field:
  85. if hasattr(self.query, 'fields'):
  86. # django 1.4 replaced columns with fields
  87. fields = self.query.fields
  88. auto_field = meta.auto_field
  89. else:
  90. # < django 1.4
  91. fields = self.query.columns
  92. auto_field = meta.auto_field.db_column or meta.auto_field.column
  93. auto_in_fields = auto_field in fields
  94. quoted_table = self.connection.ops.quote_name(meta.db_table)
  95. if not fields or (auto_in_fields and len(fields) == 1 and not params):
  96. # convert format when inserting only the primary key without
  97. # specifying a value
  98. sql = 'INSERT INTO {0} DEFAULT VALUES'.format(
  99. quoted_table
  100. )
  101. params = []
  102. elif auto_in_fields:
  103. # wrap with identity insert
  104. sql = 'SET IDENTITY_INSERT {table} ON;{sql};SET IDENTITY_INSERT {table} OFF'.format(
  105. table=quoted_table,
  106. sql=sql,
  107. )
  108. # mangle SQL to return ID from insert
  109. # http://msdn.microsoft.com/en-us/library/ms177564.aspx
  110. if self.return_id and self.connection.features.can_return_id_from_insert:
  111. col = self.connection.ops.quote_name(meta.pk.db_column or meta.pk.get_attname())
  112. # Determine datatype for use with the table variable that will return the inserted ID
  113. pk_db_type = _re_data_type_terminator.split(meta.pk.db_type(self.connection))[0]
  114. # NOCOUNT ON to prevent additional trigger/stored proc related resultsets
  115. sql = 'SET NOCOUNT ON;{declare_table_var};{sql};{select_return_id}'.format(
  116. sql=sql,
  117. declare_table_var="DECLARE @sqlserver_ado_return_id table ({col_name} {pk_type})".format(
  118. col_name=col,
  119. pk_type=pk_db_type,
  120. ),
  121. select_return_id="SELECT * FROM @sqlserver_ado_return_id",
  122. )
  123. output = self._values_repl.format(col=col)
  124. sql = self._re_values_sub.sub(output, sql)
  125. return sql, params
  126. class SQLDeleteCompiler(compiler.SQLDeleteCompiler, SQLCompiler):
  127. pass
  128. class SQLUpdateCompiler(compiler.SQLUpdateCompiler, SQLCompiler):
  129. def as_sql(self):
  130. sql, params = super(SQLUpdateCompiler, self).as_sql()
  131. if sql:
  132. # Need the NOCOUNT OFF so UPDATE returns a count, instead of -1
  133. sql = 'SET NOCOUNT OFF; {0}; SET NOCOUNT ON'.format(sql)
  134. return sql, params
  135. class SQLAggregateCompiler(compiler.SQLAggregateCompiler, SQLCompiler):
  136. pass