Browse Source

Merge milki's shallow branch support.

Jelmer Vernooij 11 years ago
parent
commit
606fa97b20

+ 2 - 0
NEWS

@@ -7,6 +7,8 @@
  * New module `dulwich.objectspec` for parsing strings referencing
  * New module `dulwich.objectspec` for parsing strings referencing
    objects and commit ranges. (Jelmer Vernooij)
    objects and commit ranges. (Jelmer Vernooij)
 
 
+ * Add shallow branch support. (milki)
+
  CHANGES
  CHANGES
 
 
  * Drop support for Python 2.4 and 2.5. (Jelmer Vernooij)
  * Drop support for Python 2.4 and 2.5. (Jelmer Vernooij)

+ 18 - 8
dulwich/object_store.py

@@ -162,7 +162,8 @@ class BaseObjectStore(object):
                 yield entry
                 yield entry
 
 
     def find_missing_objects(self, haves, wants, progress=None,
     def find_missing_objects(self, haves, wants, progress=None,
-                             get_tagged=None):
+                             get_tagged=None,
+                             get_parents=lambda commit: commit.parents):
         """Find the missing objects required for a set of revisions.
         """Find the missing objects required for a set of revisions.
 
 
         :param haves: Iterable over SHAs already in common.
         :param haves: Iterable over SHAs already in common.
@@ -171,9 +172,10 @@ class BaseObjectStore(object):
             updated progress strings.
             updated progress strings.
         :param get_tagged: Function that returns a dict of pointed-to sha -> tag
         :param get_tagged: Function that returns a dict of pointed-to sha -> tag
             sha for including tags.
             sha for including tags.
+        :param get_parents: Optional function for getting the parents of a commit.
         :return: Iterator over (sha, path) pairs.
         :return: Iterator over (sha, path) pairs.
         """
         """
-        finder = MissingObjectFinder(self, haves, wants, progress, get_tagged)
+        finder = MissingObjectFinder(self, haves, wants, progress, get_tagged, get_parents=get_parents)
         return iter(finder.next, None)
         return iter(finder.next, None)
 
 
     def find_common_revisions(self, graphwalker):
     def find_common_revisions(self, graphwalker):
@@ -215,12 +217,14 @@ class BaseObjectStore(object):
             obj = self[sha]
             obj = self[sha]
         return obj
         return obj
 
 
-    def _collect_ancestors(self, heads, common=set()):
+    def _collect_ancestors(self, heads, common=set(),
+                           get_parents=lambda commit: commit.parents):
         """Collect all ancestors of heads up to (excluding) those in common.
         """Collect all ancestors of heads up to (excluding) those in common.
 
 
         :param heads: commits to start from
         :param heads: commits to start from
         :param common: commits to end at, or empty set to walk repository
         :param common: commits to end at, or empty set to walk repository
             completely
             completely
+        :param get_parents: Optional function for getting the parents of a commit.
         :return: a tuple (A, B) where A - all commits reachable
         :return: a tuple (A, B) where A - all commits reachable
             from heads but not present in common, B - common (shared) elements
             from heads but not present in common, B - common (shared) elements
             that are directly reachable from heads
             that are directly reachable from heads
@@ -236,7 +240,7 @@ class BaseObjectStore(object):
             elif e not in commits:
             elif e not in commits:
                 commits.add(e)
                 commits.add(e)
                 cmt = self[e]
                 cmt = self[e]
-                queue.extend(cmt.parents)
+                queue.extend(get_parents(cmt))
         return (commits, bases)
         return (commits, bases)
 
 
     def close(self):
     def close(self):
@@ -970,12 +974,14 @@ class MissingObjectFinder(object):
     :param progress: Optional function to report progress to.
     :param progress: Optional function to report progress to.
     :param get_tagged: Function that returns a dict of pointed-to sha -> tag
     :param get_tagged: Function that returns a dict of pointed-to sha -> tag
         sha for including tags.
         sha for including tags.
+    :param get_parents: Optional function for getting the parents of a commit.
     :param tagged: dict of pointed-to sha -> tag sha for including tags
     :param tagged: dict of pointed-to sha -> tag sha for including tags
     """
     """
 
 
     def __init__(self, object_store, haves, wants, progress=None,
     def __init__(self, object_store, haves, wants, progress=None,
-                 get_tagged=None):
+            get_tagged=None, get_parents=lambda commit: commit.parents):
         self.object_store = object_store
         self.object_store = object_store
+        self._get_parents = get_parents
         # process Commits and Tags differently
         # process Commits and Tags differently
         # Note, while haves may list commits/tags not available locally,
         # Note, while haves may list commits/tags not available locally,
         # and such SHAs would get filtered out by _split_commits_and_tags,
         # and such SHAs would get filtered out by _split_commits_and_tags,
@@ -987,12 +993,16 @@ class MissingObjectFinder(object):
                 _split_commits_and_tags(object_store, wants, False)
                 _split_commits_and_tags(object_store, wants, False)
         # all_ancestors is a set of commits that shall not be sent
         # all_ancestors is a set of commits that shall not be sent
         # (complete repository up to 'haves')
         # (complete repository up to 'haves')
-        all_ancestors = object_store._collect_ancestors(have_commits)[0]
+        all_ancestors = object_store._collect_ancestors(
+                have_commits,
+                get_parents=self._get_parents)[0]
         # all_missing - complete set of commits between haves and wants
         # all_missing - complete set of commits between haves and wants
         # common - commits from all_ancestors we hit into while
         # common - commits from all_ancestors we hit into while
         # traversing parent hierarchy of wants
         # traversing parent hierarchy of wants
-        missing_commits, common_commits = \
-            object_store._collect_ancestors(want_commits, all_ancestors)
+        missing_commits, common_commits = object_store._collect_ancestors(
+            want_commits,
+            all_ancestors,
+            get_parents=self._get_parents);
         self.sha_done = set()
         self.sha_done = set()
         # Now, fill sha_done with commits and revisions of
         # Now, fill sha_done with commits and revisions of
         # files and directories known to be both locally
         # files and directories known to be both locally

+ 31 - 3
dulwich/repo.py

@@ -248,14 +248,38 @@ class BaseRepo(object):
         wants = determine_wants(self.get_refs())
         wants = determine_wants(self.get_refs())
         if type(wants) is not list:
         if type(wants) is not list:
             raise TypeError("determine_wants() did not return a list")
             raise TypeError("determine_wants() did not return a list")
+
+        shallows = getattr(graph_walker, 'shallow', frozenset())
+        unshallows = getattr(graph_walker, 'unshallow', frozenset())
+
         if wants == []:
         if wants == []:
             # TODO(dborowitz): find a way to short-circuit that doesn't change
             # TODO(dborowitz): find a way to short-circuit that doesn't change
             # this interface.
             # this interface.
+
+            if shallows or unshallows:
+                # Do not send a pack in shallow short-circuit path
+                return None
+
             return []
             return []
+
         haves = self.object_store.find_common_revisions(graph_walker)
         haves = self.object_store.find_common_revisions(graph_walker)
+
+        # Deal with shallow requests separately because the haves do
+        # not reflect what objects are missing
+        if shallows or unshallows:
+            haves = []  # TODO: filter the haves commits from iter_shas.
+                        # the specific commits aren't missing.
+
+        def get_parents(commit):
+            if commit.id in shallows:
+                return []
+            return self.get_parents(commit.id, commit)
+
         return self.object_store.iter_shas(
         return self.object_store.iter_shas(
-          self.object_store.find_missing_objects(haves, wants, progress,
-                                                 get_tagged))
+          self.object_store.find_missing_objects(
+              haves, wants, progress,
+              get_tagged,
+              get_parents=get_parents))
 
 
     def get_graph_walker(self, heads=None):
     def get_graph_walker(self, heads=None):
         """Retrieve a graph walker.
         """Retrieve a graph walker.
@@ -632,9 +656,13 @@ class Repo(BaseRepo):
         refs = DiskRefsContainer(self.controldir())
         refs = DiskRefsContainer(self.controldir())
         BaseRepo.__init__(self, object_store, refs)
         BaseRepo.__init__(self, object_store, refs)
 
 
+        self._graftpoints = {}
         graft_file = self.get_named_file(os.path.join("info", "grafts"))
         graft_file = self.get_named_file(os.path.join("info", "grafts"))
         if graft_file:
         if graft_file:
-            self._graftpoints = parse_graftpoints(graft_file)
+            self._graftpoints.update(parse_graftpoints(graft_file))
+        graft_file = self.get_named_file("shallow")
+        if graft_file:
+            self._graftpoints.update(parse_graftpoints(graft_file))
 
 
         self.hooks['pre-commit'] = PreCommitShellHook(self.controldir())
         self.hooks['pre-commit'] = PreCommitShellHook(self.controldir())
         self.hooks['commit-msg'] = CommitMsgShellHook(self.controldir())
         self.hooks['commit-msg'] = CommitMsgShellHook(self.controldir())

+ 88 - 7
dulwich/server.py

@@ -59,6 +59,7 @@ from dulwich.errors import (
 from dulwich import log_utils
 from dulwich import log_utils
 from dulwich.objects import (
 from dulwich.objects import (
     hex_to_sha,
     hex_to_sha,
+    Commit,
     )
     )
 from dulwich.pack import (
 from dulwich.pack import (
     write_pack_objects,
     write_pack_objects,
@@ -226,7 +227,7 @@ class UploadPackHandler(Handler):
     @classmethod
     @classmethod
     def capabilities(cls):
     def capabilities(cls):
         return ("multi_ack_detailed", "multi_ack", "side-band-64k", "thin-pack",
         return ("multi_ack_detailed", "multi_ack", "side-band-64k", "thin-pack",
-                "ofs-delta", "no-progress", "include-tag")
+                "ofs-delta", "no-progress", "include-tag", "shallow")
 
 
     @classmethod
     @classmethod
     def required_capabilities(cls):
     def required_capabilities(cls):
@@ -315,14 +316,55 @@ def _split_proto_line(line, allowed):
     try:
     try:
         if len(fields) == 1 and command in ('done', None):
         if len(fields) == 1 and command in ('done', None):
             return (command, None)
             return (command, None)
-        elif len(fields) == 2 and command in ('want', 'have'):
-            hex_to_sha(fields[1])
-            return tuple(fields)
+        elif len(fields) == 2:
+            if command in ('want', 'have', 'shallow', 'unshallow'):
+                hex_to_sha(fields[1])
+                return tuple(fields)
+            elif command == 'deepen':
+                return command, int(fields[1])
     except (TypeError, AssertionError), e:
     except (TypeError, AssertionError), e:
         raise GitProtocolError(e)
         raise GitProtocolError(e)
     raise GitProtocolError('Received invalid line from client: %s' % line)
     raise GitProtocolError('Received invalid line from client: %s' % line)
 
 
 
 
+def _find_shallow(store, heads, depth):
+    """Find shallow commits according to a given depth.
+
+    :param store: An ObjectStore for looking up objects.
+    :param heads: Iterable of head SHAs to start walking from.
+    :param depth: The depth of ancestors to include.
+    :return: A tuple of (shallow, not_shallow), sets of SHAs that should be
+        considered shallow and unshallow according to the arguments. Note that
+        these sets may overlap if a commit is reachable along multiple paths.
+    """
+    parents = {}
+    def get_parents(sha):
+        result = parents.get(sha, None)
+        if not result:
+            result = store[sha].parents
+            parents[sha] = result
+        return result
+
+    todo = []  # stack of (sha, depth)
+    for head_sha in heads:
+        obj = store.peel_sha(head_sha)
+        if isinstance(obj, Commit):
+            todo.append((obj.id, 0))
+
+    not_shallow = set()
+    shallow = set()
+    while todo:
+        sha, cur_depth = todo.pop()
+        if cur_depth < depth:
+            not_shallow.add(sha)
+            new_depth = cur_depth + 1
+            todo.extend((p, new_depth) for p in get_parents(sha))
+        else:
+            shallow.add(sha)
+
+    return shallow, not_shallow
+
+
 class ProtocolGraphWalker(object):
 class ProtocolGraphWalker(object):
     """A graph walker that knows the git protocol.
     """A graph walker that knows the git protocol.
 
 
@@ -344,6 +386,9 @@ class ProtocolGraphWalker(object):
         self.http_req = handler.http_req
         self.http_req = handler.http_req
         self.advertise_refs = handler.advertise_refs
         self.advertise_refs = handler.advertise_refs
         self._wants = []
         self._wants = []
+        self.shallow = set()
+        self.client_shallow = set()
+        self.unshallow = set()
         self._cached = False
         self._cached = False
         self._cache = []
         self._cache = []
         self._cache_index = 0
         self._cache_index = 0
@@ -357,6 +402,12 @@ class ProtocolGraphWalker(object):
         same regardless of ack type, and in fact is used to set the ack type of
         same regardless of ack type, and in fact is used to set the ack type of
         the ProtocolGraphWalker.
         the ProtocolGraphWalker.
 
 
+        If the client has the 'shallow' capability, this method also reads and
+        responds to the 'shallow' and 'deepen' lines from the client. These are
+        not part of the wants per se, but they set up necessary state for
+        walking the graph. Additionally, later code depends on this method
+        consuming everything up to the first 'have' line.
+
         :param heads: a dict of refname->SHA1 to advertise
         :param heads: a dict of refname->SHA1 to advertise
         :return: a list of SHA1s requested by the client
         :return: a list of SHA1s requested by the client
         """
         """
@@ -389,11 +440,11 @@ class ProtocolGraphWalker(object):
         line, caps = extract_want_line_capabilities(want)
         line, caps = extract_want_line_capabilities(want)
         self.handler.set_client_capabilities(caps)
         self.handler.set_client_capabilities(caps)
         self.set_ack_type(ack_type(caps))
         self.set_ack_type(ack_type(caps))
-        allowed = ('want', None)
+        allowed = ('want', 'shallow', 'deepen', None)
         command, sha = _split_proto_line(line, allowed)
         command, sha = _split_proto_line(line, allowed)
 
 
         want_revs = []
         want_revs = []
-        while command != None:
+        while command == 'want':
             if sha not in values:
             if sha not in values:
                 raise GitProtocolError(
                 raise GitProtocolError(
                   'Client wants invalid object %s' % sha)
                   'Client wants invalid object %s' % sha)
@@ -401,6 +452,9 @@ class ProtocolGraphWalker(object):
             command, sha = self.read_proto_line(allowed)
             command, sha = self.read_proto_line(allowed)
 
 
         self.set_wants(want_revs)
         self.set_wants(want_revs)
+        if command in ('shallow', 'deepen'):
+            self.unread_proto_line(command, sha)
+            self._handle_shallow_request(want_revs)
 
 
         if self.http_req and self.proto.eof():
         if self.http_req and self.proto.eof():
             # The client may close the socket at this point, expecting a
             # The client may close the socket at this point, expecting a
@@ -410,6 +464,9 @@ class ProtocolGraphWalker(object):
 
 
         return want_revs
         return want_revs
 
 
+    def unread_proto_line(self, command, value):
+        self.proto.unread_pkt_line('%s %s' % (command, value))
+
     def ack(self, have_ref):
     def ack(self, have_ref):
         return self._impl.ack(have_ref)
         return self._impl.ack(have_ref)
 
 
@@ -432,10 +489,34 @@ class ProtocolGraphWalker(object):
 
 
         :param allowed: An iterable of command names that should be allowed.
         :param allowed: An iterable of command names that should be allowed.
         :return: A tuple of (command, value); see _split_proto_line.
         :return: A tuple of (command, value); see _split_proto_line.
-        :raise GitProtocolError: If an error occurred reading the line.
+        :raise UnexpectedCommandError: If an error occurred reading the line.
         """
         """
         return _split_proto_line(self.proto.read_pkt_line(), allowed)
         return _split_proto_line(self.proto.read_pkt_line(), allowed)
 
 
+    def _handle_shallow_request(self, wants):
+        while True:
+            command, val = self.read_proto_line(('deepen', 'shallow'))
+            if command == 'deepen':
+                depth = val
+                break
+            self.client_shallow.add(val)
+        self.read_proto_line((None,))  # consume client's flush-pkt
+
+        shallow, not_shallow = _find_shallow(self.store, wants, depth)
+
+        # Update self.shallow instead of reassigning it since we passed a
+        # reference to it before this method was called.
+        self.shallow.update(shallow - not_shallow)
+        new_shallow = self.shallow - self.client_shallow
+        unshallow = self.unshallow = not_shallow & self.client_shallow
+
+        for sha in sorted(new_shallow):
+            self.proto.write_pkt_line('shallow %s' % sha)
+        for sha in sorted(unshallow):
+            self.proto.write_pkt_line('unshallow %s' % sha)
+
+        self.proto.write_pkt_line(None)
+
     def send_ack(self, sha, ack_type=''):
     def send_ack(self, sha, ack_type=''):
         if ack_type:
         if ack_type:
             ack_type = ' %s' % ack_type
             ack_type = ' %s' % ack_type

+ 95 - 0
dulwich/tests/compat/server_utils.py

@@ -28,6 +28,7 @@ import tempfile
 import threading
 import threading
 
 
 from dulwich.repo import Repo
 from dulwich.repo import Repo
+from dulwich.objects import hex_to_sha
 from dulwich.server import (
 from dulwich.server import (
     ReceivePackHandler,
     ReceivePackHandler,
     )
     )
@@ -40,6 +41,32 @@ from dulwich.tests.compat.utils import (
     )
     )
 
 
 
 
+class _StubRepo(object):
+    """A stub repo that just contains a path to tear down."""
+
+    def __init__(self, name):
+        temp_dir = tempfile.mkdtemp()
+        self.path = os.path.join(temp_dir, name)
+        os.mkdir(self.path)
+
+
+def _get_shallow(repo):
+    shallow_file = repo.get_named_file('shallow')
+    if not shallow_file:
+        return []
+    shallows = []
+    try:
+        for line in shallow_file:
+            sha = line.strip()
+            if not sha:
+                continue
+            hex_to_sha(sha)
+            shallows.append(sha)
+    finally:
+        shallow_file.close()
+    return shallows
+
+
 class ServerTests(object):
 class ServerTests(object):
     """Base tests for testing servers.
     """Base tests for testing servers.
 
 
@@ -71,7 +98,9 @@ class ServerTests(object):
 
 
     def test_push_to_dulwich_no_op(self):
     def test_push_to_dulwich_no_op(self):
         self._old_repo = import_repo('server_old.export')
         self._old_repo = import_repo('server_old.export')
+        self.addCleanup(tear_down_repo, self._old_repo)
         self._new_repo = import_repo('server_old.export')
         self._new_repo = import_repo('server_old.export')
+        self.addCleanup(tear_down_repo, self._new_repo)
         self.assertReposEqual(self._old_repo, self._new_repo)
         self.assertReposEqual(self._old_repo, self._new_repo)
         port = self._start_server(self._old_repo)
         port = self._start_server(self._old_repo)
 
 
@@ -81,7 +110,9 @@ class ServerTests(object):
 
 
     def test_push_to_dulwich_remove_branch(self):
     def test_push_to_dulwich_remove_branch(self):
         self._old_repo = import_repo('server_old.export')
         self._old_repo = import_repo('server_old.export')
+        self.addCleanup(tear_down_repo, self._old_repo)
         self._new_repo = import_repo('server_old.export')
         self._new_repo = import_repo('server_old.export')
+        self.addCleanup(tear_down_repo, self._new_repo)
         self.assertReposEqual(self._old_repo, self._new_repo)
         self.assertReposEqual(self._old_repo, self._new_repo)
         port = self._start_server(self._old_repo)
         port = self._start_server(self._old_repo)
 
 
@@ -104,7 +135,9 @@ class ServerTests(object):
 
 
     def test_fetch_from_dulwich_no_op(self):
     def test_fetch_from_dulwich_no_op(self):
         self._old_repo = import_repo('server_old.export')
         self._old_repo = import_repo('server_old.export')
+        self.addCleanup(tear_down_repo, self._old_repo)
         self._new_repo = import_repo('server_old.export')
         self._new_repo = import_repo('server_old.export')
+        self.addCleanup(tear_down_repo, self._new_repo)
         self.assertReposEqual(self._old_repo, self._new_repo)
         self.assertReposEqual(self._old_repo, self._new_repo)
         port = self._start_server(self._new_repo)
         port = self._start_server(self._new_repo)
 
 
@@ -118,6 +151,7 @@ class ServerTests(object):
         old_repo_dir = os.path.join(tempfile.mkdtemp(), 'empty_old')
         old_repo_dir = os.path.join(tempfile.mkdtemp(), 'empty_old')
         run_git_or_fail(['init', '--quiet', '--bare', old_repo_dir])
         run_git_or_fail(['init', '--quiet', '--bare', old_repo_dir])
         self._old_repo = Repo(old_repo_dir)
         self._old_repo = Repo(old_repo_dir)
+        self.addCleanup(tear_down_repo, self._old_repo)
         port = self._start_server(self._old_repo)
         port = self._start_server(self._old_repo)
 
 
         new_repo_base_dir = tempfile.mkdtemp()
         new_repo_base_dir = tempfile.mkdtemp()
@@ -138,6 +172,67 @@ class ServerTests(object):
         o = run_git_or_fail(['ls-remote', self.url(port)])
         o = run_git_or_fail(['ls-remote', self.url(port)])
         self.assertEqual(len(o.split('\n')), 4)
         self.assertEqual(len(o.split('\n')), 4)
 
 
+    def test_new_shallow_clone_from_dulwich(self):
+        self._source_repo = import_repo('server_new.export')
+        self.addCleanup(tear_down_repo, self._source_repo)
+        self._stub_repo = _StubRepo('shallow')
+        self.addCleanup(tear_down_repo, self._stub_repo)
+        port = self._start_server(self._source_repo)
+
+        # Fetch at depth 1
+        run_git_or_fail(['clone', '--mirror', '--depth=1', '--no-single-branch',
+                        self.url(port), self._stub_repo.path])
+        clone = self._stub_repo = Repo(self._stub_repo.path)
+        expected_shallow = ['94de09a530df27ac3bb613aaecdd539e0a0655e1',
+                            'da5cd81e1883c62a25bb37c4d1f8ad965b29bf8d']
+        self.assertEqual(expected_shallow, _get_shallow(clone))
+        self.assertReposNotEqual(clone, self._source_repo)
+
+    def test_fetch_same_depth_into_shallow_clone_from_dulwich(self):
+        self._source_repo = import_repo('server_new.export')
+        self.addCleanup(tear_down_repo, self._source_repo)
+        self._stub_repo = _StubRepo('shallow')
+        self.addCleanup(tear_down_repo, self._stub_repo)
+        port = self._start_server(self._source_repo)
+
+        # Fetch at depth 1
+        run_git_or_fail(['clone', '--mirror', '--depth=1', '--no-single-branch',
+                        self.url(port), self._stub_repo.path])
+        clone = self._stub_repo = Repo(self._stub_repo.path)
+
+        # Fetching at the same depth is a no-op.
+        run_git_or_fail(
+          ['fetch', '--depth=1', self.url(port)] + self.branch_args(),
+          cwd=self._stub_repo.path)
+        expected_shallow = ['94de09a530df27ac3bb613aaecdd539e0a0655e1',
+                            'da5cd81e1883c62a25bb37c4d1f8ad965b29bf8d']
+        self.assertEqual(expected_shallow, _get_shallow(clone))
+        self.assertReposNotEqual(clone, self._source_repo)
+
+    def test_fetch_full_depth_into_shallow_clone_from_dulwich(self):
+        self._source_repo = import_repo('server_new.export')
+        self.addCleanup(tear_down_repo, self._source_repo)
+        self._stub_repo = _StubRepo('shallow')
+        self.addCleanup(tear_down_repo, self._stub_repo)
+        port = self._start_server(self._source_repo)
+
+        # Fetch at depth 1
+        run_git_or_fail(['clone', '--mirror', '--depth=1', '--no-single-branch',
+                        self.url(port), self._stub_repo.path])
+        clone = self._stub_repo = Repo(self._stub_repo.path)
+
+        # Fetching at the same depth is a no-op.
+        run_git_or_fail(
+          ['fetch', '--depth=1', self.url(port)] + self.branch_args(),
+          cwd=self._stub_repo.path)
+
+        # The whole repo only has depth 3, so it should equal server_new.
+        run_git_or_fail(
+          ['fetch', '--depth=3', self.url(port)] + self.branch_args(),
+          cwd=self._stub_repo.path)
+        self.assertEqual([], _get_shallow(clone))
+        self.assertReposEqual(clone, self._source_repo)
+
 
 
 class ShutdownServerMixIn:
 class ShutdownServerMixIn:
     """Mixin that allows serve_forever to be shut down.
     """Mixin that allows serve_forever to be shut down.

+ 16 - 1
dulwich/tests/compat/test_web.py

@@ -134,9 +134,24 @@ class DumbWebTestCase(WebTests, CompatTestCase):
         return make_wsgi_chain(backend, dumb=True)
         return make_wsgi_chain(backend, dumb=True)
 
 
     def test_push_to_dulwich(self):
     def test_push_to_dulwich(self):
-        # Note: remove this if dumb pushing is supported
+        # Note: remove this if dulwich implements dumb web pushing.
         raise SkipTest('Dumb web pushing not supported.')
         raise SkipTest('Dumb web pushing not supported.')
 
 
     def test_push_to_dulwich_remove_branch(self):
     def test_push_to_dulwich_remove_branch(self):
         # Note: remove this if dumb pushing is supported
         # Note: remove this if dumb pushing is supported
         raise SkipTest('Dumb web pushing not supported.')
         raise SkipTest('Dumb web pushing not supported.')
+
+    def test_new_shallow_clone_from_dulwich(self):
+        # Note: remove this if C git and dulwich implement dumb web shallow
+        # clones.
+        raise SkipTest('Dumb web shallow cloning not supported.')
+
+    def test_fetch_same_depth_into_shallow_clone_from_dulwich(self):
+        # Note: remove this if C git and dulwich implement dumb web shallow
+        # clones.
+        raise SkipTest('Dumb web shallow cloning not supported.')
+
+    def test_fetch_full_depth_into_shallow_clone_from_dulwich(self):
+        # Note: remove this if C git and dulwich implement dumb web shallow
+        # clones.
+        raise SkipTest('Dumb web shallow cloning not supported.')

+ 123 - 0
dulwich/tests/test_server.py

@@ -27,6 +27,13 @@ from dulwich.errors import (
     NotGitRepository,
     NotGitRepository,
     UnexpectedCommandError,
     UnexpectedCommandError,
     )
     )
+from dulwich.objects import (
+    Commit,
+    Tag,
+    )
+from dulwich.object_store import (
+    MemoryObjectStore,
+    )
 from dulwich.repo import (
 from dulwich.repo import (
     MemoryRepo,
     MemoryRepo,
     Repo,
     Repo,
@@ -40,6 +47,7 @@ from dulwich.server import (
     MultiAckDetailedGraphWalkerImpl,
     MultiAckDetailedGraphWalkerImpl,
     _split_proto_line,
     _split_proto_line,
     serve_command,
     serve_command,
+    _find_shallow,
     ProtocolGraphWalker,
     ProtocolGraphWalker,
     ReceivePackHandler,
     ReceivePackHandler,
     SingleAckGraphWalkerImpl,
     SingleAckGraphWalkerImpl,
@@ -49,6 +57,7 @@ from dulwich.server import (
 from dulwich.tests import TestCase
 from dulwich.tests import TestCase
 from dulwich.tests.utils import (
 from dulwich.tests.utils import (
     make_commit,
     make_commit,
+    make_object,
     )
     )
 from dulwich.protocol import (
 from dulwich.protocol import (
     ZERO_SHA,
     ZERO_SHA,
@@ -197,6 +206,81 @@ class UploadPackHandlerTestCase(TestCase):
         self.assertEqual({}, self._handler.get_tagged(refs, repo=self._repo))
         self.assertEqual({}, self._handler.get_tagged(refs, repo=self._repo))
 
 
 
 
+class FindShallowTests(TestCase):
+
+    def setUp(self):
+        self._store = MemoryObjectStore()
+
+    def make_commit(self, **attrs):
+        commit = make_commit(**attrs)
+        self._store.add_object(commit)
+        return commit
+
+    def make_linear_commits(self, n, message=''):
+        commits = []
+        parents = []
+        for _ in xrange(n):
+            commits.append(self.make_commit(parents=parents, message=message))
+            parents = [commits[-1].id]
+        return commits
+
+    def assertSameElements(self, expected, actual):
+        self.assertEqual(set(expected), set(actual))
+
+    def test_linear(self):
+        c1, c2, c3 = self.make_linear_commits(3)
+
+        self.assertEqual((set([c3.id]), set([])),
+                         _find_shallow(self._store, [c3.id], 0))
+        self.assertEqual((set([c2.id]), set([c3.id])),
+                         _find_shallow(self._store, [c3.id], 1))
+        self.assertEqual((set([c1.id]), set([c2.id, c3.id])),
+                         _find_shallow(self._store, [c3.id], 2))
+        self.assertEqual((set([]), set([c1.id, c2.id, c3.id])),
+                         _find_shallow(self._store, [c3.id], 3))
+
+    def test_multiple_independent(self):
+        a = self.make_linear_commits(2, message='a')
+        b = self.make_linear_commits(2, message='b')
+        c = self.make_linear_commits(2, message='c')
+        heads = [a[1].id, b[1].id, c[1].id]
+
+        self.assertEqual((set([a[0].id, b[0].id, c[0].id]), set(heads)),
+                         _find_shallow(self._store, heads, 1))
+
+    def test_multiple_overlapping(self):
+        # Create the following commit tree:
+        # 1--2
+        #  \
+        #   3--4
+        c1, c2 = self.make_linear_commits(2)
+        c3 = self.make_commit(parents=[c1.id])
+        c4 = self.make_commit(parents=[c3.id])
+
+        # 1 is shallow along the path from 4, but not along the path from 2.
+        self.assertEqual((set([c1.id]), set([c1.id, c2.id, c3.id, c4.id])),
+                         _find_shallow(self._store, [c2.id, c4.id], 2))
+
+    def test_merge(self):
+        c1 = self.make_commit()
+        c2 = self.make_commit()
+        c3 = self.make_commit(parents=[c1.id, c2.id])
+
+        self.assertEqual((set([c1.id, c2.id]), set([c3.id])),
+                         _find_shallow(self._store, [c3.id], 1))
+
+    def test_tag(self):
+        c1, c2 = self.make_linear_commits(2)
+        tag = make_object(Tag, name='tag', message='',
+                          tagger='Tagger <test@example.com>',
+                          tag_time=12345, tag_timezone=0,
+                          object=(Commit, c2.id))
+        self._store.add_object(tag)
+
+        self.assertEqual((set([c1.id]), set([c2.id])),
+                         _find_shallow(self._store, [tag.id], 1))
+
+
 class TestUploadPackHandler(UploadPackHandler):
 class TestUploadPackHandler(UploadPackHandler):
     @classmethod
     @classmethod
     def required_capabilities(self):
     def required_capabilities(self):
@@ -354,6 +438,45 @@ class ProtocolGraphWalkerTestCase(TestCase):
 
 
     # TODO: test commit time cutoff
     # TODO: test commit time cutoff
 
 
+    def _handle_shallow_request(self, lines, heads):
+        self._walker.proto.set_output(lines)
+        self._walker._handle_shallow_request(heads)
+
+    def assertReceived(self, expected):
+        self.assertEquals(
+          expected, list(iter(self._walker.proto.get_received_line, None)))
+
+    def test_handle_shallow_request_no_client_shallows(self):
+        self._handle_shallow_request(['deepen 1\n'], [FOUR, FIVE])
+        self.assertEquals(set([TWO, THREE]), self._walker.shallow)
+        self.assertReceived([
+          'shallow %s' % TWO,
+          'shallow %s' % THREE,
+          ])
+
+    def test_handle_shallow_request_no_new_shallows(self):
+        lines = [
+          'shallow %s\n' % TWO,
+          'shallow %s\n' % THREE,
+          'deepen 1\n',
+          ]
+        self._handle_shallow_request(lines, [FOUR, FIVE])
+        self.assertEquals(set([TWO, THREE]), self._walker.shallow)
+        self.assertReceived([])
+
+    def test_handle_shallow_request_unshallows(self):
+        lines = [
+          'shallow %s\n' % TWO,
+          'deepen 2\n',
+          ]
+        self._handle_shallow_request(lines, [FOUR, FIVE])
+        self.assertEquals(set([ONE]), self._walker.shallow)
+        self.assertReceived([
+          'shallow %s' % ONE,
+          'unshallow %s' % TWO,
+          # THREE is unshallow but was is not shallow in the client
+          ])
+
 
 
 class TestProtocolGraphWalker(object):
 class TestProtocolGraphWalker(object):