ソースを参照

Fix #140: Don't expect 'wsgi.input' to have 'seek' method

The WSGI specification does not require 'wsgi.input' to have a 'seek'
method.  However, the gunzipping code in the standard library of Python
2.x uses that method.  We work around this by buffering data from a
'wsgi.input' that lacks this method into a temporary file.

Signed-off-by: Jelmer Vernooij <jelmer@samba.org>
Jonas Haag 11 年 前
コミット
5a69c460b6
2 ファイル変更49 行追加9 行削除
  1. 33 6
      dulwich/tests/test_web.py
  2. 16 3
      dulwich/web.py

+ 33 - 6
dulwich/tests/test_web.py

@@ -457,6 +457,7 @@ class GunzipTestCase(HTTPGitApplicationTestCase):
     """TestCase for testing the GunzipFilter, ensuring the wsgi.input
     """TestCase for testing the GunzipFilter, ensuring the wsgi.input
     is correctly decompressed and headers are corrected.
     is correctly decompressed and headers are corrected.
     """
     """
+    example_text = __doc__
 
 
     def setUp(self):
     def setUp(self):
         super(GunzipTestCase, self).setUp()
         super(GunzipTestCase, self).setUp()
@@ -469,14 +470,12 @@ class GunzipTestCase(HTTPGitApplicationTestCase):
         zfile = gzip.GzipFile(fileobj=zstream, mode='w')
         zfile = gzip.GzipFile(fileobj=zstream, mode='w')
         zfile.write(text)
         zfile.write(text)
         zfile.close()
         zfile.close()
-        return zstream
-
-    def test_call(self):
-        self._add_handler(self._app.app)
-        orig = self.__class__.__doc__
-        zstream = self._get_zstream(orig)
         zlength = zstream.tell()
         zlength = zstream.tell()
         zstream.seek(0)
         zstream.seek(0)
+        return zstream, zlength
+
+    def _test_call(self, orig, zstream, zlength):
+        self._add_handler(self._app.app)
         self.assertLess(zlength, len(orig))
         self.assertLess(zlength, len(orig))
         self.assertEqual(self._environ['HTTP_CONTENT_ENCODING'], 'gzip')
         self.assertEqual(self._environ['HTTP_CONTENT_ENCODING'], 'gzip')
         self._environ['CONTENT_LENGTH'] = zlength
         self._environ['CONTENT_LENGTH'] = zlength
@@ -488,3 +487,31 @@ class GunzipTestCase(HTTPGitApplicationTestCase):
         self.assertEqual(orig, buf.read())
         self.assertEqual(orig, buf.read())
         self.assertIs(None, self._environ.get('CONTENT_LENGTH'))
         self.assertIs(None, self._environ.get('CONTENT_LENGTH'))
         self.assertNotIn('HTTP_CONTENT_ENCODING', self._environ)
         self.assertNotIn('HTTP_CONTENT_ENCODING', self._environ)
+
+    def test_call(self):
+        self._test_call(
+            self.example_text,
+            *self._get_zstream(self.example_text)
+        )
+
+    def test_call_no_seek(self):
+        """
+        This ensures that the gunzipping code doesn't require any methods on
+        'wsgi.input' except for '.read()'.  (In particular, it shouldn't
+        require '.seek()'. See https://github.com/jelmer/dulwich/issues/140.)
+        """
+        class MinimalistWSGIInputStream(object):
+            def __init__(self, data):
+                self.data = data
+                self.pos = 0
+
+            def read(self, howmuch):
+                start = self.pos
+                end = self.pos + howmuch
+                if start >= len(self.data):
+                    return ''
+                self.pos = end
+                return self.data[start:end]
+
+        zstream, zlength = self._get_zstream(self.example_text)
+        self._test_call(self.example_text, MinimalistWSGIInputStream(zstream.read()), zlength)

+ 16 - 3
dulwich/web.py

@@ -20,6 +20,8 @@
 """HTTP server for dulwich that implements the git smart HTTP protocol."""
 """HTTP server for dulwich that implements the git smart HTTP protocol."""
 
 
 from cStringIO import StringIO
 from cStringIO import StringIO
+import shutil
+import tempfile
 import gzip
 import gzip
 import os
 import os
 import re
 import re
@@ -358,11 +360,22 @@ class GunzipFilter(object):
 
 
     def __call__(self, environ, start_response):
     def __call__(self, environ, start_response):
         if environ.get('HTTP_CONTENT_ENCODING', '') == 'gzip':
         if environ.get('HTTP_CONTENT_ENCODING', '') == 'gzip':
-            environ.pop('HTTP_CONTENT_ENCODING')
+            if hasattr(environ['wsgi.input'], 'seek'):
+                wsgi_input = environ['wsgi.input']
+            else:
+                # The gzip implementation in the standard library of Python 2.x
+                # requires the '.seek()' and '.tell()' methods to be available
+                # on the input stream.  Read the data into a temporary file to
+                # work around this limitation.
+                wsgi_input = tempfile.SpooledTemporaryFile(16 * 1024 * 1024)
+                shutil.copyfileobj(environ['wsgi.input'], wsgi_input)
+                wsgi_input.seek(0)
+
+            environ['wsgi.input'] = gzip.GzipFile(filename=None, fileobj=wsgi_input, mode='r')
+            del environ['HTTP_CONTENT_ENCODING']
             if 'CONTENT_LENGTH' in environ:
             if 'CONTENT_LENGTH' in environ:
                 del environ['CONTENT_LENGTH']
                 del environ['CONTENT_LENGTH']
-            environ['wsgi.input'] = gzip.GzipFile(filename=None,
-                fileobj=environ['wsgi.input'], mode='r')
+
         return self.app(environ, start_response)
         return self.app(environ, start_response)