123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500 |
- import hashlib
- import sys
- import time
- from django.conf import settings
- from django.db.utils import load_backend
- from django.utils.encoding import force_bytes
- from django.utils.functional import cached_property
- from django.utils.six.moves import input
- from .utils import truncate_name
# The prefix to put on the default database name when creating
# the test database.
TEST_DATABASE_PREFIX = 'test_'
# Alias given to the database-less connection that is used while creating
# and dropping test databases.
NO_DB_ALIAS = '__no_db__'
class BaseDatabaseCreation(object):
    """
    This class encapsulates all backend-specific differences that pertain to
    database *creation*, such as the column types to use for particular Django
    Fields, the SQL used to create and destroy tables, and the creation and
    destruction of test databases.
    """
    # Lookup tables left empty here and not read anywhere in this class;
    # presumably populated by the concrete backend subclasses.
    data_types = {}
    data_types_suffix = {}
    data_type_check_constraints = {}

    def __init__(self, connection):
        # Database wrapper whose ops, features and settings drive every method.
        self.connection = connection

    @cached_property
    def _nodb_connection(self):
        """
        Alternative connection to be used when there is no need to access
        the main database, specifically for test db creation/deletion.
        This also prevents the production database from being exposed to
        potential child threads while (or after) the test database is destroyed.
        Refs #10868, #17786, #16969.
        """
        # Work on a copy so the real connection's settings are left untouched.
        settings_dict = self.connection.settings_dict.copy()
        settings_dict['NAME'] = None
        backend = load_backend(settings_dict['ENGINE'])
        return backend.DatabaseWrapper(
            settings_dict,
            alias=NO_DB_ALIAS,
            allow_thread_sharing=False)

    @classmethod
    def _digest(cls, *args):
        """
        Generates a 32-bit digest of a set of arguments that can be used to
        shorten identifying names.

        The digest is the first 8 hexadecimal characters of the MD5 hash of
        the arguments, each converted with force_bytes() and fed in order.
        """
        h = hashlib.md5()
        for arg in args:
            h.update(force_bytes(arg))
        return h.hexdigest()[:8]

    def _index_name(self, model, fields):
        """
        Returns the name of the index covering ``fields`` on ``model``,
        truncated to the backend's maximum name length.

        Shared by the CREATE INDEX and DROP INDEX generators so that both
        always agree on the name.
        """
        # The list of field names is digested as a single argument.
        index_name = "%s_%s" % (
            model._meta.db_table, self._digest([f.name for f in fields]))
        return truncate_name(index_name, self.connection.ops.max_name_length())

    def _fk_constraint_name(self, from_table, from_col, to_table, to_col):
        """
        Returns the name of the foreign key constraint from
        ``from_table.from_col`` to ``to_table.to_col``, truncated to the
        backend's maximum name length.

        Shared by the ADD CONSTRAINT and drop-constraint generators so that
        both always agree on the name.
        """
        # For MySQL, the name must be unique in the first 64 characters.
        # So we are careful with character usage here.
        name = '%s_refs_%s_%s' % (
            from_col, to_col, self._digest(from_table, to_table))
        return truncate_name(name, self.connection.ops.max_name_length())

    def sql_create_model(self, model, style, known_models=None):
        """
        Returns the SQL required to create a single model, as a tuple of:
            (list_of_sql, pending_references_dict)

        ``known_models`` is the collection of models that may be referenced
        inline; it defaults to an empty set. A foreign key that cannot be
        emitted inline is recorded in ``pending_references_dict``, keyed by
        the target model, as a ``(model, field)`` pair. Unmanaged, proxy and
        swapped models yield ``([], {})``.
        """
        if known_models is None:
            # Built per call rather than as a default argument, so no mutable
            # object is shared between calls.
            known_models = set()
        opts = model._meta
        if not opts.managed or opts.proxy or opts.swapped:
            return [], {}
        final_output = []
        table_output = []
        pending_references = {}
        qn = self.connection.ops.quote_name
        for f in opts.local_fields:
            col_type = f.db_type(connection=self.connection)
            col_type_suffix = f.db_type_suffix(connection=self.connection)
            tablespace = f.db_tablespace or opts.db_tablespace
            if col_type is None:
                # Skip ManyToManyFields, because they're not represented as
                # database columns in this table.
                continue
            # Make the definition (e.g. 'foo VARCHAR(30)') for this field.
            field_output = [style.SQL_FIELD(qn(f.column)),
                            style.SQL_COLTYPE(col_type)]
            # Oracle treats the empty string ('') as null, so coerce the null
            # option whenever '' is a possible value.
            null = f.null
            if (f.empty_strings_allowed and not f.primary_key and
                    self.connection.features.interprets_empty_strings_as_nulls):
                null = True
            if not null:
                field_output.append(style.SQL_KEYWORD('NOT NULL'))
            if f.primary_key:
                field_output.append(style.SQL_KEYWORD('PRIMARY KEY'))
            elif f.unique:
                field_output.append(style.SQL_KEYWORD('UNIQUE'))
            if tablespace and f.unique:
                # We must specify the index tablespace inline, because we
                # won't be generating a CREATE INDEX statement for this field.
                tablespace_sql = self.connection.ops.tablespace_sql(
                    tablespace, inline=True)
                if tablespace_sql:
                    field_output.append(tablespace_sql)
            if f.rel and f.db_constraint:
                ref_output, pending = self.sql_for_inline_foreign_key_references(
                    model, f, known_models, style)
                if pending:
                    pending_references.setdefault(f.rel.to, []).append(
                        (model, f))
                else:
                    field_output.extend(ref_output)
            if col_type_suffix:
                field_output.append(style.SQL_KEYWORD(col_type_suffix))
            table_output.append(' '.join(field_output))
        for field_constraints in opts.unique_together:
            columns = ", ".join(
                [style.SQL_FIELD(qn(opts.get_field(f).column))
                 for f in field_constraints])
            table_output.append(style.SQL_KEYWORD('UNIQUE') + ' (%s)' % columns)

        full_statement = [style.SQL_KEYWORD('CREATE TABLE') + ' ' +
                          style.SQL_TABLE(qn(opts.db_table)) + ' (']
        last = len(table_output) - 1
        for i, line in enumerate(table_output):  # Combine and add commas.
            full_statement.append(' %s%s' % (line, ',' if i < last else ''))
        full_statement.append(')')
        if opts.db_tablespace:
            tablespace_sql = self.connection.ops.tablespace_sql(
                opts.db_tablespace)
            if tablespace_sql:
                full_statement.append(tablespace_sql)
        full_statement.append(';')
        final_output.append('\n'.join(full_statement))

        if opts.has_auto_field:
            # Add any extra SQL needed to support auto-incrementing primary
            # keys.
            auto_column = opts.auto_field.db_column or opts.auto_field.name
            autoinc_sql = self.connection.ops.autoinc_sql(opts.db_table,
                                                          auto_column)
            if autoinc_sql:
                final_output.extend(autoinc_sql)

        return final_output, pending_references

    def sql_for_inline_foreign_key_references(self, model, field, known_models, style):
        """
        Return the SQL snippet defining the foreign key reference for a field.

        Returns an ``(output, pending)`` tuple. When the related model is in
        ``known_models`` or is ``model`` itself, ``output`` is a one-element
        list holding the REFERENCES clause and ``pending`` is False.
        Otherwise ``output`` is empty and ``pending`` is True.
        """
        qn = self.connection.ops.quote_name
        rel_to = field.rel.to
        if rel_to in known_models or rel_to == model:
            referenced_column = rel_to._meta.get_field(
                field.rel.field_name).column
            output = [style.SQL_KEYWORD('REFERENCES') + ' ' +
                      style.SQL_TABLE(qn(rel_to._meta.db_table)) + ' (' +
                      style.SQL_FIELD(qn(referenced_column)) + ')' +
                      self.connection.ops.deferrable_sql()]
            return output, False
        # We haven't yet created the table to which this field
        # is related, so save it for later.
        return [], True

    def sql_for_pending_references(self, model, style, pending_references):
        """
        Returns any ALTER TABLE statements to add constraints after the fact.

        ``pending_references`` maps a model to a list of ``(rel_class, field)``
        pairs. The entry for ``model`` is deleted from it once its statements
        have been generated. Unmanaged and swapped models yield ``[]`` and
        leave ``pending_references`` untouched.
        """
        opts = model._meta
        if not opts.managed or opts.swapped:
            return []
        if model not in pending_references:
            return []
        qn = self.connection.ops.quote_name
        final_output = []
        table = opts.db_table
        for rel_class, f in pending_references[model]:
            r_table = rel_class._meta.db_table
            r_col = f.column
            col = opts.get_field(f.rel.field_name).column
            r_name = self._fk_constraint_name(r_table, r_col, table, col)
            final_output.append(style.SQL_KEYWORD('ALTER TABLE') +
                ' %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s)%s;' %
                (qn(r_table), qn(r_name),
                 qn(r_col), qn(table), qn(col),
                 self.connection.ops.deferrable_sql()))
        del pending_references[model]
        return final_output

    def sql_indexes_for_model(self, model, style):
        """
        Returns the CREATE INDEX SQL statements for a single model.

        Covers every local field as well as every ``index_together`` group.
        Unmanaged, proxy and swapped models yield ``[]``.
        """
        opts = model._meta
        if not opts.managed or opts.proxy or opts.swapped:
            return []
        output = []
        for f in opts.local_fields:
            output.extend(self.sql_indexes_for_field(model, f, style))
        for fs in opts.index_together:
            fields = [opts.get_field_by_name(f)[0] for f in fs]
            output.extend(self.sql_indexes_for_fields(model, fields, style))
        return output

    def sql_indexes_for_field(self, model, f, style):
        """
        Return the CREATE INDEX SQL statements for a single model field.

        Only fields with ``db_index`` set that are not ``unique`` get a
        statement; every other field yields ``[]``.
        """
        if f.db_index and not f.unique:
            return self.sql_indexes_for_fields(model, [f], style)
        return []

    def sql_indexes_for_fields(self, model, fields, style):
        """
        Return a one-element list holding the CREATE INDEX statement that
        covers ``fields`` on ``model``.

        The tablespace clause comes from the field's own ``db_tablespace``
        when a single field is indexed and has one, otherwise from the
        model's ``db_tablespace``, otherwise it is omitted.
        """
        ops = self.connection.ops
        if len(fields) == 1 and fields[0].db_tablespace:
            tablespace_sql = ops.tablespace_sql(fields[0].db_tablespace)
        elif model._meta.db_tablespace:
            tablespace_sql = ops.tablespace_sql(model._meta.db_tablespace)
        else:
            tablespace_sql = ""
        if tablespace_sql:
            tablespace_sql = " " + tablespace_sql

        qn = ops.quote_name
        field_names = [style.SQL_FIELD(qn(f.column)) for f in fields]
        return [
            style.SQL_KEYWORD("CREATE INDEX") + " " +
            style.SQL_TABLE(qn(self._index_name(model, fields))) + " " +
            style.SQL_KEYWORD("ON") + " " +
            style.SQL_TABLE(qn(model._meta.db_table)) + " " +
            "(%s)" % style.SQL_FIELD(", ".join(field_names)) +
            "%s;" % tablespace_sql,
        ]

    def sql_destroy_model(self, model, references_to_delete, style):
        """
        Return the DROP TABLE and restraint dropping statements for a single
        model.

        Also appends the backend's drop-sequence statement, if it provides
        one, for models with an auto field. Unmanaged, proxy and swapped
        models yield ``[]``.
        """
        opts = model._meta
        if not opts.managed or opts.proxy or opts.swapped:
            return []
        # Drop the table now
        qn = self.connection.ops.quote_name
        output = ['%s %s;' % (style.SQL_KEYWORD('DROP TABLE'),
                              style.SQL_TABLE(qn(opts.db_table)))]
        if model in references_to_delete:
            output.extend(self.sql_remove_table_constraints(
                model, references_to_delete, style))
        if opts.has_auto_field:
            ds = self.connection.ops.drop_sequence_sql(opts.db_table)
            if ds:
                output.append(ds)
        return output

    def sql_remove_table_constraints(self, model, references_to_delete, style):
        """
        Return the ALTER TABLE statements dropping the foreign key
        constraints that point at ``model``.

        ``references_to_delete`` maps a model to a list of
        ``(rel_class, field)`` pairs and must contain ``model``; that entry
        is deleted once its statements have been generated. Unmanaged, proxy
        and swapped models yield ``[]`` and leave the mapping untouched.
        """
        opts = model._meta
        if not opts.managed or opts.proxy or opts.swapped:
            return []
        output = []
        qn = self.connection.ops.quote_name
        r_table = opts.db_table
        for rel_class, f in references_to_delete[model]:
            table = rel_class._meta.db_table
            col = f.column
            r_col = opts.get_field(f.rel.field_name).column
            r_name = self._fk_constraint_name(table, col, r_table, r_col)
            output.append('%s %s %s %s;' % (
                style.SQL_KEYWORD('ALTER TABLE'),
                style.SQL_TABLE(qn(table)),
                style.SQL_KEYWORD(self.connection.ops.drop_foreignkey_sql()),
                style.SQL_FIELD(qn(r_name))
            ))
        del references_to_delete[model]
        return output

    def sql_destroy_indexes_for_model(self, model, style):
        """
        Returns the DROP INDEX SQL statements for a single model.

        Covers every local field as well as every ``index_together`` group.
        Unmanaged, proxy and swapped models yield ``[]``.
        """
        opts = model._meta
        if not opts.managed or opts.proxy or opts.swapped:
            return []
        output = []
        for f in opts.local_fields:
            output.extend(self.sql_destroy_indexes_for_field(model, f, style))
        for fs in opts.index_together:
            fields = [opts.get_field_by_name(f)[0] for f in fs]
            output.extend(self.sql_destroy_indexes_for_fields(model, fields, style))
        return output

    def sql_destroy_indexes_for_field(self, model, f, style):
        """
        Return the DROP INDEX SQL statements for a single model field.

        Only fields with ``db_index`` set that are not ``unique`` get a
        statement; every other field yields ``[]``.
        """
        if f.db_index and not f.unique:
            return self.sql_destroy_indexes_for_fields(model, [f], style)
        return []

    def sql_destroy_indexes_for_fields(self, model, fields, style):
        """
        Return a one-element list holding the DROP INDEX statement for the
        index covering ``fields`` on ``model``.

        Only the index name is needed; the statement keeps a space before the
        terminating semicolon.
        """
        qn = self.connection.ops.quote_name
        return [
            style.SQL_KEYWORD("DROP INDEX") + " " +
            style.SQL_TABLE(qn(self._index_name(model, fields))) + " " +
            ";",
        ]

    def create_test_db(self, verbosity=1, autoclobber=False, keepdb=False):
        """
        Creates a test database, prompting the user for confirmation if the
        database already exists. Returns the name of the test database created.

        After creation the connection is closed and both settings.DATABASES
        and the connection's settings are pointed at the test database. The
        'migrate', 'flush' and 'createcachetable' commands are then run
        against it.
        """
        # Don't import django.core.management if it isn't needed.
        from django.core.management import call_command

        test_database_name = self._get_test_db_name()

        if verbosity >= 1:
            test_db_repr = ''
            action = 'Creating'
            if verbosity >= 2:
                test_db_repr = " ('%s')" % test_database_name
            if keepdb:
                action = "Using existing"

            print("%s test database for alias '%s'%s..." % (
                action, self.connection.alias, test_db_repr))

        # We could skip this call if keepdb is True, but we instead
        # give it the keepdb param. This is to handle the case
        # where the test DB doesn't exist, in which case we need to
        # create it, then just not destroy it. If we instead skip
        # this, we will get an exception.
        self._create_test_db(verbosity, autoclobber, keepdb)

        self.connection.close()
        settings.DATABASES[self.connection.alias]["NAME"] = test_database_name
        self.connection.settings_dict["NAME"] = test_database_name

        # Report migrate messages at one level lower than that requested.
        # This ensures we don't get flooded with messages during testing
        # (unless you really ask to be flooded)
        call_command('migrate',
            verbosity=max(verbosity - 1, 0),
            interactive=False,
            database=self.connection.alias,
            load_initial_data=False,
            test_database=True)

        # We need to then do a flush to ensure that any data installed by
        # custom SQL has been removed. The only test data should come from
        # test fixtures, or autogenerated from post_migrate triggers.
        # This has the side effect of loading initial data (which was
        # intentionally skipped by the migrate call above through
        # load_initial_data=False).
        call_command('flush',
            verbosity=max(verbosity - 1, 0),
            interactive=False,
            database=self.connection.alias)

        call_command('createcachetable', database=self.connection.alias)

        # Ensure a connection for the side effect of initializing the test database.
        self.connection.ensure_connection()

        return test_database_name

    def _get_test_db_name(self):
        """
        Internal implementation - returns the name of the test DB that will be
        created. Only useful when called from create_test_db() and
        _create_test_db() and when no external munging is done with the 'NAME'
        or TEST['NAME'] settings.

        A truthy TEST['NAME'] setting wins; otherwise the name is
        TEST_DATABASE_PREFIX followed by the 'NAME' setting.
        """
        settings_dict = self.connection.settings_dict
        test_name = settings_dict['TEST']['NAME']
        if test_name:
            return test_name
        return TEST_DATABASE_PREFIX + settings_dict['NAME']

    def _create_test_db(self, verbosity, autoclobber, keepdb=False):
        """
        Internal implementation - creates the test database itself and returns
        its name.

        If the CREATE DATABASE statement fails:
          * with ``keepdb`` set, the error is ignored and the name is returned;
          * otherwise the error is written to stderr and, unless
            ``autoclobber`` is set, the user is asked whether the existing
            database may be deleted. On approval the database is dropped and
            created again, exiting with status 2 if that fails. On refusal
            the process exits with status 1.
        """
        suffix = self.sql_table_creation_suffix()
        test_database_name = self._get_test_db_name()
        qn = self.connection.ops.quote_name

        # Create the test database and connect to it.
        with self._nodb_connection.cursor() as cursor:
            try:
                cursor.execute(
                    "CREATE DATABASE %s %s" % (qn(test_database_name), suffix))
            except Exception as e:
                # if we want to keep the db, then no need to do any of the below,
                # just return and skip it all.
                if keepdb:
                    return test_database_name

                sys.stderr.write(
                    "Got an error creating the test database: %s\n" % e)
                # Only prompt when autoclobber is off; the short-circuit
                # keeps input() from being called otherwise.
                proceed = autoclobber or input(
                    "Type 'yes' if you would like to try deleting the test "
                    "database '%s', or 'no' to cancel: " % test_database_name) == 'yes'
                if not proceed:
                    print("Tests cancelled.")
                    sys.exit(1)
                try:
                    if verbosity >= 1:
                        # NOTE: this message interpolates the connection
                        # alias, not the test database name.
                        print("Destroying old test database '%s'..."
                              % self.connection.alias)
                    cursor.execute(
                        "DROP DATABASE %s" % qn(test_database_name))
                    cursor.execute(
                        "CREATE DATABASE %s %s" % (qn(test_database_name),
                                                   suffix))
                except Exception as e:
                    sys.stderr.write(
                        "Got an error recreating the test database: %s\n" % e)
                    sys.exit(2)

        return test_database_name

    def destroy_test_db(self, old_database_name, verbosity=1):
        """
        Destroy a test database without asking for confirmation.

        Closes the connection and drops the database currently named by the
        connection's 'NAME' setting. ``old_database_name`` is not used by this
        implementation; the 'NAME' setting is not restored here.
        """
        self.connection.close()
        test_database_name = self.connection.settings_dict['NAME']
        if verbosity >= 1:
            test_db_repr = ''
            if verbosity >= 2:
                test_db_repr = " ('%s')" % test_database_name
            print("Destroying test database for alias '%s'%s..." % (
                self.connection.alias, test_db_repr))

        self._destroy_test_db(test_database_name, verbosity)

    def _destroy_test_db(self, test_database_name, verbosity):
        """
        Internal implementation - drops the test database.

        ``verbosity`` is accepted but not used by this implementation.
        """
        # Remove the test database to clean up after
        # ourselves. Connect to the previous database (not the test database)
        # to do so, because it's not allowed to delete a database while being
        # connected to it.
        with self._nodb_connection.cursor() as cursor:
            # Wait to avoid "database is being accessed by other users" errors.
            time.sleep(1)
            cursor.execute("DROP DATABASE %s"
                           % self.connection.ops.quote_name(test_database_name))

    def sql_table_creation_suffix(self):
        """
        SQL to append to the end of the CREATE DATABASE statement issued for
        the test database. Empty in this base implementation.
        """
        return ''

    def test_db_signature(self):
        """
        Returns a tuple with elements of self.connection.settings_dict (a
        DATABASES setting value) that uniquely identify a database
        accordingly to the RDBMS particularities.

        The tuple is (HOST, PORT, ENGINE, NAME).
        """
        settings_dict = self.connection.settings_dict
        return (
            settings_dict['HOST'],
            settings_dict['PORT'],
            settings_dict['ENGINE'],
            settings_dict['NAME']
        )
|