compiler.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322
  1. import datetime
  2. from django.conf import settings
  3. from django.db.backends.utils import truncate_name, typecast_date, typecast_timestamp
  4. from django.db.models.sql import compiler
  5. from django.db.models.sql.constants import MULTI
  6. from django.utils import six
  7. from django.utils.six.moves import zip, zip_longest
  8. from django.utils import timezone
# Re-export Django's stock SQLCompiler under this module's namespace, next to
# the geo-aware compiler classes defined below.
# NOTE(review): presumably so that code looking up compiler classes by name on
# this module also finds the plain SELECT compiler -- confirm against callers.
SQLCompiler = compiler.SQLCompiler
  10. class GeoSQLCompiler(compiler.SQLCompiler):
  11. def get_columns(self, with_aliases=False):
  12. """
  13. Return the list of columns to use in the select statement. If no
  14. columns have been specified, returns all columns relating to fields in
  15. the model.
  16. If 'with_aliases' is true, any column names that are duplicated
  17. (without the table names) are given unique aliases. This is needed in
  18. some cases to avoid ambiguity with nested queries.
  19. This routine is overridden from Query to handle customized selection of
  20. geometry columns.
  21. """
  22. qn = self
  23. qn2 = self.connection.ops.quote_name
  24. result = ['(%s) AS %s' % (self.get_extra_select_format(alias) % col[0], qn2(alias))
  25. for alias, col in six.iteritems(self.query.extra_select)]
  26. params = []
  27. aliases = set(self.query.extra_select.keys())
  28. if with_aliases:
  29. col_aliases = aliases.copy()
  30. else:
  31. col_aliases = set()
  32. if self.query.select:
  33. only_load = self.deferred_to_columns()
  34. # This loop customized for GeoQuery.
  35. for col, field in self.query.select:
  36. if isinstance(col, (list, tuple)):
  37. alias, column = col
  38. table = self.query.alias_map[alias].table_name
  39. if table in only_load and column not in only_load[table]:
  40. continue
  41. r = self.get_field_select(field, alias, column)
  42. if with_aliases:
  43. if col[1] in col_aliases:
  44. c_alias = 'Col%d' % len(col_aliases)
  45. result.append('%s AS %s' % (r, c_alias))
  46. aliases.add(c_alias)
  47. col_aliases.add(c_alias)
  48. else:
  49. result.append('%s AS %s' % (r, qn2(col[1])))
  50. aliases.add(r)
  51. col_aliases.add(col[1])
  52. else:
  53. result.append(r)
  54. aliases.add(r)
  55. col_aliases.add(col[1])
  56. else:
  57. col_sql, col_params = col.as_sql(qn, self.connection)
  58. result.append(col_sql)
  59. params.extend(col_params)
  60. if hasattr(col, 'alias'):
  61. aliases.add(col.alias)
  62. col_aliases.add(col.alias)
  63. elif self.query.default_cols:
  64. cols, new_aliases = self.get_default_columns(with_aliases,
  65. col_aliases)
  66. result.extend(cols)
  67. aliases.update(new_aliases)
  68. max_name_length = self.connection.ops.max_name_length()
  69. for alias, aggregate in self.query.aggregate_select.items():
  70. agg_sql, agg_params = aggregate.as_sql(qn, self.connection)
  71. if alias is None:
  72. result.append(agg_sql)
  73. else:
  74. result.append('%s AS %s' % (agg_sql, qn(truncate_name(alias, max_name_length))))
  75. params.extend(agg_params)
  76. # This loop customized for GeoQuery.
  77. for (table, col), field in self.query.related_select_cols:
  78. r = self.get_field_select(field, table, col)
  79. if with_aliases and col in col_aliases:
  80. c_alias = 'Col%d' % len(col_aliases)
  81. result.append('%s AS %s' % (r, c_alias))
  82. aliases.add(c_alias)
  83. col_aliases.add(c_alias)
  84. else:
  85. result.append(r)
  86. aliases.add(r)
  87. col_aliases.add(col)
  88. self._select_aliases = aliases
  89. return result, params
  90. def get_default_columns(self, with_aliases=False, col_aliases=None,
  91. start_alias=None, opts=None, as_pairs=False, from_parent=None):
  92. """
  93. Computes the default columns for selecting every field in the base
  94. model. Will sometimes be called to pull in related models (e.g. via
  95. select_related), in which case "opts" and "start_alias" will be given
  96. to provide a starting point for the traversal.
  97. Returns a list of strings, quoted appropriately for use in SQL
  98. directly, as well as a set of aliases used in the select statement (if
  99. 'as_pairs' is True, returns a list of (alias, col_name) pairs instead
  100. of strings as the first component and None as the second component).
  101. This routine is overridden from Query to handle customized selection of
  102. geometry columns.
  103. """
  104. result = []
  105. if opts is None:
  106. opts = self.query.get_meta()
  107. aliases = set()
  108. only_load = self.deferred_to_columns()
  109. seen = self.query.included_inherited_models.copy()
  110. if start_alias:
  111. seen[None] = start_alias
  112. for field, model in opts.get_concrete_fields_with_model():
  113. if from_parent and model is not None and issubclass(from_parent, model):
  114. # Avoid loading data for already loaded parents.
  115. continue
  116. alias = self.query.join_parent_model(opts, model, start_alias, seen)
  117. table = self.query.alias_map[alias].table_name
  118. if table in only_load and field.column not in only_load[table]:
  119. continue
  120. if as_pairs:
  121. result.append((alias, field))
  122. aliases.add(alias)
  123. continue
  124. # This part of the function is customized for GeoQuery. We
  125. # see if there was any custom selection specified in the
  126. # dictionary, and set up the selection format appropriately.
  127. field_sel = self.get_field_select(field, alias)
  128. if with_aliases and field.column in col_aliases:
  129. c_alias = 'Col%d' % len(col_aliases)
  130. result.append('%s AS %s' % (field_sel, c_alias))
  131. col_aliases.add(c_alias)
  132. aliases.add(c_alias)
  133. else:
  134. r = field_sel
  135. result.append(r)
  136. aliases.add(r)
  137. if with_aliases:
  138. col_aliases.add(field.column)
  139. return result, aliases
  140. def resolve_columns(self, row, fields=()):
  141. """
  142. This routine is necessary so that distances and geometries returned
  143. from extra selection SQL get resolved appropriately into Python
  144. objects.
  145. """
  146. values = []
  147. aliases = list(self.query.extra_select)
  148. # Have to set a starting row number offset that is used for
  149. # determining the correct starting row index -- needed for
  150. # doing pagination with Oracle.
  151. rn_offset = 0
  152. if self.connection.ops.oracle:
  153. if self.query.high_mark is not None or self.query.low_mark:
  154. rn_offset = 1
  155. index_start = rn_offset + len(aliases)
  156. # Converting any extra selection values (e.g., geometries and
  157. # distance objects added by GeoQuerySet methods).
  158. values = [self.query.convert_values(v,
  159. self.query.extra_select_fields.get(a, None),
  160. self.connection)
  161. for v, a in zip(row[rn_offset:index_start], aliases)]
  162. if self.connection.ops.oracle or getattr(self.query, 'geo_values', False):
  163. # We resolve the rest of the columns if we're on Oracle or if
  164. # the `geo_values` attribute is defined.
  165. for value, field in zip_longest(row[index_start:], fields):
  166. values.append(self.query.convert_values(value, field, self.connection))
  167. else:
  168. values.extend(row[index_start:])
  169. return tuple(values)
  170. #### Routines unique to GeoQuery ####
  171. def get_extra_select_format(self, alias):
  172. sel_fmt = '%s'
  173. if hasattr(self.query, 'custom_select') and alias in self.query.custom_select:
  174. sel_fmt = sel_fmt % self.query.custom_select[alias]
  175. return sel_fmt
  176. def get_field_select(self, field, alias=None, column=None):
  177. """
  178. Returns the SELECT SQL string for the given field. Figures out
  179. if any custom selection SQL is needed for the column The `alias`
  180. keyword may be used to manually specify the database table where
  181. the column exists, if not in the model associated with this
  182. `GeoQuery`. Similarly, `column` may be used to specify the exact
  183. column name, rather than using the `column` attribute on `field`.
  184. """
  185. sel_fmt = self.get_select_format(field)
  186. if field in self.query.custom_select:
  187. field_sel = sel_fmt % self.query.custom_select[field]
  188. else:
  189. field_sel = sel_fmt % self._field_column(field, alias, column)
  190. return field_sel
  191. def get_select_format(self, fld):
  192. """
  193. Returns the selection format string, depending on the requirements
  194. of the spatial backend. For example, Oracle and MySQL require custom
  195. selection formats in order to retrieve geometries in OGC WKT. For all
  196. other fields a simple '%s' format string is returned.
  197. """
  198. if self.connection.ops.select and hasattr(fld, 'geom_type'):
  199. # This allows operations to be done on fields in the SELECT,
  200. # overriding their values -- used by the Oracle and MySQL
  201. # spatial backends to get database values as WKT, and by the
  202. # `transform` method.
  203. sel_fmt = self.connection.ops.select
  204. # Because WKT doesn't contain spatial reference information,
  205. # the SRID is prefixed to the returned WKT to ensure that the
  206. # transformed geometries have an SRID different than that of the
  207. # field -- this is only used by `transform` for Oracle and
  208. # SpatiaLite backends.
  209. if self.query.transformed_srid and (self.connection.ops.oracle or
  210. self.connection.ops.spatialite):
  211. sel_fmt = "'SRID=%d;'||%s" % (self.query.transformed_srid, sel_fmt)
  212. else:
  213. sel_fmt = '%s'
  214. return sel_fmt
  215. # Private API utilities, subject to change.
  216. def _field_column(self, field, table_alias=None, column=None):
  217. """
  218. Helper function that returns the database column for the given field.
  219. The table and column are returned (quoted) in the proper format, e.g.,
  220. `"geoapp_city"."point"`. If `table_alias` is not specified, the
  221. database table associated with the model of this `GeoQuery` will be
  222. used. If `column` is specified, it will be used instead of the value
  223. in `field.column`.
  224. """
  225. if table_alias is None:
  226. table_alias = self.query.get_meta().db_table
  227. return "%s.%s" % (self.quote_name_unless_alias(table_alias),
  228. self.connection.ops.quote_name(column or field.column))
  229. class SQLInsertCompiler(compiler.SQLInsertCompiler, GeoSQLCompiler):
  230. pass
  231. class SQLDeleteCompiler(compiler.SQLDeleteCompiler, GeoSQLCompiler):
  232. pass
  233. class SQLUpdateCompiler(compiler.SQLUpdateCompiler, GeoSQLCompiler):
  234. pass
  235. class SQLAggregateCompiler(compiler.SQLAggregateCompiler, GeoSQLCompiler):
  236. pass
  237. class SQLDateCompiler(compiler.SQLDateCompiler, GeoSQLCompiler):
  238. """
  239. This is overridden for GeoDjango to properly cast date columns, since
  240. `GeoQuery.resolve_columns` is used for spatial values.
  241. See #14648, #16757.
  242. """
  243. def results_iter(self):
  244. if self.connection.ops.oracle:
  245. from django.db.models.fields import DateTimeField
  246. fields = [DateTimeField()]
  247. else:
  248. needs_string_cast = self.connection.features.needs_datetime_string_cast
  249. offset = len(self.query.extra_select)
  250. for rows in self.execute_sql(MULTI):
  251. for row in rows:
  252. date = row[offset]
  253. if self.connection.ops.oracle:
  254. date = self.resolve_columns(row, fields)[offset]
  255. elif needs_string_cast:
  256. date = typecast_date(str(date))
  257. if isinstance(date, datetime.datetime):
  258. date = date.date()
  259. yield date
  260. class SQLDateTimeCompiler(compiler.SQLDateTimeCompiler, GeoSQLCompiler):
  261. """
  262. This is overridden for GeoDjango to properly cast date columns, since
  263. `GeoQuery.resolve_columns` is used for spatial values.
  264. See #14648, #16757.
  265. """
  266. def results_iter(self):
  267. if self.connection.ops.oracle:
  268. from django.db.models.fields import DateTimeField
  269. fields = [DateTimeField()]
  270. else:
  271. needs_string_cast = self.connection.features.needs_datetime_string_cast
  272. offset = len(self.query.extra_select)
  273. for rows in self.execute_sql(MULTI):
  274. for row in rows:
  275. datetime = row[offset]
  276. if self.connection.ops.oracle:
  277. datetime = self.resolve_columns(row, fields)[offset]
  278. elif needs_string_cast:
  279. datetime = typecast_timestamp(str(datetime))
  280. # Datetimes are artificially returned in UTC on databases that
  281. # don't support time zone. Restore the zone used in the query.
  282. if settings.USE_TZ:
  283. datetime = datetime.replace(tzinfo=None)
  284. datetime = timezone.make_aware(datetime, self.query.tzinfo)
  285. yield datetime