Selaa lähdekoodia

Don't retrieve all pack files when fetching.

Jelmer Vernooij 16 vuotta sitten
vanhempi
commit
aa912fdd9f
4 muutettua tiedostoa jossa 39 lisäystä ja 16 poistoa
  1. 9 7
      bin/dul-fetch-pack
  2. 25 8
      dulwich/client.py
  3. 2 1
      dulwich/pack.py
  4. 3 0
      dulwich/repo.py

+ 9 - 7
bin/dul-fetch-pack

@@ -27,18 +27,20 @@ client = TCPGitClient(host)
 all = True
 
 if all:
-	determine_wants = lambda x: x.values()
+    determine_wants = lambda x: x.values()
 else:
-	determinw_wants = lambda x: sys.argv[1:]
+    determine_wants = lambda x: sys.argv[1:]
 
 r = Repo(".")
 
 # FIXME: Will just fetch everything..
-graphwalker = SimpleFetchGraphWalker([], None)
+graphwalker = SimpleFetchGraphWalker(r.heads().values(), r.get_parents)
 
 f, commit = r.add_pack()
 try:
-	client.fetch_pack(path, determine_wants, graphwalker, f.write, sys.stdout.write)
-finally:
-	f.close()
-	commit()
+    client.fetch_pack(path, determine_wants, graphwalker, f.write, sys.stdout.write)
+    f.close()
+    commit()
+except:
+    f.close()
+    raise

+ 25 - 8
dulwich/client.py

@@ -16,6 +16,7 @@
 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
 # MA  02110-1301, USA.
 
+import select
 import socket
 
 
@@ -30,13 +31,24 @@ class SimpleFetchGraphWalker(object):
 
     def __init__(self, local_heads, get_parents):
         self.heads = set(local_heads)
+        self.get_parents = get_parents
+        self.parents = {}
 
     def ack(self, ref):
-        pass
+        if ref in self.heads:
+            self.heads.remove(ref)
+        if not ref in self.parents:
+            return
+        for p in self.parents[ref]:
+            self.ack(p)
 
     def next(self):
         if self.heads:
-            return self.heads.pop()
+            ret = self.heads.pop()
+            ps = self.get_parents(ret)
+            self.parents[ret] = ps
+            self.heads.update(ps)
+            return ret
         return None
 
 
@@ -86,7 +98,7 @@ class GitClient(object):
         self.write_pkt_line("%s %s" % (name, "".join(["%s\0" % a for a in args])))
 
     def capabilities(self):
-        return "4b multi_ack side-band-64k thin-pack ofs-delta"
+        return "multi_ack side-band-64k thin-pack ofs-delta"
 
     def read_refs(self):
         server_capabilities = None
@@ -132,20 +144,25 @@ class GitClient(object):
         self.write_pkt_line("want %s %s\n" % (wants[0], self.capabilities()))
         for want in wants[1:]:
             self.write_pkt_line("want %s\n" % want)
+        self.write_pkt_line(None)
         have = graph_walker.next()
         while have:
             self.write_pkt_line("have %s\n" % have)
             if len(select.select([self.fileno], [], [], 0)[0]) > 0:
                 pkt = self.read_pkt_line()
-                if pkt[:3] == "ACK":
-                    graph_walker.ack(pkt.split(" ")[1])
+                parts = pkt.rstrip("\n").split(" ")
+                if parts[0] == "ACK":
+                    graph_walker.ack(parts[1])
+                    assert parts[2] == "continue"
             have = graph_walker.next()
-        self.write_pkt_line(None)
         self.write_pkt_line("done\n")
         pkt = self.read_pkt_line()
-        while pkt != "NAK\n":
-            if pkt[:3] == "ACK":
+        while pkt:
+            parts = pkt.rstrip("\n").split(" ")
+            if parts[0] == "ACK":
                 graph_walker.ack(pkt.split(" ")[1])
+            if len(parts) < 3 or parts[2] != "continue":
+                break
             pkt = self.read_pkt_line()
         for pkt in self.read_pkt_seq():
             channel = ord(pkt[0])

+ 2 - 1
dulwich/pack.py

@@ -362,6 +362,7 @@ class PackData(object):
     self._filename = filename
     assert os.path.exists(filename), "%s is not a packfile" % filename
     self._size = os.path.getsize(filename)
+    assert self._size >= 12, "%s is too small for a packfile" % filename
     self._header_size = self._read_header()
 
   def _read_header(self):
@@ -774,5 +775,5 @@ def load_packs(path):
     if not os.path.exists(path):
         return
     for name in os.listdir(path):
-        if name.endswith(".pack"):
+        if name.startswith("pack-") and name.endswith(".pack"):
             yield Pack(os.path.join(path, name[:-len(".pack")]))

+ 3 - 0
dulwich/repo.py

@@ -150,6 +150,9 @@ class Repo(object):
   def get_object(self, sha):
     return self._get_object(sha, ShaFile)
 
+  def get_parents(self, sha):
+    return self.commit(sha).parents
+
   def commit(self, sha):
     return self._get_object(sha, Commit)