Forráskód Böngészése

Add more object_store tests (#1538)

Jelmer Vernooij 3 hete
szülő
commit
e54e55a978
2 módosított fájl, 46 hozzáadás és 8 törlés
  1. 12 4
      dulwich/object_store.py
  2. 34 4
      dulwich/tests/test_object_store.py

+ 12 - 4
dulwich/object_store.py

@@ -1577,12 +1577,20 @@ class OverlayObjectStore(BaseObjectStore):
         self, shas: Iterable[bytes], *, allow_missing: bool = False
     ) -> Iterator[ShaFile]:
         todo = set(shas)
+        found: set[bytes] = set()
+
         for b in self.bases:
-            for o in b.iterobjects_subset(todo, allow_missing=True):
+            # Create a copy of todo for each base to avoid modifying
+            # the set while iterating through it
+            current_todo = todo - found
+            for o in b.iterobjects_subset(current_todo, allow_missing=True):
                 yield o
-                todo.remove(o.id)
-        if todo and not allow_missing:
-            raise KeyError(o.id)
+                found.add(o.id)
+
+        # Check for any remaining objects not found
+        missing = todo - found
+        if missing and not allow_missing:
+            raise KeyError(next(iter(missing)))
 
     def iter_unpacked_subset(
         self,

+ 34 - 4
dulwich/tests/test_object_store.py

@@ -55,7 +55,8 @@ class ObjectStoreTests:
     store: "BaseObjectStore"
 
     assertEqual: Callable[[object, object], None]
-    assertRaises: Callable[[type[Exception], Callable[[], Any]], None]
+    # For type checker purposes - actual implementation supports both styles
+    assertRaises: Callable[..., Any]
     assertNotIn: Callable[[object, object], None]
     assertNotEqual: Callable[[object, object], None]
     assertIn: Callable[[object, object], None]
@@ -259,10 +260,39 @@ class ObjectStoreTests:
         self.assertEqual(
             [testobject.id], list(self.store.iter_prefix(testobject.id[:10]))
         )
-        self.assertEqual(
-            [testobject.id], list(self.store.iter_prefix(testobject.id[:4]))
+
+    def test_iterobjects_subset_all_present(self) -> None:
+        """Test iterating over a subset of objects that all exist."""
+        blob1 = make_object(Blob, data=b"blob 1 data")
+        blob2 = make_object(Blob, data=b"blob 2 data")
+        self.store.add_object(blob1)
+        self.store.add_object(blob2)
+
+        objects = list(self.store.iterobjects_subset([blob1.id, blob2.id]))
+        self.assertEqual(2, len(objects))
+        object_ids = set(o.id for o in objects)
+        self.assertEqual(set([blob1.id, blob2.id]), object_ids)
+
+    def test_iterobjects_subset_missing_not_allowed(self) -> None:
+        """Test iterating with missing objects when not allowed."""
+        blob1 = make_object(Blob, data=b"blob 1 data")
+        self.store.add_object(blob1)
+        missing_sha = b"1" * 40
+
+        with self.assertRaises(KeyError):
+            list(self.store.iterobjects_subset([blob1.id, missing_sha]))
+
+    def test_iterobjects_subset_missing_allowed(self) -> None:
+        """Test iterating with missing objects when allowed."""
+        blob1 = make_object(Blob, data=b"blob 1 data")
+        self.store.add_object(blob1)
+        missing_sha = b"1" * 40
+
+        objects = list(
+            self.store.iterobjects_subset([blob1.id, missing_sha], allow_missing=True)
         )
-        self.assertEqual([testobject.id], list(self.store.iter_prefix(b"")))
+        self.assertEqual(1, len(objects))
+        self.assertEqual(blob1.id, objects[0].id)
 
     def test_iter_prefix_not_found(self) -> None:
         self.assertEqual([], list(self.store.iter_prefix(b"1" * 40)))