
Make Ref and ObjectID newtypes for improved typing

Jelmer Vernooij 1 month ago
parent
commit
b5674746f5

+ 4 - 0
NEWS

@@ -1,5 +1,9 @@
 0.25.0	UNRELEASED
 
+ * The ``ObjectID`` and ``Ref`` types are now newtypes, making it harder to
+   accidentally pass the wrong type, as flagged by mypy. Most of this is in
+   lower-level code. (Jelmer Vernooij)
+
  * Implement support for ``core.sharedRepository`` configuration option.
    Repository files and directories now respect shared repository permissions
    for group-writable or world-writable repositories. Affects loose objects,
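The entry above is the heart of the commit. As a rough sketch of what a `NewType`-based `ObjectID` and `Ref` look like, and the kind of slip mypy can now flag (assumed shape for illustration, not the verbatim dulwich definitions):

```python
from typing import NewType

# Assumed shape of the new types: zero-cost wrappers over bytes.
ObjectID = NewType("ObjectID", bytes)  # 40-byte hex SHA
Ref = NewType("Ref", bytes)            # ref name, e.g. b"refs/heads/main"

def resolve(refs: dict[Ref, ObjectID], name: Ref) -> ObjectID:
    """Look up a ref by its (typed) name."""
    return refs[name]

refs = {Ref(b"refs/heads/main"): ObjectID(b"b5674746f5" + b"0" * 30)}
resolve(refs, Ref(b"refs/heads/main"))  # OK
# resolve(refs, b"refs/heads/main")     # mypy: expected "Ref", got "bytes"
```

At runtime `Ref(...)` and `ObjectID(...)` are identity calls, which is why the diff below can wrap values freely without changing behaviour.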

+ 2 - 2
dulwich/annotate.py

@@ -39,7 +39,7 @@ from dulwich.walk import (
 if TYPE_CHECKING:
     from dulwich.diff_tree import TreeChange
     from dulwich.object_store import BaseObjectStore
-    from dulwich.objects import Commit, TreeEntry
+    from dulwich.objects import Commit, ObjectID, TreeEntry
 
 # Walk over ancestry graph breadth-first
 # When checking each revision, find lines that according to difflib.Differ()
@@ -74,7 +74,7 @@ def update_lines(
 
 def annotate_lines(
     store: "BaseObjectStore",
-    commit_id: bytes,
+    commit_id: "ObjectID",
     path: bytes,
     order: str = ORDER_DATE,
     lines: Sequence[tuple[tuple["Commit", "TreeEntry"], bytes]] | None = None,

+ 28 - 26
dulwich/bisect.py

@@ -24,7 +24,8 @@ import os
 from collections.abc import Sequence, Set
 
 from dulwich.object_store import peel_sha
-from dulwich.objects import Commit
+from dulwich.objects import Commit, ObjectID
+from dulwich.refs import HEADREF, Ref
 from dulwich.repo import Repo
 
 
@@ -47,8 +48,8 @@ class BisectState:
 
     def start(
         self,
-        bad: bytes | None = None,
-        good: Sequence[bytes] | None = None,
+        bad: ObjectID | None = None,
+        good: Sequence[ObjectID] | None = None,
         paths: Sequence[bytes] | None = None,
         no_checkout: bool = False,
         term_bad: str = "bad",
@@ -73,11 +74,12 @@ class BisectState:
 
         # Store current branch/commit
         try:
-            ref_chain, sha = self.repo.refs.follow(b"HEAD")
+            ref_chain, sha = self.repo.refs.follow(HEADREF)
             if sha is None:
                 # No HEAD exists
                 raise ValueError("Cannot start bisect: repository has no HEAD")
             # Use the first non-HEAD ref in the chain, or the SHA itself
+            current_branch: Ref | ObjectID
             if len(ref_chain) > 1:
                 current_branch = ref_chain[1]  # The actual branch ref
             else:
@@ -124,7 +126,7 @@ class BisectState:
             for g in good:
                 self.mark_good(g)
 
-    def mark_bad(self, rev: bytes | None = None) -> bytes | None:
+    def mark_bad(self, rev: ObjectID | None = None) -> ObjectID | None:
         """Mark a commit as bad.
 
         Args:
@@ -154,7 +156,7 @@ class BisectState:
 
         return self._find_next_commit()
 
-    def mark_good(self, rev: bytes | None = None) -> bytes | None:
+    def mark_good(self, rev: ObjectID | None = None) -> ObjectID | None:
         """Mark a commit as good.
 
         Args:
@@ -186,7 +188,7 @@ class BisectState:
 
         return self._find_next_commit()
 
-    def skip(self, revs: Sequence[bytes] | None = None) -> bytes | None:
+    def skip(self, revs: Sequence[ObjectID] | None = None) -> ObjectID | None:
         """Skip one or more commits.
 
         Args:
@@ -213,7 +215,7 @@ class BisectState:
 
         return self._find_next_commit()
 
-    def reset(self, commit: bytes | None = None) -> None:
+    def reset(self, commit: ObjectID | None = None) -> None:
         """Reset bisect state and return to original branch/commit.
 
         Args:
@@ -250,13 +252,13 @@ class BisectState:
         if commit is None:
             if original.startswith(b"refs/"):
                 # It's a branch reference - need to create a symbolic ref
-                self.repo.refs.set_symbolic_ref(b"HEAD", original)
+                self.repo.refs.set_symbolic_ref(HEADREF, Ref(original))
             else:
                 # It's a commit SHA
-                self.repo.refs[b"HEAD"] = original
+                self.repo.refs[HEADREF] = ObjectID(original)
         else:
             commit = peel_sha(self.repo.object_store, commit)[1].id
-            self.repo.refs[b"HEAD"] = commit
+            self.repo.refs[HEADREF] = commit
 
     def get_log(self) -> str:
         """Get the bisect log."""
@@ -289,16 +291,16 @@ class BisectState:
             if cmd == "start":
                 self.start()
             elif cmd == "bad":
-                rev = args[0].encode("ascii") if args else None
+                rev = ObjectID(args[0].encode("ascii")) if args else None
                 self.mark_bad(rev)
             elif cmd == "good":
-                rev = args[0].encode("ascii") if args else None
+                rev = ObjectID(args[0].encode("ascii")) if args else None
                 self.mark_good(rev)
             elif cmd == "skip":
-                revs = [arg.encode("ascii") for arg in args] if args else None
+                revs = [ObjectID(arg.encode("ascii")) for arg in args] if args else None
                 self.skip(revs)
 
-    def _find_next_commit(self) -> bytes | None:
+    def _find_next_commit(self) -> ObjectID | None:
         """Find the next commit to test using binary search.
 
         Returns:
@@ -311,15 +313,15 @@ class BisectState:
             return None
 
         with open(bad_ref_path, "rb") as f:
-            bad_sha = f.read().strip()
+            bad_sha = ObjectID(f.read().strip())
 
         # Get all good commits
-        good_shas = []
+        good_shas: list[ObjectID] = []
         bisect_refs_dir = os.path.join(self.repo.controldir(), "refs", "bisect")
         for filename in os.listdir(bisect_refs_dir):
             if filename.startswith("good-"):
                 with open(os.path.join(bisect_refs_dir, filename), "rb") as f:
-                    good_shas.append(f.read().strip())
+                    good_shas.append(ObjectID(f.read().strip()))
 
         if not good_shas:
             self._append_to_log(
@@ -328,11 +330,11 @@ class BisectState:
             return None
 
         # Get skip commits
-        skip_shas = set()
+        skip_shas: set[ObjectID] = set()
         for filename in os.listdir(bisect_refs_dir):
             if filename.startswith("skip-"):
                 with open(os.path.join(bisect_refs_dir, filename), "rb") as f:
-                    skip_shas.add(f.read().strip())
+                    skip_shas.add(ObjectID(f.read().strip()))
 
         # Find commits between good and bad
         candidates = self._find_bisect_candidates(bad_sha, good_shas, skip_shas)
@@ -367,8 +369,8 @@ class BisectState:
         return next_commit
 
     def _find_bisect_candidates(
-        self, bad_sha: bytes, good_shas: Sequence[bytes], skip_shas: Set[bytes]
-    ) -> list[bytes]:
+        self, bad_sha: ObjectID, good_shas: Sequence[ObjectID], skip_shas: Set[ObjectID]
+    ) -> list[ObjectID]:
         """Find all commits between good and bad commits.
 
         Args:
@@ -382,9 +384,9 @@ class BisectState:
         # Use git's graph walking to find commits
         # This is a simplified version - a full implementation would need
         # to handle merge commits properly
-        candidates = []
-        visited = set(good_shas)
-        queue = [bad_sha]
+        candidates: list[ObjectID] = []
+        visited: set[ObjectID] = set(good_shas)
+        queue: list[ObjectID] = [bad_sha]
 
         while queue:
             sha = queue.pop(0)
@@ -410,7 +412,7 @@ class BisectState:
 
         return candidates
 
-    def _get_commit_subject(self, sha: bytes) -> str:
+    def _get_commit_subject(self, sha: ObjectID) -> str:
         """Get the subject line of a commit message."""
         obj = self.repo.object_store[sha]
         if isinstance(obj, Commit):
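Taken together, the retyped methods give `BisectState` a fully typed public surface. A hedged usage sketch (the constructor argument is assumed from the `self.repo` attribute used above; the SHAs are placeholders):

```python
from dulwich.bisect import BisectState
from dulwich.objects import ObjectID
from dulwich.repo import Repo

repo = Repo(".")
state = BisectState(repo)  # constructor argument assumed from self.repo above

# ObjectID() only tags the bytes for the type checker; at runtime these are
# still plain 40-byte hex SHAs.
state.start(
    bad=ObjectID(b"1" * 40),
    good=[ObjectID(b"2" * 40)],
)
next_commit = state.mark_good()  # -> ObjectID | None, the next commit to test
```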

+ 39 - 29
dulwich/bitmap.py

@@ -36,11 +36,21 @@ from io import BytesIO
 from typing import IO, TYPE_CHECKING
 
 from .file import GitFile
-from .objects import Blob, Commit, Tag, Tree
+from .objects import (
+    Blob,
+    Commit,
+    ObjectID,
+    RawObjectID,
+    Tag,
+    Tree,
+    hex_to_sha,
+    sha_to_hex,
+)
 
 if TYPE_CHECKING:
     from .object_store import BaseObjectStore
     from .pack import Pack, PackIndex
+    from .refs import Ref
 
 # Bitmap file signature
 BITMAP_SIGNATURE = b"BITM"
@@ -781,10 +791,10 @@ def _compute_name_hash(name: bytes) -> int:
 
 
 def select_bitmap_commits(
-    refs: dict[bytes, bytes],
+    refs: dict["Ref", ObjectID],
     object_store: "BaseObjectStore",
     commit_interval: int = DEFAULT_COMMIT_INTERVAL,
-) -> list[bytes]:
+) -> list[ObjectID]:
     """Select commits for bitmap generation.
 
     Uses Git's strategy:
@@ -849,8 +859,8 @@ def select_bitmap_commits(
 
 
 def build_reachability_bitmap(
-    commit_sha: bytes,
-    sha_to_pos: dict[bytes, int],
+    commit_sha: ObjectID,
+    sha_to_pos: dict[RawObjectID, int],
     object_store: "BaseObjectStore",
 ) -> EWAHBitmap:
     """Build a reachability bitmap for a commit.
@@ -879,8 +889,10 @@ def build_reachability_bitmap(
         seen.add(sha)
 
         # Add this object to the bitmap if it's in the pack
-        if sha in sha_to_pos:
-            bitmap.add(sha_to_pos[sha])
+        # Convert hex SHA to binary for pack index lookup
+        raw_sha = hex_to_sha(sha)
+        if raw_sha in sha_to_pos:
+            bitmap.add(sha_to_pos[raw_sha])
 
         # Get the object and traverse its references
         try:
@@ -902,9 +914,9 @@ def build_reachability_bitmap(
 
 
 def apply_xor_compression(
-    bitmaps: list[tuple[bytes, EWAHBitmap]],
+    bitmaps: list[tuple[ObjectID, EWAHBitmap]],
     max_xor_offset: int = MAX_XOR_OFFSET,
-) -> list[tuple[bytes, EWAHBitmap, int]]:
+) -> list[tuple[ObjectID, EWAHBitmap, int]]:
     """Apply XOR compression to bitmaps.
 
     XOR compression stores some bitmaps as XOR differences from previous bitmaps,
@@ -942,7 +954,7 @@ def apply_xor_compression(
 
 
 def build_type_bitmaps(
-    sha_to_pos: dict[bytes, int],
+    sha_to_pos: dict["RawObjectID", int],
     object_store: "BaseObjectStore",
 ) -> tuple[EWAHBitmap, EWAHBitmap, EWAHBitmap, EWAHBitmap]:
     """Build type bitmaps for all objects in a pack.
@@ -956,8 +968,6 @@ def build_type_bitmaps(
     Returns:
         Tuple of (commit_bitmap, tree_bitmap, blob_bitmap, tag_bitmap)
     """
-    from .objects import sha_to_hex
-
     commit_bitmap = EWAHBitmap()
     tree_bitmap = EWAHBitmap()
     blob_bitmap = EWAHBitmap()
@@ -965,7 +975,7 @@ def build_type_bitmaps(
 
     for sha, pos in sha_to_pos.items():
         # Pack index returns binary SHA (20 bytes), but object_store expects hex SHA (40 bytes)
-        hex_sha = sha_to_hex(sha) if len(sha) == 20 else sha
+        hex_sha = sha_to_hex(sha) if len(sha) == 20 else ObjectID(sha)
         try:
             obj = object_store[hex_sha]
         except KeyError:
@@ -987,7 +997,7 @@
 
 
 def build_name_hash_cache(
-    sha_to_pos: dict[bytes, int],
+    sha_to_pos: dict["RawObjectID", int],
     object_store: "BaseObjectStore",
 ) -> list[int]:
     """Build name-hash cache for all objects in a pack.
@@ -1002,15 +1012,13 @@
     Returns:
         List of 32-bit hash values, one per object in the pack
     """
-    from .objects import sha_to_hex
-
     # Pre-allocate list with correct size
     num_objects = len(sha_to_pos)
    name_hashes = [0] * num_objects
 
     for sha, pos in sha_to_pos.items():
         # Pack index returns binary SHA (20 bytes), but object_store expects hex SHA (40 bytes)
-        hex_sha = sha_to_hex(sha) if len(sha) == 20 else sha
+        hex_sha = sha_to_hex(sha) if len(sha) == 20 else ObjectID(sha)
         try:
             obj = object_store[hex_sha]
         except KeyError:
@@ -1038,7 +1046,7 @@
 def generate_bitmap(
     pack_index: "PackIndex",
     object_store: "BaseObjectStore",
-    refs: dict[bytes, bytes],
+    refs: dict["Ref", ObjectID],
     pack_checksum: bytes,
     include_hash_cache: bool = True,
     include_lookup_table: bool = True,
@@ -1068,7 +1076,7 @@ def generate_bitmap(
 
     # Build mapping from SHA to position in pack index ONCE
     # This is used by all subsequent operations and avoids repeated enumeration
-    sha_to_pos = {}
+    sha_to_pos: dict[RawObjectID, int] = {}
     for pos, (sha, _offset, _crc32) in enumerate(pack_index.iterentries()):
         sha_to_pos[sha] = pos
 
@@ -1120,11 +1128,12 @@ def generate_bitmap(
 
     # Add bitmap entries
     for commit_sha, xor_bitmap, xor_offset in compressed_bitmaps:
-        if commit_sha not in sha_to_pos:
+        raw_commit_sha = hex_to_sha(commit_sha)
+        if raw_commit_sha not in sha_to_pos:
             continue
 
         entry = BitmapEntry(
-            object_pos=sha_to_pos[commit_sha],
+            object_pos=sha_to_pos[raw_commit_sha],
             xor_offset=xor_offset,
             flags=0,
             bitmap=xor_bitmap,
@@ -1154,8 +1163,8 @@ def generate_bitmap(
 
 
 def find_commit_bitmaps(
-    commit_shas: set[bytes], packs: Iterable["Pack"]
-) -> dict[bytes, tuple["Pack", "PackBitmap", dict[bytes, int]]]:
+    commit_shas: set["ObjectID"], packs: Iterable["Pack"]
+) -> dict["ObjectID", tuple["Pack", "PackBitmap", dict[RawObjectID, int]]]:
     """Find which packs have bitmaps for the given commits.
 
     Args:
@@ -1178,14 +1187,15 @@ def find_commit_bitmaps(
             continue
 
         # Build SHA to position mapping for this pack
-        sha_to_pos = {}
+        sha_to_pos: dict[RawObjectID, int] = {}
         for pos, (sha, _offset, _crc32) in enumerate(pack.index.iterentries()):
             sha_to_pos[sha] = pos
 
         # Check which commits have bitmaps
         for commit_sha in list(remaining):
             if pack_bitmap.has_commit(commit_sha):
-                if commit_sha in sha_to_pos:
+                raw_commit_sha = hex_to_sha(commit_sha)
+                if raw_commit_sha in sha_to_pos:
                     result[commit_sha] = (pack, pack_bitmap, sha_to_pos)
                     remaining.remove(commit_sha)
 
@@ -1196,7 +1206,7 @@ def bitmap_to_object_shas(
     bitmap: EWAHBitmap,
     pack_index: "PackIndex",
     type_filter: EWAHBitmap | None = None,
-) -> set[bytes]:
+) -> set[ObjectID]:
     """Convert a bitmap to a set of object SHAs.
 
     Args:
@@ -1205,15 +1215,15 @@ def bitmap_to_object_shas(
         type_filter: Optional type bitmap to filter results (e.g., commits only)
 
     Returns:
-        Set of object SHAs
+        Set of object SHAs (hex format)
     """
-    result = set()
+    result: set[ObjectID] = set()
 
     for pos, (sha, _offset, _crc32) in enumerate(pack_index.iterentries()):
         # Check if this position is in the bitmap
         if pos in bitmap:
             # Apply type filter if provided
             if type_filter is None or pos in type_filter:
-                result.add(sha)
+                result.add(sha_to_hex(sha))
 
     return result
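A recurring detail in these hunks: pack indexes yield 20-byte binary SHAs (`RawObjectID`), while refs and object stores use 40-byte hex SHAs (`ObjectID`), so lookups convert with `hex_to_sha`/`sha_to_hex`. A minimal sketch, assuming the helpers carry the new annotations after this change:

```python
from dulwich.objects import ObjectID, hex_to_sha, sha_to_hex

hex_sha = ObjectID(b"2" * 40)  # hex form, as stored in refs and object stores
raw_sha = hex_to_sha(hex_sha)  # binary form, as enumerated by pack indexes

assert len(raw_sha) == 20 and len(hex_sha) == 40
assert sha_to_hex(raw_sha) == hex_sha  # round-trips losslessly
```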

+ 13 - 11
dulwich/bundle.py

@@ -30,7 +30,9 @@ from typing import (
     runtime_checkable,
 )
 
+from .objects import ObjectID
 from .pack import PackData, UnpackedObject, write_pack_data
+from .refs import Ref
 
 
 @runtime_checkable
@@ -57,8 +59,8 @@ class Bundle:
     version: int | None
 
     capabilities: dict[str, str | None]
-    prerequisites: list[tuple[bytes, bytes]]
-    references: dict[bytes, bytes]
+    prerequisites: list[tuple[ObjectID, bytes]]
+    references: dict[Ref, ObjectID]
    pack_data: PackDataLike | None
 
     def __repr__(self) -> str:
@@ -121,7 +123,7 @@ class Bundle:
 def _read_bundle(f: BinaryIO, version: int) -> Bundle:
     capabilities = {}
     prerequisites = []
-    references = {}
+    references: dict[Ref, ObjectID] = {}
     line = f.readline()
     if version >= 3:
         while line.startswith(b"@"):
@@ -136,11 +138,11 @@ def _read_bundle(f: BinaryIO, version: int) -> Bundle:
             line = f.readline()
     while line.startswith(b"-"):
         (obj_id, comment) = line[1:].rstrip(b"\n").split(b" ", 1)
-        prerequisites.append((obj_id, comment))
+        prerequisites.append((ObjectID(obj_id), comment))
         line = f.readline()
     while line != b"\n":
         (obj_id, ref) = line.rstrip(b"\n").split(b" ", 1)
-        references[ref] = obj_id
+        references[Ref(ref)] = ObjectID(obj_id)
         line = f.readline()
     # Extract pack data to separate stream since PackData expects
     # the file to start with PACK header at position 0
@@ -220,7 +222,7 @@ def write_bundle(f: BinaryIO, bundle: Bundle) -> None:
 
 def create_bundle_from_repo(
     repo: "BaseRepo",
-    refs: Sequence[bytes] | None = None,
+    refs: Sequence[Ref] | None = None,
     prerequisites: Sequence[bytes] | None = None,
     version: int | None = None,
     capabilities: dict[str, str | None] | None = None,
@@ -249,8 +251,8 @@ def create_bundle_from_repo(
         capabilities = {}
 
     # Build the references dictionary for the bundle
-    bundle_refs = {}
-    want_objects = set()
+    bundle_refs: dict[Ref, ObjectID] = {}
+    want_objects: set[ObjectID] = set()
 
     for ref in refs:
         if ref in repo.refs:
@@ -268,7 +270,7 @@ def create_bundle_from_repo(
 
     # Convert prerequisites to proper format
     bundle_prerequisites = []
-    have_objects = set()
+    have_objects: set[ObjectID] = set()
     for prereq in prerequisites:
         if not isinstance(prereq, bytes):
             raise TypeError(
@@ -284,8 +286,8 @@ def create_bundle_from_repo(
         except ValueError:
             raise ValueError(f"Invalid prerequisite format: {prereq!r}")
         # Store hex in bundle and for pack generation
-        bundle_prerequisites.append((prereq, b""))
-        have_objects.add(prereq)
+        bundle_prerequisites.append((ObjectID(prereq), b""))
+        have_objects.add(ObjectID(prereq))
 
     # Generate pack data containing all objects needed for the refs
     pack_count, pack_objects = repo.generate_pack_data(
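For orientation, the header layout that `_read_bundle` walks through above looks like this (illustrative SHAs; a real file continues with ordinary pack data after the blank line):

```python
# Layout of a v2 bundle header, matching the parsing loops above:
bundle_header = (
    b"# v2 git bundle\n"                                  # signature line
    b"-1111111111111111111111111111111111111111 base\n"   # "-" prerequisite -> (ObjectID, comment)
    b"2222222222222222222222222222222222222222 refs/heads/main\n"  # references[Ref] = ObjectID
    b"\n"                                                 # blank line ends the header; PACK data follows
)
```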

+ 32 - 25
dulwich/cli.py

@@ -53,6 +53,7 @@ from typing import (
 
 from dulwich import porcelain
 from dulwich._typing import Buffer
+from dulwich.refs import HEADREF, Ref
 
 from .bundle import Bundle, create_bundle_from_repo, read_bundle, write_bundle
 from .client import get_transport_and_path
@@ -65,7 +66,7 @@ from .errors import (
 )
 from .index import Index
 from .log_utils import _configure_logging_from_trace
-from .objects import Commit, sha_to_hex, valid_hexsha
+from .objects import Commit, ObjectID, RawObjectID, sha_to_hex, valid_hexsha
 from .objectspec import parse_commit_range
 from .pack import Pack
 from .patch import DiffAlgorithmNotAvailable
@@ -1237,9 +1238,13 @@ class cmd_fetch_pack(Command):
         else:
 
             def determine_wants(
-                refs: Mapping[bytes, bytes], depth: int | None = None
-            ) -> list[bytes]:
-                return [y.encode("utf-8") for y in args.refs if y not in r.object_store]
+                refs: Mapping[Ref, ObjectID], depth: int | None = None
+            ) -> list[ObjectID]:
+                return [
+                    ObjectID(y.encode("utf-8"))
+                    for y in args.refs
+                    if y not in r.object_store
+                ]
 
         client.fetch(path.encode("utf-8"), r, determine_wants)
 
@@ -1552,7 +1557,7 @@ class cmd_dump_pack(Command):
         basename, _ = os.path.splitext(parsed_args.filename)
         x = Pack(basename)
         logger.info("Object names checksum: %s", x.name().decode("ascii", "replace"))
-        logger.info("Checksum: %r", sha_to_hex(x.get_stored_checksum()))
+        logger.info("Checksum: %r", sha_to_hex(RawObjectID(x.get_stored_checksum())))
         x.check()
         logger.info("Length: %d", len(x))
         for name in x:
@@ -1921,7 +1926,7 @@ def _get_commit_message_with_template(
     # Add branch info if repo is provided
     if repo:
         try:
-            ref_names, _ref_sha = repo.refs.follow(b"HEAD")
+            ref_names, _ref_sha = repo.refs.follow(HEADREF)
             ref_path = ref_names[-1]  # Get the final reference
             if ref_path.startswith(b"refs/heads/"):
                 branch = ref_path[11:]  # Remove 'refs/heads/' prefix
@@ -3348,7 +3353,7 @@ class cmd_pack_objects(Command):
         if not parsed_args.stdout and not parsed_args.basename:
             parser.error("basename required when not using --stdout")
 
-        object_ids = [line.strip().encode() for line in sys.stdin.readlines()]
+        object_ids = [ObjectID(line.strip().encode()) for line in sys.stdin.readlines()]
         deltify = parsed_args.deltify
         reuse_deltas = not parsed_args.no_reuse_deltas
 
@@ -4102,7 +4107,7 @@ class cmd_bisect(SuperCommand):
                     with porcelain.open_repo_closing(".") as r:
                         bad_ref = os.path.join(r.controldir(), "refs", "bisect", "bad")
                         with open(bad_ref, "rb") as f:
-                            bad_sha = f.read().strip()
+                            bad_sha = ObjectID(f.read().strip())
                         commit = r.object_store[bad_sha]
                         assert isinstance(commit, Commit)
                         message = commit.message.decode(
@@ -5342,7 +5347,7 @@ class cmd_filter_branch(Command):
         tree_filter = None
         if parsed_args.tree_filter:
 
-            def tree_filter(tree_sha: bytes, tmpdir: str) -> bytes:
+            def tree_filter(tree_sha: ObjectID, tmpdir: str) -> ObjectID:
                 from dulwich.objects import Blob, Tree
 
                 # Export tree to tmpdir
@@ -5364,7 +5369,7 @@ class cmd_filter_branch(Command):
                     run_filter(parsed_args.tree_filter, cwd=tmpdir)
 
                     # Rebuild tree from modified temp directory
-                    def build_tree_from_dir(dir_path: str) -> bytes:
+                    def build_tree_from_dir(dir_path: str) -> ObjectID:
                         tree = Tree()
                         for name in sorted(os.listdir(dir_path)):
                             if name.startswith("."):
@@ -5393,7 +5398,7 @@ class cmd_filter_branch(Command):
         index_filter = None
         if parsed_args.index_filter:
 
-            def index_filter(tree_sha: bytes, index_path: str) -> bytes | None:
+            def index_filter(tree_sha: ObjectID, index_path: str) -> ObjectID | None:
                 run_filter(
                     parsed_args.index_filter, extra_env={"GIT_INDEX_FILE": index_path}
                 )
@@ -5402,7 +5407,7 @@ class cmd_filter_branch(Command):
         parent_filter = None
         if parsed_args.parent_filter:
 
-            def parent_filter(parents: Sequence[bytes]) -> list[bytes]:
+            def parent_filter(parents: Sequence[ObjectID]) -> list[ObjectID]:
                 parent_str = " ".join(p.hex() for p in parents)
                 result = run_filter(
                     parsed_args.parent_filter, input_data=parent_str.encode()
@@ -5417,13 +5422,15 @@ class cmd_filter_branch(Command):
                 for sha in output.split():
                     sha_bytes = sha.encode()
                     if valid_hexsha(sha_bytes):
-                        new_parents.append(sha_bytes)
+                        new_parents.append(ObjectID(sha_bytes))
                 return new_parents
 
         commit_filter = None
         if parsed_args.commit_filter:
 
-            def commit_filter(commit_obj: Commit, tree_sha: bytes) -> bytes | None:
+            def commit_filter(
+                commit_obj: Commit, tree_sha: ObjectID
+            ) -> ObjectID | None:
                 # The filter receives: tree parent1 parent2...
                 cmd_input = tree_sha.hex()
                 for parent in commit_obj.parents:
@@ -5442,7 +5449,7 @@ class cmd_filter_branch(Command):
                     return None  # Skip commit
 
                 if valid_hexsha(output):
-                    return output.encode()
+                    return ObjectID(output.encode())
                 return None
 
         tag_name_filter = None
@@ -5775,7 +5782,7 @@ class cmd_format_patch(Command):
         parsed_args = parser.parse_args(args)
 
         # Parse committish using the new function
-        committish: bytes | tuple[bytes, bytes] | None = None
+        committish: ObjectID | tuple[ObjectID, ObjectID] | None = None
         if parsed_args.committish:
             with Repo(".") as r:
                 range_result = parse_commit_range(r, parsed_args.committish)
@@ -5783,7 +5790,7 @@ class cmd_format_patch(Command):
                     # Convert Commit objects to their SHAs
                     committish = (range_result[0].id, range_result[1].id)
                 else:
-                    committish = (
+                    committish = ObjectID(
                         parsed_args.committish.encode()
                         if isinstance(parsed_args.committish, str)
                         else parsed_args.committish
@@ -6025,7 +6032,7 @@ class cmd_bundle(Command):
                     msg = msg.decode("utf-8", "replace")
                 logger.error("%s", msg)
 
-        refs_to_include = []
+        refs_to_include: list[Ref] = []
         prerequisites = []
 
         if parsed_args.all:
@@ -6034,7 +6041,7 @@ class cmd_bundle(Command):
             for line in sys.stdin:
                 ref = line.strip().encode("utf-8")
                 if ref:
-                    refs_to_include.append(ref)
+                    refs_to_include.append(Ref(ref))
         elif parsed_args.refs:
             for ref_arg in parsed_args.refs:
                 if ".." in ref_arg:
@@ -6046,19 +6053,19 @@ class cmd_bundle(Command):
                         # Split the range to get the end part
                         end_part = ref_arg.split("..")[1]
                         if end_part:  # Not empty (not "A..")
-                            end_ref = end_part.encode("utf-8")
+                            end_ref = Ref(end_part.encode("utf-8"))
                             if end_ref in repo.refs:
                                 refs_to_include.append(end_ref)
                     else:
-                        sha = repo.refs[ref_arg.encode("utf-8")]
-                        refs_to_include.append(ref_arg.encode("utf-8"))
+                        sha = repo.refs[Ref(ref_arg.encode("utf-8"))]
+                        refs_to_include.append(Ref(ref_arg.encode("utf-8")))
                 else:
                     if ref_arg.startswith("^"):
-                        sha = repo.refs[ref_arg[1:].encode("utf-8")]
+                        sha = repo.refs[Ref(ref_arg[1:].encode("utf-8"))]
                         prerequisites.append(sha)
                     else:
-                        sha = repo.refs[ref_arg.encode("utf-8")]
-                        refs_to_include.append(ref_arg.encode("utf-8"))
+                        sha = repo.refs[Ref(ref_arg.encode("utf-8"))]
+                        refs_to_include.append(Ref(ref_arg.encode("utf-8")))
         else:
             logger.error("No refs specified. Use --all, --stdin, or specify refs")
             return 1

+ 114 - 102
dulwich/client.py

@@ -67,7 +67,9 @@ import dulwich
 if TYPE_CHECKING:
     from typing import Protocol as TypingProtocol
 
+    from .objects import ObjectID
     from .pack import UnpackedObject
+    from .refs import Ref
 
     class HTTPResponse(TypingProtocol):
         """Protocol for HTTP response objects."""
@@ -84,8 +86,8 @@ if TYPE_CHECKING:
 
         def __call__(
             self,
-            have: Set[bytes],
-            want: Set[bytes],
+            have: Set[ObjectID],
+            want: Set[ObjectID],
             *,
             ofs_delta: bool = False,
             progress: Callable[[bytes], None] | None = None,
@@ -98,9 +100,9 @@ if TYPE_CHECKING:
 
         def __call__(
             self,
-            refs: Mapping[bytes, bytes],
+            refs: Mapping[Ref, ObjectID],
             depth: int | None = None,
-        ) -> list[bytes]:
+        ) -> list[ObjectID]:
             """Determine the objects to fetch from the given refs."""
             ...
 
@@ -110,6 +112,7 @@ from .config import Config, apply_instead_of, get_xdg_config_home_path
 from .credentials import match_partial_url, match_urls
 from .errors import GitProtocolError, HangupException, NotGitRepository, SendPackError
 from .object_store import GraphWalker
+from .objects import ObjectID
 from .pack import (
     PACK_SPOOL_FILE_MAX_SIZE,
     PackChunkGenerator,
@@ -146,6 +149,7 @@ from .protocol import (
     GIT_PROTOCOL_VERSIONS,
     KNOWN_RECEIVE_CAPABILITIES,
     KNOWN_UPLOAD_CAPABILITIES,
+    PEELED_TAG_SUFFIX,
     SIDE_BAND_CHANNEL_DATA,
     SIDE_BAND_CHANNEL_FATAL,
     SIDE_BAND_CHANNEL_PROGRESS,
@@ -162,7 +166,7 @@ from .protocol import (
     pkt_seq,
 )
 from .refs import (
-    PEELED_TAG_SUFFIX,
+    HEADREF,
     SYMREF,
     Ref,
     _import_remote_refs,
@@ -181,8 +185,6 @@ from .repo import BaseRepo, Repo
 # behaviour with v1 when no ref-prefix is specified.
 DEFAULT_REF_PREFIX = [b"HEAD", b"refs/"]
 
-ObjectID = bytes
-
 
 logger = logging.getLogger(__name__)
 
@@ -216,8 +218,8 @@ class HTTPUnauthorized(Exception):
         self.url = url
 
 
-def _to_optional_dict(refs: Mapping[bytes, bytes]) -> dict[bytes, bytes | None]:
-    """Convert a dict[bytes, bytes] to dict[bytes, Optional[bytes]].
+def _to_optional_dict(refs: Mapping[Ref, ObjectID]) -> dict[Ref, ObjectID | None]:
+    """Convert a dict[Ref, ObjectID] to dict[Ref, Optional[ObjectID]].
 
     This is needed for compatibility with result types that expect Optional values.
     """
@@ -350,23 +352,26 @@ def read_server_capabilities(pkt_seq: Iterable[bytes]) -> set[bytes]:
 
 def read_pkt_refs_v2(
     pkt_seq: Iterable[bytes],
-) -> tuple[dict[bytes, bytes | None], dict[bytes, bytes], dict[bytes, bytes]]:
+) -> tuple[dict[Ref, ObjectID | None], dict[Ref, Ref], dict[Ref, ObjectID]]:
     """Read references using protocol version 2."""
-    refs: dict[bytes, bytes | None] = {}
-    symrefs = {}
-    peeled = {}
+    refs: dict[Ref, ObjectID | None] = {}
+    symrefs: dict[Ref, Ref] = {}
+    peeled: dict[Ref, ObjectID] = {}
     # Receive refs from server
     for pkt in pkt_seq:
         parts = pkt.rstrip(b"\n").split(b" ")
-        sha: bytes | None = parts[0]
-        if sha == b"unborn":
+        sha_bytes = parts[0]
+        sha: ObjectID | None
+        if sha_bytes == b"unborn":
             sha = None
-        ref = parts[1]
+        else:
+            sha = ObjectID(sha_bytes)
+        ref = Ref(parts[1])
         for part in parts[2:]:
             if part.startswith(b"peeled:"):
-                peeled[ref] = part[7:]
+                peeled[ref] = ObjectID(part[7:])
             elif part.startswith(b"symref-target:"):
-                symrefs[ref] = Ref(part[14:])
+                symrefs[ref] = Ref(part[14:])
             else:
                 logging.warning("unknown part in pkt-ref: %s", part)
         refs[ref] = sha
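To make the parsing above concrete, here is a hedged sketch of the ls-refs-style lines `read_pkt_refs_v2` consumes (pkt-line framing already stripped; SHAs illustrative), covering the `unborn`, `peeled:` and `symref-target:` attributes handled in the loop:

```python
from dulwich.client import read_pkt_refs_v2

pkts = [
    b"unborn HEAD symref-target:refs/heads/main\n",
    b"2222222222222222222222222222222222222222 refs/heads/main\n",
    b"3333333333333333333333333333333333333333 refs/tags/v1.0 "
    b"peeled:2222222222222222222222222222222222222222\n",
]
refs, symrefs, peeled = read_pkt_refs_v2(pkts)
# refs[Ref(b"HEAD")] is None (unborn); symrefs maps HEAD to refs/heads/main;
# peeled maps the tag ref to the ObjectID of the peeled commit.
```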
@@ -376,10 +381,10 @@ def read_pkt_refs_v2(
 
 def read_pkt_refs_v1(
     pkt_seq: Iterable[bytes],
-) -> tuple[dict[bytes, bytes], set[bytes]]:
+) -> tuple[dict[Ref, ObjectID], set[bytes]]:
     """Read references using protocol version 1."""
     server_capabilities = None
-    refs: dict[bytes, bytes] = {}
+    refs: dict[Ref, ObjectID] = {}
     # Receive refs from server
     for pkt in pkt_seq:
         (sha, ref) = pkt.rstrip(b"\n").split(None, 1)
@@ -387,7 +392,7 @@ def read_pkt_refs_v1(
             raise GitProtocolError(ref.decode("utf-8", "replace"))
         if server_capabilities is None:
             (ref, server_capabilities) = extract_capabilities(ref)
-        refs[ref] = sha
+        refs[Ref(ref)] = ObjectID(sha)
 
     if len(refs) == 0:
         return {}, set()
@@ -400,7 +405,7 @@ def read_pkt_refs_v1(
 class _DeprecatedDictProxy:
     """Base class for result objects that provide deprecated dict-like interface."""
 
-    refs: dict[bytes, bytes | None]  # To be overridden by subclasses
+    refs: dict[Ref, ObjectID | None]  # To be overridden by subclasses
 
     _FORWARDED_ATTRS: ClassVar[set[str]] = {
         "clear",
@@ -428,11 +433,11 @@ class _DeprecatedDictProxy:
             stacklevel=3,
         )
 
-    def __contains__(self, name: bytes) -> bool:
+    def __contains__(self, name: Ref) -> bool:
         self._warn_deprecated()
         return name in self.refs
 
-    def __getitem__(self, name: bytes) -> bytes | None:
+    def __getitem__(self, name: Ref) -> ObjectID | None:
         self._warn_deprecated()
         return self.refs[name]
 
@@ -440,7 +445,7 @@ class _DeprecatedDictProxy:
         self._warn_deprecated()
         return len(self.refs)
 
-    def __iter__(self) -> Iterator[bytes]:
+    def __iter__(self) -> Iterator[Ref]:
         self._warn_deprecated()
         return iter(self.refs)
 
@@ -463,16 +468,17 @@ class FetchPackResult(_DeprecatedDictProxy):
       agent: User agent string
     """
 
-    symrefs: dict[bytes, bytes]
+    refs: dict[Ref, ObjectID | None]
+    symrefs: dict[Ref, Ref]
     agent: bytes | None
 
     def __init__(
         self,
-        refs: dict[bytes, bytes | None],
-        symrefs: dict[bytes, bytes],
+        refs: dict[Ref, ObjectID | None],
+        symrefs: dict[Ref, Ref],
         agent: bytes | None,
-        new_shallow: set[bytes] | None = None,
-        new_unshallow: set[bytes] | None = None,
+        new_shallow: set[ObjectID] | None = None,
+        new_unshallow: set[ObjectID] | None = None,
     ) -> None:
         """Initialize FetchPackResult.
 
@@ -515,10 +521,10 @@ class LsRemoteResult(_DeprecatedDictProxy):
       symrefs: Dictionary with remote symrefs
     """
 
-    symrefs: dict[bytes, bytes]
+    symrefs: dict[Ref, Ref]
 
     def __init__(
-        self, refs: dict[bytes, bytes | None], symrefs: dict[bytes, bytes]
+        self, refs: dict[Ref, ObjectID | None], symrefs: dict[Ref, Ref]
     ) -> None:
         """Initialize LsRemoteResult.
 
@@ -565,7 +571,7 @@ class SendPackResult(_DeprecatedDictProxy):
 
     def __init__(
         self,
-        refs: dict[bytes, bytes | None],
+        refs: dict[Ref, ObjectID | None],
         agent: bytes | None = None,
         ref_status: dict[bytes, str | None] | None = None,
     ) -> None:
@@ -594,9 +600,11 @@
         return f"{self.__class__.__name__}({self.refs!r}, {self.agent!r})"
 
 
-def _read_shallow_updates(pkt_seq: Iterable[bytes]) -> tuple[set[bytes], set[bytes]]:
-    new_shallow = set()
-    new_unshallow = set()
+def _read_shallow_updates(
+    pkt_seq: Iterable[bytes],
+) -> tuple[set[ObjectID], set[ObjectID]]:
+    new_shallow: set[ObjectID] = set()
+    new_unshallow: set[ObjectID] = set()
     for pkt in pkt_seq:
         if pkt == b"shallow-info\n":  # Git-protocol v2
             continue
@@ -605,9 +613,9 @@ def _read_shallow_updates(pkt_seq: Iterable[bytes]) -> tuple[set[bytes], set[byt
         except ValueError:
             raise GitProtocolError(f"unknown command {pkt!r}")
         if cmd == COMMAND_SHALLOW:
-            new_shallow.add(sha.strip())
+            new_shallow.add(ObjectID(sha.strip()))
         elif cmd == COMMAND_UNSHALLOW:
-            new_unshallow.add(sha.strip())
+            new_unshallow.add(ObjectID(sha.strip()))
         else:
             raise GitProtocolError(f"unknown command {pkt!r}")
     return (new_shallow, new_unshallow)
@@ -617,11 +625,11 @@ class _v1ReceivePackHeader:
     def __init__(
         self,
         capabilities: Sequence[bytes],
-        old_refs: Mapping[bytes, bytes],
-        new_refs: Mapping[bytes, bytes],
+        old_refs: Mapping[Ref, ObjectID],
+        new_refs: Mapping[Ref, ObjectID],
     ) -> None:
-        self.want: set[bytes] = set()
-        self.have: set[bytes] = set()
+        self.want: set[ObjectID] = set()
+        self.have: set[ObjectID] = set()
         self._it = self._handle_receive_pack_head(capabilities, old_refs, new_refs)
         self.sent_capabilities = False
 
@@ -631,8 +639,8 @@
     def _handle_receive_pack_head(
         self,
         capabilities: Sequence[bytes],
-        old_refs: Mapping[bytes, bytes],
-        new_refs: Mapping[bytes, bytes],
+        old_refs: Mapping[Ref, ObjectID],
+        new_refs: Mapping[Ref, ObjectID],
     ) -> Iterator[bytes | None]:
         """Handle the head of a 'git-receive-pack' request.
 
@@ -713,13 +721,13 @@ def _handle_upload_pack_head(
     proto: Protocol,
     capabilities: Iterable[bytes],
     graph_walker: GraphWalker,
-    wants: list[bytes],
+    wants: list[ObjectID],
     can_read: Callable[[], bool] | None,
     depth: int | None,
     protocol_version: int | None,
     shallow_since: str | None = None,
     shallow_exclude: list[str] | None = None,
-) -> tuple[set[bytes] | None, set[bytes] | None]:
+) -> tuple[set[ObjectID] | None, set[ObjectID] | None]:
     """Handle the head of a 'git-upload-pack' request.
 
     Args:
@@ -734,8 +742,8 @@ def _handle_upload_pack_head(
       shallow_since: Deepen the history to include commits after this date
       shallow_exclude: Deepen the history to exclude commits reachable from these refs
     """
-    new_shallow: set[bytes] | None
-    new_unshallow: set[bytes] | None
+    new_shallow: set[ObjectID] | None
+    new_unshallow: set[ObjectID] | None
     assert isinstance(wants, list) and isinstance(wants[0], bytes)
     wantcmd = COMMAND_WANT + b" " + wants[0]
     if protocol_version is None:
@@ -788,7 +796,7 @@
             assert pkt is not None
             parts = pkt.rstrip(b"\n").split(b" ")
             if parts[0] == b"ACK":
-                graph_walker.ack(parts[1])
+                graph_walker.ack(ObjectID(parts[1]))
                 if parts[2] in (b"continue", b"common"):
                     pass
                 elif parts[2] == b"ready":
@@ -809,7 +817,7 @@ def _handle_upload_pack_head(
             new_shallow = None
             new_shallow = None
             new_unshallow = None
             new_unshallow = None
     else:
     else:
-        new_shallow = new_unshallow = set()
+        new_shallow = new_unshallow = set[ObjectID]()
 
 
     return (new_shallow, new_unshallow)
     return (new_shallow, new_unshallow)
 
 
@@ -841,7 +849,7 @@ def _handle_upload_pack_tail(
             break
             break
         else:
         else:
             if parts[0] == b"ACK":
             if parts[0] == b"ACK":
-                graph_walker.ack(parts[1])
+                graph_walker.ack(ObjectID(parts[1]))
             if parts[0] == b"NAK":
             if parts[0] == b"NAK":
                 graph_walker.nak()
                 graph_walker.nak()
             if len(parts) < 3 or parts[2] not in (
             if len(parts) < 3 or parts[2] not in (
@@ -875,7 +883,7 @@ def _handle_upload_pack_tail(
 
 
 def _extract_symrefs_and_agent(
 def _extract_symrefs_and_agent(
     capabilities: Iterable[bytes],
     capabilities: Iterable[bytes],
-) -> tuple[dict[bytes, bytes], bytes | None]:
+) -> tuple[dict[Ref, Ref], bytes | None]:
     """Extract symrefs and agent from capabilities.
     """Extract symrefs and agent from capabilities.
 
 
     Args:
     Args:
@@ -883,14 +891,14 @@ def _extract_symrefs_and_agent(
     Returns:
     Returns:
      (symrefs, agent) tuple
      (symrefs, agent) tuple
     """
     """
-    symrefs = {}
+    symrefs: dict[Ref, Ref] = {}
     agent = None
     agent = None
     for capability in capabilities:
     for capability in capabilities:
         k, v = parse_capability(capability)
         k, v = parse_capability(capability)
         if k == CAPABILITY_SYMREF:
         if k == CAPABILITY_SYMREF:
             assert v is not None
             assert v is not None
             (src, dst) = v.split(b":", 1)
             (src, dst) = v.split(b":", 1)
-            symrefs[src] = dst
+            symrefs[Ref(src)] = Ref(dst)
         if k == CAPABILITY_AGENT:
         if k == CAPABILITY_AGENT:
             agent = v
             agent = v
     return (symrefs, agent)
     return (symrefs, agent)
@@ -979,7 +987,7 @@ class GitClient:
     def send_pack(
         self,
         path: bytes,
-        update_refs: Callable[[dict[bytes, bytes]], dict[bytes, bytes]],
+        update_refs: Callable[[dict[Ref, ObjectID]], dict[Ref, ObjectID]],
         generate_pack_data: "GeneratePackDataFunc",
         progress: Callable[[bytes], None] | None = None,
     ) -> SendPackResult:
@@ -1014,7 +1022,7 @@ class GitClient:
         branch: str | None = None,
         progress: Callable[[bytes], None] | None = None,
         depth: int | None = None,
-        ref_prefix: Sequence[Ref] | None = None,
+        ref_prefix: Sequence[bytes] | None = None,
         filter_spec: bytes | None = None,
         protocol_version: int | None = None,
     ) -> Repo:
@@ -1067,12 +1075,12 @@ class GitClient:
                     target.refs, origin, result.refs, message=ref_message
                 )
 
-            origin_head = result.symrefs.get(b"HEAD")
-            origin_sha = result.refs.get(b"HEAD")
+            origin_head = result.symrefs.get(HEADREF)
+            origin_sha = result.refs.get(HEADREF)
             if origin is None or (origin_sha and not origin_head):
                 # set detached HEAD
                 if origin_sha is not None:
-                    target.refs[b"HEAD"] = origin_sha
+                    target.refs[HEADREF] = origin_sha
                     head = origin_sha
                 else:
                     head = None
@@ -1111,7 +1119,7 @@ class GitClient:
         determine_wants: "DetermineWantsFunc | None" = None,
         progress: Callable[[bytes], None] | None = None,
         depth: int | None = None,
-        ref_prefix: Sequence[Ref] | None = None,
+        ref_prefix: Sequence[bytes] | None = None,
         filter_spec: bytes | None = None,
         protocol_version: int | None = None,
         shallow_since: str | None = None,
@@ -1195,7 +1203,7 @@ class GitClient:
         *,
         progress: Callable[[bytes], None] | None = None,
         depth: int | None = None,
-        ref_prefix: Sequence[Ref] | None = None,
+        ref_prefix: Sequence[bytes] | None = None,
         filter_spec: bytes | None = None,
         protocol_version: int | None = None,
         shallow_since: str | None = None,
@@ -1233,7 +1241,7 @@ class GitClient:
         self,
         path: bytes,
         protocol_version: int | None = None,
-        ref_prefix: Sequence[Ref] | None = None,
+        ref_prefix: Sequence[bytes] | None = None,
     ) -> LsRemoteResult:
         """Retrieve the current refs from a git smart server.
 
@@ -1248,7 +1256,7 @@ class GitClient:
         raise NotImplementedError(self.get_refs)
 
     @staticmethod
-    def _should_send_pack(new_refs: Mapping[bytes, bytes]) -> bool:
+    def _should_send_pack(new_refs: Mapping[Ref, ObjectID]) -> bool:
         # The packfile MUST NOT be sent if the only command used is delete.
         return any(sha != ZERO_SHA for sha in new_refs.values())
 
@@ -1308,7 +1316,7 @@ class GitClient:
 
     def _negotiate_upload_pack_capabilities(
         self, server_capabilities: set[bytes]
-    ) -> tuple[set[bytes], dict[bytes, bytes], bytes | None]:
+    ) -> tuple[set[bytes], dict[Ref, Ref], bytes | None]:
         (extract_capability_names(server_capabilities) - KNOWN_UPLOAD_CAPABILITIES)
         # TODO(jelmer): warn about unknown capabilities
         fetch_capa = None
@@ -1440,7 +1448,7 @@ class TraditionalGitClient(GitClient):
     def send_pack(
         self,
         path: bytes,
-        update_refs: Callable[[dict[bytes, bytes]], dict[bytes, bytes]],
+        update_refs: Callable[[dict[Ref, ObjectID]], dict[Ref, ObjectID]],
         generate_pack_data: "GeneratePackDataFunc",
         progress: Callable[[bytes], None] | None = None,
     ) -> SendPackResult:
@@ -1541,11 +1549,8 @@ class TraditionalGitClient(GitClient):
             ref_status = self._handle_receive_pack_tail(
                 proto, negotiated_capabilities, progress
             )
-            refs_with_optional_2: dict[bytes, bytes | None] = {
-                k: v for k, v in new_refs.items()
-            }
             return SendPackResult(
-                refs_with_optional_2, agent=agent, ref_status=ref_status
+                _to_optional_dict(new_refs), agent=agent, ref_status=ref_status
             )
 
     def fetch_pack(
@@ -1556,7 +1561,7 @@ class TraditionalGitClient(GitClient):
         pack_data: Callable[[bytes], int],
         progress: Callable[[bytes], None] | None = None,
         depth: int | None = None,
-        ref_prefix: Sequence[Ref] | None = None,
+        ref_prefix: Sequence[bytes] | None = None,
         filter_spec: bytes | None = None,
         protocol_version: int | None = None,
         shallow_since: str | None = None,
@@ -1606,7 +1611,9 @@ class TraditionalGitClient(GitClient):
         self.protocol_version = server_protocol_version
         with proto:
             # refs may have None values in v2 but not in v1
-            refs: dict[bytes, bytes | None]
+            refs: dict[Ref, ObjectID | None]
+            symrefs: dict[Ref, Ref]
+            agent: bytes | None
             if self.protocol_version == 2:
                 try:
                     server_capabilities = read_server_capabilities(proto.read_pkt_seq())
@@ -1710,7 +1717,7 @@ class TraditionalGitClient(GitClient):
         self,
         path: bytes,
         protocol_version: int | None = None,
-        ref_prefix: Sequence[Ref] | None = None,
+        ref_prefix: Sequence[bytes] | None = None,
     ) -> LsRemoteResult:
         """Retrieve the current refs from a git smart server."""
         # stock `git ls-remote` uses upload-pack
@@ -1748,7 +1755,7 @@ class TraditionalGitClient(GitClient):
                     raise _remote_error_from_stderr(stderr) from exc
                 proto.write_pkt_line(None)
                 for refname, refvalue in peeled.items():
-                    refs[refname + PEELED_TAG_SUFFIX] = refvalue
+                    refs[Ref(refname + PEELED_TAG_SUFFIX)] = refvalue
                 return LsRemoteResult(refs, symrefs)
         else:
             with proto:
@@ -2231,7 +2238,7 @@ class LocalGitClient(GitClient):
     def send_pack(
         self,
         path: str | bytes,
-        update_refs: Callable[[dict[bytes, bytes]], dict[bytes, bytes]],
+        update_refs: Callable[[dict[Ref, ObjectID]], dict[Ref, ObjectID]],
         generate_pack_data: "GeneratePackDataFunc",
         progress: Callable[[bytes], None] | None = None,
     ) -> SendPackResult:
@@ -2354,7 +2361,7 @@ class LocalGitClient(GitClient):
         pack_data: Callable[[bytes], int],
         progress: Callable[[bytes], None] | None = None,
         depth: int | None = None,
-        ref_prefix: Sequence[Ref] | None = None,
+        ref_prefix: Sequence[bytes] | None = None,
         filter_spec: bytes | None = None,
         protocol_version: int | None = None,
         shallow_since: str | None = None,
@@ -2415,21 +2422,23 @@ class LocalGitClient(GitClient):
         self,
         path: str | bytes,
         protocol_version: int | None = None,
-        ref_prefix: Sequence[Ref] | None = None,
+        ref_prefix: Sequence[bytes] | None = None,
     ) -> LsRemoteResult:
         """Retrieve the current refs from a local on-disk repository."""
         with self._open_repo(path) as target:
             refs_dict = target.get_refs()
             refs = _to_optional_dict(refs_dict)
             # Extract symrefs from the local repository
-            symrefs: dict[bytes, bytes] = {}
+            from dulwich.refs import Ref
+
+            symrefs: dict[Ref, Ref] = {}
             for ref in refs:
                 try:
                     # Check if this ref is symbolic by reading it directly
                     ref_value = target.refs.read_ref(ref)
                     if ref_value and ref_value.startswith(SYMREF):
                         # Extract the target from the symref
-                        symrefs[ref] = ref_value[len(SYMREF) :]
+                        symrefs[ref] = Ref(ref_value[len(SYMREF) :])
                 except (KeyError, ValueError):
                     # Not a symbolic ref or error reading it
                     pass
@@ -2546,9 +2555,9 @@ class BundleClient(GitClient):
             else:
                 raise AssertionError(f"unsupported bundle format header: {firstline!r}")
 
-            capabilities = {}
-            prerequisites = []
-            references = {}
+            capabilities: dict[str, str | None] = {}
+            prerequisites: list[tuple[ObjectID, bytes]] = []
+            references: dict[Ref, ObjectID] = {}
             line = f.readline()
 
             if version >= 3:
@@ -2565,12 +2574,12 @@ class BundleClient(GitClient):
 
             while line.startswith(b"-"):
                 (obj_id, comment) = line[1:].rstrip(b"\n").split(b" ", 1)
-                prerequisites.append((obj_id, comment))
+                prerequisites.append((ObjectID(obj_id), comment))
                 line = f.readline()
 
             while line != b"\n":
                 (obj_id, ref) = line.rstrip(b"\n").split(b" ", 1)
-                references[ref] = obj_id
+                references[Ref(ref)] = ObjectID(obj_id)
                 line = f.readline()
 
             # Don't read PackData here, we'll do it later
@@ -2619,7 +2628,7 @@ class BundleClient(GitClient):
     def send_pack(
         self,
         path: str | bytes,
-        update_refs: Callable[[dict[bytes, bytes]], dict[bytes, bytes]],
+        update_refs: Callable[[dict[Ref, ObjectID]], dict[Ref, ObjectID]],
         generate_pack_data: "GeneratePackDataFunc",
         progress: Callable[[bytes], None] | None = None,
     ) -> SendPackResult:
@@ -2633,7 +2642,7 @@ class BundleClient(GitClient):
         determine_wants: "DetermineWantsFunc | None" = None,
         progress: Callable[[bytes], None] | None = None,
         depth: int | None = None,
-        ref_prefix: Sequence[Ref] | None = None,
+        ref_prefix: Sequence[bytes] | None = None,
         filter_spec: bytes | None = None,
         protocol_version: int | None = None,
         shallow_since: str | None = None,
@@ -2687,7 +2696,7 @@ class BundleClient(GitClient):
         pack_data: Callable[[bytes], int],
         progress: Callable[[bytes], None] | None = None,
         depth: int | None = None,
-        ref_prefix: Sequence[Ref] | None = None,
+        ref_prefix: Sequence[bytes] | None = None,
         filter_spec: bytes | None = None,
         protocol_version: int | None = None,
         shallow_since: str | None = None,
@@ -2732,7 +2741,7 @@ class BundleClient(GitClient):
         self,
         path: str | bytes,
         protocol_version: int | None = None,
-        ref_prefix: Sequence[Ref] | None = None,
+        ref_prefix: Sequence[bytes] | None = None,
     ) -> LsRemoteResult:
         """Retrieve the current refs from a bundle file."""
         bundle = self._open_bundle(path)
@@ -3502,7 +3511,7 @@ class AbstractHttpGitClient(GitClient):
         service: bytes,
         base_url: str,
         protocol_version: int | None = None,
-        ref_prefix: Sequence[Ref] | None = None,
+        ref_prefix: Sequence[bytes] | None = None,
     ) -> tuple[
         dict[Ref, ObjectID | None],
         set[bytes],
@@ -3627,7 +3636,8 @@ class AbstractHttpGitClient(GitClient):
                         ) = read_pkt_refs_v1(proto.read_pkt_seq())
                         # Convert v1 refs to Optional type
                         refs = _to_optional_dict(refs_v1)
-                        (refs, peeled) = split_peeled_refs(refs)
+                        # TODO: split_peeled_refs should accept Optional values
+                        (refs, peeled) = split_peeled_refs(refs)  # type: ignore[arg-type,assignment]
                         (symrefs, _agent) = _extract_symrefs_and_agent(
                             server_capabilities
                         )
@@ -3646,12 +3656,13 @@ class AbstractHttpGitClient(GitClient):
                 from typing import cast
 
                 info_refs = read_info_refs(BytesIO(data))
-                (refs, peeled) = split_peeled_refs(
-                    cast(dict[bytes, bytes | None], info_refs)
-                )
+                (refs_nonopt, peeled) = split_peeled_refs(info_refs)
                 if ref_prefix is not None:
-                    refs = filter_ref_prefix(refs, ref_prefix)
-                return refs, set(), base_url, {}, peeled
+                    refs_nonopt = filter_ref_prefix(refs_nonopt, ref_prefix)
+                refs_result: dict[Ref, ObjectID | None] = cast(
+                    dict[Ref, ObjectID | None], refs_nonopt
+                )
+                return refs_result, set(), base_url, {}, peeled
         finally:
             resp.close()
 
@@ -3687,7 +3698,7 @@ class AbstractHttpGitClient(GitClient):
     def send_pack(
         self,
         path: str | bytes,
-        update_refs: Callable[[dict[bytes, bytes]], dict[bytes, bytes]],
+        update_refs: Callable[[dict[Ref, ObjectID]], dict[Ref, ObjectID]],
         generate_pack_data: "GeneratePackDataFunc",
         progress: Callable[[bytes], None] | None = None,
     ) -> SendPackResult:
@@ -3726,13 +3737,14 @@ class AbstractHttpGitClient(GitClient):
         assert all(v is not None for v in old_refs.values()), (
             "old_refs should not contain None values"
         )
-        old_refs_typed: dict[bytes, bytes] = old_refs  # type: ignore[assignment]
+        old_refs_typed: dict[Ref, ObjectID] = old_refs  # type: ignore[assignment]
        new_refs = update_refs(dict(old_refs_typed))
         if new_refs is None:
             # Determine wants function is aborting the push.
             # Convert to Optional type for SendPackResult
-            old_refs_optional: dict[bytes, bytes | None] = old_refs
-            return SendPackResult(old_refs_optional, agent=agent, ref_status={})
+            return SendPackResult(
+                _to_optional_dict(old_refs_typed), agent=agent, ref_status={}
+            )
         if set(new_refs.items()).issubset(set(old_refs_typed.items())):
             # Convert to Optional type for SendPackResult
             return SendPackResult(
@@ -3777,7 +3789,7 @@ class AbstractHttpGitClient(GitClient):
         pack_data: Callable[[bytes], int],
         progress: Callable[[bytes], None] | None = None,
         depth: int | None = None,
-        ref_prefix: Sequence[Ref] | None = None,
+        ref_prefix: Sequence[bytes] | None = None,
         filter_spec: bytes | None = None,
         protocol_version: int | None = None,
         shallow_since: str | None = None,
@@ -3850,7 +3862,7 @@ class AbstractHttpGitClient(GitClient):
                 )
             )
 
-            symrefs[b"HEAD"] = dumb_repo.get_head()
+            symrefs[HEADREF] = dumb_repo.get_head()
 
             # Write pack data
             if pack_data_list:
@@ -3923,7 +3935,7 @@ class AbstractHttpGitClient(GitClient):
         self,
         path: str | bytes,
         protocol_version: int | None = None,
-        ref_prefix: Sequence[Ref] | None = None,
+        ref_prefix: Sequence[bytes] | None = None,
     ) -> LsRemoteResult:
         """Retrieve the current refs from a git smart server."""
         url = self._get_url(path)
@@ -3934,7 +3946,7 @@ class AbstractHttpGitClient(GitClient):
             ref_prefix=ref_prefix,
         )
         for refname, refvalue in peeled.items():
-            refs[refname + PEELED_TAG_SUFFIX] = refvalue
+            refs[Ref(refname + PEELED_TAG_SUFFIX)] = refvalue
         return LsRemoteResult(refs, symrefs)
 
     def get_url(self, path: str) -> str:

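Editor's note: the hunks above wrap plain bytes in Ref(...) and ObjectID(...) wherever untyped data enters. That calling convention matches typing.NewType; the definitions below are an illustrative sketch only (the diff shows just the call sites, and HEADREF is assumed to be a Ref constant for b"HEAD"):

    from typing import NewType

    # Assumed definitions -- the diff shows only call sites such as Ref(src)
    # and ObjectID(parts[1]), which is the typing.NewType calling convention.
    Ref = NewType("Ref", bytes)
    ObjectID = NewType("ObjectID", bytes)
    HEADREF = Ref(b"HEAD")

    def set_ref(refs: dict[Ref, ObjectID], name: bytes, sha: bytes) -> None:
        # refs[name] = sha would now be rejected by mypy: plain bytes are
        # neither Ref nor ObjectID until explicitly wrapped.
        refs[Ref(name)] = ObjectID(sha)

    refs: dict[Ref, ObjectID] = {}
    set_ref(refs, b"refs/heads/master", b"aa" * 20)
    assert refs[Ref(b"refs/heads/master")] == b"aa" * 20  # erased at runtime
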
+ 27 - 32
dulwich/commit_graph.py

@@ -26,7 +26,7 @@ from .file import _GitFile
 if TYPE_CHECKING:
     from .object_store import BaseObjectStore
 
-from .objects import Commit, ObjectID, hex_to_sha, sha_to_hex
+from .objects import Commit, ObjectID, RawObjectID, hex_to_sha, sha_to_hex
 
 # File format constants
 COMMIT_GRAPH_SIGNATURE = b"CGPH"
@@ -188,9 +188,9 @@ class CommitGraph:
         for i in range(num_commits):
             start = i * self._hash_size
             end = start + self._hash_size
-            oid = oid_lookup_data[start:end]
+            oid = RawObjectID(oid_lookup_data[start:end])
             oids.append(oid)
-            self._oid_to_index[oid] = i
+            self._oid_to_index[sha_to_hex(oid)] = i
 
         # Parse commit data chunk
         commit_data = self.chunks[CHUNK_COMMIT_DATA].data
@@ -205,7 +205,7 @@ class CommitGraph:
             offset = i * (self._hash_size + 16)
 
             # Tree OID
-            tree_id = commit_data[offset : offset + self._hash_size]
+            tree_id = RawObjectID(commit_data[offset : offset + self._hash_size])
             offset += self._hash_size
 
             # Parent positions (2 x 4 bytes)
@@ -271,14 +271,7 @@ class CommitGraph:
 
     def get_entry_by_oid(self, oid: ObjectID) -> CommitGraphEntry | None:
         """Get commit graph entry by commit OID."""
-        # Convert hex ObjectID to binary if needed for lookup
-        if isinstance(oid, bytes) and len(oid) == 40:
-            # Input is hex ObjectID, convert to binary for internal lookup
-            lookup_oid = hex_to_sha(oid)
-        else:
-            # Input is already binary
-            lookup_oid = oid
-        index = self._oid_to_index.get(lookup_oid)
+        index = self._oid_to_index.get(oid)
         if index is not None:
             return self.entries[index]
         return None
@@ -288,7 +281,7 @@ class CommitGraph:
         entry = self.get_entry_by_oid(oid)
         return entry.generation if entry else None
 
-    def get_parents(self, oid: ObjectID) -> list[bytes] | None:
+    def get_parents(self, oid: ObjectID) -> list[ObjectID] | None:
         """Get parent commit IDs for a commit."""
         entry = self.get_entry_by_oid(oid)
         return entry.parents if entry else None
@@ -443,20 +436,20 @@ def generate_commit_graph(
 
     # Ensure all commit_ids are in the correct format for object store access
     # DiskObjectStore expects hex ObjectIDs (40-byte hex strings)
-    normalized_commit_ids = []
+    normalized_commit_ids: list[ObjectID] = []
     for commit_id in commit_ids:
         if isinstance(commit_id, bytes) and len(commit_id) == 40:
             # Already hex ObjectID
-            normalized_commit_ids.append(commit_id)
+            normalized_commit_ids.append(ObjectID(commit_id))
         elif isinstance(commit_id, bytes) and len(commit_id) == 20:
             # Binary SHA, convert to hex ObjectID
-            normalized_commit_ids.append(sha_to_hex(commit_id))
+            normalized_commit_ids.append(sha_to_hex(RawObjectID(commit_id)))
         else:
             # Assume it's already correct format
-            normalized_commit_ids.append(commit_id)
+            normalized_commit_ids.append(ObjectID(commit_id))
 
     # Build a map of all commits and their metadata
-    commit_map: dict[bytes, Commit] = {}
+    commit_map: dict[ObjectID, Commit] = {}
     for commit_id in normalized_commit_ids:
         try:
             commit_obj = object_store[commit_id]
@@ -503,19 +496,20 @@ def generate_commit_graph(
     # Build commit graph entries
     for commit_id, commit_obj in commit_map.items():
         # commit_id is already hex ObjectID from normalized_commit_ids
-        commit_hex = commit_id
+        commit_hex: ObjectID = commit_id
 
         # Handle tree ID - might already be hex ObjectID
+        tree_hex: ObjectID
         if isinstance(commit_obj.tree, bytes) and len(commit_obj.tree) == 40:
-            tree_hex = commit_obj.tree  # Already hex ObjectID
+            tree_hex = ObjectID(commit_obj.tree)  # Already hex ObjectID
         else:
             tree_hex = sha_to_hex(commit_obj.tree)  # Binary, convert to hex
 
         # Handle parent IDs - might already be hex ObjectIDs
-        parents_hex = []
+        parents_hex: list[ObjectID] = []
         for parent_id in commit_obj.parents:
             if isinstance(parent_id, bytes) and len(parent_id) == 40:
-                parents_hex.append(parent_id)  # Already hex ObjectID
+                parents_hex.append(ObjectID(parent_id))  # Already hex ObjectID
             else:
                 parents_hex.append(sha_to_hex(parent_id))  # Binary, convert to hex
 
@@ -531,8 +525,7 @@ def generate_commit_graph(
     # Build the OID to index mapping for lookups
     graph._oid_to_index = {}
     for i, entry in enumerate(graph.entries):
-        binary_oid = hex_to_sha(entry.commit_id.decode())
-        graph._oid_to_index[binary_oid] = i
+        graph._oid_to_index[entry.commit_id] = i
 
     return graph
 
@@ -582,25 +575,27 @@ def get_reachable_commits(
     Returns:
         List of all reachable commit IDs (including the starting commits)
     """
-    visited = set()
-    reachable = []
-    stack = []
+    visited: set[ObjectID] = set()
+    reachable: list[ObjectID] = []
+    stack: list[ObjectID] = []
 
     # Normalize commit IDs for object store access and tracking
     for commit_id in start_commits:
         if isinstance(commit_id, bytes) and len(commit_id) == 40:
             # Hex ObjectID - use directly for object store access
-            if commit_id not in visited:
-                stack.append(commit_id)
+            oid = ObjectID(commit_id)
+            if oid not in visited:
+                stack.append(oid)
         elif isinstance(commit_id, bytes) and len(commit_id) == 20:
             # Binary SHA, convert to hex ObjectID for object store access
-            hex_id = sha_to_hex(commit_id)
+            hex_id = sha_to_hex(RawObjectID(commit_id))
             if hex_id not in visited:
                 stack.append(hex_id)
         else:
             # Assume it's already correct format
-            if commit_id not in visited:
-                stack.append(commit_id)
+            oid = ObjectID(commit_id)
+            if oid not in visited:
+                stack.append(oid)
 
     while stack:
         commit_id = stack.pop()

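Editor's note: the commit_graph.py hunks lean on ObjectID being the 40-byte hex form and RawObjectID the 20-byte binary form, with sha_to_hex/hex_to_sha converting between them. A minimal self-contained sketch of that relationship (the real definitions live in dulwich.objects; these stand-ins are assumptions):

    import binascii
    from typing import NewType

    # Assumed shapes: ObjectID = 40-byte hex SHA, RawObjectID = 20-byte binary.
    ObjectID = NewType("ObjectID", bytes)
    RawObjectID = NewType("RawObjectID", bytes)

    def sha_to_hex(sha: RawObjectID) -> ObjectID:
        return ObjectID(binascii.hexlify(sha))

    def hex_to_sha(hex_sha: ObjectID) -> RawObjectID:
        return RawObjectID(binascii.unhexlify(hex_sha))

    raw = RawObjectID(b"\x01" * 20)
    hexed = sha_to_hex(raw)
    assert len(hexed) == 40 and hex_to_sha(hexed) == raw  # lossless round trip
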
+ 24 - 17
dulwich/contrib/swift.py

@@ -47,7 +47,7 @@ from ..file import _GitFile
 from ..greenthreads import GreenThreadsMissingObjectFinder
 from ..lru_cache import LRUSizeCache
 from ..object_store import INFODIR, PACKDIR, PackBasedObjectStore
-from ..objects import S_ISGITLINK, Blob, Commit, Tag, Tree
+from ..objects import S_ISGITLINK, Blob, Commit, ObjectID, Tag, Tree
 from ..pack import (
     ObjectContainer,
     Pack,
@@ -66,7 +66,14 @@ from ..pack import (
     write_pack_object,
 )
 from ..protocol import TCP_GIT_PORT
-from ..refs import InfoRefsContainer, read_info_refs, split_peeled_refs, write_info_refs
+from ..refs import (
+    HEADREF,
+    InfoRefsContainer,
+    Ref,
+    read_info_refs,
+    split_peeled_refs,
+    write_info_refs,
+)
 from ..repo import OBJECTDIR, BaseRepo
 from ..server import Backend, BackendRepo, TCPGitServer
 
@@ -763,7 +770,7 @@ class SwiftObjectStore(PackBasedObjectStore):
         """Loose objects are not supported by this repository."""
         return iter([])
 
-    def pack_info_get(self, sha: bytes) -> tuple[Any, ...] | None:
+    def pack_info_get(self, sha: ObjectID) -> tuple[Any, ...] | None:
         """Get pack info for a specific SHA.
 
         Args:
@@ -786,7 +793,7 @@ class SwiftObjectStore(PackBasedObjectStore):
         if common is None:
             common = set()
 
-        def _find_parents(commit: bytes) -> list[Any]:
+        def _find_parents(commit: ObjectID) -> list[Any]:
             for pack in self.packs:
                 if commit in pack:
                     try:
@@ -983,8 +990,8 @@ class SwiftInfoRefsContainer(InfoRefsContainer):
         super().__init__(f)
 
     def _load_check_ref(
-        self, name: bytes, old_ref: bytes | None
-    ) -> dict[bytes, bytes] | bool:
+        self, name: Ref, old_ref: ObjectID | None
+    ) -> dict[Ref, ObjectID] | bool:
         self._check_refname(name)
         obj = self.scon.get_object(self.filename)
         if not obj:
@@ -1000,23 +1007,23 @@ class SwiftInfoRefsContainer(InfoRefsContainer):
                 return False
         return refs
 
-    def _write_refs(self, refs: Mapping[bytes, bytes]) -> None:
+    def _write_refs(self, refs: Mapping[Ref, ObjectID]) -> None:
         f = BytesIO()
         f.writelines(write_info_refs(refs, cast("ObjectContainer", self.store)))
         self.scon.put_object(self.filename, f)
 
     def set_if_equals(
         self,
-        name: bytes,
-        old_ref: bytes | None,
-        new_ref: bytes,
+        name: Ref,
+        old_ref: ObjectID | None,
+        new_ref: ObjectID,
         committer: bytes | None = None,
         timestamp: float | None = None,
         timezone: int | None = None,
         message: bytes | None = None,
     ) -> bool:
         """Set a refname to new_ref only if it currently equals old_ref."""
-        if name == b"HEAD":
+        if name == HEADREF:
             return True
         refs = self._load_check_ref(name, old_ref)
         if not isinstance(refs, dict):
@@ -1028,15 +1035,15 @@ class SwiftInfoRefsContainer(InfoRefsContainer):
 
     def remove_if_equals(
         self,
-        name: bytes,
-        old_ref: bytes | None,
+        name: Ref,
+        old_ref: ObjectID | None,
         committer: object = None,
         timestamp: object = None,
         timezone: object = None,
         message: object = None,
     ) -> bool:
         """Remove a refname only if it currently equals old_ref."""
-        if name == b"HEAD":
+        if name == HEADREF:
             return True
         refs = self._load_check_ref(name, old_ref)
         if not isinstance(refs, dict):
@@ -1046,14 +1053,14 @@ class SwiftInfoRefsContainer(InfoRefsContainer):
         del self._refs[name]
         return True
 
-    def allkeys(self) -> set[bytes]:
+    def allkeys(self) -> set[Ref]:
         """Get all reference names.
 
         Returns:
-          Set of reference names as bytes
+          Set of reference names as Ref
         """
         try:
-            self._refs[b"HEAD"] = self._refs[b"refs/heads/master"]
+            self._refs[HEADREF] = self._refs[Ref(b"refs/heads/master")]
         except KeyError:
             pass
         return set(self._refs.keys())

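Editor's note: set_if_equals/remove_if_equals above implement compare-and-swap semantics over a refs mapping. Reduced to a plain dict for illustration (simplified signatures, not the dulwich API):

    def set_if_equals(refs: dict[bytes, bytes], name: bytes,
                      old_ref: bytes | None, new_ref: bytes) -> bool:
        # Update only when the current value still matches old_ref, so a
        # concurrent change to the same ref is detected and reported.
        if old_ref is not None and refs.get(name) != old_ref:
            return False
        refs[name] = new_ref
        return True

    refs = {b"refs/heads/master": b"a" * 40}
    assert set_if_equals(refs, b"refs/heads/master", b"a" * 40, b"b" * 40)
    assert not set_if_equals(refs, b"refs/heads/master", b"a" * 40, b"c" * 40)
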
+ 6 - 4
dulwich/diff.py

@@ -54,7 +54,7 @@ from typing import BinaryIO
 from ._typing import Buffer
 from .index import ConflictedIndexEntry, commit_index
 from .object_store import iter_tree_contents
-from .objects import S_ISGITLINK, Blob, Commit
+from .objects import S_ISGITLINK, Blob, Commit, ObjectID
 from .patch import write_blob_diff, write_object_diff
 from .repo import Repo
 
@@ -79,7 +79,7 @@ def should_include_path(path: bytes, paths: Sequence[bytes] | None) -> bool:
 def diff_index_to_tree(
     repo: Repo,
     outstream: BinaryIO,
-    commit_sha: bytes | None = None,
+    commit_sha: ObjectID | None = None,
     paths: Sequence[bytes] | None = None,
     diff_algorithm: str | None = None,
 ) -> None:
@@ -94,7 +94,9 @@ def diff_index_to_tree(
     """
     if commit_sha is None:
         try:
-            commit_sha = repo.refs[b"HEAD"]
+            from dulwich.refs import HEADREF
+
+            commit_sha = repo.refs[HEADREF]
             old_commit = repo[commit_sha]
             assert isinstance(old_commit, Commit)
             old_tree = old_commit.tree
@@ -124,7 +126,7 @@ def diff_index_to_tree(
 def diff_working_tree_to_tree(
     repo: Repo,
     outstream: BinaryIO,
-    commit_sha: bytes,
+    commit_sha: ObjectID,
     paths: Sequence[bytes] | None = None,
     diff_algorithm: str | None = None,
 ) -> None:

+ 19 - 13
dulwich/dumb.py

@@ -36,6 +36,7 @@ from .objects import (
     Blob,
     Commit,
     ObjectID,
+    RawObjectID,
     ShaFile,
     Tag,
     Tree,
@@ -201,7 +202,7 @@ class DumbHTTPObjectStore(BaseObjectStore):
                 return idx
         raise KeyError(f"Pack not found: {pack_name}")
 
-    def _fetch_from_pack(self, sha: bytes) -> tuple[int, bytes]:
+    def _fetch_from_pack(self, sha: RawObjectID | ObjectID) -> tuple[int, bytes]:
         """Try to fetch an object from pack files.
 
         Args:
@@ -215,7 +216,10 @@ class DumbHTTPObjectStore(BaseObjectStore):
         """
         self._load_packs()
         # Convert hex to binary for pack operations
-        binsha = hex_to_sha(sha)
+        if len(sha) == 20:
+            binsha = RawObjectID(sha)  # Already binary
+        else:
+            binsha = hex_to_sha(ObjectID(sha))  # Convert hex to binary
 
         for pack_name, pack_idx in self._packs or []:
             if pack_idx is None:
@@ -251,7 +255,7 @@ class DumbHTTPObjectStore(BaseObjectStore):
 
         raise KeyError(sha)
 
-    def get_raw(self, sha: bytes) -> tuple[int, bytes]:
+    def get_raw(self, sha: RawObjectID | ObjectID) -> tuple[int, bytes]:
         """Obtain the raw text for an object.
 
         Args:
@@ -276,7 +280,7 @@ class DumbHTTPObjectStore(BaseObjectStore):
         self._cached_objects[sha] = result
         return result
 
-    def contains_loose(self, sha: bytes) -> bool:
+    def contains_loose(self, sha: RawObjectID | ObjectID) -> bool:
         """Check if a particular object is present by SHA1 and is loose."""
         try:
             self._fetch_loose_object(sha)
@@ -284,7 +288,7 @@ class DumbHTTPObjectStore(BaseObjectStore):
         except KeyError:
             return False
 
-    def __contains__(self, sha: bytes) -> bool:
+    def __contains__(self, sha: RawObjectID | ObjectID) -> bool:
         """Check if a particular object is present by SHA1."""
         if sha in self._cached_objects:
             return True
@@ -303,7 +307,7 @@ class DumbHTTPObjectStore(BaseObjectStore):
         except KeyError:
             return False
 
-    def __iter__(self) -> Iterator[bytes]:
+    def __iter__(self) -> Iterator[ObjectID]:
         """Iterate over all SHAs in the store.
 
         Note: This is inefficient for dumb HTTP as it requires
@@ -322,7 +326,7 @@ class DumbHTTPObjectStore(BaseObjectStore):
             for sha in idx:
                 if sha not in seen:
                     seen.add(sha)
-                    yield sha_to_hex(sha)
+                    yield sha_to_hex(RawObjectID(sha))
 
     @property
     def packs(self) -> list[Any]:
@@ -405,7 +409,10 @@ class DumbRemoteHTTPRepo:
 
             refs_hex = read_info_refs(BytesIO(refs_data))
             # Keep SHAs as hex
-            self._refs, self._peeled = split_peeled_refs(refs_hex)
+            refs_raw, peeled_raw = split_peeled_refs(refs_hex)
+            # Convert to typed dicts
+            self._refs = {Ref(k): ObjectID(v) for k, v in refs_raw.items()}
+            self._peeled = peeled_raw
 
         return dict(self._refs)
 
@@ -417,13 +424,12 @@ class DumbRemoteHTTPRepo:
         """
         head_resp_bytes = self._fetch_url("HEAD")
         head_split = head_resp_bytes.replace(b"\n", b"").split(b" ")
-        head_target = head_split[1] if len(head_split) > 1 else head_split[0]
+        head_target_bytes = head_split[1] if len(head_split) > 1 else head_split[0]
         # handle HEAD legacy format containing a commit id instead of a ref name
         for ref_name, ret_target in self.get_refs().items():
-            if ret_target == head_target:
-                head_target = ref_name
-                break
-        return head_target
+            if ret_target == head_target_bytes:
+                return ref_name
+        return Ref(head_target_bytes)
 
     def get_peeled(self, ref: Ref) -> ObjectID:
         """Get the peeled value of a ref."""

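Editor's note: the new branch in _fetch_from_pack accepts either SHA form and dispatches on length. The same normalization in isolation (normalize_to_binary is a hypothetical name, not a dulwich function):

    def normalize_to_binary(sha: bytes) -> bytes:
        # Mirrors the length check added above: 20 bytes is already the
        # binary form, 40 bytes is hex and needs converting.
        if len(sha) == 20:
            return sha
        if len(sha) == 40:
            return bytes.fromhex(sha.decode("ascii"))
        raise ValueError(f"unexpected SHA length: {len(sha)}")

    assert normalize_to_binary(b"ab" * 20) == b"\xab" * 20  # hex input
    assert normalize_to_binary(b"\xab" * 20) == b"\xab" * 20  # already binary
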
+ 12 - 7
dulwich/fastexport.py

@@ -66,7 +66,7 @@ class GitFastExporter:
         """
         self.outf = outf
         self.store = store
-        self.markers: dict[bytes, bytes] = {}
+        self.markers: dict[bytes, ObjectID] = {}
         self._marker_idx = 0
 
     def print_cmd(self, cmd: object) -> None:
@@ -117,7 +117,7 @@ class GitFastExporter:
         return marker
 
     def _iter_files(
-        self, base_tree: bytes | None, new_tree: bytes | None
+        self, base_tree: ObjectID | None, new_tree: ObjectID | None
     ) -> Generator[Any, None, None]:
         for (
             (old_path, new_path),
@@ -216,7 +216,7 @@ class GitImportProcessor(processor.ImportProcessor):  # type: ignore[misc,unused
         processor.ImportProcessor.__init__(self, params, verbose)  # type: ignore[no-untyped-call,unused-ignore]
         self.repo = repo
         self.last_commit = ZERO_SHA
-        self.markers: dict[bytes, bytes] = {}
+        self.markers: dict[bytes, ObjectID] = {}
         self._contents: dict[bytes, tuple[int, bytes]] = {}
 
     def lookup_object(self, objectish: bytes) -> ObjectID:
@@ -230,9 +230,9 @@ class GitImportProcessor(processor.ImportProcessor):  # type: ignore[misc,unused
         """
         if objectish.startswith(b":"):
             return self.markers[objectish[1:]]
-        return objectish
+        return ObjectID(objectish)
 
-    def import_stream(self, stream: BinaryIO) -> dict[bytes, bytes]:
+    def import_stream(self, stream: BinaryIO) -> dict[bytes, ObjectID]:
         """Import from a fast-import stream.
 
         Args:
@@ -314,9 +314,14 @@ class GitImportProcessor(processor.ImportProcessor):  # type: ignore[misc,unused
                 self._contents = {}
             else:
                 raise Exception(f"Command {filecmd.name!r} not supported")
+        from dulwich.objects import ObjectID
+
         commit.tree = commit_tree(
             self.repo.object_store,
-            ((path, hexsha, mode) for (path, (mode, hexsha)) in self._contents.items()),
+            (
+                (path, ObjectID(hexsha), mode)
+                for (path, (mode, hexsha)) in self._contents.items()
+            ),
         )
         if self.last_commit != ZERO_SHA:
             commit.parents.append(self.last_commit)
@@ -363,7 +368,7 @@ class GitImportProcessor(processor.ImportProcessor):  # type: ignore[misc,unused
         else:
             from_ = self.lookup_object(cmd.from_)
         self._reset_base(from_)
-        self.repo.refs[cmd.ref] = from_
+        self.repo.refs[Ref(cmd.ref)] = from_
 
     def tag_handler(self, cmd: commands.TagCommand) -> None:
         """Process a TagCommand."""

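Editor's note: lookup_object above resolves fast-import ":<mark>" references through the marker table and now wraps bare IDs in ObjectID. The resolution logic in isolation (sketch only; types simplified to bytes):

    markers: dict[bytes, bytes] = {b"1": b"aa" * 20}

    def lookup_object(objectish: bytes) -> bytes:
        # A leading b":" marks a fast-import marker reference; anything else
        # is taken to already be an object ID (wrapped in ObjectID above).
        if objectish.startswith(b":"):
            return markers[objectish[1:]]
        return objectish

    assert lookup_object(b":1") == b"aa" * 20
    assert lookup_object(b"bb" * 20) == b"bb" * 20
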
+ 26 - 24
dulwich/filter_branch.py

@@ -29,8 +29,8 @@ from typing import TypedDict
 
 from .index import Index, build_index_from_tree
 from .object_store import BaseObjectStore
-from .objects import Commit, Tag, Tree
-from .refs import RefsContainer, local_tag_name
+from .objects import Commit, ObjectID, Tag, Tree
+from .refs import Ref, RefsContainer, local_tag_name
 
 
 class CommitData(TypedDict, total=False):
@@ -57,10 +57,10 @@ class CommitFilter:
         filter_author: Callable[[bytes], bytes | None] | None = None,
         filter_committer: Callable[[bytes], bytes | None] | None = None,
         filter_message: Callable[[bytes], bytes | None] | None = None,
-        tree_filter: Callable[[bytes, str], bytes | None] | None = None,
-        index_filter: Callable[[bytes, str], bytes | None] | None = None,
-        parent_filter: Callable[[Sequence[bytes]], list[bytes]] | None = None,
-        commit_filter: Callable[[Commit, bytes], bytes | None] | None = None,
+        tree_filter: Callable[[ObjectID, str], ObjectID | None] | None = None,
+        index_filter: Callable[[ObjectID, str], ObjectID | None] | None = None,
+        parent_filter: Callable[[Sequence[ObjectID]], list[ObjectID]] | None = None,
+        commit_filter: Callable[[Commit, ObjectID], ObjectID | None] | None = None,
         subdirectory_filter: bytes | None = None,
         prune_empty: bool = False,
         tag_name_filter: Callable[[bytes], bytes | None] | None = None,
@@ -101,13 +101,13 @@ class CommitFilter:
         self.subdirectory_filter = subdirectory_filter
         self.prune_empty = prune_empty
         self.tag_name_filter = tag_name_filter
-        self._old_to_new: dict[bytes, bytes] = {}
-        self._processed: set[bytes] = set()
-        self._tree_cache: dict[bytes, bytes] = {}  # Cache for filtered trees
+        self._old_to_new: dict[ObjectID, ObjectID] = {}
+        self._processed: set[ObjectID] = set()
+        self._tree_cache: dict[ObjectID, ObjectID] = {}  # Cache for filtered trees
 
     def _filter_tree_with_subdirectory(
-        self, tree_sha: bytes, subdirectory: bytes
-    ) -> bytes | None:
+        self, tree_sha: ObjectID, subdirectory: bytes
+    ) -> ObjectID | None:
         """Extract a subdirectory from a tree as the new root.
 
         Args:
@@ -153,7 +153,7 @@ class CommitFilter:
         # Return the subdirectory tree
         return current_tree.id
 
-    def _apply_tree_filter(self, tree_sha: bytes) -> bytes:
+    def _apply_tree_filter(self, tree_sha: ObjectID) -> ObjectID:
         """Apply tree filter by checking out tree and running filter.
 
         Args:
@@ -181,7 +181,7 @@ class CommitFilter:
             self._tree_cache[tree_sha] = new_tree_sha
             return new_tree_sha
 
-    def _apply_index_filter(self, tree_sha: bytes) -> bytes:
+    def _apply_index_filter(self, tree_sha: ObjectID) -> ObjectID:
         """Apply index filter by creating temp index and running filter.
 
         Args:
@@ -217,7 +217,7 @@ class CommitFilter:
         finally:
             os.unlink(tmp_index_path)
 
-    def process_commit(self, commit_sha: bytes) -> bytes | None:
+    def process_commit(self, commit_sha: ObjectID) -> ObjectID | None:
         """Process a single commit, creating a filtered version.
 
         Args:
@@ -366,7 +366,7 @@ class CommitFilter:
             self._old_to_new[commit_sha] = commit_sha
             return commit_sha
 
-    def get_mapping(self) -> dict[bytes, bytes]:
+    def get_mapping(self) -> dict[ObjectID, ObjectID]:
         """Get the mapping of old commit SHAs to new commit SHAs.
 
         Returns:
@@ -383,8 +383,8 @@ def filter_refs(
     *,
     keep_original: bool = True,
     force: bool = False,
-    tag_callback: Callable[[bytes, bytes], None] | None = None,
-) -> dict[bytes, bytes]:
+    tag_callback: Callable[[Ref, Ref], None] | None = None,
+) -> dict[ObjectID, ObjectID]:
     """Filter commits reachable from the given refs.
 
     Args:
@@ -405,7 +405,7 @@ def filter_refs(
     # Check if already filtered
     if keep_original and not force:
         for ref in ref_names:
-            original_ref = b"refs/original/" + ref
+            original_ref = Ref(b"refs/original/" + ref)
             if original_ref in refs:
                 raise ValueError(
                     f"Branch {ref.decode()} appears to have been filtered already. "
                     f"Branch {ref.decode()} appears to have been filtered already. "
@@ -416,8 +416,9 @@ def filter_refs(
     for ref in ref_names:
     for ref in ref_names:
         try:
         try:
             # Get the commit SHA for this ref
             # Get the commit SHA for this ref
-            if ref in refs:
-                ref_sha = refs[ref]
+            ref_obj = Ref(ref)
+            if ref_obj in refs:
+                ref_sha = refs[ref_obj]
                 if ref_sha:
                 if ref_sha:
                     commit_filter.process_commit(ref_sha)
                     commit_filter.process_commit(ref_sha)
         except KeyError:
         except KeyError:
@@ -429,18 +430,19 @@ def filter_refs(
     mapping = commit_filter.get_mapping()
     mapping = commit_filter.get_mapping()
     for ref in ref_names:
     for ref in ref_names:
         try:
         try:
-            if ref in refs:
-                old_sha = refs[ref]
+            ref_obj = Ref(ref)
+            if ref_obj in refs:
+                old_sha = refs[ref_obj]
                 new_sha = mapping.get(old_sha, old_sha)
                 new_sha = mapping.get(old_sha, old_sha)
 
 
                 if old_sha != new_sha:
                 if old_sha != new_sha:
                     # Save original ref if requested
                     # Save original ref if requested
                     if keep_original:
                     if keep_original:
-                        original_ref = b"refs/original/" + ref
+                        original_ref = Ref(b"refs/original/" + ref)
                         refs[original_ref] = old_sha
                         refs[original_ref] = old_sha
 
 
                     # Update ref to new commit
                     # Update ref to new commit
-                    refs[ref] = new_sha
+                    refs[ref_obj] = new_sha
         except KeyError:
         except KeyError:
             # Not a valid ref, skip updating
             # Not a valid ref, skip updating
             warnings.warn(f"Could not update ref {ref!r}: ref not found")
             warnings.warn(f"Could not update ref {ref!r}: ref not found")

+ 7 - 7
dulwich/gc.py

@@ -28,7 +28,7 @@ DEFAULT_GC_AUTO_PACK_LIMIT = 50
 class GCStats:
 class GCStats:
     """Statistics from garbage collection."""
     """Statistics from garbage collection."""
 
 
-    pruned_objects: set[bytes] = field(default_factory=set)
+    pruned_objects: set[ObjectID] = field(default_factory=set)
     bytes_freed: int = 0
     bytes_freed: int = 0
     packs_before: int = 0
     packs_before: int = 0
     packs_after: int = 0
     packs_after: int = 0
@@ -41,7 +41,7 @@ def find_reachable_objects(
     refs_container: RefsContainer,
     refs_container: RefsContainer,
     include_reflogs: bool = True,
     include_reflogs: bool = True,
     progress: Callable[[str], None] | None = None,
     progress: Callable[[str], None] | None = None,
-) -> set[bytes]:
+) -> set[ObjectID]:
     """Find all reachable objects in the repository.
     """Find all reachable objects in the repository.
 
 
     Args:
     Args:
@@ -53,7 +53,7 @@ def find_reachable_objects(
     Returns:
     Returns:
         Set of reachable object SHAs
         Set of reachable object SHAs
     """
     """
-    reachable = set()
+    reachable: set[ObjectID] = set()
     pending: deque[ObjectID] = deque()
     pending: deque[ObjectID] = deque()
 
 
     # Start with all refs
     # Start with all refs
@@ -115,7 +115,7 @@ def find_unreachable_objects(
     refs_container: RefsContainer,
     refs_container: RefsContainer,
     include_reflogs: bool = True,
     include_reflogs: bool = True,
     progress: Callable[[str], None] | None = None,
     progress: Callable[[str], None] | None = None,
-) -> set[bytes]:
+) -> set[ObjectID]:
     """Find all unreachable objects in the repository.
     """Find all unreachable objects in the repository.
 
 
     Args:
     Args:
@@ -131,7 +131,7 @@ def find_unreachable_objects(
         object_store, refs_container, include_reflogs, progress
         object_store, refs_container, include_reflogs, progress
     )
     )
 
 
-    unreachable = set()
+    unreachable: set[ObjectID] = set()
     for sha in object_store:
     for sha in object_store:
         if sha not in reachable:
         if sha not in reachable:
             unreachable.add(sha)
             unreachable.add(sha)
@@ -145,7 +145,7 @@ def prune_unreachable_objects(
     grace_period: int | None = None,
     grace_period: int | None = None,
     dry_run: bool = False,
     dry_run: bool = False,
     progress: Callable[[str], None] | None = None,
     progress: Callable[[str], None] | None = None,
-) -> tuple[set[bytes], int]:
+) -> tuple[set[ObjectID], int]:
     """Remove unreachable objects from the repository.
     """Remove unreachable objects from the repository.
 
 
     Args:
     Args:
@@ -162,7 +162,7 @@ def prune_unreachable_objects(
         object_store, refs_container, progress=progress
         object_store, refs_container, progress=progress
     )
     )
 
 
-    pruned = set()
+    pruned: set[ObjectID] = set()
     bytes_freed = 0
     bytes_freed = 0
 
 
     for sha in unreachable:
     for sha in unreachable:
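
The gc helpers now advertise ``set[ObjectID]`` results. A usage sketch, assuming an existing repository in the current directory::

    from dulwich.gc import find_unreachable_objects
    from dulwich.repo import Repo

    repo = Repo(".")
    unreachable = find_unreachable_objects(repo.object_store, repo.refs)
    # set[ObjectID]: each element is still a 40-byte hex bytestring at
    # runtime, so existing comparisons and printing keep working.
    print(len(unreachable), "unreachable objects")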

+ 3 - 3
dulwich/graph.py

@@ -96,7 +96,7 @@ def _find_lcas(
         List of lowest common ancestor commit IDs
         List of lowest common ancestor commit IDs
     """
     """
     cands = []
     cands = []
-    cstates = {}
+    cstates: dict[ObjectID, int] = {}
 
 
     # Flags to Record State
     # Flags to Record State
     _ANC_OF_1 = 1  # ancestor of commit 1
     _ANC_OF_1 = 1  # ancestor of commit 1
@@ -124,7 +124,7 @@ def _find_lcas(
 
 
     # initialize the working list states with ancestry info
     # initialize the working list states with ancestry info
     # note possibility of c1 being one of c2s should be handled
     # note possibility of c1 being one of c2s should be handled
-    wlst: WorkList[bytes] = WorkList()
+    wlst: WorkList[ObjectID] = WorkList()
     cstates[c1] = _ANC_OF_1
     cstates[c1] = _ANC_OF_1
     try:
     try:
         wlst.add((lookup_stamp(c1), c1))
         wlst.add((lookup_stamp(c1), c1))
@@ -298,7 +298,7 @@ def find_octopus_base(
     return lcas
     return lcas
 
 
 
 
-def can_fast_forward(repo: "BaseRepo", c1: bytes, c2: bytes) -> bool:
+def can_fast_forward(repo: "BaseRepo", c1: ObjectID, c2: ObjectID) -> bool:
     """Is it possible to fast-forward from c1 to c2?
     """Is it possible to fast-forward from c1 to c2?
 
 
     Args:
     Args:
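
A short sketch of the typed signature, assuming a repository with a ``feature`` branch::

    from dulwich.graph import can_fast_forward
    from dulwich.refs import HEADREF, Ref
    from dulwich.repo import Repo

    repo = Repo(".")
    ours = repo.refs[HEADREF]                       # ObjectID
    theirs = repo.refs[Ref(b"refs/heads/feature")]  # ObjectID
    if can_fast_forward(repo, ours, theirs):
        print("merge would fast-forward")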

+ 1 - 1
dulwich/greenthreads.py

@@ -137,4 +137,4 @@ class GreenThreadsMissingObjectFinder(MissingObjectFinder):
             self.progress: Callable[[bytes], None] = lambda x: None
             self.progress: Callable[[bytes], None] = lambda x: None
         else:
         else:
             self.progress = progress
             self.progress = progress
-        self._tagged = (get_tagged and get_tagged()) or {}
+        self._tagged: dict[ObjectID, ObjectID] = (get_tagged and get_tagged()) or {}

+ 22 - 18
dulwich/index.py

@@ -69,7 +69,7 @@ from .objects import (
 from .pack import ObjectContainer, SHA1Reader, SHA1Writer
 from .pack import ObjectContainer, SHA1Reader, SHA1Writer
 
 
 # Type alias for recursive tree structure used in commit_tree
 # Type alias for recursive tree structure used in commit_tree
-TreeDict = dict[bytes, "TreeDict | tuple[int, bytes]"]
+TreeDict = dict[bytes, "TreeDict | tuple[int, ObjectID]"]
 
 
 # 2-bit stage (during merge)
 # 2-bit stage (during merge)
 FLAG_STAGEMASK = 0x3000
 FLAG_STAGEMASK = 0x3000
@@ -294,7 +294,7 @@ class SerializedIndexEntry:
     uid: int
     uid: int
     gid: int
     gid: int
     size: int
     size: int
-    sha: bytes
+    sha: ObjectID
     flags: int
     flags: int
     extended_flags: int
     extended_flags: int
 
 
@@ -505,7 +505,7 @@ class IndexEntry:
     uid: int
     uid: int
     gid: int
     gid: int
     size: int
     size: int
-    sha: bytes
+    sha: ObjectID
     flags: int = 0
     flags: int = 0
     extended_flags: int = 0
     extended_flags: int = 0
 
 
@@ -1168,7 +1168,7 @@ class Index:
         """Check if a path exists in the index."""
         """Check if a path exists in the index."""
         return key in self._byname
         return key in self._byname
 
 
-    def get_sha1(self, path: bytes) -> bytes:
+    def get_sha1(self, path: bytes) -> ObjectID:
         """Return the (git object) SHA1 for the object at a path."""
         """Return the (git object) SHA1 for the object at a path."""
         value = self[path]
         value = self[path]
         if isinstance(value, ConflictedIndexEntry):
         if isinstance(value, ConflictedIndexEntry):
@@ -1182,7 +1182,7 @@ class Index:
             raise UnmergedEntries
             raise UnmergedEntries
         return value.mode
         return value.mode
 
 
-    def iterobjects(self) -> Iterable[tuple[bytes, bytes, int]]:
+    def iterobjects(self) -> Iterable[tuple[bytes, ObjectID, int]]:
         """Iterate over path, sha, mode tuples for use with commit_tree."""
         """Iterate over path, sha, mode tuples for use with commit_tree."""
         for path in self:
         for path in self:
             entry = self[path]
             entry = self[path]
@@ -1291,7 +1291,7 @@ class Index:
             want_unchanged=want_unchanged,
             want_unchanged=want_unchanged,
         )
         )
 
 
-    def commit(self, object_store: ObjectContainer) -> bytes:
+    def commit(self, object_store: ObjectContainer) -> ObjectID:
         """Create a new tree from an index.
         """Create a new tree from an index.
 
 
         Args:
         Args:
@@ -1398,7 +1398,7 @@ class Index:
     def convert_to_sparse(
     def convert_to_sparse(
         self,
         self,
         object_store: "BaseObjectStore",
         object_store: "BaseObjectStore",
-        tree_sha: bytes,
+        tree_sha: ObjectID,
         sparse_dirs: Set[bytes],
         sparse_dirs: Set[bytes],
     ) -> None:
     ) -> None:
         """Convert full index entries to sparse directory entries.
         """Convert full index entries to sparse directory entries.
@@ -1443,6 +1443,8 @@ class Index:
 
 
             # Create a sparse directory entry
             # Create a sparse directory entry
             # Use minimal metadata since it's not a real file
             # Use minimal metadata since it's not a real file
+            from dulwich.objects import ObjectID
+
             sparse_entry = IndexEntry(
             sparse_entry = IndexEntry(
                 ctime=0,
                 ctime=0,
                 mtime=0,
                 mtime=0,
@@ -1452,7 +1454,7 @@ class Index:
                 uid=0,
                 uid=0,
                 gid=0,
                 gid=0,
                 size=0,
                 size=0,
-                sha=subtree_sha,
+                sha=ObjectID(subtree_sha),
                 flags=0,
                 flags=0,
                 extended_flags=EXTENDED_FLAG_SKIP_WORKTREE,
                 extended_flags=EXTENDED_FLAG_SKIP_WORKTREE,
             )
             )
@@ -1505,8 +1507,8 @@ class Index:
 
 
 
 
 def commit_tree(
 def commit_tree(
-    object_store: ObjectContainer, blobs: Iterable[tuple[bytes, bytes, int]]
-) -> bytes:
+    object_store: ObjectContainer, blobs: Iterable[tuple[bytes, ObjectID, int]]
+) -> ObjectID:
     """Commit a new tree.
     """Commit a new tree.
 
 
     Args:
     Args:
@@ -1533,7 +1535,7 @@ def commit_tree(
         tree = add_tree(tree_path)
         tree = add_tree(tree_path)
         tree[basename] = (mode, sha)
         tree[basename] = (mode, sha)
 
 
-    def build_tree(path: bytes) -> bytes:
+    def build_tree(path: bytes) -> ObjectID:
         tree = Tree()
         tree = Tree()
         for basename, entry in trees[path].items():
         for basename, entry in trees[path].items():
             if isinstance(entry, dict):
             if isinstance(entry, dict):
@@ -1548,7 +1550,7 @@ def commit_tree(
     return build_tree(b"")
     return build_tree(b"")
 
 
 
 
-def commit_index(object_store: ObjectContainer, index: Index) -> bytes:
+def commit_index(object_store: ObjectContainer, index: Index) -> ObjectID:
     """Create a new tree from an index.
     """Create a new tree from an index.
 
 
     Args:
     Args:
@@ -1564,7 +1566,7 @@ def changes_from_tree(
     names: Iterable[bytes],
     names: Iterable[bytes],
     lookup_entry: Callable[[bytes], tuple[bytes, int]],
     lookup_entry: Callable[[bytes], tuple[bytes, int]],
     object_store: ObjectContainer,
     object_store: ObjectContainer,
-    tree: bytes | None,
+    tree: ObjectID | None,
     want_unchanged: bool = False,
     want_unchanged: bool = False,
 ) -> Iterable[
 ) -> Iterable[
     tuple[
     tuple[
@@ -1625,6 +1627,8 @@ def index_entry_from_stat(
     if mode is None:
     if mode is None:
         mode = cleanup_mode(stat_val.st_mode)
         mode = cleanup_mode(stat_val.st_mode)
 
 
+    from dulwich.objects import ObjectID
+
     return IndexEntry(
     return IndexEntry(
         ctime=stat_val.st_ctime,
         ctime=stat_val.st_ctime,
         mtime=stat_val.st_mtime,
         mtime=stat_val.st_mtime,
@@ -1634,7 +1638,7 @@ def index_entry_from_stat(
         uid=stat_val.st_uid,
         uid=stat_val.st_uid,
         gid=stat_val.st_gid,
         gid=stat_val.st_gid,
         size=stat_val.st_size,
         size=stat_val.st_size,
-        sha=hex_sha,
+        sha=ObjectID(hex_sha),
         flags=0,
         flags=0,
         extended_flags=0,
         extended_flags=0,
     )
     )
@@ -1884,7 +1888,7 @@ def build_index_from_tree(
     root_path: str | bytes,
     root_path: str | bytes,
     index_path: str | bytes,
     index_path: str | bytes,
     object_store: ObjectContainer,
     object_store: ObjectContainer,
-    tree_id: bytes,
+    tree_id: ObjectID,
     honor_filemode: bool = True,
     honor_filemode: bool = True,
     validate_path_element: Callable[[bytes], bool] = validate_path_element_default,
     validate_path_element: Callable[[bytes], bool] = validate_path_element_default,
     symlink_fn: Callable[
     symlink_fn: Callable[
@@ -2132,7 +2136,7 @@ def _remove_empty_parents(path: bytes, stop_at: bytes) -> None:
 
 
 
 
 def _check_symlink_matches(
 def _check_symlink_matches(
-    full_path: bytes, repo_object_store: "BaseObjectStore", entry_sha: bytes
+    full_path: bytes, repo_object_store: "BaseObjectStore", entry_sha: ObjectID
 ) -> bool:
 ) -> bool:
     """Check if symlink target matches expected target.
     """Check if symlink target matches expected target.
 
 
@@ -2158,7 +2162,7 @@ def _check_symlink_matches(
 def _check_file_matches(
 def _check_file_matches(
     repo_object_store: "BaseObjectStore",
     repo_object_store: "BaseObjectStore",
     full_path: bytes,
     full_path: bytes,
-    entry_sha: bytes,
+    entry_sha: ObjectID,
     entry_mode: int,
     entry_mode: int,
     current_stat: os.stat_result,
     current_stat: os.stat_result,
     honor_filemode: bool,
     honor_filemode: bool,
@@ -3045,7 +3049,7 @@ def iter_fresh_objects(
     root_path: bytes,
     root_path: bytes,
     include_deleted: bool = False,
     include_deleted: bool = False,
     object_store: ObjectContainer | None = None,
     object_store: ObjectContainer | None = None,
-) -> Iterator[tuple[bytes, bytes | None, int | None]]:
+) -> Iterator[tuple[bytes, ObjectID | None, int | None]]:
     """Iterate over versions of objects on disk referenced by index.
     """Iterate over versions of objects on disk referenced by index.
 
 
     Args:
     Args:
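
``commit_tree`` now takes and returns ``ObjectID``. A self-contained sketch using the in-memory store::

    from dulwich.index import commit_tree
    from dulwich.object_store import MemoryObjectStore
    from dulwich.objects import Blob

    store = MemoryObjectStore()
    blob = Blob.from_string(b"hello\n")
    store.add_object(blob)
    # blobs are (path, sha, mode) tuples; blob.id is already an ObjectID.
    tree_id = commit_tree(store, [(b"hello.txt", blob.id, 0o100644)])
    print(tree_id.decode())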

+ 5 - 5
dulwich/merge.py

@@ -16,7 +16,7 @@ from dulwich.attrs import GitAttributes
 from dulwich.config import Config
 from dulwich.config import Config
 from dulwich.merge_drivers import get_merge_driver_registry
 from dulwich.merge_drivers import get_merge_driver_registry
 from dulwich.object_store import BaseObjectStore
 from dulwich.object_store import BaseObjectStore
-from dulwich.objects import S_ISGITLINK, Blob, Commit, Tree, is_blob, is_tree
+from dulwich.objects import S_ISGITLINK, Blob, Commit, ObjectID, Tree, is_blob, is_tree
 
 
 
 
 def make_merge3(
 def make_merge3(
@@ -303,7 +303,7 @@ class Merger:
             tuple of (merged_tree, list_of_conflicted_paths)
             tuple of (merged_tree, list_of_conflicted_paths)
         """
         """
         conflicts: list[bytes] = []
         conflicts: list[bytes] = []
-        merged_entries: dict[bytes, tuple[int | None, bytes | None]] = {}
+        merged_entries: dict[bytes, tuple[int | None, ObjectID | None]] = {}
 
 
         # Get all paths from all trees
         # Get all paths from all trees
         all_paths = set()
         all_paths = set()
@@ -481,7 +481,7 @@ class Merger:
 def _create_virtual_commit(
 def _create_virtual_commit(
     object_store: BaseObjectStore,
     object_store: BaseObjectStore,
     tree: Tree,
     tree: Tree,
-    parents: list[bytes],
+    parents: list[ObjectID],
     message: bytes = b"Virtual merge base",
     message: bytes = b"Virtual merge base",
 ) -> Commit:
 ) -> Commit:
     """Create a virtual commit object for recursive merging.
     """Create a virtual commit object for recursive merging.
@@ -519,7 +519,7 @@ def _create_virtual_commit(
 
 
 def recursive_merge(
 def recursive_merge(
     object_store: BaseObjectStore,
     object_store: BaseObjectStore,
-    merge_bases: list[bytes],
+    merge_bases: list[ObjectID],
     ours_commit: Commit,
     ours_commit: Commit,
     theirs_commit: Commit,
     theirs_commit: Commit,
     gitattributes: GitAttributes | None = None,
     gitattributes: GitAttributes | None = None,
@@ -671,7 +671,7 @@ def three_way_merge(
 
 
 def octopus_merge(
 def octopus_merge(
     object_store: BaseObjectStore,
     object_store: BaseObjectStore,
-    merge_bases: list[bytes],
+    merge_bases: list[ObjectID],
     head_commit: Commit,
     head_commit: Commit,
     other_commits: list[Commit],
     other_commits: list[Commit],
     gitattributes: GitAttributes | None = None,
     gitattributes: GitAttributes | None = None,
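
The merge entry points now take ``list[ObjectID]`` merge bases. A sketch of building such a list, assuming ``dulwich.graph.find_merge_base`` keeps its ``(repo, commit_ids)`` signature and that a ``feature`` branch exists::

    from dulwich.graph import find_merge_base
    from dulwich.refs import HEADREF, Ref
    from dulwich.repo import Repo

    repo = Repo(".")
    ours = repo.refs[HEADREF]
    theirs = repo.refs[Ref(b"refs/heads/feature")]
    bases = find_merge_base(repo, [ours, theirs])  # list[ObjectID]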

+ 15 - 12
dulwich/notes.py

@@ -24,7 +24,8 @@ import stat
 from collections.abc import Iterator, Sequence
 from collections.abc import Iterator, Sequence
 from typing import TYPE_CHECKING
 from typing import TYPE_CHECKING
 
 
-from .objects import Blob, Tree
+from .objects import Blob, ObjectID, Tree
+from .refs import Ref
 
 
 if TYPE_CHECKING:
 if TYPE_CHECKING:
     from .config import StackedConfig
     from .config import StackedConfig
@@ -240,7 +241,7 @@ class NotesTree:
 
 
             # Build new tree structure
             # Build new tree structure
             def update_tree(
             def update_tree(
-                tree: Tree, components: Sequence[bytes], blob_sha: bytes
+                tree: Tree, components: Sequence[bytes], blob_sha: ObjectID
             ) -> Tree:
             ) -> Tree:
                 """Update tree with new note entry.
                 """Update tree with new note entry.
 
 
@@ -302,7 +303,7 @@ class NotesTree:
             self._object_store.add_object(self._tree)
             self._object_store.add_object(self._tree)
 
 
     def _update_tree_entry(
     def _update_tree_entry(
-        self, tree: Tree, name: bytes, mode: int, sha: bytes
+        self, tree: Tree, name: bytes, mode: int, sha: ObjectID
     ) -> Tree:
     ) -> Tree:
         """Update a tree entry and return the updated tree.
         """Update a tree entry and return the updated tree.
 
 
@@ -333,7 +334,7 @@ class NotesTree:
 
 
         return new_tree
         return new_tree
 
 
-    def _get_note_sha(self, object_sha: bytes) -> bytes | None:
+    def _get_note_sha(self, object_sha: bytes) -> ObjectID | None:
         """Get the SHA of the note blob for an object.
         """Get the SHA of the note blob for an object.
 
 
         Args:
         Args:
@@ -412,7 +413,7 @@ class NotesTree:
 
 
         # Build new tree structure
         # Build new tree structure
         def update_tree(
         def update_tree(
-            tree: Tree, components: Sequence[bytes], blob_sha: bytes
+            tree: Tree, components: Sequence[bytes], blob_sha: ObjectID
         ) -> Tree:
         ) -> Tree:
             """Update tree with new note entry.
             """Update tree with new note entry.
 
 
@@ -546,14 +547,16 @@ class NotesTree:
         self._fanout_level = self._detect_fanout_level()
         self._fanout_level = self._detect_fanout_level()
         return new_tree
         return new_tree
 
 
-    def list_notes(self) -> Iterator[tuple[bytes, bytes]]:
+    def list_notes(self) -> Iterator[tuple[ObjectID, ObjectID]]:
         """List all notes in this tree.
         """List all notes in this tree.
 
 
         Yields:
         Yields:
             Tuples of (object_sha, note_sha)
             Tuples of (object_sha, note_sha)
         """
         """
 
 
-        def walk_tree(tree: Tree, prefix: bytes = b"") -> Iterator[tuple[bytes, bytes]]:
+        def walk_tree(
+            tree: Tree, prefix: bytes = b""
+        ) -> Iterator[tuple[ObjectID, ObjectID]]:
             """Walk the notes tree recursively.
             """Walk the notes tree recursively.
 
 
             Args:
             Args:
@@ -572,7 +575,7 @@ class NotesTree:
                 elif stat.S_ISREG(mode):  # File
                 elif stat.S_ISREG(mode):  # File
                     # Reconstruct the full hex SHA from the path
                     # Reconstruct the full hex SHA from the path
                     full_hex = prefix + name
                     full_hex = prefix + name
-                    yield (full_hex, sha)
+                    yield (ObjectID(full_hex), sha)
 
 
         yield from walk_tree(self._tree)
         yield from walk_tree(self._tree)
 
 
@@ -610,7 +613,7 @@ class Notes:
         self,
         self,
         notes_ref: bytes | None = None,
         notes_ref: bytes | None = None,
         config: "StackedConfig | None" = None,
         config: "StackedConfig | None" = None,
-    ) -> bytes:
+    ) -> Ref:
         """Get the notes reference to use.
         """Get the notes reference to use.
 
 
         Args:
         Args:
@@ -625,7 +628,7 @@ class Notes:
                 notes_ref = config.get((b"notes",), b"displayRef")
                 notes_ref = config.get((b"notes",), b"displayRef")
             if notes_ref is None:
             if notes_ref is None:
                 notes_ref = DEFAULT_NOTES_REF
                 notes_ref = DEFAULT_NOTES_REF
-        return notes_ref
+        return Ref(notes_ref)
 
 
     def get_note(
     def get_note(
         self,
         self,
@@ -838,7 +841,7 @@ class Notes:
         self,
         self,
         notes_ref: bytes | None = None,
         notes_ref: bytes | None = None,
         config: "StackedConfig | None" = None,
         config: "StackedConfig | None" = None,
-    ) -> list[tuple[bytes, bytes]]:
+    ) -> list[tuple[ObjectID, bytes]]:
         """List all notes in a notes ref.
         """List all notes in a notes ref.
 
 
         Args:
         Args:
@@ -870,7 +873,7 @@ class Notes:
             return []
             return []
 
 
         notes_tree_obj = NotesTree(notes_tree, self._object_store)
         notes_tree_obj = NotesTree(notes_tree, self._object_store)
-        result = []
+        result: list[tuple[ObjectID, bytes]] = []
         for object_sha, note_sha in notes_tree_obj.list_notes():
         for object_sha, note_sha in notes_tree_obj.list_notes():
             note_obj = self._object_store[note_sha]
             note_obj = self._object_store[note_sha]
             if isinstance(note_obj, Blob):
             if isinstance(note_obj, Blob):
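
``list_notes`` now returns typed tuples. A sketch, assuming the ``Notes`` constructor takes the object store and refs container and that notes exist under ``refs/notes/commits``::

    from dulwich.notes import Notes
    from dulwich.repo import Repo

    repo = Repo(".")
    notes = Notes(repo.object_store, repo.refs)  # assumed constructor
    for object_sha, note_text in notes.list_notes():
        # object_sha is an ObjectID; note_text is the note blob's payload.
        print(object_sha.decode(), note_text.decode(errors="replace"))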

+ 157 - 163
dulwich/object_store.py

@@ -37,6 +37,7 @@ from typing import (
     TYPE_CHECKING,
     TYPE_CHECKING,
     BinaryIO,
     BinaryIO,
     Protocol,
     Protocol,
+    cast,
 )
 )
 
 
 from .errors import NotTreeError
 from .errors import NotTreeError
@@ -47,6 +48,7 @@ from .objects import (
     Blob,
     Blob,
     Commit,
     Commit,
     ObjectID,
     ObjectID,
+    RawObjectID,
     ShaFile,
     ShaFile,
     Tag,
     Tag,
     Tree,
     Tree,
@@ -78,8 +80,8 @@ from .pack import (
     write_pack_data,
     write_pack_data,
     write_pack_index,
     write_pack_index,
 )
 )
-from .protocol import DEPTH_INFINITE
-from .refs import PEELED_TAG_SUFFIX, Ref
+from .protocol import DEPTH_INFINITE, PEELED_TAG_SUFFIX
+from .refs import Ref
 
 
 if TYPE_CHECKING:
 if TYPE_CHECKING:
     from .bitmap import EWAHBitmap
     from .bitmap import EWAHBitmap
@@ -92,11 +94,11 @@ if TYPE_CHECKING:
 class GraphWalker(Protocol):
 class GraphWalker(Protocol):
     """Protocol for graph walker objects."""
     """Protocol for graph walker objects."""
 
 
-    def __next__(self) -> bytes | None:
+    def __next__(self) -> ObjectID | None:
         """Return the next object SHA to visit."""
         """Return the next object SHA to visit."""
         ...
         ...
 
 
-    def ack(self, sha: bytes) -> None:
+    def ack(self, sha: ObjectID) -> None:
         """Acknowledge that an object has been received."""
         """Acknowledge that an object has been received."""
         ...
         ...
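
Since ``GraphWalker`` is a ``Protocol``, any object with these two methods conforms; a toy implementation for tests might look like::

    from dulwich.objects import ObjectID

    class ListGraphWalker:
        """Announce local heads one by one and record acked SHAs."""

        def __init__(self, heads: list[ObjectID]) -> None:
            self._heads = iter(heads)
            self.acked: set[ObjectID] = set()

        def __next__(self) -> ObjectID | None:
            return next(self._heads, None)

        def ack(self, sha: ObjectID) -> None:
            self.acked.add(sha)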
 
 
@@ -114,10 +116,10 @@ class ObjectReachabilityProvider(Protocol):
 
 
     def get_reachable_commits(
     def get_reachable_commits(
         self,
         self,
-        heads: Iterable[bytes],
-        exclude: Iterable[bytes] | None = None,
-        shallow: Set[bytes] | None = None,
-    ) -> set[bytes]:
+        heads: Iterable[ObjectID],
+        exclude: Iterable[ObjectID] | None = None,
+        shallow: Set[ObjectID] | None = None,
+    ) -> set[ObjectID]:
         """Get all commits reachable from heads, excluding those in exclude.
         """Get all commits reachable from heads, excluding those in exclude.
 
 
         Args:
         Args:
@@ -132,9 +134,9 @@ class ObjectReachabilityProvider(Protocol):
 
 
     def get_reachable_objects(
     def get_reachable_objects(
         self,
         self,
-        commits: Iterable[bytes],
-        exclude_commits: Iterable[bytes] | None = None,
-    ) -> set[bytes]:
+        commits: Iterable[ObjectID],
+        exclude_commits: Iterable[ObjectID] | None = None,
+    ) -> set[ObjectID]:
         """Get all objects (commits + trees + blobs) reachable from commits.
         """Get all objects (commits + trees + blobs) reachable from commits.
 
 
         Args:
         Args:
@@ -148,8 +150,8 @@ class ObjectReachabilityProvider(Protocol):
 
 
     def get_tree_objects(
     def get_tree_objects(
         self,
         self,
-        tree_shas: Iterable[bytes],
-    ) -> set[bytes]:
+        tree_shas: Iterable[ObjectID],
+    ) -> set[ObjectID]:
         """Get all trees and blobs reachable from the given trees.
         """Get all trees and blobs reachable from the given trees.
 
 
         Args:
         Args:
@@ -175,8 +177,8 @@ DEFAULT_TEMPFILE_GRACE_PERIOD = 14 * 24 * 60 * 60  # 2 weeks
 
 
 
 
 def find_shallow(
 def find_shallow(
-    store: ObjectContainer, heads: Iterable[bytes], depth: int
-) -> tuple[set[bytes], set[bytes]]:
+    store: ObjectContainer, heads: Iterable[ObjectID], depth: int
+) -> tuple[set[ObjectID], set[ObjectID]]:
     """Find shallow commits according to a given depth.
     """Find shallow commits according to a given depth.
 
 
     Args:
     Args:
@@ -188,10 +190,10 @@ def find_shallow(
         considered shallow and unshallow according to the arguments. Note that
         considered shallow and unshallow according to the arguments. Note that
         these sets may overlap if a commit is reachable along multiple paths.
         these sets may overlap if a commit is reachable along multiple paths.
     """
     """
-    parents: dict[bytes, list[bytes]] = {}
+    parents: dict[ObjectID, list[ObjectID]] = {}
     commit_graph = store.get_commit_graph()
     commit_graph = store.get_commit_graph()
 
 
-    def get_parents(sha: bytes) -> list[bytes]:
+    def get_parents(sha: ObjectID) -> list[ObjectID]:
         result = parents.get(sha, None)
         result = parents.get(sha, None)
         if not result:
         if not result:
             # Try to use commit graph first if available
             # Try to use commit graph first if available
@@ -234,8 +236,8 @@ def find_shallow(
 
 
 def get_depth(
 def get_depth(
     store: ObjectContainer,
     store: ObjectContainer,
-    head: bytes,
-    get_parents: Callable[..., list[bytes]] = lambda commit: commit.parents,
+    head: ObjectID,
+    get_parents: Callable[..., list[ObjectID]] = lambda commit: commit.parents,
     max_depth: int | None = None,
     max_depth: int | None = None,
 ) -> int:
 ) -> int:
     """Return the current available depth for the given head.
     """Return the current available depth for the given head.
@@ -291,7 +293,7 @@ class BaseObjectStore:
     ) -> list[ObjectID]:
     ) -> list[ObjectID]:
         """Determine which objects are wanted based on refs."""
         """Determine which objects are wanted based on refs."""
 
 
-        def _want_deepen(sha: bytes) -> bool:
+        def _want_deepen(sha: ObjectID) -> bool:
             if not depth:
             if not depth:
                 return False
                 return False
             if depth == DEPTH_INFINITE:
             if depth == DEPTH_INFINITE:
@@ -306,15 +308,15 @@ class BaseObjectStore:
             and not sha == ZERO_SHA
             and not sha == ZERO_SHA
         ]
         ]
 
 
-    def contains_loose(self, sha: bytes) -> bool:
+    def contains_loose(self, sha: ObjectID | RawObjectID) -> bool:
         """Check if a particular object is present by SHA1 and is loose."""
         """Check if a particular object is present by SHA1 and is loose."""
         raise NotImplementedError(self.contains_loose)
         raise NotImplementedError(self.contains_loose)
 
 
-    def contains_packed(self, sha: bytes) -> bool:
+    def contains_packed(self, sha: ObjectID | RawObjectID) -> bool:
         """Check if a particular object is present by SHA1 and is packed."""
         """Check if a particular object is present by SHA1 and is packed."""
         return False  # Default implementation for stores that don't support packing
         return False  # Default implementation for stores that don't support packing
 
 
-    def __contains__(self, sha1: bytes) -> bool:
+    def __contains__(self, sha1: ObjectID | RawObjectID) -> bool:
         """Check if a particular object is present by SHA1.
         """Check if a particular object is present by SHA1.
 
 
         This method makes no distinction between loose and packed objects.
         This method makes no distinction between loose and packed objects.
@@ -326,7 +328,7 @@ class BaseObjectStore:
         """Iterable of pack objects."""
         """Iterable of pack objects."""
         raise NotImplementedError
         raise NotImplementedError
 
 
-    def get_raw(self, name: bytes) -> tuple[int, bytes]:
+    def get_raw(self, name: RawObjectID | ObjectID) -> tuple[int, bytes]:
         """Obtain the raw text for an object.
         """Obtain the raw text for an object.
 
 
         Args:
         Args:
@@ -335,12 +337,12 @@ class BaseObjectStore:
         """
         """
         raise NotImplementedError(self.get_raw)
         raise NotImplementedError(self.get_raw)
 
 
-    def __getitem__(self, sha1: ObjectID) -> ShaFile:
+    def __getitem__(self, sha1: ObjectID | RawObjectID) -> ShaFile:
         """Obtain an object by SHA1."""
         """Obtain an object by SHA1."""
         type_num, uncomp = self.get_raw(sha1)
         type_num, uncomp = self.get_raw(sha1)
         return ShaFile.from_raw_string(type_num, uncomp, sha=sha1)
         return ShaFile.from_raw_string(type_num, uncomp, sha=sha1)
 
 
-    def __iter__(self) -> Iterator[bytes]:
+    def __iter__(self) -> Iterator[ObjectID]:
         """Iterate over the SHAs that are present in this store."""
         """Iterate over the SHAs that are present in this store."""
         raise NotImplementedError(self.__iter__)
         raise NotImplementedError(self.__iter__)
 
 
@@ -381,8 +383,8 @@ class BaseObjectStore:
 
 
     def tree_changes(
     def tree_changes(
         self,
         self,
-        source: bytes | None,
-        target: bytes | None,
+        source: ObjectID | None,
+        target: ObjectID | None,
         want_unchanged: bool = False,
         want_unchanged: bool = False,
         include_trees: bool = False,
         include_trees: bool = False,
         change_type_same: bool = False,
         change_type_same: bool = False,
@@ -392,7 +394,7 @@ class BaseObjectStore:
         tuple[
         tuple[
             tuple[bytes | None, bytes | None],
             tuple[bytes | None, bytes | None],
             tuple[int | None, int | None],
             tuple[int | None, int | None],
-            tuple[bytes | None, bytes | None],
+            tuple[ObjectID | None, ObjectID | None],
         ]
         ]
     ]:
     ]:
         """Find the differences between the contents of two trees.
         """Find the differences between the contents of two trees.
@@ -434,7 +436,7 @@ class BaseObjectStore:
             )
             )
 
 
     def iter_tree_contents(
     def iter_tree_contents(
-        self, tree_id: bytes, include_trees: bool = False
+        self, tree_id: ObjectID, include_trees: bool = False
     ) -> Iterator[TreeEntry]:
     ) -> Iterator[TreeEntry]:
         """Iterate the contents of a tree and all subtrees.
         """Iterate the contents of a tree and all subtrees.
 
 
@@ -454,7 +456,7 @@ class BaseObjectStore:
         return iter_tree_contents(self, tree_id, include_trees=include_trees)
         return iter_tree_contents(self, tree_id, include_trees=include_trees)
 
 
     def iterobjects_subset(
     def iterobjects_subset(
-        self, shas: Iterable[bytes], *, allow_missing: bool = False
+        self, shas: Iterable[ObjectID], *, allow_missing: bool = False
     ) -> Iterator[ShaFile]:
     ) -> Iterator[ShaFile]:
         """Iterate over a subset of objects in the store.
         """Iterate over a subset of objects in the store.
 
 
@@ -477,7 +479,7 @@ class BaseObjectStore:
 
 
     def iter_unpacked_subset(
     def iter_unpacked_subset(
         self,
         self,
-        shas: Iterable[bytes],
+        shas: Iterable[ObjectID | RawObjectID],
         include_comp: bool = False,
         include_comp: bool = False,
         allow_missing: bool = False,
         allow_missing: bool = False,
         convert_ofs_delta: bool = True,
         convert_ofs_delta: bool = True,
@@ -518,13 +520,13 @@ class BaseObjectStore:
 
 
     def find_missing_objects(
     def find_missing_objects(
         self,
         self,
-        haves: Iterable[bytes],
-        wants: Iterable[bytes],
-        shallow: Set[bytes] | None = None,
+        haves: Iterable[ObjectID],
+        wants: Iterable[ObjectID],
+        shallow: Set[ObjectID] | None = None,
         progress: Callable[..., None] | None = None,
         progress: Callable[..., None] | None = None,
-        get_tagged: Callable[[], dict[bytes, bytes]] | None = None,
-        get_parents: Callable[..., list[bytes]] = lambda commit: commit.parents,
-    ) -> Iterator[tuple[bytes, PackHint | None]]:
+        get_tagged: Callable[[], dict[ObjectID, ObjectID]] | None = None,
+        get_parents: Callable[..., list[ObjectID]] = lambda commit: commit.parents,
+    ) -> Iterator[tuple[ObjectID, PackHint | None]]:
         """Find the missing objects required for a set of revisions.
         """Find the missing objects required for a set of revisions.
 
 
         Args:
         Args:
@@ -551,7 +553,7 @@ class BaseObjectStore:
         )
         )
         return iter(finder)
         return iter(finder)
 
 
-    def find_common_revisions(self, graphwalker: GraphWalker) -> list[bytes]:
+    def find_common_revisions(self, graphwalker: GraphWalker) -> list[ObjectID]:
         """Find which revisions this store has in common using graphwalker.
         """Find which revisions this store has in common using graphwalker.
 
 
         Args:
         Args:
@@ -569,10 +571,10 @@ class BaseObjectStore:
 
 
     def generate_pack_data(
     def generate_pack_data(
         self,
         self,
-        have: Iterable[bytes],
-        want: Iterable[bytes],
+        have: Iterable[ObjectID],
+        want: Iterable[ObjectID],
         *,
         *,
-        shallow: Set[bytes] | None = None,
+        shallow: Set[ObjectID] | None = None,
         progress: Callable[..., None] | None = None,
         progress: Callable[..., None] | None = None,
         ofs_delta: bool = True,
         ofs_delta: bool = True,
     ) -> tuple[int, Iterator[UnpackedObject]]:
     ) -> tuple[int, Iterator[UnpackedObject]]:
@@ -597,7 +599,7 @@ class BaseObjectStore:
             progress=progress,
             progress=progress,
         )
         )
 
 
-    def peel_sha(self, sha: bytes) -> bytes:
+    def peel_sha(self, sha: ObjectID | RawObjectID) -> ObjectID:
         """Peel all tags from a SHA.
         """Peel all tags from a SHA.
 
 
         Args:
         Args:
@@ -615,8 +617,8 @@ class BaseObjectStore:
 
 
     def _get_depth(
     def _get_depth(
         self,
         self,
-        head: bytes,
-        get_parents: Callable[..., list[bytes]] = lambda commit: commit.parents,
+        head: ObjectID,
+        get_parents: Callable[..., list[ObjectID]] = lambda commit: commit.parents,
         max_depth: int | None = None,
         max_depth: int | None = None,
     ) -> int:
     ) -> int:
         """Return the current available depth for the given head.
         """Return the current available depth for the given head.
@@ -667,7 +669,7 @@ class BaseObjectStore:
         return None
         return None
 
 
     def write_commit_graph(
     def write_commit_graph(
-        self, refs: Sequence[bytes] | None = None, reachable: bool = True
+        self, refs: Iterable[ObjectID] | None = None, reachable: bool = True
     ) -> None:
     ) -> None:
         """Write a commit graph file for this object store.
         """Write a commit graph file for this object store.
 
 
@@ -682,7 +684,7 @@ class BaseObjectStore:
         """
         """
         raise NotImplementedError(self.write_commit_graph)
         raise NotImplementedError(self.write_commit_graph)
 
 
-    def get_object_mtime(self, sha: bytes) -> float:
+    def get_object_mtime(self, sha: ObjectID) -> float:
         """Get the modification time of an object.
         """Get the modification time of an object.
 
 
         Args:
         Args:
@@ -729,7 +731,7 @@ class PackCapableObjectStore(BaseObjectStore, PackedObjectContainer):
         raise NotImplementedError(self.add_pack_data)
         raise NotImplementedError(self.add_pack_data)
 
 
     def get_unpacked_object(
     def get_unpacked_object(
-        self, sha1: bytes, *, include_comp: bool = False
+        self, sha1: ObjectID | RawObjectID, *, include_comp: bool = False
     ) -> "UnpackedObject":
     ) -> "UnpackedObject":
         """Get a raw unresolved object.
         """Get a raw unresolved object.
 
 
@@ -746,7 +748,7 @@ class PackCapableObjectStore(BaseObjectStore, PackedObjectContainer):
         return UnpackedObject(obj.type_num, sha=sha1, decomp_chunks=obj.as_raw_chunks())
         return UnpackedObject(obj.type_num, sha=sha1, decomp_chunks=obj.as_raw_chunks())
 
 
     def iterobjects_subset(
     def iterobjects_subset(
-        self, shas: Iterable[bytes], *, allow_missing: bool = False
+        self, shas: Iterable[ObjectID], *, allow_missing: bool = False
     ) -> Iterator[ShaFile]:
     ) -> Iterator[ShaFile]:
         """Iterate over a subset of objects.
         """Iterate over a subset of objects.
 
 
@@ -878,7 +880,7 @@ class PackBasedObjectStore(PackCapableObjectStore, PackedObjectContainer):
         """Return list of alternate object stores."""
         """Return list of alternate object stores."""
         return []
         return []
 
 
-    def contains_packed(self, sha: bytes) -> bool:
+    def contains_packed(self, sha: ObjectID | RawObjectID) -> bool:
         """Check if a particular object is present by SHA1 and is packed.
         """Check if a particular object is present by SHA1 and is packed.
 
 
         This does not check alternates.
         This does not check alternates.
@@ -891,7 +893,7 @@ class PackBasedObjectStore(PackCapableObjectStore, PackedObjectContainer):
                 pass
                 pass
         return False
         return False
 
 
-    def __contains__(self, sha: bytes) -> bool:
+    def __contains__(self, sha: ObjectID | RawObjectID) -> bool:
         """Check if a particular object is present by SHA1.
         """Check if a particular object is present by SHA1.
 
 
         This method makes no distinction between loose and packed objects.
         This method makes no distinction between loose and packed objects.
@@ -913,10 +915,10 @@ class PackBasedObjectStore(PackCapableObjectStore, PackedObjectContainer):
 
 
     def generate_pack_data(
     def generate_pack_data(
         self,
         self,
-        have: Iterable[bytes],
-        want: Iterable[bytes],
+        have: Iterable[ObjectID],
+        want: Iterable[ObjectID],
         *,
         *,
-        shallow: Set[bytes] | None = None,
+        shallow: Set[ObjectID] | None = None,
         progress: Callable[..., None] | None = None,
         progress: Callable[..., None] | None = None,
         ofs_delta: bool = True,
         ofs_delta: bool = True,
     ) -> tuple[int, Iterator[UnpackedObject]]:
     ) -> tuple[int, Iterator[UnpackedObject]]:
@@ -981,19 +983,19 @@ class PackBasedObjectStore(PackCapableObjectStore, PackedObjectContainer):
                 count += 1
                 count += 1
         return count
         return count
 
 
-    def _iter_alternate_objects(self) -> Iterator[bytes]:
+    def _iter_alternate_objects(self) -> Iterator[ObjectID]:
         """Iterate over the SHAs of all the objects in alternate stores."""
         """Iterate over the SHAs of all the objects in alternate stores."""
         for alternate in self.alternates:
         for alternate in self.alternates:
             yield from alternate
             yield from alternate
 
 
-    def _iter_loose_objects(self) -> Iterator[bytes]:
+    def _iter_loose_objects(self) -> Iterator[ObjectID]:
         """Iterate over the SHAs of all loose objects."""
         """Iterate over the SHAs of all loose objects."""
         raise NotImplementedError(self._iter_loose_objects)
         raise NotImplementedError(self._iter_loose_objects)
 
 
-    def _get_loose_object(self, sha: bytes) -> ShaFile | None:
+    def _get_loose_object(self, sha: ObjectID | RawObjectID) -> ShaFile | None:
         raise NotImplementedError(self._get_loose_object)
         raise NotImplementedError(self._get_loose_object)
 
 
-    def delete_loose_object(self, sha: bytes) -> None:
+    def delete_loose_object(self, sha: ObjectID) -> None:
         """Delete a loose object.
         """Delete a loose object.
 
 
         This method only handles loose objects. For packed objects,
         This method only handles loose objects. For packed objects,
@@ -1079,7 +1081,7 @@ class PackBasedObjectStore(PackCapableObjectStore, PackedObjectContainer):
 
 
     def generate_pack_bitmaps(
     def generate_pack_bitmaps(
         self,
         self,
-        refs: dict[bytes, bytes],
+        refs: dict[Ref, ObjectID],
         *,
         *,
         commit_interval: int | None = None,
         commit_interval: int | None = None,
         progress: Callable[[str], None] | None = None,
         progress: Callable[[str], None] | None = None,
@@ -1109,7 +1111,7 @@ class PackBasedObjectStore(PackCapableObjectStore, PackedObjectContainer):
 
 
         return count
         return count
 
 
-    def __iter__(self) -> Iterator[bytes]:
+    def __iter__(self) -> Iterator[ObjectID]:
         """Iterate over the SHAs that are present in this store."""
         """Iterate over the SHAs that are present in this store."""
         self._update_pack_cache()
         self._update_pack_cache()
         for pack in self._iter_cached_packs():
         for pack in self._iter_cached_packs():
@@ -1120,14 +1122,14 @@ class PackBasedObjectStore(PackCapableObjectStore, PackedObjectContainer):
         yield from self._iter_loose_objects()
         yield from self._iter_loose_objects()
         yield from self._iter_alternate_objects()
         yield from self._iter_alternate_objects()
 
 
-    def contains_loose(self, sha: bytes) -> bool:
+    def contains_loose(self, sha: ObjectID | RawObjectID) -> bool:
         """Check if a particular object is present by SHA1 and is loose.
         """Check if a particular object is present by SHA1 and is loose.
 
 
         This does not check alternates.
         This does not check alternates.
         """
         """
         return self._get_loose_object(sha) is not None
         return self._get_loose_object(sha) is not None
 
 
-    def get_raw(self, name: bytes) -> tuple[int, bytes]:
+    def get_raw(self, name: RawObjectID | ObjectID) -> tuple[int, bytes]:
         """Obtain the raw fulltext for an object.
         """Obtain the raw fulltext for an object.
 
 
         Args:
         Args:
@@ -1137,10 +1139,10 @@ class PackBasedObjectStore(PackCapableObjectStore, PackedObjectContainer):
         if name == ZERO_SHA:
         if name == ZERO_SHA:
             raise KeyError(name)
             raise KeyError(name)
         if len(name) == 40:
         if len(name) == 40:
-            sha = hex_to_sha(name)
-            hexsha = name
+            sha = hex_to_sha(cast(ObjectID, name))
+            hexsha = cast(ObjectID, name)
         elif len(name) == 20:
         elif len(name) == 20:
-            sha = name
+            sha = cast(RawObjectID, name)
             hexsha = None
             hexsha = None
         else:
         else:
             raise AssertionError(f"Invalid object name {name!r}")
             raise AssertionError(f"Invalid object name {name!r}")
@@ -1150,7 +1152,7 @@ class PackBasedObjectStore(PackCapableObjectStore, PackedObjectContainer):
             except (KeyError, PackFileDisappeared):
             except (KeyError, PackFileDisappeared):
                 pass
                 pass
         if hexsha is None:
         if hexsha is None:
-            hexsha = sha_to_hex(name)
+            hexsha = sha_to_hex(sha)
         ret = self._get_loose_object(hexsha)
         ret = self._get_loose_object(hexsha)
         if ret is not None:
         if ret is not None:
             return ret.type_num, ret.as_raw_string()
             return ret.type_num, ret.as_raw_string()
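
``get_raw`` keeps accepting both forms, now spelled ``RawObjectID | ObjectID``; the ``cast`` calls above only guide mypy through the length check. Illustrated with the in-memory store for brevity::

    from dulwich.object_store import MemoryObjectStore
    from dulwich.objects import Blob, hex_to_sha

    store = MemoryObjectStore()
    blob = Blob.from_string(b"data")
    store.add_object(blob)
    t1, raw1 = store.get_raw(blob.id)              # 40-byte hex ObjectID
    t2, raw2 = store.get_raw(hex_to_sha(blob.id))  # 20-byte RawObjectID
    assert raw1 == raw2 == b"data"
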
@@ -1170,7 +1172,7 @@ class PackBasedObjectStore(PackCapableObjectStore, PackedObjectContainer):
 
 
     def iter_unpacked_subset(
     def iter_unpacked_subset(
         self,
         self,
-        shas: Iterable[bytes],
+        shas: Iterable[ObjectID | RawObjectID],
         include_comp: bool = False,
         include_comp: bool = False,
         allow_missing: bool = False,
         allow_missing: bool = False,
         convert_ofs_delta: bool = True,
         convert_ofs_delta: bool = True,
@@ -1189,7 +1191,7 @@ class PackBasedObjectStore(PackCapableObjectStore, PackedObjectContainer):
         Raises:
         Raises:
           KeyError: If an object is missing and allow_missing is False
           KeyError: If an object is missing and allow_missing is False
         """
         """
-        todo: set[bytes] = set(shas)
+        todo: set[ObjectID | RawObjectID] = set(shas)
         for p in self._iter_cached_packs():
         for p in self._iter_cached_packs():
             for unpacked in p.iter_unpacked_subset(
             for unpacked in p.iter_unpacked_subset(
                 todo,
                 todo,
@@ -1225,7 +1227,7 @@ class PackBasedObjectStore(PackCapableObjectStore, PackedObjectContainer):
                 todo.remove(hexsha)
                 todo.remove(hexsha)
 
 
     def iterobjects_subset(
     def iterobjects_subset(
-        self, shas: Iterable[bytes], *, allow_missing: bool = False
+        self, shas: Iterable[ObjectID], *, allow_missing: bool = False
     ) -> Iterator[ShaFile]:
     ) -> Iterator[ShaFile]:
         """Iterate over a subset of objects in the store.
         """Iterate over a subset of objects in the store.
 
 
@@ -1241,7 +1243,7 @@ class PackBasedObjectStore(PackCapableObjectStore, PackedObjectContainer):
         Raises:
         Raises:
           KeyError: If an object is missing and allow_missing is False
           KeyError: If an object is missing and allow_missing is False
         """
         """
-        todo: set[bytes] = set(shas)
+        todo: set[ObjectID] = set(shas)
         for p in self._iter_cached_packs():
         for p in self._iter_cached_packs():
             for o in p.iterobjects_subset(todo, allow_missing=True):
             for o in p.iterobjects_subset(todo, allow_missing=True):
                 yield o
                 yield o
@@ -1275,10 +1277,10 @@ class PackBasedObjectStore(PackCapableObjectStore, PackedObjectContainer):
         if sha1 == ZERO_SHA:
         if sha1 == ZERO_SHA:
             raise KeyError(sha1)
             raise KeyError(sha1)
         if len(sha1) == 40:
         if len(sha1) == 40:
-            sha = hex_to_sha(sha1)
-            hexsha = sha1
+            sha = hex_to_sha(cast(ObjectID, sha1))
+            hexsha = cast(ObjectID, sha1)
         elif len(sha1) == 20:
         elif len(sha1) == 20:
-            sha = sha1
+            sha = cast(RawObjectID, sha1)
             hexsha = None
             hexsha = None
         else:
         else:
             raise AssertionError(f"Invalid object sha1 {sha1!r}")
             raise AssertionError(f"Invalid object sha1 {sha1!r}")
@@ -1288,7 +1290,7 @@ class PackBasedObjectStore(PackCapableObjectStore, PackedObjectContainer):
             except (KeyError, PackFileDisappeared):
             except (KeyError, PackFileDisappeared):
                 pass
                 pass
         if hexsha is None:
         if hexsha is None:
-            hexsha = sha_to_hex(sha1)
+            hexsha = sha_to_hex(sha)
         # Maybe something else has added a pack with the object
         # Maybe something else has added a pack with the object
         # in the mean time?
         # in the mean time?
         for pack in self._update_pack_cache():
         for pack in self._update_pack_cache():
@@ -1614,11 +1616,11 @@ class DiskObjectStore(PackBasedObjectStore):
             self._pack_cache.pop(f).close()
             self._pack_cache.pop(f).close()
         return new_packs
         return new_packs
 
 
-    def _get_shafile_path(self, sha: bytes) -> str:
+    def _get_shafile_path(self, sha: ObjectID | RawObjectID) -> str:
         # Check from object dir
         # Check from object dir
         return hex_to_filename(os.fspath(self.path), sha)
         return hex_to_filename(os.fspath(self.path), sha)
 
 
-    def _iter_loose_objects(self) -> Iterator[bytes]:
+    def _iter_loose_objects(self) -> Iterator[ObjectID]:
         for base in os.listdir(self.path):
         for base in os.listdir(self.path):
             if len(base) != 2:
             if len(base) != 2:
                 continue
                 continue
@@ -1626,7 +1628,7 @@ class DiskObjectStore(PackBasedObjectStore):
                 sha = os.fsencode(base + rest)
                 sha = os.fsencode(base + rest)
                 if not valid_hexsha(sha):
                 if not valid_hexsha(sha):
                     continue
                     continue
-                yield sha
+                yield ObjectID(sha)
 
 
     def count_loose_objects(self) -> int:
     def count_loose_objects(self) -> int:
         """Count the number of loose objects in the object store.
         """Count the number of loose objects in the object store.
@@ -1654,14 +1656,14 @@ class DiskObjectStore(PackBasedObjectStore):
 
 
         return count
         return count
 
 
-    def _get_loose_object(self, sha: bytes) -> ShaFile | None:
+    def _get_loose_object(self, sha: ObjectID | RawObjectID) -> ShaFile | None:
         path = self._get_shafile_path(sha)
         path = self._get_shafile_path(sha)
         try:
         try:
             return ShaFile.from_path(path)
             return ShaFile.from_path(path)
         except FileNotFoundError:
         except FileNotFoundError:
             return None
             return None
 
 
-    def delete_loose_object(self, sha: bytes) -> None:
+    def delete_loose_object(self, sha: ObjectID) -> None:
         """Delete a loose object from disk.
         """Delete a loose object from disk.
 
 
         Args:
         Args:
@@ -1672,7 +1674,7 @@ class DiskObjectStore(PackBasedObjectStore):
         """
         """
         os.remove(self._get_shafile_path(sha))
         os.remove(self._get_shafile_path(sha))
 
 
-    def get_object_mtime(self, sha: bytes) -> float:
+    def get_object_mtime(self, sha: ObjectID) -> float:
         """Get the modification time of an object.
         """Get the modification time of an object.
 
 
         Args:
         Args:
@@ -1732,7 +1734,7 @@ class DiskObjectStore(PackBasedObjectStore):
         num_objects: int,
         num_objects: int,
         indexer: PackIndexer,
         indexer: PackIndexer,
         progress: Callable[..., None] | None = None,
         progress: Callable[..., None] | None = None,
-        refs: dict[bytes, bytes] | None = None,
+        refs: dict[Ref, ObjectID] | None = None,
     ) -> Pack:
     ) -> Pack:
         """Move a specific file containing a pack into the pack directory.
         """Move a specific file containing a pack into the pack directory.
 
 
@@ -1974,14 +1976,14 @@ class DiskObjectStore(PackBasedObjectStore):
             os.chmod(pack_path, dir_mode)
             os.chmod(pack_path, dir_mode)
         return cls(path, file_mode=file_mode, dir_mode=dir_mode)
         return cls(path, file_mode=file_mode, dir_mode=dir_mode)
 
 
-    def iter_prefix(self, prefix: bytes) -> Iterator[bytes]:
+    def iter_prefix(self, prefix: bytes) -> Iterator[ObjectID]:
         """Iterate over all object SHAs with the given prefix.
         """Iterate over all object SHAs with the given prefix.
 
 
         Args:
         Args:
           prefix: Hex prefix to search for (as bytes)
           prefix: Hex prefix to search for (as bytes)
 
 
         Returns:
         Returns:
-          Iterator of object SHAs (as bytes) matching the prefix
+          Iterator of object SHAs (as ObjectID) matching the prefix
         """
         """
         if len(prefix) < 2:
         if len(prefix) < 2:
             yield from super().iter_prefix(prefix)
             yield from super().iter_prefix(prefix)
@@ -1992,7 +1994,7 @@ class DiskObjectStore(PackBasedObjectStore):
         try:
         try:
             for name in os.listdir(os.path.join(self.path, dir)):
             for name in os.listdir(os.path.join(self.path, dir)):
                 if name.startswith(rest):
                 if name.startswith(rest):
-                    sha = os.fsencode(dir + name)
+                    sha = ObjectID(os.fsencode(dir + name))
                     if sha not in seen:
                     if sha not in seen:
                         seen.add(sha)
                         seen.add(sha)
                         yield sha
                         yield sha
@@ -2005,8 +2007,8 @@ class DiskObjectStore(PackBasedObjectStore):
                 if len(prefix) % 2 == 0
                 if len(prefix) % 2 == 0
                 else binascii.unhexlify(prefix[:-1])
                 else binascii.unhexlify(prefix[:-1])
             )
             )
-            for sha in p.index.iter_prefix(bin_prefix):
-                sha = sha_to_hex(sha)
+            for bin_sha in p.index.iter_prefix(bin_prefix):
+                sha = sha_to_hex(bin_sha)
                 if sha.startswith(prefix) and sha not in seen:
                 if sha.startswith(prefix) and sha not in seen:
                     seen.add(sha)
                     seen.add(sha)
                     yield sha
                     yield sha
@@ -2035,7 +2037,7 @@ class DiskObjectStore(PackBasedObjectStore):
         return self._commit_graph
         return self._commit_graph
 
 
     def write_commit_graph(
     def write_commit_graph(
-        self, refs: Iterable[bytes] | None = None, reachable: bool = True
+        self, refs: Iterable[ObjectID] | None = None, reachable: bool = True
     ) -> None:
     ) -> None:
         """Write a commit graph file for this object store.
         """Write a commit graph file for this object store.
 
 
@@ -2068,20 +2070,8 @@ class DiskObjectStore(PackBasedObjectStore):
             # Get all reachable commits
             # Get all reachable commits
             commit_ids = get_reachable_commits(self, all_refs)
             commit_ids = get_reachable_commits(self, all_refs)
         else:
         else:
-            # Just use the direct ref targets - ensure they're hex ObjectIDs
-            commit_ids = []
-            for ref in all_refs:
-                if isinstance(ref, bytes) and len(ref) == 40:
-                    # Already hex ObjectID
-                    commit_ids.append(ref)
-                elif isinstance(ref, bytes) and len(ref) == 20:
-                    # Binary SHA, convert to hex ObjectID
-                    from .objects import sha_to_hex
-
-                    commit_ids.append(sha_to_hex(ref))
-                else:
-                    # Assume it's already correct format
-                    commit_ids.append(ref)
+            # Just use the direct ref targets (already ObjectIDs)
+            commit_ids = all_refs
 
 
         if commit_ids:
         if commit_ids:
             # Write commit graph directly to our object store path
             # Write commit graph directly to our object store path
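
Because all_refs is now guaranteed to hold hex ObjectIDs, the defensive 40-vs-20-byte re-normalization loop could simply be dropped. A usage sketch for the surrounding method (assumes an on-disk repository; the ref is illustrative):

    from dulwich.repo import Repo

    repo = Repo(".")                 # any existing repository
    head = repo.refs[b"HEAD"]        # resolves to a hex ObjectID
    # reachable=True walks ancestry from the refs; False uses them as-is.
    repo.object_store.write_commit_graph([head], reachable=True)
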
@@ -2169,26 +2159,26 @@ class MemoryObjectStore(PackCapableObjectStore):
         Creates an empty in-memory object store.
         Creates an empty in-memory object store.
         """
         """
         super().__init__()
         super().__init__()
-        self._data: dict[bytes, ShaFile] = {}
+        self._data: dict[ObjectID, ShaFile] = {}
         self.pack_compression_level = -1
         self.pack_compression_level = -1
 
 
-    def _to_hexsha(self, sha: bytes) -> bytes:
+    def _to_hexsha(self, sha: ObjectID | RawObjectID) -> ObjectID:
         if len(sha) == 40:
         if len(sha) == 40:
-            return sha
+            return cast(ObjectID, sha)
         elif len(sha) == 20:
         elif len(sha) == 20:
-            return sha_to_hex(sha)
+            return sha_to_hex(cast(RawObjectID, sha))
         else:
         else:
             raise ValueError(f"Invalid sha {sha!r}")
             raise ValueError(f"Invalid sha {sha!r}")
 
 
-    def contains_loose(self, sha: bytes) -> bool:
+    def contains_loose(self, sha: ObjectID | RawObjectID) -> bool:
         """Check if a particular object is present by SHA1 and is loose."""
         """Check if a particular object is present by SHA1 and is loose."""
         return self._to_hexsha(sha) in self._data
         return self._to_hexsha(sha) in self._data
 
 
-    def contains_packed(self, sha: bytes) -> bool:
+    def contains_packed(self, sha: ObjectID | RawObjectID) -> bool:
         """Check if a particular object is present by SHA1 and is packed."""
         """Check if a particular object is present by SHA1 and is packed."""
         return False
         return False
 
 
-    def __iter__(self) -> Iterator[bytes]:
+    def __iter__(self) -> Iterator[ObjectID]:
         """Iterate over the SHAs that are present in this store."""
         """Iterate over the SHAs that are present in this store."""
         return iter(self._data.keys())
         return iter(self._data.keys())
 
 
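
_to_hexsha() is now the single normalization point for MemoryObjectStore: callers may pass either newtype, and lookups are folded to the hex form used as the dict key. A small sketch of both lookup styles:

    from dulwich.object_store import MemoryObjectStore
    from dulwich.objects import Blob, hex_to_sha

    store = MemoryObjectStore()
    blob = Blob.from_string(b"hello")
    store.add_object(blob)
    # The 40-byte hex ObjectID and the 20-byte RawObjectID both resolve:
    assert store[blob.id].as_raw_string() == b"hello"
    assert store[hex_to_sha(blob.id)].as_raw_string() == b"hello"
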
@@ -2197,7 +2187,7 @@ class MemoryObjectStore(PackCapableObjectStore):
         """List with pack objects."""
         """List with pack objects."""
         return []
         return []
 
 
-    def get_raw(self, name: ObjectID) -> tuple[int, bytes]:
+    def get_raw(self, name: RawObjectID | ObjectID) -> tuple[int, bytes]:
         """Obtain the raw text for an object.
         """Obtain the raw text for an object.
 
 
         Args:
         Args:
@@ -2207,7 +2197,7 @@ class MemoryObjectStore(PackCapableObjectStore):
         obj = self[self._to_hexsha(name)]
         obj = self[self._to_hexsha(name)]
         return obj.type_num, obj.as_raw_string()
         return obj.type_num, obj.as_raw_string()
 
 
-    def __getitem__(self, name: ObjectID) -> ShaFile:
+    def __getitem__(self, name: ObjectID | RawObjectID) -> ShaFile:
         """Retrieve an object by SHA.
         """Retrieve an object by SHA.
 
 
         Args:
         Args:
@@ -2350,8 +2340,10 @@ class ObjectIterator(Protocol):
 
 
 
 
 def tree_lookup_path(
 def tree_lookup_path(
-    lookup_obj: Callable[[bytes], ShaFile], root_sha: bytes, path: bytes
-) -> tuple[int, bytes]:
+    lookup_obj: Callable[[ObjectID | RawObjectID], ShaFile],
+    root_sha: ObjectID | RawObjectID,
+    path: bytes,
+) -> tuple[int, ObjectID]:
     """Look up an object in a Git tree.
     """Look up an object in a Git tree.
 
 
     Args:
     Args:
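
tree_lookup_path() keeps its shape but now advertises that the SHA it returns is an ObjectID. A usage sketch (assumes a repository whose root tree contains a README):

    from dulwich.object_store import tree_lookup_path
    from dulwich.repo import Repo

    repo = Repo(".")
    commit = repo[repo.head()]
    # lookup_obj is any SHA -> ShaFile callable; the returned sha is
    # typed as ObjectID rather than bare bytes.
    mode, sha = tree_lookup_path(
        repo.object_store.__getitem__, commit.tree, b"README"
    )
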
@@ -2388,8 +2380,8 @@ def _collect_filetree_revs(
 
 
 
 
 def _split_commits_and_tags(
 def _split_commits_and_tags(
-    obj_store: ObjectContainer, lst: Iterable[bytes], *, ignore_unknown: bool = False
-) -> tuple[set[bytes], set[bytes], set[bytes]]:
+    obj_store: ObjectContainer, lst: Iterable[ObjectID], *, ignore_unknown: bool = False
+) -> tuple[set[ObjectID], set[ObjectID], set[ObjectID]]:
     """Split object id list into three lists with commit, tag, and other SHAs.
     """Split object id list into three lists with commit, tag, and other SHAs.
 
 
     Commits referenced by tags are included into commits
     Commits referenced by tags are included into commits
@@ -2404,9 +2396,9 @@ def _split_commits_and_tags(
         silently.
         silently.
     Returns: A tuple of (commits, tags, others) SHA1s
     Returns: A tuple of (commits, tags, others) SHA1s
     """
     """
-    commits: set[bytes] = set()
-    tags: set[bytes] = set()
-    others: set[bytes] = set()
+    commits: set[ObjectID] = set()
+    tags: set[ObjectID] = set()
+    others: set[ObjectID] = set()
     for e in lst:
     for e in lst:
         try:
         try:
             o = obj_store[e]
             o = obj_store[e]
@@ -2447,13 +2439,13 @@ class MissingObjectFinder:
     def __init__(
     def __init__(
         self,
         self,
         object_store: BaseObjectStore,
         object_store: BaseObjectStore,
-        haves: Iterable[bytes],
-        wants: Iterable[bytes],
+        haves: Iterable[ObjectID],
+        wants: Iterable[ObjectID],
         *,
         *,
-        shallow: Set[bytes] | None = None,
+        shallow: Set[ObjectID] | None = None,
         progress: Callable[[bytes], None] | None = None,
         progress: Callable[[bytes], None] | None = None,
-        get_tagged: Callable[[], dict[bytes, bytes]] | None = None,
-        get_parents: Callable[[Commit], list[bytes]] = lambda commit: commit.parents,
+        get_tagged: Callable[[], dict[ObjectID, ObjectID]] | None = None,
+        get_parents: Callable[[Commit], list[ObjectID]] = lambda commit: commit.parents,
     ) -> None:
     ) -> None:
         """Initialize a MissingObjectFinder.
         """Initialize a MissingObjectFinder.
 
 
@@ -2501,7 +2493,7 @@ class MissingObjectFinder:
             get_parents=self._get_parents,
             get_parents=self._get_parents,
         )
         )
 
 
-        self.remote_has: set[bytes] = set()
+        self.remote_has: set[ObjectID] = set()
         # Now, fill sha_done with commits and revisions of
         # Now, fill sha_done with commits and revisions of
         # files and directories known to be both locally
         # files and directories known to be both locally
         # and on target. Thus these commits and files
         # and on target. Thus these commits and files
@@ -2537,7 +2529,7 @@ class MissingObjectFinder:
             self.progress = progress
             self.progress = progress
         self._tagged = (get_tagged and get_tagged()) or {}
         self._tagged = (get_tagged and get_tagged()) or {}
 
 
-    def get_remote_has(self) -> set[bytes]:
+    def get_remote_has(self) -> set[ObjectID]:
         """Get the set of SHAs the remote has.
         """Get the set of SHAs the remote has.
 
 
         Returns:
         Returns:
@@ -2555,7 +2547,7 @@ class MissingObjectFinder:
         """
         """
         self.objects_to_send.update([e for e in entries if e[0] not in self.sha_done])
         self.objects_to_send.update([e for e in entries if e[0] not in self.sha_done])
 
 
-    def __next__(self) -> tuple[bytes, PackHint | None]:
+    def __next__(self) -> tuple[ObjectID, PackHint | None]:
         """Get the next object to send.
         """Get the next object to send.
 
 
         Returns:
         Returns:
@@ -2606,7 +2598,7 @@ class MissingObjectFinder:
             pack_hint = (type_num, name)
             pack_hint = (type_num, name)
         return (sha, pack_hint)
         return (sha, pack_hint)
 
 
-    def __iter__(self) -> Iterator[tuple[bytes, PackHint | None]]:
+    def __iter__(self) -> Iterator[tuple[ObjectID, PackHint | None]]:
         """Return iterator over objects to send.
         """Return iterator over objects to send.
 
 
         Returns:
         Returns:
@@ -2698,7 +2690,7 @@ class ObjectStoreGraphWalker:
 def commit_tree_changes(
 def commit_tree_changes(
     object_store: BaseObjectStore,
     object_store: BaseObjectStore,
     tree: ObjectID | Tree,
     tree: ObjectID | Tree,
-    changes: Sequence[tuple[bytes, int | None, bytes | None]],
+    changes: Sequence[tuple[bytes, int | None, ObjectID | None]],
 ) -> ObjectID:
 ) -> ObjectID:
     """Commit a specified set of changes to a tree structure.
     """Commit a specified set of changes to a tree structure.
 
 
@@ -2729,7 +2721,7 @@ def commit_tree_changes(
         sha_obj = object_store[tree]
         sha_obj = object_store[tree]
         assert isinstance(sha_obj, Tree)
         assert isinstance(sha_obj, Tree)
         tree_obj = sha_obj
         tree_obj = sha_obj
-    nested_changes: dict[bytes, list[tuple[bytes, int | None, bytes | None]]] = {}
+    nested_changes: dict[bytes, list[tuple[bytes, int | None, ObjectID | None]]] = {}
     for path, new_mode, new_sha in changes:
     for path, new_mode, new_sha in changes:
         try:
         try:
             (dirname, subpath) = path.split(b"/", 1)
             (dirname, subpath) = path.split(b"/", 1)
@@ -2743,7 +2735,7 @@ def commit_tree_changes(
             nested_changes.setdefault(dirname, []).append((subpath, new_mode, new_sha))
             nested_changes.setdefault(dirname, []).append((subpath, new_mode, new_sha))
     for name, subchanges in nested_changes.items():
     for name, subchanges in nested_changes.items():
         try:
         try:
-            orig_subtree_id: bytes | Tree = tree_obj[name][1]
+            orig_subtree_id: ObjectID | Tree = tree_obj[name][1]
         except KeyError:
         except KeyError:
             # For new directories, pass an empty Tree object
             # For new directories, pass an empty Tree object
             orig_subtree_id = Tree()
             orig_subtree_id = Tree()
@@ -2832,7 +2824,7 @@ class OverlayObjectStore(BaseObjectStore):
                     done.add(o_id)
                     done.add(o_id)
 
 
     def iterobjects_subset(
     def iterobjects_subset(
-        self, shas: Iterable[bytes], *, allow_missing: bool = False
+        self, shas: Iterable[ObjectID], *, allow_missing: bool = False
     ) -> Iterator[ShaFile]:
     ) -> Iterator[ShaFile]:
         """Iterate over a subset of objects from the overlaid stores.
         """Iterate over a subset of objects from the overlaid stores.
 
 
@@ -2847,7 +2839,7 @@ class OverlayObjectStore(BaseObjectStore):
           KeyError: If an object is missing and allow_missing is False
           KeyError: If an object is missing and allow_missing is False
         """
         """
         todo = set(shas)
         todo = set(shas)
-        found: set[bytes] = set()
+        found: set[ObjectID] = set()
 
 
         for b in self.bases:
         for b in self.bases:
             # Create a copy of todo for each base to avoid modifying
             # Create a copy of todo for each base to avoid modifying
@@ -2864,7 +2856,7 @@ class OverlayObjectStore(BaseObjectStore):
 
 
     def iter_unpacked_subset(
     def iter_unpacked_subset(
         self,
         self,
-        shas: Iterable[bytes],
+        shas: Iterable[ObjectID | RawObjectID],
         include_comp: bool = False,
         include_comp: bool = False,
         allow_missing: bool = False,
         allow_missing: bool = False,
         convert_ofs_delta: bool = True,
         convert_ofs_delta: bool = True,
@@ -2883,7 +2875,7 @@ class OverlayObjectStore(BaseObjectStore):
         Raises:
         Raises:
           KeyError: If an object is missing and allow_missing is False
           KeyError: If an object is missing and allow_missing is False
         """
         """
-        todo = set(shas)
+        todo: set[ObjectID | RawObjectID] = set(shas)
         for b in self.bases:
         for b in self.bases:
             for o in b.iter_unpacked_subset(
             for o in b.iter_unpacked_subset(
                 todo,
                 todo,
@@ -2896,7 +2888,7 @@ class OverlayObjectStore(BaseObjectStore):
         if todo and not allow_missing:
         if todo and not allow_missing:
             raise KeyError(next(iter(todo)))
             raise KeyError(next(iter(todo)))
 
 
-    def get_raw(self, sha_id: ObjectID) -> tuple[int, bytes]:
+    def get_raw(self, sha_id: ObjectID | RawObjectID) -> tuple[int, bytes]:
         """Get the raw object data from the overlaid stores.
         """Get the raw object data from the overlaid stores.
 
 
         Args:
         Args:
@@ -2915,7 +2907,7 @@ class OverlayObjectStore(BaseObjectStore):
                 pass
                 pass
         raise KeyError(sha_id)
         raise KeyError(sha_id)
 
 
-    def contains_packed(self, sha: bytes) -> bool:
+    def contains_packed(self, sha: ObjectID | RawObjectID) -> bool:
         """Check if an object is packed in any base store.
         """Check if an object is packed in any base store.
 
 
         Args:
         Args:
@@ -2929,7 +2921,7 @@ class OverlayObjectStore(BaseObjectStore):
                 return True
                 return True
         return False
         return False
 
 
-    def contains_loose(self, sha: bytes) -> bool:
+    def contains_loose(self, sha: ObjectID | RawObjectID) -> bool:
         """Check if an object is loose in any base store.
         """Check if an object is loose in any base store.
 
 
         Args:
         Args:
@@ -2958,14 +2950,14 @@ def read_packs_file(f: BinaryIO) -> Iterator[str]:
 class BucketBasedObjectStore(PackBasedObjectStore):
 class BucketBasedObjectStore(PackBasedObjectStore):
     """Object store implementation that uses a bucket store like S3 as backend."""
     """Object store implementation that uses a bucket store like S3 as backend."""
 
 
-    def _iter_loose_objects(self) -> Iterator[bytes]:
+    def _iter_loose_objects(self) -> Iterator[ObjectID]:
         """Iterate over the SHAs of all loose objects."""
         """Iterate over the SHAs of all loose objects."""
         return iter([])
         return iter([])
 
 
-    def _get_loose_object(self, sha: bytes) -> None:
+    def _get_loose_object(self, sha: ObjectID | RawObjectID) -> None:
         return None
         return None
 
 
-    def delete_loose_object(self, sha: bytes) -> None:
+    def delete_loose_object(self, sha: ObjectID) -> None:
         """Delete a loose object (no-op for bucket stores).
         """Delete a loose object (no-op for bucket stores).
 
 
         Bucket-based stores don't have loose objects, so this is a no-op.
         Bucket-based stores don't have loose objects, so this is a no-op.
@@ -3069,7 +3061,7 @@ def _collect_ancestors(
     heads: Iterable[ObjectID],
     heads: Iterable[ObjectID],
     common: frozenset[ObjectID] = frozenset(),
     common: frozenset[ObjectID] = frozenset(),
     shallow: frozenset[ObjectID] = frozenset(),
     shallow: frozenset[ObjectID] = frozenset(),
-    get_parents: Callable[[Commit], list[bytes]] = lambda commit: commit.parents,
+    get_parents: Callable[[Commit], list[ObjectID]] = lambda commit: commit.parents,
 ) -> tuple[set[ObjectID], set[ObjectID]]:
 ) -> tuple[set[ObjectID], set[ObjectID]]:
     """Collect all ancestors of heads up to (excluding) those in common.
     """Collect all ancestors of heads up to (excluding) those in common.
 
 
@@ -3154,7 +3146,7 @@ def iter_tree_contents(
 
 
 def iter_commit_contents(
 def iter_commit_contents(
     store: ObjectContainer,
     store: ObjectContainer,
-    commit: Commit | bytes,
+    commit: Commit | ObjectID | RawObjectID,
     *,
     *,
     include: Sequence[str | bytes | Path] | None = None,
     include: Sequence[str | bytes | Path] | None = None,
 ) -> Iterator[TreeEntry]:
 ) -> Iterator[TreeEntry]:
@@ -3203,7 +3195,9 @@ def iter_commit_contents(
             yield TreeEntry(path, mode, obj_id)
             yield TreeEntry(path, mode, obj_id)
 
 
 
 
-def peel_sha(store: ObjectContainer, sha: bytes) -> tuple[ShaFile, ShaFile]:
+def peel_sha(
+    store: ObjectContainer, sha: ObjectID | RawObjectID
+) -> tuple[ShaFile, ShaFile]:
     """Peel all tags from a SHA.
     """Peel all tags from a SHA.
 
 
     Args:
     Args:
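
peel_sha() still returns the (unpeeled, peeled) pair; only the accepted SHA type widened to cover both encodings. A sketch, assuming an annotated tag v1.0 exists:

    from dulwich.object_store import peel_sha
    from dulwich.repo import Repo

    repo = Repo(".")
    tag_id = repo.refs[b"refs/tags/v1.0"]
    unpeeled, peeled = peel_sha(repo.object_store, tag_id)
    # For an annotated tag: unpeeled is the Tag, peeled the tagged commit.
    print(unpeeled.type_name, peeled.type_name)
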
@@ -3240,10 +3234,10 @@ class GraphTraversalReachability:
 
 
     def get_reachable_commits(
     def get_reachable_commits(
         self,
         self,
-        heads: Iterable[bytes],
-        exclude: Iterable[bytes] | None = None,
-        shallow: Set[bytes] | None = None,
-    ) -> set[bytes]:
+        heads: Iterable[ObjectID],
+        exclude: Iterable[ObjectID] | None = None,
+        shallow: Set[ObjectID] | None = None,
+    ) -> set[ObjectID]:
         """Get all commits reachable from heads, excluding those in exclude.
         """Get all commits reachable from heads, excluding those in exclude.
 
 
         Uses _collect_ancestors for commit traversal.
         Uses _collect_ancestors for commit traversal.
@@ -3265,8 +3259,8 @@ class GraphTraversalReachability:
 
 
     def get_tree_objects(
     def get_tree_objects(
         self,
         self,
-        tree_shas: Iterable[bytes],
-    ) -> set[bytes]:
+        tree_shas: Iterable[ObjectID],
+    ) -> set[ObjectID]:
         """Get all trees and blobs reachable from the given trees.
         """Get all trees and blobs reachable from the given trees.
 
 
         Uses _collect_filetree_revs for tree traversal.
         Uses _collect_filetree_revs for tree traversal.
@@ -3277,16 +3271,16 @@ class GraphTraversalReachability:
         Returns:
         Returns:
           Set of tree and blob SHAs
           Set of tree and blob SHAs
         """
         """
-        result: set[bytes] = set()
+        result: set[ObjectID] = set()
         for tree_sha in tree_shas:
         for tree_sha in tree_shas:
             _collect_filetree_revs(self.store, tree_sha, result)
             _collect_filetree_revs(self.store, tree_sha, result)
         return result
         return result
 
 
     def get_reachable_objects(
     def get_reachable_objects(
         self,
         self,
-        commits: Iterable[bytes],
-        exclude_commits: Iterable[bytes] | None = None,
-    ) -> set[bytes]:
+        commits: Iterable[ObjectID],
+        exclude_commits: Iterable[ObjectID] | None = None,
+    ) -> set[ObjectID]:
         """Get all objects (commits + trees + blobs) reachable from commits.
         """Get all objects (commits + trees + blobs) reachable from commits.
 
 
         Args:
         Args:
@@ -3341,8 +3335,8 @@ class BitmapReachability:
 
 
     def _combine_commit_bitmaps(
     def _combine_commit_bitmaps(
         self,
         self,
-        commit_shas: set[bytes],
-        exclude_shas: set[bytes] | None = None,
+        commit_shas: set[ObjectID],
+        exclude_shas: set[ObjectID] | None = None,
     ) -> tuple["EWAHBitmap", "Pack"] | None:
     ) -> tuple["EWAHBitmap", "Pack"] | None:
         """Combine bitmaps for multiple commits using OR, with optional exclusion.
         """Combine bitmaps for multiple commits using OR, with optional exclusion.
 
 
@@ -3413,10 +3407,10 @@ class BitmapReachability:
 
 
     def get_reachable_commits(
     def get_reachable_commits(
         self,
         self,
-        heads: Iterable[bytes],
-        exclude: Iterable[bytes] | None = None,
-        shallow: Set[bytes] | None = None,
-    ) -> set[bytes]:
+        heads: Iterable[ObjectID],
+        exclude: Iterable[ObjectID] | None = None,
+        shallow: Set[ObjectID] | None = None,
+    ) -> set[ObjectID]:
         """Get all commits reachable from heads using bitmaps where possible.
         """Get all commits reachable from heads using bitmaps where possible.
 
 
         Args:
         Args:
@@ -3455,8 +3449,8 @@ class BitmapReachability:
 
 
     def get_tree_objects(
     def get_tree_objects(
         self,
         self,
-        tree_shas: Iterable[bytes],
-    ) -> set[bytes]:
+        tree_shas: Iterable[ObjectID],
+    ) -> set[ObjectID]:
         """Get all trees and blobs reachable from the given trees.
         """Get all trees and blobs reachable from the given trees.
 
 
         Args:
         Args:
@@ -3470,9 +3464,9 @@ class BitmapReachability:
 
 
     def get_reachable_objects(
     def get_reachable_objects(
         self,
         self,
-        commits: Iterable[bytes],
-        exclude_commits: Iterable[bytes] | None = None,
-    ) -> set[bytes]:
+        commits: Iterable[ObjectID],
+        exclude_commits: Iterable[ObjectID] | None = None,
+    ) -> set[ObjectID]:
         """Get all objects reachable from commits using bitmaps.
         """Get all objects reachable from commits using bitmaps.
 
 
         Args:
         Args:

+ 39 - 28
dulwich/objects.py

@@ -43,7 +43,7 @@ if sys.version_info >= (3, 11):
 else:
 else:
     from typing_extensions import Self
     from typing_extensions import Self
 
 
-from typing import TypeGuard
+from typing import NewType, TypeGuard
 
 
 from . import replace_me
 from . import replace_me
 from .errors import (
 from .errors import (
@@ -62,8 +62,6 @@ if TYPE_CHECKING:
 
 
     from .file import _GitFile
     from .file import _GitFile
 
 
-ZERO_SHA = b"0" * 40
-
 # Header fields for commits
 # Header fields for commits
 _TREE_HEADER = b"tree"
 _TREE_HEADER = b"tree"
 _PARENT_HEADER = b"parent"
 _PARENT_HEADER = b"parent"
@@ -93,7 +91,14 @@ SIGNATURE_PGP = b"pgp"
 SIGNATURE_SSH = b"ssh"
 SIGNATURE_SSH = b"ssh"
 
 
 
 
-ObjectID = bytes
+# Hex SHA type
+ObjectID = NewType("ObjectID", bytes)
+
+# Raw SHA type
+RawObjectID = NewType("RawObjectID", bytes)
+
+# Zero SHA constant
+ZERO_SHA: ObjectID = ObjectID(b"0" * 40)
 
 
 
 
 class EmptyFileException(FileFormatException):
 class EmptyFileException(FileFormatException):
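
This hunk is the core of the change: ObjectID and RawObjectID become NewType wrappers that mypy treats as distinct from bytes and from each other, and ZERO_SHA moves below them so it can carry the new type. A minimal sketch of what that buys (lookup() is a hypothetical consumer):

    from typing import NewType

    ObjectID = NewType("ObjectID", bytes)        # 40 hex characters
    RawObjectID = NewType("RawObjectID", bytes)  # 20 raw bytes

    def lookup(sha: ObjectID) -> None: ...

    lookup(ObjectID(b"a" * 40))        # ok
    lookup(RawObjectID(b"\x00" * 20))  # flagged by mypy, runs fine
    lookup(b"a" * 40)                  # flagged by mypy, runs fine

The rejected calls still execute, because NewType is erased at runtime; the guarantees are purely static.
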
@@ -117,18 +122,18 @@ def _decompress(string: bytes) -> bytes:
     return dcomped
     return dcomped
 
 
 
 
-def sha_to_hex(sha: ObjectID) -> bytes:
+def sha_to_hex(sha: RawObjectID) -> ObjectID:
     """Takes a string and returns the hex of the sha within."""
     """Takes a string and returns the hex of the sha within."""
     hexsha = binascii.hexlify(sha)
     hexsha = binascii.hexlify(sha)
     assert len(hexsha) == 40, f"Incorrect length of sha1 string: {hexsha!r}"
     assert len(hexsha) == 40, f"Incorrect length of sha1 string: {hexsha!r}"
-    return hexsha
+    return ObjectID(hexsha)
 
 
 
 
-def hex_to_sha(hex: bytes | str) -> bytes:
+def hex_to_sha(hex: ObjectID | str) -> RawObjectID:
     """Takes a hex sha and returns a binary sha."""
     """Takes a hex sha and returns a binary sha."""
     assert len(hex) == 40, f"Incorrect length of hexsha: {hex!r}"
     assert len(hex) == 40, f"Incorrect length of hexsha: {hex!r}"
     try:
     try:
-        return binascii.unhexlify(hex)
+        return RawObjectID(binascii.unhexlify(hex))
     except TypeError as exc:
     except TypeError as exc:
         if not isinstance(hex, bytes):
         if not isinstance(hex, bytes):
             raise
             raise
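
sha_to_hex() and hex_to_sha() are now the typed bridge between the two encodings. A round-trip sketch using the well-known SHA-1 of the empty string:

    from dulwich.objects import ObjectID, hex_to_sha, sha_to_hex

    hexsha = ObjectID(b"da39a3ee5e6b4b0d3255bfef95601890afd80709")
    raw = hex_to_sha(hexsha)           # RawObjectID, 20 bytes
    assert len(raw) == 20
    assert sha_to_hex(raw) == hexsha   # round-trips exactly
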
@@ -206,7 +211,7 @@ def filename_to_hex(filename: str | bytes) -> str:
         base_b, rest_b = names_b
         base_b, rest_b = names_b
         assert len(base_b) == 2 and len(rest_b) == 38, errmsg
         assert len(base_b) == 2 and len(rest_b) == 38, errmsg
         hex_bytes = base_b + rest_b
         hex_bytes = base_b + rest_b
-    hex_to_sha(hex_bytes)
+    hex_to_sha(ObjectID(hex_bytes))
     return hex_bytes.decode("ascii")
     return hex_bytes.decode("ascii")
 
 
 
 
@@ -337,7 +342,7 @@ class FixedSha:
         if not isinstance(hexsha, bytes):
         if not isinstance(hexsha, bytes):
             raise TypeError(f"Expected bytes for hexsha, got {hexsha!r}")
             raise TypeError(f"Expected bytes for hexsha, got {hexsha!r}")
         self._hexsha = hexsha
         self._hexsha = hexsha
-        self._sha = hex_to_sha(hexsha)
+        self._sha = hex_to_sha(ObjectID(hexsha))
 
 
     def digest(self) -> bytes:
     def digest(self) -> bytes:
         """Return the raw SHA digest."""
         """Return the raw SHA digest."""
@@ -481,13 +486,17 @@ class ShaFile:
         """Return a string representing this object, fit for display."""
         """Return a string representing this object, fit for display."""
         return self.as_raw_string().decode("utf-8", "replace")
         return self.as_raw_string().decode("utf-8", "replace")
 
 
-    def set_raw_string(self, text: bytes, sha: ObjectID | None = None) -> None:
+    def set_raw_string(
+        self, text: bytes, sha: ObjectID | RawObjectID | None = None
+    ) -> None:
         """Set the contents of this object from a serialized string."""
         """Set the contents of this object from a serialized string."""
         if not isinstance(text, bytes):
         if not isinstance(text, bytes):
             raise TypeError(f"Expected bytes for text, got {text!r}")
             raise TypeError(f"Expected bytes for text, got {text!r}")
         self.set_raw_chunks([text], sha)
         self.set_raw_chunks([text], sha)
 
 
-    def set_raw_chunks(self, chunks: list[bytes], sha: ObjectID | None = None) -> None:
+    def set_raw_chunks(
+        self, chunks: list[bytes], sha: ObjectID | RawObjectID | None = None
+    ) -> None:
         """Set the contents of this object from a list of chunks."""
         """Set the contents of this object from a list of chunks."""
         self._chunked_text = chunks
         self._chunked_text = chunks
         self._deserialize(chunks)
         self._deserialize(chunks)
@@ -571,7 +580,7 @@ class ShaFile:
 
 
     @staticmethod
     @staticmethod
     def from_raw_string(
     def from_raw_string(
-        type_num: int, string: bytes, sha: ObjectID | None = None
+        type_num: int, string: bytes, sha: ObjectID | RawObjectID | None = None
     ) -> "ShaFile":
     ) -> "ShaFile":
         """Creates an object of the indicated type from the raw string given.
         """Creates an object of the indicated type from the raw string given.
 
 
@@ -589,7 +598,7 @@ class ShaFile:
 
 
     @staticmethod
     @staticmethod
     def from_raw_chunks(
     def from_raw_chunks(
-        type_num: int, chunks: list[bytes], sha: ObjectID | None = None
+        type_num: int, chunks: list[bytes], sha: ObjectID | RawObjectID | None = None
     ) -> "ShaFile":
     ) -> "ShaFile":
         """Creates an object of the indicated type from the raw chunks given.
         """Creates an object of the indicated type from the raw chunks given.
 
 
@@ -673,9 +682,9 @@ class ShaFile:
         return obj_class.from_raw_string(self.type_num, self.as_raw_string(), self.id)
         return obj_class.from_raw_string(self.type_num, self.as_raw_string(), self.id)
 
 
     @property
     @property
-    def id(self) -> bytes:
+    def id(self) -> ObjectID:
         """The hex SHA of this object."""
         """The hex SHA of this object."""
-        return self.sha().hexdigest().encode("ascii")
+        return ObjectID(self.sha().hexdigest().encode("ascii"))
 
 
     def __repr__(self) -> str:
     def __repr__(self) -> str:
         """Return string representation of this object."""
         """Return string representation of this object."""
@@ -1178,7 +1187,7 @@ class TreeEntry(NamedTuple):
 
 
     path: bytes
     path: bytes
     mode: int
     mode: int
-    sha: bytes
+    sha: ObjectID
 
 
     def in_path(self, path: bytes) -> "TreeEntry":
     def in_path(self, path: bytes) -> "TreeEntry":
         """Return a copy of this entry with the given path prepended."""
         """Return a copy of this entry with the given path prepended."""
@@ -1187,7 +1196,9 @@ class TreeEntry(NamedTuple):
         return TreeEntry(posixpath.join(path, self.path), self.mode, self.sha)
         return TreeEntry(posixpath.join(path, self.path), self.mode, self.sha)
 
 
 
 
-def parse_tree(text: bytes, strict: bool = False) -> Iterator[tuple[bytes, int, bytes]]:
+def parse_tree(
+    text: bytes, strict: bool = False
+) -> Iterator[tuple[bytes, int, ObjectID]]:
     """Parse a tree text.
     """Parse a tree text.
 
 
     Args:
     Args:
@@ -1215,11 +1226,11 @@ def parse_tree(text: bytes, strict: bool = False) -> Iterator[tuple[bytes, int,
         sha = text[name_end + 1 : count]
         sha = text[name_end + 1 : count]
         if len(sha) != 20:
         if len(sha) != 20:
             raise ObjectFormatException("Sha has invalid length")
             raise ObjectFormatException("Sha has invalid length")
-        hexsha = sha_to_hex(sha)
+        hexsha = sha_to_hex(RawObjectID(sha))
         yield (name, mode, hexsha)
         yield (name, mode, hexsha)
 
 
 
 
-def serialize_tree(items: Iterable[tuple[bytes, int, bytes]]) -> Iterator[bytes]:
+def serialize_tree(items: Iterable[tuple[bytes, int, ObjectID]]) -> Iterator[bytes]:
     """Serialize the items in a tree to a text.
     """Serialize the items in a tree to a text.
 
 
     Args:
     Args:
@@ -1233,7 +1244,7 @@ def serialize_tree(items: Iterable[tuple[bytes, int, bytes]]) -> Iterator[bytes]
 
 
 
 
 def sorted_tree_items(
 def sorted_tree_items(
-    entries: dict[bytes, tuple[int, bytes]], name_order: bool
+    entries: dict[bytes, tuple[int, ObjectID]], name_order: bool
 ) -> Iterator[TreeEntry]:
 ) -> Iterator[TreeEntry]:
     """Iterate over a tree entries dictionary.
     """Iterate over a tree entries dictionary.
 
 
@@ -1275,7 +1286,7 @@ def key_entry_name_order(entry: tuple[bytes, tuple[int, ObjectID]]) -> bytes:
 
 
 
 
 def pretty_format_tree_entry(
 def pretty_format_tree_entry(
-    name: bytes, mode: int, hexsha: bytes, encoding: str = "utf-8"
+    name: bytes, mode: int, hexsha: ObjectID, encoding: str = "utf-8"
 ) -> str:
 ) -> str:
     """Pretty format tree entry.
     """Pretty format tree entry.
 
 
@@ -1323,7 +1334,7 @@ class Tree(ShaFile):
     def __init__(self) -> None:
     def __init__(self) -> None:
         """Initialize an empty Tree."""
         """Initialize an empty Tree."""
         super().__init__()
         super().__init__()
-        self._entries: dict[bytes, tuple[int, bytes]] = {}
+        self._entries: dict[bytes, tuple[int, ObjectID]] = {}
 
 
     @classmethod
     @classmethod
     def from_path(cls, filename: str | bytes) -> "Tree":
     def from_path(cls, filename: str | bytes) -> "Tree":
@@ -1377,7 +1388,7 @@ class Tree(ShaFile):
         """Iterate over tree entry names."""
         """Iterate over tree entry names."""
         return iter(self._entries)
         return iter(self._entries)
 
 
-    def add(self, name: bytes, mode: int, hexsha: bytes) -> None:
+    def add(self, name: bytes, mode: int, hexsha: ObjectID) -> None:
         """Add an entry to the tree.
         """Add an entry to the tree.
 
 
         Args:
         Args:
@@ -1698,7 +1709,7 @@ class Commit(ShaFile):
     def __init__(self) -> None:
     def __init__(self) -> None:
         """Initialize an empty Commit."""
         """Initialize an empty Commit."""
         super().__init__()
         super().__init__()
-        self._parents: list[bytes] = []
+        self._parents: list[ObjectID] = []
         self._encoding: bytes | None = None
         self._encoding: bytes | None = None
         self._mergetag: list[Tag] = []
         self._mergetag: list[Tag] = []
         self._gpgsig: bytes | None = None
         self._gpgsig: bytes | None = None
@@ -1749,7 +1760,7 @@ class Commit(ShaFile):
                 self._tree = value
                 self._tree = value
             elif field == _PARENT_HEADER:
             elif field == _PARENT_HEADER:
                 assert value is not None
                 assert value is not None
-                self._parents.append(value)
+                self._parents.append(ObjectID(value))
             elif field == _AUTHOR_HEADER:
             elif field == _AUTHOR_HEADER:
                 if value is None:
                 if value is None:
                     raise ObjectFormatException("missing author value")
                     raise ObjectFormatException("missing author value")
@@ -1976,11 +1987,11 @@ class Commit(ShaFile):
 
 
     tree = serializable_property("tree", "Tree that is the state of this commit")
     tree = serializable_property("tree", "Tree that is the state of this commit")
 
 
-    def _get_parents(self) -> list[bytes]:
+    def _get_parents(self) -> list[ObjectID]:
         """Return a list of parents of this commit."""
         """Return a list of parents of this commit."""
         return self._parents
         return self._parents
 
 
-    def _set_parents(self, value: list[bytes]) -> None:
+    def _set_parents(self, value: list[ObjectID]) -> None:
         """Set a list of parents of this commit."""
         """Set a list of parents of this commit."""
         self._needs_serialization = True
         self._needs_serialization = True
         self._parents = value
         self._parents = value

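
Taken together, the objects.py hunks push ObjectID through the tree and commit structures: TreeEntry.sha, Tree.add() and Commit.parents all carry it now. A short sketch (the second SHA is a made-up placeholder):

    from dulwich.objects import Blob, ObjectID, Tree

    blob = Blob.from_string(b"content")
    tree = Tree()
    tree.add(b"file.txt", 0o100644, blob.id)  # blob.id is an ObjectID
    other = ObjectID(b"a" * 40)               # placeholder hex SHA
    tree.add(b"link.txt", 0o100644, other)    # bare bytes would not typecheck
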
+ 27 - 18
dulwich/objectspec.py

@@ -24,8 +24,8 @@
 from collections.abc import Sequence
 from collections.abc import Sequence
 from typing import TYPE_CHECKING
 from typing import TYPE_CHECKING
 
 
-from .objects import Commit, ShaFile, Tag, Tree
-from .refs import local_branch_name, local_tag_name
+from .objects import Commit, ObjectID, RawObjectID, ShaFile, Tag, Tree
+from .refs import Ref, local_branch_name, local_tag_name
 from .repo import BaseRepo
 from .repo import BaseRepo
 
 
 if TYPE_CHECKING:
 if TYPE_CHECKING:
@@ -48,7 +48,9 @@ def to_bytes(text: str | bytes) -> bytes:
     return text
     return text
 
 
 
 
-def _resolve_object(repo: "Repo", ref: bytes) -> "ShaFile":
+def _resolve_object(
+    repo: "Repo", ref: Ref | ObjectID | RawObjectID | bytes
+) -> "ShaFile":
     """Resolve a reference to an object using multiple strategies."""
     """Resolve a reference to an object using multiple strategies."""
     try:
     try:
         return repo[ref]
         return repo[ref]
@@ -58,7 +60,7 @@ def _resolve_object(repo: "Repo", ref: bytes) -> "ShaFile":
             return repo[ref_sha]
             return repo[ref_sha]
         except KeyError:
         except KeyError:
             try:
             try:
-                return repo.object_store[ref]
+                return repo.object_store[ref]  # type: ignore[index]
             except (KeyError, ValueError):
             except (KeyError, ValueError):
                 # Re-raise original KeyError for consistency
                 # Re-raise original KeyError for consistency
                 raise KeyError(ref)
                 raise KeyError(ref)
@@ -262,12 +264,13 @@ def parse_tree(repo: "BaseRepo", treeish: bytes | str | Tree | Commit | Tag) ->
         treeish = treeish.id
         treeish = treeish.id
     else:
     else:
         treeish = to_bytes(treeish)
         treeish = to_bytes(treeish)
+    treeish_typed: Ref | ObjectID
     try:
     try:
-        treeish = parse_ref(repo.refs, treeish)
+        treeish_typed = parse_ref(repo.refs, treeish)
     except KeyError:  # treeish is commit sha
     except KeyError:  # treeish is commit sha
-        pass
+        treeish_typed = ObjectID(treeish)
     try:
     try:
-        o = repo[treeish]
+        o = repo[treeish_typed]
     except KeyError:
     except KeyError:
         # Try parsing as commit (handles short hashes)
         # Try parsing as commit (handles short hashes)
         try:
         try:
@@ -311,7 +314,7 @@ def parse_ref(container: "Repo | RefsContainer", refspec: str | bytes) -> "Ref":
     ]
     ]
     for ref in possible_refs:
     for ref in possible_refs:
         if ref in container:
         if ref in container:
-            return ref
+            return Ref(ref)
     raise KeyError(refspec)
     raise KeyError(refspec)
 
 
 
 
@@ -336,25 +339,31 @@ def parse_reftuple(
     if refspec.startswith(b"+"):
     if refspec.startswith(b"+"):
         force = True
         force = True
         refspec = refspec[1:]
         refspec = refspec[1:]
-    lh: bytes | None
-    rh: bytes | None
+    lh_bytes: bytes | None
+    rh_bytes: bytes | None
     if b":" in refspec:
     if b":" in refspec:
-        (lh, rh) = refspec.split(b":")
+        (lh_bytes, rh_bytes) = refspec.split(b":")
     else:
     else:
-        lh = rh = refspec
-    if lh == b"":
+        lh_bytes = rh_bytes = refspec
+
+    lh: Ref | None
+    if lh_bytes == b"":
         lh = None
         lh = None
     else:
     else:
-        lh = parse_ref(lh_container, lh)
-    if rh == b"":
+        lh = parse_ref(lh_container, lh_bytes)
+
+    rh: Ref | None
+    if rh_bytes == b"":
         rh = None
         rh = None
     else:
     else:
         try:
         try:
-            rh = parse_ref(rh_container, rh)
+            rh = parse_ref(rh_container, rh_bytes)
         except KeyError:
         except KeyError:
             # TODO: check force?
             # TODO: check force?
-            if b"/" not in rh:
-                rh = local_branch_name(rh)
+            if b"/" not in rh_bytes:
+                rh = Ref(local_branch_name(rh_bytes))
+            else:
+                rh = Ref(rh_bytes)
     return (lh, rh, force)
     return (lh, rh, force)
 
 
 
 

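
The parse_reftuple() rewrite keeps the raw refspec halves (lh_bytes/rh_bytes) apart from the resolved Ref results, and the new else branch ensures a right-hand side containing '/' is wrapped as a Ref instead of leaking through as plain bytes. A usage sketch (assumes a local branch named main):

    from dulwich.objectspec import parse_reftuple
    from dulwich.repo import Repo

    repo = Repo(".")
    # A leading '+' forces the update; an empty side would come back None.
    lh, rh, force = parse_reftuple(repo.refs, repo.refs, b"+main:main")
    assert force and lh == rh == b"refs/heads/main"
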
+ 89 - 65
dulwich/pack.py

@@ -62,6 +62,7 @@ from typing import (
     Generic,
     Generic,
     Protocol,
     Protocol,
     TypeVar,
     TypeVar,
+    cast,
 )
 )
 
 
 try:
 try:
@@ -77,6 +78,7 @@ if TYPE_CHECKING:
     from .bitmap import PackBitmap
     from .bitmap import PackBitmap
     from .commit_graph import CommitGraph
     from .commit_graph import CommitGraph
     from .object_store import BaseObjectStore
     from .object_store import BaseObjectStore
+    from .refs import Ref
 
 
 # For some reason the above try, except fails to set has_mmap = False for plan9
 # For some reason the above try, except fails to set has_mmap = False for plan9
 if sys.platform == "Plan9":
 if sys.platform == "Plan9":
@@ -86,7 +88,14 @@ from . import replace_me
 from .errors import ApplyDeltaError, ChecksumMismatch
 from .errors import ApplyDeltaError, ChecksumMismatch
 from .file import GitFile, _GitFile
 from .file import GitFile, _GitFile
 from .lru_cache import LRUSizeCache
 from .lru_cache import LRUSizeCache
-from .objects import ObjectID, ShaFile, hex_to_sha, object_header, sha_to_hex
+from .objects import (
+    ObjectID,
+    RawObjectID,
+    ShaFile,
+    hex_to_sha,
+    object_header,
+    sha_to_hex,
+)
 
 
 OFS_DELTA = 6
 OFS_DELTA = 6
 REF_DELTA = 7
 REF_DELTA = 7
@@ -140,10 +149,10 @@ class ObjectContainer(Protocol):
         Returns: Optional Pack object of the objects written.
         Returns: Optional Pack object of the objects written.
         """
         """
 
 
-    def __contains__(self, sha1: bytes) -> bool:
+    def __contains__(self, sha1: "ObjectID") -> bool:
         """Check if a hex sha is present."""
         """Check if a hex sha is present."""
 
 
-    def __getitem__(self, sha1: bytes) -> ShaFile:
+    def __getitem__(self, sha1: "ObjectID | RawObjectID") -> ShaFile:
         """Retrieve an object."""
         """Retrieve an object."""
 
 
     def get_commit_graph(self) -> "CommitGraph | None":
     def get_commit_graph(self) -> "CommitGraph | None":
@@ -159,7 +168,7 @@ class PackedObjectContainer(ObjectContainer):
     """Container for objects packed in a pack file."""
     """Container for objects packed in a pack file."""
 
 
     def get_unpacked_object(
     def get_unpacked_object(
-        self, sha1: bytes, *, include_comp: bool = False
+        self, sha1: "ObjectID | RawObjectID", *, include_comp: bool = False
     ) -> "UnpackedObject":
     ) -> "UnpackedObject":
         """Get a raw unresolved object.
         """Get a raw unresolved object.
 
 
@@ -173,7 +182,7 @@ class PackedObjectContainer(ObjectContainer):
         raise NotImplementedError(self.get_unpacked_object)
         raise NotImplementedError(self.get_unpacked_object)
 
 
     def iterobjects_subset(
     def iterobjects_subset(
-        self, shas: Iterable[bytes], *, allow_missing: bool = False
+        self, shas: Iterable["ObjectID"], *, allow_missing: bool = False
     ) -> Iterator[ShaFile]:
     ) -> Iterator[ShaFile]:
         """Iterate over a subset of objects.
         """Iterate over a subset of objects.
 
 
@@ -188,7 +197,7 @@ class PackedObjectContainer(ObjectContainer):
 
 
     def iter_unpacked_subset(
     def iter_unpacked_subset(
         self,
         self,
-        shas: Iterable[bytes],
+        shas: Iterable["ObjectID | RawObjectID"],
         *,
         *,
         include_comp: bool = False,
         include_comp: bool = False,
         allow_missing: bool = False,
         allow_missing: bool = False,
@@ -332,12 +341,12 @@ class UnpackedObject:
             self.obj_chunks = self.decomp_chunks
             self.obj_chunks = self.decomp_chunks
             self.delta_base = delta_base
             self.delta_base = delta_base
 
 
-    def sha(self) -> bytes:
+    def sha(self) -> RawObjectID:
         """Return the binary SHA of this object."""
         """Return the binary SHA of this object."""
         if self._sha is None:
         if self._sha is None:
             assert self.obj_type_num is not None and self.obj_chunks is not None
             assert self.obj_type_num is not None and self.obj_chunks is not None
             self._sha = obj_sha(self.obj_type_num, self.obj_chunks)
             self._sha = obj_sha(self.obj_type_num, self.obj_chunks)
-        return self._sha
+        return RawObjectID(self._sha)
 
 
     def sha_file(self) -> ShaFile:
     def sha_file(self) -> ShaFile:
         """Return a ShaFile from this object."""
         """Return a ShaFile from this object."""
@@ -547,7 +556,7 @@ def bisect_find_sha(
     return None
     return None
 
 
 
 
-PackIndexEntry = tuple[bytes, int, int | None]
+PackIndexEntry = tuple[RawObjectID, int, int | None]
 
 
 
 
 class PackIndex:
 class PackIndex:
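
PackIndexEntry now pins its first element to RawObjectID, matching the binary SHAs that pack indexes actually store. A sketch against MemoryPackIndex with a single made-up entry:

    from dulwich.objects import RawObjectID, sha_to_hex
    from dulwich.pack import MemoryPackIndex

    raw = RawObjectID(b"\x01" * 20)
    idx = MemoryPackIndex([(raw, 1234, None)])
    # object_offset() accepts either encoding; a 40-byte hex input is
    # converted back to the raw form internally.
    assert idx.object_offset(raw) == 1234
    assert idx.object_offset(sha_to_hex(raw)) == 1234
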
@@ -581,9 +590,9 @@ class PackIndex:
         """Return the number of entries in this pack index."""
         """Return the number of entries in this pack index."""
         raise NotImplementedError(self.__len__)
         raise NotImplementedError(self.__len__)
 
 
-    def __iter__(self) -> Iterator[bytes]:
+    def __iter__(self) -> Iterator[ObjectID]:
         """Iterate over the SHAs in this pack."""
         """Iterate over the SHAs in this pack."""
-        return map(sha_to_hex, self._itersha())
+        return map(lambda sha: sha_to_hex(RawObjectID(sha)), self._itersha())
 
 
     def iterentries(self) -> Iterator[PackIndexEntry]:
     def iterentries(self) -> Iterator[PackIndexEntry]:
         """Iterate over the entries in this pack index.
         """Iterate over the entries in this pack index.
@@ -601,7 +610,7 @@ class PackIndex:
         raise NotImplementedError(self.get_pack_checksum)
         raise NotImplementedError(self.get_pack_checksum)
 
 
     @replace_me(since="0.21.0", remove_in="0.23.0")
     @replace_me(since="0.21.0", remove_in="0.23.0")
-    def object_index(self, sha: bytes) -> int:
+    def object_index(self, sha: ObjectID | RawObjectID) -> int:
         """Return the index for the given SHA.
         """Return the index for the given SHA.
 
 
         Args:
         Args:
@@ -612,7 +621,7 @@ class PackIndex:
         """
         """
         return self.object_offset(sha)
         return self.object_offset(sha)
 
 
-    def object_offset(self, sha: bytes) -> int:
+    def object_offset(self, sha: ObjectID | RawObjectID) -> int:
         """Return the offset in to the corresponding packfile for the object.
         """Return the offset in to the corresponding packfile for the object.
 
 
         Given the name of an object it will return the offset that object
         Given the name of an object it will return the offset that object
@@ -648,7 +657,7 @@ class PackIndex:
         """Yield all the SHA1's of the objects in the index, sorted."""
         """Yield all the SHA1's of the objects in the index, sorted."""
         raise NotImplementedError(self._itersha)
         raise NotImplementedError(self._itersha)
 
 
-    def iter_prefix(self, prefix: bytes) -> Iterator[bytes]:
+    def iter_prefix(self, prefix: bytes) -> Iterator[RawObjectID]:
         """Iterate over all SHA1s with the given prefix.
         """Iterate over all SHA1s with the given prefix.
 
 
         Args:
         Args:
@@ -658,7 +667,7 @@ class PackIndex:
         # Default implementation for PackIndex classes that don't override
         # Default implementation for PackIndex classes that don't override
         for sha, _, _ in self.iterentries():
         for sha, _, _ in self.iterentries():
             if sha.startswith(prefix):
             if sha.startswith(prefix):
-                yield sha
+                yield RawObjectID(sha)
 
 
     def close(self) -> None:
     def close(self) -> None:
         """Close any open files."""
         """Close any open files."""
@@ -672,7 +681,7 @@ class MemoryPackIndex(PackIndex):
 
 
     def __init__(
     def __init__(
         self,
         self,
-        entries: list[tuple[bytes, int, int | None]],
+        entries: list[PackIndexEntry],
         pack_checksum: bytes | None = None,
         pack_checksum: bytes | None = None,
     ) -> None:
     ) -> None:
         """Create a new MemoryPackIndex.
         """Create a new MemoryPackIndex.
@@ -697,7 +706,7 @@ class MemoryPackIndex(PackIndex):
         """Return the number of entries in this pack index."""
         """Return the number of entries in this pack index."""
         return len(self._entries)
         return len(self._entries)
 
 
-    def object_offset(self, sha: bytes) -> int:
+    def object_offset(self, sha: ObjectID | RawObjectID) -> int:
         """Return the offset for the given SHA.
         """Return the offset for the given SHA.
 
 
         Args:
         Args:
@@ -705,8 +714,8 @@ class MemoryPackIndex(PackIndex):
         Returns: Offset in the pack file
         Returns: Offset in the pack file
         """
         """
         if len(sha) == 40:
         if len(sha) == 40:
-            sha = hex_to_sha(sha)
-        return self._by_sha[sha]
+            sha = hex_to_sha(cast(ObjectID, sha))
+        return self._by_sha[cast(RawObjectID, sha)]
 
 
     def object_sha1(self, offset: int) -> bytes:
     def object_sha1(self, offset: int) -> bytes:
         """Return the SHA1 for the object at the given offset."""
         """Return the SHA1 for the object at the given offset."""
@@ -880,7 +889,7 @@ class FilePackIndex(PackIndex):
         """
         """
         return bytes(self._contents[-20:])
         return bytes(self._contents[-20:])
 
 
-    def object_offset(self, sha: bytes) -> int:
+    def object_offset(self, sha: ObjectID | RawObjectID) -> int:
         """Return the offset in to the corresponding packfile for the object.
         """Return the offset in to the corresponding packfile for the object.
 
 
         Given the name of an object it will return the offset that object
         Given the name of an object it will return the offset that object
@@ -888,7 +897,7 @@ class FilePackIndex(PackIndex):
         have the object then None will be returned.
         have the object then None will be returned.
         """
         """
         if len(sha) == 40:
         if len(sha) == 40:
-            sha = hex_to_sha(sha)
+            sha = hex_to_sha(cast(ObjectID, sha))
         try:
         try:
             return self._object_offset(sha)
             return self._object_offset(sha)
         except ValueError as exc:
         except ValueError as exc:
@@ -915,7 +924,7 @@ class FilePackIndex(PackIndex):
             raise KeyError(sha)
             raise KeyError(sha)
         return self._unpack_offset(i)
         return self._unpack_offset(i)
 
 
-    def iter_prefix(self, prefix: bytes) -> Iterator[bytes]:
+    def iter_prefix(self, prefix: bytes) -> Iterator[RawObjectID]:
         """Iterate over all SHA1s with the given prefix."""
         """Iterate over all SHA1s with the given prefix."""
         start = ord(prefix[:1])
         start = ord(prefix[:1])
         if start == 0:
         if start == 0:
@@ -932,7 +941,7 @@ class FilePackIndex(PackIndex):
         for i in range(start, end):
         for i in range(start, end):
             name: bytes = self._unpack_name(i)
             name: bytes = self._unpack_name(i)
             if name.startswith(prefix):
             if name.startswith(prefix):
-                yield name
+                yield RawObjectID(name)
                 started = True
                 started = True
             elif started:
             elif started:
                 break
                 break
@@ -960,9 +969,9 @@ class PackIndex1(FilePackIndex):
         self.version = 1
         self.version = 1
         self._fan_out_table = self._read_fan_out_table(0)
         self._fan_out_table = self._read_fan_out_table(0)
 
 
-    def _unpack_entry(self, i: int) -> tuple[bytes, int, None]:
+    def _unpack_entry(self, i: int) -> tuple[RawObjectID, int, None]:
         (offset, name) = unpack_from(">L20s", self._contents, (0x100 * 4) + (i * 24))
         (offset, name) = unpack_from(">L20s", self._contents, (0x100 * 4) + (i * 24))
-        return (name, offset, None)
+        return (RawObjectID(name), offset, None)
 
 
     def _unpack_name(self, i: int) -> bytes:
     def _unpack_name(self, i: int) -> bytes:
         offset = (0x100 * 4) + (i * 24) + 4
         offset = (0x100 * 4) + (i * 24) + 4
@@ -1011,9 +1020,9 @@ class PackIndex2(FilePackIndex):
             self
             self
         )
         )
 
 
-    def _unpack_entry(self, i: int) -> tuple[bytes, int, int]:
+    def _unpack_entry(self, i: int) -> tuple[RawObjectID, int, int]:
         return (
         return (
-            self._unpack_name(i),
+            RawObjectID(self._unpack_name(i)),
             self._unpack_offset(i),
             self._unpack_offset(i),
             self._unpack_crc32_checksum(i),
             self._unpack_crc32_checksum(i),
         )
         )
@@ -1091,9 +1100,9 @@ class PackIndex3(FilePackIndex):
             self
             self
         )
         )
 
 
-    def _unpack_entry(self, i: int) -> tuple[bytes, int, int]:
+    def _unpack_entry(self, i: int) -> tuple[RawObjectID, int, int]:
         return (
         return (
-            self._unpack_name(i),
+            RawObjectID(self._unpack_name(i)),
             self._unpack_offset(i),
             self._unpack_offset(i),
             self._unpack_crc32_checksum(i),
             self._unpack_crc32_checksum(i),
         )
         )
@@ -1390,7 +1399,9 @@ class PackStreamReader:
 
 
         pack_sha = bytearray(self._trailer)
         pack_sha = bytearray(self._trailer)
         if pack_sha != self.sha.digest():
         if pack_sha != self.sha.digest():
-            raise ChecksumMismatch(sha_to_hex(bytes(pack_sha)), self.sha.hexdigest())
+            raise ChecksumMismatch(
+                sha_to_hex(RawObjectID(bytes(pack_sha))), self.sha.hexdigest()
+            )
 
 
 
 
 class PackStreamCopier(PackStreamReader):
 class PackStreamCopier(PackStreamReader):
@@ -1663,7 +1674,7 @@ class PackData:
         self,
         progress: Callable[[int, int], None] | None = None,
         resolve_ext_ref: ResolveExtRefFn | None = None,
-    ) -> Iterator[tuple[bytes, int, int | None]]:
+    ) -> Iterator[PackIndexEntry]:
         """Yield entries summarizing the contents of this pack.
 
         Args:
@@ -1683,7 +1694,7 @@ class PackData:
         self,
         progress: ProgressFn | None = None,
         resolve_ext_ref: ResolveExtRefFn | None = None,
-    ) -> list[tuple[bytes, int, int]]:
+    ) -> list[tuple[RawObjectID, int, int]]:
         """Return entries in this pack, sorted by SHA.
 
         Args:
@@ -1883,7 +1894,7 @@ class DeltaChainIterator(Generic[T]):
         self._pending_ofs: dict[int, list[int]] = defaultdict(list)
         self._pending_ref: dict[bytes, list[int]] = defaultdict(list)
         self._full_ofs: list[tuple[int, int]] = []
-        self._ext_refs: list[bytes] = []
+        self._ext_refs: list[RawObjectID] = []
 
     @classmethod
     def for_pack_data(
@@ -1908,7 +1919,7 @@ class DeltaChainIterator(Generic[T]):
     def for_pack_subset(
         cls,
         pack: "Pack",
-        shas: Iterable[bytes],
+        shas: Iterable[ObjectID | RawObjectID],
         *,
         allow_missing: bool = False,
         resolve_ext_ref: ResolveExtRefFn | None = None,
@@ -1928,7 +1939,6 @@ class DeltaChainIterator(Generic[T]):
         walker.set_pack_data(pack.data)
         todo = set()
         for sha in shas:
-            assert isinstance(sha, bytes)
             try:
                 off = pack.index.object_offset(sha)
             except KeyError:
@@ -1951,7 +1961,7 @@
             elif unpacked.pack_type_num == REF_DELTA:
                 with suppress(KeyError):
                     assert isinstance(unpacked.delta_base, bytes)
-                    base_ofs = pack.index.object_index(unpacked.delta_base)
+                    base_ofs = pack.index.object_index(RawObjectID(unpacked.delta_base))
             if base_ofs is not None and base_ofs not in done:
                 todo.add(base_ofs)
         return walker
@@ -1992,7 +2002,9 @@
 
     def _ensure_no_pending(self) -> None:
         if self._pending_ref:
-            raise UnresolvedDeltas([sha_to_hex(s) for s in self._pending_ref])
+            raise UnresolvedDeltas(
+                [sha_to_hex(RawObjectID(s)) for s in self._pending_ref]
+            )
 
     def _walk_ref_chains(self) -> Iterator[T]:
         if not self._resolve_ext_ref:
@@ -2009,7 +2021,7 @@
                 # get popped via a _follow_chain call, or we will raise an
                 # error below.
                 continue
-            self._ext_refs.append(base_sha)
+            self._ext_refs.append(RawObjectID(base_sha))
             self._pending_ref.pop(base_sha)
             for new_offset in pending:
                 yield from self._follow_chain(new_offset, type_num, chunks)  # type: ignore[arg-type]
@@ -2063,7 +2075,7 @@
         """Iterate over objects in the pack."""
         return self._walk_all_chains()
 
-    def ext_refs(self) -> list[bytes]:
+    def ext_refs(self) -> list[RawObjectID]:
         """Return external references."""
         return self._ext_refs
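
Wrapping ``base_sha`` before appending keeps ``_ext_refs`` uniformly a ``list[RawObjectID]``, so mixing in a hex ID becomes a type error instead of a silent bug. A self-contained illustration (``lookup_offset`` is a hypothetical stand-in for an index lookup)::

    from typing import NewType

    RawObjectID = NewType("RawObjectID", bytes)
    ObjectID = NewType("ObjectID", bytes)

    def lookup_offset(sha: RawObjectID) -> int:
        return 0  # placeholder body

    hex_id = ObjectID(b"ab" * 20)  # 40-byte hex ID
    # lookup_offset(hex_id)        # mypy: incompatible type "ObjectID"
    raw = RawObjectID(bytes.fromhex(hex_id.decode("ascii")))
    lookup_offset(raw)             # fine after an explicit conversion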
 
 
@@ -2088,7 +2100,7 @@ class PackIndexer(DeltaChainIterator[PackIndexEntry]):
 
     _compute_crc32 = True
 
-    def _result(self, unpacked: UnpackedObject) -> tuple[bytes, int, int | None]:
+    def _result(self, unpacked: UnpackedObject) -> PackIndexEntry:
         """Convert unpacked object to pack index entry.
 
         Args:
@@ -2154,9 +2166,12 @@ class SHA1Reader(BinaryIO):
         # If git option index.skipHash is set the index will be empty
         if stored != self.sha1.digest() and (
             not allow_empty
-            or sha_to_hex(stored) != b"0000000000000000000000000000000000000000"
+            or sha_to_hex(RawObjectID(stored))
+            != b"0000000000000000000000000000000000000000"
         ):
-            raise ChecksumMismatch(self.sha1.hexdigest(), sha_to_hex(stored))
+            raise ChecksumMismatch(
+                self.sha1.hexdigest(), sha_to_hex(RawObjectID(stored))
+            )
 
     def close(self) -> None:
         """Close the underlying file."""
@@ -2595,9 +2610,9 @@ def write_pack_header(
 
 def find_reusable_deltas(
     container: PackedObjectContainer,
-    object_ids: Set[bytes],
+    object_ids: Set[ObjectID],
     *,
-    other_haves: Set[bytes] | None = None,
+    other_haves: Set[ObjectID] | None = None,
     progress: Callable[..., None] | None = None,
 ) -> Iterator[UnpackedObject]:
     """Find deltas in a pack that can be reused.
@@ -2799,7 +2814,7 @@ def generate_unpacked_objects(
     deltify: bool | None = None,
     reuse_deltas: bool = True,
     ofs_delta: bool = True,
-    other_haves: set[bytes] | None = None,
+    other_haves: set[ObjectID] | None = None,
     progress: Callable[..., None] | None = None,
 ) -> Iterator[UnpackedObject]:
     """Create pack data from objects.
@@ -2811,7 +2826,7 @@ def generate_unpacked_objects(
         for unpack in find_reusable_deltas(
             container, set(todo), other_haves=other_haves, progress=progress
         ):
-            del todo[sha_to_hex(unpack.sha())]
+            del todo[sha_to_hex(RawObjectID(unpack.sha()))]
             yield unpack
     if deltify is None:
         # PERFORMANCE/TODO(jelmer): This should be enabled but is *much* too
@@ -2860,7 +2875,7 @@ def write_pack_from_container(
     deltify: bool | None = None,
     reuse_deltas: bool = True,
     compression_level: int = -1,
-    other_haves: set[bytes] | None = None,
+    other_haves: set[ObjectID] | None = None,
 ) -> tuple[dict[bytes, tuple[int, int]], bytes]:
     """Write a new pack data file.
 
@@ -3535,7 +3550,7 @@ class Pack:
     def ensure_bitmap(
         self,
         object_store: "BaseObjectStore",
-        refs: dict[bytes, bytes],
+        refs: dict["Ref", "ObjectID"],
         commit_interval: int | None = None,
         progress: Callable[[str], None] | None = None,
     ) -> "PackBitmap":
@@ -3618,7 +3633,7 @@ class Pack:
         """Return string representation of this pack."""
         return f"{self.__class__.__name__}({self._basename!r})"
 
-    def __iter__(self) -> Iterator[bytes]:
+    def __iter__(self) -> Iterator[ObjectID]:
         """Iterate over all the sha1s of the objects in this pack."""
         return iter(self.index)
 
@@ -3634,8 +3649,8 @@
             and idx_stored_checksum != data_stored_checksum
         ):
             raise ChecksumMismatch(
-                sha_to_hex(idx_stored_checksum),
-                sha_to_hex(data_stored_checksum),
+                sha_to_hex(RawObjectID(idx_stored_checksum)),
+                sha_to_hex(RawObjectID(data_stored_checksum)),
             )
 
     def check(self) -> None:
@@ -3658,7 +3673,7 @@ class Pack:
         """Return pack tuples for all objects in pack."""
         return [(o, None) for o in self.iterobjects()]
 
-    def __contains__(self, sha1: bytes) -> bool:
+    def __contains__(self, sha1: ObjectID | RawObjectID) -> bool:
         """Check whether this pack contains a particular SHA1."""
         try:
             self.index.object_offset(sha1)
@@ -3666,14 +3681,14 @@
         except KeyError:
             return False
 
-    def get_raw(self, sha1: bytes) -> tuple[int, bytes]:
+    def get_raw(self, sha1: RawObjectID | ObjectID) -> tuple[int, bytes]:
         """Get raw object data by SHA1."""
         offset = self.index.object_offset(sha1)
         obj_type, obj = self.data.get_object_at(offset)
         type_num, chunks = self.resolve_object(offset, obj_type, obj)
         return type_num, b"".join(chunks)  # type: ignore[arg-type]
 
-    def __getitem__(self, sha1: bytes) -> ShaFile:
+    def __getitem__(self, sha1: "ObjectID | RawObjectID") -> ShaFile:
         """Retrieve the specified SHA1."""
         type, uncomp = self.get_raw(sha1)
         return ShaFile.from_raw_string(type, uncomp, sha=sha1)
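
``__contains__``, ``get_raw`` and ``__getitem__`` now advertise ``ObjectID | RawObjectID``, reflecting that a pack index can be probed with either flavour. A hedged usage sketch (repository path and SHA are placeholders)::

    from dulwich.repo import Repo

    repo = Repo("/path/to/repo")  # hypothetical path
    hex_id = b"a" * 40            # stand-in for a real 40-character hex SHA

    for pack in repo.object_store.packs:
        if hex_id in pack:  # either ID flavour is accepted
            type_num, data = pack.get_raw(hex_id)
            print(pack, type_num, len(data))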
@@ -3701,7 +3716,7 @@ class Pack:
 
     def iter_unpacked_subset(
         self,
-        shas: Iterable[ObjectID],
+        shas: Iterable[ObjectID | RawObjectID],
         *,
         include_comp: bool = False,
         allow_missing: bool = False,
@@ -3710,12 +3725,12 @@
         """Iterate over unpacked objects in subset."""
         ofs_pending: dict[int, list[UnpackedObject]] = defaultdict(list)
         ofs: dict[int, bytes] = {}
-        todo = set(shas)
+        todo: set[ObjectID | RawObjectID] = set(shas)
         for unpacked in self.iter_unpacked(include_comp=include_comp):
             sha = unpacked.sha()
             if unpacked.offset is not None:
                 ofs[unpacked.offset] = sha
-            hexsha = sha_to_hex(sha)
+            hexsha = sha_to_hex(RawObjectID(sha))
             if hexsha in todo:
                 if unpacked.pack_type_num == OFS_DELTA:
                     assert isinstance(unpacked.delta_base, int)
@@ -3766,7 +3781,9 @@ class Pack:
                 keepfile.write(b"\n")
         return keepfile_name
 
-    def get_ref(self, sha: bytes) -> tuple[int | None, int, OldUnpackedObject]:
+    def get_ref(
+        self, sha: RawObjectID | ObjectID
+    ) -> tuple[int | None, int, OldUnpackedObject]:
         """Get the object for a ref SHA, only looking in this pack."""
         # TODO: cache these results
         try:
@@ -3786,7 +3803,9 @@
         offset: int,
         type: int,
         obj: OldUnpackedObject,
-        get_ref: Callable[[bytes], tuple[int | None, int, OldUnpackedObject]]
+        get_ref: Callable[
+            [RawObjectID | ObjectID], tuple[int | None, int, OldUnpackedObject]
+        ]
         | None = None,
     ) -> tuple[int, OldUnpackedObject]:
         """Resolve an object, possibly resolving deltas when necessary.
@@ -3795,7 +3814,7 @@ class Pack:
         """
         # Walk down the delta chain, building a stack of deltas to reach
         # the requested object.
-        base_offset = offset
+        base_offset: int | None = offset
         base_type = type
         base_obj = obj
         delta_stack = []
@@ -3809,13 +3828,14 @@
                 assert isinstance(delta_offset, int), (
                     f"Expected int, got {delta_offset.__class__}"
                 )
+                assert base_offset is not None
                 base_offset = base_offset - delta_offset
                 base_type, base_obj = self.data.get_object_at(base_offset)
                 assert isinstance(base_type, int)
             elif base_type == REF_DELTA:
                 (basename, delta) = base_obj
                 assert isinstance(basename, bytes) and len(basename) == 20
-                base_offset, base_type, base_obj = get_ref(basename)  # type: ignore[assignment]
+                base_offset, base_type, base_obj = get_ref(cast(RawObjectID, basename))
                 assert isinstance(base_type, int)
                 if base_offset == prev_offset:  # object is based on itself
                     raise UnresolvedDeltas([basename])
@@ -3876,7 +3896,11 @@ class Pack:
         )
 
     def get_unpacked_object(
-        self, sha: bytes, *, include_comp: bool = False, convert_ofs_delta: bool = True
+        self,
+        sha: ObjectID | RawObjectID,
+        *,
+        include_comp: bool = False,
+        convert_ofs_delta: bool = True,
     ) -> UnpackedObject:
         """Get the unpacked object for a sha.
 
@@ -3896,12 +3920,12 @@
 
 def extend_pack(
     f: BinaryIO,
-    object_ids: Set[ObjectID],
-    get_raw: Callable[[ObjectID], tuple[int, bytes]],
+    object_ids: Set["RawObjectID"],
+    get_raw: Callable[["RawObjectID | ObjectID"], tuple[int, bytes]],
     *,
     compression_level: int = -1,
     progress: Callable[[bytes], None] | None = None,
-) -> tuple[bytes, list[tuple[bytes, int, int]]]:
+) -> tuple[bytes, list[tuple["RawObjectID", int, int]]]:
     """Extend a pack file with more objects.
 
     The caller should make sure that object_ids does not contain any objects
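
``extend_pack`` stays entirely at the raw-digest level: both the input set and the returned index entries carry ``RawObjectID``. Raw IDs are what ``hashlib`` produces directly; for instance (``raw_id`` is an illustrative helper, not part of dulwich)::

    import hashlib
    from typing import NewType

    RawObjectID = NewType("RawObjectID", bytes)  # illustrative definition

    def raw_id(serialized: bytes) -> RawObjectID:
        # SHA-1 over the serialized object, git-style header included.
        return RawObjectID(hashlib.sha1(serialized).digest())

    empty_blob = raw_id(b"blob 0\x00")  # header of an empty blob
    print(empty_blob.hex())             # e69de29bb2d1d6434b8b29ae775ad8c2e48c5391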

+ 9 - 7
dulwich/patch.py

@@ -43,7 +43,7 @@ from typing import (
 if TYPE_CHECKING:
     from .object_store import BaseObjectStore
 
-from .objects import S_ISGITLINK, Blob, Commit
+from .objects import S_ISGITLINK, Blob, Commit, ObjectID, RawObjectID
 
 FIRST_FEW_BYTES = 8000
 
@@ -361,8 +361,8 @@ def patch_filename(p: bytes | None, root: bytes) -> bytes:
 def write_object_diff(
     f: IO[bytes],
     store: "BaseObjectStore",
-    old_file: tuple[bytes | None, int | None, bytes | None],
-    new_file: tuple[bytes | None, int | None, bytes | None],
+    old_file: tuple[bytes | None, int | None, ObjectID | None],
+    new_file: tuple[bytes | None, int | None, ObjectID | None],
     diff_binary: bool = False,
     diff_algorithm: str | None = None,
 ) -> None:
@@ -384,7 +384,7 @@ def write_object_diff(
     patched_old_path = patch_filename(old_path, b"a")
     patched_new_path = patch_filename(new_path, b"b")
 
-    def content(mode: int | None, hexsha: bytes | None) -> Blob:
+    def content(mode: int | None, hexsha: ObjectID | None) -> Blob:
         """Get blob content for a file.
 
         Args:
@@ -542,8 +542,8 @@ def write_blob_diff(
 def write_tree_diff(
     f: IO[bytes],
     store: "BaseObjectStore",
-    old_tree: bytes | None,
-    new_tree: bytes | None,
+    old_tree: ObjectID | None,
+    new_tree: ObjectID | None,
     diff_binary: bool = False,
     diff_algorithm: str | None = None,
 ) -> None:
@@ -731,7 +731,9 @@ def patch_id(diff_data: bytes) -> bytes:
     return hashlib.sha1(normalized).hexdigest().encode("ascii")
 
 
-def commit_patch_id(store: "BaseObjectStore", commit_id: bytes) -> bytes:
+def commit_patch_id(
+    store: "BaseObjectStore", commit_id: ObjectID | RawObjectID
+) -> bytes:
     """Compute patch ID for a commit.
 
     Args:
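
With ``old_tree``/``new_tree`` typed as ``ObjectID | None``, tree IDs taken from ``Commit.tree`` flow through unchanged. A sketch of diffing HEAD against its first parent (path is a placeholder, and HEAD is assumed to have a parent)::

    import sys
    from dulwich.objects import Commit
    from dulwich.patch import write_tree_diff
    from dulwich.repo import Repo

    repo = Repo("/path/to/repo")  # hypothetical path
    head = repo[repo.head()]
    assert isinstance(head, Commit)
    parent = repo[head.parents[0]]
    assert isinstance(parent, Commit)

    # Both tree arguments are ObjectID | None under the new annotations.
    write_tree_diff(sys.stdout.buffer, repo.object_store, parent.tree, head.tree)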

+ 137 - 127
dulwich/porcelain.py

@@ -169,6 +169,7 @@ from .object_store import BaseObjectStore, tree_lookup_path
 from .objects import (
     Blob,
     Commit,
+    ObjectID,
     Tag,
     Tree,
     TreeEntry,
@@ -193,6 +194,7 @@ from .patch import (
 )
 from .protocol import ZERO_SHA, Protocol
 from .refs import (
+    HEADREF,
     LOCAL_BRANCH_PREFIX,
     LOCAL_NOTES_PREFIX,
     LOCAL_REMOTE_PREFIX,
@@ -543,7 +545,7 @@ class DivergedBranches(Error):
         self.new_sha = new_sha
 
 
-def check_diverged(repo: BaseRepo, current_sha: bytes, new_sha: bytes) -> None:
+def check_diverged(repo: BaseRepo, current_sha: ObjectID, new_sha: ObjectID) -> None:
     """Check if updating to a sha can be done with fast forwarding.
 
     Args:
@@ -625,7 +627,7 @@ def symbolic_ref(repo: RepoPath, ref_name: str | bytes, force: bool = False) ->
                 else ref_name
             )
             raise Error(f"fatal: ref `{ref_name_str}` is not a ref")
-        repo_obj.refs.set_symbolic_ref(b"HEAD", ref_path)
+        repo_obj.refs.set_symbolic_ref(HEADREF, ref_path)
 
 
 def pack_refs(repo: RepoPath, all: bool = False) -> None:
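
Throughout porcelain the bare ``b"HEAD"`` literal gives way to ``HEADREF`` from ``dulwich.refs``: with ``Ref`` as a newtype, a plain bytes literal no longer type-checks where a ``Ref`` is required. The constant is presumably defined along these lines (illustrative)::

    from typing import NewType

    Ref = NewType("Ref", bytes)
    HEADREF = Ref(b"HEAD")

    def follow(ref: Ref) -> bytes:
        return bytes(ref)  # stand-in for RefsContainer.follow()

    follow(HEADREF)    # ok
    # follow(b"HEAD")  # mypy: expected "Ref", got "bytes"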
@@ -878,7 +880,7 @@ def commit(
             )
             # Update HEAD to point to the new commit with reflog message
             try:
-                old_head = r.refs[b"HEAD"]
+                old_head = r.refs[HEADREF]
             except KeyError:
                 old_head = None
 
@@ -892,7 +894,7 @@ def commit(
                 default_message = default_message[:97] + b"..."
             reflog_message = _get_reflog_message(default_message)
 
-            r.refs.set_if_equals(b"HEAD", old_head, commit_sha, message=reflog_message)
+            r.refs.set_if_equals(HEADREF, old_head, commit_sha, message=reflog_message)
             return commit_sha
         else:
             return r.get_worktree().commit(
@@ -911,11 +913,11 @@ def commit(
 
 def commit_tree(
     repo: RepoPath,
-    tree: bytes,
+    tree: ObjectID,
     message: str | bytes | None = None,
     author: bytes | None = None,
     committer: bytes | None = None,
-) -> bytes:
+) -> ObjectID:
     """Create a new commit object.
 
     Args:
@@ -2002,18 +2004,18 @@ def diff_tree(
     """
     with open_repo_closing(repo) as r:
         if isinstance(old_tree, Tree):
-            old_tree_id: bytes | None = old_tree.id
+            old_tree_id: ObjectID | None = old_tree.id
         elif isinstance(old_tree, str):
-            old_tree_id = old_tree.encode()
+            old_tree_id = ObjectID(old_tree.encode())
         else:
-            old_tree_id = old_tree
+            old_tree_id = ObjectID(old_tree)
 
         if isinstance(new_tree, Tree):
-            new_tree_id: bytes | None = new_tree.id
+            new_tree_id: ObjectID | None = new_tree.id
         elif isinstance(new_tree, str):
-            new_tree_id = new_tree.encode()
+            new_tree_id = ObjectID(new_tree.encode())
         else:
-            new_tree_id = new_tree
+            new_tree_id = ObjectID(new_tree)
 
         write_tree_diff(outstream, r.object_store, old_tree_id, new_tree_id)
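
``diff_tree`` keeps its permissive public signature (``Tree``, ``str`` or ``bytes``) and normalizes to ``ObjectID`` internally, so callers never construct the newtype themselves. For example (placeholder path and tree IDs; treat as a sketch)::

    import sys
    from dulwich import porcelain

    porcelain.diff_tree(
        "/path/to/repo",  # hypothetical repository path
        "a" * 40,         # old tree as a hex str
        b"b" * 40,        # new tree as hex bytes
        outstream=sys.stdout.buffer,
    )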
 
 
@@ -2323,7 +2325,7 @@ def submodule_update(
                     sub_config.write_to_path()
 
                     # Checkout the target commit
-                    sub_repo.refs[b"HEAD"] = target_sha
+                    sub_repo.refs[HEADREF] = target_sha
 
                     # Build the index and checkout files
                     tree = sub_repo[target_sha]
@@ -2346,7 +2348,7 @@ def submodule_update(
                     client.fetch(path_segments.encode(), sub_repo)
 
                     # Update to the target commit
-                    sub_repo.refs[b"HEAD"] = target_sha
+                    sub_repo.refs[HEADREF] = target_sha
 
                     # Reset the working directory
                     reset(sub_repo, "hard", target_sha)
@@ -2513,7 +2515,7 @@ def verify_tag(
         tag_obj.verify(keyids)
 
 
-def tag_list(repo: RepoPath, outstream: TextIO = sys.stdout) -> list[bytes]:
+def tag_list(repo: RepoPath, outstream: TextIO = sys.stdout) -> list[Ref]:
     """List all tags.
 
     Args:
@@ -2521,7 +2523,7 @@ def tag_list(repo: RepoPath, outstream: TextIO = sys.stdout) -> list[bytes]:
       outstream: Stream to write tags to
     """
     with open_repo_closing(repo) as r:
-        tags = sorted(r.refs.as_dict(b"refs/tags"))
+        tags: list[Ref] = sorted(r.refs.as_dict(Ref(b"refs/tags")))
         return tags
 
 
@@ -2666,7 +2668,7 @@ def notes_show(
         return r.notes.get_note(object_sha, notes_ref, config=config)
 
 
-def notes_list(repo: RepoPath, ref: bytes = b"commits") -> list[tuple[bytes, bytes]]:
+def notes_list(repo: RepoPath, ref: bytes = b"commits") -> list[tuple[ObjectID, bytes]]:
     """List all notes in a notes ref.
 
     Args:
@@ -2686,7 +2688,7 @@ def notes_list(repo: RepoPath, ref: bytes = b"commits") -> list[tuple[bytes, byt
         return r.notes.list_notes(notes_ref, config=config)
 
 
-def replace_list(repo: RepoPath) -> list[tuple[bytes, bytes]]:
+def replace_list(repo: RepoPath) -> list[tuple[ObjectID, ObjectID]]:
     """List all replacement refs.
 
     Args:
@@ -2697,16 +2699,16 @@ def replace_list(repo: RepoPath) -> list[tuple[bytes, bytes]]:
       object being replaced and replacement_sha is what it's replaced with
     """
     with open_repo_closing(repo) as r:
-        replacements = []
+        replacements: list[tuple[ObjectID, ObjectID]] = []
         for ref in r.refs.keys():
             if ref.startswith(LOCAL_REPLACE_PREFIX):
-                object_sha = ref[len(LOCAL_REPLACE_PREFIX) :]
+                object_sha = ObjectID(ref[len(LOCAL_REPLACE_PREFIX) :])
                 replacement_sha = r.refs[ref]
                 replacements.append((object_sha, replacement_sha))
         return replacements
 
 
-def replace_delete(repo: RepoPath, object_sha: bytes | str) -> None:
+def replace_delete(repo: RepoPath, object_sha: ObjectID | str) -> None:
     """Delete a replacement ref.
 
     Args:
@@ -2714,24 +2716,24 @@ def replace_delete(repo: RepoPath, object_sha: bytes | str) -> None:
       object_sha: SHA of the object whose replacement should be removed
     """
     with open_repo_closing(repo) as r:
-        # Convert to bytes if string
+        # Convert to ObjectID if string
         if isinstance(object_sha, str):
-            object_sha_hex = object_sha.encode("ascii")
+            object_sha_id = ObjectID(object_sha.encode("ascii"))
         else:
-            object_sha_hex = object_sha
+            object_sha_id = object_sha
 
-        replace_ref = _make_replace_ref(object_sha_hex)
+        replace_ref = _make_replace_ref(object_sha_id)
         if replace_ref not in r.refs:
             raise KeyError(
-                f"No replacement ref found for {object_sha_hex.decode('ascii')}"
+                f"No replacement ref found for {object_sha_id.decode('ascii')}"
             )
         del r.refs[replace_ref]
 
 
 def replace_create(
     repo: RepoPath,
-    object_sha: str | bytes,
-    replacement_sha: str | bytes,
+    object_sha: str | ObjectID,
+    replacement_sha: str | ObjectID,
 ) -> None:
     """Create a replacement ref to replace one object with another.
 
@@ -2741,20 +2743,20 @@ def replace_create(
       replacement_sha: SHA of the replacement object
     """
     with open_repo_closing(repo) as r:
-        # Convert to bytes if string
+        # Convert to ObjectID if string
        if isinstance(object_sha, str):
-            object_sha_hex = object_sha.encode("ascii")
+            object_sha_id = ObjectID(object_sha.encode("ascii"))
         else:
-            object_sha_hex = object_sha
+            object_sha_id = object_sha
 
         if isinstance(replacement_sha, str):
-            replacement_sha_hex = replacement_sha.encode("ascii")
+            replacement_sha_id = ObjectID(replacement_sha.encode("ascii"))
         else:
-            replacement_sha_hex = replacement_sha
+            replacement_sha_id = replacement_sha
 
         # Create the replacement ref
-        replace_ref = _make_replace_ref(object_sha_hex)
-        r.refs[replace_ref] = replacement_sha_hex
+        replace_ref = _make_replace_ref(object_sha_id)
+        r.refs[replace_ref] = replacement_sha_id
 
 
 def reset(
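
Both ``replace_create`` and ``replace_delete`` accept the SHA as ``str`` or ``ObjectID`` and encode internally, so string-based callers keep working unchanged. A usage sketch with placeholder SHAs::

    from dulwich import porcelain

    broken = "a" * 40  # placeholder 40-character hex SHAs
    fixed = "b" * 40

    porcelain.replace_create("/path/to/repo", broken, fixed)
    porcelain.replace_delete("/path/to/repo", broken)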
@@ -2783,7 +2785,7 @@ def reset(
         if target_commit is not None:
             # Get the current HEAD value for set_if_equals
             try:
-                old_head = r.refs[b"HEAD"]
+                old_head = r.refs[HEADREF]
             except KeyError:
                 old_head = None
 
@@ -2800,7 +2802,7 @@ def reset(
 
             # Update HEAD with reflog message
             r.refs.set_if_equals(
-                b"HEAD", old_head, target_commit.id, message=reflog_message
+                HEADREF, old_head, target_commit.id, message=reflog_message
             )
 
         if mode == "soft":
@@ -2978,14 +2980,14 @@ def push(
         )
 
         selected_refs = []
-        remote_changed_refs: dict[bytes, bytes | None] = {}
+        remote_changed_refs: dict[Ref, ObjectID | None] = {}
 
-        def update_refs(refs: dict[bytes, bytes]) -> dict[bytes, bytes]:
-            remote_refs = DictRefsContainer(refs)
+        def update_refs(refs: dict[Ref, ObjectID]) -> dict[Ref, ObjectID]:
+            remote_refs = DictRefsContainer(refs)  # type: ignore[arg-type]
             selected_refs.extend(
                 parse_reftuples(r.refs, remote_refs, refspecs_bytes, force=force)
             )
-            new_refs = {}
+            new_refs: dict[Ref, ObjectID] = {}
 
             # In mirror mode, delete remote refs that don't exist locally
             if mirror_mode:
@@ -3019,8 +3021,8 @@ def push(
         try:
 
             def generate_pack_data_wrapper(
-                have: AbstractSet[bytes],
-                want: AbstractSet[bytes],
+                have: AbstractSet[ObjectID],
+                want: AbstractSet[ObjectID],
                 *,
                 ofs_delta: bool = False,
                 progress: Callable[..., None] | None = None,
@@ -3046,7 +3048,7 @@ def push(
                 b"Push to " + remote_location.encode(err_encoding) + b" successful.\n"
             )
 
-        for ref, error in (result.ref_status or {}).items():
+        for ref, error in (result.ref_status or {}).items():  # type: ignore[assignment]
             if error is not None:
                 errstream.write(
                     f"Push of ref {ref.decode('utf-8', 'replace')} failed: {error}\n".encode(
@@ -3125,9 +3127,9 @@ def pull(
                     refspecs_normalized.append(spec)
 
         def determine_wants(
-            remote_refs: dict[bytes, bytes], depth: int | None = None
-        ) -> list[bytes]:
-            remote_refs_container = DictRefsContainer(remote_refs)
+            remote_refs: dict[Ref, ObjectID], depth: int | None = None
+        ) -> list[ObjectID]:
+            remote_refs_container = DictRefsContainer(remote_refs)  # type: ignore[arg-type]
             selected_refs.extend(
                 parse_reftuples(
                     remote_refs_container, r.refs, refspecs_normalized, force=force
@@ -3166,7 +3168,7 @@ def pull(
 
         # Store the old HEAD tree before making changes
         try:
-            old_head = r.refs[b"HEAD"]
+            old_head = r.refs[HEADREF]
             old_commit = r[old_head]
             assert isinstance(old_commit, Commit)
             old_tree_id = old_commit.tree
@@ -3202,7 +3204,7 @@ def pull(
             if rh is not None and lh is not None:
                 lh_value = fetch_result.refs[lh]
                 if lh_value is not None:
-                    r.refs[rh] = lh_value
+                    r.refs[Ref(rh)] = lh_value
 
         # Only update HEAD if we didn't perform a merge
         if selected_refs and not merged:
@@ -3833,7 +3835,7 @@ def _make_tag_ref(name: str | bytes) -> Ref:
     return local_tag_name(name)
 
 
-def _make_replace_ref(name: str | bytes) -> Ref:
+def _make_replace_ref(name: str | bytes | ObjectID) -> Ref:
     if isinstance(name, str):
         name = name.encode(DEFAULT_ENCODING)
     return local_replace_name(name)
@@ -3881,7 +3883,7 @@ def branch_create(
             else objectish
         )
 
-        if b"refs/remotes/" + objectish_bytes in r.refs:
+        if Ref(b"refs/remotes/" + objectish_bytes) in r.refs:
            objectish = b"refs/remotes/" + objectish_bytes
         elif local_branch_name(objectish_bytes) in r.refs:
             objectish = local_branch_name(objectish_bytes)
@@ -3918,15 +3920,15 @@ def branch_create(
                 else original_objectish
             )
 
-            if objectish_bytes in r.refs:
+            if Ref(objectish_bytes) in r.refs:
                 objectish_ref = objectish_bytes
-            elif b"refs/remotes/" + objectish_bytes in r.refs:
+            elif Ref(b"refs/remotes/" + objectish_bytes) in r.refs:
                 objectish_ref = b"refs/remotes/" + objectish_bytes
             elif local_branch_name(objectish_bytes) in r.refs:
                 objectish_ref = local_branch_name(objectish_bytes)
         else:
             # HEAD might point to a remote-tracking branch
-            head_ref = r.refs.follow(b"HEAD")[0][1]
+            head_ref = r.refs.follow(HEADREF)[0][1]
             if head_ref.startswith(b"refs/remotes/"):
                 objectish_ref = head_ref
 
@@ -3974,7 +3976,7 @@ def filter_branches_by_pattern(branches: Iterable[bytes], pattern: str) -> list[
     ]
 
 
-def branch_list(repo: RepoPath) -> list[bytes]:
+def branch_list(repo: RepoPath) -> list[Ref]:
     """List all branches.
 
     Args:
@@ -3983,7 +3985,7 @@ def branch_list(repo: RepoPath) -> list[bytes]:
       List of branch names (without refs/heads/ prefix)
     """
     with open_repo_closing(repo) as r:
-        branches = list(r.refs.keys(base=LOCAL_BRANCH_PREFIX))
+        branches: list[Ref] = list(r.refs.keys(base=Ref(LOCAL_BRANCH_PREFIX)))
 
         # Check for branch.sort configuration
         config = r.get_config_stack()
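
``branch_list`` now advertises ``list[Ref]`` rather than ``list[bytes]``, but since ``NewType`` is erased at runtime, callers that treat the entries as plain bytes are unaffected::

    from dulwich import porcelain

    for name in porcelain.branch_list("/path/to/repo"):  # hypothetical path
        assert isinstance(name, bytes)  # Ref is plain bytes at runtime
        print(name.decode("utf-8", "replace"))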
@@ -4040,7 +4042,7 @@ def branch_remotes_list(repo: RepoPath) -> list[bytes]:
       List of branch names (without refs/remotes/ prefix, and without remote name; e.g. 'main' from 'origin/main')
       List of branch names (without refs/remotes/ prefix, and without remote name; e.g. 'main' from 'origin/main')
     """
     """
     with open_repo_closing(repo) as r:
     with open_repo_closing(repo) as r:
-        branches = list(r.refs.keys(base=LOCAL_REMOTE_PREFIX))
+        branches = [bytes(ref) for ref in r.refs.keys(base=Ref(LOCAL_REMOTE_PREFIX))]
 
 
         config = r.get_config_stack()
         config = r.get_config_stack()
         try:
         try:
@@ -4063,7 +4065,7 @@ def branch_remotes_list(repo: RepoPath) -> list[bytes]:
             # Sort by date
             # Sort by date
             def get_commit_date(branch_name: bytes) -> int:
             def get_commit_date(branch_name: bytes) -> int:
                 ref = LOCAL_REMOTE_PREFIX + branch_name
                 ref = LOCAL_REMOTE_PREFIX + branch_name
-                sha = r.refs[ref]
+                sha = r.refs[Ref(ref)]
                 commit = r.object_store[sha]
                 commit = r.object_store[sha]
                 assert isinstance(commit, Commit)
                 assert isinstance(commit, Commit)
                 if sort_key == "committerdate":
                 if sort_key == "committerdate":
@@ -4100,9 +4102,9 @@ def _get_branch_merge_status(repo: RepoPath) -> Iterator[tuple[bytes, bool]]:
         - ``is_merged``: True if branch is merged into HEAD, False otherwise
         - ``is_merged``: True if branch is merged into HEAD, False otherwise
     """
     """
     with open_repo_closing(repo) as r:
     with open_repo_closing(repo) as r:
-        current_sha = r.refs[b"HEAD"]
+        current_sha = r.refs[HEADREF]
 
 
-        for branch_ref, branch_sha in r.refs.as_dict(base=b"refs/heads/").items():
+        for branch_ref, branch_sha in r.refs.as_dict(base=Ref(b"refs/heads/")).items():
             # Check if branch is an ancestor of HEAD (fully merged)
             # Check if branch is an ancestor of HEAD (fully merged)
             is_merged = can_fast_forward(r, branch_sha, current_sha)
             is_merged = can_fast_forward(r, branch_sha, current_sha)
             yield branch_ref, is_merged
             yield branch_ref, is_merged
@@ -4154,7 +4156,9 @@ def branches_containing(repo: RepoPath, commit: str) -> Iterator[bytes]:
         commit_obj = parse_commit(r, commit)
         commit_obj = parse_commit(r, commit)
         commit_sha = commit_obj.id
         commit_sha = commit_obj.id
 
 
-        for branch_ref, branch_sha in r.refs.as_dict(base=LOCAL_BRANCH_PREFIX).items():
+        for branch_ref, branch_sha in r.refs.as_dict(
+            base=Ref(LOCAL_BRANCH_PREFIX)
+        ).items():
             if can_fast_forward(r, commit_sha, branch_sha):
             if can_fast_forward(r, commit_sha, branch_sha):
                 yield branch_ref
                 yield branch_ref
 
 
@@ -4171,7 +4175,7 @@ def active_branch(repo: RepoPath) -> bytes:
       IndexError: if HEAD is floating
       IndexError: if HEAD is floating
     """
     """
     with open_repo_closing(repo) as r:
     with open_repo_closing(repo) as r:
-        active_ref = r.refs.follow(b"HEAD")[0][1]
+        active_ref = r.refs.follow(HEADREF)[0][1]
         if not active_ref.startswith(LOCAL_BRANCH_PREFIX):
         if not active_ref.startswith(LOCAL_BRANCH_PREFIX):
             raise ValueError(active_ref)
             raise ValueError(active_ref)
         return active_ref[len(LOCAL_BRANCH_PREFIX) :]
         return active_ref[len(LOCAL_BRANCH_PREFIX) :]
@@ -4352,7 +4356,7 @@ def for_each_ref(
         refs = r.get_refs()
         refs = r.get_refs()
 
 
     if pattern:
     if pattern:
-        matching_refs: dict[bytes, bytes] = {}
+        matching_refs: dict[Ref, ObjectID] = {}
         pattern_parts = pattern.split(b"/")
         pattern_parts = pattern.split(b"/")
         for ref, sha in refs.items():
         for ref, sha in refs.items():
             matches = False
             matches = False
@@ -4429,12 +4433,12 @@ def show_ref(
             filtered_refs = filter_ref_prefix(refs, [b"refs/"])
             filtered_refs = filter_ref_prefix(refs, [b"refs/"])
 
 
         # Add HEAD if requested
         # Add HEAD if requested
-        if head and b"HEAD" in refs:
-            filtered_refs[b"HEAD"] = refs[b"HEAD"]
+        if head and HEADREF in refs:
+            filtered_refs[HEADREF] = refs[HEADREF]
 
 
         # Filter by patterns if specified
         # Filter by patterns if specified
         if byte_patterns:
         if byte_patterns:
-            matching_refs: dict[bytes, bytes] = {}
+            matching_refs: dict[Ref, ObjectID] = {}
             for ref, sha in filtered_refs.items():
             for ref, sha in filtered_refs.items():
                 for pattern in byte_patterns:
                 for pattern in byte_patterns:
                     if verify:
                     if verify:
@@ -4527,7 +4531,7 @@ def show_branch(
         refs = r.get_refs()
         refs = r.get_refs()
 
 
         # Determine which branches to show
         # Determine which branches to show
-        branch_refs: dict[bytes, bytes] = {}
+        branch_refs: dict[Ref, ObjectID] = {}
 
 
         if branches:
         if branches:
             # Specific branches requested
             # Specific branches requested
@@ -4536,18 +4540,19 @@ def show_branch(
                     os.fsencode(branch) if isinstance(branch, str) else branch
                     os.fsencode(branch) if isinstance(branch, str) else branch
                 )
                 )
                 # Try as full ref name first
                 # Try as full ref name first
-                if branch_bytes in refs:
-                    branch_refs[branch_bytes] = refs[branch_bytes]
+                branch_ref_check = Ref(branch_bytes)
+                if branch_ref_check in refs:
+                    branch_refs[branch_ref_check] = refs[branch_ref_check]
                 else:
                 else:
                     # Try as branch name
                     # Try as branch name
                     branch_ref = local_branch_name(branch_bytes)
                     branch_ref = local_branch_name(branch_bytes)
                     if branch_ref in refs:
                     if branch_ref in refs:
                         branch_refs[branch_ref] = refs[branch_ref]
                         branch_refs[branch_ref] = refs[branch_ref]
                     # Try as remote branch
                     # Try as remote branch
-                    elif LOCAL_REMOTE_PREFIX + branch_bytes in refs:
-                        branch_refs[LOCAL_REMOTE_PREFIX + branch_bytes] = refs[
-                            LOCAL_REMOTE_PREFIX + branch_bytes
-                        ]
+                    else:
+                        remote_ref = Ref(LOCAL_REMOTE_PREFIX + branch_bytes)
+                        if remote_ref in refs:
+                            branch_refs[remote_ref] = refs[remote_ref]
         else:
         else:
             # Default behavior: show local branches
             # Default behavior: show local branches
             if all_branches:
             if all_branches:
@@ -4565,7 +4570,7 @@ def show_branch(
         # Add current branch if requested and not already included
         # Add current branch if requested and not already included
         if current:
         if current:
             try:
             try:
-                head_refs, _ = r.refs.follow(b"HEAD")
+                head_refs, _ = r.refs.follow(HEADREF)
                 if head_refs:
                 if head_refs:
                     head_ref = head_refs[0]
                     head_ref = head_refs[0]
                     if head_ref not in branch_refs and head_ref in refs:
                     if head_ref not in branch_refs and head_ref in refs:
@@ -4579,7 +4584,7 @@ def show_branch(
 
 
         # Sort branches for consistent output
         # Sort branches for consistent output
         sorted_branches = sorted(branch_refs.items(), key=lambda x: x[0])
         sorted_branches = sorted(branch_refs.items(), key=lambda x: x[0])
-        branch_sha_list = [sha for _, sha in sorted_branches]
+        branch_sha_list: list[ObjectID] = [sha for _, sha in sorted_branches]
 
 
         # Handle --independent flag
         # Handle --independent flag
         if independent_branches:
         if independent_branches:
@@ -4604,7 +4609,7 @@ def show_branch(
         # Get current branch for marking
         # Get current branch for marking
         current_branch: bytes | None = None
         current_branch: bytes | None = None
         try:
         try:
-            head_refs, _ = r.refs.follow(b"HEAD")
+            head_refs, _ = r.refs.follow(HEADREF)
             if head_refs:
             if head_refs:
                 current_branch = head_refs[0]
                 current_branch = head_refs[0]
         except (KeyError, TypeError):
         except (KeyError, TypeError):
@@ -4837,7 +4842,7 @@ def repack(repo: RepoPath, write_bitmaps: bool = False) -> None:
 
 
 def pack_objects(
 def pack_objects(
     repo: RepoPath,
     repo: RepoPath,
-    object_ids: Sequence[bytes],
+    object_ids: Sequence[ObjectID],
     packf: BinaryIO,
     packf: BinaryIO,
     idxf: BinaryIO | None,
     idxf: BinaryIO | None,
     delta_window_size: int | None = None,
     delta_window_size: int | None = None,
@@ -4889,7 +4894,7 @@ def ls_tree(
       name_only: Only print item name
       name_only: Only print item name
     """
     """
 
 
-    def list_tree(store: BaseObjectStore, treeid: bytes, base: bytes) -> None:
+    def list_tree(store: BaseObjectStore, treeid: ObjectID, base: bytes) -> None:
         tree = store[treeid]
         tree = store[treeid]
         assert isinstance(tree, Tree)
         assert isinstance(tree, Tree)
         for name, mode, sha in tree.iteritems():
         for name, mode, sha in tree.iteritems():
@@ -5057,7 +5062,7 @@ def check_ignore(
                 yield _quote_path(output_path) if quote_path else output_path
                 yield _quote_path(output_path) if quote_path else output_path
 
 
 
 
-def _get_current_head_tree(repo: Repo) -> bytes | None:
+def _get_current_head_tree(repo: Repo) -> ObjectID | None:
     """Get the current HEAD tree ID.
     """Get the current HEAD tree ID.
 
 
     Args:
     Args:
@@ -5067,10 +5072,10 @@ def _get_current_head_tree(repo: Repo) -> bytes | None:
       Tree ID of current HEAD, or None if no HEAD exists (empty repo)
       Tree ID of current HEAD, or None if no HEAD exists (empty repo)
     """
     """
     try:
     try:
-        current_head = repo.refs[b"HEAD"]
+        current_head = repo.refs[HEADREF]
         current_commit = repo[current_head]
         current_commit = repo[current_head]
         assert isinstance(current_commit, Commit), "Expected a Commit object"
         assert isinstance(current_commit, Commit), "Expected a Commit object"
-        tree_id: bytes = current_commit.tree
+        tree_id: ObjectID = current_commit.tree
         return tree_id
         return tree_id
     except KeyError:
     except KeyError:
         # No HEAD yet (empty repo)
         # No HEAD yet (empty repo)
@@ -5078,7 +5083,7 @@ def _get_current_head_tree(repo: Repo) -> bytes | None:
 
 
 
 
 def _check_uncommitted_changes(
 def _check_uncommitted_changes(
-    repo: Repo, target_tree_id: bytes, force: bool = False
+    repo: Repo, target_tree_id: ObjectID, force: bool = False
 ) -> None:
 ) -> None:
     """Check for uncommitted changes that would conflict with a checkout/switch.
     """Check for uncommitted changes that would conflict with a checkout/switch.
 
 
@@ -5182,8 +5187,8 @@ def _get_worktree_update_config(
 
 
 def _perform_tree_switch(
 def _perform_tree_switch(
     repo: Repo,
     repo: Repo,
-    current_tree_id: bytes | None,
-    target_tree_id: bytes,
+    current_tree_id: ObjectID | None,
+    target_tree_id: ObjectID,
     force: bool = False,
     force: bool = False,
 ) -> None:
 ) -> None:
     """Perform the actual working tree switch.
     """Perform the actual working tree switch.
@@ -5239,7 +5244,7 @@ def update_head(
         if new_branch is not None:
         if new_branch is not None:
             to_set = _make_branch_ref(new_branch)
             to_set = _make_branch_ref(new_branch)
         else:
         else:
-            to_set = b"HEAD"
+            to_set = HEADREF
         if detached:
         if detached:
             # TODO(jelmer): Provide some way so that the actual ref gets
             # TODO(jelmer): Provide some way so that the actual ref gets
             # updated rather than what it points to, so the delete isn't
             # updated rather than what it points to, so the delete isn't
@@ -5249,7 +5254,7 @@ def update_head(
         else:
         else:
             r.refs.set_symbolic_ref(to_set, parse_ref(r, target))
             r.refs.set_symbolic_ref(to_set, parse_ref(r, target))
         if new_branch is not None:
         if new_branch is not None:
-            r.refs.set_symbolic_ref(b"HEAD", to_set)
+            r.refs.set_symbolic_ref(HEADREF, to_set)
 
 
 
 
 def checkout(
 def checkout(
@@ -5297,7 +5302,7 @@ def checkout(
             # If no target specified, use HEAD
             # If no target specified, use HEAD
             if target is None:
             if target is None:
                 try:
                 try:
-                    target = r.refs[b"HEAD"]
+                    target = r.refs[HEADREF]
                 except KeyError:
                 except KeyError:
                     raise CheckoutError("No HEAD reference found")
                     raise CheckoutError("No HEAD reference found")
             else:
             else:
@@ -5464,7 +5469,7 @@ def restore(
             if staged:
             if staged:
                 # Restoring staged files from HEAD
                 # Restoring staged files from HEAD
                 try:
                 try:
-                    source = r.refs[b"HEAD"]
+                    source = r.refs[HEADREF]
                 except KeyError:
                 except KeyError:
                     raise CheckoutError("No HEAD reference found")
                     raise CheckoutError("No HEAD reference found")
             elif worktree:
             elif worktree:
@@ -5626,8 +5631,6 @@ def switch(
             update_head(r, create)
             update_head(r, create)
 
 
             # Set up tracking if creating from a remote branch
             # Set up tracking if creating from a remote branch
-            from .refs import LOCAL_REMOTE_PREFIX, local_branch_name, parse_remote_ref
-
             if isinstance(original_target, bytes) and target_bytes.startswith(
             if isinstance(original_target, bytes) and target_bytes.startswith(
                 LOCAL_REMOTE_PREFIX
                 LOCAL_REMOTE_PREFIX
             ):
             ):
@@ -6121,13 +6124,13 @@ def write_tree(repo: RepoPath) -> bytes:
 
 
 def _do_merge(
 def _do_merge(
     r: Repo,
     r: Repo,
-    merge_commit_id: bytes,
+    merge_commit_id: ObjectID,
     no_commit: bool = False,
     no_commit: bool = False,
     no_ff: bool = False,
     no_ff: bool = False,
     message: bytes | None = None,
     message: bytes | None = None,
     author: bytes | None = None,
     author: bytes | None = None,
     committer: bytes | None = None,
     committer: bytes | None = None,
-) -> tuple[bytes | None, list[bytes]]:
+) -> tuple[ObjectID | None, list[bytes]]:
     """Internal merge implementation that operates on an open repository.
     """Internal merge implementation that operates on an open repository.
 
 
     Args:
     Args:
@@ -6148,7 +6151,7 @@ def _do_merge(
 
 
     # Get HEAD commit
     # Get HEAD commit
     try:
     try:
-        head_commit_id = r.refs[b"HEAD"]
+        head_commit_id = r.refs[HEADREF]
     except KeyError:
     except KeyError:
         raise Error("No HEAD reference found")
         raise Error("No HEAD reference found")
 
 
@@ -6174,7 +6177,7 @@ def _do_merge(
     # Check for fast-forward
     # Check for fast-forward
     if base_commit_id == head_commit_id and not no_ff:
     if base_commit_id == head_commit_id and not no_ff:
         # Fast-forward merge
         # Fast-forward merge
-        r.refs[b"HEAD"] = merge_commit_id
+        r.refs[HEADREF] = merge_commit_id
         # Update the working directory
         # Update the working directory
         changes = tree_changes(r.object_store, head_commit.tree, merge_commit.tree)
         changes = tree_changes(r.object_store, head_commit.tree, merge_commit.tree)
         update_working_tree(
         update_working_tree(
@@ -6235,20 +6238,20 @@ def _do_merge(
     r.object_store.add_object(merge_commit_obj)
     r.object_store.add_object(merge_commit_obj)
 
 
     # Update HEAD
     # Update HEAD
-    r.refs[b"HEAD"] = merge_commit_obj.id
+    r.refs[HEADREF] = merge_commit_obj.id
 
 
     return (merge_commit_obj.id, [])
     return (merge_commit_obj.id, [])
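The fast-forward branch above rests on a single observation: when the merge base of HEAD and the incoming commit is HEAD itself, the incoming branch strictly extends current history and no merge commit is needed. A condensed sketch of that check, assuming find_merge_base(repo, commit_ids) from dulwich.graph returns the list of best common ancestors:

    from dulwich.graph import find_merge_base

    def is_fast_forward(r, head_commit_id, merge_commit_id) -> bool:
        # Best common ancestor(s) of the two tips.
        bases = find_merge_base(r, [head_commit_id, merge_commit_id])
        # The base is HEAD itself exactly when the incoming commit already
        # contains HEAD in its history, so the ref can simply move forward.
        return bool(bases) and bases[0] == head_commit_id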
 
 
 
 
 def _do_octopus_merge(
 def _do_octopus_merge(
     r: Repo,
     r: Repo,
-    merge_commit_ids: list[bytes],
+    merge_commit_ids: list[ObjectID],
     no_commit: bool = False,
     no_commit: bool = False,
     no_ff: bool = False,
     no_ff: bool = False,
     message: bytes | None = None,
     message: bytes | None = None,
     author: bytes | None = None,
     author: bytes | None = None,
     committer: bytes | None = None,
     committer: bytes | None = None,
-) -> tuple[bytes | None, list[bytes]]:
+) -> tuple[ObjectID | None, list[bytes]]:
     """Internal octopus merge implementation that operates on an open repository.
     """Internal octopus merge implementation that operates on an open repository.
 
 
     Args:
     Args:
@@ -6269,7 +6272,7 @@ def _do_octopus_merge(
 
 
     # Get HEAD commit
     # Get HEAD commit
     try:
     try:
-        head_commit_id = r.refs[b"HEAD"]
+        head_commit_id = r.refs[HEADREF]
     except KeyError:
     except KeyError:
         raise Error("No HEAD reference found")
         raise Error("No HEAD reference found")
 
 
@@ -6367,7 +6370,7 @@ def _do_octopus_merge(
     r.object_store.add_object(merge_commit_obj)
     r.object_store.add_object(merge_commit_obj)
 
 
     # Update HEAD
     # Update HEAD
-    r.refs[b"HEAD"] = merge_commit_obj.id
+    r.refs[HEADREF] = merge_commit_obj.id
 
 
     return (merge_commit_obj.id, [])
     return (merge_commit_obj.id, [])
 
 
@@ -6544,7 +6547,7 @@ def cherry(
         if upstream is None:
         if upstream is None:
             # Try to find tracking branch
             # Try to find tracking branch
             upstream_found = False
             upstream_found = False
-            head_refs, _ = r.refs.follow(b"HEAD")
+            head_refs, _ = r.refs.follow(HEADREF)
             if head_refs:
             if head_refs:
                 head_ref = head_refs[0]
                 head_ref = head_refs[0]
                 if head_ref.startswith(b"refs/heads/"):
                 if head_ref.startswith(b"refs/heads/"):
@@ -6566,7 +6569,7 @@ def cherry(
 
 
                         if remote_name:
                         if remote_name:
                             # Build the tracking branch ref
                             # Build the tracking branch ref
-                            upstream_refname = (
+                            upstream_refname = Ref(
                                 b"refs/remotes/"
                                 b"refs/remotes/"
                                 + remote_name
                                 + remote_name
                                 + b"/"
                                 + b"/"
@@ -6578,7 +6581,7 @@ def cherry(
 
 
             if not upstream_found:
             if not upstream_found:
                 # Default to HEAD^ if no tracking branch found
                 # Default to HEAD^ if no tracking branch found
-                head_commit = r[b"HEAD"]
+                head_commit = r[HEADREF]
                 if isinstance(head_commit, Commit) and head_commit.parents:
                 if isinstance(head_commit, Commit) and head_commit.parents:
                     upstream = head_commit.parents[0]
                     upstream = head_commit.parents[0]
                 else:
                 else:
@@ -6873,7 +6876,7 @@ def revert(
 
 
         # Get current HEAD
         # Get current HEAD
         try:
         try:
-            head_commit_id = r.refs[b"HEAD"]
+            head_commit_id = r.refs[HEADREF]
         except KeyError:
         except KeyError:
             raise Error("No HEAD reference found")
             raise Error("No HEAD reference found")
 
 
@@ -6973,7 +6976,7 @@ def revert(
                 r.object_store.add_object(revert_commit)
                 r.object_store.add_object(revert_commit)
 
 
                 # Update HEAD
                 # Update HEAD
-                r.refs[b"HEAD"] = revert_commit.id
+                r.refs[HEADREF] = revert_commit.id
                 head_commit_id = revert_commit.id
                 head_commit_id = revert_commit.id
 
 
         return head_commit_id if not no_commit else None
         return head_commit_id if not no_commit else None
@@ -7362,17 +7365,17 @@ def filter_branch(
     filter_author: Callable[[bytes], bytes | None] | None = None,
     filter_author: Callable[[bytes], bytes | None] | None = None,
     filter_committer: Callable[[bytes], bytes | None] | None = None,
     filter_committer: Callable[[bytes], bytes | None] | None = None,
     filter_message: Callable[[bytes], bytes | None] | None = None,
     filter_message: Callable[[bytes], bytes | None] | None = None,
-    tree_filter: Callable[[bytes, str], bytes | None] | None = None,
-    index_filter: Callable[[bytes, str], bytes | None] | None = None,
-    parent_filter: Callable[[Sequence[bytes]], list[bytes]] | None = None,
-    commit_filter: Callable[[Commit, bytes], bytes | None] | None = None,
+    tree_filter: Callable[[ObjectID, str], ObjectID | None] | None = None,
+    index_filter: Callable[[ObjectID, str], ObjectID | None] | None = None,
+    parent_filter: Callable[[Sequence[ObjectID]], list[ObjectID]] | None = None,
+    commit_filter: Callable[[Commit, ObjectID], ObjectID | None] | None = None,
     subdirectory_filter: str | bytes | None = None,
     subdirectory_filter: str | bytes | None = None,
     prune_empty: bool = False,
     prune_empty: bool = False,
     tag_name_filter: Callable[[bytes], bytes | None] | None = None,
     tag_name_filter: Callable[[bytes], bytes | None] | None = None,
     force: bool = False,
     force: bool = False,
     keep_original: bool = True,
     keep_original: bool = True,
     refs: list[bytes] | None = None,
     refs: list[bytes] | None = None,
-) -> dict[bytes, bytes]:
+) -> dict[ObjectID, ObjectID]:
     """Rewrite branch history by creating new commits with filtered properties.
     """Rewrite branch history by creating new commits with filtered properties.
 
 
     This is similar to git filter-branch, allowing you to rewrite commit
     This is similar to git filter-branch, allowing you to rewrite commit
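With the retyped callbacks, caller-side use looks roughly like this. The repo argument and the meaning of a None return (leave the value unchanged) are assumptions read off the signature fragment above; only filter_message is exercised:

    from dulwich import porcelain

    def scrub_message(message: bytes) -> bytes | None:
        # Hypothetical filter: drop a "WIP: " prefix from commit messages.
        if message.startswith(b"WIP: "):
            return message[5:]
        return None  # assumed to mean "keep the original message"

    # Returns dict[ObjectID, ObjectID] mapping old commit ids to rewritten ones,
    # per the new return annotation.
    mapping = porcelain.filter_branch(".", filter_message=scrub_message)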
@@ -7422,7 +7425,7 @@ def filter_branch(
             if branch == b"HEAD":
             if branch == b"HEAD":
                 # Resolve HEAD to actual branch
                 # Resolve HEAD to actual branch
                 try:
                 try:
-                    resolved = r.refs.follow(b"HEAD")
+                    resolved = r.refs.follow(HEADREF)
                     if resolved and resolved[0]:
                     if resolved and resolved[0]:
                         # resolved is a list of (refname, sha) tuples
                         # resolved is a list of (refname, sha) tuples
                         resolved_ref = resolved[0][-1]
                         resolved_ref = resolved[0][-1]
@@ -7465,7 +7468,7 @@ def filter_branch(
         )
         )
 
 
         # Tag callback for renaming tags
         # Tag callback for renaming tags
-        def rename_tag(old_ref: bytes, new_ref: bytes) -> None:
+        def rename_tag(old_ref: Ref, new_ref: Ref) -> None:
             # Copy tag to new name
             # Copy tag to new name
             r.refs[new_ref] = r.refs[old_ref]
             r.refs[new_ref] = r.refs[old_ref]
             # Delete old tag
             # Delete old tag
@@ -7488,7 +7491,7 @@ def filter_branch(
 
 
 def format_patch(
 def format_patch(
     repo: RepoPath = ".",
     repo: RepoPath = ".",
-    committish: bytes | tuple[bytes, bytes] | None = None,
+    committish: ObjectID | tuple[ObjectID, ObjectID] | None = None,
     outstream: TextIO = sys.stdout,
     outstream: TextIO = sys.stdout,
     outdir: str | os.PathLike[str] | None = None,
     outdir: str | os.PathLike[str] | None = None,
     n: int = 1,
     n: int = 1,
@@ -7664,7 +7667,7 @@ def bisect_start(
                 old_commit = r[r.head()]
                 old_commit = r[r.head()]
                 assert isinstance(old_commit, Commit)
                 assert isinstance(old_commit, Commit)
                 old_tree = old_commit.tree if r.head() else None
                 old_tree = old_commit.tree if r.head() else None
-                r.refs[b"HEAD"] = next_sha
+                r.refs[HEADREF] = next_sha
                 commit = r[next_sha]
                 commit = r[next_sha]
                 assert isinstance(commit, Commit)
                 assert isinstance(commit, Commit)
                 changes = tree_changes(r.object_store, old_tree, commit.tree)
                 changes = tree_changes(r.object_store, old_tree, commit.tree)
@@ -7696,7 +7699,7 @@ def bisect_bad(
             old_commit = r[r.head()]
             old_commit = r[r.head()]
             assert isinstance(old_commit, Commit)
             assert isinstance(old_commit, Commit)
             old_tree = old_commit.tree if r.head() else None
             old_tree = old_commit.tree if r.head() else None
-            r.refs[b"HEAD"] = next_sha
+            r.refs[HEADREF] = next_sha
             commit = r[next_sha]
             commit = r[next_sha]
             assert isinstance(commit, Commit)
             assert isinstance(commit, Commit)
             changes = tree_changes(r.object_store, old_tree, commit.tree)
             changes = tree_changes(r.object_store, old_tree, commit.tree)
@@ -7728,7 +7731,7 @@ def bisect_good(
             old_commit = r[r.head()]
             old_commit = r[r.head()]
             assert isinstance(old_commit, Commit)
             assert isinstance(old_commit, Commit)
             old_tree = old_commit.tree if r.head() else None
             old_tree = old_commit.tree if r.head() else None
-            r.refs[b"HEAD"] = next_sha
+            r.refs[HEADREF] = next_sha
             commit = r[next_sha]
             commit = r[next_sha]
             assert isinstance(commit, Commit)
             assert isinstance(commit, Commit)
             changes = tree_changes(r.object_store, old_tree, commit.tree)
             changes = tree_changes(r.object_store, old_tree, commit.tree)
@@ -7773,7 +7776,7 @@ def bisect_skip(
             old_commit = r[r.head()]
             old_commit = r[r.head()]
             assert isinstance(old_commit, Commit)
             assert isinstance(old_commit, Commit)
             old_tree = old_commit.tree if r.head() else None
             old_tree = old_commit.tree if r.head() else None
-            r.refs[b"HEAD"] = next_sha
+            r.refs[HEADREF] = next_sha
             commit = r[next_sha]
             commit = r[next_sha]
             assert isinstance(commit, Commit)
             assert isinstance(commit, Commit)
             changes = tree_changes(r.object_store, old_tree, commit.tree)
             changes = tree_changes(r.object_store, old_tree, commit.tree)
@@ -7942,7 +7945,7 @@ def reflog_expire(
             refs_to_process = [ref]
             refs_to_process = [ref]
 
 
         # Build set of reachable objects if we have unreachable expiration time
         # Build set of reachable objects if we have unreachable expiration time
-        reachable_objects: set[bytes] | None = None
+        reachable_objects: set[ObjectID] | None = None
         if expire_unreachable_time is not None:
         if expire_unreachable_time is not None:
             from .gc import find_reachable_objects
             from .gc import find_reachable_objects
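reachable_objects is populated only when an unreachable-expiry time was given: reflog entries that still anchor reachable history are kept until expire_time, while orphaned entries may be dropped at the (typically earlier) expire_unreachable_time. The decision, sketched with an assumed entry shape:

    def cutoff_for(new_sha, reachable, expire_time, expire_unreachable_time):
        # Entries whose target is no longer reachable from any ref expire on
        # the shorter unreachable schedule; everything else on the normal one.
        if reachable is not None and new_sha not in reachable:
            return expire_unreachable_time
        return expire_time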
 
 
@@ -8461,9 +8464,13 @@ def lfs_fetch(
 
 
         for ref in refs:
         for ref in refs:
             if isinstance(ref, str):
             if isinstance(ref, str):
-                ref = ref.encode()
+                ref_key = Ref(ref.encode())
+            elif isinstance(ref, bytes):
+                ref_key = Ref(ref)
+            else:
+                ref_key = ref
             try:
             try:
-                commit = r[r.refs[ref]]
+                commit = r[r.refs[ref_key]]
             except KeyError:
             except KeyError:
                 continue
                 continue
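The three-way isinstance ladder normalizes whatever callers pass into the Ref newtype; factored out, it is simply:

    from dulwich.refs import Ref

    def as_ref(ref: "str | bytes | Ref") -> "Ref":
        # Accept str, bytes, or an existing Ref and always return a Ref.
        if isinstance(ref, str):
            return Ref(ref.encode())
        if isinstance(ref, bytes):
            return Ref(ref)
        return ref

If Ref is a NewType over bytes, a Ref reaching the second branch is just re-wrapped harmlessly; the branches exist for the type checker, not the interpreter.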
 
 
@@ -8583,19 +8590,21 @@ def lfs_push(
         # Find all LFS objects to push
         # Find all LFS objects to push
         if refs is None:
         if refs is None:
             # Push current branch
             # Push current branch
-            head_ref = r.refs.read_ref(b"HEAD")
+            head_ref = r.refs.read_ref(HEADREF)
             refs = [head_ref] if head_ref else []
             refs = [head_ref] if head_ref else []
 
 
         objects_to_push = set()
         objects_to_push = set()
 
 
         for ref in refs:
         for ref in refs:
             if isinstance(ref, str):
             if isinstance(ref, str):
-                ref = ref.encode()
+                ref_bytes = ref.encode()
+            else:
+                ref_bytes = ref
             try:
             try:
-                if ref.startswith(b"refs/"):
-                    commit = r[r.refs[ref]]
+                if ref_bytes.startswith(b"refs/"):
+                    commit = r[r.refs[Ref(ref_bytes)]]
                 else:
                 else:
-                    commit = r[ref]
+                    commit = r[ref_bytes]
             except KeyError:
             except KeyError:
                 continue
                 continue
 
 
@@ -8736,8 +8745,9 @@ def worktree_add(
 
 
     with open_repo_closing(repo) as r:
     with open_repo_closing(repo) as r:
         commit_bytes = commit.encode() if isinstance(commit, str) else commit
         commit_bytes = commit.encode() if isinstance(commit, str) else commit
+        commit_id = ObjectID(commit_bytes) if commit_bytes is not None else None
         wt_repo = add_worktree(
         wt_repo = add_worktree(
-            r, path, branch=branch, commit=commit_bytes, detach=detach, force=force
+            r, path, branch=branch, commit=commit_id, detach=detach, force=force
         )
         )
         return wt_repo.path
         return wt_repo.path
 
 
@@ -8867,7 +8877,7 @@ def merge_base(
     committishes: Sequence[str | bytes] | None = None,
     committishes: Sequence[str | bytes] | None = None,
     all: bool = False,
     all: bool = False,
     octopus: bool = False,
     octopus: bool = False,
-) -> list[bytes]:
+) -> list[ObjectID]:
     """Find the best common ancestor(s) between commits.
     """Find the best common ancestor(s) between commits.
 
 
     Args:
     Args:
@@ -8946,7 +8956,7 @@ def is_ancestor(
 def independent_commits(
 def independent_commits(
     repo: RepoPath = ".",
     repo: RepoPath = ".",
     committishes: Sequence[str | bytes] | None = None,
     committishes: Sequence[str | bytes] | None = None,
-) -> list[bytes]:
+) -> list[ObjectID]:
     """Filter commits to only those that are not reachable from others.
     """Filter commits to only those that are not reachable from others.
 
 
     Args:
     Args:

+ 5 - 1
dulwich/protocol.py

@@ -30,6 +30,7 @@ from os import SEEK_END
 import dulwich
 import dulwich
 
 
 from .errors import GitProtocolError, HangupException
 from .errors import GitProtocolError, HangupException
+from .objects import ObjectID
 
 
 TCP_GIT_PORT = 9418
 TCP_GIT_PORT = 9418
 
 
@@ -49,7 +50,10 @@ GIT_PROTOCOL_VERSIONS = [0, 1, 2]
 DEFAULT_GIT_PROTOCOL_VERSION_FETCH = 2
 DEFAULT_GIT_PROTOCOL_VERSION_FETCH = 2
 DEFAULT_GIT_PROTOCOL_VERSION_SEND = 0
 DEFAULT_GIT_PROTOCOL_VERSION_SEND = 0
 
 
-ZERO_SHA = b"0" * 40
+# Suffix used in the Git protocol to indicate peeled tag references
+PEELED_TAG_SUFFIX = b"^{}"
+
+ZERO_SHA: ObjectID = ObjectID(b"0" * 40)
 
 
 SINGLE_ACK = 0
 SINGLE_ACK = 0
 MULTI_ACK = 1
 MULTI_ACK = 1
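The new PEELED_TAG_SUFFIX constant names the ^{} marker that Git appends during ref advertisement to publish the peeled target of an annotated tag. A toy illustration of the two lines a server sends for one annotated tag:

    PEELED_TAG_SUFFIX = b"^{}"

    def advertise(refname: bytes, sha: bytes, peeled: bytes | None) -> list[bytes]:
        # An annotated tag appears twice: once as the tag object, and once
        # peeled to the commit it points at, with ^{} appended to the name.
        lines = [sha + b" " + refname]
        if peeled is not None:
            lines.append(peeled + b" " + refname + PEELED_TAG_SUFFIX)
        return lines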

+ 39 - 35
dulwich/rebase.py

@@ -31,9 +31,9 @@ from typing import Protocol, TypedDict
 
 
 from dulwich.graph import find_merge_base
 from dulwich.graph import find_merge_base
 from dulwich.merge import three_way_merge
 from dulwich.merge import three_way_merge
-from dulwich.objects import Commit
+from dulwich.objects import Commit, ObjectID
 from dulwich.objectspec import parse_commit
 from dulwich.objectspec import parse_commit
-from dulwich.refs import local_branch_name
+from dulwich.refs import HEADREF, Ref, local_branch_name, set_ref_from_raw
 from dulwich.repo import BaseRepo, Repo
 from dulwich.repo import BaseRepo, Repo
 
 
 
 
@@ -119,7 +119,7 @@ class RebaseTodoEntry:
     """Represents a single entry in a rebase todo list."""
     """Represents a single entry in a rebase todo list."""
 
 
     command: RebaseTodoCommand
     command: RebaseTodoCommand
-    commit_sha: bytes | None = None  # Store as hex string encoded as bytes
+    commit_sha: ObjectID | None = None  # Store as hex string encoded as bytes
     short_message: str | None = None
     short_message: str | None = None
     arguments: str | None = None
     arguments: str | None = None
 
 
@@ -209,7 +209,7 @@ class RebaseTodoEntry:
             # Commands that operate on commits
             # Commands that operate on commits
             if len(parts) > 1:
             if len(parts) > 1:
                 # Store SHA as hex string encoded as bytes
                 # Store SHA as hex string encoded as bytes
-                commit_sha = parts[1].encode()
+                commit_sha = ObjectID(parts[1].encode())
 
 
                 # Parse commit message if present
                 # Parse commit message if present
                 if len(parts) > 2:
                 if len(parts) > 2:
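The lines parsed here are ordinary interactive-rebase todo entries. A worked example of the split, with a hypothetical abbreviated SHA (the exact split call is assumed; the code above only relies on parts[0], parts[1] and parts[2]):

    parts = "pick 9f3c2ab Make Ref a newtype".split(maxsplit=2)
    # parts[0] == "pick"                -> the command
    # parts[1] == "9f3c2ab"             -> becomes ObjectID(b"9f3c2ab")
    # parts[2] == "Make Ref a newtype"  -> the short message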
@@ -374,8 +374,8 @@ class RebaseStateManager(Protocol):
     def save(
     def save(
         self,
         self,
         original_head: bytes | None,
         original_head: bytes | None,
-        rebasing_branch: bytes | None,
-        onto: bytes | None,
+        rebasing_branch: Ref | None,
+        onto: ObjectID | None,
         todo: list[Commit],
         todo: list[Commit],
         done: list[Commit],
         done: list[Commit],
     ) -> None:
     ) -> None:
@@ -386,8 +386,8 @@ class RebaseStateManager(Protocol):
         self,
         self,
     ) -> tuple[
     ) -> tuple[
         bytes | None,  # original_head
         bytes | None,  # original_head
-        bytes | None,  # rebasing_branch
-        bytes | None,  # onto
+        Ref | None,  # rebasing_branch
+        ObjectID | None,  # onto
         list[Commit],  # todo
         list[Commit],  # todo
         list[Commit],  # done
         list[Commit],  # done
     ]:
     ]:
@@ -425,8 +425,8 @@ class DiskRebaseStateManager:
     def save(
     def save(
         self,
         self,
         original_head: bytes | None,
         original_head: bytes | None,
-        rebasing_branch: bytes | None,
-        onto: bytes | None,
+        rebasing_branch: Ref | None,
+        onto: ObjectID | None,
         todo: list[Commit],
         todo: list[Commit],
         done: list[Commit],
         done: list[Commit],
     ) -> None:
     ) -> None:
@@ -467,22 +467,26 @@ class DiskRebaseStateManager:
         self,
         self,
     ) -> tuple[
     ) -> tuple[
         bytes | None,
         bytes | None,
-        bytes | None,
-        bytes | None,
+        Ref | None,
+        ObjectID | None,
         list[Commit],
         list[Commit],
         list[Commit],
         list[Commit],
     ]:
     ]:
         """Load rebase state from disk."""
         """Load rebase state from disk."""
         original_head = None
         original_head = None
-        rebasing_branch = None
-        onto = None
+        rebasing_branch_bytes = None
+        onto_bytes = None
         todo: list[Commit] = []
         todo: list[Commit] = []
         done: list[Commit] = []
         done: list[Commit] = []
 
 
         # Load rebase state files
         # Load rebase state files
         original_head = self._read_file("orig-head")
         original_head = self._read_file("orig-head")
-        rebasing_branch = self._read_file("head-name")
-        onto = self._read_file("onto")
+        rebasing_branch_bytes = self._read_file("head-name")
+        rebasing_branch = (
+            Ref(rebasing_branch_bytes) if rebasing_branch_bytes is not None else None
+        )
+        onto_bytes = self._read_file("onto")
+        onto = ObjectID(onto_bytes) if onto_bytes is not None else None
 
 
         return original_head, rebasing_branch, onto, todo, done
         return original_head, rebasing_branch, onto, todo, done
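The three files read here mirror the state git itself keeps while a rebase is underway (under .git/rebase-merge/ in stock git; the directory used by DiskRebaseStateManager is assumed to match):

    orig-head    SHA of HEAD when the rebase started
    head-name    full name of the branch being rebased, e.g. refs/heads/topic
    onto         SHA of the commit the series is being replayed onto

Only head-name and onto gain newtype wrappers; orig-head stays plain bytes because, as the later startswith(b"ref: ") check shows, it may hold either a SHA or a symref.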
 
 
@@ -532,8 +536,8 @@ class RebaseState(TypedDict):
     """Type definition for rebase state."""
     """Type definition for rebase state."""
 
 
     original_head: bytes | None
     original_head: bytes | None
-    rebasing_branch: bytes | None
-    onto: bytes | None
+    rebasing_branch: Ref | None
+    onto: ObjectID | None
     todo: list[Commit]
     todo: list[Commit]
     done: list[Commit]
     done: list[Commit]
 
 
@@ -554,8 +558,8 @@ class MemoryRebaseStateManager:
     def save(
     def save(
         self,
         self,
         original_head: bytes | None,
         original_head: bytes | None,
-        rebasing_branch: bytes | None,
-        onto: bytes | None,
+        rebasing_branch: Ref | None,
+        onto: ObjectID | None,
         todo: list[Commit],
         todo: list[Commit],
         done: list[Commit],
         done: list[Commit],
     ) -> None:
     ) -> None:
@@ -572,8 +576,8 @@ class MemoryRebaseStateManager:
         self,
         self,
     ) -> tuple[
     ) -> tuple[
         bytes | None,
         bytes | None,
-        bytes | None,
-        bytes | None,
+        Ref | None,
+        ObjectID | None,
         list[Commit],
         list[Commit],
         list[Commit],
         list[Commit],
     ]:
     ]:
@@ -630,10 +634,10 @@ class Rebaser:
 
 
         # Initialize state
         # Initialize state
         self._original_head: bytes | None = None
         self._original_head: bytes | None = None
-        self._onto: bytes | None = None
+        self._onto: ObjectID | None = None
         self._todo: list[Commit] = []
         self._todo: list[Commit] = []
         self._done: list[Commit] = []
         self._done: list[Commit] = []
-        self._rebasing_branch: bytes | None = None
+        self._rebasing_branch: Ref | None = None
 
 
         # Load any existing rebase state
         # Load any existing rebase state
         self._load_rebase_state()
         self._load_rebase_state()
@@ -653,7 +657,7 @@ class Rebaser:
         # Get the branch commit
         # Get the branch commit
         if branch is None:
         if branch is None:
             # Use current HEAD
             # Use current HEAD
-            _head_ref, head_sha = self.repo.refs.follow(b"HEAD")
+            _head_ref, head_sha = self.repo.refs.follow(HEADREF)
             if head_sha is None:
             if head_sha is None:
                 raise ValueError("HEAD does not point to a valid commit")
                 raise ValueError("HEAD does not point to a valid commit")
             branch_commit = self.repo[head_sha]
             branch_commit = self.repo[head_sha]
@@ -688,8 +692,8 @@ class Rebaser:
         return list(reversed(commits))
         return list(reversed(commits))
 
 
     def _cherry_pick(
     def _cherry_pick(
-        self, commit: Commit, onto: bytes
-    ) -> tuple[bytes | None, list[bytes]]:
+        self, commit: Commit, onto: ObjectID
+    ) -> tuple[ObjectID | None, list[bytes]]:
         """Cherry-pick a commit onto another commit.
         """Cherry-pick a commit onto another commit.
 
 
         Args:
         Args:
@@ -754,22 +758,22 @@ class Rebaser:
             List of commits that will be rebased
             List of commits that will be rebased
         """
         """
         # Save original HEAD
         # Save original HEAD
-        self._original_head = self.repo.refs.read_ref(b"HEAD")
+        self._original_head = self.repo.refs.read_ref(HEADREF)
 
 
         # Save which branch we're rebasing (for later update)
         # Save which branch we're rebasing (for later update)
         if branch is not None:
         if branch is not None:
             # Parse the branch ref
             # Parse the branch ref
             if branch.startswith(b"refs/heads/"):
             if branch.startswith(b"refs/heads/"):
-                self._rebasing_branch = branch
+                self._rebasing_branch = Ref(branch)
             else:
             else:
                 # Assume it's a branch name
                 # Assume it's a branch name
-                self._rebasing_branch = local_branch_name(branch)
+                self._rebasing_branch = Ref(local_branch_name(branch))
         else:
         else:
             # Use current branch
             # Use current branch
             if self._original_head is not None and self._original_head.startswith(
             if self._original_head is not None and self._original_head.startswith(
                 b"ref: "
                 b"ref: "
             ):
             ):
-                self._rebasing_branch = self._original_head[5:]
+                self._rebasing_branch = Ref(self._original_head[5:])
             else:
             else:
                 self._rebasing_branch = None
                 self._rebasing_branch = None
 
 
@@ -844,7 +848,7 @@ class Rebaser:
         # Restore original HEAD
         # Restore original HEAD
         if self._original_head is None:
         if self._original_head is None:
             raise RebaseError("No original HEAD to restore")
             raise RebaseError("No original HEAD to restore")
-        self.repo.refs[b"HEAD"] = self._original_head
+        set_ref_from_raw(self.repo.refs, HEADREF, self._original_head)
 
 
         # Clean up rebase state
         # Clean up rebase state
         self._clean_rebase_state()
         self._clean_rebase_state()
@@ -870,13 +874,13 @@ class Rebaser:
             # If HEAD was pointing to this branch, it will follow automatically
             # If HEAD was pointing to this branch, it will follow automatically
         else:
         else:
             # If we don't know which branch, check current HEAD
             # If we don't know which branch, check current HEAD
-            head_ref = self.repo.refs[b"HEAD"]
+            head_ref = self.repo.refs[HEADREF]
             if head_ref.startswith(b"ref: "):
             if head_ref.startswith(b"ref: "):
-                branch_ref = head_ref[5:]
+                branch_ref = Ref(head_ref[5:])
                 self.repo.refs[branch_ref] = last_commit.id
                 self.repo.refs[branch_ref] = last_commit.id
             else:
             else:
                 # Detached HEAD
                 # Detached HEAD
-                self.repo.refs[b"HEAD"] = last_commit.id
+                self.repo.refs[HEADREF] = last_commit.id
 
 
         # Clean up rebase state
         # Clean up rebase state
         self._clean_rebase_state()
         self._clean_rebase_state()

Diff data for this file was not shown because the file is too large
+ 210 - 171
dulwich/refs.py


+ 56 - 47
dulwich/reftable.py

@@ -22,6 +22,7 @@ from typing import BinaryIO
 from dulwich.objects import ObjectID
 from dulwich.objects import ObjectID
 from dulwich.refs import (
 from dulwich.refs import (
     SYMREF,
     SYMREF,
+    Ref,
     RefsContainer,
     RefsContainer,
 )
 )
 
 
@@ -909,7 +910,7 @@ class ReftableRefsContainer(RefsContainer):
                     files.append(os.path.join(self.reftable_dir, table_name))
                     files.append(os.path.join(self.reftable_dir, table_name))
         return files
         return files
 
 
-    def _read_all_tables(self) -> dict[bytes, tuple[int, bytes]]:
+    def _read_all_tables(self) -> dict[Ref, tuple[int, bytes]]:
         """Read all reftable files and merge results."""
         """Read all reftable files and merge results."""
         # First, read all tables and sort them by min_update_index
         # First, read all tables and sort them by min_update_index
         table_data = []
         table_data = []
@@ -924,19 +925,20 @@ class ReftableRefsContainer(RefsContainer):
         table_data.sort(key=lambda x: x[0])
         table_data.sort(key=lambda x: x[0])
 
 
         # Merge results in chronological order
         # Merge results in chronological order
-        all_refs: dict[bytes, tuple[int, bytes]] = {}
+        all_refs: dict[Ref, tuple[int, bytes]] = {}
         for min_update_index, table_file, refs in table_data:
         for min_update_index, table_file, refs in table_data:
             # Apply updates from this table
             # Apply updates from this table
             for refname, (value_type, value) in refs.items():
             for refname, (value_type, value) in refs.items():
+                ref = Ref(refname)
                 if value_type == REF_VALUE_DELETE:
                 if value_type == REF_VALUE_DELETE:
                     # Remove ref if it exists
                     # Remove ref if it exists
-                    all_refs.pop(refname, None)
+                    all_refs.pop(ref, None)
                 else:
                 else:
                     # Add/update ref
                     # Add/update ref
-                    all_refs[refname] = (value_type, value)
+                    all_refs[ref] = (value_type, value)
         return all_refs
         return all_refs
 
 
-    def allkeys(self) -> set[bytes]:
+    def allkeys(self) -> set[Ref]:
         """Return set of all ref names."""
         """Return set of all ref names."""
         refs = self._read_all_tables()
         refs = self._read_all_tables()
         result = set(refs.keys())
         result = set(refs.keys())
@@ -946,17 +948,17 @@ class ReftableRefsContainer(RefsContainer):
             if value_type == REF_VALUE_SYMREF:
             if value_type == REF_VALUE_SYMREF:
                 # Add the target ref as an implicit ref
                 # Add the target ref as an implicit ref
                 target = value
                 target = value
-                result.add(target)
+                result.add(Ref(target))
 
 
         return result
         return result
 
 
-    def follow(self, name: bytes) -> tuple[list[bytes], bytes]:
+    def follow(self, name: Ref) -> tuple[list[Ref], ObjectID | None]:
         """Follow a reference name.
         """Follow a reference name.
 
 
         Returns: a tuple of (refnames, sha), where refnames are the names of
         Returns: a tuple of (refnames, sha), where refnames are the names of
             references in the chain
             references in the chain
         """
         """
-        refnames = []
+        refnames: list[Ref] = []
         current = name
         current = name
         refs = self._read_all_tables()
         refs = self._read_all_tables()
 
 
@@ -968,11 +970,11 @@ class ReftableRefsContainer(RefsContainer):
 
 
             value_type, value = ref_data
             value_type, value = ref_data
             if value_type == REF_VALUE_REF:
             if value_type == REF_VALUE_REF:
-                return refnames, value
+                return refnames, ObjectID(value)
             if value_type == REF_VALUE_PEELED:
             if value_type == REF_VALUE_PEELED:
-                return refnames, value[:SHA1_HEX_SIZE]  # First SHA1 hex chars
+                return refnames, ObjectID(value[:SHA1_HEX_SIZE])  # First SHA1 hex chars
             if value_type == REF_VALUE_SYMREF:
             if value_type == REF_VALUE_SYMREF:
-                current = value
+                current = Ref(value)
                 continue
                 continue
 
 
             # Unknown value type
             # Unknown value type
@@ -981,7 +983,7 @@ class ReftableRefsContainer(RefsContainer):
         # Too many levels of indirection
         # Too many levels of indirection
         raise ValueError(f"Too many levels of symbolic ref indirection for {name!r}")
         raise ValueError(f"Too many levels of symbolic ref indirection for {name!r}")
 
 
-    def __getitem__(self, name: bytes) -> ObjectID:
+    def __getitem__(self, name: Ref) -> ObjectID:
         """Get the SHA1 for a reference name.
         """Get the SHA1 for a reference name.
 
 
         This method follows all symbolic references.
         This method follows all symbolic references.
@@ -991,7 +993,7 @@ class ReftableRefsContainer(RefsContainer):
             raise KeyError(name)
             raise KeyError(name)
         return sha
         return sha
 
 
-    def read_loose_ref(self, name: bytes) -> bytes:
+    def read_loose_ref(self, name: Ref) -> bytes:
         """Read a reference value without following symbolic refs.
         """Read a reference value without following symbolic refs.
 
 
         Args:
         Args:
@@ -1019,18 +1021,18 @@ class ReftableRefsContainer(RefsContainer):
 
 
         raise ValueError(f"Unknown ref value type: {value_type}")
         raise ValueError(f"Unknown ref value type: {value_type}")
 
 
-    def get_packed_refs(self) -> dict[bytes, bytes]:
+    def get_packed_refs(self) -> dict[Ref, ObjectID]:
         """Get packed refs. Reftable doesn't distinguish packed/loose."""
         """Get packed refs. Reftable doesn't distinguish packed/loose."""
         refs = self._read_all_tables()
         refs = self._read_all_tables()
         result = {}
         result = {}
         for name, (value_type, value) in refs.items():
         for name, (value_type, value) in refs.items():
             if value_type == REF_VALUE_REF:
             if value_type == REF_VALUE_REF:
-                result[name] = value
+                result[name] = ObjectID(value)
             elif value_type == REF_VALUE_PEELED:
             elif value_type == REF_VALUE_PEELED:
-                result[name] = value[:SHA1_HEX_SIZE]  # First SHA1 hex chars
+                result[name] = ObjectID(value[:SHA1_HEX_SIZE])  # First SHA1 hex chars
         return result
         return result
 
 
-    def get_peeled(self, name: bytes) -> bytes | None:
+    def get_peeled(self, name: Ref) -> ObjectID | None:
         """Return the cached peeled value of a ref, if available.
         """Return the cached peeled value of a ref, if available.
 
 
         Args:
         Args:
@@ -1048,10 +1050,10 @@ class ReftableRefsContainer(RefsContainer):
         value_type, value = ref_data
         value_type, value = ref_data
         if value_type == REF_VALUE_PEELED:
         if value_type == REF_VALUE_PEELED:
             # Return the peeled SHA (second 40 hex chars)
             # Return the peeled SHA (second 40 hex chars)
-            return value[40:80]
+            return ObjectID(value[40:80])
         elif value_type == REF_VALUE_REF:
         elif value_type == REF_VALUE_REF:
             # Known not to be peeled
             # Known not to be peeled
-            return value
+            return ObjectID(value)
         else:
         else:
             # Symbolic ref or other - no peeled info
             # Symbolic ref or other - no peeled info
             return None
             return None
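The slicing here, like value[:SHA1_HEX_SIZE] in follow() above, decodes a fixed layout: a REF_VALUE_PEELED record is two concatenated 40-character hex SHAs.

    def split_peeled(value: bytes) -> tuple[bytes, bytes]:
        # value[0:40]  -> SHA of the ref itself (the annotated tag object)
        # value[40:80] -> SHA it peels to (the tagged commit)
        return value[:40], value[40:80]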
@@ -1073,12 +1075,14 @@ class ReftableRefsContainer(RefsContainer):
         table_name = f"0x{min_idx:016x}-0x{max_idx:016x}-{hash_part:08x}.ref"
         table_name = f"0x{min_idx:016x}-0x{max_idx:016x}-{hash_part:08x}.ref"
         return os.path.join(self.reftable_dir, table_name)
         return os.path.join(self.reftable_dir, table_name)
 
 
-    def add_packed_refs(self, new_refs: Mapping[bytes, bytes | None]) -> None:
+    def add_packed_refs(self, new_refs: Mapping[Ref, ObjectID | None]) -> None:
         """Add packed refs. Creates a new reftable file with all refs consolidated."""
         """Add packed refs. Creates a new reftable file with all refs consolidated."""
         if not new_refs:
         if not new_refs:
             return
             return
 
 
-        self._write_batch_updates(new_refs)
+        # Convert to bytes for internal use
+        byte_refs = {bytes(k): bytes(v) if v else None for k, v in new_refs.items()}
+        self._write_batch_updates(byte_refs)
 
 
     def _write_batch_updates(self, updates: Mapping[bytes, bytes | None]) -> None:
     def _write_batch_updates(self, updates: Mapping[bytes, bytes | None]) -> None:
         """Write multiple ref updates to a single reftable file."""
         """Write multiple ref updates to a single reftable file."""
@@ -1100,9 +1104,9 @@ class ReftableRefsContainer(RefsContainer):
 
 
     def set_if_equals(
     def set_if_equals(
         self,
         self,
-        name: bytes,
-        old_ref: bytes | None,
-        new_ref: bytes | None,
+        name: Ref,
+        old_ref: ObjectID | None,
+        new_ref: ObjectID,
         committer: bytes | None = None,
         committer: bytes | None = None,
         timestamp: int | None = None,
         timestamp: int | None = None,
         timezone: int | None = None,
         timezone: int | None = None,
@@ -1116,22 +1120,19 @@ class ReftableRefsContainer(RefsContainer):
         except KeyError:
         except KeyError:
             current = None
             current = None
 
 
-        if current != old_ref:
+        old_ref_bytes = bytes(old_ref) if old_ref else None
+        if current != old_ref_bytes:
             return False
             return False
 
 
-        if new_ref is None:
-            # Delete ref
-            self._write_ref_update(name, REF_VALUE_DELETE, b"")
-        else:
-            # Update ref
-            self._write_ref_update(name, REF_VALUE_REF, new_ref)
+        # Update ref
+        self._write_ref_update(bytes(name), REF_VALUE_REF, bytes(new_ref))
 
 
         return True
         return True
 
 
     def add_if_new(
     def add_if_new(
         self,
         self,
-        name: bytes,
-        ref: bytes,
+        name: Ref,
+        ref: ObjectID,
         committer: bytes | None = None,
         committer: bytes | None = None,
         timestamp: int | None = None,
         timestamp: int | None = None,
         timezone: int | None = None,
         timezone: int | None = None,
@@ -1143,28 +1144,31 @@ class ReftableRefsContainer(RefsContainer):
             return False  # Ref exists
             return False  # Ref exists
         except KeyError:
         except KeyError:
             pass  # Ref doesn't exist, continue
             pass  # Ref doesn't exist, continue
-        self._write_ref_update(name, REF_VALUE_REF, ref)
+        self._write_ref_update(bytes(name), REF_VALUE_REF, bytes(ref))
         return True
         return True
 
 
     def remove_if_equals(
     def remove_if_equals(
         self,
         self,
-        name: bytes,
-        old_ref: bytes | None,
+        name: Ref,
+        old_ref: ObjectID | None,
         committer: bytes | None = None,
         committer: bytes | None = None,
         timestamp: int | None = None,
         timestamp: int | None = None,
         timezone: int | None = None,
         timezone: int | None = None,
         message: bytes | None = None,
         message: bytes | None = None,
     ) -> bool:
     ) -> bool:
         """Remove a ref if it equals old_ref."""
         """Remove a ref if it equals old_ref."""
-        return self.set_if_equals(
-            name,
-            old_ref,
-            None,
-            committer=committer,
-            timestamp=timestamp,
-            timezone=timezone,
-            message=message,
-        )
+        # For deletion, we need to use the internal method since set_if_equals requires new_ref
+        try:
+            current = self.read_loose_ref(name)
+        except KeyError:
+            current = None
+
+        old_ref_bytes = bytes(old_ref) if old_ref else None
+        if current != old_ref_bytes:
+            return False
+
+        self._write_ref_update(bytes(name), REF_VALUE_DELETE, b"")
+        return True
 
 
     def set_symbolic_ref(
     def set_symbolic_ref(
         self,
         self,
@@ -1234,14 +1238,19 @@ class ReftableRefsContainer(RefsContainer):
         # Get next update index - all refs in batch get the SAME index
         # Get next update index - all refs in batch get the SAME index
         batch_update_index = self._get_next_update_index()
         batch_update_index = self._get_next_update_index()
 
 
+        # Convert Ref keys to bytes for internal methods
+        all_refs_bytes = {bytes(k): v for k, v in all_refs.items()}
+
         # Apply updates to get final state
         # Apply updates to get final state
         self._apply_batch_updates(
         self._apply_batch_updates(
-            all_refs, other_updates, head_update, batch_update_index
+            all_refs_bytes, other_updates, head_update, batch_update_index
         )
         )
 
 
         # Write consolidated batch file
         # Write consolidated batch file
         created_files = (
         created_files = (
-            self._write_batch_file(all_refs, batch_update_index) if all_refs else []
+            self._write_batch_file(all_refs_bytes, batch_update_index)
+            if all_refs_bytes
+            else []
         )
         )
 
 
         # Update tables list with new files (don't compact, keep separate)
         # Update tables list with new files (don't compact, keep separate)

+ 74 - 59
dulwich/repo.py

@@ -92,6 +92,7 @@ from .objects import (
     Blob,
     Blob,
     Commit,
     Commit,
     ObjectID,
     ObjectID,
+    RawObjectID,
     ShaFile,
     ShaFile,
     Tag,
     Tag,
     Tree,
     Tree,
@@ -100,7 +101,7 @@ from .objects import (
 )
 )
 from .pack import generate_unpacked_objects
 from .pack import generate_unpacked_objects
 from .refs import (
 from .refs import (
-    ANNOTATED_TAG_SUFFIX,  # noqa: F401
+    HEADREF,
     LOCAL_TAG_PREFIX,  # noqa: F401
     LOCAL_TAG_PREFIX,  # noqa: F401
     SYMREF,  # noqa: F401
     SYMREF,  # noqa: F401
     DictRefsContainer,
     DictRefsContainer,
@@ -268,7 +269,7 @@ def check_user_identity(identity: bytes) -> None:
 
 
 def parse_graftpoints(
 def parse_graftpoints(
     graftpoints: Iterable[bytes],
     graftpoints: Iterable[bytes],
-) -> dict[bytes, list[bytes]]:
+) -> dict[ObjectID, list[ObjectID]]:
     """Convert a list of graftpoints into a dict.
     """Convert a list of graftpoints into a dict.
 
 
     Args:
     Args:
@@ -282,13 +283,13 @@ def parse_graftpoints(
 
 
     https://git.wiki.kernel.org/index.php/GraftPoint
     https://git.wiki.kernel.org/index.php/GraftPoint
     """
     """
-    grafts = {}
+    grafts: dict[ObjectID, list[ObjectID]] = {}
     for line in graftpoints:
     for line in graftpoints:
         raw_graft = line.split(None, 1)
         raw_graft = line.split(None, 1)
 
 
-        commit = raw_graft[0]
+        commit = ObjectID(raw_graft[0])
         if len(raw_graft) == 2:
         if len(raw_graft) == 2:
-            parents = raw_graft[1].split()
+            parents = [ObjectID(p) for p in raw_graft[1].split()]
         else:
         else:
             parents = []
             parents = []
 
 
@@ -299,7 +300,7 @@ def parse_graftpoints(
     return grafts
     return grafts
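The graftpoint format being parsed is one commit per line followed by its replacement parents, whitespace-separated. A tiny worked example with placeholder SHAs:

    c, p1, p2 = b"1" * 40, b"2" * 40, b"3" * 40
    grafts = parse_graftpoints([c + b" " + p1 + b" " + p2, p1])
    # {ObjectID(c):  [ObjectID(p1), ObjectID(p2)],  # two grafted-in parents
    #  ObjectID(p1): []}                            # grafted to be parentless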
 
 
 
 
-def serialize_graftpoints(graftpoints: Mapping[bytes, Sequence[bytes]]) -> bytes:
+def serialize_graftpoints(graftpoints: Mapping[ObjectID, Sequence[ObjectID]]) -> bytes:
     """Convert a dictionary of grafts into string.
     """Convert a dictionary of grafts into string.
 
 
     The graft dictionary is:
     The graft dictionary is:
@@ -414,8 +415,8 @@ class ParentsProvider:
     def __init__(
     def __init__(
         self,
         self,
         store: "BaseObjectStore",
         store: "BaseObjectStore",
-        grafts: dict[bytes, list[bytes]] = {},
-        shallows: Iterable[bytes] = [],
+        grafts: dict[ObjectID, list[ObjectID]] = {},
+        shallows: Iterable[ObjectID] = [],
     ) -> None:
     ) -> None:
         """Initialize ParentsProvider.
         """Initialize ParentsProvider.
 
 
@@ -432,8 +433,8 @@ class ParentsProvider:
         self.commit_graph = store.get_commit_graph()
         self.commit_graph = store.get_commit_graph()
 
 
     def get_parents(
     def get_parents(
-        self, commit_id: bytes, commit: Commit | None = None
-    ) -> list[bytes]:
+        self, commit_id: ObjectID, commit: Commit | None = None
+    ) -> list[ObjectID]:
         """Get parents for a commit using the parents provider."""
         """Get parents for a commit using the parents provider."""
         try:
         try:
             return self.grafts[commit_id]
             return self.grafts[commit_id]
@@ -453,9 +454,8 @@ class ParentsProvider:
             obj = self.store[commit_id]
             obj = self.store[commit_id]
             assert isinstance(obj, Commit)
             assert isinstance(obj, Commit)
             commit = obj
             commit = obj
-        parents = commit.parents
-        assert isinstance(parents, list)
-        return parents
+        result: list[ObjectID] = commit.parents
+        return result
 
 
 
 
 class BaseRepo:
 class BaseRepo:
@@ -486,7 +486,7 @@ class BaseRepo:
         self.object_store = object_store
         self.object_store = object_store
         self.refs = refs
         self.refs = refs
 
 
-        self._graftpoints: dict[bytes, list[bytes]] = {}
+        self._graftpoints: dict[ObjectID, list[ObjectID]] = {}
         self.hooks: dict[str, Hook] = {}
         self.hooks: dict[str, Hook] = {}
 
 
     def _determine_file_mode(self) -> bool:
     def _determine_file_mode(self) -> bool:
@@ -585,11 +585,11 @@ class BaseRepo:
     def fetch(
     def fetch(
         self,
         self,
         target: "BaseRepo",
         target: "BaseRepo",
-        determine_wants: Callable[[Mapping[bytes, bytes], int | None], list[bytes]]
+        determine_wants: Callable[[Mapping[Ref, ObjectID], int | None], list[ObjectID]]
         | None = None,
         | None = None,
         progress: Callable[..., None] | None = None,
         progress: Callable[..., None] | None = None,
         depth: int | None = None,
         depth: int | None = None,
-    ) -> dict[bytes, bytes]:
+    ) -> dict[Ref, ObjectID]:
         """Fetch objects into another repository.
         """Fetch objects into another repository.
 
 
         Args:
         Args:
@@ -613,11 +613,11 @@ class BaseRepo:
 
 
     def fetch_pack_data(
     def fetch_pack_data(
         self,
         self,
-        determine_wants: Callable[[Mapping[bytes, bytes], int | None], list[bytes]],
+        determine_wants: Callable[[Mapping[Ref, ObjectID], int | None], list[ObjectID]],
         graph_walker: "GraphWalker",
         graph_walker: "GraphWalker",
         progress: Callable[[bytes], None] | None,
         progress: Callable[[bytes], None] | None,
         *,
         *,
-        get_tagged: Callable[[], dict[bytes, bytes]] | None = None,
+        get_tagged: Callable[[], dict[ObjectID, ObjectID]] | None = None,
         depth: int | None = None,
         depth: int | None = None,
     ) -> tuple[int, Iterator["UnpackedObject"]]:
     ) -> tuple[int, Iterator["UnpackedObject"]]:
         """Fetch the pack data required for a set of revisions.
         """Fetch the pack data required for a set of revisions.
@@ -648,11 +648,11 @@ class BaseRepo:
 
 
     def find_missing_objects(
     def find_missing_objects(
         self,
         self,
-        determine_wants: Callable[[Mapping[bytes, bytes], int | None], list[bytes]],
+        determine_wants: Callable[[Mapping[Ref, ObjectID], int | None], list[ObjectID]],
         graph_walker: "GraphWalker",
         graph_walker: "GraphWalker",
         progress: Callable[[bytes], None] | None,
         progress: Callable[[bytes], None] | None,
         *,
         *,
-        get_tagged: Callable[[], dict[bytes, bytes]] | None = None,
+        get_tagged: Callable[[], dict[ObjectID, ObjectID]] | None = None,
         depth: int | None = None,
         depth: int | None = None,
     ) -> MissingObjectFinder | None:
     ) -> MissingObjectFinder | None:
         """Fetch the missing objects required for a set of revisions.
         """Fetch the missing objects required for a set of revisions.
@@ -670,9 +670,11 @@ class BaseRepo:
           depth: Shallow fetch depth
           depth: Shallow fetch depth
         Returns: iterator over objects, with __len__ implemented
         Returns: iterator over objects, with __len__ implemented
         """
         """
+        # TODO: serialize_refs returns dict[bytes, ObjectID] with peeled refs (^{}),
+        # but determine_wants expects Mapping[Ref, ObjectID]. Need to reconcile this.
         refs = serialize_refs(self.object_store, self.get_refs())
         refs = serialize_refs(self.object_store, self.get_refs())
 
 
-        wants = determine_wants(refs, depth)
+        wants = determine_wants(refs, depth)  # type: ignore[arg-type]
         if not isinstance(wants, list):
         if not isinstance(wants, list):
             raise TypeError("determine_wants() did not return a list")
             raise TypeError("determine_wants() did not return a list")
 
 
@@ -721,7 +723,7 @@ class BaseRepo:
 
 
         parents_provider = ParentsProvider(self.object_store, shallows=current_shallow)
         parents_provider = ParentsProvider(self.object_store, shallows=current_shallow)
 
 
-        def get_parents(commit: Commit) -> list[bytes]:
+        def get_parents(commit: Commit) -> list[ObjectID]:
             """Get parents for a commit using the parents provider.
             """Get parents for a commit using the parents provider.
 
 
             Args:
             Args:
@@ -785,7 +787,7 @@ class BaseRepo:
         if heads is None:
         if heads is None:
             heads = [
             heads = [
                 sha
                 sha
-                for sha in self.refs.as_dict(b"refs/heads").values()
+                for sha in self.refs.as_dict(Ref(b"refs/heads")).values()
                 if sha in self.object_store
                 if sha in self.object_store
             ]
             ]
         parents_provider = ParentsProvider(self.object_store)
         parents_provider = ParentsProvider(self.object_store)
@@ -796,21 +798,22 @@ class BaseRepo:
             update_shallow=self.update_shallow,
             update_shallow=self.update_shallow,
         )
         )
 
 
-    def get_refs(self) -> dict[bytes, bytes]:
+    def get_refs(self) -> dict[Ref, ObjectID]:
         """Get dictionary with all refs.
         """Get dictionary with all refs.
 
 
         Returns: A ``dict`` mapping ref names to SHA1s
         Returns: A ``dict`` mapping ref names to SHA1s
         """
         """
         return self.refs.as_dict()
         return self.refs.as_dict()
 
 
-    def head(self) -> bytes:
+    def head(self) -> ObjectID:
         """Return the SHA1 pointed at by HEAD."""
         """Return the SHA1 pointed at by HEAD."""
         # TODO: move this method to WorkTree
         # TODO: move this method to WorkTree
-        return self.refs[b"HEAD"]
+        return self.refs[HEADREF]
 
 
     def _get_object(self, sha: bytes, cls: type[T]) -> T:
     def _get_object(self, sha: bytes, cls: type[T]) -> T:
         assert len(sha) in (20, 40)
         assert len(sha) in (20, 40)
-        ret = self.get_object(sha)
+        obj_id = ObjectID(sha) if len(sha) == 40 else RawObjectID(sha)
+        ret = self.get_object(obj_id)
         if not isinstance(ret, cls):
         if not isinstance(ret, cls):
             if cls is Commit:
             if cls is Commit:
                 raise NotCommitError(ret.id)
                 raise NotCommitError(ret.id)
@@ -824,7 +827,7 @@ class BaseRepo:
                 raise Exception(f"Type invalid: {ret.type_name!r} != {cls.type_name!r}")
                 raise Exception(f"Type invalid: {ret.type_name!r} != {cls.type_name!r}")
         return ret
         return ret
 
 
-    def get_object(self, sha: bytes) -> ShaFile:
+    def get_object(self, sha: ObjectID | RawObjectID) -> ShaFile:
         """Retrieve the object with the specified SHA.
         """Retrieve the object with the specified SHA.
 
 
         Args:
         Args:
@@ -847,7 +850,9 @@ class BaseRepo:
             shallows=self.get_shallow(),
             shallows=self.get_shallow(),
         )
         )
 
 
-    def get_parents(self, sha: bytes, commit: Commit | None = None) -> list[bytes]:
+    def get_parents(
+        self, sha: ObjectID, commit: Commit | None = None
+    ) -> list[ObjectID]:
         """Retrieve the parents of a specific commit.
         """Retrieve the parents of a specific commit.
 
 
         If the specific commit is a graftpoint, the graft parents
         If the specific commit is a graftpoint, the graft parents
@@ -940,10 +945,10 @@ class BaseRepo:
         if f is None:
         if f is None:
             return set()
             return set()
         with f:
         with f:
-            return {line.strip() for line in f}
+            return {ObjectID(line.strip()) for line in f}
 
 
     def update_shallow(
     def update_shallow(
-        self, new_shallow: set[bytes] | None, new_unshallow: set[bytes] | None
+        self, new_shallow: set[ObjectID] | None, new_unshallow: set[ObjectID] | None
     ) -> None:
     ) -> None:
         """Update the list of shallow objects.
         """Update the list of shallow objects.
 
 
@@ -988,8 +993,8 @@ class BaseRepo:
 
 
     def get_walker(
     def get_walker(
         self,
         self,
-        include: Sequence[bytes] | None = None,
-        exclude: Sequence[bytes] | None = None,
+        include: Sequence[ObjectID] | None = None,
+        exclude: Sequence[ObjectID] | None = None,
         order: str = "date",
         order: str = "date",
         reverse: bool = False,
         reverse: bool = False,
         max_entries: int | None = None,
         max_entries: int | None = None,
@@ -1047,7 +1052,7 @@ class BaseRepo:
             queue_cls=queue_cls if queue_cls is not None else _CommitTimeQueue,
             queue_cls=queue_cls if queue_cls is not None else _CommitTimeQueue,
         )
         )
 
 
-    def __getitem__(self, name: ObjectID | Ref) -> "ShaFile":
+    def __getitem__(self, name: ObjectID | Ref | bytes) -> "ShaFile":
         """Retrieve a Git object by SHA1 or ref.
         """Retrieve a Git object by SHA1 or ref.
 
 
         Args:
         Args:
@@ -1060,11 +1065,14 @@ class BaseRepo:
             raise TypeError(f"'name' must be bytestring, not {type(name).__name__:.80}")
             raise TypeError(f"'name' must be bytestring, not {type(name).__name__:.80}")
         if len(name) in (20, 40):
         if len(name) in (20, 40):
             try:
             try:
-                return self.object_store[name]
+                # Try as ObjectID/RawObjectID
+                return self.object_store[
+                    ObjectID(name) if len(name) == 40 else RawObjectID(name)
+                ]
             except (KeyError, ValueError):
             except (KeyError, ValueError):
                 pass
                 pass
         try:
         try:
-            return self.object_store[self.refs[name]]
+            return self.object_store[self.refs[Ref(name)]]
         except RefFormatError as exc:
         except RefFormatError as exc:
             raise KeyError(name) from exc
             raise KeyError(name) from exc
 
 
@@ -1074,10 +1082,12 @@ class BaseRepo:
         Args:
         Args:
           name: Git object SHA1 or ref name
           name: Git object SHA1 or ref name
         """
         """
-        if len(name) == 20 or (len(name) == 40 and valid_hexsha(name)):
-            return name in self.object_store or name in self.refs
+        if len(name) == 20:
+            return RawObjectID(name) in self.object_store or Ref(name) in self.refs
+        elif len(name) == 40 and valid_hexsha(name):
+            return ObjectID(name) in self.object_store or Ref(name) in self.refs
         else:
         else:
-            return name in self.refs
+            return Ref(name) in self.refs
 
 
     def __setitem__(self, name: bytes, value: ShaFile | bytes) -> None:
     def __setitem__(self, name: bytes, value: ShaFile | bytes) -> None:
         """Set a ref.
         """Set a ref.
@@ -1086,11 +1096,12 @@ class BaseRepo:
           name: ref name
           name: ref name
           value: Ref value - either a ShaFile object, or a hex sha
           value: Ref value - either a ShaFile object, or a hex sha
         """
         """
-        if name.startswith(b"refs/") or name == b"HEAD":
+        if name.startswith(b"refs/") or name == HEADREF:
+            ref_name = Ref(name)
             if isinstance(value, ShaFile):
             if isinstance(value, ShaFile):
-                self.refs[name] = value.id
+                self.refs[ref_name] = value.id
             elif isinstance(value, bytes):
             elif isinstance(value, bytes):
-                self.refs[name] = value
+                self.refs[ref_name] = ObjectID(value)
             else:
             else:
                 raise TypeError(value)
                 raise TypeError(value)
         else:
         else:
@@ -1102,8 +1113,8 @@ class BaseRepo:
         Args:
         Args:
           name: Name of the ref to remove
           name: Name of the ref to remove
         """
         """
-        if name.startswith(b"refs/") or name == b"HEAD":
-            del self.refs[name]
+        if name.startswith(b"refs/") or name == HEADREF:
+            del self.refs[Ref(name)]
         else:
         else:
             raise ValueError(name)
             raise ValueError(name)
 
 
@@ -1117,7 +1128,9 @@ class BaseRepo:
         )
         )
         return get_user_identity(config)
         return get_user_identity(config)
 
 
-    def _add_graftpoints(self, updated_graftpoints: dict[bytes, list[bytes]]) -> None:
+    def _add_graftpoints(
+        self, updated_graftpoints: dict[ObjectID, list[ObjectID]]
+    ) -> None:
         """Add or modify graftpoints.
         """Add or modify graftpoints.
 
 
         Args:
         Args:
@@ -1130,7 +1143,7 @@ class BaseRepo:
 
 
         self._graftpoints.update(updated_graftpoints)
         self._graftpoints.update(updated_graftpoints)
 
 
-    def _remove_graftpoints(self, to_remove: Sequence[bytes] = ()) -> None:
+    def _remove_graftpoints(self, to_remove: Sequence[ObjectID] = ()) -> None:
         """Remove graftpoints.
         """Remove graftpoints.
 
 
         Args:
         Args:
@@ -1139,12 +1152,12 @@ class BaseRepo:
         for sha in to_remove:
         for sha in to_remove:
             del self._graftpoints[sha]
             del self._graftpoints[sha]
 
 
-    def _read_heads(self, name: str) -> list[bytes]:
+    def _read_heads(self, name: str) -> list[ObjectID]:
         f = self.get_named_file(name)
         f = self.get_named_file(name)
         if f is None:
         if f is None:
             return []
             return []
         with f:
         with f:
-            return [line.strip() for line in f.readlines() if line.strip()]
+            return [ObjectID(line.strip()) for line in f.readlines() if line.strip()]
 
 
     def get_worktree(self) -> "WorkTree":
     def get_worktree(self) -> "WorkTree":
         """Get the working tree for this repository.
         """Get the working tree for this repository.
@@ -1171,7 +1184,7 @@ class BaseRepo:
         author_timezone: int | None = None,
         author_timezone: int | None = None,
         tree: ObjectID | None = None,
         tree: ObjectID | None = None,
         encoding: bytes | None = None,
         encoding: bytes | None = None,
-        ref: Ref | None = b"HEAD",
+        ref: Ref | None = HEADREF,
         merge_heads: list[ObjectID] | None = None,
         merge_heads: list[ObjectID] | None = None,
         no_verify: bool = False,
         no_verify: bool = False,
         sign: bool = False,
         sign: bool = False,
@@ -1782,19 +1795,21 @@ class Repo(BaseRepo):
                 ref_message = b"clone: from " + encoded_path
                 ref_message = b"clone: from " + encoded_path
                 self.fetch(target, depth=depth)
                 self.fetch(target, depth=depth)
                 target.refs.import_refs(
                 target.refs.import_refs(
-                    b"refs/remotes/" + origin,
-                    self.refs.as_dict(b"refs/heads"),
+                    Ref(b"refs/remotes/" + origin),
+                    self.refs.as_dict(Ref(b"refs/heads")),
                     message=ref_message,
                     message=ref_message,
                 )
                 )
                 target.refs.import_refs(
                 target.refs.import_refs(
-                    b"refs/tags", self.refs.as_dict(b"refs/tags"), message=ref_message
+                    Ref(b"refs/tags"),
+                    self.refs.as_dict(Ref(b"refs/tags")),
+                    message=ref_message,
                 )
                 )
 
 
-                head_chain, origin_sha = self.refs.follow(b"HEAD")
+                head_chain, origin_sha = self.refs.follow(HEADREF)
                 origin_head = head_chain[-1] if head_chain else None
                 origin_head = head_chain[-1] if head_chain else None
                 if origin_sha and not origin_head:
                 if origin_sha and not origin_head:
                     # set detached HEAD
                     # set detached HEAD
-                    target.refs[b"HEAD"] = origin_sha
+                    target.refs[HEADREF] = origin_sha
                 else:
                 else:
                     _set_origin_head(target.refs, origin, origin_head)
                     _set_origin_head(target.refs, origin, origin_head)
                     head_ref = _set_default_branch(
                     head_ref = _set_default_branch(
@@ -1821,7 +1836,7 @@ class Repo(BaseRepo):
         return target
         return target
 
 
     @replace_me(remove_in="0.26.0")
     @replace_me(remove_in="0.26.0")
-    def reset_index(self, tree: bytes | None = None) -> None:
+    def reset_index(self, tree: ObjectID | None = None) -> None:
         """Reset the index back to a specific tree.
         """Reset the index back to a specific tree.
 
 
         Args:
         Args:
@@ -1896,7 +1911,7 @@ class Repo(BaseRepo):
             """
             """
             try:
             try:
                 # Get the current branch using refs
                 # Get the current branch using refs
-                ref_chain, _ = self.refs.follow(b"HEAD")
+                ref_chain, _ = self.refs.follow(HEADREF)
                 head_ref = ref_chain[-1]  # Get the final resolved ref
                 head_ref = ref_chain[-1]  # Get the final resolved ref
             except KeyError:
             except KeyError:
                 pass
                 pass
@@ -2037,7 +2052,7 @@ class Repo(BaseRepo):
                 default_branch = config.get("init", "defaultBranch")
                 default_branch = config.get("init", "defaultBranch")
             except KeyError:
             except KeyError:
                 default_branch = DEFAULT_BRANCH
                 default_branch = DEFAULT_BRANCH
-        ret.refs.set_symbolic_ref(b"HEAD", local_branch_name(default_branch))
+        ret.refs.set_symbolic_ref(HEADREF, local_branch_name(default_branch))
         ret._init_files(
         ret._init_files(
             bare=bare,
             bare=bare,
             symlinks=symlinks,
             symlinks=symlinks,
@@ -2559,7 +2574,7 @@ class MemoryRepo(BaseRepo):
         author_timezone: int | None = None,
         author_timezone: int | None = None,
         tree: ObjectID | None = None,
         tree: ObjectID | None = None,
         encoding: bytes | None = None,
         encoding: bytes | None = None,
-        ref: Ref | None = b"HEAD",
+        ref: Ref | None = HEADREF,
         merge_heads: list[ObjectID] | None = None,
         merge_heads: list[ObjectID] | None = None,
         no_verify: bool = False,
         no_verify: bool = False,
         sign: bool = False,
         sign: bool = False,
@@ -2682,7 +2697,7 @@ class MemoryRepo(BaseRepo):
     def init_bare(
     def init_bare(
         cls,
         cls,
         objects: Iterable[ShaFile],
         objects: Iterable[ShaFile],
-        refs: Mapping[bytes, bytes],
+        refs: Mapping[Ref, ObjectID],
         format: int | None = None,
         format: int | None = None,
     ) -> "MemoryRepo":
     ) -> "MemoryRepo":
         """Create a new bare repository in memory.
         """Create a new bare repository in memory.

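The hunks above all make the same move: bare bytes become Ref or ObjectID, and the b"HEAD" literal becomes the HEADREF constant. A minimal standalone sketch (illustrative, not dulwich code) of what typing.NewType buys here:

    from typing import NewType

    # Stand-ins mirroring what dulwich.refs.Ref and dulwich.objects.ObjectID
    # now are: zero-cost newtypes over bytes.
    ObjectID = NewType("ObjectID", bytes)
    Ref = NewType("Ref", bytes)

    def get_parents(sha: ObjectID) -> list[ObjectID]:
        return []  # placeholder body

    ref = Ref(b"refs/heads/master")
    sha = ObjectID(b"a" * 40)

    get_parents(sha)  # accepted
    get_parents(ref)  # rejected by mypy: "Ref" is not "ObjectID"

    # At runtime NewType is an identity function, so values stay plain
    # bytes and existing byte-level code keeps working.
    assert isinstance(sha, bytes)
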
+ 47 - 45
dulwich/server.py

@@ -99,6 +99,7 @@ from .protocol import (
     MULTI_ACK,
     MULTI_ACK_DETAILED,
     NAK_LINE,
+    PEELED_TAG_SUFFIX,
     SIDE_BAND_CHANNEL_DATA,
     SIDE_BAND_CHANNEL_FATAL,
     SIDE_BAND_CHANNEL_PROGRESS,
@@ -118,7 +119,7 @@ from .protocol import (
     format_unshallow_line,
     symref_capabilities,
 )
-from .refs import PEELED_TAG_SUFFIX, Ref, RefsContainer, write_info_refs
+from .refs import Ref, RefsContainer, write_info_refs
 from .repo import Repo

 logger = log_utils.getLogger(__name__)
@@ -156,7 +157,7 @@ class BackendRepo(TypingProtocol):
         """
         raise NotImplementedError

-    def get_peeled(self, name: bytes) -> bytes | None:
+    def get_peeled(self, name: bytes) -> ObjectID | None:
         """Return the cached peeled value of a ref, if available.

         Args:
@@ -170,11 +171,11 @@ class BackendRepo(TypingProtocol):

     def find_missing_objects(
         self,
-        determine_wants: Callable[[Mapping[bytes, bytes], int | None], list[bytes]],
+        determine_wants: Callable[[Mapping[Ref, ObjectID], int | None], list[ObjectID]],
         graph_walker: "_ProtocolGraphWalker",
         progress: Callable[[bytes], None] | None,
         *,
-        get_tagged: Callable[[], dict[bytes, bytes]] | None = None,
+        get_tagged: Callable[[], dict[ObjectID, ObjectID]] | None = None,
         depth: int | None = None,
     ) -> "MissingObjectFinder | None":
         """Yield the objects required for a list of commits.
@@ -496,8 +497,8 @@ class UploadPackHandler(PackHandler):
         tagged = {}
         for name, sha in refs.items():
             peeled_sha = repo.get_peeled(name)
-            if peeled_sha is not None and peeled_sha != sha:
-                tagged[peeled_sha] = sha
+            if peeled_sha is not None and peeled_sha != ObjectID(sha):
+                tagged[peeled_sha] = ObjectID(sha)
         return tagged

     def handle(self) -> None:
@@ -517,11 +518,11 @@ class UploadPackHandler(PackHandler):
             self.repo.get_peeled,
             self.repo.refs.get_symrefs,
         )
-        wants = []
+        wants: list[ObjectID] = []

         def wants_wrapper(
-            refs: Mapping[bytes, bytes], depth: int | None = None
-        ) -> list[bytes]:
+            refs: Mapping[Ref, ObjectID], depth: int | None = None
+        ) -> list[ObjectID]:
             wants.extend(graph_walker.determine_wants(refs, depth))
             return wants

@@ -612,7 +613,7 @@ def _split_proto_line(


 def _want_satisfied(
-    store: ObjectContainer, haves: set[bytes], want: bytes, earliest: int
+    store: ObjectContainer, haves: set[ObjectID], want: ObjectID, earliest: int
 ) -> bool:
     """Check if a specific want is satisfied by a set of haves.

@@ -646,7 +647,7 @@


 def _all_wants_satisfied(
-    store: ObjectContainer, haves: AbstractSet[bytes], wants: set[bytes]
+    store: ObjectContainer, haves: AbstractSet[ObjectID], wants: set[ObjectID]
 ) -> bool:
     """Check whether all the current wants are satisfied by a set of haves.

@@ -712,8 +713,8 @@ class _ProtocolGraphWalker:
         self,
         handler: PackHandler,
         object_store: ObjectContainer,
-        get_peeled: Callable[[bytes], bytes | None],
-        get_symrefs: Callable[[], dict[bytes, bytes]],
+        get_peeled: Callable[[bytes], ObjectID | None],
+        get_symrefs: Callable[[], dict[Ref, Ref]],
     ) -> None:
         """Initialize a ProtocolGraphWalker.

@@ -730,18 +731,18 @@
         self.proto = handler.proto
         self.stateless_rpc = handler.stateless_rpc
         self.advertise_refs = handler.advertise_refs
-        self._wants: list[bytes] = []
-        self.shallow: set[bytes] = set()
-        self.client_shallow: set[bytes] = set()
-        self.unshallow: set[bytes] = set()
+        self._wants: list[ObjectID] = []
+        self.shallow: set[ObjectID] = set()
+        self.client_shallow: set[ObjectID] = set()
+        self.unshallow: set[ObjectID] = set()
         self._cached = False
-        self._cache: list[bytes] = []
+        self._cache: list[ObjectID] = []
         self._cache_index = 0
         self._impl: AckGraphWalkerImpl | None = None

     def determine_wants(
-        self, heads: Mapping[bytes, bytes], depth: int | None = None
-    ) -> list[bytes]:
+        self, heads: Mapping[Ref, ObjectID], depth: int | None = None
+    ) -> list[ObjectID]:
         """Determine the wants for a set of heads.

         The given heads are advertised to the client, who then specifies which
@@ -804,12 +805,12 @@
         allowed = (COMMAND_WANT, COMMAND_SHALLOW, COMMAND_DEEPEN, None)
         command, sha_result = _split_proto_line(line, allowed)

-        want_revs = []
+        want_revs: list[ObjectID] = []
         while command == COMMAND_WANT:
             assert isinstance(sha_result, bytes)
             if sha_result not in values:
                 raise GitProtocolError(f"Client wants invalid object {sha_result!r}")
-            want_revs.append(sha_result)
+            want_revs.append(ObjectID(sha_result))
             command, sha_result = self.read_proto_line(allowed)

         self.set_wants(want_revs)
@@ -841,7 +842,7 @@
     def nak(self) -> None:
         """Send a NAK response."""

-    def ack(self, have_ref: bytes) -> None:
+    def ack(self, have_ref: ObjectID) -> None:
         """Acknowledge a have reference.

         Args:
@@ -860,7 +861,7 @@
         self._cached = True
         self._cache_index = 0

-    def next(self) -> bytes | None:
+    def next(self) -> ObjectID | None:
         """Get the next SHA from the graph walker.

         Returns: Next SHA or None if done
@@ -891,7 +892,7 @@
         """
         return _split_proto_line(self.proto.read_pkt_line(), allowed)

-    def _handle_shallow_request(self, wants: Sequence[bytes]) -> None:
+    def _handle_shallow_request(self, wants: Sequence[ObjectID]) -> None:
         """Handle shallow clone requests from the client.

         Args:
@@ -904,7 +905,7 @@
                 depth = val
                 break
             assert isinstance(val, bytes)
-            self.client_shallow.add(val)
+            self.client_shallow.add(ObjectID(val))
         self.read_proto_line((None,))  # consume client's flush-pkt

         shallow, not_shallow = find_shallow(self.store, wants, depth)
@@ -918,7 +919,7 @@
         self.update_shallow(new_shallow, unshallow)

     def update_shallow(
-        self, new_shallow: AbstractSet[bytes], unshallow: AbstractSet[bytes]
+        self, new_shallow: AbstractSet[ObjectID], unshallow: AbstractSet[ObjectID]
     ) -> None:
         """Update shallow/unshallow information to the client.

@@ -938,7 +939,7 @@
         # relay the message down to the handler.
         self.handler.notify_done()

-    def send_ack(self, sha: bytes, ack_type: bytes = b"") -> None:
+    def send_ack(self, sha: ObjectID, ack_type: bytes = b"") -> None:
         """Send an ACK to the client.

         Args:
@@ -963,7 +964,7 @@
         assert self._impl is not None
         return self._impl.handle_done(done_required, done_received)

-    def set_wants(self, wants: list[bytes]) -> None:
+    def set_wants(self, wants: list[ObjectID]) -> None:
         """Set the list of wanted objects.

         Args:
@@ -971,7 +972,7 @@
         """
         self._wants = wants

-    def all_wants_satisfied(self, haves: AbstractSet[bytes]) -> bool:
+    def all_wants_satisfied(self, haves: AbstractSet[ObjectID]) -> bool:
         """Check whether all the current wants are satisfied by a set of haves.

         Args:
@@ -1008,9 +1009,9 @@ class SingleAckGraphWalkerImpl(AckGraphWalkerImpl):
           walker: Parent ProtocolGraphWalker instance
         """
         self.walker = walker
-        self._common: list[bytes] = []
+        self._common: list[ObjectID] = []

-    def ack(self, have_ref: bytes) -> None:
+    def ack(self, have_ref: ObjectID) -> None:
         """Acknowledge a have reference.

         Args:
@@ -1020,7 +1021,7 @@ class SingleAckGraphWalkerImpl(AckGraphWalkerImpl):
             self.walker.send_ack(have_ref)
             self._common.append(have_ref)

-    def next(self) -> bytes | None:
+    def next(self) -> ObjectID | None:
         """Get next SHA from graph walker.

         Returns:
@@ -1033,7 +1034,7 @@ class SingleAckGraphWalkerImpl(AckGraphWalkerImpl):
             return None
         elif command == COMMAND_HAVE:
             assert isinstance(sha, bytes)
-            return sha
+            return ObjectID(sha)
         return None

     __next__ = next
@@ -1079,9 +1080,9 @@ class MultiAckGraphWalkerImpl(AckGraphWalkerImpl):
         """
         self.walker = walker
         self._found_base = False
-        self._common: list[bytes] = []
+        self._common: list[ObjectID] = []

-    def ack(self, have_ref: bytes) -> None:
+    def ack(self, have_ref: ObjectID) -> None:
         """Acknowledge a have reference.

         Args:
@@ -1094,7 +1095,7 @@ class MultiAckGraphWalkerImpl(AckGraphWalkerImpl):
                 self._found_base = True
         # else we blind ack within next

-    def next(self) -> bytes | None:
+    def next(self) -> ObjectID | None:
         """Get next SHA from graph walker.

         Returns:
@@ -1112,10 +1113,11 @@ class MultiAckGraphWalkerImpl(AckGraphWalkerImpl):
                 return None
             elif command == COMMAND_HAVE:
                 assert isinstance(sha, bytes)
+                sha_id = ObjectID(sha)
                 if self._found_base:
                     # blind ack
-                    self.walker.send_ack(sha, b"continue")
-                return sha
+                    self.walker.send_ack(sha_id, b"continue")
+                return sha_id

     __next__ = next

@@ -1162,9 +1164,9 @@ class MultiAckDetailedGraphWalkerImpl(AckGraphWalkerImpl):
             walker: Parent ProtocolGraphWalker instance
         """
         self.walker = walker
-        self._common: list[bytes] = []
+        self._common: list[ObjectID] = []

-    def ack(self, have_ref: bytes) -> None:
+    def ack(self, have_ref: ObjectID) -> None:
         """Acknowledge a have reference.

         Args:
@@ -1174,7 +1176,7 @@ class MultiAckDetailedGraphWalkerImpl(AckGraphWalkerImpl):
         self._common.append(have_ref)
         self.walker.send_ack(have_ref, b"common")

-    def next(self) -> bytes | None:
+    def next(self) -> ObjectID | None:
         """Get next SHA from graph walker.

         Returns:
@@ -1203,7 +1205,7 @@
                 # return the sha and let the caller ACK it with the
                 # above ack method.
                 assert isinstance(sha, bytes)
-                return sha
+                return ObjectID(sha)
         # don't nak unless no common commits were found, even if not
         # everything is satisfied
         return None
@@ -1432,7 +1434,7 @@ class ReceivePackHandler(PackHandler):
         # client will now send us a list of (oldsha, newsha, ref)
         while ref_line:
             (oldsha, newsha, ref_name) = ref_line.split()
-            client_refs.append((oldsha, newsha, ref_name))
+            client_refs.append((ObjectID(oldsha), ObjectID(newsha), Ref(ref_name)))
             ref_line = self.proto.read_pkt_line()

         # backend can now deal with this refs and read a pack using self.read
@@ -1492,7 +1494,7 @@ class UploadArchiveHandler(Handler):
                 i += 1
                 format = arguments[i].decode("ascii")
             else:
-                commit_sha = self.repo.refs[argument]
+                commit_sha = self.repo.refs[Ref(argument)]
                 commit_obj = store[commit_sha]
                 assert isinstance(commit_obj, Commit)
                 tree_obj = store[commit_obj.tree]

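With determine_wants now typed Callable[[Mapping[Ref, ObjectID], int | None], list[ObjectID]], custom callbacks are checked end to end. A small sketch of such a callback (the filtering logic is illustrative, not taken from dulwich):

    from collections.abc import Mapping

    from dulwich.objects import ObjectID
    from dulwich.refs import Ref

    def wants_branches_only(
        refs: Mapping[Ref, ObjectID], depth: int | None = None
    ) -> list[ObjectID]:
        # Ref is a newtype over bytes, so bytes methods like startswith
        # are still available.
        return [sha for ref, sha in refs.items() if ref.startswith(b"refs/heads/")]
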
+ 4 - 2
dulwich/stash.py

@@ -58,7 +58,7 @@ class CommitKwargs(TypedDict, total=False):
     author: bytes


-DEFAULT_STASH_REF = b"refs/stash"
+DEFAULT_STASH_REF = Ref(b"refs/stash")


 class Stash:
@@ -135,7 +135,9 @@ class Stash:

         # Get current HEAD to determine if we can apply cleanly
         try:
-            current_head = self._repo.refs[b"HEAD"]
+            from dulwich.refs import HEADREF
+
+            current_head = self._repo.refs[HEADREF]
         except KeyError:
             raise ValueError("Cannot pop stash: no HEAD")


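DEFAULT_STASH_REF keeps its runtime value, since Ref is plain bytes underneath. A brief usage sketch, assuming the Stash.from_repo()/push()/len() API of current dulwich:

    from dulwich.repo import Repo
    from dulwich.stash import Stash

    repo = Repo(".")  # hypothetical local repository
    stash = Stash.from_repo(repo)  # backed by DEFAULT_STASH_REF, now Ref(b"refs/stash")
    stash.push()
    print(len(stash), "stash entries")
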
+ 3 - 2
dulwich/submodule.py

@@ -29,13 +29,14 @@ from .object_store import iter_tree_contents
 from .objects import S_ISGITLINK

 if TYPE_CHECKING:
+    from .objects import ObjectID
     from .pack import ObjectContainer
     from .repo import Repo


 def iter_cached_submodules(
-    store: "ObjectContainer", root_tree_id: bytes
-) -> Iterator[tuple[bytes, bytes]]:
+    store: "ObjectContainer", root_tree_id: "ObjectID"
+) -> Iterator[tuple[bytes, "ObjectID"]]:
     """Iterate over cached submodules.

     Args:

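iter_cached_submodules now advertises that the second tuple element is an object ID. A short usage sketch (hypothetical repository path; HEAD is assumed to resolve to a commit):

    from dulwich.repo import Repo
    from dulwich.submodule import iter_cached_submodules

    repo = Repo(".")
    commit = repo[repo.head()]  # HEAD resolves to a Commit here
    for path, sha in iter_cached_submodules(repo.object_store, commit.tree):
        print(path.decode(), sha.decode())
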
+ 14 - 7
dulwich/tests/test_object_store.py

@@ -38,12 +38,14 @@ from dulwich.object_store import (
 from dulwich.objects import (
     Blob,
     Commit,
+    ObjectID,
     ShaFile,
     Tag,
     Tree,
     TreeEntry,
 )
 from dulwich.protocol import DEPTH_INFINITE
+from dulwich.refs import Ref

 from .utils import make_commit, make_object, make_tag

@@ -71,20 +73,25 @@ class ObjectStoreTests:
     def test_determine_wants_all(self) -> None:
         """Test determine_wants_all with valid ref."""
         self.assertEqual(
-            [b"1" * 40],
-            self.store.determine_wants_all({b"refs/heads/foo": b"1" * 40}),
+            [ObjectID(b"1" * 40)],
+            self.store.determine_wants_all(
+                {Ref(b"refs/heads/foo"): ObjectID(b"1" * 40)}
+            ),
         )

     def test_determine_wants_all_zero(self) -> None:
         """Test determine_wants_all with zero ref."""
         self.assertEqual(
-            [], self.store.determine_wants_all({b"refs/heads/foo": b"0" * 40})
+            [],
+            self.store.determine_wants_all(
+                {Ref(b"refs/heads/foo"): ObjectID(b"0" * 40)}
+            ),
         )

     def test_determine_wants_all_depth(self) -> None:
         """Test determine_wants_all with depth parameter."""
         self.store.add_object(testobject)
-        refs = {b"refs/heads/foo": testobject.id}
+        refs = {Ref(b"refs/heads/foo"): testobject.id}
         with patch.object(self.store, "_get_depth", return_value=1) as m:
             self.assertEqual([], self.store.determine_wants_all(refs, depth=0))
             self.assertEqual(
@@ -124,7 +131,7 @@ class ObjectStoreTests:

     def test_get_nonexistant(self) -> None:
         """Test getting non-existent object raises KeyError."""
-        self.assertRaises(KeyError, lambda: self.store[b"a" * 40])
+        self.assertRaises(KeyError, lambda: self.store[ObjectID(b"a" * 40)])

     def test_contains_nonexistant(self) -> None:
         """Test checking for non-existent object."""
@@ -300,7 +307,7 @@ class ObjectStoreTests:
         """Test iterating with missing objects when not allowed."""
         blob1 = make_object(Blob, data=b"blob 1 data")
         self.store.add_object(blob1)
-        missing_sha = b"1" * 40
+        missing_sha = ObjectID(b"1" * 40)

         self.assertRaises(
             KeyError,
@@ -311,7 +318,7 @@
         """Test iterating with missing objects when allowed."""
         blob1 = make_object(Blob, data=b"blob 1 data")
         self.store.add_object(blob1)
-        missing_sha = b"1" * 40
+        missing_sha = ObjectID(b"1" * 40)

         objects = list(
             self.store.iterobjects_subset([blob1.id, missing_sha], allow_missing=True)

+ 4 - 4
dulwich/walk.py

@@ -274,8 +274,8 @@ class Walker:
     def __init__(
         self,
         store: "BaseObjectStore",
-        include: Sequence[bytes],
-        exclude: Sequence[bytes] | None = None,
+        include: ObjectID | Sequence[ObjectID],
+        exclude: Sequence[ObjectID] | None = None,
         order: str = "date",
         reverse: bool = False,
         max_entries: int | None = None,
@@ -284,7 +284,7 @@
         follow: bool = False,
         since: int | None = None,
         until: int | None = None,
-        get_parents: Callable[[Commit], list[bytes]] = lambda commit: commit.parents,
+        get_parents: Callable[[Commit], list[ObjectID]] = lambda commit: commit.parents,
         queue_cls: type = _CommitTimeQueue,
     ) -> None:
         """Constructor.
@@ -468,7 +468,7 @@ class Walker:

 def _topo_reorder(
     entries: Iterator[WalkEntry],
-    get_parents: Callable[[Commit], list[bytes]] = lambda commit: commit.parents,
+    get_parents: Callable[[Commit], list[ObjectID]] = lambda commit: commit.parents,
 ) -> Iterator[WalkEntry]:
     """Reorder an iterable of entries topologically.


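Walker's include/exclude lists and the get_parents callback are now expressed in ObjectID terms; repo.get_walker() remains the convenient entry point. A minimal sketch (hypothetical repository path):

    from dulwich.repo import Repo

    repo = Repo(".")
    # repo.head() returns the ObjectID that HEAD resolves to.
    for entry in repo.get_walker(include=[repo.head()], max_entries=10):
        print(entry.commit.id.decode())
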
+ 2 - 1
dulwich/web.py

@@ -107,6 +107,7 @@ else:
 from dulwich import log_utils

 from .errors import NotGitRepository
+from .objects import ObjectID
 from .protocol import ReceivableProtocol
 from .repo import BaseRepo, Repo
 from .server import (
@@ -309,7 +310,7 @@ def get_loose_object(
     Returns:
       Iterator yielding object contents as bytes
     """
-    sha = (mat.group(1) + mat.group(2)).encode("ascii")
+    sha = cast(ObjectID, (mat.group(1) + mat.group(2)).encode("ascii"))
     logger.info("Sending loose object %s", sha)
     object_store = get_repo(backend, mat).object_store
     if not object_store.contains_loose(sha):

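The cast here only informs the type checker; nothing changes at runtime. The same pattern in isolation (hypothetical helper name, not dulwich code):

    from typing import cast

    from dulwich.objects import ObjectID

    def hex_sha_from_path(dir_part: str, file_part: str) -> ObjectID:
        # cast() is a no-op at runtime; it only asserts to mypy that this
        # bytes value is meant to be a hex ObjectID.
        return cast(ObjectID, (dir_part + file_part).encode("ascii"))
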
+ 10 - 8
dulwich/worktree.py

@@ -67,7 +67,7 @@ class WorkTreeInfo:
         self,
         path: str,
         head: bytes | None = None,
-        branch: bytes | None = None,
+        branch: Ref | None = None,
         bare: bool = False,
         detached: bool = False,
         locked: bool = False,
@@ -358,7 +358,7 @@ class WorkTree:

         index = self._repo.open_index()
         try:
-            commit = self._repo[b"HEAD"]
+            commit = self._repo[Ref(b"HEAD")]
         except KeyError:
             # no head mean no commit in the repo
             for fs_path in fs_paths:
@@ -425,7 +425,7 @@ class WorkTree:
         author_timezone: int | None = None,
         tree: ObjectID | None = None,
         encoding: bytes | None = None,
-        ref: Ref | None = b"HEAD",
+        ref: Ref | None = Ref(b"HEAD"),
         merge_heads: Sequence[ObjectID] | None = None,
         no_verify: bool = False,
         sign: bool | None = None,
@@ -670,7 +670,7 @@ class WorkTree:

         return c.id

-    def reset_index(self, tree: bytes | None = None) -> None:
+    def reset_index(self, tree: ObjectID | None = None) -> None:
         """Reset the index back to a specific tree.

         Args:
@@ -685,7 +685,7 @@
         )

         if tree is None:
-            head = self._repo[b"HEAD"]
+            head = self._repo[Ref(b"HEAD")]
             if isinstance(head, Tag):
                 _cls, obj = head.object
                 head = self._repo.get_object(obj)
@@ -840,7 +840,7 @@ def list_worktrees(repo: Repo) -> list[WorkTreeInfo]:
         with open(os.path.join(repo.controldir(), "HEAD"), "rb") as f:
             head_contents = f.read().strip()
             if head_contents.startswith(SYMREF):
-                ref_name = head_contents[len(SYMREF) :].strip()
+                ref_name = Ref(head_contents[len(SYMREF) :].strip())
                 main_wt_info.branch = ref_name
             else:
                 main_wt_info.detached = True
@@ -892,7 +892,7 @@
                 with open(head_path, "rb") as f:
                     head_contents = f.read().strip()
                     if head_contents.startswith(SYMREF):
-                        ref_name = head_contents[len(SYMREF) :].strip()
+                        ref_name = Ref(head_contents[len(SYMREF) :].strip())
                         wt_info.branch = ref_name
                         # Resolve ref to get commit sha
                         try:
@@ -1005,7 +1005,9 @@ def add_worktree(
     else:
         # Point to branch
         assert branch is not None  # Should be guaranteed by logic above
-        wt_repo.refs.set_symbolic_ref(b"HEAD", branch)
+        from dulwich.refs import HEADREF
+
+        wt_repo.refs.set_symbolic_ref(HEADREF, branch)

     # Reset index to match HEAD
     wt_repo.get_worktree().reset_index()

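Symbolic refs get the same treatment, with HEADREF as the typed replacement for the bare b"HEAD" literal. A small sketch (hypothetical repository and branch):

    from dulwich.refs import HEADREF, Ref
    from dulwich.repo import Repo

    repo = Repo(".")
    # Point HEAD at a branch without touching the working tree.
    repo.refs.set_symbolic_ref(HEADREF, Ref(b"refs/heads/develop"))
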
+ 3 - 3
tests/compat/test_reftable.py

@@ -498,7 +498,7 @@ class ReftableCompatTestCase(CompatTestCase):

         # Delete a ref using dulwich
         with repo.refs.batch_update():
-            repo.refs.set_if_equals(b"refs/heads/feature", commit_sha2, None)
+            repo.refs.remove_if_equals(b"refs/heads/feature", commit_sha2)

         repo.close()

@@ -725,8 +725,8 @@
             repo.refs.set_if_equals(
                 b"refs/heads/develop", commits[1], commits[4]
             )  # Update develop
-            repo.refs.set_if_equals(
-                b"refs/heads/feature", commits[3], None
+            repo.refs.remove_if_equals(
+                b"refs/heads/feature", commits[3]
             )  # Delete feature
             repo.refs.set_symbolic_ref(
                 b"HEAD", b"refs/heads/develop"

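Besides the typing work, the hunks above swap in the clearer deletion API: remove_if_equals(name, old_sha) replaces the set_if_equals(name, old_sha, None) idiom. In isolation (repo and old_sha are assumed to be an open Repo and the ref's current ObjectID):

    # Atomically delete a ref, but only if it still points at old_sha.
    repo.refs.remove_if_equals(b"refs/heads/feature", old_sha)
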
+ 3 - 1
tests/test_bundle.py

@@ -413,7 +413,9 @@ class BundleTests(TestCase):
         repo.refs[b"refs/heads/feature"] = commit.id

         # Create bundle with only master ref
-        bundle = create_bundle_from_repo(repo, refs=[b"refs/heads/master"])
+        from dulwich.refs import Ref
+
+        bundle = create_bundle_from_repo(repo, refs=[Ref(b"refs/heads/master")])

         # Verify only master ref is included
         self.assertEqual(len(bundle.references), 1)

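Bundle refs are typed the same way. A usage sketch built from the test above (assuming dulwich.bundle.write_bundle for serialization):

    from dulwich.bundle import create_bundle_from_repo, write_bundle
    from dulwich.refs import Ref
    from dulwich.repo import Repo

    repo = Repo(".")  # hypothetical local repository
    bundle = create_bundle_from_repo(repo, refs=[Ref(b"refs/heads/master")])
    with open("master.bundle", "wb") as f:
        write_bundle(f, bundle)
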
Some files were not shown because too many files changed in this diff