소스 검색

Refs #33308 -- Added get_type_oids() hook and simplified registering type handlers on PostgreSQL.

Daniele Varrazzo 2 년 전
부모
커밋
d3e746ace5
2개의 변경된 파일27개의 추가작업 그리고 33개의 파일을 삭제
  1. 25 32
      django/contrib/postgres/signals.py
  2. 2 1
      tests/postgres_tests/test_signals.py

+ 25 - 32
django/contrib/postgres/signals.py

@@ -1,22 +1,16 @@
 import functools
 
 import psycopg2
-from psycopg2 import ProgrammingError
 from psycopg2.extras import register_hstore
 
 from django.db import connections
 from django.db.backends.base.base import NO_DB_ALIAS
 
 
-@functools.lru_cache
-def get_hstore_oids(connection_alias):
-    """Return hstore and hstore array OIDs."""
+def get_type_oids(connection_alias, type_name):
     with connections[connection_alias].cursor() as cursor:
         cursor.execute(
-            "SELECT t.oid, typarray "
-            "FROM pg_type t "
-            "JOIN pg_namespace ns ON typnamespace = ns.oid "
-            "WHERE typname = 'hstore'"
+            "SELECT oid, typarray FROM pg_type WHERE typname = %s", (type_name,)
         )
         oids = []
         array_oids = []
@@ -26,43 +20,42 @@ def get_hstore_oids(connection_alias):
         return tuple(oids), tuple(array_oids)
 
 
+@functools.lru_cache
+def get_hstore_oids(connection_alias):
+    """Return hstore and hstore array OIDs."""
+    return get_type_oids(connection_alias, "hstore")
+
+
 @functools.lru_cache
 def get_citext_oids(connection_alias):
-    """Return citext array OIDs."""
-    with connections[connection_alias].cursor() as cursor:
-        cursor.execute("SELECT typarray FROM pg_type WHERE typname = 'citext'")
-        return tuple(row[0] for row in cursor)
+    """Return citext and citext array OIDs."""
+    return get_type_oids(connection_alias, "citext")
 
 
 def register_type_handlers(connection, **kwargs):
     if connection.vendor != "postgresql" or connection.alias == NO_DB_ALIAS:
         return
 
-    try:
-        oids, array_oids = get_hstore_oids(connection.alias)
+    oids, array_oids = get_hstore_oids(connection.alias)
+    # Don't register handlers when hstore is not available on the database.
+    #
+    # If someone tries to create an hstore field it will error there. This is
+    # necessary as someone may be using PSQL without extensions installed but
+    # be using other features of contrib.postgres.
+    #
+    # This is also needed in order to create the connection in order to install
+    # the hstore extension.
+    if oids:
         register_hstore(
             connection.connection, globally=True, oid=oids, array_oid=array_oids
         )
-    except ProgrammingError:
-        # Hstore is not available on the database.
-        #
-        # If someone tries to create an hstore field it will error there.
-        # This is necessary as someone may be using PSQL without extensions
-        # installed but be using other features of contrib.postgres.
-        #
-        # This is also needed in order to create the connection in order to
-        # install the hstore extension.
-        pass
 
-    try:
-        citext_oids = get_citext_oids(connection.alias)
+    oids, citext_oids = get_citext_oids(connection.alias)
+    # Don't register handlers when citext is not available on the database.
+    #
+    # The same comments in the above call to register_hstore() also apply here.
+    if oids:
         array_type = psycopg2.extensions.new_array_type(
             citext_oids, "citext[]", psycopg2.STRING
         )
         psycopg2.extensions.register_type(array_type, None)
-    except ProgrammingError:
-        # citext is not available on the database.
-        #
-        # The same comments in the except block of the above call to
-        # register_hstore() also apply here.
-        pass

+ 2 - 1
tests/postgres_tests/test_signals.py

@@ -34,8 +34,9 @@ class OIDTests(PostgreSQLTestCase):
         self.assertOIDs(array_oids)
 
     def test_citext_values(self):
-        oids = get_citext_oids(connection.alias)
+        oids, citext_oids = get_citext_oids(connection.alias)
         self.assertOIDs(oids)
+        self.assertOIDs(citext_oids)
 
     def test_register_type_handlers_no_db(self):
         """Registering type handlers for the nodb connection does nothing."""