base.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. import datetime
  2. from django.utils import timezone
  3. import pymssql as Database
  4. from sqlserver_ado.base import (
  5. DatabaseFeatures as _DatabaseFeatures,
  6. DatabaseOperations as _DatabaseOperations,
  7. DatabaseWrapper as _DatabaseWrapper)
  8. DatabaseError = Database.DatabaseError
  9. IntegrityError = Database.IntegrityError
  10. VERSION_SQL2000 = 8
  11. VERSION_SQL2005 = 9
  12. VERSION_SQL2008 = 10
  13. def _fix_query(query):
  14. # For Django's inspectdb tests -- a model has a non-ASCII column name.
  15. if not isinstance(query, str):
  16. query = query.encode('utf-8')
  17. # For Django's backends and expressions_regress tests.
  18. query = query.replace('%%', '%')
  19. return query
  20. def _fix_value(value):
  21. if isinstance(value, datetime.datetime):
  22. if timezone.is_aware(value):
  23. value = timezone.make_naive(value, timezone.utc)
  24. return value
  25. def _fix_params(params):
  26. if params is not None:
  27. # pymssql needs a tuple, not another kind of iterable.
  28. params = tuple(_fix_value(value) for value in params)
  29. return params
  30. class CursorWrapper(object):
  31. def __init__(self, cursor):
  32. self.cursor = cursor
  33. def callproc(self, procname, params=None):
  34. params = _fix_params(params)
  35. return self.cursor.callproc(procname, params)
  36. def execute(self, query, params=None):
  37. query = _fix_query(query)
  38. params = _fix_params(params)
  39. return self.cursor.execute(query, params)
  40. def executemany(self, query, param_list):
  41. query = _fix_query(query)
  42. param_list = [_fix_params(params) for params in param_list]
  43. return self.cursor.executemany(query, param_list)
  44. def __getattr__(self, attr):
  45. return getattr(self.cursor, attr)
  46. def __iter__(self):
  47. return iter(self.cursor)
  48. class DatabaseOperations(_DatabaseOperations):
  49. compiler_module = "sqlserver_pymssql.compiler"
  50. class DatabaseFeatures(_DatabaseFeatures):
  51. can_introspect_max_length = False
  52. can_introspect_null = False
  53. can_introspect_decimal_field = False
  54. failing_tests = _DatabaseFeatures.failing_tests.copy()
  55. failing_tests.update({
  56. # pymssql doesn't handle binary data correctly.
  57. 'backends.tests.LastExecutedQueryTest'
  58. '.test_query_encoding': [(1, 7)],
  59. 'model_fields.tests.BinaryFieldTests'
  60. '.test_set_and_retrieve': [(1, 7)],
  61. # pymssql doesn't check parameter counts.
  62. 'backends.tests.ParameterHandlingTest'
  63. '.test_bad_parameter_count': [(1, 7)],
  64. # Several tests that depend on schema alteration fail at this time.
  65. # This should get fixed in django-mssql when it supports migrations.
  66. })
  67. class DatabaseWrapper(_DatabaseWrapper):
  68. Database = Database
  69. def __init__(self, *args, **kwargs):
  70. super(DatabaseWrapper, self).__init__(*args, **kwargs)
  71. self.features = DatabaseFeatures(self)
  72. self.ops = DatabaseOperations(self)
  73. def get_connection_params(self):
  74. settings_dict = self.settings_dict
  75. params = {
  76. 'host': settings_dict['HOST'],
  77. 'database': settings_dict['NAME'],
  78. 'user': settings_dict['USER'],
  79. 'password': settings_dict['PASSWORD'],
  80. 'port': settings_dict['PORT'],
  81. }
  82. options = settings_dict.get('OPTIONS', {})
  83. params.update(options)
  84. return params
  85. def get_new_connection(self, conn_params):
  86. return Database.connect(**conn_params)
  87. def init_connection_state(self):
  88. # Not calling super() because we don't care much about version checks.
  89. pass
  90. def create_cursor(self):
  91. cursor = self.connection.cursor()
  92. return CursorWrapper(cursor)
  93. def _set_autocommit(self, autocommit):
  94. self.connection.autocommit(autocommit)
  95. def __get_dbms_version(self, make_connection=True):
  96. """
  97. Returns the 'DBMS Version' string, or ''. If a connection to the
  98. database has not already been established, a connection will be made
  99. when `make_connection` is True.
  100. """
  101. if not self.connection and make_connection:
  102. self.connect()
  103. with self.connection.cursor() as cursor:
  104. cursor.execute("SELECT SERVERPROPERTY('productversion')")
  105. return cursor.fetchone()[0]
  106. def _is_sql2005_and_up(self, conn):
  107. return self._get_major_ver(conn) >= VERSION_SQL2005
  108. def _is_sql2008_and_up(self, conn):
  109. return self._get_major_ver(conn) >= VERSION_SQL2008