Browse Source

Rename hash to object_format

Jelmer Vernooij 2 months ago
parent
commit
db1eab3425

+ 59 - 26
dulwich/hash.py → dulwich/object_format.py

@@ -1,4 +1,4 @@
-# hash.py -- Hash algorithm abstraction layer for Git
+# object_format.py -- Object format abstraction layer for Git
 # Copyright (C) 2024 The Dulwich contributors
 #
 # SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
@@ -19,26 +19,34 @@
 # License, Version 2.0.
 #
 
-"""Hash algorithm abstraction for Git objects.
+"""Object format abstraction for Git objects.
 
-This module provides an abstraction layer for different hash algorithms
+This module provides an abstraction layer for different object formats
 used in Git repositories (SHA-1 and SHA-256).
 """
 
+from collections.abc import Callable
 from hashlib import sha1, sha256
-from typing import Callable, Optional
+from typing import TYPE_CHECKING
 
+if TYPE_CHECKING:
+    from _hashlib import HASH
 
-class HashAlgorithm:
-    """Base class for hash algorithms used in Git."""
+
+class ObjectFormat:
+    """Object format (hash algorithm) used in Git."""
 
     def __init__(
-        self, name: str, oid_length: int, hex_length: int, hash_func: Callable
+        self,
+        name: str,
+        oid_length: int,
+        hex_length: int,
+        hash_func: Callable[[], "HASH"],
     ) -> None:
-        """Initialize a hash algorithm.
+        """Initialize an object format.
 
         Args:
-            name: Name of the algorithm (e.g., "sha1", "sha256")
+            name: Name of the format (e.g., "sha1", "sha256")
             oid_length: Length of the binary object ID in bytes
             hex_length: Length of the hexadecimal object ID in characters
             hash_func: Hash function from hashlib
@@ -51,12 +59,14 @@ class HashAlgorithm:
         self.zero_oid_bin = b"\x00" * oid_length
 
     def __str__(self) -> str:
+        """Return string representation."""
         return self.name
 
     def __repr__(self) -> str:
-        return f"HashAlgorithm({self.name!r})"
+        """Return repr."""
+        return f"ObjectFormat({self.name!r})"
 
-    def new_hash(self):
+    def new_hash(self) -> "HASH":
         """Create a new hash object."""
         return self.hash_func()
 
@@ -87,35 +97,58 @@ class HashAlgorithm:
         return h.hexdigest().encode("ascii")
 
 
-# Define the supported hash algorithms
-SHA1 = HashAlgorithm("sha1", 20, 40, sha1)
-SHA256 = HashAlgorithm("sha256", 32, 64, sha256)
+# Define the supported object formats
+SHA1 = ObjectFormat("sha1", 20, 40, sha1)
+SHA256 = ObjectFormat("sha256", 32, 64, sha256)
 
-# Map of algorithm names to HashAlgorithm instances
-HASH_ALGORITHMS = {
+# Map of format names to ObjectFormat instances
+OBJECT_FORMATS = {
     "sha1": SHA1,
     "sha256": SHA256,
 }
 
-# Default algorithm for backward compatibility
-DEFAULT_HASH_ALGORITHM = SHA1
+# Default format for backward compatibility
+DEFAULT_OBJECT_FORMAT = SHA1
 
 
-def get_hash_algorithm(name: Optional[str] = None) -> HashAlgorithm:
-    """Get a hash algorithm by name.
+def get_object_format(name: str | None = None) -> ObjectFormat:
+    """Get an object format by name.
 
     Args:
-        name: Algorithm name ("sha1" or "sha256"). If None, returns default.
+        name: Format name ("sha1" or "sha256"). If None, returns default.
 
     Returns:
-        HashAlgorithm instance
+        ObjectFormat instance
 
     Raises:
-        ValueError: If the algorithm name is not supported
+        ValueError: If the format name is not supported
     """
     if name is None:
-        return DEFAULT_HASH_ALGORITHM
+        return DEFAULT_OBJECT_FORMAT
     try:
-        return HASH_ALGORITHMS[name.lower()]
+        return OBJECT_FORMATS[name.lower()]
     except KeyError:
-        raise ValueError(f"Unsupported hash algorithm: {name}")
+        raise ValueError(f"Unsupported object format: {name}")
+
+
+def verify_same_object_format(*formats: ObjectFormat) -> ObjectFormat:
+    """Verify that all provided object formats are the same.
+
+    Args:
+        *formats: Object format instances to verify
+
+    Returns:
+        The common object format
+
+    Raises:
+        ValueError: If formats don't match or no formats provided
+    """
+    if not formats:
+        raise ValueError("At least one object format must be provided")
+
+    first = formats[0]
+    for fmt in formats[1:]:
+        if fmt != first:
+            raise ValueError(f"Object format mismatch: {first.name} != {fmt.name}")
+
+    return first
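For orientation, here is a minimal usage sketch of the renamed module as it appears above; it is illustration only, not part of the commit, and the hashed input bytes are made up.

    from dulwich.object_format import (
        SHA1,
        SHA256,
        get_object_format,
        verify_same_object_format,
    )

    fmt = get_object_format("sha256")
    assert fmt is SHA256
    assert (fmt.oid_length, fmt.hex_length) == (32, 64)

    h = fmt.new_hash()          # plain hashlib.sha256() underneath
    h.update(b"blob 0\x00")     # arbitrary example input
    print(h.hexdigest())

    assert verify_same_object_format(SHA1, SHA1) is SHA1
    # verify_same_object_format(SHA1, SHA256) raises ValueError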

+ 71 - 22
dulwich/object_store.py

@@ -319,6 +319,16 @@ class PackContainer(Protocol):
 class BaseObjectStore:
     """Object store interface."""
 
+    def __init__(self, *, object_format=None) -> None:
+        """Initialize object store.
+
+        Args:
+            object_format: Hash algorithm to use (defaults to SHA1)
+        """
+        from .object_format import get_object_format
+
+        self.object_format = object_format if object_format else get_object_format()
+
     def determine_wants_all(
         self, refs: Mapping[Ref, ObjectID], depth: int | None = None
     ) -> list[ObjectID]:
@@ -371,7 +381,9 @@ class BaseObjectStore:
     def __getitem__(self, sha1: ObjectID | RawObjectID) -> ShaFile:
         """Obtain an object by SHA1."""
         type_num, uncomp = self.get_raw(sha1)
-        return ShaFile.from_raw_string(type_num, uncomp, sha=sha1)
+        return ShaFile.from_raw_string(
+            type_num, uncomp, sha=sha1, object_format=self.object_format
+        )
 
     def __iter__(self) -> Iterator[ObjectID]:
         """Iterate over the SHAs that are present in this store."""
@@ -816,6 +828,8 @@ class PackBasedObjectStore(PackCapableObjectStore, PackedObjectContainer):
         pack_depth: int | None = None,
         pack_threads: int | None = None,
         pack_big_file_threshold: int | None = None,
+        *,
+        object_format=None,
     ) -> None:
         """Initialize a PackBasedObjectStore.
 
@@ -828,7 +842,9 @@ class PackBasedObjectStore(PackCapableObjectStore, PackedObjectContainer):
           pack_depth: Maximum depth for pack deltas
           pack_threads: Number of threads to use for packing
           pack_big_file_threshold: Threshold for treating files as "big"
+          object_format: Hash algorithm to use
         """
+        super().__init__(object_format=object_format)
         self._pack_cache: dict[str, Pack] = {}
         self.pack_compression_level = pack_compression_level
         self.pack_index_version = pack_index_version
@@ -1382,7 +1398,7 @@ class DiskObjectStore(PackBasedObjectStore):
         pack_write_bitmap_lookup_table: bool = True,
         file_mode: int | None = None,
         dir_mode: int | None = None,
-        hash_algorithm=None,
+        object_format=None,
     ) -> None:
         """Open an object store.
 
@@ -1403,8 +1419,12 @@ class DiskObjectStore(PackBasedObjectStore):
           pack_write_bitmap_lookup_table: whether to include lookup table in bitmaps
           file_mode: File permission mask for shared repository
           dir_mode: Directory permission mask for shared repository
-          hash_algorithm: Hash algorithm to use (SHA1 or SHA256)
+          object_format: Hash algorithm to use (SHA1 or SHA256)
         """
+        # Import here to avoid circular dependency
+        from .object_format import get_object_format
+
         super().__init__(
             pack_compression_level=pack_compression_level,
             pack_index_version=pack_index_version,
@@ -1414,6 +1434,7 @@ class DiskObjectStore(PackBasedObjectStore):
             pack_depth=pack_depth,
             pack_threads=pack_threads,
             pack_big_file_threshold=pack_big_file_threshold,
+            object_format=object_format if object_format else get_object_format(),
         )
         self.path = path
         self.pack_dir = os.path.join(self.path, PACKDIR)
@@ -1428,11 +1449,6 @@ class DiskObjectStore(PackBasedObjectStore):
         self.file_mode = file_mode
         self.dir_mode = dir_mode
 
-        # Import here to avoid circular dependency
-        from .hash import get_hash_algorithm
-
-        self.hash_algorithm = hash_algorithm if hash_algorithm else get_hash_algorithm()
-
         # Commit graph support - lazy loaded
         self._commit_graph = None
         self._use_commit_graph = True  # Default to true
@@ -1548,9 +1564,9 @@ class DiskObjectStore(PackBasedObjectStore):
             )
 
         # Get hash algorithm from config
-        from .hash import get_hash_algorithm
+        from .object_format import get_object_format
 
-        hash_algorithm = None
+        object_format = None
         try:
             try:
                 version = int(config.get((b"core",), b"repositoryformatversion"))
@@ -1561,7 +1577,7 @@ class DiskObjectStore(PackBasedObjectStore):
                     object_format = config.get((b"extensions",), b"objectformat")
                 except KeyError:
                     object_format = b"sha1"
-                hash_algorithm = get_hash_algorithm(object_format.decode("ascii"))
+                object_format = get_object_format(object_format.decode("ascii"))
         except (KeyError, ValueError):
             pass
 
@@ -1582,7 +1598,18 @@ class DiskObjectStore(PackBasedObjectStore):
             pack_write_bitmap_lookup_table=pack_write_bitmap_lookup_table,
             file_mode=file_mode,
             dir_mode=dir_mode,
-            hash_algorithm=hash_algorithm,
+            object_format=object_format,
         )
         instance._use_commit_graph = use_commit_graph
         instance._use_midx = use_midx
@@ -1673,7 +1700,7 @@ class DiskObjectStore(PackBasedObjectStore):
                     depth=self.pack_depth,
                     threads=self.pack_threads,
                     big_file_threshold=self.pack_big_file_threshold,
-                    hash_algorithm=self.hash_algorithm,
+                    object_format=self.object_format,
                 )
                 new_packs.append(pack)
                 self._pack_cache[f] = pack
@@ -1725,9 +1752,8 @@ class DiskObjectStore(PackBasedObjectStore):
     def _get_loose_object(self, sha: ObjectID | RawObjectID) -> ShaFile | None:
         path = self._get_shafile_path(sha)
         try:
-            # Load the object from path with SHA for hash algorithm detection
-            # sha parameter here is already hex, so pass it directly
-            return ShaFile.from_path(path, sha)
+            # Load the object from path with SHA and hash algorithm from object store
+            return ShaFile.from_path(path, sha, object_format=self.object_format)
         except FileNotFoundError:
             return None
 
@@ -1914,7 +1940,7 @@ class DiskObjectStore(PackBasedObjectStore):
             depth=self.pack_depth,
             threads=self.pack_threads,
             big_file_threshold=self.pack_big_file_threshold,
-            hash_algorithm=self.hash_algorithm,
+            object_format=self.object_format,
         )
         final_pack.check_length_and_checksum()
         self._add_cached_pack(pack_base_name, final_pack)
@@ -1995,7 +2021,7 @@ class DiskObjectStore(PackBasedObjectStore):
           obj: Object to add
         """
         # Use the correct hash algorithm for the object ID
-        obj_id = obj.get_id(self.hash_algorithm)
+        obj_id = obj.get_id(self.object_format)
         path = self._get_shafile_path(obj_id)
         dir = os.path.dirname(path)
         try:
@@ -2019,7 +2045,8 @@ class DiskObjectStore(PackBasedObjectStore):
         *,
         file_mode: int | None = None,
         dir_mode: int | None = None,
-        hash_algorithm=None,
+        object_format=None,
     ) -> "DiskObjectStore":
         """Initialize a new disk object store.
 
@@ -2029,7 +2056,8 @@ class DiskObjectStore(PackBasedObjectStore):
           path: Path where the object store should be created
           file_mode: Optional file permission mask for shared repository
           dir_mode: Optional directory permission mask for shared repository
-          hash_algorithm: Hash algorithm to use (SHA1 or SHA256)
+          object_format: Hash algorithm to use (SHA1 or SHA256)
 
         Returns:
           New DiskObjectStore instance
@@ -2047,7 +2075,10 @@ class DiskObjectStore(PackBasedObjectStore):
         if dir_mode is not None:
             os.chmod(info_path, dir_mode)
             os.chmod(pack_path, dir_mode)
-        return cls(path, file_mode=file_mode, dir_mode=dir_mode, hash_algorithm=hash_algorithm)
+        return cls(path, file_mode=file_mode, dir_mode=dir_mode, object_format=object_format)
 
     def iter_prefix(self, prefix: bytes) -> Iterator[ObjectID]:
         """Iterate over all object SHAs with the given prefix.
@@ -2401,13 +2432,18 @@ class DiskObjectStore(PackBasedObjectStore):
 class MemoryObjectStore(PackCapableObjectStore):
     """Object store that keeps all objects in memory."""
 
-    def __init__(self) -> None:
+    def __init__(self, *, object_format=None) -> None:
         """Initialize a MemoryObjectStore.
 
         Creates an empty in-memory object store.
+
+        Args:
+            object_format: Hash algorithm to use (defaults to SHA1)
         """
-        super().__init__()
+        super().__init__(object_format=object_format)
         self._data: dict[ObjectID, ShaFile] = {}
         self.pack_compression_level = -1
 
     def _to_hexsha(self, sha: ObjectID | RawObjectID) -> ObjectID:
@@ -3020,7 +3056,20 @@ class OverlayObjectStore(BaseObjectStore):
         Args:
           bases: List of base object stores to overlay
           add_store: Optional store to write new objects to
+
+        Raises:
+          ValueError: If stores have different hash algorithms
         """
+        from .object_format import verify_same_object_format
+
+        # Verify all stores use the same hash algorithm
+        store_algorithms = [store.object_format for store in bases]
+        if add_store:
+            store_algorithms.append(add_store.object_format)
+
+        object_format = verify_same_object_format(*store_algorithms)
+
+        super().__init__(object_format=object_format)
         self.bases = bases
         self.add_store = add_store
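A rough sketch of how the store-level wiring above behaves, assuming OverlayObjectStore takes its base stores as a positional list; not part of the commit:

    from dulwich.object_format import SHA256
    from dulwich.object_store import MemoryObjectStore, OverlayObjectStore

    sha1_store = MemoryObjectStore()                        # defaults to SHA-1
    sha256_store = MemoryObjectStore(object_format=SHA256)

    try:
        OverlayObjectStore([sha1_store, sha256_store])
    except ValueError as err:
        print(err)  # Object format mismatch: sha1 != sha256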
 

+ 84 - 66
dulwich/objects.py

@@ -116,18 +116,19 @@ if TYPE_CHECKING:
 ZERO_SHA = b"0" * 40  # SHA1 - kept for backward compatibility
 
 
-def zero_sha_for(hash_algorithm=None):
+def zero_sha_for(object_format=None) -> bytes:
     """Get the zero SHA for a given hash algorithm.
 
     Args:
-        hash_algorithm: HashAlgorithm instance. If None, returns SHA1 zero.
+        object_format: ObjectFormat instance. If None, returns SHA1 zero.
 
     Returns:
         Zero SHA as hex bytes (40 chars for SHA1, 64 for SHA256)
     """
-    if hash_algorithm is None:
+    if object_format is None:
         return ZERO_SHA
-    return hash_algorithm.zero_oid
+    return object_format.zero_oid
+
 
 
 # Header fields for commits
@@ -203,7 +204,7 @@ def hex_to_sha(hex: ObjectID | str) -> RawObjectID:
     """Takes a hex sha and returns a binary sha."""
     # Support both SHA1 (40 chars) and SHA256 (64 chars)
     if len(hex) not in (40, 64):
-        raise ValueError(f"Incorrect length of hexsha: {hex}")
+        raise ValueError(f"Incorrect length of hexsha: {hex!r}")
     try:
         return RawObjectID(binascii.unhexlify(hex))
     except TypeError as exc:
@@ -465,7 +466,7 @@ else:
 class ShaFile:
     """A git SHA file."""
 
-    __slots__ = ("_chunked_text", "_needs_serialization", "_sha")
+    __slots__ = ("_chunked_text", "_needs_serialization", "_sha", "object_format")
 
     _needs_serialization: bool
     type_name: bytes
@@ -473,6 +474,13 @@ class ShaFile:
     _chunked_text: list[bytes] | None
     _sha: "FixedSha | None | HASH"
 
+    def __init__(self) -> None:
+        """Initialize a ShaFile."""
+        self._sha = None
+        self._chunked_text = None
+        self._needs_serialization = True
+        self.object_format = None
+
     @staticmethod
     def _parse_legacy_object_header(
         magic: bytes, f: BufferedIOBase | IO[bytes] | "_GitFile"
@@ -567,15 +575,18 @@ class ShaFile:
         self.set_raw_chunks([text], sha)
 
     def set_raw_chunks(
-        self, chunks: list[bytes], sha: ObjectID | RawObjectID | None = None
+        self, chunks: list[bytes], sha: ObjectID | RawObjectID | None = None, *, object_format=None
     ) -> None:
         """Set the contents of this object from a list of chunks."""
         self._chunked_text = chunks
-        # Set SHA before deserialization so Tree can detect hash algorithm
+        # Set hash algorithm if provided
+        if object_format is not None:
+            self.object_format = object_format
+        # Set SHA before deserialization so Tree can use hash algorithm
         if sha is None:
             self._sha = None
         else:
-            self._sha = FixedSha(sha)  # type: ignore
+            self._sha = FixedSha(sha)
         self._deserialize(chunks)
         self._needs_serialization = False
 
@@ -610,25 +621,25 @@ class ShaFile:
         return (b0 & 0x8F) == 0x08 and (word % 31) == 0
 
     @classmethod
-    def _parse_file(cls, f: BufferedIOBase | IO[bytes] | "_GitFile") -> "ShaFile":
+    def _parse_file(
+        cls, f: BufferedIOBase | IO[bytes] | "_GitFile", *, object_format=None
+    ) -> "ShaFile":
         map = f.read()
         if not map:
             raise EmptyFileException("Corrupted empty file detected")
 
         if cls._is_legacy_object(map):
             obj = cls._parse_legacy_object_header(map, f)
+            if object_format is not None:
+                obj.object_format = object_format
             obj._parse_legacy_object(map)
         else:
             obj = cls._parse_object_header(map, f)
+            if object_format is not None:
+                obj.object_format = object_format
             obj._parse_object(map)
         return obj
 
-    def __init__(self) -> None:
-        """Don't call this directly."""
-        self._sha = None
-        self._chunked_text = []
-        self._needs_serialization = True
-
     def _deserialize(self, chunks: list[bytes]) -> None:
         raise NotImplementedError(self._deserialize)
 
@@ -636,17 +647,33 @@ class ShaFile:
         raise NotImplementedError(self._serialize)
 
     @classmethod
-    def from_path(cls, path: str | bytes, sha: ObjectID | None = None) -> "ShaFile":
+    def from_path(
+        cls, path: str | bytes, sha: ObjectID | None = None, *, object_format=None
+    ) -> "ShaFile":
         """Open a SHA file from disk."""
         with GitFile(path, "rb") as f:
-            return cls.from_file(f, sha)
+            return cls.from_file(f, sha, object_format=object_format)
 
     @classmethod
-    def from_file(cls, f: BufferedIOBase | IO[bytes] | "_GitFile", sha: ObjectID | None = None) -> "ShaFile":
+    def from_file(
+        cls,
+        f: BufferedIOBase | IO[bytes] | "_GitFile",
+        sha: ObjectID | None = None,
+        *,
+        object_format=None,
+    ) -> "ShaFile":
         """Get the contents of a SHA file on disk."""
         try:
-            obj = cls._parse_file(f)
-            # Set SHA after parsing but before any further processing
+            # Validate SHA length matches hash algorithm if both provided
+            if sha is not None and object_format is not None:
+                expected_len = object_format.hex_length
+                if len(sha) != expected_len:
+                    raise ValueError(
+                        f"SHA length {len(sha)} doesn't match hash algorithm "
+                        f"{object_format.name} (expected {expected_len})"
+                    )
+
+            obj = cls._parse_file(f, object_format=object_format)
             if sha is not None:
                 obj._sha = FixedSha(sha)
             else:
@@ -657,7 +684,11 @@ class ShaFile:
 
     @staticmethod
     def from_raw_string(
-        type_num: int, string: bytes, sha: ObjectID | RawObjectID | None = None
+        type_num: int,
+        string: bytes,
+        sha: ObjectID | RawObjectID | None = None,
+        *,
+        object_format=None,
     ) -> "ShaFile":
         """Creates an object of the indicated type from the raw string given.
 
@@ -665,11 +696,14 @@ class ShaFile:
           type_num: The numeric type of the object.
           string: The raw uncompressed contents.
           sha: Optional known sha for the object
+          object_format: Optional hash algorithm for the object
         """
         cls = object_class(type_num)
         if cls is None:
             raise AssertionError(f"unsupported class type num: {type_num}")
         obj = cls()
+        if object_format is not None:
+            obj.object_format = object_format
         obj.set_raw_string(string, sha)
         return obj
 
@@ -740,15 +774,15 @@ class ShaFile:
         """Returns the length of the raw string of this object."""
         return sum(map(len, self.as_raw_chunks()))
 
-    def sha(self, hash_algorithm=None) -> "FixedSha | HASH":
+    def sha(self, object_format=None) -> "FixedSha | HASH":
         """The SHA object that is the name of this object.
 
         Args:
-            hash_algorithm: Optional HashAlgorithm to use. Defaults to SHA1.
+            object_format: Optional ObjectFormat to use. Defaults to SHA1.
         """
         # If using a different hash algorithm, always recalculate
-        if hash_algorithm is not None:
-            new_sha = hash_algorithm.new_hash()
+        if object_format is not None:
+            new_sha = object_format.new_hash()
             new_sha.update(self._header())
             for chunk in self.as_raw_chunks():
                 new_sha.update(chunk)
@@ -775,16 +809,16 @@ class ShaFile:
     def id(self) -> ObjectID:
         """The hex SHA1 of this object.
 
-        For SHA256 repositories, use get_id(hash_algorithm) instead.
+        For SHA256 repositories, use get_id(object_format) instead.
         This property always returns SHA1 for backward compatibility.
         """
         return ObjectID(self.sha().hexdigest().encode("ascii"))
 
-    def get_id(self, hash_algorithm=None):
+    def get_id(self, object_format=None) -> bytes:
         """Get the hex SHA of this object using the specified hash algorithm.
 
         Args:
-            hash_algorithm: Optional HashAlgorithm to use. Defaults to SHA1.
+            object_format: Optional ObjectFormat to use. Defaults to SHA1.
 
         Example:
             >>> blob = Blob()
@@ -793,11 +827,11 @@ class ShaFile:
             b'4ab299c8ad6ed14f31923dd94f8b5f5cb89dfb54'
             >>> blob.get_id()  # Same as .id
             b'4ab299c8ad6ed14f31923dd94f8b5f5cb89dfb54'
-            >>> from dulwich.hash import SHA256
+            >>> from dulwich.object_format import SHA256
             >>> blob.get_id(SHA256)  # Get SHA256 hash
             b'03ba204e2f2e707...'  # 64-character SHA256
         """
-        return self.sha(hash_algorithm).hexdigest().encode("ascii")
+        return self.sha(object_format).hexdigest().encode("ascii")
 
     def __repr__(self) -> str:
         """Return string representation of this object."""
@@ -869,11 +903,12 @@ class Blob(ShaFile):
     )
 
     @classmethod
-    def from_path(cls, path: str | bytes) -> "Blob":
+    def from_path(cls, path: str | bytes, sha: ObjectID | None = None) -> "Blob":
         """Read a blob from a file on disk.
 
         Args:
           path: Path to the blob file
+          sha: Optional known SHA for the object
 
         Returns:
           A Blob object
@@ -881,7 +916,7 @@ class Blob(ShaFile):
         Raises:
           NotBlobError: If the file is not a blob
         """
-        blob = ShaFile.from_path(path)
+        blob = ShaFile.from_path(path, sha)
         if not isinstance(blob, cls):
             raise NotBlobError(_path_to_bytes(path))
         return blob
@@ -1030,11 +1065,12 @@ class Tag(ShaFile):
         self._signature: bytes | None = None
 
     @classmethod
-    def from_path(cls, filename: str | bytes) -> "Tag":
+    def from_path(cls, filename: str | bytes, sha: ObjectID | None = None) -> "Tag":
         """Read a tag from a file on disk.
 
         Args:
           filename: Path to the tag file
+          sha: Optional known SHA for the object
 
         Returns:
           A Tag object
@@ -1042,7 +1078,7 @@ class Tag(ShaFile):
         Raises:
           NotTagError: If the file is not a tag
         """
-        tag = ShaFile.from_path(filename)
+        tag = ShaFile.from_path(filename, sha)
         if not isinstance(tag, cls):
             raise NotTagError(_path_to_bytes(filename))
         return tag
@@ -1310,21 +1346,21 @@ class TreeEntry(NamedTuple):
 
 
 def parse_tree(
-    text: bytes, strict: bool = False, hash_algorithm=None
+    text: bytes, strict: bool = False, object_format=None
 ) -> Iterator[tuple[bytes, int, ObjectID]]:
     """Parse a tree text.
 
     Args:
       text: Serialized text to parse
       strict: Whether to be strict about format
-      hash_algorithm: Hash algorithm object (SHA1 or SHA256) - if None, auto-detect
+      object_format: Hash algorithm object (SHA1 or SHA256) - if None, auto-detect
     Returns: iterator of tuples of (name, mode, sha)
 
     Raises:
       ObjectFormatException: if the object was malformed in some way
     """
-    if hash_algorithm is not None:
-        sha_len = hash_algorithm.oid_length
+    if object_format is not None:
+        sha_len = object_format.oid_length
         return _parse_tree_with_sha_len(text, strict, sha_len)
 
     # Try both hash lengths and use the one that works
@@ -1336,7 +1372,9 @@ def parse_tree(
         return _parse_tree_with_sha_len(text, strict, 32)
 
 
-def _parse_tree_with_sha_len(text, strict, sha_len):
+def _parse_tree_with_sha_len(
+    text: bytes, strict: bool, sha_len: int
+) -> Iterator[tuple[bytes, int, bytes]]:
     """Helper function to parse tree with a specific hash length."""
     count = 0
     length = len(text)
@@ -1474,27 +1512,6 @@ class Tree(ShaFile):
         super().__init__()
         self._entries: dict[bytes, tuple[int, ObjectID]] = {}
 
-    def _get_hash_algorithm(self):
-        """Get the hash algorithm based on the object's SHA."""
-        if not hasattr(self, "_sha") or self._sha is None:
-            return None
-
-        # Get the raw SHA bytes
-        sha = self._sha.digest() if hasattr(self._sha, "digest") else self._sha
-        if not isinstance(sha, bytes):
-            return None
-
-        # Import hash modules lazily to avoid circular imports
-        if len(sha) == 32:
-            from .hash import SHA256
-
-            return SHA256
-        elif len(sha) == 20:
-            from .hash import SHA1
-
-            return SHA1
-        return None
-
     @classmethod
     def from_path(cls, filename: str | bytes, sha: ObjectID | None = None) -> "Tree":
         """Read a tree from a file on disk.
@@ -1581,7 +1598,7 @@ class Tree(ShaFile):
         """Grab the entries in the tree."""
         try:
             parsed_entries = parse_tree(
-                b"".join(chunks), hash_algorithm=self._get_hash_algorithm()
+                b"".join(chunks), object_format=self.object_format
             )
         except ValueError as exc:
             raise ObjectFormatException(exc) from exc
@@ -1611,9 +1628,9 @@ class Tree(ShaFile):
         for name, mode, sha in parse_tree(
             b"".join(self._chunked_text),
             strict=True,
-            hash_algorithm=self._get_hash_algorithm(),
+            object_format=self.object_format,
         ):
-            check_hexsha(sha, f"invalid sha {sha}")
+            check_hexsha(sha, f"invalid sha {sha!r}")
             if b"/" in name or name in (b"", b".", b"..", b".git"):
                 raise ObjectFormatException(
                     "invalid name {}".format(name.decode("utf-8", "replace"))
@@ -1884,11 +1901,12 @@ class Commit(ShaFile):
         self._commit_timezone_neg_utc: bool | None = False
 
     @classmethod
-    def from_path(cls, path: str | bytes) -> "Commit":
+    def from_path(cls, path: str | bytes, sha: ObjectID | None = None) -> "Commit":
         """Read a commit from a file on disk.
 
         Args:
           path: Path to the commit file
+          sha: Optional known SHA for the object
 
         Returns:
           A Commit object
@@ -1896,7 +1914,7 @@ class Commit(ShaFile):
         Raises:
           NotCommitError: If the file is not a commit
         """
-        commit = ShaFile.from_path(path)
+        commit = ShaFile.from_path(path, sha)
         if not isinstance(commit, cls):
             raise NotCommitError(_path_to_bytes(path))
         return commit
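For illustration (the blob content is invented), the dual-format object-ID API changed above can be exercised like this:

    from dulwich.object_format import SHA1, SHA256
    from dulwich.objects import Blob, ZERO_SHA, zero_sha_for

    assert zero_sha_for() == ZERO_SHA          # 40 hex zeros, SHA-1 default
    assert len(zero_sha_for(SHA256)) == 64     # SHA-256 zero OID is 64 hex chars

    blob = Blob.from_string(b"hello world\n")
    print(blob.id)                 # SHA-1 hex ID, unchanged for compatibility
    print(blob.get_id(SHA256))     # SHA-256 hex ID of the same content
    assert blob.get_id(SHA1) == blob.id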

+ 48 - 40
dulwich/pack.py

@@ -530,16 +530,16 @@ def iter_sha1(iter: Iterable[bytes]) -> bytes:
     return sha.hexdigest().encode("ascii")
 
 
-def load_pack_index(path: str | os.PathLike[str], hash_algorithm: int | None = None) -> "PackIndex":
+def load_pack_index(path: str | os.PathLike[str], object_format: ObjectFormat | None = None) -> "PackIndex":
     """Load an index file by path.
 
     Args:
       path: Path to the index file
-      hash_algorithm: Hash algorithm used by the repository
+      object_format: Hash algorithm used by the repository
     Returns: A PackIndex loaded from the given path
     """
     with GitFile(path, "rb") as f:
-        return load_pack_index_file(path, f, hash_algorithm=hash_algorithm)
+        return load_pack_index_file(path, f, object_format=object_format)
 
 
 def _load_file_contents(
@@ -575,14 +575,14 @@ def _load_file_contents(
 
 def load_pack_index_file(
     path: str | os.PathLike[str], f: IO[bytes] | _GitFile,
-    hash_algorithm: int | None = None,
+    object_format: ObjectFormat | None = None,
 ) -> "PackIndex":
     """Load an index file from a file-like object.
 
     Args:
       path: Path for the index file
       f: File-like object
-      hash_algorithm: Hash algorithm used by the repository
+      object_format: Hash algorithm used by the repository
     Returns: A PackIndex loaded from the given file
     """
     contents, size = _load_file_contents(f)
@@ -594,7 +594,7 @@ def load_pack_index_file(
                 file=f,
                 contents=contents,
                 size=size,
-                hash_algorithm=hash_algorithm,
+                object_format=object_format,
             )
         elif version == 3:
             return PackIndex3(path, file=f, contents=contents, size=size)
@@ -602,7 +602,7 @@ def load_pack_index_file(
             raise KeyError(f"Unknown pack index format {version}")
     else:
         return PackIndex1(
-            path, file=f, contents=contents, size=size, hash_algorithm=hash_algorithm
+            path, file=f, contents=contents, size=size, object_format=object_format
         )
 
 
@@ -642,7 +642,7 @@ class PackIndex:
     """
 
     # Default to SHA-1 for backward compatibility
-    hash_algorithm = 1
+    object_format = 1
     hash_size = 20
 
     def __eq__(self, other: object) -> bool:
@@ -1032,7 +1032,7 @@ class PackIndex1(FilePackIndex):
         file: IO[bytes] | _GitFile | None = None,
         contents: bytes | None = None,
         size: int | None = None,
-        hash_algorithm: int | None = None,
+        object_format: ObjectFormat | None = None,
     ) -> None:
         """Initialize a version 1 pack index.
 
@@ -1047,8 +1047,8 @@ class PackIndex1(FilePackIndex):
         self.version = 1
         self._fan_out_table = self._read_fan_out_table(0)
         # Use provided hash algorithm if available, otherwise default to SHA1
-        if hash_algorithm:
-            self.hash_size = hash_algorithm.oid_length
+        if object_format:
+            self.hash_size = object_format.oid_length
         else:
             self.hash_size = 20  # Default to SHA1
 
@@ -1085,7 +1085,7 @@ class PackIndex2(FilePackIndex):
         file: IO[bytes] | _GitFile | None = None,
         contents: bytes | None = None,
         size: int | None = None,
-        hash_algorithm: int | None = None,
+        object_format: ObjectFormat | None = None,
     ) -> None:
         """Initialize a version 2 pack index.
 
@@ -1105,8 +1105,8 @@ class PackIndex2(FilePackIndex):
         self._fan_out_table = self._read_fan_out_table(8)
 
         # Use provided hash algorithm if available, otherwise default to SHA1
-        if hash_algorithm:
-            self.hash_size = hash_algorithm.oid_length
+        if object_format:
+            self.hash_size = object_format.oid_length
         else:
             self.hash_size = 20  # Default to SHA1
 
@@ -1177,13 +1177,13 @@ class PackIndex3(FilePackIndex):
             raise AssertionError(f"Version was {self.version}")
 
         # Read hash algorithm identifier (1 = SHA-1, 2 = SHA-256)
-        (self.hash_algorithm,) = unpack_from(b">L", self._contents, 8)
-        if self.hash_algorithm == 1:
+        (self.object_format,) = unpack_from(b">L", self._contents, 8)
+        if self.object_format == 1:
             self.hash_size = 20  # SHA-1
-        elif self.hash_algorithm == 2:
+        elif self.object_format == 2:
             self.hash_size = 32  # SHA-256
         else:
-            raise AssertionError(f"Unknown hash algorithm {self.hash_algorithm}")
+            raise AssertionError(f"Unknown hash algorithm {self.object_format}")
 
         # Read length of shortened object names
         (self.shortened_oid_len,) = unpack_from(b">L", self._contents, 12)
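As a reading aid, a small sketch of the fixed v3 index header fields that PackIndex3 parses above and write_pack_index_v3 emits further down; the helper name is hypothetical:

    import struct

    def read_v3_index_header(data: bytes) -> tuple[int, int, int]:
        """Parse the fixed part of a version 3 pack index header."""
        assert data[:4] == b"\377tOc"                       # magic
        (version,) = struct.unpack_from(">L", data, 4)      # always 3
        (hash_format,) = struct.unpack_from(">L", data, 8)  # 1 = SHA-1, 2 = SHA-256
        (shortened_oid_len,) = struct.unpack_from(">L", data, 12)
        return version, hash_format, shortened_oid_len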
@@ -1857,7 +1857,7 @@ class PackData:
         filename: str,
         progress: Callable[..., None] | None = None,
         resolve_ext_ref: ResolveExtRefFn | None = None,
-        hash_algorithm: int = 1,
+        hash_format: int | None = None,
     ) -> bytes:
         """Create a version 3 index file for this data file.
 
@@ -1865,7 +1865,7 @@ class PackData:
           filename: Index filename.
           progress: Progress report function
           resolve_ext_ref: Function to resolve external references
-          hash_algorithm: Hash algorithm identifier (1 = SHA-1, 2 = SHA-256)
+          hash_format: Hash algorithm identifier (1 = SHA-1, 2 = SHA-256)
         Returns: Checksum of index file
         """
         entries = self.sorted_entries(
@@ -1873,7 +1873,7 @@ class PackData:
         )
         with GitFile(filename, "wb") as f:
             return write_pack_index_v3(
-                f, entries, self.calculate_checksum(), hash_algorithm
+                f, entries, self.calculate_checksum(), hash_format
             )
 
     def create_index(
@@ -1882,7 +1882,7 @@ class PackData:
         progress: Callable[..., None] | None = None,
         version: int = 2,
         resolve_ext_ref: ResolveExtRefFn | None = None,
-        hash_algorithm: int = 1,
+        hash_format: int | None = None,
     ) -> bytes:
         """Create an  index file for this data file.
 
@@ -1891,7 +1891,7 @@ class PackData:
           progress: Progress report function
           version: Index version (1, 2, or 3)
           resolve_ext_ref: Function to resolve external references
-          hash_algorithm: Hash algorithm identifier for v3 (1 = SHA-1, 2 = SHA-256)
+          hash_format: Hash algorithm identifier for v3 (1 = SHA-1, 2 = SHA-256)
         Returns: Checksum of index file
         """
         if version == 1:
@@ -1907,7 +1907,7 @@ class PackData:
                 filename,
                 progress,
                 resolve_ext_ref=resolve_ext_ref,
-                hash_algorithm=hash_algorithm,
+                hash_format=hash_format,
             )
         else:
             raise ValueError(f"unknown index format {version}")
@@ -3428,7 +3428,7 @@ def write_pack_index_v3(
     f: IO[bytes],
     entries: Iterable[tuple[bytes, int, int | None]],
     pack_checksum: bytes,
-    hash_algorithm: int = 1,
+    hash_format: int = 1,
 ) -> bytes:
     """Write a new pack index file in v3 format.
 
@@ -3437,18 +3437,18 @@ def write_pack_index_v3(
       entries: List of tuples with object name (sha), offset_in_pack, and
         crc32_checksum.
       pack_checksum: Checksum of the pack file.
-      hash_algorithm: Hash algorithm identifier (1 = SHA-1, 2 = SHA-256)
+      hash_format: Hash algorithm identifier (1 = SHA-1, 2 = SHA-256)
     Returns: The SHA of the index file written
     """
-    if hash_algorithm == 1:
+    if hash_format == 1:
         hash_size = 20  # SHA-1
         writer_cls = SHA1Writer
-    elif hash_algorithm == 2:
+    elif hash_format == 2:
         hash_size = 32  # SHA-256
         # TODO: Add SHA256Writer when SHA-256 support is implemented
         raise NotImplementedError("SHA-256 support not yet implemented")
     else:
-        raise ValueError(f"Unknown hash algorithm {hash_algorithm}")
+        raise ValueError(f"Unknown hash algorithm {hash_format}")
 
     # Convert entries to list to allow multiple iterations
     entries_list = list(entries)
@@ -3460,7 +3460,7 @@ def write_pack_index_v3(
     f = writer_cls(f)
     f.write(b"\377tOc")  # Magic!
     f.write(struct.pack(">L", 3))  # Version 3
-    f.write(struct.pack(">L", hash_algorithm))  # Hash algorithm
+    f.write(struct.pack(">L", hash_format))  # Hash algorithm
     f.write(struct.pack(">L", shortened_oid_len))  # Shortened OID length
 
     fan_out_table: dict[int, int] = defaultdict(lambda: 0)
@@ -3522,6 +3522,9 @@ def write_pack_index(
 
     Returns:
       SHA of the written index file
+
+    Raises:
+      ValueError: If an unsupported version is specified
     """
     if version is None:
         version = DEFAULT_PACK_INDEX_VERSION
@@ -3557,7 +3560,7 @@ class Pack:
         depth: int | None = None,
         threads: int | None = None,
         big_file_threshold: int | None = None,
-        hash_algorithm: int | None = None,
+        object_format: ObjectFormat | None = None,
     ) -> None:
         """Initialize a Pack object.
 
@@ -3570,7 +3573,7 @@ class Pack:
           depth: Maximum depth for delta chains
           threads: Number of threads to use for operations
           big_file_threshold: Size threshold for big file handling
-          hash_algorithm: Hash algorithm identifier (1 = SHA-1, 2 = SHA-256)
+          object_format: Hash algorithm to use (defaults to SHA1)
         """
         self._basename = basename
         self._data = None
@@ -3594,27 +3597,32 @@ class Pack:
             threads=threads,
             big_file_threshold=big_file_threshold,
         )
-        self._idx_load = lambda: load_pack_index(self._idx_path)
+        self._idx_load = lambda: load_pack_index(
+            self._idx_path, object_format=object_format
+        )
         self.resolve_ext_ref = resolve_ext_ref
-        self.hash_algorithm = (
-            hash_algorithm if hash_algorithm is not None else DEFAULT_HASH_ALGORITHM
+        # Always set object_format, defaulting to SHA1
+        from .object_format import get_object_format
+
+        self.object_format = (
+            object_format if object_format else get_object_format("sha1")
         )
 
     @classmethod
     def from_lazy_objects(
         cls, data_fn: Callable[[], PackData], idx_fn: Callable[[], PackIndex],
-        hash_algorithm: int | None = None
+        object_format: ObjectFormat | None = None
     ) -> "Pack":
         """Create a new pack object from callables to load pack data and index objects."""
-        ret = cls("", hash_algorithm=hash_algorithm)
+        ret = cls("", object_format=object_format)
         ret._data_load = data_fn
         ret._idx_load = idx_fn
         return ret
 
     @classmethod
-    def from_objects(cls, data: PackData, idx: PackIndex, hash_algorithm: int | None = None) -> "Pack":
+    def from_objects(cls, data: PackData, idx: PackIndex, object_format: ObjectFormat | None = None) -> "Pack":
         """Create a new pack object from pack data and index objects."""
-        ret = cls("", hash_algorithm=hash_algorithm)
+        ret = cls("", object_format=object_format)
         ret._data = data
         ret._data_load = None
         ret._idx = idx
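A hedged usage sketch of the loader changes above; the pack basename is hypothetical and the files would need to exist on disk:

    from dulwich.object_format import SHA256
    from dulwich.pack import Pack, load_pack_index

    pack = Pack("objects/pack/pack-1234abcd", object_format=SHA256)
    idx = load_pack_index("objects/pack/pack-1234abcd.idx", object_format=SHA256)
    print(idx.hash_size)  # 32 for a SHA-256 index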
@@ -3981,7 +3989,7 @@ class Pack:
         return base_type, chunks
 
     def entries(
-        self, progress: Callable[[int, int], None] | None = None
+        self, progress: ProgressFn | None = None
     ) -> Iterator[PackIndexEntry]:
         """Yield entries summarizing the contents of this pack.
 

+ 49 - 45
dulwich/repo.py

@@ -499,7 +499,10 @@ class BaseRepo:
     """
 
     def __init__(
-        self, object_store: "PackCapableObjectStore", refs: RefsContainer
+        self,
+        object_store: "PackCapableObjectStore",
+        refs: RefsContainer,
+        object_format=None,
     ) -> None:
         """Open a repository.
 
@@ -509,13 +512,14 @@ class BaseRepo:
         Args:
           object_store: Object store to use
           refs: Refs container to use
+          object_format: Hash algorithm to use (if None, will be determined from config)
         """
         self.object_store = object_store
         self.refs = refs
 
         self._graftpoints: dict[ObjectID, list[ObjectID]] = {}
         self.hooks: dict[str, Hook] = {}
-        self._hash_algorithm = None  # Cached hash algorithm
+        self.object_format = object_format  # Hash algorithm (SHA1 or SHA256)
 
     def _determine_file_mode(self) -> bool:
         """Probe the file-system to determine whether permissions can be trusted.
@@ -570,6 +574,11 @@ class BaseRepo:
         if object_format == "sha256":
             cf.set("extensions", "objectformat", "sha256")
 
+        # Set hash algorithm based on object format
+        from .object_format import get_object_format
+
+        self.object_format = get_object_format(object_format)
+
         if self._determine_file_mode():
             cf.set("core", "filemode", True)
         else:
@@ -946,42 +955,6 @@ class BaseRepo:
         """
         raise NotImplementedError(self.get_config)
 
-    def get_hash_algorithm(self):
-        """Get the hash algorithm used by this repository.
-
-        Returns: HashAlgorithm instance (SHA1 or SHA256)
-        """
-        if self._hash_algorithm is None:
-            from .hash import get_hash_algorithm
-
-            # Check if repository uses SHA256
-            try:
-                config = self.get_config()
-                try:
-                    version = int(config.get(("core",), "repositoryformatversion"))
-                except KeyError:
-                    version = 0  # Default version is 0
-
-                if version == 1:
-                    # Check for SHA256 extension
-                    try:
-                        object_format = config.get(("extensions",), "objectformat")
-                        if object_format == b"sha256":
-                            self._hash_algorithm = get_hash_algorithm("sha256")
-                        else:
-                            self._hash_algorithm = get_hash_algorithm("sha1")
-                    except KeyError:
-                        # No objectformat extension, default to SHA1
-                        self._hash_algorithm = get_hash_algorithm("sha1")
-                else:
-                    # Version 0 always uses SHA1
-                    self._hash_algorithm = get_hash_algorithm("sha1")
-            except (KeyError, ValueError):
-                # If we can't read config, default to SHA1
-                self._hash_algorithm = get_hash_algorithm("sha1")
-
-        return self._hash_algorithm
-
     def get_worktree_config(self) -> "ConfigFile":
         """Retrieve the worktree config object."""
         raise NotImplementedError(self.get_worktree_config)
@@ -1190,12 +1163,17 @@ class BaseRepo:
         """Check if a specific Git object or ref is present.
 
         Args:
-          name: Git object SHA1 or ref name
+          name: Git object SHA1/SHA256 or ref name
         """
         if len(name) == 20:
             return RawObjectID(name) in self.object_store or Ref(name) in self.refs
         elif len(name) == 40 and valid_hexsha(name):
             return ObjectID(name) in self.object_store or Ref(name) in self.refs
+        # Check if it's a binary or hex SHA
+        if len(name) == self.object_format.oid_length or (
+            len(name) == self.object_format.hex_length and valid_hexsha(name)
+        ):
+            return name in self.object_store or name in self.refs
         else:
             return Ref(name) in self.refs
 
@@ -1521,6 +1499,21 @@ class Repo(BaseRepo):
             self.worktrees = WorkTreeContainer(self)
         BaseRepo.__init__(self, object_store, self.refs)
 
+        # Determine hash algorithm from config if not already set
+        if self.object_format is None:
+            from .object_format import get_object_format
+
+            if format_version == 1:
+                try:
+                    object_format = config.get((b"extensions",), b"objectformat")
+                    self.object_format = get_object_format(
+                        object_format.decode("ascii")
+                    )
+                except KeyError:
+                    self.object_format = get_object_format("sha1")
+            else:
+                self.object_format = get_object_format("sha1")
+
         self._graftpoints = {}
         graft_file = self.get_named_file(
             os.path.join("info", "grafts"), basedir=self.commondir()
@@ -2120,6 +2113,7 @@ class Repo(BaseRepo):
         controldir: str | bytes | os.PathLike[str],
         bare: bool,
-        object_store: PackBasedObjectStore | None = None,
+        object_store: "PackBasedObjectStore | None" = None,
         config: "StackedConfig | None" = None,
         default_branch: bytes | None = None,
         symlinks: bool | None = None,
@@ -2150,15 +2144,20 @@ class Repo(BaseRepo):
-        if object_store is None:
-            # Get hash algorithm for object store
-            from .hash import get_hash_algorithm

-            hash_alg = get_hash_algorithm(
-                "sha256" if object_format == "sha256" else "sha1"
-            )
+        # Determine hash algorithm
+        from .object_format import get_object_format
+
+        hash_alg = get_object_format(object_format)
+
+        if object_store is None:
             object_store = DiskObjectStore.init(
                 os.path.join(controldir, OBJECTDIR),
                 file_mode=file_mode,
                 dir_mode=dir_mode,
-                hash_algorithm=hash_alg,
+                object_format=hash_alg,
             )
         ret = cls(path, bare=bare, object_store=object_store)
         if default_branch is None:
@@ -2537,6 +2536,7 @@ class MemoryRepo(BaseRepo):
     def __init__(self) -> None:
         """Create a new repository in memory."""
         from .config import ConfigFile
+        from .object_format import get_object_format
 
         self._reflog: list[Any] = []
         refs_container = DictRefsContainer({}, logger=self._append_reflog)
@@ -2546,6 +2546,8 @@ class MemoryRepo(BaseRepo):
         self._config = ConfigFile()
         self._description: bytes | None = None
         self.filter_context = None
+        # MemoryRepo defaults to SHA1
+        self.object_format = get_object_format("sha1")
 
     def _append_reflog(
         self,
@@ -2740,8 +2742,10 @@ class MemoryRepo(BaseRepo):
             raise ValueError("tree must be specified for MemoryRepo")
 
         c = Commit()
-        if len(tree) != 40:
-            raise ValueError("tree must be a 40-byte hex sha string")
+        if len(tree) != self.object_format.hex_length:
+            raise ValueError(
+                f"tree must be a {self.object_format.hex_length}-character hex sha string"
+            )
         c.tree = tree
 
         config = self.get_config_stack()
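A minimal end-to-end sketch of the repository-level behaviour, mirroring the tests below; not part of the commit:

    from dulwich.object_format import SHA1, SHA256
    from dulwich.repo import MemoryRepo

    repo = MemoryRepo.init_bare([], {}, object_format="sha256")
    assert repo.object_format == SHA256

    plain = MemoryRepo()          # plain in-memory repo defaults to SHA-1
    assert plain.object_format == SHA1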

+ 2 - 3
tests/compat/test_sha256.py

@@ -24,7 +24,7 @@
 import os
 import tempfile
 
-from dulwich.hash import SHA256
+from dulwich.object_format import SHA256
 from dulwich.objects import Blob, Commit, Tree
 from dulwich.repo import Repo
 
@@ -90,8 +90,7 @@ class GitSHA256CompatibilityTests(CompatTestCase):
         repo = Repo(repo_path)
 
         # Verify dulwich detects SHA256
-        hash_alg = repo.get_hash_algorithm()
-        self.assertEqual(hash_alg, SHA256)
+        self.assertEqual(repo.object_format, SHA256)
 
         # Verify dulwich can read objects
         # Try both main and master branches (git default changed over time)

+ 1 - 1
tests/compat/test_sha256_packs.py

@@ -24,7 +24,7 @@
 import os
 import tempfile
 
-from dulwich.hash import SHA256
+from dulwich.object_format import SHA256
 from dulwich.objects import Blob, Commit, Tree
 from dulwich.pack import load_pack_index_file
 from dulwich.repo import Repo

+ 8 - 11
tests/test_sha256.py

@@ -26,7 +26,7 @@ import shutil
 import tempfile
 import unittest
 
-from dulwich.hash import SHA1, SHA256, get_hash_algorithm
+from dulwich.object_format import SHA1, SHA256, get_object_format
 from dulwich.objects import Blob, Tree, valid_hexsha, zero_sha_for
 from dulwich.repo import MemoryRepo, Repo
 
@@ -54,12 +54,12 @@ class HashAlgorithmTests(unittest.TestCase):
 
     def test_get_hash_algorithm(self):
         """Test getting hash algorithms by name."""
-        self.assertEqual(get_hash_algorithm("sha1"), SHA1)
-        self.assertEqual(get_hash_algorithm("sha256"), SHA256)
-        self.assertEqual(get_hash_algorithm(None), SHA1)  # Default
+        self.assertEqual(get_object_format("sha1"), SHA1)
+        self.assertEqual(get_object_format("sha256"), SHA256)
+        self.assertEqual(get_object_format(None), SHA1)  # Default
 
         with self.assertRaises(ValueError):
-            get_hash_algorithm("invalid")
+            get_object_format("invalid")
 
 
 class ObjectHashingTests(unittest.TestCase):
@@ -167,8 +167,7 @@ class RepositorySHA256Tests(unittest.TestCase):
         self.assertEqual(config.get(("extensions",), "objectformat"), b"sha256")
 
         # Check hash algorithm detection
-        hash_alg = repo.get_hash_algorithm()
-        self.assertEqual(hash_alg, SHA256)
+        self.assertEqual(repo.object_format, SHA256)
 
         repo.close()
 
@@ -186,8 +185,7 @@ class RepositorySHA256Tests(unittest.TestCase):
             config.get(("extensions",), "objectformat")
 
         # Check hash algorithm detection
-        hash_alg = repo.get_hash_algorithm()
-        self.assertEqual(hash_alg, SHA1)
+        self.assertEqual(repo.object_format, SHA1)
 
         repo.close()
 
@@ -205,8 +203,7 @@ class RepositorySHA256Tests(unittest.TestCase):
         repo = MemoryRepo.init_bare([], {}, object_format="sha256")
 
         # Check hash algorithm
-        hash_alg = repo.get_hash_algorithm()
-        self.assertEqual(hash_alg, SHA256)
+        self.assertEqual(repo.object_format, SHA256)
 
 
 if __name__ == "__main__":

+ 1 - 1
tests/test_sha256_pack.py

@@ -26,7 +26,7 @@ import tempfile
 import unittest
 from io import BytesIO
 
-from dulwich.hash import SHA256
+from dulwich.object_format import SHA256
 from dulwich.pack import (
     load_pack_index_file,
     write_pack_index_v2,