operations.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434
  1. import datetime
  2. import decimal
  3. import uuid
  4. from functools import lru_cache
  5. from itertools import chain
  6. from django.conf import settings
  7. from django.core.exceptions import FieldError
  8. from django.db import DatabaseError, NotSupportedError, models
  9. from django.db.backends.base.operations import BaseDatabaseOperations
  10. from django.db.models.constants import OnConflict
  11. from django.db.models.expressions import Col
  12. from django.utils import timezone
  13. from django.utils.dateparse import parse_date, parse_datetime, parse_time
  14. from django.utils.functional import cached_property
  15. class DatabaseOperations(BaseDatabaseOperations):
  cast_char_field_without_max_length = "text"
  cast_data_types = {"DateField": "TEXT", "DateTimeField": "TEXT"}
  explain_prefix = "EXPLAIN QUERY PLAN"
  # Datatype values that cannot be extracted with JSON_EXTRACT() on SQLite.
  # Use JSON_TYPE() for these instead.
  jsonfield_datatype_values = frozenset(("null", "false", "true"))
  def bulk_batch_size(self, fields, objs):
      """
      Return how many objects may be inserted by a single query.

      SQLite has a compile-time default (SQLITE_LIMIT_VARIABLE_NUMBER) of
      999 variables per query, so with several fields the batch size is the
      connection's max_query_params divided by the number of fields.

      With exactly one field the limit is 500 (SQLITE_MAX_COMPOUND_SELECT).
      With no fields at all, every object fits in one batch.
      """
      field_count = len(fields)
      if not field_count:
          return len(objs)
      if field_count == 1:
          return 500
      return self.connection.features.max_query_params // field_count
  def check_expression_support(self, expression):
      """
      Raise NotSupportedError for expressions this backend cannot evaluate.

      Two cases are rejected:
      * Sum, Avg, Variance, or StdDev over a source expression whose output
        field is a date, datetime, or time field.
      * Any aggregate flagged as distinct that has more than one source
        expression.
      """
      temporal_fields = (models.DateField, models.DateTimeField, models.TimeField)
      numeric_aggregates = (models.Sum, models.Avg, models.Variance, models.StdDev)
      if isinstance(expression, numeric_aggregates):
          for source in expression.get_source_expressions():
              try:
                  source_field = source.output_field
              except (AttributeError, FieldError):
                  # Not every subexpression has an output_field, which is
                  # fine to ignore.
                  continue
              if isinstance(source_field, temporal_fields):
                  raise NotSupportedError(
                      "You cannot use Sum, Avg, StdDev, and Variance "
                      "aggregations on date/time fields in sqlite3 "
                      "since date/time is saved as text."
                  )
      multi_arg_distinct = (
          isinstance(expression, models.Aggregate)
          and expression.distinct
          and len(expression.source_expressions) > 1
      )
      if multi_arg_distinct:
          raise NotSupportedError(
              "SQLite doesn't support DISTINCT on aggregate functions "
              "accepting multiple arguments."
          )
  def date_extract_sql(self, lookup_type, sql, params):
      """
      Support EXTRACT with a user-defined function django_date_extract()
      that's registered in connect().

      The lowercased lookup type is passed as the first query parameter,
      followed by the parameters of the wrapped expression.
      """
      extract_sql = f"django_date_extract(%s, {sql})"
      return extract_sql, (lookup_type.lower(), *params)
  def fetch_returned_insert_rows(self, cursor):
      """
      Return every row produced by the INSERT...RETURNING statement that
      the given cursor has just executed.
      """
      returned_rows = cursor.fetchall()
      return returned_rows
  def format_for_duration_arithmetic(self, sql):
      """
      Return the SQL unchanged; formatting is handled in the custom
      function.
      """
      return sql
  def date_trunc_sql(self, lookup_type, sql, params, tzname=None):
      """
      Return the SQL and parameters for a call to django_date_trunc().

      Parameters are, in order: the lowercased lookup type, the wrapped
      expression's parameters, and the two timezone names produced by
      _convert_tznames_to_sql().
      """
      trunc_sql = f"django_date_trunc(%s, {sql}, %s, %s)"
      args = [lookup_type.lower(), *params]
      args.extend(self._convert_tznames_to_sql(tzname))
      return trunc_sql, tuple(args)
  def time_trunc_sql(self, lookup_type, sql, params, tzname=None):
      """
      Return the SQL and parameters for a call to django_time_trunc().

      Parameters are, in order: the lowercased lookup type, the wrapped
      expression's parameters, and the two timezone names produced by
      _convert_tznames_to_sql().
      """
      trunc_sql = f"django_time_trunc(%s, {sql}, %s, %s)"
      args = [lookup_type.lower(), *params]
      args.extend(self._convert_tznames_to_sql(tzname))
      return trunc_sql, tuple(args)
  def _convert_tznames_to_sql(self, tzname):
      """
      Return the pair of timezone names to hand to the SQL functions.

      When a tzname is given and settings.USE_TZ is enabled, the pair is
      (tzname, the connection's timezone name); otherwise (None, None).
      """
      if not (tzname and settings.USE_TZ):
          return None, None
      return tzname, self.connection.timezone_name
  def datetime_cast_date_sql(self, sql, params, tzname):
      """
      Return the SQL and parameters for a call to
      django_datetime_cast_date(): the wrapped expression's parameters
      followed by the two timezone names from _convert_tznames_to_sql().
      """
      cast_sql = f"django_datetime_cast_date({sql}, %s, %s)"
      args = [*params]
      args.extend(self._convert_tznames_to_sql(tzname))
      return cast_sql, tuple(args)
  def datetime_cast_time_sql(self, sql, params, tzname):
      """
      Return the SQL and parameters for a call to
      django_datetime_cast_time(): the wrapped expression's parameters
      followed by the two timezone names from _convert_tznames_to_sql().
      """
      cast_sql = f"django_datetime_cast_time({sql}, %s, %s)"
      args = [*params]
      args.extend(self._convert_tznames_to_sql(tzname))
      return cast_sql, tuple(args)
  def datetime_extract_sql(self, lookup_type, sql, params, tzname):
      """
      Return the SQL and parameters for a call to django_datetime_extract().

      Parameters are, in order: the lowercased lookup type, the wrapped
      expression's parameters, and the two timezone names produced by
      _convert_tznames_to_sql().
      """
      extract_sql = f"django_datetime_extract(%s, {sql}, %s, %s)"
      args = [lookup_type.lower(), *params]
      args.extend(self._convert_tznames_to_sql(tzname))
      return extract_sql, tuple(args)
  def datetime_trunc_sql(self, lookup_type, sql, params, tzname):
      """
      Return the SQL and parameters for a call to django_datetime_trunc().

      Parameters are, in order: the lowercased lookup type, the wrapped
      expression's parameters, and the two timezone names produced by
      _convert_tznames_to_sql().
      """
      trunc_sql = f"django_datetime_trunc(%s, {sql}, %s, %s)"
      args = [lookup_type.lower(), *params]
      args.extend(self._convert_tznames_to_sql(tzname))
      return trunc_sql, tuple(args)
  def time_extract_sql(self, lookup_type, sql, params):
      """
      Return the SQL and parameters for a call to django_time_extract():
      the lowercased lookup type followed by the wrapped expression's
      parameters.
      """
      extract_sql = f"django_time_extract(%s, {sql})"
      return extract_sql, (lookup_type.lower(), *params)
  def pk_default_value(self):
      """Return the SQL text used for a defaulted primary key: NULL."""
      return "NULL"
  123. def _quote_params_for_last_executed_query(self, params):
  124. """
  125. Only for last_executed_query! Don't use this to execute SQL queries!
  126. """
  127. # This function is limited both by SQLITE_LIMIT_VARIABLE_NUMBER (the
  128. # number of parameters, default = 999) and SQLITE_MAX_COLUMN (the
  129. # number of return values, default = 2000). Since Python's sqlite3
  130. # module doesn't expose the get_limit() C API, assume the default
  131. # limits are in effect and split the work in batches if needed.
  132. BATCH_SIZE = 999
  133. if len(params) > BATCH_SIZE:
  134. results = ()
  135. for index in range(0, len(params), BATCH_SIZE):
  136. chunk = params[index : index + BATCH_SIZE]
  137. results += self._quote_params_for_last_executed_query(chunk)
  138. return results
  139. sql = "SELECT " + ", ".join(["QUOTE(?)"] * len(params))
  140. # Bypass Django's wrappers and use the underlying sqlite3 connection
  141. # to avoid logging this query - it would trigger infinite recursion.
  142. cursor = self.connection.connection.cursor()
  143. # Native sqlite3 cursors cannot be used as context managers.
  144. try:
  145. return cursor.execute(sql, params).fetchone()
  146. finally:
  147. cursor.close()
  def last_executed_query(self, cursor, sql, params):
      """
      Return the SQL with its parameters quoted and substituted in.

      Sequence parameters (list or tuple) are substituted positionally; any
      other truthy params object is treated as a mapping and substituted by
      key. Without parameters the SQL is returned untouched.
      """
      # Python substitutes parameters in Modules/_sqlite/cursor.c with:
      #   bind_parameters(state, self->statement, parameters);
      # Unfortunately there is no way to reach self->statement from Python,
      # so we quote and substitute parameters manually.
      if not params:
          # For consistency with SQLiteCursorWrapper.execute(), just return
          # sql when there are no parameters. See #13648 and #17158.
          return sql
      if isinstance(params, (list, tuple)):
          quoted = self._quote_params_for_last_executed_query(params)
      else:
          quoted_values = self._quote_params_for_last_executed_query(
              tuple(params.values())
          )
          quoted = dict(zip(params, quoted_values))
      return sql % quoted
  def quote_name(self, name):
      """
      Wrap the name in double quotes unless it already starts and ends
      with one (quoting once is enough).
      """
      already_quoted = name.startswith('"') and name.endswith('"')
      return name if already_quoted else '"' + name + '"'
  def no_limit_value(self):
      """Return the LIMIT value that stands for "no limit": -1."""
      return -1
  def __references_graph(self, table_name):
      """
      Return the names selected by a CTE seeded with table_name and
      extended with every sqlite_master entry whose SQL matches a
      REFERENCES clause naming a table already collected.
      """
      query = """
      WITH tables AS (
          SELECT %s name
          UNION
          SELECT sqlite_master.name
          FROM sqlite_master
          JOIN tables ON (sql REGEXP %s || tables.name || %s)
      ) SELECT name FROM tables;
      """
      # The two patterns surround tables.name: optional quote before it,
      # optional quote and an opening parenthesis after it.
      reference_prefix = r'(?i)\s+references\s+("|\')?'
      reference_suffix = r'("|\')?\s*\('
      with self.connection.cursor() as cursor:
          rows = cursor.execute(
              query, (table_name, reference_prefix, reference_suffix)
          ).fetchall()
          return [row[0] for row in rows]
  @cached_property
  def _references_graph(self):
      """
      Return __references_graph() wrapped in an LRU cache; the wrapper
      itself is stored on the instance by cached_property.
      """
      # 512 is large enough to fit the ~330 tables (as of this writing) in
      # Django's test suite.
      memoize = lru_cache(maxsize=512)
      return memoize(self.__references_graph)
  def sql_flush(self, style, tables, *, reset_sequences=False, allow_cascade=False):
      """
      Return the list of SQL statements that empty the given tables.

      One DELETE FROM statement is produced per table. With allow_cascade,
      the table set is first replaced by everything _references_graph()
      returns for each table. With reset_sequences, the statements from
      sequence_reset_by_name_sql() are appended.
      """
      if tables and allow_cascade:
          # Simulate TRUNCATE CASCADE by recursively collecting the tables
          # referencing the tables to be flushed.
          referencing = (self._references_graph(table) for table in tables)
          tables = set(chain.from_iterable(referencing))
      statements = []
      for table in tables:
          statements.append(
              "%s %s %s;"
              % (
                  style.SQL_KEYWORD("DELETE"),
                  style.SQL_KEYWORD("FROM"),
                  style.SQL_FIELD(self.quote_name(table)),
              )
          )
      if reset_sequences:
          sequence_infos = [{"table": table} for table in tables]
          statements.extend(self.sequence_reset_by_name_sql(style, sequence_infos))
      return statements
  def sequence_reset_by_name_sql(self, style, sequences):
      """
      Return the SQL statements that reset the sqlite_sequence counters of
      the given tables to 0.

      `sequences` is an iterable of dicts, each carrying the table name
      under the "table" key. An empty `sequences` yields an empty list;
      otherwise a single UPDATE statement covering every table is returned.
      """
      if not sequences:
          return []
      # Table names are embedded as SQL string literals, so any single
      # quote inside a name must be doubled to keep the literal well formed.
      table_literals = ", ".join(
          "'%s'" % str(sequence_info["table"]).replace("'", "''")
          for sequence_info in sequences
      )
      return [
          "%s %s %s %s = 0 %s %s %s (%s);"
          % (
              style.SQL_KEYWORD("UPDATE"),
              style.SQL_TABLE(self.quote_name("sqlite_sequence")),
              style.SQL_KEYWORD("SET"),
              style.SQL_FIELD(self.quote_name("seq")),
              style.SQL_KEYWORD("WHERE"),
              style.SQL_FIELD(self.quote_name("name")),
              style.SQL_KEYWORD("IN"),
              table_literals,
          ),
      ]
  def adapt_datetimefield_value(self, value):
      """
      Return the value in the form handed to the database for a datetime.

      None and expression objects (anything with resolve_expression) pass
      through untouched. Aware datetimes are made naive in the connection's
      timezone when USE_TZ is on and rejected with ValueError otherwise.
      Everything else is converted with str().
      """
      if value is None:
          return None
      # Expression values are adapted by the database.
      if hasattr(value, "resolve_expression"):
          return value
      # SQLite doesn't support tz-aware datetimes
      if not timezone.is_aware(value):
          return str(value)
      if not settings.USE_TZ:
          raise ValueError(
              "SQLite backend does not support timezone-aware datetimes when "
              "USE_TZ is False."
          )
      return str(timezone.make_naive(value, self.connection.timezone))
  def adapt_timefield_value(self, value):
      """
      Return the value in the form handed to the database for a time.

      None and expression objects (anything with resolve_expression) pass
      through untouched; aware values raise ValueError; everything else is
      converted with str().
      """
      if value is None or hasattr(value, "resolve_expression"):
          # None stays None; expression values are adapted by the database.
          return value
      # SQLite doesn't support tz-aware datetimes
      if timezone.is_aware(value):
          raise ValueError("SQLite backend does not support timezone-aware times.")
      return str(value)
  def get_db_converters(self, expression):
      """
      Return the base class's converters for the expression, extended with
      the SQLite-specific converter matching the internal type of its
      output field (if there is one).
      """
      converters = super().get_db_converters(expression)
      internal_type = expression.output_field.get_internal_type()
      if internal_type == "DecimalField":
          # The decimal converter is built per expression.
          converters.append(self.get_decimalfield_converter(expression))
          return converters
      if internal_type == "DateTimeField":
          method_name = "convert_datetimefield_value"
      elif internal_type == "DateField":
          method_name = "convert_datefield_value"
      elif internal_type == "TimeField":
          method_name = "convert_timefield_value"
      elif internal_type == "UUIDField":
          method_name = "convert_uuidfield_value"
      elif internal_type == "BooleanField":
          method_name = "convert_booleanfield_value"
      else:
          return converters
      converters.append(getattr(self, method_name))
      return converters
  def convert_datetimefield_value(self, value, expression, connection):
      """
      Convert a database value to a datetime.

      None is returned as is. Values that are not already datetime objects
      are parsed with parse_datetime(). When USE_TZ is on, a naive result is
      made aware in the connection's timezone.
      """
      if value is None:
          return value
      if not isinstance(value, datetime.datetime):
          value = parse_datetime(value)
      if settings.USE_TZ and not timezone.is_aware(value):
          value = timezone.make_aware(value, self.connection.timezone)
      return value
  def convert_datefield_value(self, value, expression, connection):
      """
      Convert a database value to a date: None and datetime.date instances
      are returned as is, anything else is parsed with parse_date().
      """
      if value is None or isinstance(value, datetime.date):
          return value
      return parse_date(value)
  def convert_timefield_value(self, value, expression, connection):
      """
      Convert a database value to a time: None and datetime.time instances
      are returned as is, anything else is parsed with parse_time().
      """
      if value is None or isinstance(value, datetime.time):
          return value
      return parse_time(value)
  def get_decimalfield_converter(self, expression):
      """
      Build the converter turning database floats into Decimal values.

      For a Col expression the result is additionally quantized to the
      output field's decimal_places, using the context of the output field
      of the expression passed to the converter. Both converters return
      None for a None value.
      """
      # SQLite stores only 15 significant digits. Digits coming from
      # float inaccuracy must be removed.
      from_float = decimal.Context(prec=15).create_decimal_from_float

      if not isinstance(expression, Col):

          def converter(value, expression, connection):
              if value is None:
                  return None
              return from_float(value)

          return converter

      quantum = decimal.Decimal(1).scaleb(-expression.output_field.decimal_places)

      def converter(value, expression, connection):
          if value is None:
              return None
          return from_float(value).quantize(
              quantum, context=expression.output_field.context
          )

      return converter
  def convert_uuidfield_value(self, value, expression, connection):
      """Return uuid.UUID(value), or None when the value is None."""
      return None if value is None else uuid.UUID(value)
  def convert_booleanfield_value(self, value, expression, connection):
      """
      Return True or False for values equal to 1 or 0; return any other
      value unchanged.
      """
      if value in (1, 0):
          return bool(value)
      return value
  def bulk_insert_sql(self, fields, placeholder_rows):
      """
      Return a VALUES clause with one parenthesized, comma-separated group
      of placeholders per row.
      """
      row_groups = ["(%s)" % ", ".join(row) for row in placeholder_rows]
      return "VALUES " + ", ".join(row_groups)
  def combine_expression(self, connector, sub_expressions):
      """
      Combine the sub-expressions with the connector.

      "^" becomes a POWER() call and "#" a BITXOR() call; every other
      connector is delegated to the base implementation.
      """
      # SQLite doesn't have a ^ operator, so use the user-defined POWER
      # function that's registered in connect().
      if connector == "^":
          function_name = "POWER"
      elif connector == "#":
          function_name = "BITXOR"
      else:
          return super().combine_expression(connector, sub_expressions)
      return "%s(%s)" % (function_name, ",".join(sub_expressions))
  def combine_duration_expression(self, connector, sub_expressions):
      """
      Return a django_format_dtdelta() call applying the connector to the
      sub-expressions.

      Raise DatabaseError for a connector other than +, -, * or /, and
      ValueError when more than two sub-expressions are given.
      """
      if connector not in ["+", "-", "*", "/"]:
          raise DatabaseError("Invalid connector for timedelta: %s." % connector)
      # The quoted connector is the function's first argument.
      arguments = ["'%s'" % connector] + sub_expressions
      if len(arguments) > 3:
          raise ValueError("Too many params for timedelta operations.")
      joined = ", ".join(arguments)
      return "django_format_dtdelta(%s)" % joined
  def integer_field_range(self, internal_type):
      """Return (None, None): SQLite doesn't enforce any integer constraints."""
      unbounded = (None, None)
      return unbounded
  def subtract_temporals(self, internal_type, lhs, rhs):
      """
      Return the SQL and parameters for the difference of two temporal
      expressions, each given as an (sql, params) pair.

      django_time_diff() is used for "TimeField" and
      django_timestamp_diff() for every other internal type.
      """
      lhs_sql, lhs_params = lhs
      rhs_sql, rhs_params = rhs
      combined_params = (*lhs_params, *rhs_params)
      if internal_type == "TimeField":
          diff_function = "django_time_diff"
      else:
          diff_function = "django_timestamp_diff"
      return f"{diff_function}({lhs_sql}, {rhs_sql})", combined_params
  def insert_statement(self, on_conflict=None):
      """
      Return the opening of an INSERT statement: "INSERT OR IGNORE INTO"
      for OnConflict.IGNORE, the base implementation's result otherwise.
      """
      if on_conflict != OnConflict.IGNORE:
          return super().insert_statement(on_conflict=on_conflict)
      return "INSERT OR IGNORE INTO"
  def return_insert_columns(self, fields):
      """
      Return the RETURNING clause (and an empty parameter tuple) listing
      each field as quoted "table"."column"; return ("", ()) when no
      fields are given.
      """
      # SQLite versions before 3.35 don't support an INSERT...RETURNING
      # statement.
      if not fields:
          return "", ()
      quote = self.quote_name
      qualified_columns = []
      for field in fields:
          table_sql = quote(field.model._meta.db_table)
          column_sql = quote(field.column)
          qualified_columns.append("%s.%s" % (table_sql, column_sql))
      return "RETURNING " + ", ".join(qualified_columns), ()
  def on_conflict_suffix_sql(self, fields, on_conflict, update_fields, unique_fields):
      """
      Return the conflict-handling suffix of an INSERT statement.

      For OnConflict.UPDATE on a connection whose features report
      supports_update_conflicts_with_target, build an
      ON CONFLICT(...) DO UPDATE SET clause assigning each update field
      from EXCLUDED. Every other case is delegated to the base
      implementation.
      """
      handles_update = (
          on_conflict == OnConflict.UPDATE
          and self.connection.features.supports_update_conflicts_with_target
      )
      if not handles_update:
          return super().on_conflict_suffix_sql(
              fields,
              on_conflict,
              update_fields,
              unique_fields,
          )
      conflict_target = ", ".join(map(self.quote_name, unique_fields))
      assignments = ", ".join(
          [
              f"{column} = EXCLUDED.{column}"
              for column in map(self.quote_name, update_fields)
          ]
      )
      return "ON CONFLICT(%s) DO UPDATE SET %s" % (conflict_target, assignments)