Browse Source

Fixed #22078 -- Fixed crash of Feed with decorated methods.

Marcelo Galigniana 2 years ago
parent
commit
8c0886b068

+ 15 - 2
django/contrib/syndication/views.py

@@ -1,3 +1,5 @@
+from inspect import getattr_static, unwrap
+
 from django.contrib.sites.shortcuts import get_current_site
 from django.core.exceptions import ImproperlyConfigured, ObjectDoesNotExist
 from django.http import Http404, HttpResponse
@@ -82,10 +84,21 @@ class Feed:
             # Check co_argcount rather than try/excepting the function and
             # catching the TypeError, because something inside the function
             # may raise the TypeError. This technique is more accurate.
+            func = unwrap(attr)
             try:
-                code = attr.__code__
+                code = func.__code__
             except AttributeError:
-                code = attr.__call__.__code__
+                func = unwrap(attr.__call__)
+                code = func.__code__
+            # If function doesn't have arguments and it is not a static method,
+            # it was decorated without using @functools.wraps.
+            if not code.co_argcount and not isinstance(
+                getattr_static(self, func.__name__, None), staticmethod
+            ):
+                raise ImproperlyConfigured(
+                    f"Feed method {attname!r} decorated by {func.__name__!r} needs to "
+                    f"use @functools.wraps."
+                )
             if code.co_argcount == 2:  # one argument is 'self'
                 return attr(obj)
             else:

+ 54 - 1
tests/syndication_tests/feeds.py

@@ -1,3 +1,5 @@
+from functools import wraps
+
 from django.contrib.syndication import views
 from django.utils import feedgenerator
 from django.utils.timezone import get_fixed_timezone
@@ -5,6 +7,23 @@ from django.utils.timezone import get_fixed_timezone
 from .models import Article, Entry
 
 
+def wraps_decorator(f):
+    @wraps(f)
+    def wrapper(*args, **kwargs):
+        value = f(*args, **kwargs)
+        return f"{value} -- decorated by @wraps."
+
+    return wrapper
+
+
+def common_decorator(f):
+    def wrapper(*args, **kwargs):
+        value = f(*args, **kwargs)
+        return f"{value} -- common decorated."
+
+    return wrapper
+
+
 class TestRss2Feed(views.Feed):
     title = "My blog"
     description = "A more thorough description of my blog."
@@ -47,11 +66,45 @@ class TestRss2FeedWithCallableObject(TestRss2Feed):
     ttl = TimeToLive()
 
 
-class TestRss2FeedWithStaticMethod(TestRss2Feed):
+class TestRss2FeedWithDecoratedMethod(TestRss2Feed):
+    class TimeToLive:
+        @wraps_decorator
+        def __call__(self):
+            return 800
+
+    @staticmethod
+    @wraps_decorator
+    def feed_copyright():
+        return "Copyright (c) 2022, John Doe"
+
+    ttl = TimeToLive()
+
     @staticmethod
     def categories():
         return ("javascript", "vue")
 
+    @wraps_decorator
+    def title(self):
+        return "Overridden title"
+
+    @wraps_decorator
+    def item_title(self, item):
+        return f"Overridden item title: {item.title}"
+
+    @wraps_decorator
+    def description(self, obj):
+        return "Overridden description"
+
+    @wraps_decorator
+    def item_description(self):
+        return "Overridden item description"
+
+
+class TestRss2FeedWithWrongDecoratedMethod(TestRss2Feed):
+    @common_decorator
+    def item_description(self, item):
+        return f"Overridden item description: {item.title}"
+
 
 class TestRss2FeedWithGuidIsPermaLinkTrue(TestRss2Feed):
     def item_guid_is_permalink(self, item):

+ 29 - 2
tests/syndication_tests/tests.py

@@ -202,11 +202,38 @@ class SyndicationFeedTest(FeedTestCase):
         chan = doc.getElementsByTagName("rss")[0].getElementsByTagName("channel")[0]
         self.assertChildNodeContent(chan, {"ttl": "700"})
 
-    def test_rss2_feed_with_static_methods(self):
-        response = self.client.get("/syndication/rss2/with-static-methods/")
+    def test_rss2_feed_with_decorated_methods(self):
+        response = self.client.get("/syndication/rss2/with-decorated-methods/")
         doc = minidom.parseString(response.content)
         chan = doc.getElementsByTagName("rss")[0].getElementsByTagName("channel")[0]
         self.assertCategories(chan, ["javascript", "vue"])
+        self.assertChildNodeContent(
+            chan,
+            {
+                "title": "Overridden title -- decorated by @wraps.",
+                "description": "Overridden description -- decorated by @wraps.",
+                "ttl": "800 -- decorated by @wraps.",
+                "copyright": "Copyright (c) 2022, John Doe -- decorated by @wraps.",
+            },
+        )
+        items = chan.getElementsByTagName("item")
+        self.assertChildNodeContent(
+            items[0],
+            {
+                "title": (
+                    f"Overridden item title: {self.e1.title} -- decorated by @wraps."
+                ),
+                "description": "Overridden item description -- decorated by @wraps.",
+            },
+        )
+
+    def test_rss2_feed_with_wrong_decorated_methods(self):
+        msg = (
+            "Feed method 'item_description' decorated by 'wrapper' needs to use "
+            "@functools.wraps."
+        )
+        with self.assertRaisesMessage(ImproperlyConfigured, msg):
+            self.client.get("/syndication/rss2/with-wrong-decorated-methods/")
 
     def test_rss2_feed_guid_permalink_false(self):
         """

+ 8 - 1
tests/syndication_tests/urls.py

@@ -7,7 +7,14 @@ urlpatterns = [
     path(
         "syndication/rss2/with-callable-object/", feeds.TestRss2FeedWithCallableObject()
     ),
-    path("syndication/rss2/with-static-methods/", feeds.TestRss2FeedWithStaticMethod()),
+    path(
+        "syndication/rss2/with-decorated-methods/",
+        feeds.TestRss2FeedWithDecoratedMethod(),
+    ),
+    path(
+        "syndication/rss2/with-wrong-decorated-methods/",
+        feeds.TestRss2FeedWithWrongDecoratedMethod(),
+    ),
     path("syndication/rss2/articles/<int:entry_id>/", feeds.TestGetObjectFeed()),
     path(
         "syndication/rss2/guid_ispermalink_true/",