Просмотр исходного кода

Fixed #14774 -- the test client and assertNumQueries didn't work well together. Thanks to Jonas Obrist for the initial patch.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@15251 bcc190cf-cafb-0310-a4f2-bffc1f526a37
Alex Gaynor 14 лет назад
Родитель
Сommit
8308ad4f05

+ 5 - 1
django/test/testcases.py

@@ -6,8 +6,10 @@ from xml.dom.minidom import parseString, Node
 from django.conf import settings
 from django.core import mail
 from django.core.management import call_command
+from django.core.signals import request_started
 from django.core.urlresolvers import clear_url_caches
-from django.db import transaction, connection, connections, DEFAULT_DB_ALIAS
+from django.db import (transaction, connection, connections, DEFAULT_DB_ALIAS,
+    reset_queries)
 from django.http import QueryDict
 from django.test import _doctest as doctest
 from django.test.client import Client
@@ -220,10 +222,12 @@ class _AssertNumQueriesContext(object):
         self.old_debug_cursor = self.connection.use_debug_cursor
         self.connection.use_debug_cursor = True
         self.starting_queries = len(self.connection.queries)
+        request_started.disconnect(reset_queries)
         return self
 
     def __exit__(self, exc_type, exc_value, traceback):
         self.connection.use_debug_cursor = self.old_debug_cursor
+        request_started.connect(reset_queries)
         if exc_type is not None:
             return
 

+ 31 - 9
tests/regressiontests/test_utils/tests.py

@@ -2,12 +2,24 @@ import sys
 
 from django.test import TestCase, skipUnlessDBFeature, skipIfDBFeature
 
+from models import Person
 
 if sys.version_info >= (2, 5):
-    from tests_25 import AssertNumQueriesTests
+    from tests_25 import AssertNumQueriesContextManagerTests
 
 
 class SkippingTestCase(TestCase):
+    def test_skip_unless_db_feature(self):
+        "A test that might be skipped is actually called."
+        # Total hack, but it works, just want an attribute that's always true.
+        @skipUnlessDBFeature("__class__")
+        def test_func():
+            raise ValueError
+
+        self.assertRaises(ValueError, test_func)
+
+
+class AssertNumQueriesTests(TestCase):
     def test_assert_num_queries(self):
         def test_func():
             raise ValueError
@@ -16,18 +28,28 @@ class SkippingTestCase(TestCase):
             self.assertNumQueries, 2, test_func
         )
 
-    def test_skip_unless_db_feature(self):
-        "A test that might be skipped is actually called."
-        # Total hack, but it works, just want an attribute that's always true.
-        @skipUnlessDBFeature("__class__")
-        def test_func():
-            raise ValueError
+    def test_assert_num_queries_with_client(self):
+        person = Person.objects.create(name='test')
 
-        self.assertRaises(ValueError, test_func)
+        self.assertNumQueries(
+            1,
+            self.client.get,
+            "/test_utils/get_person/%s/" % person.pk
+        )
 
+        self.assertNumQueries(
+            1,
+            self.client.get,
+            "/test_utils/get_person/%s/" % person.pk
+        )
 
-class SaveRestoreWarningState(TestCase):
+        def test_func():
+            self.client.get("/test_utils/get_person/%s/" % person.pk)
+            self.client.get("/test_utils/get_person/%s/" % person.pk)
+        self.assertNumQueries(2, test_func)
 
+
+class SaveRestoreWarningState(TestCase):
     def test_save_restore_warnings_state(self):
         """
         Ensure save_warnings_state/restore_warnings_state work correctly.

+ 14 - 1
tests/regressiontests/test_utils/tests_25.py

@@ -5,7 +5,7 @@ from django.test import TestCase
 from models import Person
 
 
-class AssertNumQueriesTests(TestCase):
+class AssertNumQueriesContextManagerTests(TestCase):
     def test_simple(self):
         with self.assertNumQueries(0):
             pass
@@ -26,3 +26,16 @@ class AssertNumQueriesTests(TestCase):
         with self.assertRaises(TypeError):
             with self.assertNumQueries(4000):
                 raise TypeError
+
+    def test_with_client(self):
+        person = Person.objects.create(name="test")
+
+        with self.assertNumQueries(1):
+            self.client.get("/test_utils/get_person/%s/" % person.pk)
+
+        with self.assertNumQueries(1):
+            self.client.get("/test_utils/get_person/%s/" % person.pk)
+
+        with self.assertNumQueries(2):
+            self.client.get("/test_utils/get_person/%s/" % person.pk)
+            self.client.get("/test_utils/get_person/%s/" % person.pk)

+ 8 - 0
tests/regressiontests/test_utils/urls.py

@@ -0,0 +1,8 @@
+from django.conf.urls.defaults import patterns
+
+import views
+
+
+urlpatterns = patterns('',
+    (r'^get_person/(\d+)/$', views.get_person),
+)

+ 7 - 0
tests/regressiontests/test_utils/views.py

@@ -0,0 +1,7 @@
+from django.http import HttpResponse
+from django.shortcuts import get_object_or_404
+from models import Person
+
+def get_person(request, pk):
+    person = get_object_or_404(Person, pk=pk)
+    return HttpResponse(person.name)

+ 4 - 0
tests/urls.py

@@ -1,5 +1,6 @@
 from django.conf.urls.defaults import *
 
+
 urlpatterns = patterns('',
     # test_client modeltest urls
     (r'^test_client/', include('modeltests.test_client.urls')),
@@ -41,4 +42,7 @@ urlpatterns = patterns('',
 
     # special headers views
     (r'special_headers/', include('regressiontests.special_headers.urls')),
+
+    # test util views
+    (r'test_utils/', include('regressiontests.test_utils.urls')),
 )