Browse Source

Avoid seek when loading pack index files.

Jelmer Vernooij 15 years ago
parent
commit
34fae6292f
1 changed files with 34 additions and 23 deletions
  1. 34 23
      dulwich/pack.py

+ 34 - 23
dulwich/pack.py

@@ -36,6 +36,7 @@ except ImportError:
     from misc import defaultdict
 
 import difflib
+import errno
 from itertools import (
     chain,
     imap,
@@ -124,22 +125,41 @@ def load_pack_index(path):
     return load_pack_index_file(path, f)
 
 
+def _load_file_contents(f, size=None):
+    fileno = getattr(f, 'fileno', None)
+    # Attempt to use mmap if possible
+    if fileno is not None:
+        fd = f.fileno()
+        if size is None:
+            size = os.fstat(fd).st_size
+        try:
+            contents = mmap.mmap(fd, size, access=mmap.ACCESS_READ)
+        except mmap.error:
+            # Perhaps a socket?
+            pass
+        else:
+            return contents, size
+    contents = f.read()
+    size = len(contents)
+    return contents, size
+
+
 def load_pack_index_file(path, f):
     """Load an index file from a file-like object.
 
     :param path: Path for the index file
     :param f: File-like object
     """
-    if f.read(4) == '\377tOc':
-        version = struct.unpack(">L", f.read(4))[0]
+    contents, size = _load_file_contents(f)
+    if contents[:4] == '\377tOc':
+        version = struct.unpack(">L", contents[4:8])[0]
         if version == 2:
-            f.seek(0)
-            return PackIndex2(path, file=f)
+            return PackIndex2(path, file=f, contents=contents,
+                size=size)
         else:
             raise KeyError("Unknown pack index format %d" % version)
     else:
-        f.seek(0)
-        return PackIndex1(path, file=f)
+        return PackIndex1(path, file=f, contents=contents, size=size)
 
 
 def bisect_find_sha(start, end, sha, unpack_name):
@@ -179,7 +199,7 @@ class PackIndex(object):
     the start and end offset and then bisect in to find if the value is present.
     """
   
-    def __init__(self, filename, file=None, size=None):
+    def __init__(self, filename, file=None, contents=None, size=None):
         """Create a pack index object.
     
         Provide it with the name of the index file to consider, and it will map
@@ -192,19 +212,10 @@ class PackIndex(object):
             self._file = GitFile(filename, 'rb')
         else:
             self._file = file
-        fileno = getattr(self._file, 'fileno', None)
-        if fileno is not None:
-            fd = self._file.fileno()
-            if size is None:
-                self._size = os.fstat(fd).st_size
-            else:
-                self._size = size
-            self._contents = mmap.mmap(fd, self._size,
-                access=mmap.ACCESS_READ)
+        if contents is None:
+            self._contents, self._size = _load_file_contents(file, size)
         else:
-            self._file.seek(0)
-            self._contents = self._file.read()
-            self._size = len(self._contents)
+            self._contents, self._size = (contents, size)
   
     def __eq__(self, other):
         if not isinstance(other, PackIndex):
@@ -338,8 +349,8 @@ class PackIndex(object):
 class PackIndex1(PackIndex):
     """Version 1 Pack Index."""
 
-    def __init__(self, filename, file=None, size=None):
-        PackIndex.__init__(self, filename, file, size)
+    def __init__(self, filename, file=None, contents=None, size=None):
+        PackIndex.__init__(self, filename, file, contents, size)
         self.version = 1
         self._fan_out_table = self._read_fan_out_table(0)
 
@@ -364,8 +375,8 @@ class PackIndex1(PackIndex):
 class PackIndex2(PackIndex):
     """Version 2 Pack Index."""
 
-    def __init__(self, filename, file=None, size=None):
-        PackIndex.__init__(self, filename, file, size)
+    def __init__(self, filename, file=None, contents=None, size=None):
+        PackIndex.__init__(self, filename, file, contents, size)
         assert self._contents[:4] == '\377tOc', "Not a v2 pack index file"
         (self.version, ) = unpack_from(">L", self._contents, 4)
         assert self.version == 2, "Version was %d" % self.version