123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222 |
- # -*- coding: utf-8 -*-
- from django.apps import apps
- from django.contrib.contenttypes.fields import GenericForeignKey
- from django.core.management import BaseCommand
- from django.db import transaction
- from django_extensions.management.utils import signalcommand
- def get_model_to_deduplicate():
- models = apps.get_models()
- iterator = 1
- for model in models:
- print("%s. %s" % (iterator, model.__name__))
- iterator += 1
- model_choice = int(input("Enter the number of the model you would like to de-duplicate:"))
- model_to_deduplicate = models[model_choice - 1]
- return model_to_deduplicate
- def get_field_names(model):
- fields = [field.name for field in model._meta.get_fields()]
- iterator = 1
- for field in fields:
- print("%s. %s" % (iterator, field))
- iterator += 1
- validated = False
- while not validated:
- first_field = int(input("Enter the number of the (first) field you would like to de-duplicate."))
- if first_field in range(1, iterator):
- validated = True
- else:
- print("Invalid input. Please try again.")
- fields_to_deduplicate = [fields[first_field - 1]]
- done = False
- while not done:
- available_fields = [
- f for f in fields if f not in fields_to_deduplicate
- ]
- iterator = 1
- for field in available_fields:
- print("%s. %s" % (iterator, field))
- iterator += 1
- print("C. Done adding fields.")
- validated = False
- while not validated:
- print("You are currently deduplicating on the following fields:")
- print('\n'.join(fields_to_deduplicate) + '\n')
- additional_field = input("""
- Enter the number of the field you would like to de-duplicate.
- If you have entered all fields, enter C to continue.
- """)
- if additional_field == "C":
- done = True
- validated = True
- elif int(additional_field) in list(range(1, len(available_fields) + 1)):
- fields_to_deduplicate += [available_fields[int(additional_field) - 1]]
- validated = True
- else:
- print("Invalid input. Please try again.")
- return fields_to_deduplicate
- def keep_first_or_last_instance():
- while True:
- first_or_last = input("""
- Do you want to keep the first or last duplicate instance?
- Enter "first" or "last" to continue.
- """)
- if first_or_last in ["first", "last"]:
- return first_or_last
- def get_generic_fields():
- """Return a list of all GenericForeignKeys in all models."""
- generic_fields = []
- for model in apps.get_models():
- for field_name, field in model.__dict__.items():
- if isinstance(field, GenericForeignKey):
- generic_fields.append(field)
- return generic_fields
- class Command(BaseCommand):
- help = """
- Removes duplicate model instances based on a specified
- model and field name(s).
- Makes sure that any OneToOne, ForeignKey, or ManyToMany relationships
- attached to a deleted model(s) get reattached to the remaining model.
- Based on the following:
- https://djangosnippets.org/snippets/2283/
- https://stackoverflow.com/a/41291137/2532070
- https://gist.github.com/edelvalle/01886b6f79ba0c4dce66
- """
- @signalcommand
- def handle(self, *args, **options):
- model = get_model_to_deduplicate()
- field_names = get_field_names(model)
- first_or_last = keep_first_or_last_instance()
- total_deleted_objects_count = 0
- for instance in model.objects.all():
- kwargs = {}
- for field_name in field_names:
- instance_field_value = instance.__getattribute__(field_name)
- kwargs.update({
- field_name: instance_field_value
- })
- try:
- model.objects.get(**kwargs)
- except model.MultipleObjectsReturned:
- instances = model.objects.filter(**kwargs)
- if first_or_last == "first":
- primary_object = instances.first()
- alias_objects = instances.exclude(pk=primary_object.pk)
- elif first_or_last == "last":
- primary_object = instances.last()
- alias_objects = instances.exclude(pk=primary_object.pk)
- primary_object, deleted_objects, deleted_objects_count = self.merge_model_instances(primary_object, alias_objects)
- total_deleted_objects_count += deleted_objects_count
- print("Successfully deleted {} model instances.".format(total_deleted_objects_count))
- @transaction.atomic()
- def merge_model_instances(self, primary_object, alias_objects):
- """
- Merge several model instances into one, the `primary_object`.
- Use this function to merge model objects and migrate all of the related
- fields from the alias objects the primary object.
- """
- generic_fields = get_generic_fields()
- # get related fields
- related_fields = list(filter(
- lambda x: x.is_relation is True,
- primary_object._meta.get_fields()))
- many_to_many_fields = list(filter(
- lambda x: x.many_to_many is True, related_fields))
- related_fields = list(filter(
- lambda x: x.many_to_many is False, related_fields))
- # Loop through all alias objects and migrate their references to the
- # primary object
- deleted_objects = []
- deleted_objects_count = 0
- for alias_object in alias_objects:
- # Migrate all foreign key references from alias object to primary
- # object.
- for many_to_many_field in many_to_many_fields:
- alias_varname = many_to_many_field.name
- related_objects = getattr(alias_object, alias_varname)
- for obj in related_objects.all():
- try:
- # Handle regular M2M relationships.
- getattr(alias_object, alias_varname).remove(obj)
- getattr(primary_object, alias_varname).add(obj)
- except AttributeError:
- # Handle M2M relationships with a 'through' model.
- # This does not delete the 'through model.
- # TODO: Allow the user to delete a duplicate 'through' model.
- through_model = getattr(alias_object, alias_varname).through
- kwargs = {
- many_to_many_field.m2m_reverse_field_name(): obj,
- many_to_many_field.m2m_field_name(): alias_object,
- }
- through_model_instances = through_model.objects.filter(**kwargs)
- for instance in through_model_instances:
- # Re-attach the through model to the primary_object
- setattr(
- instance,
- many_to_many_field.m2m_field_name(),
- primary_object)
- instance.save()
- # TODO: Here, try to delete duplicate instances that are
- # disallowed by a unique_together constraint
- for related_field in related_fields:
- if related_field.one_to_many:
- alias_varname = related_field.get_accessor_name()
- related_objects = getattr(alias_object, alias_varname)
- for obj in related_objects.all():
- field_name = related_field.field.name
- setattr(obj, field_name, primary_object)
- obj.save()
- elif related_field.one_to_one or related_field.many_to_one:
- alias_varname = related_field.name
- related_object = getattr(alias_object, alias_varname)
- primary_related_object = getattr(primary_object, alias_varname)
- if primary_related_object is None:
- setattr(primary_object, alias_varname, related_object)
- primary_object.save()
- elif related_field.one_to_one:
- self.stdout.write("Deleted {} with id {}\n".format(
- related_object, related_object.id))
- related_object.delete()
- for field in generic_fields:
- filter_kwargs = {}
- filter_kwargs[field.fk_field] = alias_object._get_pk_val()
- filter_kwargs[field.ct_field] = field.get_content_type(alias_object)
- related_objects = field.model.objects.filter(**filter_kwargs)
- for generic_related_object in related_objects:
- setattr(generic_related_object, field.name, primary_object)
- generic_related_object.save()
- if alias_object.id:
- deleted_objects += [alias_object]
- self.stdout.write("Deleted {} with id {}\n".format(
- alias_object, alias_object.id))
- alias_object.delete()
- deleted_objects_count += 1
- return primary_object, deleted_objects, deleted_objects_count
|