123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434 |
- import datetime
- import decimal
- import uuid
- from functools import lru_cache
- from itertools import chain
- from django.conf import settings
- from django.core.exceptions import FieldError
- from django.db import DatabaseError, NotSupportedError, models
- from django.db.backends.base.operations import BaseDatabaseOperations
- from django.db.models.constants import OnConflict
- from django.db.models.expressions import Col
- from django.utils import timezone
- from django.utils.dateparse import parse_date, parse_datetime, parse_time
- from django.utils.functional import cached_property
class DatabaseOperations(BaseDatabaseOperations):
    """
    SQLite overrides of the hooks inherited from BaseDatabaseOperations.

    Several methods emit calls to user-defined SQL functions
    (django_date_extract(), django_format_dtdelta(), ...). Per the comments
    in this class those functions are registered in connect(), which lives
    outside this class.
    """

    cast_char_field_without_max_length = "text"
    cast_data_types = {
        "DateField": "TEXT",
        "DateTimeField": "TEXT",
    }
    explain_prefix = "EXPLAIN QUERY PLAN"
    # List of datatypes that cannot be extracted with JSON_EXTRACT() on
    # SQLite. Use JSON_TYPE() instead.
    jsonfield_datatype_values = frozenset(["null", "false", "true"])

    def bulk_batch_size(self, fields, objs):
        """
        SQLite has a compile-time default (SQLITE_LIMIT_VARIABLE_NUMBER) of
        999 variables per query.

        If there's only a single field to insert, the limit is 500
        (SQLITE_MAX_COMPOUND_SELECT).

        With several fields the batch size is the connection's
        max_query_params divided by the number of fields; with no fields all
        objects fit in one batch.
        """
        if len(fields) == 1:
            return 500
        elif len(fields) > 1:
            return self.connection.features.max_query_params // len(fields)
        else:
            return len(objs)

    def check_expression_support(self, expression):
        """
        Raise NotSupportedError for expressions SQLite cannot evaluate:
        Sum/Avg/Variance/StdDev over date/time fields, and DISTINCT
        aggregates taking more than one source expression.
        """
        bad_fields = (models.DateField, models.DateTimeField, models.TimeField)
        bad_aggregates = (models.Sum, models.Avg, models.Variance, models.StdDev)
        if isinstance(expression, bad_aggregates):
            for expr in expression.get_source_expressions():
                try:
                    output_field = expr.output_field
                except (AttributeError, FieldError):
                    # Not every subexpression has an output_field which is fine
                    # to ignore.
                    pass
                else:
                    if isinstance(output_field, bad_fields):
                        raise NotSupportedError(
                            "You cannot use Sum, Avg, StdDev, and Variance "
                            "aggregations on date/time fields in sqlite3 "
                            "since date/time is saved as text."
                        )
        if (
            isinstance(expression, models.Aggregate)
            and expression.distinct
            and len(expression.source_expressions) > 1
        ):
            raise NotSupportedError(
                "SQLite doesn't support DISTINCT on aggregate functions "
                "accepting multiple arguments."
            )

    def date_extract_sql(self, lookup_type, sql, params):
        """
        Support EXTRACT with a user-defined function django_date_extract()
        that's registered in connect(). The lowercased lookup type is passed
        as the first query parameter rather than being inlined into the SQL.

        Return a (sql, params) pair.
        """
        return f"django_date_extract(%s, {sql})", (lookup_type.lower(), *params)

    def fetch_returned_insert_rows(self, cursor):
        """
        Given a cursor object that has just performed an INSERT...RETURNING
        statement into a table, return the list of returned data.
        """
        return cursor.fetchall()

    def format_for_duration_arithmetic(self, sql):
        """Do nothing since formatting is handled in the custom function."""
        return sql

    def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
        """
        Return a (sql, params) pair calling django_date_trunc() with the
        lowercased lookup type, the wrapped expression and the two timezone
        names produced by _convert_tznames_to_sql().
        """
        return f"django_date_trunc(%s, {sql}, %s, %s)", (
            lookup_type.lower(),
            *params,
            *self._convert_tznames_to_sql(tzname),
        )

    def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
        """
        Return a (sql, params) pair calling django_time_trunc() with the
        lowercased lookup type, the wrapped expression and the two timezone
        names produced by _convert_tznames_to_sql().
        """
        return f"django_time_trunc(%s, {sql}, %s, %s)", (
            lookup_type.lower(),
            *params,
            *self._convert_tznames_to_sql(tzname),
        )

    def _convert_tznames_to_sql(self, tzname):
        """
        Return (tzname, connection timezone name) when a tzname is given and
        USE_TZ is enabled, otherwise (None, None).
        """
        if tzname and settings.USE_TZ:
            return tzname, self.connection.timezone_name
        return None, None

    def datetime_cast_date_sql(self, sql, params, tzname):
        """Return a (sql, params) pair calling django_datetime_cast_date()."""
        return f"django_datetime_cast_date({sql}, %s, %s)", (
            *params,
            *self._convert_tznames_to_sql(tzname),
        )

    def datetime_cast_time_sql(self, sql, params, tzname):
        """Return a (sql, params) pair calling django_datetime_cast_time()."""
        return f"django_datetime_cast_time({sql}, %s, %s)", (
            *params,
            *self._convert_tznames_to_sql(tzname),
        )

    def datetime_extract_sql(self, lookup_type, sql, params, tzname):
        """
        Return a (sql, params) pair calling django_datetime_extract() with
        the lowercased lookup type as the first parameter.
        """
        return f"django_datetime_extract(%s, {sql}, %s, %s)", (
            lookup_type.lower(),
            *params,
            *self._convert_tznames_to_sql(tzname),
        )

    def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
        """
        Return a (sql, params) pair calling django_datetime_trunc() with the
        lowercased lookup type as the first parameter.
        """
        return f"django_datetime_trunc(%s, {sql}, %s, %s)", (
            lookup_type.lower(),
            *params,
            *self._convert_tznames_to_sql(tzname),
        )

    def time_extract_sql(self, lookup_type, sql, params):
        """
        Return a (sql, params) pair calling django_time_extract() with the
        lowercased lookup type as the first parameter.
        """
        return f"django_time_extract(%s, {sql})", (lookup_type.lower(), *params)

    def pk_default_value(self):
        """Return the SQL keyword used as the primary key default: NULL."""
        return "NULL"

    def _quote_params_for_last_executed_query(self, params):
        """
        Only for last_executed_query! Don't use this to execute SQL queries!

        Return a tuple with the SQL-quoted form of each value in params,
        obtained by running SELECT QUOTE(?), ... on the raw connection.
        """
        # "SELECT " with no expressions isn't valid SQL; nothing to quote.
        if not params:
            return ()
        # This function is limited both by SQLITE_LIMIT_VARIABLE_NUMBER (the
        # number of parameters, default = 999) and SQLITE_MAX_COLUMN (the
        # number of return values, default = 2000). Since Python's sqlite3
        # module doesn't expose the get_limit() C API, assume the default
        # limits are in effect and split the work in batches if needed.
        BATCH_SIZE = 999
        if len(params) > BATCH_SIZE:
            results = ()
            for index in range(0, len(params), BATCH_SIZE):
                chunk = params[index : index + BATCH_SIZE]
                results += self._quote_params_for_last_executed_query(chunk)
            return results

        sql = "SELECT " + ", ".join(["QUOTE(?)"] * len(params))
        # Bypass Django's wrappers and use the underlying sqlite3 connection
        # to avoid logging this query - it would trigger infinite recursion.
        cursor = self.connection.connection.cursor()
        # Native sqlite3 cursors cannot be used as context managers.
        try:
            return cursor.execute(sql, params).fetchone()
        finally:
            cursor.close()

    def last_executed_query(self, cursor, sql, params):
        """
        Return sql with params quoted and interpolated, or sql unchanged
        when params is empty. params may be a list/tuple or a mapping.
        """
        # Python substitutes parameters in Modules/_sqlite/cursor.c with:
        #   bind_parameters(state, self->statement, parameters);
        # Unfortunately there is no way to reach self->statement from Python,
        # so we quote and substitute parameters manually.
        if params:
            if isinstance(params, (list, tuple)):
                params = self._quote_params_for_last_executed_query(params)
            else:
                # Mapping: quote the values, then re-pair them with the keys
                # in the same iteration order.
                values = tuple(params.values())
                values = self._quote_params_for_last_executed_query(values)
                params = dict(zip(params, values))
            return sql % params
        # For consistency with SQLiteCursorWrapper.execute(), just return sql
        # when there are no parameters. See #13648 and #17158.
        else:
            return sql

    def quote_name(self, name):
        """Wrap name in double quotes unless it is already wrapped."""
        if name.startswith('"') and name.endswith('"'):
            return name  # Quoting once is enough.
        return '"%s"' % name

    def no_limit_value(self):
        """Return the LIMIT value meaning "no limit": -1."""
        return -1

    def __references_graph(self, table_name):
        """
        Return the names yielded by a CTE seeded with table_name that joins
        sqlite_master entries whose SQL matches a REFERENCES clause naming a
        table already collected.
        """
        query = """
        WITH tables AS (
            SELECT %s name
            UNION
            SELECT sqlite_master.name
            FROM sqlite_master
            JOIN tables ON (sql REGEXP %s || tables.name || %s)
        ) SELECT name FROM tables;
        """
        params = (
            table_name,
            r'(?i)\s+references\s+("|\')?',
            r'("|\')?\s*\(',
        )
        with self.connection.cursor() as cursor:
            results = cursor.execute(query, params)
            return [row[0] for row in results.fetchall()]

    @cached_property
    def _references_graph(self):
        """Per-instance LRU-cached wrapper around __references_graph()."""
        # 512 is large enough to fit the ~330 tables (as of this writing) in
        # Django's test suite.
        return lru_cache(maxsize=512)(self.__references_graph)

    def sql_flush(self, style, tables, *, reset_sequences=False, allow_cascade=False):
        """
        Return a list of DELETE FROM statements, one per table, followed by
        the sequence reset statement when reset_sequences is true.
        """
        if tables and allow_cascade:
            # Simulate TRUNCATE CASCADE by recursively collecting the tables
            # referencing the tables to be flushed.
            tables = set(
                chain.from_iterable(self._references_graph(table) for table in tables)
            )
        sql = [
            "%s %s %s;"
            % (
                style.SQL_KEYWORD("DELETE"),
                style.SQL_KEYWORD("FROM"),
                style.SQL_FIELD(self.quote_name(table)),
            )
            for table in tables
        ]
        if reset_sequences:
            sequences = [{"table": table} for table in tables]
            sql.extend(self.sequence_reset_by_name_sql(style, sequences))
        return sql

    def sequence_reset_by_name_sql(self, style, sequences):
        """
        Return a one-element list with an UPDATE on sqlite_sequence setting
        seq = 0 for every table named in sequences, or [] when sequences is
        empty. Each item of sequences is a dict with a "table" key.
        """
        if not sequences:
            return []
        # Table names are embedded as SQL string literals, so double any
        # single quote; otherwise a name containing an apostrophe would
        # terminate the literal early and produce invalid SQL.
        table_literals = ", ".join(
            "'%s'" % sequence_info["table"].replace("'", "''")
            for sequence_info in sequences
        )
        return [
            "%s %s %s %s = 0 %s %s %s (%s);"
            % (
                style.SQL_KEYWORD("UPDATE"),
                style.SQL_TABLE(self.quote_name("sqlite_sequence")),
                style.SQL_KEYWORD("SET"),
                style.SQL_FIELD(self.quote_name("seq")),
                style.SQL_KEYWORD("WHERE"),
                style.SQL_FIELD(self.quote_name("name")),
                style.SQL_KEYWORD("IN"),
                table_literals,
            ),
        ]

    def adapt_datetimefield_value(self, value):
        """
        Return value as a string for storage. None and expressions pass
        through unchanged. Aware datetimes are made naive in the connection
        timezone when USE_TZ is enabled and raise ValueError otherwise.
        """
        if value is None:
            return None

        # Expression values are adapted by the database.
        if hasattr(value, "resolve_expression"):
            return value

        # SQLite doesn't support tz-aware datetimes
        if timezone.is_aware(value):
            if settings.USE_TZ:
                value = timezone.make_naive(value, self.connection.timezone)
            else:
                raise ValueError(
                    "SQLite backend does not support timezone-aware datetimes when "
                    "USE_TZ is False."
                )

        return str(value)

    def adapt_timefield_value(self, value):
        """
        Return value as a string for storage. None and expressions pass
        through unchanged; aware times raise ValueError.
        """
        if value is None:
            return None

        # Expression values are adapted by the database.
        if hasattr(value, "resolve_expression"):
            return value

        # SQLite doesn't support tz-aware times
        if timezone.is_aware(value):
            raise ValueError("SQLite backend does not support timezone-aware times.")

        return str(value)

    def get_db_converters(self, expression):
        """
        Extend the inherited converters with the one matching the internal
        type of the expression's output field, if any.
        """
        converters = super().get_db_converters(expression)
        internal_type = expression.output_field.get_internal_type()
        if internal_type == "DateTimeField":
            converters.append(self.convert_datetimefield_value)
        elif internal_type == "DateField":
            converters.append(self.convert_datefield_value)
        elif internal_type == "TimeField":
            converters.append(self.convert_timefield_value)
        elif internal_type == "DecimalField":
            converters.append(self.get_decimalfield_converter(expression))
        elif internal_type == "UUIDField":
            converters.append(self.convert_uuidfield_value)
        elif internal_type == "BooleanField":
            converters.append(self.convert_booleanfield_value)
        return converters

    def convert_datetimefield_value(self, value, expression, connection):
        """
        Parse non-datetime values with parse_datetime() and, when USE_TZ is
        enabled, make naive results aware in the connection timezone. None
        is returned unchanged.
        """
        if value is not None:
            if not isinstance(value, datetime.datetime):
                value = parse_datetime(value)
            if settings.USE_TZ and not timezone.is_aware(value):
                value = timezone.make_aware(value, self.connection.timezone)
        return value

    def convert_datefield_value(self, value, expression, connection):
        """Parse non-date values with parse_date(); None passes through."""
        if value is not None:
            if not isinstance(value, datetime.date):
                value = parse_date(value)
        return value

    def convert_timefield_value(self, value, expression, connection):
        """Parse non-time values with parse_time(); None passes through."""
        if value is not None:
            if not isinstance(value, datetime.time):
                value = parse_time(value)
        return value

    def get_decimalfield_converter(self, expression):
        """
        Return a converter building Decimals through a 15-digit context.
        For Col expressions the result is additionally quantized to the
        output field's decimal_places. The converter returns None for None.
        """
        # SQLite stores only 15 significant digits. Digits coming from
        # float inaccuracy must be removed.
        create_decimal = decimal.Context(prec=15).create_decimal_from_float
        if isinstance(expression, Col):
            quantize_value = decimal.Decimal(1).scaleb(
                -expression.output_field.decimal_places
            )

            def converter(value, expression, connection):
                if value is not None:
                    return create_decimal(value).quantize(
                        quantize_value, context=expression.output_field.context
                    )

        else:

            def converter(value, expression, connection):
                if value is not None:
                    return create_decimal(value)

        return converter

    def convert_uuidfield_value(self, value, expression, connection):
        """Build a uuid.UUID from value; None passes through."""
        if value is not None:
            value = uuid.UUID(value)
        return value

    def convert_booleanfield_value(self, value, expression, connection):
        """Map 1/0 to True/False and return any other value unchanged."""
        return bool(value) if value in (1, 0) else value

    def bulk_insert_sql(self, fields, placeholder_rows):
        """
        Return a VALUES clause with one parenthesized, comma-separated group
        per row of placeholders.
        """
        placeholder_rows_sql = (", ".join(row) for row in placeholder_rows)
        values_sql = ", ".join(f"({sql})" for sql in placeholder_rows_sql)
        return f"VALUES {values_sql}"

    def combine_expression(self, connector, sub_expressions):
        """
        Render "^" as POWER(...) and "#" as BITXOR(...); defer every other
        connector to the inherited implementation.
        """
        # SQLite doesn't have a ^ operator, so use the user-defined POWER
        # function that's registered in connect().
        if connector == "^":
            return "POWER(%s)" % ",".join(sub_expressions)
        elif connector == "#":
            return "BITXOR(%s)" % ",".join(sub_expressions)
        return super().combine_expression(connector, sub_expressions)

    def combine_duration_expression(self, connector, sub_expressions):
        """
        Return a django_format_dtdelta() call whose first argument is the
        quoted connector followed by the sub-expressions.

        Raise DatabaseError for connectors other than + - * / and
        ValueError when more than two sub-expressions are given.
        """
        if connector not in ["+", "-", "*", "/"]:
            raise DatabaseError("Invalid connector for timedelta: %s." % connector)
        fn_params = ["'%s'" % connector] + sub_expressions
        if len(fn_params) > 3:
            raise ValueError("Too many params for timedelta operations.")
        return "django_format_dtdelta(%s)" % ", ".join(fn_params)

    def integer_field_range(self, internal_type):
        """Return (None, None): no lower or upper bound is reported."""
        # SQLite doesn't enforce any integer constraints
        return (None, None)

    def subtract_temporals(self, internal_type, lhs, rhs):
        """
        Return a (sql, params) pair computing lhs - rhs with
        django_time_diff() for TimeField and django_timestamp_diff()
        otherwise. lhs and rhs are (sql, params) pairs.
        """
        lhs_sql, lhs_params = lhs
        rhs_sql, rhs_params = rhs
        params = (*lhs_params, *rhs_params)
        if internal_type == "TimeField":
            return "django_time_diff(%s, %s)" % (lhs_sql, rhs_sql), params
        return "django_timestamp_diff(%s, %s)" % (lhs_sql, rhs_sql), params

    def insert_statement(self, on_conflict=None):
        """
        Return "INSERT OR IGNORE INTO" for OnConflict.IGNORE; otherwise
        defer to the inherited implementation.
        """
        if on_conflict == OnConflict.IGNORE:
            return "INSERT OR IGNORE INTO"
        return super().insert_statement(on_conflict=on_conflict)

    def return_insert_columns(self, fields):
        """
        Return a ("RETURNING <table>.<column>, ...", ()) pair for fields, or
        ("", ()) when fields is empty.
        """
        # SQLite < 3.35 doesn't support an INSERT...RETURNING statement.
        if not fields:
            return "", ()
        columns = [
            "%s.%s"
            % (
                self.quote_name(field.model._meta.db_table),
                self.quote_name(field.column),
            )
            for field in fields
        ]
        return "RETURNING %s" % ", ".join(columns), ()

    def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
        """
        Return an ON CONFLICT(...) DO UPDATE SET ... clause for
        OnConflict.UPDATE when the connection supports conflict targets;
        otherwise defer to the inherited implementation.
        """
        if (
            on_conflict == OnConflict.UPDATE
            and self.connection.features.supports_update_conflicts_with_target
        ):
            return "ON CONFLICT(%s) DO UPDATE SET %s" % (
                ", ".join(map(self.quote_name, unique_fields)),
                ", ".join(
                    [
                        f"{field} = EXCLUDED.{field}"
                        for field in map(self.quote_name, update_fields)
                    ]
                ),
            )
        return super().on_conflict_suffix_sql(
            fields,
            on_conflict,
            update_fields,
            unique_fields,
        )
|