123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151 |
- import datetime
- from django.utils import timezone
- import pymssql as Database
- from sqlserver_ado.base import (
- DatabaseFeatures as _DatabaseFeatures,
- DatabaseOperations as _DatabaseOperations,
- DatabaseWrapper as _DatabaseWrapper)
# Re-export the driver's exception classes under module-level names.
IntegrityError = Database.IntegrityError
DatabaseError = Database.DatabaseError

# Major version numbers; DatabaseWrapper compares the result of
# ``_get_major_ver()`` against these.
VERSION_SQL2008 = 10
VERSION_SQL2005 = 9
VERSION_SQL2000 = 8
- def _fix_query(query):
- # For Django's inspectdb tests -- a model has a non-ASCII column name.
- if not isinstance(query, str):
- query = query.encode('utf-8')
- # For Django's backends and expressions_regress tests.
- query = query.replace('%%', '%')
- return query
- def _fix_value(value):
- if isinstance(value, datetime.datetime):
- if timezone.is_aware(value):
- value = timezone.make_naive(value, timezone.utc)
- return value
- def _fix_params(params):
- if params is not None:
- # pymssql needs a tuple, not another kind of iterable.
- params = tuple(_fix_value(value) for value in params)
- return params
class CursorWrapper(object):
    """Thin proxy around a driver cursor that normalises queries and params.

    ``callproc``, ``execute`` and ``executemany`` clean their arguments with
    ``_fix_query`` / ``_fix_params`` before delegating to the wrapped cursor.
    Every other attribute, as well as iteration, is forwarded unchanged.
    """

    def __init__(self, cursor):
        """Store the cursor object that all calls are delegated to."""
        self.cursor = cursor

    def callproc(self, procname, params=None):
        """Call stored procedure ``procname`` with the fixed-up ``params``."""
        params = _fix_params(params)
        return self.cursor.callproc(procname, params)

    def execute(self, query, params=None):
        """Execute ``query`` after fixing both the query text and ``params``."""
        query = _fix_query(query)
        params = _fix_params(params)
        return self.cursor.execute(query, params)

    def executemany(self, query, param_list):
        """Execute ``query`` once per entry of ``param_list``.

        The query text is fixed once; each parameter set is fixed separately.
        """
        query = _fix_query(query)
        param_list = [_fix_params(params) for params in param_list]
        return self.cursor.executemany(query, param_list)

    def __getattr__(self, attr):
        """Forward attribute lookups that fail on the wrapper to the cursor.

        Raises:
            AttributeError: if ``attr`` is missing on the wrapped cursor, or
                if the wrapper itself has no ``cursor`` attribute yet.
        """
        # __getattr__ only runs when normal lookup fails.  If ``cursor`` is
        # the missing attribute (instance built without __init__, e.g. via
        # __new__), reading self.cursor below would re-enter this method
        # and recurse until RecursionError.
        if attr == 'cursor':
            raise AttributeError(attr)
        return getattr(self.cursor, attr)

    def __iter__(self):
        """Iterate over the wrapped cursor."""
        return iter(self.cursor)
class DatabaseOperations(_DatabaseOperations):
    """Backend operations; differs from the base only in its compiler module."""

    # Dotted path of the SQL compiler module used by this backend.
    compiler_module = 'sqlserver_pymssql.compiler'
class DatabaseFeatures(_DatabaseFeatures):
    """Feature flags for the pymssql-based backend.

    Disables three introspection capabilities and extends the base class's
    ``failing_tests`` mapping with pymssql-specific entries.
    """

    can_introspect_decimal_field = False
    can_introspect_null = False
    can_introspect_max_length = False

    # Work on a copy so the base class's mapping is not mutated.
    # NOTE(review): the (1, 7) tuples are presumably the affected Django
    # versions -- confirm against how the base class consumes this mapping.
    failing_tests = _DatabaseFeatures.failing_tests.copy()
    failing_tests.update({
        # Binary data is not handled correctly by pymssql.
        'backends.tests.LastExecutedQueryTest.test_query_encoding': [(1, 7)],
        'model_fields.tests.BinaryFieldTests.test_set_and_retrieve': [(1, 7)],

        # Parameter counts are not checked by pymssql.
        'backends.tests.ParameterHandlingTest.test_bad_parameter_count': [
            (1, 7),
        ],

        # Several tests that depend on schema alteration fail at this time.
        # This should get fixed in django-mssql when it supports migrations.
    })
class DatabaseWrapper(_DatabaseWrapper):
    """Database wrapper that connects through pymssql.

    Reuses the base wrapper but swaps in this module's ``DatabaseFeatures``
    and ``DatabaseOperations``, builds pymssql connection parameters from the
    Django settings dict, and wraps cursors in ``CursorWrapper``.
    """

    # Expose the driver module on the wrapper class.
    Database = Database

    def __init__(self, *args, **kwargs):
        """Initialise the base wrapper, then install this backend's helpers."""
        super(DatabaseWrapper, self).__init__(*args, **kwargs)
        # Assigned after super().__init__() so they replace whatever the
        # base class set up.
        self.features = DatabaseFeatures(self)
        self.ops = DatabaseOperations(self)

    def get_connection_params(self):
        """Build the keyword arguments passed to ``Database.connect``.

        Returns:
            A dict with ``host``, ``database``, ``user``, ``password`` and
            ``port`` taken from the settings dict, updated with everything
            in ``OPTIONS`` (so ``OPTIONS`` entries win on key clashes).
        """
        settings_dict = self.settings_dict
        params = {
            'host': settings_dict['HOST'],
            'database': settings_dict['NAME'],
            'user': settings_dict['USER'],
            'password': settings_dict['PASSWORD'],
            'port': settings_dict['PORT'],
        }
        options = settings_dict.get('OPTIONS', {})
        params.update(options)
        return params

    def get_new_connection(self, conn_params):
        """Open and return a new driver connection using ``conn_params``."""
        return Database.connect(**conn_params)

    def init_connection_state(self):
        """Do nothing: no per-connection initialisation is performed."""
        # Not calling super() because we don't care much about version checks.
        pass

    def create_cursor(self):
        """Return a ``CursorWrapper`` around a fresh driver cursor."""
        cursor = self.connection.cursor()
        return CursorWrapper(cursor)

    def _set_autocommit(self, autocommit):
        """Forward the autocommit flag to the driver connection."""
        self.connection.autocommit(autocommit)

    # The double-underscore name is mangled to
    # ``_DatabaseWrapper__get_dbms_version``.
    # NOTE(review): presumably this is meant to override the same-named
    # private method of the base class, which only works because both
    # classes are called ``DatabaseWrapper`` -- confirm before renaming.
    def __get_dbms_version(self, make_connection=True):
        """
        Returns the 'DBMS Version' string, or ''. If a connection to the
        database has not already been established, a connection will be made
        when `make_connection` is True.
        """
        if not self.connection and make_connection:
            self.connect()
        if not self.connection:
            # Still no connection (make_connection was False): honour the
            # documented contract instead of failing on ``None.cursor()``.
            return ''
        with self.connection.cursor() as cursor:
            cursor.execute("SELECT SERVERPROPERTY('productversion')")
            return cursor.fetchone()[0]

    def _is_sql2005_and_up(self, conn):
        """Return True when the major version is at least VERSION_SQL2005."""
        return self._get_major_ver(conn) >= VERSION_SQL2005

    def _is_sql2008_and_up(self, conn):
        """Return True when the major version is at least VERSION_SQL2008."""
        return self._get_major_ver(conn) >= VERSION_SQL2008
|