Selaa lähdekoodia

Fixed #17687 -- Made LayerMapping router-aware

Thanks nosamanuel@gmail.com for the report and the initial patch.
Claude Paroz 12 vuotta sitten
vanhempi
commit
7e32dab3a6

+ 30 - 0
django/contrib/gis/tests/layermap/tests.py

@@ -8,6 +8,7 @@ from django.contrib.gis.gdal import DataSource
 from django.contrib.gis.tests.utils import mysql
 from django.contrib.gis.utils.layermapping import (LayerMapping, LayerMapError,
     InvalidDecimal, MissingForeignKey)
+from django.db import router
 from django.test import TestCase
 
 from .models import (
@@ -26,6 +27,7 @@ NAMES  = ['Bexar', 'Galveston', 'Harris', 'Honolulu', 'Pueblo']
 NUMS   = [1, 2, 1, 19, 1] # Number of polygons for each.
 STATES = ['Texas', 'Texas', 'Texas', 'Hawaii', 'Colorado']
 
+
 class LayerMapTest(TestCase):
 
     def test_init(self):
@@ -281,3 +283,31 @@ class LayerMapTest(TestCase):
         lm.save(silent=True, strict=True)
         self.assertEqual(City.objects.count(), 3)
         self.assertEqual(City.objects.all().order_by('name_txt')[0].name_txt, "Houston")
+
+
+class OtherRouter(object):
+    def db_for_read(self, model, **hints):
+        return 'other'
+
+    def db_for_write(self, model, **hints):
+        return self.db_for_read(model, **hints)
+
+    def allow_relation(self, obj1, obj2, **hints):
+        return None
+
+    def allow_syncdb(self, db, model):
+        return True
+
+
+class LayerMapRouterTest(TestCase):
+
+    def setUp(self):
+        self.old_routers = router.routers
+        router.routers = [OtherRouter()]
+
+    def tearDown(self):
+        router.routers = self.old_routers
+
+    def test_layermapping_default_db(self):
+        lm = LayerMapping(City, city_shp, city_mapping)
+        self.assertEqual(lm.using, 'other')

+ 4 - 4
django/contrib/gis/utils/layermapping.py

@@ -9,7 +9,7 @@
 import sys
 from decimal import Decimal
 from django.core.exceptions import ObjectDoesNotExist
-from django.db import connections, DEFAULT_DB_ALIAS
+from django.db import connections, router
 from django.contrib.gis.db.models import GeometryField
 from django.contrib.gis.gdal import (CoordTransform, DataSource,
     OGRException, OGRGeometry, OGRGeomType, SpatialReference)
@@ -67,7 +67,7 @@ class LayerMapping(object):
     def __init__(self, model, data, mapping, layer=0,
                  source_srs=None, encoding=None,
                  transaction_mode='commit_on_success',
-                 transform=True, unique=None, using=DEFAULT_DB_ALIAS):
+                 transform=True, unique=None, using=None):
         """
         A LayerMapping object is initialized using the given Model (not an instance),
         a DataSource (or string path to an OGR-supported data file), and a mapping
@@ -81,8 +81,8 @@ class LayerMapping(object):
             self.ds = data
         self.layer = self.ds[layer]
 
-        self.using = using
-        self.spatial_backend = connections[using].ops
+        self.using = using if using is not None else router.db_for_write(model)
+        self.spatial_backend = connections[self.using].ops
 
         # Setting the mapping & model attributes.
         self.mapping = mapping