Jelajahi Sumber

Fixed #34587 -- Allowed customizing table name normalization in inspectdb command.

Andrii Kohut 1 tahun lalu
induk
melakukan
f8172f45fc

+ 6 - 5
django/core/management/commands/inspectdb.py

@@ -56,9 +56,6 @@ class Command(BaseCommand):
         # 'table_name_filter' is a stealth option
         table_name_filter = options.get("table_name_filter")
 
-        def table2model(table_name):
-            return re.sub(r"[^a-zA-Z0-9]", "", table_name.title())
-
         with connection.cursor() as cursor:
             yield "# This is an auto-generated Django model module."
             yield "# You'll have to do the following manually to clean this up:"
@@ -125,7 +122,7 @@ class Command(BaseCommand):
                     yield "# The error was: %s" % e
                     continue
 
-                model_name = table2model(table_name)
+                model_name = self.normalize_table_name(table_name)
                 yield ""
                 yield ""
                 yield "class %s(models.Model):" % model_name
@@ -180,7 +177,7 @@ class Command(BaseCommand):
                         rel_to = (
                             "self"
                             if ref_db_table == table_name
-                            else table2model(ref_db_table)
+                            else self.normalize_table_name(ref_db_table)
                         )
                         if rel_to in known_models:
                             field_type = "%s(%s" % (rel_type, rel_to)
@@ -322,6 +319,10 @@ class Command(BaseCommand):
 
         return new_name, field_params, field_notes
 
+    def normalize_table_name(self, table_name):
+        """Translate the table name to a Python-compatible model name."""
+        return re.sub(r"[^a-zA-Z0-9]", "", table_name.title())
+
     def get_field_type(self, connection, table_name, row):
         """
         Given the database connection, the table name, and the cursor row

+ 5 - 0
tests/inspectdb/models.py

@@ -51,6 +51,11 @@ class SpecialName(models.Model):
         db_table = "inspectdb_special.table name"
 
 
+class PascalCaseName(models.Model):
+    class Meta:
+        db_table = "inspectdb_pascal.PascalCase"
+
+
 class ColumnTypes(models.Model):
     id = models.AutoField(primary_key=True)
     big_int_field = models.BigIntegerField()

+ 20 - 0
tests/inspectdb/tests.py

@@ -3,6 +3,7 @@ from io import StringIO
 from unittest import mock, skipUnless
 
 from django.core.management import call_command
+from django.core.management.commands import inspectdb
 from django.db import connection
 from django.db.backends.base.introspection import TableInfo
 from django.test import TestCase, TransactionTestCase, skipUnlessDBFeature
@@ -354,6 +355,25 @@ class InspectDBTestCase(TestCase):
         output = out.getvalue()
         self.assertIn("class InspectdbSpecialTableName(models.Model):", output)
 
+    def test_custom_normalize_table_name(self):
+        def pascal_case_table_only(table_name):
+            return table_name.startswith("inspectdb_pascal")
+
+        class MyCommand(inspectdb.Command):
+            def normalize_table_name(self, table_name):
+                normalized_name = table_name.split(".")[1]
+                if connection.features.ignores_table_name_case:
+                    normalized_name = normalized_name.lower()
+                return normalized_name
+
+        out = StringIO()
+        call_command(MyCommand(), table_name_filter=pascal_case_table_only, stdout=out)
+        if connection.features.ignores_table_name_case:
+            expected_model_name = "pascalcase"
+        else:
+            expected_model_name = "PascalCase"
+        self.assertIn(f"class {expected_model_name}(models.Model):", out.getvalue())
+
     @skipUnlessDBFeature("supports_expression_indexes")
     def test_table_with_func_unique_constraint(self):
         out = StringIO()