Forráskód Böngészése

Merge milki's shallow branch support.

Jelmer Vernooij 11 éve
szülő
commit
606fa97b20

+ 2 - 0
NEWS

@@ -7,6 +7,8 @@
  * New module `dulwich.objectspec` for parsing strings referencing
    objects and commit ranges. (Jelmer Vernooij)
 
+ * Add shallow branch support. (milki)
+
  CHANGES
 
  * Drop support for Python 2.4 and 2.5. (Jelmer Vernooij)

+ 18 - 8
dulwich/object_store.py

@@ -162,7 +162,8 @@ class BaseObjectStore(object):
                 yield entry
 
     def find_missing_objects(self, haves, wants, progress=None,
-                             get_tagged=None):
+                             get_tagged=None,
+                             get_parents=lambda commit: commit.parents):
         """Find the missing objects required for a set of revisions.
 
         :param haves: Iterable over SHAs already in common.
@@ -171,9 +172,10 @@ class BaseObjectStore(object):
             updated progress strings.
         :param get_tagged: Function that returns a dict of pointed-to sha -> tag
             sha for including tags.
+        :param get_parents: Optional function for getting the parents of a commit.
         :return: Iterator over (sha, path) pairs.
         """
-        finder = MissingObjectFinder(self, haves, wants, progress, get_tagged)
+        finder = MissingObjectFinder(self, haves, wants, progress, get_tagged, get_parents=get_parents)
         return iter(finder.next, None)
 
     def find_common_revisions(self, graphwalker):
@@ -215,12 +217,14 @@ class BaseObjectStore(object):
             obj = self[sha]
         return obj
 
-    def _collect_ancestors(self, heads, common=set()):
+    def _collect_ancestors(self, heads, common=set(),
+                           get_parents=lambda commit: commit.parents):
         """Collect all ancestors of heads up to (excluding) those in common.
 
         :param heads: commits to start from
         :param common: commits to end at, or empty set to walk repository
             completely
+        :param get_parents: Optional function for getting the parents of a commit.
         :return: a tuple (A, B) where A - all commits reachable
             from heads but not present in common, B - common (shared) elements
             that are directly reachable from heads
@@ -236,7 +240,7 @@ class BaseObjectStore(object):
             elif e not in commits:
                 commits.add(e)
                 cmt = self[e]
-                queue.extend(cmt.parents)
+                queue.extend(get_parents(cmt))
         return (commits, bases)
 
     def close(self):
@@ -970,12 +974,14 @@ class MissingObjectFinder(object):
     :param progress: Optional function to report progress to.
     :param get_tagged: Function that returns a dict of pointed-to sha -> tag
         sha for including tags.
+    :param get_parents: Optional function for getting the parents of a commit.
     :param tagged: dict of pointed-to sha -> tag sha for including tags
     """
 
     def __init__(self, object_store, haves, wants, progress=None,
-                 get_tagged=None):
+            get_tagged=None, get_parents=lambda commit: commit.parents):
         self.object_store = object_store
+        self._get_parents = get_parents
         # process Commits and Tags differently
         # Note, while haves may list commits/tags not available locally,
         # and such SHAs would get filtered out by _split_commits_and_tags,
@@ -987,12 +993,16 @@ class MissingObjectFinder(object):
                 _split_commits_and_tags(object_store, wants, False)
         # all_ancestors is a set of commits that shall not be sent
         # (complete repository up to 'haves')
-        all_ancestors = object_store._collect_ancestors(have_commits)[0]
+        all_ancestors = object_store._collect_ancestors(
+                have_commits,
+                get_parents=self._get_parents)[0]
         # all_missing - complete set of commits between haves and wants
         # common - commits from all_ancestors we hit into while
         # traversing parent hierarchy of wants
-        missing_commits, common_commits = \
-            object_store._collect_ancestors(want_commits, all_ancestors)
+        missing_commits, common_commits = object_store._collect_ancestors(
+            want_commits,
+            all_ancestors,
+            get_parents=self._get_parents);
         self.sha_done = set()
         # Now, fill sha_done with commits and revisions of
         # files and directories known to be both locally

+ 31 - 3
dulwich/repo.py

@@ -248,14 +248,38 @@ class BaseRepo(object):
         wants = determine_wants(self.get_refs())
         if type(wants) is not list:
             raise TypeError("determine_wants() did not return a list")
+
+        shallows = getattr(graph_walker, 'shallow', frozenset())
+        unshallows = getattr(graph_walker, 'unshallow', frozenset())
+
         if wants == []:
             # TODO(dborowitz): find a way to short-circuit that doesn't change
             # this interface.
+
+            if shallows or unshallows:
+                # Do not send a pack in shallow short-circuit path
+                return None
+
             return []
+
         haves = self.object_store.find_common_revisions(graph_walker)
+
+        # Deal with shallow requests separately because the haves do
+        # not reflect what objects are missing
+        if shallows or unshallows:
+            haves = []  # TODO: filter the haves commits from iter_shas.
+                        # the specific commits aren't missing.
+
+        def get_parents(commit):
+            if commit.id in shallows:
+                return []
+            return self.get_parents(commit.id, commit)
+
         return self.object_store.iter_shas(
-          self.object_store.find_missing_objects(haves, wants, progress,
-                                                 get_tagged))
+          self.object_store.find_missing_objects(
+              haves, wants, progress,
+              get_tagged,
+              get_parents=get_parents))
 
     def get_graph_walker(self, heads=None):
         """Retrieve a graph walker.
@@ -632,9 +656,13 @@ class Repo(BaseRepo):
         refs = DiskRefsContainer(self.controldir())
         BaseRepo.__init__(self, object_store, refs)
 
+        self._graftpoints = {}
         graft_file = self.get_named_file(os.path.join("info", "grafts"))
         if graft_file:
-            self._graftpoints = parse_graftpoints(graft_file)
+            self._graftpoints.update(parse_graftpoints(graft_file))
+        graft_file = self.get_named_file("shallow")
+        if graft_file:
+            self._graftpoints.update(parse_graftpoints(graft_file))
 
         self.hooks['pre-commit'] = PreCommitShellHook(self.controldir())
         self.hooks['commit-msg'] = CommitMsgShellHook(self.controldir())

+ 88 - 7
dulwich/server.py

@@ -59,6 +59,7 @@ from dulwich.errors import (
 from dulwich import log_utils
 from dulwich.objects import (
     hex_to_sha,
+    Commit,
     )
 from dulwich.pack import (
     write_pack_objects,
@@ -226,7 +227,7 @@ class UploadPackHandler(Handler):
     @classmethod
     def capabilities(cls):
         return ("multi_ack_detailed", "multi_ack", "side-band-64k", "thin-pack",
-                "ofs-delta", "no-progress", "include-tag")
+                "ofs-delta", "no-progress", "include-tag", "shallow")
 
     @classmethod
     def required_capabilities(cls):
@@ -315,14 +316,55 @@ def _split_proto_line(line, allowed):
     try:
         if len(fields) == 1 and command in ('done', None):
             return (command, None)
-        elif len(fields) == 2 and command in ('want', 'have'):
-            hex_to_sha(fields[1])
-            return tuple(fields)
+        elif len(fields) == 2:
+            if command in ('want', 'have', 'shallow', 'unshallow'):
+                hex_to_sha(fields[1])
+                return tuple(fields)
+            elif command == 'deepen':
+                return command, int(fields[1])
     except (TypeError, AssertionError), e:
         raise GitProtocolError(e)
     raise GitProtocolError('Received invalid line from client: %s' % line)
 
 
+def _find_shallow(store, heads, depth):
+    """Find shallow commits according to a given depth.
+
+    :param store: An ObjectStore for looking up objects.
+    :param heads: Iterable of head SHAs to start walking from.
+    :param depth: The depth of ancestors to include.
+    :return: A tuple of (shallow, not_shallow), sets of SHAs that should be
+        considered shallow and unshallow according to the arguments. Note that
+        these sets may overlap if a commit is reachable along multiple paths.
+    """
+    parents = {}
+    def get_parents(sha):
+        result = parents.get(sha, None)
+        if not result:
+            result = store[sha].parents
+            parents[sha] = result
+        return result
+
+    todo = []  # stack of (sha, depth)
+    for head_sha in heads:
+        obj = store.peel_sha(head_sha)
+        if isinstance(obj, Commit):
+            todo.append((obj.id, 0))
+
+    not_shallow = set()
+    shallow = set()
+    while todo:
+        sha, cur_depth = todo.pop()
+        if cur_depth < depth:
+            not_shallow.add(sha)
+            new_depth = cur_depth + 1
+            todo.extend((p, new_depth) for p in get_parents(sha))
+        else:
+            shallow.add(sha)
+
+    return shallow, not_shallow
+
+
 class ProtocolGraphWalker(object):
     """A graph walker that knows the git protocol.
 
@@ -344,6 +386,9 @@ class ProtocolGraphWalker(object):
         self.http_req = handler.http_req
         self.advertise_refs = handler.advertise_refs
         self._wants = []
+        self.shallow = set()
+        self.client_shallow = set()
+        self.unshallow = set()
         self._cached = False
         self._cache = []
         self._cache_index = 0
@@ -357,6 +402,12 @@ class ProtocolGraphWalker(object):
         same regardless of ack type, and in fact is used to set the ack type of
         the ProtocolGraphWalker.
 
+        If the client has the 'shallow' capability, this method also reads and
+        responds to the 'shallow' and 'deepen' lines from the client. These are
+        not part of the wants per se, but they set up necessary state for
+        walking the graph. Additionally, later code depends on this method
+        consuming everything up to the first 'have' line.
+
         :param heads: a dict of refname->SHA1 to advertise
         :return: a list of SHA1s requested by the client
         """
@@ -389,11 +440,11 @@ class ProtocolGraphWalker(object):
         line, caps = extract_want_line_capabilities(want)
         self.handler.set_client_capabilities(caps)
         self.set_ack_type(ack_type(caps))
-        allowed = ('want', None)
+        allowed = ('want', 'shallow', 'deepen', None)
         command, sha = _split_proto_line(line, allowed)
 
         want_revs = []
-        while command != None:
+        while command == 'want':
             if sha not in values:
                 raise GitProtocolError(
                   'Client wants invalid object %s' % sha)
@@ -401,6 +452,9 @@ class ProtocolGraphWalker(object):
             command, sha = self.read_proto_line(allowed)
 
         self.set_wants(want_revs)
+        if command in ('shallow', 'deepen'):
+            self.unread_proto_line(command, sha)
+            self._handle_shallow_request(want_revs)
 
         if self.http_req and self.proto.eof():
             # The client may close the socket at this point, expecting a
@@ -410,6 +464,9 @@ class ProtocolGraphWalker(object):
 
         return want_revs
 
+    def unread_proto_line(self, command, value):
+        self.proto.unread_pkt_line('%s %s' % (command, value))
+
     def ack(self, have_ref):
         return self._impl.ack(have_ref)
 
@@ -432,10 +489,34 @@ class ProtocolGraphWalker(object):
 
         :param allowed: An iterable of command names that should be allowed.
         :return: A tuple of (command, value); see _split_proto_line.
-        :raise GitProtocolError: If an error occurred reading the line.
+        :raise UnexpectedCommandError: If an error occurred reading the line.
         """
         return _split_proto_line(self.proto.read_pkt_line(), allowed)
 
+    def _handle_shallow_request(self, wants):
+        while True:
+            command, val = self.read_proto_line(('deepen', 'shallow'))
+            if command == 'deepen':
+                depth = val
+                break
+            self.client_shallow.add(val)
+        self.read_proto_line((None,))  # consume client's flush-pkt
+
+        shallow, not_shallow = _find_shallow(self.store, wants, depth)
+
+        # Update self.shallow instead of reassigning it since we passed a
+        # reference to it before this method was called.
+        self.shallow.update(shallow - not_shallow)
+        new_shallow = self.shallow - self.client_shallow
+        unshallow = self.unshallow = not_shallow & self.client_shallow
+
+        for sha in sorted(new_shallow):
+            self.proto.write_pkt_line('shallow %s' % sha)
+        for sha in sorted(unshallow):
+            self.proto.write_pkt_line('unshallow %s' % sha)
+
+        self.proto.write_pkt_line(None)
+
     def send_ack(self, sha, ack_type=''):
         if ack_type:
             ack_type = ' %s' % ack_type

+ 95 - 0
dulwich/tests/compat/server_utils.py

@@ -28,6 +28,7 @@ import tempfile
 import threading
 
 from dulwich.repo import Repo
+from dulwich.objects import hex_to_sha
 from dulwich.server import (
     ReceivePackHandler,
     )
@@ -40,6 +41,32 @@ from dulwich.tests.compat.utils import (
     )
 
 
+class _StubRepo(object):
+    """A stub repo that just contains a path to tear down."""
+
+    def __init__(self, name):
+        temp_dir = tempfile.mkdtemp()
+        self.path = os.path.join(temp_dir, name)
+        os.mkdir(self.path)
+
+
+def _get_shallow(repo):
+    shallow_file = repo.get_named_file('shallow')
+    if not shallow_file:
+        return []
+    shallows = []
+    try:
+        for line in shallow_file:
+            sha = line.strip()
+            if not sha:
+                continue
+            hex_to_sha(sha)
+            shallows.append(sha)
+    finally:
+        shallow_file.close()
+    return shallows
+
+
 class ServerTests(object):
     """Base tests for testing servers.
 
@@ -71,7 +98,9 @@ class ServerTests(object):
 
     def test_push_to_dulwich_no_op(self):
         self._old_repo = import_repo('server_old.export')
+        self.addCleanup(tear_down_repo, self._old_repo)
         self._new_repo = import_repo('server_old.export')
+        self.addCleanup(tear_down_repo, self._new_repo)
         self.assertReposEqual(self._old_repo, self._new_repo)
         port = self._start_server(self._old_repo)
 
@@ -81,7 +110,9 @@ class ServerTests(object):
 
     def test_push_to_dulwich_remove_branch(self):
         self._old_repo = import_repo('server_old.export')
+        self.addCleanup(tear_down_repo, self._old_repo)
         self._new_repo = import_repo('server_old.export')
+        self.addCleanup(tear_down_repo, self._new_repo)
         self.assertReposEqual(self._old_repo, self._new_repo)
         port = self._start_server(self._old_repo)
 
@@ -104,7 +135,9 @@ class ServerTests(object):
 
     def test_fetch_from_dulwich_no_op(self):
         self._old_repo = import_repo('server_old.export')
+        self.addCleanup(tear_down_repo, self._old_repo)
         self._new_repo = import_repo('server_old.export')
+        self.addCleanup(tear_down_repo, self._new_repo)
         self.assertReposEqual(self._old_repo, self._new_repo)
         port = self._start_server(self._new_repo)
 
@@ -118,6 +151,7 @@ class ServerTests(object):
         old_repo_dir = os.path.join(tempfile.mkdtemp(), 'empty_old')
         run_git_or_fail(['init', '--quiet', '--bare', old_repo_dir])
         self._old_repo = Repo(old_repo_dir)
+        self.addCleanup(tear_down_repo, self._old_repo)
         port = self._start_server(self._old_repo)
 
         new_repo_base_dir = tempfile.mkdtemp()
@@ -138,6 +172,67 @@ class ServerTests(object):
         o = run_git_or_fail(['ls-remote', self.url(port)])
         self.assertEqual(len(o.split('\n')), 4)
 
+    def test_new_shallow_clone_from_dulwich(self):
+        self._source_repo = import_repo('server_new.export')
+        self.addCleanup(tear_down_repo, self._source_repo)
+        self._stub_repo = _StubRepo('shallow')
+        self.addCleanup(tear_down_repo, self._stub_repo)
+        port = self._start_server(self._source_repo)
+
+        # Fetch at depth 1
+        run_git_or_fail(['clone', '--mirror', '--depth=1', '--no-single-branch',
+                        self.url(port), self._stub_repo.path])
+        clone = self._stub_repo = Repo(self._stub_repo.path)
+        expected_shallow = ['94de09a530df27ac3bb613aaecdd539e0a0655e1',
+                            'da5cd81e1883c62a25bb37c4d1f8ad965b29bf8d']
+        self.assertEqual(expected_shallow, _get_shallow(clone))
+        self.assertReposNotEqual(clone, self._source_repo)
+
+    def test_fetch_same_depth_into_shallow_clone_from_dulwich(self):
+        self._source_repo = import_repo('server_new.export')
+        self.addCleanup(tear_down_repo, self._source_repo)
+        self._stub_repo = _StubRepo('shallow')
+        self.addCleanup(tear_down_repo, self._stub_repo)
+        port = self._start_server(self._source_repo)
+
+        # Fetch at depth 1
+        run_git_or_fail(['clone', '--mirror', '--depth=1', '--no-single-branch',
+                        self.url(port), self._stub_repo.path])
+        clone = self._stub_repo = Repo(self._stub_repo.path)
+
+        # Fetching at the same depth is a no-op.
+        run_git_or_fail(
+          ['fetch', '--depth=1', self.url(port)] + self.branch_args(),
+          cwd=self._stub_repo.path)
+        expected_shallow = ['94de09a530df27ac3bb613aaecdd539e0a0655e1',
+                            'da5cd81e1883c62a25bb37c4d1f8ad965b29bf8d']
+        self.assertEqual(expected_shallow, _get_shallow(clone))
+        self.assertReposNotEqual(clone, self._source_repo)
+
+    def test_fetch_full_depth_into_shallow_clone_from_dulwich(self):
+        self._source_repo = import_repo('server_new.export')
+        self.addCleanup(tear_down_repo, self._source_repo)
+        self._stub_repo = _StubRepo('shallow')
+        self.addCleanup(tear_down_repo, self._stub_repo)
+        port = self._start_server(self._source_repo)
+
+        # Fetch at depth 1
+        run_git_or_fail(['clone', '--mirror', '--depth=1', '--no-single-branch',
+                        self.url(port), self._stub_repo.path])
+        clone = self._stub_repo = Repo(self._stub_repo.path)
+
+        # Fetching at the same depth is a no-op.
+        run_git_or_fail(
+          ['fetch', '--depth=1', self.url(port)] + self.branch_args(),
+          cwd=self._stub_repo.path)
+
+        # The whole repo only has depth 3, so it should equal server_new.
+        run_git_or_fail(
+          ['fetch', '--depth=3', self.url(port)] + self.branch_args(),
+          cwd=self._stub_repo.path)
+        self.assertEqual([], _get_shallow(clone))
+        self.assertReposEqual(clone, self._source_repo)
+
 
 class ShutdownServerMixIn:
     """Mixin that allows serve_forever to be shut down.

+ 16 - 1
dulwich/tests/compat/test_web.py

@@ -134,9 +134,24 @@ class DumbWebTestCase(WebTests, CompatTestCase):
         return make_wsgi_chain(backend, dumb=True)
 
     def test_push_to_dulwich(self):
-        # Note: remove this if dumb pushing is supported
+        # Note: remove this if dulwich implements dumb web pushing.
         raise SkipTest('Dumb web pushing not supported.')
 
     def test_push_to_dulwich_remove_branch(self):
         # Note: remove this if dumb pushing is supported
         raise SkipTest('Dumb web pushing not supported.')
+
+    def test_new_shallow_clone_from_dulwich(self):
+        # Note: remove this if C git and dulwich implement dumb web shallow
+        # clones.
+        raise SkipTest('Dumb web shallow cloning not supported.')
+
+    def test_fetch_same_depth_into_shallow_clone_from_dulwich(self):
+        # Note: remove this if C git and dulwich implement dumb web shallow
+        # clones.
+        raise SkipTest('Dumb web shallow cloning not supported.')
+
+    def test_fetch_full_depth_into_shallow_clone_from_dulwich(self):
+        # Note: remove this if C git and dulwich implement dumb web shallow
+        # clones.
+        raise SkipTest('Dumb web shallow cloning not supported.')

+ 123 - 0
dulwich/tests/test_server.py

@@ -27,6 +27,13 @@ from dulwich.errors import (
     NotGitRepository,
     UnexpectedCommandError,
     )
+from dulwich.objects import (
+    Commit,
+    Tag,
+    )
+from dulwich.object_store import (
+    MemoryObjectStore,
+    )
 from dulwich.repo import (
     MemoryRepo,
     Repo,
@@ -40,6 +47,7 @@ from dulwich.server import (
     MultiAckDetailedGraphWalkerImpl,
     _split_proto_line,
     serve_command,
+    _find_shallow,
     ProtocolGraphWalker,
     ReceivePackHandler,
     SingleAckGraphWalkerImpl,
@@ -49,6 +57,7 @@ from dulwich.server import (
 from dulwich.tests import TestCase
 from dulwich.tests.utils import (
     make_commit,
+    make_object,
     )
 from dulwich.protocol import (
     ZERO_SHA,
@@ -197,6 +206,81 @@ class UploadPackHandlerTestCase(TestCase):
         self.assertEqual({}, self._handler.get_tagged(refs, repo=self._repo))
 
 
+class FindShallowTests(TestCase):
+
+    def setUp(self):
+        self._store = MemoryObjectStore()
+
+    def make_commit(self, **attrs):
+        commit = make_commit(**attrs)
+        self._store.add_object(commit)
+        return commit
+
+    def make_linear_commits(self, n, message=''):
+        commits = []
+        parents = []
+        for _ in xrange(n):
+            commits.append(self.make_commit(parents=parents, message=message))
+            parents = [commits[-1].id]
+        return commits
+
+    def assertSameElements(self, expected, actual):
+        self.assertEqual(set(expected), set(actual))
+
+    def test_linear(self):
+        c1, c2, c3 = self.make_linear_commits(3)
+
+        self.assertEqual((set([c3.id]), set([])),
+                         _find_shallow(self._store, [c3.id], 0))
+        self.assertEqual((set([c2.id]), set([c3.id])),
+                         _find_shallow(self._store, [c3.id], 1))
+        self.assertEqual((set([c1.id]), set([c2.id, c3.id])),
+                         _find_shallow(self._store, [c3.id], 2))
+        self.assertEqual((set([]), set([c1.id, c2.id, c3.id])),
+                         _find_shallow(self._store, [c3.id], 3))
+
+    def test_multiple_independent(self):
+        a = self.make_linear_commits(2, message='a')
+        b = self.make_linear_commits(2, message='b')
+        c = self.make_linear_commits(2, message='c')
+        heads = [a[1].id, b[1].id, c[1].id]
+
+        self.assertEqual((set([a[0].id, b[0].id, c[0].id]), set(heads)),
+                         _find_shallow(self._store, heads, 1))
+
+    def test_multiple_overlapping(self):
+        # Create the following commit tree:
+        # 1--2
+        #  \
+        #   3--4
+        c1, c2 = self.make_linear_commits(2)
+        c3 = self.make_commit(parents=[c1.id])
+        c4 = self.make_commit(parents=[c3.id])
+
+        # 1 is shallow along the path from 4, but not along the path from 2.
+        self.assertEqual((set([c1.id]), set([c1.id, c2.id, c3.id, c4.id])),
+                         _find_shallow(self._store, [c2.id, c4.id], 2))
+
+    def test_merge(self):
+        c1 = self.make_commit()
+        c2 = self.make_commit()
+        c3 = self.make_commit(parents=[c1.id, c2.id])
+
+        self.assertEqual((set([c1.id, c2.id]), set([c3.id])),
+                         _find_shallow(self._store, [c3.id], 1))
+
+    def test_tag(self):
+        c1, c2 = self.make_linear_commits(2)
+        tag = make_object(Tag, name='tag', message='',
+                          tagger='Tagger <test@example.com>',
+                          tag_time=12345, tag_timezone=0,
+                          object=(Commit, c2.id))
+        self._store.add_object(tag)
+
+        self.assertEqual((set([c1.id]), set([c2.id])),
+                         _find_shallow(self._store, [tag.id], 1))
+
+
 class TestUploadPackHandler(UploadPackHandler):
     @classmethod
     def required_capabilities(self):
@@ -354,6 +438,45 @@ class ProtocolGraphWalkerTestCase(TestCase):
 
     # TODO: test commit time cutoff
 
+    def _handle_shallow_request(self, lines, heads):
+        self._walker.proto.set_output(lines)
+        self._walker._handle_shallow_request(heads)
+
+    def assertReceived(self, expected):
+        self.assertEquals(
+          expected, list(iter(self._walker.proto.get_received_line, None)))
+
+    def test_handle_shallow_request_no_client_shallows(self):
+        self._handle_shallow_request(['deepen 1\n'], [FOUR, FIVE])
+        self.assertEquals(set([TWO, THREE]), self._walker.shallow)
+        self.assertReceived([
+          'shallow %s' % TWO,
+          'shallow %s' % THREE,
+          ])
+
+    def test_handle_shallow_request_no_new_shallows(self):
+        lines = [
+          'shallow %s\n' % TWO,
+          'shallow %s\n' % THREE,
+          'deepen 1\n',
+          ]
+        self._handle_shallow_request(lines, [FOUR, FIVE])
+        self.assertEquals(set([TWO, THREE]), self._walker.shallow)
+        self.assertReceived([])
+
+    def test_handle_shallow_request_unshallows(self):
+        lines = [
+          'shallow %s\n' % TWO,
+          'deepen 2\n',
+          ]
+        self._handle_shallow_request(lines, [FOUR, FIVE])
+        self.assertEquals(set([ONE]), self._walker.shallow)
+        self.assertReceived([
+          'shallow %s' % ONE,
+          'unshallow %s' % TWO,
+          # THREE is unshallow but was is not shallow in the client
+          ])
+
 
 class TestProtocolGraphWalker(object):