Browse Source

Reverted #25961 -- Removed handling of thread-non-safe GEOS functions.

This reverts commit 312fc1af7b8a028611b45ee272652d59ea7bd85a as it seems
to cause segmentation faults as described in the ticket.
Tim Graham 9 years ago
parent
commit
59ef6559a3

+ 10 - 36
django/contrib/gis/geos/libgeos.py

@@ -9,13 +9,12 @@
 import logging
 import os
 import re
-import threading
 from ctypes import CDLL, CFUNCTYPE, POINTER, Structure, c_char_p
 from ctypes.util import find_library
 
 from django.contrib.gis.geos.error import GEOSException
 from django.core.exceptions import ImproperlyConfigured
-from django.utils.functional import SimpleLazyObject, cached_property
+from django.utils.functional import SimpleLazyObject
 
 logger = logging.getLogger('django.contrib.gis')
 
@@ -64,11 +63,10 @@ def load_geos():
     _lgeos = CDLL(lib_path)
     # Here we set up the prototypes for the initGEOS_r and finishGEOS_r
     # routines.  These functions aren't actually called until they are
-    # attached to a GEOS context handle.
+    # attached to a GEOS context handle -- this actually occurs in
+    # geos/prototypes/threadsafe.py.
     _lgeos.initGEOS_r.restype = CONTEXT_PTR
     _lgeos.finishGEOS_r.argtypes = [CONTEXT_PTR]
-    # Ensures compatibility across 32 and 64-bit platforms.
-    _lgeos.GEOSversion.restype = c_char_p
     return _lgeos
 
 
@@ -136,27 +134,6 @@ def get_pointer_arr(n):
 lgeos = SimpleLazyObject(load_geos)
 
 
-class GEOSContextHandle(object):
-    def __init__(self):
-        # Initializing the context handle for this thread with
-        # the notice and error handler.
-        self.ptr = lgeos.initGEOS_r(notice_h, error_h)
-
-    def __del__(self):
-        if self.ptr and lgeos:
-            lgeos.finishGEOS_r(self.ptr)
-
-
-class GEOSContext(threading.local):
-
-    @cached_property
-    def ptr(self):
-        # Assign handle so it will will garbage collected when
-        # thread is finished.
-        self.handle = GEOSContextHandle()
-        return self.handle.ptr
-
-
 class GEOSFuncFactory(object):
     """
     Lazy loading of GEOS functions.
@@ -164,7 +141,6 @@ class GEOSFuncFactory(object):
     argtypes = None
     restype = None
     errcheck = None
-    thread_context = GEOSContext()
 
     def __init__(self, func_name, *args, **kwargs):
         self.func_name = func_name
@@ -178,23 +154,21 @@ class GEOSFuncFactory(object):
     def __call__(self, *args, **kwargs):
         if self.func is None:
             self.func = self.get_func(*self.args, **self.kwargs)
-        # Call the threaded GEOS routine with pointer of the context handle
-        # as the first argument.
-        return self.func(self.thread_context.ptr, *args)
+        return self.func(*args, **kwargs)
 
     def get_func(self, *args, **kwargs):
-        # GEOS thread-safe function signatures end with '_r' and
-        # take an additional context handle parameter.
-        func = getattr(lgeos, self.func_name + '_r')
-        func.argtypes = [CONTEXT_PTR] + (self.argtypes or [])
+        from django.contrib.gis.geos.prototypes.threadsafe import GEOSFunc
+        func = GEOSFunc(self.func_name)
+        func.argtypes = self.argtypes or []
         func.restype = self.restype
         if self.errcheck:
             func.errcheck = self.errcheck
         return func
 
 
-# Returns the string version of the GEOS library.
-geos_version = lambda: lgeos.GEOSversion()
+# Returns the string version of the GEOS library. Have to set the restype
+# explicitly to c_char_p to ensure compatibility across 32 and 64-bit platforms.
+geos_version = GEOSFuncFactory('GEOSversion', restype=c_char_p)
 
 # Regular expression should be able to parse version strings such as
 # '3.0.0rc4-CAPI-1.3.3', '3.0.0-CAPI-1.4.1', '3.4.0dev-CAPI-1.8.0' or '3.4.0dev-CAPI-1.8.0 r0'

+ 93 - 0
django/contrib/gis/geos/prototypes/threadsafe.py

@@ -0,0 +1,93 @@
+import threading
+
+from django.contrib.gis.geos.libgeos import (
+    CONTEXT_PTR, error_h, lgeos, notice_h,
+)
+
+
+class GEOSContextHandle(object):
+    """
+    Python object representing a GEOS context handle.
+    """
+    def __init__(self):
+        # Initializing the context handler for this thread with
+        # the notice and error handler.
+        self.ptr = lgeos.initGEOS_r(notice_h, error_h)
+
+    def __del__(self):
+        if self.ptr and lgeos:
+            lgeos.finishGEOS_r(self.ptr)
+
+
+# Defining a thread-local object and creating an instance
+# to hold a reference to GEOSContextHandle for this thread.
+class GEOSContext(threading.local):
+    handle = None
+
+thread_context = GEOSContext()
+
+
+class GEOSFunc(object):
+    """
+    Class that serves as a wrapper for GEOS C Functions, and will
+    use thread-safe function variants when available.
+    """
+    def __init__(self, func_name):
+        try:
+            # GEOS thread-safe function signatures end with '_r', and
+            # take an additional context handle parameter.
+            self.cfunc = getattr(lgeos, func_name + '_r')
+            self.threaded = True
+            # Create a reference here to thread_context so it's not
+            # garbage-collected before an attempt to call this object.
+            self.thread_context = thread_context
+        except AttributeError:
+            # Otherwise, use usual function.
+            self.cfunc = getattr(lgeos, func_name)
+            self.threaded = False
+
+    def __call__(self, *args):
+        if self.threaded:
+            # If a context handle does not exist for this thread, initialize one.
+            if not self.thread_context.handle:
+                self.thread_context.handle = GEOSContextHandle()
+            # Call the threaded GEOS routine with pointer of the context handle
+            # as the first argument.
+            return self.cfunc(self.thread_context.handle.ptr, *args)
+        else:
+            return self.cfunc(*args)
+
+    def __str__(self):
+        return self.cfunc.__name__
+
+    # argtypes property
+    def _get_argtypes(self):
+        return self.cfunc.argtypes
+
+    def _set_argtypes(self, argtypes):
+        if self.threaded:
+            new_argtypes = [CONTEXT_PTR]
+            new_argtypes.extend(argtypes)
+            self.cfunc.argtypes = new_argtypes
+        else:
+            self.cfunc.argtypes = argtypes
+
+    argtypes = property(_get_argtypes, _set_argtypes)
+
+    # restype property
+    def _get_restype(self):
+        return self.cfunc.restype
+
+    def _set_restype(self, restype):
+        self.cfunc.restype = restype
+
+    restype = property(_get_restype, _set_restype)
+
+    # errcheck property
+    def _get_errcheck(self):
+        return self.cfunc.errcheck
+
+    def _set_errcheck(self, errcheck):
+        self.cfunc.errcheck = errcheck
+
+    errcheck = property(_get_errcheck, _set_errcheck)

+ 1 - 44
tests/gis_tests/geos_tests/test_geos.py

@@ -3,7 +3,6 @@ from __future__ import unicode_literals
 import ctypes
 import json
 import random
-import threading
 from binascii import a2b_hex, b2a_hex
 from io import BytesIO
 from unittest import skipUnless
@@ -13,7 +12,7 @@ from django.contrib.gis.gdal import HAS_GDAL
 from django.contrib.gis.geos import (
     HAS_GEOS, GeometryCollection, GEOSException, GEOSGeometry, LinearRing,
     LineString, MultiLineString, MultiPoint, MultiPolygon, Point, Polygon,
-    fromfile, fromstr, libgeos,
+    fromfile, fromstr,
 )
 from django.contrib.gis.geos.base import GEOSBase
 from django.contrib.gis.geos.libgeos import geos_version_info
@@ -1233,48 +1232,6 @@ class GEOSTest(SimpleTestCase, TestDataMixin):
             self.assertEqual(m.group('version'), v_geos)
             self.assertEqual(m.group('capi_version'), v_capi)
 
-    def test_geos_threads(self):
-        pnt = Point()
-        context_ptrs = []
-
-        geos_init = libgeos.lgeos.initGEOS_r
-        geos_finish = libgeos.lgeos.finishGEOS_r
-
-        def init(*args, **kwargs):
-            result = geos_init(*args, **kwargs)
-            context_ptrs.append(result)
-            return result
-
-        def finish(*args, **kwargs):
-            result = geos_finish(*args, **kwargs)
-            destructor_called.set()
-            return result
-
-        for i in range(2):
-            destructor_called = threading.Event()
-            patch_path = 'django.contrib.gis.geos.libgeos.lgeos'
-            with mock.patch.multiple(patch_path, initGEOS_r=mock.DEFAULT, finishGEOS_r=mock.DEFAULT) as mocked:
-                mocked['initGEOS_r'].side_effect = init
-                mocked['finishGEOS_r'].side_effect = finish
-                with mock.patch('django.contrib.gis.geos.prototypes.predicates.geos_hasz.func') as mocked_hasz:
-                    thread = threading.Thread(target=lambda: pnt.hasz)
-                    thread.start()
-                    thread.join()
-
-                    # We can't be sure that members of thread locals are
-                    # garbage collected right after `thread.join()` so
-                    # we must wait until destructor is actually called.
-                    # Fail if destructor wasn't called within a second.
-                    self.assertTrue(destructor_called.wait(1))
-
-                    context_ptr = context_ptrs[i]
-                    self.assertIsInstance(context_ptr, libgeos.CONTEXT_PTR)
-                    mocked_hasz.assert_called_once_with(context_ptr, pnt.ptr)
-                    mocked['finishGEOS_r'].assert_called_once_with(context_ptr)
-
-        # Check that different contexts were used for the different threads.
-        self.assertNotEqual(context_ptrs[0], context_ptrs[1])
-
     @ignore_warnings(category=RemovedInDjango20Warning)
     def test_deprecated_srid_getters_setters(self):
         p = Point(1, 2, srid=123)