Bladeren bron

Fixed #14386, #8960, #10235, #10909, #10608, #13845, #14377 - standardize Site/RequestSite usage in various places.

Many thanks to gabrielhurley for putting most of this together.  Also to
bmihelac, arthurk, qingfeng, hvendelbo, petr.pulc@s-cape.cz, Hraban for
reports and some initial patches.

The patch also contains some whitespace/PEP8 fixes.

git-svn-id: http://code.djangoproject.com/svn/django/trunk@13980 bcc190cf-cafb-0310-a4f2-bffc1f526a37
Luke Plant 14 jaren geleden
bovenliggende
commit
667d832e90

+ 3 - 3
django/contrib/auth/forms.py

@@ -1,7 +1,7 @@
 from django.contrib.auth.models import User
 from django.contrib.auth import authenticate
 from django.contrib.auth.tokens import default_token_generator
-from django.contrib.sites.models import Site
+from django.contrib.sites.models import get_current_site
 from django.template import Context, loader
 from django import forms
 from django.utils.translation import ugettext_lazy as _
@@ -117,14 +117,14 @@ class PasswordResetForm(forms.Form):
         return email
 
     def save(self, domain_override=None, email_template_name='registration/password_reset_email.html',
-             use_https=False, token_generator=default_token_generator, from_email=None):
+             use_https=False, token_generator=default_token_generator, from_email=None, request=None):
         """
         Generates a one-use only link for resetting password and sends to the user
         """
         from django.core.mail import send_mail
         for user in self.users_cache:
             if not domain_override:
-                current_site = Site.objects.get_current()
+                current_site = get_current_site(request)
                 site_name = current_site.name
                 domain = current_site.domain
             else:

+ 6 - 0
django/contrib/auth/tests/views.py

@@ -248,6 +248,12 @@ class LogoutTest(AuthViewsTestCase):
         self.assert_('Logged out' in response.content)
         self.confirm_logged_out()
 
+    def test_14377(self):
+        # Bug 14377
+        self.login()
+        response = self.client.get('/logout/')
+        self.assertTrue('site' in response.context)
+
     def test_logout_with_next_page_specified(self): 
         "Logout with next_page option given redirects to specified resource"
         self.login()

+ 16 - 18
django/contrib/auth/views.py

@@ -10,7 +10,7 @@ from django.contrib.auth.tokens import default_token_generator
 from django.views.decorators.csrf import csrf_protect
 from django.core.urlresolvers import reverse
 from django.shortcuts import render_to_response, get_object_or_404
-from django.contrib.sites.models import Site, RequestSite
+from django.contrib.sites.models import get_current_site
 from django.http import HttpResponseRedirect, Http404
 from django.template import RequestContext
 from django.utils.http import urlquote, base36_to_int
@@ -26,21 +26,21 @@ def login(request, template_name='registration/login.html',
     """Displays the login form and handles the login action."""
 
     redirect_to = request.REQUEST.get(redirect_field_name, '')
-    
+
     if request.method == "POST":
         form = authentication_form(data=request.POST)
         if form.is_valid():
             # Light security check -- make sure redirect_to isn't garbage.
             if not redirect_to or ' ' in redirect_to:
                 redirect_to = settings.LOGIN_REDIRECT_URL
-            
-            # Heavier security check -- redirects to http://example.com should 
-            # not be allowed, but things like /view/?param=http://example.com 
+
+            # Heavier security check -- redirects to http://example.com should
+            # not be allowed, but things like /view/?param=http://example.com
             # should be allowed. This regex checks if there is a '//' *before* a
             # question mark.
             elif '//' in redirect_to and re.match(r'[^\?]*//', redirect_to):
                     redirect_to = settings.LOGIN_REDIRECT_URL
-            
+
             # Okay, security checks complete. Log the user in.
             auth_login(request, form.get_user())
 
@@ -51,14 +51,11 @@ def login(request, template_name='registration/login.html',
 
     else:
         form = authentication_form(request)
-    
+
     request.session.set_test_cookie()
-    
-    if Site._meta.installed:
-        current_site = Site.objects.get_current()
-    else:
-        current_site = RequestSite(request)
-    
+
+    current_site = get_current_site(request)
+
     return render_to_response(template_name, {
         'form': form,
         redirect_field_name: redirect_to,
@@ -75,7 +72,10 @@ def logout(request, next_page=None, template_name='registration/logged_out.html'
         if redirect_to:
             return HttpResponseRedirect(redirect_to)
         else:
+            current_site = get_current_site(request)
             return render_to_response(template_name, {
+                'site': current_site,
+                'site_name': current_site.name,
                 'title': _('Logged out')
             }, context_instance=RequestContext(request))
     else:
@@ -97,7 +97,7 @@ def redirect_to_login(next, login_url=None, redirect_field_name=REDIRECT_FIELD_N
 # 4 views for password reset:
 # - password_reset sends the mail
 # - password_reset_done shows a success message for the above
-# - password_reset_confirm checks the link the user clicked and 
+# - password_reset_confirm checks the link the user clicked and
 #   prompts for a new password
 # - password_reset_complete shows a success message for the above
 
@@ -115,12 +115,10 @@ def password_reset(request, is_admin_site=False, template_name='registration/pas
             opts['use_https'] = request.is_secure()
             opts['token_generator'] = token_generator
             opts['from_email'] = from_email
+            opts['email_template_name'] = email_template_name
+            opts['request'] = request
             if is_admin_site:
                 opts['domain_override'] = request.META['HTTP_HOST']
-            else:
-                opts['email_template_name'] = email_template_name
-                if not Site._meta.installed:
-                    opts['domain_override'] = RequestSite(request).domain
             form.save(**opts)
             return HttpResponseRedirect(post_reset_redirect)
     else:

+ 59 - 37
django/contrib/contenttypes/tests.py

@@ -1,47 +1,69 @@
-"""
-Make sure that the content type cache (see ContentTypeManager) works correctly.
-Lookups for a particular content type -- by model or by ID -- should hit the
-database only on the first lookup.
+from django import db
+from django.conf import settings
+from django.contrib.contenttypes.models import ContentType
+from django.contrib.sites.models import Site
+from django.contrib.contenttypes.views import shortcut
+from django.core.exceptions import ObjectDoesNotExist
+from django.http import HttpRequest
+from django.test import TestCase
 
-First, let's make sure we're dealing with a blank slate (and that DEBUG is on so
-that queries get logged)::
 
-    >>> from django.conf import settings
-    >>> settings.DEBUG = True
+class ContentTypesTests(TestCase):
 
-    >>> from django.contrib.contenttypes.models import ContentType
-    >>> ContentType.objects.clear_cache()
+    def setUp(self):
+        # First, let's make sure we're dealing with a blank slate (and that
+        # DEBUG is on so that queries get logged)
+        self.old_DEBUG = settings.DEBUG
+        self.old_Site_meta_installed = Site._meta.installed
+        settings.DEBUG = True
+        ContentType.objects.clear_cache()
+        db.reset_queries()
 
-    >>> from django import db
-    >>> db.reset_queries()
-    
-At this point, a lookup for a ContentType should hit the DB::
+    def tearDown(self):
+        settings.DEBUG = self.old_DEBUG
+        Site._meta.installed = self.old_Site_meta_installed
 
-    >>> ContentType.objects.get_for_model(ContentType)
-    <ContentType: content type>
-    
-    >>> len(db.connection.queries)
-    1
+    def test_lookup_cache(self):
+        """
+        Make sure that the content type cache (see ContentTypeManager)
+        works correctly. Lookups for a particular content type -- by model or
+        by ID -- should hit the database only on the first lookup.
+        """
 
-A second hit, though, won't hit the DB, nor will a lookup by ID::
+        # At this point, a lookup for a ContentType should hit the DB
+        ContentType.objects.get_for_model(ContentType)
+        self.assertEqual(1, len(db.connection.queries))
 
-    >>> ct = ContentType.objects.get_for_model(ContentType)
-    >>> len(db.connection.queries)
-    1
-    >>> ContentType.objects.get_for_id(ct.id)
-    <ContentType: content type>
-    >>> len(db.connection.queries)
-    1
+        # A second hit, though, won't hit the DB, nor will a lookup by ID
+        ct = ContentType.objects.get_for_model(ContentType)
+        self.assertEqual(1, len(db.connection.queries))
+        ContentType.objects.get_for_id(ct.id)
+        self.assertEqual(1, len(db.connection.queries))
 
-Once we clear the cache, another lookup will again hit the DB::
+        # Once we clear the cache, another lookup will again hit the DB
+        ContentType.objects.clear_cache()
+        ContentType.objects.get_for_model(ContentType)
+        len(db.connection.queries)
+        self.assertEqual(2, len(db.connection.queries))
 
-    >>> ContentType.objects.clear_cache()
-    >>> ContentType.objects.get_for_model(ContentType)
-    <ContentType: content type>
-    >>> len(db.connection.queries)
-    2
+    def test_shortcut_view(self):
+        """
+        Check that the shortcut view (used for the admin "view on site"
+        functionality) returns a complete URL regardless of whether the sites
+        framework is installed
+        """
 
-Don't forget to reset DEBUG!
-
-    >>> settings.DEBUG = False
-"""
+        request = HttpRequest()
+        request.META = {
+            "SERVER_NAME": "Example.com",
+            "SERVER_PORT": "80",
+        }
+        from django.contrib.auth.models import User
+        user_ct = ContentType.objects.get_for_model(User)
+        obj = User.objects.create(username="john")
+        Site._meta.installed = True
+        response = shortcut(request, user_ct.id, obj.id)
+        self.assertEqual("http://example.com/users/john/", response._headers.get("location")[1])
+        Site._meta.installed = False
+        response = shortcut(request, user_ct.id, obj.id)
+        self.assertEqual("http://Example.com/users/john/", response._headers.get("location")[1])

+ 22 - 20
django/contrib/contenttypes/views.py

@@ -1,6 +1,6 @@
 from django import http
 from django.contrib.contenttypes.models import ContentType
-from django.contrib.sites.models import Site
+from django.contrib.sites.models import Site, get_current_site
 from django.core.exceptions import ObjectDoesNotExist
 
 def shortcut(request, content_type_id, object_id):
@@ -26,35 +26,37 @@ def shortcut(request, content_type_id, object_id):
     # Otherwise, we need to introspect the object's relationships for a
     # relation to the Site object
     object_domain = None
-    opts = obj._meta
 
-    # First, look for an many-to-many relationship to Site.
-    for field in opts.many_to_many:
-        if field.rel.to is Site:
-            try:
-                # Caveat: In the case of multiple related Sites, this just
-                # selects the *first* one, which is arbitrary.
-                object_domain = getattr(obj, field.name).all()[0].domain
-            except IndexError:
-                pass
-            if object_domain is not None:
-                break
+    if Site._meta.installed:
+        opts = obj._meta
 
-    # Next, look for a many-to-one relationship to Site.
-    if object_domain is None:
-        for field in obj._meta.fields:
-            if field.rel and field.rel.to is Site:
+        # First, look for an many-to-many relationship to Site.
+        for field in opts.many_to_many:
+            if field.rel.to is Site:
                 try:
-                    object_domain = getattr(obj, field.name).domain
-                except Site.DoesNotExist:
+                    # Caveat: In the case of multiple related Sites, this just
+                    # selects the *first* one, which is arbitrary.
+                    object_domain = getattr(obj, field.name).all()[0].domain
+                except IndexError:
                     pass
                 if object_domain is not None:
                     break
 
+        # Next, look for a many-to-one relationship to Site.
+        if object_domain is None:
+            for field in obj._meta.fields:
+                if field.rel and field.rel.to is Site:
+                    try:
+                        object_domain = getattr(obj, field.name).domain
+                    except Site.DoesNotExist:
+                        pass
+                    if object_domain is not None:
+                        break
+
     # Fall back to the current site (if possible).
     if object_domain is None:
         try:
-            object_domain = Site.objects.get_current().domain
+            object_domain = get_current_site(request).domain
         except Site.DoesNotExist:
             pass
 

+ 2 - 2
django/contrib/gis/sitemaps/views.py

@@ -1,6 +1,6 @@
 from django.http import HttpResponse, Http404
 from django.template import loader
-from django.contrib.sites.models import Site
+from django.contrib.sites.models import get_current_site
 from django.core import urlresolvers
 from django.core.paginator import EmptyPage, PageNotAnInteger
 from django.contrib.gis.db.models.fields import GeometryField
@@ -15,7 +15,7 @@ def index(request, sitemaps):
     This view generates a sitemap index that uses the proper view
     for resolving geographic section sitemap URLs.
     """
-    current_site = Site.objects.get_current()
+    current_site = get_current_site(request)
     sites = []
     protocol = request.is_secure() and 'https' or 'http'
     for section, site in sitemaps.items():

+ 7 - 5
django/contrib/sitemaps/__init__.py

@@ -1,3 +1,4 @@
+from django.contrib.sites.models import get_current_site
 from django.core import urlresolvers, paginator
 import urllib
 
@@ -60,8 +61,7 @@ class Sitemap(object):
     paginator = property(_get_paginator)
 
     def get_urls(self, page=1):
-        from django.contrib.sites.models import Site
-        current_site = Site.objects.get_current()
+        current_site = get_current_site(self.request)
         urls = []
         for item in self.paginator.page(page).object_list:
             loc = "http://%s%s" % (current_site.domain, self.__get('location', item))
@@ -77,9 +77,11 @@ class Sitemap(object):
 
 class FlatPageSitemap(Sitemap):
     def items(self):
-        from django.contrib.sites.models import Site
-        current_site = Site.objects.get_current()
-        return current_site.flatpage_set.filter(registration_required=False)
+        current_site = get_current_site(self.request)
+        if hasattr(current_site, "flatpage_set"):
+            return current_site.flatpage_set.filter(registration_required=False)
+        else:
+            return ()
 
 class GenericSitemap(Sitemap):
     priority = None

+ 17 - 1
django/contrib/sitemaps/tests/basic.py

@@ -2,6 +2,7 @@ from datetime import date
 from django.conf import settings
 from django.contrib.auth.models import User
 from django.contrib.flatpages.models import FlatPage
+from django.contrib.sites.models import Site
 from django.test import TestCase
 from django.utils.formats import localize
 from django.utils.translation import activate
@@ -12,11 +13,13 @@ class SitemapTests(TestCase):
 
     def setUp(self):
         self.old_USE_L10N = settings.USE_L10N
+        self.old_Site_meta_installed = Site._meta.installed
         # Create a user that will double as sitemap content
         User.objects.create_user('testuser', 'test@example.com', 's3krit')
 
     def tearDown(self):
         settings.USE_L10N = self.old_USE_L10N
+        Site._meta.installed = self.old_Site_meta_installed
 
     def test_simple_sitemap(self):
         "A simple sitemap can be rendered"
@@ -66,7 +69,7 @@ class SitemapTests(TestCase):
             url=u'/private/',
             title=u'Public Page',
             enable_comments=True,
-            registration_required=True    
+            registration_required=True
         )
         private.sites.add(settings.SITE_ID)
         response = self.client.get('/flatpages/sitemap.xml')
@@ -75,3 +78,16 @@ class SitemapTests(TestCase):
         # Private flatpage should not be in the sitemap
         self.assertNotContains(response, '<loc>http://example.com%s</loc>' % private.url)
 
+    def test_requestsite_sitemap(self):
+        # Make sure hitting the flatpages sitemap without the sites framework
+        # installed doesn't raise an exception
+        Site._meta.installed = False
+        response = self.client.get('/flatpages/sitemap.xml')
+        # Retrieve the sitemap.
+        response = self.client.get('/simple/sitemap.xml')
+        # Check for all the important bits:
+        self.assertEquals(response.content, """<?xml version="1.0" encoding="UTF-8"?>
+<urlset xmlns="http://www.sitemaps.org/schemas/sitemap/0.9">
+<url><loc>http://testserver/location/</loc><lastmod>%s</lastmod><changefreq>never</changefreq><priority>0.5</priority></url>
+</urlset>
+""" % date.today().strftime('%Y-%m-%d'))

+ 4 - 2
django/contrib/sitemaps/views.py

@@ -1,15 +1,16 @@
 from django.http import HttpResponse, Http404
 from django.template import loader
-from django.contrib.sites.models import Site
+from django.contrib.sites.models import get_current_site
 from django.core import urlresolvers
 from django.utils.encoding import smart_str
 from django.core.paginator import EmptyPage, PageNotAnInteger
 
 def index(request, sitemaps):
-    current_site = Site.objects.get_current()
+    current_site = get_current_site(request)
     sites = []
     protocol = request.is_secure() and 'https' or 'http'
     for section, site in sitemaps.items():
+        site.request = request
         if callable(site):
             pages = site().paginator.num_pages
         else:
@@ -32,6 +33,7 @@ def sitemap(request, sitemaps, section=None):
         maps = sitemaps.values()
     page = request.GET.get("p", 1)
     for site in maps:
+        site.request = request
         try:
             if callable(site):
                 urls.extend(site().get_urls(page))

+ 18 - 0
django/contrib/sites/models.py

@@ -1,9 +1,12 @@
 from django.db import models
 from django.utils.translation import ugettext_lazy as _
 
+
 SITE_CACHE = {}
 
+
 class SiteManager(models.Manager):
+
     def get_current(self):
         """
         Returns the current ``Site`` based on the SITE_ID in the
@@ -28,7 +31,9 @@ class SiteManager(models.Manager):
         global SITE_CACHE
         SITE_CACHE = {}
 
+
 class Site(models.Model):
+
     domain = models.CharField(_('domain name'), max_length=100)
     name = models.CharField(_('display name'), max_length=50)
     objects = SiteManager()
@@ -56,6 +61,7 @@ class Site(models.Model):
         except KeyError:
             pass
 
+
 class RequestSite(object):
     """
     A class that shares the primary interface of Site (i.e., it has
@@ -75,3 +81,15 @@ class RequestSite(object):
 
     def delete(self):
         raise NotImplementedError('RequestSite cannot be deleted.')
+
+
+def get_current_site(request):
+    """
+    Checks if contrib.sites is installed and returns either the current
+    ``Site`` object or a ``RequestSite`` object based on the request.
+    """
+    if Site._meta.installed:
+        current_site = Site.objects.get_current()
+    else:
+        current_site = RequestSite(request)
+    return current_site

+ 56 - 29
django/contrib/sites/tests.py

@@ -1,29 +1,56 @@
-"""
->>> from django.contrib.sites.models import Site
->>> from django.conf import settings
->>> Site(id=settings.SITE_ID, domain="example.com", name="example.com").save()
-
-# Make sure that get_current() does not return a deleted Site object.
->>> s = Site.objects.get_current()
->>> isinstance(s, Site)
-True
-
->>> s.delete()
->>> Site.objects.get_current()
-Traceback (most recent call last):
-...
-DoesNotExist: Site matching query does not exist.
-
-# After updating a Site object (e.g. via the admin), we shouldn't return a
-# bogus value from the SITE_CACHE.
->>> _ = Site.objects.create(id=settings.SITE_ID, domain="example.com", name="example.com")
->>> site = Site.objects.get_current()
->>> site.name
-u"example.com"
->>> s2 = Site.objects.get(id=settings.SITE_ID)
->>> s2.name = "Example site"
->>> s2.save()
->>> site = Site.objects.get_current()
->>> site.name
-u"Example site"
-"""
+from django.conf import settings
+from django.contrib.sites.models import Site, RequestSite, get_current_site
+from django.core.exceptions import ObjectDoesNotExist
+from django.http import HttpRequest
+from django.test import TestCase
+
+
+class SitesFrameworkTests(TestCase):
+
+    def setUp(self):
+        Site(id=settings.SITE_ID, domain="example.com", name="example.com").save()
+        self.old_Site_meta_installed = Site._meta.installed
+        Site._meta.installed = True
+
+    def tearDown(self):
+        Site._meta.installed = self.old_Site_meta_installed
+
+    def test_site_manager(self):
+        # Make sure that get_current() does not return a deleted Site object.
+        s = Site.objects.get_current()
+        self.assert_(isinstance(s, Site))
+        s.delete()
+        self.assertRaises(ObjectDoesNotExist, Site.objects.get_current)
+
+    def test_site_cache(self):
+        # After updating a Site object (e.g. via the admin), we shouldn't return a
+        # bogus value from the SITE_CACHE.
+        site = Site.objects.get_current()
+        self.assertEqual(u"example.com", site.name)
+        s2 = Site.objects.get(id=settings.SITE_ID)
+        s2.name = "Example site"
+        s2.save()
+        site = Site.objects.get_current()
+        self.assertEqual(u"Example site", site.name)
+
+    def test_get_current_site(self):
+        # Test that the correct Site object is returned
+        request = HttpRequest()
+        request.META = {
+            "SERVER_NAME": "example.com",
+            "SERVER_PORT": "80",
+        }
+        site = get_current_site(request)
+        self.assert_(isinstance(site, Site))
+        self.assertEqual(site.id, settings.SITE_ID)
+
+        # Test that an exception is raised if the sites framework is installed
+        # but there is no matching Site
+        site.delete()
+        self.assertRaises(ObjectDoesNotExist, get_current_site, request)
+
+        # A RequestSite is returned if the sites framework is not installed
+        Site._meta.installed = False
+        site = get_current_site(request)
+        self.assert_(isinstance(site, RequestSite))
+        self.assertEqual(site.name, u"example.com")

+ 2 - 5
django/contrib/syndication/views.py

@@ -1,6 +1,6 @@
 import datetime
 from django.conf import settings
-from django.contrib.sites.models import Site, RequestSite
+from django.contrib.sites.models import get_current_site
 from django.core.exceptions import ImproperlyConfigured, ObjectDoesNotExist
 from django.http import HttpResponse, Http404
 from django.template import loader, Template, TemplateDoesNotExist, RequestContext
@@ -91,10 +91,7 @@ class Feed(object):
         Returns a feedgenerator.DefaultFeed object, fully populated, for
         this feed. Raises FeedDoesNotExist for invalid parameters.
         """
-        if Site._meta.installed:
-            current_site = Site.objects.get_current()
-        else:
-            current_site = RequestSite(request)
+        current_site = get_current_site(request)
 
         link = self.__get_dynamic_attr('link', obj)
         link = add_domain(current_site.domain, link)

+ 11 - 1
docs/ref/contrib/sites.txt

@@ -107,7 +107,7 @@ This has the same benefits as described in the last section.
 Hooking into the current site from views
 ----------------------------------------
 
-On a lower level, you can use the sites framework in your Django views to do
+You can use the sites framework in your Django views to do
 particular things based on the site in which the view is being called.
 For example::
 
@@ -148,6 +148,16 @@ the :class:`~django.contrib.sites.models.Site` model's manager has a
         else:
             # Do something else.
 
+.. versionchanged:: 1.3
+
+For code which relies on getting the current domain but cannot be certain
+that the sites framework will be installed for any given project, there is a
+utility function :func:`~django.contrib.sites.models.get_current_site` that
+takes a request object as an argument and returns either a Site instance (if
+the sites framework is installed) or a RequestSite instance (if it is not).
+This allows loose coupling with the sites framework and provides a usable
+fallback for cases where it is not installed.
+
 Getting the current domain for display
 --------------------------------------