Browse Source

Addressing PR comments

 - Fix docstring
 - Less confusing abspath/split/iterate code
 - More tests
bmcorser 9 years ago
parent
commit
d64b632e40
2 changed files with 17 additions and 8 deletions
  1. 8 7
      dulwich/repo.py
  2. 9 1
      dulwich/tests/test_repository.py

+ 8 - 7
dulwich/repo.py

@@ -679,19 +679,20 @@ class Repo(BaseRepo):
         self.hooks['post-commit'] = PostCommitShellHook(self.controldir())
 
     @classmethod
-    def discover(cls, start):
-        """
+    def discover(cls, start='.'):
+        """Iterate parent directories to discover a repository
+
         Return a Repo object for the first parent directory that looks like a
         Git repository.
 
-        :param start: The directory to start discovery from
+        :param start: The directory to start discovery from (defaults to '.')
         """
-        abs_split = os.path.abspath(start)[1:].split(os.path.sep)
-        for _ in range(len(abs_split)):
+        path = os.path.abspath(start)
+        while path != '/':
             try:
-                return cls(os.path.join('/', *abs_split))
+                return cls(path)
             except NotGitRepository:
-                abs_split.pop()
+                path, _ = os.path.split(path)
         raise NotGitRepository(
             "No git repository was found at %(path)s" % dict(path=start)
         )

+ 9 - 1
dulwich/tests/test_repository.py

@@ -35,6 +35,7 @@ from dulwich.object_store import (
     )
 from dulwich import objects
 from dulwich.config import Config
+from dulwich.errors import NotGitRepository
 from dulwich.repo import (
     Repo,
     MemoryRepo,
@@ -766,7 +767,14 @@ class BuildRepoRootTests(TestCase):
             self.assertEqual(stat.S_IFREG | 0o644, mode)
             self.assertEqual(encoding.encode('ascii'), r[id].data)
 
-    def test_discover(self):
+    def test_discover_intended(self):
         path = os.path.join(self._repo_dir, 'b/c')
         r = Repo.discover(path)
         self.assertEqual(r.head(), self._repo.head())
+
+    def test_discover_isrepo(self):
+        r = Repo.discover(self._repo_dir)
+        self.assertEqual(r.head(), self._repo.head())
+
+    def test_discover_notrepo(self):
+        self.assertRaises(NotGitRepository, Repo.discover('/'))