Browse Source

Return symrefs from fetch_pack.

Jelmer Vernooij 7 years ago
parent
commit
453f80b481
5 changed files with 70 additions and 26 deletions
  1. 3 0
      NEWS
  2. 1 1
      docs/tutorial/remote.txt
  3. 34 11
      dulwich/client.py
  4. 8 8
      dulwich/tests/compat/test_client.py
  5. 24 6
      dulwich/tests/test_client.py

+ 3 - 0
NEWS

@@ -11,6 +11,9 @@
 
   * Add ``dulwich.porcelain.update_head``. (Jelmer Vernooij, #439)
 
+  * ``GitClient.fetch_pack`` now returns symrefs.
+    (Jelmer Vernooij, #485)
+
 0.18.2	2017-08-01
 
  TEST FIXES

+ 1 - 1
docs/tutorial/remote.txt

@@ -60,7 +60,7 @@ which we will write to a ``BytesIO`` object::
 
    >>> from io import BytesIO
    >>> f = BytesIO()
-   >>> remote_refs = client.fetch_pack(b"/", determine_wants,
+   >>> result = client.fetch_pack(b"/", determine_wants,
    ...    DummyGraphWalker(), pack_data=f.write)
 
 ``f`` will now contain a full pack file::

+ 34 - 11
dulwich/client.py

@@ -206,6 +206,28 @@ def read_pkt_refs(proto):
     return refs, set(server_capabilities)
 
 
+class FetchPackResult(object):
+
+    _FORWARDED_ATTRS = [
+            'clear', 'copy', 'fromkeys', 'get', 'has_key', 'items',
+            'iteritems', 'iterkeys', 'itervalues', 'keys', 'pop', 'popitem',
+            'setdefault', 'update', 'values', 'viewitems', 'viewkeys',
+            'viewvalues']
+
+    def __init__(self, refs, symrefs):
+        self.refs = refs
+        self.symrefs = symrefs
+
+    def __getattribute__(self, name):
+        if name in type(self)._FORWARDED_ATTRS:
+            import warnings
+            warnings.warn(
+                "Use FetchPackResult.refs instead.",
+                DeprecationWarning, stacklevel=2)
+            return getattr(self.refs, name)
+        return super(FetchPackResult, self).__getattribute__(name)
+
+
 # TODO(durin42): this doesn't correctly degrade if the server doesn't
 # support some capabilities. This should work properly with servers
 # that don't support multi_ack.
@@ -320,7 +342,7 @@ class GitClient(object):
         :param graph_walker: Object with next() and ack().
         :param pack_data: Callback called for each bit of data in the pack
         :param progress: Callback for progress reports (strings)
-        :return: Dictionary with all remote refs (not just those fetched)
+        :return: FetchPackResult object
         """
         raise NotImplementedError(self.fetch_pack)
 
@@ -660,7 +682,7 @@ class TraditionalGitClient(GitClient):
         :param graph_walker: Object with next() and ack().
         :param pack_data: Callback called for each bit of data in the pack
         :param progress: Callback for progress reports (strings)
-        :return: Dictionary with all remote refs (not just those fetched)
+        :return: FetchPackResult object
         """
         proto, can_read = self._connect(b'upload-pack', path)
         with proto:
@@ -671,7 +693,7 @@ class TraditionalGitClient(GitClient):
 
             if refs is None:
                 proto.write_pkt_line(None)
-                return refs
+                return FetchPackResult(refs, symrefs)
 
             try:
                 wants = determine_wants(refs)
@@ -682,13 +704,13 @@ class TraditionalGitClient(GitClient):
                 wants = [cid for cid in wants if cid != ZERO_SHA]
             if not wants:
                 proto.write_pkt_line(None)
-                return refs
+                return FetchPackResult(refs, symrefs)
             self._handle_upload_pack_head(
                 proto, negotiated_capabilities, graph_walker, wants, can_read)
             self._handle_upload_pack_tail(
                 proto, negotiated_capabilities, graph_walker, pack_data,
                 progress)
-            return refs
+            return FetchPackResult(refs, symrefs)
 
     def get_refs(self, path):
         """Retrieve the current refs from a git smart server."""
@@ -967,18 +989,19 @@ class LocalGitClient(GitClient):
         :param graph_walker: Object with next() and ack().
         :param pack_data: Callback called for each bit of data in the pack
         :param progress: Callback for progress reports (strings)
-        :return: Dictionary with all remote refs (not just those fetched)
+        :return: FetchPackResult object
         """
         with self._open_repo(path) as r:
             objects_iter = r.fetch_objects(
                 determine_wants, graph_walker, progress)
+            symrefs = r.refs.get_symrefs()
 
             # Did the process short-circuit (e.g. in a stateless RPC call)?
             # Note that the client still expects a 0-object pack in most cases.
             if objects_iter is None:
-                return
+                return FetchPackResult(None, symrefs)
             write_pack_objects(ProtocolFile(None, pack_data), objects_iter)
-            return r.get_refs()
+            return FetchPackResult(r.get_refs(), symrefs)
 
     def get_refs(self, path):
         """Retrieve the current refs from a git smart server."""
@@ -1288,7 +1311,7 @@ class HttpGitClient(GitClient):
         :param graph_walker: Object with next() and ack().
         :param pack_data: Callback called for each bit of data in the pack
         :param progress: Callback for progress reports (strings)
-        :return: Dictionary with all remote refs (not just those fetched)
+        :return: FetchPackResult object
         """
         url = self._get_url(path)
         refs, server_capabilities = self._discover_references(
@@ -1300,7 +1323,7 @@ class HttpGitClient(GitClient):
         if wants is not None:
             wants = [cid for cid in wants if cid != ZERO_SHA]
         if not wants:
-            return refs
+            return FetchPackResult(refs, symrefs)
         if self.dumb:
             raise NotImplementedError(self.send_pack)
         req_data = BytesIO()
@@ -1315,7 +1338,7 @@ class HttpGitClient(GitClient):
             self._handle_upload_pack_tail(
                 resp_proto, negotiated_capabilities, graph_walker, pack_data,
                 progress)
-            return refs
+            return FetchPackResult(refs, symrefs)
         finally:
             resp.close()
 

+ 8 - 8
dulwich/tests/compat/test_client.py

@@ -205,8 +205,8 @@ class DulwichClientTestBase(object):
     def test_fetch_pack(self):
         c = self._client()
         with repo.Repo(os.path.join(self.gitroot, 'dest')) as dest:
-            refs = c.fetch(self._build_path('/server_new.export'), dest)
-            for r in refs.items():
+            result = c.fetch(self._build_path('/server_new.export'), dest)
+            for r in result.refs.items():
                 dest.refs.set_if_equals(r[0], None, r[1])
             self.assertDestEqualsSrc()
 
@@ -217,8 +217,8 @@ class DulwichClientTestBase(object):
         c = self._client()
         repo_dir = os.path.join(self.gitroot, 'server_new.export')
         with repo.Repo(repo_dir) as dest:
-            refs = c.fetch(self._build_path('/dest'), dest)
-            for r in refs.items():
+            result = c.fetch(self._build_path('/dest'), dest)
+            for r in result.refs.items():
                 dest.refs.set_if_equals(r[0], None, r[1])
             self.assertDestEqualsSrc()
 
@@ -226,8 +226,8 @@ class DulwichClientTestBase(object):
         c = self._client()
         c._fetch_capabilities.remove(b'side-band-64k')
         with repo.Repo(os.path.join(self.gitroot, 'dest')) as dest:
-            refs = c.fetch(self._build_path('/server_new.export'), dest)
-            for r in refs.items():
+            result = c.fetch(self._build_path('/server_new.export'), dest)
+            for r in result.refs.items():
                 dest.refs.set_if_equals(r[0], None, r[1])
             self.assertDestEqualsSrc()
 
@@ -236,10 +236,10 @@ class DulwichClientTestBase(object):
         # be ignored
         c = self._client()
         with repo.Repo(os.path.join(self.gitroot, 'dest')) as dest:
-            refs = c.fetch(
+            result = c.fetch(
                 self._build_path('/server_new.export'), dest,
                 lambda refs: [protocol.ZERO_SHA])
-            for r in refs.items():
+            for r in result.refs.items():
                 dest.refs.set_if_equals(r[0], None, r[1])
 
     def test_send_remove_branch(self):

+ 24 - 6
dulwich/tests/test_client.py

@@ -123,7 +123,8 @@ class GitClientTests(TestCase):
             self.assertIs(heads, None)
             return []
         ret = self.client.fetch_pack(b'/', check_heads, None, None)
-        self.assertIs(None, ret)
+        self.assertIs(None, ret.refs)
+        self.assertEqual({}, ret.symrefs)
 
     def test_fetch_pack_ignores_magic_ref(self):
         self.rin.write(
@@ -138,17 +139,23 @@ class GitClientTests(TestCase):
             self.assertEquals({}, heads)
             return []
         ret = self.client.fetch_pack(b'bla', check_heads, None, None, None)
-        self.assertIs(None, ret)
+        self.assertIs(None, ret.refs)
+        self.assertEqual({}, ret.symrefs)
         self.assertEqual(self.rout.getvalue(), b'0000')
 
     def test_fetch_pack_none(self):
         self.rin.write(
-            b'008855dcc6bf963f922e1ed5c4bbaaefcfacef57b1d7 HEAD.multi_ack '
+            b'008855dcc6bf963f922e1ed5c4bbaaefcfacef57b1d7 HEAD\x00multi_ack '
             b'thin-pack side-band side-band-64k ofs-delta shallow no-progress '
             b'include-tag\n'
             b'0000')
         self.rin.seek(0)
-        self.client.fetch_pack(b'bla', lambda heads: [], None, None, None)
+        ret = self.client.fetch_pack(
+                b'bla', lambda heads: [], None, None, None)
+        self.assertEqual(
+                {b'HEAD': b'55dcc6bf963f922e1ed5c4bbaaefcfacef57b1d7'},
+                ret.refs)
+        self.assertEqual({}, ret.symrefs)
         self.assertEqual(self.rout.getvalue(), b'0000')
 
     def test_send_pack_no_sideband64k_with_update_ref_error(self):
@@ -745,7 +752,10 @@ class LocalGitClientTests(TestCase):
             b'refs/tags/mytag': b'28237f4dc30d0d462658d6b937b08a0f0b6ef55a',
             b'refs/tags/mytag-packed':
                 b'b0931cadc54336e78a1d980420e3268903b57a50'
-            }, ret)
+            }, ret.refs)
+        self.assertEqual(
+                {b'HEAD': b'refs/heads/master'},
+                ret.symrefs)
         self.assertEqual(
                 b"PACK\x00\x00\x00\x02\x00\x00\x00\x00\x02\x9d\x08"
                 b"\x82;\xd8\xa8\xea\xb5\x10\xadj\xc7\\\x82<\xfd>\xd3\x1e",
@@ -757,10 +767,18 @@ class LocalGitClientTests(TestCase):
         self.addCleanup(tear_down_repo, s)
         out = BytesIO()
         walker = MemoryRepo().get_graph_walker()
-        c.fetch_pack(
+        ret = c.fetch_pack(
             s.path,
             lambda heads: [b"a90fa2d900a17e99b433217e988c4eb4a2e9a097"],
             graph_walker=walker, pack_data=out.write)
+        self.assertEqual({b'HEAD': b'refs/heads/master'}, ret.symrefs)
+        self.assertEqual({
+            b'HEAD': b'a90fa2d900a17e99b433217e988c4eb4a2e9a097',
+            b'refs/heads/master': b'a90fa2d900a17e99b433217e988c4eb4a2e9a097',
+            b'refs/tags/mytag': b'28237f4dc30d0d462658d6b937b08a0f0b6ef55a',
+            b'refs/tags/mytag-packed':
+            b'b0931cadc54336e78a1d980420e3268903b57a50'
+            }, ret.refs)
         # Hardcoding is not ideal, but we'll fix that some other day..
         self.assertTrue(out.getvalue().startswith(
                 b'PACK\x00\x00\x00\x02\x00\x00\x00\x07'))