Fix SHA-256 pack checksum generation and trailer handling

Jelmer Vernooij, 1 month ago
commit 1cced19ca4
6 files changed, 437 additions and 87 deletions
  1. dulwich/contrib/swift.py (+38, -7)
  2. dulwich/object_store.py (+25, -5)
  3. dulwich/pack.py (+301, -62)
  4. tests/compat/test_sha256.py (+43, -0)
  5. tests/test_config.py (+6, -0)
  6. tests/test_pack.py (+24, -13)

dulwich/contrib/swift.py (+38, -7)

@@ -59,7 +59,10 @@ import zlib
 from collections.abc import Callable, Iterator, Mapping
 from configparser import ConfigParser
 from io import BytesIO
-from typing import Any, BinaryIO, cast
+from typing import TYPE_CHECKING, Any, BinaryIO, cast
+
+if TYPE_CHECKING:
+    from dulwich.object_format import ObjectFormat
 
 from geventhttpclient import HTTPClient
 
@@ -655,13 +658,31 @@ class SwiftPackData(PackData):
     using the Range header feature of Swift.
     """
 
-    def __init__(self, scon: SwiftConnector, filename: str | os.PathLike[str]) -> None:
+    def __init__(
+        self,
+        scon: SwiftConnector,
+        filename: str | os.PathLike[str],
+        object_format: "ObjectFormat | None" = None,
+    ) -> None:
         """Initialize a SwiftPackReader.
 
         Args:
           scon: a `SwiftConnector` instance
           filename: the pack filename
+          object_format: Object format for this pack
         """
+        from dulwich.object_format import DEFAULT_OBJECT_FORMAT
+
+        if object_format is None:
+            import warnings
+
+            warnings.warn(
+                "SwiftPackData() should be called with object_format parameter",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            object_format = DEFAULT_OBJECT_FORMAT
+        self.object_format = object_format
         self.scon = scon
         self._filename = filename
         self._header_size = 12
@@ -693,7 +714,7 @@ class SwiftPackData(PackData):
         assert offset >= self._header_size
         pack_reader = SwiftPackReader(self.scon, str(self._filename), self.pack_length)
         pack_reader.seek(offset)
-        unpacked, _ = unpack_object(pack_reader.read)
+        unpacked, _ = unpack_object(pack_reader.read, self.object_format.hash_func)
         obj_data = unpacked._obj()
         return (unpacked.pack_type_num, obj_data)
 
@@ -910,9 +931,17 @@ class SwiftObjectStore(PackBasedObjectStore):
         fd, path = tempfile.mkstemp(prefix="tmp_pack_")
         f = os.fdopen(fd, "w+b")
         try:
-            pack_data = PackData(file=cast(_GitFile, f), filename=path)
-            indexer = PackIndexer(cast(BinaryIO, pack_data._file), resolve_ext_ref=None)
-            copier = PackStreamCopier(read_all, read_some, f, delta_iter=None)
+            pack_data = PackData(
+                file=cast(_GitFile, f), filename=path, object_format=self.object_format
+            )
+            indexer = PackIndexer(
+                cast(BinaryIO, pack_data._file),
+                self.object_format.hash_func,
+                resolve_ext_ref=None,
+            )
+            copier = PackStreamCopier(
+                self.object_format.hash_func, read_all, read_some, f, delta_iter=None
+            )
             copier.verify()
             return self._complete_thin_pack(f, path, copier, indexer)
         finally:
@@ -932,7 +961,9 @@ class SwiftObjectStore(PackBasedObjectStore):
         f.flush()
 
         # Rescan the rest of the pack, computing the SHA with the new header.
-        new_sha = compute_file_sha(f, end_ofs=-20)
+        new_sha = compute_file_sha(
+            f, self.object_format, end_ofs=-self.object_format.oid_length
+        )
 
         # Must reposition before writing (http://bugs.python.org/issue3207)
         f.seek(0, os.SEEK_CUR)
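
Note: the hunk above replaces the hard-coded 20-byte SHA-1 trailer with the object format's oid_length when rescanning a thin pack. A minimal sketch of the same idea using only hashlib; the helper name and sample data below are illustrative, not dulwich API:

    import hashlib
    from io import SEEK_END, BytesIO

    def pack_body_digest(f, oid_length=32, hash_name="sha256"):
        # Hash everything except the trailing oid_length-byte pack checksum,
        # mirroring compute_file_sha(f, object_format, end_ofs=-oid_length).
        f.seek(0, SEEK_END)
        length = f.tell()
        f.seek(0)
        h = hashlib.new(hash_name)
        h.update(f.read(length - oid_length))
        return h.digest()

    body = b"PACK...pretend pack contents..."
    f = BytesIO(body + hashlib.sha256(body).digest())
    assert pack_body_digest(f) == f.getvalue()[-32:]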

dulwich/object_store.py (+25, -5)

@@ -1906,7 +1906,9 @@ class DiskObjectStore(PackBasedObjectStore):
             # Load the index we just wrote
             with open(target_index_path, "rb") as idx_file:
                 pack_index = load_pack_index_file(
-                    os.path.basename(target_index_path), idx_file
+                    os.path.basename(target_index_path),
+                    idx_file,
+                    object_format=self.object_format,
                 )
 
             # Generate the bitmap
@@ -1970,8 +1972,18 @@ class DiskObjectStore(PackBasedObjectStore):
         fd, path = tempfile.mkstemp(dir=self.path, prefix="tmp_pack_")
         with os.fdopen(fd, "w+b") as f:
             os.chmod(path, PACK_MODE)
-            indexer = PackIndexer(f, resolve_ext_ref=self.get_raw)  # type: ignore[arg-type]
-            copier = PackStreamCopier(read_all, read_some, f, delta_iter=indexer)  # type: ignore[arg-type]
+            indexer = PackIndexer(
+                f,
+                self.object_format.hash_func,
+                resolve_ext_ref=self.get_raw,  # type: ignore[arg-type]
+            )
+            copier = PackStreamCopier(
+                self.object_format.hash_func,
+                read_all,
+                read_some,
+                f,
+                delta_iter=indexer,  # type: ignore[arg-type]
+            )
             copier.verify(progress=progress)
             return self._complete_pack(f, path, len(copier), indexer, progress=progress)
 
@@ -2192,6 +2204,7 @@ class DiskObjectStore(PackBasedObjectStore):
             depth=self.pack_depth,
             threads=self.pack_threads,
             big_file_threshold=self.pack_big_file_threshold,
+            object_format=self.object_format,
         )
         self._pack_cache[base_name] = pack
         return pack
@@ -2612,7 +2625,12 @@ class MemoryObjectStore(PackCapableObjectStore):
         """
         f, commit, abort = self.add_pack()
         try:
-            copier = PackStreamCopier(read_all, read_some, f)  # type: ignore[arg-type]
+            copier = PackStreamCopier(
+                self.object_format.hash_func,
+                read_all,
+                read_some,
+                f,  # type: ignore[arg-type]
+            )
             copier.verify()
         except BaseException:
             abort()
@@ -3353,7 +3371,9 @@ class BucketBasedObjectStore(PackBasedObjectStore):
             checksum = p.get_stored_checksum()
             write_pack_index(idxf, entries, checksum, version=self.pack_index_version)
             idxf.seek(0)
-            idx = load_pack_index_file(basename + ".idx", idxf)
+            idx = load_pack_index_file(
+                basename + ".idx", idxf, object_format=self.object_format
+            )
             for pack in self.packs:
                 if pack.get_stored_checksum() == p.get_stored_checksum():
                     p.close()
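
Note: across the object stores, PackIndexer and PackStreamCopier now take the store's object_format.hash_func as their first argument instead of assuming SHA-1. A hedged sketch of choosing such a hash constructor from the configured object format name; hash_func_for is illustrative and not part of dulwich:

    from hashlib import sha1, sha256

    def hash_func_for(object_format_name: bytes):
        # b"sha256" is the value stored under extensions.objectformat;
        # anything else falls back to classic SHA-1 packs.
        if object_format_name == b"sha256":
            return sha256  # 32-byte digests
        return sha1        # 20-byte digests

    assert len(hash_func_for(b"sha256")().digest()) == 32
    assert len(hash_func_for(b"sha1")().digest()) == 20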

dulwich/pack.py (+301, -62)

@@ -113,7 +113,7 @@ import sys
 import warnings
 import zlib
 from collections.abc import Callable, Iterable, Iterator, Sequence, Set
-from hashlib import sha1
+from hashlib import sha1, sha256
 from itertools import chain
 from os import SEEK_CUR, SEEK_END
 from struct import unpack_from
@@ -345,6 +345,7 @@ class UnpackedObject:
         "decomp_chunks",  # Decompressed object chunks.
         "decomp_len",  # Decompressed length of this object.
         "delta_base",  # Delta base offset or SHA.
+        "hash_func",  # Hash function to use for computing object IDs.
         "obj_chunks",  # Decompressed and delta-resolved chunks.
         "obj_type_num",  # Type of this object.
         "offset",  # Offset in its pack.
@@ -361,6 +362,7 @@ class UnpackedObject:
     offset: int | None
     pack_type_num: int
     _sha: bytes | None
+    hash_func: Callable[[], "HashObject"]
 
     # TODO(dborowitz): read_zlib_chunks and unpack_object could very well be
     # methods of this object.
@@ -374,6 +376,7 @@ class UnpackedObject:
         sha: bytes | None = None,
         decomp_chunks: list[bytes] | None = None,
         offset: int | None = None,
+        hash_func: Callable[[], "HashObject"] = sha1,
     ) -> None:
         """Initialize an UnpackedObject.
 
@@ -382,9 +385,10 @@ class UnpackedObject:
             delta_base: Delta base (offset or SHA) if this is a delta object
             decomp_len: Decompressed length of this object
             crc32: CRC32 checksum
-            sha: SHA-1 hash of the object
+            sha: SHA hash of the object
             decomp_chunks: Decompressed chunks
             offset: Offset in the pack file
+            hash_func: Hash function to use (defaults to sha1)
         """
         self.offset = offset
         self._sha = sha
@@ -397,6 +401,7 @@ class UnpackedObject:
         else:
             self.decomp_len = decomp_len
         self.crc32 = crc32
+        self.hash_func = hash_func
 
         if pack_type_num in DELTA_TYPES:
             self.obj_type_num = None
@@ -410,7 +415,7 @@ class UnpackedObject:
         """Return the binary SHA of this object."""
         if self._sha is None:
             assert self.obj_type_num is not None and self.obj_chunks is not None
-            self._sha = obj_sha(self.obj_type_num, self.obj_chunks)
+            self._sha = obj_sha(self.obj_type_num, self.obj_chunks, self.hash_func)
         return RawObjectID(self._sha)
 
     def sha_file(self) -> ShaFile:
@@ -1176,6 +1181,29 @@ class PackIndex2(FilePackIndex):
         checksum_size = self.hash_size
         return bytes(self._contents[-2 * checksum_size : -checksum_size])
 
+    def get_stored_checksum(self) -> bytes:
+        """Return the checksum stored for this index.
+
+        Returns: binary digest (size depends on hash algorithm)
+        """
+        checksum_size = self.hash_size
+        return bytes(self._contents[-checksum_size:])
+
+    def calculate_checksum(self) -> bytes:
+        """Calculate the checksum over this pack index.
+
+        Returns: binary digest (size depends on hash algorithm)
+        """
+        # Determine hash function based on hash_size
+        if self.hash_size == 20:
+            hash_func = sha1
+        elif self.hash_size == 32:
+            hash_func = sha256
+        else:
+            raise ValueError(f"Unsupported hash size: {self.hash_size}")
+
+        return hash_func(self._contents[: -self.hash_size]).digest()
+
 
 class PackIndex3(FilePackIndex):
     """Version 3 Pack Index file.
@@ -1292,6 +1320,7 @@ def chunks_length(chunks: bytes | Iterable[bytes]) -> int:
 
 def unpack_object(
     read_all: Callable[[int], bytes],
+    hash_func: Callable[[], "HashObject"],
     read_some: Callable[[int], bytes] | None = None,
     compute_crc32: bool = False,
     include_comp: bool = False,
@@ -1302,6 +1331,7 @@ def unpack_object(
     Args:
       read_all: Read function that blocks until the number of requested
         bytes are read.
+      hash_func: Hash function to use for computing object IDs.
       read_some: Read function that returns at least one byte, but may not
         return the number of bytes requested.
       compute_crc32: If True, compute the CRC32 of the compressed data. If
@@ -1347,16 +1377,22 @@ def unpack_object(
             delta_base_offset += byte & 0x7F
         delta_base = delta_base_offset
     elif type_num == REF_DELTA:
-        delta_base_obj = read_all(20)
+        # Determine hash size from hash_func
+        hash_size = len(hash_func().digest())
+        delta_base_obj = read_all(hash_size)
         if crc32 is not None:
             crc32 = binascii.crc32(delta_base_obj, crc32)
         delta_base = delta_base_obj
-        raw_base += 20
+        raw_base += hash_size
     else:
         delta_base = None
 
     unpacked = UnpackedObject(
-        type_num, delta_base=delta_base, decomp_len=size, crc32=crc32
+        type_num,
+        delta_base=delta_base,
+        decomp_len=size,
+        crc32=crc32,
+        hash_func=hash_func,
     )
     unused = read_zlib_chunks(
         read_some,
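
Note: unpack_object now receives the hash constructor so it can size REF_DELTA base identifiers correctly; the base id stored in the pack is exactly one digest wide. A small standalone sketch of that relationship:

    from hashlib import sha1, sha256

    for hash_func in (sha1, sha256):
        hash_size = len(hash_func().digest())  # 20 for SHA-1, 32 for SHA-256
        print(f"{hash_func().name}: REF_DELTA base id occupies {hash_size} bytes")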
@@ -1384,6 +1420,7 @@ class PackStreamReader:
 
     def __init__(
         self,
+        hash_func: Callable[[], "HashObject"],
         read_all: Callable[[int], bytes],
         read_some: Callable[[int], bytes] | None = None,
         zlib_bufsize: int = _ZLIB_BUFSIZE,
@@ -1391,6 +1428,7 @@ class PackStreamReader:
         """Initialize pack stream reader.
 
         Args:
+            hash_func: Hash function to use for computing object IDs
             read_all: Function to read all requested bytes
             read_some: Function to read some bytes (optional)
             zlib_bufsize: Buffer size for zlib decompression
@@ -1400,7 +1438,9 @@ class PackStreamReader:
             self.read_some = read_all
         else:
             self.read_some = read_some
-        self.sha = sha1()
+        self.hash_func = hash_func
+        self.sha = hash_func()
+        self._hash_size = len(hash_func().digest())
         self._offset = 0
         self._rbuf = BytesIO()
         # trailer is a deque to avoid memory allocation on small reads
@@ -1410,8 +1450,8 @@ class PackStreamReader:
     def _read(self, read: Callable[[int], bytes], size: int) -> bytes:
         """Read up to size bytes using the given callback.
 
-        As a side effect, update the verifier's hash (excluding the last 20
-        bytes read).
+        As a side effect, update the verifier's hash (excluding the last
+        hash_size bytes read, which is the pack checksum).
 
         Args:
           read: The read callback to read from.
@@ -1421,15 +1461,15 @@ class PackStreamReader:
         """
         data = read(size)
 
-        # maintain a trailer of the last 20 bytes we've read
+        # maintain a trailer of the last hash_size bytes we've read
         n = len(data)
         self._offset += n
         tn = len(self._trailer)
-        if n >= 20:
+        if n >= self._hash_size:
             to_pop = tn
-            to_add = 20
+            to_add = self._hash_size
         else:
-            to_pop = max(n + tn - 20, 0)
+            to_pop = max(n + tn - self._hash_size, 0)
             to_add = n
         self.sha.update(
             bytes(bytearray([self._trailer.popleft() for _ in range(to_pop)]))
@@ -1503,6 +1543,7 @@ class PackStreamReader:
             offset = self.offset
             unpacked, unused = unpack_object(
                 self.read,
+                self.hash_func,
                 read_some=self.recv,
                 compute_crc32=compute_crc32,
                 zlib_bufsize=self._zlib_bufsize,
@@ -1518,12 +1559,12 @@ class PackStreamReader:
 
             yield unpacked
 
-        if self._buf_len() < 20:
+        if self._buf_len() < self._hash_size:
             # If the read buffer is full, then the last read() got the whole
             # trailer off the wire. If not, it means there is still some of the
-            # trailer to read. We need to read() all 20 bytes; N come from the
-            # read buffer and (20 - N) come from the wire.
-            self.read(20)
+            # trailer to read. We need to read() all hash_size bytes; N come from the
+            # read buffer and (hash_size - N) come from the wire.
+            self.read(self._hash_size)
 
         pack_sha = bytearray(self._trailer)
         if pack_sha != self.sha.digest():
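
Note: the trailer bookkeeping above generalises from 20 bytes to hash_size bytes. The reader keeps the most recent hash_size bytes in a deque and only hashes bytes once they fall out of that window, so the pack checksum at the end of the stream is never hashed into itself. A standalone sketch of the same loop, outside of PackStreamReader:

    from collections import deque
    from hashlib import sha256

    hash_size = 32             # would be object_format.oid_length
    trailer: deque = deque()   # newest hash_size bytes seen so far
    digest = sha256()          # running checksum of everything older

    def feed(data: bytes) -> None:
        n, tn = len(data), len(trailer)
        if n >= hash_size:
            to_pop, to_add = tn, hash_size
        else:
            to_pop, to_add = max(n + tn - hash_size, 0), n
        # Bytes pushed out of the trailer window are safe to hash.
        digest.update(bytes(trailer.popleft() for _ in range(to_pop)))
        digest.update(data[: n - to_add])
        trailer.extend(data[n - to_add:])

    body = b"pack entries" * 10
    for chunk in (body[:7], body[7:], sha256(body).digest()):
        feed(chunk)
    assert bytes(trailer) == digest.digest() == sha256(body).digest()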
@@ -1541,6 +1582,7 @@ class PackStreamCopier(PackStreamReader):
 
     def __init__(
         self,
+        hash_func: Callable[[], "HashObject"],
         read_all: Callable[[int], bytes],
         read_some: Callable[[int], bytes] | None,
         outfile: IO[bytes],
@@ -1549,6 +1591,7 @@ class PackStreamCopier(PackStreamReader):
         """Initialize the copier.
 
         Args:
+          hash_func: Hash function to use for computing object IDs
           read_all: Read function that blocks until the number of
             requested bytes are read.
           read_some: Read function that returns at least one byte, but may
@@ -1557,7 +1600,7 @@ class PackStreamCopier(PackStreamReader):
           delta_iter: Optional DeltaChainIterator to record deltas as we
             read them.
         """
-        super().__init__(read_all, read_some=read_some)
+        super().__init__(hash_func, read_all, read_some=read_some)
         self.outfile = outfile
         self._delta_iter = delta_iter
 
@@ -1583,9 +1626,22 @@ class PackStreamCopier(PackStreamReader):
             progress(f"copied {i} pack entries\n".encode("ascii"))
 
 
-def obj_sha(type: int, chunks: bytes | Iterable[bytes]) -> bytes:
-    """Compute the SHA for a numeric type and object chunks."""
-    sha = sha1()
+def obj_sha(
+    type: int,
+    chunks: bytes | Iterable[bytes],
+    hash_func: Callable[[], "HashObject"] = sha1,
+) -> bytes:
+    """Compute the SHA for a numeric type and object chunks.
+
+    Args:
+        type: Object type number
+        chunks: Object data chunks
+        hash_func: Hash function to use (defaults to sha1)
+
+    Returns:
+        Binary hash digest
+    """
+    sha = hash_func()
     sha.update(object_header(type, chunks_length(chunks)))
     if isinstance(chunks, bytes):
         sha.update(chunks)
@@ -1596,21 +1652,28 @@ def obj_sha(type: int, chunks: bytes | Iterable[bytes]) -> bytes:
 
 
 def compute_file_sha(
-    f: IO[bytes], start_ofs: int = 0, end_ofs: int = 0, buffer_size: int = 1 << 16
+    f: IO[bytes],
+    object_format: "ObjectFormat",
+    start_ofs: int = 0,
+    end_ofs: int = 0,
+    buffer_size: int = 1 << 16,
 ) -> "HashObject":
     """Hash a portion of a file into a new SHA.
 
     Args:
       f: A file-like object to read from that supports seek().
+      object_format: Hash algorithm to use.
       start_ofs: The offset in the file to start reading at.
       end_ofs: The offset in the file to end reading at, relative to the
         end of the file.
       buffer_size: A buffer size for reading.
     Returns: A new SHA object updated with data read from the file.
     """
-    sha = sha1()
+    sha = object_format.new_hash()
     f.seek(0, SEEK_END)
     length = f.tell()
+    if start_ofs < 0:
+        raise AssertionError(f"start_ofs cannot be negative: {start_ofs}")
     if (end_ofs < 0 and length + end_ofs < start_ofs) or end_ofs > length:
         raise AssertionError(
             f"Attempt to read beyond file length. start_ofs: {start_ofs}, end_ofs: {end_ofs}, file length: {length}"
@@ -1799,9 +1862,11 @@ class PackData:
     def calculate_checksum(self) -> bytes:
         """Calculate the checksum for this pack.
 
-        Returns: 20-byte binary SHA1 digest
+        Returns: Binary digest (size depends on hash algorithm)
         """
-        return compute_file_sha(self._file, end_ofs=-20).digest()
+        return compute_file_sha(
+            self._file, self.object_format, end_ofs=-self.object_format.oid_length
+        ).digest()
 
     def iter_unpacked(self, *, include_comp: bool = False) -> Iterator[UnpackedObject]:
         """Iterate over unpacked objects in the pack."""
@@ -1813,7 +1878,10 @@ class PackData:
         for _ in range(self._num_objects):
             offset = self._file.tell()
             unpacked, unused = unpack_object(
-                self._file.read, compute_crc32=False, include_comp=include_comp
+                self._file.read,
+                self.object_format.hash_func,
+                compute_crc32=False,
+                include_comp=include_comp,
             )
             unpacked.offset = offset
             yield unpacked
@@ -1984,7 +2052,9 @@ class PackData:
         """Given offset in the packfile return a UnpackedObject."""
         assert offset >= self._header_size
         self._file.seek(offset)
-        unpacked, _ = unpack_object(self._file.read, include_comp=include_comp)
+        unpacked, _ = unpack_object(
+            self._file.read, self.object_format.hash_func, include_comp=include_comp
+        )
         unpacked.offset = offset
         return unpacked
 
@@ -2033,6 +2103,7 @@ class DeltaChainIterator(Generic[T]):
     def __init__(
         self,
         file_obj: IO[bytes] | None,
+        hash_func: Callable[[], "HashObject"],
         *,
         resolve_ext_ref: ResolveExtRefFn | None = None,
     ) -> None:
@@ -2040,9 +2111,11 @@ class DeltaChainIterator(Generic[T]):
 
         Args:
             file_obj: File object to read pack data from
+            hash_func: Hash function to use for computing object IDs
             resolve_ext_ref: Optional function to resolve external references
         """
         self._file = file_obj
+        self.hash_func = hash_func
         self._resolve_ext_ref = resolve_ext_ref
         self._pending_ofs: dict[int, list[int]] = defaultdict(list)
         self._pending_ref: dict[bytes, list[int]] = defaultdict(list)
@@ -2062,7 +2135,9 @@ class DeltaChainIterator(Generic[T]):
         Returns:
           DeltaChainIterator instance
         """
-        walker = cls(None, resolve_ext_ref=resolve_ext_ref)
+        walker = cls(
+            None, pack_data.object_format.hash_func, resolve_ext_ref=resolve_ext_ref
+        )
         walker.set_pack_data(pack_data)
         for unpacked in pack_data.iter_unpacked(include_comp=False):
             walker.record(unpacked)
@@ -2088,7 +2163,9 @@ class DeltaChainIterator(Generic[T]):
         Returns:
           DeltaChainIterator instance
         """
-        walker = cls(None, resolve_ext_ref=resolve_ext_ref)
+        walker = cls(
+            None, pack.object_format.hash_func, resolve_ext_ref=resolve_ext_ref
+        )
         walker.set_pack_data(pack.data)
         todo = set()
         for sha in shas:
@@ -2191,8 +2268,10 @@ class DeltaChainIterator(Generic[T]):
         self._file.seek(offset)
         unpacked, _ = unpack_object(
             self._file.read,
-            include_comp=self._include_comp,
+            self.hash_func,
+            read_some=None,
             compute_crc32=self._compute_crc32,
+            include_comp=self._include_comp,
         )
         unpacked.offset = offset
         if base_chunks is None:
@@ -2319,11 +2398,15 @@ class SHA1Reader(BinaryIO):
         # If git option index.skipHash is set the index will be empty
         if stored != self.sha1.digest() and (
             not allow_empty
-            or sha_to_hex(RawObjectID(stored))
-            != b"0000000000000000000000000000000000000000"
+            or (
+                len(stored) == 20
+                and sha_to_hex(RawObjectID(stored))
+                != b"0000000000000000000000000000000000000000"
+            )
         ):
             raise ChecksumMismatch(
-                self.sha1.hexdigest(), sha_to_hex(RawObjectID(stored))
+                self.sha1.hexdigest(),
+                sha_to_hex(RawObjectID(stored)) if stored else b"",
             )
 
     def close(self) -> None:
@@ -2573,19 +2656,162 @@ class SHA1Writer(BinaryIO):
         traceback: TracebackType | None,
     ) -> None:
         """Exit context manager and close file."""
-        self.close()
+        self.f.close()
 
-    def __iter__(self) -> "SHA1Writer":
-        """Return iterator."""
-        return self
+    def fileno(self) -> int:
+        """Return file descriptor number."""
+        return self.f.fileno()
 
-    def __next__(self) -> bytes:
+    def isatty(self) -> bool:
+        """Check if file is a terminal."""
+        return getattr(self.f, "isatty", lambda: False)()
+
+    def truncate(self, size: int | None = None) -> int:
+        """Not supported for write-only file.
+
+        Raises:
+            UnsupportedOperation: Always raised
+        """
+        raise UnsupportedOperation("truncate")
+
+
+class HashWriter(BinaryIO):
+    """Wrapper for file-like object that computes hash of its data.
+
+    This is a generic version that works with any hash algorithm.
+    """
+
+    def __init__(
+        self, f: BinaryIO | IO[bytes], hash_func: Callable[[], "HashObject"]
+    ) -> None:
+        """Initialize HashWriter.
+
+        Args:
+            f: File-like object to wrap
+            hash_func: Hash function (e.g., sha1, sha256)
+        """
+        self.f = f
+        self.length = 0
+        self.hash_obj = hash_func()
+        self.digest: bytes | None = None
+
+    def write(self, data: bytes | bytearray | memoryview, /) -> int:  # type: ignore[override]
+        """Write data and update hash.
+
+        Args:
+            data: Data to write
+
+        Returns:
+            Number of bytes written
+        """
+        self.hash_obj.update(data)
+        written = self.f.write(data)
+        self.length += written
+        return written
+
+    def write_hash(self) -> bytes:
+        """Write the hash digest to the file.
+
+        Returns:
+            The hash digest bytes
+        """
+        digest = self.hash_obj.digest()
+        self.f.write(digest)
+        self.length += len(digest)
+        return digest
+
+    def close(self) -> None:
+        """Close the pack file and finalize the hash."""
+        self.digest = self.write_hash()
+        self.f.close()
+
+    def offset(self) -> int:
+        """Get the total number of bytes written.
+
+        Returns:
+            Total bytes written
+        """
+        return self.length
+
+    def tell(self) -> int:
+        """Return current file position."""
+        return self.f.tell()
+
+    # BinaryIO abstract methods
+    def readable(self) -> bool:
+        """Check if file is readable."""
+        return False
+
+    def writable(self) -> bool:
+        """Check if file is writable."""
+        return True
+
+    def seekable(self) -> bool:
+        """Check if file is seekable."""
+        return getattr(self.f, "seekable", lambda: False)()
+
+    def seek(self, offset: int, whence: int = 0) -> int:
+        """Seek to position in file.
+
+        Args:
+            offset: Position offset
+            whence: Reference point (0=start, 1=current, 2=end)
+
+        Returns:
+            New file position
+        """
+        return self.f.seek(offset, whence)
+
+    def flush(self) -> None:
+        """Flush the file buffer."""
+        if hasattr(self.f, "flush"):
+            self.f.flush()
+
+    def readline(self, size: int = -1) -> bytes:
+        """Not supported for write-only file.
+
+        Raises:
+            UnsupportedOperation: Always raised
+        """
+        raise UnsupportedOperation("readline")
+
+    def readlines(self, hint: int = -1) -> list[bytes]:
+        """Not supported for write-only file.
+
+        Raises:
+            UnsupportedOperation: Always raised
+        """
+        raise UnsupportedOperation("readlines")
+
+    def writelines(self, lines: Iterable[bytes], /) -> None:  # type: ignore[override]
+        """Write multiple lines to the file.
+
+        Args:
+            lines: Iterable of lines to write
+        """
+        for line in lines:
+            self.write(line)
+
+    def read(self, size: int = -1) -> bytes:
         """Not supported for write-only file.
 
         Raises:
             UnsupportedOperation: Always raised
         """
-        raise UnsupportedOperation("__next__")
+        raise UnsupportedOperation("read")
+
+    def __enter__(self) -> "HashWriter":
+        """Enter context manager."""
+        return self
+
+    def __exit__(
+        self,
+        type: type | None,
+        value: BaseException | None,
+        traceback: TracebackType | None,
+    ) -> None:
+        """Exit context manager and close file."""
+        self.close()
 
     def fileno(self) -> int:
         """Return file descriptor number."""
@@ -3183,7 +3409,7 @@ class PackChunkGenerator:
             )
             object_format = DEFAULT_OBJECT_FORMAT
         self.object_format = object_format
-        self.cs = sha1(b"")
+        self.cs = object_format.new_hash()
         self.entries: dict[bytes, tuple[int, int]] = {}
         if records is None:
             records = iter([])  # Empty iterator if None
@@ -3518,42 +3744,55 @@ def write_pack_index_v2(
       entries: List of tuples with object name (sha), offset_in_pack, and
         crc32_checksum.
       pack_checksum: Checksum of the pack file.
-    Returns: The SHA of the index file written
+    Returns: The checksum of the index file written
     """
-    f = SHA1Writer(f)
-    f.write(b"\377tOc")  # Magic!
-    f.write(struct.pack(">L", 2))
+    # Determine hash algorithm from pack_checksum length
+    if len(pack_checksum) == 20:
+        hash_func = sha1
+    elif len(pack_checksum) == 32:
+        hash_func = sha256
+    else:
+        raise ValueError(f"Unsupported pack checksum length: {len(pack_checksum)}")
+
+    f_writer = HashWriter(f, hash_func)  # type: ignore[abstract]
+    f_writer.write(b"\377tOc")  # Magic!
+    f_writer.write(struct.pack(">L", 2))
+
+    # Convert to list to allow multiple iterations
+    entries_list = list(entries)
+
     fan_out_table: dict[int, int] = defaultdict(lambda: 0)
-    for name, offset, entry_checksum in entries:
+    for name, offset, entry_checksum in entries_list:
         fan_out_table[ord(name[:1])] += 1
-    try:
-        hash_size = len(next(iter(entries))[0])
-    except StopIteration:
-        hash_size = 20  # Default to SHA-1 size if no entries
+
+    if entries_list:
+        hash_size = len(entries_list[0][0])
+    else:
+        hash_size = len(pack_checksum)  # Use pack_checksum length as hash size
+
     # Fan-out table
     largetable: list[int] = []
     for i in range(0x100):
-        f.write(struct.pack(b">L", fan_out_table[i]))
+        f_writer.write(struct.pack(b">L", fan_out_table[i]))
         fan_out_table[i + 1] += fan_out_table[i]
-    for name, offset, entry_checksum in entries:
+    for name, offset, entry_checksum in entries_list:
         if len(name) != hash_size:
             raise TypeError(
                 f"Object name has wrong length: expected {hash_size}, got {len(name)}"
             )
-        f.write(name)
-    for name, offset, entry_checksum in entries:
-        f.write(struct.pack(b">L", entry_checksum))
-    for name, offset, entry_checksum in entries:
+        f_writer.write(name)
+    for name, offset, entry_checksum in entries_list:
+        f_writer.write(struct.pack(b">L", entry_checksum))
+    for name, offset, entry_checksum in entries_list:
         if offset < 2**31:
-            f.write(struct.pack(b">L", offset))
+            f_writer.write(struct.pack(b">L", offset))
         else:
-            f.write(struct.pack(b">L", 2**31 + len(largetable)))
+            f_writer.write(struct.pack(b">L", 2**31 + len(largetable)))
             largetable.append(offset)
     for offset in largetable:
-        f.write(struct.pack(b">Q", offset))
-    assert len(pack_checksum) == 20
-    f.write(pack_checksum)
-    return f.write_sha()
+        f_writer.write(struct.pack(b">Q", offset))
+    f_writer.write(pack_checksum)
+    return f_writer.write_hash()
 
 
 def write_pack_index_v3(
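
Note: write_pack_index_v2 now derives the hash algorithm from the length of pack_checksum rather than assuming SHA-1, and builds the 256-entry fan-out table over object ids of either width. A sketch of just the fan-out construction, independent of digest size; fan_out_table is an illustrative helper, not dulwich API:

    import struct
    from collections import defaultdict

    def fan_out_table(entries):
        # entries: (binary_oid, offset, crc32) tuples; oids may be 20 or 32 bytes.
        counts = defaultdict(int)
        for name, _offset, _crc in entries:
            counts[name[0]] += 1          # bucket by the first byte of the oid
        table, total = b"", 0
        for i in range(0x100):
            total += counts[i]            # cumulative counts, as in the index format
            table += struct.pack(">L", total)
        return table

    entries = [(bytes([0x2A]) + b"\x00" * 31, 12, 0)]
    assert len(fan_out_table(entries)) == 0x100 * 4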
@@ -4219,7 +4458,7 @@ def extend_pack(
         f.flush()
 
     # Rescan the rest of the pack, computing the SHA with the new header.
-    new_sha = compute_file_sha(f, end_ofs=-20)
+    new_sha = compute_file_sha(f, object_format, end_ofs=-object_format.oid_length)
 
     # Must reposition before writing (http://bugs.python.org/issue3207)
     f.seek(0, os.SEEK_CUR)

tests/compat/test_sha256.py (+43, -0)

@@ -364,3 +364,46 @@ class GitSHA256CompatibilityTests(CompatTestCase):
         self.assertNotIn(b"error", fsck_output.lower())
         self.assertNotIn(b"missing", fsck_output.lower())
         self.assertNotIn(b"broken", fsck_output.lower())
+
+    def test_dulwich_clone_sha256_repo(self):
+        """Test that dulwich's clone() auto-detects SHA-256 format from git repo."""
+        from dulwich.client import LocalGitClient
+
+        # Create source SHA-256 repo with git
+        source_path = tempfile.mkdtemp()
+        self.addCleanup(rmtree_ro, source_path)
+        self._run_git(
+            ["init", "--object-format=sha256", "--initial-branch=main", source_path]
+        )
+
+        # Add content and commit
+        test_file = os.path.join(source_path, "test.txt")
+        with open(test_file, "w") as f:
+            f.write("SHA-256 clone test")
+
+        self._run_git(["add", "test.txt"], cwd=source_path)
+        self._run_git(["commit", "-m", "Test commit"], cwd=source_path)
+
+        # Clone with dulwich LocalGitClient
+        target_path = tempfile.mkdtemp()
+        self.addCleanup(rmtree_ro, target_path)
+
+        client = LocalGitClient()
+        cloned_repo = client.clone(source_path, target_path, mkdir=False)
+        self.addCleanup(cloned_repo.close)
+
+        # Verify the cloned repo is SHA-256
+        self.assertEqual(cloned_repo.object_format, SHA256)
+
+        # Verify config has correct objectformat extension
+        config = cloned_repo.get_config()
+        self.assertEqual(b"sha256", config.get((b"extensions",), b"objectformat"))
+
+        # Verify git also sees it as SHA-256
+        output = self._run_git(["rev-parse", "--show-object-format"], cwd=target_path)
+        self.assertEqual(output.strip(), b"sha256")
+
+        # Verify objects were cloned correctly
+        source_head = self._run_git(["rev-parse", "refs/heads/main"], cwd=source_path)
+        cloned_head = cloned_repo.refs[b"refs/heads/main"]
+        self.assertEqual(source_head.strip(), cloned_head)

tests/test_config.py (+6, -0)

@@ -616,6 +616,7 @@ who\"
         handler = logging.StreamHandler(log_capture)
         handler.setLevel(logging.DEBUG)
         logger = logging.getLogger("dulwich.config")
+        old_level = logger.level
         logger.addHandler(handler)
         logger.setLevel(logging.DEBUG)
 
@@ -636,6 +637,7 @@ who\"
                 self.assertIn("nonexistent.config", log_output)
         finally:
             logger.removeHandler(handler)
+            logger.setLevel(old_level)
 
     def test_invalid_include_path_logging(self) -> None:
         """Test that invalid include paths are logged but don't cause failure."""
@@ -647,6 +649,7 @@ who\"
         handler = logging.StreamHandler(log_capture)
         handler.setLevel(logging.DEBUG)
         logger = logging.getLogger("dulwich.config")
+        old_level = logger.level
         logger.addHandler(handler)
         logger.setLevel(logging.DEBUG)
 
@@ -667,6 +670,7 @@ who\"
                 self.assertIn("Invalid include path", log_output)
         finally:
             logger.removeHandler(handler)
+            logger.setLevel(old_level)
 
     def test_unknown_includeif_condition_logging(self) -> None:
         """Test that unknown includeIf conditions are logged."""
@@ -678,6 +682,7 @@ who\"
         handler = logging.StreamHandler(log_capture)
         handler.setLevel(logging.DEBUG)
         logger = logging.getLogger("dulwich.config")
+        old_level = logger.level
         logger.addHandler(handler)
         logger.setLevel(logging.DEBUG)
 
@@ -700,6 +705,7 @@ who\"
                 self.assertIn("futurefeature:value", log_output)
         finally:
             logger.removeHandler(handler)
+            logger.setLevel(old_level)
 
     def test_custom_file_opener_with_include_depth(self) -> None:
         """Test that custom file opener is passed through include chain."""

tests/test_pack.py (+24, -13)

@@ -606,23 +606,26 @@ class TestPackData(PackTests):
         f = BytesIO(b"abcd1234wxyz")
         try:
             self.assertEqual(
-                sha1(b"abcd1234wxyz").hexdigest(), compute_file_sha(f).hexdigest()
+                sha1(b"abcd1234wxyz").hexdigest(),
+                compute_file_sha(f, DEFAULT_OBJECT_FORMAT).hexdigest(),
             )
             self.assertEqual(
                 sha1(b"abcd1234wxyz").hexdigest(),
-                compute_file_sha(f, buffer_size=5).hexdigest(),
+                compute_file_sha(f, DEFAULT_OBJECT_FORMAT, buffer_size=5).hexdigest(),
             )
             self.assertEqual(
                 sha1(b"abcd1234").hexdigest(),
-                compute_file_sha(f, end_ofs=-4).hexdigest(),
+                compute_file_sha(f, DEFAULT_OBJECT_FORMAT, end_ofs=-4).hexdigest(),
             )
             self.assertEqual(
                 sha1(b"1234wxyz").hexdigest(),
-                compute_file_sha(f, start_ofs=4).hexdigest(),
+                compute_file_sha(f, DEFAULT_OBJECT_FORMAT, start_ofs=4).hexdigest(),
             )
             self.assertEqual(
                 sha1(b"1234").hexdigest(),
-                compute_file_sha(f, start_ofs=4, end_ofs=-4).hexdigest(),
+                compute_file_sha(
+                    f, DEFAULT_OBJECT_FORMAT, start_ofs=4, end_ofs=-4
+                ).hexdigest(),
             )
         finally:
             f.close()
@@ -630,10 +633,14 @@ class TestPackData(PackTests):
     def test_compute_file_sha_short_file(self) -> None:
         f = BytesIO(b"abcd1234wxyz")
         try:
-            self.assertRaises(AssertionError, compute_file_sha, f, end_ofs=-20)
-            self.assertRaises(AssertionError, compute_file_sha, f, end_ofs=20)
             self.assertRaises(
-                AssertionError, compute_file_sha, f, start_ofs=10, end_ofs=-12
+                AssertionError, compute_file_sha, f, DEFAULT_OBJECT_FORMAT, -20
+            )
+            self.assertRaises(
+                AssertionError, compute_file_sha, f, DEFAULT_OBJECT_FORMAT, 0, 20
+            )
+            self.assertRaises(
+                AssertionError, compute_file_sha, f, DEFAULT_OBJECT_FORMAT, 10, -12
             )
         finally:
             f.close()
@@ -934,7 +941,9 @@ class WritePackTests(TestCase):
 
             f.write(b"x")  # unpack_object needs extra trailing data.
             f.seek(offset)
-            unpacked, unused = unpack_object(f.read, compute_crc32=True)
+            unpacked, unused = unpack_object(
+                f.read, DEFAULT_OBJECT_FORMAT.hash_func, compute_crc32=True
+            )
             self.assertEqual(Blob.type_num, unpacked.pack_type_num)
             self.assertEqual(Blob.type_num, unpacked.obj_type_num)
             self.assertEqual([b"blob"], unpacked.decomp_chunks)
@@ -1410,7 +1419,7 @@ class TestPackStreamReader(TestCase):
     def test_read_objects_emtpy(self) -> None:
         f = BytesIO()
         build_pack(f, [])
-        reader = PackStreamReader(f.read)
+        reader = PackStreamReader(DEFAULT_OBJECT_FORMAT.hash_func, f.read)
         self.assertEqual(0, len(list(reader.read_objects())))
 
     def test_read_objects(self) -> None:
@@ -1422,7 +1431,7 @@ class TestPackStreamReader(TestCase):
                 (OFS_DELTA, (0, b"blob1")),
             ],
         )
-        reader = PackStreamReader(f.read)
+        reader = PackStreamReader(DEFAULT_OBJECT_FORMAT.hash_func, f.read)
         objects = list(reader.read_objects(compute_crc32=True))
         self.assertEqual(2, len(objects))
 
@@ -1455,11 +1464,13 @@ class TestPackStreamReader(TestCase):
                 (OFS_DELTA, (0, b"blob1")),
             ],
         )
-        reader = PackStreamReader(f.read, zlib_bufsize=4)
+        reader = PackStreamReader(
+            DEFAULT_OBJECT_FORMAT.hash_func, f.read, zlib_bufsize=4
+        )
         self.assertEqual(2, len(list(reader.read_objects())))
 
     def test_read_objects_empty(self) -> None:
-        reader = PackStreamReader(BytesIO().read)
+        reader = PackStreamReader(DEFAULT_OBJECT_FORMAT.hash_func, BytesIO().read)
         self.assertRaises(AssertionError, list, reader.read_objects())