Browse Source

gc: improve test coverage (#1679)

Jelmer Vernooij 1 month ago
parent
commit
cb986035b0
1 changed files with 248 additions and 1 deletions
  1. 248 1
      tests/test_gc.py

+ 248 - 1
tests/test_gc.py

@@ -17,7 +17,7 @@ from dulwich.gc import (
     prune_unreachable_objects,
     should_run_gc,
 )
-from dulwich.objects import Blob, Commit, Tree
+from dulwich.objects import Blob, Commit, Tag, Tree
 from dulwich.repo import MemoryRepo, Repo
 
 
@@ -271,6 +271,151 @@ class GCTestCase(TestCase):
         self.assertGreater(stats.bytes_freed, 0)
         self.assertNotIn(unreachable_blob.id, self.repo.object_store)
 
+    def test_garbage_collect_with_progress(self):
+        """Test garbage collection with progress callback."""
+        # Create some objects
+        blob = Blob.from_string(b"test content")
+        self.repo.object_store.add_object(blob)
+
+        tree = Tree()
+        tree.add(b"test.txt", 0o100644, blob.id)
+        self.repo.object_store.add_object(tree)
+
+        commit = Commit()
+        commit.tree = tree.id
+        commit.author = commit.committer = b"Test User <test@example.com>"
+        commit.commit_time = commit.author_time = 1234567890
+        commit.commit_timezone = commit.author_timezone = 0
+        commit.message = b"Test commit"
+        self.repo.object_store.add_object(commit)
+
+        self.repo.refs[b"HEAD"] = commit.id
+
+        # Create an unreachable blob
+        unreachable_blob = Blob.from_string(b"unreachable content")
+        self.repo.object_store.add_object(unreachable_blob)
+
+        # Track progress messages
+        progress_messages = []
+
+        def progress_callback(msg):
+            progress_messages.append(msg)
+
+        # Run garbage collection with progress
+        garbage_collect(
+            self.repo, prune=True, grace_period=None, progress=progress_callback
+        )
+
+        # Check that progress was reported
+        self.assertGreater(len(progress_messages), 0)
+        self.assertIn("Finding unreachable objects", progress_messages)
+        self.assertIn("Packing references", progress_messages)
+        self.assertIn("Repacking repository", progress_messages)
+        self.assertIn("Pruning temporary files", progress_messages)
+
+    def test_find_reachable_objects_with_broken_ref(self):
+        """Test finding reachable objects with a broken ref."""
+        # Create a valid object
+        blob = Blob.from_string(b"test content")
+        self.repo.object_store.add_object(blob)
+
+        # Create a commit pointing to the blob
+        tree = Tree()
+        tree.add(b"test.txt", 0o100644, blob.id)
+        self.repo.object_store.add_object(tree)
+
+        commit = Commit()
+        commit.tree = tree.id
+        commit.author = commit.committer = b"Test User <test@example.com>"
+        commit.commit_time = commit.author_time = 1234567890
+        commit.commit_timezone = commit.author_timezone = 0
+        commit.message = b"Test commit"
+        self.repo.object_store.add_object(commit)
+
+        self.repo.refs[b"HEAD"] = commit.id
+
+        # Create a broken ref pointing to non-existent object
+        broken_sha = b"0" * 40
+        self.repo.refs[b"refs/heads/broken"] = broken_sha
+
+        # Track progress to see warning
+        progress_messages = []
+
+        def progress_callback(msg):
+            progress_messages.append(msg)
+
+        # Find reachable objects
+        reachable = find_reachable_objects(
+            self.repo.object_store, self.repo.refs, progress=progress_callback
+        )
+
+        # Valid objects should still be found, plus the broken ref SHA
+        # (which will be included in reachable but won't be walkable)
+        self.assertEqual({blob.id, tree.id, commit.id, broken_sha}, reachable)
+
+        # Check that we got a message about checking the broken object
+        # The warning happens when trying to walk from the broken SHA
+        check_messages = [msg for msg in progress_messages if "Checking object" in msg]
+        self.assertTrue(
+            any(broken_sha.decode("ascii") in msg for msg in check_messages)
+        )
+
+    def test_find_reachable_objects_with_tag(self):
+        """Test finding reachable objects through tags."""
+        # Create a blob
+        blob = Blob.from_string(b"tagged content")
+        self.repo.object_store.add_object(blob)
+
+        # Create a tree
+        tree = Tree()
+        tree.add(b"tagged.txt", 0o100644, blob.id)
+        self.repo.object_store.add_object(tree)
+
+        # Create a commit
+        commit = Commit()
+        commit.tree = tree.id
+        commit.author = commit.committer = b"Test User <test@example.com>"
+        commit.commit_time = commit.author_time = 1234567890
+        commit.commit_timezone = commit.author_timezone = 0
+        commit.message = b"Tagged commit"
+        self.repo.object_store.add_object(commit)
+
+        # Create a tag pointing to the commit
+        tag = Tag()
+        tag.name = b"v1.0"
+        tag.message = b"Version 1.0"
+        tag.tag_time = 1234567890
+        tag.tag_timezone = 0
+        tag.object = (Commit, commit.id)
+        tag.tagger = b"Test Tagger <tagger@example.com>"
+        self.repo.object_store.add_object(tag)
+
+        # Set a ref to the tag
+        self.repo.refs[b"refs/tags/v1.0"] = tag.id
+
+        # Find reachable objects
+        reachable = find_reachable_objects(self.repo.object_store, self.repo.refs)
+
+        # All objects should be reachable through the tag
+        self.assertEqual({blob.id, tree.id, commit.id, tag.id}, reachable)
+
+    def test_prune_with_missing_mtime(self):
+        """Test pruning when get_object_mtime raises KeyError."""
+        # Create an unreachable blob
+        unreachable_blob = Blob.from_string(b"unreachable content")
+        self.repo.object_store.add_object(unreachable_blob)
+
+        # Mock get_object_mtime to raise KeyError
+        with patch.object(
+            self.repo.object_store, "get_object_mtime", side_effect=KeyError
+        ):
+            # Run garbage collection with grace period
+            stats = garbage_collect(self.repo, prune=True, grace_period=3600)
+
+        # Object should be kept because mtime couldn't be determined
+        self.assertEqual(set(), stats.pruned_objects)
+        self.assertEqual(0, stats.bytes_freed)
+
 
 class AutoGCTestCase(TestCase):
     """Tests for auto GC functionality."""
@@ -474,3 +619,105 @@ class AutoGCTestCase(TestCase):
             with open(gc_log_path, "rb") as f:
                 content = f.read()
                 self.assertIn(b"Auto GC failed: GC failed", content)
+
+    def test_gc_log_expiry_singular_day(self):
+        """Test that gc.logExpiry supports singular '.day' format."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            r = Repo.init(tmpdir)
+            config = ConfigDict()
+            config.set(b"gc", b"auto", b"1")  # Low threshold
+            config.set(b"gc", b"logExpiry", b"1.day")  # Singular form
+
+            # Create gc.log file
+            gc_log_path = os.path.join(r.controldir(), "gc.log")
+            with open(gc_log_path, "wb") as f:
+                f.write(b"Previous GC failed\n")
+
+            # Make the file 2 days old (older than 1 day expiry)
+            old_time = time.time() - (2 * 86400)
+            os.utime(gc_log_path, (old_time, old_time))
+
+            # Add objects to trigger GC
+            blob = Blob()
+            blob.data = b"test"
+            r.object_store.add_object(blob)
+
+            with patch("dulwich.gc.garbage_collect") as mock_gc:
+                result = maybe_auto_gc(r, config)
+
+            self.assertTrue(result)
+            mock_gc.assert_called_once_with(r, auto=True)
+
+    def test_gc_log_expiry_invalid_format(self):
+        """Test that invalid gc.logExpiry format defaults to 1 day."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            r = Repo.init(tmpdir)
+            config = ConfigDict()
+            config.set(b"gc", b"auto", b"1")  # Low threshold
+            config.set(b"gc", b"logExpiry", b"invalid")  # Invalid format
+
+            # Create gc.log file
+            gc_log_path = os.path.join(r.controldir(), "gc.log")
+            with open(gc_log_path, "wb") as f:
+                f.write(b"Previous GC failed\n")
+
+            # Make the file recent (within default 1 day)
+            recent_time = time.time() - 3600  # 1 hour ago
+            os.utime(gc_log_path, (recent_time, recent_time))
+
+            # Add objects to trigger GC
+            blob = Blob()
+            blob.data = b"test"
+            r.object_store.add_object(blob)
+
+            with patch("builtins.print") as mock_print:
+                result = maybe_auto_gc(r, config)
+
+            # Should not run GC because gc.log is recent (within default 1 day)
+            self.assertFalse(result)
+            mock_print.assert_called_once()
+
+    def test_maybe_auto_gc_non_disk_repo(self):
+        """Test auto GC on non-disk repository (MemoryRepo)."""
+        r = MemoryRepo()
+        config = ConfigDict()
+        config.set(b"gc", b"auto", b"1")  # Would trigger if it were disk-based
+
+        # Add objects that would trigger GC in a disk repo
+        for i in range(10):
+            blob = Blob()
+            blob.data = f"test {i}".encode()
+            r.object_store.add_object(blob)
+
+        # For non-disk repos, should_run_gc returns False
+        # because it can't count loose objects
+        result = maybe_auto_gc(r, config)
+        self.assertFalse(result)
+
+    def test_gc_removes_existing_gc_log_on_success(self):
+        """Test that successful GC removes pre-existing gc.log file."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            r = Repo.init(tmpdir)
+            config = ConfigDict()
+            config.set(b"gc", b"auto", b"1")  # Low threshold
+
+            # Create gc.log file from previous failure
+            gc_log_path = os.path.join(r.controldir(), "gc.log")
+            with open(gc_log_path, "wb") as f:
+                f.write(b"Previous GC failed\n")
+
+            # Make it old enough to be expired
+            old_time = time.time() - (2 * 86400)  # 2 days ago
+            os.utime(gc_log_path, (old_time, old_time))
+
+            # Add objects to trigger GC
+            blob = Blob()
+            blob.data = b"test"
+            r.object_store.add_object(blob)
+
+            # Run auto GC
+            result = maybe_auto_gc(r, config)
+
+            self.assertTrue(result)
+            # gc.log should be removed after successful GC
+            self.assertFalse(os.path.exists(gc_log_path))