Explore the code

Close Repo, ObjectStore, and Pack objects.

Gary van der Merwe committed 10 years ago
parent commit
ff831aa895

+ 2 - 0
NEWS

@@ -4,6 +4,8 @@
 
   * Extended Python3 support to most of the codebase.
     (Gary van der Merwe, Jelmer Vernooij)
+  * The `Repo` object has a new `close` method that can be called to close any
+    open resources. (Gary van der Merwe)
 
 0.10.1  2015-03-25
 

+ 27 - 26
dulwich/client.py

@@ -38,6 +38,7 @@ Known capabilities that are not supported:
 
 __docformat__ = 'restructuredText'
 
+from contextlib import closing
 from io import BytesIO, BufferedReader
 import dulwich
 import select
@@ -702,26 +703,26 @@ class LocalGitClient(GitClient):
         """
         from dulwich.repo import Repo
 
-        target = Repo(path)
-        old_refs = target.get_refs()
-        new_refs = determine_wants(old_refs)
+        with closing(Repo(path)) as target:
+            old_refs = target.get_refs()
+            new_refs = determine_wants(old_refs)
 
-        have = [sha1 for sha1 in old_refs.values() if sha1 != ZERO_SHA]
-        want = []
-        all_refs = set(new_refs.keys()).union(set(old_refs.keys()))
-        for refname in all_refs:
-            old_sha1 = old_refs.get(refname, ZERO_SHA)
-            new_sha1 = new_refs.get(refname, ZERO_SHA)
-            if new_sha1 not in have and new_sha1 != ZERO_SHA:
-                want.append(new_sha1)
+            have = [sha1 for sha1 in old_refs.values() if sha1 != ZERO_SHA]
+            want = []
+            all_refs = set(new_refs.keys()).union(set(old_refs.keys()))
+            for refname in all_refs:
+                old_sha1 = old_refs.get(refname, ZERO_SHA)
+                new_sha1 = new_refs.get(refname, ZERO_SHA)
+                if new_sha1 not in have and new_sha1 != ZERO_SHA:
+                    want.append(new_sha1)
 
-        if not want and old_refs == new_refs:
-            return new_refs
+            if not want and old_refs == new_refs:
+                return new_refs
 
-        target.object_store.add_objects(generate_pack_contents(have, want))
+            target.object_store.add_objects(generate_pack_contents(have, want))
 
-        for name, sha in new_refs.items():
-            target.refs[name] = sha
+            for name, sha in new_refs.items():
+                target.refs[name] = sha
 
         return new_refs
 
@@ -736,9 +737,9 @@ class LocalGitClient(GitClient):
         :return: remote refs as dictionary
         """
         from dulwich.repo import Repo
-        r = Repo(path)
-        return r.fetch(target, determine_wants=determine_wants,
-                       progress=progress)
+        with closing(Repo(path)) as r:
+            return r.fetch(target, determine_wants=determine_wants,
+                           progress=progress)
 
     def fetch_pack(self, path, determine_wants, graph_walker, pack_data,
                    progress=None):
@@ -750,14 +751,14 @@ class LocalGitClient(GitClient):
         :param progress: Callback for progress reports (strings)
         """
         from dulwich.repo import Repo
-        r = Repo(path)
-        objects_iter = r.fetch_objects(determine_wants, graph_walker, progress)
+        with closing(Repo(path)) as r:
+            objects_iter = r.fetch_objects(determine_wants, graph_walker, progress)
 
-        # Did the process short-circuit (e.g. in a stateless RPC call)? Note
-        # that the client still expects a 0-object pack in most cases.
-        if objects_iter is None:
-            return
-        write_pack_objects(ProtocolFile(None, pack_data), objects_iter)
+            # Did the process short-circuit (e.g. in a stateless RPC call)? Note
+            # that the client still expects a 0-object pack in most cases.
+            if objects_iter is None:
+                return
+            write_pack_objects(ProtocolFile(None, pack_data), objects_iter)
 
 
 # What Git client to use for local access

+ 0 - 6
dulwich/object_store.py

@@ -248,12 +248,6 @@ class BaseObjectStore(object):
         """Close any files opened by this object store."""
         # Default implementation is a NO-OP
 
-    def __enter__(self):
-        return self
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        self.close()
-
 
 class PackBasedObjectStore(BaseObjectStore):
 

+ 199 - 177
dulwich/porcelain.py

@@ -48,6 +48,10 @@ Differences in behaviour are considered bugs.
 __docformat__ = 'restructuredText'
 
 from collections import namedtuple
+from contextlib import (
+    closing,
+    contextmanager,
+)
 import os
 import sys
 import time
@@ -86,6 +90,22 @@ def open_repo(path_or_repo):
     return Repo(path_or_repo)
 
 
+@contextmanager
+def _noop_context_manager(obj):
+    """Context manager that has the same api as closing but does nothing."""
+    yield obj
+
+
+def open_repo_closing(path_or_repo):
+    """Open an argument that can be a repository or a path for a repository.
+    returns a context manager that will close the repo on exit if the argument
+    is a path, else does nothing if the argument is a repo.
+    """
+    if isinstance(path_or_repo, BaseRepo):
+        return _noop_context_manager(path_or_repo)
+    return closing(Repo(path_or_repo))
+
+
 def archive(location, committish=None, outstream=sys.stdout,
             errstream=sys.stderr):
     """Create an archive.
@@ -110,8 +130,8 @@ def update_server_info(repo="."):
 
     :param repo: path to the repository
     """
-    r = open_repo(repo)
-    server_update_server_info(r)
+    with open_repo_closing(repo) as r:
+        server_update_server_info(r)
 
 
 def symbolic_ref(repo, ref_name, force=False):
@@ -121,11 +141,11 @@ def symbolic_ref(repo, ref_name, force=False):
     :param ref_name: short name of the new ref
     :param force: force settings without checking if it exists in refs/heads
     """
-    repo_obj = open_repo(repo)
-    ref_path = b'refs/heads/' + ref_name
-    if not force and ref_path not in repo_obj.refs.keys():
-        raise ValueError('fatal: ref `%s` is not a ref' % ref_name)
-    repo_obj.refs.set_symbolic_ref(b'HEAD', ref_path)
+    with open_repo_closing(repo) as repo_obj:
+        ref_path = b'refs/heads/' + ref_name
+        if not force and ref_path not in repo_obj.refs.keys():
+            raise ValueError('fatal: ref `%s` is not a ref' % ref_name)
+        repo_obj.refs.set_symbolic_ref(b'HEAD', ref_path)
 
 
 def commit(repo=".", message=None, author=None, committer=None):
@@ -139,9 +159,9 @@ def commit(repo=".", message=None, author=None, committer=None):
     """
     # FIXME: Support --all argument
     # FIXME: Support --signoff argument
-    r = open_repo(repo)
-    return r.do_commit(message=message, author=author,
-        committer=committer)
+    with open_repo_closing(repo) as r:
+        return r.do_commit(message=message, author=author,
+            committer=committer)
 
 
 def commit_tree(repo, tree, message=None, author=None, committer=None):
@@ -152,9 +172,9 @@ def commit_tree(repo, tree, message=None, author=None, committer=None):
     :param author: Optional author name and email
     :param committer: Optional committer name and email
     """
-    r = open_repo(repo)
-    return r.do_commit(message=message, tree=tree, committer=committer,
-            author=author)
+    with open_repo_closing(repo) as r:
+        return r.do_commit(message=message, tree=tree, committer=committer,
+                author=author)
 
 
 def init(path=".", bare=False):
@@ -200,17 +220,22 @@ def clone(source, target=None, bare=False, checkout=None, errstream=sys.stdout,
 
     if not os.path.exists(target):
         os.mkdir(target)
+
     if bare:
         r = Repo.init_bare(target)
     else:
         r = Repo.init(target)
-    remote_refs = client.fetch(host_path, r,
-        determine_wants=r.object_store.determine_wants_all,
-        progress=errstream.write)
-    r[b"HEAD"] = remote_refs[b"HEAD"]
-    if checkout:
-        errstream.write(b'Checking out HEAD')
-        r.reset_index()
+    try:
+        remote_refs = client.fetch(host_path, r,
+            determine_wants=r.object_store.determine_wants_all,
+            progress=errstream.write)
+        r[b"HEAD"] = remote_refs[b"HEAD"]
+        if checkout:
+            errstream.write(b'Checking out HEAD')
+            r.reset_index()
+    except Exception:
+        r.close()
+        raise
 
     return r
 
@@ -222,17 +247,17 @@ def add(repo=".", paths=None):
     :param paths: Paths to add.  No value passed stages all modified files.
     """
     # FIXME: Support patterns, directories.
-    r = open_repo(repo)
-    if not paths:
-        # If nothing is specified, add all non-ignored files.
-        paths = []
-        for dirpath, dirnames, filenames in os.walk(r.path):
-            # Skip .git and below.
-            if '.git' in dirnames:
-                dirnames.remove('.git')
-            for filename in filenames:
-                paths.append(os.path.join(dirpath[len(r.path)+1:], filename))
-    r.stage(paths)
+    with open_repo_closing(repo) as r:
+        if not paths:
+            # If nothing is specified, add all non-ignored files.
+            paths = []
+            for dirpath, dirnames, filenames in os.walk(r.path):
+                # Skip .git and below.
+                if '.git' in dirnames:
+                    dirnames.remove('.git')
+                for filename in filenames:
+                    paths.append(os.path.join(dirpath[len(r.path)+1:], filename))
+        r.stage(paths)
 
 
 def rm(repo=".", paths=None):
@@ -241,11 +266,11 @@ def rm(repo=".", paths=None):
     :param repo: Repository for the files
     :param paths: Paths to remove
     """
-    r = open_repo(repo)
-    index = r.open_index()
-    for p in paths:
-        del index[p.encode(sys.getfilesystemencoding())]
-    index.write()
+    with open_repo_closing(repo) as r:
+        index = r.open_index()
+        for p in paths:
+            del index[p.encode(sys.getfilesystemencoding())]
+        index.write()
 
 
 def commit_decode(commit, contents):
@@ -344,10 +369,10 @@ def log(repo=".", outstream=sys.stdout, max_entries=None):
     :param outstream: Stream to write log output to
     :param max_entries: Optional maximum number of entries to display
     """
-    r = open_repo(repo)
-    walker = r.get_walker(max_entries=max_entries)
-    for entry in walker:
-        print_commit(entry.commit, outstream)
+    with open_repo_closing(repo) as r:
+        walker = r.get_walker(max_entries=max_entries)
+        for entry in walker:
+            print_commit(entry.commit, outstream)
 
 
 def show(repo=".", objects=None, outstream=sys.stdout):
@@ -361,9 +386,9 @@ def show(repo=".", objects=None, outstream=sys.stdout):
         objects = ["HEAD"]
     if not isinstance(objects, list):
         objects = [objects]
-    r = open_repo(repo)
-    for objectish in objects:
-        show_object(r, parse_object(r, objectish), outstream)
+    with open_repo_closing(repo) as r:
+        for objectish in objects:
+            show_object(r, parse_object(r, objectish), outstream)
 
 
 def diff_tree(repo, old_tree, new_tree, outstream=sys.stdout):
@@ -374,8 +399,8 @@ def diff_tree(repo, old_tree, new_tree, outstream=sys.stdout):
     :param new_tree: Id of new tree
     :param outstream: Stream to write to
     """
-    r = open_repo(repo)
-    write_tree_diff(outstream, r.object_store, old_tree, new_tree)
+    with open_repo_closing(repo) as r:
+        write_tree_diff(outstream, r.object_store, old_tree, new_tree)
 
 
 def rev_list(repo, commits, outstream=sys.stdout):
@@ -385,9 +410,9 @@ def rev_list(repo, commits, outstream=sys.stdout):
     :param commits: Commits over which to iterate
     :param outstream: Stream to write to
     """
-    r = open_repo(repo)
-    for entry in r.get_walker(include=[r[c].id for c in commits]):
-        outstream.write(entry.commit.id + b"\n")
+    with open_repo_closing(repo) as r:
+        for entry in r.get_walker(include=[r[c].id for c in commits]):
+            outstream.write(entry.commit.id + b"\n")
 
 
 def tag(*args, **kwargs):
@@ -410,34 +435,34 @@ def tag_create(repo, tag, author=None, message=None, annotated=False,
     :param tag_timezone: Optional timezone for annotated tag
     """
 
-    r = open_repo(repo)
-    object = parse_object(r, objectish)
-
-    if annotated:
-        # Create the tag object
-        tag_obj = Tag()
-        if author is None:
-            # TODO(jelmer): Don't use repo private method.
-            author = r._get_user_identity()
-        tag_obj.tagger = author
-        tag_obj.message = message
-        tag_obj.name = tag
-        tag_obj.object = (type(object), object.id)
-        tag_obj.tag_time = tag_time
-        if tag_time is None:
-            tag_time = int(time.time())
-        if tag_timezone is None:
-            # TODO(jelmer) Use current user timezone rather than UTC
-            tag_timezone = 0
-        elif isinstance(tag_timezone, str):
-            tag_timezone = parse_timezone(tag_timezone)
-        tag_obj.tag_timezone = tag_timezone
-        r.object_store.add_object(tag_obj)
-        tag_id = tag_obj.id
-    else:
-        tag_id = object.id
+    with open_repo_closing(repo) as r:
+        object = parse_object(r, objectish)
+
+        if annotated:
+            # Create the tag object
+            tag_obj = Tag()
+            if author is None:
+                # TODO(jelmer): Don't use repo private method.
+                author = r._get_user_identity()
+            tag_obj.tagger = author
+            tag_obj.message = message
+            tag_obj.name = tag
+            tag_obj.object = (type(object), object.id)
+            tag_obj.tag_time = tag_time
+            if tag_time is None:
+                tag_time = int(time.time())
+            if tag_timezone is None:
+                # TODO(jelmer) Use current user timezone rather than UTC
+                tag_timezone = 0
+            elif isinstance(tag_timezone, str):
+                tag_timezone = parse_timezone(tag_timezone)
+            tag_obj.tag_timezone = tag_timezone
+            r.object_store.add_object(tag_obj)
+            tag_id = tag_obj.id
+        else:
+            tag_id = object.id
 
-    r.refs[b'refs/tags/' + tag] = tag_id
+        r.refs[b'refs/tags/' + tag] = tag_id
 
 
 def list_tags(*args, **kwargs):
@@ -452,10 +477,10 @@ def tag_list(repo, outstream=sys.stdout):
     :param repo: Path to repository
     :param outstream: Stream to write tags to
     """
-    r = open_repo(repo)
-    tags = list(r.refs.as_dict(b"refs/tags"))
-    tags.sort()
-    return tags
+    with open_repo_closing(repo) as r:
+        tags = list(r.refs.as_dict(b"refs/tags"))
+        tags.sort()
+        return tags
 
 
 def tag_delete(repo, name):
@@ -464,15 +489,15 @@ def tag_delete(repo, name):
     :param repo: Path to repository
     :param name: Name of tag to remove
     """
-    r = open_repo(repo)
-    if isinstance(name, bytes):
-        names = [name]
-    elif isinstance(name, list):
-        names = name
-    else:
-        raise TypeError("Unexpected tag name type %r" % name)
-    for name in names:
-        del r.refs[b"refs/tags/" + name]
+    with open_repo_closing(repo) as r:
+        if isinstance(name, bytes):
+            names = [name]
+        elif isinstance(name, list):
+            names = name
+        else:
+            raise TypeError("Unexpected tag name type %r" % name)
+        for name in names:
+            del r.refs[b"refs/tags/" + name]
 
 
 def reset(repo, mode, committish="HEAD"):
@@ -485,10 +510,9 @@ def reset(repo, mode, committish="HEAD"):
     if mode != "hard":
         raise ValueError("hard is the only mode currently supported")
 
-    r = open_repo(repo)
-
-    tree = r[committish].tree
-    r.reset_index()
+    with open_repo_closing(repo) as r:
+        tree = r[committish].tree
+        r.reset_index()
 
 
 def push(repo, remote_location, refs_path,
@@ -503,28 +527,28 @@ def push(repo, remote_location, refs_path,
     """
 
     # Open the repo
-    r = open_repo(repo)
+    with open_repo_closing(repo) as r:
 
-    # Get the client and path
-    client, path = get_transport_and_path(remote_location)
+        # Get the client and path
+        client, path = get_transport_and_path(remote_location)
 
-    def update_refs(refs):
-        new_refs = r.get_refs()
-        refs[refs_path] = new_refs[b'HEAD']
-        del new_refs[b'HEAD']
-        return refs
+        def update_refs(refs):
+            new_refs = r.get_refs()
+            refs[refs_path] = new_refs[b'HEAD']
+            del new_refs[b'HEAD']
+            return refs
 
-    err_encoding = getattr(errstream, 'encoding', 'utf-8')
-    if not isinstance(remote_location, bytes):
-        remote_location_bytes = remote_location.encode(err_encoding)
-    else:
-        remote_location_bytes = remote_location
-    try:
-        client.send_pack(path, update_refs,
-            r.object_store.generate_pack_contents, progress=errstream.write)
-        errstream.write(b"Push to " + remote_location_bytes + b" successful.\n")
-    except (UpdateRefsError, SendPackError) as e:
-        errstream.write(b"Push to " + remote_location_bytes + b" failed -> " + e.message.encode(err_encoding) + b"\n")
+        err_encoding = getattr(errstream, 'encoding', 'utf-8')
+        if not isinstance(remote_location, bytes):
+            remote_location_bytes = remote_location.encode(err_encoding)
+        else:
+            remote_location_bytes = remote_location
+        try:
+            client.send_pack(path, update_refs,
+                r.object_store.generate_pack_contents, progress=errstream.write)
+            errstream.write(b"Push to " + remote_location_bytes + b" successful.\n")
+        except (UpdateRefsError, SendPackError) as e:
+            errstream.write(b"Push to " + remote_location_bytes + b" failed -> " + e.message.encode(err_encoding) + b"\n")
 
 
 def pull(repo, remote_location, refs_path,
@@ -539,15 +563,14 @@ def pull(repo, remote_location, refs_path,
     """
 
     # Open the repo
-    r = open_repo(repo)
-
-    client, path = get_transport_and_path(remote_location)
-    remote_refs = client.fetch(path, r, progress=errstream.write)
-    r[b'HEAD'] = remote_refs[refs_path]
+    with open_repo_closing(repo) as r:
+        client, path = get_transport_and_path(remote_location)
+        remote_refs = client.fetch(path, r, progress=errstream.write)
+        r[b'HEAD'] = remote_refs[refs_path]
 
-    # Perform 'git checkout .' - syncs staged changes
-    tree = r[b"HEAD"].tree
-    r.reset_index()
+        # Perform 'git checkout .' - syncs staged changes
+        tree = r[b"HEAD"].tree
+        r.reset_index()
 
 
 def status(repo="."):
@@ -559,15 +582,14 @@ def status(repo="."):
         unstaged -  list of unstaged paths (diff index/working-tree)
         untracked - list of untracked, un-ignored & non-.git paths
     """
-    r = open_repo(repo)
-
-    # 1. Get status of staged
-    tracked_changes = get_tree_changes(r)
-    # 2. Get status of unstaged
-    unstaged_changes = list(get_unstaged_changes(r.open_index(), r.path))
-    # TODO - Status of untracked - add untracked changes, need gitignore.
-    untracked_changes = []
-    return GitStatus(tracked_changes, unstaged_changes, untracked_changes)
+    with open_repo_closing(repo) as r:
+        # 1. Get status of staged
+        tracked_changes = get_tree_changes(r)
+        # 2. Get status of unstaged
+        unstaged_changes = list(get_unstaged_changes(r.open_index(), r.path))
+        # TODO - Status of untracked - add untracked changes, need gitignore.
+        untracked_changes = []
+        return GitStatus(tracked_changes, unstaged_changes, untracked_changes)
 
 
 def get_tree_changes(repo):
@@ -576,27 +598,27 @@ def get_tree_changes(repo):
     :param repo: repo path or object
     :return: dict with lists for each type of change
     """
-    r = open_repo(repo)
-    index = r.open_index()
-
-    # Compares the Index to the HEAD & determines changes
-    # Iterate through the changes and report add/delete/modify
-    # TODO: call out to dulwich.diff_tree somehow.
-    tracked_changes = {
-        'add': [],
-        'delete': [],
-        'modify': [],
-    }
-    for change in index.changes_from_tree(r.object_store, r[b'HEAD'].tree):
-        if not change[0][0]:
-            tracked_changes['add'].append(change[0][1])
-        elif not change[0][1]:
-            tracked_changes['delete'].append(change[0][0])
-        elif change[0][0] == change[0][1]:
-            tracked_changes['modify'].append(change[0][0])
-        else:
-            raise AssertionError('git mv ops not yet supported')
-    return tracked_changes
+    with open_repo_closing(repo) as r:
+        index = r.open_index()
+
+        # Compares the Index to the HEAD & determines changes
+        # Iterate through the changes and report add/delete/modify
+        # TODO: call out to dulwich.diff_tree somehow.
+        tracked_changes = {
+            'add': [],
+            'delete': [],
+            'modify': [],
+        }
+        for change in index.changes_from_tree(r.object_store, r[b'HEAD'].tree):
+            if not change[0][0]:
+                tracked_changes['add'].append(change[0][1])
+            elif not change[0][1]:
+                tracked_changes['delete'].append(change[0][0])
+            elif change[0][0] == change[0][1]:
+                tracked_changes['modify'].append(change[0][0])
+            else:
+                raise AssertionError('git mv ops not yet supported')
+        return tracked_changes
 
 
 def daemon(path=".", address=None, port=None):
@@ -683,15 +705,15 @@ def branch_delete(repo, name):
     :param repo: Path to the repository
     :param name: Name of the branch
     """
-    r = open_repo(repo)
-    if isinstance(name, bytes):
-        names = [name]
-    elif isinstance(name, list):
-        names = name
-    else:
-        raise TypeError("Unexpected branch name type %r" % name)
-    for name in names:
-        del r.refs[b"refs/heads/" + name]
+    with open_repo_closing(repo) as r:
+        if isinstance(name, bytes):
+            names = [name]
+        elif isinstance(name, list):
+            names = name
+        else:
+            raise TypeError("Unexpected branch name type %r" % name)
+        for name in names:
+            del r.refs[b"refs/heads/" + name]
 
 
 def branch_create(repo, name, objectish=None, force=False):
@@ -702,20 +724,20 @@ def branch_create(repo, name, objectish=None, force=False):
     :param objectish: Target object to point new branch at (defaults to HEAD)
     :param force: Force creation of branch, even if it already exists
     """
-    r = open_repo(repo)
-    if isinstance(name, bytes):
-        names = [name]
-    elif isinstance(name, list):
-        names = name
-    else:
-        raise TypeError("Unexpected branch name type %r" % name)
-    if objectish is None:
-        objectish = "HEAD"
-    object = parse_object(r, objectish)
-    refname = b"refs/heads/" + name
-    if refname in r.refs and not force:
-        raise KeyError("Branch with name %s already exists." % name)
-    r.refs[refname] = object.id
+    with open_repo_closing(repo) as r:
+        if isinstance(name, bytes):
+            names = [name]
+        elif isinstance(name, list):
+            names = name
+        else:
+            raise TypeError("Unexpected branch name type %r" % name)
+        if objectish is None:
+            objectish = "HEAD"
+        object = parse_object(r, objectish)
+        refname = b"refs/heads/" + name
+        if refname in r.refs and not force:
+            raise KeyError("Branch with name %s already exists." % name)
+        r.refs[refname] = object.id
 
 
 def branch_list(repo):
@@ -723,8 +745,8 @@ def branch_list(repo):
 
     :param repo: Path to the repository
     """
-    r = open_repo(repo)
-    return r.refs.keys(base=b"refs/heads/")
+    with open_repo_closing(repo) as r:
+        return r.refs.keys(base=b"refs/heads/")
 
 
 def fetch(repo, remote_location, outstream=sys.stdout, errstream=sys.stderr):
@@ -736,7 +758,7 @@ def fetch(repo, remote_location, outstream=sys.stdout, errstream=sys.stderr):
     :param errstream: Error stream (defaults to stderr)
     :return: Dictionary with refs on the remote
     """
-    r = open_repo(repo)
-    client, path = get_transport_and_path(remote_location)
-    remote_refs = client.fetch(path, r, progress=errstream.write)
+    with open_repo_closing(repo) as r:
+        client, path = get_transport_and_path(remote_location)
+        remote_refs = client.fetch(path, r, progress=errstream.write)
     return remote_refs

+ 1 - 5
dulwich/repo.py

@@ -917,13 +917,9 @@ class Repo(BaseRepo):
     create = init_bare
 
     def close(self):
+        """Close any files opened by this repository."""
         self.object_store.close()
 
-    def __enter__(self):
-        return self
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        self.close()
 
 class MemoryRepo(BaseRepo):
     """Repo that stores refs, objects, and named files in memory.

+ 79 - 73
dulwich/tests/compat/test_client.py

@@ -19,8 +19,9 @@
 
 """Compatibilty tests between the Dulwich client and the cgit server."""
 
-from io import BytesIO
+from contextlib import closing
 import copy
+from io import BytesIO
 import os
 import select
 import signal
@@ -86,9 +87,11 @@ class DulwichClientTestBase(object):
         rmtree_ro(self.gitroot)
 
     def assertDestEqualsSrc(self):
-        src = repo.Repo(os.path.join(self.gitroot, 'server_new.export'))
-        dest = repo.Repo(os.path.join(self.gitroot, 'dest'))
-        self.assertReposEqual(src, dest)
+        repo_dir = os.path.join(self.gitroot, 'server_new.export')
+        dest_repo_dir = os.path.join(self.gitroot, 'dest')
+        with closing(repo.Repo(repo_dir)) as src:
+            with closing(repo.Repo(dest_repo_dir)) as dest:
+                self.assertReposEqual(src, dest)
 
     def _client(self):
         raise NotImplementedError()
@@ -99,11 +102,11 @@ class DulwichClientTestBase(object):
     def _do_send_pack(self):
         c = self._client()
         srcpath = os.path.join(self.gitroot, 'server_new.export')
-        src = repo.Repo(srcpath)
-        sendrefs = dict(src.get_refs())
-        del sendrefs[b'HEAD']
-        c.send_pack(self._build_path('/dest'), lambda _: sendrefs,
-                    src.object_store.generate_pack_contents)
+        with closing(repo.Repo(srcpath)) as src:
+            sendrefs = dict(src.get_refs())
+            del sendrefs[b'HEAD']
+            c.send_pack(self._build_path('/dest'), lambda _: sendrefs,
+                        src.object_store.generate_pack_contents)
 
     def test_send_pack(self):
         self._do_send_pack()
@@ -119,12 +122,12 @@ class DulwichClientTestBase(object):
         c = self._client()
         c._send_capabilities.remove(b'report-status')
         srcpath = os.path.join(self.gitroot, 'server_new.export')
-        src = repo.Repo(srcpath)
-        sendrefs = dict(src.get_refs())
-        del sendrefs[b'HEAD']
-        c.send_pack(self._build_path('/dest'), lambda _: sendrefs,
-                    src.object_store.generate_pack_contents)
-        self.assertDestEqualsSrc()
+        with closing(repo.Repo(srcpath)) as src:
+            sendrefs = dict(src.get_refs())
+            del sendrefs[b'HEAD']
+            c.send_pack(self._build_path('/dest'), lambda _: sendrefs,
+                        src.object_store.generate_pack_contents)
+            self.assertDestEqualsSrc()
 
     def make_dummy_commit(self, dest):
         b = objects.Blob.from_string(b'hi')
@@ -147,9 +150,7 @@ class DulwichClientTestBase(object):
         commit_id = self.make_dummy_commit(dest)
         return dest, commit_id
 
-    def compute_send(self):
-        srcpath = os.path.join(self.gitroot, 'server_new.export')
-        src = repo.Repo(srcpath)
+    def compute_send(self, src):
         sendrefs = dict(src.get_refs())
         del sendrefs[b'HEAD']
         return sendrefs, src.object_store.generate_pack_contents
@@ -157,35 +158,39 @@ class DulwichClientTestBase(object):
     def test_send_pack_one_error(self):
         dest, dummy_commit = self.disable_ff_and_make_dummy_commit()
         dest.refs[b'refs/heads/master'] = dummy_commit
-        sendrefs, gen_pack = self.compute_send()
-        c = self._client()
-        try:
-            c.send_pack(self._build_path('/dest'), lambda _: sendrefs, gen_pack)
-        except errors.UpdateRefsError as e:
-            self.assertEqual('refs/heads/master failed to update',
-                             e.args[0])
-            self.assertEqual({b'refs/heads/branch': b'ok',
-                              b'refs/heads/master': b'non-fast-forward'},
-                             e.ref_status)
+        repo_dir = os.path.join(self.gitroot, 'server_new.export')
+        with closing(repo.Repo(repo_dir)) as src:
+            sendrefs, gen_pack = self.compute_send(src)
+            c = self._client()
+            try:
+                c.send_pack(self._build_path('/dest'), lambda _: sendrefs, gen_pack)
+            except errors.UpdateRefsError as e:
+                self.assertEqual('refs/heads/master failed to update',
+                                 e.args[0])
+                self.assertEqual({b'refs/heads/branch': b'ok',
+                                  b'refs/heads/master': b'non-fast-forward'},
+                                 e.ref_status)
 
     def test_send_pack_multiple_errors(self):
         dest, dummy = self.disable_ff_and_make_dummy_commit()
         # set up for two non-ff errors
         branch, master = b'refs/heads/branch', b'refs/heads/master'
         dest.refs[branch] = dest.refs[master] = dummy
-        sendrefs, gen_pack = self.compute_send()
-        c = self._client()
-        try:
-            c.send_pack(self._build_path('/dest'), lambda _: sendrefs, gen_pack)
-        except errors.UpdateRefsError as e:
-            self.assertIn(str(e),
-                          ['{0}, {1} failed to update'.format(
-                              branch.decode('ascii'), master.decode('ascii')),
-                           '{1}, {0} failed to update'.format(
-                               branch.decode('ascii'), master.decode('ascii'))])
-            self.assertEqual({branch: b'non-fast-forward',
-                              master: b'non-fast-forward'},
-                             e.ref_status)
+        repo_dir = os.path.join(self.gitroot, 'server_new.export')
+        with closing(repo.Repo(repo_dir)) as src:
+            sendrefs, gen_pack = self.compute_send(src)
+            c = self._client()
+            try:
+                c.send_pack(self._build_path('/dest'), lambda _: sendrefs, gen_pack)
+            except errors.UpdateRefsError as e:
+                self.assertIn(str(e),
+                              ['{0}, {1} failed to update'.format(
+                                  branch.decode('ascii'), master.decode('ascii')),
+                               '{1}, {0} failed to update'.format(
+                                   branch.decode('ascii'), master.decode('ascii'))])
+                self.assertEqual({branch: b'non-fast-forward',
+                                  master: b'non-fast-forward'},
+                                 e.ref_status)
 
     def test_archive(self):
         c = self._client()
@@ -197,55 +202,56 @@ class DulwichClientTestBase(object):
 
     def test_fetch_pack(self):
         c = self._client()
-        dest = repo.Repo(os.path.join(self.gitroot, 'dest'))
-        refs = c.fetch(self._build_path('/server_new.export'), dest)
-        for r in refs.items():
-            dest.refs.set_if_equals(r[0], None, r[1])
-        self.assertDestEqualsSrc()
+        with closing(repo.Repo(os.path.join(self.gitroot, 'dest'))) as dest:
+            refs = c.fetch(self._build_path('/server_new.export'), dest)
+            for r in refs.items():
+                dest.refs.set_if_equals(r[0], None, r[1])
+            self.assertDestEqualsSrc()
 
     def test_incremental_fetch_pack(self):
         self.test_fetch_pack()
         dest, dummy = self.disable_ff_and_make_dummy_commit()
         dest.refs[b'refs/heads/master'] = dummy
         c = self._client()
-        dest = repo.Repo(os.path.join(self.gitroot, 'server_new.export'))
-        refs = c.fetch(self._build_path('/dest'), dest)
-        for r in refs.items():
-            dest.refs.set_if_equals(r[0], None, r[1])
-        self.assertDestEqualsSrc()
+        repo_dir = os.path.join(self.gitroot, 'server_new.export')
+        with closing(repo.Repo(repo_dir)) as dest:
+            refs = c.fetch(self._build_path('/dest'), dest)
+            for r in refs.items():
+                dest.refs.set_if_equals(r[0], None, r[1])
+            self.assertDestEqualsSrc()
 
     def test_fetch_pack_no_side_band_64k(self):
         c = self._client()
         c._fetch_capabilities.remove(b'side-band-64k')
-        dest = repo.Repo(os.path.join(self.gitroot, 'dest'))
-        refs = c.fetch(self._build_path('/server_new.export'), dest)
-        for r in refs.items():
-            dest.refs.set_if_equals(r[0], None, r[1])
-        self.assertDestEqualsSrc()
+        with closing(repo.Repo(os.path.join(self.gitroot, 'dest'))) as dest:
+            refs = c.fetch(self._build_path('/server_new.export'), dest)
+            for r in refs.items():
+                dest.refs.set_if_equals(r[0], None, r[1])
+            self.assertDestEqualsSrc()
 
     def test_fetch_pack_zero_sha(self):
         # zero sha1s are already present on the client, and should
         # be ignored
         c = self._client()
-        dest = repo.Repo(os.path.join(self.gitroot, 'dest'))
-        refs = c.fetch(self._build_path('/server_new.export'), dest,
-            lambda refs: [protocol.ZERO_SHA])
-        for r in refs.items():
-            dest.refs.set_if_equals(r[0], None, r[1])
+        with closing(repo.Repo(os.path.join(self.gitroot, 'dest'))) as dest:
+            refs = c.fetch(self._build_path('/server_new.export'), dest,
+                lambda refs: [protocol.ZERO_SHA])
+            for r in refs.items():
+                dest.refs.set_if_equals(r[0], None, r[1])
 
     def test_send_remove_branch(self):
-        dest = repo.Repo(os.path.join(self.gitroot, 'dest'))
-        dummy_commit = self.make_dummy_commit(dest)
-        dest.refs[b'refs/heads/master'] = dummy_commit
-        dest.refs[b'refs/heads/abranch'] = dummy_commit
-        sendrefs = dict(dest.refs)
-        sendrefs[b'refs/heads/abranch'] = b"00" * 20
-        del sendrefs[b'HEAD']
-        gen_pack = lambda have, want: []
-        c = self._client()
-        self.assertEqual(dest.refs[b"refs/heads/abranch"], dummy_commit)
-        c.send_pack(self._build_path('/dest'), lambda _: sendrefs, gen_pack)
-        self.assertFalse(b"refs/heads/abranch" in dest.refs)
+        with closing(repo.Repo(os.path.join(self.gitroot, 'dest'))) as dest:
+            dummy_commit = self.make_dummy_commit(dest)
+            dest.refs[b'refs/heads/master'] = dummy_commit
+            dest.refs[b'refs/heads/abranch'] = dummy_commit
+            sendrefs = dict(dest.refs)
+            sendrefs[b'refs/heads/abranch'] = b"00" * 20
+            del sendrefs[b'HEAD']
+            gen_pack = lambda have, want: []
+            c = self._client()
+            self.assertEqual(dest.refs[b"refs/heads/abranch"], dummy_commit)
+            c.send_pack(self._build_path('/dest'), lambda _: sendrefs, gen_pack)
+            self.assertFalse(b"refs/heads/abranch" in dest.refs)
 
 
 class DulwichTCPClientTest(CompatTestCase, DulwichClientTestBase):

+ 21 - 21
dulwich/tests/compat/test_pack.py

@@ -81,11 +81,11 @@ class TestPack(PackTests):
             self.assertEqual(orig_shas, _git_verify_pack_object_list(output))
 
     def test_deltas_work(self):
-        orig_pack = self.get_pack(pack1_sha)
-        orig_blob = orig_pack[a_sha]
-        new_blob = Blob()
-        new_blob.data = orig_blob.data + b'x'
-        all_to_pack = list(orig_pack.pack_tuples()) + [(new_blob, None)]
+        with self.get_pack(pack1_sha) as orig_pack:
+            orig_blob = orig_pack[a_sha]
+            new_blob = Blob()
+            new_blob.data = orig_blob.data + b'x'
+            all_to_pack = list(orig_pack.pack_tuples()) + [(new_blob, None)]
         pack_path = os.path.join(self._tempdir, b'pack_with_deltas')
         write_pack(pack_path, all_to_pack, deltify=True)
         output = run_git_or_fail(['verify-pack', '-v', pack_path])
@@ -102,14 +102,14 @@ class TestPack(PackTests):
     def test_delta_medium_object(self):
         # This tests an object set that will have a copy operation
         # 2**20 in size.
-        orig_pack = self.get_pack(pack1_sha)
-        orig_blob = orig_pack[a_sha]
-        new_blob = Blob()
-        new_blob.data = orig_blob.data + (b'x' * 2 ** 20)
-        new_blob_2 = Blob()
-        new_blob_2.data = new_blob.data + b'y'
-        all_to_pack = list(orig_pack.pack_tuples()) + [(new_blob, None),
-                                                       (new_blob_2, None)]
+        with self.get_pack(pack1_sha) as orig_pack:
+            orig_blob = orig_pack[a_sha]
+            new_blob = Blob()
+            new_blob.data = orig_blob.data + (b'x' * 2 ** 20)
+            new_blob_2 = Blob()
+            new_blob_2.data = new_blob.data + b'y'
+            all_to_pack = list(orig_pack.pack_tuples()) + [(new_blob, None),
+                                                           (new_blob_2, None)]
         pack_path = os.path.join(self._tempdir, b'pack_with_deltas')
         write_pack(pack_path, all_to_pack, deltify=True)
         output = run_git_or_fail(['verify-pack', '-v', pack_path])
@@ -136,14 +136,14 @@ class TestPack(PackTests):
         # 2**25 in size. This is a copy large enough that it requires
         # two copy operations in git's binary delta format.
         raise SkipTest('skipping slow, large test')
-        orig_pack = self.get_pack(pack1_sha)
-        orig_blob = orig_pack[a_sha]
-        new_blob = Blob()
-        new_blob.data = 'big blob' + ('x' * 2 ** 25)
-        new_blob_2 = Blob()
-        new_blob_2.data = new_blob.data + 'y'
-        all_to_pack = list(orig_pack.pack_tuples()) + [(new_blob, None),
-                                                       (new_blob_2, None)]
+        with self.get_pack(pack1_sha) as orig_pack:
+            orig_blob = orig_pack[a_sha]
+            new_blob = Blob()
+            new_blob.data = 'big blob' + ('x' * 2 ** 25)
+            new_blob_2 = Blob()
+            new_blob_2.data = new_blob.data + 'y'
+            all_to_pack = list(orig_pack.pack_tuples()) + [(new_blob, None),
+                                                           (new_blob_2, None)]
         pack_path = os.path.join(self._tempdir, "pack_with_deltas")
         write_pack(pack_path, all_to_pack, deltify=True)
         output = run_git_or_fail(['verify-pack', '-v', pack_path])

+ 6 - 2
dulwich/tests/compat/utils.py

@@ -233,8 +233,12 @@ class CompatTestCase(TestCase):
         :returns: An initialized Repo object that lives in a temporary directory.
         """
         path = import_repo_to_dir(name)
-        self.addCleanup(rmtree_ro, path)
-        return Repo(path)
+        repo = Repo(path)
+        def cleanup():
+            repo.close()
+            rmtree_ro(path)
+        self.addCleanup(cleanup)
+        return repo
 
 
 if sys.platform == 'win32':

+ 24 - 23
dulwich/tests/test_client.py

@@ -16,6 +16,7 @@
 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
 # MA  02110-1301, USA.
 
+from contextlib import closing
 from io import BytesIO
 import sys
 import shutil
@@ -585,36 +586,36 @@ class LocalGitClientTests(TestCase):
 
     def test_fetch_empty(self):
         c = LocalGitClient()
-        s = open_repo('a.git')
-        out = BytesIO()
-        walker = {}
-        c.fetch_pack(s.path, lambda heads: [], graph_walker=walker,
-            pack_data=out.write)
-        self.assertEqual(b"PACK\x00\x00\x00\x02\x00\x00\x00\x00\x02\x9d\x08"
-            b"\x82;\xd8\xa8\xea\xb5\x10\xadj\xc7\\\x82<\xfd>\xd3\x1e", out.getvalue())
+        with closing(open_repo('a.git')) as s:
+            out = BytesIO()
+            walker = {}
+            c.fetch_pack(s.path, lambda heads: [], graph_walker=walker,
+                pack_data=out.write)
+            self.assertEqual(b"PACK\x00\x00\x00\x02\x00\x00\x00\x00\x02\x9d\x08"
+                b"\x82;\xd8\xa8\xea\xb5\x10\xadj\xc7\\\x82<\xfd>\xd3\x1e", out.getvalue())
 
     def test_fetch_pack_none(self):
         c = LocalGitClient()
-        s = open_repo('a.git')
-        out = BytesIO()
-        walker = MemoryRepo().get_graph_walker()
-        c.fetch_pack(s.path,
-            lambda heads: [b"a90fa2d900a17e99b433217e988c4eb4a2e9a097"],
-            graph_walker=walker, pack_data=out.write)
-        # Hardcoding is not ideal, but we'll fix that some other day..
-        self.assertTrue(out.getvalue().startswith(b'PACK\x00\x00\x00\x02\x00\x00\x00\x07'))
+        with closing(open_repo('a.git')) as s:
+            out = BytesIO()
+            walker = MemoryRepo().get_graph_walker()
+            c.fetch_pack(s.path,
+                lambda heads: [b"a90fa2d900a17e99b433217e988c4eb4a2e9a097"],
+                graph_walker=walker, pack_data=out.write)
+            # Hardcoding is not ideal, but we'll fix that some other day..
+            self.assertTrue(out.getvalue().startswith(b'PACK\x00\x00\x00\x02\x00\x00\x00\x07'))
 
     def test_send_pack_without_changes(self):
-        local = open_repo('a.git')
-        target = open_repo('a.git')
-        self.send_and_verify(b"master", local, target)
+        with closing(open_repo('a.git')) as local:
+            with closing(open_repo('a.git')) as target:
+                self.send_and_verify(b"master", local, target)
 
     def test_send_pack_with_changes(self):
-        local = open_repo('a.git')
-        target_path = tempfile.mkdtemp()
-        self.addCleanup(shutil.rmtree, target_path)
-        target = Repo.init_bare(target_path)
-        self.send_and_verify(b"master", local, target)
+        with closing(open_repo('a.git')) as local:
+            target_path = tempfile.mkdtemp()
+            self.addCleanup(shutil.rmtree, target_path)
+            with closing(Repo.init_bare(target_path)) as target:
+                self.send_and_verify(b"master", local, target)
 
     def send_and_verify(self, branch, local, target):
         client = LocalGitClient()

+ 108 - 107
dulwich/tests/test_index.py

@@ -19,6 +19,7 @@
 """Tests for the index."""
 
 
+from contextlib import closing
 from io import BytesIO
 import os
 import shutil
@@ -255,122 +256,122 @@ class BuildIndexTests(TestCase):
 
     def test_empty(self):
         repo_dir = tempfile.mkdtemp()
-        repo = Repo.init(repo_dir)
-        self.addCleanup(shutil.rmtree, repo_dir)
+        with closing(Repo.init(repo_dir)) as repo:
+            self.addCleanup(shutil.rmtree, repo_dir)
 
-        tree = Tree()
-        repo.object_store.add_object(tree)
+            tree = Tree()
+            repo.object_store.add_object(tree)
 
-        build_index_from_tree(repo.path, repo.index_path(),
-                repo.object_store, tree.id)
+            build_index_from_tree(repo.path, repo.index_path(),
+                    repo.object_store, tree.id)
 
-        # Verify index entries
-        index = repo.open_index()
-        self.assertEqual(len(index), 0)
+            # Verify index entries
+            index = repo.open_index()
+            self.assertEqual(len(index), 0)
 
-        # Verify no files
-        self.assertEqual(['.git'], os.listdir(repo.path))
+            # Verify no files
+            self.assertEqual(['.git'], os.listdir(repo.path))
 
     def test_git_dir(self):
         if os.name != 'posix':
             self.skipTest("test depends on POSIX shell")
 
         repo_dir = tempfile.mkdtemp()
-        repo = Repo.init(repo_dir)
         self.addCleanup(shutil.rmtree, repo_dir)
+        with closing(Repo.init(repo_dir)) as repo:
 
-        # Populate repo
-        filea = Blob.from_string(b'file a')
-        filee = Blob.from_string(b'd')
+            # Populate repo
+            filea = Blob.from_string(b'file a')
+            filee = Blob.from_string(b'd')
 
-        tree = Tree()
-        tree[b'.git/a'] = (stat.S_IFREG | 0o644, filea.id)
-        tree[b'c/e'] = (stat.S_IFREG | 0o644, filee.id)
+            tree = Tree()
+            tree[b'.git/a'] = (stat.S_IFREG | 0o644, filea.id)
+            tree[b'c/e'] = (stat.S_IFREG | 0o644, filee.id)
 
-        repo.object_store.add_objects([(o, None)
-            for o in [filea, filee, tree]])
+            repo.object_store.add_objects([(o, None)
+                for o in [filea, filee, tree]])
 
-        build_index_from_tree(repo.path, repo.index_path(),
-                repo.object_store, tree.id)
+            build_index_from_tree(repo.path, repo.index_path(),
+                    repo.object_store, tree.id)
 
-        # Verify index entries
-        index = repo.open_index()
-        self.assertEqual(len(index), 1)
+            # Verify index entries
+            index = repo.open_index()
+            self.assertEqual(len(index), 1)
 
-        # filea
-        apath = os.path.join(repo.path, '.git', 'a')
-        self.assertFalse(os.path.exists(apath))
+            # filea
+            apath = os.path.join(repo.path, '.git', 'a')
+            self.assertFalse(os.path.exists(apath))
 
-        # filee
-        epath = os.path.join(repo.path, 'c', 'e')
-        self.assertTrue(os.path.exists(epath))
-        self.assertReasonableIndexEntry(index[b'c/e'],
-            stat.S_IFREG | 0o644, 1, filee.id)
-        self.assertFileContents(epath, b'd')
+            # filee
+            epath = os.path.join(repo.path, 'c', 'e')
+            self.assertTrue(os.path.exists(epath))
+            self.assertReasonableIndexEntry(index[b'c/e'],
+                stat.S_IFREG | 0o644, 1, filee.id)
+            self.assertFileContents(epath, b'd')
 
     def test_nonempty(self):
         if os.name != 'posix':
             self.skipTest("test depends on POSIX shell")
 
         repo_dir = tempfile.mkdtemp()
-        repo = Repo.init(repo_dir)
         self.addCleanup(shutil.rmtree, repo_dir)
-
-        # Populate repo
-        filea = Blob.from_string(b'file a')
-        fileb = Blob.from_string(b'file b')
-        filed = Blob.from_string(b'file d')
-        filee = Blob.from_string(b'd')
-
-        tree = Tree()
-        tree[b'a'] = (stat.S_IFREG | 0o644, filea.id)
-        tree[b'b'] = (stat.S_IFREG | 0o644, fileb.id)
-        tree[b'c/d'] = (stat.S_IFREG | 0o644, filed.id)
-        tree[b'c/e'] = (stat.S_IFLNK, filee.id)  # symlink
-
-        repo.object_store.add_objects([(o, None)
-            for o in [filea, fileb, filed, filee, tree]])
-
-        build_index_from_tree(repo.path, repo.index_path(),
-                repo.object_store, tree.id)
-
-        # Verify index entries
-        index = repo.open_index()
-        self.assertEqual(len(index), 4)
-
-        # filea
-        apath = os.path.join(repo.path, 'a')
-        self.assertTrue(os.path.exists(apath))
-        self.assertReasonableIndexEntry(index[b'a'],
-            stat.S_IFREG | 0o644, 6, filea.id)
-        self.assertFileContents(apath, b'file a')
-
-        # fileb
-        bpath = os.path.join(repo.path, 'b')
-        self.assertTrue(os.path.exists(bpath))
-        self.assertReasonableIndexEntry(index[b'b'],
-            stat.S_IFREG | 0o644, 6, fileb.id)
-        self.assertFileContents(bpath, b'file b')
-
-        # filed
-        dpath = os.path.join(repo.path, 'c', 'd')
-        self.assertTrue(os.path.exists(dpath))
-        self.assertReasonableIndexEntry(index[b'c/d'],
-            stat.S_IFREG | 0o644, 6, filed.id)
-        self.assertFileContents(dpath, b'file d')
-
-        # symlink to d
-        epath = os.path.join(repo.path, 'c', 'e')
-        self.assertTrue(os.path.exists(epath))
-        self.assertReasonableIndexEntry(index[b'c/e'],
-            stat.S_IFLNK, 1, filee.id)
-        self.assertFileContents(epath, 'd', symlink=True)
-
-        # Verify no extra files
-        self.assertEqual(['.git', 'a', 'b', 'c'],
-            sorted(os.listdir(repo.path)))
-        self.assertEqual(['d', 'e'],
-            sorted(os.listdir(os.path.join(repo.path, 'c'))))
+        with closing(Repo.init(repo_dir)) as repo:
+
+            # Populate repo
+            filea = Blob.from_string(b'file a')
+            fileb = Blob.from_string(b'file b')
+            filed = Blob.from_string(b'file d')
+            filee = Blob.from_string(b'd')
+
+            tree = Tree()
+            tree[b'a'] = (stat.S_IFREG | 0o644, filea.id)
+            tree[b'b'] = (stat.S_IFREG | 0o644, fileb.id)
+            tree[b'c/d'] = (stat.S_IFREG | 0o644, filed.id)
+            tree[b'c/e'] = (stat.S_IFLNK, filee.id)  # symlink
+
+            repo.object_store.add_objects([(o, None)
+                for o in [filea, fileb, filed, filee, tree]])
+
+            build_index_from_tree(repo.path, repo.index_path(),
+                    repo.object_store, tree.id)
+
+            # Verify index entries
+            index = repo.open_index()
+            self.assertEqual(len(index), 4)
+
+            # filea
+            apath = os.path.join(repo.path, 'a')
+            self.assertTrue(os.path.exists(apath))
+            self.assertReasonableIndexEntry(index[b'a'],
+                stat.S_IFREG | 0o644, 6, filea.id)
+            self.assertFileContents(apath, b'file a')
+
+            # fileb
+            bpath = os.path.join(repo.path, 'b')
+            self.assertTrue(os.path.exists(bpath))
+            self.assertReasonableIndexEntry(index[b'b'],
+                stat.S_IFREG | 0o644, 6, fileb.id)
+            self.assertFileContents(bpath, b'file b')
+
+            # filed
+            dpath = os.path.join(repo.path, 'c', 'd')
+            self.assertTrue(os.path.exists(dpath))
+            self.assertReasonableIndexEntry(index[b'c/d'],
+                stat.S_IFREG | 0o644, 6, filed.id)
+            self.assertFileContents(dpath, b'file d')
+
+            # symlink to d
+            epath = os.path.join(repo.path, 'c', 'e')
+            self.assertTrue(os.path.exists(epath))
+            self.assertReasonableIndexEntry(index[b'c/e'],
+                stat.S_IFLNK, 1, filee.id)
+            self.assertFileContents(epath, 'd', symlink=True)
+
+            # Verify no extra files
+            self.assertEqual(['.git', 'a', 'b', 'c'],
+                sorted(os.listdir(repo.path)))
+            self.assertEqual(['d', 'e'],
+                sorted(os.listdir(os.path.join(repo.path, 'c'))))
 
 
 class GetUnstagedChangesTests(TestCase):
@@ -379,30 +380,30 @@ class GetUnstagedChangesTests(TestCase):
         """Unit test for get_unstaged_changes."""
 
         repo_dir = tempfile.mkdtemp()
-        repo = Repo.init(repo_dir)
         self.addCleanup(shutil.rmtree, repo_dir)
+        with closing(Repo.init(repo_dir)) as repo:
 
-        # Commit a dummy file then modify it
-        foo1_fullpath = os.path.join(repo_dir, 'foo1')
-        with open(foo1_fullpath, 'wb') as f:
-            f.write(b'origstuff')
+            # Commit a dummy file then modify it
+            foo1_fullpath = os.path.join(repo_dir, 'foo1')
+            with open(foo1_fullpath, 'wb') as f:
+                f.write(b'origstuff')
 
-        foo2_fullpath = os.path.join(repo_dir, 'foo2')
-        with open(foo2_fullpath, 'wb') as f:
-            f.write(b'origstuff')
+            foo2_fullpath = os.path.join(repo_dir, 'foo2')
+            with open(foo2_fullpath, 'wb') as f:
+                f.write(b'origstuff')
 
-        repo.stage(['foo1', 'foo2'])
-        repo.do_commit(b'test status', author=b'', committer=b'')
+            repo.stage(['foo1', 'foo2'])
+            repo.do_commit(b'test status', author=b'', committer=b'')
 
-        with open(foo1_fullpath, 'wb') as f:
-            f.write(b'newstuff')
+            with open(foo1_fullpath, 'wb') as f:
+                f.write(b'newstuff')
 
-        # modify access and modify time of path
-        os.utime(foo1_fullpath, (0, 0))
+            # modify access and modify time of path
+            os.utime(foo1_fullpath, (0, 0))
 
-        changes = get_unstaged_changes(repo.open_index(), repo_dir)
+            changes = get_unstaged_changes(repo.open_index(), repo_dir)
 
-        self.assertEqual(list(changes), [b'foo1'])
+            self.assertEqual(list(changes), [b'foo1'])
 
 
 class TestValidatePathElement(TestCase):

+ 6 - 6
dulwich/tests/test_object_store.py

@@ -19,6 +19,7 @@
 """Tests for the object store interface."""
 
 
+from contextlib import closing
 from io import BytesIO
 import os
 import shutil
@@ -350,12 +351,11 @@ class DiskObjectStoreTests(PackBasedObjectStoreTests, TestCase):
             o.close()
 
     def test_add_thin_pack_empty(self):
-        o = DiskObjectStore(self.store_dir)
-
-        f = BytesIO()
-        entries = build_pack(f, [], store=o)
-        self.assertEqual([], entries)
-        o.add_thin_pack(f.read, None)
+        with closing(DiskObjectStore(self.store_dir)) as o:
+            f = BytesIO()
+            entries = build_pack(f, [], store=o)
+            self.assertEqual([], entries)
+            o.add_thin_pack(f.read, None)
 
 
 class TreeLookupPathTests(TestCase):

+ 28 - 17
dulwich/tests/test_porcelain.py

@@ -18,6 +18,7 @@
 
 """Tests for dulwich.porcelain."""
 
+from contextlib import closing
 from io import BytesIO
 import os
 import shutil
@@ -50,6 +51,10 @@ class PorcelainTestCase(TestCase):
         self.addCleanup(shutil.rmtree, repo_dir)
         self.repo = Repo.init(repo_dir)
 
+    def tearDown(self):
+        super(PorcelainTestCase, self).tearDown()
+        self.repo.close()
+
 
 class ArchiveTests(PorcelainTestCase):
     """Tests for the archive command."""
@@ -127,10 +132,12 @@ class CloneTests(PorcelainTestCase):
         target_path = tempfile.mkdtemp()
         errstream = BytesIO()
         self.addCleanup(shutil.rmtree, target_path)
-        r = porcelain.clone(self.repo.path, target_path,
-                            checkout=True, errstream=errstream)
-        self.assertEqual(r.path, target_path)
-        self.assertEqual(Repo(target_path).head(), c3.id)
+        with closing(porcelain.clone(self.repo.path, target_path,
+                                     checkout=True,
+                                     errstream=errstream)) as r:
+            self.assertEqual(r.path, target_path)
+        with closing(Repo(target_path)) as r:
+            self.assertEqual(r.head(), c3.id)
         self.assertTrue('f1' in os.listdir(target_path))
         self.assertTrue('f2' in os.listdir(target_path))
 
@@ -445,7 +452,8 @@ class PushTests(PorcelainTestCase):
 
         # Setup target repo cloned from temp test repo
         clone_path = tempfile.mkdtemp()
-        porcelain.clone(self.repo.path, target=clone_path, errstream=errstream)
+        target_repo = porcelain.clone(self.repo.path, target=clone_path, errstream=errstream)
+        target_repo.close()
 
         # create a second file to be pushed back to origin
         handle, fullpath = tempfile.mkstemp(dir=clone_path)
@@ -463,15 +471,15 @@ class PushTests(PorcelainTestCase):
                 errstream=errstream)
 
         # Check that the target and source
-        r_clone = Repo(clone_path)
+        with closing(Repo(clone_path)) as r_clone:
 
-        # Get the change in the target repo corresponding to the add
-        # this will be in the foo branch.
-        change = list(tree_changes(self.repo, self.repo[b'HEAD'].tree,
-                                   self.repo[b'refs/heads/foo'].tree))[0]
+            # Get the change in the target repo corresponding to the add
+            # this will be in the foo branch.
+            change = list(tree_changes(self.repo, self.repo[b'HEAD'].tree,
+                                       self.repo[b'refs/heads/foo'].tree))[0]
 
-        self.assertEqual(r_clone[b'HEAD'].id, self.repo[refs_path].id)
-        self.assertEqual(os.path.basename(fullpath), change.new.path.decode('ascii'))
+            self.assertEqual(r_clone[b'HEAD'].id, self.repo[refs_path].id)
+            self.assertEqual(os.path.basename(fullpath), change.new.path.decode('ascii'))
 
 
 class PullTests(PorcelainTestCase):
@@ -490,7 +498,9 @@ class PullTests(PorcelainTestCase):
 
         # Setup target repo
         target_path = tempfile.mkdtemp()
-        porcelain.clone(self.repo.path, target=target_path, errstream=errstream)
+        target_repo = porcelain.clone(self.repo.path, target=target_path,
+                                      errstream=errstream)
+        target_repo.close()
 
         # create a second file to be pushed
         handle, fullpath = tempfile.mkstemp(dir=self.repo.path)
@@ -505,8 +515,8 @@ class PullTests(PorcelainTestCase):
             outstream=outstream, errstream=errstream)
 
         # Check the target repo for pushed changes
-        r = Repo(target_path)
-        self.assertEqual(r[b'HEAD'].id, self.repo[b'HEAD'].id)
+        with closing(Repo(target_path)) as r:
+            self.assertEqual(r[b'HEAD'].id, self.repo[b'HEAD'].id)
 
 
 class StatusTests(PorcelainTestCase):
@@ -710,11 +720,12 @@ class FetchTests(PorcelainTestCase):
             author=b'test2', committer=b'test2')
 
         self.assertFalse(self.repo[b'HEAD'].id in target_repo)
+        target_repo.close()
 
         # Fetch changes into the cloned repo
         porcelain.fetch(target_path, self.repo.path, outstream=outstream,
             errstream=errstream)
 
         # Check the target repo for pushed changes
-        r = Repo(target_path)
-        self.assertTrue(self.repo[b'HEAD'].id in r)
+        with closing(Repo(target_path)) as r:
+            self.assertTrue(self.repo[b'HEAD'].id in r)

+ 14 - 14
dulwich/tests/test_repository.py

@@ -20,10 +20,10 @@
 
 """Tests for the repository."""
 
+from contextlib import closing
 import os
 import stat
 import shutil
-import sys
 import tempfile
 import warnings
 import sys
@@ -267,19 +267,19 @@ class RepositoryBytesRootTests(TestCase):
         r = self._repo = self.open_repo('a.git')
         tmp_dir = self.mkdtemp()
         self.addCleanup(shutil.rmtree, tmp_dir)
-        t = r.clone(tmp_dir, mkdir=False)
-        self.assertEqual({
-            b'HEAD': b'a90fa2d900a17e99b433217e988c4eb4a2e9a097',
-            b'refs/remotes/origin/master':
-                b'a90fa2d900a17e99b433217e988c4eb4a2e9a097',
-            b'refs/heads/master': b'a90fa2d900a17e99b433217e988c4eb4a2e9a097',
-            b'refs/tags/mytag': b'28237f4dc30d0d462658d6b937b08a0f0b6ef55a',
-            b'refs/tags/mytag-packed':
-                b'b0931cadc54336e78a1d980420e3268903b57a50',
-            }, t.refs.as_dict())
-        shas = [e.commit.id for e in r.get_walker()]
-        self.assertEqual(shas, [t.head(),
-                         b'2a72d929692c41d8554c07f6301757ba18a65d91'])
+        with closing(r.clone(tmp_dir, mkdir=False)) as t:
+            self.assertEqual({
+                b'HEAD': b'a90fa2d900a17e99b433217e988c4eb4a2e9a097',
+                b'refs/remotes/origin/master':
+                    b'a90fa2d900a17e99b433217e988c4eb4a2e9a097',
+                b'refs/heads/master': b'a90fa2d900a17e99b433217e988c4eb4a2e9a097',
+                b'refs/tags/mytag': b'28237f4dc30d0d462658d6b937b08a0f0b6ef55a',
+                b'refs/tags/mytag-packed':
+                    b'b0931cadc54336e78a1d980420e3268903b57a50',
+                }, t.refs.as_dict())
+            shas = [e.commit.id for e in r.get_walker()]
+            self.assertEqual(shas, [t.head(),
+                             b'2a72d929692c41d8554c07f6301757ba18a65d91'])
 
     def test_clone_no_head(self):
         temp_dir = self.mkdtemp()

+ 1 - 0
dulwich/tests/utils.py

@@ -82,6 +82,7 @@ def open_repo(name, temp_dir=None):
 
 def tear_down_repo(repo):
     """Tear down a test repository."""
+    repo.close()
     temp_dir = os.path.dirname(repo._path_bytes.rstrip(os.sep.encode(sys.getfilesystemencoding())))
     shutil.rmtree(temp_dir)