Prechádzať zdrojové kódy

Refs #25629 -- Added `arity` class attribute to `Func` expressions

Sergey Fedoseev 9 rokov pred
rodič
commit
0a26121797

+ 10 - 0
django/db/models/expressions.py

@@ -482,8 +482,18 @@ class Func(Expression):
     function = None
     template = '%(function)s(%(expressions)s)'
     arg_joiner = ', '
+    arity = None  # The number of arguments the function accepts.
 
     def __init__(self, *expressions, **extra):
+        if self.arity is not None and len(expressions) != self.arity:
+            raise TypeError(
+                "'%s' takes exactly %s %s (%s given)" % (
+                    self.__class__.__name__,
+                    self.arity,
+                    "argument" if self.arity == 1 else "arguments",
+                    len(expressions),
+                )
+            )
         output_field = extra.pop('output_field', None)
         super(Func, self).__init__(output_field=output_field)
         self.source_expressions = self._parse_expressions(*expressions)

+ 0 - 6
django/db/models/functions.py

@@ -145,9 +145,6 @@ class Lower(Transform):
     function = 'LOWER'
     lookup_name = 'lower'
 
-    def __init__(self, expression, **extra):
-        super(Lower, self).__init__(expression, **extra)
-
 
 class Now(Func):
     template = 'CURRENT_TIMESTAMP'
@@ -197,6 +194,3 @@ class Substr(Func):
 class Upper(Transform):
     function = 'UPPER'
     lookup_name = 'upper'
-
-    def __init__(self, expression, **extra):
-        super(Upper, self).__init__(expression, **extra)

+ 1 - 4
django/db/models/lookups.py

@@ -123,10 +123,7 @@ class Transform(RegisterLookupMixin, Func):
     first examine self and then check output_field.
     """
     bilateral = False
-
-    def __init__(self, expression, **extra):
-        # Restrict Transform to allow only a single expression.
-        super(Transform, self).__init__(expression, **extra)
+    arity = 1
 
     @property
     def lhs(self):

+ 9 - 0
docs/ref/models/expressions.txt

@@ -252,6 +252,15 @@ The ``Func`` API is as follows:
         A class attribute that denotes the character used to join the list of
         ``expressions`` together. Defaults to ``', '``.
 
+    .. attribute:: arity
+
+        .. versionadded:: 1.10
+
+        A class attribute that denotes the number of arguments the function
+        accepts. If this attribute is set and the function is called with a
+        different number of expressions, ``TypeError`` will be raised. Defaults
+        to ``None``.
+
 The ``*expressions`` argument is a list of positional expressions that the
 function will be applied to. The expressions will be converted to strings,
 joined together with ``arg_joiner``, and then interpolated into the ``template``

+ 4 - 0
docs/releases/1.10.txt

@@ -175,6 +175,10 @@ Models
   accessible as a descriptor on the proxied model class and may be referenced in
   queryset filtering.
 
+* The :attr:`~django.db.models.Func.arity` class attribute is added to
+  :class:`~django.db.models.Func`. This attribute can be used to set the number
+  of arguments the function accepts.
+
 Requests and Responses
 ^^^^^^^^^^^^^^^^^^^^^^
 

+ 3 - 0
tests/db_functions/tests.py

@@ -389,6 +389,9 @@ class FunctionTests(TestCase):
             lambda a: (a.lower_name, a.name)
         )
 
+        with self.assertRaisesMessage(TypeError, "'Lower' takes exactly 1 argument (2 given)"):
+            Author.objects.update(name=Lower('name', 'name'))
+
     def test_upper(self):
         Author.objects.create(name='John Smith', alias='smithj')
         Author.objects.create(name='Rhonda')