123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174 |
- from __future__ import absolute_import, unicode_literals
- import sys
- import time
- from unittest import expectedFailure
- import django
- from django.conf import settings
- from django.core.exceptions import ImproperlyConfigured
- from django.db import connections
- from django.db.backends.base.base import NO_DB_ALIAS
- from django.db.backends.base.creation import BaseDatabaseCreation
- from django.utils import six
- from django.utils.functional import cached_property
- from django.utils.module_loading import import_string
class DatabaseCreation(BaseDatabaseCreation):
    """
    Test-database creation/destruction for the SQL Server backend.

    On top of the stock Django behaviour this class can skip creating and
    dropping the test database (see ``_test_database_create``), marks the
    backend's known-failing Django tests as expected failures, and installs
    the ``REGEXP_LIKE`` CLR function into the test database it creates.
    """

    def _create_master_connection(self):
        """
        Create a connection to the 'master' database.

        The connection is a new instance of this backend's connection class,
        built from a copy of the current connection's settings with NAME
        replaced by 'master', so it can create and drop other databases.
        """
        settings_dict = self.connection.settings_dict.copy()
        settings_dict['NAME'] = 'master'
        nodb_connection = type(self.connection)(
            settings_dict,
            alias=NO_DB_ALIAS,
            allow_thread_sharing=False)
        return nodb_connection

    # Created lazily on first access, then reused for the lifetime of this
    # DatabaseCreation instance (so session state such as USE persists).
    _nodb_connection = cached_property(_create_master_connection)

    def mark_tests_as_expected_failure(self, failing_tests):
        """
        Flag tests as expectedFailure. This should only run during the
        testsuite.

        ``failing_tests`` maps dotted test paths
        (``'package.module.TestCase.test_method'``) to the Django version(s)
        on which the test is expected to fail. A value may be a single
        version such as ``(1, 8)`` or a list of versions such as
        ``[(1, 8), (1, 9)]``. Only the first two components of each version
        are compared with ``django.VERSION``. Entries whose value is empty
        or not a list/tuple are ignored. So are tests whose class cannot be
        imported in this environment, or that no longer exist.
        """
        django_version = django.VERSION[:2]
        for test_name, versions in six.iteritems(failing_tests):
            if not versions or not isinstance(versions, (list, tuple)):
                # Skip None, empty, or invalid version specs.
                continue
            if not isinstance(versions[0], (list, tuple)):
                # A single version was given; normalize to a list of versions.
                versions = [versions]
            # Compare as tuples so a version written as a list ([1, 8]) still
            # matches django.VERSION[:2], which is a tuple.
            if not any(tuple(v[:2]) == django_version for v in versions):
                continue
            test_case_name, _, method_name = test_name.rpartition('.')
            try:
                test_case = import_string(test_case_name)
                method = getattr(test_case, method_name)
            except (ImportError, ImproperlyConfigured, AttributeError):
                # Test module not importable here, or the test was removed or
                # renamed in this Django version: nothing to mark.
                continue
            setattr(test_case, method_name, expectedFailure(method))

    def create_test_db(self, *args, **kwargs):
        """
        Mark this backend's known-failing tests, then create the test database.

        All arguments are forwarded to the base implementation. Its return
        value (documented by Django as the test database name) is passed
        back to the caller.
        """
        self.mark_tests_as_expected_failure(self.connection.features.failing_tests)
        return super(DatabaseCreation, self).create_test_db(*args, **kwargs)

    def _create_test_db(self, verbosity=1, autoclobber=False, keepdb=False):
        """
        Create the test database and install the REGEXP_LIKE CLR function.

        When test database creation is disabled (see
        ``_test_database_create``), nothing is created and the configured
        test database name is returned unchanged.

        :param verbosity: print status messages when >= 1.
        :param autoclobber: forwarded to the base implementation.
        :param keepdb: forwarded to the base implementation so an existing
            test database can be reused.
        :returns: the test database name.
        """
        if not self._test_database_create(settings):
            if verbosity >= 1:
                six.print_("Skipping Test DB creation")
            return self._get_test_db_name()

        # Resolve the name up front. The error handler below needs it even
        # when the base implementation raises before returning it.
        test_database_name = self._get_test_db_name()
        try:
            test_database_name = super(DatabaseCreation, self)._create_test_db(
                verbosity, autoclobber, keepdb=keepdb)
        except Exception as e:
            # An error containing this text is treated as "database already
            # exists": keep going with the existing database. Re-raise
            # anything else unchanged.
            if 'Choose a different database name.' not in str(e):
                raise
            six.print_('Database "%s" could not be created because it already exists.' % test_database_name)
        self.install_regex_clr(test_database_name)
        return test_database_name

    def _destroy_test_db(self, test_database_name, verbosity=1):
        """
        Drop the test database using the cached connection to 'master'.

        Does nothing (apart from an optional message) when test database
        creation is disabled (see ``_test_database_create``).

        :param test_database_name: name of the database to drop.
        :param verbosity: print status messages when >= 1.
        """
        if not self._test_database_create(settings):
            if verbosity >= 1:
                six.print_("Skipping Test DB destruction")
            return

        # Close every Django-managed connection so none of them keeps a
        # session open on the test database.
        for alias in connections:
            connections[alias].close()

        qn_db_name = self.connection.ops.quote_name(test_database_name)
        with self._nodb_connection.cursor() as cursor:
            # Boot all other connections to the database (rolling back their
            # work), leaving only this connection.
            cursor.execute("ALTER DATABASE %s SET SINGLE_USER WITH ROLLBACK IMMEDIATE" % qn_db_name)
            time.sleep(1)
            # The database is now clear to drop.
            cursor.execute("DROP DATABASE %s" % qn_db_name)

    def _test_database_create(self, settings):
        """
        Check the settings to see if the test database should be created.

        A ``TEST_CREATE`` key in this connection's settings dict takes
        precedence. Otherwise the project-wide ``TEST_DATABASE_CREATE``
        setting is used, and the default is True when neither is present.
        """
        settings_dict = self.connection.settings_dict
        if 'TEST_CREATE' in settings_dict:
            return settings_dict['TEST_CREATE']
        if hasattr(settings, 'TEST_DATABASE_CREATE'):
            return settings.TEST_DATABASE_CREATE
        return True

    def enable_clr(self):
        """
        Enable CLR on the server if it is not already enabled.

        ``sp_configure``/``RECONFIGURE`` only run when ``sys.configurations``
        reports 'clr enabled' as off. A user without permission to change
        the server configuration can therefore still call this when CLR is
        already enabled.
        """
        with self._nodb_connection.cursor() as cursor:
            # Check whether CLR is enabled.
            cursor.execute('''
SELECT value FROM sys.configurations
WHERE name = 'clr enabled'
''')
            res = cursor.fetchone()
            if not res or not res[0]:
                # Not enabled: enable it.
                cursor.execute("sp_configure 'clr enabled', 1")
                cursor.execute("RECONFIGURE")

    def install_regex_clr(self, database_name):
        """
        Install the ``regex_clr`` assembly and the ``dbo.REGEXP_LIKE``
        function into ``database_name``, replacing any existing copies.

        CLR is enabled on the server first (see ``enable_clr``). The script
        is split on ';' and each piece is executed separately. Presumably
        this is because SQL Server requires CREATE FUNCTION to start its own
        batch (TODO confirm).
        """
        qn = self.connection.ops.quote_name
        statements = '''
USE {database_name};
-- Drop and recreate the function if it already exists
IF OBJECT_ID('REGEXP_LIKE') IS NOT NULL
DROP FUNCTION [dbo].[REGEXP_LIKE]
IF EXISTS(select * from sys.assemblies where name like 'regex_clr')
DROP ASSEMBLY regex_clr
;
CREATE ASSEMBLY regex_clr
FROM 0x{assembly_hex}
WITH PERMISSION_SET = SAFE;
create function [dbo].[REGEXP_LIKE]
(
@input nvarchar(max),
@pattern nvarchar(max),
@caseSensitive int
)
RETURNS INT AS
EXTERNAL NAME regex_clr.UserDefinedFunctions.REGEXP_LIKE
'''.format(
            database_name=qn(database_name),
            assembly_hex=self.get_regex_clr_assembly_hex(),
        ).split(';')
        # The master connection is cached and reused by _destroy_test_db to
        # DROP this same database. Switch its session back to master so it
        # is not left "in use" by the connection that has to drop it.
        statements.append('USE {0}'.format(qn('master')))

        self.enable_clr()
        with self._nodb_connection.cursor() as cursor:
            for statement in statements:
                cursor.execute(statement)

    def get_regex_clr_assembly_hex(self):
        """
        Return the bytes of ``regex_clr.dll`` (located next to this module)
        as a lowercase hex string, for use in a ``FROM 0x...`` literal.
        """
        import os
        import binascii
        dll_path = os.path.join(os.path.dirname(__file__), 'regex_clr.dll')
        with open(dll_path, 'rb') as f:
            return binascii.hexlify(f.read()).decode('ascii')
|