Преглед изворни кода

Fix previously unrun tests

Jelmer Vernooij пре 1 месец
родитељ
комит
9596a2c08c
7 измењених фајлова са 160 додато и 79 уклоњено
  1. 5 15
      dulwich/client.py
  2. 11 14
      dulwich/dumb.py
  3. 71 12
      dulwich/index.py
  4. 4 2
      dulwich/rebase.py
  5. 4 4
      tests/compat/test_dumb.py
  6. 52 29
      tests/compat/test_index.py
  7. 13 3
      tests/test_cli_cherry_pick.py

+ 5 - 15
dulwich/client.py

@@ -2812,26 +2812,16 @@ class AbstractHttpGitClient(GitClient):
 
             # Write pack data
             if pack_data:
-                from .pack import pack_objects_to_data, write_pack_data
+                from .pack import write_pack_data
 
-                # Convert unpacked objects to ShaFile objects for packing
-                objects = []
-                for unpacked in pack_data_list:
-                    objects.append(unpacked.sha_file())
-
-                # Generate pack data and write it to a buffer
-                pack_buffer = BytesIO()
-                count, unpacked_iter = pack_objects_to_data(objects)
+                # Write pack data directly using the unpacked objects
                 write_pack_data(
-                    pack_buffer.write,
-                    unpacked_iter,
-                    num_records=count,
+                    pack_data,
+                    iter(pack_data_list),
+                    num_records=len(pack_data_list),
                     progress=progress,
                 )
 
-                # Pass the raw pack data to pack_data callback
-                pack_data(pack_buffer.getvalue())
-
             return FetchPackResult(refs, symrefs, agent)
         req_data = BytesIO()
         req_proto = Protocol(None, req_data.write)  # type: ignore

+ 11 - 14
dulwich/dumb.py

@@ -410,13 +410,13 @@ class DumbRemoteHTTPRepo(BaseRepo):
 
         This is the main method for fetching objects from a dumb HTTP remote.
         Since dumb HTTP doesn't support negotiation, we need to download
-        all objects reachable from the wanted refs that we don't have locally.
+        all objects reachable from the wanted refs.
 
         Args:
-          graph_walker: GraphWalker instance that can tell us which commits we have
+          graph_walker: GraphWalker instance (not used for dumb HTTP)
           determine_wants: Function that returns list of wanted SHAs
           progress: Optional progress callback
-          depth: Depth for shallow clones (not fully supported)
+          depth: Depth for shallow clones (not supported for dumb HTTP)
 
         Returns:
           Iterator of UnpackedObject instances
@@ -427,8 +427,7 @@ class DumbRemoteHTTPRepo(BaseRepo):
         if not wants:
             return
 
-        # For dumb HTTP, we can't negotiate, so we need to fetch all objects
-        # reachable from wants that we don't already have
+        # For dumb HTTP, we traverse the object graph starting from wants
         to_fetch = set(wants)
         seen = set()
 
@@ -438,19 +437,17 @@ class DumbRemoteHTTPRepo(BaseRepo):
                 continue
             seen.add(sha)
 
-            # Check if we already have this object
-            haves = list(graph_walker.ack(sha))
-            if haves:
+            # Fetch the object
+            try:
+                type_num, content = self._object_store.get_raw(sha)
+            except KeyError:
+                # Object not found, skip it
                 continue
 
-            # Fetch the object
-            type_num, content = self._object_store.get_raw(sha)
-            unpacked = UnpackedObject(type_num, sha=sha)
-            unpacked.obj_type_num = type_num
-            unpacked.obj_chunks = [content]
+            unpacked = UnpackedObject(type_num, sha=sha, decomp_chunks=[content])
             yield unpacked
 
-            # If it's a commit or tag, we need to fetch its references
+            # Parse the object to find references to other objects
             obj = ShaFile.from_raw_string(type_num, content)
 
             if isinstance(obj, Commit):  # Commit

+ 71 - 12
dulwich/index.py

@@ -661,11 +661,13 @@ def read_index(f: BinaryIO) -> Iterator[SerializedIndexEntry]:
 
 def read_index_dict_with_version(
     f: BinaryIO,
-) -> tuple[dict[bytes, Union[IndexEntry, ConflictedIndexEntry]], int]:
+) -> tuple[
+    dict[bytes, Union[IndexEntry, ConflictedIndexEntry]], int, list[IndexExtension]
+]:
     """Read an index file and return it as a dictionary along with the version.
 
     Returns:
-      tuple of (entries_dict, version)
+      tuple of (entries_dict, version, extensions)
     """
     version, num_entries = read_index_header(f)
 
@@ -688,7 +690,44 @@ def read_index_dict_with_version(
             elif stage == Stage.MERGE_CONFLICT_OTHER:
                 existing.other = IndexEntry.from_serialized(entry)
 
-    return ret, version
+    # Read extensions
+    extensions = []
+    while True:
+        # Check if we're at the end (20 bytes before EOF for SHA checksum)
+        current_pos = f.tell()
+        f.seek(0, 2)  # EOF
+        eof_pos = f.tell()
+        f.seek(current_pos)
+
+        if current_pos >= eof_pos - 20:
+            break
+
+        # Try to read extension signature
+        signature = f.read(4)
+        if len(signature) < 4:
+            break
+
+        # Check if it's a valid extension signature (4 uppercase letters)
+        if not all(65 <= b <= 90 for b in signature):
+            # Not an extension, seek back
+            f.seek(-4, 1)
+            break
+
+        # Read extension size
+        size_data = f.read(4)
+        if len(size_data) < 4:
+            break
+        size = struct.unpack(">I", size_data)[0]
+
+        # Read extension data
+        data = f.read(size)
+        if len(data) < size:
+            break
+
+        extension = IndexExtension.from_raw(signature, data)
+        extensions.append(extension)
+
+    return ret, version, extensions
 
 
 def read_index_dict(
@@ -719,7 +758,10 @@ def read_index_dict(
 
 
 def write_index(
-    f: BinaryIO, entries: list[SerializedIndexEntry], version: Optional[int] = None
+    f: BinaryIO,
+    entries: list[SerializedIndexEntry],
+    version: Optional[int] = None,
+    extensions: Optional[list[IndexExtension]] = None,
 ) -> None:
     """Write an index file.
 
@@ -727,6 +769,7 @@ def write_index(
       f: File-like object to write to
       version: Version number to write
       entries: Iterable over the entries to write
+      extensions: Optional list of extensions to write
     """
     if version is None:
         version = DEFAULT_VERSION
@@ -749,11 +792,17 @@ def write_index(
         write_cache_entry(f, entry, version=version, previous_path=previous_path)
         previous_path = entry.name
 
+    # Write extensions
+    if extensions:
+        for extension in extensions:
+            write_index_extension(f, extension)
+
 
 def write_index_dict(
     f: BinaryIO,
     entries: dict[bytes, Union[IndexEntry, ConflictedIndexEntry]],
     version: Optional[int] = None,
+    extensions: Optional[list[IndexExtension]] = None,
 ) -> None:
     """Write an index file based on the contents of a dictionary.
     being careful to sort by path and then by stage.
@@ -776,7 +825,8 @@ def write_index_dict(
                 )
         else:
             entries_list.append(value.serialize(key, Stage.NORMAL))
-    write_index(f, entries_list, version=version)
+
+    write_index(f, entries_list, version=version, extensions=extensions)
 
 
 def cleanup_mode(mode: int) -> int:
@@ -826,6 +876,7 @@ class Index:
         # TODO(jelmer): Store the version returned by read_index
         self._version = version
         self._skip_hash = skip_hash
+        self._extensions: list[IndexExtension] = []
         self.clear()
         if read:
             self.read()
@@ -845,14 +896,22 @@ class Index:
         try:
             if self._skip_hash:
                 # When skipHash is enabled, write the index without computing SHA1
-                write_index_dict(cast(BinaryIO, f), self._byname, version=self._version)
+                write_index_dict(
+                    cast(BinaryIO, f),
+                    self._byname,
+                    version=self._version,
+                    extensions=self._extensions,
+                )
                 # Write 20 zero bytes instead of SHA1
                 f.write(b"\x00" * 20)
                 f.close()
             else:
                 sha1_writer = SHA1Writer(cast(BinaryIO, f))
                 write_index_dict(
-                    cast(BinaryIO, sha1_writer), self._byname, version=self._version
+                    cast(BinaryIO, sha1_writer),
+                    self._byname,
+                    version=self._version,
+                    extensions=self._extensions,
                 )
                 sha1_writer.close()
         except:
@@ -866,13 +925,13 @@ class Index:
         f = GitFile(self._filename, "rb")
         try:
             sha1_reader = SHA1Reader(f)
-            entries, version = read_index_dict_with_version(cast(BinaryIO, sha1_reader))
+            entries, version, extensions = read_index_dict_with_version(
+                cast(BinaryIO, sha1_reader)
+            )
             self._version = version
+            self._extensions = extensions
             self.update(entries)
-            # Read any remaining data before the SHA
-            remaining = os.path.getsize(self._filename) - sha1_reader.tell() - 20
-            if remaining > 0:
-                sha1_reader.read(remaining)
+            # Extensions have already been read by read_index_dict_with_version
             sha1_reader.check_sha(allow_empty=True)
         finally:
             f.close()

+ 4 - 2
dulwich/rebase.py

@@ -316,7 +316,9 @@ class Rebaser:
         # Return in chronological order (oldest first)
         return list(reversed(commits))
 
-    def _cherry_pick(self, commit: Commit, onto: bytes) -> tuple[bytes, list[bytes]]:
+    def _cherry_pick(
+        self, commit: Commit, onto: bytes
+    ) -> tuple[Optional[bytes], list[bytes]]:
         """Cherry-pick a commit onto another commit.
 
         Args:
@@ -341,7 +343,7 @@ class Rebaser:
         if conflicts:
             # Store merge state for conflict resolution
             self.repo._put_named_file("rebase-merge/stopped-sha", commit.id)
-            return commit.id, conflicts
+            return None, conflicts
 
         # Create new commit
         new_commit = Commit()

+ 4 - 4
tests/compat/test_dumb.py

@@ -30,7 +30,7 @@ from unittest import skipUnless
 
 from dulwich.client import HttpGitClient
 from dulwich.repo import Repo
-from dulwich.tests.compat.utils import (
+from tests.compat.utils import (
     CompatTestCase,
     run_git_or_fail,
 )
@@ -131,7 +131,7 @@ class DumbHTTPClientTests(CompatTestCase):
         client = HttpGitClient(self.server.url)
 
         # Create destination repo
-        dest_repo = Repo.init(dest_path)
+        dest_repo = Repo.init(dest_path, mkdir=True)
 
         # Fetch from dumb HTTP
         def determine_wants(refs):
@@ -191,7 +191,7 @@ class DumbHTTPClientTests(CompatTestCase):
                 dest_repo.refs[ref] = sha
 
         # Reset to new commit
-        dest_repo.reset_index(dest_repo.refs[b"refs/heads/master"])
+        dest_repo.reset_index()
 
         # Verify the new file exists
         test_file2_dest = os.path.join(dest_path, "test2.txt")
@@ -213,7 +213,7 @@ class DumbHTTPClientTests(CompatTestCase):
 
         # Clone with dulwich
         dest_path = os.path.join(self.temp_dir, "cloned_with_tags")
-        dest_repo = Repo.init(dest_path)
+        dest_repo = Repo.init(dest_path, mkdir=True)
 
         client = HttpGitClient(self.server.url)
 

+ 52 - 29
tests/compat/test_index.py

@@ -27,7 +27,7 @@ import tempfile
 from dulwich.index import Index, read_index_dict_with_version, write_index_dict
 from dulwich.repo import Repo
 
-from .utils import CompatTestCase, require_git_version, run_git_or_fail
+from .utils import CompatTestCase, require_git_version, run_git, run_git_or_fail
 
 
 class IndexV4CompatTestCase(CompatTestCase):
@@ -84,7 +84,7 @@ class IndexV4CompatTestCase(CompatTestCase):
         # Read the index with dulwich
         index_path = os.path.join(repo.path, ".git", "index")
         with open(index_path, "rb") as f:
-            entries, version = read_index_dict_with_version(f)
+            entries, version, extensions = read_index_dict_with_version(f)
 
         # Verify it's version 4
         self.assertEqual(version, 4)
@@ -96,7 +96,11 @@ class IndexV4CompatTestCase(CompatTestCase):
 
         # Write the index back with dulwich
         with open(index_path + ".dulwich", "wb") as f:
-            write_index_dict(f, entries, version=4)
+            from dulwich.pack import SHA1Writer
+
+            sha1_writer = SHA1Writer(f)
+            write_index_dict(sha1_writer, entries, version=4, extensions=extensions)
+            sha1_writer.close()
 
         # Compare with C git - use git ls-files to read both indexes
         output1 = run_git_or_fail(["ls-files", "--stage"], cwd=repo.path)
@@ -165,7 +169,7 @@ class IndexV4CompatTestCase(CompatTestCase):
         # Read the index
         index_path = os.path.join(repo.path, ".git", "index")
         with open(index_path, "rb") as f:
-            entries, version = read_index_dict_with_version(f)
+            entries, version, extensions = read_index_dict_with_version(f)
 
         self.assertEqual(version, 4)
         self.assertIn(b"test.txt", entries)
@@ -217,7 +221,7 @@ class IndexV4CompatTestCase(CompatTestCase):
         # Read with dulwich
         index_path = os.path.join(repo.path, ".git", "index")
         with open(index_path, "rb") as f:
-            entries, version = read_index_dict_with_version(f)
+            entries, version, extensions = read_index_dict_with_version(f)
 
         self.assertEqual(version, 4)
         self.assertEqual(len(entries), len(test_files))
@@ -229,14 +233,19 @@ class IndexV4CompatTestCase(CompatTestCase):
 
         # Test round-trip: dulwich write -> C Git read
         with open(index_path + ".dulwich", "wb") as f:
-            write_index_dict(f, entries, version=4)
+            from dulwich.pack import SHA1Writer
+
+            sha1_writer = SHA1Writer(f)
+            write_index_dict(sha1_writer, entries, version=4, extensions=extensions)
+            sha1_writer.close()
 
         # Replace index
         os.rename(index_path + ".dulwich", index_path)
 
         # Verify C Git can read all files
-        output = run_git_or_fail(["ls-files"], cwd=repo.path)
-        git_files = set(output.strip().split(b"\n"))
+        # Use -z flag to avoid quoting of non-ASCII filenames
+        output = run_git_or_fail(["ls-files", "-z"], cwd=repo.path)
+        git_files = set(output.strip(b"\x00").split(b"\x00"))
         expected_files = {f.encode("utf-8") for f in test_files}
         self.assertEqual(git_files, expected_files)
 
@@ -276,7 +285,7 @@ class IndexV4CompatTestCase(CompatTestCase):
         # Read the index
         index_path = os.path.join(repo.path, ".git", "index")
         with open(index_path, "rb") as f:
-            entries, version = read_index_dict_with_version(f)
+            entries, version, extensions = read_index_dict_with_version(f)
 
         self.assertEqual(version, 4)
         self.assertEqual(len(entries), len(all_files))
@@ -287,14 +296,22 @@ class IndexV4CompatTestCase(CompatTestCase):
 
         # Test that dulwich can write a compatible index
         with open(index_path + ".dulwich", "wb") as f:
-            write_index_dict(f, entries, version=4)
+            from dulwich.pack import SHA1Writer
 
-        # Verify the written index is smaller (compression should help)
+            sha1_writer = SHA1Writer(f)
+            write_index_dict(sha1_writer, entries, version=4, extensions=extensions)
+            sha1_writer.close()
+
+        # Verify the written index is the same size (for byte-for-byte compatibility)
         original_size = os.path.getsize(index_path)
         dulwich_size = os.path.getsize(index_path + ".dulwich")
 
-        # Allow some variance due to different compression decisions
-        self.assertLess(abs(original_size - dulwich_size), original_size * 0.2)
+        # For v4 format with proper compression, checksum, and extensions, sizes should match
+        self.assertEqual(
+            original_size,
+            dulwich_size,
+            f"Index sizes don't match: Git={original_size}, Dulwich={dulwich_size}",
+        )
 
     def test_index_v4_with_extensions(self) -> None:
         """Test v4 index with various extensions."""
@@ -320,7 +337,7 @@ class IndexV4CompatTestCase(CompatTestCase):
         # Read index with extensions
         index_path = os.path.join(repo.path, ".git", "index")
         with open(index_path, "rb") as f:
-            entries, version = read_index_dict_with_version(f)
+            entries, version, extensions = read_index_dict_with_version(f)
 
         self.assertEqual(version, 4)
         self.assertEqual(len(entries), len(files))
@@ -348,14 +365,20 @@ class IndexV4CompatTestCase(CompatTestCase):
         index_path = os.path.join(repo.path, ".git", "index")
         if os.path.exists(index_path):
             with open(index_path, "rb") as f:
-                entries, version = read_index_dict_with_version(f)
+                entries, version, extensions = read_index_dict_with_version(f)
 
             # Even empty indexes should be readable
             self.assertEqual(len(entries), 0)
 
             # Test writing empty index
             with open(index_path + ".dulwich", "wb") as f:
-                write_index_dict(f, entries, version=version)
+                from dulwich.pack import SHA1Writer
+
+                sha1_writer = SHA1Writer(f)
+                write_index_dict(
+                    sha1_writer, entries, version=version, extensions=extensions
+                )
+                sha1_writer.close()
 
     def test_index_v4_large_file_count(self) -> None:
         """Test v4 index with many files (stress test)."""
@@ -379,7 +402,7 @@ class IndexV4CompatTestCase(CompatTestCase):
         # Read index
         index_path = os.path.join(repo.path, ".git", "index")
         with open(index_path, "rb") as f:
-            entries, version = read_index_dict_with_version(f)
+            entries, version, extensions = read_index_dict_with_version(f)
 
         self.assertEqual(version, 4)
         self.assertEqual(len(entries), len(files))
@@ -422,7 +445,7 @@ class IndexV4CompatTestCase(CompatTestCase):
         # Test dulwich can read the updated index
         index_path = os.path.join(repo.path, ".git", "index")
         with open(index_path, "rb") as f:
-            entries, version = read_index_dict_with_version(f)
+            entries, version, extensions = read_index_dict_with_version(f)
 
         self.assertEqual(version, 4)
         self.assertEqual(len(entries), 4)  # 3 original + 1 new
@@ -469,13 +492,13 @@ class IndexV4CompatTestCase(CompatTestCase):
         run_git_or_fail(["commit", "-m", "master change"], cwd=repo.path)
 
         # Try to merge (should create conflicts)
-        run_git_or_fail(["merge", "feature"], cwd=repo.path, check=False)
+        run_git(["merge", "feature"], cwd=repo.path)
 
         # Read the index with conflicts
         index_path = os.path.join(repo.path, ".git", "index")
         if os.path.exists(index_path):
             with open(index_path, "rb") as f:
-                entries, version = read_index_dict_with_version(f)
+                entries, version, extensions = read_index_dict_with_version(f)
 
             self.assertEqual(version, 4)
 
@@ -525,7 +548,7 @@ class IndexV4CompatTestCase(CompatTestCase):
             # Test reading
             index_path = os.path.join(repo.path, ".git", "index")
             with open(index_path, "rb") as f:
-                entries, version = read_index_dict_with_version(f)
+                entries, version, extensions = read_index_dict_with_version(f)
 
             self.assertEqual(version, 4)
 
@@ -579,7 +602,7 @@ class IndexV4CompatTestCase(CompatTestCase):
             # Test reading
             index_path = os.path.join(repo.path, ".git", "index")
             with open(index_path, "rb") as f:
-                entries, version = read_index_dict_with_version(f)
+                entries, version, extensions = read_index_dict_with_version(f)
 
             self.assertEqual(version, 4)
             self.assertGreater(len(entries), 0)
@@ -618,7 +641,7 @@ class IndexV4CompatTestCase(CompatTestCase):
         # Test reading
         index_path = os.path.join(repo.path, ".git", "index")
         with open(index_path, "rb") as f:
-            entries, version = read_index_dict_with_version(f)
+            entries, version, extensions = read_index_dict_with_version(f)
 
         self.assertEqual(version, 4)
 
@@ -675,7 +698,7 @@ class IndexV4CompatTestCase(CompatTestCase):
         # Test reading
         index_path = os.path.join(repo.path, ".git", "index")
         with open(index_path, "rb") as f:
-            entries, version = read_index_dict_with_version(f)
+            entries, version, extensions = read_index_dict_with_version(f)
 
         self.assertEqual(version, 4)
         self.assertEqual(len(entries), len(files))
@@ -716,7 +739,7 @@ class IndexV4CompatTestCase(CompatTestCase):
         # Test reading index with submodule
         index_path = os.path.join(repo.path, ".git", "index")
         with open(index_path, "rb") as f:
-            entries, version = read_index_dict_with_version(f)
+            entries, version, extensions = read_index_dict_with_version(f)
 
         self.assertEqual(version, 4)
 
@@ -759,7 +782,7 @@ class IndexV4CompatTestCase(CompatTestCase):
         # Test reading this complex index state
         index_path = os.path.join(repo.path, ".git", "index")
         with open(index_path, "rb") as f:
-            entries, version = read_index_dict_with_version(f)
+            entries, version, extensions = read_index_dict_with_version(f)
 
         self.assertEqual(version, 4)
         self.assertIn(b"partial.txt", entries)
@@ -806,7 +829,7 @@ class IndexV4CompatTestCase(CompatTestCase):
         # Test reading
         index_path = os.path.join(repo.path, ".git", "index")
         with open(index_path, "rb") as f:
-            entries, version = read_index_dict_with_version(f)
+            entries, version, extensions = read_index_dict_with_version(f)
 
         self.assertEqual(version, 4)
 
@@ -872,7 +895,7 @@ class IndexV4CompatTestCase(CompatTestCase):
         # Test reading large index
         index_path = os.path.join(repo.path, ".git", "index")
         with open(index_path, "rb") as f:
-            entries, version = read_index_dict_with_version(f)
+            entries, version, extensions = read_index_dict_with_version(f)
 
         self.assertEqual(version, 4)
         self.assertEqual(len(entries), len(files))
@@ -912,7 +935,7 @@ class IndexV4CompatTestCase(CompatTestCase):
         # Test reading index with renames
         index_path = os.path.join(repo.path, ".git", "index")
         with open(index_path, "rb") as f:
-            entries, version = read_index_dict_with_version(f)
+            entries, version, extensions = read_index_dict_with_version(f)
 
         self.assertEqual(version, 4)
 

+ 13 - 3
tests/test_cli_cherry_pick.py

@@ -120,6 +120,9 @@ class CherryPickCommandTests(TestCase):
 
     def test_cherry_pick_missing_argument(self):
         """Test cherry-pick without commit argument."""
+        import io
+        import sys
+
         with tempfile.TemporaryDirectory() as tmpdir:
             orig_cwd = os.getcwd()
             try:
@@ -128,10 +131,17 @@ class CherryPickCommandTests(TestCase):
 
                 # Try to cherry-pick without argument
                 cmd = cmd_cherry_pick()
-                with self.assertRaises(SystemExit) as cm:
-                    cmd.run([])
 
-                self.assertEqual(cm.exception.code, 2)  # argparse error code
+                # Capture stderr to prevent argparse from printing to console
+                old_stderr = sys.stderr
+                sys.stderr = io.StringIO()
+
+                try:
+                    with self.assertRaises(SystemExit) as cm:
+                        cmd.run([])
+                    self.assertEqual(cm.exception.code, 2)  # argparse error code
+                finally:
+                    sys.stderr = old_stderr
 
             finally:
                 os.chdir(orig_cwd)