Explorar o código

client: Rewrite URI parsing in get_transport_and_path.

Fixes bug 568493.

Change-Id: I1382d70f0357e8754863e2dbebffa861c0a69504
Signed-off-by: Jelmer Vernooij <jelmer@samba.org>
Dave Borowitz %!s(int64=14) %!d(string=hai) anos
pai
achega
a9e4643b42
Modificáronse 2 ficheiros con 81 adicións e 7 borrados
  1. 20 7
      dulwich/client.py
  2. 61 0
      dulwich/tests/test_client.py

+ 20 - 7
dulwich/client.py

@@ -24,6 +24,7 @@ __docformat__ = 'restructuredText'
 import select
 import socket
 import subprocess
+import urlparse
 
 from dulwich.errors import (
     SendPackError,
@@ -358,11 +359,23 @@ def get_transport_and_path(uri):
     :param uri: URI or path
     :return: Tuple with client instance and relative path.
     """
-    from dulwich.client import TCPGitClient, SSHGitClient, SubprocessGitClient
-    for handler, transport in (("git://", TCPGitClient), ("git+ssh://", SSHGitClient)):
-        if uri.startswith(handler):
-            host, path = uri[len(handler):].split("/", 1)
-            return transport(host), "/"+path
-    # FIXME: Parse rsync-like git URLs (user@host:/path), bug 568493
-    # if its not git or git+ssh, try a local url..
+    parsed = urlparse.urlparse(uri)
+    if parsed.scheme == 'git':
+        return TCPGitClient(parsed.hostname, port=parsed.port), parsed.path
+    elif parsed.scheme == 'git+ssh':
+        return SSHGitClient(parsed.hostname, port=parsed.port,
+                            username=parsed.username), parsed.path
+
+    if parsed.scheme and not parsed.netloc:
+        # SSH with no user@, zero or one leading slash.
+        return SSHGitClient(parsed.scheme), parsed.path
+    elif parsed.scheme:
+        raise ValueError('Unknown git protocol scheme: %s' % parsed.scheme)
+    elif '@' in parsed.path and ':' in parsed.path:
+        # SSH with user@host:foo.
+        user_host, path = parsed.path.split(':')
+        user, host = user_host.rsplit('@')
+        return SSHGitClient(host, username=user), path
+
+    # Otherwise, assume it's a local path.
     return SubprocessGitClient(), uri

+ 61 - 0
dulwich/tests/test_client.py

@@ -20,12 +20,16 @@ from cStringIO import StringIO
 
 from dulwich.client import (
     GitClient,
+    TCPGitClient,
+    SubprocessGitClient,
     SSHGitClient,
+    get_transport_and_path,
     )
 from dulwich.tests import (
     TestCase,
     )
 from dulwich.protocol import (
+    TCP_GIT_PORT,
     Protocol,
     )
 
@@ -66,6 +70,63 @@ class GitClientTests(TestCase):
         self.client.fetch_pack("bla", lambda heads: [], None, None, None)
         self.assertEquals(self.rout.getvalue(), "0000")
 
+    def test_get_transport_and_path_tcp(self):
+        client, path = get_transport_and_path('git://foo.com/bar/baz')
+        self.assertTrue(isinstance(client, TCPGitClient))
+        self.assertEquals('foo.com', client._host)
+        self.assertEquals(TCP_GIT_PORT, client._port)
+        self.assertEqual('/bar/baz', path)
+
+        client, path = get_transport_and_path('git://foo.com:1234/bar/baz')
+        self.assertTrue(isinstance(client, TCPGitClient))
+        self.assertEquals('foo.com', client._host)
+        self.assertEquals(1234, client._port)
+        self.assertEqual('/bar/baz', path)
+
+    def test_get_transport_and_path_ssh_explicit(self):
+        client, path = get_transport_and_path('git+ssh://foo.com/bar/baz')
+        self.assertTrue(isinstance(client, SSHGitClient))
+        self.assertEquals('foo.com', client.host)
+        self.assertEquals(None, client.port)
+        self.assertEquals(None, client.username)
+        self.assertEqual('/bar/baz', path)
+
+        client, path = get_transport_and_path('git+ssh://foo.com:1234/bar/baz')
+        self.assertTrue(isinstance(client, SSHGitClient))
+        self.assertEquals('foo.com', client.host)
+        self.assertEquals(1234, client.port)
+        self.assertEqual('/bar/baz', path)
+
+    def test_get_transport_and_path_ssh_implicit(self):
+        client, path = get_transport_and_path('foo:/bar/baz')
+        self.assertTrue(isinstance(client, SSHGitClient))
+        self.assertEquals('foo', client.host)
+        self.assertEquals(None, client.port)
+        self.assertEquals(None, client.username)
+        self.assertEqual('/bar/baz', path)
+
+        client, path = get_transport_and_path('foo.com:/bar/baz')
+        self.assertTrue(isinstance(client, SSHGitClient))
+        self.assertEquals('foo.com', client.host)
+        self.assertEquals(None, client.port)
+        self.assertEquals(None, client.username)
+        self.assertEqual('/bar/baz', path)
+
+        client, path = get_transport_and_path('user@foo.com:/bar/baz')
+        self.assertTrue(isinstance(client, SSHGitClient))
+        self.assertEquals('foo.com', client.host)
+        self.assertEquals(None, client.port)
+        self.assertEquals('user', client.username)
+        self.assertEqual('/bar/baz', path)
+
+    def test_get_transport_and_path_subprocess(self):
+        client, path = get_transport_and_path('foo.bar/baz')
+        self.assertTrue(isinstance(client, SubprocessGitClient))
+        self.assertEquals('foo.bar/baz', path)
+
+    def test_get_transport_and_path_error(self):
+        self.assertRaises(ValueError, get_transport_and_path, 'foo://bar/baz')
+
 
 class SSHGitClientTests(TestCase):