Browse Source

Add tests for HTTP server.

These are unittests for the functionality in web.py, not end-to-end integration
tests. Fixed two bugs where the smart HTTP handlers wouldn't correctly forbid
unknown RPC services. Refactored for testing by injecting a dict of services
into smart HTTP handlers and rewriting tag-peeling code to not depend on objects
having a particular type.

Change-Id: I42a0ed89781687a655b5803103eb21c1b62cee83
Dave Borowitz 15 years ago
parent
commit
076b8d5a7b
2 changed files with 303 additions and 5 deletions
  1. 289 0
      dulwich/tests/test_web.py
  2. 14 5
      dulwich/web.py

+ 289 - 0
dulwich/tests/test_web.py

@@ -0,0 +1,289 @@
+# test_web.py -- Tests for the git HTTP server
+# Copryight (C) 2010 Google, Inc.
+#
+# This program is free software; you can redistribute it and/or
+# modify it under the terms of the GNU General Public License
+# as published by the Free Software Foundation; version 2
+# or (at your option) any later version of the License.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
+# MA  02110-1301, USA.
+
+"""Tests for the Git HTTP server."""
+
+from cStringIO import StringIO
+import re
+from unittest import TestCase
+
+from dulwich.objects import (
+    type_map,
+    Tag,
+    Blob,
+    )
+from dulwich.web import (
+    HTTP_OK,
+    HTTP_NOT_FOUND,
+    HTTP_FORBIDDEN,
+    send_file,
+    get_info_refs,
+    handle_service_request,
+    _LengthLimitedFile,
+    HTTPGitRequest,
+    HTTPGitApplication,
+    )
+
+
+class WebTestCase(TestCase):
+    """Base TestCase that sets up some useful instance vars."""
+    def setUp(self):
+        self._environ = {}
+        self._req = HTTPGitRequest(self._environ, self._start_response)
+        self._status = None
+        self._headers = []
+
+    def _start_response(self, status, headers):
+        self._status = status
+        self._headers = list(headers)
+
+
+class DumbHandlersTestCase(WebTestCase):
+
+    def test_send_file_not_found(self):
+        list(send_file(self._req, None, 'text/plain'))
+        self.assertEquals(HTTP_NOT_FOUND, self._status)
+
+    def test_send_file(self):
+        f = StringIO('foobar')
+        output = ''.join(send_file(self._req, f, 'text/plain'))
+        self.assertEquals('foobar', output)
+        self.assertEquals(HTTP_OK, self._status)
+        self.assertTrue(('Content-Type', 'text/plain') in self._headers)
+        self.assertTrue(f.closed)
+
+    def test_send_file_buffered(self):
+        bufsize = 10240
+        xs = 'x' * bufsize
+        f = StringIO(2 * xs)
+        self.assertEquals([xs, xs],
+                          list(send_file(self._req, f, 'text/plain')))
+        self.assertEquals(HTTP_OK, self._status)
+        self.assertTrue(('Content-Type', 'text/plain') in self._headers)
+        self.assertTrue(f.closed)
+
+    def test_send_file_error(self):
+        class TestFile(object):
+            def __init__(self):
+                self.closed = False
+
+            def read(self, size=-1):
+                raise IOError
+
+            def close(self):
+                self.closed = True
+
+        f = TestFile()
+        list(send_file(self._req, f, 'text/plain'))
+        self.assertEquals(HTTP_NOT_FOUND, self._status)
+        self.assertTrue(f.closed)
+
+    def test_get_info_refs(self):
+        self._environ['QUERY_STRING'] = ''
+
+        class TestTag(object):
+            type = Tag().type
+
+            def __init__(self, sha, obj_type, obj_sha):
+                self.sha = lambda: sha
+                self.object = (obj_type, obj_sha)
+
+        class TestBlob(object):
+            type = Blob().type
+
+            def __init__(self, sha):
+                self.sha = lambda: sha
+
+        blob1 = TestBlob('111')
+        blob2 = TestBlob('222')
+        blob3 = TestBlob('333')
+
+        tag1 = TestTag('aaa', TestTag.type, 'bbb')
+        tag2 = TestTag('bbb', TestBlob.type, '222')
+
+        class TestBackend(object):
+            def __init__(self):
+                objects = [blob1, blob2, blob3, tag1, tag2]
+                self.repo = dict((o.sha(), o) for o in objects)
+
+            def get_refs(self):
+                return {
+                    'HEAD': '000',
+                    'refs/heads/master': blob1.sha(),
+                    'refs/tags/tag-tag': tag1.sha(),
+                    'refs/tags/blob-tag': blob3.sha(),
+                    }
+
+        self.assertEquals(['111\trefs/heads/master\n',
+                           '333\trefs/tags/blob-tag\n',
+                           'aaa\trefs/tags/tag-tag\n',
+                           '222\trefs/tags/tag-tag^{}\n'],
+                          list(get_info_refs(self._req, TestBackend(), None)))
+
+
+class SmartHandlersTestCase(WebTestCase):
+
+    class TestProtocol(object):
+        def __init__(self, handler):
+            self._handler = handler
+
+        def write_pkt_line(self, line):
+            if line is None:
+                self._handler.write('flush-pkt\n')
+            else:
+                self._handler.write('pkt-line: %s' % line)
+
+    class _TestUploadPackHandler(object):
+        def __init__(self, backend, read, write, stateless_rpc=False,
+                     advertise_refs=False):
+            self.read = read
+            self.write = write
+            self.proto = SmartHandlersTestCase.TestProtocol(self)
+            self.stateless_rpc = stateless_rpc
+            self.advertise_refs = advertise_refs
+
+        def handle(self):
+            self.write('handled input: %s' % self.read())
+
+    def _MakeHandler(self, *args, **kwargs):
+        self._handler = self._TestUploadPackHandler(*args, **kwargs)
+        return self._handler
+
+    def services(self):
+        return {'git-upload-pack': self._MakeHandler}
+
+    def test_handle_service_request_unknown(self):
+        mat = re.search('.*', '/git-evil-handler')
+        list(handle_service_request(self._req, 'backend', mat))
+        self.assertEquals(HTTP_FORBIDDEN, self._status)
+
+    def test_handle_service_request(self):
+        self._environ['wsgi.input'] = StringIO('foo')
+        mat = re.search('.*', '/git-upload-pack')
+        output = ''.join(handle_service_request(self._req, 'backend', mat,
+                                                services=self.services()))
+        self.assertEqual('handled input: foo', output)
+        response_type = 'application/x-git-upload-pack-response'
+        self.assertTrue(('Content-Type', response_type) in self._headers)
+        self.assertFalse(self._handler.advertise_refs)
+        self.assertTrue(self._handler.stateless_rpc)
+
+    def test_handle_service_request_with_length(self):
+        self._environ['wsgi.input'] = StringIO('foobar')
+        self._environ['CONTENT_LENGTH'] = 3
+        mat = re.search('.*', '/git-upload-pack')
+        output = ''.join(handle_service_request(self._req, 'backend', mat,
+                                                services=self.services()))
+        self.assertEqual('handled input: foo', output)
+        response_type = 'application/x-git-upload-pack-response'
+        self.assertTrue(('Content-Type', response_type) in self._headers)
+
+    def test_get_info_refs_unknown(self):
+        self._environ['QUERY_STRING'] = 'service=git-evil-handler'
+        list(get_info_refs(self._req, 'backend', None,
+                           services=self.services()))
+        self.assertEquals(HTTP_FORBIDDEN, self._status)
+
+    def test_get_info_refs(self):
+        self._environ['wsgi.input'] = StringIO('foo')
+        self._environ['QUERY_STRING'] = 'service=git-upload-pack'
+
+        output = ''.join(get_info_refs(self._req, 'backend', None,
+                                       services=self.services()))
+        self.assertEquals(('pkt-line: # service=git-upload-pack\n'
+                           'flush-pkt\n'
+                           # input is ignored by the handler
+                           'handled input: '), output)
+        self.assertTrue(self._handler.advertise_refs)
+        self.assertTrue(self._handler.stateless_rpc)
+
+
+class LengthLimitedFileTestCase(TestCase):
+    def test_no_cutoff(self):
+        f = _LengthLimitedFile(StringIO('foobar'), 1024)
+        self.assertEquals('foobar', f.read())
+
+    def test_cutoff(self):
+        f = _LengthLimitedFile(StringIO('foobar'), 3)
+        self.assertEquals('foo', f.read())
+        self.assertEquals('', f.read())
+
+    def test_multiple_reads(self):
+        f = _LengthLimitedFile(StringIO('foobar'), 3)
+        self.assertEquals('fo', f.read(2))
+        self.assertEquals('o', f.read(2))
+        self.assertEquals('', f.read())
+
+
+class HTTPGitRequestTestCase(WebTestCase):
+    def test_not_found(self):
+        self._req.cache_forever()  # cache headers should be discarded
+        message = 'Something not found'
+        self.assertEquals(message, self._req.not_found(message))
+        self.assertEquals(HTTP_NOT_FOUND, self._status)
+        self.assertEquals(set([('Content-Type', 'text/plain')]),
+                          set(self._headers))
+
+    def test_forbidden(self):
+        self._req.cache_forever()  # cache headers should be discarded
+        message = 'Something not found'
+        self.assertEquals(message, self._req.forbidden(message))
+        self.assertEquals(HTTP_FORBIDDEN, self._status)
+        self.assertEquals(set([('Content-Type', 'text/plain')]),
+                          set(self._headers))
+
+    def test_respond_ok(self):
+        self._req.respond()
+        self.assertEquals([], self._headers)
+        self.assertEquals(HTTP_OK, self._status)
+
+    def test_respond(self):
+        self._req.nocache()
+        self._req.respond(status=402, content_type='some/type',
+                          headers=[('X-Foo', 'foo'), ('X-Bar', 'bar')])
+        self.assertEquals(set([
+            ('X-Foo', 'foo'),
+            ('X-Bar', 'bar'),
+            ('Content-Type', 'some/type'),
+            ('Expires', 'Fri, 01 Jan 1980 00:00:00 GMT'),
+            ('Pragma', 'no-cache'),
+            ('Cache-Control', 'no-cache, max-age=0, must-revalidate'),
+            ]), set(self._headers))
+        self.assertEquals(402, self._status)
+
+
+class HTTPGitApplicationTestCase(TestCase):
+    def setUp(self):
+        self._app = HTTPGitApplication('backend')
+
+    def test_call(self):
+        def test_handler(req, backend, mat):
+            # tests interface used by all handlers
+            self.assertEquals(environ, req.environ)
+            self.assertEquals('backend', backend)
+            self.assertEquals('/foo', mat.group(0))
+            return 'output'
+
+        self._app.services = {
+            ('GET', re.compile('/foo$')): test_handler,
+        }
+        environ = {
+            'PATH_INFO': '/foo',
+            'REQUEST_METHOD': 'GET',
+            }
+        self.assertEquals('output', self._app(environ, None))

+ 14 - 5
dulwich/web.py

@@ -26,6 +26,7 @@ import time
 
 from dulwich.objects import (
     Tag,
+    num_type_map,
     )
 from dulwich.repo import (
     Repo,
@@ -114,13 +115,16 @@ def get_idx_file(req, backend, mat):
 
 services = {'git-upload-pack': UploadPackHandler,
             'git-receive-pack': ReceivePackHandler}
-def get_info_refs(req, backend, mat):
+def get_info_refs(req, backend, mat, services=None):
+    if services is None:
+        services = services
     params = cgi.parse_qs(req.environ['QUERY_STRING'])
     service = params.get('service', [None])[0]
     if service:
         handler_cls = services.get(service, None)
         if handler_cls is None:
             yield req.forbidden('Unsupported service %s' % service)
+            return
         req.nocache()
         req.respond(HTTP_OK, 'application/x-%s-advertisement' % service)
         output = StringIO()
@@ -147,9 +151,11 @@ def get_info_refs(req, backend, mat):
             if not o:
                 continue
             yield '%s\t%s\n' % (sha, name)
-            if isinstance(o, Tag):
-                while isinstance(o, Tag):
-                    _, sha = o.object
+            obj_type = num_type_map[o.type]
+            if obj_type == Tag:
+                while obj_type == Tag:
+                    num_type, sha = o.object
+                    obj_type = num_type_map[num_type]
                     o = backend.repo[sha]
                 if not o:
                     continue
@@ -184,11 +190,14 @@ class _LengthLimitedFile(object):
 
     # TODO: support more methods as necessary
 
-def handle_service_request(req, backend, mat):
+def handle_service_request(req, backend, mat, services=services):
+    if services is None:
+        services = services
     service = mat.group().lstrip('/')
     handler_cls = services.get(service, None)
     if handler_cls is None:
         yield req.forbidden('Unsupported service %s' % service)
+        return
     req.nocache()
     req.respond(HTTP_OK, 'application/x-%s-response' % service)