Răsfoiți Sursa

Support progress reporting in iterobjects.

Jelmer Vernooij 16 ani în urmă
părinte
comite
ab09036231
2 a modificat fișierele cu 22 adăugiri și 16 ștergeri
  1. 16 13
      dulwich/pack.py
  2. 6 3
      dulwich/protocol.py

+ 16 - 13
dulwich/pack.py

@@ -35,7 +35,7 @@ try:
 except ImportError:
     from misc import defaultdict
 
-from itertools import imap, izip
+from itertools import chain, imap, izip
 import mmap
 import os
 import struct
@@ -491,7 +491,7 @@ class PackData(object):
         ret = (type, apply_delta(base_text, delta))
         return ret
   
-    def iterobjects(self):
+    def iterobjects(self, progress=None):
         offset = self._header_size
         num = len(self)
         map, _ = simple_mmap(self._file, 0, self._size)
@@ -501,16 +501,19 @@ class PackData(object):
                 crc32 = zlib.crc32(map[offset:offset+total_size]) & 0xffffffff
                 yield offset, type, obj, crc32
                 offset += total_size
+                if progress:
+                    progress(i, num)
         finally:
             map.close()
   
-    def iterentries(self, ext_resolve_ref=None):
+    def iterentries(self, ext_resolve_ref=None, progress=None):
         found = {}
         postponed = defaultdict(list)
         class Postpone(Exception):
             """Raised to postpone delta resolving."""
           
         def get_ref_text(sha):
+            assert len(sha) == 20
             if sha in found:
                 return found[sha]
             if ext_resolve_ref:
@@ -519,9 +522,9 @@ class PackData(object):
                 except KeyError:
                     pass
             raise Postpone, (sha, )
-        todo = list(self.iterobjects())
-        while todo:
-            (offset, type, obj, crc32) = todo.pop(0)
+        extra = []
+        todo = chain(self.iterobjects(progress), extra)
+        for (offset, type, obj, crc32) in todo:
             assert isinstance(offset, int)
             assert isinstance(type, int)
             assert isinstance(obj, tuple) or isinstance(obj, str)
@@ -534,21 +537,21 @@ class PackData(object):
                 sha = shafile.sha().digest()
                 found[sha] = (type, obj)
                 yield sha, offset, crc32
-                todo += postponed.get(sha, [])
+                extra.extend(postponed.get(sha, []))
         if postponed:
             raise KeyError([sha_to_hex(h) for h in postponed.keys()])
   
-    def sorted_entries(self, resolve_ext_ref=None):
-        ret = list(self.iterentries(resolve_ext_ref))
+    def sorted_entries(self, resolve_ext_ref=None, progress=None):
+        ret = list(self.iterentries(resolve_ext_ref, progress=progress))
         ret.sort()
         return ret
   
-    def create_index_v1(self, filename, resolve_ext_ref=None):
-        entries = self.sorted_entries(resolve_ext_ref)
+    def create_index_v1(self, filename, resolve_ext_ref=None, progress=None):
+        entries = self.sorted_entries(resolve_ext_ref, progress=progress)
         write_pack_index_v1(filename, entries, self.calculate_checksum())
   
-    def create_index_v2(self, filename, resolve_ext_ref=None):
-        entries = self.sorted_entries(resolve_ext_ref)
+    def create_index_v2(self, filename, resolve_ext_ref=None, progress=None):
+        entries = self.sorted_entries(resolve_ext_ref, progress=progress)
         write_pack_index_v2(filename, entries, self.calculate_checksum())
   
     def get_stored_checksum(self):

+ 6 - 3
dulwich/protocol.py

@@ -64,7 +64,8 @@ class Protocol(object):
                 raise HangupException()
             size = int(sizestr, 16)
             if size == 0:
-                self.report_activity(4, 'read')
+                if self.report_activity:
+                    self.report_activity(4, 'read')
                 return None
             if self.report_activity:
                 self.report_activity(size, 'read')
@@ -87,10 +88,12 @@ class Protocol(object):
         try:
             if line is None:
                 self.write("0000")
-                self.report_activity(4, 'write')
+                if self.report_activity:
+                    self.report_activity(4, 'write')
             else:
                 self.write("%04x%s" % (len(line)+4, line))
-                self.report_activity(4+len(line), 'write')
+                if self.report_activity:
+                    self.report_activity(4+len(line), 'write')
         except socket.error, e:
             raise GitProtocolError(e)