فهرست منبع

Close files for Protocol objects.

Gary van der Merwe 11 سال پیش
والد
کامیت
c506678ad5
3فایلهای تغییر یافته به همراه152 افزوده شده و 123 حذف شده
  1. 139 122
      dulwich/client.py
  2. 12 1
      dulwich/protocol.py
  3. 1 0
      dulwich/tests/test_client.py

+ 139 - 122
dulwich/client.py

@@ -438,67 +438,68 @@ class TraditionalGitClient(GitClient):
                                  and rejects ref updates
         """
         proto, unused_can_read = self._connect('receive-pack', path)
-        old_refs, server_capabilities = read_pkt_refs(proto)
-        negotiated_capabilities = self._send_capabilities & server_capabilities
-
-        if 'report-status' in negotiated_capabilities:
-            self._report_status_parser = ReportStatusParser()
-        report_status_parser = self._report_status_parser
+        with proto:
+            old_refs, server_capabilities = read_pkt_refs(proto)
+            negotiated_capabilities = self._send_capabilities & server_capabilities
 
-        try:
-            new_refs = orig_new_refs = determine_wants(dict(old_refs))
-        except:
-            proto.write_pkt_line(None)
-            raise
-
-        if not 'delete-refs' in server_capabilities:
-            # Server does not support deletions. Fail later.
-            def remove_del(pair):
-                if pair[1] == ZERO_SHA:
-                    if 'report-status' in negotiated_capabilities:
-                        report_status_parser._ref_statuses.append(
-                            'ng %s remote does not support deleting refs'
-                            % pair[1])
-                        report_status_parser._ref_status_ok = False
-                    return False
-                else:
-                    return True
-
-            new_refs = dict(
-                filter(
-                    remove_del,
-                    [(ref, sha) for ref, sha in new_refs.iteritems()]))
-
-        if new_refs is None:
-            proto.write_pkt_line(None)
-            return old_refs
+            if 'report-status' in negotiated_capabilities:
+                self._report_status_parser = ReportStatusParser()
+            report_status_parser = self._report_status_parser
 
-        if len(new_refs) == 0 and len(orig_new_refs):
-            # NOOP - Original new refs filtered out by policy
-            proto.write_pkt_line(None)
-            if self._report_status_parser is not None:
-                self._report_status_parser.check()
-            return old_refs
-
-        (have, want) = self._handle_receive_pack_head(
-            proto, negotiated_capabilities, old_refs, new_refs)
-        if not want and old_refs == new_refs:
-            return new_refs
-        objects = generate_pack_contents(have, want)
-        if len(objects) > 0:
-            entries, sha = write_pack_objects(proto.write_file(), objects)
-        elif len(set(new_refs.values()) - set([ZERO_SHA])) > 0:
-            # Check for valid create/update refs
-            filtered_new_refs = \
-                dict([(ref, sha) for ref, sha in new_refs.iteritems()
-                     if sha != ZERO_SHA])
-            if len(set(filtered_new_refs.iteritems()) -
-                    set(old_refs.iteritems())) > 0:
+            try:
+                new_refs = orig_new_refs = determine_wants(dict(old_refs))
+            except:
+                proto.write_pkt_line(None)
+                raise
+
+            if not 'delete-refs' in server_capabilities:
+                # Server does not support deletions. Fail later.
+                def remove_del(pair):
+                    if pair[1] == ZERO_SHA:
+                        if 'report-status' in negotiated_capabilities:
+                            report_status_parser._ref_statuses.append(
+                                'ng %s remote does not support deleting refs'
+                                % pair[1])
+                            report_status_parser._ref_status_ok = False
+                        return False
+                    else:
+                        return True
+
+                new_refs = dict(
+                    filter(
+                        remove_del,
+                        [(ref, sha) for ref, sha in new_refs.iteritems()]))
+
+            if new_refs is None:
+                proto.write_pkt_line(None)
+                return old_refs
+
+            if len(new_refs) == 0 and len(orig_new_refs):
+                # NOOP - Original new refs filtered out by policy
+                proto.write_pkt_line(None)
+                if self._report_status_parser is not None:
+                    self._report_status_parser.check()
+                return old_refs
+
+            (have, want) = self._handle_receive_pack_head(
+                proto, negotiated_capabilities, old_refs, new_refs)
+            if not want and old_refs == new_refs:
+                return new_refs
+            objects = generate_pack_contents(have, want)
+            if len(objects) > 0:
                 entries, sha = write_pack_objects(proto.write_file(), objects)
-
-        self._handle_receive_pack_tail(
-            proto, negotiated_capabilities, progress)
-        return new_refs
+            elif len(set(new_refs.values()) - set([ZERO_SHA])) > 0:
+                # Check for valid create/update refs
+                filtered_new_refs = \
+                    dict([(ref, sha) for ref, sha in new_refs.iteritems()
+                         if sha != ZERO_SHA])
+                if len(set(filtered_new_refs.iteritems()) -
+                        set(old_refs.iteritems())) > 0:
+                    entries, sha = write_pack_objects(proto.write_file(), objects)
+
+            self._handle_receive_pack_tail(
+                proto, negotiated_capabilities, progress)
+            return new_refs
 
     def fetch_pack(self, path, determine_wants, graph_walker, pack_data,
                    progress=None):
@@ -510,47 +511,49 @@ class TraditionalGitClient(GitClient):
         :param progress: Callback for progress reports (strings)
         """
         proto, can_read = self._connect('upload-pack', path)
-        refs, server_capabilities = read_pkt_refs(proto)
-        negotiated_capabilities = (
-            self._fetch_capabilities & server_capabilities)
+        with proto:
+            refs, server_capabilities = read_pkt_refs(proto)
+            negotiated_capabilities = (
+                self._fetch_capabilities & server_capabilities)
 
-        if refs is None:
-            proto.write_pkt_line(None)
-            return refs
+            if refs is None:
+                proto.write_pkt_line(None)
+                return refs
 
-        try:
-            wants = determine_wants(refs)
-        except:
-            proto.write_pkt_line(None)
-            raise
-        if wants is not None:
-            wants = [cid for cid in wants if cid != ZERO_SHA]
-        if not wants:
-            proto.write_pkt_line(None)
+            try:
+                wants = determine_wants(refs)
+            except:
+                proto.write_pkt_line(None)
+                raise
+            if wants is not None:
+                wants = [cid for cid in wants if cid != ZERO_SHA]
+            if not wants:
+                proto.write_pkt_line(None)
+                return refs
+            self._handle_upload_pack_head(
+                proto, negotiated_capabilities, graph_walker, wants, can_read)
+            self._handle_upload_pack_tail(
+                proto, negotiated_capabilities, graph_walker, pack_data, progress)
             return refs
-        self._handle_upload_pack_head(
-            proto, negotiated_capabilities, graph_walker, wants, can_read)
-        self._handle_upload_pack_tail(
-            proto, negotiated_capabilities, graph_walker, pack_data, progress)
-        return refs
 
     def archive(self, path, committish, write_data, progress=None):
-        proto, can_read = self._connect('upload-archive', path)
-        proto.write_pkt_line("argument %s" % committish)
-        proto.write_pkt_line(None)
-        pkt = proto.read_pkt_line()
-        if pkt == "NACK\n":
-            return
-        elif pkt == "ACK\n":
-            pass
-        elif pkt.startswith("ERR "):
-            raise GitProtocolError(pkt[4:].rstrip("\n"))
-        else:
-            raise AssertionError("invalid response %r" % pkt)
-        ret = proto.read_pkt_line()
-        if ret is not None:
-            raise AssertionError("expected pkt tail")
-        self._read_side_band64k_data(proto, {1: write_data, 2: progress})
+        proto, can_read = self._connect(b'upload-archive', path)
+        with proto:
+            proto.write_pkt_line("argument %s" % committish)
+            proto.write_pkt_line(None)
+            pkt = proto.read_pkt_line()
+            if pkt == "NACK\n":
+                return
+            elif pkt == "ACK\n":
+                pass
+            elif pkt.startswith("ERR "):
+                raise GitProtocolError(pkt[4:].rstrip("\n"))
+            else:
+                raise AssertionError("invalid response %r" % pkt)
+            ret = proto.read_pkt_line()
+            if ret is not None:
+                raise AssertionError("expected pkt tail")
+            self._read_side_band64k_data(proto, {1: write_data, 2: progress})
 
 
 class TCPGitClient(TraditionalGitClient):
@@ -584,7 +587,12 @@ class TCPGitClient(TraditionalGitClient):
         rfile = s.makefile('rb', -1)
         # 0 means unbuffered
         wfile = s.makefile('wb', 0)
-        proto = Protocol(rfile.read, wfile.write,
+        def close():
+            rfile.close()
+            wfile.close()
+            s.close()
+
+        proto = Protocol(rfile.read, wfile.write, close,
                          report_activity=self._report_activity)
         if path.startswith("/~"):
             path = path[1:]
@@ -612,6 +620,8 @@ class SubprocessWrapper(object):
     def close(self):
         self.proc.stdin.close()
         self.proc.stdout.close()
+        if self.proc.stderr:
+            self.proc.stderr.close()
         self.proc.wait()
 
 
@@ -633,7 +643,7 @@ class SubprocessGitClient(TraditionalGitClient):
             subprocess.Popen(argv, bufsize=0, stdin=subprocess.PIPE,
                              stdout=subprocess.PIPE,
                              stderr=self._stderr))
-        return Protocol(p.read, p.write,
+        return Protocol(p.read, p.write, p.close,
                         report_activity=self._report_activity), p.can_read
 
 
@@ -828,9 +838,6 @@ else:
             self.channel.close()
             self.stop_monitoring()
 
-        def __del__(self):
-            self.close()
-
     class ParamikoSSHVendor(object):
 
         def __init__(self):
@@ -882,9 +889,9 @@ class SSHGitClient(TraditionalGitClient):
         con = get_ssh_vendor().run_command(
             self.host, ["%s '%s'" % (self._get_cmd_path(cmd), path)],
             port=self.port, username=self.username)
-        return (Protocol(
-            con.read, con.write, report_activity=self._report_activity),
-            con.can_read)
+        return (Protocol(con.read, con.write, con.close, 
+                         report_activity=self._report_activity), 
+                con.can_read)
 
 
 def default_user_agent_string():
@@ -944,17 +951,20 @@ class HttpGitClient(GitClient):
             url += "?service=%s" % service
             headers["Content-Type"] = "application/x-%s-request" % service
         resp = self._http_request(url, headers)
-        self.dumb = (not resp.info().gettype().startswith("application/x-git-"))
-        if not self.dumb:
-            proto = Protocol(resp.read, None)
-            # The first line should mention the service
-            pkts = list(proto.read_pkt_seq())
-            if pkts != [('# service=%s\n' % service)]:
-                raise GitProtocolError(
-                    "unexpected first line %r from smart server" % pkts)
-            return read_pkt_refs(proto)
-        else:
-            return read_info_refs(resp), set()
+        try:
+            self.dumb = (not resp.info().gettype().startswith("application/x-git-"))
+            if not self.dumb:
+                proto = Protocol(resp.read, None)
+                # The first line should mention the service
+                pkts = list(proto.read_pkt_seq())
+                if pkts != [('# service=%s\n' % service)]:
+                    raise GitProtocolError(
+                        "unexpected first line %r from smart server" % pkts)
+                return read_pkt_refs(proto)
+            else:
+                return read_info_refs(resp), set()
+        finally:
+            resp.close()
 
     def _smart_request(self, service, url, data):
         assert url[-1] == "/"
@@ -1002,11 +1012,15 @@ class HttpGitClient(GitClient):
         if len(objects) > 0:
             entries, sha = write_pack_objects(req_proto.write_file(), objects)
         resp = self._smart_request("git-receive-pack", url,
-            data=req_data.getvalue())
-        resp_proto = Protocol(resp.read, None)
-        self._handle_receive_pack_tail(resp_proto, negotiated_capabilities,
-            progress)
-        return new_refs
+                                   data=req_data.getvalue())
+        try:
+            resp_proto = Protocol(resp.read, None)
+            self._handle_receive_pack_tail(resp_proto, negotiated_capabilities,
+                progress)
+            return new_refs
+        finally:
+            resp.close()
+
 
     def fetch_pack(self, path, determine_wants, graph_walker, pack_data,
                    progress=None):
@@ -1036,10 +1050,13 @@ class HttpGitClient(GitClient):
             lambda: False)
         resp = self._smart_request(
             "git-upload-pack", url, data=req_data.getvalue())
-        resp_proto = Protocol(resp.read, None)
-        self._handle_upload_pack_tail(resp_proto, negotiated_capabilities,
-            graph_walker, pack_data, progress)
-        return refs
+        try:
+            resp_proto = Protocol(resp.read, None)
+            self._handle_upload_pack_tail(resp_proto, negotiated_capabilities,
+                graph_walker, pack_data, progress)
+            return refs
+        finally:
+            resp.close()
 
 
 def get_transport_and_path_from_url(url, config=None, **kwargs):

+ 12 - 1
dulwich/protocol.py

@@ -77,12 +77,23 @@ class Protocol(object):
         Documentation/technical/protocol-common.txt
     """
 
-    def __init__(self, read, write, report_activity=None):
+    def __init__(self, read, write, close=None, report_activity=None):
         self.read = read
         self.write = write
+        self._close = close
         self.report_activity = report_activity
         self._readahead = None
 
+    def close(self):
+        if self._close:
+            self._close()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.close()
+
     def read_pkt_line(self):
         """Reads a pkt-line from the remote git process.
 

+ 1 - 0
dulwich/tests/test_client.py

@@ -492,6 +492,7 @@ class TestSSHVendor(object):
         class Subprocess: pass
         setattr(Subprocess, 'read', lambda: None)
         setattr(Subprocess, 'write', lambda: None)
+        setattr(Subprocess, 'close', lambda: None)
         setattr(Subprocess, 'can_read', lambda: None)
         return Subprocess()