Explorar el Código

Fixed #30651 -- Made __eq__() methods return NotImplemented for not implemented comparisons.

Changed __eq__ to return NotImplemented instead of False if compared to
an object of the same type, as is recommended by the Python data model
reference. Now these models can be compared to ANY (or other objects
with __eq__ overwritten) without returning False automatically.
ElizabethU hace 5 años
padre
commit
54ea290e5b

+ 3 - 2
django/contrib/messages/storage/base.py

@@ -25,8 +25,9 @@ class Message:
         self.extra_tags = str(self.extra_tags) if self.extra_tags is not None else None
 
     def __eq__(self, other):
-        return isinstance(other, Message) and self.level == other.level and \
-            self.message == other.message
+        if not isinstance(other, Message):
+            return NotImplemented
+        return self.level == other.level and self.message == other.message
 
     def __str__(self):
         return str(self.message)

+ 8 - 7
django/contrib/postgres/constraints.py

@@ -89,13 +89,14 @@ class ExclusionConstraint(BaseConstraint):
         return path, args, kwargs
 
     def __eq__(self, other):
-        return (
-            isinstance(other, self.__class__) and
-            self.name == other.name and
-            self.index_type == other.index_type and
-            self.expressions == other.expressions and
-            self.condition == other.condition
-        )
+        if isinstance(other, self.__class__):
+            return (
+                self.name == other.name and
+                self.index_type == other.index_type and
+                self.expressions == other.expressions and
+                self.condition == other.condition
+            )
+        return super().__eq__(other)
 
     def __repr__(self):
         return '<%s: index_type=%s, expressions=%s%s>' % (

+ 2 - 1
django/core/validators.py

@@ -324,8 +324,9 @@ class BaseValidator:
             raise ValidationError(self.message, code=self.code, params=params)
 
     def __eq__(self, other):
+        if not isinstance(other, self.__class__):
+            return NotImplemented
         return (
-            isinstance(other, self.__class__) and
             self.limit_value == other.limit_value and
             self.message == other.message and
             self.code == other.code

+ 1 - 1
django/db/models/base.py

@@ -522,7 +522,7 @@ class Model(metaclass=ModelBase):
 
     def __eq__(self, other):
         if not isinstance(other, Model):
-            return False
+            return NotImplemented
         if self._meta.concrete_model != other._meta.concrete_model:
             return False
         my_pk = self.pk

+ 10 - 11
django/db/models/constraints.py

@@ -54,11 +54,9 @@ class CheckConstraint(BaseConstraint):
         return "<%s: check='%s' name=%r>" % (self.__class__.__name__, self.check, self.name)
 
     def __eq__(self, other):
-        return (
-            isinstance(other, CheckConstraint) and
-            self.name == other.name and
-            self.check == other.check
-        )
+        if isinstance(other, CheckConstraint):
+            return self.name == other.name and self.check == other.check
+        return super().__eq__(other)
 
     def deconstruct(self):
         path, args, kwargs = super().deconstruct()
@@ -106,12 +104,13 @@ class UniqueConstraint(BaseConstraint):
         )
 
     def __eq__(self, other):
-        return (
-            isinstance(other, UniqueConstraint) and
-            self.name == other.name and
-            self.fields == other.fields and
-            self.condition == other.condition
-        )
+        if isinstance(other, UniqueConstraint):
+            return (
+                self.name == other.name and
+                self.fields == other.fields and
+                self.condition == other.condition
+            )
+        return super().__eq__(other)
 
     def deconstruct(self):
         path, args, kwargs = super().deconstruct()

+ 3 - 1
django/db/models/expressions.py

@@ -401,7 +401,9 @@ class BaseExpression:
         return tuple(identity)
 
     def __eq__(self, other):
-        return isinstance(other, BaseExpression) and other.identity == self.identity
+        if not isinstance(other, BaseExpression):
+            return NotImplemented
+        return other.identity == self.identity
 
     def __hash__(self):
         return hash(self.identity)

+ 3 - 1
django/db/models/indexes.py

@@ -112,4 +112,6 @@ class Index:
         )
 
     def __eq__(self, other):
-        return (self.__class__ == other.__class__) and (self.deconstruct() == other.deconstruct())
+        if self.__class__ == other.__class__:
+            return self.deconstruct() == other.deconstruct()
+        return NotImplemented

+ 3 - 1
django/db/models/query.py

@@ -1543,7 +1543,9 @@ class Prefetch:
         return None
 
     def __eq__(self, other):
-        return isinstance(other, Prefetch) and self.prefetch_to == other.prefetch_to
+        if not isinstance(other, Prefetch):
+            return NotImplemented
+        return self.prefetch_to == other.prefetch_to
 
     def __hash__(self):
         return hash((self.__class__, self.prefetch_to))

+ 2 - 1
django/db/models/query_utils.py

@@ -309,8 +309,9 @@ class FilteredRelation:
         self.path = []
 
     def __eq__(self, other):
+        if not isinstance(other, self.__class__):
+            return NotImplemented
         return (
-            isinstance(other, self.__class__) and
             self.relation_name == other.relation_name and
             self.alias == other.alias and
             self.condition == other.condition

+ 4 - 6
django/template/context.py

@@ -124,12 +124,10 @@ class BaseContext:
         """
         Compare two contexts by comparing theirs 'dicts' attributes.
         """
-        return (
-            isinstance(other, BaseContext) and
-            # because dictionaries can be put in different order
-            # we have to flatten them like in templates
-            self.flatten() == other.flatten()
-        )
+        if not isinstance(other, BaseContext):
+            return NotImplemented
+        # flatten dictionaries because they can be put in a different order.
+        return self.flatten() == other.flatten()
 
 
 class Context(BaseContext):

+ 2 - 0
tests/basic/tests.py

@@ -1,5 +1,6 @@
 import threading
 from datetime import datetime, timedelta
+from unittest import mock
 
 from django.core.exceptions import MultipleObjectsReturned, ObjectDoesNotExist
 from django.db import DEFAULT_DB_ALIAS, DatabaseError, connections, models
@@ -354,6 +355,7 @@ class ModelTest(TestCase):
         self.assertNotEqual(object(), Article(id=1))
         a = Article()
         self.assertEqual(a, a)
+        self.assertEqual(a, mock.ANY)
         self.assertNotEqual(Article(), a)
 
     def test_hash(self):

+ 7 - 0
tests/constraints/tests.py

@@ -1,3 +1,5 @@
+from unittest import mock
+
 from django.core.exceptions import ValidationError
 from django.db import IntegrityError, connection, models
 from django.db.models.constraints import BaseConstraint
@@ -39,6 +41,7 @@ class CheckConstraintTests(TestCase):
             models.CheckConstraint(check=check1, name='price'),
             models.CheckConstraint(check=check1, name='price'),
         )
+        self.assertEqual(models.CheckConstraint(check=check1, name='price'), mock.ANY)
         self.assertNotEqual(
             models.CheckConstraint(check=check1, name='price'),
             models.CheckConstraint(check=check1, name='price2'),
@@ -102,6 +105,10 @@ class UniqueConstraintTests(TestCase):
             models.UniqueConstraint(fields=['foo', 'bar'], name='unique'),
             models.UniqueConstraint(fields=['foo', 'bar'], name='unique'),
         )
+        self.assertEqual(
+            models.UniqueConstraint(fields=['foo', 'bar'], name='unique'),
+            mock.ANY,
+        )
         self.assertNotEqual(
             models.UniqueConstraint(fields=['foo', 'bar'], name='unique'),
             models.UniqueConstraint(fields=['foo', 'bar'], name='unique2'),

+ 2 - 0
tests/expressions/tests.py

@@ -3,6 +3,7 @@ import pickle
 import unittest
 import uuid
 from copy import deepcopy
+from unittest import mock
 
 from django.core.exceptions import FieldError
 from django.db import DatabaseError, connection, models
@@ -965,6 +966,7 @@ class SimpleExpressionTests(SimpleTestCase):
             Expression(models.IntegerField()),
             Expression(output_field=models.IntegerField())
         )
+        self.assertEqual(Expression(models.IntegerField()), mock.ANY)
         self.assertNotEqual(
             Expression(models.IntegerField()),
             Expression(models.CharField())

+ 5 - 0
tests/filtered_relation/tests.py

@@ -1,3 +1,5 @@
+from unittest import mock
+
 from django.db import connection, transaction
 from django.db.models import Case, Count, F, FilteredRelation, Q, When
 from django.test import TestCase
@@ -323,6 +325,9 @@ class FilteredRelationTests(TestCase):
             [self.book1]
         )
 
+    def test_eq(self):
+        self.assertEqual(FilteredRelation('book', condition=Q(book__title='b')), mock.ANY)
+
 
 class FilteredRelationAggregationTests(TestCase):
 

+ 3 - 0
tests/messages_tests/tests.py

@@ -1,3 +1,5 @@
+from unittest import mock
+
 from django.contrib.messages import constants
 from django.contrib.messages.storage.base import Message
 from django.test import SimpleTestCase
@@ -9,6 +11,7 @@ class MessageTests(SimpleTestCase):
         msg_2 = Message(constants.INFO, 'Test message 2')
         msg_3 = Message(constants.WARNING, 'Test message 1')
         self.assertEqual(msg_1, msg_1)
+        self.assertEqual(msg_1, mock.ANY)
         self.assertNotEqual(msg_1, msg_2)
         self.assertNotEqual(msg_1, msg_3)
         self.assertNotEqual(msg_2, msg_3)

+ 3 - 0
tests/model_indexes/tests.py

@@ -1,3 +1,5 @@
+from unittest import mock
+
 from django.conf import settings
 from django.db import connection, models
 from django.db.models.query_utils import Q
@@ -28,6 +30,7 @@ class SimpleIndexesTests(SimpleTestCase):
         same_index.model = Book
         another_index.model = Book
         self.assertEqual(index, same_index)
+        self.assertEqual(index, mock.ANY)
         self.assertNotEqual(index, another_index)
 
     def test_index_fields_type(self):

+ 2 - 0
tests/postgres_tests/test_constraints.py

@@ -1,4 +1,5 @@
 import datetime
+from unittest import mock
 
 from django.db import connection, transaction
 from django.db.models import F, Func, Q
@@ -175,6 +176,7 @@ class ExclusionConstraintTests(PostgreSQLTestCase):
             condition=Q(cancelled=False),
         )
         self.assertEqual(constraint_1, constraint_1)
+        self.assertEqual(constraint_1, mock.ANY)
         self.assertNotEqual(constraint_1, constraint_2)
         self.assertNotEqual(constraint_1, constraint_3)
         self.assertNotEqual(constraint_2, constraint_3)

+ 3 - 0
tests/prefetch_related/tests.py

@@ -1,3 +1,5 @@
+from unittest import mock
+
 from django.contrib.contenttypes.models import ContentType
 from django.core.exceptions import ObjectDoesNotExist
 from django.db import connection
@@ -243,6 +245,7 @@ class PrefetchRelatedTests(TestDataMixin, TestCase):
         prefetch_1 = Prefetch('authors', queryset=Author.objects.all())
         prefetch_2 = Prefetch('books', queryset=Book.objects.all())
         self.assertEqual(prefetch_1, prefetch_1)
+        self.assertEqual(prefetch_1, mock.ANY)
         self.assertNotEqual(prefetch_1, prefetch_2)
 
     def test_forward_m2m_to_attr_conflict(self):

+ 3 - 0
tests/template_tests/test_context.py

@@ -1,3 +1,5 @@
+from unittest import mock
+
 from django.http import HttpRequest
 from django.template import (
     Context, Engine, RequestContext, Template, Variable, VariableDoesNotExist,
@@ -18,6 +20,7 @@ class ContextTests(SimpleTestCase):
         self.assertEqual(c.pop(), {"a": 2})
         self.assertEqual(c["a"], 1)
         self.assertEqual(c.get("foo", 42), 42)
+        self.assertEqual(c, mock.ANY)
 
     def test_push_context_manager(self):
         c = Context({"a": 1})

+ 2 - 1
tests/validators/tests.py

@@ -3,7 +3,7 @@ import re
 import types
 from datetime import datetime, timedelta
 from decimal import Decimal
-from unittest import TestCase
+from unittest import TestCase, mock
 
 from django.core.exceptions import ValidationError
 from django.core.files.base import ContentFile
@@ -424,6 +424,7 @@ class TestValidatorEquality(TestCase):
             MaxValueValidator(44),
             MaxValueValidator(44),
         )
+        self.assertEqual(MaxValueValidator(44), mock.ANY)
         self.assertNotEqual(
             MaxValueValidator(44),
             MinValueValidator(44),