Browse Source

Fixed #5805 -- it is now possible to specify multi-column indexes. Thanks to jgelens for the original patch.

Alex Gaynor 12 years ago
parent
commit
4285571c5a

+ 25 - 10
django/core/management/validation.py

@@ -1,3 +1,4 @@
+import collections
 import sys
 
 from django.conf import settings
@@ -327,15 +328,29 @@ def get_validation_errors(outfile, app=None):
 
         # Check unique_together.
         for ut in opts.unique_together:
-            for field_name in ut:
-                try:
-                    f = opts.get_field(field_name, many_to_many=True)
-                except models.FieldDoesNotExist:
-                    e.add(opts, '"unique_together" refers to %s, a field that doesn\'t exist. Check your syntax.' % field_name)
-                else:
-                    if isinstance(f.rel, models.ManyToManyRel):
-                        e.add(opts, '"unique_together" refers to %s. ManyToManyFields are not supported in unique_together.' % f.name)
-                    if f not in opts.local_fields:
-                        e.add(opts, '"unique_together" refers to %s. This is not in the same model as the unique_together statement.' % f.name)
+            validate_local_fields(e, opts, "unique_together", ut)
+        if not isinstance(opts.index_together, collections.Sequence):
+            e.add(opts, '"index_together" must a sequence')
+        else:
+            for it in opts.index_together:
+                validate_local_fields(e, opts, "index_together", it)
 
     return len(e.errors)
+
+
+def validate_local_fields(e, opts, field_name, fields):
+    from django.db import models
+
+    if not isinstance(fields, collections.Sequence):
+        e.add(opts, 'all %s elements must be sequences' % field_name)
+    else:
+        for field in fields:
+            try:
+                f = opts.get_field(field, many_to_many=True)
+            except models.FieldDoesNotExist:
+                e.add(opts, '"%s" refers to %s, a field that doesn\'t exist.' % (field_name, field))
+            else:
+                if isinstance(f.rel, models.ManyToManyRel):
+                    e.add(opts, '"%s" refers to %s. ManyToManyFields are not supported in %s.' % (field_name, f.name, field_name))
+                if f not in opts.local_fields:
+                    e.add(opts, '"%s" refers to %s. This is not in the same model as the %s statement.' % (field_name, f.name, field_name))

+ 32 - 19
django/db/backends/creation.py

@@ -177,34 +177,47 @@ class BaseDatabaseCreation(object):
         output = []
         for f in model._meta.local_fields:
             output.extend(self.sql_indexes_for_field(model, f, style))
+        for fs in model._meta.index_together:
+            fields = [model._meta.get_field_by_name(f)[0] for f in fs]
+            output.extend(self.sql_indexes_for_fields(model, fields, style))
         return output
 
     def sql_indexes_for_field(self, model, f, style):
         """
         Return the CREATE INDEX SQL statements for a single model field.
         """
+        if f.db_index and not f.unique:
+            return self.sql_indexes_for_fields(model, [f], style)
+        else:
+            return []
+
+    def sql_indexes_for_fields(self, model, fields, style):
         from django.db.backends.util import truncate_name
 
-        if f.db_index and not f.unique:
-            qn = self.connection.ops.quote_name
-            tablespace = f.db_tablespace or model._meta.db_tablespace
-            if tablespace:
-                tablespace_sql = self.connection.ops.tablespace_sql(tablespace)
-                if tablespace_sql:
-                    tablespace_sql = ' ' + tablespace_sql
-            else:
-                tablespace_sql = ''
-            i_name = '%s_%s' % (model._meta.db_table, self._digest(f.column))
-            output = [style.SQL_KEYWORD('CREATE INDEX') + ' ' +
-                style.SQL_TABLE(qn(truncate_name(
-                    i_name, self.connection.ops.max_name_length()))) + ' ' +
-                style.SQL_KEYWORD('ON') + ' ' +
-                style.SQL_TABLE(qn(model._meta.db_table)) + ' ' +
-                "(%s)" % style.SQL_FIELD(qn(f.column)) +
-                "%s;" % tablespace_sql]
+        if len(fields) == 1 and fields[0].db_tablespace:
+            tablespace_sql = self.connection.ops.tablespace_sql(fields[0].db_tablespace)
+        elif model._meta.db_tablespace:
+            tablespace_sql = self.connection.ops.tablespace_sql(model._meta.db_tablespace)
         else:
-            output = []
-        return output
+            tablespace_sql = ""
+        if tablespace_sql:
+            tablespace_sql = " " + tablespace_sql
+
+        field_names = []
+        qn = self.connection.ops.quote_name
+        for f in fields:
+            field_names.append(style.SQL_FIELD(qn(f.column)))
+
+        index_name = "%s_%s" % (model._meta.db_table, self._digest([f.name for f in fields]))
+
+        return [
+            style.SQL_KEYWORD("CREATE INDEX") + " " +
+            style.SQL_TABLE(qn(truncate_name(index_name, self.connection.ops.max_name_length()))) + " " +
+            style.SQL_KEYWORD("ON") + " " +
+            style.SQL_TABLE(qn(model._meta.db_table)) + " " +
+            "(%s)" % style.SQL_FIELD(", ".join(field_names)) +
+            "%s;" % tablespace_sql,
+        ]
 
     def sql_destroy_model(self, model, references_to_delete, style):
         """

+ 3 - 1
django/db/models/options.py

@@ -21,7 +21,8 @@ get_verbose_name = lambda class_name: re.sub('(((?<=[a-z])[A-Z])|([A-Z](?![A-Z]|
 DEFAULT_NAMES = ('verbose_name', 'verbose_name_plural', 'db_table', 'ordering',
                  'unique_together', 'permissions', 'get_latest_by',
                  'order_with_respect_to', 'app_label', 'db_tablespace',
-                 'abstract', 'managed', 'proxy', 'swappable', 'auto_created')
+                 'abstract', 'managed', 'proxy', 'swappable', 'auto_created',
+                 'index_together')
 
 
 @python_2_unicode_compatible
@@ -34,6 +35,7 @@ class Options(object):
         self.db_table = ''
         self.ordering = []
         self.unique_together = []
+        self.index_together = []
         self.permissions = []
         self.object_name, self.app_label = None, app_label
         self.get_latest_by = None

+ 15 - 0
docs/ref/models/options.txt

@@ -261,6 +261,21 @@ Django quotes column and table names behind the scenes.
     :class:`~django.db.models.ManyToManyField`, try using a signal or
     an explicit :attr:`through <ManyToManyField.through>` model.
 
+``index_together``
+
+.. versionadded:: 1.5
+
+.. attribute:: Options.index_together
+
+    Sets of field names that, taken together, are indexed::
+
+        index_together = [
+            ["pub_date", "deadline"],
+        ]
+
+    This list of fields will be indexed together (i.e. the appropriate
+    ``CREATE INDEX`` statement will be issued.)
+
 ``verbose_name``
 ----------------
 

+ 8 - 0
tests/modeltests/invalid_models/invalid_models/models.py

@@ -356,6 +356,13 @@ class HardReferenceModel(models.Model):
     m2m_4 = models.ManyToManyField('invalid_models.SwappedModel', related_name='m2m_hardref4')
 
 
+class BadIndexTogether1(models.Model):
+    class Meta:
+        index_together = [
+            ["field_that_does_not_exist"],
+        ]
+
+
 model_errors = """invalid_models.fielderrors: "charfield": CharFields require a "max_length" attribute that is a positive integer.
 invalid_models.fielderrors: "charfield2": CharFields require a "max_length" attribute that is a positive integer.
 invalid_models.fielderrors: "charfield3": CharFields require a "max_length" attribute that is a positive integer.
@@ -470,6 +477,7 @@ invalid_models.hardreferencemodel: 'm2m_3' defines a relation with the model 'in
 invalid_models.hardreferencemodel: 'm2m_4' defines a relation with the model 'invalid_models.SwappedModel', which has been swapped out. Update the relation to point at settings.TEST_SWAPPED_MODEL.
 invalid_models.badswappablevalue: TEST_SWAPPED_MODEL_BAD_VALUE is not of the form 'app_label.app_name'.
 invalid_models.badswappablemodel: Model has been swapped out for 'not_an_app.Target' which has not been installed or is abstract.
+invalid_models.badindextogether1: "index_together" refers to field_that_does_not_exist, a field that doesn't exist.
 """
 
 if not connection.features.interprets_empty_strings_as_nulls:

+ 0 - 0
tests/regressiontests/indexes/__init__.py


+ 11 - 0
tests/regressiontests/indexes/models.py

@@ -0,0 +1,11 @@
+from django.db import models
+
+
+class Article(models.Model):
+    headline = models.CharField(max_length=100)
+    pub_date = models.DateTimeField()
+
+    class Meta:
+        index_together = [
+            ["headline", "pub_date"],
+        ]

+ 12 - 0
tests/regressiontests/indexes/tests.py

@@ -0,0 +1,12 @@
+from django.core.management.color import no_style
+from django.db import connections, DEFAULT_DB_ALIAS
+from django.test import TestCase
+
+from .models import Article
+
+
+class IndexesTests(TestCase):
+    def test_index_together(self):
+        connection = connections[DEFAULT_DB_ALIAS]
+        index_sql = connection.creation.sql_indexes_for_model(Article, no_style())
+        self.assertEqual(len(index_sql), 1)

+ 3 - 4
tests/regressiontests/initial_sql_regress/tests.py

@@ -1,3 +1,6 @@
+from django.core.management.color import no_style
+from django.core.management.sql import custom_sql_for_model
+from django.db import connections, DEFAULT_DB_ALIAS
 from django.test import TestCase
 
 from .models import Simple
@@ -15,10 +18,6 @@ class InitialSQLTests(TestCase):
         self.assertEqual(Simple.objects.count(), 0)
 
     def test_custom_sql(self):
-        from django.core.management.sql import custom_sql_for_model
-        from django.core.management.color import no_style
-        from django.db import connections, DEFAULT_DB_ALIAS
-
         # Simulate the custom SQL loading by syncdb
         connection = connections[DEFAULT_DB_ALIAS]
         custom_sql = custom_sql_for_model(Simple, no_style(), connection)

+ 4 - 0
tests/regressiontests/introspection/models.py

@@ -17,6 +17,7 @@ class Reporter(models.Model):
     def __str__(self):
         return "%s %s" % (self.first_name, self.last_name)
 
+
 @python_2_unicode_compatible
 class Article(models.Model):
     headline = models.CharField(max_length=100)
@@ -28,3 +29,6 @@ class Article(models.Model):
 
     class Meta:
         ordering = ('headline',)
+        index_together = [
+            ["headline", "pub_date"],
+        ]

+ 7 - 5
tests/regressiontests/introspection/tests.py

@@ -1,4 +1,4 @@
-from __future__ import absolute_import,unicode_literals
+from __future__ import absolute_import, unicode_literals
 
 from functools import update_wrapper
 
@@ -13,7 +13,7 @@ if connection.vendor == 'oracle':
 else:
     expectedFailureOnOracle = lambda f: f
 
-#
+
 # The introspection module is optional, so methods tested here might raise
 # NotImplementedError. This is perfectly acceptable behavior for the backend
 # in question, but the tests need to handle this without failing. Ideally we'd
@@ -23,7 +23,7 @@ else:
 # wrapper that ignores the exception.
 #
 # The metaclass is just for fun.
-#
+
 
 def ignore_not_implemented(func):
     def _inner(*args, **kwargs):
@@ -34,15 +34,16 @@ def ignore_not_implemented(func):
     update_wrapper(_inner, func)
     return _inner
 
+
 class IgnoreNotimplementedError(type):
     def __new__(cls, name, bases, attrs):
-        for k,v in attrs.items():
+        for k, v in attrs.items():
             if k.startswith('test'):
                 attrs[k] = ignore_not_implemented(v)
         return type.__new__(cls, name, bases, attrs)
 
-class IntrospectionTests(six.with_metaclass(IgnoreNotimplementedError, TestCase)):
 
+class IntrospectionTests(six.with_metaclass(IgnoreNotimplementedError, TestCase)):
     def test_table_names(self):
         tl = connection.introspection.table_names()
         self.assertEqual(tl, sorted(tl))
@@ -163,6 +164,7 @@ class IntrospectionTests(six.with_metaclass(IgnoreNotimplementedError, TestCase)
         self.assertNotIn('first_name', indexes)
         self.assertIn('id', indexes)
 
+
 def datatype(dbtype, description):
     """Helper to convert a data type into a string."""
     dt = connection.introspection.get_field_type(dbtype, description)

+ 1 - 1
tests/runtests.py

@@ -277,7 +277,7 @@ if __name__ == "__main__":
     usage = "%prog [options] [module module module ...]"
     parser = OptionParser(usage=usage)
     parser.add_option(
-        '-v','--verbosity', action='store', dest='verbosity', default='1',
+        '-v', '--verbosity', action='store', dest='verbosity', default='1',
         type='choice', choices=['0', '1', '2', '3'],
         help='Verbosity level; 0=minimal output, 1=normal output, 2=all '
              'output')