Browse Source

Refs #28595 -- Added a hook to add execute wrappers for database queries.

Thanks Adam Johnson, Carl Meyer, Anssi Kääriäinen, Mariusz Felisiak,
Michael Manfre, and Tim Graham for discussion and review.
Shai Berger 7 years ago
parent
commit
d612026c37

+ 18 - 0
django/db/backends/base/base.py

@@ -92,6 +92,12 @@ class BaseDatabaseWrapper:
         # is called?
         self.run_commit_hooks_on_set_autocommit_on = False
 
+        # A stack of wrappers to be invoked around execute()/executemany()
+        # calls. Each entry is a function taking five arguments: execute, sql,
+        # params, many, and context. It's the function's responsibility to
+        # call execute(sql, params, many, context).
+        self.execute_wrappers = []
+
         self.client = self.client_class(self)
         self.creation = self.creation_class(self)
         self.features = self.features_class(self)
@@ -629,6 +635,18 @@ class BaseDatabaseWrapper:
             sids, func = current_run_on_commit.pop(0)
             func()
 
+    @contextmanager
+    def execute_wrapper(self, wrapper):
+        """
+        Return a context manager under which the wrapper is applied to suitable
+        database query executions.
+        """
+        self.execute_wrappers.append(wrapper)
+        try:
+            yield
+        finally:
+            self.execute_wrappers.pop()
+
     def copy(self, alias=None, allow_thread_sharing=None):
         """
         Return a copy of this connection.

+ 14 - 1
django/db/backends/utils.py

@@ -1,5 +1,6 @@
 import datetime
 import decimal
+import functools
 import hashlib
 import logging
 import re
@@ -65,6 +66,18 @@ class CursorWrapper:
                 return self.cursor.callproc(procname, params, kparams)
 
     def execute(self, sql, params=None):
+        return self._execute_with_wrappers(sql, params, many=False, executor=self._execute)
+
+    def executemany(self, sql, param_list):
+        return self._execute_with_wrappers(sql, param_list, many=True, executor=self._executemany)
+
+    def _execute_with_wrappers(self, sql, params, many, executor):
+        context = {'connection': self.db, 'cursor': self}
+        for wrapper in reversed(self.db.execute_wrappers):
+            executor = functools.partial(wrapper, executor)
+        return executor(sql, params, many, context)
+
+    def _execute(self, sql, params, *ignored_wrapper_args):
         self.db.validate_no_broken_transaction()
         with self.db.wrap_database_errors:
             if params is None:
@@ -72,7 +85,7 @@ class CursorWrapper:
             else:
                 return self.cursor.execute(sql, params)
 
-    def executemany(self, sql, param_list):
+    def _executemany(self, sql, param_list, *ignored_wrapper_args):
         self.db.validate_no_broken_transaction()
         with self.db.wrap_database_errors:
             return self.cursor.executemany(sql, param_list)

+ 5 - 0
docs/releases/2.0.txt

@@ -339,6 +339,11 @@ Models
   parameters, if the backend supports this feature. Of Django's built-in
   backends, only Oracle supports it.
 
+* The new :meth:`connection.execute_wrapper()
+  <django.db.backends.base.DatabaseWrapper.execute_wrapper>` method allows
+  :doc:`installing wrappers around execution of database queries
+  </topics/db/instrumentation>`.
+
 * The new ``filter`` argument for built-in aggregates allows :ref:`adding
   different conditionals <conditional-aggregation>` to multiple aggregations
   over the same fields or relations.

+ 1 - 0
docs/topics/db/index.txt

@@ -21,4 +21,5 @@ model maps to a single database table.
    multi-db
    tablespaces
    optimization
+   instrumentation
    examples/index

+ 114 - 0
docs/topics/db/instrumentation.txt

@@ -0,0 +1,114 @@
+========================
+Database instrumentation
+========================
+
+.. versionadded:: 2.0
+
+To help you understand and control the queries issued by your code, Django
+provides a hook for installing wrapper functions around the execution of
+database queries. For example, wrappers can count queries, measure query
+duration, log queries, or even prevent query execution (e.g. to make sure that
+no queries are issued while rendering a template with prefetched data).
+
+The wrappers are modeled after :doc:`middleware </topics/http/middleware>` --
+they are callables which take another callable as one of their arguments. They
+call that callable to invoke the (possibly wrapped) database query, and they
+can do what they want around that call. They are, however, created and
+installed by user code, and so don't need a separate factory like middleware do.
+
+Installing a wrapper is done in a context manager -- so the wrappers are
+temporary and specific to some flow in your code.
+
+As mentioned above, an example of a wrapper is a query execution blocker. It
+could look like this::
+
+    def blocker(*args):
+        raise Exception('No database access allowed here.')
+
+And it would be used in a view to block queries from the template like so::
+
+    from django.db import connection
+    from django.shortcuts import render
+
+    def my_view(request):
+        context = {...}  # Code to generate context with all data.
+        template_name = ...
+        with connection.execute_wrapper(blocker):
+            return render(request, template_name, context)
+
+The parameters sent to the wrappers are:
+
+* ``execute`` -- a callable, which should be invoked with the rest of the
+  parameters in order to execute the query.
+
+* ``sql`` -- a ``str``, the SQL query to be sent to the database.
+
+* ``params`` -- a list/tuple of parameter values for the SQL command, or a
+  list/tuple of lists/tuples if the wrapped call is ``executemany()``.
+
+* ``many`` -- a ``bool`` indicating whether the ultimately invoked call is
+  ``execute()`` or ``executemany()`` (and whether ``params`` is expected to be
+  a sequence of values, or a sequence of sequences of values).
+
+* ``context`` -- a dictionary with further data about the context of
+  invocation. This includes the connection and cursor.
+
+Using the parameters, a slightly more complex version of the blocker could
+include the connection name in the error message::
+
+    def blocker(execute, sql, params, many, context):
+        alias = context['connection'].alias
+        raise Exception("Access to database '{}' blocked here".format(alias))
+
+For a more complete example, a query logger could look like this::
+
+    import time
+
+    class QueryLogger:
+
+        def __init__(self):
+            self.queries = []
+
+        def __call__(self, execute, sql, params, many, context):
+            current_query = {'sql': sql, 'params': params, 'many': many}
+            start = time.time()
+            try:
+                result = execute(sql, params, many, context)
+            except Exception as e:
+                current_query['status'] = 'error'
+                current_query['exception'] = e
+                raise
+            else:
+                current_query['status'] = 'ok'
+                return result
+            finally:
+                duration = time.time() - start
+                current_query['duration'] = duration
+                self.queries.append(current_query)
+
+To use this, you would create a logger object and install it as a wrapper::
+
+    from django.db import connection
+
+    ql = QueryLogger()
+    with connection.execute_wrapper(ql):
+        do_queries()
+    # Now we can print the log.
+    print(ql.queries)
+
+.. currentmodule:: django.db.backends.base.DatabaseWrapper
+
+``connection.execute_wrapper()``
+--------------------------------
+
+.. method:: execute_wrapper(wrapper)
+
+Returns a context manager which, when entered, installs a wrapper around
+database query executions, and when exited, removes the wrapper. The wrapper is
+installed on the thread-local connection object.
+
+``wrapper`` is a callable taking five arguments.  It is called for every query
+execution in the scope of the context manager, with arguments ``execute``,
+``sql``, ``params``, ``many``, and ``context`` as described above. It's
+expected to call ``execute(sql, params, many, context)`` and return the return
+value of that call.

+ 98 - 1
tests/backends/base/test_base.py

@@ -1,6 +1,10 @@
+from unittest.mock import MagicMock
+
 from django.db import DEFAULT_DB_ALIAS, connection, connections
 from django.db.backends.base.base import BaseDatabaseWrapper
-from django.test import SimpleTestCase
+from django.test import SimpleTestCase, TestCase
+
+from ..models import Square
 
 
 class DatabaseWrapperTests(SimpleTestCase):
@@ -30,3 +34,96 @@ class DatabaseWrapperTests(SimpleTestCase):
     def test_initialization_display_name(self):
         self.assertEqual(BaseDatabaseWrapper.display_name, 'unknown')
         self.assertNotEqual(connection.display_name, 'unknown')
+
+
+class ExecuteWrapperTests(TestCase):
+
+    @staticmethod
+    def call_execute(connection, params=None):
+        ret_val = '1' if params is None else '%s'
+        sql = 'SELECT ' + ret_val + connection.features.bare_select_suffix
+        with connection.cursor() as cursor:
+            cursor.execute(sql, params)
+
+    def call_executemany(self, connection, params=None):
+        # executemany() must use an update query. Make sure it does nothing
+        # by putting a false condition in the WHERE clause.
+        sql = 'DELETE FROM {} WHERE 0=1 AND 0=%s'.format(Square._meta.db_table)
+        if params is None:
+            params = [(i,) for i in range(3)]
+        with connection.cursor() as cursor:
+            cursor.executemany(sql, params)
+
+    @staticmethod
+    def mock_wrapper():
+        return MagicMock(side_effect=lambda execute, *args: execute(*args))
+
+    def test_wrapper_invoked(self):
+        wrapper = self.mock_wrapper()
+        with connection.execute_wrapper(wrapper):
+            self.call_execute(connection)
+        self.assertTrue(wrapper.called)
+        (_, sql, params, many, context), _ = wrapper.call_args
+        self.assertIn('SELECT', sql)
+        self.assertIsNone(params)
+        self.assertIs(many, False)
+        self.assertEqual(context['connection'], connection)
+
+    def test_wrapper_invoked_many(self):
+        wrapper = self.mock_wrapper()
+        with connection.execute_wrapper(wrapper):
+            self.call_executemany(connection)
+        self.assertTrue(wrapper.called)
+        (_, sql, param_list, many, context), _ = wrapper.call_args
+        self.assertIn('DELETE', sql)
+        self.assertIsInstance(param_list, (list, tuple))
+        self.assertIs(many, True)
+        self.assertEqual(context['connection'], connection)
+
+    def test_database_queried(self):
+        wrapper = self.mock_wrapper()
+        with connection.execute_wrapper(wrapper):
+            with connection.cursor() as cursor:
+                sql = 'SELECT 17' + connection.features.bare_select_suffix
+                cursor.execute(sql)
+                seventeen = cursor.fetchall()
+                self.assertEqual(list(seventeen), [(17,)])
+            self.call_executemany(connection)
+
+    def test_nested_wrapper_invoked(self):
+        outer_wrapper = self.mock_wrapper()
+        inner_wrapper = self.mock_wrapper()
+        with connection.execute_wrapper(outer_wrapper), connection.execute_wrapper(inner_wrapper):
+            self.call_execute(connection)
+            self.assertEqual(inner_wrapper.call_count, 1)
+            self.call_executemany(connection)
+            self.assertEqual(inner_wrapper.call_count, 2)
+
+    def test_outer_wrapper_blocks(self):
+        def blocker(*args):
+            pass
+        wrapper = self.mock_wrapper()
+        c = connection  # This alias shortens the next line.
+        with c.execute_wrapper(wrapper), c.execute_wrapper(blocker), c.execute_wrapper(wrapper):
+            with c.cursor() as cursor:
+                cursor.execute("The database never sees this")
+                self.assertEqual(wrapper.call_count, 1)
+                cursor.executemany("The database never sees this %s", [("either",)])
+                self.assertEqual(wrapper.call_count, 2)
+
+    def test_wrapper_gets_sql(self):
+        wrapper = self.mock_wrapper()
+        sql = "SELECT 'aloha'" + connection.features.bare_select_suffix
+        with connection.execute_wrapper(wrapper), connection.cursor() as cursor:
+            cursor.execute(sql)
+        (_, reported_sql, _, _, _), _ = wrapper.call_args
+        self.assertEqual(reported_sql, sql)
+
+    def test_wrapper_connection_specific(self):
+        wrapper = self.mock_wrapper()
+        with connections['other'].execute_wrapper(wrapper):
+            self.assertEqual(connections['other'].execute_wrappers, [wrapper])
+            self.call_execute(connection)
+        self.assertFalse(wrapper.called)
+        self.assertEqual(connection.execute_wrappers, [])
+        self.assertEqual(connections['other'].execute_wrappers, [])