Browse Source

Fixed #373 -- Added CompositePrimaryKey.

Thanks Lily Foote and Simon Charette for reviews and mentoring
this Google Summer of Code 2024 project.

Co-authored-by: Simon Charette <charette.s@gmail.com>
Co-authored-by: Lily Foote <code@lilyf.org>
Bendeguz Csirmaz 11 months ago
parent
commit
978aae4334
43 changed files with 3078 additions and 29 deletions
  1. 5 0
      django/contrib/admin/sites.py
  2. 3 0
      django/core/serializers/python.py
  3. 12 0
      django/db/backends/base/schema.py
  4. 2 0
      django/db/backends/oracle/schema.py
  5. 14 1
      django/db/backends/sqlite3/schema.py
  6. 2 0
      django/db/models/__init__.py
  7. 18 1
      django/db/models/aggregates.py
  8. 71 2
      django/db/models/base.py
  9. 2 0
      django/db/models/fields/__init__.py
  10. 150 0
      django/db/models/fields/composite.py
  11. 15 5
      django/db/models/fields/related.py
  12. 4 2
      django/db/models/fields/related_lookups.py
  13. 9 1
      django/db/models/fields/tuple_lookups.py
  14. 16 1
      django/db/models/options.py
  15. 6 3
      django/db/models/query.py
  16. 65 10
      django/db/models/sql/compiler.py
  17. 6 2
      django/db/models/sql/query.py
  18. 3 0
      docs/ref/checks.txt
  19. 19 0
      docs/ref/models/fields.txt
  20. 19 0
      docs/releases/5.2.txt
  21. 183 0
      docs/topics/composite-primary-key.txt
  22. 1 0
      docs/topics/index.txt
  23. 6 0
      tests/admin_registration/models.py
  24. 9 1
      tests/admin_registration/tests.py
  25. 0 0
      tests/composite_pk/__init__.py
  26. 75 0
      tests/composite_pk/fixtures/tenant.json
  27. 9 0
      tests/composite_pk/models/__init__.py
  28. 50 0
      tests/composite_pk/models/tenant.py
  29. 139 0
      tests/composite_pk/test_aggregate.py
  30. 242 0
      tests/composite_pk/test_checks.py
  31. 138 0
      tests/composite_pk/test_create.py
  32. 83 0
      tests/composite_pk/test_delete.py
  33. 412 0
      tests/composite_pk/test_filter.py
  34. 126 0
      tests/composite_pk/test_get.py
  35. 153 0
      tests/composite_pk/test_models.py
  36. 134 0
      tests/composite_pk/test_names_to_path.py
  37. 135 0
      tests/composite_pk/test_update.py
  38. 212 0
      tests/composite_pk/test_values.py
  39. 345 0
      tests/composite_pk/tests.py
  40. 89 0
      tests/migrations/test_autodetector.py
  41. 55 0
      tests/migrations/test_operations.py
  42. 22 0
      tests/migrations/test_state.py
  43. 19 0
      tests/migrations/test_writer.py

+ 5 - 0
django/contrib/admin/sites.py

@@ -113,6 +113,11 @@ class AdminSite:
                     "The model %s is abstract, so it cannot be registered with admin."
                     % model.__name__
                 )
+            if model._meta.is_composite_pk:
+                raise ImproperlyConfigured(
+                    "The model %s has a composite primary key, so it cannot be "
+                    "registered with admin." % model.__name__
+                )
 
             if self.is_registered(model):
                 registered_admin = str(self.get_model_admin(model))

+ 3 - 0
django/core/serializers/python.py

@@ -7,6 +7,7 @@ other serializers.
 from django.apps import apps
 from django.core.serializers import base
 from django.db import DEFAULT_DB_ALIAS, models
+from django.db.models import CompositePrimaryKey
 from django.utils.encoding import is_protected_type
 
 
@@ -39,6 +40,8 @@ class Serializer(base.Serializer):
         return data
 
     def _value_from_field(self, obj, field):
+        if isinstance(field, CompositePrimaryKey):
+            return [self._value_from_field(obj, f) for f in field]
         value = field.value_from_object(obj)
         # Protected types (i.e., primitives like None, numbers, dates,
         # and Decimals) are passed through as is. All other values are

+ 12 - 0
django/db/backends/base/schema.py

@@ -14,6 +14,7 @@ from django.db.backends.ddl_references import (
 )
 from django.db.backends.utils import names_digest, split_identifier, truncate_name
 from django.db.models import NOT_PROVIDED, Deferrable, Index
+from django.db.models.fields.composite import CompositePrimaryKey
 from django.db.models.sql import Query
 from django.db.transaction import TransactionManagementError, atomic
 from django.utils import timezone
@@ -106,6 +107,7 @@ class BaseDatabaseSchemaEditor:
     sql_check_constraint = "CHECK (%(check)s)"
     sql_delete_constraint = "ALTER TABLE %(table)s DROP CONSTRAINT %(name)s"
     sql_constraint = "CONSTRAINT %(name)s %(constraint)s"
+    sql_pk_constraint = "PRIMARY KEY (%(columns)s)"
 
     sql_create_check = "ALTER TABLE %(table)s ADD CONSTRAINT %(name)s CHECK (%(check)s)"
     sql_delete_check = sql_delete_constraint
@@ -282,6 +284,11 @@ class BaseDatabaseSchemaEditor:
                 constraint.constraint_sql(model, self)
                 for constraint in model._meta.constraints
             )
+
+        pk = model._meta.pk
+        if isinstance(pk, CompositePrimaryKey):
+            constraint_sqls.append(self._pk_constraint_sql(pk.columns))
+
         sql = self.sql_create_table % {
             "table": self.quote_name(model._meta.db_table),
             "definition": ", ".join(
@@ -1999,6 +2006,11 @@ class BaseDatabaseSchemaEditor:
                     result.append(name)
         return result
 
+    def _pk_constraint_sql(self, columns):
+        return self.sql_pk_constraint % {
+            "columns": ", ".join(self.quote_name(column) for column in columns)
+        }
+
     def _delete_primary_key(self, model, strict=False):
         constraint_names = self._constraint_names(model, primary_key=True)
         if strict and len(constraint_names) != 1:

+ 2 - 0
django/db/backends/oracle/schema.py

@@ -211,6 +211,8 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
         return create_index
 
     def _is_identity_column(self, table_name, column_name):
+        if not column_name:
+            return False
         with self.connection.cursor() as cursor:
             cursor.execute(
                 """

+ 14 - 1
django/db/backends/sqlite3/schema.py

@@ -6,7 +6,7 @@ from django.db import NotSupportedError
 from django.db.backends.base.schema import BaseDatabaseSchemaEditor
 from django.db.backends.ddl_references import Statement
 from django.db.backends.utils import strip_quotes
-from django.db.models import NOT_PROVIDED, UniqueConstraint
+from django.db.models import NOT_PROVIDED, CompositePrimaryKey, UniqueConstraint
 
 
 class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
@@ -104,6 +104,13 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
             f.name: f.clone() if is_self_referential(f) else f
             for f in model._meta.local_concrete_fields
         }
+
+        # Since CompositePrimaryKey is not a concrete field (column is None),
+        # it's not copied by default.
+        pk = model._meta.pk
+        if isinstance(pk, CompositePrimaryKey):
+            body[pk.name] = pk.clone()
+
         # Since mapping might mix column names and default values,
         # its values must be already quoted.
         mapping = {
@@ -296,6 +303,12 @@ class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
         # Special-case implicit M2M tables.
         if field.many_to_many and field.remote_field.through._meta.auto_created:
             self.create_model(field.remote_field.through)
+        elif isinstance(field, CompositePrimaryKey):
+            # If a CompositePrimaryKey field was added, the existing primary key field
+            # had to be altered too, resulting in an AddField, AlterField migration.
+            # The table cannot be re-created on AddField, it would result in a
+            # duplicate primary key error.
+            return
         elif (
             # Primary keys and unique fields are not supported in ALTER TABLE
             # ADD COLUMN.

+ 2 - 0
django/db/models/__init__.py

@@ -38,6 +38,7 @@ from django.db.models.expressions import (
 )
 from django.db.models.fields import *  # NOQA
 from django.db.models.fields import __all__ as fields_all
+from django.db.models.fields.composite import CompositePrimaryKey
 from django.db.models.fields.files import FileField, ImageField
 from django.db.models.fields.generated import GeneratedField
 from django.db.models.fields.json import JSONField
@@ -82,6 +83,7 @@ __all__ += [
     "ProtectedError",
     "RestrictedError",
     "Case",
+    "CompositePrimaryKey",
     "Exists",
     "Expression",
     "ExpressionList",

+ 18 - 1
django/db/models/aggregates.py

@@ -3,7 +3,8 @@ Classes to represent the definitions of aggregate functions.
 """
 
 from django.core.exceptions import FieldError, FullResultSet
-from django.db.models.expressions import Case, Func, Star, Value, When
+from django.db import NotSupportedError
+from django.db.models.expressions import Case, ColPairs, Func, Star, Value, When
 from django.db.models.fields import IntegerField
 from django.db.models.functions.comparison import Coalesce
 from django.db.models.functions.mixins import (
@@ -174,6 +175,22 @@ class Count(Aggregate):
             raise ValueError("Star cannot be used with filter. Please specify a field.")
         super().__init__(expression, filter=filter, **extra)
 
+    def resolve_expression(self, *args, **kwargs):
+        result = super().resolve_expression(*args, **kwargs)
+        expr = result.source_expressions[0]
+
+        # In case of composite primary keys, count the first column.
+        if isinstance(expr, ColPairs):
+            if self.distinct:
+                raise NotSupportedError(
+                    "COUNT(DISTINCT) doesn't support composite primary keys"
+                )
+
+            cols = expr.get_cols()
+            return Count(cols[0], filter=result.filter)
+
+        return result
+
 
 class Max(Aggregate):
     function = "MAX"

+ 71 - 2
django/db/models/base.py

@@ -1,6 +1,7 @@
 import copy
 import inspect
 import warnings
+from collections import defaultdict
 from functools import partialmethod
 from itertools import chain
 
@@ -30,6 +31,7 @@ from django.db.models import NOT_PROVIDED, ExpressionWrapper, IntegerField, Max,
 from django.db.models.constants import LOOKUP_SEP
 from django.db.models.deletion import CASCADE, Collector
 from django.db.models.expressions import DatabaseDefault
+from django.db.models.fields.composite import CompositePrimaryKey
 from django.db.models.fields.related import (
     ForeignObjectRel,
     OneToOneField,
@@ -508,7 +510,7 @@ class Model(AltersData, metaclass=ModelBase):
         for field in fields_iter:
             is_related_object = False
             # Virtual field
-            if field.attname not in kwargs and field.column is None or field.generated:
+            if field.column is None or field.generated:
                 continue
             if kwargs:
                 if isinstance(field.remote_field, ForeignObjectRel):
@@ -663,7 +665,11 @@ class Model(AltersData, metaclass=ModelBase):
     pk = property(_get_pk_val, _set_pk_val)
 
     def _is_pk_set(self, meta=None):
-        return self._get_pk_val(meta) is not None
+        pk_val = self._get_pk_val(meta)
+        return not (
+            pk_val is None
+            or (isinstance(pk_val, tuple) and any(f is None for f in pk_val))
+        )
 
     def get_deferred_fields(self):
         """
@@ -1454,6 +1460,11 @@ class Model(AltersData, metaclass=ModelBase):
                 name = f.name
                 if name in exclude:
                     continue
+                if isinstance(f, CompositePrimaryKey):
+                    names = tuple(field.name for field in f.fields)
+                    if exclude.isdisjoint(names):
+                        unique_checks.append((model_class, names))
+                    continue
                 if f.unique:
                     unique_checks.append((model_class, (name,)))
                 if f.unique_for_date and f.unique_for_date not in exclude:
@@ -1728,6 +1739,7 @@ class Model(AltersData, metaclass=ModelBase):
                 *cls._check_constraints(databases),
                 *cls._check_default_pk(),
                 *cls._check_db_table_comment(databases),
+                *cls._check_composite_pk(),
             ]
 
         return errors
@@ -1764,6 +1776,63 @@ class Model(AltersData, metaclass=ModelBase):
             ]
         return []
 
+    @classmethod
+    def _check_composite_pk(cls):
+        errors = []
+        meta = cls._meta
+        pk = meta.pk
+
+        if not isinstance(pk, CompositePrimaryKey):
+            return errors
+
+        seen_columns = defaultdict(list)
+
+        for field_name in pk.field_names:
+            hint = None
+
+            try:
+                field = meta.get_field(field_name)
+            except FieldDoesNotExist:
+                field = None
+
+            if not field:
+                hint = f"{field_name!r} is not a valid field."
+            elif not field.column:
+                hint = f"{field_name!r} field has no column."
+            elif field.null:
+                hint = f"{field_name!r} field may not set 'null=True'."
+            elif field.generated:
+                hint = f"{field_name!r} field is a generated field."
+            else:
+                seen_columns[field.column].append(field_name)
+
+            if hint:
+                errors.append(
+                    checks.Error(
+                        f"{field_name!r} cannot be included in the composite primary "
+                        "key.",
+                        hint=hint,
+                        obj=cls,
+                        id="models.E042",
+                    )
+                )
+
+        for column, field_names in seen_columns.items():
+            if len(field_names) > 1:
+                field_name, *rest = field_names
+                duplicates = ", ".join(repr(field) for field in rest)
+                errors.append(
+                    checks.Error(
+                        f"{duplicates} cannot be included in the composite primary "
+                        "key.",
+                        hint=f"{duplicates} and {field_name!r} are the same fields.",
+                        obj=cls,
+                        id="models.E042",
+                    )
+                )
+
+        return errors
+
     @classmethod
     def _check_db_table_comment(cls, databases):
         if not cls._meta.db_table_comment:

+ 2 - 0
django/db/models/fields/__init__.py

@@ -656,6 +656,8 @@ class Field(RegisterLookupMixin):
             path = path.replace("django.db.models.fields.json", "django.db.models")
         elif path.startswith("django.db.models.fields.proxy"):
             path = path.replace("django.db.models.fields.proxy", "django.db.models")
+        elif path.startswith("django.db.models.fields.composite"):
+            path = path.replace("django.db.models.fields.composite", "django.db.models")
         elif path.startswith("django.db.models.fields"):
             path = path.replace("django.db.models.fields", "django.db.models")
         # Return basic info - other fields should override this.

+ 150 - 0
django/db/models/fields/composite.py

@@ -0,0 +1,150 @@
+from django.core import checks
+from django.db.models import NOT_PROVIDED, Field
+from django.db.models.expressions import ColPairs
+from django.db.models.fields.tuple_lookups import (
+    TupleExact,
+    TupleGreaterThan,
+    TupleGreaterThanOrEqual,
+    TupleIn,
+    TupleIsNull,
+    TupleLessThan,
+    TupleLessThanOrEqual,
+)
+from django.utils.functional import cached_property
+
+
+class CompositeAttribute:
+    def __init__(self, field):
+        self.field = field
+
+    @property
+    def attnames(self):
+        return [field.attname for field in self.field.fields]
+
+    def __get__(self, instance, cls=None):
+        return tuple(getattr(instance, attname) for attname in self.attnames)
+
+    def __set__(self, instance, values):
+        attnames = self.attnames
+        length = len(attnames)
+
+        if values is None:
+            values = (None,) * length
+
+        if not isinstance(values, (list, tuple)):
+            raise ValueError(f"{self.field.name!r} must be a list or a tuple.")
+        if length != len(values):
+            raise ValueError(f"{self.field.name!r} must have {length} elements.")
+
+        for attname, value in zip(attnames, values):
+            setattr(instance, attname, value)
+
+
+class CompositePrimaryKey(Field):
+    descriptor_class = CompositeAttribute
+
+    def __init__(self, *args, **kwargs):
+        if (
+            not args
+            or not all(isinstance(field, str) for field in args)
+            or len(set(args)) != len(args)
+        ):
+            raise ValueError("CompositePrimaryKey args must be unique strings.")
+        if len(args) == 1:
+            raise ValueError("CompositePrimaryKey must include at least two fields.")
+        if kwargs.get("default", NOT_PROVIDED) is not NOT_PROVIDED:
+            raise ValueError("CompositePrimaryKey cannot have a default.")
+        if kwargs.get("db_default", NOT_PROVIDED) is not NOT_PROVIDED:
+            raise ValueError("CompositePrimaryKey cannot have a database default.")
+        if kwargs.setdefault("editable", False):
+            raise ValueError("CompositePrimaryKey cannot be editable.")
+        if not kwargs.setdefault("primary_key", True):
+            raise ValueError("CompositePrimaryKey must be a primary key.")
+        if not kwargs.setdefault("blank", True):
+            raise ValueError("CompositePrimaryKey must be blank.")
+
+        self.field_names = args
+        super().__init__(**kwargs)
+
+    def deconstruct(self):
+        # args is always [] so it can be ignored.
+        name, path, _, kwargs = super().deconstruct()
+        return name, path, self.field_names, kwargs
+
+    @cached_property
+    def fields(self):
+        meta = self.model._meta
+        return tuple(meta.get_field(field_name) for field_name in self.field_names)
+
+    @cached_property
+    def columns(self):
+        return tuple(field.column for field in self.fields)
+
+    def contribute_to_class(self, cls, name, private_only=False):
+        super().contribute_to_class(cls, name, private_only=private_only)
+        cls._meta.pk = self
+        setattr(cls, self.attname, self.descriptor_class(self))
+
+    def get_attname_column(self):
+        return self.get_attname(), None
+
+    def __iter__(self):
+        return iter(self.fields)
+
+    def __len__(self):
+        return len(self.field_names)
+
+    @cached_property
+    def cached_col(self):
+        return ColPairs(self.model._meta.db_table, self.fields, self.fields, self)
+
+    def get_col(self, alias, output_field=None):
+        if alias == self.model._meta.db_table and (
+            output_field is None or output_field == self
+        ):
+            return self.cached_col
+
+        return ColPairs(alias, self.fields, self.fields, output_field)
+
+    def get_pk_value_on_save(self, instance):
+        values = []
+
+        for field in self.fields:
+            value = field.value_from_object(instance)
+            if value is None:
+                value = field.get_pk_value_on_save(instance)
+            values.append(value)
+
+        return tuple(values)
+
+    def _check_field_name(self):
+        if self.name == "pk":
+            return []
+        return [
+            checks.Error(
+                "'CompositePrimaryKey' must be named 'pk'.",
+                obj=self,
+                id="fields.E013",
+            )
+        ]
+
+
+CompositePrimaryKey.register_lookup(TupleExact)
+CompositePrimaryKey.register_lookup(TupleGreaterThan)
+CompositePrimaryKey.register_lookup(TupleGreaterThanOrEqual)
+CompositePrimaryKey.register_lookup(TupleLessThan)
+CompositePrimaryKey.register_lookup(TupleLessThanOrEqual)
+CompositePrimaryKey.register_lookup(TupleIn)
+CompositePrimaryKey.register_lookup(TupleIsNull)
+
+
+def unnest(fields):
+    result = []
+
+    for field in fields:
+        if isinstance(field, CompositePrimaryKey):
+            result.extend(field.fields)
+        else:
+            result.append(field)
+
+    return result

+ 15 - 5
django/db/models/fields/related.py

@@ -624,11 +624,21 @@ class ForeignObject(RelatedField):
         if not has_unique_constraint:
             foreign_fields = {f.name for f in self.foreign_related_fields}
             remote_opts = self.remote_field.model._meta
-            has_unique_constraint = any(
-                frozenset(ut) <= foreign_fields for ut in remote_opts.unique_together
-            ) or any(
-                frozenset(uc.fields) <= foreign_fields
-                for uc in remote_opts.total_unique_constraints
+            has_unique_constraint = (
+                any(
+                    frozenset(ut) <= foreign_fields
+                    for ut in remote_opts.unique_together
+                )
+                or any(
+                    frozenset(uc.fields) <= foreign_fields
+                    for uc in remote_opts.total_unique_constraints
+                )
+                # If the model defines a composite primary key and the foreign key
+                # refers to it, the target is unique.
+                or (
+                    frozenset(field.name for field in remote_opts.pk_fields)
+                    == foreign_fields
+                )
             )
 
         if not has_unique_constraint:

+ 4 - 2
django/db/models/fields/related_lookups.py

@@ -1,5 +1,6 @@
 from django.db import NotSupportedError
 from django.db.models.expressions import ColPairs
+from django.db.models.fields import composite
 from django.db.models.fields.tuple_lookups import TupleIn, tuple_lookups
 from django.db.models.lookups import (
     Exact,
@@ -19,7 +20,7 @@ def get_normalized_value(value, lhs):
         if not value._is_pk_set():
             raise ValueError("Model instances passed to related filters must be saved.")
         value_list = []
-        sources = lhs.output_field.path_infos[-1].target_fields
+        sources = composite.unnest(lhs.output_field.path_infos[-1].target_fields)
         for source in sources:
             while not isinstance(value, source.model) and source.remote_field:
                 source = source.remote_field.model._meta.get_field(
@@ -30,7 +31,8 @@ def get_normalized_value(value, lhs):
             except AttributeError:
                 # A case like Restaurant.objects.filter(place=restaurant_instance),
                 # where place is a OneToOneField and the primary key of Restaurant.
-                return (value.pk,)
+                pk = value.pk
+                return pk if isinstance(pk, tuple) else (pk,)
         return tuple(value_list)
     if not isinstance(value, tuple):
         return (value,)

+ 9 - 1
django/db/models/fields/tuple_lookups.py

@@ -250,6 +250,8 @@ class TupleIn(TupleLookupMixin, In):
 
     def check_rhs_select_length_equals_lhs_length(self):
         len_rhs = len(self.rhs.select)
+        if len_rhs == 1 and isinstance(self.rhs.select[0], ColPairs):
+            len_rhs = len(self.rhs.select[0])
         len_lhs = len(self.lhs)
         if len_rhs != len_lhs:
             lhs_str = self.get_lhs_str()
@@ -304,7 +306,13 @@ class TupleIn(TupleLookupMixin, In):
         return root.as_sql(compiler, connection)
 
     def as_subquery(self, compiler, connection):
-        return compiler.compile(In(self.lhs, self.rhs))
+        lhs = self.lhs
+        rhs = self.rhs
+        if isinstance(lhs, ColPairs):
+            rhs = rhs.clone()
+            rhs.set_values([source.name for source in lhs.sources])
+            lhs = Tuple(lhs)
+        return compiler.compile(In(lhs, rhs))
 
 
 tuple_lookups = {

+ 16 - 1
django/db/models/options.py

@@ -7,7 +7,14 @@ from django.conf import settings
 from django.core.exceptions import FieldDoesNotExist, ImproperlyConfigured
 from django.core.signals import setting_changed
 from django.db import connections
-from django.db.models import AutoField, Manager, OrderWrt, UniqueConstraint
+from django.db.models import (
+    AutoField,
+    CompositePrimaryKey,
+    Manager,
+    OrderWrt,
+    UniqueConstraint,
+)
+from django.db.models.fields import composite
 from django.db.models.query_utils import PathInfo
 from django.utils.datastructures import ImmutableList, OrderedSet
 from django.utils.functional import cached_property
@@ -973,6 +980,14 @@ class Options:
             )
         ]
 
+    @cached_property
+    def pk_fields(self):
+        return composite.unnest([self.pk])
+
+    @property
+    def is_composite_pk(self):
+        return isinstance(self.pk, CompositePrimaryKey)
+
     @cached_property
     def _property_names(self):
         """Return a set of the names of the properties defined on the model."""

+ 6 - 3
django/db/models/query.py

@@ -171,11 +171,14 @@ class RawModelIterable(BaseIterable):
                     "Raw query must include the primary key"
                 )
             fields = [self.queryset.model_fields.get(c) for c in self.queryset.columns]
-            converters = compiler.get_converters(
-                [f.get_col(f.model._meta.db_table) if f else None for f in fields]
-            )
+            cols = [f.get_col(f.model._meta.db_table) if f else None for f in fields]
+            converters = compiler.get_converters(cols)
             if converters:
                 query_iterator = compiler.apply_converters(query_iterator, converters)
+            if compiler.has_composite_fields(cols):
+                query_iterator = compiler.composite_fields_to_tuples(
+                    query_iterator, cols
+                )
             for values in query_iterator:
                 # Associate fields to values
                 model_init_values = [values[pos] for pos in model_init_pos]

+ 65 - 10
django/db/models/sql/compiler.py

@@ -7,7 +7,9 @@ from itertools import chain
 from django.core.exceptions import EmptyResultSet, FieldError, FullResultSet
 from django.db import DatabaseError, NotSupportedError
 from django.db.models.constants import LOOKUP_SEP
-from django.db.models.expressions import F, OrderBy, RawSQL, Ref, Value
+from django.db.models.expressions import ColPairs, F, OrderBy, RawSQL, Ref, Value
+from django.db.models.fields import composite
+from django.db.models.fields.composite import CompositePrimaryKey
 from django.db.models.functions import Cast, Random
 from django.db.models.lookups import Lookup
 from django.db.models.query_utils import select_related_descend
@@ -283,6 +285,9 @@ class SQLCompiler:
                 # Reference to a column.
                 elif isinstance(expression, int):
                     expression = cols[expression]
+                # ColPairs cannot be aliased.
+                if isinstance(expression, ColPairs):
+                    alias = None
                 selected.append((alias, expression))
 
         for select_idx, (alias, expression) in enumerate(selected):
@@ -997,6 +1002,7 @@ class SQLCompiler:
         # alias for a given field. This also includes None -> start_alias to
         # be used by local fields.
         seen_models = {None: start_alias}
+        select_mask_fields = set(composite.unnest(select_mask))
 
         for field in opts.concrete_fields:
             model = field.model._meta.concrete_model
@@ -1017,7 +1023,7 @@ class SQLCompiler:
                 # parent model data is already present in the SELECT clause,
                 # and we want to avoid reloading the same data again.
                 continue
-            if select_mask and field not in select_mask:
+            if select_mask and field not in select_mask_fields:
                 continue
             alias = self.query.join_parent_model(opts, model, start_alias, seen_models)
             column = field.get_col(alias)
@@ -1110,9 +1116,10 @@ class SQLCompiler:
                 )
             return results
         targets, alias, _ = self.query.trim_joins(targets, joins, path)
+        target_fields = composite.unnest(targets)
         return [
             (OrderBy(transform_function(t, alias), descending=descending), False)
-            for t in targets
+            for t in target_fields
         ]
 
     def _setup_joins(self, pieces, opts, alias):
@@ -1504,13 +1511,25 @@ class SQLCompiler:
         return result
 
     def get_converters(self, expressions):
+        i = 0
         converters = {}
-        for i, expression in enumerate(expressions):
-            if expression:
+
+        for expression in expressions:
+            if isinstance(expression, ColPairs):
+                cols = expression.get_source_expressions()
+                cols_converters = self.get_converters(cols)
+                for j, (convs, col) in cols_converters.items():
+                    converters[i + j] = (convs, col)
+                i += len(expression)
+            elif expression:
                 backend_converters = self.connection.ops.get_db_converters(expression)
                 field_converters = expression.get_db_converters(self.connection)
                 if backend_converters or field_converters:
                     converters[i] = (backend_converters + field_converters, expression)
+                i += 1
+            else:
+                i += 1
+
         return converters
 
     def apply_converters(self, rows, converters):
@@ -1524,6 +1543,24 @@ class SQLCompiler:
                 row[pos] = value
             yield row
 
+    def has_composite_fields(self, expressions):
+        # Check for composite fields before calling the relatively costly
+        # composite_fields_to_tuples.
+        return any(isinstance(expression, ColPairs) for expression in expressions)
+
+    def composite_fields_to_tuples(self, rows, expressions):
+        col_pair_slices = [
+            slice(i, i + len(expression))
+            for i, expression in enumerate(expressions)
+            if isinstance(expression, ColPairs)
+        ]
+
+        for row in map(list, rows):
+            for pos in col_pair_slices:
+                row[pos] = (tuple(row[pos]),)
+
+            yield row
+
     def results_iter(
         self,
         results=None,
@@ -1541,8 +1578,10 @@ class SQLCompiler:
         rows = chain.from_iterable(results)
         if converters:
             rows = self.apply_converters(rows, converters)
-            if tuple_expected:
-                rows = map(tuple, rows)
+        if self.has_composite_fields(fields):
+            rows = self.composite_fields_to_tuples(rows, fields)
+        if tuple_expected:
+            rows = map(tuple, rows)
         return rows
 
     def has_results(self):
@@ -1863,6 +1902,18 @@ class SQLInsertCompiler(SQLCompiler):
                     )
                 ]
                 cols = [field.get_col(opts.db_table) for field in self.returning_fields]
+            elif isinstance(opts.pk, CompositePrimaryKey):
+                returning_field = returning_fields[0]
+                cols = [returning_field.get_col(opts.db_table)]
+                rows = [
+                    (
+                        self.connection.ops.last_insert_id(
+                            cursor,
+                            opts.db_table,
+                            returning_field.column,
+                        ),
+                    )
+                ]
             else:
                 cols = [opts.pk.get_col(opts.db_table)]
                 rows = [
@@ -1876,8 +1927,10 @@ class SQLInsertCompiler(SQLCompiler):
                 ]
         converters = self.get_converters(cols)
         if converters:
-            rows = list(self.apply_converters(rows, converters))
-        return rows
+            rows = self.apply_converters(rows, converters)
+        if self.has_composite_fields(cols):
+            rows = self.composite_fields_to_tuples(rows, cols)
+        return list(rows)
 
 
 class SQLDeleteCompiler(SQLCompiler):
@@ -2065,6 +2118,7 @@ class SQLUpdateCompiler(SQLCompiler):
         query.add_fields(fields)
         super().pre_sql_setup()
 
+        is_composite_pk = meta.is_composite_pk
         must_pre_select = (
             count > 1 and not self.connection.features.update_can_self_select
         )
@@ -2079,7 +2133,8 @@ class SQLUpdateCompiler(SQLCompiler):
             idents = []
             related_ids = collections.defaultdict(list)
             for rows in query.get_compiler(self.using).execute_sql(MULTI):
-                idents.extend(r[0] for r in rows)
+                pks = [row if is_composite_pk else row[0] for row in rows]
+                idents.extend(pks)
                 for parent, index in related_ids_index:
                     related_ids[parent].extend(r[index] for r in rows)
             self.query.add_filter("pk__in", idents)

+ 6 - 2
django/db/models/sql/query.py

@@ -627,8 +627,12 @@ class Query(BaseExpression):
         if result is None:
             result = empty_set_result
         else:
-            converters = compiler.get_converters(outer_query.annotation_select.values())
-            result = next(compiler.apply_converters((result,), converters))
+            cols = outer_query.annotation_select.values()
+            converters = compiler.get_converters(cols)
+            rows = compiler.apply_converters((result,), converters)
+            if compiler.has_composite_fields(cols):
+                rows = compiler.composite_fields_to_tuples(rows, cols)
+            result = next(rows)
 
         return dict(zip(outer_query.annotation_select, result))
 

+ 3 - 0
docs/ref/checks.txt

@@ -181,6 +181,7 @@ Model fields
 * **fields.E011**: ``<database>`` does not support default database values with
   expressions (``db_default``).
 * **fields.E012**: ``<expression>`` cannot be used in ``db_default``.
+* **fields.E013**: ``CompositePrimaryKey`` must be named ``pk``.
 * **fields.E100**: ``AutoField``\s must set primary_key=True.
 * **fields.E110**: ``BooleanField``\s do not accept null values. *This check
   appeared before support for null values was added in Django 2.1.*
@@ -417,6 +418,8 @@ Models
 * **models.W040**: ``<database>`` does not support indexes with non-key
   columns.
 * **models.E041**: ``constraints`` refers to the joined field ``<field name>``.
+* **models.E042**: ``<field name>`` cannot be included in the composite
+  primary key.
 * **models.W042**: Auto-created primary key used when not defining a primary
   key type, by default ``django.db.models.AutoField``.
 * **models.W043**: ``<database>`` does not support indexes on expressions.

+ 19 - 0
docs/ref/models/fields.txt

@@ -707,6 +707,23 @@ or :class:`~django.forms.NullBooleanSelect` if :attr:`null=True <Field.null>`.
 The default value of ``BooleanField`` is ``None`` when :attr:`Field.default`
 isn't defined.
 
+``CompositePrimaryKey``
+-----------------------
+
+.. versionadded:: 5.2
+
+.. class:: CompositePrimaryKey(*field_names, **options)
+
+A virtual field used for defining a composite primary key.
+
+This field must be defined as the model's ``pk`` field. If present, Django will
+create the underlying model table with a composite primary key.
+
+The ``*field_names`` argument is a list of positional field names that compose
+the primary key.
+
+See :doc:`/topics/composite-primary-key` for more details.
+
 ``CharField``
 -------------
 
@@ -1615,6 +1632,8 @@ not an instance of ``UUID``.
     hyphens, because PostgreSQL and MariaDB 10.7+ store them in a hyphenated
     uuid datatype type.
 
+.. _relationship-fields:
+
 Relationship fields
 ===================
 

+ 19 - 0
docs/releases/5.2.txt

@@ -31,6 +31,25 @@ and only officially support the latest release of each series.
 What's new in Django 5.2
 ========================
 
+Composite Primary Keys
+----------------------
+
+The new :class:`django.db.models.CompositePrimaryKey` allows tables to be
+created with a primary key consisting of multiple fields.
+
+To use a composite primary key, when creating a model set the ``pk`` field to
+be a ``CompositePrimaryKey``::
+
+    from django.db import models
+
+
+    class Release(models.Model):
+        pk = models.CompositePrimaryKey("version", "name")
+        version = models.IntegerField()
+        name = models.CharField(max_length=20)
+
+See :doc:`/topics/composite-primary-key` for more details.
+
 Minor features
 --------------
 

+ 183 - 0
docs/topics/composite-primary-key.txt

@@ -0,0 +1,183 @@
+======================
+Composite primary keys
+======================
+
+.. versionadded:: 5.2
+
+In Django, each model has a primary key. By default, this primary key consists
+of a single field.
+
+In most cases, a single primary key should suffice. In database design,
+however, defining a primary key consisting of multiple fields is sometimes
+necessary.
+
+To use a composite primary key, when creating a model set the ``pk`` field to
+be a :class:`.CompositePrimaryKey`::
+
+    class Product(models.Model):
+        name = models.CharField(max_length=100)
+
+
+    class Order(models.Model):
+        reference = models.CharField(max_length=20, primary_key=True)
+
+
+    class OrderLineItem(models.Model):
+        pk = models.CompositePrimaryKey("product_id", "order_id")
+        product = models.ForeignKey(Product, on_delete=models.CASCADE)
+        order = models.ForeignKey(Order, on_delete=models.CASCADE)
+        quantity = models.IntegerField()
+
+This will instruct Django to create a composite primary key
+(``PRIMARY KEY (product_id, order_id)``) when creating the table.
+
+A composite primary key is represented by a ``tuple``:
+
+.. code-block:: pycon
+
+    >>> product = Product.objects.create(name="apple")
+    >>> order = Order.objects.create(reference="A755H")
+    >>> item = OrderLineItem.objects.create(product=product, order=order, quantity=1)
+    >>> item.pk
+    (1, "A755H")
+
+You can assign a ``tuple`` to a composite primary key. This sets the associated
+field values.
+
+.. code-block:: pycon
+
+    >>> item = OrderLineItem(pk=(2, "B142C"))
+    >>> item.pk
+    (2, "B142C")
+    >>> item.product_id
+    2
+    >>> item.order_id
+    "B142C"
+
+A composite primary key can also be filtered by a ``tuple``:
+
+.. code-block:: pycon
+
+    >>> OrderLineItem.objects.filter(pk=(1, "A755H")).count()
+    1
+
+We're still working on composite primary key support for
+:ref:`relational fields <cpk-and-relations>`, including
+:class:`.GenericForeignKey` fields, and the Django admin. Models with composite
+primary keys cannot be registered in the Django admin at this time. You can
+expect to see this in future releases.
+
+Migrating to a composite primary key
+====================================
+
+Django doesn't support migrating to, or from, a composite primary key after the
+table is created. It also doesn't support adding or removing fields from the
+composite primary key.
+
+If you would like to migrate an existing table from a single primary key to a
+composite primary key, follow your database backend's instructions to do so.
+
+Once the composite primary key is in place, add the ``CompositePrimaryKey``
+field to your model. This allows Django to recognize and handle the composite
+primary key appropriately.
+
+While migration operations (e.g. ``AddField``, ``AlterField``) on primary key
+fields are not supported, ``makemigrations`` will still detect changes.
+
+In order to avoid errors, it's recommended to apply such migrations with
+``--fake``.
+
+Alternatively, :class:`.SeparateDatabaseAndState` may be used to execute the
+backend-specific migrations and Django-generated migrations in a single
+operation.
+
+.. _cpk-and-relations:
+
+Composite primary keys and relations
+====================================
+
+:ref:`Relationship fields <relationship-fields>`, including
+:ref:`generic relations <generic-relations>` do not support composite primary
+keys.
+
+For example, given the ``OrderLineItem`` model, the following is not
+supported::
+
+    class Foo(models.Model):
+        item = models.ForeignKey(OrderLineItem, on_delete=models.CASCADE)
+
+Because ``ForeignKey`` currently cannot reference models with composite primary
+keys.
+
+To work around this limitation, ``ForeignObject`` can be used as an
+alternative::
+
+    class Foo(models.Model):
+        item_order_id = models.IntegerField()
+        item_product_id = models.CharField(max_length=20)
+        item = models.ForeignObject(
+            OrderLineItem,
+            on_delete=models.CASCADE,
+            from_fields=("item_order_id", "item_product_id"),
+            to_fields=("order_id", "product_id"),
+        )
+
+``ForeignObject`` is much like ``ForeignKey``, except that it doesn't create
+any columns (e.g. ``item_id``), foreign key constraints or indexes in the
+database.
+
+.. warning::
+
+    ``ForeignObject`` is an internal API. This means it is not covered by our
+    :ref:`deprecation policy <internal-release-deprecation-policy>`.
+
+Composite primary keys and database functions
+=============================================
+
+Many database functions only accept a single expression.
+
+.. code-block:: sql
+
+    MAX("order_id")  -- OK
+    MAX("product_id", "order_id")  -- ERROR
+
+As a consequence, they cannot be used with composite primary key references as
+they are composed of multiple column expressions.
+
+.. code-block:: python
+
+    Max("order_id")  # OK
+    Max("pk")  # ERROR
+
+Composite primary keys in forms
+===============================
+
+As a composite primary key is a virtual field, a field which doesn't represent
+a single database column, this field is excluded from ModelForms.
+
+For example, take the following form::
+
+    class OrderLineItemForm(forms.ModelForm):
+        class Meta:
+            model = OrderLineItem
+            fields = "__all__"
+
+This form does not have a form field ``pk`` for the composite primary key:
+
+.. code-block:: pycon
+
+    >>> OrderLineItemForm()
+    <OrderLineItemForm bound=False, valid=Unknown, fields=(product;order;quantity)>
+
+Setting the primary composite field ``pk`` as a form field raises an unknown
+field :exc:`.FieldError`.
+
+.. admonition:: Primary key fields are read only
+
+    If you change the value of a primary key on an existing object and then
+    save it, a new object will be created alongside the old one (see
+    :attr:`.Field.primary_key`).
+
+    This is also true of composite primary keys. Hence, you may want to set
+    :attr:`.Field.editable` to ``False`` on all primary key fields to exclude
+    them from ModelForms.

+ 1 - 0
docs/topics/index.txt

@@ -19,6 +19,7 @@ Introductions to all the key parts of Django you'll need to know:
    auth/index
    cache
    conditional-view-processing
+   composite-primary-key
    signing
    email
    i18n/index

+ 6 - 0
tests/admin_registration/models.py

@@ -20,3 +20,9 @@ class Location(models.Model):
 
 class Place(Location):
     name = models.CharField(max_length=200)
+
+
+class Guest(models.Model):
+    pk = models.CompositePrimaryKey("traveler", "place")
+    traveler = models.ForeignKey(Traveler, on_delete=models.CASCADE)
+    place = models.ForeignKey(Place, on_delete=models.CASCADE)

+ 9 - 1
tests/admin_registration/tests.py

@@ -5,7 +5,7 @@ from django.contrib.admin.sites import site
 from django.core.exceptions import ImproperlyConfigured
 from django.test import SimpleTestCase
 
-from .models import Location, Person, Place, Traveler
+from .models import Guest, Location, Person, Place, Traveler
 
 
 class NameAdmin(admin.ModelAdmin):
@@ -92,6 +92,14 @@ class TestRegistration(SimpleTestCase):
         with self.assertRaisesMessage(ImproperlyConfigured, msg):
             self.site.register(Location)
 
+    def test_composite_pk_model(self):
+        msg = (
+            "The model Guest has a composite primary key, so it cannot be registered "
+            "with admin."
+        )
+        with self.assertRaisesMessage(ImproperlyConfigured, msg):
+            self.site.register(Guest)
+
     def test_is_registered_model(self):
         "Checks for registered models should return true."
         self.site.register(Person)

+ 0 - 0
tests/composite_pk/__init__.py


+ 75 - 0
tests/composite_pk/fixtures/tenant.json

@@ -0,0 +1,75 @@
+[
+    {
+        "pk": 1,
+        "model": "composite_pk.tenant",
+        "fields": {
+            "id": 1,
+            "name": "Tenant 1"
+        }
+    },
+    {
+        "pk": 2,
+        "model": "composite_pk.tenant",
+        "fields": {
+            "id": 2,
+            "name": "Tenant 2"
+        }
+    },
+    {
+        "pk": 3,
+        "model": "composite_pk.tenant",
+        "fields": {
+            "id": 3,
+            "name": "Tenant 3"
+        }
+    },
+    {
+        "pk": [1, 1],
+        "model": "composite_pk.user",
+        "fields": {
+            "tenant_id": 1,
+            "id": 1,
+            "email": "user0001@example.com"
+        }
+    },
+    {
+        "pk": [1, 2],
+        "model": "composite_pk.user",
+        "fields": {
+            "tenant_id": 1,
+            "id": 2,
+            "email": "user0002@example.com"
+        }
+    },
+    {
+        "pk": [2, 3],
+        "model": "composite_pk.user",
+        "fields": {
+            "email": "user0003@example.com"
+        }
+    },
+    {
+        "model": "composite_pk.user",
+        "fields": {
+            "tenant_id": 2,
+            "id": 4,
+            "email": "user0004@example.com"
+        }
+    },
+    {
+        "pk": [2, "11111111-1111-1111-1111-111111111111"],
+        "model": "composite_pk.post",
+        "fields": {
+            "tenant_id": 2,
+            "id": "11111111-1111-1111-1111-111111111111"
+        }
+    },
+    {
+        "pk": [2, "ffffffff-ffff-ffff-ffff-ffffffffffff"],
+        "model": "composite_pk.post",
+        "fields": {
+            "tenant_id": 2,
+            "id": "ffffffff-ffff-ffff-ffff-ffffffffffff"
+        }
+    }
+]

+ 9 - 0
tests/composite_pk/models/__init__.py

@@ -0,0 +1,9 @@
+from .tenant import Comment, Post, Tenant, Token, User
+
+__all__ = [
+    "Comment",
+    "Post",
+    "Tenant",
+    "Token",
+    "User",
+]

+ 50 - 0
tests/composite_pk/models/tenant.py

@@ -0,0 +1,50 @@
+from django.db import models
+
+
+class Tenant(models.Model):
+    name = models.CharField(max_length=10, default="", blank=True)
+
+
+class Token(models.Model):
+    pk = models.CompositePrimaryKey("tenant_id", "id")
+    tenant = models.ForeignKey(Tenant, on_delete=models.CASCADE, related_name="tokens")
+    id = models.SmallIntegerField()
+    secret = models.CharField(max_length=10, default="", blank=True)
+
+
+class BaseModel(models.Model):
+    pk = models.CompositePrimaryKey("tenant_id", "id")
+    tenant = models.ForeignKey(Tenant, on_delete=models.CASCADE)
+    id = models.SmallIntegerField(unique=True)
+
+    class Meta:
+        abstract = True
+
+
+class User(BaseModel):
+    email = models.EmailField(unique=True)
+
+
+class Comment(models.Model):
+    pk = models.CompositePrimaryKey("tenant", "id")
+    tenant = models.ForeignKey(
+        Tenant,
+        on_delete=models.CASCADE,
+        related_name="comments",
+    )
+    id = models.SmallIntegerField(unique=True, db_column="comment_id")
+    user_id = models.SmallIntegerField()
+    user = models.ForeignObject(
+        User,
+        on_delete=models.CASCADE,
+        from_fields=("tenant_id", "user_id"),
+        to_fields=("tenant_id", "id"),
+        related_name="comments",
+    )
+    text = models.TextField(default="", blank=True)
+
+
+class Post(models.Model):
+    pk = models.CompositePrimaryKey("tenant_id", "id")
+    tenant = models.ForeignKey(Tenant, on_delete=models.CASCADE)
+    id = models.UUIDField()

+ 139 - 0
tests/composite_pk/test_aggregate.py

@@ -0,0 +1,139 @@
+from django.db import NotSupportedError
+from django.db.models import Count, Q
+from django.test import TestCase
+
+from .models import Comment, Tenant, User
+
+
+class CompositePKAggregateTests(TestCase):
+    @classmethod
+    def setUpTestData(cls):
+        cls.tenant_1 = Tenant.objects.create()
+        cls.tenant_2 = Tenant.objects.create()
+        cls.user_1 = User.objects.create(
+            tenant=cls.tenant_1,
+            id=1,
+            email="user0001@example.com",
+        )
+        cls.user_2 = User.objects.create(
+            tenant=cls.tenant_1,
+            id=2,
+            email="user0002@example.com",
+        )
+        cls.user_3 = User.objects.create(
+            tenant=cls.tenant_2,
+            id=3,
+            email="user0003@example.com",
+        )
+        cls.comment_1 = Comment.objects.create(id=1, user=cls.user_2, text="foo")
+        cls.comment_2 = Comment.objects.create(id=2, user=cls.user_1, text="bar")
+        cls.comment_3 = Comment.objects.create(id=3, user=cls.user_1, text="foobar")
+        cls.comment_4 = Comment.objects.create(id=4, user=cls.user_3, text="foobarbaz")
+        cls.comment_5 = Comment.objects.create(id=5, user=cls.user_3, text="barbaz")
+        cls.comment_6 = Comment.objects.create(id=6, user=cls.user_3, text="baz")
+
+    def test_users_annotated_with_comments_id_count(self):
+        user_1, user_2, user_3 = User.objects.annotate(Count("comments__id")).order_by(
+            "pk"
+        )
+
+        self.assertEqual(user_1, self.user_1)
+        self.assertEqual(user_1.comments__id__count, 2)
+        self.assertEqual(user_2, self.user_2)
+        self.assertEqual(user_2.comments__id__count, 1)
+        self.assertEqual(user_3, self.user_3)
+        self.assertEqual(user_3.comments__id__count, 3)
+
+    def test_users_annotated_with_aliased_comments_id_count(self):
+        user_1, user_2, user_3 = User.objects.annotate(
+            comments_count=Count("comments__id")
+        ).order_by("pk")
+
+        self.assertEqual(user_1, self.user_1)
+        self.assertEqual(user_1.comments_count, 2)
+        self.assertEqual(user_2, self.user_2)
+        self.assertEqual(user_2.comments_count, 1)
+        self.assertEqual(user_3, self.user_3)
+        self.assertEqual(user_3.comments_count, 3)
+
+    def test_users_annotated_with_comments_count(self):
+        user_1, user_2, user_3 = User.objects.annotate(Count("comments")).order_by("pk")
+
+        self.assertEqual(user_1, self.user_1)
+        self.assertEqual(user_1.comments__count, 2)
+        self.assertEqual(user_2, self.user_2)
+        self.assertEqual(user_2.comments__count, 1)
+        self.assertEqual(user_3, self.user_3)
+        self.assertEqual(user_3.comments__count, 3)
+
+    def test_users_annotated_with_comments_count_filter(self):
+        user_1, user_2, user_3 = User.objects.annotate(
+            comments__count=Count(
+                "comments", filter=Q(pk__in=[self.user_1.pk, self.user_2.pk])
+            )
+        ).order_by("pk")
+
+        self.assertEqual(user_1, self.user_1)
+        self.assertEqual(user_1.comments__count, 2)
+        self.assertEqual(user_2, self.user_2)
+        self.assertEqual(user_2.comments__count, 1)
+        self.assertEqual(user_3, self.user_3)
+        self.assertEqual(user_3.comments__count, 0)
+
+    def test_count_distinct_not_supported(self):
+        with self.assertRaisesMessage(
+            NotSupportedError, "COUNT(DISTINCT) doesn't support composite primary keys"
+        ):
+            self.assertIsNone(
+                User.objects.annotate(comments__count=Count("comments", distinct=True))
+            )
+
+    def test_user_values_annotated_with_comments_id_count(self):
+        self.assertSequenceEqual(
+            User.objects.values("pk").annotate(Count("comments__id")).order_by("pk"),
+            (
+                {"pk": self.user_1.pk, "comments__id__count": 2},
+                {"pk": self.user_2.pk, "comments__id__count": 1},
+                {"pk": self.user_3.pk, "comments__id__count": 3},
+            ),
+        )
+
+    def test_user_values_annotated_with_filtered_comments_id_count(self):
+        self.assertSequenceEqual(
+            User.objects.values("pk")
+            .annotate(
+                comments_count=Count(
+                    "comments__id",
+                    filter=Q(comments__text__icontains="foo"),
+                )
+            )
+            .order_by("pk"),
+            (
+                {"pk": self.user_1.pk, "comments_count": 1},
+                {"pk": self.user_2.pk, "comments_count": 1},
+                {"pk": self.user_3.pk, "comments_count": 1},
+            ),
+        )
+
+    def test_filter_and_count_users_by_comments_fields(self):
+        users = User.objects.filter(comments__id__gt=2).order_by("pk")
+        self.assertEqual(users.count(), 4)
+        self.assertSequenceEqual(
+            users, (self.user_1, self.user_3, self.user_3, self.user_3)
+        )
+
+        users = User.objects.filter(comments__text__icontains="foo").order_by("pk")
+        self.assertEqual(users.count(), 3)
+        self.assertSequenceEqual(users, (self.user_1, self.user_2, self.user_3))
+
+        users = User.objects.filter(comments__text__icontains="baz").order_by("pk")
+        self.assertEqual(users.count(), 3)
+        self.assertSequenceEqual(users, (self.user_3, self.user_3, self.user_3))
+
+    def test_order_by_comments_id_count(self):
+        self.assertSequenceEqual(
+            User.objects.annotate(comments_count=Count("comments__id")).order_by(
+                "-comments_count"
+            ),
+            (self.user_3, self.user_1, self.user_2),
+        )

+ 242 - 0
tests/composite_pk/test_checks.py

@@ -0,0 +1,242 @@
+from django.core import checks
+from django.db import connection, models
+from django.db.models import F
+from django.test import TestCase
+from django.test.utils import isolate_apps
+
+
+@isolate_apps("composite_pk")
+class CompositePKChecksTests(TestCase):
+    maxDiff = None
+
+    def test_composite_pk_must_be_unique_strings(self):
+        test_cases = (
+            (),
+            (0,),
+            (1,),
+            ("id", False),
+            ("id", "id"),
+            (("id",),),
+        )
+
+        for i, args in enumerate(test_cases):
+            with (
+                self.subTest(args=args),
+                self.assertRaisesMessage(
+                    ValueError, "CompositePrimaryKey args must be unique strings."
+                ),
+            ):
+                models.CompositePrimaryKey(*args)
+
+    def test_composite_pk_must_include_at_least_2_fields(self):
+        expected_message = "CompositePrimaryKey must include at least two fields."
+        with self.assertRaisesMessage(ValueError, expected_message):
+            models.CompositePrimaryKey("id")
+
+    def test_composite_pk_cannot_have_a_default(self):
+        expected_message = "CompositePrimaryKey cannot have a default."
+        with self.assertRaisesMessage(ValueError, expected_message):
+            models.CompositePrimaryKey("tenant_id", "id", default=(1, 1))
+
+    def test_composite_pk_cannot_have_a_database_default(self):
+        expected_message = "CompositePrimaryKey cannot have a database default."
+        with self.assertRaisesMessage(ValueError, expected_message):
+            models.CompositePrimaryKey("tenant_id", "id", db_default=models.F("id"))
+
+    def test_composite_pk_cannot_be_editable(self):
+        expected_message = "CompositePrimaryKey cannot be editable."
+        with self.assertRaisesMessage(ValueError, expected_message):
+            models.CompositePrimaryKey("tenant_id", "id", editable=True)
+
+    def test_composite_pk_must_be_a_primary_key(self):
+        expected_message = "CompositePrimaryKey must be a primary key."
+        with self.assertRaisesMessage(ValueError, expected_message):
+            models.CompositePrimaryKey("tenant_id", "id", primary_key=False)
+
+    def test_composite_pk_must_be_blank(self):
+        expected_message = "CompositePrimaryKey must be blank."
+        with self.assertRaisesMessage(ValueError, expected_message):
+            models.CompositePrimaryKey("tenant_id", "id", blank=False)
+
+    def test_composite_pk_must_not_have_other_pk_field(self):
+        class Foo(models.Model):
+            pk = models.CompositePrimaryKey("foo_id", "id")
+            foo_id = models.IntegerField()
+            id = models.IntegerField(primary_key=True)
+
+        self.assertEqual(
+            Foo.check(databases=self.databases),
+            [
+                checks.Error(
+                    "The model cannot have more than one field with "
+                    "'primary_key=True'.",
+                    obj=Foo,
+                    id="models.E026",
+                ),
+            ],
+        )
+
+    def test_composite_pk_cannot_include_nullable_field(self):
+        class Foo(models.Model):
+            pk = models.CompositePrimaryKey("foo_id", "id")
+            foo_id = models.IntegerField()
+            id = models.IntegerField(null=True)
+
+        self.assertEqual(
+            Foo.check(databases=self.databases),
+            [
+                checks.Error(
+                    "'id' cannot be included in the composite primary key.",
+                    hint="'id' field may not set 'null=True'.",
+                    obj=Foo,
+                    id="models.E042",
+                ),
+            ],
+        )
+
+    def test_composite_pk_can_include_fk_name(self):
+        class Foo(models.Model):
+            pass
+
+        class Bar(models.Model):
+            pk = models.CompositePrimaryKey("foo", "id")
+            foo = models.ForeignKey(Foo, on_delete=models.CASCADE)
+            id = models.SmallIntegerField()
+
+        self.assertEqual(Foo.check(databases=self.databases), [])
+        self.assertEqual(Bar.check(databases=self.databases), [])
+
+    def test_composite_pk_cannot_include_same_field(self):
+        class Foo(models.Model):
+            pass
+
+        class Bar(models.Model):
+            pk = models.CompositePrimaryKey("foo", "foo_id")
+            foo = models.ForeignKey(Foo, on_delete=models.CASCADE)
+            id = models.SmallIntegerField()
+
+        self.assertEqual(Foo.check(databases=self.databases), [])
+        self.assertEqual(
+            Bar.check(databases=self.databases),
+            [
+                checks.Error(
+                    "'foo_id' cannot be included in the composite primary key.",
+                    hint="'foo_id' and 'foo' are the same fields.",
+                    obj=Bar,
+                    id="models.E042",
+                ),
+            ],
+        )
+
+    def test_composite_pk_cannot_include_composite_pk_field(self):
+        class Foo(models.Model):
+            pk = models.CompositePrimaryKey("id", "pk")
+            id = models.SmallIntegerField()
+
+        self.assertEqual(
+            Foo.check(databases=self.databases),
+            [
+                checks.Error(
+                    "'pk' cannot be included in the composite primary key.",
+                    hint="'pk' field has no column.",
+                    obj=Foo,
+                    id="models.E042",
+                ),
+            ],
+        )
+
+    def test_composite_pk_cannot_include_db_column(self):
+        class Foo(models.Model):
+            pk = models.CompositePrimaryKey("foo", "bar")
+            foo = models.SmallIntegerField(db_column="foo_id")
+            bar = models.SmallIntegerField(db_column="bar_id")
+
+        class Bar(models.Model):
+            pk = models.CompositePrimaryKey("foo_id", "bar_id")
+            foo = models.SmallIntegerField(db_column="foo_id")
+            bar = models.SmallIntegerField(db_column="bar_id")
+
+        self.assertEqual(Foo.check(databases=self.databases), [])
+        self.assertEqual(
+            Bar.check(databases=self.databases),
+            [
+                checks.Error(
+                    "'foo_id' cannot be included in the composite primary key.",
+                    hint="'foo_id' is not a valid field.",
+                    obj=Bar,
+                    id="models.E042",
+                ),
+                checks.Error(
+                    "'bar_id' cannot be included in the composite primary key.",
+                    hint="'bar_id' is not a valid field.",
+                    obj=Bar,
+                    id="models.E042",
+                ),
+            ],
+        )
+
+    def test_foreign_object_can_refer_composite_pk(self):
+        class Foo(models.Model):
+            pass
+
+        class Bar(models.Model):
+            pk = models.CompositePrimaryKey("foo_id", "id")
+            foo = models.ForeignKey(Foo, on_delete=models.CASCADE)
+            id = models.IntegerField()
+
+        class Baz(models.Model):
+            pk = models.CompositePrimaryKey("foo_id", "id")
+            foo = models.ForeignKey(Foo, on_delete=models.CASCADE)
+            id = models.IntegerField()
+            bar_id = models.IntegerField()
+            bar = models.ForeignObject(
+                Bar,
+                on_delete=models.CASCADE,
+                from_fields=("foo_id", "bar_id"),
+                to_fields=("foo_id", "id"),
+            )
+
+        self.assertEqual(Foo.check(databases=self.databases), [])
+        self.assertEqual(Bar.check(databases=self.databases), [])
+        self.assertEqual(Baz.check(databases=self.databases), [])
+
+    def test_composite_pk_must_be_named_pk(self):
+        class Foo(models.Model):
+            primary_key = models.CompositePrimaryKey("foo_id", "id")
+            foo_id = models.IntegerField()
+            id = models.IntegerField()
+
+        self.assertEqual(
+            Foo.check(databases=self.databases),
+            [
+                checks.Error(
+                    "'CompositePrimaryKey' must be named 'pk'.",
+                    obj=Foo._meta.get_field("primary_key"),
+                    id="fields.E013",
+                ),
+            ],
+        )
+
+    def test_composite_pk_cannot_include_generated_field(self):
+        is_oracle = connection.vendor == "oracle"
+
+        class Foo(models.Model):
+            pk = models.CompositePrimaryKey("id", "foo")
+            id = models.IntegerField()
+            foo = models.GeneratedField(
+                expression=F("id"),
+                output_field=models.IntegerField(),
+                db_persist=not is_oracle,
+            )
+
+        self.assertEqual(
+            Foo.check(databases=self.databases),
+            [
+                checks.Error(
+                    "'foo' cannot be included in the composite primary key.",
+                    hint="'foo' field is a generated field.",
+                    obj=Foo,
+                    id="models.E042",
+                ),
+            ],
+        )

+ 138 - 0
tests/composite_pk/test_create.py

@@ -0,0 +1,138 @@
+from django.test import TestCase
+
+from .models import Tenant, User
+
+
+class CompositePKCreateTests(TestCase):
+    maxDiff = None
+
+    @classmethod
+    def setUpTestData(cls):
+        cls.tenant = Tenant.objects.create()
+        cls.user = User.objects.create(
+            tenant=cls.tenant,
+            id=1,
+            email="user0001@example.com",
+        )
+
+    def test_create_user(self):
+        test_cases = (
+            {"tenant": self.tenant, "id": 2412, "email": "user2412@example.com"},
+            {"tenant_id": self.tenant.id, "id": 5316, "email": "user5316@example.com"},
+            {"pk": (self.tenant.id, 7424), "email": "user7424@example.com"},
+        )
+
+        for fields in test_cases:
+            with self.subTest(fields=fields):
+                count = User.objects.count()
+                user = User(**fields)
+                obj = User.objects.create(**fields)
+                self.assertEqual(obj.tenant_id, self.tenant.id)
+                self.assertEqual(obj.id, user.id)
+                self.assertEqual(obj.pk, (self.tenant.id, user.id))
+                self.assertEqual(obj.email, user.email)
+                self.assertEqual(count + 1, User.objects.count())
+
+    def test_save_user(self):
+        test_cases = (
+            {"tenant": self.tenant, "id": 9241, "email": "user9241@example.com"},
+            {"tenant_id": self.tenant.id, "id": 5132, "email": "user5132@example.com"},
+            {"pk": (self.tenant.id, 3014), "email": "user3014@example.com"},
+        )
+
+        for fields in test_cases:
+            with self.subTest(fields=fields):
+                count = User.objects.count()
+                user = User(**fields)
+                self.assertIsNotNone(user.id)
+                self.assertIsNotNone(user.email)
+                user.save()
+                self.assertEqual(user.tenant_id, self.tenant.id)
+                self.assertEqual(user.tenant, self.tenant)
+                self.assertIsNotNone(user.id)
+                self.assertEqual(user.pk, (self.tenant.id, user.id))
+                self.assertEqual(user.email, fields["email"])
+                self.assertEqual(user.email, f"user{user.id}@example.com")
+                self.assertEqual(count + 1, User.objects.count())
+
+    def test_bulk_create_users(self):
+        objs = [
+            User(tenant=self.tenant, id=8291, email="user8291@example.com"),
+            User(tenant_id=self.tenant.id, id=4021, email="user4021@example.com"),
+            User(pk=(self.tenant.id, 8214), email="user8214@example.com"),
+        ]
+
+        obj_1, obj_2, obj_3 = User.objects.bulk_create(objs)
+
+        self.assertEqual(obj_1.tenant_id, self.tenant.id)
+        self.assertEqual(obj_1.id, 8291)
+        self.assertEqual(obj_1.pk, (obj_1.tenant_id, obj_1.id))
+        self.assertEqual(obj_1.email, "user8291@example.com")
+        self.assertEqual(obj_2.tenant_id, self.tenant.id)
+        self.assertEqual(obj_2.id, 4021)
+        self.assertEqual(obj_2.pk, (obj_2.tenant_id, obj_2.id))
+        self.assertEqual(obj_2.email, "user4021@example.com")
+        self.assertEqual(obj_3.tenant_id, self.tenant.id)
+        self.assertEqual(obj_3.id, 8214)
+        self.assertEqual(obj_3.pk, (obj_3.tenant_id, obj_3.id))
+        self.assertEqual(obj_3.email, "user8214@example.com")
+
+    def test_get_or_create_user(self):
+        test_cases = (
+            {
+                "pk": (self.tenant.id, 8314),
+                "defaults": {"email": "user8314@example.com"},
+            },
+            {
+                "tenant": self.tenant,
+                "id": 3142,
+                "defaults": {"email": "user3142@example.com"},
+            },
+            {
+                "tenant_id": self.tenant.id,
+                "id": 4218,
+                "defaults": {"email": "user4218@example.com"},
+            },
+        )
+
+        for fields in test_cases:
+            with self.subTest(fields=fields):
+                count = User.objects.count()
+                user, created = User.objects.get_or_create(**fields)
+                self.assertIs(created, True)
+                self.assertIsNotNone(user.id)
+                self.assertEqual(user.pk, (self.tenant.id, user.id))
+                self.assertEqual(user.tenant_id, self.tenant.id)
+                self.assertEqual(user.email, fields["defaults"]["email"])
+                self.assertEqual(user.email, f"user{user.id}@example.com")
+                self.assertEqual(count + 1, User.objects.count())
+
+    def test_update_or_create_user(self):
+        test_cases = (
+            {
+                "pk": (self.tenant.id, 2931),
+                "defaults": {"email": "user2931@example.com"},
+            },
+            {
+                "tenant": self.tenant,
+                "id": 6428,
+                "defaults": {"email": "user6428@example.com"},
+            },
+            {
+                "tenant_id": self.tenant.id,
+                "id": 5278,
+                "defaults": {"email": "user5278@example.com"},
+            },
+        )
+
+        for fields in test_cases:
+            with self.subTest(fields=fields):
+                count = User.objects.count()
+                user, created = User.objects.update_or_create(**fields)
+                self.assertIs(created, True)
+                self.assertIsNotNone(user.id)
+                self.assertEqual(user.pk, (self.tenant.id, user.id))
+                self.assertEqual(user.tenant_id, self.tenant.id)
+                self.assertEqual(user.email, fields["defaults"]["email"])
+                self.assertEqual(user.email, f"user{user.id}@example.com")
+                self.assertEqual(count + 1, User.objects.count())

+ 83 - 0
tests/composite_pk/test_delete.py

@@ -0,0 +1,83 @@
+from django.test import TestCase
+
+from .models import Comment, Tenant, User
+
+
+class CompositePKDeleteTests(TestCase):
+    maxDiff = None
+
+    @classmethod
+    def setUpTestData(cls):
+        cls.tenant_1 = Tenant.objects.create()
+        cls.tenant_2 = Tenant.objects.create()
+        cls.user_1 = User.objects.create(
+            tenant=cls.tenant_1,
+            id=1,
+            email="user0001@example.com",
+        )
+        cls.user_2 = User.objects.create(
+            tenant=cls.tenant_2,
+            id=2,
+            email="user0002@example.com",
+        )
+        cls.comment_1 = Comment.objects.create(id=1, user=cls.user_1)
+        cls.comment_2 = Comment.objects.create(id=2, user=cls.user_2)
+        cls.comment_3 = Comment.objects.create(id=3, user=cls.user_2)
+
+    def test_delete_tenant_by_pk(self):
+        result = Tenant.objects.filter(pk=self.tenant_1.pk).delete()
+
+        self.assertEqual(
+            result,
+            (
+                3,
+                {
+                    "composite_pk.Comment": 1,
+                    "composite_pk.User": 1,
+                    "composite_pk.Tenant": 1,
+                },
+            ),
+        )
+
+        self.assertIs(Tenant.objects.filter(pk=self.tenant_1.pk).exists(), False)
+        self.assertIs(Tenant.objects.filter(pk=self.tenant_2.pk).exists(), True)
+        self.assertIs(User.objects.filter(pk=self.user_1.pk).exists(), False)
+        self.assertIs(User.objects.filter(pk=self.user_2.pk).exists(), True)
+        self.assertIs(Comment.objects.filter(pk=self.comment_1.pk).exists(), False)
+        self.assertIs(Comment.objects.filter(pk=self.comment_2.pk).exists(), True)
+        self.assertIs(Comment.objects.filter(pk=self.comment_3.pk).exists(), True)
+
+    def test_delete_user_by_pk(self):
+        result = User.objects.filter(pk=self.user_1.pk).delete()
+
+        self.assertEqual(
+            result, (2, {"composite_pk.User": 1, "composite_pk.Comment": 1})
+        )
+
+        self.assertIs(User.objects.filter(pk=self.user_1.pk).exists(), False)
+        self.assertIs(User.objects.filter(pk=self.user_2.pk).exists(), True)
+        self.assertIs(Comment.objects.filter(pk=self.comment_1.pk).exists(), False)
+        self.assertIs(Comment.objects.filter(pk=self.comment_2.pk).exists(), True)
+        self.assertIs(Comment.objects.filter(pk=self.comment_3.pk).exists(), True)
+
+    def test_delete_comments_by_user(self):
+        result = Comment.objects.filter(user=self.user_2).delete()
+
+        self.assertEqual(result, (2, {"composite_pk.Comment": 2}))
+
+        self.assertIs(Comment.objects.filter(pk=self.comment_1.pk).exists(), True)
+        self.assertIs(Comment.objects.filter(pk=self.comment_2.pk).exists(), False)
+        self.assertIs(Comment.objects.filter(pk=self.comment_3.pk).exists(), False)
+
+    def test_delete_without_pk(self):
+        msg = (
+            "Comment object can't be deleted because its pk attribute is set "
+            "to None."
+        )
+
+        with self.assertRaisesMessage(ValueError, msg):
+            Comment().delete()
+        with self.assertRaisesMessage(ValueError, msg):
+            Comment(tenant_id=1).delete()
+        with self.assertRaisesMessage(ValueError, msg):
+            Comment(id=1).delete()

+ 412 - 0
tests/composite_pk/test_filter.py

@@ -0,0 +1,412 @@
+from django.test import TestCase
+
+from .models import Comment, Tenant, User
+
+
+class CompositePKFilterTests(TestCase):
+    maxDiff = None
+
+    @classmethod
+    def setUpTestData(cls):
+        cls.tenant_1 = Tenant.objects.create()
+        cls.tenant_2 = Tenant.objects.create()
+        cls.tenant_3 = Tenant.objects.create()
+        cls.user_1 = User.objects.create(
+            tenant=cls.tenant_1,
+            id=1,
+            email="user0001@example.com",
+        )
+        cls.user_2 = User.objects.create(
+            tenant=cls.tenant_1,
+            id=2,
+            email="user0002@example.com",
+        )
+        cls.user_3 = User.objects.create(
+            tenant=cls.tenant_2,
+            id=3,
+            email="user0003@example.com",
+        )
+        cls.user_4 = User.objects.create(
+            tenant=cls.tenant_3,
+            id=4,
+            email="user0004@example.com",
+        )
+        cls.comment_1 = Comment.objects.create(id=1, user=cls.user_1)
+        cls.comment_2 = Comment.objects.create(id=2, user=cls.user_1)
+        cls.comment_3 = Comment.objects.create(id=3, user=cls.user_2)
+        cls.comment_4 = Comment.objects.create(id=4, user=cls.user_3)
+        cls.comment_5 = Comment.objects.create(id=5, user=cls.user_1)
+
+    def test_filter_and_count_user_by_pk(self):
+        test_cases = (
+            ({"pk": self.user_1.pk}, 1),
+            ({"pk": self.user_2.pk}, 1),
+            ({"pk": self.user_3.pk}, 1),
+            ({"pk": (self.tenant_1.id, self.user_1.id)}, 1),
+            ({"pk": (self.tenant_1.id, self.user_2.id)}, 1),
+            ({"pk": (self.tenant_2.id, self.user_3.id)}, 1),
+            ({"pk": (self.tenant_1.id, self.user_3.id)}, 0),
+            ({"pk": (self.tenant_2.id, self.user_1.id)}, 0),
+            ({"pk": (self.tenant_2.id, self.user_2.id)}, 0),
+        )
+
+        for lookup, count in test_cases:
+            with self.subTest(lookup=lookup, count=count):
+                self.assertEqual(User.objects.filter(**lookup).count(), count)
+
+    def test_order_comments_by_pk_asc(self):
+        self.assertSequenceEqual(
+            Comment.objects.order_by("pk"),
+            (
+                self.comment_1,  # (1, 1)
+                self.comment_2,  # (1, 2)
+                self.comment_3,  # (1, 3)
+                self.comment_5,  # (1, 5)
+                self.comment_4,  # (2, 4)
+            ),
+        )
+
+    def test_order_comments_by_pk_desc(self):
+        self.assertSequenceEqual(
+            Comment.objects.order_by("-pk"),
+            (
+                self.comment_4,  # (2, 4)
+                self.comment_5,  # (1, 5)
+                self.comment_3,  # (1, 3)
+                self.comment_2,  # (1, 2)
+                self.comment_1,  # (1, 1)
+            ),
+        )
+
+    def test_filter_comments_by_pk_gt(self):
+        c11, c12, c13, c24, c15 = (
+            self.comment_1,
+            self.comment_2,
+            self.comment_3,
+            self.comment_4,
+            self.comment_5,
+        )
+        test_cases = (
+            (c11, (c12, c13, c15, c24)),
+            (c12, (c13, c15, c24)),
+            (c13, (c15, c24)),
+            (c15, (c24,)),
+            (c24, ()),
+        )
+
+        for obj, objs in test_cases:
+            with self.subTest(obj=obj, objs=objs):
+                self.assertSequenceEqual(
+                    Comment.objects.filter(pk__gt=obj.pk).order_by("pk"), objs
+                )
+
+    def test_filter_comments_by_pk_gte(self):
+        c11, c12, c13, c24, c15 = (
+            self.comment_1,
+            self.comment_2,
+            self.comment_3,
+            self.comment_4,
+            self.comment_5,
+        )
+        test_cases = (
+            (c11, (c11, c12, c13, c15, c24)),
+            (c12, (c12, c13, c15, c24)),
+            (c13, (c13, c15, c24)),
+            (c15, (c15, c24)),
+            (c24, (c24,)),
+        )
+
+        for obj, objs in test_cases:
+            with self.subTest(obj=obj, objs=objs):
+                self.assertSequenceEqual(
+                    Comment.objects.filter(pk__gte=obj.pk).order_by("pk"), objs
+                )
+
+    def test_filter_comments_by_pk_lt(self):
+        c11, c12, c13, c24, c15 = (
+            self.comment_1,
+            self.comment_2,
+            self.comment_3,
+            self.comment_4,
+            self.comment_5,
+        )
+        test_cases = (
+            (c24, (c11, c12, c13, c15)),
+            (c15, (c11, c12, c13)),
+            (c13, (c11, c12)),
+            (c12, (c11,)),
+            (c11, ()),
+        )
+
+        for obj, objs in test_cases:
+            with self.subTest(obj=obj, objs=objs):
+                self.assertSequenceEqual(
+                    Comment.objects.filter(pk__lt=obj.pk).order_by("pk"), objs
+                )
+
+    def test_filter_comments_by_pk_lte(self):
+        c11, c12, c13, c24, c15 = (
+            self.comment_1,
+            self.comment_2,
+            self.comment_3,
+            self.comment_4,
+            self.comment_5,
+        )
+        test_cases = (
+            (c24, (c11, c12, c13, c15, c24)),
+            (c15, (c11, c12, c13, c15)),
+            (c13, (c11, c12, c13)),
+            (c12, (c11, c12)),
+            (c11, (c11,)),
+        )
+
+        for obj, objs in test_cases:
+            with self.subTest(obj=obj, objs=objs):
+                self.assertSequenceEqual(
+                    Comment.objects.filter(pk__lte=obj.pk).order_by("pk"), objs
+                )
+
+    def test_filter_comments_by_pk_in(self):
+        test_cases = (
+            (),
+            (self.comment_1,),
+            (self.comment_1, self.comment_4),
+        )
+
+        for objs in test_cases:
+            with self.subTest(objs=objs):
+                pks = [obj.pk for obj in objs]
+                self.assertSequenceEqual(
+                    Comment.objects.filter(pk__in=pks).order_by("pk"), objs
+                )
+
+    def test_filter_comments_by_user_and_order_by_pk_asc(self):
+        self.assertSequenceEqual(
+            Comment.objects.filter(user=self.user_1).order_by("pk"),
+            (self.comment_1, self.comment_2, self.comment_5),
+        )
+
+    def test_filter_comments_by_user_and_order_by_pk_desc(self):
+        self.assertSequenceEqual(
+            Comment.objects.filter(user=self.user_1).order_by("-pk"),
+            (self.comment_5, self.comment_2, self.comment_1),
+        )
+
+    def test_filter_comments_by_user_and_exclude_by_pk(self):
+        self.assertSequenceEqual(
+            Comment.objects.filter(user=self.user_1)
+            .exclude(pk=self.comment_1.pk)
+            .order_by("pk"),
+            (self.comment_2, self.comment_5),
+        )
+
+    def test_filter_comments_by_user_and_contains(self):
+        self.assertIs(
+            Comment.objects.filter(user=self.user_1).contains(self.comment_1), True
+        )
+
+    def test_filter_users_by_comments_in(self):
+        c1, c2, c3, c4, c5 = (
+            self.comment_1,
+            self.comment_2,
+            self.comment_3,
+            self.comment_4,
+            self.comment_5,
+        )
+        u1, u2, u3 = (
+            self.user_1,
+            self.user_2,
+            self.user_3,
+        )
+        test_cases = (
+            ((), ()),
+            ((c1,), (u1,)),
+            ((c1, c2), (u1, u1)),
+            ((c1, c2, c3), (u1, u1, u2)),
+            ((c1, c2, c3, c4), (u1, u1, u2, u3)),
+            ((c1, c2, c3, c4, c5), (u1, u1, u1, u2, u3)),
+        )
+
+        for comments, users in test_cases:
+            with self.subTest(comments=comments, users=users):
+                self.assertSequenceEqual(
+                    User.objects.filter(comments__in=comments).order_by("pk"), users
+                )
+
+    def test_filter_users_by_comments_lt(self):
+        c11, c12, c13, c24, c15 = (
+            self.comment_1,
+            self.comment_2,
+            self.comment_3,
+            self.comment_4,
+            self.comment_5,
+        )
+        u1, u2 = (
+            self.user_1,
+            self.user_2,
+        )
+        test_cases = (
+            (c11, ()),
+            (c12, (u1,)),
+            (c13, (u1, u1)),
+            (c15, (u1, u1, u2)),
+            (c24, (u1, u1, u1, u2)),
+        )
+
+        for comment, users in test_cases:
+            with self.subTest(comment=comment, users=users):
+                self.assertSequenceEqual(
+                    User.objects.filter(comments__lt=comment).order_by("pk"), users
+                )
+
+    def test_filter_users_by_comments_lte(self):
+        c11, c12, c13, c24, c15 = (
+            self.comment_1,
+            self.comment_2,
+            self.comment_3,
+            self.comment_4,
+            self.comment_5,
+        )
+        u1, u2, u3 = (
+            self.user_1,
+            self.user_2,
+            self.user_3,
+        )
+        test_cases = (
+            (c11, (u1,)),
+            (c12, (u1, u1)),
+            (c13, (u1, u1, u2)),
+            (c15, (u1, u1, u1, u2)),
+            (c24, (u1, u1, u1, u2, u3)),
+        )
+
+        for comment, users in test_cases:
+            with self.subTest(comment=comment, users=users):
+                self.assertSequenceEqual(
+                    User.objects.filter(comments__lte=comment).order_by("pk"), users
+                )
+
+    def test_filter_users_by_comments_gt(self):
+        c11, c12, c13, c24, c15 = (
+            self.comment_1,
+            self.comment_2,
+            self.comment_3,
+            self.comment_4,
+            self.comment_5,
+        )
+        u1, u2, u3 = (
+            self.user_1,
+            self.user_2,
+            self.user_3,
+        )
+        test_cases = (
+            (c11, (u1, u1, u2, u3)),
+            (c12, (u1, u2, u3)),
+            (c13, (u1, u3)),
+            (c15, (u3,)),
+            (c24, ()),
+        )
+
+        for comment, users in test_cases:
+            with self.subTest(comment=comment, users=users):
+                self.assertSequenceEqual(
+                    User.objects.filter(comments__gt=comment).order_by("pk"), users
+                )
+
+    def test_filter_users_by_comments_gte(self):
+        c11, c12, c13, c24, c15 = (
+            self.comment_1,
+            self.comment_2,
+            self.comment_3,
+            self.comment_4,
+            self.comment_5,
+        )
+        u1, u2, u3 = (
+            self.user_1,
+            self.user_2,
+            self.user_3,
+        )
+        test_cases = (
+            (c11, (u1, u1, u1, u2, u3)),
+            (c12, (u1, u1, u2, u3)),
+            (c13, (u1, u2, u3)),
+            (c15, (u1, u3)),
+            (c24, (u3,)),
+        )
+
+        for comment, users in test_cases:
+            with self.subTest(comment=comment, users=users):
+                self.assertSequenceEqual(
+                    User.objects.filter(comments__gte=comment).order_by("pk"), users
+                )
+
+    def test_filter_users_by_comments_exact(self):
+        c11, c12, c13, c24, c15 = (
+            self.comment_1,
+            self.comment_2,
+            self.comment_3,
+            self.comment_4,
+            self.comment_5,
+        )
+        u1, u2, u3 = (
+            self.user_1,
+            self.user_2,
+            self.user_3,
+        )
+        test_cases = (
+            (c11, (u1,)),
+            (c12, (u1,)),
+            (c13, (u2,)),
+            (c15, (u1,)),
+            (c24, (u3,)),
+        )
+
+        for comment, users in test_cases:
+            with self.subTest(comment=comment, users=users):
+                self.assertSequenceEqual(
+                    User.objects.filter(comments=comment).order_by("pk"), users
+                )
+
+    def test_filter_users_by_comments_isnull(self):
+        u1, u2, u3, u4 = (
+            self.user_1,
+            self.user_2,
+            self.user_3,
+            self.user_4,
+        )
+
+        with self.subTest("comments__isnull=True"):
+            self.assertSequenceEqual(
+                User.objects.filter(comments__isnull=True).order_by("pk"),
+                (u4,),
+            )
+        with self.subTest("comments__isnull=False"):
+            self.assertSequenceEqual(
+                User.objects.filter(comments__isnull=False).order_by("pk"),
+                (u1, u1, u1, u2, u3),
+            )
+
+    def test_filter_comments_by_pk_isnull(self):
+        c11, c12, c13, c24, c15 = (
+            self.comment_1,
+            self.comment_2,
+            self.comment_3,
+            self.comment_4,
+            self.comment_5,
+        )
+
+        with self.subTest("pk__isnull=True"):
+            self.assertSequenceEqual(
+                Comment.objects.filter(pk__isnull=True).order_by("pk"),
+                (),
+            )
+        with self.subTest("pk__isnull=False"):
+            self.assertSequenceEqual(
+                Comment.objects.filter(pk__isnull=False).order_by("pk"),
+                (c11, c12, c13, c15, c24),
+            )
+
+    def test_filter_users_by_comments_subquery(self):
+        subquery = Comment.objects.filter(id=3).only("pk")
+        queryset = User.objects.filter(comments__in=subquery)
+        self.assertSequenceEqual(queryset, (self.user_2,))

+ 126 - 0
tests/composite_pk/test_get.py

@@ -0,0 +1,126 @@
+from django.test import TestCase
+
+from .models import Comment, Tenant, User
+
+
+class CompositePKGetTests(TestCase):
+    maxDiff = None
+
+    @classmethod
+    def setUpTestData(cls):
+        cls.tenant_1 = Tenant.objects.create()
+        cls.tenant_2 = Tenant.objects.create()
+        cls.user_1 = User.objects.create(
+            tenant=cls.tenant_1,
+            id=1,
+            email="user0001@example.com",
+        )
+        cls.user_2 = User.objects.create(
+            tenant=cls.tenant_1,
+            id=2,
+            email="user0002@example.com",
+        )
+        cls.user_3 = User.objects.create(
+            tenant=cls.tenant_2,
+            id=3,
+            email="user0003@example.com",
+        )
+        cls.comment_1 = Comment.objects.create(id=1, user=cls.user_1)
+
+    def test_get_user(self):
+        test_cases = (
+            {"pk": self.user_1.pk},
+            {"pk": (self.tenant_1.id, self.user_1.id)},
+            {"id": self.user_1.id},
+        )
+
+        for lookup in test_cases:
+            with self.subTest(lookup=lookup):
+                self.assertEqual(User.objects.get(**lookup), self.user_1)
+
+    def test_get_comment(self):
+        test_cases = (
+            {"pk": self.comment_1.pk},
+            {"pk": (self.tenant_1.id, self.comment_1.id)},
+            {"id": self.comment_1.id},
+            {"user": self.user_1},
+            {"user_id": self.user_1.id},
+            {"user__id": self.user_1.id},
+            {"user__pk": self.user_1.pk},
+            {"tenant": self.tenant_1},
+            {"tenant_id": self.tenant_1.id},
+            {"tenant__id": self.tenant_1.id},
+            {"tenant__pk": self.tenant_1.pk},
+        )
+
+        for lookup in test_cases:
+            with self.subTest(lookup=lookup):
+                self.assertEqual(Comment.objects.get(**lookup), self.comment_1)
+
+    def test_get_or_create_user(self):
+        test_cases = (
+            {
+                "pk": self.user_1.pk,
+                "defaults": {"email": "user9201@example.com"},
+            },
+            {
+                "pk": (self.tenant_1.id, self.user_1.id),
+                "defaults": {"email": "user9201@example.com"},
+            },
+            {
+                "tenant": self.tenant_1,
+                "id": self.user_1.id,
+                "defaults": {"email": "user3512@example.com"},
+            },
+            {
+                "tenant_id": self.tenant_1.id,
+                "id": self.user_1.id,
+                "defaults": {"email": "user8239@example.com"},
+            },
+        )
+
+        for fields in test_cases:
+            with self.subTest(fields=fields):
+                count = User.objects.count()
+                user, created = User.objects.get_or_create(**fields)
+                self.assertIs(created, False)
+                self.assertEqual(user.id, self.user_1.id)
+                self.assertEqual(user.pk, (self.tenant_1.id, self.user_1.id))
+                self.assertEqual(user.tenant_id, self.tenant_1.id)
+                self.assertEqual(user.email, self.user_1.email)
+                self.assertEqual(count, User.objects.count())
+
+    def test_lookup_errors(self):
+        m_tuple = "'%s' lookup of 'pk' must be a tuple or a list"
+        m_2_elements = "'%s' lookup of 'pk' must have 2 elements"
+        m_tuple_collection = (
+            "'in' lookup of 'pk' must be a collection of tuples or lists"
+        )
+        m_2_elements_each = "'in' lookup of 'pk' must have 2 elements each"
+        test_cases = (
+            ({"pk": 1}, m_tuple % "exact"),
+            ({"pk": (1, 2, 3)}, m_2_elements % "exact"),
+            ({"pk__exact": 1}, m_tuple % "exact"),
+            ({"pk__exact": (1, 2, 3)}, m_2_elements % "exact"),
+            ({"pk__in": 1}, m_tuple % "in"),
+            ({"pk__in": (1, 2, 3)}, m_tuple_collection),
+            ({"pk__in": ((1, 2, 3),)}, m_2_elements_each),
+            ({"pk__gt": 1}, m_tuple % "gt"),
+            ({"pk__gt": (1, 2, 3)}, m_2_elements % "gt"),
+            ({"pk__gte": 1}, m_tuple % "gte"),
+            ({"pk__gte": (1, 2, 3)}, m_2_elements % "gte"),
+            ({"pk__lt": 1}, m_tuple % "lt"),
+            ({"pk__lt": (1, 2, 3)}, m_2_elements % "lt"),
+            ({"pk__lte": 1}, m_tuple % "lte"),
+            ({"pk__lte": (1, 2, 3)}, m_2_elements % "lte"),
+        )
+
+        for kwargs, message in test_cases:
+            with (
+                self.subTest(kwargs=kwargs),
+                self.assertRaisesMessage(ValueError, message),
+            ):
+                Comment.objects.get(**kwargs)
+
+    def test_get_user_by_comments(self):
+        self.assertEqual(User.objects.get(comments=self.comment_1), self.user_1)

+ 153 - 0
tests/composite_pk/test_models.py

@@ -0,0 +1,153 @@
+from django.contrib.contenttypes.models import ContentType
+from django.core.exceptions import ValidationError
+from django.test import TestCase
+
+from .models import Comment, Tenant, Token, User
+
+
+class CompositePKModelsTests(TestCase):
+    @classmethod
+    def setUpTestData(cls):
+        cls.tenant_1 = Tenant.objects.create()
+        cls.tenant_2 = Tenant.objects.create()
+        cls.user_1 = User.objects.create(
+            tenant=cls.tenant_1,
+            id=1,
+            email="user0001@example.com",
+        )
+        cls.user_2 = User.objects.create(
+            tenant=cls.tenant_1,
+            id=2,
+            email="user0002@example.com",
+        )
+        cls.user_3 = User.objects.create(
+            tenant=cls.tenant_2,
+            id=3,
+            email="user0003@example.com",
+        )
+        cls.comment_1 = Comment.objects.create(id=1, user=cls.user_1)
+        cls.comment_2 = Comment.objects.create(id=2, user=cls.user_1)
+        cls.comment_3 = Comment.objects.create(id=3, user=cls.user_2)
+        cls.comment_4 = Comment.objects.create(id=4, user=cls.user_3)
+
+    def test_fields(self):
+        # tenant_1
+        self.assertSequenceEqual(
+            self.tenant_1.user_set.order_by("pk"),
+            [self.user_1, self.user_2],
+        )
+        self.assertSequenceEqual(
+            self.tenant_1.comments.order_by("pk"),
+            [self.comment_1, self.comment_2, self.comment_3],
+        )
+
+        # tenant_2
+        self.assertSequenceEqual(self.tenant_2.user_set.order_by("pk"), [self.user_3])
+        self.assertSequenceEqual(
+            self.tenant_2.comments.order_by("pk"), [self.comment_4]
+        )
+
+        # user_1
+        self.assertEqual(self.user_1.id, 1)
+        self.assertEqual(self.user_1.tenant_id, self.tenant_1.id)
+        self.assertEqual(self.user_1.tenant, self.tenant_1)
+        self.assertEqual(self.user_1.pk, (self.tenant_1.id, self.user_1.id))
+        self.assertSequenceEqual(
+            self.user_1.comments.order_by("pk"), [self.comment_1, self.comment_2]
+        )
+
+        # user_2
+        self.assertEqual(self.user_2.id, 2)
+        self.assertEqual(self.user_2.tenant_id, self.tenant_1.id)
+        self.assertEqual(self.user_2.tenant, self.tenant_1)
+        self.assertEqual(self.user_2.pk, (self.tenant_1.id, self.user_2.id))
+        self.assertSequenceEqual(self.user_2.comments.order_by("pk"), [self.comment_3])
+
+        # comment_1
+        self.assertEqual(self.comment_1.id, 1)
+        self.assertEqual(self.comment_1.user_id, self.user_1.id)
+        self.assertEqual(self.comment_1.user, self.user_1)
+        self.assertEqual(self.comment_1.tenant_id, self.tenant_1.id)
+        self.assertEqual(self.comment_1.tenant, self.tenant_1)
+        self.assertEqual(self.comment_1.pk, (self.tenant_1.id, self.user_1.id))
+
+    def test_full_clean_success(self):
+        test_cases = (
+            # 1, 1234, {}
+            ({"tenant": self.tenant_1, "id": 1234}, {}),
+            ({"tenant_id": self.tenant_1.id, "id": 1234}, {}),
+            ({"pk": (self.tenant_1.id, 1234)}, {}),
+            # 1, 1, {"id"}
+            ({"tenant": self.tenant_1, "id": 1}, {"id"}),
+            ({"tenant_id": self.tenant_1.id, "id": 1}, {"id"}),
+            ({"pk": (self.tenant_1.id, 1)}, {"id"}),
+            # 1, 1, {"tenant", "id"}
+            ({"tenant": self.tenant_1, "id": 1}, {"tenant", "id"}),
+            ({"tenant_id": self.tenant_1.id, "id": 1}, {"tenant", "id"}),
+            ({"pk": (self.tenant_1.id, 1)}, {"tenant", "id"}),
+        )
+
+        for kwargs, exclude in test_cases:
+            with self.subTest(kwargs):
+                kwargs["email"] = "user0004@example.com"
+                User(**kwargs).full_clean(exclude=exclude)
+
+    def test_full_clean_failure(self):
+        e_tenant_and_id = "User with this Tenant and Id already exists."
+        e_id = "User with this Id already exists."
+        test_cases = (
+            # 1, 1, {}
+            ({"tenant": self.tenant_1, "id": 1}, {}, (e_tenant_and_id, e_id)),
+            ({"tenant_id": self.tenant_1.id, "id": 1}, {}, (e_tenant_and_id, e_id)),
+            ({"pk": (self.tenant_1.id, 1)}, {}, (e_tenant_and_id, e_id)),
+            # 2, 1, {}
+            ({"tenant": self.tenant_2, "id": 1}, {}, (e_id,)),
+            ({"tenant_id": self.tenant_2.id, "id": 1}, {}, (e_id,)),
+            ({"pk": (self.tenant_2.id, 1)}, {}, (e_id,)),
+            # 1, 1, {"tenant"}
+            ({"tenant": self.tenant_1, "id": 1}, {"tenant"}, (e_id,)),
+            ({"tenant_id": self.tenant_1.id, "id": 1}, {"tenant"}, (e_id,)),
+            ({"pk": (self.tenant_1.id, 1)}, {"tenant"}, (e_id,)),
+        )
+
+        for kwargs, exclude, messages in test_cases:
+            with self.subTest(kwargs):
+                with self.assertRaises(ValidationError) as ctx:
+                    kwargs["email"] = "user0004@example.com"
+                    User(**kwargs).full_clean(exclude=exclude)
+
+                self.assertSequenceEqual(ctx.exception.messages, messages)
+
+    def test_field_conflicts(self):
+        test_cases = (
+            ({"pk": (1, 1), "id": 2}, (1, 1)),
+            ({"id": 2, "pk": (1, 1)}, (1, 1)),
+            ({"pk": (1, 1), "tenant_id": 2}, (1, 1)),
+            ({"tenant_id": 2, "pk": (1, 1)}, (1, 1)),
+            ({"pk": (2, 2), "tenant_id": 3, "id": 4}, (2, 2)),
+            ({"tenant_id": 3, "id": 4, "pk": (2, 2)}, (2, 2)),
+        )
+
+        for kwargs, pk in test_cases:
+            with self.subTest(kwargs=kwargs):
+                user = User(**kwargs)
+                self.assertEqual(user.pk, pk)
+
+    def test_validate_unique(self):
+        user = User.objects.get(pk=self.user_1.pk)
+        user.id = None
+
+        with self.assertRaises(ValidationError) as ctx:
+            user.validate_unique()
+
+        self.assertSequenceEqual(
+            ctx.exception.messages, ("User with this Email already exists.",)
+        )
+
+    def test_permissions(self):
+        token = ContentType.objects.get_for_model(Token)
+        user = ContentType.objects.get_for_model(User)
+        comment = ContentType.objects.get_for_model(Comment)
+        self.assertEqual(4, token.permission_set.count())
+        self.assertEqual(4, user.permission_set.count())
+        self.assertEqual(4, comment.permission_set.count())

+ 134 - 0
tests/composite_pk/test_names_to_path.py

@@ -0,0 +1,134 @@
+from django.db.models.query_utils import PathInfo
+from django.db.models.sql import Query
+from django.test import TestCase
+
+from .models import Comment, Tenant, User
+
+
+class NamesToPathTests(TestCase):
+    def test_id(self):
+        query = Query(User)
+        path, final_field, targets, rest = query.names_to_path(["id"], User._meta)
+
+        self.assertEqual(path, [])
+        self.assertEqual(final_field, User._meta.get_field("id"))
+        self.assertEqual(targets, (User._meta.get_field("id"),))
+        self.assertEqual(rest, [])
+
+    def test_pk(self):
+        query = Query(User)
+        path, final_field, targets, rest = query.names_to_path(["pk"], User._meta)
+
+        self.assertEqual(path, [])
+        self.assertEqual(final_field, User._meta.get_field("pk"))
+        self.assertEqual(targets, (User._meta.get_field("pk"),))
+        self.assertEqual(rest, [])
+
+    def test_tenant_id(self):
+        query = Query(User)
+        path, final_field, targets, rest = query.names_to_path(
+            ["tenant", "id"], User._meta
+        )
+
+        self.assertEqual(
+            path,
+            [
+                PathInfo(
+                    from_opts=User._meta,
+                    to_opts=Tenant._meta,
+                    target_fields=(Tenant._meta.get_field("id"),),
+                    join_field=User._meta.get_field("tenant"),
+                    m2m=False,
+                    direct=True,
+                    filtered_relation=None,
+                ),
+            ],
+        )
+        self.assertEqual(final_field, Tenant._meta.get_field("id"))
+        self.assertEqual(targets, (Tenant._meta.get_field("id"),))
+        self.assertEqual(rest, [])
+
+    def test_user_id(self):
+        query = Query(Comment)
+        path, final_field, targets, rest = query.names_to_path(
+            ["user", "id"], Comment._meta
+        )
+
+        self.assertEqual(
+            path,
+            [
+                PathInfo(
+                    from_opts=Comment._meta,
+                    to_opts=User._meta,
+                    target_fields=(
+                        User._meta.get_field("tenant"),
+                        User._meta.get_field("id"),
+                    ),
+                    join_field=Comment._meta.get_field("user"),
+                    m2m=False,
+                    direct=True,
+                    filtered_relation=None,
+                ),
+            ],
+        )
+        self.assertEqual(final_field, User._meta.get_field("id"))
+        self.assertEqual(targets, (User._meta.get_field("id"),))
+        self.assertEqual(rest, [])
+
+    def test_user_tenant_id(self):
+        query = Query(Comment)
+        path, final_field, targets, rest = query.names_to_path(
+            ["user", "tenant", "id"], Comment._meta
+        )
+
+        self.assertEqual(
+            path,
+            [
+                PathInfo(
+                    from_opts=Comment._meta,
+                    to_opts=User._meta,
+                    target_fields=(
+                        User._meta.get_field("tenant"),
+                        User._meta.get_field("id"),
+                    ),
+                    join_field=Comment._meta.get_field("user"),
+                    m2m=False,
+                    direct=True,
+                    filtered_relation=None,
+                ),
+                PathInfo(
+                    from_opts=User._meta,
+                    to_opts=Tenant._meta,
+                    target_fields=(Tenant._meta.get_field("id"),),
+                    join_field=User._meta.get_field("tenant"),
+                    m2m=False,
+                    direct=True,
+                    filtered_relation=None,
+                ),
+            ],
+        )
+        self.assertEqual(final_field, Tenant._meta.get_field("id"))
+        self.assertEqual(targets, (Tenant._meta.get_field("id"),))
+        self.assertEqual(rest, [])
+
+    def test_comments(self):
+        query = Query(User)
+        path, final_field, targets, rest = query.names_to_path(["comments"], User._meta)
+
+        self.assertEqual(
+            path,
+            [
+                PathInfo(
+                    from_opts=User._meta,
+                    to_opts=Comment._meta,
+                    target_fields=(Comment._meta.get_field("pk"),),
+                    join_field=User._meta.get_field("comments"),
+                    m2m=True,
+                    direct=False,
+                    filtered_relation=None,
+                ),
+            ],
+        )
+        self.assertEqual(final_field, User._meta.get_field("comments"))
+        self.assertEqual(targets, (Comment._meta.get_field("pk"),))
+        self.assertEqual(rest, [])

+ 135 - 0
tests/composite_pk/test_update.py

@@ -0,0 +1,135 @@
+from django.test import TestCase
+
+from .models import Comment, Tenant, Token, User
+
+
+class CompositePKUpdateTests(TestCase):
+    maxDiff = None
+
+    @classmethod
+    def setUpTestData(cls):
+        cls.tenant_1 = Tenant.objects.create(name="A")
+        cls.tenant_2 = Tenant.objects.create(name="B")
+        cls.user_1 = User.objects.create(
+            tenant=cls.tenant_1,
+            id=1,
+            email="user0001@example.com",
+        )
+        cls.user_2 = User.objects.create(
+            tenant=cls.tenant_1,
+            id=2,
+            email="user0002@example.com",
+        )
+        cls.user_3 = User.objects.create(
+            tenant=cls.tenant_2,
+            id=3,
+            email="user0003@example.com",
+        )
+        cls.comment_1 = Comment.objects.create(id=1, user=cls.user_1)
+        cls.comment_2 = Comment.objects.create(id=2, user=cls.user_1)
+        cls.comment_3 = Comment.objects.create(id=3, user=cls.user_2)
+        cls.token_1 = Token.objects.create(id=1, tenant=cls.tenant_1)
+        cls.token_2 = Token.objects.create(id=2, tenant=cls.tenant_2)
+        cls.token_3 = Token.objects.create(id=3, tenant=cls.tenant_1)
+        cls.token_4 = Token.objects.create(id=4, tenant=cls.tenant_2)
+
+    def test_update_user(self):
+        email = "user9315@example.com"
+        result = User.objects.filter(pk=self.user_1.pk).update(email=email)
+        self.assertEqual(result, 1)
+        user = User.objects.get(pk=self.user_1.pk)
+        self.assertEqual(user.email, email)
+
+    def test_save_user(self):
+        count = User.objects.count()
+        email = "user9314@example.com"
+        user = User.objects.get(pk=self.user_1.pk)
+        user.email = email
+        user.save()
+        user.refresh_from_db()
+        self.assertEqual(user.email, email)
+        user = User.objects.get(pk=self.user_1.pk)
+        self.assertEqual(user.email, email)
+        self.assertEqual(count, User.objects.count())
+
+    def test_bulk_update_comments(self):
+        comment_1 = Comment.objects.get(pk=self.comment_1.pk)
+        comment_2 = Comment.objects.get(pk=self.comment_2.pk)
+        comment_3 = Comment.objects.get(pk=self.comment_3.pk)
+        comment_1.text = "foo"
+        comment_2.text = "bar"
+        comment_3.text = "baz"
+
+        result = Comment.objects.bulk_update(
+            [comment_1, comment_2, comment_3], ["text"]
+        )
+
+        self.assertEqual(result, 3)
+        comment_1 = Comment.objects.get(pk=self.comment_1.pk)
+        comment_2 = Comment.objects.get(pk=self.comment_2.pk)
+        comment_3 = Comment.objects.get(pk=self.comment_3.pk)
+        self.assertEqual(comment_1.text, "foo")
+        self.assertEqual(comment_2.text, "bar")
+        self.assertEqual(comment_3.text, "baz")
+
+    def test_update_or_create_user(self):
+        test_cases = (
+            {
+                "pk": self.user_1.pk,
+                "defaults": {"email": "user3914@example.com"},
+            },
+            {
+                "pk": (self.tenant_1.id, self.user_1.id),
+                "defaults": {"email": "user9375@example.com"},
+            },
+            {
+                "tenant": self.tenant_1,
+                "id": self.user_1.id,
+                "defaults": {"email": "user3517@example.com"},
+            },
+            {
+                "tenant_id": self.tenant_1.id,
+                "id": self.user_1.id,
+                "defaults": {"email": "user8391@example.com"},
+            },
+        )
+
+        for fields in test_cases:
+            with self.subTest(fields=fields):
+                count = User.objects.count()
+                user, created = User.objects.update_or_create(**fields)
+                self.assertIs(created, False)
+                self.assertEqual(user.id, self.user_1.id)
+                self.assertEqual(user.pk, (self.tenant_1.id, self.user_1.id))
+                self.assertEqual(user.tenant_id, self.tenant_1.id)
+                self.assertEqual(user.email, fields["defaults"]["email"])
+                self.assertEqual(count, User.objects.count())
+
+    def test_update_comment_by_user_email(self):
+        result = Comment.objects.filter(user__email=self.user_1.email).update(
+            text="foo"
+        )
+
+        self.assertEqual(result, 2)
+        comment_1 = Comment.objects.get(pk=self.comment_1.pk)
+        comment_2 = Comment.objects.get(pk=self.comment_2.pk)
+        self.assertEqual(comment_1.text, "foo")
+        self.assertEqual(comment_2.text, "foo")
+
+    def test_update_token_by_tenant_name(self):
+        result = Token.objects.filter(tenant__name="A").update(secret="bar")
+
+        self.assertEqual(result, 2)
+        token_1 = Token.objects.get(pk=self.token_1.pk)
+        self.assertEqual(token_1.secret, "bar")
+        token_3 = Token.objects.get(pk=self.token_3.pk)
+        self.assertEqual(token_3.secret, "bar")
+
+    def test_cant_update_to_unsaved_object(self):
+        msg = (
+            "Unsaved model instance <User: User object ((None, None))> cannot be used "
+            "in an ORM query."
+        )
+
+        with self.assertRaisesMessage(ValueError, msg):
+            Comment.objects.update(user=User())

+ 212 - 0
tests/composite_pk/test_values.py

@@ -0,0 +1,212 @@
+from collections import namedtuple
+from uuid import UUID
+
+from django.test import TestCase
+
+from .models import Post, Tenant, User
+
+
+class CompositePKValuesTests(TestCase):
+    USER_1_EMAIL = "user0001@example.com"
+    USER_2_EMAIL = "user0002@example.com"
+    USER_3_EMAIL = "user0003@example.com"
+    POST_1_ID = "77777777-7777-7777-7777-777777777777"
+    POST_2_ID = "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"
+    POST_3_ID = "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"
+
+    @classmethod
+    def setUpTestData(cls):
+        super().setUpTestData()
+        cls.tenant_1 = Tenant.objects.create()
+        cls.tenant_2 = Tenant.objects.create()
+        cls.user_1 = User.objects.create(
+            tenant=cls.tenant_1, id=1, email=cls.USER_1_EMAIL
+        )
+        cls.user_2 = User.objects.create(
+            tenant=cls.tenant_1, id=2, email=cls.USER_2_EMAIL
+        )
+        cls.user_3 = User.objects.create(
+            tenant=cls.tenant_2, id=3, email=cls.USER_3_EMAIL
+        )
+        cls.post_1 = Post.objects.create(tenant=cls.tenant_1, id=cls.POST_1_ID)
+        cls.post_2 = Post.objects.create(tenant=cls.tenant_1, id=cls.POST_2_ID)
+        cls.post_3 = Post.objects.create(tenant=cls.tenant_2, id=cls.POST_3_ID)
+
+    def test_values_list(self):
+        with self.subTest('User.objects.values_list("pk")'):
+            self.assertSequenceEqual(
+                User.objects.values_list("pk").order_by("pk"),
+                (
+                    (self.user_1.pk,),
+                    (self.user_2.pk,),
+                    (self.user_3.pk,),
+                ),
+            )
+        with self.subTest('User.objects.values_list("pk", "email")'):
+            self.assertSequenceEqual(
+                User.objects.values_list("pk", "email").order_by("pk"),
+                (
+                    (self.user_1.pk, self.USER_1_EMAIL),
+                    (self.user_2.pk, self.USER_2_EMAIL),
+                    (self.user_3.pk, self.USER_3_EMAIL),
+                ),
+            )
+        with self.subTest('User.objects.values_list("pk", "id")'):
+            self.assertSequenceEqual(
+                User.objects.values_list("pk", "id").order_by("pk"),
+                (
+                    (self.user_1.pk, self.user_1.id),
+                    (self.user_2.pk, self.user_2.id),
+                    (self.user_3.pk, self.user_3.id),
+                ),
+            )
+        with self.subTest('User.objects.values_list("pk", "tenant_id", "id")'):
+            self.assertSequenceEqual(
+                User.objects.values_list("pk", "tenant_id", "id").order_by("pk"),
+                (
+                    (self.user_1.pk, self.user_1.tenant_id, self.user_1.id),
+                    (self.user_2.pk, self.user_2.tenant_id, self.user_2.id),
+                    (self.user_3.pk, self.user_3.tenant_id, self.user_3.id),
+                ),
+            )
+        with self.subTest('User.objects.values_list("pk", flat=True)'):
+            self.assertSequenceEqual(
+                User.objects.values_list("pk", flat=True).order_by("pk"),
+                (
+                    self.user_1.pk,
+                    self.user_2.pk,
+                    self.user_3.pk,
+                ),
+            )
+        with self.subTest('Post.objects.values_list("pk", flat=True)'):
+            self.assertSequenceEqual(
+                Post.objects.values_list("pk", flat=True).order_by("pk"),
+                (
+                    (self.tenant_1.id, UUID(self.POST_1_ID)),
+                    (self.tenant_1.id, UUID(self.POST_2_ID)),
+                    (self.tenant_2.id, UUID(self.POST_3_ID)),
+                ),
+            )
+        with self.subTest('Post.objects.values_list("pk")'):
+            self.assertSequenceEqual(
+                Post.objects.values_list("pk").order_by("pk"),
+                (
+                    ((self.tenant_1.id, UUID(self.POST_1_ID)),),
+                    ((self.tenant_1.id, UUID(self.POST_2_ID)),),
+                    ((self.tenant_2.id, UUID(self.POST_3_ID)),),
+                ),
+            )
+        with self.subTest('Post.objects.values_list("pk", "id")'):
+            self.assertSequenceEqual(
+                Post.objects.values_list("pk", "id").order_by("pk"),
+                (
+                    ((self.tenant_1.id, UUID(self.POST_1_ID)), UUID(self.POST_1_ID)),
+                    ((self.tenant_1.id, UUID(self.POST_2_ID)), UUID(self.POST_2_ID)),
+                    ((self.tenant_2.id, UUID(self.POST_3_ID)), UUID(self.POST_3_ID)),
+                ),
+            )
+        with self.subTest('Post.objects.values_list("id", "pk")'):
+            self.assertSequenceEqual(
+                Post.objects.values_list("id", "pk").order_by("pk"),
+                (
+                    (UUID(self.POST_1_ID), (self.tenant_1.id, UUID(self.POST_1_ID))),
+                    (UUID(self.POST_2_ID), (self.tenant_1.id, UUID(self.POST_2_ID))),
+                    (UUID(self.POST_3_ID), (self.tenant_2.id, UUID(self.POST_3_ID))),
+                ),
+            )
+        with self.subTest('User.objects.values_list("pk", named=True)'):
+            Row = namedtuple("Row", ["pk"])
+            self.assertSequenceEqual(
+                User.objects.values_list("pk", named=True).order_by("pk"),
+                (
+                    Row(pk=self.user_1.pk),
+                    Row(pk=self.user_2.pk),
+                    Row(pk=self.user_3.pk),
+                ),
+            )
+        with self.subTest('User.objects.values_list("pk", "pk")'):
+            self.assertSequenceEqual(
+                User.objects.values_list("pk", "pk").order_by("pk"),
+                (
+                    (self.user_1.pk,),
+                    (self.user_2.pk,),
+                    (self.user_3.pk,),
+                ),
+            )
+        with self.subTest('User.objects.values_list("pk", "id", "pk", "id")'):
+            self.assertSequenceEqual(
+                User.objects.values_list("pk", "id", "pk", "id").order_by("pk"),
+                (
+                    (self.user_1.pk, self.user_1.id),
+                    (self.user_2.pk, self.user_2.id),
+                    (self.user_3.pk, self.user_3.id),
+                ),
+            )
+
+    def test_values(self):
+        with self.subTest('User.objects.values("pk")'):
+            self.assertSequenceEqual(
+                User.objects.values("pk").order_by("pk"),
+                (
+                    {"pk": self.user_1.pk},
+                    {"pk": self.user_2.pk},
+                    {"pk": self.user_3.pk},
+                ),
+            )
+        with self.subTest('User.objects.values("pk", "email")'):
+            self.assertSequenceEqual(
+                User.objects.values("pk", "email").order_by("pk"),
+                (
+                    {"pk": self.user_1.pk, "email": self.USER_1_EMAIL},
+                    {"pk": self.user_2.pk, "email": self.USER_2_EMAIL},
+                    {"pk": self.user_3.pk, "email": self.USER_3_EMAIL},
+                ),
+            )
+        with self.subTest('User.objects.values("pk", "id")'):
+            self.assertSequenceEqual(
+                User.objects.values("pk", "id").order_by("pk"),
+                (
+                    {"pk": self.user_1.pk, "id": self.user_1.id},
+                    {"pk": self.user_2.pk, "id": self.user_2.id},
+                    {"pk": self.user_3.pk, "id": self.user_3.id},
+                ),
+            )
+        with self.subTest('User.objects.values("pk", "tenant_id", "id")'):
+            self.assertSequenceEqual(
+                User.objects.values("pk", "tenant_id", "id").order_by("pk"),
+                (
+                    {
+                        "pk": self.user_1.pk,
+                        "tenant_id": self.user_1.tenant_id,
+                        "id": self.user_1.id,
+                    },
+                    {
+                        "pk": self.user_2.pk,
+                        "tenant_id": self.user_2.tenant_id,
+                        "id": self.user_2.id,
+                    },
+                    {
+                        "pk": self.user_3.pk,
+                        "tenant_id": self.user_3.tenant_id,
+                        "id": self.user_3.id,
+                    },
+                ),
+            )
+        with self.subTest('User.objects.values("pk", "pk")'):
+            self.assertSequenceEqual(
+                User.objects.values("pk", "pk").order_by("pk"),
+                (
+                    {"pk": self.user_1.pk},
+                    {"pk": self.user_2.pk},
+                    {"pk": self.user_3.pk},
+                ),
+            )
+        with self.subTest('User.objects.values("pk", "id", "pk", "id")'):
+            self.assertSequenceEqual(
+                User.objects.values("pk", "id", "pk", "id").order_by("pk"),
+                (
+                    {"pk": self.user_1.pk, "id": self.user_1.id},
+                    {"pk": self.user_2.pk, "id": self.user_2.id},
+                    {"pk": self.user_3.pk, "id": self.user_3.id},
+                ),
+            )

+ 345 - 0
tests/composite_pk/tests.py

@@ -0,0 +1,345 @@
+import json
+import unittest
+from uuid import UUID
+
+import yaml
+
+from django import forms
+from django.core import serializers
+from django.core.exceptions import FieldError
+from django.db import IntegrityError, connection
+from django.db.models import CompositePrimaryKey
+from django.forms import modelform_factory
+from django.test import TestCase
+
+from .models import Comment, Post, Tenant, User
+
+
+class CommentForm(forms.ModelForm):
+    class Meta:
+        model = Comment
+        fields = "__all__"
+
+
+class CompositePKTests(TestCase):
+    maxDiff = None
+
+    @classmethod
+    def setUpTestData(cls):
+        cls.tenant = Tenant.objects.create()
+        cls.user = User.objects.create(
+            tenant=cls.tenant,
+            id=1,
+            email="user0001@example.com",
+        )
+        cls.comment = Comment.objects.create(tenant=cls.tenant, id=1, user=cls.user)
+
+    @staticmethod
+    def get_constraints(table):
+        with connection.cursor() as cursor:
+            return connection.introspection.get_constraints(cursor, table)
+
+    def test_pk_updated_if_field_updated(self):
+        user = User.objects.get(pk=self.user.pk)
+        self.assertEqual(user.pk, (self.tenant.id, self.user.id))
+        self.assertIs(user._is_pk_set(), True)
+        user.tenant_id = 9831
+        self.assertEqual(user.pk, (9831, self.user.id))
+        self.assertIs(user._is_pk_set(), True)
+        user.id = 4321
+        self.assertEqual(user.pk, (9831, 4321))
+        self.assertIs(user._is_pk_set(), True)
+        user.pk = (9132, 3521)
+        self.assertEqual(user.tenant_id, 9132)
+        self.assertEqual(user.id, 3521)
+        self.assertIs(user._is_pk_set(), True)
+        user.id = None
+        self.assertEqual(user.pk, (9132, None))
+        self.assertEqual(user.tenant_id, 9132)
+        self.assertIsNone(user.id)
+        self.assertIs(user._is_pk_set(), False)
+
+    def test_hash(self):
+        self.assertEqual(hash(User(pk=(1, 2))), hash((1, 2)))
+        self.assertEqual(hash(User(tenant_id=2, id=3)), hash((2, 3)))
+        msg = "Model instances without primary key value are unhashable"
+
+        with self.assertRaisesMessage(TypeError, msg):
+            hash(User())
+        with self.assertRaisesMessage(TypeError, msg):
+            hash(User(tenant_id=1))
+        with self.assertRaisesMessage(TypeError, msg):
+            hash(User(id=1))
+
+    def test_pk_must_be_list_or_tuple(self):
+        user = User.objects.get(pk=self.user.pk)
+        test_cases = [
+            "foo",
+            1000,
+            3.14,
+            True,
+            False,
+        ]
+
+        for pk in test_cases:
+            with self.assertRaisesMessage(
+                ValueError, "'pk' must be a list or a tuple."
+            ):
+                user.pk = pk
+
+    def test_pk_must_have_2_elements(self):
+        user = User.objects.get(pk=self.user.pk)
+        test_cases = [
+            (),
+            [],
+            (1000,),
+            [1000],
+            (1, 2, 3),
+            [1, 2, 3],
+        ]
+
+        for pk in test_cases:
+            with self.assertRaisesMessage(ValueError, "'pk' must have 2 elements."):
+                user.pk = pk
+
+    def test_composite_pk_in_fields(self):
+        user_fields = {f.name for f in User._meta.get_fields()}
+        self.assertEqual(user_fields, {"pk", "tenant", "id", "email", "comments"})
+
+        comment_fields = {f.name for f in Comment._meta.get_fields()}
+        self.assertEqual(
+            comment_fields,
+            {"pk", "tenant", "id", "user_id", "user", "text"},
+        )
+
+    def test_pk_field(self):
+        pk = User._meta.get_field("pk")
+        self.assertIsInstance(pk, CompositePrimaryKey)
+        self.assertIs(User._meta.pk, pk)
+
+    def test_error_on_user_pk_conflict(self):
+        with self.assertRaises(IntegrityError):
+            User.objects.create(tenant=self.tenant, id=self.user.id)
+
+    def test_error_on_comment_pk_conflict(self):
+        with self.assertRaises(IntegrityError):
+            Comment.objects.create(tenant=self.tenant, id=self.comment.id)
+
+    @unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL specific test")
+    def test_get_constraints_postgresql(self):
+        user_constraints = self.get_constraints(User._meta.db_table)
+        user_pk = user_constraints["composite_pk_user_pkey"]
+        self.assertEqual(user_pk["columns"], ["tenant_id", "id"])
+        self.assertIs(user_pk["primary_key"], True)
+
+        comment_constraints = self.get_constraints(Comment._meta.db_table)
+        comment_pk = comment_constraints["composite_pk_comment_pkey"]
+        self.assertEqual(comment_pk["columns"], ["tenant_id", "comment_id"])
+        self.assertIs(comment_pk["primary_key"], True)
+
+    @unittest.skipUnless(connection.vendor == "sqlite", "SQLite specific test")
+    def test_get_constraints_sqlite(self):
+        user_constraints = self.get_constraints(User._meta.db_table)
+        user_pk = user_constraints["__primary__"]
+        self.assertEqual(user_pk["columns"], ["tenant_id", "id"])
+        self.assertIs(user_pk["primary_key"], True)
+
+        comment_constraints = self.get_constraints(Comment._meta.db_table)
+        comment_pk = comment_constraints["__primary__"]
+        self.assertEqual(comment_pk["columns"], ["tenant_id", "comment_id"])
+        self.assertIs(comment_pk["primary_key"], True)
+
+    @unittest.skipUnless(connection.vendor == "mysql", "MySQL specific test")
+    def test_get_constraints_mysql(self):
+        user_constraints = self.get_constraints(User._meta.db_table)
+        user_pk = user_constraints["PRIMARY"]
+        self.assertEqual(user_pk["columns"], ["tenant_id", "id"])
+        self.assertIs(user_pk["primary_key"], True)
+
+        comment_constraints = self.get_constraints(Comment._meta.db_table)
+        comment_pk = comment_constraints["PRIMARY"]
+        self.assertEqual(comment_pk["columns"], ["tenant_id", "comment_id"])
+        self.assertIs(comment_pk["primary_key"], True)
+
+    @unittest.skipUnless(connection.vendor == "oracle", "Oracle specific test")
+    def test_get_constraints_oracle(self):
+        user_constraints = self.get_constraints(User._meta.db_table)
+        user_pk = next(c for c in user_constraints.values() if c["primary_key"])
+        self.assertEqual(user_pk["columns"], ["tenant_id", "id"])
+        self.assertEqual(user_pk["primary_key"], 1)
+
+        comment_constraints = self.get_constraints(Comment._meta.db_table)
+        comment_pk = next(c for c in comment_constraints.values() if c["primary_key"])
+        self.assertEqual(comment_pk["columns"], ["tenant_id", "comment_id"])
+        self.assertEqual(comment_pk["primary_key"], 1)
+
+    def test_in_bulk(self):
+        """
+        Test the .in_bulk() method of composite_pk models.
+        """
+        result = Comment.objects.in_bulk()
+        self.assertEqual(result, {self.comment.pk: self.comment})
+
+        result = Comment.objects.in_bulk([self.comment.pk])
+        self.assertEqual(result, {self.comment.pk: self.comment})
+
+    def test_iterator(self):
+        """
+        Test the .iterator() method of composite_pk models.
+        """
+        result = list(Comment.objects.iterator())
+        self.assertEqual(result, [self.comment])
+
+    def test_query(self):
+        users = User.objects.values_list("pk").order_by("pk")
+        self.assertNotIn('AS "pk"', str(users.query))
+
+    def test_only(self):
+        users = User.objects.only("pk")
+        self.assertSequenceEqual(users, (self.user,))
+        user = users[0]
+
+        with self.assertNumQueries(0):
+            self.assertEqual(user.pk, (self.user.tenant_id, self.user.id))
+            self.assertEqual(user.tenant_id, self.user.tenant_id)
+            self.assertEqual(user.id, self.user.id)
+        with self.assertNumQueries(1):
+            self.assertEqual(user.email, self.user.email)
+
+    def test_model_forms(self):
+        fields = ["tenant", "id", "user_id", "text"]
+        self.assertEqual(list(CommentForm.base_fields), fields)
+
+        form = modelform_factory(Comment, fields="__all__")
+        self.assertEqual(list(form().fields), fields)
+
+        with self.assertRaisesMessage(
+            FieldError, "Unknown field(s) (pk) specified for Comment"
+        ):
+            self.assertIsNone(modelform_factory(Comment, fields=["pk"]))
+
+
+class CompositePKFixturesTests(TestCase):
+    fixtures = ["tenant"]
+
+    def test_objects(self):
+        tenant_1, tenant_2, tenant_3 = Tenant.objects.order_by("pk")
+        self.assertEqual(tenant_1.id, 1)
+        self.assertEqual(tenant_1.name, "Tenant 1")
+        self.assertEqual(tenant_2.id, 2)
+        self.assertEqual(tenant_2.name, "Tenant 2")
+        self.assertEqual(tenant_3.id, 3)
+        self.assertEqual(tenant_3.name, "Tenant 3")
+
+        user_1, user_2, user_3, user_4 = User.objects.order_by("pk")
+        self.assertEqual(user_1.id, 1)
+        self.assertEqual(user_1.tenant_id, 1)
+        self.assertEqual(user_1.pk, (user_1.tenant_id, user_1.id))
+        self.assertEqual(user_1.email, "user0001@example.com")
+        self.assertEqual(user_2.id, 2)
+        self.assertEqual(user_2.tenant_id, 1)
+        self.assertEqual(user_2.pk, (user_2.tenant_id, user_2.id))
+        self.assertEqual(user_2.email, "user0002@example.com")
+        self.assertEqual(user_3.id, 3)
+        self.assertEqual(user_3.tenant_id, 2)
+        self.assertEqual(user_3.pk, (user_3.tenant_id, user_3.id))
+        self.assertEqual(user_3.email, "user0003@example.com")
+        self.assertEqual(user_4.id, 4)
+        self.assertEqual(user_4.tenant_id, 2)
+        self.assertEqual(user_4.pk, (user_4.tenant_id, user_4.id))
+        self.assertEqual(user_4.email, "user0004@example.com")
+
+        post_1, post_2 = Post.objects.order_by("pk")
+        self.assertEqual(post_1.id, UUID("11111111-1111-1111-1111-111111111111"))
+        self.assertEqual(post_1.tenant_id, 2)
+        self.assertEqual(post_1.pk, (post_1.tenant_id, post_1.id))
+        self.assertEqual(post_2.id, UUID("ffffffff-ffff-ffff-ffff-ffffffffffff"))
+        self.assertEqual(post_2.tenant_id, 2)
+        self.assertEqual(post_2.pk, (post_2.tenant_id, post_2.id))
+
+    def test_serialize_user_json(self):
+        users = User.objects.filter(pk=(1, 1))
+        result = serializers.serialize("json", users)
+        self.assertEqual(
+            json.loads(result),
+            [
+                {
+                    "model": "composite_pk.user",
+                    "pk": [1, 1],
+                    "fields": {
+                        "email": "user0001@example.com",
+                        "id": 1,
+                        "tenant": 1,
+                    },
+                }
+            ],
+        )
+
+    def test_serialize_user_jsonl(self):
+        users = User.objects.filter(pk=(1, 2))
+        result = serializers.serialize("jsonl", users)
+        self.assertEqual(
+            json.loads(result),
+            {
+                "model": "composite_pk.user",
+                "pk": [1, 2],
+                "fields": {
+                    "email": "user0002@example.com",
+                    "id": 2,
+                    "tenant": 1,
+                },
+            },
+        )
+
+    def test_serialize_user_yaml(self):
+        users = User.objects.filter(pk=(2, 3))
+        result = serializers.serialize("yaml", users)
+        self.assertEqual(
+            yaml.safe_load(result),
+            [
+                {
+                    "model": "composite_pk.user",
+                    "pk": [2, 3],
+                    "fields": {
+                        "email": "user0003@example.com",
+                        "id": 3,
+                        "tenant": 2,
+                    },
+                },
+            ],
+        )
+
+    def test_serialize_user_python(self):
+        users = User.objects.filter(pk=(2, 4))
+        result = serializers.serialize("python", users)
+        self.assertEqual(
+            result,
+            [
+                {
+                    "model": "composite_pk.user",
+                    "pk": [2, 4],
+                    "fields": {
+                        "email": "user0004@example.com",
+                        "id": 4,
+                        "tenant": 2,
+                    },
+                },
+            ],
+        )
+
+    def test_serialize_post_uuid(self):
+        posts = Post.objects.filter(pk=(2, "11111111-1111-1111-1111-111111111111"))
+        result = serializers.serialize("json", posts)
+        self.assertEqual(
+            json.loads(result),
+            [
+                {
+                    "model": "composite_pk.post",
+                    "pk": [2, "11111111-1111-1111-1111-111111111111"],
+                    "fields": {
+                        "id": "11111111-1111-1111-1111-111111111111",
+                        "tenant": 2,
+                    },
+                },
+            ],
+        )

+ 89 - 0
tests/migrations/test_autodetector.py

@@ -5059,6 +5059,95 @@ class AutodetectorTests(BaseAutodetectorTests):
         self.assertOperationTypes(changes, "testapp", 0, ["CreateModel"])
         self.assertOperationAttributes(changes, "testapp", 0, 0, name="Book")
 
+    @mock.patch(
+        "django.db.migrations.questioner.MigrationQuestioner.ask_not_null_addition"
+    )
+    def test_add_composite_pk(self, mocked_ask_method):
+        before = [
+            ModelState(
+                "app",
+                "foo",
+                [
+                    ("id", models.AutoField(primary_key=True)),
+                ],
+            ),
+        ]
+        after = [
+            ModelState(
+                "app",
+                "foo",
+                [
+                    ("pk", models.CompositePrimaryKey("foo_id", "bar_id")),
+                    ("id", models.IntegerField()),
+                ],
+            ),
+        ]
+
+        changes = self.get_changes(before, after)
+        self.assertEqual(mocked_ask_method.call_count, 0)
+        self.assertNumberMigrations(changes, "app", 1)
+        self.assertOperationTypes(changes, "app", 0, ["AddField", "AlterField"])
+        self.assertOperationAttributes(
+            changes,
+            "app",
+            0,
+            0,
+            name="pk",
+            model_name="foo",
+            preserve_default=True,
+        )
+        self.assertOperationAttributes(
+            changes,
+            "app",
+            0,
+            1,
+            name="id",
+            model_name="foo",
+            preserve_default=True,
+        )
+
+    def test_remove_composite_pk(self):
+        before = [
+            ModelState(
+                "app",
+                "foo",
+                [
+                    ("pk", models.CompositePrimaryKey("foo_id", "bar_id")),
+                    ("id", models.IntegerField()),
+                ],
+            ),
+        ]
+        after = [
+            ModelState(
+                "app",
+                "foo",
+                [
+                    ("id", models.AutoField(primary_key=True)),
+                ],
+            ),
+        ]
+
+        changes = self.get_changes(before, after)
+        self.assertNumberMigrations(changes, "app", 1)
+        self.assertOperationTypes(changes, "app", 0, ["RemoveField", "AlterField"])
+        self.assertOperationAttributes(
+            changes,
+            "app",
+            0,
+            0,
+            name="pk",
+            model_name="foo",
+        )
+        self.assertOperationAttributes(
+            changes,
+            "app",
+            0,
+            1,
+            name="id",
+            model_name="foo",
+            preserve_default=True,
+        )
+
 
 class MigrationSuggestNameTests(SimpleTestCase):
     def test_no_operations(self):

+ 55 - 0
tests/migrations/test_operations.py

@@ -6287,6 +6287,61 @@ class OperationTests(OperationTestBase):
         self.assertEqual(pony_new.generated, 1)
         self.assertEqual(pony_new.static, 2)
 
+    def test_composite_pk_operations(self):
+        app_label = "test_d8d90af6"
+        project_state = self.set_up_test_model(app_label)
+        operation_1 = migrations.AddField(
+            "Pony", "pk", models.CompositePrimaryKey("id", "pink")
+        )
+        operation_2 = migrations.AlterField("Pony", "id", models.IntegerField())
+        operation_3 = migrations.RemoveField("Pony", "pk")
+        table_name = f"{app_label}_pony"
+
+        # 1. Add field (pk).
+        new_state = project_state.clone()
+        operation_1.state_forwards(app_label, new_state)
+        with connection.schema_editor() as editor:
+            operation_1.database_forwards(app_label, editor, project_state, new_state)
+        self.assertColumnNotExists(table_name, "pk")
+        Pony = new_state.apps.get_model(app_label, "pony")
+        obj_1 = Pony.objects.create(weight=1)
+        msg = (
+            f"obj_1={obj_1}, "
+            f"obj_1.id={obj_1.id}, "
+            f"obj_1.pink={obj_1.pink}, "
+            f"obj_1.pk={obj_1.pk}, "
+            f"Pony._meta.pk={repr(Pony._meta.pk)}, "
+            f"Pony._meta.get_field('id')={repr(Pony._meta.get_field('id'))}"
+        )
+        self.assertEqual(obj_1.pink, 3, msg)
+        self.assertEqual(obj_1.pk, (obj_1.id, obj_1.pink), msg)
+
+        # 2. Alter field (id -> IntegerField()).
+        project_state, new_state = new_state, new_state.clone()
+        operation_2.state_forwards(app_label, new_state)
+        with connection.schema_editor() as editor:
+            operation_2.database_forwards(app_label, editor, project_state, new_state)
+        Pony = new_state.apps.get_model(app_label, "pony")
+        obj_1 = Pony.objects.get(id=obj_1.id)
+        self.assertEqual(obj_1.pink, 3)
+        self.assertEqual(obj_1.pk, (obj_1.id, obj_1.pink))
+        obj_2 = Pony.objects.create(id=2, weight=2)
+        self.assertEqual(obj_2.id, 2)
+        self.assertEqual(obj_2.pink, 3)
+        self.assertEqual(obj_2.pk, (obj_2.id, obj_2.pink))
+
+        # 3. Remove field (pk).
+        project_state, new_state = new_state, new_state.clone()
+        operation_3.state_forwards(app_label, new_state)
+        with connection.schema_editor() as editor:
+            operation_3.database_forwards(app_label, editor, project_state, new_state)
+        Pony = new_state.apps.get_model(app_label, "pony")
+        obj_1 = Pony.objects.get(id=obj_1.id)
+        self.assertEqual(obj_1.pk, obj_1.id)
+        obj_2 = Pony.objects.get(id=obj_2.id)
+        self.assertEqual(obj_2.id, 2)
+        self.assertEqual(obj_2.pk, obj_2.id)
+
 
 class SwappableOperationTests(OperationTestBase):
     """

+ 22 - 0
tests/migrations/test_state.py

@@ -1206,6 +1206,28 @@ class StateTests(SimpleTestCase):
         choices_field = Author._meta.get_field("choice")
         self.assertEqual(list(choices_field.choices), choices)
 
+    def test_composite_pk_state(self):
+        new_apps = Apps(["migrations"])
+
+        class Foo(models.Model):
+            pk = models.CompositePrimaryKey("account_id", "id")
+            account_id = models.SmallIntegerField()
+            id = models.SmallIntegerField()
+
+            class Meta:
+                app_label = "migrations"
+                apps = new_apps
+
+        project_state = ProjectState.from_apps(new_apps)
+        model_state = project_state.models["migrations", "foo"]
+        self.assertEqual(len(model_state.options), 2)
+        self.assertEqual(model_state.options["constraints"], [])
+        self.assertEqual(model_state.options["indexes"], [])
+        self.assertEqual(len(model_state.fields), 3)
+        self.assertIn("pk", model_state.fields)
+        self.assertIn("account_id", model_state.fields)
+        self.assertIn("id", model_state.fields)
+
 
 class StateRelationsTests(SimpleTestCase):
     def get_base_project_state(self):

+ 19 - 0
tests/migrations/test_writer.py

@@ -1138,3 +1138,22 @@ class WriterTests(SimpleTestCase):
             ValueError, "'TestModel1' must inherit from 'BaseSerializer'."
         ):
             MigrationWriter.register_serializer(complex, TestModel1)
+
+    def test_composite_pk_import(self):
+        migration = type(
+            "Migration",
+            (migrations.Migration,),
+            {
+                "operations": [
+                    migrations.AddField(
+                        "foo",
+                        "bar",
+                        models.CompositePrimaryKey("foo_id", "bar_id"),
+                    ),
+                ],
+            },
+        )
+        writer = MigrationWriter(migration)
+        output = writer.as_string()
+        self.assertEqual(output.count("import"), 1)
+        self.assertIn("from django.db import migrations, models", output)