Browse Source

Fixed #29048 -- Added **extra_context to database function as_vendor() methods.

priyanshsaxena 7 years ago
parent
commit
83b04d4f88

+ 2 - 2
django/contrib/gis/db/models/aggregates.py

@@ -26,10 +26,10 @@ class GeoAggregate(Aggregate):
             **extra_context
         )
 
-    def as_oracle(self, compiler, connection):
+    def as_oracle(self, compiler, connection, **extra_context):
         tolerance = self.extra.get('tolerance') or getattr(self, 'tolerance', 0.05)
         template = None if self.is_extent else '%(function)s(SDOAGGRTYPE(%(expressions)s,%(tolerance)s))'
-        return self.as_sql(compiler, connection, template=template, tolerance=tolerance)
+        return self.as_sql(compiler, connection, template=template, tolerance=tolerance, **extra_context)
 
     def resolve_expression(self, query=None, allow_joins=True, reuse=None, summarize=False, for_save=False):
         c = super().resolve_expression(query, allow_joins, reuse, summarize, for_save)

+ 26 - 22
django/contrib/gis/db/models/functions.py

@@ -102,19 +102,23 @@ class SQLiteDecimalToFloatMixin:
     By default, Decimal values are converted to str by the SQLite backend, which
     is not acceptable by the GIS functions expecting numeric values.
     """
-    def as_sqlite(self, compiler, connection):
+    def as_sqlite(self, compiler, connection, **extra_context):
         for expr in self.get_source_expressions():
             if hasattr(expr, 'value') and isinstance(expr.value, Decimal):
                 expr.value = float(expr.value)
-        return super().as_sql(compiler, connection)
+        return super().as_sql(compiler, connection, **extra_context)
 
 
 class OracleToleranceMixin:
     tolerance = 0.05
 
-    def as_oracle(self, compiler, connection):
+    def as_oracle(self, compiler, connection, **extra_context):
         tol = self.extra.get('tolerance', self.tolerance)
-        return self.as_sql(compiler, connection, template="%%(function)s(%%(expressions)s, %s)" % tol)
+        return self.as_sql(
+            compiler, connection,
+            template="%%(function)s(%%(expressions)s, %s)" % tol,
+            **extra_context
+        )
 
 
 class Area(OracleToleranceMixin, GeoFunc):
@@ -181,11 +185,11 @@ class AsGML(GeoFunc):
 
 
 class AsKML(AsGML):
-    def as_sqlite(self, compiler, connection):
+    def as_sqlite(self, compiler, connection, **extra_context):
         # No version parameter
         clone = self.copy()
         clone.set_source_expressions(self.get_source_expressions()[1:])
-        return clone.as_sql(compiler, connection)
+        return clone.as_sql(compiler, connection, **extra_context)
 
 
 class AsSVG(GeoFunc):
@@ -205,10 +209,10 @@ class BoundingCircle(OracleToleranceMixin, GeoFunc):
     def __init__(self, expression, num_seg=48, **extra):
         super().__init__(expression, num_seg, **extra)
 
-    def as_oracle(self, compiler, connection):
+    def as_oracle(self, compiler, connection, **extra_context):
         clone = self.copy()
         clone.set_source_expressions([self.get_source_expressions()[0]])
-        return super(BoundingCircle, clone).as_oracle(compiler, connection)
+        return super(BoundingCircle, clone).as_oracle(compiler, connection, **extra_context)
 
 
 class Centroid(OracleToleranceMixin, GeomOutputGeoFunc):
@@ -239,7 +243,7 @@ class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
             self.spheroid = self._handle_param(spheroid, 'spheroid', bool)
         super().__init__(*expressions, **extra)
 
-    def as_postgresql(self, compiler, connection):
+    def as_postgresql(self, compiler, connection, **extra_context):
         clone = self.copy()
         function = None
         expr2 = clone.source_expressions[1]
@@ -262,7 +266,7 @@ class Distance(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
                 clone.source_expressions.append(Value(self.geo_field.spheroid(connection)))
             else:
                 function = connection.ops.spatial_function_name('DistanceSphere')
-        return super(Distance, clone).as_sql(compiler, connection, function=function)
+        return super(Distance, clone).as_sql(compiler, connection, function=function, **extra_context)
 
     def as_sqlite(self, compiler, connection, **extra_context):
         if self.geo_field.geodetic(connection):
@@ -300,12 +304,12 @@ class GeoHash(GeoFunc):
             expressions.append(self._handle_param(precision, 'precision', int))
         super().__init__(*expressions, **extra)
 
-    def as_mysql(self, compiler, connection):
+    def as_mysql(self, compiler, connection, **extra_context):
         clone = self.copy()
         # If no precision is provided, set it to the maximum.
         if len(clone.source_expressions) < 2:
             clone.source_expressions.append(Value(100))
-        return clone.as_sql(compiler, connection)
+        return clone.as_sql(compiler, connection, **extra_context)
 
 
 class Intersection(OracleToleranceMixin, GeomOutputGeoFunc):
@@ -333,7 +337,7 @@ class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
             raise NotSupportedError("This backend doesn't support Length on geodetic fields")
         return super().as_sql(compiler, connection, **extra_context)
 
-    def as_postgresql(self, compiler, connection):
+    def as_postgresql(self, compiler, connection, **extra_context):
         clone = self.copy()
         function = None
         if self.source_is_geography():
@@ -346,13 +350,13 @@ class Length(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
             dim = min(f.dim for f in self.get_source_fields() if f)
             if dim > 2:
                 function = connection.ops.length3d
-        return super(Length, clone).as_sql(compiler, connection, function=function)
+        return super(Length, clone).as_sql(compiler, connection, function=function, **extra_context)
 
-    def as_sqlite(self, compiler, connection):
+    def as_sqlite(self, compiler, connection, **extra_context):
         function = None
         if self.geo_field.geodetic(connection):
             function = 'GeodesicLength' if self.spheroid else 'GreatCircleLength'
-        return super().as_sql(compiler, connection, function=function)
+        return super().as_sql(compiler, connection, function=function, **extra_context)
 
 
 class LineLocatePoint(GeoFunc):
@@ -383,19 +387,19 @@ class NumPoints(GeoFunc):
 class Perimeter(DistanceResultMixin, OracleToleranceMixin, GeoFunc):
     arity = 1
 
-    def as_postgresql(self, compiler, connection):
+    def as_postgresql(self, compiler, connection, **extra_context):
         function = None
         if self.geo_field.geodetic(connection) and not self.source_is_geography():
             raise NotSupportedError("ST_Perimeter cannot use a non-projected non-geography field.")
         dim = min(f.dim for f in self.get_source_fields())
         if dim > 2:
             function = connection.ops.perimeter3d
-        return super().as_sql(compiler, connection, function=function)
+        return super().as_sql(compiler, connection, function=function, **extra_context)
 
-    def as_sqlite(self, compiler, connection):
+    def as_sqlite(self, compiler, connection, **extra_context):
         if self.geo_field.geodetic(connection):
             raise NotSupportedError("Perimeter cannot use a non-projected field.")
-        return super().as_sql(compiler, connection)
+        return super().as_sql(compiler, connection, **extra_context)
 
 
 class PointOnSurface(OracleToleranceMixin, GeomOutputGeoFunc):
@@ -454,12 +458,12 @@ class Transform(GeomOutputGeoFunc):
 
 
 class Translate(Scale):
-    def as_sqlite(self, compiler, connection):
+    def as_sqlite(self, compiler, connection, **extra_context):
         clone = self.copy()
         if len(self.source_expressions) < 4:
             # Always provide the z parameter for ST_Translate
             clone.source_expressions.append(Value(0))
-        return super(Translate, clone).as_sqlite(compiler, connection)
+        return super(Translate, clone).as_sqlite(compiler, connection, **extra_context)
 
 
 class Union(OracleToleranceMixin, GeomOutputGeoFunc):

+ 12 - 9
django/db/models/aggregates.py

@@ -64,7 +64,10 @@ class Aggregate(Func):
             if connection.features.supports_aggregate_filter_clause:
                 filter_sql, filter_params = self.filter.as_sql(compiler, connection)
                 template = self.filter_template % extra_context.get('template', self.template)
-                sql, params = super().as_sql(compiler, connection, template=template, filter=filter_sql)
+                sql, params = super().as_sql(
+                    compiler, connection, template=template, filter=filter_sql,
+                    **extra_context
+                )
                 return sql, params + filter_params
             else:
                 copy = self.copy()
@@ -92,20 +95,20 @@ class Avg(Aggregate):
             return FloatField()
         return super()._resolve_output_field()
 
-    def as_mysql(self, compiler, connection):
-        sql, params = super().as_sql(compiler, connection)
+    def as_mysql(self, compiler, connection, **extra_context):
+        sql, params = super().as_sql(compiler, connection, **extra_context)
         if self.output_field.get_internal_type() == 'DurationField':
             sql = 'CAST(%s as SIGNED)' % sql
         return sql, params
 
-    def as_oracle(self, compiler, connection):
+    def as_oracle(self, compiler, connection, **extra_context):
         if self.output_field.get_internal_type() == 'DurationField':
             expression = self.get_source_expressions()[0]
             from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval
             return compiler.compile(
                 SecondsToInterval(Avg(IntervalToSeconds(expression), filter=self.filter))
             )
-        return super().as_sql(compiler, connection)
+        return super().as_sql(compiler, connection, **extra_context)
 
 
 class Count(Aggregate):
@@ -157,20 +160,20 @@ class Sum(Aggregate):
     function = 'SUM'
     name = 'Sum'
 
-    def as_mysql(self, compiler, connection):
-        sql, params = super().as_sql(compiler, connection)
+    def as_mysql(self, compiler, connection, **extra_context):
+        sql, params = super().as_sql(compiler, connection, **extra_context)
         if self.output_field.get_internal_type() == 'DurationField':
             sql = 'CAST(%s as SIGNED)' % sql
         return sql, params
 
-    def as_oracle(self, compiler, connection):
+    def as_oracle(self, compiler, connection, **extra_context):
         if self.output_field.get_internal_type() == 'DurationField':
             expression = self.get_source_expressions()[0]
             from django.db.backends.oracle.functions import IntervalToSeconds, SecondsToInterval
             return compiler.compile(
                 SecondsToInterval(Sum(IntervalToSeconds(expression)))
             )
-        return super().as_sql(compiler, connection)
+        return super().as_sql(compiler, connection, **extra_context)
 
 
 class Variance(Aggregate):

+ 11 - 11
django/db/models/functions/comparison.py

@@ -14,16 +14,16 @@ class Cast(Func):
         extra_context['db_type'] = self.output_field.cast_db_type(connection)
         return super().as_sql(compiler, connection, **extra_context)
 
-    def as_mysql(self, compiler, connection):
+    def as_mysql(self, compiler, connection, **extra_context):
         # MySQL doesn't support explicit cast to float.
         template = '(%(expressions)s + 0.0)' if self.output_field.get_internal_type() == 'FloatField' else None
-        return self.as_sql(compiler, connection, template=template)
+        return self.as_sql(compiler, connection, template=template, **extra_context)
 
-    def as_postgresql(self, compiler, connection):
+    def as_postgresql(self, compiler, connection, **extra_context):
         # CAST would be valid too, but the :: shortcut syntax is more readable.
         # 'expressions' is wrapped in parentheses in case it's a complex
         # expression.
-        return self.as_sql(compiler, connection, template='(%(expressions)s)::%(db_type)s')
+        return self.as_sql(compiler, connection, template='(%(expressions)s)::%(db_type)s', **extra_context)
 
 
 class Coalesce(Func):
@@ -35,7 +35,7 @@ class Coalesce(Func):
             raise ValueError('Coalesce must take at least two expressions')
         super().__init__(*expressions, **extra)
 
-    def as_oracle(self, compiler, connection):
+    def as_oracle(self, compiler, connection, **extra_context):
         # Oracle prohibits mixing TextField (NCLOB) and CharField (NVARCHAR2),
         # so convert all fields to NCLOB when that type is expected.
         if self.output_field.get_internal_type() == 'TextField':
@@ -47,8 +47,8 @@ class Coalesce(Func):
             ]
             clone = self.copy()
             clone.set_source_expressions(expressions)
-            return super(Coalesce, clone).as_sql(compiler, connection)
-        return self.as_sql(compiler, connection)
+            return super(Coalesce, clone).as_sql(compiler, connection, **extra_context)
+        return self.as_sql(compiler, connection, **extra_context)
 
 
 class Greatest(Func):
@@ -66,9 +66,9 @@ class Greatest(Func):
             raise ValueError('Greatest must take at least two expressions')
         super().__init__(*expressions, **extra)
 
-    def as_sqlite(self, compiler, connection):
+    def as_sqlite(self, compiler, connection, **extra_context):
         """Use the MAX function on SQLite."""
-        return super().as_sqlite(compiler, connection, function='MAX')
+        return super().as_sqlite(compiler, connection, function='MAX', **extra_context)
 
 
 class Least(Func):
@@ -86,6 +86,6 @@ class Least(Func):
             raise ValueError('Least must take at least two expressions')
         super().__init__(*expressions, **extra)
 
-    def as_sqlite(self, compiler, connection):
+    def as_sqlite(self, compiler, connection, **extra_context):
         """Use the MIN function on SQLite."""
-        return super().as_sqlite(compiler, connection, function='MIN')
+        return super().as_sqlite(compiler, connection, function='MIN', **extra_context)

+ 2 - 2
django/db/models/functions/datetime.py

@@ -159,11 +159,11 @@ class Now(Func):
     template = 'CURRENT_TIMESTAMP'
     output_field = fields.DateTimeField()
 
-    def as_postgresql(self, compiler, connection):
+    def as_postgresql(self, compiler, connection, **extra_context):
         # PostgreSQL's CURRENT_TIMESTAMP means "the time at the start of the
         # transaction". Use STATEMENT_TIMESTAMP to be cross-compatible with
         # other databases.
-        return self.as_sql(compiler, connection, template='STATEMENT_TIMESTAMP()')
+        return self.as_sql(compiler, connection, template='STATEMENT_TIMESTAMP()', **extra_context)
 
 
 class TruncBase(TimezoneMixin, Transform):

+ 24 - 16
django/db/models/functions/math.py

@@ -9,7 +9,7 @@ from django.db.models.functions import Cast
 
 class DecimalInputMixin:
 
-    def as_postgresql(self, compiler, connection):
+    def as_postgresql(self, compiler, connection, **extra_context):
         # Cast FloatField to DecimalField as PostgreSQL doesn't support the
         # following function signatures:
         # - LOG(double, double)
@@ -20,7 +20,7 @@ class DecimalInputMixin:
             Cast(expression, output_field) if isinstance(expression.output_field, FloatField)
             else expression for expression in self.get_source_expressions()
         ])
-        return clone.as_sql(compiler, connection)
+        return clone.as_sql(compiler, connection, **extra_context)
 
 
 class OutputFieldMixin:
@@ -54,7 +54,7 @@ class ATan2(OutputFieldMixin, Func):
     function = 'ATAN2'
     arity = 2
 
-    def as_sqlite(self, compiler, connection):
+    def as_sqlite(self, compiler, connection, **extra_context):
         if not getattr(connection.ops, 'spatialite', False) or connection.ops.spatial_version < (4, 3, 0):
             return self.as_sql(compiler, connection)
         # This function is usually ATan2(y, x), returning the inverse tangent
@@ -67,15 +67,15 @@ class ATan2(OutputFieldMixin, Func):
             Cast(expression, FloatField()) if isinstance(expression.output_field, IntegerField)
             else expression for expression in self.get_source_expressions()[::-1]
         ])
-        return clone.as_sql(compiler, connection)
+        return clone.as_sql(compiler, connection, **extra_context)
 
 
 class Ceil(Transform):
     function = 'CEILING'
     lookup_name = 'ceil'
 
-    def as_oracle(self, compiler, connection):
-        return super().as_sql(compiler, connection, function='CEIL')
+    def as_oracle(self, compiler, connection, **extra_context):
+        return super().as_sql(compiler, connection, function='CEIL', **extra_context)
 
 
 class Cos(OutputFieldMixin, Transform):
@@ -87,16 +87,20 @@ class Cot(OutputFieldMixin, Transform):
     function = 'COT'
     lookup_name = 'cot'
 
-    def as_oracle(self, compiler, connection):
-        return super().as_sql(compiler, connection, template='(1 / TAN(%(expressions)s))')
+    def as_oracle(self, compiler, connection, **extra_context):
+        return super().as_sql(compiler, connection, template='(1 / TAN(%(expressions)s))', **extra_context)
 
 
 class Degrees(OutputFieldMixin, Transform):
     function = 'DEGREES'
     lookup_name = 'degrees'
 
-    def as_oracle(self, compiler, connection):
-        return super().as_sql(compiler, connection, template='((%%(expressions)s) * 180 / %s)' % math.pi)
+    def as_oracle(self, compiler, connection, **extra_context):
+        return super().as_sql(
+            compiler, connection,
+            template='((%%(expressions)s) * 180 / %s)' % math.pi,
+            **extra_context
+        )
 
 
 class Exp(OutputFieldMixin, Transform):
@@ -118,14 +122,14 @@ class Log(DecimalInputMixin, OutputFieldMixin, Func):
     function = 'LOG'
     arity = 2
 
-    def as_sqlite(self, compiler, connection):
+    def as_sqlite(self, compiler, connection, **extra_context):
         if not getattr(connection.ops, 'spatialite', False):
             return self.as_sql(compiler, connection)
         # This function is usually Log(b, x) returning the logarithm of x to
         # the base b, but on SpatiaLite it's Log(x, b).
         clone = self.copy()
         clone.set_source_expressions(self.get_source_expressions()[::-1])
-        return clone.as_sql(compiler, connection)
+        return clone.as_sql(compiler, connection, **extra_context)
 
 
 class Mod(DecimalInputMixin, OutputFieldMixin, Func):
@@ -137,8 +141,8 @@ class Pi(OutputFieldMixin, Func):
     function = 'PI'
     arity = 0
 
-    def as_oracle(self, compiler, connection):
-        return super().as_sql(compiler, connection, template=str(math.pi))
+    def as_oracle(self, compiler, connection, **extra_context):
+        return super().as_sql(compiler, connection, template=str(math.pi), **extra_context)
 
 
 class Power(OutputFieldMixin, Func):
@@ -150,8 +154,12 @@ class Radians(OutputFieldMixin, Transform):
     function = 'RADIANS'
     lookup_name = 'radians'
 
-    def as_oracle(self, compiler, connection):
-        return super().as_sql(compiler, connection, template='((%%(expressions)s) * %s / 180)' % math.pi)
+    def as_oracle(self, compiler, connection, **extra_context):
+        return super().as_sql(
+            compiler, connection,
+            template='((%%(expressions)s) * %s / 180)' % math.pi,
+            **extra_context
+        )
 
 
 class Round(Transform):

+ 25 - 16
django/db/models/functions/text.py

@@ -22,13 +22,19 @@ class Chr(Transform):
     function = 'CHR'
     lookup_name = 'chr'
 
-    def as_mysql(self, compiler, connection):
+    def as_mysql(self, compiler, connection, **extra_context):
         return super().as_sql(
-            compiler, connection, function='CHAR', template='%(function)s(%(expressions)s USING utf16)'
+            compiler, connection, function='CHAR',
+            template='%(function)s(%(expressions)s USING utf16)',
+            **extra_context
         )
 
-    def as_oracle(self, compiler, connection):
-        return super().as_sql(compiler, connection, template='%(function)s(%(expressions)s USING NCHAR_CS)')
+    def as_oracle(self, compiler, connection, **extra_context):
+        return super().as_sql(
+            compiler, connection,
+            template='%(function)s(%(expressions)s USING NCHAR_CS)',
+            **extra_context
+        )
 
     def as_sqlite(self, compiler, connection, **extra_context):
         return super().as_sql(compiler, connection, function='CHAR', **extra_context)
@@ -41,16 +47,19 @@ class ConcatPair(Func):
     """
     function = 'CONCAT'
 
-    def as_sqlite(self, compiler, connection):
+    def as_sqlite(self, compiler, connection, **extra_context):
         coalesced = self.coalesce()
         return super(ConcatPair, coalesced).as_sql(
-            compiler, connection, template='%(expressions)s', arg_joiner=' || '
+            compiler, connection, template='%(expressions)s', arg_joiner=' || ',
+            **extra_context
         )
 
-    def as_mysql(self, compiler, connection):
+    def as_mysql(self, compiler, connection, **extra_context):
         # Use CONCAT_WS with an empty separator so that NULLs are ignored.
         return super().as_sql(
-            compiler, connection, function='CONCAT_WS', template="%(function)s('', %(expressions)s)"
+            compiler, connection, function='CONCAT_WS',
+            template="%(function)s('', %(expressions)s)",
+            **extra_context
         )
 
     def coalesce(self):
@@ -117,8 +126,8 @@ class Length(Transform):
     lookup_name = 'length'
     output_field = fields.IntegerField()
 
-    def as_mysql(self, compiler, connection):
-        return super().as_sql(compiler, connection, function='CHAR_LENGTH')
+    def as_mysql(self, compiler, connection, **extra_context):
+        return super().as_sql(compiler, connection, function='CHAR_LENGTH', **extra_context)
 
 
 class Lower(Transform):
@@ -199,8 +208,8 @@ class StrIndex(Func):
     arity = 2
     output_field = fields.IntegerField()
 
-    def as_postgresql(self, compiler, connection):
-        return super().as_sql(compiler, connection, function='STRPOS')
+    def as_postgresql(self, compiler, connection, **extra_context):
+        return super().as_sql(compiler, connection, function='STRPOS', **extra_context)
 
 
 class Substr(Func):
@@ -220,11 +229,11 @@ class Substr(Func):
             expressions.append(length)
         super().__init__(*expressions, **extra)
 
-    def as_sqlite(self, compiler, connection):
-        return super().as_sql(compiler, connection, function='SUBSTR')
+    def as_sqlite(self, compiler, connection, **extra_context):
+        return super().as_sql(compiler, connection, function='SUBSTR', **extra_context)
 
-    def as_oracle(self, compiler, connection):
-        return super().as_sql(compiler, connection, function='SUBSTR')
+    def as_oracle(self, compiler, connection, **extra_context):
+        return super().as_sql(compiler, connection, function='SUBSTR', **extra_context)
 
 
 class Trim(Transform):

+ 1 - 1
docs/howto/custom-lookups.txt

@@ -275,7 +275,7 @@ We can change the behavior on a specific backend by creating a subclass of
 ``NotEqual`` with an ``as_mysql`` method::
 
   class MySQLNotEqual(NotEqual):
-      def as_mysql(self, compiler, connection):
+      def as_mysql(self, compiler, connection, **extra_context):
           lhs, lhs_params = self.process_lhs(compiler, connection)
           rhs, rhs_params = self.process_rhs(compiler, connection)
           params = lhs_params + rhs_params

+ 2 - 1
docs/ref/models/expressions.txt

@@ -322,11 +322,12 @@ The ``Func`` API is as follows:
                 function = 'CONCAT'
                 ...
 
-                def as_mysql(self, compiler, connection):
+                def as_mysql(self, compiler, connection, **extra_context):
                     return super().as_sql(
                         compiler, connection,
                         function='CONCAT_WS',
                         template="%(function)s('', %(expressions)s)",
+                        **extra_context
                     )
 
         To avoid a SQL injection vulnerability, ``extra_context`` :ref:`must

+ 2 - 2
tests/aggregation/tests.py

@@ -1083,8 +1083,8 @@ class AggregateTestCase(TestCase):
         class Greatest(Func):
             function = 'GREATEST'
 
-            def as_sqlite(self, compiler, connection):
-                return super().as_sql(compiler, connection, function='MAX')
+            def as_sqlite(self, compiler, connection, **extra_context):
+                return super().as_sql(compiler, connection, function='MAX', **extra_context)
 
         qs = Publisher.objects.annotate(
             price_or_median=Greatest(Avg('book__rating'), Avg('book__price'))

+ 1 - 1
tests/custom_lookups/tests.py

@@ -34,7 +34,7 @@ class Div3Transform(models.Transform):
         lhs, lhs_params = compiler.compile(self.lhs)
         return '(%s) %%%% 3' % lhs, lhs_params
 
-    def as_oracle(self, compiler, connection):
+    def as_oracle(self, compiler, connection, **extra_context):
         lhs, lhs_params = compiler.compile(self.lhs)
         return 'mod(%s, 3)' % lhs, lhs_params