Ver Fonte

Merge branch 'python3-close-files' of git://github.com/garyvdm/dulwich

Jelmer Vernooij há 11 anos atrás
pai
commit
4900903b87

+ 11 - 0
NEWS

@@ -7,6 +7,9 @@
   * Support staging symbolic links in Repo.stage.
     (Robert Brown)
 
+  * Ensure that all files object are closed when running the test suite.
+    (Gary van der Merwe)
+
  IMPROVEMENTS
 
   * Add porcelain 'status'. (Ryan Faulkner)
@@ -18,6 +21,14 @@
    * Various changes to improve compatibility with Python 3.
      (Gary van der Merwe, Hannu Valtonen, michael-k)
 
+API CHANGES
+
+  * An optional close function can be passed to the Protocol class. This will
+    be called by it's close method. (Gary van der Merwe)
+
+  * All classes with close methods are now context managers, so that they can
+    be easily closed using a `with` statement. (Gary van der Merwe)
+
 0.9.6	2014-04-23
 
  IMPROVEMENTS

+ 139 - 122
dulwich/client.py

@@ -438,67 +438,68 @@ class TraditionalGitClient(GitClient):
                                  and rejects ref updates
         """
         proto, unused_can_read = self._connect('receive-pack', path)
-        old_refs, server_capabilities = read_pkt_refs(proto)
-        negotiated_capabilities = self._send_capabilities & server_capabilities
-
-        if 'report-status' in negotiated_capabilities:
-            self._report_status_parser = ReportStatusParser()
-        report_status_parser = self._report_status_parser
+        with proto:
+            old_refs, server_capabilities = read_pkt_refs(proto)
+            negotiated_capabilities = self._send_capabilities & server_capabilities
 
-        try:
-            new_refs = orig_new_refs = determine_wants(dict(old_refs))
-        except:
-            proto.write_pkt_line(None)
-            raise
-
-        if not 'delete-refs' in server_capabilities:
-            # Server does not support deletions. Fail later.
-            def remove_del(pair):
-                if pair[1] == ZERO_SHA:
-                    if 'report-status' in negotiated_capabilities:
-                        report_status_parser._ref_statuses.append(
-                            'ng %s remote does not support deleting refs'
-                            % pair[1])
-                        report_status_parser._ref_status_ok = False
-                    return False
-                else:
-                    return True
-
-            new_refs = dict(
-                filter(
-                    remove_del,
-                    [(ref, sha) for ref, sha in new_refs.iteritems()]))
-
-        if new_refs is None:
-            proto.write_pkt_line(None)
-            return old_refs
+            if 'report-status' in negotiated_capabilities:
+                self._report_status_parser = ReportStatusParser()
+            report_status_parser = self._report_status_parser
 
-        if len(new_refs) == 0 and len(orig_new_refs):
-            # NOOP - Original new refs filtered out by policy
-            proto.write_pkt_line(None)
-            if self._report_status_parser is not None:
-                self._report_status_parser.check()
-            return old_refs
-
-        (have, want) = self._handle_receive_pack_head(
-            proto, negotiated_capabilities, old_refs, new_refs)
-        if not want and old_refs == new_refs:
-            return new_refs
-        objects = generate_pack_contents(have, want)
-        if len(objects) > 0:
-            entries, sha = write_pack_objects(proto.write_file(), objects)
-        elif len(set(new_refs.values()) - set([ZERO_SHA])) > 0:
-            # Check for valid create/update refs
-            filtered_new_refs = \
-                dict([(ref, sha) for ref, sha in new_refs.iteritems()
-                     if sha != ZERO_SHA])
-            if len(set(filtered_new_refs.iteritems()) -
-                    set(old_refs.iteritems())) > 0:
+            try:
+                new_refs = orig_new_refs = determine_wants(dict(old_refs))
+            except:
+                proto.write_pkt_line(None)
+                raise
+
+            if not 'delete-refs' in server_capabilities:
+                # Server does not support deletions. Fail later.
+                def remove_del(pair):
+                    if pair[1] == ZERO_SHA:
+                        if 'report-status' in negotiated_capabilities:
+                            report_status_parser._ref_statuses.append(
+                                'ng %s remote does not support deleting refs'
+                                % pair[1])
+                            report_status_parser._ref_status_ok = False
+                        return False
+                    else:
+                        return True
+
+                new_refs = dict(
+                    filter(
+                        remove_del,
+                        [(ref, sha) for ref, sha in new_refs.iteritems()]))
+
+            if new_refs is None:
+                proto.write_pkt_line(None)
+                return old_refs
+
+            if len(new_refs) == 0 and len(orig_new_refs):
+                # NOOP - Original new refs filtered out by policy
+                proto.write_pkt_line(None)
+                if self._report_status_parser is not None:
+                    self._report_status_parser.check()
+                return old_refs
+
+            (have, want) = self._handle_receive_pack_head(
+                proto, negotiated_capabilities, old_refs, new_refs)
+            if not want and old_refs == new_refs:
+                return new_refs
+            objects = generate_pack_contents(have, want)
+            if len(objects) > 0:
                 entries, sha = write_pack_objects(proto.write_file(), objects)
-
-        self._handle_receive_pack_tail(
-            proto, negotiated_capabilities, progress)
-        return new_refs
+            elif len(set(new_refs.values()) - set([ZERO_SHA])) > 0:
+                # Check for valid create/update refs
+                filtered_new_refs = \
+                    dict([(ref, sha) for ref, sha in new_refs.iteritems()
+                         if sha != ZERO_SHA])
+                if len(set(filtered_new_refs.iteritems()) -
+                        set(old_refs.iteritems())) > 0:
+                    entries, sha = write_pack_objects(proto.write_file(), objects)
+
+            self._handle_receive_pack_tail(
+                proto, negotiated_capabilities, progress)
+            return new_refs
 
     def fetch_pack(self, path, determine_wants, graph_walker, pack_data,
                    progress=None):
@@ -510,47 +511,49 @@ class TraditionalGitClient(GitClient):
         :param progress: Callback for progress reports (strings)
         """
         proto, can_read = self._connect('upload-pack', path)
-        refs, server_capabilities = read_pkt_refs(proto)
-        negotiated_capabilities = (
-            self._fetch_capabilities & server_capabilities)
+        with proto:
+            refs, server_capabilities = read_pkt_refs(proto)
+            negotiated_capabilities = (
+                self._fetch_capabilities & server_capabilities)
 
-        if refs is None:
-            proto.write_pkt_line(None)
-            return refs
+            if refs is None:
+                proto.write_pkt_line(None)
+                return refs
 
-        try:
-            wants = determine_wants(refs)
-        except:
-            proto.write_pkt_line(None)
-            raise
-        if wants is not None:
-            wants = [cid for cid in wants if cid != ZERO_SHA]
-        if not wants:
-            proto.write_pkt_line(None)
+            try:
+                wants = determine_wants(refs)
+            except:
+                proto.write_pkt_line(None)
+                raise
+            if wants is not None:
+                wants = [cid for cid in wants if cid != ZERO_SHA]
+            if not wants:
+                proto.write_pkt_line(None)
+                return refs
+            self._handle_upload_pack_head(
+                proto, negotiated_capabilities, graph_walker, wants, can_read)
+            self._handle_upload_pack_tail(
+                proto, negotiated_capabilities, graph_walker, pack_data, progress)
             return refs
-        self._handle_upload_pack_head(
-            proto, negotiated_capabilities, graph_walker, wants, can_read)
-        self._handle_upload_pack_tail(
-            proto, negotiated_capabilities, graph_walker, pack_data, progress)
-        return refs
 
     def archive(self, path, committish, write_data, progress=None):
-        proto, can_read = self._connect('upload-archive', path)
-        proto.write_pkt_line("argument %s" % committish)
-        proto.write_pkt_line(None)
-        pkt = proto.read_pkt_line()
-        if pkt == "NACK\n":
-            return
-        elif pkt == "ACK\n":
-            pass
-        elif pkt.startswith("ERR "):
-            raise GitProtocolError(pkt[4:].rstrip("\n"))
-        else:
-            raise AssertionError("invalid response %r" % pkt)
-        ret = proto.read_pkt_line()
-        if ret is not None:
-            raise AssertionError("expected pkt tail")
-        self._read_side_band64k_data(proto, {1: write_data, 2: progress})
+        proto, can_read = self._connect(b'upload-archive', path)
+        with proto:
+            proto.write_pkt_line("argument %s" % committish)
+            proto.write_pkt_line(None)
+            pkt = proto.read_pkt_line()
+            if pkt == "NACK\n":
+                return
+            elif pkt == "ACK\n":
+                pass
+            elif pkt.startswith("ERR "):
+                raise GitProtocolError(pkt[4:].rstrip("\n"))
+            else:
+                raise AssertionError("invalid response %r" % pkt)
+            ret = proto.read_pkt_line()
+            if ret is not None:
+                raise AssertionError("expected pkt tail")
+            self._read_side_band64k_data(proto, {1: write_data, 2: progress})
 
 
 class TCPGitClient(TraditionalGitClient):
@@ -584,7 +587,12 @@ class TCPGitClient(TraditionalGitClient):
         rfile = s.makefile('rb', -1)
         # 0 means unbuffered
         wfile = s.makefile('wb', 0)
-        proto = Protocol(rfile.read, wfile.write,
+        def close():
+            rfile.close()
+            wfile.close()
+            s.close()
+
+        proto = Protocol(rfile.read, wfile.write, close,
                          report_activity=self._report_activity)
         if path.startswith("/~"):
             path = path[1:]
@@ -612,6 +620,8 @@ class SubprocessWrapper(object):
     def close(self):
         self.proc.stdin.close()
         self.proc.stdout.close()
+        if self.proc.stderr:
+            self.proc.stderr.close()
         self.proc.wait()
 
 
@@ -633,7 +643,7 @@ class SubprocessGitClient(TraditionalGitClient):
             subprocess.Popen(argv, bufsize=0, stdin=subprocess.PIPE,
                              stdout=subprocess.PIPE,
                              stderr=self._stderr))
-        return Protocol(p.read, p.write,
+        return Protocol(p.read, p.write, p.close,
                         report_activity=self._report_activity), p.can_read
 
 
@@ -828,9 +838,6 @@ else:
             self.channel.close()
             self.stop_monitoring()
 
-        def __del__(self):
-            self.close()
-
     class ParamikoSSHVendor(object):
 
         def __init__(self):
@@ -882,9 +889,9 @@ class SSHGitClient(TraditionalGitClient):
         con = get_ssh_vendor().run_command(
             self.host, ["%s '%s'" % (self._get_cmd_path(cmd), path)],
             port=self.port, username=self.username)
-        return (Protocol(
-            con.read, con.write, report_activity=self._report_activity),
-            con.can_read)
+        return (Protocol(con.read, con.write, con.close, 
+                         report_activity=self._report_activity), 
+                con.can_read)
 
 
 def default_user_agent_string():
@@ -944,17 +951,20 @@ class HttpGitClient(GitClient):
             url += "?service=%s" % service
             headers["Content-Type"] = "application/x-%s-request" % service
         resp = self._http_request(url, headers)
-        self.dumb = (not resp.info().gettype().startswith("application/x-git-"))
-        if not self.dumb:
-            proto = Protocol(resp.read, None)
-            # The first line should mention the service
-            pkts = list(proto.read_pkt_seq())
-            if pkts != [('# service=%s\n' % service)]:
-                raise GitProtocolError(
-                    "unexpected first line %r from smart server" % pkts)
-            return read_pkt_refs(proto)
-        else:
-            return read_info_refs(resp), set()
+        try:
+            self.dumb = (not resp.info().gettype().startswith("application/x-git-"))
+            if not self.dumb:
+                proto = Protocol(resp.read, None)
+                # The first line should mention the service
+                pkts = list(proto.read_pkt_seq())
+                if pkts != [('# service=%s\n' % service)]:
+                    raise GitProtocolError(
+                        "unexpected first line %r from smart server" % pkts)
+                return read_pkt_refs(proto)
+            else:
+                return read_info_refs(resp), set()
+        finally:
+            resp.close()
 
     def _smart_request(self, service, url, data):
         assert url[-1] == "/"
@@ -1002,11 +1012,15 @@ class HttpGitClient(GitClient):
         if len(objects) > 0:
             entries, sha = write_pack_objects(req_proto.write_file(), objects)
         resp = self._smart_request("git-receive-pack", url,
-            data=req_data.getvalue())
-        resp_proto = Protocol(resp.read, None)
-        self._handle_receive_pack_tail(resp_proto, negotiated_capabilities,
-            progress)
-        return new_refs
+                                   data=req_data.getvalue())
+        try:
+            resp_proto = Protocol(resp.read, None)
+            self._handle_receive_pack_tail(resp_proto, negotiated_capabilities,
+                progress)
+            return new_refs
+        finally:
+            resp.close()
+
 
     def fetch_pack(self, path, determine_wants, graph_walker, pack_data,
                    progress=None):
@@ -1036,10 +1050,13 @@ class HttpGitClient(GitClient):
             lambda: False)
         resp = self._smart_request(
             "git-upload-pack", url, data=req_data.getvalue())
-        resp_proto = Protocol(resp.read, None)
-        self._handle_upload_pack_tail(resp_proto, negotiated_capabilities,
-            graph_walker, pack_data, progress)
-        return refs
+        try:
+            resp_proto = Protocol(resp.read, None)
+            self._handle_upload_pack_tail(resp_proto, negotiated_capabilities,
+                graph_walker, pack_data, progress)
+            return refs
+        finally:
+            resp.close()
 
 
 def get_transport_and_path_from_url(url, config=None, **kwargs):

+ 6 - 0
dulwich/file.py

@@ -155,6 +155,12 @@ class _GitFile(object):
         finally:
             self.abort()
 
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.close()
+
     def __getattr__(self, name):
         """Proxy property calls to the underlying file."""
         if name in self.PROXY_PROPERTIES:

+ 12 - 0
dulwich/pack.py

@@ -989,6 +989,12 @@ class PackData(object):
     def close(self):
         self._file.close()
 
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.close()
+
     def _get_size(self):
         if self._size is not None:
             return self._size
@@ -1801,6 +1807,12 @@ class Pack(object):
         if self._idx is not None:
             self._idx.close()
 
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.close()
+
     def __eq__(self, other):
         return isinstance(self, type(other)) and self.index == other.index
 

+ 12 - 1
dulwich/protocol.py

@@ -77,12 +77,23 @@ class Protocol(object):
         Documentation/technical/protocol-common.txt
     """
 
-    def __init__(self, read, write, report_activity=None):
+    def __init__(self, read, write, close=None, report_activity=None):
         self.read = read
         self.write = write
+        self._close = close
         self.report_activity = report_activity
         self._readahead = None
 
+    def close(self):
+        if self._close:
+            self._close()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.close()
+
     def read_pkt_line(self):
         """Reads a pkt-line from the remote git process.
 

+ 4 - 2
dulwich/repo.py

@@ -667,10 +667,12 @@ class Repo(BaseRepo):
         self._graftpoints = {}
         graft_file = self.get_named_file(os.path.join("info", "grafts"))
         if graft_file:
-            self._graftpoints.update(parse_graftpoints(graft_file))
+            with graft_file:
+                self._graftpoints.update(parse_graftpoints(graft_file))
         graft_file = self.get_named_file("shallow")
         if graft_file:
-            self._graftpoints.update(parse_graftpoints(graft_file))
+            with graft_file:
+                self._graftpoints.update(parse_graftpoints(graft_file))
 
         self.hooks['pre-commit'] = PreCommitShellHook(self.controldir())
         self.hooks['commit-msg'] = CommitMsgShellHook(self.controldir())

+ 5 - 1
dulwich/tests/compat/test_client.py

@@ -248,7 +248,9 @@ class DulwichTCPClientTest(CompatTestCase, DulwichClientTestBase):
 
     def tearDown(self):
         try:
-            os.kill(int(open(self.pidfile).read().strip()), signal.SIGKILL)
+            with open(self.pidfile) as f:
+                pid = f.read()
+            os.kill(int(pid.strip()), signal.SIGKILL)
             os.unlink(self.pidfile)
         except (OSError, IOError):
             pass
@@ -460,6 +462,8 @@ class DulwichHttpClientTest(CompatTestCase, DulwichClientTestBase):
     def tearDown(self):
         DulwichClientTestBase.tearDown(self)
         CompatTestCase.tearDown(self)
+        self._httpd.shutdown()
+        self._httpd.socket.close()
 
     def _client(self):
         return client.HttpGitClient(self._httpd.get_url())

+ 15 - 15
dulwich/tests/compat/test_pack.py

@@ -48,19 +48,19 @@ class TestPack(PackTests):
         self.addCleanup(shutil.rmtree, self._tempdir)
 
     def test_copy(self):
-        origpack = self.get_pack(pack1_sha)
-        self.assertSucceeds(origpack.index.check)
-        pack_path = os.path.join(self._tempdir, "Elch")
-        write_pack(pack_path, origpack.pack_tuples())
-        output = run_git_or_fail(['verify-pack', '-v', pack_path])
+        with self.get_pack(pack1_sha) as origpack:
+            self.assertSucceeds(origpack.index.check)
+            pack_path = os.path.join(self._tempdir, "Elch")
+            write_pack(pack_path, origpack.pack_tuples())
+            output = run_git_or_fail(['verify-pack', '-v', pack_path])
 
-        pack_shas = set()
-        for line in output.splitlines():
-            sha = line[:40]
-            try:
-                binascii.unhexlify(sha)
-            except (TypeError, binascii.Error):
-                continue  # non-sha line
-            pack_shas.add(sha)
-        orig_shas = set(o.id for o in origpack.iterobjects())
-        self.assertEqual(orig_shas, pack_shas)
+            pack_shas = set()
+            for line in output.splitlines():
+                sha = line[:40]
+                try:
+                    binascii.unhexlify(sha)
+                except (TypeError, binascii.Error):
+                    continue  # non-sha line
+                pack_shas.add(sha)
+            orig_shas = set(o.id for o in origpack.iterobjects())
+            self.assertEqual(orig_shas, pack_shas)

+ 1 - 0
dulwich/tests/compat/test_server.py

@@ -61,6 +61,7 @@ class GitServerTestCase(ServerTests, CompatTestCase):
                                   handlers=self._handlers())
         self._check_server(dul_server)
         self.addCleanup(dul_server.shutdown)
+        self.addCleanup(dul_server.server_close)
         threading.Thread(target=dul_server.serve).start()
         self._server = dul_server
         _, port = self._server.socket.getsockname()

+ 1 - 0
dulwich/tests/compat/test_web.py

@@ -65,6 +65,7 @@ class WebTests(ServerTests):
           'localhost', 0, app, server_class=WSGIServerLogger,
           handler_class=WSGIRequestHandlerLogger)
         self.addCleanup(dul_server.shutdown)
+        self.addCleanup(dul_server.server_close)
         threading.Thread(target=dul_server.serve_forever).start()
         self._server = dul_server
         _, port = dul_server.socket.getsockname()

+ 2 - 1
dulwich/tests/compat/utils.py

@@ -193,13 +193,14 @@ def check_for_daemon(limit=10, delay=0.1, timeout=0.1, port=TCP_GIT_PORT):
         s.settimeout(delay)
         try:
             s.connect(('localhost', port))
-            s.close()
             return True
         except socket.error as e:
             if getattr(e, 'errno', False) and e.errno != errno.ECONNREFUSED:
                 raise
             elif e.args[0] != errno.ECONNREFUSED:
                 raise
+        finally:
+            s.close()
     return False
 
 

+ 1 - 0
dulwich/tests/test_client.py

@@ -492,6 +492,7 @@ class TestSSHVendor(object):
         class Subprocess: pass
         setattr(Subprocess, 'read', lambda: None)
         setattr(Subprocess, 'write', lambda: None)
+        setattr(Subprocess, 'close', lambda: None)
         setattr(Subprocess, 'can_read', lambda: None)
         return Subprocess()
 

+ 135 - 128
dulwich/tests/test_pack.py

@@ -184,64 +184,64 @@ class TestPackData(PackTests):
     """Tests getting the data from the packfile."""
 
     def test_create_pack(self):
-        self.get_pack_data(pack1_sha)
+        self.get_pack_data(pack1_sha).close()
 
     def test_from_file(self):
         path = os.path.join(self.datadir, 'pack-%s.pack' % pack1_sha)
         PackData.from_file(open(path), os.path.getsize(path))
 
     def test_pack_len(self):
-        p = self.get_pack_data(pack1_sha)
-        self.assertEqual(3, len(p))
+        with self.get_pack_data(pack1_sha) as p:
+            self.assertEqual(3, len(p))
 
     def test_index_check(self):
-        p = self.get_pack_data(pack1_sha)
-        self.assertSucceeds(p.check)
+        with self.get_pack_data(pack1_sha) as p:
+            self.assertSucceeds(p.check)
 
     def test_iterobjects(self):
-        p = self.get_pack_data(pack1_sha)
-        commit_data = ('tree b2a2766a2879c209ab1176e7e778b81ae422eeaa\n'
-                       'author James Westby <jw+debian@jameswestby.net> '
-                       '1174945067 +0100\n'
-                       'committer James Westby <jw+debian@jameswestby.net> '
-                       '1174945067 +0100\n'
-                       '\n'
-                       'Test commit\n')
-        blob_sha = '6f670c0fb53f9463760b7295fbb814e965fb20c8'
-        tree_data = '100644 a\0%s' % hex_to_sha(blob_sha)
-        actual = []
-        for offset, type_num, chunks, crc32 in p.iterobjects():
-            actual.append((offset, type_num, ''.join(chunks), crc32))
-        self.assertEqual([
-          (12, 1, commit_data, 3775879613),
-          (138, 2, tree_data, 912998690),
-          (178, 3, 'test 1\n', 1373561701)
-          ], actual)
+        with self.get_pack_data(pack1_sha) as p:
+            commit_data = ('tree b2a2766a2879c209ab1176e7e778b81ae422eeaa\n'
+                           'author James Westby <jw+debian@jameswestby.net> '
+                           '1174945067 +0100\n'
+                           'committer James Westby <jw+debian@jameswestby.net> '
+                           '1174945067 +0100\n'
+                           '\n'
+                           'Test commit\n')
+            blob_sha = '6f670c0fb53f9463760b7295fbb814e965fb20c8'
+            tree_data = '100644 a\0%s' % hex_to_sha(blob_sha)
+            actual = []
+            for offset, type_num, chunks, crc32 in p.iterobjects():
+                actual.append((offset, type_num, ''.join(chunks), crc32))
+            self.assertEqual([
+              (12, 1, commit_data, 3775879613),
+              (138, 2, tree_data, 912998690),
+              (178, 3, 'test 1\n', 1373561701)
+              ], actual)
 
     def test_iterentries(self):
-        p = self.get_pack_data(pack1_sha)
-        entries = set((sha_to_hex(s), o, c) for s, o, c in p.iterentries())
-        self.assertEqual(set([
-          ('6f670c0fb53f9463760b7295fbb814e965fb20c8', 178, 1373561701),
-          ('b2a2766a2879c209ab1176e7e778b81ae422eeaa', 138, 912998690),
-          ('f18faa16531ac570a3fdc8c7ca16682548dafd12', 12, 3775879613),
-          ]), entries)
+        with self.get_pack_data(pack1_sha) as p:
+            entries = set((sha_to_hex(s), o, c) for s, o, c in p.iterentries())
+            self.assertEqual(set([
+              ('6f670c0fb53f9463760b7295fbb814e965fb20c8', 178, 1373561701),
+              ('b2a2766a2879c209ab1176e7e778b81ae422eeaa', 138, 912998690),
+              ('f18faa16531ac570a3fdc8c7ca16682548dafd12', 12, 3775879613),
+              ]), entries)
 
     def test_create_index_v1(self):
-        p = self.get_pack_data(pack1_sha)
-        filename = os.path.join(self.tempdir, 'v1test.idx')
-        p.create_index_v1(filename)
-        idx1 = load_pack_index(filename)
-        idx2 = self.get_pack_index(pack1_sha)
-        self.assertEqual(idx1, idx2)
+        with self.get_pack_data(pack1_sha) as p:
+            filename = os.path.join(self.tempdir, 'v1test.idx')
+            p.create_index_v1(filename)
+            idx1 = load_pack_index(filename)
+            idx2 = self.get_pack_index(pack1_sha)
+            self.assertEqual(idx1, idx2)
 
     def test_create_index_v2(self):
-        p = self.get_pack_data(pack1_sha)
-        filename = os.path.join(self.tempdir, 'v2test.idx')
-        p.create_index_v2(filename)
-        idx1 = load_pack_index(filename)
-        idx2 = self.get_pack_index(pack1_sha)
-        self.assertEqual(idx1, idx2)
+        with self.get_pack_data(pack1_sha) as p:
+            filename = os.path.join(self.tempdir, 'v2test.idx')
+            p.create_index_v2(filename)
+            idx1 = load_pack_index(filename)
+            idx2 = self.get_pack_index(pack1_sha)
+            self.assertEqual(idx1, idx2)
 
     def test_compute_file_sha(self):
         f = BytesIO('abcd1234wxyz')
@@ -261,46 +261,46 @@ class TestPackData(PackTests):
 class TestPack(PackTests):
 
     def test_len(self):
-        p = self.get_pack(pack1_sha)
-        self.assertEqual(3, len(p))
+        with self.get_pack(pack1_sha) as p:
+            self.assertEqual(3, len(p))
 
     def test_contains(self):
-        p = self.get_pack(pack1_sha)
-        self.assertTrue(tree_sha in p)
+        with self.get_pack(pack1_sha) as p:
+            self.assertTrue(tree_sha in p)
 
     def test_get(self):
-        p = self.get_pack(pack1_sha)
-        self.assertEqual(type(p[tree_sha]), Tree)
+        with self.get_pack(pack1_sha) as p:
+            self.assertEqual(type(p[tree_sha]), Tree)
 
     def test_iter(self):
-        p = self.get_pack(pack1_sha)
-        self.assertEqual(set([tree_sha, commit_sha, a_sha]), set(p))
+        with self.get_pack(pack1_sha) as p:
+            self.assertEqual(set([tree_sha, commit_sha, a_sha]), set(p))
 
     def test_iterobjects(self):
-        p = self.get_pack(pack1_sha)
-        expected = set([p[s] for s in [commit_sha, tree_sha, a_sha]])
-        self.assertEqual(expected, set(list(p.iterobjects())))
+        with self.get_pack(pack1_sha) as p:
+            expected = set([p[s] for s in [commit_sha, tree_sha, a_sha]])
+            self.assertEqual(expected, set(list(p.iterobjects())))
 
     def test_pack_tuples(self):
-        p = self.get_pack(pack1_sha)
-        tuples = p.pack_tuples()
-        expected = set([(p[s], None) for s in [commit_sha, tree_sha, a_sha]])
-        self.assertEqual(expected, set(list(tuples)))
-        self.assertEqual(expected, set(list(tuples)))
-        self.assertEqual(3, len(tuples))
+        with self.get_pack(pack1_sha) as p:
+            tuples = p.pack_tuples()
+            expected = set([(p[s], None) for s in [commit_sha, tree_sha, a_sha]])
+            self.assertEqual(expected, set(list(tuples)))
+            self.assertEqual(expected, set(list(tuples)))
+            self.assertEqual(3, len(tuples))
 
     def test_get_object_at(self):
         """Tests random access for non-delta objects"""
-        p = self.get_pack(pack1_sha)
-        obj = p[a_sha]
-        self.assertEqual(obj.type_name, 'blob')
-        self.assertEqual(obj.sha().hexdigest(), a_sha)
-        obj = p[tree_sha]
-        self.assertEqual(obj.type_name, 'tree')
-        self.assertEqual(obj.sha().hexdigest(), tree_sha)
-        obj = p[commit_sha]
-        self.assertEqual(obj.type_name, 'commit')
-        self.assertEqual(obj.sha().hexdigest(), commit_sha)
+        with self.get_pack(pack1_sha) as p:
+            obj = p[a_sha]
+            self.assertEqual(obj.type_name, 'blob')
+            self.assertEqual(obj.sha().hexdigest(), a_sha)
+            obj = p[tree_sha]
+            self.assertEqual(obj.type_name, 'tree')
+            self.assertEqual(obj.sha().hexdigest(), tree_sha)
+            obj = p[commit_sha]
+            self.assertEqual(obj.type_name, 'commit')
+            self.assertEqual(obj.sha().hexdigest(), commit_sha)
 
     def test_copy(self):
         origpack = self.get_pack(pack1_sha)
@@ -328,11 +328,11 @@ class TestPack(PackTests):
             origpack.close()
 
     def test_commit_obj(self):
-        p = self.get_pack(pack1_sha)
-        commit = p[commit_sha]
-        self.assertEqual('James Westby <jw+debian@jameswestby.net>',
-                          commit.author)
-        self.assertEqual([], commit.parents)
+        with self.get_pack(pack1_sha) as p:
+            commit = p[commit_sha]
+            self.assertEqual('James Westby <jw+debian@jameswestby.net>',
+                             commit.author)
+            self.assertEqual([], commit.parents)
 
     def _copy_pack(self, origpack):
         basename = os.path.join(self.tempdir, 'somepack')
@@ -340,10 +340,12 @@ class TestPack(PackTests):
         return Pack(basename)
 
     def test_keep_no_message(self):
-        p = self.get_pack(pack1_sha)
-        p = self._copy_pack(p)
+        with self.get_pack(pack1_sha) as p:
+            p = self._copy_pack(p)
+
+        with p:
+            keepfile_name = p.keep()
 
-        keepfile_name = p.keep()
         # file should exist
         self.assertTrue(os.path.exists(keepfile_name))
 
@@ -355,11 +357,12 @@ class TestPack(PackTests):
             f.close()
 
     def test_keep_message(self):
-        p = self.get_pack(pack1_sha)
-        p = self._copy_pack(p)
+        with self.get_pack(pack1_sha) as p:
+            p = self._copy_pack(p)
 
         msg = 'some message'
-        keepfile_name = p.keep(msg)
+        with p:
+            keepfile_name = p.keep(msg)
 
         # file should exist
         self.assertTrue(os.path.exists(keepfile_name))
@@ -373,46 +376,46 @@ class TestPack(PackTests):
             f.close()
 
     def test_name(self):
-        p = self.get_pack(pack1_sha)
-        self.assertEqual(pack1_sha, p.name())
+        with self.get_pack(pack1_sha) as p:
+            self.assertEqual(pack1_sha, p.name())
 
     def test_length_mismatch(self):
-        data = self.get_pack_data(pack1_sha)
-        index = self.get_pack_index(pack1_sha)
-        Pack.from_objects(data, index).check_length_and_checksum()
-
-        data._file.seek(12)
-        bad_file = BytesIO()
-        write_pack_header(bad_file, 9999)
-        bad_file.write(data._file.read())
-        bad_file = BytesIO(bad_file.getvalue())
-        bad_data = PackData('', file=bad_file)
-        bad_pack = Pack.from_lazy_objects(lambda: bad_data, lambda: index)
-        self.assertRaises(AssertionError, lambda: bad_pack.data)
-        self.assertRaises(AssertionError,
-                          lambda: bad_pack.check_length_and_checksum())
+        with self.get_pack_data(pack1_sha) as data:
+            index = self.get_pack_index(pack1_sha)
+            Pack.from_objects(data, index).check_length_and_checksum()
+
+            data._file.seek(12)
+            bad_file = BytesIO()
+            write_pack_header(bad_file, 9999)
+            bad_file.write(data._file.read())
+            bad_file = BytesIO(bad_file.getvalue())
+            bad_data = PackData('', file=bad_file)
+            bad_pack = Pack.from_lazy_objects(lambda: bad_data, lambda: index)
+            self.assertRaises(AssertionError, lambda: bad_pack.data)
+            self.assertRaises(AssertionError,
+                              lambda: bad_pack.check_length_and_checksum())
 
     def test_checksum_mismatch(self):
-        data = self.get_pack_data(pack1_sha)
-        index = self.get_pack_index(pack1_sha)
-        Pack.from_objects(data, index).check_length_and_checksum()
-
-        data._file.seek(0)
-        bad_file = BytesIO(data._file.read()[:-20] + ('\xff' * 20))
-        bad_data = PackData('', file=bad_file)
-        bad_pack = Pack.from_lazy_objects(lambda: bad_data, lambda: index)
-        self.assertRaises(ChecksumMismatch, lambda: bad_pack.data)
-        self.assertRaises(ChecksumMismatch, lambda:
-                          bad_pack.check_length_and_checksum())
+        with self.get_pack_data(pack1_sha) as data:
+            index = self.get_pack_index(pack1_sha)
+            Pack.from_objects(data, index).check_length_and_checksum()
+
+            data._file.seek(0)
+            bad_file = BytesIO(data._file.read()[:-20] + ('\xff' * 20))
+            bad_data = PackData('', file=bad_file)
+            bad_pack = Pack.from_lazy_objects(lambda: bad_data, lambda: index)
+            self.assertRaises(ChecksumMismatch, lambda: bad_pack.data)
+            self.assertRaises(ChecksumMismatch, lambda:
+                              bad_pack.check_length_and_checksum())
 
     def test_iterobjects_2(self):
-        p = self.get_pack(pack1_sha)
-        objs = dict((o.id, o) for o in p.iterobjects())
-        self.assertEqual(3, len(objs))
-        self.assertEqual(sorted(objs), sorted(p.index))
-        self.assertTrue(isinstance(objs[a_sha], Blob))
-        self.assertTrue(isinstance(objs[tree_sha], Tree))
-        self.assertTrue(isinstance(objs[commit_sha], Commit))
+        with self.get_pack(pack1_sha) as p:
+            objs = dict((o.id, o) for o in p.iterobjects())
+            self.assertEqual(3, len(objs))
+            self.assertEqual(sorted(objs), sorted(p.index))
+            self.assertTrue(isinstance(objs[a_sha], Blob))
+            self.assertTrue(isinstance(objs[tree_sha], Tree))
+            self.assertTrue(isinstance(objs[commit_sha], Commit))
 
 
 class TestThinPack(PackTests):
@@ -443,10 +446,10 @@ class TestThinPack(PackTests):
             f.close()
 
         # Index the new pack.
-        pack = self.make_pack(True)
-        data = PackData(pack._data_path)
-        data.pack = pack
-        data.create_index(self.pack_prefix + '.idx')
+        with self.make_pack(True) as pack:
+            with PackData(pack._data_path) as data:
+                data.pack = pack
+                data.create_index(self.pack_prefix + '.idx')
 
         del self.store[self.blobs['bar'].id]
 
@@ -456,18 +459,22 @@ class TestThinPack(PackTests):
             resolve_ext_ref=self.store.get_raw if resolve_ext_ref else None)
 
     def test_get_raw(self):
-        self.assertRaises(
-            KeyError, self.make_pack(False).get_raw, self.blobs['foo1234'].id)
-        self.assertEqual(
-            (3, 'foo1234'),
-            self.make_pack(True).get_raw(self.blobs['foo1234'].id))
+        with self.make_pack(False) as p:
+            self.assertRaises(
+                KeyError, p.get_raw, self.blobs['foo1234'].id)
+        with self.make_pack(True) as p:
+            self.assertEqual(
+                (3, 'foo1234'),
+                p.get_raw(self.blobs['foo1234'].id))
 
     def test_iterobjects(self):
-        self.assertRaises(KeyError, list, self.make_pack(False).iterobjects())
-        self.assertEqual(
-            sorted([self.blobs['foo1234'].id, self.blobs['bar'].id,
-                    self.blobs['bar2468'].id]),
-            sorted(o.id for o in self.make_pack(True).iterobjects()))
+        with self.make_pack(False) as p:
+            self.assertRaises(KeyError, list, p.iterobjects())
+        with self.make_pack(True) as p:
+            self.assertEqual(
+                sorted([self.blobs['foo1234'].id, self.blobs[b'bar'].id,
+                        self.blobs['bar2468'].id]),
+                sorted(o.id for o in p.iterobjects()))
 
 
 class WritePackTests(TestCase):

+ 6 - 4
dulwich/tests/test_porcelain.py

@@ -281,8 +281,9 @@ class SymbolicRefTests(PorcelainTestCase):
         porcelain.symbolic_ref(self.repo.path, 'force_foobar', force=True)
 
         #test if we actually changed the file
-        new_ref = self.repo.get_named_file('HEAD').read()
-        self.assertEqual(new_ref, 'ref: refs/heads/force_foobar\n')
+        with self.repo.get_named_file('HEAD') as f:
+            new_ref = f.read()
+        self.assertEqual(new_ref, b'ref: refs/heads/force_foobar\n')
 
     def test_set_symbolic_ref(self):
         c1, c2, c3 = build_commit_graph(self.repo.object_store, [[1], [2, 1],
@@ -300,8 +301,9 @@ class SymbolicRefTests(PorcelainTestCase):
         porcelain.symbolic_ref(self.repo.path, 'develop')
 
         #test if we actually changed the file
-        new_ref = self.repo.get_named_file('HEAD').read()
-        self.assertEqual(new_ref, 'ref: refs/heads/develop\n')
+        with self.repo.get_named_file('HEAD') as f:
+            new_ref = f.read()
+        self.assertEqual(new_ref, b'ref: refs/heads/develop\n')
 
 
 class DiffTreeTests(PorcelainTestCase):

+ 8 - 8
dulwich/tests/test_server.py

@@ -888,10 +888,10 @@ class UpdateServerInfoTests(TestCase):
 
     def test_empty(self):
         update_server_info(self.repo)
-        self.assertEqual("",
-            open(os.path.join(self.path, ".git", "info", "refs"), 'r').read())
-        self.assertEqual("",
-            open(os.path.join(self.path, ".git", "objects", "info", "packs"), 'r').read())
+        with open(os.path.join(self.path, ".git", "info", "refs"), 'rb') as f:
+            self.assertEqual(b'', f.read())
+        with open(os.path.join(self.path, ".git", "objects", "info", "packs"), 'rb') as f:
+            self.assertEqual(b'', f.read())
 
     def test_simple(self):
         commit_id = self.repo.do_commit(
@@ -899,7 +899,7 @@ class UpdateServerInfoTests(TestCase):
             committer="Joe Example <joe@example.com>",
             ref="refs/heads/foo")
         update_server_info(self.repo)
-        ref_text = open(os.path.join(self.path, ".git", "info", "refs"), 'r').read()
-        self.assertEqual(ref_text, "%s\trefs/heads/foo\n" % commit_id)
-        packs_text = open(os.path.join(self.path, ".git", "objects", "info", "packs"), 'r').read()
-        self.assertEqual(packs_text, "")
+        with open(os.path.join(self.path, ".git", "info", "refs"), 'rb') as f:
+            self.assertEqual(f.read(), commit_id + b'\trefs/heads/foo\n')
+        with open(os.path.join(self.path, ".git", "objects", "info", "packs"), 'rb') as f:
+            self.assertEqual(f.read(), b'')