Kaynağa Gözat

Refactor server capability code into base Handler.

UploadPackHandler and ReceivePackHandler now both handle client
capabilities using a consistent interface, the set_client_capabilites
and has_capability functions. Both now error as soon as an unknown
capability is requested by the client.

Also renames the following methods:
  capabilities -> capability_line
  default_capabilities -> capabilities
This is because capability_line is the less general of the two
methods, as it is only useful when advertising capabilities to the
client.

Changed capabilities tests to use the base class and test new
functionality.

Change-Id: If7d3feeac27834119d6d4e4021569401e5444d51
Dave Borowitz 15 yıl önce
ebeveyn
işleme
37a673e4b2
2 değiştirilmiş dosya ile 57 ekleme ve 44 silme
  1. 30 24
      dulwich/server.py
  2. 27 20
      dulwich/tests/test_server.py

+ 30 - 24
dulwich/server.py

@@ -152,9 +152,27 @@ class Handler(object):
     def __init__(self, backend, read, write):
     def __init__(self, backend, read, write):
         self.backend = backend
         self.backend = backend
         self.proto = Protocol(read, write)
         self.proto = Protocol(read, write)
+        self._client_capabilities = None
+
+    def capability_line(self):
+        return " ".join(self.capabilities())
 
 
     def capabilities(self):
     def capabilities(self):
-        return " ".join(self.default_capabilities())
+        raise NotImplementedError(self.capabilities)
+
+    def set_client_capabilities(self, caps):
+        my_caps = self.capabilities()
+        for cap in caps:
+            if cap not in my_caps:
+                raise GitProtocolError('Client asked for capability %s that '
+                                       'was not advertised.' % cap)
+        self._client_capabilities = caps
+
+    def has_capability(self, cap):
+        if self._client_capabilities is None:
+            raise GitProtocolError('Server attempted to access capability %s '
+                                   'before asking client' % cap)
+        return cap in self._client_capabilities
 
 
 
 
 class UploadPackHandler(Handler):
 class UploadPackHandler(Handler):
@@ -163,29 +181,14 @@ class UploadPackHandler(Handler):
     def __init__(self, backend, read, write,
     def __init__(self, backend, read, write,
                  stateless_rpc=False, advertise_refs=False):
                  stateless_rpc=False, advertise_refs=False):
         Handler.__init__(self, backend, read, write)
         Handler.__init__(self, backend, read, write)
-        self._client_capabilities = None
         self._graph_walker = None
         self._graph_walker = None
         self.stateless_rpc = stateless_rpc
         self.stateless_rpc = stateless_rpc
         self.advertise_refs = advertise_refs
         self.advertise_refs = advertise_refs
 
 
-    def default_capabilities(self):
+    def capabilities(self):
         return ("multi_ack_detailed", "multi_ack", "side-band-64k", "thin-pack",
         return ("multi_ack_detailed", "multi_ack", "side-band-64k", "thin-pack",
                 "ofs-delta")
                 "ofs-delta")
 
 
-    def set_client_capabilities(self, caps):
-        my_caps = self.default_capabilities()
-        for cap in caps:
-            if '_ack' in cap and cap not in my_caps:
-                raise GitProtocolError('Client asked for capability %s that '
-                                       'was not advertised.' % cap)
-        self._client_capabilities = caps
-
-    def get_client_capabilities(self):
-        return self._client_capabilities
-
-    client_capabilities = property(get_client_capabilities,
-                                   set_client_capabilities)
-
     def handle(self):
     def handle(self):
 
 
         progress = lambda x: self.proto.write_sideband(2, x)
         progress = lambda x: self.proto.write_sideband(2, x)
@@ -251,7 +254,7 @@ class ProtocolGraphWalker(object):
             for i, (ref, sha) in enumerate(heads.iteritems()):
             for i, (ref, sha) in enumerate(heads.iteritems()):
                 line = "%s %s" % (sha, ref)
                 line = "%s %s" % (sha, ref)
                 if not i:
                 if not i:
-                    line = "%s\x00%s" % (line, self.handler.capabilities())
+                    line = "%s\x00%s" % (line, self.handler.capability_line())
                 self.proto.write_pkt_line("%s\n" % line)
                 self.proto.write_pkt_line("%s\n" % line)
                 # TODO: include peeled value of any tags
                 # TODO: include peeled value of any tags
 
 
@@ -266,7 +269,7 @@ class ProtocolGraphWalker(object):
         if not want:
         if not want:
             return []
             return []
         line, caps = extract_want_line_capabilities(want)
         line, caps = extract_want_line_capabilities(want)
-        self.handler.client_capabilities = caps
+        self.handler.set_client_capabilities(caps)
         self.set_ack_type(ack_type(caps))
         self.set_ack_type(ack_type(caps))
         command, sha = self._split_proto_line(line)
         command, sha = self._split_proto_line(line)
 
 
@@ -509,7 +512,7 @@ class ReceivePackHandler(Handler):
         self._stateless_rpc = stateless_rpc
         self._stateless_rpc = stateless_rpc
         self._advertise_refs = advertise_refs
         self._advertise_refs = advertise_refs
 
 
-    def default_capabilities(self):
+    def capabilities(self):
         return ("report-status", "delete-refs")
         return ("report-status", "delete-refs")
 
 
     def handle(self):
     def handle(self):
@@ -517,12 +520,14 @@ class ReceivePackHandler(Handler):
 
 
         if self.advertise_refs or not self.stateless_rpc:
         if self.advertise_refs or not self.stateless_rpc:
             if refs:
             if refs:
-                self.proto.write_pkt_line("%s %s\x00%s\n" % (refs[0][1], refs[0][0], self.capabilities()))
+                self.proto.write_pkt_line(
+                    "%s %s\x00%s\n" % (refs[0][1], refs[0][0],
+                                       self.capability_line()))
                 for i in range(1, len(refs)):
                 for i in range(1, len(refs)):
                     ref = refs[i]
                     ref = refs[i]
                     self.proto.write_pkt_line("%s %s\n" % (ref[1], ref[0]))
                     self.proto.write_pkt_line("%s %s\n" % (ref[1], ref[0]))
             else:
             else:
-                self.proto.write_pkt_line("0000000000000000000000000000000000000000 capabilities^{} %s" % self.capabilities())
+                self.proto.write_pkt_line("0000000000000000000000000000000000000000 capabilities^{} %s" % self.capability_line())
 
 
             self.proto.write("0000")
             self.proto.write("0000")
             if self.advertise_refs:
             if self.advertise_refs:
@@ -535,7 +540,8 @@ class ReceivePackHandler(Handler):
         if ref is None:
         if ref is None:
             return
             return
 
 
-        ref, client_capabilities = extract_capabilities(ref)
+        ref, caps = extract_capabilities(ref)
+        self.set_client_capabilities(caps)
 
 
         # client will now send us a list of (oldsha, newsha, ref)
         # client will now send us a list of (oldsha, newsha, ref)
         while ref:
         while ref:
@@ -547,7 +553,7 @@ class ReceivePackHandler(Handler):
 
 
         # when we have read all the pack from the client, send a status report
         # when we have read all the pack from the client, send a status report
         # if the client asked for it
         # if the client asked for it
-        if 'report-status' in client_capabilities:
+        if self.has_capability('report-status'):
             for name, msg in status:
             for name, msg in status:
                 if name == 'unpack':
                 if name == 'unpack':
                     self.proto.write_pkt_line('unpack %s\n' % msg)
                     self.proto.write_pkt_line('unpack %s\n' % msg)

+ 27 - 20
dulwich/tests/test_server.py

@@ -28,6 +28,7 @@ from dulwich.errors import (
     )
     )
 from dulwich.server import (
 from dulwich.server import (
     UploadPackHandler,
     UploadPackHandler,
+    Handler,
     ProtocolGraphWalker,
     ProtocolGraphWalker,
     SingleAckGraphWalkerImpl,
     SingleAckGraphWalkerImpl,
     MultiAckGraphWalkerImpl,
     MultiAckGraphWalkerImpl,
@@ -75,30 +76,36 @@ class TestProto(object):
             return None
             return None
 
 
 
 
-class UploadPackHandlerTestCase(TestCase):
+class HandlerTestCase(TestCase):
     def setUp(self):
     def setUp(self):
-        self._handler = UploadPackHandler(None, None, None)
+        self._handler = Handler(None, None, None)
+        self._handler.capabilities = lambda: ('cap1', 'cap2', 'cap3')
 
 
-    def test_set_client_capabilities(self):
+    def assertSucceeds(self, func, *args, **kwargs):
         try:
         try:
-            self._handler.set_client_capabilities([])
+            func(*args, **kwargs)
         except GitProtocolError:
         except GitProtocolError:
             self.fail()
             self.fail()
 
 
-        try:
+    def test_capability_line(self):
-            self._handler.set_client_capabilities([
+        self.assertEquals('cap1 cap2 cap3', self._handler.capability_line())
-                'multi_ack', 'side-band-64k', 'thin-pack', 'ofs-delta'])
-        except GitProtocolError:
-            self.fail()
 
 
-    def test_set_client_capabilities_error(self):
+    def test_set_client_capabilities(self):
-        self.assertRaises(GitProtocolError,
+        set_caps = self._handler.set_client_capabilities
-                          self._handler.set_client_capabilities,
+        self.assertSucceeds(set_caps, [])
-                          ['weird_ack_level', 'ofs-delta'])
+        self.assertSucceeds(set_caps, ['cap2'])
-        try:
+        self.assertSucceeds(set_caps, ['cap1', 'cap2'])
-            self._handler.set_client_capabilities(['include-tag'])
+        # different order
-        except GitProtocolError:
+        self.assertSucceeds(set_caps, ['cap3', 'cap1', 'cap2'])
-            self.fail()
+        self.assertRaises(GitProtocolError, set_caps, ['capxxx', 'cap1'])
+
+    def test_has_capability(self):
+        self.assertRaises(GitProtocolError, self._handler.has_capability, 'cap')
+        caps = self._handler.capabilities()
+        self._handler.set_client_capabilities(caps)
+        for cap in caps:
+            self.assertTrue(self._handler.has_capability(cap))
+        self.assertFalse(self._handler.has_capability('capxxx'))
 
 
 
 
 class TestCommit(object):
 class TestCommit(object):
@@ -119,7 +126,7 @@ class TestBackend(object):
         self.object_store = objects
         self.object_store = objects
 
 
 
 
-class TestHandler(object):
+class TestUploadPackHandler(Handler):
     def __init__(self, objects, proto):
     def __init__(self, objects, proto):
         self.backend = TestBackend(objects)
         self.backend = TestBackend(objects)
         self.proto = proto
         self.proto = proto
@@ -127,7 +134,7 @@ class TestHandler(object):
         self.advertise_refs = False
         self.advertise_refs = False
 
 
     def capabilities(self):
     def capabilities(self):
-        return 'multi_ack'
+        return ('multi_ack',)
 
 
 
 
 class ProtocolGraphWalkerTestCase(TestCase):
 class ProtocolGraphWalkerTestCase(TestCase):
@@ -144,7 +151,7 @@ class ProtocolGraphWalkerTestCase(TestCase):
             FIVE: TestCommit(FIVE, [THREE], 555),
             FIVE: TestCommit(FIVE, [THREE], 555),
             }
             }
         self._walker = ProtocolGraphWalker(
         self._walker = ProtocolGraphWalker(
-            TestHandler(self._objects, TestProto()))
+            TestUploadPackHandler(self._objects, TestProto()))
 
 
     def test_is_satisfied_no_haves(self):
     def test_is_satisfied_no_haves(self):
         self.assertFalse(self._walker._is_satisfied([], ONE, 0))
         self.assertFalse(self._walker._is_satisfied([], ONE, 0))