فهرست منبع

Merge upstream

John Carr 16 سال پیش
والد
کامیت
9c52705348
6فایلهای تغییر یافته به همراه112 افزوده شده و 54 حذف شده
  1. 4 4
      bin/dul-fetch-pack
  2. 1 1
      dulwich/client.py
  3. 15 8
      dulwich/pack.py
  4. 2 1
      dulwich/protocol.py
  5. 89 38
      dulwich/repo.py
  6. 1 2
      dulwich/tests/test_repository.py

+ 4 - 4
bin/dul-fetch-pack

@@ -34,20 +34,20 @@ if not ":" in args[0]:
 	print "Usage: dul-fetch-pack host:path"
 	sys.exit(1)
 
-(host, path) = args[0].split(":", 1)
+(host, path) = args.pop(0).split(":", 1)
 client = TCPGitClient(host)
 
 if "--all" in opts:
-    determine_wants = lambda x: x.values()
+    determine_wants = lambda x: [y for y in x.values() if not y in r.object_store]
 else:
-    determine_wants = lambda x: sys.argv[1:]
+    determine_wants = lambda x: [y for y in args if not y in r.object_store]
 
 r = Repo(".")
 
 # FIXME: Will just fetch everything..
 graphwalker = SimpleFetchGraphWalker(r.heads().values(), r.get_parents)
 
-f, commit = r.add_pack()
+f, commit = r.object_store.add_pack()
 try:
     client.fetch_pack(path, determine_wants, graphwalker, f.write, sys.stdout.write)
     f.close()

+ 1 - 1
dulwich/client.py

@@ -138,7 +138,7 @@ class GitClient(object):
 class TCPGitClient(GitClient):
 
     def __init__(self, host, port=TCP_GIT_PORT):
-        self._socket = socket.socket()
+        self._socket = socket.socket(type=socket.SOCK_STREAM)
         self._socket.connect((host, port))
         self.rfile = self._socket.makefile('rb', -1)
         self.wfile = self._socket.makefile('wb', 0)

+ 15 - 8
dulwich/pack.py

@@ -401,7 +401,7 @@ class PackData(object):
         offset += total_size
     f.close()
 
-  def iterentries(self):
+  def iterentries(self, ext_resolve_ref=None):
     found = {}
     at = {}
     postponed = defaultdict(list)
@@ -411,6 +411,11 @@ class PackData(object):
     def get_ref_text(sha):
         if sha in found:
             return found[sha]
+        if ext_resolve_ref:
+            try:
+                return ext_resolve_ref(sha)
+            except KeyError:
+                pass
         raise Postpone, (sha, )
     todo = list(self.iterobjects())
     while todo:
@@ -433,8 +438,8 @@ class PackData(object):
     if postponed:
         raise KeyError([sha_to_hex(h) for h in postponed.keys()])
 
-  def sorted_entries(self):
-    ret = list(self.iterentries())
+  def sorted_entries(self, resolve_ext_ref=None):
+    ret = list(self.iterentries(resolve_ext_ref))
     ret.sort()
     return ret
 
@@ -684,7 +689,7 @@ def write_pack_index_v2(filename, entries, pack_checksum):
     """
     f = open(filename, 'w')
     f = SHA1Writer(f)
-    f.write('\377tOc')
+    f.write('\377tOc') # Magic!
     f.write(struct.pack(">L", 2))
     fan_out_table = defaultdict(lambda: 0)
     for (name, offset, entry_checksum) in entries:
@@ -761,26 +766,28 @@ class Pack(object):
         """Check whether this pack contains a particular SHA1."""
         return (self.idx.object_index(sha1) is not None)
 
-    def _get_text(self, sha1):
+    def get_raw(self, sha1, resolve_ref=None):
+        if resolve_ref is None:
+            resolve_ref = self.get_raw
         offset = self.idx.object_index(sha1)
         if offset is None:
             raise KeyError(sha1)
 
         type, obj = self.data.get_object_at(offset)
         assert isinstance(offset, int)
-        return resolve_object(offset, type, obj, self._get_text, 
+        return resolve_object(offset, type, obj, resolve_ref,
             self.data.get_object_at)
 
     def __getitem__(self, sha1):
         """Retrieve the specified SHA1."""
-        type, uncomp = self._get_text(sha1)
+        type, uncomp = self.get_raw(sha1)
         return ShaFile.from_raw_string(type, uncomp)
 
     def iterobjects(self):
         for offset, type, obj in self.data.iterobjects():
             assert isinstance(offset, int)
             yield ShaFile.from_raw_string(
-                    *resolve_object(offset, type, obj, self._get_text, 
+                    *resolve_object(offset, type, obj, self.get_raw, 
                 self.data.get_object_at))
 
 

+ 2 - 1
dulwich/protocol.py

@@ -35,6 +35,7 @@ class ProtocolFile(object):
     def close(self):
         pass
 
+
 class Protocol(object):
 
     def __init__(self, read, write):
@@ -95,7 +96,7 @@ class Protocol(object):
         :param cmd: The remote service to access
         :param args: List of arguments to send to remove service
         """
-        self.proto.write_pkt_line("%s %s" % (name, "".join(["%s\0" % a for a in args])))
+        self.write_pkt_line("%s %s" % (cmd, "".join(["%s\0" % a for a in args])))
 
     def read_cmd(self):
         """

+ 89 - 38
dulwich/repo.py

@@ -20,7 +20,7 @@
 import os
 
 from commit import Commit
-from errors import MissingCommitError
+from errors import MissingCommitError, NotBlobError, NotTreeError, NotCommitError
 from objects import (ShaFile,
                      Commit,
                      Tree,
@@ -55,7 +55,7 @@ class Repo(object):
       self._basedir = root
     self.path = controldir
     self.tags = [Tag(name, ref) for name, ref in self.get_tags().items()]
-    self._packs = None
+    self._object_store = None
 
   def basedir(self):
     return self._basedir
@@ -63,29 +63,15 @@ class Repo(object):
   def object_dir(self):
     return os.path.join(self.basedir(), OBJECTDIR)
 
+  @property
+  def object_store(self):
+    if self._object_store is None:
+        self._object_store = ObjectStore(self.object_dir())
+    return self._object_store
+
   def pack_dir(self):
     return os.path.join(self.object_dir(), PACKDIR)
 
-  def add_pack(self):
-    fd, path = tempfile.mkstemp(dir=self.pack_dir(), suffix=".pack")
-    f = os.fdopen(fd, 'w')
-    def commit():
-       if os.path.getsize(path) > 0:
-           self._move_in_pack(path)
-    return f, commit
-
-  def _move_in_pack(self, path):
-    p = PackData(path)
-    entries = p.sorted_entries()
-    basename = os.path.join(self.pack_dir(), "pack-%s" % iter_sha1(entry[0] for entry in entries))
-    write_pack_index_v2(basename+".idx", entries, p.calculate_checksum())
-    os.rename(path, basename + ".pack")
-
-  def _get_packs(self):
-    if self._packs is None:
-        self._packs = list(load_packs(self.pack_dir()))
-    return self._packs
-
   def _get_ref(self, file):
     f = open(file, 'rb')
     try:
@@ -134,22 +120,20 @@ class Repo(object):
     return self.ref('HEAD')
 
   def _get_object(self, sha, cls):
-    assert len(sha) == 40, "Incorrect length sha: %s" % str(sha)
-    dir = sha[:2]
-    file = sha[2:]
-    # Check from object dir
-    path = os.path.join(self.object_dir(), dir, file)
-    if os.path.exists(path):
-      return cls.from_file(path)
-    # Check from packs
-    for pack in self._get_packs():
-        if sha in pack:
-            return pack[sha]
-    # Should this raise instead?
-    return None
+    ret = self.get_object(sha)
+    if ret._type != cls._type:
+        if cls is Commit:
+            raise NotCommitError(ret)
+        elif cls is Blob:
+            raise NotBlobError(ret)
+        elif cls is Tree:
+            raise NotTreeError(ret)
+        else:
+            raise Exception("Type invalid: %r != %r" % (ret._type, cls._type))
+    return ret
 
   def get_object(self, sha):
-    return self._get_object(sha, ShaFile)
+    return self.object_store[sha]
 
   def get_parents(self, sha):
     return self.commit(sha).parents
@@ -180,8 +164,9 @@ class Repo(object):
     history = []
     while pending_commits != []:
       head = pending_commits.pop(0)
-      commit = self.commit(head)
-      if commit is None:
+      try:
+          commit = self.commit(head)
+      except KeyError:
         raise MissingCommitError(head)
       if commit in history:
         continue
@@ -214,3 +199,69 @@ class Repo(object):
 
   create = init_bare
 
+
+class ObjectStore(object):
+
+    def __init__(self, path):
+        self.path = path
+        self._packs = None
+
+    def pack_dir(self):
+        return os.path.join(self.path, PACKDIR)
+
+    def __contains__(self, sha):
+        # TODO: This can be more efficient
+        try:
+            self[sha]
+            return True
+        except KeyError:
+            return False
+
+    @property
+    def packs(self):
+        if self._packs is None:
+            self._packs = list(load_packs(self.pack_dir()))
+        return self._packs
+
+    def _get_shafile(self, sha):
+        dir = sha[:2]
+        file = sha[2:]
+        # Check from object dir
+        path = os.path.join(self.path, dir, file)
+        if os.path.exists(path):
+          return ShaFile.from_file(path)
+        return None
+
+    def get_raw(self, sha):
+        for pack in self.packs:
+            if sha in pack:
+                return pack.get_raw(sha, self.get_raw)
+        # FIXME: Are pack deltas ever against on-disk shafiles ?
+        ret = self._get_shafile(sha)
+        if ret is not None:
+            return ret.as_raw_string()
+        raise KeyError(sha)
+
+    def __getitem__(self, sha):
+        assert len(sha) == 40, "Incorrect length sha: %s" % str(sha)
+        ret = self._get_shafile(sha)
+        if ret is not None:
+            return ret
+        # Check from packs
+        type, uncomp = self.get_raw(sha)
+        return ShaFile.from_raw_string(type, uncomp)
+
+    def move_in_pack(self, path):
+        p = PackData(path)
+        entries = p.sorted_entries(self.get_raw)
+        basename = os.path.join(self.pack_dir(), "pack-%s" % iter_sha1(entry[0] for entry in entries))
+        write_pack_index_v2(basename+".idx", entries, p.calculate_checksum())
+        os.rename(path, basename + ".pack")
+
+    def add_pack(self):
+        fd, path = tempfile.mkstemp(dir=self.pack_dir(), suffix=".pack")
+        f = os.fdopen(fd, 'w')
+        def commit():
+            if os.path.getsize(path) > 0:
+                self.move_in_pack(path)
+        return f, commit

+ 1 - 2
dulwich/tests/test_repository.py

@@ -52,8 +52,7 @@ class RepositoryTests(unittest.TestCase):
 
   def test_get_object_non_existant(self):
     r = self.open_repo('a')
-    obj = r.get_object(missing_sha)
-    self.assertEqual(obj, None)
+    self.assertRaises(KeyError, r.get_object, missing_sha)
 
   def test_commit(self):
     r = self.open_repo('a')