Browse Source

Merge server capability refactoring from Dave.

Jelmer Vernooij 15 years ago
parent
commit
dc0479f0bb
2 changed files with 57 additions and 44 deletions
  1. 30 24
      dulwich/server.py
  2. 27 20
      dulwich/tests/test_server.py

+ 30 - 24
dulwich/server.py

@@ -152,9 +152,27 @@ class Handler(object):
     def __init__(self, backend, read, write):
     def __init__(self, backend, read, write):
         self.backend = backend
         self.backend = backend
         self.proto = Protocol(read, write)
         self.proto = Protocol(read, write)
+        self._client_capabilities = None
+
+    def capability_line(self):
+        return " ".join(self.capabilities())
 
 
     def capabilities(self):
     def capabilities(self):
-        return " ".join(self.default_capabilities())
+        raise NotImplementedError(self.capabilities)
+
+    def set_client_capabilities(self, caps):
+        my_caps = self.capabilities()
+        for cap in caps:
+            if cap not in my_caps:
+                raise GitProtocolError('Client asked for capability %s that '
+                                       'was not advertised.' % cap)
+        self._client_capabilities = caps
+
+    def has_capability(self, cap):
+        if self._client_capabilities is None:
+            raise GitProtocolError('Server attempted to access capability %s '
+                                   'before asking client' % cap)
+        return cap in self._client_capabilities
 
 
 
 
 class UploadPackHandler(Handler):
 class UploadPackHandler(Handler):
@@ -163,29 +181,14 @@ class UploadPackHandler(Handler):
     def __init__(self, backend, read, write,
     def __init__(self, backend, read, write,
                  stateless_rpc=False, advertise_refs=False):
                  stateless_rpc=False, advertise_refs=False):
         Handler.__init__(self, backend, read, write)
         Handler.__init__(self, backend, read, write)
-        self._client_capabilities = None
         self._graph_walker = None
         self._graph_walker = None
         self.stateless_rpc = stateless_rpc
         self.stateless_rpc = stateless_rpc
         self.advertise_refs = advertise_refs
         self.advertise_refs = advertise_refs
 
 
-    def default_capabilities(self):
+    def capabilities(self):
         return ("multi_ack_detailed", "multi_ack", "side-band-64k", "thin-pack",
         return ("multi_ack_detailed", "multi_ack", "side-band-64k", "thin-pack",
                 "ofs-delta")
                 "ofs-delta")
 
 
-    def set_client_capabilities(self, caps):
-        my_caps = self.default_capabilities()
-        for cap in caps:
-            if '_ack' in cap and cap not in my_caps:
-                raise GitProtocolError('Client asked for capability %s that '
-                                       'was not advertised.' % cap)
-        self._client_capabilities = caps
-
-    def get_client_capabilities(self):
-        return self._client_capabilities
-
-    client_capabilities = property(get_client_capabilities,
-                                   set_client_capabilities)
-
     def handle(self):
     def handle(self):
 
 
         progress = lambda x: self.proto.write_sideband(2, x)
         progress = lambda x: self.proto.write_sideband(2, x)
@@ -251,7 +254,7 @@ class ProtocolGraphWalker(object):
             for i, (ref, sha) in enumerate(heads.iteritems()):
             for i, (ref, sha) in enumerate(heads.iteritems()):
                 line = "%s %s" % (sha, ref)
                 line = "%s %s" % (sha, ref)
                 if not i:
                 if not i:
-                    line = "%s\x00%s" % (line, self.handler.capabilities())
+                    line = "%s\x00%s" % (line, self.handler.capability_line())
                 self.proto.write_pkt_line("%s\n" % line)
                 self.proto.write_pkt_line("%s\n" % line)
                 # TODO: include peeled value of any tags
                 # TODO: include peeled value of any tags
 
 
@@ -266,7 +269,7 @@ class ProtocolGraphWalker(object):
         if not want:
         if not want:
             return []
             return []
         line, caps = extract_want_line_capabilities(want)
         line, caps = extract_want_line_capabilities(want)
-        self.handler.client_capabilities = caps
+        self.handler.set_client_capabilities(caps)
         self.set_ack_type(ack_type(caps))
         self.set_ack_type(ack_type(caps))
         command, sha = self._split_proto_line(line)
         command, sha = self._split_proto_line(line)
 
 
@@ -509,7 +512,7 @@ class ReceivePackHandler(Handler):
         self._stateless_rpc = stateless_rpc
         self._stateless_rpc = stateless_rpc
         self._advertise_refs = advertise_refs
         self._advertise_refs = advertise_refs
 
 
-    def default_capabilities(self):
+    def capabilities(self):
         return ("report-status", "delete-refs")
         return ("report-status", "delete-refs")
 
 
     def handle(self):
     def handle(self):
@@ -517,12 +520,14 @@ class ReceivePackHandler(Handler):
 
 
         if self.advertise_refs or not self.stateless_rpc:
         if self.advertise_refs or not self.stateless_rpc:
             if refs:
             if refs:
-                self.proto.write_pkt_line("%s %s\x00%s\n" % (refs[0][1], refs[0][0], self.capabilities()))
+                self.proto.write_pkt_line(
+                    "%s %s\x00%s\n" % (refs[0][1], refs[0][0],
+                                       self.capability_line()))
                 for i in range(1, len(refs)):
                 for i in range(1, len(refs)):
                     ref = refs[i]
                     ref = refs[i]
                     self.proto.write_pkt_line("%s %s\n" % (ref[1], ref[0]))
                     self.proto.write_pkt_line("%s %s\n" % (ref[1], ref[0]))
             else:
             else:
-                self.proto.write_pkt_line("0000000000000000000000000000000000000000 capabilities^{} %s" % self.capabilities())
+                self.proto.write_pkt_line("0000000000000000000000000000000000000000 capabilities^{} %s" % self.capability_line())
 
 
             self.proto.write("0000")
             self.proto.write("0000")
             if self.advertise_refs:
             if self.advertise_refs:
@@ -535,7 +540,8 @@ class ReceivePackHandler(Handler):
         if ref is None:
         if ref is None:
             return
             return
 
 
-        ref, client_capabilities = extract_capabilities(ref)
+        ref, caps = extract_capabilities(ref)
+        self.set_client_capabilities(caps)
 
 
         # client will now send us a list of (oldsha, newsha, ref)
         # client will now send us a list of (oldsha, newsha, ref)
         while ref:
         while ref:
@@ -547,7 +553,7 @@ class ReceivePackHandler(Handler):
 
 
         # when we have read all the pack from the client, send a status report
         # when we have read all the pack from the client, send a status report
         # if the client asked for it
         # if the client asked for it
-        if 'report-status' in client_capabilities:
+        if self.has_capability('report-status'):
             for name, msg in status:
             for name, msg in status:
                 if name == 'unpack':
                 if name == 'unpack':
                     self.proto.write_pkt_line('unpack %s\n' % msg)
                     self.proto.write_pkt_line('unpack %s\n' % msg)

+ 27 - 20
dulwich/tests/test_server.py

@@ -28,6 +28,7 @@ from dulwich.errors import (
     )
     )
 from dulwich.server import (
 from dulwich.server import (
     UploadPackHandler,
     UploadPackHandler,
+    Handler,
     ProtocolGraphWalker,
     ProtocolGraphWalker,
     SingleAckGraphWalkerImpl,
     SingleAckGraphWalkerImpl,
     MultiAckGraphWalkerImpl,
     MultiAckGraphWalkerImpl,
@@ -75,30 +76,36 @@ class TestProto(object):
             return None
             return None
 
 
 
 
-class UploadPackHandlerTestCase(TestCase):
+class HandlerTestCase(TestCase):
     def setUp(self):
     def setUp(self):
-        self._handler = UploadPackHandler(None, None, None)
+        self._handler = Handler(None, None, None)
+        self._handler.capabilities = lambda: ('cap1', 'cap2', 'cap3')
 
 
-    def test_set_client_capabilities(self):
+    def assertSucceeds(self, func, *args, **kwargs):
         try:
         try:
-            self._handler.set_client_capabilities([])
+            func(*args, **kwargs)
         except GitProtocolError:
         except GitProtocolError:
             self.fail()
             self.fail()
 
 
-        try:
-            self._handler.set_client_capabilities([
-                'multi_ack', 'side-band-64k', 'thin-pack', 'ofs-delta'])
-        except GitProtocolError:
-            self.fail()
+    def test_capability_line(self):
+        self.assertEquals('cap1 cap2 cap3', self._handler.capability_line())
 
 
-    def test_set_client_capabilities_error(self):
-        self.assertRaises(GitProtocolError,
-                          self._handler.set_client_capabilities,
-                          ['weird_ack_level', 'ofs-delta'])
-        try:
-            self._handler.set_client_capabilities(['include-tag'])
-        except GitProtocolError:
-            self.fail()
+    def test_set_client_capabilities(self):
+        set_caps = self._handler.set_client_capabilities
+        self.assertSucceeds(set_caps, [])
+        self.assertSucceeds(set_caps, ['cap2'])
+        self.assertSucceeds(set_caps, ['cap1', 'cap2'])
+        # different order
+        self.assertSucceeds(set_caps, ['cap3', 'cap1', 'cap2'])
+        self.assertRaises(GitProtocolError, set_caps, ['capxxx', 'cap1'])
+
+    def test_has_capability(self):
+        self.assertRaises(GitProtocolError, self._handler.has_capability, 'cap')
+        caps = self._handler.capabilities()
+        self._handler.set_client_capabilities(caps)
+        for cap in caps:
+            self.assertTrue(self._handler.has_capability(cap))
+        self.assertFalse(self._handler.has_capability('capxxx'))
 
 
 
 
 class TestCommit(object):
 class TestCommit(object):
@@ -119,7 +126,7 @@ class TestBackend(object):
         self.object_store = objects
         self.object_store = objects
 
 
 
 
-class TestHandler(object):
+class TestUploadPackHandler(Handler):
     def __init__(self, objects, proto):
     def __init__(self, objects, proto):
         self.backend = TestBackend(objects)
         self.backend = TestBackend(objects)
         self.proto = proto
         self.proto = proto
@@ -127,7 +134,7 @@ class TestHandler(object):
         self.advertise_refs = False
         self.advertise_refs = False
 
 
     def capabilities(self):
     def capabilities(self):
-        return 'multi_ack'
+        return ('multi_ack',)
 
 
 
 
 class ProtocolGraphWalkerTestCase(TestCase):
 class ProtocolGraphWalkerTestCase(TestCase):
@@ -144,7 +151,7 @@ class ProtocolGraphWalkerTestCase(TestCase):
             FIVE: TestCommit(FIVE, [THREE], 555),
             FIVE: TestCommit(FIVE, [THREE], 555),
             }
             }
         self._walker = ProtocolGraphWalker(
         self._walker = ProtocolGraphWalker(
-            TestHandler(self._objects, TestProto()))
+            TestUploadPackHandler(self._objects, TestProto()))
 
 
     def test_is_satisfied_no_haves(self):
     def test_is_satisfied_no_haves(self):
         self.assertFalse(self._walker._is_satisfied([], ONE, 0))
         self.assertFalse(self._walker._is_satisfied([], ONE, 0))