Sfoglia il codice sorgente

Refactor server capability code into base Handler.

UploadPackHandler and ReceivePackHandler now both handle client
capabilities using a consistent interface, the set_client_capabilites
and has_capability functions. Both now error as soon as an unknown
capability is requested by the client.

Also renames the following methods:
  capabilities -> capability_line
  default_capabilities -> capabilities
This is because capability_line is the less general of the two
methods, as it is only useful when advertising capabilities to the
client.

Changed capabilities tests to use the base class and test new
functionality.

Change-Id: If7d3feeac27834119d6d4e4021569401e5444d51
Dave Borowitz 15 anni fa
parent
commit
37a673e4b2
2 ha cambiato i file con 57 aggiunte e 44 eliminazioni
  1. 30 24
      dulwich/server.py
  2. 27 20
      dulwich/tests/test_server.py

+ 30 - 24
dulwich/server.py

@@ -152,9 +152,27 @@ class Handler(object):
     def __init__(self, backend, read, write):
         self.backend = backend
         self.proto = Protocol(read, write)
+        self._client_capabilities = None
+
+    def capability_line(self):
+        return " ".join(self.capabilities())
 
     def capabilities(self):
-        return " ".join(self.default_capabilities())
+        raise NotImplementedError(self.capabilities)
+
+    def set_client_capabilities(self, caps):
+        my_caps = self.capabilities()
+        for cap in caps:
+            if cap not in my_caps:
+                raise GitProtocolError('Client asked for capability %s that '
+                                       'was not advertised.' % cap)
+        self._client_capabilities = caps
+
+    def has_capability(self, cap):
+        if self._client_capabilities is None:
+            raise GitProtocolError('Server attempted to access capability %s '
+                                   'before asking client' % cap)
+        return cap in self._client_capabilities
 
 
 class UploadPackHandler(Handler):
@@ -163,29 +181,14 @@ class UploadPackHandler(Handler):
     def __init__(self, backend, read, write,
                  stateless_rpc=False, advertise_refs=False):
         Handler.__init__(self, backend, read, write)
-        self._client_capabilities = None
         self._graph_walker = None
         self.stateless_rpc = stateless_rpc
         self.advertise_refs = advertise_refs
 
-    def default_capabilities(self):
+    def capabilities(self):
         return ("multi_ack_detailed", "multi_ack", "side-band-64k", "thin-pack",
                 "ofs-delta")
 
-    def set_client_capabilities(self, caps):
-        my_caps = self.default_capabilities()
-        for cap in caps:
-            if '_ack' in cap and cap not in my_caps:
-                raise GitProtocolError('Client asked for capability %s that '
-                                       'was not advertised.' % cap)
-        self._client_capabilities = caps
-
-    def get_client_capabilities(self):
-        return self._client_capabilities
-
-    client_capabilities = property(get_client_capabilities,
-                                   set_client_capabilities)
-
     def handle(self):
 
         progress = lambda x: self.proto.write_sideband(2, x)
@@ -251,7 +254,7 @@ class ProtocolGraphWalker(object):
             for i, (ref, sha) in enumerate(heads.iteritems()):
                 line = "%s %s" % (sha, ref)
                 if not i:
-                    line = "%s\x00%s" % (line, self.handler.capabilities())
+                    line = "%s\x00%s" % (line, self.handler.capability_line())
                 self.proto.write_pkt_line("%s\n" % line)
                 # TODO: include peeled value of any tags
 
@@ -266,7 +269,7 @@ class ProtocolGraphWalker(object):
         if not want:
             return []
         line, caps = extract_want_line_capabilities(want)
-        self.handler.client_capabilities = caps
+        self.handler.set_client_capabilities(caps)
         self.set_ack_type(ack_type(caps))
         command, sha = self._split_proto_line(line)
 
@@ -509,7 +512,7 @@ class ReceivePackHandler(Handler):
         self._stateless_rpc = stateless_rpc
         self._advertise_refs = advertise_refs
 
-    def default_capabilities(self):
+    def capabilities(self):
         return ("report-status", "delete-refs")
 
     def handle(self):
@@ -517,12 +520,14 @@ class ReceivePackHandler(Handler):
 
         if self.advertise_refs or not self.stateless_rpc:
             if refs:
-                self.proto.write_pkt_line("%s %s\x00%s\n" % (refs[0][1], refs[0][0], self.capabilities()))
+                self.proto.write_pkt_line(
+                    "%s %s\x00%s\n" % (refs[0][1], refs[0][0],
+                                       self.capability_line()))
                 for i in range(1, len(refs)):
                     ref = refs[i]
                     self.proto.write_pkt_line("%s %s\n" % (ref[1], ref[0]))
             else:
-                self.proto.write_pkt_line("0000000000000000000000000000000000000000 capabilities^{} %s" % self.capabilities())
+                self.proto.write_pkt_line("0000000000000000000000000000000000000000 capabilities^{} %s" % self.capability_line())
 
             self.proto.write("0000")
             if self.advertise_refs:
@@ -535,7 +540,8 @@ class ReceivePackHandler(Handler):
         if ref is None:
             return
 
-        ref, client_capabilities = extract_capabilities(ref)
+        ref, caps = extract_capabilities(ref)
+        self.set_client_capabilities(caps)
 
         # client will now send us a list of (oldsha, newsha, ref)
         while ref:
@@ -547,7 +553,7 @@ class ReceivePackHandler(Handler):
 
         # when we have read all the pack from the client, send a status report
         # if the client asked for it
-        if 'report-status' in client_capabilities:
+        if self.has_capability('report-status'):
             for name, msg in status:
                 if name == 'unpack':
                     self.proto.write_pkt_line('unpack %s\n' % msg)

+ 27 - 20
dulwich/tests/test_server.py

@@ -28,6 +28,7 @@ from dulwich.errors import (
     )
 from dulwich.server import (
     UploadPackHandler,
+    Handler,
     ProtocolGraphWalker,
     SingleAckGraphWalkerImpl,
     MultiAckGraphWalkerImpl,
@@ -75,30 +76,36 @@ class TestProto(object):
             return None
 
 
-class UploadPackHandlerTestCase(TestCase):
+class HandlerTestCase(TestCase):
     def setUp(self):
-        self._handler = UploadPackHandler(None, None, None)
+        self._handler = Handler(None, None, None)
+        self._handler.capabilities = lambda: ('cap1', 'cap2', 'cap3')
 
-    def test_set_client_capabilities(self):
+    def assertSucceeds(self, func, *args, **kwargs):
         try:
-            self._handler.set_client_capabilities([])
+            func(*args, **kwargs)
         except GitProtocolError:
             self.fail()
 
-        try:
-            self._handler.set_client_capabilities([
-                'multi_ack', 'side-band-64k', 'thin-pack', 'ofs-delta'])
-        except GitProtocolError:
-            self.fail()
+    def test_capability_line(self):
+        self.assertEquals('cap1 cap2 cap3', self._handler.capability_line())
 
-    def test_set_client_capabilities_error(self):
-        self.assertRaises(GitProtocolError,
-                          self._handler.set_client_capabilities,
-                          ['weird_ack_level', 'ofs-delta'])
-        try:
-            self._handler.set_client_capabilities(['include-tag'])
-        except GitProtocolError:
-            self.fail()
+    def test_set_client_capabilities(self):
+        set_caps = self._handler.set_client_capabilities
+        self.assertSucceeds(set_caps, [])
+        self.assertSucceeds(set_caps, ['cap2'])
+        self.assertSucceeds(set_caps, ['cap1', 'cap2'])
+        # different order
+        self.assertSucceeds(set_caps, ['cap3', 'cap1', 'cap2'])
+        self.assertRaises(GitProtocolError, set_caps, ['capxxx', 'cap1'])
+
+    def test_has_capability(self):
+        self.assertRaises(GitProtocolError, self._handler.has_capability, 'cap')
+        caps = self._handler.capabilities()
+        self._handler.set_client_capabilities(caps)
+        for cap in caps:
+            self.assertTrue(self._handler.has_capability(cap))
+        self.assertFalse(self._handler.has_capability('capxxx'))
 
 
 class TestCommit(object):
@@ -119,7 +126,7 @@ class TestBackend(object):
         self.object_store = objects
 
 
-class TestHandler(object):
+class TestUploadPackHandler(Handler):
     def __init__(self, objects, proto):
         self.backend = TestBackend(objects)
         self.proto = proto
@@ -127,7 +134,7 @@ class TestHandler(object):
         self.advertise_refs = False
 
     def capabilities(self):
-        return 'multi_ack'
+        return ('multi_ack',)
 
 
 class ProtocolGraphWalkerTestCase(TestCase):
@@ -144,7 +151,7 @@ class ProtocolGraphWalkerTestCase(TestCase):
             FIVE: TestCommit(FIVE, [THREE], 555),
             }
         self._walker = ProtocolGraphWalker(
-            TestHandler(self._objects, TestProto()))
+            TestUploadPackHandler(self._objects, TestProto()))
 
     def test_is_satisfied_no_haves(self):
         self.assertFalse(self._walker._is_satisfied([], ONE, 0))