query.py 2.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. from __future__ import unicode_literals
  2. from django.db import connections
  3. from django.db.models import sql
  4. from django.db.models.query import RawQuerySet
  5. from sqlserver_ado.dbapi import FetchFailedError
  6. __all__ = [
  7. 'RawStoredProcedureQuery',
  8. 'RawStoredProcedureQuerySet',
  9. ]
  10. class RawStoredProcedureQuery(sql.RawQuery):
  11. """
  12. A single raw SQL stored procedure query
  13. """
  14. def clone(self, using):
  15. return RawStoredProcedureQuery(self.sql, using, params=self.params)
  16. def __repr__(self):
  17. return "<RawStoredProcedureQuery: %r %r>" % (self.sql, self.params)
  18. def _execute_query(self):
  19. """
  20. Execute the stored procedure using callproc, instead of execute.
  21. """
  22. self.cursor = connections[self.using].cursor()
  23. self.cursor.callproc(self.sql, self.params)
  24. class RawStoredProcedureQuerySet(RawQuerySet):
  25. """
  26. Provides an iterator which converts the results of raw SQL queries into
  27. annotated model instances.
  28. raw_query should only be the name of the stored procedure.
  29. """
  30. def __init__(self, raw_query, model=None, query=None, params=None,
  31. translations=None, using=None, hints=None):
  32. self.raw_query = raw_query
  33. self.model = model
  34. self._db = using
  35. self._hints = hints or {}
  36. self.query = query or RawStoredProcedureQuery(sql=raw_query, using=self.db, params=params)
  37. self.params = params or ()
  38. self.translations = translations or {}
  39. def __iter__(self):
  40. try:
  41. for x in super(RawStoredProcedureQuerySet, self).__iter__():
  42. yield x
  43. except FetchFailedError:
  44. # Stored procedure didn't return a record set
  45. pass
  46. def __repr__(self):
  47. return "<RawStoredProcedureQuerySet: %r %r>" % (self.raw_query, self.params)
  48. @property
  49. def columns(self):
  50. """
  51. A list of model field names in the order they'll appear in the
  52. query results.
  53. """
  54. if not hasattr(self, '_columns'):
  55. try:
  56. self._columns = self.query.get_columns()
  57. except TypeError:
  58. # "'NoneType' object is not iterable" thrown when stored procedure
  59. # doesn't return a result set.
  60. # no result means no column names, so grab them from the model
  61. self._columns = [self.model._meta.pk.db_column] # [x.db_column for x in self.model._meta.fields]
  62. # Adjust any column names which don't match field names
  63. for (query_name, model_name) in self.translations.items():
  64. try:
  65. index = self._columns.index(query_name)
  66. self._columns[index] = model_name
  67. except ValueError:
  68. # Ignore translations for non-existant column names
  69. pass
  70. return self._columns