creation.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. from __future__ import absolute_import, unicode_literals
  2. import sys
  3. import time
  4. from unittest import expectedFailure
  5. import django
  6. from django.conf import settings
  7. from django.core.exceptions import ImproperlyConfigured
  8. from django.db import connections
  9. from django.db.backends.base.base import NO_DB_ALIAS
  10. from django.db.backends.base.creation import BaseDatabaseCreation
  11. from django.utils import six
  12. from django.utils.functional import cached_property
  13. from django.utils.module_loading import import_string
  14. class DatabaseCreation(BaseDatabaseCreation):
  15. def _create_master_connection(self):
  16. """
  17. Create a transactionless connection to 'master' database.
  18. """
  19. settings_dict = self.connection.settings_dict.copy()
  20. settings_dict['NAME'] = 'master'
  21. nodb_connection = type(self.connection)(
  22. settings_dict,
  23. alias=NO_DB_ALIAS,
  24. allow_thread_sharing=False)
  25. return nodb_connection
  26. _nodb_connection = cached_property(_create_master_connection)
  27. def mark_tests_as_expected_failure(self, failing_tests):
  28. """
  29. Flag tests as expectedFailure. This should only run during the
  30. testsuite.
  31. """
  32. django_version = django.VERSION[:2]
  33. for test_name, versions in six.iteritems(failing_tests):
  34. if not versions or not isinstance(versions, (list, tuple)):
  35. # skip None, empty, or invalid
  36. continue
  37. if not isinstance(versions[0], (list, tuple)):
  38. # Ensure list of versions
  39. versions = [versions]
  40. if all(map(lambda v: v[:2] != django_version, versions)):
  41. continue
  42. try:
  43. test_case_name, _, method_name = test_name.rpartition('.')
  44. test_case = import_string(test_case_name)
  45. method = getattr(test_case, method_name)
  46. method = expectedFailure(method)
  47. setattr(test_case, method_name, method)
  48. except (ImportError, ImproperlyConfigured):
  49. pass
  50. def create_test_db(self, *args, **kwargs):
  51. self.mark_tests_as_expected_failure(self.connection.features.failing_tests)
  52. super(DatabaseCreation, self).create_test_db(*args, **kwargs)
  53. def _create_test_db(self, verbosity=1, autoclobber=False, keepdb=False):
  54. """
  55. Create the test databases using a connection to database 'master'.
  56. """
  57. if self._test_database_create(settings):
  58. try:
  59. test_database_name = super(DatabaseCreation, self)._create_test_db(verbosity, autoclobber)
  60. except Exception as e:
  61. if 'Choose a different database name.' in str(e):
  62. six.print_('Database "%s" could not be created because it already exists.' % test_database_name)
  63. else:
  64. six.reraise(*sys.exc_info())
  65. self.install_regex_clr(test_database_name)
  66. return test_database_name
  67. if verbosity >= 1:
  68. six.print_("Skipping Test DB creation")
  69. return self._get_test_db_name()
  70. def _destroy_test_db(self, test_database_name, verbosity=1):
  71. """
  72. Drop the test databases using a connection to database 'master'.
  73. """
  74. if not self._test_database_create(settings):
  75. if verbosity >= 1:
  76. six.print_("Skipping Test DB destruction")
  77. return
  78. for alias in connections:
  79. connections[alias].close()
  80. try:
  81. with self._nodb_connection.cursor() as cursor:
  82. qn_db_name = self.connection.ops.quote_name(test_database_name)
  83. # boot all other connections to the database, leaving only this connection
  84. cursor.execute("ALTER DATABASE %s SET SINGLE_USER WITH ROLLBACK IMMEDIATE" % qn_db_name)
  85. time.sleep(1)
  86. # database is now clear to drop
  87. cursor.execute("DROP DATABASE %s" % qn_db_name)
  88. except Exception:
  89. # if 'it is currently in use' in str(e):
  90. # six.print_('Cannot drop database %s because it is in use' % test_database_name)
  91. # else:
  92. six.reraise(*sys.exc_info())
  93. def _test_database_create(self, settings):
  94. """
  95. Check the settings to see if the test database should be created.
  96. """
  97. if 'TEST_CREATE' in self.connection.settings_dict:
  98. return self.connection.settings_dict.get('TEST_CREATE', True)
  99. if hasattr(settings, 'TEST_DATABASE_CREATE'):
  100. return settings.TEST_DATABASE_CREATE
  101. else:
  102. return True
  103. def enable_clr(self):
  104. """ Enables clr for server if not already enabled
  105. This function will not fail if current user doesn't have
  106. permissions to enable clr, and clr is already enabled
  107. """
  108. with self._nodb_connection.cursor() as cursor:
  109. # check whether clr is enabled
  110. cursor.execute('''
  111. SELECT value FROM sys.configurations
  112. WHERE name = 'clr enabled'
  113. ''')
  114. res = cursor.fetchone()
  115. if not res or not res[0]:
  116. # if not enabled enable clr
  117. cursor.execute("sp_configure 'clr enabled', 1")
  118. cursor.execute("RECONFIGURE")
  119. def install_regex_clr(self, database_name):
  120. sql = '''
  121. USE {database_name};
  122. -- Drop and recreate the function if it already exists
  123. IF OBJECT_ID('REGEXP_LIKE') IS NOT NULL
  124. DROP FUNCTION [dbo].[REGEXP_LIKE]
  125. IF EXISTS(select * from sys.assemblies where name like 'regex_clr')
  126. DROP ASSEMBLY regex_clr
  127. ;
  128. CREATE ASSEMBLY regex_clr
  129. FROM 0x{assembly_hex}
  130. WITH PERMISSION_SET = SAFE;
  131. create function [dbo].[REGEXP_LIKE]
  132. (
  133. @input nvarchar(max),
  134. @pattern nvarchar(max),
  135. @caseSensitive int
  136. )
  137. RETURNS INT AS
  138. EXTERNAL NAME regex_clr.UserDefinedFunctions.REGEXP_LIKE
  139. '''.format(
  140. database_name=self.connection.ops.quote_name(database_name),
  141. assembly_hex=self.get_regex_clr_assembly_hex(),
  142. ).split(';')
  143. self.enable_clr()
  144. with self._nodb_connection.cursor() as cursor:
  145. for s in sql:
  146. cursor.execute(s)
  147. def get_regex_clr_assembly_hex(self):
  148. import os
  149. import binascii
  150. with open(os.path.join(os.path.dirname(__file__), 'regex_clr.dll'), 'rb') as f:
  151. assembly = binascii.hexlify(f.read()).decode('ascii')
  152. return assembly