|
@@ -1,6 +1,10 @@
|
|
|
+from unittest.mock import MagicMock
|
|
|
+
|
|
|
from django.db import DEFAULT_DB_ALIAS, connection, connections
|
|
|
from django.db.backends.base.base import BaseDatabaseWrapper
|
|
|
-from django.test import SimpleTestCase
|
|
|
+from django.test import SimpleTestCase, TestCase
|
|
|
+
|
|
|
+from ..models import Square
|
|
|
|
|
|
|
|
|
class DatabaseWrapperTests(SimpleTestCase):
|
|
@@ -30,3 +34,96 @@ class DatabaseWrapperTests(SimpleTestCase):
|
|
|
def test_initialization_display_name(self):
|
|
|
self.assertEqual(BaseDatabaseWrapper.display_name, 'unknown')
|
|
|
self.assertNotEqual(connection.display_name, 'unknown')
|
|
|
+
|
|
|
+
|
|
|
+class ExecuteWrapperTests(TestCase):
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def call_execute(connection, params=None):
|
|
|
+ ret_val = '1' if params is None else '%s'
|
|
|
+ sql = 'SELECT ' + ret_val + connection.features.bare_select_suffix
|
|
|
+ with connection.cursor() as cursor:
|
|
|
+ cursor.execute(sql, params)
|
|
|
+
|
|
|
+ def call_executemany(self, connection, params=None):
|
|
|
+ # executemany() must use an update query. Make sure it does nothing
|
|
|
+ # by putting a false condition in the WHERE clause.
|
|
|
+ sql = 'DELETE FROM {} WHERE 0=1 AND 0=%s'.format(Square._meta.db_table)
|
|
|
+ if params is None:
|
|
|
+ params = [(i,) for i in range(3)]
|
|
|
+ with connection.cursor() as cursor:
|
|
|
+ cursor.executemany(sql, params)
|
|
|
+
|
|
|
+ @staticmethod
|
|
|
+ def mock_wrapper():
|
|
|
+ return MagicMock(side_effect=lambda execute, *args: execute(*args))
|
|
|
+
|
|
|
+ def test_wrapper_invoked(self):
|
|
|
+ wrapper = self.mock_wrapper()
|
|
|
+ with connection.execute_wrapper(wrapper):
|
|
|
+ self.call_execute(connection)
|
|
|
+ self.assertTrue(wrapper.called)
|
|
|
+ (_, sql, params, many, context), _ = wrapper.call_args
|
|
|
+ self.assertIn('SELECT', sql)
|
|
|
+ self.assertIsNone(params)
|
|
|
+ self.assertIs(many, False)
|
|
|
+ self.assertEqual(context['connection'], connection)
|
|
|
+
|
|
|
+ def test_wrapper_invoked_many(self):
|
|
|
+ wrapper = self.mock_wrapper()
|
|
|
+ with connection.execute_wrapper(wrapper):
|
|
|
+ self.call_executemany(connection)
|
|
|
+ self.assertTrue(wrapper.called)
|
|
|
+ (_, sql, param_list, many, context), _ = wrapper.call_args
|
|
|
+ self.assertIn('DELETE', sql)
|
|
|
+ self.assertIsInstance(param_list, (list, tuple))
|
|
|
+ self.assertIs(many, True)
|
|
|
+ self.assertEqual(context['connection'], connection)
|
|
|
+
|
|
|
+ def test_database_queried(self):
|
|
|
+ wrapper = self.mock_wrapper()
|
|
|
+ with connection.execute_wrapper(wrapper):
|
|
|
+ with connection.cursor() as cursor:
|
|
|
+ sql = 'SELECT 17' + connection.features.bare_select_suffix
|
|
|
+ cursor.execute(sql)
|
|
|
+ seventeen = cursor.fetchall()
|
|
|
+ self.assertEqual(list(seventeen), [(17,)])
|
|
|
+ self.call_executemany(connection)
|
|
|
+
|
|
|
+ def test_nested_wrapper_invoked(self):
|
|
|
+ outer_wrapper = self.mock_wrapper()
|
|
|
+ inner_wrapper = self.mock_wrapper()
|
|
|
+ with connection.execute_wrapper(outer_wrapper), connection.execute_wrapper(inner_wrapper):
|
|
|
+ self.call_execute(connection)
|
|
|
+ self.assertEqual(inner_wrapper.call_count, 1)
|
|
|
+ self.call_executemany(connection)
|
|
|
+ self.assertEqual(inner_wrapper.call_count, 2)
|
|
|
+
|
|
|
+ def test_outer_wrapper_blocks(self):
|
|
|
+ def blocker(*args):
|
|
|
+ pass
|
|
|
+ wrapper = self.mock_wrapper()
|
|
|
+ c = connection # This alias shortens the next line.
|
|
|
+ with c.execute_wrapper(wrapper), c.execute_wrapper(blocker), c.execute_wrapper(wrapper):
|
|
|
+ with c.cursor() as cursor:
|
|
|
+ cursor.execute("The database never sees this")
|
|
|
+ self.assertEqual(wrapper.call_count, 1)
|
|
|
+ cursor.executemany("The database never sees this %s", [("either",)])
|
|
|
+ self.assertEqual(wrapper.call_count, 2)
|
|
|
+
|
|
|
+ def test_wrapper_gets_sql(self):
|
|
|
+ wrapper = self.mock_wrapper()
|
|
|
+ sql = "SELECT 'aloha'" + connection.features.bare_select_suffix
|
|
|
+ with connection.execute_wrapper(wrapper), connection.cursor() as cursor:
|
|
|
+ cursor.execute(sql)
|
|
|
+ (_, reported_sql, _, _, _), _ = wrapper.call_args
|
|
|
+ self.assertEqual(reported_sql, sql)
|
|
|
+
|
|
|
+ def test_wrapper_connection_specific(self):
|
|
|
+ wrapper = self.mock_wrapper()
|
|
|
+ with connections['other'].execute_wrapper(wrapper):
|
|
|
+ self.assertEqual(connections['other'].execute_wrappers, [wrapper])
|
|
|
+ self.call_execute(connection)
|
|
|
+ self.assertFalse(wrapper.called)
|
|
|
+ self.assertEqual(connection.execute_wrappers, [])
|
|
|
+ self.assertEqual(connections['other'].execute_wrappers, [])
|