소스 검색

Refs #30183 -- Moved SQLite table constraint parsing to a method.

Paveł Tyślacki 6 년 전
부모
커밋
4492be348a
1개의 변경된 파일48개의 추가작업 그리고 42개의 파일을 삭제
  1. 48 42
      django/db/backends/sqlite3/introspection.py

+ 48 - 42
django/db/backends/sqlite3/introspection.py

@@ -217,6 +217,52 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
             }
         return constraints
 
+    def _parse_table_constraints(self, sql):
+        # Check constraint parsing is based of SQLite syntax diagram.
+        # https://www.sqlite.org/syntaxdiagrams.html#table-constraint
+        def next_ttype(ttype):
+            for token in tokens:
+                if token.ttype == ttype:
+                    return token
+
+        statement = sqlparse.parse(sql)[0]
+        constraints = {}
+        tokens = statement.flatten()
+        for token in tokens:
+            name = None
+            if token.match(sqlparse.tokens.Keyword, 'CONSTRAINT'):
+                # Table constraint
+                name_token = next_ttype(sqlparse.tokens.Literal.String.Symbol)
+                name = name_token.value[1:-1]
+                token = next_ttype(sqlparse.tokens.Keyword)
+            if token.match(sqlparse.tokens.Keyword, 'UNIQUE'):
+                constraints[name] = {
+                    'unique': True,
+                    'columns': [],
+                    'primary_key': False,
+                    'foreign_key': None,
+                    'check': False,
+                    'index': False,
+                }
+            if token.match(sqlparse.tokens.Keyword, 'CHECK'):
+                # Column check constraint
+                if name is None:
+                    column_token = next_ttype(sqlparse.tokens.Literal.String.Symbol)
+                    column = column_token.value[1:-1]
+                    name = '__check__%s' % column
+                    columns = [column]
+                else:
+                    columns = []
+                constraints[name] = {
+                    'check': True,
+                    'columns': columns,
+                    'primary_key': False,
+                    'unique': False,
+                    'foreign_key': None,
+                    'index': False,
+                }
+        return constraints
+
     def get_constraints(self, cursor, table_name):
         """
         Retrieve any constraints or keys (unique, pk, fk, check, index) across
@@ -234,48 +280,8 @@ class DatabaseIntrospection(BaseDatabaseIntrospection):
             # table_name is a view.
             pass
         else:
-            # Check constraint parsing is based of SQLite syntax diagram.
-            # https://www.sqlite.org/syntaxdiagrams.html#table-constraint
-            def next_ttype(ttype):
-                for token in tokens:
-                    if token.ttype == ttype:
-                        return token
-
-            statement = sqlparse.parse(table_schema)[0]
-            tokens = statement.flatten()
-            for token in tokens:
-                name = None
-                if token.match(sqlparse.tokens.Keyword, 'CONSTRAINT'):
-                    # Table constraint
-                    name_token = next_ttype(sqlparse.tokens.Literal.String.Symbol)
-                    name = name_token.value[1:-1]
-                    token = next_ttype(sqlparse.tokens.Keyword)
-                if token.match(sqlparse.tokens.Keyword, 'UNIQUE'):
-                    constraints[name] = {
-                        'unique': True,
-                        'columns': [],
-                        'primary_key': False,
-                        'foreign_key': None,
-                        'check': False,
-                        'index': False,
-                    }
-                if token.match(sqlparse.tokens.Keyword, 'CHECK'):
-                    # Column check constraint
-                    if name is None:
-                        column_token = next_ttype(sqlparse.tokens.Literal.String.Symbol)
-                        column = column_token.value[1:-1]
-                        name = '__check__%s' % column
-                        columns = [column]
-                    else:
-                        columns = []
-                    constraints[name] = {
-                        'check': True,
-                        'columns': columns,
-                        'primary_key': False,
-                        'unique': False,
-                        'foreign_key': None,
-                        'index': False,
-                    }
+            constraints.update(self._parse_table_constraints(table_schema))
+
         # Get the index info
         cursor.execute("PRAGMA index_list(%s)" % self.connection.ops.quote_name(table_name))
         for row in cursor.fetchall():