123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114 |
- # -*- coding: utf-8 -*-
- import six
- import time
- import traceback
- from contextlib import contextmanager
- import django
- from django.conf import settings
- from django.core.exceptions import ImproperlyConfigured
- from django.db.backends import utils
- @contextmanager
- def monkey_patch_cursordebugwrapper(print_sql=None, print_sql_location=False, truncate=None, logger=six.print_, confprefix="DJANGO_EXTENSIONS"):
- if not print_sql:
- yield
- else:
- truncate = getattr(settings, '%s_PRINT_SQL_TRUNCATE' % confprefix, 1000)
- # Code orginally from http://gist.github.com/118990
- sqlparse = None
- if getattr(settings, '%s_SQLPARSE_ENABLED' % confprefix, True):
- try:
- import sqlparse
- sqlparse_format_kwargs_defaults = dict(
- reindent_aligned=True,
- truncate_strings=500,
- )
- sqlparse_format_kwargs = getattr(settings, '%s_SQLPARSE_FORMAT_KWARGS' % confprefix, sqlparse_format_kwargs_defaults)
- except ImportError:
- sqlparse = None
- pygments = None
- if getattr(settings, '%s_PYGMENTS_ENABLED' % confprefix, True):
- try:
- import pygments.lexers
- import pygments.formatters
- pygments_formatter = getattr(settings, '%s_PYGMENTS_FORMATTER' % confprefix, pygments.formatters.TerminalFormatter)
- pygments_formatter_kwargs = getattr(settings, '%s_PYGMENTS_FORMATTER_KWARGS' % confprefix, {})
- except ImportError:
- pass
- class PrintQueryWrapperMixin:
- def execute(self, sql, params=()):
- starttime = time.time()
- try:
- return utils.CursorWrapper.execute(self, sql, params)
- finally:
- execution_time = time.time() - starttime
- raw_sql = self.db.ops.last_executed_query(self.cursor, sql, params)
- if truncate:
- raw_sql = raw_sql[:truncate]
- if sqlparse:
- raw_sql = sqlparse.format(raw_sql, **sqlparse_format_kwargs)
- if pygments:
- raw_sql = pygments.highlight(
- raw_sql,
- pygments.lexers.get_lexer_by_name("sql"),
- pygments_formatter(**pygments_formatter_kwargs),
- )
- logger(raw_sql)
- logger("Execution time: %.6fs [Database: %s]" % (execution_time, self.db.alias))
- if print_sql_location:
- logger("Location of SQL Call:")
- logger(''.join(traceback.format_stack()))
- _CursorDebugWrapper = utils.CursorDebugWrapper
- class PrintCursorQueryWrapper(PrintQueryWrapperMixin, _CursorDebugWrapper):
- pass
- try:
- from django.db import connections
- _force_debug_cursor = {}
- for connection_name in connections:
- _force_debug_cursor[connection_name] = connections[connection_name].force_debug_cursor
- except Exception:
- connections = None
- utils.CursorDebugWrapper = PrintCursorQueryWrapper
- postgresql_base = None
- if django.VERSION >= (3, 0):
- try:
- from django.db.backends.postgresql import base as postgresql_base
- _PostgreSQLCursorDebugWrapper = postgresql_base.CursorDebugWrapper
- class PostgreSQLPrintCursorDebugWrapper(PrintQueryWrapperMixin, _PostgreSQLCursorDebugWrapper):
- pass
- except (ImproperlyConfigured, TypeError):
- postgresql_base = None
- if postgresql_base:
- postgresql_base.CursorDebugWrapper = PostgreSQLPrintCursorDebugWrapper
- if connections:
- for connection_name in connections:
- connections[connection_name].force_debug_cursor = True
- yield
- utils.CursorDebugWrapper = _CursorDebugWrapper
- if postgresql_base:
- postgresql_base.CursorDebugWrapper = _PostgreSQLCursorDebugWrapper
- if connections:
- for connection_name in connections:
- connections[connection_name].force_debug_cursor = _force_debug_cursor[connection_name]
|