Browse Source

Refs #29898 -- Changed fields in ProjectState's relation registry to dict.

Mariusz Felisiak 3 years ago
parent
commit
fa1d7ba5b9
3 changed files with 30 additions and 27 deletions
  1. 7 4
      django/db/migrations/autodetector.py
  2. 13 13
      django/db/migrations/state.py
  3. 10 10
      tests/migrations/test_state.py

+ 7 - 4
django/db/migrations/autodetector.py

@@ -483,7 +483,7 @@ class MigrationAutodetector:
                             fields = list(model_state.fields.values()) + [
                                 field.remote_field
                                 for relations in self.to_state.relations[app_label, model_name].values()
-                                for _, field in relations
+                                for field in relations.values()
                             ]
                             for field in fields:
                                 if field.is_relation:
@@ -672,7 +672,7 @@ class MigrationAutodetector:
             if (app_label, model_name) in self.old_proxy_keys:
                 for related_model_key, related_fields in relations[app_label, model_name].items():
                     related_model_state = self.to_state.models[related_model_key]
-                    for related_field_name, related_field in related_fields:
+                    for related_field_name, related_field in related_fields.items():
                         self.add_operation(
                             related_model_state.app_label,
                             operations.AlterField(
@@ -777,7 +777,7 @@ class MigrationAutodetector:
             for (related_object_app_label, object_name), relation_related_fields in (
                 relations[app_label, model_name].items()
             ):
-                for field_name, field in relation_related_fields:
+                for field_name, field in relation_related_fields.items():
                     dependencies.append(
                         (related_object_app_label, object_name, field_name, False),
                     )
@@ -1082,7 +1082,10 @@ class MigrationAutodetector:
         else:
             relations = project_state.relations[app_label, model_name]
             for (remote_app_label, remote_model_name), fields in relations.items():
-                if any(field == related_field.remote_field for _, related_field in fields):
+                if any(
+                    field == related_field.remote_field
+                    for related_field in fields.values()
+                ):
                     remote_field_model = f'{remote_app_label}.{remote_model_name}'
                     break
         # Account for FKs to swappable models

+ 13 - 13
django/db/migrations/state.py

@@ -97,7 +97,7 @@ class ProjectState:
             assert isinstance(real_apps, set)
         self.real_apps = real_apps
         self.is_delayed = False
-        # {remote_model_key: {model_key: [(field_name, field)]}}
+        # {remote_model_key: {model_key: {field_name: field}}}
         self._relations = None
 
     @property
@@ -302,14 +302,10 @@ class ProjectState:
             old_name_lower = old_name.lower()
             new_name_lower = new_name.lower()
             for to_model in self._relations.values():
-                # It's safe to modify the same collection that is iterated
-                # because `break` is called right after.
-                for field_name, field in to_model[model_key]:
-                    if field_name == old_name_lower:
-                        field.name = new_name_lower
-                        to_model[model_key].remove((old_name_lower, field))
-                        to_model[model_key].append((new_name_lower, field))
-                        break
+                if old_name_lower in to_model[model_key]:
+                    field = to_model[model_key].pop(old_name_lower)
+                    field.name = new_name_lower
+                    to_model[model_key][new_name_lower] = field
         self.reload_model(*model_key, delay=delay)
 
     def _find_reload_model(self, app_label, model_name, delay=False):
@@ -406,9 +402,13 @@ class ProjectState:
             remote_model_key = concretes[remote_model_key]
         relations_to_remote_model = self._relations[remote_model_key]
         if field_name in self.models[model_key].fields:
-            relations_to_remote_model[model_key].append((field_name, field))
+            # The assert holds because it's a new relation, or an altered
+            # relation, in which case references have been removed by
+            # alter_field().
+            assert field_name not in relations_to_remote_model[model_key]
+            relations_to_remote_model[model_key][field_name] = field
         else:
-            relations_to_remote_model[model_key].remove((field_name, field))
+            del relations_to_remote_model[model_key][field_name]
             if not relations_to_remote_model[model_key]:
                 del relations_to_remote_model[model_key]
 
@@ -444,8 +444,8 @@ class ProjectState:
             for field_name, field in model_state.fields.items():
                 field.name = field_name
         # Resolve relations.
-        # {remote_model_key: {model_key: [(field_name, field)]}}
-        self._relations = defaultdict(partial(defaultdict, list))
+        # {remote_model_key: {model_key: {field_name: field}}}
+        self._relations = defaultdict(partial(defaultdict, dict))
         concretes, proxies = self._get_concrete_models_mapping_and_proxy_models()
 
         for model_key in concretes:

+ 10 - 10
tests/migrations/test_state.py

@@ -1216,7 +1216,7 @@ class StateRelationsTests(SimpleTestCase):
         )
         self.assertEqual(
             project_state.relations['tests', 'post']['tests', 'post'],
-            [('next_post', new_field)],
+            {'next_post': new_field},
         )
         # Add a foreign key.
         new_field = models.ForeignKey('tests.post', models.CASCADE)
@@ -1229,7 +1229,7 @@ class StateRelationsTests(SimpleTestCase):
         )
         self.assertEqual(
             project_state.relations['tests', 'post']['tests', 'comment'],
-            [('post', new_field)],
+            {'post': new_field},
         )
 
     def test_add_field_m2m_with_through(self):
@@ -1271,7 +1271,7 @@ class StateRelationsTests(SimpleTestCase):
         )
         self.assertEqual(
             project_state.relations['tests', 'tag']['tests', 'post'],
-            [('tags', new_field)],
+            {'tags': new_field},
         )
 
     def test_remove_field(self):
@@ -1308,14 +1308,14 @@ class StateRelationsTests(SimpleTestCase):
         field = project_state.models['tests', 'comment'].fields['user']
         self.assertEqual(
             project_state.relations['tests', 'user']['tests', 'comment'],
-            [('user', field)],
+            {'user': field},
         )
 
         project_state.rename_field('tests', 'comment', 'user', 'author')
         renamed_field = project_state.models['tests', 'comment'].fields['author']
         self.assertEqual(
             project_state.relations['tests', 'user']['tests', 'comment'],
-            [('author', renamed_field)],
+            {'author': renamed_field},
         )
         self.assertEqual(field, renamed_field)
 
@@ -1357,7 +1357,7 @@ class StateRelationsTests(SimpleTestCase):
         )
         self.assertEqual(
             project_state.relations['tests', 'user']['tests', 'comment'],
-            [('user', m2m_field)],
+            {'user': m2m_field},
         )
 
     def test_alter_field_m2m_to_fk(self):
@@ -1387,7 +1387,7 @@ class StateRelationsTests(SimpleTestCase):
         )
         self.assertEqual(
             project_state.relations['tests_other', 'user_other']['tests', 'post'],
-            [('authors', foreign_key)],
+            {'authors': foreign_key},
         )
 
     def test_many_relations_to_same_model(self):
@@ -1403,14 +1403,14 @@ class StateRelationsTests(SimpleTestCase):
         comment_rels = project_state.relations['tests', 'user']['tests', 'comment']
         # Two foreign keys to the same model.
         self.assertEqual(len(comment_rels), 2)
-        self.assertEqual(comment_rels[1], ('reviewer', new_field))
+        self.assertEqual(comment_rels['reviewer'], new_field)
         # Rename the second foreign key.
         project_state.rename_field('tests', 'comment', 'reviewer', 'supervisor')
         self.assertEqual(len(comment_rels), 2)
-        self.assertEqual(comment_rels[1], ('supervisor', new_field))
+        self.assertEqual(comment_rels['supervisor'], new_field)
         # Remove the first foreign key.
         project_state.remove_field('tests', 'comment', 'user')
-        self.assertEqual(comment_rels, [('supervisor', new_field)])
+        self.assertEqual(comment_rels, {'supervisor': new_field})
 
 
 class ModelStateTests(SimpleTestCase):