|
@@ -1,7 +1,7 @@
|
|
|
"""
|
|
|
PostgreSQL database backend for Django.
|
|
|
|
|
|
-Requires psycopg 2: https://www.psycopg.org/
|
|
|
+Requires psycopg2 >= 2.8.4 or psycopg >= 3.1
|
|
|
"""
|
|
|
|
|
|
import asyncio
|
|
@@ -21,48 +21,63 @@ from django.utils.safestring import SafeString
|
|
|
from django.utils.version import get_version_tuple
|
|
|
|
|
|
try:
|
|
|
- import psycopg2 as Database
|
|
|
- import psycopg2.extensions
|
|
|
- import psycopg2.extras
|
|
|
-except ImportError as e:
|
|
|
- raise ImproperlyConfigured("Error loading psycopg2 module: %s" % e)
|
|
|
+ try:
|
|
|
+ import psycopg as Database
|
|
|
+ except ImportError:
|
|
|
+ import psycopg2 as Database
|
|
|
+except ImportError:
|
|
|
+ raise ImproperlyConfigured("Error loading psycopg2 or psycopg module")
|
|
|
|
|
|
|
|
|
-def psycopg2_version():
|
|
|
- version = psycopg2.__version__.split(" ", 1)[0]
|
|
|
+def psycopg_version():
|
|
|
+ version = Database.__version__.split(" ", 1)[0]
|
|
|
return get_version_tuple(version)
|
|
|
|
|
|
|
|
|
-PSYCOPG2_VERSION = psycopg2_version()
|
|
|
-
|
|
|
-if PSYCOPG2_VERSION < (2, 8, 4):
|
|
|
+if psycopg_version() < (2, 8, 4):
|
|
|
+ raise ImproperlyConfigured(
|
|
|
+ f"psycopg2 version 2.8.4 or newer is required; you have {Database.__version__}"
|
|
|
+ )
|
|
|
+if (3,) <= psycopg_version() < (3, 1):
|
|
|
raise ImproperlyConfigured(
|
|
|
- "psycopg2 version 2.8.4 or newer is required; you have %s"
|
|
|
- % psycopg2.__version__
|
|
|
+ f"psycopg version 3.1 or newer is required; you have {Database.__version__}"
|
|
|
)
|
|
|
|
|
|
|
|
|
-# Some of these import psycopg2, so import them after checking if it's installed.
|
|
|
-from .client import DatabaseClient # NOQA
|
|
|
-from .creation import DatabaseCreation # NOQA
|
|
|
-from .features import DatabaseFeatures # NOQA
|
|
|
-from .introspection import DatabaseIntrospection # NOQA
|
|
|
-from .operations import DatabaseOperations # NOQA
|
|
|
-from .psycopg_any import IsolationLevel # NOQA
|
|
|
-from .schema import DatabaseSchemaEditor # NOQA
|
|
|
+from .psycopg_any import IsolationLevel, is_psycopg3 # NOQA isort:skip
|
|
|
+
|
|
|
+if is_psycopg3:
|
|
|
+ from psycopg import adapters, sql
|
|
|
+ from psycopg.pq import Format
|
|
|
|
|
|
-psycopg2.extensions.register_adapter(SafeString, psycopg2.extensions.QuotedString)
|
|
|
-psycopg2.extras.register_uuid()
|
|
|
+ from .psycopg_any import get_adapters_template, register_tzloader
|
|
|
|
|
|
-# Register support for inet[] manually so we don't have to handle the Inet()
|
|
|
-# object on load all the time.
|
|
|
-INETARRAY_OID = 1041
|
|
|
-INETARRAY = psycopg2.extensions.new_array_type(
|
|
|
- (INETARRAY_OID,),
|
|
|
- "INETARRAY",
|
|
|
- psycopg2.extensions.UNICODE,
|
|
|
-)
|
|
|
-psycopg2.extensions.register_type(INETARRAY)
|
|
|
+ TIMESTAMPTZ_OID = adapters.types["timestamptz"].oid
|
|
|
+
|
|
|
+else:
|
|
|
+ import psycopg2.extensions
|
|
|
+ import psycopg2.extras
|
|
|
+
|
|
|
+ psycopg2.extensions.register_adapter(SafeString, psycopg2.extensions.QuotedString)
|
|
|
+ psycopg2.extras.register_uuid()
|
|
|
+
|
|
|
+ # Register support for inet[] manually so we don't have to handle the Inet()
|
|
|
+ # object on load all the time.
|
|
|
+ INETARRAY_OID = 1041
|
|
|
+ INETARRAY = psycopg2.extensions.new_array_type(
|
|
|
+ (INETARRAY_OID,),
|
|
|
+ "INETARRAY",
|
|
|
+ psycopg2.extensions.UNICODE,
|
|
|
+ )
|
|
|
+ psycopg2.extensions.register_type(INETARRAY)
|
|
|
+
|
|
|
+# Some of these import psycopg, so import them after checking if it's installed.
|
|
|
+from .client import DatabaseClient # NOQA isort:skip
|
|
|
+from .creation import DatabaseCreation # NOQA isort:skip
|
|
|
+from .features import DatabaseFeatures # NOQA isort:skip
|
|
|
+from .introspection import DatabaseIntrospection # NOQA isort:skip
|
|
|
+from .operations import DatabaseOperations # NOQA isort:skip
|
|
|
+from .schema import DatabaseSchemaEditor # NOQA isort:skip
|
|
|
|
|
|
|
|
|
class DatabaseWrapper(BaseDatabaseWrapper):
|
|
@@ -209,6 +224,15 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
|
|
conn_params["host"] = settings_dict["HOST"]
|
|
|
if settings_dict["PORT"]:
|
|
|
conn_params["port"] = settings_dict["PORT"]
|
|
|
+ if is_psycopg3:
|
|
|
+ conn_params["context"] = get_adapters_template(
|
|
|
+ settings.USE_TZ, self.timezone
|
|
|
+ )
|
|
|
+ # Disable prepared statements by default to keep connection poolers
|
|
|
+ # working. Can be reenabled via OPTIONS in the settings dict.
|
|
|
+ conn_params["prepare_threshold"] = conn_params.pop(
|
|
|
+ "prepare_threshold", None
|
|
|
+ )
|
|
|
return conn_params
|
|
|
|
|
|
@async_unsafe
|
|
@@ -232,17 +256,19 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
|
|
except ValueError:
|
|
|
raise ImproperlyConfigured(
|
|
|
f"Invalid transaction isolation level {isolation_level_value} "
|
|
|
- f"specified. Use one of the IsolationLevel values."
|
|
|
+ f"specified. Use one of the psycopg.IsolationLevel values."
|
|
|
)
|
|
|
- connection = Database.connect(**conn_params)
|
|
|
+ connection = self.Database.connect(**conn_params)
|
|
|
if set_isolation_level:
|
|
|
connection.isolation_level = self.isolation_level
|
|
|
- # Register dummy loads() to avoid a round trip from psycopg2's decode
|
|
|
- # to json.dumps() to json.loads(), when using a custom decoder in
|
|
|
- # JSONField.
|
|
|
- psycopg2.extras.register_default_jsonb(
|
|
|
- conn_or_curs=connection, loads=lambda x: x
|
|
|
- )
|
|
|
+ if not is_psycopg3:
|
|
|
+ # Register dummy loads() to avoid a round trip from psycopg2's
|
|
|
+ # decode to json.dumps() to json.loads(), when using a custom
|
|
|
+ # decoder in JSONField.
|
|
|
+ psycopg2.extras.register_default_jsonb(
|
|
|
+ conn_or_curs=connection, loads=lambda x: x
|
|
|
+ )
|
|
|
+ connection.cursor_factory = Cursor
|
|
|
return connection
|
|
|
|
|
|
def ensure_timezone(self):
|
|
@@ -275,7 +301,15 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
|
|
)
|
|
|
else:
|
|
|
cursor = self.connection.cursor()
|
|
|
- cursor.tzinfo_factory = self.tzinfo_factory if settings.USE_TZ else None
|
|
|
+
|
|
|
+ if is_psycopg3:
|
|
|
+ # Register the cursor timezone only if the connection disagrees, to
|
|
|
+ # avoid copying the adapter map.
|
|
|
+ tzloader = self.connection.adapters.get_loader(TIMESTAMPTZ_OID, Format.TEXT)
|
|
|
+ if self.timezone != tzloader.timezone:
|
|
|
+ register_tzloader(self.timezone, cursor)
|
|
|
+ else:
|
|
|
+ cursor.tzinfo_factory = self.tzinfo_factory if settings.USE_TZ else None
|
|
|
return cursor
|
|
|
|
|
|
def tzinfo_factory(self, offset):
|
|
@@ -379,11 +413,43 @@ class DatabaseWrapper(BaseDatabaseWrapper):
|
|
|
return CursorDebugWrapper(cursor, self)
|
|
|
|
|
|
|
|
|
-class CursorDebugWrapper(BaseCursorDebugWrapper):
|
|
|
- def copy_expert(self, sql, file, *args):
|
|
|
- with self.debug_sql(sql):
|
|
|
- return self.cursor.copy_expert(sql, file, *args)
|
|
|
+if is_psycopg3:
|
|
|
+
|
|
|
+ class Cursor(Database.Cursor):
|
|
|
+ """
|
|
|
+ A subclass of psycopg cursor implementing callproc.
|
|
|
+ """
|
|
|
+
|
|
|
+ def callproc(self, name, args=None):
|
|
|
+ if not isinstance(name, sql.Identifier):
|
|
|
+ name = sql.Identifier(name)
|
|
|
+
|
|
|
+ qparts = [sql.SQL("SELECT * FROM "), name, sql.SQL("(")]
|
|
|
+ if args:
|
|
|
+ for item in args:
|
|
|
+ qparts.append(sql.Literal(item))
|
|
|
+ qparts.append(sql.SQL(","))
|
|
|
+ del qparts[-1]
|
|
|
+
|
|
|
+ qparts.append(sql.SQL(")"))
|
|
|
+ stmt = sql.Composed(qparts)
|
|
|
+ self.execute(stmt)
|
|
|
+ return args
|
|
|
+
|
|
|
+ class CursorDebugWrapper(BaseCursorDebugWrapper):
|
|
|
+ def copy(self, statement):
|
|
|
+ with self.debug_sql(statement):
|
|
|
+ return self.cursor.copy(statement)
|
|
|
+
|
|
|
+else:
|
|
|
+
|
|
|
+ Cursor = psycopg2.extensions.cursor
|
|
|
+
|
|
|
+ class CursorDebugWrapper(BaseCursorDebugWrapper):
|
|
|
+ def copy_expert(self, sql, file, *args):
|
|
|
+ with self.debug_sql(sql):
|
|
|
+ return self.cursor.copy_expert(sql, file, *args)
|
|
|
|
|
|
- def copy_to(self, file, table, *args, **kwargs):
|
|
|
- with self.debug_sql(sql="COPY %s TO STDOUT" % table):
|
|
|
- return self.cursor.copy_to(file, table, *args, **kwargs)
|
|
|
+ def copy_to(self, file, table, *args, **kwargs):
|
|
|
+ with self.debug_sql(sql="COPY %s TO STDOUT" % table):
|
|
|
+ return self.cursor.copy_to(file, table, *args, **kwargs)
|