Răsfoiți Sursa

Test for ProtocolGraphWalker if the repository is empty.

Damien Tournoud 11 ani în urmă
părinte
comite
5587baddcd
1 a modificat fișierele cu 38 adăugiri și 9 ștergeri
  1. 38 9
      dulwich/tests/test_server.py

+ 38 - 9
dulwich/tests/test_server.py

@@ -26,6 +26,7 @@ from dulwich.errors import (
     GitProtocolError,
     NotGitRepository,
     UnexpectedCommandError,
+    HangupException,
     )
 from dulwich.objects import (
     Commit,
@@ -78,13 +79,18 @@ class TestProto(object):
         self._received = {0: [], 1: [], 2: [], 3: []}
 
     def set_output(self, output_lines):
-        self._output = ['%s\n' % line.rstrip() for line in output_lines]
+        self._output = output_lines
 
     def read_pkt_line(self):
         if self._output:
-            return self._output.pop(0)
+            data = self._output.pop(0)
+            if data is not None:
+                return '%s\n' % data.rstrip()
+            else:
+                # flush-pkt ('0000').
+                return None
         else:
-            return None
+            raise HangupException()
 
     def write_sideband(self, band, data):
         self._received[band].append(data)
@@ -308,6 +314,27 @@ class ReceivePackHandlerTestCase(TestCase):
         self.assertEqual(status[1][1], 'ok')
 
 
+class ProtocolGraphWalkerEmptyTestCase(TestCase):
+    def setUp(self):
+        super(ProtocolGraphWalkerEmptyTestCase, self).setUp()
+        self._repo = MemoryRepo.init_bare([], {})
+        backend = DictBackend({'/': self._repo})
+        self._walker = ProtocolGraphWalker(
+            TestUploadPackHandler(backend, ['/', 'host=lolcats'], TestProto()),
+            self._repo.object_store, self._repo.get_peeled)
+
+    def test_empty_repository(self):
+        # The server should wait for a flush packet.
+        self._walker.proto.set_output([])
+        self.assertRaises(HangupException, self._walker.determine_wants, {})
+        self.assertEqual(None, self._walker.proto.get_received_line())
+
+        self._walker.proto.set_output([None])
+        self.assertEqual([], self._walker.determine_wants({}))
+        self.assertEqual(None, self._walker.proto.get_received_line())
+
+
+
 class ProtocolGraphWalkerTestCase(TestCase):
 
     def setUp(self):
@@ -371,12 +398,14 @@ class ProtocolGraphWalkerTestCase(TestCase):
         self.assertEqual((None, None), _split_proto_line('', allowed))
 
     def test_determine_wants(self):
+        self._walker.proto.set_output([None])
         self.assertEqual([], self._walker.determine_wants({}))
         self.assertEqual(None, self._walker.proto.get_received_line())
 
         self._walker.proto.set_output([
           'want %s multi_ack' % ONE,
           'want %s' % TWO,
+          None,
           ])
         heads = {
           'refs/heads/ref1': ONE,
@@ -390,20 +419,20 @@ class ProtocolGraphWalkerTestCase(TestCase):
         self.assertEqual([], self._walker.determine_wants(heads))
         self._walker.advertise_refs = False
 
-        self._walker.proto.set_output(['want %s multi_ack' % FOUR])
+        self._walker.proto.set_output(['want %s multi_ack' % FOUR, None])
         self.assertRaises(GitProtocolError, self._walker.determine_wants, heads)
 
-        self._walker.proto.set_output([])
+        self._walker.proto.set_output([None])
         self.assertEqual([], self._walker.determine_wants(heads))
 
-        self._walker.proto.set_output(['want %s multi_ack' % ONE, 'foo'])
+        self._walker.proto.set_output(['want %s multi_ack' % ONE, 'foo', None])
         self.assertRaises(GitProtocolError, self._walker.determine_wants, heads)
 
-        self._walker.proto.set_output(['want %s multi_ack' % FOUR])
+        self._walker.proto.set_output(['want %s multi_ack' % FOUR, None])
         self.assertRaises(GitProtocolError, self._walker.determine_wants, heads)
 
     def test_determine_wants_advertisement(self):
-        self._walker.proto.set_output([])
+        self._walker.proto.set_output([None])
         # advertise branch tips plus tag
         heads = {
           'refs/heads/ref4': FOUR,
@@ -439,7 +468,7 @@ class ProtocolGraphWalkerTestCase(TestCase):
     # TODO: test commit time cutoff
 
     def _handle_shallow_request(self, lines, heads):
-        self._walker.proto.set_output(lines)
+        self._walker.proto.set_output(lines + [None])
         self._walker._handle_shallow_request(heads)
 
     def assertReceived(self, expected):