related_lookups.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. from django.db.models.lookups import (
  2. Exact, GreaterThan, GreaterThanOrEqual, In, IsNull, LessThan,
  3. LessThanOrEqual,
  4. )
  5. class MultiColSource:
  6. contains_aggregate = False
  7. def __init__(self, alias, targets, sources, field):
  8. self.targets, self.sources, self.field, self.alias = targets, sources, field, alias
  9. self.output_field = self.field
  10. def __repr__(self):
  11. return "{}({}, {})".format(
  12. self.__class__.__name__, self.alias, self.field)
  13. def relabeled_clone(self, relabels):
  14. return self.__class__(relabels.get(self.alias, self.alias),
  15. self.targets, self.sources, self.field)
  16. def get_lookup(self, lookup):
  17. return self.output_field.get_lookup(lookup)
  18. def resolve_expression(self, *args, **kwargs):
  19. return self
  20. def get_normalized_value(value, lhs):
  21. from django.db.models import Model
  22. if isinstance(value, Model):
  23. value_list = []
  24. sources = lhs.output_field.get_path_info()[-1].target_fields
  25. for source in sources:
  26. while not isinstance(value, source.model) and source.remote_field:
  27. source = source.remote_field.model._meta.get_field(source.remote_field.field_name)
  28. try:
  29. value_list.append(getattr(value, source.attname))
  30. except AttributeError:
  31. # A case like Restaurant.objects.filter(place=restaurant_instance),
  32. # where place is a OneToOneField and the primary key of Restaurant.
  33. return (value.pk,)
  34. return tuple(value_list)
  35. if not isinstance(value, tuple):
  36. return (value,)
  37. return value
  38. class RelatedIn(In):
  39. def get_prep_lookup(self):
  40. if not isinstance(self.lhs, MultiColSource) and self.rhs_is_direct_value():
  41. # If we get here, we are dealing with single-column relations.
  42. self.rhs = [get_normalized_value(val, self.lhs)[0] for val in self.rhs]
  43. # We need to run the related field's get_prep_value(). Consider case
  44. # ForeignKey to IntegerField given value 'abc'. The ForeignKey itself
  45. # doesn't have validation for non-integers, so we must run validation
  46. # using the target field.
  47. if hasattr(self.lhs.output_field, 'get_path_info'):
  48. # Run the target field's get_prep_value. We can safely assume there is
  49. # only one as we don't get to the direct value branch otherwise.
  50. target_field = self.lhs.output_field.get_path_info()[-1].target_fields[-1]
  51. self.rhs = [target_field.get_prep_value(v) for v in self.rhs]
  52. return super().get_prep_lookup()
  53. def as_sql(self, compiler, connection):
  54. if isinstance(self.lhs, MultiColSource):
  55. # For multicolumn lookups we need to build a multicolumn where clause.
  56. # This clause is either a SubqueryConstraint (for values that need to be compiled to
  57. # SQL) or an OR-combined list of (col1 = val1 AND col2 = val2 AND ...) clauses.
  58. from django.db.models.sql.where import (
  59. AND, OR, SubqueryConstraint, WhereNode,
  60. )
  61. root_constraint = WhereNode(connector=OR)
  62. if self.rhs_is_direct_value():
  63. values = [get_normalized_value(value, self.lhs) for value in self.rhs]
  64. for value in values:
  65. value_constraint = WhereNode()
  66. for source, target, val in zip(self.lhs.sources, self.lhs.targets, value):
  67. lookup_class = target.get_lookup('exact')
  68. lookup = lookup_class(target.get_col(self.lhs.alias, source), val)
  69. value_constraint.add(lookup, AND)
  70. root_constraint.add(value_constraint, OR)
  71. else:
  72. root_constraint.add(
  73. SubqueryConstraint(
  74. self.lhs.alias, [target.column for target in self.lhs.targets],
  75. [source.name for source in self.lhs.sources], self.rhs),
  76. AND)
  77. return root_constraint.as_sql(compiler, connection)
  78. else:
  79. if (not getattr(self.rhs, 'has_select_fields', True) and
  80. not getattr(self.lhs.field.target_field, 'primary_key', False)):
  81. self.rhs.clear_select_clause()
  82. if (getattr(self.lhs.output_field, 'primary_key', False) and
  83. self.lhs.output_field.model == self.rhs.model):
  84. # A case like Restaurant.objects.filter(place__in=restaurant_qs),
  85. # where place is a OneToOneField and the primary key of
  86. # Restaurant.
  87. target_field = self.lhs.field.name
  88. else:
  89. target_field = self.lhs.field.target_field.name
  90. self.rhs.add_fields([target_field], True)
  91. return super().as_sql(compiler, connection)
  92. class RelatedLookupMixin:
  93. def get_prep_lookup(self):
  94. if not isinstance(self.lhs, MultiColSource) and not hasattr(self.rhs, 'resolve_expression'):
  95. # If we get here, we are dealing with single-column relations.
  96. self.rhs = get_normalized_value(self.rhs, self.lhs)[0]
  97. # We need to run the related field's get_prep_value(). Consider case
  98. # ForeignKey to IntegerField given value 'abc'. The ForeignKey itself
  99. # doesn't have validation for non-integers, so we must run validation
  100. # using the target field.
  101. if self.prepare_rhs and hasattr(self.lhs.output_field, 'get_path_info'):
  102. # Get the target field. We can safely assume there is only one
  103. # as we don't get to the direct value branch otherwise.
  104. target_field = self.lhs.output_field.get_path_info()[-1].target_fields[-1]
  105. self.rhs = target_field.get_prep_value(self.rhs)
  106. return super().get_prep_lookup()
  107. def as_sql(self, compiler, connection):
  108. if isinstance(self.lhs, MultiColSource):
  109. assert self.rhs_is_direct_value()
  110. self.rhs = get_normalized_value(self.rhs, self.lhs)
  111. from django.db.models.sql.where import AND, WhereNode
  112. root_constraint = WhereNode()
  113. for target, source, val in zip(self.lhs.targets, self.lhs.sources, self.rhs):
  114. lookup_class = target.get_lookup(self.lookup_name)
  115. root_constraint.add(
  116. lookup_class(target.get_col(self.lhs.alias, source), val), AND)
  117. return root_constraint.as_sql(compiler, connection)
  118. return super().as_sql(compiler, connection)
  119. class RelatedExact(RelatedLookupMixin, Exact):
  120. pass
  121. class RelatedLessThan(RelatedLookupMixin, LessThan):
  122. pass
  123. class RelatedGreaterThan(RelatedLookupMixin, GreaterThan):
  124. pass
  125. class RelatedGreaterThanOrEqual(RelatedLookupMixin, GreaterThanOrEqual):
  126. pass
  127. class RelatedLessThanOrEqual(RelatedLookupMixin, LessThanOrEqual):
  128. pass
  129. class RelatedIsNull(RelatedLookupMixin, IsNull):
  130. pass