Browse Source

client: Support smart server access over https, pass URLs around rather than tuples with hostname/port/username/etc.

Jelmer Vernooij 13 years ago
parent
commit
528cbe0287
2 changed files with 20 additions and 25 deletions
  1. 18 22
      dulwich/client.py
  2. 2 3
      dulwich/tests/compat/test_client.py

+ 18 - 22
dulwich/client.py

@@ -573,31 +573,27 @@ class SSHGitClient(TraditionalGitClient):
 
 class HttpGitClient(GitClient):
 
-    def __init__(self, host, port=None, username=None, password=None, dumb=None, *args, **kwargs):
-        self.host = host
-        self.port = port
+    def __init__(self, base_url, dumb=None, *args, **kwargs):
+        self.base_url = base_url
         self.dumb = dumb
-        self.username = username
-        self.password = password
-        netloc = self.host
-        if self.port:
-            netloc += ":%d" % self.port
-        self.url = "http://%s/" % netloc
         GitClient.__init__(self, *args, **kwargs)
 
-    @classmethod
-    def from_url(cls, url):
-        parsed = urlparse.urlparse(url)
-        assert parsed.scheme == 'http'
-        return cls(parsed.hostname, port=parsed.port, username=parsed.port,
-                   password=parsed.password), parsed.path
+    def _perform(self, req):
+        """Perform a HTTP request.
+
+        This is provided so subclasses can provide their own version.
+
+        :param req: urllib2.Request instance
+        :return: matching response
+        """
+        return urllib2.urlopen(req)
 
     def _discover_references(self, service, url):
         url = urlparse.urljoin(url+"/", "info/refs")
         if not self.dumb:
             url += "?service=%s" % service
         req = urllib2.Request(url)
-        resp = urllib2.urlopen(req)
+        resp = self._perform(req)
         if resp.getcode() == 404:
             raise NotGitRepository()
         if resp.getcode() != 200:
@@ -618,7 +614,7 @@ class HttpGitClient(GitClient):
         req = urllib2.Request(url,
             headers={"Content-Type": "application/x-%s-request" % service},
             data=data)
-        resp = urllib2.urlopen(req)
+        resp = self._perform(req)
         if resp.getcode() == 404:
             raise NotGitRepository()
         if resp.getcode() != 200:
@@ -642,7 +638,7 @@ class HttpGitClient(GitClient):
         :raises UpdateRefsError: if the server supports report-status
                                  and rejects ref updates
         """
-        url = urlparse.urljoin(self.url, path)
+        url = urlparse.urljoin(self.base_url, path)
         old_refs, server_capabilities = self._discover_references("git-receive-pack", url)
         negotiated_capabilities = list(self._send_capabilities)
         new_refs = determine_wants(old_refs)
@@ -674,7 +670,7 @@ class HttpGitClient(GitClient):
         :param pack_data: Callback called for each bit of data in the pack
         :param progress: Callback for progress reports (strings)
         """
-        url = urlparse.urljoin(self.url, path)
+        url = urlparse.urljoin(self.base_url, path)
         refs, server_capabilities = self._discover_references(
             "git-upload-pack", url)
         negotiated_capabilities = list(server_capabilities)
@@ -708,9 +704,9 @@ def get_transport_and_path(uri):
     elif parsed.scheme == 'git+ssh':
         return SSHGitClient(parsed.hostname, port=parsed.port,
                             username=parsed.username), parsed.path
-    elif parsed.scheme == 'http':
-        return HttpGitClient(parsed.hostname, port=parsed.port,
-                             username=parsed.username), parsed.path
+    elif parsed.scheme in ('http', 'https'):
+        return HttpGitClient(urlparse.urlunparse(
+            parsed.scheme, parsed.netloc, path='/'))
 
     if parsed.scheme and not parsed.netloc:
         # SSH with no user@, zero or one leading slash.

+ 2 - 3
dulwich/tests/compat/test_client.py

@@ -419,8 +419,7 @@ class DulwichHttpClientTest(CompatTestCase, DulwichClientTestBase):
         CompatTestCase.tearDown(self)
 
     def _client(self):
-        ret, self._path = client.HttpGitClient.from_url(self._httpd.get_url())
-        return ret
+        return client.HttpGitClient(self._httpd.get_url())
 
     def _build_path(self, path):
-        return urlparse.urljoin(self._path.strip("/"), path.strip("/"))
+        return path