Selaa lähdekoodia

Add iterating object finder.

Jelmer Vernooij 16 vuotta sitten
vanhempi
commit
31ee91c443
2 muutettua tiedostoa jossa 50 lisäystä ja 38 poistoa
  1. 2 2
      dulwich/object_store.py
  2. 48 36
      dulwich/repo.py

+ 2 - 2
dulwich/object_store.py

@@ -208,9 +208,9 @@ class ObjectIterator(object):
 
 class ObjectStoreIterator(ObjectIterator):
 
-    def __init__(self, store, shas):
+    def __init__(self, store, sha_iter):
         self.store = store
-        self.shas = shas
+        self.shas = list(sha_iter)
 
     def __iter__(self):
         return ((self.store[sha], path) for sha, path in self.shas)

+ 48 - 36
dulwich/repo.py

@@ -66,6 +66,51 @@ class Tags(object):
             yield k, self[k]
 
 
+class MissingObjectFinder(object):
+
+    def __init__(self, object_store, wants, graph_walker, progress=None):
+        self.sha_done = set()
+        self.objects_to_send = set([(w, None) for w in wants])
+        self.object_store = object_store
+        if progress is None:
+            self.progress = lambda x: None
+        else:
+            self.progress = progress
+        ref = graph_walker.next()
+        while ref:
+            if ref in self.object_store:
+                graph_walker.ack(ref)
+            ref = graph_walker.next()
+
+    def add_todo(self, entries):
+        self.objects_to_send.update([e for e in entries if not e in self.sha_done])
+
+    def parse_tree(self, tree):
+        self.add_todo([(sha, name) for (mode, name, sha) in tree.entries()])
+
+    def parse_commit(self, commit):
+        self.add_todo([(commit.tree, "")])
+        self.add_todo([(p, None) for p in commit.parents])
+
+    def parse_tag(self, tag):
+        self.add_todo([(tag.object[1], None)])
+
+    def next(self):
+        if not self.objects_to_send:
+            return None
+        (sha, name) = self.objects_to_send.pop()
+        o = self.object_store[sha]
+        if isinstance(o, Commit):
+            self.parse_commit(o)
+        elif isinstance(o, Tree):
+            self.parse_tree(o)
+        elif isinstance(o, Tag):
+            self.parse_tag(o)
+        self.sha_done.add((sha, name))
+        self.progress("counting objects: %d\r" % len(self.sha_done))
+        return (sha, name)
+
+
 class Repo(object):
 
     ref_locs = ['', 'refs', 'refs/tags', 'refs/heads', 'refs/remotes']
@@ -87,7 +132,7 @@ class Repo(object):
         return self._controldir
 
     def find_missing_objects(self, determine_wants, graph_walker, progress):
-        """Fetch the missing objects required for a set of revisions.
+        """Find the missing objects required for a set of revisions.
 
         :param determine_wants: Function that takes a dictionary with heads 
             and returns the list of heads to fetch.
@@ -98,41 +143,8 @@ class Repo(object):
             updated progress strings.
         """
         wants = determine_wants(self.get_refs())
-        objects_to_send = set(wants)
-        sha_done = set()
-
-        def parse_tree(tree, sha_done):
-            for mode, name, sha in tree.entries():
-                if (sha, name) in sha_done:
-                    continue
-                if mode & stat.S_IFDIR:
-                    parse_tree(self.tree(sha), sha_done)
-                sha_done.add((sha, name))
-
-        def parse_commit(commit, sha_done):
-            treesha = c.tree
-            if c.tree not in sha_done:
-                parse_tree(self.tree(c.tree), sha_done)
-                sha_done.add((c.tree, None))
-
-        ref = graph_walker.next()
-        while ref:
-            if ref in self.object_store:
-                graph_walker.ack(ref)
-            ref = graph_walker.next()
-        while objects_to_send:
-            sha = objects_to_send.pop()
-            if (sha, None) in sha_done:
-                continue
-    
-            c = self.object_store[sha]
-            if isinstance(c, Commit):
-                parse_commit(c, sha_done)
-                objects_to_send.update([p for p in c.parents if not p in sha_done])
-            sha_done.add((sha, None))
-    
-            progress("counting objects: %d\r" % len(sha_done))
-        return sha_done
+        return iter(MissingObjectFinder(self.object_store, wants, graph_walker, 
+                progress).next, None)
 
     def fetch_objects(self, determine_wants, graph_walker, progress):
         """Fetch the missing objects required for a set of revisions.