Procházet zdrojové kódy

Refs #23919 -- Replaced usage of django.utils.functional.curry() with functools.partial()/partialmethod().

Sergey Fedoseev před 7 roky
rodič
revize
5b1c389603

+ 7 - 7
django/db/models/base.py

@@ -1,6 +1,7 @@
 import copy
 import inspect
 import warnings
+from functools import partialmethod
 from itertools import chain
 
 from django.apps import apps
@@ -27,7 +28,6 @@ from django.db.models.signals import (
 )
 from django.db.models.utils import make_model_tuple
 from django.utils.encoding import force_text
-from django.utils.functional import curry
 from django.utils.text import capfirst, get_text_list
 from django.utils.translation import gettext_lazy as _
 from django.utils.version import get_version
@@ -328,8 +328,8 @@ class ModelBase(type):
         opts._prepare(cls)
 
         if opts.order_with_respect_to:
-            cls.get_next_in_order = curry(cls._get_next_or_previous_in_order, is_next=True)
-            cls.get_previous_in_order = curry(cls._get_next_or_previous_in_order, is_next=False)
+            cls.get_next_in_order = partialmethod(cls._get_next_or_previous_in_order, is_next=True)
+            cls.get_previous_in_order = partialmethod(cls._get_next_or_previous_in_order, is_next=False)
 
             # Defer creating accessors on the foreign class until it has been
             # created and registered. If remote_field is None, we're ordering
@@ -1670,7 +1670,7 @@ class Model(metaclass=ModelBase):
 
 # ORDERING METHODS #########################
 
-def method_set_order(ordered_obj, self, id_list, using=None):
+def method_set_order(self, ordered_obj, id_list, using=None):
     if using is None:
         using = DEFAULT_DB_ALIAS
     order_wrt = ordered_obj._meta.order_with_respect_to
@@ -1682,7 +1682,7 @@ def method_set_order(ordered_obj, self, id_list, using=None):
             ordered_obj.objects.filter(pk=j, **filter_args).update(_order=i)
 
 
-def method_get_order(ordered_obj, self):
+def method_get_order(self, ordered_obj):
     order_wrt = ordered_obj._meta.order_with_respect_to
     filter_args = order_wrt.get_forward_related_filter(self)
     pk_name = ordered_obj._meta.pk.name
@@ -1693,12 +1693,12 @@ def make_foreign_order_accessors(model, related_model):
     setattr(
         related_model,
         'get_%s_order' % model.__name__.lower(),
-        curry(method_get_order, model)
+        partialmethod(method_get_order, model)
     )
     setattr(
         related_model,
         'set_%s_order' % model.__name__.lower(),
-        curry(method_set_order, model)
+        partialmethod(method_set_order, model)
     )
 
 ########

+ 5 - 5
django/db/models/fields/__init__.py

@@ -6,7 +6,7 @@ import itertools
 import uuid
 import warnings
 from base64 import b64decode, b64encode
-from functools import total_ordering
+from functools import partialmethod, total_ordering
 
 from django import forms
 from django.apps import apps
@@ -26,7 +26,7 @@ from django.utils.dateparse import (
 )
 from django.utils.duration import duration_string
 from django.utils.encoding import force_bytes, smart_text
-from django.utils.functional import Promise, cached_property, curry
+from django.utils.functional import Promise, cached_property
 from django.utils.ipv6 import clean_ipv6_address
 from django.utils.itercompat import is_iterable
 from django.utils.text import capfirst
@@ -717,7 +717,7 @@ class Field(RegisterLookupMixin):
                 setattr(cls, self.attname, DeferredAttribute(self.attname, cls))
         if self.choices:
             setattr(cls, 'get_%s_display' % self.name,
-                    curry(cls._get_FIELD_display, field=self))
+                    partialmethod(cls._get_FIELD_display, field=self))
 
     def get_filter_kwargs_for_object(self, obj):
         """
@@ -1254,11 +1254,11 @@ class DateField(DateTimeCheckMixin, Field):
         if not self.null:
             setattr(
                 cls, 'get_next_by_%s' % self.name,
-                curry(cls._get_next_or_previous_by_FIELD, field=self, is_next=True)
+                partialmethod(cls._get_next_or_previous_by_FIELD, field=self, is_next=True)
             )
             setattr(
                 cls, 'get_previous_by_%s' % self.name,
-                curry(cls._get_next_or_previous_by_FIELD, field=self, is_next=False)
+                partialmethod(cls._get_next_or_previous_by_FIELD, field=self, is_next=False)
             )
 
     def get_prep_value(self, value):

+ 8 - 8
django/db/models/fields/related.py

@@ -12,7 +12,7 @@ from django.db.models.constants import LOOKUP_SEP
 from django.db.models.deletion import CASCADE, SET_DEFAULT, SET_NULL
 from django.db.models.query_utils import PathInfo
 from django.db.models.utils import make_model_tuple
-from django.utils.functional import cached_property, curry
+from django.utils.functional import cached_property
 from django.utils.translation import gettext_lazy as _
 
 from . import Field
@@ -1567,7 +1567,7 @@ class ManyToManyField(RelatedField):
         setattr(cls, self.name, ManyToManyDescriptor(self.remote_field, reverse=False))
 
         # Set up the accessor for the m2m table name for the relation.
-        self.m2m_db_table = curry(self._get_m2m_db_table, cls._meta)
+        self.m2m_db_table = partial(self._get_m2m_db_table, cls._meta)
 
     def contribute_to_related_class(self, cls, related):
         # Internal M2Ms (i.e., those with a related name ending with '+')
@@ -1576,15 +1576,15 @@ class ManyToManyField(RelatedField):
             setattr(cls, related.get_accessor_name(), ManyToManyDescriptor(self.remote_field, reverse=True))
 
         # Set up the accessors for the column names on the m2m table.
-        self.m2m_column_name = curry(self._get_m2m_attr, related, 'column')
-        self.m2m_reverse_name = curry(self._get_m2m_reverse_attr, related, 'column')
+        self.m2m_column_name = partial(self._get_m2m_attr, related, 'column')
+        self.m2m_reverse_name = partial(self._get_m2m_reverse_attr, related, 'column')
 
-        self.m2m_field_name = curry(self._get_m2m_attr, related, 'name')
-        self.m2m_reverse_field_name = curry(self._get_m2m_reverse_attr, related, 'name')
+        self.m2m_field_name = partial(self._get_m2m_attr, related, 'name')
+        self.m2m_reverse_field_name = partial(self._get_m2m_reverse_attr, related, 'name')
 
-        get_m2m_rel = curry(self._get_m2m_attr, related, 'remote_field')
+        get_m2m_rel = partial(self._get_m2m_attr, related, 'remote_field')
         self.m2m_target_field_name = lambda: get_m2m_rel().field_name
-        get_m2m_reverse_rel = curry(self._get_m2m_reverse_attr, related, 'remote_field')
+        get_m2m_reverse_rel = partial(self._get_m2m_reverse_attr, related, 'remote_field')
         self.m2m_reverse_target_field_name = lambda: get_m2m_reverse_rel().field_name
 
     def set_attributes_from_rel(self):

+ 4 - 3
django/test/client.py

@@ -4,6 +4,7 @@ import os
 import re
 import sys
 from copy import copy
+from functools import partial
 from importlib import import_module
 from io import BytesIO
 from urllib.parse import unquote_to_bytes, urljoin, urlparse, urlsplit
@@ -21,7 +22,7 @@ from django.test import signals
 from django.test.utils import ContextList
 from django.urls import resolve
 from django.utils.encoding import force_bytes
-from django.utils.functional import SimpleLazyObject, curry
+from django.utils.functional import SimpleLazyObject
 from django.utils.http import urlencode
 from django.utils.itercompat import is_iterable
 
@@ -455,7 +456,7 @@ class Client(RequestFactory):
         # Curry a data dictionary into an instance of the template renderer
         # callback function.
         data = {}
-        on_template_render = curry(store_rendered_templates, data)
+        on_template_render = partial(store_rendered_templates, data)
         signal_uid = "template-render-%s" % id(request)
         signals.template_rendered.connect(on_template_render, dispatch_uid=signal_uid)
         # Capture exceptions created by the handler.
@@ -491,7 +492,7 @@ class Client(RequestFactory):
             response.templates = data.get("templates", [])
             response.context = data.get("context")
 
-            response.json = curry(self._parse_json, response)
+            response.json = partial(self._parse_json, response)
 
             # Attach the ResolverMatch instance to the response
             response.resolver_match = SimpleLazyObject(lambda: resolve(request['PATH_INFO']))

+ 3 - 2
tests/schema/fields.py

@@ -1,9 +1,10 @@
+from functools import partial
+
 from django.db import models
 from django.db.models.fields.related import (
     RECURSIVE_RELATIONSHIP_CONSTANT, ManyToManyDescriptor, ManyToManyField,
     ManyToManyRel, RelatedField, create_many_to_many_intermediary_model,
 )
-from django.utils.functional import curry
 
 
 class CustomManyToManyField(RelatedField):
@@ -43,7 +44,7 @@ class CustomManyToManyField(RelatedField):
         if not self.remote_field.through and not cls._meta.abstract and not cls._meta.swapped:
             self.remote_field.through = create_many_to_many_intermediary_model(self, cls)
         setattr(cls, self.name, ManyToManyDescriptor(self.remote_field))
-        self.m2m_db_table = curry(self._get_m2m_db_table, cls._meta)
+        self.m2m_db_table = partial(self._get_m2m_db_table, cls._meta)
 
     def get_internal_type(self):
         return 'ManyToManyField'

+ 1 - 1
tests/serializers/test_data.py

@@ -390,7 +390,7 @@ class SerializerDataTests(TestCase):
     pass
 
 
-def serializerTest(format, self):
+def serializerTest(self, format):
 
     # Create all the objects defined in the test data
     objects = []

+ 2 - 2
tests/serializers/test_natural.py

@@ -10,7 +10,7 @@ class NaturalKeySerializerTests(TestCase):
     pass
 
 
-def natural_key_serializer_test(format, self):
+def natural_key_serializer_test(self, format):
     # Create all the objects defined in the test data
     with connection.constraint_checks_disabled():
         objects = [
@@ -36,7 +36,7 @@ def natural_key_serializer_test(format, self):
         )
 
 
-def natural_key_test(format, self):
+def natural_key_test(self, format):
     book1 = {
         'data': '978-1590597255',
         'title': 'The Definitive Guide to Django: Web Development Done Right',

+ 2 - 2
tests/serializers/tests.py

@@ -1,4 +1,5 @@
 from datetime import datetime
+from functools import partialmethod
 from io import StringIO
 from unittest import mock
 
@@ -9,7 +10,6 @@ from django.db import connection, transaction
 from django.http import HttpResponse
 from django.test import SimpleTestCase, override_settings, skipUnlessDBFeature
 from django.test.utils import Approximate
-from django.utils.functional import curry
 
 from .models import (
     Actor, Article, Author, AuthorProfile, BaseModel, Category, ComplexModel,
@@ -405,4 +405,4 @@ def register_tests(test_class, method_name, test_func, exclude=None):
             (exclude is None or f not in exclude))
     ]
     for format_ in formats:
-        setattr(test_class, method_name % format_, curry(test_func, format_))
+        setattr(test_class, method_name % format_, partialmethod(test_func, format_))