Ver código fonte

Fixed #31766 -- Made GDALRaster.transform() return a clone for the same SRID and driver.

Thanks Daniel Wiesmann for the review.
Barton Ip 4 anos atrás
pai
commit
12d6cae7c0

+ 1 - 0
AUTHORS

@@ -110,6 +110,7 @@ answer newbie questions, and generally made Django that much better:
     Baptiste Mispelon <bmispelon@gmail.com>
     Barry Pederson <bp@barryp.org>
     Bartolome Sanchez Salado <i42sasab@uco.es>
+    Barton Ip <notbartonip@gmail.com>
     Bartosz Grabski <bartosz.grabski@gmail.com>
     Bashar Al-Abdulhadi
     Bastian Kleineidam <calvin@debian.org>

+ 24 - 0
django/contrib/gis/gdal/raster/source.py

@@ -425,6 +425,27 @@ class GDALRaster(GDALRasterBase):
 
         return target
 
+    def clone(self, name=None):
+        """Return a clone of this GDALRaster."""
+        if name:
+            clone_name = name
+        elif self.driver.name != 'MEM':
+            clone_name = self.name + '_copy.' + self.driver.name
+        else:
+            clone_name = os.path.join(VSI_FILESYSTEM_BASE_PATH, str(uuid.uuid4()))
+        return GDALRaster(
+            capi.copy_ds(
+                self.driver._ptr,
+                force_bytes(clone_name),
+                self._ptr,
+                c_int(),
+                c_char_p(),
+                c_void_p(),
+                c_void_p(),
+            ),
+            write=self._write,
+        )
+
     def transform(self, srs, driver=None, name=None, resampling='NearestNeighbour',
                   max_error=0.0):
         """
@@ -443,6 +464,9 @@ class GDALRaster(GDALRasterBase):
                 'Transform only accepts SpatialReference, string, and integer '
                 'objects.'
             )
+
+        if target_srs.srid == self.srid and (not driver or driver == self.driver.name):
+            return self.clone(name)
         # Create warped virtual dataset in the target reference system
         target = capi.auto_create_warped_vrt(
             self._ptr, self.srs.wkt.encode(), target_srs.wkt.encode(),

+ 83 - 0
tests/gis_tests/gdal_tests/test_raster.py

@@ -2,6 +2,7 @@ import os
 import shutil
 import struct
 import tempfile
+from unittest import mock
 
 from django.contrib.gis.gdal import GDAL_VERSION, GDALRaster, SpatialReference
 from django.contrib.gis.gdal.error import GDALException
@@ -470,6 +471,40 @@ class GDALRasterTests(SimpleTestCase):
         # The result is an empty raster filled with the correct nodata value.
         self.assertEqual(result, [23] * 16)
 
+    def test_raster_clone(self):
+        rstfile = tempfile.NamedTemporaryFile(suffix='.tif')
+        tests = [
+            ('MEM', '', 23),  # In memory raster.
+            ('tif', rstfile.name, 99),  # In file based raster.
+        ]
+        for driver, name, nodata_value in tests:
+            with self.subTest(driver=driver):
+                source = GDALRaster({
+                    'datatype': 1,
+                    'driver': driver,
+                    'name': name,
+                    'width': 4,
+                    'height': 4,
+                    'srid': 3086,
+                    'origin': (500000, 400000),
+                    'scale': (100, -100),
+                    'skew': (0, 0),
+                    'bands': [{
+                        'data': range(16),
+                        'nodata_value': nodata_value,
+                    }],
+                })
+                clone = source.clone()
+                self.assertNotEqual(clone.name, source.name)
+                self.assertEqual(clone._write, source._write)
+                self.assertEqual(clone.srs.srid, source.srs.srid)
+                self.assertEqual(clone.width, source.width)
+                self.assertEqual(clone.height, source.height)
+                self.assertEqual(clone.origin, source.origin)
+                self.assertEqual(clone.scale, source.scale)
+                self.assertEqual(clone.skew, source.skew)
+                self.assertIsNot(clone, source)
+
     def test_raster_transform(self):
         tests = [
             3086,
@@ -531,6 +566,54 @@ class GDALRasterTests(SimpleTestCase):
                     ],
                 )
 
+    def test_raster_transform_clone(self):
+        with mock.patch.object(GDALRaster, 'clone') as mocked_clone:
+            # Create in file based raster.
+            rstfile = tempfile.NamedTemporaryFile(suffix='.tif')
+            source = GDALRaster({
+                'datatype': 1,
+                'driver': 'tif',
+                'name': rstfile.name,
+                'width': 5,
+                'height': 5,
+                'nr_of_bands': 1,
+                'srid': 4326,
+                'origin': (-5, 5),
+                'scale': (2, -2),
+                'skew': (0, 0),
+                'bands': [{
+                    'data': range(25),
+                    'nodata_value': 99,
+                }],
+            })
+            # transform() returns a clone because it is the same SRID and
+            # driver.
+            source.transform(4326)
+            self.assertEqual(mocked_clone.call_count, 1)
+
+    def test_raster_transform_clone_name(self):
+        # Create in file based raster.
+        rstfile = tempfile.NamedTemporaryFile(suffix='.tif')
+        source = GDALRaster({
+            'datatype': 1,
+            'driver': 'tif',
+            'name': rstfile.name,
+            'width': 5,
+            'height': 5,
+            'nr_of_bands': 1,
+            'srid': 4326,
+            'origin': (-5, 5),
+            'scale': (2, -2),
+            'skew': (0, 0),
+            'bands': [{
+                'data': range(25),
+                'nodata_value': 99,
+            }],
+        })
+        clone_name = rstfile.name + '_respect_name.GTiff'
+        target = source.transform(4326, name=clone_name)
+        self.assertEqual(target.name, clone_name)
+
 
 class GDALBandTests(SimpleTestCase):
     rs_path = os.path.join(os.path.dirname(__file__), '../data/rasters/raster.tif')