ソースを参照

Refs #33497 -- Added connection pool support for PostgreSQL.

Co-authored-by: Florian Apolloner <florian@apolloner.eu>
Co-authored-by: Ran Benita <ran@unusedvar.com>
Sarah Boyce 1 年間 前
コミット
fad334e1a9

+ 5 - 1
django/db/backends/base/base.py

@@ -17,7 +17,7 @@ from django.db.backends.base.validation import BaseDatabaseValidation
 from django.db.backends.signals import connection_created
 from django.db.backends.utils import debug_transaction
 from django.db.transaction import TransactionManagementError
-from django.db.utils import DatabaseErrorWrapper
+from django.db.utils import DatabaseErrorWrapper, ProgrammingError
 from django.utils.asyncio import async_unsafe
 from django.utils.functional import cached_property
 
@@ -271,6 +271,10 @@ class BaseDatabaseWrapper:
     def ensure_connection(self):
         """Guarantee that a connection to the database is established."""
         if self.connection is None:
+            if self.in_atomic_block and self.closed_in_transaction:
+                raise ProgrammingError(
+                    "Cannot open a new connection in an atomic block."
+                )
             with self.wrap_database_errors:
                 self.connect()
 

+ 122 - 23
django/db/backends/postgresql/base.py

@@ -13,7 +13,7 @@ from django.conf import settings
 from django.core.exceptions import ImproperlyConfigured
 from django.db import DatabaseError as WrappedDatabaseError
 from django.db import connections
-from django.db.backends.base.base import BaseDatabaseWrapper
+from django.db.backends.base.base import NO_DB_ALIAS, BaseDatabaseWrapper
 from django.db.backends.utils import CursorDebugWrapper as BaseCursorDebugWrapper
 from django.utils.asyncio import async_unsafe
 from django.utils.functional import cached_property
@@ -86,6 +86,24 @@ def _get_varchar_column(data):
     return "varchar(%(max_length)s)" % data
 
 
+def ensure_timezone(connection, ops, timezone_name):
+    conn_timezone_name = connection.info.parameter_status("TimeZone")
+    if timezone_name and conn_timezone_name != timezone_name:
+        with connection.cursor() as cursor:
+            cursor.execute(ops.set_time_zone_sql(), [timezone_name])
+        return True
+    return False
+
+
+def ensure_role(connection, ops, role_name):
+    if role_name:
+        with connection.cursor() as cursor:
+            sql = ops.compose_sql("SET ROLE %s", [role_name])
+            cursor.execute(sql)
+        return True
+    return False
+
+
 class DatabaseWrapper(BaseDatabaseWrapper):
     vendor = "postgresql"
     display_name = "PostgreSQL"
@@ -179,6 +197,53 @@ class DatabaseWrapper(BaseDatabaseWrapper):
     ops_class = DatabaseOperations
     # PostgreSQL backend-specific attributes.
     _named_cursor_idx = 0
+    _connection_pools = {}
+
+    @property
+    def pool(self):
+        pool_options = self.settings_dict["OPTIONS"].get("pool")
+        if self.alias == NO_DB_ALIAS or not pool_options:
+            return None
+
+        if self.alias not in self._connection_pools:
+            if self.settings_dict.get("CONN_MAX_AGE", 0) != 0:
+                raise ImproperlyConfigured(
+                    "Pooling doesn't support persistent connections."
+                )
+            # Set the default options.
+            if pool_options is True:
+                pool_options = {}
+
+            try:
+                from psycopg_pool import ConnectionPool
+            except ImportError as err:
+                raise ImproperlyConfigured(
+                    "Error loading psycopg_pool module.\nDid you install psycopg[pool]?"
+                ) from err
+
+            connect_kwargs = self.get_connection_params()
+            # Ensure we run in autocommit, Django properly sets it later on.
+            connect_kwargs["autocommit"] = True
+            enable_checks = self.settings_dict["CONN_HEALTH_CHECKS"]
+            pool = ConnectionPool(
+                kwargs=connect_kwargs,
+                open=False,  # Do not open the pool during startup.
+                configure=self._configure_connection,
+                check=ConnectionPool.check_connection if enable_checks else None,
+                **pool_options,
+            )
+            # setdefault() ensures that multiple threads don't set this in
+            # parallel. Since we do not open the pool during it's init above,
+            # this means that at worst during startup multiple threads generate
+            # pool objects and the first to set it wins.
+            self._connection_pools.setdefault(self.alias, pool)
+
+        return self._connection_pools[self.alias]
+
+    def close_pool(self):
+        if self.pool:
+            self.pool.close()
+            del self._connection_pools[self.alias]
 
     def get_database_version(self):
         """
@@ -221,6 +286,11 @@ class DatabaseWrapper(BaseDatabaseWrapper):
 
         conn_params.pop("assume_role", None)
         conn_params.pop("isolation_level", None)
+
+        pool_options = conn_params.pop("pool", None)
+        if pool_options and not is_psycopg3:
+            raise ImproperlyConfigured("Database pooling requires psycopg >= 3")
+
         server_side_binding = conn_params.pop("server_side_binding", None)
         conn_params.setdefault(
             "cursor_factory",
@@ -272,7 +342,12 @@ class DatabaseWrapper(BaseDatabaseWrapper):
                     f"Invalid transaction isolation level {isolation_level_value} "
                     f"specified. Use one of the psycopg.IsolationLevel values."
                 )
-        connection = self.Database.connect(**conn_params)
+        if self.pool:
+            # If nothing else has opened the pool, open it now.
+            self.pool.open()
+            connection = self.pool.getconn()
+        else:
+            connection = self.Database.connect(**conn_params)
         if set_isolation_level:
             connection.isolation_level = self.isolation_level
         if not is_psycopg3:
@@ -285,36 +360,52 @@ class DatabaseWrapper(BaseDatabaseWrapper):
         return connection
 
     def ensure_timezone(self):
+        # Close the pool so new connections pick up the correct timezone.
+        self.close_pool()
         if self.connection is None:
             return False
-        conn_timezone_name = self.connection.info.parameter_status("TimeZone")
-        timezone_name = self.timezone_name
-        if timezone_name and conn_timezone_name != timezone_name:
-            with self.connection.cursor() as cursor:
-                cursor.execute(self.ops.set_time_zone_sql(), [timezone_name])
-            return True
-        return False
-
-    def ensure_role(self):
-        if new_role := self.settings_dict["OPTIONS"].get("assume_role"):
-            with self.connection.cursor() as cursor:
-                sql = self.ops.compose_sql("SET ROLE %s", [new_role])
-                cursor.execute(sql)
-            return True
-        return False
+        return ensure_timezone(self.connection, self.ops, self.timezone_name)
 
-    def init_connection_state(self):
-        super().init_connection_state()
+    def _configure_connection(self, connection):
+        # This function is called from init_connection_state and from the
+        # psycopg pool itself after a connection is opened. Make sure that
+        # whatever is done here does not access anything on self aside from
+        # variables.
 
         # Commit after setting the time zone.
-        commit_tz = self.ensure_timezone()
+        commit_tz = ensure_timezone(connection, self.ops, self.timezone_name)
         # Set the role on the connection. This is useful if the credential used
         # to login is not the same as the role that owns database resources. As
         # can be the case when using temporary or ephemeral credentials.
-        commit_role = self.ensure_role()
+        role_name = self.settings_dict["OPTIONS"].get("assume_role")
+        commit_role = ensure_role(connection, self.ops, role_name)
+
+        return commit_role or commit_tz
+
+    def _close(self):
+        if self.connection is not None:
+            # `wrap_database_errors` only works for `putconn` as long as there
+            # is no `reset` function set in the pool because it is deferred
+            # into a thread and not directly executed.
+            with self.wrap_database_errors:
+                if self.pool:
+                    # Ensure the correct pool is returned. This is a workaround
+                    # for tests so a pool can be changed on setting changes
+                    # (e.g. USE_TZ, TIME_ZONE).
+                    self.connection._pool.putconn(self.connection)
+                    # Connection can no longer be used.
+                    self.connection = None
+                else:
+                    return self.connection.close()
 
-        if (commit_role or commit_tz) and not self.get_autocommit():
-            self.connection.commit()
+    def init_connection_state(self):
+        super().init_connection_state()
+
+        if self.connection is not None and not self.pool:
+            commit = self._configure_connection(self.connection)
+
+            if commit and not self.get_autocommit():
+                self.connection.commit()
 
     @async_unsafe
     def create_cursor(self, name=None):
@@ -396,6 +487,8 @@ class DatabaseWrapper(BaseDatabaseWrapper):
             cursor.execute("SET CONSTRAINTS ALL DEFERRED")
 
     def is_usable(self):
+        if self.connection is None:
+            return False
         try:
             # Use a psycopg cursor directly, bypassing Django's utilities.
             with self.connection.cursor() as cursor:
@@ -405,6 +498,12 @@ class DatabaseWrapper(BaseDatabaseWrapper):
         else:
             return True
 
+    def close_if_health_check_failed(self):
+        if self.pool:
+            # The pool only returns healthy connections.
+            return
+        return super().close_if_health_check_failed()
+
     @contextmanager
     def _nodb_cursor(self):
         cursor = None

+ 5 - 0
django/db/backends/postgresql/creation.py

@@ -58,6 +58,7 @@ class DatabaseCreation(BaseDatabaseCreation):
         # CREATE DATABASE ... WITH TEMPLATE ... requires closing connections
         # to the template database.
         self.connection.close()
+        self.connection.close_pool()
 
         source_database_name = self.connection.settings_dict["NAME"]
         target_database_name = self.get_test_db_clone_settings(suffix)["NAME"]
@@ -84,3 +85,7 @@ class DatabaseCreation(BaseDatabaseCreation):
                 except Exception as e:
                     self.log("Got an error cloning the test database: %s" % e)
                     sys.exit(2)
+
+    def _destroy_test_db(self, test_database_name, verbosity):
+        self.connection.close_pool()
+        return super()._destroy_test_db(test_database_name, verbosity)

+ 23 - 9
django/db/backends/postgresql/features.py

@@ -83,15 +83,29 @@ class DatabaseFeatures(BaseDatabaseFeatures):
     test_now_utc_template = "STATEMENT_TIMESTAMP() AT TIME ZONE 'UTC'"
     insert_test_table_with_defaults = "INSERT INTO {} DEFAULT VALUES"
 
-    django_test_skips = {
-        "opclasses are PostgreSQL only.": {
-            "indexes.tests.SchemaIndexesNotPostgreSQLTests."
-            "test_create_index_ignores_opclasses",
-        },
-        "PostgreSQL requires casting to text.": {
-            "lookup.tests.LookupTests.test_textfield_exact_null",
-        },
-    }
+    @cached_property
+    def django_test_skips(self):
+        skips = {
+            "opclasses are PostgreSQL only.": {
+                "indexes.tests.SchemaIndexesNotPostgreSQLTests."
+                "test_create_index_ignores_opclasses",
+            },
+            "PostgreSQL requires casting to text.": {
+                "lookup.tests.LookupTests.test_textfield_exact_null",
+            },
+        }
+        if self.connection.settings_dict["OPTIONS"].get("pool"):
+            skips.update(
+                {
+                    "Pool does implicit health checks": {
+                        "backends.base.test_base.ConnectionHealthChecksTests."
+                        "test_health_checks_enabled",
+                        "backends.base.test_base.ConnectionHealthChecksTests."
+                        "test_set_autocommit_health_checks_enabled",
+                    },
+                }
+            )
+        return skips
 
     @cached_property
     def django_test_expected_failures(self):

+ 25 - 0
docs/ref/databases.txt

@@ -245,6 +245,31 @@ database configuration in :setting:`DATABASES`::
         },
     }
 
+.. _postgresql-pool:
+
+Connection pool
+---------------
+
+.. versionadded:: 5.1
+
+To use a connection pool with `psycopg`_, you can either set ``"pool"`` in the
+:setting:`OPTIONS` part of your database configuration in :setting:`DATABASES`
+to be a dict to be passed to :class:`~psycopg:psycopg_pool.ConnectionPool`, or
+to ``True`` to use the ``ConnectionPool`` defaults::
+
+    DATABASES = {
+        "default": {
+            "ENGINE": "django.db.backends.postgresql",
+            # ...
+            "OPTIONS": {
+                "pool": True,
+            },
+        },
+    }
+
+This option requires ``psycopg[pool]`` or :pypi:`psycopg-pool` to be installed
+and is ignored with ``psycopg2``.
+
 .. _database-server-side-parameters-binding:
 
 Server-side parameters binding

+ 3 - 0
docs/releases/5.1.txt

@@ -162,6 +162,9 @@ Database backends
   to allow specifying :ref:`pragma options <sqlite-init-command>` to set upon
   connection.
 
+* ``"pool"`` option is now supported in :setting:`OPTIONS` on PostgreSQL to
+  allow using :ref:`connection pools <postgresql-pool>`.
+
 Decorators
 ~~~~~~~~~~
 

+ 141 - 11
tests/backends/postgresql/tests.py

@@ -8,6 +8,7 @@ from django.db import (
     DEFAULT_DB_ALIAS,
     DatabaseError,
     NotSupportedError,
+    ProgrammingError,
     connection,
     connections,
 )
@@ -20,6 +21,15 @@ except ImportError:
     is_psycopg3 = False
 
 
+def no_pool_connection(alias=None):
+    new_connection = connection.copy(alias)
+    new_connection.settings_dict = copy.deepcopy(connection.settings_dict)
+    # Ensure that the second connection circumvents the pool, this is kind
+    # of a hack, but we cannot easily change the pool connections.
+    new_connection.settings_dict["OPTIONS"]["pool"] = False
+    return new_connection
+
+
 @unittest.skipUnless(connection.vendor == "postgresql", "PostgreSQL tests")
 class Tests(TestCase):
     databases = {"default", "other"}
@@ -177,7 +187,7 @@ class Tests(TestCase):
         PostgreSQL shouldn't roll back SET TIME ZONE, even if the first
         transaction is rolled back (#17062).
         """
-        new_connection = connection.copy()
+        new_connection = no_pool_connection()
         try:
             # Ensure the database default time zone is different than
             # the time zone in new_connection.settings_dict. We can
@@ -213,7 +223,7 @@ class Tests(TestCase):
         The connection wrapper shouldn't believe that autocommit is enabled
         after setting the time zone when AUTOCOMMIT is False (#21452).
         """
-        new_connection = connection.copy()
+        new_connection = no_pool_connection()
         new_connection.settings_dict["AUTOCOMMIT"] = False
 
         try:
@@ -223,6 +233,126 @@ class Tests(TestCase):
         finally:
             new_connection.close()
 
+    @unittest.skipUnless(is_psycopg3, "psycopg3 specific test")
+    def test_connect_pool(self):
+        from psycopg_pool import PoolTimeout
+
+        new_connection = no_pool_connection(alias="default_pool")
+        new_connection.settings_dict["OPTIONS"]["pool"] = {
+            "min_size": 0,
+            "max_size": 2,
+            "timeout": 0.1,
+        }
+        self.assertIsNotNone(new_connection.pool)
+
+        connections = []
+
+        def get_connection():
+            # copy() reuses the existing alias and as such the same pool.
+            conn = new_connection.copy()
+            conn.connect()
+            connections.append(conn)
+            return conn
+
+        try:
+            connection_1 = get_connection()  # First connection.
+            connection_1_backend_pid = connection_1.connection.info.backend_pid
+            get_connection()  # Get the second connection.
+            with self.assertRaises(PoolTimeout):
+                # The pool has a maximum of 2 connections.
+                get_connection()
+
+            connection_1.close()  # Release back to the pool.
+            connection_3 = get_connection()
+            # Reuses the first connection as it is available.
+            self.assertEqual(
+                connection_3.connection.info.backend_pid, connection_1_backend_pid
+            )
+        finally:
+            # Release all connections back to the pool.
+            for conn in connections:
+                conn.close()
+            new_connection.close_pool()
+
+    @unittest.skipUnless(is_psycopg3, "psycopg3 specific test")
+    def test_connect_pool_set_to_true(self):
+        new_connection = no_pool_connection(alias="default_pool")
+        new_connection.settings_dict["OPTIONS"]["pool"] = True
+        try:
+            self.assertIsNotNone(new_connection.pool)
+        finally:
+            new_connection.close_pool()
+
+    @unittest.skipUnless(is_psycopg3, "psycopg3 specific test")
+    def test_connect_pool_with_timezone(self):
+        new_time_zone = "Africa/Nairobi"
+        new_connection = no_pool_connection(alias="default_pool")
+
+        try:
+            with new_connection.cursor() as cursor:
+                cursor.execute("SHOW TIMEZONE")
+                tz = cursor.fetchone()[0]
+                self.assertNotEqual(new_time_zone, tz)
+        finally:
+            new_connection.close()
+
+        del new_connection.timezone_name
+        new_connection.settings_dict["OPTIONS"]["pool"] = True
+        try:
+            with self.settings(TIME_ZONE=new_time_zone):
+                with new_connection.cursor() as cursor:
+                    cursor.execute("SHOW TIMEZONE")
+                    tz = cursor.fetchone()[0]
+                    self.assertEqual(new_time_zone, tz)
+        finally:
+            new_connection.close()
+            new_connection.close_pool()
+
+    @unittest.skipUnless(is_psycopg3, "psycopg3 specific test")
+    def test_pooling_health_checks(self):
+        new_connection = no_pool_connection(alias="default_pool")
+        new_connection.settings_dict["OPTIONS"]["pool"] = True
+        new_connection.settings_dict["CONN_HEALTH_CHECKS"] = False
+
+        try:
+            self.assertIsNone(new_connection.pool._check)
+        finally:
+            new_connection.close_pool()
+
+        new_connection.settings_dict["CONN_HEALTH_CHECKS"] = True
+        try:
+            self.assertIsNotNone(new_connection.pool._check)
+        finally:
+            new_connection.close_pool()
+
+    @unittest.skipUnless(is_psycopg3, "psycopg3 specific test")
+    def test_cannot_open_new_connection_in_atomic_block(self):
+        new_connection = no_pool_connection(alias="default_pool")
+        new_connection.settings_dict["OPTIONS"]["pool"] = True
+
+        msg = "Cannot open a new connection in an atomic block."
+        new_connection.in_atomic_block = True
+        new_connection.closed_in_transaction = True
+        with self.assertRaisesMessage(ProgrammingError, msg):
+            new_connection.ensure_connection()
+
+    @unittest.skipUnless(is_psycopg3, "psycopg3 specific test")
+    def test_pooling_not_support_persistent_connections(self):
+        new_connection = no_pool_connection(alias="default_pool")
+        new_connection.settings_dict["OPTIONS"]["pool"] = True
+        new_connection.settings_dict["CONN_MAX_AGE"] = 10
+        msg = "Pooling doesn't support persistent connections."
+        with self.assertRaisesMessage(ImproperlyConfigured, msg):
+            new_connection.pool
+
+    @unittest.skipIf(is_psycopg3, "psycopg2 specific test")
+    def test_connect_pool_setting_ignored_for_psycopg2(self):
+        new_connection = no_pool_connection()
+        new_connection.settings_dict["OPTIONS"]["pool"] = True
+        msg = "Database pooling requires psycopg >= 3"
+        with self.assertRaisesMessage(ImproperlyConfigured, msg):
+            new_connection.connect()
+
     def test_connect_isolation_level(self):
         """
         The transaction level can be configured with
@@ -236,7 +366,7 @@ class Tests(TestCase):
         # Check the level on the psycopg connection, not the Django wrapper.
         self.assertIsNone(connection.connection.isolation_level)
 
-        new_connection = connection.copy()
+        new_connection = no_pool_connection()
         new_connection.settings_dict["OPTIONS"][
             "isolation_level"
         ] = IsolationLevel.SERIALIZABLE
@@ -253,7 +383,7 @@ class Tests(TestCase):
 
     def test_connect_invalid_isolation_level(self):
         self.assertIsNone(connection.connection.isolation_level)
-        new_connection = connection.copy()
+        new_connection = no_pool_connection()
         new_connection.settings_dict["OPTIONS"]["isolation_level"] = -1
         msg = (
             "Invalid transaction isolation level -1 specified. Use one of the "
@@ -269,7 +399,7 @@ class Tests(TestCase):
         """
         try:
             custom_role = "django_nonexistent_role"
-            new_connection = connection.copy()
+            new_connection = no_pool_connection()
             new_connection.settings_dict["OPTIONS"]["assume_role"] = custom_role
             msg = f'role "{custom_role}" does not exist'
             with self.assertRaisesMessage(errors.InvalidParameterValue, msg):
@@ -285,7 +415,7 @@ class Tests(TestCase):
         """
         from django.db.backends.postgresql.base import ServerBindingCursor
 
-        new_connection = connection.copy()
+        new_connection = no_pool_connection()
         new_connection.settings_dict["OPTIONS"]["server_side_binding"] = True
         try:
             new_connection.connect()
@@ -306,7 +436,7 @@ class Tests(TestCase):
         class MyCursor(Cursor):
             pass
 
-        new_connection = connection.copy()
+        new_connection = no_pool_connection()
         new_connection.settings_dict["OPTIONS"]["cursor_factory"] = MyCursor
         try:
             new_connection.connect()
@@ -315,7 +445,7 @@ class Tests(TestCase):
             new_connection.close()
 
     def test_connect_no_is_usable_checks(self):
-        new_connection = connection.copy()
+        new_connection = no_pool_connection()
         try:
             with mock.patch.object(new_connection, "is_usable") as is_usable:
                 new_connection.connect()
@@ -324,7 +454,7 @@ class Tests(TestCase):
             new_connection.close()
 
     def test_client_encoding_utf8_enforce(self):
-        new_connection = connection.copy()
+        new_connection = no_pool_connection()
         new_connection.settings_dict["OPTIONS"]["client_encoding"] = "iso-8859-2"
         try:
             new_connection.connect()
@@ -417,7 +547,7 @@ class Tests(TestCase):
         self.assertEqual([q["sql"] for q in connection.queries], [copy_sql])
 
     def test_get_database_version(self):
-        new_connection = connection.copy()
+        new_connection = no_pool_connection()
         new_connection.pg_version = 130009
         self.assertEqual(new_connection.get_database_version(), (13, 9))
 
@@ -429,7 +559,7 @@ class Tests(TestCase):
         self.assertTrue(mocked_get_database_version.called)
 
     def test_compose_sql_when_no_connection(self):
-        new_connection = connection.copy()
+        new_connection = no_pool_connection()
         try:
             self.assertEqual(
                 new_connection.ops.compose_sql("SELECT %s", ["test"]),

+ 1 - 0
tests/requirements/postgres.txt

@@ -1,2 +1,3 @@
 psycopg>=3.1.14; implementation_name == 'pypy'
 psycopg[binary]>=3.1.8; implementation_name != 'pypy'
+psycopg-pool>=3.2.0