merge_model_instances.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222
  1. # -*- coding: utf-8 -*-
  2. from django.apps import apps
  3. from django.contrib.contenttypes.fields import GenericForeignKey
  4. from django.core.management import BaseCommand
  5. from django.db import transaction
  6. from django_extensions.management.utils import signalcommand
  7. def get_model_to_deduplicate():
  8. models = apps.get_models()
  9. iterator = 1
  10. for model in models:
  11. print("%s. %s" % (iterator, model.__name__))
  12. iterator += 1
  13. model_choice = int(input("Enter the number of the model you would like to de-duplicate:"))
  14. model_to_deduplicate = models[model_choice - 1]
  15. return model_to_deduplicate
  16. def get_field_names(model):
  17. fields = [field.name for field in model._meta.get_fields()]
  18. iterator = 1
  19. for field in fields:
  20. print("%s. %s" % (iterator, field))
  21. iterator += 1
  22. validated = False
  23. while not validated:
  24. first_field = int(input("Enter the number of the (first) field you would like to de-duplicate."))
  25. if first_field in range(1, iterator):
  26. validated = True
  27. else:
  28. print("Invalid input. Please try again.")
  29. fields_to_deduplicate = [fields[first_field - 1]]
  30. done = False
  31. while not done:
  32. available_fields = [
  33. f for f in fields if f not in fields_to_deduplicate
  34. ]
  35. iterator = 1
  36. for field in available_fields:
  37. print("%s. %s" % (iterator, field))
  38. iterator += 1
  39. print("C. Done adding fields.")
  40. validated = False
  41. while not validated:
  42. print("You are currently deduplicating on the following fields:")
  43. print('\n'.join(fields_to_deduplicate) + '\n')
  44. additional_field = input("""
  45. Enter the number of the field you would like to de-duplicate.
  46. If you have entered all fields, enter C to continue.
  47. """)
  48. if additional_field == "C":
  49. done = True
  50. validated = True
  51. elif int(additional_field) in list(range(1, len(available_fields) + 1)):
  52. fields_to_deduplicate += [available_fields[int(additional_field) - 1]]
  53. validated = True
  54. else:
  55. print("Invalid input. Please try again.")
  56. return fields_to_deduplicate
  57. def keep_first_or_last_instance():
  58. while True:
  59. first_or_last = input("""
  60. Do you want to keep the first or last duplicate instance?
  61. Enter "first" or "last" to continue.
  62. """)
  63. if first_or_last in ["first", "last"]:
  64. return first_or_last
  65. def get_generic_fields():
  66. """Return a list of all GenericForeignKeys in all models."""
  67. generic_fields = []
  68. for model in apps.get_models():
  69. for field_name, field in model.__dict__.items():
  70. if isinstance(field, GenericForeignKey):
  71. generic_fields.append(field)
  72. return generic_fields
  73. class Command(BaseCommand):
  74. help = """
  75. Removes duplicate model instances based on a specified
  76. model and field name(s).
  77. Makes sure that any OneToOne, ForeignKey, or ManyToMany relationships
  78. attached to a deleted model(s) get reattached to the remaining model.
  79. Based on the following:
  80. https://djangosnippets.org/snippets/2283/
  81. https://stackoverflow.com/a/41291137/2532070
  82. https://gist.github.com/edelvalle/01886b6f79ba0c4dce66
  83. """
  84. @signalcommand
  85. def handle(self, *args, **options):
  86. model = get_model_to_deduplicate()
  87. field_names = get_field_names(model)
  88. first_or_last = keep_first_or_last_instance()
  89. total_deleted_objects_count = 0
  90. for instance in model.objects.all():
  91. kwargs = {}
  92. for field_name in field_names:
  93. instance_field_value = instance.__getattribute__(field_name)
  94. kwargs.update({
  95. field_name: instance_field_value
  96. })
  97. try:
  98. model.objects.get(**kwargs)
  99. except model.MultipleObjectsReturned:
  100. instances = model.objects.filter(**kwargs)
  101. if first_or_last == "first":
  102. primary_object = instances.first()
  103. alias_objects = instances.exclude(pk=primary_object.pk)
  104. elif first_or_last == "last":
  105. primary_object = instances.last()
  106. alias_objects = instances.exclude(pk=primary_object.pk)
  107. primary_object, deleted_objects, deleted_objects_count = self.merge_model_instances(primary_object, alias_objects)
  108. total_deleted_objects_count += deleted_objects_count
  109. print("Successfully deleted {} model instances.".format(total_deleted_objects_count))
  110. @transaction.atomic()
  111. def merge_model_instances(self, primary_object, alias_objects):
  112. """
  113. Merge several model instances into one, the `primary_object`.
  114. Use this function to merge model objects and migrate all of the related
  115. fields from the alias objects the primary object.
  116. """
  117. generic_fields = get_generic_fields()
  118. # get related fields
  119. related_fields = list(filter(
  120. lambda x: x.is_relation is True,
  121. primary_object._meta.get_fields()))
  122. many_to_many_fields = list(filter(
  123. lambda x: x.many_to_many is True, related_fields))
  124. related_fields = list(filter(
  125. lambda x: x.many_to_many is False, related_fields))
  126. # Loop through all alias objects and migrate their references to the
  127. # primary object
  128. deleted_objects = []
  129. deleted_objects_count = 0
  130. for alias_object in alias_objects:
  131. # Migrate all foreign key references from alias object to primary
  132. # object.
  133. for many_to_many_field in many_to_many_fields:
  134. alias_varname = many_to_many_field.name
  135. related_objects = getattr(alias_object, alias_varname)
  136. for obj in related_objects.all():
  137. try:
  138. # Handle regular M2M relationships.
  139. getattr(alias_object, alias_varname).remove(obj)
  140. getattr(primary_object, alias_varname).add(obj)
  141. except AttributeError:
  142. # Handle M2M relationships with a 'through' model.
  143. # This does not delete the 'through model.
  144. # TODO: Allow the user to delete a duplicate 'through' model.
  145. through_model = getattr(alias_object, alias_varname).through
  146. kwargs = {
  147. many_to_many_field.m2m_reverse_field_name(): obj,
  148. many_to_many_field.m2m_field_name(): alias_object,
  149. }
  150. through_model_instances = through_model.objects.filter(**kwargs)
  151. for instance in through_model_instances:
  152. # Re-attach the through model to the primary_object
  153. setattr(
  154. instance,
  155. many_to_many_field.m2m_field_name(),
  156. primary_object)
  157. instance.save()
  158. # TODO: Here, try to delete duplicate instances that are
  159. # disallowed by a unique_together constraint
  160. for related_field in related_fields:
  161. if related_field.one_to_many:
  162. alias_varname = related_field.get_accessor_name()
  163. related_objects = getattr(alias_object, alias_varname)
  164. for obj in related_objects.all():
  165. field_name = related_field.field.name
  166. setattr(obj, field_name, primary_object)
  167. obj.save()
  168. elif related_field.one_to_one or related_field.many_to_one:
  169. alias_varname = related_field.name
  170. related_object = getattr(alias_object, alias_varname)
  171. primary_related_object = getattr(primary_object, alias_varname)
  172. if primary_related_object is None:
  173. setattr(primary_object, alias_varname, related_object)
  174. primary_object.save()
  175. elif related_field.one_to_one:
  176. self.stdout.write("Deleted {} with id {}\n".format(
  177. related_object, related_object.id))
  178. related_object.delete()
  179. for field in generic_fields:
  180. filter_kwargs = {}
  181. filter_kwargs[field.fk_field] = alias_object._get_pk_val()
  182. filter_kwargs[field.ct_field] = field.get_content_type(alias_object)
  183. related_objects = field.model.objects.filter(**filter_kwargs)
  184. for generic_related_object in related_objects:
  185. setattr(generic_related_object, field.name, primary_object)
  186. generic_related_object.save()
  187. if alias_object.id:
  188. deleted_objects += [alias_object]
  189. self.stdout.write("Deleted {} with id {}\n".format(
  190. alias_object, alias_object.id))
  191. alias_object.delete()
  192. deleted_objects_count += 1
  193. return primary_object, deleted_objects, deleted_objects_count