Browse Source

Fixed #28433 -- Made GEOSGeometry.__eq__() work correctly with non-canonical EWKT string.

Sergey Fedoseev 7 years ago
parent
commit
5ccbcc5bf6
2 changed files with 47 additions and 8 deletions
  1. 28 8
      django/contrib/gis/geos/geometry.py
  2. 19 0
      tests/gis_tests/geos_tests/test_geos.py

+ 28 - 8
django/contrib/gis/geos/geometry.py

@@ -2,6 +2,7 @@
  This module contains the 'base' GEOSGeometry object -- all GEOS Geometries
  inherit from this object.
 """
+import re
 from ctypes import addressof, byref, c_double
 
 from django.contrib.gis import gdal
@@ -54,7 +55,7 @@ class GEOSGeometry(GEOSBase, ListMixin):
                 # Handling WKT input.
                 if wkt_m.group('srid'):
                     input_srid = int(wkt_m.group('srid'))
-                g = wkt_r().read(force_bytes(wkt_m.group('wkt')))
+                g = self._from_wkt(force_bytes(wkt_m.group('wkt')))
             elif hex_regex.match(geo_input):
                 # Handling HEXEWKB input.
                 g = wkb_r().read(force_bytes(geo_input))
@@ -163,6 +164,27 @@ class GEOSGeometry(GEOSBase, ListMixin):
     def _from_wkb(cls, wkb):
         return wkb_r().read(wkb)
 
+    @staticmethod
+    def from_ewkt(ewkt):
+        ewkt = force_bytes(ewkt)
+        srid = None
+        parts = ewkt.split(b';', 1)
+        if len(parts) == 2:
+            srid_part, wkt = parts
+            match = re.match(b'SRID=(?P<srid>\-?\d+)', srid_part)
+            if not match:
+                raise ValueError('EWKT has invalid SRID part.')
+            srid = int(match.group('srid'))
+        else:
+            wkt = ewkt
+        if not wkt:
+            raise ValueError('Expected WKT but got an empty string.')
+        return GEOSGeometry(GEOSGeometry._from_wkt(wkt), srid=srid)
+
+    @staticmethod
+    def _from_wkt(wkt):
+        return wkt_r().read(wkt)
+
     @classmethod
     def from_gml(cls, gml_string):
         return gdal.OGRGeometry.from_gml(gml_string).geos
@@ -174,13 +196,11 @@ class GEOSGeometry(GEOSBase, ListMixin):
         or an EWKT representation.
         """
         if isinstance(other, str):
-            if other.startswith('SRID=0;'):
-                return self.ewkt == other[7:]  # Test only WKT part of other
-            return self.ewkt == other
-        elif isinstance(other, GEOSGeometry):
-            return self.srid == other.srid and self.equals_exact(other)
-        else:
-            return False
+            try:
+                other = GEOSGeometry.from_ewkt(other)
+            except (ValueError, GEOSException):
+                return False
+        return isinstance(other, GEOSGeometry) and self.srid == other.srid and self.equals_exact(other)
 
     # ### Geometry set-like operations ###
     # Thanks to Sean Gillies for inspiration:

+ 19 - 0
tests/gis_tests/geos_tests/test_geos.py

@@ -179,6 +179,7 @@ class GEOSTest(SimpleTestCase, TestDataMixin):
         ls = fromstr('LINESTRING(0 0, 1 1, 5 5)')
         self.assertEqual(ls, ls.wkt)
         self.assertNotEqual(p, 'bar')
+        self.assertEqual(p, 'POINT(5.0 23.0)')
         # Error shouldn't be raise on equivalence testing with
         # an invalid type.
         for g in (p, ls):
@@ -1322,6 +1323,24 @@ class GEOSTest(SimpleTestCase, TestDataMixin):
             ),
         )
 
+    def test_from_ewkt(self):
+        self.assertEqual(GEOSGeometry.from_ewkt('SRID=1;POINT(1 1)'), Point(1, 1, srid=1))
+        self.assertEqual(GEOSGeometry.from_ewkt('POINT(1 1)'), Point(1, 1))
+
+    def test_from_ewkt_empty_string(self):
+        msg = 'Expected WKT but got an empty string.'
+        with self.assertRaisesMessage(ValueError, msg):
+            GEOSGeometry.from_ewkt('')
+        with self.assertRaisesMessage(ValueError, msg):
+            GEOSGeometry.from_ewkt('SRID=1;')
+
+    def test_from_ewkt_invalid_srid(self):
+        msg = 'EWKT has invalid SRID part.'
+        with self.assertRaisesMessage(ValueError, msg):
+            GEOSGeometry.from_ewkt('SRUD=1;POINT(1 1)')
+        with self.assertRaisesMessage(ValueError, msg):
+            GEOSGeometry.from_ewkt('SRID=WGS84;POINT(1 1)')
+
     def test_normalize(self):
         g = MultiPoint(Point(0, 0), Point(2, 2), Point(1, 1))
         self.assertIsNone(g.normalize())