|
@@ -1,7 +1,9 @@
|
|
|
import copy
|
|
|
import datetime
|
|
|
+import functools
|
|
|
import inspect
|
|
|
from decimal import Decimal
|
|
|
+from uuid import UUID
|
|
|
|
|
|
from django.core.exceptions import EmptyResultSet, FieldError
|
|
|
from django.db import NotSupportedError, connection
|
|
@@ -56,12 +58,7 @@ class Combinable:
|
|
|
def _combine(self, other, connector, reversed):
|
|
|
if not hasattr(other, 'resolve_expression'):
|
|
|
# everything must be resolvable to an expression
|
|
|
- output_field = (
|
|
|
- fields.DurationField()
|
|
|
- if isinstance(other, datetime.timedelta) else
|
|
|
- None
|
|
|
- )
|
|
|
- other = Value(other, output_field=output_field)
|
|
|
+ other = Value(other)
|
|
|
|
|
|
if reversed:
|
|
|
return CombinedExpression(other, connector, self)
|
|
@@ -422,6 +419,25 @@ class Expression(BaseExpression, Combinable):
|
|
|
pass
|
|
|
|
|
|
|
|
|
+_connector_combinators = {
|
|
|
+ connector: [
|
|
|
+ (fields.IntegerField, fields.DecimalField, fields.DecimalField),
|
|
|
+ (fields.DecimalField, fields.IntegerField, fields.DecimalField),
|
|
|
+ (fields.IntegerField, fields.FloatField, fields.FloatField),
|
|
|
+ (fields.FloatField, fields.IntegerField, fields.FloatField),
|
|
|
+ ]
|
|
|
+ for connector in (Combinable.ADD, Combinable.SUB, Combinable.MUL, Combinable.DIV)
|
|
|
+}
|
|
|
+
|
|
|
+
|
|
|
+@functools.lru_cache(maxsize=128)
|
|
|
+def _resolve_combined_type(connector, lhs_type, rhs_type):
|
|
|
+ combinators = _connector_combinators.get(connector, ())
|
|
|
+ for combinator_lhs_type, combinator_rhs_type, combined_type in combinators:
|
|
|
+ if issubclass(lhs_type, combinator_lhs_type) and issubclass(rhs_type, combinator_rhs_type):
|
|
|
+ return combined_type
|
|
|
+
|
|
|
+
|
|
|
class CombinedExpression(SQLiteNumericMixin, Expression):
|
|
|
|
|
|
def __init__(self, lhs, connector, rhs, output_field=None):
|
|
@@ -442,6 +458,19 @@ class CombinedExpression(SQLiteNumericMixin, Expression):
|
|
|
def set_source_expressions(self, exprs):
|
|
|
self.lhs, self.rhs = exprs
|
|
|
|
|
|
+ def _resolve_output_field(self):
|
|
|
+ try:
|
|
|
+ return super()._resolve_output_field()
|
|
|
+ except FieldError:
|
|
|
+ combined_type = _resolve_combined_type(
|
|
|
+ self.connector,
|
|
|
+ type(self.lhs.output_field),
|
|
|
+ type(self.rhs.output_field),
|
|
|
+ )
|
|
|
+ if combined_type is None:
|
|
|
+ raise
|
|
|
+ return combined_type()
|
|
|
+
|
|
|
def as_sql(self, compiler, connection):
|
|
|
expressions = []
|
|
|
expression_params = []
|
|
@@ -721,6 +750,30 @@ class Value(Expression):
|
|
|
def get_group_by_cols(self, alias=None):
|
|
|
return []
|
|
|
|
|
|
+ def _resolve_output_field(self):
|
|
|
+ if isinstance(self.value, str):
|
|
|
+ return fields.CharField()
|
|
|
+ if isinstance(self.value, bool):
|
|
|
+ return fields.BooleanField()
|
|
|
+ if isinstance(self.value, int):
|
|
|
+ return fields.IntegerField()
|
|
|
+ if isinstance(self.value, float):
|
|
|
+ return fields.FloatField()
|
|
|
+ if isinstance(self.value, datetime.datetime):
|
|
|
+ return fields.DateTimeField()
|
|
|
+ if isinstance(self.value, datetime.date):
|
|
|
+ return fields.DateField()
|
|
|
+ if isinstance(self.value, datetime.time):
|
|
|
+ return fields.TimeField()
|
|
|
+ if isinstance(self.value, datetime.timedelta):
|
|
|
+ return fields.DurationField()
|
|
|
+ if isinstance(self.value, Decimal):
|
|
|
+ return fields.DecimalField()
|
|
|
+ if isinstance(self.value, bytes):
|
|
|
+ return fields.BinaryField()
|
|
|
+ if isinstance(self.value, UUID):
|
|
|
+ return fields.UUIDField()
|
|
|
+
|
|
|
|
|
|
class RawSQL(Expression):
|
|
|
def __init__(self, sql, params, output_field=None):
|
|
@@ -1177,7 +1230,6 @@ class OrderBy(BaseExpression):
|
|
|
copy.expression = Case(
|
|
|
When(self.expression, then=True),
|
|
|
default=False,
|
|
|
- output_field=fields.BooleanField(),
|
|
|
)
|
|
|
return copy.as_sql(compiler, connection)
|
|
|
return self.as_sql(compiler, connection)
|