|
- # -*- coding: utf-8 -*-
- """
- sqldiff.py - Prints the (approximated) difference between models and database
- TODO:
- - better support for relations
- - better support for constraints (mainly postgresql?)
- - support for table spaces with postgresql
- - when a table is not managed (meta.managed==False) then only do a one-way
- sqldiff ? show differences from db->table but not the other way around since
- it's not managed.
- KNOWN ISSUES:
- - MySQL has by far the most problems with introspection. Please be
- careful when using MySQL with sqldiff.
- - Booleans are reported back as Integers, so there's no way to know if
- there was a real change.
- - Varchar sizes are reported back without unicode support so their size
- may change in comparison to the real length of the varchar.
- - Some of the 'fixes' to counter these problems might create false
- positives or false negatives.
- """
- import importlib
- import sys
- import six
- import argparse
- from typing import Dict, Union, Callable, Optional # NOQA
- from django.apps import apps
- from django.core.management import BaseCommand, CommandError
- from django.core.management.base import OutputWrapper
- from django.core.management.color import no_style
- from django.db import connection, transaction, models
- from django.db.models.fields import AutoField, IntegerField
- from django.db.models.options import normalize_together
- from django_extensions.management.utils import signalcommand
# Stand-in field that find_differences() registers under the '_order' column
# for models that declare Meta.order_with_respect_to.
ORDERING_FIELD = IntegerField('_order', null=True)
def flatten(lst, ltypes=(list, tuple)):
    """Return a flattened copy of a nested sequence.

    Every element that is an instance of ``ltypes`` is expanded in place,
    recursively; empty nested containers disappear. The result is built with
    the same type as the outermost input (``type(lst)``).
    """
    container_type = type(lst)
    flat = []
    # Stack of partially consumed iterators: the top is the container that is
    # currently being walked, so element order is preserved.
    pending = [iter(lst)]
    while pending:
        for item in pending[-1]:
            if isinstance(item, ltypes):
                pending.append(iter(item))
                break
            flat.append(item)
        else:
            # Current container exhausted, resume its parent.
            pending.pop()
    return container_type(flat)
def all_local_fields(meta):
    """Collect the concrete local fields described by a model ``_meta``.

    For a proxy model the fields of every parent are gathered recursively.
    Otherwise the model's own ``local_fields`` are returned, leaving out
    fields whose ``db_type()`` is None for the current connection.
    """
    if meta.proxy:
        inherited = []
        for parent_model in meta.parents:
            inherited += all_local_fields(parent_model._meta)
        return inherited
    return [
        field
        for field in meta.local_fields
        if field.db_type(connection=connection) is not None
    ]
class SQLDiff:
    """Base differ: class-level tables and SQL renderers shared by backends.

    Each ``SQL_*`` renderer takes ``(style, qn, args)`` where ``style`` is a
    Django colour style, ``qn`` quotes an identifier and ``args`` is the tuple
    recorded by ``add_difference`` for that difference type.
    """

    DATA_TYPES_REVERSE_OVERRIDE = {}  # type: Dict[int, Union[str, Callable]]

    # Database tables that are never reported as "missing in models".
    IGNORE_MISSING_TABLES = [
        "django_migrations",
    ]

    # Every difference type accepted by add_difference().
    DIFF_TYPES = [
        'error',
        'comment',
        'table-missing-in-db',
        'table-missing-in-model',
        'field-missing-in-db',
        'field-missing-in-model',
        'fkey-missing-in-db',
        'fkey-missing-in-model',
        'index-missing-in-db',
        'index-missing-in-model',
        'unique-missing-in-db',
        'unique-missing-in-model',
        'field-type-differ',
        'field-parameter-differ',
        'notnull-differ',
    ]

    # Text templates; '%(N)s' refers to the N-th recorded argument.
    DIFF_TEXTS = {
        'error': 'error: %(0)s',
        'comment': 'comment: %(0)s',
        'table-missing-in-db': "table '%(0)s' missing in database",
        'table-missing-in-model': "table '%(0)s' missing in models",
        'field-missing-in-db': "field '%(1)s' defined in model but missing in database",
        'field-missing-in-model': "field '%(1)s' defined in database but missing in model",
        'fkey-missing-in-db': "field '%(1)s' FOREIGN KEY defined in model but missing in database",
        'fkey-missing-in-model': "field '%(1)s' FOREIGN KEY defined in database but missing in model",
        'index-missing-in-db': "field '%(1)s' INDEX named '%(2)s' defined in model but missing in database",
        'index-missing-in-model': "field '%(1)s' INDEX defined in database schema but missing in model",
        'unique-missing-in-db': "field '%(1)s' UNIQUE named '%(2)s' defined in model but missing in database",
        'unique-missing-in-model': "field '%(1)s' UNIQUE defined in database schema but missing in model",
        'field-type-differ': "field '%(1)s' not of same type: db='%(3)s', model='%(2)s'",
        'field-parameter-differ': "field '%(1)s' parameters differ: db='%(3)s', model='%(2)s'",
        'notnull-differ': "field '%(1)s' null constraint should be '%(2)s' in the database",
    }

    def SQL_FIELD_MISSING_IN_DB(self, style, qn, args):
        """Render ADD COLUMN; args = (table, column, coltype, *keywords)."""
        return "%s %s\n\t%s %s %s;" % (
            style.SQL_KEYWORD('ALTER TABLE'),
            style.SQL_TABLE(qn(args[0])),
            style.SQL_KEYWORD('ADD COLUMN'),
            style.SQL_FIELD(qn(args[1])),
            ' '.join(
                style.SQL_COLTYPE(a) if i == 0 else style.SQL_KEYWORD(a)
                for i, a in enumerate(args[2:])
            ),
        )

    def SQL_FIELD_MISSING_IN_MODEL(self, style, qn, args):
        """Render DROP COLUMN; args = (table, column)."""
        return "%s %s\n\t%s %s;" % (
            style.SQL_KEYWORD('ALTER TABLE'),
            style.SQL_TABLE(qn(args[0])),
            style.SQL_KEYWORD('DROP COLUMN'),
            style.SQL_FIELD(qn(args[1])),
        )

    def SQL_FKEY_MISSING_IN_DB(self, style, qn, args):
        """Render ADD COLUMN ... REFERENCES; args = (table, column, ref_table, ref_column, coltype, *keywords)."""
        return "%s %s\n\t%s %s %s %s %s (%s)%s;" % (
            style.SQL_KEYWORD('ALTER TABLE'),
            style.SQL_TABLE(qn(args[0])),
            style.SQL_KEYWORD('ADD COLUMN'),
            style.SQL_FIELD(qn(args[1])),
            ' '.join(
                style.SQL_COLTYPE(a) if i == 0 else style.SQL_KEYWORD(a)
                for i, a in enumerate(args[4:])
            ),
            style.SQL_KEYWORD('REFERENCES'),
            style.SQL_TABLE(qn(args[2])),
            style.SQL_FIELD(qn(args[3])),
            connection.ops.deferrable_sql(),
        )

    def SQL_INDEX_MISSING_IN_DB(self, style, qn, args):
        """Render CREATE INDEX; args = (table, columns, index_name, suffix)."""
        return "%s %s\n\t%s %s (%s%s);" % (
            style.SQL_KEYWORD('CREATE INDEX'),
            style.SQL_TABLE(qn(args[2])),
            style.SQL_KEYWORD('ON'),
            style.SQL_TABLE(qn(args[0])),
            style.SQL_FIELD(', '.join(qn(e) for e in args[1])),
            style.SQL_KEYWORD(args[3]),
        )

    def SQL_INDEX_MISSING_IN_MODEL(self, style, qn, args):
        """Render DROP INDEX; args = (table, index_name)."""
        return "%s %s;" % (
            style.SQL_KEYWORD('DROP INDEX'),
            style.SQL_TABLE(qn(args[1])),
        )

    def SQL_UNIQUE_MISSING_IN_DB(self, style, qn, args):
        """Render ADD CONSTRAINT ... UNIQUE; args = (table, columns, constraint_name)."""
        return "%s %s\n\t%s %s %s (%s);" % (
            style.SQL_KEYWORD('ALTER TABLE'),
            style.SQL_TABLE(qn(args[0])),
            style.SQL_KEYWORD('ADD CONSTRAINT'),
            style.SQL_TABLE(qn(args[2])),
            style.SQL_KEYWORD('UNIQUE'),
            style.SQL_FIELD(', '.join(qn(e) for e in args[1])),
        )

    def SQL_UNIQUE_MISSING_IN_MODEL(self, style, qn, args):
        """Render DROP CONSTRAINT; args = (table, constraint_name)."""
        return "%s %s\n\t%s %s %s;" % (
            style.SQL_KEYWORD('ALTER TABLE'),
            style.SQL_TABLE(qn(args[0])),
            style.SQL_KEYWORD('DROP'),
            style.SQL_KEYWORD('CONSTRAINT'),
            style.SQL_TABLE(qn(args[1])),
        )

    def SQL_FIELD_TYPE_DIFFER(self, style, qn, args):
        """Render MODIFY column type; args = (table, column, model_type, ...)."""
        return "%s %s\n\t%s %s %s;" % (
            style.SQL_KEYWORD('ALTER TABLE'),
            style.SQL_TABLE(qn(args[0])),
            style.SQL_KEYWORD("MODIFY"),
            style.SQL_FIELD(qn(args[1])),
            style.SQL_COLTYPE(args[2]),
        )

    def SQL_FIELD_PARAMETER_DIFFER(self, style, qn, args):
        """Render MODIFY column parameters; args = (table, column, model_type, ...)."""
        return "%s %s\n\t%s %s %s;" % (
            style.SQL_KEYWORD('ALTER TABLE'),
            style.SQL_TABLE(qn(args[0])),
            style.SQL_KEYWORD("MODIFY"),
            style.SQL_FIELD(qn(args[1])),
            style.SQL_COLTYPE(args[2]),
        )

    def SQL_NOTNULL_DIFFER(self, style, qn, args):
        """Render MODIFY ... NOT NULL; args = (table, column, action)."""
        return "%s %s\n\t%s %s %s %s;" % (
            style.SQL_KEYWORD('ALTER TABLE'),
            style.SQL_TABLE(qn(args[0])),
            style.SQL_KEYWORD('MODIFY'),
            style.SQL_FIELD(qn(args[1])),
            style.SQL_KEYWORD(args[2]),
            style.SQL_KEYWORD('NOT NULL'),
        )

    def SQL_ERROR(self, style, qn, args):
        """Render an error as an SQL comment line."""
        return style.NOTICE('-- Error: %s' % style.ERROR(args[0]))

    def SQL_COMMENT(self, style, qn, args):
        """Render a free-form comment as an SQL comment line."""
        return style.NOTICE('-- Comment: %s' % style.SQL_TABLE(args[0]))

    def SQL_TABLE_MISSING_IN_DB(self, style, qn, args):
        """Render a 'table missing' note as an SQL comment line."""
        return style.NOTICE('-- Table missing: %s' % args[0])

    def SQL_TABLE_MISSING_IN_MODEL(self, style, qn, args):
        """Render a 'model missing' note as an SQL comment line."""
        return style.NOTICE('-- Model missing for table: %s' % args[0])

    # Capability flags; load() calls load_null()/load_unsigned() when set.
    can_detect_notnull_differ = False
    can_detect_unsigned_differ = False
    unsigned_suffix = None  # type: Optional[str]
def __init__(self, app_models, options, stdout, stderr):
    """Store the inputs and reset all bookkeeping used while diffing.

    ``options`` must contain ``'dense_output'``; ``stdout``/``stderr`` are
    kept as the output streams.
    """
    # Inputs handed over by the caller.
    self.has_differences = None
    self.app_models = app_models
    self.options = options
    self.dense = options['dense_output']
    self.stdout = stdout
    self.stderr = stderr
    self.introspection = connection.introspection

    # Working state filled in by load() and the find_* methods.
    self.differences = []
    self.unknown_db_fields = {}
    self.new_db_fields = set()
    self.null = {}
    self.unsigned = set()

    # Difference type -> SQL renderer. Bound through ``self`` so that
    # subclass overrides are picked up.
    self.DIFF_SQL = dict((
        ('error', self.SQL_ERROR),
        ('comment', self.SQL_COMMENT),
        ('table-missing-in-db', self.SQL_TABLE_MISSING_IN_DB),
        ('table-missing-in-model', self.SQL_TABLE_MISSING_IN_MODEL),
        ('field-missing-in-db', self.SQL_FIELD_MISSING_IN_DB),
        ('field-missing-in-model', self.SQL_FIELD_MISSING_IN_MODEL),
        ('fkey-missing-in-db', self.SQL_FKEY_MISSING_IN_DB),
        # A stray foreign key column is dropped like any other column.
        ('fkey-missing-in-model', self.SQL_FIELD_MISSING_IN_MODEL),
        ('index-missing-in-db', self.SQL_INDEX_MISSING_IN_DB),
        ('index-missing-in-model', self.SQL_INDEX_MISSING_IN_MODEL),
        ('unique-missing-in-db', self.SQL_UNIQUE_MISSING_IN_DB),
        ('unique-missing-in-model', self.SQL_UNIQUE_MISSING_IN_MODEL),
        ('field-type-differ', self.SQL_FIELD_TYPE_DIFFER),
        ('field-parameter-differ', self.SQL_FIELD_PARAMETER_DIFFER),
        ('notnull-differ', self.SQL_NOTNULL_DIFFER),
    ))
def load(self):
    """Open a cursor and read the table lists needed for the comparison.

    Also triggers ``load_null`` / ``load_unsigned`` when the backend flags
    say those differences can be detected.
    """
    self.cursor = connection.cursor()
    only_existing = self.options['only_existing']
    self.django_tables = self.introspection.django_table_names(only_existing=only_existing)
    # TODO: We are losing information about tables which are views here
    table_infos = self.introspection.get_table_list(self.cursor)
    self.db_tables = [info.name for info in table_infos]

    if self.can_detect_notnull_differ:
        self.load_null()

    if self.can_detect_unsigned_differ:
        self.load_unsigned()
def load_null(self):
    """Hook for backends that set ``can_detect_notnull_differ``; must be overridden."""
    message = "load_null functions must be implemented if diff backend has 'can_detect_notnull_differ' set to True"
    raise NotImplementedError(message)
def load_unsigned(self):
    """Hook for backends that set ``can_detect_unsigned_differ``; must be overridden."""
    message = "load_unsigned function must be implemented if diff backend has 'can_detect_unsigned_differ' set to True"
    raise NotImplementedError(message)
def add_app_model_marker(self, app_label, model_name):
    """Start a new ``(app_label, model_name, diffs)`` group with no diffs yet."""
    marker = (app_label, model_name, [])
    self.differences.append(marker)
def add_difference(self, diff_type, *args):
    """Record ``(diff_type, args)`` under the most recently added marker.

    ``diff_type`` must be one of ``DIFF_TYPES``.
    """
    assert diff_type in self.DIFF_TYPES, 'Unknown difference type'
    current_marker = self.differences[-1]
    # The diff list is the last element of the marker tuple.
    current_marker[-1].append((diff_type, args))
def get_data_types_reverse_override(self):
    # type: () -> Dict[int, Union[str, Callable]]
    """Return the type-code override table (``DATA_TYPES_REVERSE_OVERRIDE``)."""
    overrides = self.DATA_TYPES_REVERSE_OVERRIDE
    return overrides
def format_field_names(self, field_names):
    """Return the column names unchanged; subclasses may normalise them."""
    unchanged = field_names
    return unchanged
def sql_to_dict(self, query, param):
    """
    Execute ``query`` with ``param`` and return the rows as a list of dicts.

    sql_to_dict(query, param) -> list of dicts

    Each dict maps a column name (after ``format_field_names``) to the
    value of that column in the row.

    code from snippet at http://www.djangosnippets.org/snippets/1383/
    """
    # Use the cursor as a context manager so it is closed even when
    # execute() or fetchall() raises.
    with connection.cursor() as cursor:
        cursor.execute(query, param)
        fieldnames = self.format_field_names([column[0] for column in cursor.description])
        return [dict(zip(fieldnames, row)) for row in cursor.fetchall()]
def get_field_model_type(self, field):
    """Return the column type the model field declares for the current connection."""
    model_type = field.db_type(connection=connection)
    return model_type
def get_field_db_type_kwargs(self, current_kwargs, description, field=None, table_name=None, reverse_type=None):
    """Return extra field kwargs for ``get_field_db_type``; none in the base class."""
    return dict()
def get_field_db_type(self, description, field=None, table_name=None):
    """Work out the column type string Django would generate for a DB column.

    The introspected column is mapped back to a Django field class, that
    class is instantiated with kwargs derived from the description, and its
    ``db_type()`` is returned.

    :param description: DB-API ``cursor.description`` style row:
        (name, type_code, display_size, internal_size, precision, scale, null_ok)
    :param field: optional model field for the column; used for GIS hints
        and the unsigned lookup.
    :param table_name: table the column belongs to.
    :returns: the column type string, or None when the type code cannot be
        mapped (a 'comment' difference is recorded once per column).
    """
    type_code = description[1]
    overrides = self.get_data_types_reverse_override()
    if type_code in overrides:
        reverse_type = overrides[type_code]
    else:
        try:
            reverse_type = self.introspection.get_field_type(type_code, description)
        except KeyError:
            reverse_type = self.get_field_db_type_lookup(type_code)

    if not reverse_type:
        # type_code not found in data_types_reverse map; report it only once
        # per (app, model, column, type_code).
        key = (self.differences[-1][:2], description[:2])
        if key not in self.unknown_db_fields:
            self.unknown_db_fields[key] = 1
            self.add_difference('comment', "Unknown database type for field '%s' (%s)" % (description[0], type_code))
        return None

    if callable(reverse_type):
        reverse_type = reverse_type()

    kwargs = {}
    if isinstance(reverse_type, dict):
        kwargs.update(reverse_type['kwargs'])
        reverse_type = reverse_type['name']

    if type_code == 16946 and field and getattr(field, 'geom_type', None) == 'POINT':
        reverse_type = 'django.contrib.gis.db.models.fields.PointField'

    if isinstance(reverse_type, tuple):
        kwargs.update(reverse_type[1])
        reverse_type = reverse_type[0]

    if reverse_type == "CharField" and description[3]:
        kwargs['max_length'] = description[3]

    if reverse_type == "DecimalField":
        scale = description[5]
        kwargs['max_digits'] = description[4]
        # Keep a falsy scale (None/0) as-is, otherwise use its magnitude.
        kwargs['decimal_places'] = abs(scale) if scale else scale

    if description[6]:
        kwargs['blank'] = True
        if reverse_type not in ('TextField', 'CharField'):
            kwargs['null'] = True

    if field and getattr(field, 'geography', False):
        kwargs['geography'] = True

    if reverse_type == 'GeometryField':
        geo_col = description[0]
        # Getting a more specific field type and any additional parameters
        # from the `get_geometry_type` routine for the spatial backend.
        reverse_type, geo_params = self.introspection.get_geometry_type(table_name, geo_col)
        if geo_params:
            kwargs.update(geo_params)
        reverse_type = 'django.contrib.gis.db.models.fields.%s' % reverse_type

    extra_kwargs = self.get_field_db_type_kwargs(kwargs, description, field, table_name, reverse_type)
    kwargs.update(extra_kwargs)

    field_class = self.get_field_class(reverse_type)
    field_db_type = field_class(**kwargs).db_type(connection=connection)

    # ``field`` is optional: without it there is no tablespace/column to look
    # up in the unsigned set, so the plain type is returned.
    if field is None:
        return field_db_type

    tablespace = field.db_tablespace or "public"
    if (tablespace, table_name, field.column) in self.unsigned and self.unsigned_suffix not in field_db_type:
        field_db_type = '%s %s' % (field_db_type, self.unsigned_suffix)

    return field_db_type
def get_field_db_type_lookup(self, type_code):
    """Fallback type lookup used by ``get_field_db_type``; the base class knows none."""
    fallback = None
    return fallback
def get_field_class(self, class_path):
    """Resolve a field class from a dotted path or a bare ``models`` attribute name."""
    if '.' not in class_path:
        return getattr(models, class_path)
    module_path, _, attribute = class_path.rpartition('.')
    return getattr(importlib.import_module(module_path), attribute)
def get_field_db_nullable(self, field, table_name):
    """Look up the nullability loaded for the field's column.

    Returns the value stored in ``self.null`` for
    ``(tablespace, table_name, column)``, or the string ``'fixme'`` when no
    entry exists. An empty tablespace is looked up as ``"public"``.
    """
    tablespace = field.db_tablespace
    column = field.db_column or field.attname
    lookup_key = ("public" if tablespace == "" else tablespace, table_name, column)
    return self.null.get(lookup_key, 'fixme')
def strip_parameters(self, field_type):
    """Reduce a column type to its lower-cased base name.

    Everything from the first space or ``(`` on is dropped. Falsy values
    and the literal ``'double precision'`` are returned untouched.
    """
    if not field_type or field_type == 'double precision':
        return field_type
    first_word = field_type.split(" ", 1)[0]
    return first_word.split("(", 1)[0].lower()
def expand_together(self, together, meta):
    """Translate a ``*_together`` option into a list of attname tuples.

    ``together`` is normalised with ``normalize_together`` and every field
    name is replaced by ``meta.get_field(name).attname``.
    """
    return [
        tuple(meta.get_field(name).attname for name in group)
        for group in normalize_together(together)
    ]
def find_unique_missing_in_db(self, meta, table_indexes, table_constraints, table_name, skip_list=None):
    """Report unique constraints declared on the model but absent from the table.

    Covers ``unique=True`` fields (only for managed models) and
    ``unique_together``. Entries found in ``skip_list`` are ignored.
    """
    schema_editor = connection.SchemaEditorClass(connection)
    like_opclasses = (
        ('varchar', ' varchar_pattern_ops'),
        ('text', ' text_pattern_ops'),
    )

    for field in all_local_fields(meta):
        if skip_list and field.attname in skip_list:
            continue
        if not (field.unique and meta.managed):
            continue

        column = field.db_column or field.attname
        unique_in_db = table_indexes.get(column, {}).get('unique')
        if not unique_in_db and table_constraints:
            # Fall back to any constraint that covers exactly this column.
            unique_in_db = any(
                constraint['unique']
                for constraint in table_constraints.values()
                if [column] == constraint['columns']
            )
        if column in table_indexes and unique_in_db:
            continue

        index_name = schema_editor._create_index_name(table_name, [column])
        self.add_difference('unique-missing-in-db', table_name, [column], index_name + "_uniq")

        # Text-like columns also get a companion *_like pattern-ops index.
        db_type = field.db_type(connection=connection)
        for type_prefix, opclass in like_opclasses:
            if db_type.startswith(type_prefix):
                self.add_difference('index-missing-in-db', table_name, [column], index_name + '_like', opclass)

    model_unique_together = self.expand_together(meta.unique_together, meta)
    db_unique_columns = normalize_together([
        constraint['columns']
        for constraint in table_constraints.values()
        if constraint['unique'] and not constraint['index']
    ])

    for unique_columns in model_unique_together:
        if unique_columns in db_unique_columns:
            continue
        if skip_list and unique_columns in skip_list:
            continue
        index_name = schema_editor._create_index_name(table_name, unique_columns)
        self.add_difference('unique-missing-in-db', table_name, unique_columns, index_name + "_uniq")
def find_unique_missing_in_model(self, meta, table_indexes, table_constraints, table_name):
    """Report unique constraints present in the table but not declared on the model."""
    fields_by_column = {field.column: field for field in all_local_fields(meta)}
    model_unique_together = self.expand_together(meta.unique_together, meta)

    for constraint_name, constraint in table_constraints.items():
        # Only pure unique constraints are handled here;
        # unique indexes are handled by find_index_missing_in_model.
        if not constraint['unique'] or constraint['index']:
            continue

        columns = constraint['columns']
        if len(columns) == 1:
            field = fields_by_column.get(columns[0])
            if field is not None and field.unique:
                continue
        elif tuple(columns) in model_unique_together:
            continue

        self.add_difference('unique-missing-in-model', table_name, constraint_name)
def find_index_missing_in_db(self, meta, table_indexes, table_constraints, table_name):
    """Report indexes declared on the model but absent from the table.

    Checks ``db_index=True`` fields, ``index_together`` and ``Meta.indexes``.
    """
    schema_editor = connection.SchemaEditorClass(connection)
    like_opclasses = (
        ('varchar', ' varchar_pattern_ops'),
        ('text', ' text_pattern_ops'),
    )

    for field in all_local_fields(meta):
        if not field.db_index:
            continue
        column = field.db_column or field.attname
        if column in table_indexes:
            continue
        index_name = schema_editor._create_index_name(table_name, [column])
        self.add_difference('index-missing-in-db', table_name, [column], index_name, '')
        # Text-like columns also get a companion *_like pattern-ops index.
        db_type = field.db_type(connection=connection)
        for type_prefix, opclass in like_opclasses:
            if db_type.startswith(type_prefix):
                self.add_difference('index-missing-in-db', table_name, [column], index_name + '_like', opclass)

    model_index_together = self.expand_together(meta.index_together, meta)
    db_index_together = normalize_together([
        constraint['columns']
        for constraint in table_constraints.values()
        if constraint['index'] and not constraint['unique']
    ])
    for columns in model_index_together:
        if columns in db_index_together:
            continue
        index_name = schema_editor._create_index_name(table_name, columns)
        self.add_difference('index-missing-in-db', table_name, columns, index_name + "_idx", '')

    for index in meta.indexes:
        if index.name in table_constraints:
            continue
        self.add_difference('index-missing-in-db', table_name, index.fields, index.name, '')
def find_index_missing_in_model(self, meta, table_indexes, table_constraints, table_name):
    """Report indexes present in the table that nothing on the model accounts for."""
    fields_by_column = {field.column: field for field in all_local_fields(meta)}
    meta_index_names = [idx.name for idx in meta.indexes]
    model_index_together = self.expand_together(meta.index_together, meta)

    for constraint_name, constraint in table_constraints.items():
        if constraint_name in meta_index_names:
            continue
        if constraint['unique'] and not constraint['index']:
            # unique constraints are handled by find_unique_missing_in_model
            continue

        columns = constraint['columns']
        field = fields_by_column.get(columns[0])
        # unique indexes do not exist in django ? only unique constraints,
        # so those (and columns unknown to the model) are always reported.
        is_unique_index = constraint['unique'] and constraint['index']
        if not is_unique_index and field is not None:
            if len(columns) == 1:
                explained_by_field = (
                    (constraint['primary_key'] and field.primary_key)
                    or (constraint['foreign_key'] and isinstance(field, models.ForeignKey) and field.db_constraint)
                    or (constraint['unique'] and field.unique)
                    # django automatically creates a _like varchar_pattern_ops/text_pattern_ops index see https://code.djangoproject.com/ticket/12234
                    # note: mysql does not have and/or introspect and fill the 'orders' attribute of constraint information
                    or (constraint['index'] and constraint['type'] == 'idx' and constraint.get('orders') and field.unique)
                    or (constraint['index'] and field.db_index)
                    or (constraint['check'] and field.db_check(connection=connection))
                    or getattr(field, 'spatial_index', False)
                )
                if explained_by_field:
                    continue
            elif constraint['index'] and tuple(columns) in model_index_together:
                continue

        self.add_difference('index-missing-in-model', table_name, constraint_name)
def find_field_missing_in_model(self, fieldmap, table_description, table_name):
    """Report every table column whose name is not a key of ``fieldmap``."""
    for column_info in table_description:
        column_name = column_info[0]
        if column_name in fieldmap:
            continue
        self.add_difference('field-missing-in-model', table_name, column_name)
def find_field_missing_in_db(self, fieldmap, table_description, table_name):
    """Report model columns that the table does not have.

    Relational fields are recorded as 'fkey-missing-in-db' (with the
    referenced table and column first), all others as
    'field-missing-in-db'. Each reported column is remembered in
    ``self.new_db_fields``.
    """
    existing_columns = [row[0] for row in table_description]
    for column, field in fieldmap.items():
        if column in existing_columns:
            continue

        if field.remote_field:
            remote_meta = field.remote_field.model._meta
            op = 'fkey-missing-in-db'
            details = [
                remote_meta.db_table,
                remote_meta.get_field(field.remote_field.field_name).column,
            ]
        else:
            op = 'field-missing-in-db'
            details = []

        details.append(field.db_type(connection=connection))
        if self.options['include_defaults'] and field.has_default():
            details.append('DEFAULT %s' % field.get_prep_value(field.get_default()))
        if not field.null:
            details.append('NOT NULL')

        self.add_difference(op, table_name, column, *details)
        self.new_db_fields.add((table_name, column))
def find_field_type_differ(self, meta, table_description, table_name, func=None):
    """Report fields whose base column type differs between model and table.

    ``func``, when given, is called as
    ``func(field, description, model_type, db_type)`` and must return the
    (possibly adjusted) ``(model_type, db_type)`` pair before comparison.
    """
    descriptions = {row[0]: row for row in table_description}
    for field in all_local_fields(meta):
        if field.name not in descriptions:
            continue
        description = descriptions[field.name]

        model_type = self.get_field_model_type(field)
        db_type = self.get_field_db_type(description, field, table_name)

        # use callback function if defined
        if func:
            model_type, db_type = func(field, description, model_type, db_type)

        if self.strip_parameters(db_type) != self.strip_parameters(model_type):
            self.add_difference('field-type-differ', table_name, field.name, model_type, db_type)
def find_field_parameter_differ(self, meta, table_description, table_name, func=None):
    """Report fields with the same base type but different parameters or CHECK.

    Fields whose base types already differ are skipped here. ``func``, when
    given, may adjust ``(model_type, db_type)`` before the comparison.
    """
    descriptions = {row[0]: row for row in table_description}
    for field in all_local_fields(meta):
        if field.name not in descriptions:
            continue
        description = descriptions[field.name]

        model_type = self.get_field_model_type(field)
        db_type = self.get_field_db_type(description, field, table_name)

        if self.strip_parameters(model_type) != self.strip_parameters(db_type):
            continue

        # use callback function if defined
        if func:
            model_type, db_type = func(field, description, model_type, db_type)

        model_check = field.db_parameters(connection=connection)['check']
        db_check = None
        if ' CHECK' in db_type:
            # Split "<type> CHECK (<expr>)" into the type and the bare expression.
            db_type, raw_check = db_type.split(" CHECK", 1)
            db_check = raw_check.strip().lstrip("(").rstrip(")")

        if model_type != db_type or model_check != db_check:
            self.add_difference('field-parameter-differ', table_name, field.name, model_type, db_type)
def find_field_notnull_differ(self, meta, table_description, table_name):
    """Report columns whose NULL-ability differs from the model field.

    Does nothing unless the backend sets ``can_detect_notnull_differ``.
    Columns already reported as missing in the database are skipped.
    """
    if not self.can_detect_notnull_differ:
        return

    for field in all_local_fields(meta):
        column = field.db_column or field.attname
        if (table_name, column) in self.new_db_fields:
            continue
        db_nullable = self.get_field_db_nullable(field, table_name)
        if field.null == db_nullable:
            continue
        action = 'DROP' if field.null else 'SET'
        self.add_difference('notnull-differ', table_name, column, action)
def get_constraints(self, cursor, table_name, introspection):
    """Fallback used when the introspection object has no ``get_constraints``; returns no constraints."""
    return dict()
def find_differences(self):
    """Run every check and collect the results in ``self.differences``.

    With the ``all_applications`` option, database tables unknown to Django
    are reported first. Then each model in ``self.app_models`` is compared
    with its table. Finally ``self.has_differences`` is set to the size of
    the largest per-model difference list (0 when nothing was scanned).
    """
    if self.options['all_applications']:
        self.add_app_model_marker(None, None)
        for table in self.db_tables:
            if table not in self.django_tables and table not in self.IGNORE_MISSING_TABLES:
                self.add_difference('table-missing-in-model', table)

    for app_model in self.app_models:
        meta = app_model._meta
        table_name = meta.db_table
        app_label = meta.app_label

        if not self.options['include_proxy_models'] and meta.proxy:
            continue

        # Marker indicating start of difference scan for this table_name.
        # The original compared against a never-updated ``None``, i.e. a
        # marker is added for every model that has an app label.
        if app_label is not None:
            self.add_app_model_marker(app_label, app_model.__name__)

        if table_name not in self.db_tables:
            # Table is missing from database
            self.add_difference('table-missing-in-db', table_name)
            continue

        if hasattr(self.introspection, 'get_constraints'):
            table_constraints = self.introspection.get_constraints(self.cursor, table_name)
        else:
            table_constraints = self.get_constraints(self.cursor, table_name, self.introspection)

        fieldmap = {
            (field.db_column or field.get_attname()): field
            for field in all_local_fields(meta)
        }

        # add ordering field if model uses order_with_respect_to
        if meta.order_with_respect_to:
            fieldmap['_order'] = ORDERING_FIELD

        try:
            table_description = self.introspection.get_table_description(self.cursor, table_name)
        except Exception as e:
            # Best effort: record the failure and move on to the next model.
            self.add_difference('error', 'unable to introspect table: %s' % str(e).strip())
            transaction.rollback()  # reset transaction
            continue

        # map table_constraints into table_indexes (single-column entries only)
        table_indexes = {}
        for contraint_name, dct in table_constraints.items():
            columns = dct['columns']
            if len(columns) == 1:
                table_indexes[columns[0]] = {
                    'primary_key': dct['primary_key'],
                    'unique': dct['unique'],
                    'type': dct.get('type'),
                    'contraint_name': contraint_name,
                }

        # Fields which are defined in database but not in model
        # 1) find: 'unique-missing-in-model'
        self.find_unique_missing_in_model(meta, table_indexes, table_constraints, table_name)
        # 2) find: 'index-missing-in-model'
        self.find_index_missing_in_model(meta, table_indexes, table_constraints, table_name)
        # 3) find: 'field-missing-in-model'
        self.find_field_missing_in_model(fieldmap, table_description, table_name)

        # Fields which are defined in models but not in database
        # 4) find: 'field-missing-in-db'
        self.find_field_missing_in_db(fieldmap, table_description, table_name)
        # 5) find: 'unique-missing-in-db'
        self.find_unique_missing_in_db(meta, table_indexes, table_constraints, table_name)
        # 6) find: 'index-missing-in-db'
        self.find_index_missing_in_db(meta, table_indexes, table_constraints, table_name)

        # Fields which have a different type or parameters
        # 7) find: 'type-differs'
        self.find_field_type_differ(meta, table_description, table_name)
        # 8) find: 'type-parameter-differs'
        self.find_field_parameter_differ(meta, table_description, table_name)
        # 9) find: 'field-notnull'
        self.find_field_notnull_differ(meta, table_description, table_name)

    # ``default=0`` keeps this from raising ValueError when no marker was
    # ever added (no models scanned and ``all_applications`` disabled).
    self.has_differences = max(
        (len(diffs) for _app_label, _model_name, diffs in self.differences),
        default=0,
    )
def print_diff(self, style=no_style()):
    """ Print differences to stdout, as SQL when the ``sql`` option is set, else as text """
    printer = self.print_diff_sql if self.options['sql'] else self.print_diff_text
    printer(style)
def print_diff_text(self, style):
    """Write the collected differences to ``self.stdout`` as readable text.

    In dense mode each difference is one self-contained line; otherwise
    differences are grouped under application and model headings.
    """
    write = self.stdout.write

    capability_notices = (
        (self.can_detect_notnull_differ, "# Detecting notnull changes not implemented for this database backend"),
        (self.can_detect_unsigned_differ, "# Detecting unsigned changes not implemented for this database backend"),
    )
    for supported, notice in capability_notices:
        if not supported:
            write(style.NOTICE(notice))
            write("")

    last_app_label = None
    for app_label, model_name, diffs in self.differences:
        if not diffs:
            continue

        if not self.dense and app_label and last_app_label != app_label:
            write("%s %s" % (style.NOTICE("+ Application:"), style.SQL_TABLE(app_label)))
            last_app_label = app_label
        if not self.dense and model_name:
            write("%s %s" % (style.NOTICE("|-+ Differences for model:"), style.SQL_TABLE(model_name)))

        for diff_type, diff_args in diffs:
            # '%(N)s' placeholders are filled from the N-th recorded argument.
            substitutions = {}
            for position, arg in enumerate(diff_args):
                if isinstance(arg, (list, tuple)):
                    arg = ', '.join(arg)
                substitutions[str(position)] = style.SQL_TABLE(arg)
            text = self.DIFF_TEXTS[diff_type] % substitutions

            # Text outside single quotes (even-numbered pieces) gets the ERROR style.
            text = "'".join(
                (style.ERROR(piece) or piece) if index % 2 == 0 else piece
                for index, piece in enumerate(text.split("'"))
            )

            if not self.dense:
                write("%s %s" % (style.NOTICE("|--+"), text))
            elif app_label:
                write("%s %s %s %s %s" % (style.NOTICE("App"), style.SQL_TABLE(app_label), style.NOTICE('Model'), style.SQL_TABLE(model_name), text))
            else:
                write(text)
def print_diff_sql(self, style):
    """Write the collected differences to ``self.stdout`` as an SQL script.

    The statements are wrapped in ``BEGIN;`` / ``COMMIT;``. When there are
    no differences only a comment is written (nothing at all in dense mode).
    """
    write = self.stdout.write

    if not self.can_detect_notnull_differ:
        write(style.NOTICE("-- Detecting notnull changes not implemented for this database backend"))
        write("")

    quote_name = connection.ops.quote_name

    if not self.has_differences:
        if not self.dense:
            write(style.SQL_KEYWORD("-- No differences"))
        return

    write(style.SQL_KEYWORD("BEGIN;"))
    last_app_label = None
    for app_label, model_name, diffs in self.differences:
        if not diffs:
            continue
        if not self.dense and last_app_label != app_label:
            write(style.NOTICE("-- Application: %s" % style.SQL_TABLE(app_label)))
            last_app_label = app_label
        if not self.dense and model_name:
            write(style.NOTICE("-- Model: %s" % style.SQL_TABLE(model_name)))
        for diff_type, diff_args in diffs:
            statement = self.DIFF_SQL[diff_type](style, quote_name, diff_args)
            if self.dense:
                # Collapse the two-line layout into a single line.
                statement = statement.replace("\n\t", " ")
            write(statement)
    write(style.SQL_KEYWORD("COMMIT;"))
class GenericSQLDiff(SQLDiff):
    """Differ without backend-specific detection; both capability flags are off."""

    can_detect_notnull_differ = False
    can_detect_unsigned_differ = False

    def load_null(self):
        """Nothing to load for the generic backend."""
        return None

    def load_unsigned(self):
        """Nothing to load for the generic backend."""
        return None
- class MySQLDiff(SQLDiff):
- can_detect_notnull_differ = True
- can_detect_unsigned_differ = True
- unsigned_suffix = 'UNSIGNED'
def load(self):
    """Run the base ``load()`` and then collect the auto-increment columns."""
    super().load()
    # Filled by load_auto_increment() with (table_name, column_name) pairs.
    self.auto_increment = set()
    self.load_auto_increment()
def format_field_names(self, field_names):
    """Return the column names lower-cased."""
    return list(map(str.lower, field_names)) if all(type(n) is str for n in field_names) else [n.lower() for n in field_names]
def load_null(self):
    """Fill ``self.null`` with the nullability of every column of every table.

    Keys are ``('public', table_name, column_name)``; the value is True
    when information_schema reports ``is_nullable = 'YES'``.
    """
    tablespace = 'public'
    query = """
        SELECT column_name, is_nullable
        FROM information_schema.columns
        WHERE table_schema = DATABASE()
        AND table_name = %s"""
    for table_name in self.db_tables:
        for column in self.sql_to_dict(query, [table_name]):
            nullable = column['is_nullable'] == 'YES'
            self.null[(tablespace, table_name, column['column_name'])] = nullable
def load_unsigned(self):
    """Add every unsigned column to ``self.unsigned``.

    Entries are ``('public', table_name, column_name)`` for columns whose
    ``column_type`` ends in ``unsigned``.
    """
    tablespace = 'public'
    query = """
        SELECT column_name
        FROM information_schema.columns
        WHERE table_schema = DATABASE()
        AND table_name = %s
        AND column_type LIKE '%%unsigned'"""
    for table_name in self.db_tables:
        for column in self.sql_to_dict(query, [table_name]):
            self.unsigned.add((tablespace, table_name, column['column_name']))
def load_auto_increment(self):
    """Add every ``auto_increment`` column to ``self.auto_increment`` as ``(table_name, column_name)``."""
    query = """
        SELECT column_name
        FROM information_schema.columns
        WHERE table_schema = DATABASE()
        AND table_name = %s
        AND extra = 'auto_increment'"""
    for table_name in self.db_tables:
        for column in self.sql_to_dict(query, [table_name]):
            self.auto_increment.add((table_name, column['column_name']))
- # All the MySQL hacks together create something of a problem
- # Fixing one bug in MySQL creates another issue. So just keep in mind
- # that this is way unreliable for MySQL atm.
def get_field_db_type(self, description, field=None, table_name=None):
    """Return the introspected column type with MySQL-specific corrections.

    Starts from the base implementation and, when ``field`` is given:
    reports ``varchar`` as ``char`` if the model declares ``char``, reports
    ``integer`` as ``bool`` if the model declares ``bool``, and appends
    ``AUTO_INCREMENT`` for columns recorded in ``self.auto_increment``.

    :returns: the column type string, or None when the base implementation
        could not determine one.
    """
    db_type = super().get_field_db_type(description, field, table_name)
    if not db_type:
        return None
    if field:
        # MySQL isn't really sure about char's and varchar's like sqlite
        model_base_type = self.strip_parameters(self.get_field_model_type(field))

        # Fix char/varchar inconsistencies. Remove the literal 'var' prefix;
        # str.lstrip("var") would strip any leading run of the characters
        # 'v', 'a' and 'r' instead of the prefix.
        if model_base_type == 'char' and self.strip_parameters(db_type) == 'varchar':
            if db_type.startswith('var'):
                db_type = db_type[len('var'):]

        # They like to call bools various integer types and introspection makes that a integer
        # just convert them all to bools
        if model_base_type == 'bool' and db_type == 'integer':
            db_type = 'bool'

        if (table_name, field.column) in self.auto_increment and 'AUTO_INCREMENT' not in db_type:
            db_type += ' AUTO_INCREMENT'
    return db_type
def find_index_missing_in_model(self, meta, table_indexes, table_constraints, table_name):
    """Report indexes present in the table that nothing on the model accounts for.

    Differs from the base implementation in that unique indexes are not
    reported unconditionally: single-column ones are matched against the
    field, multi-column ones against ``index_together``/``unique_together``.
    A single-column index on a column the model does not define is
    reported as missing in the model.
    """
    fields = {field.column: field for field in all_local_fields(meta)}
    meta_index_names = [idx.name for idx in meta.indexes]
    index_together = self.expand_together(meta.index_together, meta)
    unique_together = self.expand_together(meta.unique_together, meta)

    for constraint_name, constraint in table_constraints.items():
        if constraint_name in meta_index_names:
            continue
        if constraint['unique'] and not constraint['index']:
            # unique constraints are handled by find_unique_missing_in_model
            continue

        columns = constraint['columns']
        if len(columns) == 1:
            # The column may exist only in the database; there is then no
            # field that could explain the index, so it is reported below
            # instead of failing on attribute access on None.
            field = fields.get(columns[0])
            if field is not None:
                if constraint['primary_key'] and field.primary_key:
                    continue
                if constraint['foreign_key'] and isinstance(field, models.ForeignKey) and field.db_constraint:
                    continue
                if constraint['unique'] and field.unique:
                    continue
                if constraint['index'] and constraint['type'] == 'idx' and constraint.get('orders') and field.unique:
                    # django automatically creates a _like varchar_pattern_ops/text_pattern_ops index see https://code.djangoproject.com/ticket/12234
                    # note: mysql does not have and/or introspect and fill the 'orders' attribute of constraint information
                    continue
                if constraint['index'] and field.db_index:
                    continue
                if constraint['check'] and field.db_check(connection=connection):
                    continue
                if getattr(field, 'spatial_index', False):
                    continue
        else:
            if constraint['index'] and tuple(columns) in index_together:
                continue
            if constraint['index'] and constraint['unique'] and tuple(columns) in unique_together:
                continue

        self.add_difference('index-missing-in-model', table_name, constraint_name)
def find_unique_missing_in_db(self, meta, table_indexes, table_constraints, table_name, skip_list=None):
    """Report unique constraints declared on the model but absent in the database.

    Single-column ``unique`` fields are checked first, then the
    ``unique_together`` tuples. Entries found in ``skip_list`` (field
    attnames or column tuples) are ignored. Differences are recorded via
    ``self.add_difference``.
    """
    schema_editor = connection.SchemaEditorClass(connection)

    for field in all_local_fields(meta):
        if skip_list and field.attname in skip_list:
            continue
        if not (field.unique and meta.managed):
            continue

        attname = field.db_column or field.attname
        db_field_unique = table_indexes.get(attname, {}).get('unique')
        if not db_field_unique and table_constraints:
            # Fall back to the constraint listing: any unique constraint
            # covering exactly this one column counts.
            db_field_unique = any(
                constraint['unique']
                for constraint in six.itervalues(table_constraints)
                if [attname] == constraint['columns']
            )
        if attname in table_indexes and db_field_unique:
            continue

        index_name = schema_editor._create_index_name(table_name, [attname])
        self.add_difference('unique-missing-in-db', table_name, [attname], index_name + "_uniq")

        db_type = field.db_type(connection=connection)
        if db_type.startswith('varchar'):
            self.add_difference('index-missing-in-db', table_name, [attname], index_name + '_like', ' varchar_pattern_ops')
        if db_type.startswith('text'):
            self.add_difference('index-missing-in-db', table_name, [attname], index_name + '_like', ' text_pattern_ops')

    unique_together = self.expand_together(meta.unique_together, meta)
    # This comparison changed from superclass - otherwise function is the same
    db_unique_columns = normalize_together(
        [info['columns'] for info in six.itervalues(table_constraints) if info['unique']]
    )

    for unique_columns in unique_together:
        if unique_columns in db_unique_columns:
            continue
        if skip_list and unique_columns in skip_list:
            continue

        index_name = schema_editor._create_index_name(table_name, unique_columns)
        self.add_difference('unique-missing-in-db', table_name, unique_columns, index_name + "_uniq")
class SqliteSQLDiff(SQLDiff):
    """SQLDiff variant for SQLite.

    NOT NULL differences are detected; unsigned differences and index
    comparisons are disabled (the corresponding hooks are no-ops).
    """

    can_detect_notnull_differ = True
    can_detect_unsigned_differ = False

    def load_null(self):
        """Fill ``self.null`` with the nullability of every column in ``self.db_tables``."""
        for table_name in self.db_tables:
            # sqlite does not support tablespaces
            tablespace = "public"
            # index, column_name, column_type, nullable, default_value
            # see: http://www.sqlite.org/pragma.html#pragma_table_info
            for table_info in self.sql_to_dict("PRAGMA table_info('%s');" % table_name, []):
                key = (tablespace, table_name, table_info['name'])
                self.null[key] = not table_info['notnull']

    def load_unsigned(self):
        """No-op: unsigned detection is disabled for this backend."""
        pass

    # Unique does not seem to be implied on Sqlite for Primary_key's
    # if this is more generic among databases this might be usefull
    # to add to the superclass's find_unique_missing_in_db method
    def find_unique_missing_in_db(self, meta, table_indexes, table_constraints, table_name, skip_list=None):
        """Run the superclass check, skipping what the database already enforces.

        Single columns covered by a unique *or primary key* constraint and
        ``unique_together`` tuples already present in the database are
        added to the skip list handed to the superclass.
        """
        # Work on a private copy so the caller's list is never mutated.
        skip_list = [] if skip_list is None else list(skip_list)

        unique_columns = [field.db_column or field.attname for field in all_local_fields(meta) if field.unique]
        for constraint in six.itervalues(table_constraints):
            columns = constraint['columns']
            if len(columns) != 1:
                continue
            column = columns[0]
            if column in unique_columns and (constraint['unique'] or constraint['primary_key']):
                skip_list.append(column)

        unique_together = self.expand_together(meta.unique_together, meta)
        db_unique_columns = normalize_together(
            [v['columns'] for v in six.itervalues(table_constraints) if v['unique']]
        )
        for together in unique_together:
            if together in db_unique_columns:
                skip_list.append(together)

        super().find_unique_missing_in_db(meta, table_indexes, table_constraints, table_name, skip_list=skip_list)

    # Finding Indexes by using the get_indexes dictionary doesn't seem to work
    # for sqlite.
    def find_index_missing_in_db(self, meta, table_indexes, table_constraints, table_name):
        """No-op: index comparison is disabled for sqlite."""
        pass

    def find_index_missing_in_model(self, meta, table_indexes, table_constraints, table_name):
        """No-op: index comparison is disabled for sqlite."""
        pass

    def get_field_db_type(self, description, field=None, table_name=None):
        """Return the introspected column type, reporting ``varchar`` as ``char``
        when the model type is ``char``.

        Returns ``None`` when the superclass could not determine a type.
        """
        db_type = super().get_field_db_type(description, field, table_name)
        if not db_type:
            return None
        if field:
            field_type = self.get_field_model_type(field)
            # Fix char/varchar inconsistencies
            if self.strip_parameters(field_type) == 'char' and self.strip_parameters(db_type) == 'varchar':
                # Remove the literal "var" prefix; str.lstrip("var") strips
                # a character set rather than a prefix.
                if db_type.startswith('var'):
                    db_type = db_type[len('var'):]
        return db_type
class PostgresqlSQLDiff(SQLDiff):
    """SQLDiff variant for PostgreSQL.

    Besides the generic comparison it loads column nullability and CHECK
    constraints from the ``pg_*`` catalogs and maps additional type codes
    (arrays, hstore, jsonb, tsvector) back to Django field classes.
    """

    can_detect_notnull_differ = True
    can_detect_unsigned_differ = True

    DATA_TYPES_REVERSE_NAME = {
        'hstore': 'django.contrib.postgres.fields.HStoreField',
        'jsonb': 'django.contrib.postgres.fields.JSONField',
    }

    # Hopefully in the future we can add constraint checking and other more
    # advanced checks based on this database.
    SQL_LOAD_CONSTRAINTS = """
    SELECT nspname, relname, conname, attname, pg_get_constraintdef(pg_constraint.oid)
    FROM pg_constraint
    INNER JOIN pg_attribute ON pg_constraint.conrelid = pg_attribute.attrelid AND pg_attribute.attnum = any(pg_constraint.conkey)
    INNER JOIN pg_class ON conrelid=pg_class.oid
    INNER JOIN pg_namespace ON pg_namespace.oid=pg_class.relnamespace
    ORDER BY CASE WHEN contype='f' THEN 0 ELSE 1 END,contype,nspname,relname,conname;
    """
    SQL_LOAD_NULL = """
    SELECT nspname, relname, attname, attnotnull
    FROM pg_attribute
    INNER JOIN pg_class ON attrelid=pg_class.oid
    INNER JOIN pg_namespace ON pg_namespace.oid=pg_class.relnamespace;
    """
    # Concrete type of a single column as formatted by the database itself.
    # NOTE(review): the table name is bound unquoted to ::regclass;
    # presumably mixed-case table names need quoting here - verify.
    SQL_INTROSPECT_COLUMN_TYPE = """SELECT attname, format_type(atttypid, atttypmod) AS type
    FROM pg_attribute
    WHERE attrelid = %s::regclass
    AND attname = %s
    AND attnum > 0
    AND NOT attisdropped
    ORDER BY attnum;
    """

    # args is (table, column, new type)
    SQL_FIELD_TYPE_DIFFER = lambda self, style, qn, args: "%s %s\n\t%s %s %s %s;" % (
        style.SQL_KEYWORD('ALTER TABLE'),
        style.SQL_TABLE(qn(args[0])),
        style.SQL_KEYWORD('ALTER'),
        style.SQL_FIELD(qn(args[1])),
        style.SQL_KEYWORD("TYPE"),
        style.SQL_COLTYPE(args[2]),
    )
    # A parameter change is emitted with exactly the same statement.
    SQL_FIELD_PARAMETER_DIFFER = lambda self, style, qn, args: "%s %s\n\t%s %s %s %s;" % (
        style.SQL_KEYWORD('ALTER TABLE'),
        style.SQL_TABLE(qn(args[0])),
        style.SQL_KEYWORD('ALTER'),
        style.SQL_FIELD(qn(args[1])),
        style.SQL_KEYWORD("TYPE"),
        style.SQL_COLTYPE(args[2]),
    )
    # args is (table, column, keyword placed in front of NOT NULL)
    SQL_NOTNULL_DIFFER = lambda self, style, qn, args: "%s %s\n\t%s %s %s %s;" % (
        style.SQL_KEYWORD('ALTER TABLE'),
        style.SQL_TABLE(qn(args[0])),
        style.SQL_KEYWORD('ALTER COLUMN'),
        style.SQL_FIELD(qn(args[1])),
        style.SQL_KEYWORD(args[2]),
        style.SQL_KEYWORD('NOT NULL'),
    )

    def load(self):
        """Run the generic load, then collect CHECK constraints."""
        super().load()
        self.check_constraints = {}
        self.load_constraints()

    def load_null(self):
        """Fill ``self.null`` keyed by ``(namespace, table, column)``."""
        for dct in self.sql_to_dict(self.SQL_LOAD_NULL, []):
            key = (dct['nspname'], dct['relname'], dct['attname'])
            self.null[key] = not dct['attnotnull']

    def load_unsigned(self):
        # PostgreSQL does not support unsigned, so no columns are
        # unsigned. Nothing to do.
        pass

    def load_constraints(self):
        """Store every constraint row whose definition contains ``CHECK``.

        Rows are keyed by ``(namespace, table, column)``; a later row for
        the same key replaces an earlier one.
        """
        for dct in self.sql_to_dict(self.SQL_LOAD_CONSTRAINTS, []):
            key = (dct['nspname'], dct['relname'], dct['attname'])
            if 'CHECK' in dct['pg_get_constraintdef']:
                self.check_constraints[key] = dct

    def get_data_type_arrayfield(self, base_field):
        """Describe an ``ArrayField`` whose base field is an instance of ``base_field``."""
        return {
            'name': 'django.contrib.postgres.fields.ArrayField',
            'kwargs': {
                'base_field': self.get_field_class(base_field)(),
            },
        }

    def get_data_types_reverse_override(self):
        """Map PostgreSQL type codes to field names or lazy ArrayField descriptions."""
        return {
            1042: 'CharField',
            1000: lambda: self.get_data_type_arrayfield(base_field='BooleanField'),
            1001: lambda: self.get_data_type_arrayfield(base_field='BinaryField'),
            1002: lambda: self.get_data_type_arrayfield(base_field='CharField'),
            1005: lambda: self.get_data_type_arrayfield(base_field='IntegerField'),
            1006: lambda: self.get_data_type_arrayfield(base_field='IntegerField'),
            1007: lambda: self.get_data_type_arrayfield(base_field='IntegerField'),
            1009: lambda: self.get_data_type_arrayfield(base_field='CharField'),
            1014: lambda: self.get_data_type_arrayfield(base_field='CharField'),
            1015: lambda: self.get_data_type_arrayfield(base_field='CharField'),
            1016: lambda: self.get_data_type_arrayfield(base_field='BigIntegerField'),
            1017: lambda: self.get_data_type_arrayfield(base_field='FloatField'),
            1021: lambda: self.get_data_type_arrayfield(base_field='FloatField'),
            1022: lambda: self.get_data_type_arrayfield(base_field='FloatField'),
            1115: lambda: self.get_data_type_arrayfield(base_field='DateTimeField'),
            1185: lambda: self.get_data_type_arrayfield(base_field='DateTimeField'),
            1231: lambda: self.get_data_type_arrayfield(base_field='DecimalField'),
            # {'name': 'django.contrib.postgres.fields.ArrayField', 'kwargs': {'base_field': 'IntegerField'}},
            # NOTE(review): 1186 is presumably the scalar interval type
            # rather than its array type - confirm against pg_type.
            1186: lambda: self.get_data_type_arrayfield(base_field='DurationField'),
            # 1186: 'django.db.models.fields.DurationField',
            3614: 'django.contrib.postgres.search.SearchVectorField',
            3802: 'django.contrib.postgres.fields.JSONField',
        }

    def get_constraints(self, cursor, table_name, introspection):
        """
        Find constraints for table

        Backport of django's introspection.get_constraints(...)

        Returns a dict mapping constraint/index name to a dict with the keys
        ``columns``, ``primary_key``, ``unique``, ``foreign_key``, ``check``
        and ``index``. Key and CHECK constraints are looked up in the
        ``public`` schema only.
        """
        constraints = {}
        # Loop over the key table, collecting things as constraints
        # This will get PKs, FKs, and uniques, but not CHECK
        cursor.execute("""
            SELECT
                kc.constraint_name,
                kc.column_name,
                c.constraint_type,
                array(SELECT table_name::text || '.' || column_name::text FROM information_schema.constraint_column_usage WHERE constraint_name = kc.constraint_name)
            FROM information_schema.key_column_usage AS kc
            JOIN information_schema.table_constraints AS c ON
                kc.table_schema = c.table_schema AND
                kc.table_name = c.table_name AND
                kc.constraint_name = c.constraint_name
            WHERE
                kc.table_schema = %s AND
                kc.table_name = %s
        """, ["public", table_name])
        for constraint, column, kind, used_cols in cursor.fetchall():
            # If we're the first column, make the record
            if constraint not in constraints:
                kind_lower = kind.lower()
                constraints[constraint] = {
                    "columns": [],
                    "primary_key": kind_lower == "primary key",
                    "unique": kind_lower in ["primary key", "unique"],
                    "foreign_key": tuple(used_cols[0].split(".", 1)) if kind_lower == "foreign key" else None,
                    "check": False,
                    "index": False,
                }
            # Record the details
            constraints[constraint]['columns'].append(column)

        # Now get CHECK constraint columns
        cursor.execute("""
            SELECT kc.constraint_name, kc.column_name
            FROM information_schema.constraint_column_usage AS kc
            JOIN information_schema.table_constraints AS c ON
                kc.table_schema = c.table_schema AND
                kc.table_name = c.table_name AND
                kc.constraint_name = c.constraint_name
            WHERE
                c.constraint_type = 'CHECK' AND
                kc.table_schema = %s AND
                kc.table_name = %s
        """, ["public", table_name])
        for constraint, column in cursor.fetchall():
            # If we're the first column, make the record
            if constraint not in constraints:
                constraints[constraint] = {
                    "columns": [],
                    "primary_key": False,
                    "unique": False,
                    "foreign_key": None,
                    "check": True,
                    "index": False,
                }
            # Record the details
            constraints[constraint]['columns'].append(column)

        # Now get indexes
        cursor.execute("""
            SELECT
                c2.relname,
                ARRAY(
                    SELECT (SELECT attname FROM pg_catalog.pg_attribute WHERE attnum = i AND attrelid = c.oid)
                    FROM unnest(idx.indkey) i
                ),
                idx.indisunique,
                idx.indisprimary
            FROM pg_catalog.pg_class c, pg_catalog.pg_class c2,
                pg_catalog.pg_index idx
            WHERE c.oid = idx.indrelid
                AND idx.indexrelid = c2.oid
                AND c.relname = %s
        """, [table_name])
        for index, columns, unique, primary in cursor.fetchall():
            # Names already recorded as constraints keep their constraint record.
            if index not in constraints:
                constraints[index] = {
                    "columns": list(columns),
                    "primary_key": primary,
                    "unique": unique,
                    "foreign_key": None,
                    "check": False,
                    "index": True,
                }
        return constraints

    def _introspect_column_type(self, table_name, attname):
        """Return the database's own formatted type for one column, or ``None``.

        ``None`` is returned when no table name is available or the catalog
        lookup yields no row. ``character varying`` is abbreviated to
        ``varchar``.
        """
        if not table_name:
            return None
        rows = self.sql_to_dict(self.SQL_INTROSPECT_COLUMN_TYPE, (table_name, attname))
        if not rows:
            return None
        column_type = rows[0]['type']
        if column_type.startswith("character varying"):
            column_type = column_type.replace("character varying", "varchar")
        return column_type

    def get_field_db_type(self, description, field=None, table_name=None):
        """Return the introspected column type, adjusted for PostgreSQL.

        With a model ``field``:

        * array types (ending in ``[]``) are replaced by the concrete type
          the database reports for the column; if that lookup is not
          possible the approximated type is returned unchanged;
        * ``integer`` / ``bigint`` primary-key ``AutoField`` columns are
          reported as ``serial`` / ``bigserial``;
        * a CHECK constraint recorded for the column is appended.

        Returns ``None`` when the superclass could not determine a type.
        """
        db_type = super().get_field_db_type(description, field, table_name)
        if not db_type:
            return None
        if not field:
            return db_type

        if db_type.endswith("[]"):
            # TODO: This is a hack for array types. Ideally we either pass the correct
            # constraints for the type in `get_data_type_arrayfield` which instantiates
            # the array base_field or maybe even better restructure sqldiff entirely
            # to be based around the concrete type yielded by the code below. That gives
            # the complete type the database uses, why not use this much earlier in the
            # process to compare to whatever django spits out as the desired database type ?
            concrete_type = self._introspect_column_type(table_name, field.db_column or field.attname)
            # Without a table name or a catalog row there is nothing better
            # than the approximated type; previously this raised IndexError.
            return db_type if concrete_type is None else concrete_type

        if field.primary_key and isinstance(field, AutoField):
            if db_type == 'integer':
                db_type = 'serial'
            elif db_type == 'bigint':
                db_type = 'bigserial'

        if table_name:
            tablespace = field.db_tablespace
            if tablespace == "":
                tablespace = "public"
            attname = field.db_column or field.attname
            check_constraint = self.check_constraints.get((tablespace, table_name, attname), {}).get('pg_get_constraintdef', None)
            if check_constraint:
                check_constraint = check_constraint.replace("((", "(")
                check_constraint = check_constraint.replace("))", ")")
                # Re-quote the first word after each "(" in segments that
                # contain a ")": `(name ...)` becomes `("name" ...)`.
                check_constraint = '("'.join([
                    '" '.join(part.strip('"') for part in segment.split(" ", 1)) if ')' in segment else segment
                    for segment in check_constraint.split("(")
                ])
                # TODO: might be more then one constraint in definition ?
                db_type += ' ' + check_constraint
        return db_type

    def get_field_db_type_lookup(self, type_code):
        """Resolve a type code through ``pg_type`` to a known field path.

        Returns the ``DATA_TYPES_REVERSE_NAME`` entry for the type name
        (surrounding underscores removed), or ``None`` when the lookup
        yields nothing or the name is not listed.
        """
        try:
            name = self.sql_to_dict("SELECT typname FROM pg_type WHERE typelem=%s;", [type_code])[0]['typname']
            return self.DATA_TYPES_REVERSE_NAME.get(name.strip('_'))
        except (IndexError, KeyError):
            # Best effort: an unknown type code simply has no override.
            return None
# Maps the last component of the configured database ENGINE to the SQLDiff
# implementation used for it; Command.handle falls back to GenericSQLDiff
# for engines that are not listed.
DATABASE_SQLDIFF_CLASSES = dict(
    postgis=PostgresqlSQLDiff,
    postgresql_psycopg2=PostgresqlSQLDiff,
    postgresql=PostgresqlSQLDiff,
    mysql=MySQLDiff,
    sqlite3=SqliteSQLDiff,
    oracle=GenericSQLDiff,
)
class Command(BaseCommand):
    """Management command printing the model/database differences.

    When run from the command line the process exits with status 0 if no
    differences were found, 1 if there were differences and 2 on a
    ``CommandError``.
    """

    help = """Prints the (approximated) difference between models and fields in the database for the given app name(s).
It indicates how columns in the database are different from the sql that would
be generated by Django. This command is not a database migration tool. (Though
it can certainly help) It's purpose is to show the current differences as a way
to check/debug ur models compared to the real database tables and columns."""

    output_transaction = False

    def add_arguments(self, parser):
        """Register the positional app labels and the sqldiff options."""
        super().add_arguments(parser)
        parser.add_argument('app_label', nargs='*')
        parser.add_argument(
            '--all-applications', '-a', action='store_true',
            default=False,
            dest='all_applications',
            help="Automaticly include all application from INSTALLED_APPS."
        )
        parser.add_argument(
            '--not-only-existing', '-e', action='store_false',
            default=True,
            dest='only_existing',
            help="Check all tables that exist in the database, not only tables that should exist based on models."
        )
        parser.add_argument(
            '--dense-output', '-d', action='store_true', dest='dense_output',
            default=False,
            help="Shows the output in dense format, normally output is spreaded over multiple lines."
        )
        parser.add_argument(
            '--output_text', '-t', action='store_false', dest='sql',
            default=True,
            help="Outputs the differences as descriptive text instead of SQL"
        )
        parser.add_argument(
            '--include-proxy-models', action='store_true', dest='include_proxy_models',
            default=False,
            help="Include proxy models in the graph"
        )
        parser.add_argument(
            '--include-defaults', action='store_true', dest='include_defaults',
            default=False,
            help="Include default values in SQL output (beta feature)"
        )
        # Hidden option (help suppressed): runs "migrate" before diffing.
        parser.add_argument(
            '--migrate-for-tests', action='store_true', dest='migrate_for_tests',
            default=False,
            help=argparse.SUPPRESS
        )

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Pessimistic default; handle() resets it to 0 when nothing differs.
        self.exit_code = 1

    @signalcommand
    def handle(self, *args, **options):
        """Select the models and the backend specific SQLDiff class, then print the diff.

        Raises:
            CommandError: for the dummy database backend, when no app label
                is given without ``--all-applications``, or when no models
                are found.
        """
        from django.conf import settings

        app_labels = options['app_label']
        if hasattr(settings, 'DATABASES'):
            engine = settings.DATABASES['default']['ENGINE']
        else:
            engine = settings.DATABASE_ENGINE

        # The engine may be a dotted path (it is reduced to its last
        # component further down), so compare that component; comparing the
        # raw value to 'dummy' never matches a dotted dummy backend.
        if engine and engine.rsplit('.', 1)[-1] == 'dummy':
            # This must be the "dummy" database backend, which means the user
            # hasn't set DATABASE_ENGINE.
            raise CommandError("""Django doesn't know which syntax to use for your SQL statements,
because you haven't specified the DATABASE_ENGINE setting.
Edit your settings file and change DATABASE_ENGINE to something like 'postgresql' or 'mysql'.""")

        if options['all_applications']:
            app_models = apps.get_models(include_auto_created=True)
        else:
            if not app_labels:
                raise CommandError('Enter at least one appname.')

            if not isinstance(app_labels, (list, tuple, set)):
                app_labels = [app_labels]

            app_models = []
            for app_label in app_labels:
                app_config = apps.get_app_config(app_label)
                app_models.extend(app_config.get_models(include_auto_created=True))

        if not app_models:
            raise CommandError('Unable to execute sqldiff no models founds.')

        if options['migrate_for_tests']:
            from django.core.management import call_command
            call_command("migrate", *app_labels, no_input=True, run_syncdb=True)

        if not engine:
            # No engine configured: derive it from the connection's module path.
            engine = connection.__module__.split('.')[-2]

        if '.' in engine:
            engine = engine.split('.')[-1]

        cls = DATABASE_SQLDIFF_CLASSES.get(engine, GenericSQLDiff)
        sqldiff_instance = cls(app_models, options, stdout=self.stdout, stderr=self.stderr)
        sqldiff_instance.load()
        sqldiff_instance.find_differences()
        if not sqldiff_instance.has_differences:
            self.exit_code = 0
        sqldiff_instance.print_diff(self.style)

    def execute(self, *args, **options):
        """Run the command; on ``CommandError`` print it and exit with status 2.

        The error is re-raised instead when the ``traceback`` option is set.
        """
        try:
            super().execute(*args, **options)
        except CommandError as e:
            if options['traceback']:
                raise

            # self.stderr is not guaranteed to be set here
            stderr = getattr(self, 'stderr', None)
            if not stderr:
                stderr = OutputWrapper(sys.stderr, self.style.ERROR)
            stderr.write('%s: %s' % (e.__class__.__name__, e))
            sys.exit(2)

    def run_from_argv(self, argv):
        """Run from the command line and exit with ``self.exit_code``."""
        super().run_from_argv(argv)
        sys.exit(self.exit_code)
|