|
@@ -1,5 +1,7 @@
|
|
|
+import operator
|
|
|
import unittest
|
|
|
from collections import namedtuple
|
|
|
+from contextlib import contextmanager
|
|
|
|
|
|
from django.db import connection
|
|
|
from django.test import TestCase
|
|
@@ -23,6 +25,18 @@ class ServerSideCursorsPostgres(TestCase):
|
|
|
cursors = cursor.fetchall()
|
|
|
return [self.PostgresCursor._make(cursor) for cursor in cursors]
|
|
|
|
|
|
+ @contextmanager
|
|
|
+ def override_db_setting(self, **kwargs):
|
|
|
+ for setting, value in kwargs.items():
|
|
|
+ original_value = connection.settings_dict.get(setting)
|
|
|
+ if setting in connection.settings_dict:
|
|
|
+ self.addCleanup(operator.setitem, connection.settings_dict, setting, original_value)
|
|
|
+ else:
|
|
|
+ self.addCleanup(operator.delitem, connection.settings_dict, setting)
|
|
|
+
|
|
|
+ connection.settings_dict[setting] = kwargs[setting]
|
|
|
+ yield
|
|
|
+
|
|
|
def test_server_side_cursor(self):
|
|
|
persons = Person.objects.iterator()
|
|
|
next(persons) # Open a server-side cursor
|
|
@@ -52,3 +66,17 @@ class ServerSideCursorsPostgres(TestCase):
|
|
|
del persons
|
|
|
cursors = self.inspect_cursors()
|
|
|
self.assertEqual(len(cursors), 0)
|
|
|
+
|
|
|
+ def test_server_side_cursors_setting(self):
|
|
|
+ with self.override_db_setting(DISABLE_SERVER_SIDE_CURSORS=False):
|
|
|
+ persons = Person.objects.iterator()
|
|
|
+ next(persons) # Open a server-side cursor
|
|
|
+ cursors = self.inspect_cursors()
|
|
|
+ self.assertEqual(len(cursors), 1)
|
|
|
+ del persons # Close server-side cursor
|
|
|
+
|
|
|
+ with self.override_db_setting(DISABLE_SERVER_SIDE_CURSORS=True):
|
|
|
+ persons = Person.objects.iterator()
|
|
|
+ next(persons) # Should not open a server-side cursor
|
|
|
+ cursors = self.inspect_cursors()
|
|
|
+ self.assertEqual(len(cursors), 0)
|