Kaynağa Gözat

Merge upstream

John Carr 16 yıl önce
ebeveyn
işleme
38939454d1
6 değiştirilmiş dosya ile 290 ekleme ve 14 silme
  1. 46 0
      bin/dul-fetch-pack
  2. 1 0
      bin/dumppack
  3. 185 0
      dulwich/client.py
  4. 34 13
      dulwich/pack.py
  5. 19 1
      dulwich/repo.py
  6. 5 0
      dulwich/tests/test_pack.py

+ 46 - 0
bin/dul-fetch-pack

@@ -0,0 +1,46 @@
+#!/usr/bin/python
+# dul-fetch-pack - Simple git smart server client
+# Copyright (C) 2008 Jelmer Vernooij <jelmer@samba.org>
+# 
+# This program is free software; you can redistribute it and/or
+# modify it under the terms of the GNU General Public License
+# as published by the Free Software Foundation; version 2
+# or (at your option) a later version of the License.
+# 
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+# 
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
+# MA  02110-1301, USA.
+
+from dulwich.client import TCPGitClient, SimpleFetchGraphWalker
+from dulwich.repo import Repo
+import sys
+
+(host, path) = sys.argv[1].split(":", 1)
+client = TCPGitClient(host)
+
+all = True
+
+if all:
+    determine_wants = lambda x: x.values()
+else:
+    determine_wants = lambda x: sys.argv[1:]
+
+r = Repo(".")
+
+# FIXME: Will just fetch everything..
+graphwalker = SimpleFetchGraphWalker(r.heads().values(), r.get_parents)
+
+f, commit = r.add_pack()
+try:
+    client.fetch_pack(path, determine_wants, graphwalker, f.write, sys.stdout.write)
+    f.close()
+    commit()
+except:
+    f.close()
+    raise

+ 1 - 0
bin/dumppack

@@ -24,6 +24,7 @@ import sys
 
 basename = sys.argv[1]
 x = Pack(basename)
+print "Object names checksum: %s" % x.name()
 print "Checksum: %s" % sha_to_hex(x.get_stored_checksum())
 if not x.check():
     print "CHECKSUM DOES NOT MATCH"

+ 185 - 0
dulwich/client.py

@@ -0,0 +1,185 @@
+# client.py -- Implementation of the client side git protocols
+# Copyright (C) 2008 Jelmer Vernooij <jelmer@samba.org>
+#
+# This program is free software; you can redistribute it and/or
+# modify it under the terms of the GNU General Public License
+# as published by the Free Software Foundation; version 2
+# of the License.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
+# MA  02110-1301, USA.
+
+import select
+import socket
+
+
+def extract_capabilities(text):
+    if not "\0" in text:
+        return text
+    capabilities = text.split("\0")
+    return (capabilities[0], capabilities[1:])
+
+
+class SimpleFetchGraphWalker(object):
+
+    def __init__(self, local_heads, get_parents):
+        self.heads = set(local_heads)
+        self.get_parents = get_parents
+        self.parents = {}
+
+    def ack(self, ref):
+        if ref in self.heads:
+            self.heads.remove(ref)
+        if not ref in self.parents:
+            return
+        for p in self.parents[ref]:
+            self.ack(p)
+
+    def next(self):
+        if self.heads:
+            ret = self.heads.pop()
+            ps = self.get_parents(ret)
+            self.parents[ret] = ps
+            self.heads.update(ps)
+            return ret
+        return None
+
+
+class GitClient(object):
+    """Git smart server client.
+
+    """
+
+    def __init__(self, fileno, read, write, host):
+        self.read = read
+        self.write = write
+        self.fileno = fileno
+        self.host = host
+
+    def read_pkt_line(self):
+        """
+        Reads a 'pkt line' from the remote git process
+
+        :return: The next string from the stream
+        """
+        sizestr = self.read(4)
+        if sizestr == "":
+            raise RuntimeError("Socket broken")
+        size = int(sizestr, 16)
+        if size == 0:
+            return None
+        return self.read(size-4)
+
+    def read_pkt_seq(self):
+        pkt = self.read_pkt_line()
+        while pkt:
+            yield pkt
+            pkt = self.read_pkt_line()
+
+    def write_pkt_line(self, line):
+        """
+        Sends a 'pkt line' to the remote git process
+
+        :param line: A string containing the data to send
+        """
+        if line is None:
+            self.write("0000")
+        else:
+            self.write("%04x%s" % (len(line)+4, line))
+
+    def send_cmd(self, name, *args):
+        self.write_pkt_line("%s %s" % (name, "".join(["%s\0" % a for a in args])))
+
+    def capabilities(self):
+        return "multi_ack side-band-64k thin-pack ofs-delta"
+
+    def read_refs(self):
+        server_capabilities = None
+        refs = {}
+        # Receive refs from server
+        for pkt in self.read_pkt_seq():
+            (sha, ref) = pkt.rstrip("\n").split(" ", 1)
+            if server_capabilities is None:
+                (ref, server_capabilities) = extract_capabilities(ref)
+            if not (ref == "capabilities^{}" and sha == "0" * 40):
+                refs[ref] = sha
+        return refs, server_capabilities
+
+    def send_pack(self, path):
+        self.send_cmd("git-receive-pack", path, "host=%s" % self.host)
+        refs, server_capabilities = self.read_refs()
+        changed_refs = [] # FIXME
+        if not changed_refs:
+            self.write_pkt_line(None)
+            return
+        self.write_pkt_line("%s %s %s\0%s" % (changed_refs[0][0], changed_refs[0][1], changed_refs[0][2], self.capabilities()))
+        for changed_ref in changed_refs[:]:
+            self.write_pkt_line("%s %s %s" % changed_refs)
+        self.write_pkt_line(None)
+        # FIXME: Send pack
+
+    def fetch_pack(self, path, determine_wants, graph_walker, pack_data, progress):
+        """Retrieve a pack from a git smart server.
+
+        :param determine_wants: Callback that returns list of commits to fetch
+        :param graph_walker: Object with next() and ack().
+        :param pack_data: Callback called for each bit of data in the pack
+        :param progress: Callback for progress reports (strings)
+        """
+        self.send_cmd("git-upload-pack", path, "host=%s" % self.host)
+
+        (refs, server_capabilities) = self.read_refs()
+       
+        wants = determine_wants(refs)
+        if not wants:
+            self.write_pkt_line(None)
+            return
+        self.write_pkt_line("want %s %s\n" % (wants[0], self.capabilities()))
+        for want in wants[1:]:
+            self.write_pkt_line("want %s\n" % want)
+        self.write_pkt_line(None)
+        have = graph_walker.next()
+        while have:
+            self.write_pkt_line("have %s\n" % have)
+            if len(select.select([self.fileno], [], [], 0)[0]) > 0:
+                pkt = self.read_pkt_line()
+                parts = pkt.rstrip("\n").split(" ")
+                if parts[0] == "ACK":
+                    graph_walker.ack(parts[1])
+                    assert parts[2] == "continue"
+            have = graph_walker.next()
+        self.write_pkt_line("done\n")
+        pkt = self.read_pkt_line()
+        while pkt:
+            parts = pkt.rstrip("\n").split(" ")
+            if parts[0] == "ACK":
+                graph_walker.ack(pkt.split(" ")[1])
+            if len(parts) < 3 or parts[2] != "continue":
+                break
+            pkt = self.read_pkt_line()
+        for pkt in self.read_pkt_seq():
+            channel = ord(pkt[0])
+            pkt = pkt[1:]
+            if channel == 1:
+                pack_data(pkt)
+            elif channel == 2:
+                progress(pkt)
+            else:
+                raise AssertionError("Invalid sideband channel %d" % channel)
+
+
+class TCPGitClient(GitClient):
+
+    def __init__(self, host, port=9418):
+        self._socket = socket.socket()
+        self._socket.connect((host, port))
+        self.rfile = self._socket.makefile('rb', -1)
+        self.wfile = self._socket.makefile('wb', 0)
+        super(TCPGitClient, self).__init__(self._socket.fileno(), self.rfile.read, self.wfile.write, host)

+ 34 - 13
dulwich/pack.py

@@ -35,9 +35,10 @@ a pointer in to the corresponding packfile.
 
 from collections import defaultdict
 import hashlib
-from itertools import izip
+from itertools import imap, izip
 import mmap
 import os
+import sha
 import struct
 import sys
 import zlib
@@ -73,6 +74,13 @@ def read_zlib(data, offset, dec_size):
     return x, comp_len
 
 
+def iter_sha1(iter):
+    sha = hashlib.sha1()
+    for name in iter:
+        sha.update(name)
+    return sha.hexdigest()
+
+
 def hex_to_sha(hex):
   """Convert a hex string to a binary sha string."""
   ret = ""
@@ -241,8 +249,14 @@ class PackIndex(object):
                                   self._crc32_table_offset + i * 4)[0]
 
   def __iter__(self):
+      return imap(sha_to_hex, self._itersha())
+
+  def _itersha(self):
     for i in range(len(self)):
-        yield sha_to_hex(self._unpack_name(i))
+        yield self._unpack_name(i)
+
+  def objects_sha1(self):
+    return iter_sha1(self._itersha())
 
   def iterentries(self):
     """Iterate over the entries in this pack index.
@@ -348,6 +362,7 @@ class PackData(object):
     self._filename = filename
     assert os.path.exists(filename), "%s is not a packfile" % filename
     self._size = os.path.getsize(filename)
+    assert self._size >= 12, "%s is too small for a packfile" % filename
     self._header_size = self._read_header()
 
   def _read_header(self):
@@ -405,13 +420,18 @@ class PackData(object):
         found[sha] = (type, obj)
         yield sha, offset, shafile.crc32()
 
+  def sorted_entries(self):
+    ret = list(self.iterentries())
+    ret.sort()
+    return ret
+
   def create_index_v1(self, filename):
-    entries = list(self.iterentries())
+    entries = self.sorted_entries()
     write_pack_index_v1(filename, entries, self.calculate_checksum())
 
   def create_index_v2(self, filename):
-    entries = list(self.iterentries())
-    write_pack_index_v1(filename, entries, self.calculate_checksum())
+    entries = self.sorted_entries()
+    write_pack_index_v2(filename, entries, self.calculate_checksum())
 
   def get_stored_checksum(self):
     return self._stored_checksum
@@ -534,6 +554,7 @@ def write_pack(filename, objects, num_objects):
         entries, data_sum = write_pack_data(f, objects, num_objects)
     except:
         f.close()
+    entries.sort()
     write_pack_index_v2(filename + ".idx", entries, data_sum)
 
 
@@ -567,9 +588,6 @@ def write_pack_index_v1(filename, entries, pack_checksum):
             crc32_checksum.
     :param pack_checksum: Checksum of the pack file.
     """
-    # Sort entries first
-
-    entries = sorted(entries)
     f = open(filename, 'w')
     f = SHA1Writer(f)
     fan_out_table = defaultdict(lambda: 0)
@@ -651,8 +669,6 @@ def write_pack_index_v2(filename, entries, pack_checksum):
             crc32_checksum.
     :param pack_checksum: Checksum of the pack file.
     """
-    # Sort entries first
-    entries = sorted(entries)
     f = open(filename, 'w')
     f = SHA1Writer(f)
     f.write('\377tOc')
@@ -681,13 +697,18 @@ class Pack(object):
 
     def __init__(self, basename):
         self._basename = basename
+        self._data_path = self._basename + ".pack"
+        self._idx_path = self._basename + ".idx"
         self._data = None
         self._idx = None
 
+    def name(self):
+        return self.idx.objects_sha1()
+
     @property
     def data(self):
         if self._data is None:
-            self._data = PackData(self._basename + ".pack")
+            self._data = PackData(self._data_path)
             assert len(self.idx) == len(self._data)
             assert self.idx.get_stored_checksums()[0] == self._data.get_stored_checksum()
         return self._data
@@ -695,7 +716,7 @@ class Pack(object):
     @property
     def idx(self):
         if self._idx is None:
-            self._idx = PackIndex(self._basename + ".idx")
+            self._idx = PackIndex(self._idx_path)
         return self._idx
 
     def close(self):
@@ -754,5 +775,5 @@ def load_packs(path):
     if not os.path.exists(path):
         return
     for name in os.listdir(path):
-        if name.endswith(".pack"):
+        if name.startswith("pack-") and name.endswith(".pack"):
             yield Pack(os.path.join(path, name[:-len(".pack")]))

+ 19 - 1
dulwich/repo.py

@@ -26,7 +26,8 @@ from objects import (ShaFile,
                      Tree,
                      Blob,
                      )
-from pack import load_packs
+from pack import load_packs, iter_sha1, PackData, write_pack_index_v2
+import tempfile
 
 OBJECTDIR = 'objects'
 PACKDIR = 'pack'
@@ -65,6 +66,20 @@ class Repo(object):
   def pack_dir(self):
     return os.path.join(self.object_dir(), PACKDIR)
 
+  def add_pack(self):
+    fd, path = tempfile.mkstemp(dir=self.pack_dir(), suffix=".pack")
+    f = os.fdopen(fd, 'w')
+    def commit():
+       self._move_in_pack(path)
+    return f, commit
+
+  def _move_in_pack(self, path):
+    p = PackData(path)
+    entries = p.sorted_entries()
+    basename = os.path.join(self.pack_dir(), "pack-%s" % iter_sha1(entry[0] for entry in entries))
+    write_pack_index_v2(basename+".idx", entries, p.calculate_checksum())
+    os.rename(path, basename + ".pack")
+
   def _get_packs(self):
     if self._packs is None:
         self._packs = list(load_packs(self.pack_dir()))
@@ -135,6 +150,9 @@ class Repo(object):
   def get_object(self, sha):
     return self._get_object(sha, ShaFile)
 
+  def get_parents(self, sha):
+    return self.commit(sha).parents
+
   def commit(self, sha):
     return self._get_object(sha, Commit)
 

+ 5 - 0
dulwich/tests/test_pack.py

@@ -171,6 +171,10 @@ class TestPack(PackTests):
         self.assertEquals("James Westby <jw+debian@jameswestby.net>", commit.author)
         self.assertEquals([], commit.parents)
 
+    def test_name(self):
+        p = self.get_pack(pack1_sha)
+        self.assertEquals(pack1_sha, p.name())
+
 
 class TestHexToSha(unittest.TestCase):
 
@@ -194,6 +198,7 @@ class TestPackIndexWriting(object):
     def test_single(self):
         pack_checksum = 'r\x19\x80\xe8f\xaf\x9a_\x93\xadgAD\xe1E\x9b\x8b\xa3\xe7\xb7'
         my_entries = [('og\x0c\x0f\xb5?\x94cv\x0br\x95\xfb\xb8\x14\xe9e\xfb \xc8', 178, 42)]
+        my_entries.sort()
         self._write_fn("single.idx", my_entries, pack_checksum)
         idx = PackIndex("single.idx")
         self.assertEquals(idx.version, self._expected_version)