Browse Source

Fixed #36158 -- Refactored shell command to improve auto-imported objects reporting.

Natalia 1 month ago
parent
commit
56e23b2319
3 changed files with 238 additions and 126 deletions
  1. 73 31
      django/core/management/commands/shell.py
  2. 36 14
      docs/howto/custom-shell.txt
  3. 129 81
      tests/shell/tests.py

+ 73 - 31
django/core/management/commands/shell.py

@@ -3,10 +3,12 @@ import select
 import sys
 import traceback
 from collections import defaultdict
+from importlib import import_module
 
 from django.apps import apps
 from django.core.management import BaseCommand, CommandError
 from django.utils.datastructures import OrderedSet
+from django.utils.module_loading import import_string as import_dotted_path
 
 
 class Command(BaseCommand):
@@ -54,18 +56,18 @@ class Command(BaseCommand):
     def ipython(self, options):
         from IPython import start_ipython
 
-        start_ipython(argv=[], user_ns=self.get_and_report_namespace(**options))
+        start_ipython(argv=[], user_ns=self.get_namespace(**options))
 
     def bpython(self, options):
         import bpython
 
-        bpython.embed(self.get_and_report_namespace(**options))
+        bpython.embed(self.get_namespace(**options))
 
     def python(self, options):
         import code
 
         # Set up a dictionary to serve as the environment for the shell.
-        imported_objects = self.get_and_report_namespace(**options)
+        imported_objects = self.get_namespace(**options)
 
         # We want to honor both $PYTHONSTARTUP and .pythonrc.py, so follow system
         # conventions and get $PYTHONSTARTUP first then .pythonrc.py.
@@ -118,16 +120,75 @@ class Command(BaseCommand):
         # Start the interactive interpreter.
         code.interact(local=imported_objects)
 
-    def get_and_report_namespace(self, **options):
+    def get_auto_imports(self):
+        """Return a sequence of import paths for objects to be auto-imported.
+
+        By default, import paths for models in INSTALLED_APPS are included,
+        with models from earlier apps taking precedence in case of a name
+        collision.
+
+        For example, for an unchanged INSTALLED_APPS, this method returns:
+
+        [
+            "django.contrib.sessions.models.Session",
+            "django.contrib.contenttypes.models.ContentType",
+            "django.contrib.auth.models.User",
+            "django.contrib.auth.models.Group",
+            "django.contrib.auth.models.Permission",
+            "django.contrib.admin.models.LogEntry",
+        ]
+
+        """
+        app_models_imports = [
+            f"{model.__module__}.{model.__name__}"
+            for model in reversed(apps.get_models())
+            if model.__module__
+        ]
+        return app_models_imports
+
+    def get_namespace(self, **options):
         if options and options.get("no_imports"):
             return {}
 
-        namespace = self.get_namespace()
+        path_imports = self.get_auto_imports()
+        if path_imports is None:
+            return {}
+
+        auto_imports = defaultdict(list)
+        import_errors = []
+        for path in path_imports:
+            try:
+                obj = import_dotted_path(path) if "." in path else import_module(path)
+            except ImportError:
+                import_errors.append(path)
+                continue
+
+            if "." in path:
+                module, name = path.rsplit(".", 1)
+            else:
+                module = None
+                name = path
+
+            auto_imports[module].append((name, obj))
+
+        namespace = {
+            name: obj for items in auto_imports.values() for name, obj in items
+        }
 
         verbosity = options["verbosity"] if options else 0
         if verbosity < 1:
             return namespace
 
+        errors = len(import_errors)
+        if errors:
+            msg = "\n".join(f"  {e}" for e in import_errors)
+            objects = "objects" if errors != 1 else "object"
+            self.stdout.write(
+                f"{errors} {objects} could not be automatically imported:\n\n{msg}",
+                self.style.ERROR,
+                ending="\n\n",
+            )
+
         amount = len(namespace)
         objects_str = "objects" if amount != 1 else "object"
         msg = f"{amount} {objects_str} imported automatically"
@@ -135,27 +196,16 @@ class Command(BaseCommand):
         if verbosity < 2:
             if amount:
                 msg += " (use -v 2 for details)"
-            self.stdout.write(f"{msg}.", self.style.SUCCESS)
+            self.stdout.write(f"{msg}.", self.style.SUCCESS, ending="\n\n")
             return namespace
 
-        imports_by_module = defaultdict(list)
-        for obj_name, obj in namespace.items():
-            if hasattr(obj, "__module__") and (
-                (hasattr(obj, "__qualname__") and obj.__qualname__.find(".") == -1)
-                or not hasattr(obj, "__qualname__")
-            ):
-                imports_by_module[obj.__module__].append(obj_name)
-            if not hasattr(obj, "__module__") and hasattr(obj, "__name__"):
-                tokens = obj.__name__.split(".")
-                if obj_name in tokens:
-                    module = ".".join(t for t in tokens if t != obj_name)
-                    imports_by_module[module].append(obj_name)
-
+        top_level = auto_imports.pop(None, [])
         import_string = "\n".join(
-            [
+            [f"  import {obj}" for obj, _ in top_level]
+            + [
                 f"  from {module} import {objects}"
-                for module, imported_objects in imports_by_module.items()
-                if (objects := ", ".join(imported_objects))
+                for module, imported_objects in auto_imports.items()
+                if (objects := ", ".join(i[0] for i in imported_objects))
             ]
         )
 
@@ -167,7 +217,7 @@ class Command(BaseCommand):
             import_string = isort.code(import_string)
 
         if import_string:
-            msg = f"{msg}, including:\n\n{import_string}"
+            msg = f"{msg}:\n\n{import_string}"
         else:
             msg = f"{msg}."
 
@@ -175,14 +225,6 @@ class Command(BaseCommand):
 
         return namespace
 
-    def get_namespace(self):
-        apps_models = apps.get_models()
-        namespace = {}
-        for model in reversed(apps_models):
-            if model.__module__:
-                namespace[model.__name__] = model
-        return namespace
-
     def handle(self, **options):
         # Execute the command and exit.
         if options["command"]:

+ 36 - 14
docs/howto/custom-shell.txt

@@ -20,7 +20,9 @@ Customize automatic imports
 .. versionadded:: 5.2
 
 To customize the automatic import behavior of the :djadmin:`shell` management
-command, override the ``get_namespace()`` method. For example:
+command, override the ``get_auto_imports()`` method. This method should return
+a sequence of import paths for objects or modules available in the application.
+For example:
 
 .. code-block:: python
     :caption: ``polls/management/commands/shell.py``
@@ -29,16 +31,36 @@ command, override the ``get_namespace()`` method. For example:
 
 
     class Command(shell.Command):
-        def get_namespace(self):
-            from django.urls.base import resolve, reverse
-
-            return {
-                **super().get_namespace(),
-                "resolve": resolve,
-                "reverse": reverse,
-            }
-
-The above customization adds :func:`~django.urls.resolve` and
-:func:`~django.urls.reverse` to the default namespace, which includes all
-models from all apps. These two functions will then be available when the
-shell opens, without a manual import statement.
+        def get_auto_imports(self):
+            return super().get_auto_imports() + [
+                "django.urls.reverse",
+                "django.urls.resolve",
+            ]
+
+The customization above adds :func:`~django.urls.resolve` and
+:func:`~django.urls.reverse` to the default namespace, which already includes
+all models from the apps listed in :setting:`INSTALLED_APPS`. These objects
+will be available in the ``shell`` without requiring a manual import.
+
+Running this customized ``shell`` command with ``verbosity=2`` would show:
+
+.. console::
+
+    8 objects imported automatically:
+
+      from django.contrib.admin.models import LogEntry
+      from django.contrib.auth.models import Group, Permission, User
+      from django.contrib.contenttypes.models import ContentType
+      from django.contrib.sessions.models import Session
+      from django.urls import resolve, reverse
+
+If an overridden ``shell`` command includes paths that cannot be imported,
+these errors are shown when ``verbosity`` is set to ``1`` or higher.
+
+Note that automatic imports can be disabled for a specific ``shell`` session
+using the :option:`--no-imports <shell --no-imports>` flag. To permanently
+disable automatic imports, override ``get_auto_imports()`` to return ``None``::
+
+    class Command(shell.Command):
+        def get_auto_imports(self):
+            return None

+ 129 - 81
tests/shell/tests.py

@@ -7,15 +7,10 @@ from django.contrib.auth.models import Group, Permission, User
 from django.contrib.contenttypes.models import ContentType
 from django.core.management import CommandError, call_command
 from django.core.management.commands import shell
-from django.db import connection, models
+from django.db import connection
 from django.test import SimpleTestCase
-from django.test.utils import (
-    captured_stdin,
-    captured_stdout,
-    isolate_apps,
-    override_settings,
-)
-from django.urls.base import resolve, reverse
+from django.test.utils import captured_stdin, captured_stdout, override_settings
+from django.urls import resolve, reverse
 
 from .models import Marker, Phone
 
@@ -92,7 +87,7 @@ class ShellCommandTestCase(SimpleTestCase):
 
         self.assertEqual(
             mock_ipython.start_ipython.mock_calls,
-            [mock.call(argv=[], user_ns=cmd.get_and_report_namespace(**options))],
+            [mock.call(argv=[], user_ns=cmd.get_namespace(**options))],
         )
 
     @mock.patch("django.core.management.commands.shell.select.select")  # [1]
@@ -113,8 +108,7 @@ class ShellCommandTestCase(SimpleTestCase):
             cmd.bpython(options)
 
         self.assertEqual(
-            mock_bpython.embed.mock_calls,
-            [mock.call(cmd.get_and_report_namespace(**options))],
+            mock_bpython.embed.mock_calls, [mock.call(cmd.get_namespace(**options))]
         )
 
     @mock.patch("django.core.management.commands.shell.select.select")  # [1]
@@ -136,7 +130,7 @@ class ShellCommandTestCase(SimpleTestCase):
 
         self.assertEqual(
             mock_code.interact.mock_calls,
-            [mock.call(local=cmd.get_and_report_namespace(**options))],
+            [mock.call(local=cmd.get_namespace(**options))],
         )
 
     # [1] Patch select to prevent tests failing when the test suite is run
@@ -167,42 +161,35 @@ class ShellCommandAutoImportsTestCase(SimpleTestCase):
             },
         )
 
-    @override_settings(INSTALLED_APPS=["basic", "shell"])
-    @isolate_apps("basic", "shell", kwarg_name="apps")
-    def test_get_namespace_precedence(self, apps):
-        class Article(models.Model):
-            class Meta:
-                app_label = "basic"
-
-        winner_article = Article
-
-        class Article(models.Model):
-            class Meta:
-                app_label = "shell"
+    @override_settings(
+        INSTALLED_APPS=["model_forms", "contenttypes_tests", "forms_tests"]
+    )
+    def test_get_namespace_precedence(self):
+        # All of these apps define an `Article` model. The one defined first in
+        # INSTALLED_APPS, takes precedence.
+        import model_forms.models
 
-        with mock.patch("django.apps.apps.get_models", return_value=apps.get_models()):
-            namespace = shell.Command().get_namespace()
-            self.assertEqual(namespace, {"Article": winner_article})
+        namespace = shell.Command().get_namespace()
+        self.assertIs(namespace.get("Article"), model_forms.models.Article)
 
     @override_settings(
         INSTALLED_APPS=["shell", "django.contrib.auth", "django.contrib.contenttypes"]
     )
     def test_get_namespace_overridden(self):
         class TestCommand(shell.Command):
-            def get_namespace(self):
-                from django.urls.base import resolve, reverse
-
-                return {
-                    **super().get_namespace(),
-                    "resolve": resolve,
-                    "reverse": reverse,
-                }
+            def get_auto_imports(self):
+                return super().get_auto_imports() + [
+                    "django.urls.reverse",
+                    "django.urls.resolve",
+                    "django.db.connection",
+                ]
 
         namespace = TestCommand().get_namespace()
 
         self.assertEqual(
             namespace,
             {
+                "connection": connection,
                 "resolve": resolve,
                 "reverse": reverse,
                 "Marker": Marker,
@@ -220,7 +207,7 @@ class ShellCommandAutoImportsTestCase(SimpleTestCase):
     def test_no_imports_flag(self):
         for verbosity in (0, 1, 2, 3):
             with self.subTest(verbosity=verbosity), captured_stdout() as stdout:
-                namespace = shell.Command().get_and_report_namespace(
+                namespace = shell.Command().get_namespace(
                     verbosity=verbosity, no_imports=True
                 )
             self.assertEqual(namespace, {})
@@ -232,8 +219,8 @@ class ShellCommandAutoImportsTestCase(SimpleTestCase):
     def test_verbosity_zero(self):
         with captured_stdout() as stdout:
             cmd = shell.Command()
-            namespace = cmd.get_and_report_namespace(verbosity=0)
-        self.assertEqual(namespace, cmd.get_namespace())
+            namespace = cmd.get_namespace(verbosity=0)
+        self.assertEqual(len(namespace), len(cmd.get_auto_imports()))
         self.assertEqual(stdout.getvalue().strip(), "")
 
     @override_settings(
@@ -242,8 +229,8 @@ class ShellCommandAutoImportsTestCase(SimpleTestCase):
     def test_verbosity_one(self):
         with captured_stdout() as stdout:
             cmd = shell.Command()
-            namespace = cmd.get_and_report_namespace(verbosity=1)
-        self.assertEqual(namespace, cmd.get_namespace())
+            namespace = cmd.get_namespace(verbosity=1)
+        self.assertEqual(len(namespace), len(cmd.get_auto_imports()))
         self.assertEqual(
             stdout.getvalue().strip(),
             "6 objects imported automatically (use -v 2 for details).",
@@ -253,55 +240,51 @@ class ShellCommandAutoImportsTestCase(SimpleTestCase):
     @mock.patch.dict(sys.modules, {"isort": None})
     def test_message_with_stdout_listing_objects_with_isort_not_installed(self):
         class TestCommand(shell.Command):
-            def get_namespace(self):
-                class MyClass:
-                    pass
-
-                constant = "constant"
-
-                return {
-                    **super().get_namespace(),
-                    "MyClass": MyClass,
-                    "constant": constant,
-                }
+            def get_auto_imports(self):
+                return super().get_auto_imports() + [
+                    "django.urls.reverse",
+                    "django.urls.resolve",
+                    "shell",
+                    "django",
+                ]
 
         with captured_stdout() as stdout:
-            TestCommand().get_and_report_namespace(verbosity=2)
+            TestCommand().get_namespace(verbosity=2)
 
         self.assertEqual(
             stdout.getvalue().strip(),
-            "5 objects imported automatically, including:\n\n"
+            "7 objects imported automatically:\n\n"
+            "  import shell\n"
+            "  import django\n"
             "  from django.contrib.contenttypes.models import ContentType\n"
-            "  from shell.models import Phone, Marker",
+            "  from shell.models import Phone, Marker\n"
+            "  from django.urls import reverse, resolve",
         )
 
     def test_message_with_stdout_one_object(self):
         class TestCommand(shell.Command):
-            def get_namespace(self):
-                return {"connection": connection}
+            def get_auto_imports(self):
+                return ["django.db.connection"]
 
         with captured_stdout() as stdout:
-            TestCommand().get_and_report_namespace(verbosity=2)
+            TestCommand().get_namespace(verbosity=2)
 
         cases = {
             0: "",
             1: "1 object imported automatically (use -v 2 for details).",
             2: (
-                "1 object imported automatically, including:\n\n"
-                "  from django.utils.connection import connection"
+                "1 object imported automatically:\n\n"
+                "  from django.db import connection"
             ),
         }
         for verbosity, expected in cases.items():
             with self.subTest(verbosity=verbosity):
                 with captured_stdout() as stdout:
-                    TestCommand().get_and_report_namespace(verbosity=verbosity)
+                    TestCommand().get_namespace(verbosity=verbosity)
                     self.assertEqual(stdout.getvalue().strip(), expected)
 
-    def test_message_with_stdout_zero_objects(self):
-        class TestCommand(shell.Command):
-            def get_namespace(self):
-                return {}
-
+    @override_settings(INSTALLED_APPS=[])
+    def test_message_with_stdout_no_installed_apps(self):
         cases = {
             0: "",
             1: "0 objects imported automatically.",
@@ -310,9 +293,21 @@ class ShellCommandAutoImportsTestCase(SimpleTestCase):
         for verbosity, expected in cases.items():
             with self.subTest(verbosity=verbosity):
                 with captured_stdout() as stdout:
-                    TestCommand().get_and_report_namespace(verbosity=verbosity)
+                    shell.Command().get_namespace(verbosity=verbosity)
                     self.assertEqual(stdout.getvalue().strip(), expected)
 
+    def test_message_with_stdout_overriden_none_result(self):
+        class TestCommand(shell.Command):
+            def get_auto_imports(self):
+                return None
+
+        for verbosity in [0, 1, 2]:
+            with self.subTest(verbosity=verbosity):
+                with captured_stdout() as stdout:
+                    result = TestCommand().get_namespace(verbosity=verbosity)
+                    self.assertEqual(result, {})
+                    self.assertEqual(stdout.getvalue().strip(), "")
+
     @override_settings(INSTALLED_APPS=["shell", "django.contrib.contenttypes"])
     def test_message_with_stdout_listing_objects_with_isort(self):
         sorted_imports = (
@@ -322,27 +317,80 @@ class ShellCommandAutoImportsTestCase(SimpleTestCase):
         mock_isort_code = mock.Mock(code=mock.MagicMock(return_value=sorted_imports))
 
         class TestCommand(shell.Command):
-            def get_namespace(self):
-                class MyClass:
-                    pass
-
-                constant = "constant"
-
-                return {
-                    **super().get_namespace(),
-                    "MyClass": MyClass,
-                    "constant": constant,
-                }
+            def get_auto_imports(self):
+                return super().get_auto_imports() + [
+                    "django.urls.reverse",
+                    "django.urls.resolve",
+                    "django",
+                ]
 
         with (
             mock.patch.dict(sys.modules, {"isort": mock_isort_code}),
             captured_stdout() as stdout,
         ):
-            TestCommand().get_and_report_namespace(verbosity=2)
+            TestCommand().get_namespace(verbosity=2)
 
         self.assertEqual(
             stdout.getvalue().strip(),
-            "5 objects imported automatically, including:\n\n"
-            "  from shell.models import Marker, Phone\n\n"
-            "  from django.contrib.contenttypes.models import ContentType",
+            "6 objects imported automatically:\n\n" + sorted_imports,
+        )
+
+    def test_override_get_auto_imports(self):
+        class TestCommand(shell.Command):
+            def get_auto_imports(self):
+                return [
+                    "model_forms",
+                    "shell",
+                    "does.not.exist",
+                    "doesntexisteither",
+                ]
+
+        with captured_stdout() as stdout:
+            TestCommand().get_namespace(verbosity=2)
+
+        expected = (
+            "2 objects could not be automatically imported:\n\n"
+            "  does.not.exist\n"
+            "  doesntexisteither\n\n"
+            "2 objects imported automatically:\n\n"
+            "  import model_forms\n"
+            "  import shell\n\n"
+        )
+        self.assertEqual(stdout.getvalue(), expected)
+
+    def test_override_get_auto_imports_one_error(self):
+        class TestCommand(shell.Command):
+            def get_auto_imports(self):
+                return [
+                    "foo",
+                ]
+
+        expected = (
+            "1 object could not be automatically imported:\n\n  foo\n\n"
+            "0 objects imported automatically.\n\n"
+        )
+        for verbosity, expected in [(0, ""), (1, expected), (2, expected)]:
+            with self.subTest(verbosity=verbosity):
+                with captured_stdout() as stdout:
+                    TestCommand().get_namespace(verbosity=verbosity)
+                    self.assertEqual(stdout.getvalue(), expected)
+
+    def test_override_get_auto_imports_many_errors(self):
+        class TestCommand(shell.Command):
+            def get_auto_imports(self):
+                return [
+                    "does.not.exist",
+                    "doesntexisteither",
+                ]
+
+        expected = (
+            "2 objects could not be automatically imported:\n\n"
+            "  does.not.exist\n"
+            "  doesntexisteither\n\n"
+            "0 objects imported automatically.\n\n"
         )
+        for verbosity, expected in [(0, ""), (1, expected), (2, expected)]:
+            with self.subTest(verbosity=verbosity):
+                with captured_stdout() as stdout:
+                    TestCommand().get_namespace(verbosity=verbosity)
+                    self.assertEqual(stdout.getvalue(), expected)