Browse Source

Fixed #26315 -- Allowed call_command() to accept a Command object as the first argument.

Jon Dufresne 9 years ago
parent
commit
4115288b4f

+ 23 - 10
django/core/management/__init__.py

@@ -82,22 +82,35 @@ def call_command(name, *args, **options):
 
     This is the primary API you should use for calling specific commands.
 
+    `name` may be a string or a command object. Using a string is preferred
+    unless the command object is required for further processing or testing.
+
     Some examples:
         call_command('migrate')
         call_command('shell', plain=True)
         call_command('sqlmigrate', 'myapp')
+
+        from django.core.management.commands import flush
+        cmd = flush.Command()
+        call_command(cmd, verbosity=0, interactive=False)
+        # Do something with cmd ...
     """
-    # Load the command object.
-    try:
-        app_name = get_commands()[name]
-    except KeyError:
-        raise CommandError("Unknown command: %r" % name)
-
-    if isinstance(app_name, BaseCommand):
-        # If the command is already loaded, use it directly.
-        command = app_name
+    if isinstance(name, BaseCommand):
+        # Command object passed in.
+        command = name
+        name = command.__class__.__module__.split('.')[-1]
     else:
-        command = load_command_class(app_name, name)
+        # Load the command object by name.
+        try:
+            app_name = get_commands()[name]
+        except KeyError:
+            raise CommandError("Unknown command: %r" % name)
+
+        if isinstance(app_name, BaseCommand):
+            # If the command is already loaded, use it directly.
+            command = app_name
+        else:
+            command = load_command_class(app_name, name)
 
     # Simulate argument parsing to get the option defaults (see #10080 for details).
     parser = command.create_parser('', name)

+ 7 - 2
docs/ref/django-admin.txt

@@ -1760,7 +1760,8 @@ Running management commands from your code
 To call a management command from code use ``call_command``.
 
 ``name``
-  the name of the command to call.
+  the name of the command to call or a command object. Passing the name is
+  preferred unless the object is required for testing.
 
 ``*args``
   a list of arguments accepted by the command.
@@ -1771,8 +1772,11 @@ To call a management command from code use ``call_command``.
 Examples::
 
       from django.core import management
+      from django.core.management.commands import loaddata
+
       management.call_command('flush', verbosity=0, interactive=False)
       management.call_command('loaddata', 'test_data', verbosity=0)
+      management.call_command(loaddata.Command(), 'test_data', verbosity=0)
 
 Note that command options that take no arguments are passed as keywords
 with ``True`` or ``False``, as you can see with the ``interactive`` option above.
@@ -1799,7 +1803,8 @@ value of the ``handle()`` method of the command.
 .. versionchanged:: 1.10
 
     ``call_command()`` now returns the value received from the
-    ``command.handle()`` method.
+    ``command.handle()`` method. It now also accepts a command object as the
+    first argument.
 
 Output redirection
 ==================

+ 3 - 0
docs/releases/1.10.txt

@@ -278,6 +278,9 @@ Management Commands
   :djadmin:`runserver` does, if the set of migrations on disk don't match the
   migrations in the database.
 
+* To assist with testing, :func:`~django.core.management.call_command` now
+  accepts a command object as the first argument.
+
 Migrations
 ~~~~~~~~~~
 

+ 19 - 19
tests/admin_scripts/tests.py

@@ -1309,52 +1309,52 @@ class ManageRunserver(AdminScriptTestCase):
         self.cmd = Command(stdout=self.output)
         self.cmd.run = monkey_run
 
-    def assertServerSettings(self, addr, port, ipv6=None, raw_ipv6=False):
+    def assertServerSettings(self, addr, port, ipv6=False, raw_ipv6=False):
         self.assertEqual(self.cmd.addr, addr)
         self.assertEqual(self.cmd.port, port)
         self.assertEqual(self.cmd.use_ipv6, ipv6)
         self.assertEqual(self.cmd._raw_ipv6, raw_ipv6)
 
     def test_runserver_addrport(self):
-        self.cmd.handle()
+        call_command(self.cmd)
         self.assertServerSettings('127.0.0.1', '8000')
 
-        self.cmd.handle(addrport="1.2.3.4:8000")
+        call_command(self.cmd, addrport="1.2.3.4:8000")
         self.assertServerSettings('1.2.3.4', '8000')
 
-        self.cmd.handle(addrport="7000")
+        call_command(self.cmd, addrport="7000")
         self.assertServerSettings('127.0.0.1', '7000')
 
     @unittest.skipUnless(socket.has_ipv6, "platform doesn't support IPv6")
     def test_runner_addrport_ipv6(self):
-        self.cmd.handle(addrport="", use_ipv6=True)
+        call_command(self.cmd, addrport="", use_ipv6=True)
         self.assertServerSettings('::1', '8000', ipv6=True, raw_ipv6=True)
 
-        self.cmd.handle(addrport="7000", use_ipv6=True)
+        call_command(self.cmd, addrport="7000", use_ipv6=True)
         self.assertServerSettings('::1', '7000', ipv6=True, raw_ipv6=True)
 
-        self.cmd.handle(addrport="[2001:0db8:1234:5678::9]:7000")
+        call_command(self.cmd, addrport="[2001:0db8:1234:5678::9]:7000")
         self.assertServerSettings('2001:0db8:1234:5678::9', '7000', ipv6=True, raw_ipv6=True)
 
     def test_runner_hostname(self):
-        self.cmd.handle(addrport="localhost:8000")
+        call_command(self.cmd, addrport="localhost:8000")
         self.assertServerSettings('localhost', '8000')
 
-        self.cmd.handle(addrport="test.domain.local:7000")
+        call_command(self.cmd, addrport="test.domain.local:7000")
         self.assertServerSettings('test.domain.local', '7000')
 
     @unittest.skipUnless(socket.has_ipv6, "platform doesn't support IPv6")
     def test_runner_hostname_ipv6(self):
-        self.cmd.handle(addrport="test.domain.local:7000", use_ipv6=True)
+        call_command(self.cmd, addrport="test.domain.local:7000", use_ipv6=True)
         self.assertServerSettings('test.domain.local', '7000', ipv6=True)
 
     def test_runner_ambiguous(self):
         # Only 4 characters, all of which could be in an ipv6 address
-        self.cmd.handle(addrport="beef:7654")
+        call_command(self.cmd, addrport="beef:7654")
         self.assertServerSettings('beef', '7654')
 
         # Uses only characters that could be in an ipv6 address
-        self.cmd.handle(addrport="deadbeef:7654")
+        call_command(self.cmd, addrport="deadbeef:7654")
         self.assertServerSettings('deadbeef', '7654')
 
     def test_no_database(self):
@@ -1530,7 +1530,7 @@ class CommandTypes(AdminScriptTestCase):
         out = StringIO()
         err = StringIO()
         command = Command(stdout=out, stderr=err)
-        command.execute()
+        call_command(command)
         if color.supports_color():
             self.assertIn('Hello, world!\n', out.getvalue())
             self.assertIn('Hello, world!\n', err.getvalue())
@@ -1552,14 +1552,14 @@ class CommandTypes(AdminScriptTestCase):
         out = StringIO()
         err = StringIO()
         command = Command(stdout=out, stderr=err, no_color=True)
-        command.execute()
+        call_command(command)
         self.assertEqual(out.getvalue(), 'Hello, world!\n')
         self.assertEqual(err.getvalue(), 'Hello, world!\n')
 
         out = StringIO()
         err = StringIO()
         command = Command(stdout=out, stderr=err)
-        command.execute(no_color=True)
+        call_command(command, no_color=True)
         self.assertEqual(out.getvalue(), 'Hello, world!\n')
         self.assertEqual(err.getvalue(), 'Hello, world!\n')
 
@@ -1572,11 +1572,11 @@ class CommandTypes(AdminScriptTestCase):
 
         out = StringIO()
         command = Command(stdout=out)
-        command.execute()
+        call_command(command)
         self.assertEqual(out.getvalue(), "Hello, World!\n")
         out.truncate(0)
         new_out = StringIO()
-        command.execute(stdout=new_out)
+        call_command(command, stdout=new_out)
         self.assertEqual(out.getvalue(), "")
         self.assertEqual(new_out.getvalue(), "Hello, World!\n")
 
@@ -1589,11 +1589,11 @@ class CommandTypes(AdminScriptTestCase):
 
         err = StringIO()
         command = Command(stderr=err)
-        command.execute()
+        call_command(command)
         self.assertEqual(err.getvalue(), "Hello, World!\n")
         err.truncate(0)
         new_err = StringIO()
-        command.execute(stderr=new_err)
+        call_command(command, stderr=new_err)
         self.assertEqual(err.getvalue(), "")
         self.assertEqual(new_err.getvalue(), "Hello, World!\n")
 

+ 4 - 4
tests/auth_tests/test_management.py

@@ -342,8 +342,8 @@ class CreatesuperuserManagementCommandTestCase(TestCase):
         """
         sentinel = object()
         command = createsuperuser.Command()
-        command.check = lambda: []
-        command.execute(
+        call_command(
+            command,
             stdin=sentinel,
             stdout=six.StringIO(),
             stderr=six.StringIO(),
@@ -355,8 +355,8 @@ class CreatesuperuserManagementCommandTestCase(TestCase):
         self.assertIs(command.stdin, sentinel)
 
         command = createsuperuser.Command()
-        command.check = lambda: []
-        command.execute(
+        call_command(
+            command,
             stdout=six.StringIO(),
             stderr=six.StringIO(),
             interactive=False,