fields.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. from __future__ import unicode_literals
  2. from django.db.models.fields import NOT_PROVIDED
  3. from django.utils import six
  4. from .base import Operation
  5. class AddField(Operation):
  6. """
  7. Adds a field to a model.
  8. """
  9. def __init__(self, model_name, name, field, preserve_default=True):
  10. self.model_name = model_name
  11. self.name = name
  12. self.field = field
  13. self.preserve_default = preserve_default
  14. def state_forwards(self, app_label, state):
  15. # If preserve default is off, don't use the default for future state
  16. if not self.preserve_default:
  17. field = self.field.clone()
  18. field.default = NOT_PROVIDED
  19. else:
  20. field = self.field
  21. state.models[app_label, self.model_name.lower()].fields.append((self.name, field))
  22. def database_forwards(self, app_label, schema_editor, from_state, to_state):
  23. from_model = from_state.render().get_model(app_label, self.model_name)
  24. to_model = to_state.render().get_model(app_label, self.model_name)
  25. if self.allowed_to_migrate(schema_editor.connection.alias, to_model):
  26. field = to_model._meta.get_field_by_name(self.name)[0]
  27. if not self.preserve_default:
  28. field.default = self.field.default
  29. schema_editor.add_field(
  30. from_model,
  31. field,
  32. )
  33. if not self.preserve_default:
  34. field.default = NOT_PROVIDED
  35. def database_backwards(self, app_label, schema_editor, from_state, to_state):
  36. from_model = from_state.render().get_model(app_label, self.model_name)
  37. if self.allowed_to_migrate(schema_editor.connection.alias, from_model):
  38. schema_editor.remove_field(from_model, from_model._meta.get_field_by_name(self.name)[0])
  39. def describe(self):
  40. return "Add field %s to %s" % (self.name, self.model_name)
  41. def __eq__(self, other):
  42. return (
  43. (self.__class__ == other.__class__) and
  44. (self.name == other.name) and
  45. (self.model_name.lower() == other.model_name.lower()) and
  46. (self.field.deconstruct()[1:] == other.field.deconstruct()[1:])
  47. )
  48. def references_model(self, name, app_label=None):
  49. return name.lower() == self.model_name.lower()
  50. def references_field(self, model_name, name, app_label=None):
  51. return self.references_model(model_name) and name.lower() == self.name.lower()
  52. class RemoveField(Operation):
  53. """
  54. Removes a field from a model.
  55. """
  56. def __init__(self, model_name, name):
  57. self.model_name = model_name
  58. self.name = name
  59. def state_forwards(self, app_label, state):
  60. new_fields = []
  61. for name, instance in state.models[app_label, self.model_name.lower()].fields:
  62. if name != self.name:
  63. new_fields.append((name, instance))
  64. state.models[app_label, self.model_name.lower()].fields = new_fields
  65. def database_forwards(self, app_label, schema_editor, from_state, to_state):
  66. from_model = from_state.render().get_model(app_label, self.model_name)
  67. if self.allowed_to_migrate(schema_editor.connection.alias, from_model):
  68. schema_editor.remove_field(from_model, from_model._meta.get_field_by_name(self.name)[0])
  69. def database_backwards(self, app_label, schema_editor, from_state, to_state):
  70. from_model = from_state.render().get_model(app_label, self.model_name)
  71. to_model = to_state.render().get_model(app_label, self.model_name)
  72. if self.allowed_to_migrate(schema_editor.connection.alias, to_model):
  73. schema_editor.add_field(from_model, to_model._meta.get_field_by_name(self.name)[0])
  74. def describe(self):
  75. return "Remove field %s from %s" % (self.name, self.model_name)
  76. def references_model(self, name, app_label=None):
  77. return name.lower() == self.model_name.lower()
  78. def references_field(self, model_name, name, app_label=None):
  79. return self.references_model(model_name) and name.lower() == self.name.lower()
  80. class AlterField(Operation):
  81. """
  82. Alters a field's database column (e.g. null, max_length) to the provided new field
  83. """
  84. def __init__(self, model_name, name, field):
  85. self.model_name = model_name
  86. self.name = name
  87. self.field = field
  88. def state_forwards(self, app_label, state):
  89. state.models[app_label, self.model_name.lower()].fields = [
  90. (n, self.field if n == self.name else f) for n, f in state.models[app_label, self.model_name.lower()].fields
  91. ]
  92. def database_forwards(self, app_label, schema_editor, from_state, to_state):
  93. from_model = from_state.render().get_model(app_label, self.model_name)
  94. to_model = to_state.render().get_model(app_label, self.model_name)
  95. if self.allowed_to_migrate(schema_editor.connection.alias, to_model):
  96. from_field = from_model._meta.get_field_by_name(self.name)[0]
  97. to_field = to_model._meta.get_field_by_name(self.name)[0]
  98. # If the field is a relatedfield with an unresolved rel.to, just
  99. # set it equal to the other field side. Bandaid fix for AlterField
  100. # migrations that are part of a RenameModel change.
  101. if from_field.rel and from_field.rel.to:
  102. if isinstance(from_field.rel.to, six.string_types):
  103. from_field.rel.to = to_field.rel.to
  104. elif to_field.rel and isinstance(to_field.rel.to, six.string_types):
  105. to_field.rel.to = from_field.rel.to
  106. schema_editor.alter_field(from_model, from_field, to_field)
  107. def database_backwards(self, app_label, schema_editor, from_state, to_state):
  108. self.database_forwards(app_label, schema_editor, from_state, to_state)
  109. def describe(self):
  110. return "Alter field %s on %s" % (self.name, self.model_name)
  111. def __eq__(self, other):
  112. return (
  113. (self.__class__ == other.__class__) and
  114. (self.name == other.name) and
  115. (self.model_name.lower() == other.model_name.lower()) and
  116. (self.field.deconstruct()[1:] == other.field.deconstruct()[1:])
  117. )
  118. def references_model(self, name, app_label=None):
  119. return name.lower() == self.model_name.lower()
  120. def references_field(self, model_name, name, app_label=None):
  121. return self.references_model(model_name) and name.lower() == self.name.lower()
  122. class RenameField(Operation):
  123. """
  124. Renames a field on the model. Might affect db_column too.
  125. """
  126. def __init__(self, model_name, old_name, new_name):
  127. self.model_name = model_name
  128. self.old_name = old_name
  129. self.new_name = new_name
  130. def state_forwards(self, app_label, state):
  131. # Rename the field
  132. state.models[app_label, self.model_name.lower()].fields = [
  133. (self.new_name if n == self.old_name else n, f) for n, f in state.models[app_label, self.model_name.lower()].fields
  134. ]
  135. # Fix unique_together to refer to the new field
  136. options = state.models[app_label, self.model_name.lower()].options
  137. if "unique_together" in options:
  138. options['unique_together'] = [
  139. [self.new_name if n == self.old_name else n for n in unique]
  140. for unique in options['unique_together']
  141. ]
  142. def database_forwards(self, app_label, schema_editor, from_state, to_state):
  143. from_model = from_state.render().get_model(app_label, self.model_name)
  144. to_model = to_state.render().get_model(app_label, self.model_name)
  145. if self.allowed_to_migrate(schema_editor.connection.alias, to_model):
  146. schema_editor.alter_field(
  147. from_model,
  148. from_model._meta.get_field_by_name(self.old_name)[0],
  149. to_model._meta.get_field_by_name(self.new_name)[0],
  150. )
  151. def database_backwards(self, app_label, schema_editor, from_state, to_state):
  152. from_model = from_state.render().get_model(app_label, self.model_name)
  153. to_model = to_state.render().get_model(app_label, self.model_name)
  154. if self.allowed_to_migrate(schema_editor.connection.alias, to_model):
  155. schema_editor.alter_field(
  156. from_model,
  157. from_model._meta.get_field_by_name(self.new_name)[0],
  158. to_model._meta.get_field_by_name(self.old_name)[0],
  159. )
  160. def describe(self):
  161. return "Rename field %s on %s to %s" % (self.old_name, self.model_name, self.new_name)
  162. def references_model(self, name, app_label=None):
  163. return name.lower() == self.model_name.lower()
  164. def references_field(self, model_name, name, app_label=None):
  165. return self.references_model(model_name) and (
  166. name.lower() == self.old_name.lower() or
  167. name.lower() == self.new_name.lower()
  168. )