Browse Source

Make Ref and ObjectID newtypes for improved typing

Jelmer Vernooij 1 month ago
parent
commit
b5674746f5

+ 4 - 0
NEWS

@@ -1,5 +1,9 @@
 0.25.0	UNRELEASED
 
+ * The ``ObjectID`` and ``Ref`` types are now newtypes, making it harder to
+   accidentally pass the wrong type - as flagged by mypy. Most of this is in
+   lower-level code. (Jelmer Vernooij)
+
  * Implement support for ``core.sharedRepository`` configuration option.
    Repository files and directories now respect shared repository permissions
    for group-writable or world-writable repositories. Affects loose objects,

+ 2 - 2
dulwich/annotate.py

@@ -39,7 +39,7 @@ from dulwich.walk import (
 if TYPE_CHECKING:
     from dulwich.diff_tree import TreeChange
     from dulwich.object_store import BaseObjectStore
-    from dulwich.objects import Commit, TreeEntry
+    from dulwich.objects import Commit, ObjectID, TreeEntry
 
 # Walk over ancestry graph breadth-first
 # When checking each revision, find lines that according to difflib.Differ()
@@ -74,7 +74,7 @@ def update_lines(
 
 def annotate_lines(
     store: "BaseObjectStore",
-    commit_id: bytes,
+    commit_id: "ObjectID",
     path: bytes,
     order: str = ORDER_DATE,
     lines: Sequence[tuple[tuple["Commit", "TreeEntry"], bytes]] | None = None,

+ 28 - 26
dulwich/bisect.py

@@ -24,7 +24,8 @@ import os
 from collections.abc import Sequence, Set
 
 from dulwich.object_store import peel_sha
-from dulwich.objects import Commit
+from dulwich.objects import Commit, ObjectID
+from dulwich.refs import HEADREF, Ref
 from dulwich.repo import Repo
 
 
@@ -47,8 +48,8 @@ class BisectState:
 
     def start(
         self,
-        bad: bytes | None = None,
-        good: Sequence[bytes] | None = None,
+        bad: ObjectID | None = None,
+        good: Sequence[ObjectID] | None = None,
         paths: Sequence[bytes] | None = None,
         no_checkout: bool = False,
         term_bad: str = "bad",
@@ -73,11 +74,12 @@ class BisectState:
 
         # Store current branch/commit
         try:
-            ref_chain, sha = self.repo.refs.follow(b"HEAD")
+            ref_chain, sha = self.repo.refs.follow(HEADREF)
             if sha is None:
                 # No HEAD exists
                 raise ValueError("Cannot start bisect: repository has no HEAD")
             # Use the first non-HEAD ref in the chain, or the SHA itself
+            current_branch: Ref | ObjectID
             if len(ref_chain) > 1:
                 current_branch = ref_chain[1]  # The actual branch ref
             else:
@@ -124,7 +126,7 @@ class BisectState:
             for g in good:
                 self.mark_good(g)
 
-    def mark_bad(self, rev: bytes | None = None) -> bytes | None:
+    def mark_bad(self, rev: ObjectID | None = None) -> ObjectID | None:
         """Mark a commit as bad.
 
         Args:
@@ -154,7 +156,7 @@ class BisectState:
 
         return self._find_next_commit()
 
-    def mark_good(self, rev: bytes | None = None) -> bytes | None:
+    def mark_good(self, rev: ObjectID | None = None) -> ObjectID | None:
         """Mark a commit as good.
 
         Args:
@@ -186,7 +188,7 @@ class BisectState:
 
         return self._find_next_commit()
 
-    def skip(self, revs: Sequence[bytes] | None = None) -> bytes | None:
+    def skip(self, revs: Sequence[ObjectID] | None = None) -> ObjectID | None:
         """Skip one or more commits.
 
         Args:
@@ -213,7 +215,7 @@ class BisectState:
 
         return self._find_next_commit()
 
-    def reset(self, commit: bytes | None = None) -> None:
+    def reset(self, commit: ObjectID | None = None) -> None:
         """Reset bisect state and return to original branch/commit.
 
         Args:
@@ -250,13 +252,13 @@ class BisectState:
         if commit is None:
             if original.startswith(b"refs/"):
                 # It's a branch reference - need to create a symbolic ref
-                self.repo.refs.set_symbolic_ref(b"HEAD", original)
+                self.repo.refs.set_symbolic_ref(HEADREF, Ref(original))
             else:
                 # It's a commit SHA
-                self.repo.refs[b"HEAD"] = original
+                self.repo.refs[HEADREF] = ObjectID(original)
         else:
             commit = peel_sha(self.repo.object_store, commit)[1].id
-            self.repo.refs[b"HEAD"] = commit
+            self.repo.refs[HEADREF] = commit
 
     def get_log(self) -> str:
         """Get the bisect log."""
@@ -289,16 +291,16 @@ class BisectState:
             if cmd == "start":
                 self.start()
             elif cmd == "bad":
-                rev = args[0].encode("ascii") if args else None
+                rev = ObjectID(args[0].encode("ascii")) if args else None
                 self.mark_bad(rev)
             elif cmd == "good":
-                rev = args[0].encode("ascii") if args else None
+                rev = ObjectID(args[0].encode("ascii")) if args else None
                 self.mark_good(rev)
             elif cmd == "skip":
-                revs = [arg.encode("ascii") for arg in args] if args else None
+                revs = [ObjectID(arg.encode("ascii")) for arg in args] if args else None
                 self.skip(revs)
 
-    def _find_next_commit(self) -> bytes | None:
+    def _find_next_commit(self) -> ObjectID | None:
         """Find the next commit to test using binary search.
 
         Returns:
@@ -311,15 +313,15 @@ class BisectState:
             return None
 
         with open(bad_ref_path, "rb") as f:
-            bad_sha = f.read().strip()
+            bad_sha = ObjectID(f.read().strip())
 
         # Get all good commits
-        good_shas = []
+        good_shas: list[ObjectID] = []
         bisect_refs_dir = os.path.join(self.repo.controldir(), "refs", "bisect")
         for filename in os.listdir(bisect_refs_dir):
             if filename.startswith("good-"):
                 with open(os.path.join(bisect_refs_dir, filename), "rb") as f:
-                    good_shas.append(f.read().strip())
+                    good_shas.append(ObjectID(f.read().strip()))
 
         if not good_shas:
             self._append_to_log(
@@ -328,11 +330,11 @@ class BisectState:
             return None
 
         # Get skip commits
-        skip_shas = set()
+        skip_shas: set[ObjectID] = set()
         for filename in os.listdir(bisect_refs_dir):
             if filename.startswith("skip-"):
                 with open(os.path.join(bisect_refs_dir, filename), "rb") as f:
-                    skip_shas.add(f.read().strip())
+                    skip_shas.add(ObjectID(f.read().strip()))
 
         # Find commits between good and bad
         candidates = self._find_bisect_candidates(bad_sha, good_shas, skip_shas)
@@ -367,8 +369,8 @@ class BisectState:
         return next_commit
 
     def _find_bisect_candidates(
-        self, bad_sha: bytes, good_shas: Sequence[bytes], skip_shas: Set[bytes]
-    ) -> list[bytes]:
+        self, bad_sha: ObjectID, good_shas: Sequence[ObjectID], skip_shas: Set[ObjectID]
+    ) -> list[ObjectID]:
         """Find all commits between good and bad commits.
 
         Args:
@@ -382,9 +384,9 @@ class BisectState:
         # Use git's graph walking to find commits
         # This is a simplified version - a full implementation would need
         # to handle merge commits properly
-        candidates = []
-        visited = set(good_shas)
-        queue = [bad_sha]
+        candidates: list[ObjectID] = []
+        visited: set[ObjectID] = set(good_shas)
+        queue: list[ObjectID] = [bad_sha]
 
         while queue:
             sha = queue.pop(0)
@@ -410,7 +412,7 @@ class BisectState:
 
         return candidates
 
-    def _get_commit_subject(self, sha: bytes) -> str:
+    def _get_commit_subject(self, sha: ObjectID) -> str:
         """Get the subject line of a commit message."""
         obj = self.repo.object_store[sha]
         if isinstance(obj, Commit):

+ 39 - 29
dulwich/bitmap.py

@@ -36,11 +36,21 @@ from io import BytesIO
 from typing import IO, TYPE_CHECKING
 
 from .file import GitFile
-from .objects import Blob, Commit, Tag, Tree
+from .objects import (
+    Blob,
+    Commit,
+    ObjectID,
+    RawObjectID,
+    Tag,
+    Tree,
+    hex_to_sha,
+    sha_to_hex,
+)
 
 if TYPE_CHECKING:
     from .object_store import BaseObjectStore
     from .pack import Pack, PackIndex
+    from .refs import Ref
 
 # Bitmap file signature
 BITMAP_SIGNATURE = b"BITM"
@@ -781,10 +791,10 @@ def _compute_name_hash(name: bytes) -> int:
 
 
 def select_bitmap_commits(
-    refs: dict[bytes, bytes],
+    refs: dict["Ref", ObjectID],
     object_store: "BaseObjectStore",
     commit_interval: int = DEFAULT_COMMIT_INTERVAL,
-) -> list[bytes]:
+) -> list[ObjectID]:
     """Select commits for bitmap generation.
 
     Uses Git's strategy:
@@ -849,8 +859,8 @@ def select_bitmap_commits(
 
 
 def build_reachability_bitmap(
-    commit_sha: bytes,
-    sha_to_pos: dict[bytes, int],
+    commit_sha: ObjectID,
+    sha_to_pos: dict[RawObjectID, int],
     object_store: "BaseObjectStore",
 ) -> EWAHBitmap:
     """Build a reachability bitmap for a commit.
@@ -879,8 +889,10 @@ def build_reachability_bitmap(
         seen.add(sha)
 
         # Add this object to the bitmap if it's in the pack
-        if sha in sha_to_pos:
-            bitmap.add(sha_to_pos[sha])
+        # Convert hex SHA to binary for pack index lookup
+        raw_sha = hex_to_sha(sha)
+        if raw_sha in sha_to_pos:
+            bitmap.add(sha_to_pos[raw_sha])
 
         # Get the object and traverse its references
         try:
@@ -902,9 +914,9 @@ def build_reachability_bitmap(
 
 
 def apply_xor_compression(
-    bitmaps: list[tuple[bytes, EWAHBitmap]],
+    bitmaps: list[tuple[ObjectID, EWAHBitmap]],
     max_xor_offset: int = MAX_XOR_OFFSET,
-) -> list[tuple[bytes, EWAHBitmap, int]]:
+) -> list[tuple[ObjectID, EWAHBitmap, int]]:
     """Apply XOR compression to bitmaps.
 
     XOR compression stores some bitmaps as XOR differences from previous bitmaps,
@@ -942,7 +954,7 @@ def apply_xor_compression(
 
 
 def build_type_bitmaps(
-    sha_to_pos: dict[bytes, int],
+    sha_to_pos: dict["RawObjectID", int],
     object_store: "BaseObjectStore",
 ) -> tuple[EWAHBitmap, EWAHBitmap, EWAHBitmap, EWAHBitmap]:
     """Build type bitmaps for all objects in a pack.
@@ -956,8 +968,6 @@ def build_type_bitmaps(
     Returns:
         Tuple of (commit_bitmap, tree_bitmap, blob_bitmap, tag_bitmap)
     """
-    from .objects import sha_to_hex
-
     commit_bitmap = EWAHBitmap()
     tree_bitmap = EWAHBitmap()
     blob_bitmap = EWAHBitmap()
@@ -965,7 +975,7 @@ def build_type_bitmaps(
 
     for sha, pos in sha_to_pos.items():
         # Pack index returns binary SHA (20 bytes), but object_store expects hex SHA (40 bytes)
-        hex_sha = sha_to_hex(sha) if len(sha) == 20 else sha
+        hex_sha = sha_to_hex(sha) if len(sha) == 20 else ObjectID(sha)
         try:
             obj = object_store[hex_sha]
         except KeyError:
@@ -987,7 +997,7 @@ def build_type_bitmaps(
 
 
 def build_name_hash_cache(
-    sha_to_pos: dict[bytes, int],
+    sha_to_pos: dict["RawObjectID", int],
     object_store: "BaseObjectStore",
 ) -> list[int]:
     """Build name-hash cache for all objects in a pack.
@@ -1002,15 +1012,13 @@ def build_name_hash_cache(
     Returns:
         List of 32-bit hash values, one per object in the pack
     """
-    from .objects import sha_to_hex
-
     # Pre-allocate list with correct size
     num_objects = len(sha_to_pos)
     name_hashes = [0] * num_objects
 
     for sha, pos in sha_to_pos.items():
         # Pack index returns binary SHA (20 bytes), but object_store expects hex SHA (40 bytes)
-        hex_sha = sha_to_hex(sha) if len(sha) == 20 else sha
+        hex_sha = sha_to_hex(sha) if len(sha) == 20 else ObjectID(sha)
         try:
             obj = object_store[hex_sha]
         except KeyError:
@@ -1038,7 +1046,7 @@ def build_name_hash_cache(
 def generate_bitmap(
     pack_index: "PackIndex",
     object_store: "BaseObjectStore",
-    refs: dict[bytes, bytes],
+    refs: dict["Ref", ObjectID],
     pack_checksum: bytes,
     include_hash_cache: bool = True,
     include_lookup_table: bool = True,
@@ -1068,7 +1076,7 @@ def generate_bitmap(
 
     # Build mapping from SHA to position in pack index ONCE
     # This is used by all subsequent operations and avoids repeated enumeration
-    sha_to_pos = {}
+    sha_to_pos: dict[RawObjectID, int] = {}
     for pos, (sha, _offset, _crc32) in enumerate(pack_index.iterentries()):
         sha_to_pos[sha] = pos
 
@@ -1120,11 +1128,12 @@ def generate_bitmap(
 
     # Add bitmap entries
     for commit_sha, xor_bitmap, xor_offset in compressed_bitmaps:
-        if commit_sha not in sha_to_pos:
+        raw_commit_sha = hex_to_sha(commit_sha)
+        if raw_commit_sha not in sha_to_pos:
             continue
 
         entry = BitmapEntry(
-            object_pos=sha_to_pos[commit_sha],
+            object_pos=sha_to_pos[raw_commit_sha],
             xor_offset=xor_offset,
             flags=0,
             bitmap=xor_bitmap,
@@ -1154,8 +1163,8 @@ def generate_bitmap(
 
 
 def find_commit_bitmaps(
-    commit_shas: set[bytes], packs: Iterable["Pack"]
-) -> dict[bytes, tuple["Pack", "PackBitmap", dict[bytes, int]]]:
+    commit_shas: set["ObjectID"], packs: Iterable["Pack"]
+) -> dict["ObjectID", tuple["Pack", "PackBitmap", dict[RawObjectID, int]]]:
     """Find which packs have bitmaps for the given commits.
 
     Args:
@@ -1178,14 +1187,15 @@ def find_commit_bitmaps(
             continue
 
         # Build SHA to position mapping for this pack
-        sha_to_pos = {}
+        sha_to_pos: dict[RawObjectID, int] = {}
         for pos, (sha, _offset, _crc32) in enumerate(pack.index.iterentries()):
             sha_to_pos[sha] = pos
 
         # Check which commits have bitmaps
         for commit_sha in list(remaining):
             if pack_bitmap.has_commit(commit_sha):
-                if commit_sha in sha_to_pos:
+                raw_commit_sha = hex_to_sha(commit_sha)
+                if raw_commit_sha in sha_to_pos:
                     result[commit_sha] = (pack, pack_bitmap, sha_to_pos)
                     remaining.remove(commit_sha)
 
@@ -1196,7 +1206,7 @@ def bitmap_to_object_shas(
     bitmap: EWAHBitmap,
     pack_index: "PackIndex",
     type_filter: EWAHBitmap | None = None,
-) -> set[bytes]:
+) -> set[ObjectID]:
     """Convert a bitmap to a set of object SHAs.
 
     Args:
@@ -1205,15 +1215,15 @@ def bitmap_to_object_shas(
         type_filter: Optional type bitmap to filter results (e.g., commits only)
 
     Returns:
-        Set of object SHAs
+        Set of object SHAs (hex format)
     """
-    result = set()
+    result: set[ObjectID] = set()
 
     for pos, (sha, _offset, _crc32) in enumerate(pack_index.iterentries()):
         # Check if this position is in the bitmap
         if pos in bitmap:
             # Apply type filter if provided
             if type_filter is None or pos in type_filter:
-                result.add(sha)
+                result.add(sha_to_hex(sha))
 
     return result

+ 13 - 11
dulwich/bundle.py

@@ -30,7 +30,9 @@ from typing import (
     runtime_checkable,
 )
 
+from .objects import ObjectID
 from .pack import PackData, UnpackedObject, write_pack_data
+from .refs import Ref
 
 
 @runtime_checkable
@@ -57,8 +59,8 @@ class Bundle:
     version: int | None
 
     capabilities: dict[str, str | None]
-    prerequisites: list[tuple[bytes, bytes]]
-    references: dict[bytes, bytes]
+    prerequisites: list[tuple[ObjectID, bytes]]
+    references: dict[Ref, ObjectID]
     pack_data: PackDataLike | None
 
     def __repr__(self) -> str:
@@ -121,7 +123,7 @@ class Bundle:
 def _read_bundle(f: BinaryIO, version: int) -> Bundle:
     capabilities = {}
     prerequisites = []
-    references = {}
+    references: dict[Ref, ObjectID] = {}
     line = f.readline()
     if version >= 3:
         while line.startswith(b"@"):
@@ -136,11 +138,11 @@ def _read_bundle(f: BinaryIO, version: int) -> Bundle:
             line = f.readline()
     while line.startswith(b"-"):
         (obj_id, comment) = line[1:].rstrip(b"\n").split(b" ", 1)
-        prerequisites.append((obj_id, comment))
+        prerequisites.append((ObjectID(obj_id), comment))
         line = f.readline()
     while line != b"\n":
         (obj_id, ref) = line.rstrip(b"\n").split(b" ", 1)
-        references[ref] = obj_id
+        references[Ref(ref)] = ObjectID(obj_id)
         line = f.readline()
     # Extract pack data to separate stream since PackData expects
     # the file to start with PACK header at position 0
@@ -220,7 +222,7 @@ def write_bundle(f: BinaryIO, bundle: Bundle) -> None:
 
 def create_bundle_from_repo(
     repo: "BaseRepo",
-    refs: Sequence[bytes] | None = None,
+    refs: Sequence[Ref] | None = None,
     prerequisites: Sequence[bytes] | None = None,
     version: int | None = None,
     capabilities: dict[str, str | None] | None = None,
@@ -249,8 +251,8 @@ def create_bundle_from_repo(
         capabilities = {}
 
     # Build the references dictionary for the bundle
-    bundle_refs = {}
-    want_objects = set()
+    bundle_refs: dict[Ref, ObjectID] = {}
+    want_objects: set[ObjectID] = set()
 
     for ref in refs:
         if ref in repo.refs:
@@ -268,7 +270,7 @@ def create_bundle_from_repo(
 
     # Convert prerequisites to proper format
     bundle_prerequisites = []
-    have_objects = set()
+    have_objects: set[ObjectID] = set()
     for prereq in prerequisites:
         if not isinstance(prereq, bytes):
             raise TypeError(
@@ -284,8 +286,8 @@ def create_bundle_from_repo(
         except ValueError:
             raise ValueError(f"Invalid prerequisite format: {prereq!r}")
         # Store hex in bundle and for pack generation
-        bundle_prerequisites.append((prereq, b""))
-        have_objects.add(prereq)
+        bundle_prerequisites.append((ObjectID(prereq), b""))
+        have_objects.add(ObjectID(prereq))
 
     # Generate pack data containing all objects needed for the refs
     pack_count, pack_objects = repo.generate_pack_data(

+ 32 - 25
dulwich/cli.py

@@ -53,6 +53,7 @@ from typing import (
 
 from dulwich import porcelain
 from dulwich._typing import Buffer
+from dulwich.refs import HEADREF, Ref
 
 from .bundle import Bundle, create_bundle_from_repo, read_bundle, write_bundle
 from .client import get_transport_and_path
@@ -65,7 +66,7 @@ from .errors import (
 )
 from .index import Index
 from .log_utils import _configure_logging_from_trace
-from .objects import Commit, sha_to_hex, valid_hexsha
+from .objects import Commit, ObjectID, RawObjectID, sha_to_hex, valid_hexsha
 from .objectspec import parse_commit_range
 from .pack import Pack
 from .patch import DiffAlgorithmNotAvailable
@@ -1237,9 +1238,13 @@ class cmd_fetch_pack(Command):
         else:
 
             def determine_wants(
-                refs: Mapping[bytes, bytes], depth: int | None = None
-            ) -> list[bytes]:
-                return [y.encode("utf-8") for y in args.refs if y not in r.object_store]
+                refs: Mapping[Ref, ObjectID], depth: int | None = None
+            ) -> list[ObjectID]:
+                return [
+                    ObjectID(y.encode("utf-8"))
+                    for y in args.refs
+                    if y not in r.object_store
+                ]
 
         client.fetch(path.encode("utf-8"), r, determine_wants)
 
@@ -1552,7 +1557,7 @@ class cmd_dump_pack(Command):
         basename, _ = os.path.splitext(parsed_args.filename)
         x = Pack(basename)
         logger.info("Object names checksum: %s", x.name().decode("ascii", "replace"))
-        logger.info("Checksum: %r", sha_to_hex(x.get_stored_checksum()))
+        logger.info("Checksum: %r", sha_to_hex(RawObjectID(x.get_stored_checksum())))
         x.check()
         logger.info("Length: %d", len(x))
         for name in x:
@@ -1921,7 +1926,7 @@ def _get_commit_message_with_template(
     # Add branch info if repo is provided
     if repo:
         try:
-            ref_names, _ref_sha = repo.refs.follow(b"HEAD")
+            ref_names, _ref_sha = repo.refs.follow(HEADREF)
             ref_path = ref_names[-1]  # Get the final reference
             if ref_path.startswith(b"refs/heads/"):
                 branch = ref_path[11:]  # Remove 'refs/heads/' prefix
@@ -3348,7 +3353,7 @@ class cmd_pack_objects(Command):
         if not parsed_args.stdout and not parsed_args.basename:
             parser.error("basename required when not using --stdout")
 
-        object_ids = [line.strip().encode() for line in sys.stdin.readlines()]
+        object_ids = [ObjectID(line.strip().encode()) for line in sys.stdin.readlines()]
         deltify = parsed_args.deltify
         reuse_deltas = not parsed_args.no_reuse_deltas
 
@@ -4102,7 +4107,7 @@ class cmd_bisect(SuperCommand):
                     with porcelain.open_repo_closing(".") as r:
                         bad_ref = os.path.join(r.controldir(), "refs", "bisect", "bad")
                         with open(bad_ref, "rb") as f:
-                            bad_sha = f.read().strip()
+                            bad_sha = ObjectID(f.read().strip())
                         commit = r.object_store[bad_sha]
                         assert isinstance(commit, Commit)
                         message = commit.message.decode(
@@ -5342,7 +5347,7 @@ class cmd_filter_branch(Command):
         tree_filter = None
         if parsed_args.tree_filter:
 
-            def tree_filter(tree_sha: bytes, tmpdir: str) -> bytes:
+            def tree_filter(tree_sha: ObjectID, tmpdir: str) -> ObjectID:
                 from dulwich.objects import Blob, Tree
 
                 # Export tree to tmpdir
@@ -5364,7 +5369,7 @@ class cmd_filter_branch(Command):
                     run_filter(parsed_args.tree_filter, cwd=tmpdir)
 
                     # Rebuild tree from modified temp directory
-                    def build_tree_from_dir(dir_path: str) -> bytes:
+                    def build_tree_from_dir(dir_path: str) -> ObjectID:
                         tree = Tree()
                         for name in sorted(os.listdir(dir_path)):
                             if name.startswith("."):
@@ -5393,7 +5398,7 @@ class cmd_filter_branch(Command):
         index_filter = None
         if parsed_args.index_filter:
 
-            def index_filter(tree_sha: bytes, index_path: str) -> bytes | None:
+            def index_filter(tree_sha: ObjectID, index_path: str) -> ObjectID | None:
                 run_filter(
                     parsed_args.index_filter, extra_env={"GIT_INDEX_FILE": index_path}
                 )
@@ -5402,7 +5407,7 @@ class cmd_filter_branch(Command):
         parent_filter = None
         if parsed_args.parent_filter:
 
-            def parent_filter(parents: Sequence[bytes]) -> list[bytes]:
+            def parent_filter(parents: Sequence[ObjectID]) -> list[ObjectID]:
                 parent_str = " ".join(p.hex() for p in parents)
                 result = run_filter(
                     parsed_args.parent_filter, input_data=parent_str.encode()
@@ -5417,13 +5422,15 @@ class cmd_filter_branch(Command):
                 for sha in output.split():
                     sha_bytes = sha.encode()
                     if valid_hexsha(sha_bytes):
-                        new_parents.append(sha_bytes)
+                        new_parents.append(ObjectID(sha_bytes))
                 return new_parents
 
         commit_filter = None
         if parsed_args.commit_filter:
 
-            def commit_filter(commit_obj: Commit, tree_sha: bytes) -> bytes | None:
+            def commit_filter(
+                commit_obj: Commit, tree_sha: ObjectID
+            ) -> ObjectID | None:
                 # The filter receives: tree parent1 parent2...
                 cmd_input = tree_sha.hex()
                 for parent in commit_obj.parents:
@@ -5442,7 +5449,7 @@ class cmd_filter_branch(Command):
                     return None  # Skip commit
 
                 if valid_hexsha(output):
-                    return output.encode()
+                    return ObjectID(output.encode())
                 return None
 
         tag_name_filter = None
@@ -5775,7 +5782,7 @@ class cmd_format_patch(Command):
         parsed_args = parser.parse_args(args)
 
         # Parse committish using the new function
-        committish: bytes | tuple[bytes, bytes] | None = None
+        committish: ObjectID | tuple[ObjectID, ObjectID] | None = None
         if parsed_args.committish:
             with Repo(".") as r:
                 range_result = parse_commit_range(r, parsed_args.committish)
@@ -5783,7 +5790,7 @@ class cmd_format_patch(Command):
                     # Convert Commit objects to their SHAs
                     committish = (range_result[0].id, range_result[1].id)
                 else:
-                    committish = (
+                    committish = ObjectID(
                         parsed_args.committish.encode()
                         if isinstance(parsed_args.committish, str)
                         else parsed_args.committish
@@ -6025,7 +6032,7 @@ class cmd_bundle(Command):
                     msg = msg.decode("utf-8", "replace")
                 logger.error("%s", msg)
 
-        refs_to_include = []
+        refs_to_include: list[Ref] = []
         prerequisites = []
 
         if parsed_args.all:
@@ -6034,7 +6041,7 @@ class cmd_bundle(Command):
             for line in sys.stdin:
                 ref = line.strip().encode("utf-8")
                 if ref:
-                    refs_to_include.append(ref)
+                    refs_to_include.append(Ref(ref))
         elif parsed_args.refs:
             for ref_arg in parsed_args.refs:
                 if ".." in ref_arg:
@@ -6046,19 +6053,19 @@ class cmd_bundle(Command):
                         # Split the range to get the end part
                         end_part = ref_arg.split("..")[1]
                         if end_part:  # Not empty (not "A..")
-                            end_ref = end_part.encode("utf-8")
+                            end_ref = Ref(end_part.encode("utf-8"))
                             if end_ref in repo.refs:
                                 refs_to_include.append(end_ref)
                     else:
-                        sha = repo.refs[ref_arg.encode("utf-8")]
-                        refs_to_include.append(ref_arg.encode("utf-8"))
+                        sha = repo.refs[Ref(ref_arg.encode("utf-8"))]
+                        refs_to_include.append(Ref(ref_arg.encode("utf-8")))
                 else:
                     if ref_arg.startswith("^"):
-                        sha = repo.refs[ref_arg[1:].encode("utf-8")]
+                        sha = repo.refs[Ref(ref_arg[1:].encode("utf-8"))]
                         prerequisites.append(sha)
                     else:
-                        sha = repo.refs[ref_arg.encode("utf-8")]
-                        refs_to_include.append(ref_arg.encode("utf-8"))
+                        sha = repo.refs[Ref(ref_arg.encode("utf-8"))]
+                        refs_to_include.append(Ref(ref_arg.encode("utf-8")))
         else:
             logger.error("No refs specified. Use --all, --stdin, or specify refs")
             return 1

+ 114 - 102
dulwich/client.py

@@ -67,7 +67,9 @@ import dulwich
 if TYPE_CHECKING:
     from typing import Protocol as TypingProtocol
 
+    from .objects import ObjectID
     from .pack import UnpackedObject
+    from .refs import Ref
 
     class HTTPResponse(TypingProtocol):
         """Protocol for HTTP response objects."""
@@ -84,8 +86,8 @@ if TYPE_CHECKING:
 
         def __call__(
             self,
-            have: Set[bytes],
-            want: Set[bytes],
+            have: Set[ObjectID],
+            want: Set[ObjectID],
             *,
             ofs_delta: bool = False,
             progress: Callable[[bytes], None] | None = None,
@@ -98,9 +100,9 @@ if TYPE_CHECKING:
 
         def __call__(
             self,
-            refs: Mapping[bytes, bytes],
+            refs: Mapping[Ref, ObjectID],
             depth: int | None = None,
-        ) -> list[bytes]:
+        ) -> list[ObjectID]:
             """Determine the objects to fetch from the given refs."""
             ...
 
@@ -110,6 +112,7 @@ from .config import Config, apply_instead_of, get_xdg_config_home_path
 from .credentials import match_partial_url, match_urls
 from .errors import GitProtocolError, HangupException, NotGitRepository, SendPackError
 from .object_store import GraphWalker
+from .objects import ObjectID
 from .pack import (
     PACK_SPOOL_FILE_MAX_SIZE,
     PackChunkGenerator,
@@ -146,6 +149,7 @@ from .protocol import (
     GIT_PROTOCOL_VERSIONS,
     KNOWN_RECEIVE_CAPABILITIES,
     KNOWN_UPLOAD_CAPABILITIES,
+    PEELED_TAG_SUFFIX,
     SIDE_BAND_CHANNEL_DATA,
     SIDE_BAND_CHANNEL_FATAL,
     SIDE_BAND_CHANNEL_PROGRESS,
@@ -162,7 +166,7 @@ from .protocol import (
     pkt_seq,
 )
 from .refs import (
-    PEELED_TAG_SUFFIX,
+    HEADREF,
     SYMREF,
     Ref,
     _import_remote_refs,
@@ -181,8 +185,6 @@ from .repo import BaseRepo, Repo
 # behaviour with v1 when no ref-prefix is specified.
 DEFAULT_REF_PREFIX = [b"HEAD", b"refs/"]
 
-ObjectID = bytes
-
 
 logger = logging.getLogger(__name__)
 
@@ -216,8 +218,8 @@ class HTTPUnauthorized(Exception):
         self.url = url
 
 
-def _to_optional_dict(refs: Mapping[bytes, bytes]) -> dict[bytes, bytes | None]:
-    """Convert a dict[bytes, bytes] to dict[bytes, Optional[bytes]].
+def _to_optional_dict(refs: Mapping[Ref, ObjectID]) -> dict[Ref, ObjectID | None]:
+    """Convert a dict[Ref, ObjectID] to dict[Ref, Optional[ObjectID]].
 
     This is needed for compatibility with result types that expect Optional values.
     """
@@ -350,23 +352,26 @@ def read_server_capabilities(pkt_seq: Iterable[bytes]) -> set[bytes]:
 
 def read_pkt_refs_v2(
     pkt_seq: Iterable[bytes],
-) -> tuple[dict[bytes, bytes | None], dict[bytes, bytes], dict[bytes, bytes]]:
+) -> tuple[dict[Ref, ObjectID | None], dict[Ref, Ref], dict[Ref, ObjectID]]:
     """Read references using protocol version 2."""
-    refs: dict[bytes, bytes | None] = {}
-    symrefs = {}
-    peeled = {}
+    refs: dict[Ref, ObjectID | None] = {}
+    symrefs: dict[Ref, Ref] = {}
+    peeled: dict[Ref, ObjectID] = {}
     # Receive refs from server
     for pkt in pkt_seq:
         parts = pkt.rstrip(b"\n").split(b" ")
-        sha: bytes | None = parts[0]
-        if sha == b"unborn":
+        sha_bytes = parts[0]
+        sha: ObjectID | None
+        if sha_bytes == b"unborn":
             sha = None
-        ref = parts[1]
+        else:
+            sha = ObjectID(sha_bytes)
+        ref = Ref(parts[1])
         for part in parts[2:]:
             if part.startswith(b"peeled:"):
-                peeled[ref] = part[7:]
+                peeled[ref] = ObjectID(part[7:])
             elif part.startswith(b"symref-target:"):
-                symrefs[ref] = part[14:]
+                symrefs[ref] = Ref(part[14:])
             else:
                 logging.warning("unknown part in pkt-ref: %s", part)
         refs[ref] = sha
@@ -376,10 +381,10 @@ def read_pkt_refs_v2(
 
 def read_pkt_refs_v1(
     pkt_seq: Iterable[bytes],
-) -> tuple[dict[bytes, bytes], set[bytes]]:
+) -> tuple[dict[Ref, ObjectID], set[bytes]]:
     """Read references using protocol version 1."""
     server_capabilities = None
-    refs: dict[bytes, bytes] = {}
+    refs: dict[Ref, ObjectID] = {}
     # Receive refs from server
     for pkt in pkt_seq:
         (sha, ref) = pkt.rstrip(b"\n").split(None, 1)
@@ -387,7 +392,7 @@ def read_pkt_refs_v1(
             raise GitProtocolError(ref.decode("utf-8", "replace"))
         if server_capabilities is None:
             (ref, server_capabilities) = extract_capabilities(ref)
-        refs[ref] = sha
+        refs[Ref(ref)] = ObjectID(sha)
 
     if len(refs) == 0:
         return {}, set()
@@ -400,7 +405,7 @@ def read_pkt_refs_v1(
 class _DeprecatedDictProxy:
     """Base class for result objects that provide deprecated dict-like interface."""
 
-    refs: dict[bytes, bytes | None]  # To be overridden by subclasses
+    refs: dict[Ref, ObjectID | None]  # To be overridden by subclasses
 
     _FORWARDED_ATTRS: ClassVar[set[str]] = {
         "clear",
@@ -428,11 +433,11 @@ class _DeprecatedDictProxy:
             stacklevel=3,
         )
 
-    def __contains__(self, name: bytes) -> bool:
+    def __contains__(self, name: Ref) -> bool:
         self._warn_deprecated()
         return name in self.refs
 
-    def __getitem__(self, name: bytes) -> bytes | None:
+    def __getitem__(self, name: Ref) -> ObjectID | None:
         self._warn_deprecated()
         return self.refs[name]
 
@@ -440,7 +445,7 @@ class _DeprecatedDictProxy:
         self._warn_deprecated()
         return len(self.refs)
 
-    def __iter__(self) -> Iterator[bytes]:
+    def __iter__(self) -> Iterator[Ref]:
         self._warn_deprecated()
         return iter(self.refs)
 
@@ -463,16 +468,17 @@ class FetchPackResult(_DeprecatedDictProxy):
       agent: User agent string
     """
 
-    symrefs: dict[bytes, bytes]
+    refs: dict[Ref, ObjectID | None]
+    symrefs: dict[Ref, Ref]
     agent: bytes | None
 
     def __init__(
         self,
-        refs: dict[bytes, bytes | None],
-        symrefs: dict[bytes, bytes],
+        refs: dict[Ref, ObjectID | None],
+        symrefs: dict[Ref, Ref],
         agent: bytes | None,
-        new_shallow: set[bytes] | None = None,
-        new_unshallow: set[bytes] | None = None,
+        new_shallow: set[ObjectID] | None = None,
+        new_unshallow: set[ObjectID] | None = None,
     ) -> None:
         """Initialize FetchPackResult.
 
@@ -515,10 +521,10 @@ class LsRemoteResult(_DeprecatedDictProxy):
       symrefs: Dictionary with remote symrefs
     """
 
-    symrefs: dict[bytes, bytes]
+    symrefs: dict[Ref, Ref]
 
     def __init__(
-        self, refs: dict[bytes, bytes | None], symrefs: dict[bytes, bytes]
+        self, refs: dict[Ref, ObjectID | None], symrefs: dict[Ref, Ref]
     ) -> None:
         """Initialize LsRemoteResult.
 
@@ -565,7 +571,7 @@ class SendPackResult(_DeprecatedDictProxy):
 
     def __init__(
         self,
-        refs: dict[bytes, bytes | None],
+        refs: dict[Ref, ObjectID | None],
         agent: bytes | None = None,
         ref_status: dict[bytes, str | None] | None = None,
     ) -> None:
@@ -594,9 +600,11 @@ class SendPackResult(_DeprecatedDictProxy):
         return f"{self.__class__.__name__}({self.refs!r}, {self.agent!r})"
 
 
-def _read_shallow_updates(pkt_seq: Iterable[bytes]) -> tuple[set[bytes], set[bytes]]:
-    new_shallow = set()
-    new_unshallow = set()
+def _read_shallow_updates(
+    pkt_seq: Iterable[bytes],
+) -> tuple[set[ObjectID], set[ObjectID]]:
+    new_shallow: set[ObjectID] = set()
+    new_unshallow: set[ObjectID] = set()
     for pkt in pkt_seq:
         if pkt == b"shallow-info\n":  # Git-protocol v2
             continue
@@ -605,9 +613,9 @@ def _read_shallow_updates(pkt_seq: Iterable[bytes]) -> tuple[set[bytes], set[byt
         except ValueError:
             raise GitProtocolError(f"unknown command {pkt!r}")
         if cmd == COMMAND_SHALLOW:
-            new_shallow.add(sha.strip())
+            new_shallow.add(ObjectID(sha.strip()))
         elif cmd == COMMAND_UNSHALLOW:
-            new_unshallow.add(sha.strip())
+            new_unshallow.add(ObjectID(sha.strip()))
         else:
             raise GitProtocolError(f"unknown command {pkt!r}")
     return (new_shallow, new_unshallow)
@@ -617,11 +625,11 @@ class _v1ReceivePackHeader:
     def __init__(
         self,
         capabilities: Sequence[bytes],
-        old_refs: Mapping[bytes, bytes],
-        new_refs: Mapping[bytes, bytes],
+        old_refs: Mapping[Ref, ObjectID],
+        new_refs: Mapping[Ref, ObjectID],
     ) -> None:
-        self.want: set[bytes] = set()
-        self.have: set[bytes] = set()
+        self.want: set[ObjectID] = set()
+        self.have: set[ObjectID] = set()
         self._it = self._handle_receive_pack_head(capabilities, old_refs, new_refs)
         self.sent_capabilities = False
 
@@ -631,8 +639,8 @@ class _v1ReceivePackHeader:
     def _handle_receive_pack_head(
         self,
         capabilities: Sequence[bytes],
-        old_refs: Mapping[bytes, bytes],
-        new_refs: Mapping[bytes, bytes],
+        old_refs: Mapping[Ref, ObjectID],
+        new_refs: Mapping[Ref, ObjectID],
     ) -> Iterator[bytes | None]:
         """Handle the head of a 'git-receive-pack' request.
 
@@ -713,13 +721,13 @@ def _handle_upload_pack_head(
     proto: Protocol,
     capabilities: Iterable[bytes],
     graph_walker: GraphWalker,
-    wants: list[bytes],
+    wants: list[ObjectID],
     can_read: Callable[[], bool] | None,
     depth: int | None,
     protocol_version: int | None,
     shallow_since: str | None = None,
     shallow_exclude: list[str] | None = None,
-) -> tuple[set[bytes] | None, set[bytes] | None]:
+) -> tuple[set[ObjectID] | None, set[ObjectID] | None]:
     """Handle the head of a 'git-upload-pack' request.
 
     Args:
@@ -734,8 +742,8 @@ def _handle_upload_pack_head(
       shallow_since: Deepen the history to include commits after this date
       shallow_exclude: Deepen the history to exclude commits reachable from these refs
     """
-    new_shallow: set[bytes] | None
-    new_unshallow: set[bytes] | None
+    new_shallow: set[ObjectID] | None
+    new_unshallow: set[ObjectID] | None
     assert isinstance(wants, list) and isinstance(wants[0], bytes)
     wantcmd = COMMAND_WANT + b" " + wants[0]
     if protocol_version is None:
@@ -788,7 +796,7 @@ def _handle_upload_pack_head(
             assert pkt is not None
             parts = pkt.rstrip(b"\n").split(b" ")
             if parts[0] == b"ACK":
-                graph_walker.ack(parts[1])
+                graph_walker.ack(ObjectID(parts[1]))
                 if parts[2] in (b"continue", b"common"):
                     pass
                 elif parts[2] == b"ready":
@@ -809,7 +817,7 @@ def _handle_upload_pack_head(
             new_shallow = None
             new_unshallow = None
     else:
-        new_shallow = new_unshallow = set()
+        new_shallow = new_unshallow = set[ObjectID]()
 
     return (new_shallow, new_unshallow)
 
@@ -841,7 +849,7 @@ def _handle_upload_pack_tail(
             break
         else:
             if parts[0] == b"ACK":
-                graph_walker.ack(parts[1])
+                graph_walker.ack(ObjectID(parts[1]))
             if parts[0] == b"NAK":
                 graph_walker.nak()
             if len(parts) < 3 or parts[2] not in (
@@ -875,7 +883,7 @@ def _handle_upload_pack_tail(
 
 def _extract_symrefs_and_agent(
     capabilities: Iterable[bytes],
-) -> tuple[dict[bytes, bytes], bytes | None]:
+) -> tuple[dict[Ref, Ref], bytes | None]:
     """Extract symrefs and agent from capabilities.
 
     Args:
@@ -883,14 +891,14 @@ def _extract_symrefs_and_agent(
     Returns:
      (symrefs, agent) tuple
     """
-    symrefs = {}
+    symrefs: dict[Ref, Ref] = {}
     agent = None
     for capability in capabilities:
         k, v = parse_capability(capability)
         if k == CAPABILITY_SYMREF:
             assert v is not None
             (src, dst) = v.split(b":", 1)
-            symrefs[src] = dst
+            symrefs[Ref(src)] = Ref(dst)
         if k == CAPABILITY_AGENT:
             agent = v
     return (symrefs, agent)
@@ -979,7 +987,7 @@ class GitClient:
     def send_pack(
         self,
         path: bytes,
-        update_refs: Callable[[dict[bytes, bytes]], dict[bytes, bytes]],
+        update_refs: Callable[[dict[Ref, ObjectID]], dict[Ref, ObjectID]],
         generate_pack_data: "GeneratePackDataFunc",
         progress: Callable[[bytes], None] | None = None,
     ) -> SendPackResult:
@@ -1014,7 +1022,7 @@ class GitClient:
         branch: str | None = None,
         progress: Callable[[bytes], None] | None = None,
         depth: int | None = None,
-        ref_prefix: Sequence[Ref] | None = None,
+        ref_prefix: Sequence[bytes] | None = None,
         filter_spec: bytes | None = None,
         protocol_version: int | None = None,
     ) -> Repo:
@@ -1067,12 +1075,12 @@ class GitClient:
                     target.refs, origin, result.refs, message=ref_message
                 )
 
-            origin_head = result.symrefs.get(b"HEAD")
-            origin_sha = result.refs.get(b"HEAD")
+            origin_head = result.symrefs.get(HEADREF)
+            origin_sha = result.refs.get(HEADREF)
             if origin is None or (origin_sha and not origin_head):
                 # set detached HEAD
                 if origin_sha is not None:
-                    target.refs[b"HEAD"] = origin_sha
+                    target.refs[HEADREF] = origin_sha
                     head = origin_sha
                 else:
                     head = None
@@ -1111,7 +1119,7 @@ class GitClient:
         determine_wants: "DetermineWantsFunc | None" = None,
         progress: Callable[[bytes], None] | None = None,
         depth: int | None = None,
-        ref_prefix: Sequence[Ref] | None = None,
+        ref_prefix: Sequence[bytes] | None = None,
         filter_spec: bytes | None = None,
         protocol_version: int | None = None,
         shallow_since: str | None = None,
@@ -1195,7 +1203,7 @@ class GitClient:
         *,
         progress: Callable[[bytes], None] | None = None,
         depth: int | None = None,
-        ref_prefix: Sequence[Ref] | None = None,
+        ref_prefix: Sequence[bytes] | None = None,
         filter_spec: bytes | None = None,
         protocol_version: int | None = None,
         shallow_since: str | None = None,
@@ -1233,7 +1241,7 @@ class GitClient:
         self,
         path: bytes,
         protocol_version: int | None = None,
-        ref_prefix: Sequence[Ref] | None = None,
+        ref_prefix: Sequence[bytes] | None = None,
     ) -> LsRemoteResult:
         """Retrieve the current refs from a git smart server.
 
@@ -1248,7 +1256,7 @@ class GitClient:
         raise NotImplementedError(self.get_refs)
 
     @staticmethod
-    def _should_send_pack(new_refs: Mapping[bytes, bytes]) -> bool:
+    def _should_send_pack(new_refs: Mapping[Ref, ObjectID]) -> bool:
         # The packfile MUST NOT be sent if the only command used is delete.
         return any(sha != ZERO_SHA for sha in new_refs.values())
 
@@ -1308,7 +1316,7 @@ class GitClient:
 
     def _negotiate_upload_pack_capabilities(
         self, server_capabilities: set[bytes]
-    ) -> tuple[set[bytes], dict[bytes, bytes], bytes | None]:
+    ) -> tuple[set[bytes], dict[Ref, Ref], bytes | None]:
         (extract_capability_names(server_capabilities) - KNOWN_UPLOAD_CAPABILITIES)
         # TODO(jelmer): warn about unknown capabilities
         fetch_capa = None
@@ -1440,7 +1448,7 @@ class TraditionalGitClient(GitClient):
     def send_pack(
         self,
         path: bytes,
-        update_refs: Callable[[dict[bytes, bytes]], dict[bytes, bytes]],
+        update_refs: Callable[[dict[Ref, ObjectID]], dict[Ref, ObjectID]],
         generate_pack_data: "GeneratePackDataFunc",
         progress: Callable[[bytes], None] | None = None,
     ) -> SendPackResult:
@@ -1541,11 +1549,8 @@ class TraditionalGitClient(GitClient):
             ref_status = self._handle_receive_pack_tail(
                 proto, negotiated_capabilities, progress
             )
-            refs_with_optional_2: dict[bytes, bytes | None] = {
-                k: v for k, v in new_refs.items()
-            }
             return SendPackResult(
-                refs_with_optional_2, agent=agent, ref_status=ref_status
+                _to_optional_dict(new_refs), agent=agent, ref_status=ref_status
             )
 
     def fetch_pack(
@@ -1556,7 +1561,7 @@ class TraditionalGitClient(GitClient):
         pack_data: Callable[[bytes], int],
         progress: Callable[[bytes], None] | None = None,
         depth: int | None = None,
-        ref_prefix: Sequence[Ref] | None = None,
+        ref_prefix: Sequence[bytes] | None = None,
         filter_spec: bytes | None = None,
         protocol_version: int | None = None,
         shallow_since: str | None = None,
@@ -1606,7 +1611,9 @@ class TraditionalGitClient(GitClient):
         self.protocol_version = server_protocol_version
         with proto:
             # refs may have None values in v2 but not in v1
-            refs: dict[bytes, bytes | None]
+            refs: dict[Ref, ObjectID | None]
+            symrefs: dict[Ref, Ref]
+            agent: bytes | None
             if self.protocol_version == 2:
                 try:
                     server_capabilities = read_server_capabilities(proto.read_pkt_seq())
@@ -1710,7 +1717,7 @@ class TraditionalGitClient(GitClient):
         self,
         path: bytes,
         protocol_version: int | None = None,
-        ref_prefix: Sequence[Ref] | None = None,
+        ref_prefix: Sequence[bytes] | None = None,
     ) -> LsRemoteResult:
         """Retrieve the current refs from a git smart server."""
         # stock `git ls-remote` uses upload-pack
@@ -1748,7 +1755,7 @@ class TraditionalGitClient(GitClient):
                     raise _remote_error_from_stderr(stderr) from exc
                 proto.write_pkt_line(None)
                 for refname, refvalue in peeled.items():
-                    refs[refname + PEELED_TAG_SUFFIX] = refvalue
+                    refs[Ref(refname + PEELED_TAG_SUFFIX)] = refvalue
                 return LsRemoteResult(refs, symrefs)
         else:
             with proto:
@@ -2231,7 +2238,7 @@ class LocalGitClient(GitClient):
     def send_pack(
         self,
         path: str | bytes,
-        update_refs: Callable[[dict[bytes, bytes]], dict[bytes, bytes]],
+        update_refs: Callable[[dict[Ref, ObjectID]], dict[Ref, ObjectID]],
         generate_pack_data: "GeneratePackDataFunc",
         progress: Callable[[bytes], None] | None = None,
     ) -> SendPackResult:
@@ -2354,7 +2361,7 @@ class LocalGitClient(GitClient):
         pack_data: Callable[[bytes], int],
         progress: Callable[[bytes], None] | None = None,
         depth: int | None = None,
-        ref_prefix: Sequence[Ref] | None = None,
+        ref_prefix: Sequence[bytes] | None = None,
         filter_spec: bytes | None = None,
         protocol_version: int | None = None,
         shallow_since: str | None = None,
@@ -2415,21 +2422,23 @@ class LocalGitClient(GitClient):
         self,
         path: str | bytes,
         protocol_version: int | None = None,
-        ref_prefix: Sequence[Ref] | None = None,
+        ref_prefix: Sequence[bytes] | None = None,
     ) -> LsRemoteResult:
         """Retrieve the current refs from a local on-disk repository."""
         with self._open_repo(path) as target:
             refs_dict = target.get_refs()
             refs = _to_optional_dict(refs_dict)
             # Extract symrefs from the local repository
-            symrefs: dict[bytes, bytes] = {}
+            from dulwich.refs import Ref
+
+            symrefs: dict[Ref, Ref] = {}
             for ref in refs:
                 try:
                     # Check if this ref is symbolic by reading it directly
                     ref_value = target.refs.read_ref(ref)
                     if ref_value and ref_value.startswith(SYMREF):
                         # Extract the target from the symref
-                        symrefs[ref] = ref_value[len(SYMREF) :]
+                        symrefs[ref] = Ref(ref_value[len(SYMREF) :])
                 except (KeyError, ValueError):
                     # Not a symbolic ref or error reading it
                     pass
@@ -2546,9 +2555,9 @@ class BundleClient(GitClient):
             else:
                 raise AssertionError(f"unsupported bundle format header: {firstline!r}")
 
-            capabilities = {}
-            prerequisites = []
-            references = {}
+            capabilities: dict[str, str | None] = {}
+            prerequisites: list[tuple[ObjectID, bytes]] = []
+            references: dict[Ref, ObjectID] = {}
             line = f.readline()
 
             if version >= 3:
@@ -2565,12 +2574,12 @@ class BundleClient(GitClient):
 
             while line.startswith(b"-"):
                 (obj_id, comment) = line[1:].rstrip(b"\n").split(b" ", 1)
-                prerequisites.append((obj_id, comment))
+                prerequisites.append((ObjectID(obj_id), comment))
                 line = f.readline()
 
             while line != b"\n":
                 (obj_id, ref) = line.rstrip(b"\n").split(b" ", 1)
-                references[ref] = obj_id
+                references[Ref(ref)] = ObjectID(obj_id)
                 line = f.readline()
 
             # Don't read PackData here, we'll do it later
@@ -2619,7 +2628,7 @@ class BundleClient(GitClient):
     def send_pack(
         self,
         path: str | bytes,
-        update_refs: Callable[[dict[bytes, bytes]], dict[bytes, bytes]],
+        update_refs: Callable[[dict[Ref, ObjectID]], dict[Ref, ObjectID]],
         generate_pack_data: "GeneratePackDataFunc",
         progress: Callable[[bytes], None] | None = None,
     ) -> SendPackResult:
@@ -2633,7 +2642,7 @@ class BundleClient(GitClient):
         determine_wants: "DetermineWantsFunc | None" = None,
         progress: Callable[[bytes], None] | None = None,
         depth: int | None = None,
-        ref_prefix: Sequence[Ref] | None = None,
+        ref_prefix: Sequence[bytes] | None = None,
         filter_spec: bytes | None = None,
         protocol_version: int | None = None,
         shallow_since: str | None = None,
@@ -2687,7 +2696,7 @@ class BundleClient(GitClient):
         pack_data: Callable[[bytes], int],
         progress: Callable[[bytes], None] | None = None,
         depth: int | None = None,
-        ref_prefix: Sequence[Ref] | None = None,
+        ref_prefix: Sequence[bytes] | None = None,
         filter_spec: bytes | None = None,
         protocol_version: int | None = None,
         shallow_since: str | None = None,
@@ -2732,7 +2741,7 @@ class BundleClient(GitClient):
         self,
         path: str | bytes,
         protocol_version: int | None = None,
-        ref_prefix: Sequence[Ref] | None = None,
+        ref_prefix: Sequence[bytes] | None = None,
     ) -> LsRemoteResult:
         """Retrieve the current refs from a bundle file."""
         bundle = self._open_bundle(path)
@@ -3502,7 +3511,7 @@ class AbstractHttpGitClient(GitClient):
         service: bytes,
         base_url: str,
         protocol_version: int | None = None,
-        ref_prefix: Sequence[Ref] | None = None,
+        ref_prefix: Sequence[bytes] | None = None,
     ) -> tuple[
         dict[Ref, ObjectID | None],
         set[bytes],
@@ -3627,7 +3636,8 @@ class AbstractHttpGitClient(GitClient):
                         ) = read_pkt_refs_v1(proto.read_pkt_seq())
                         # Convert v1 refs to Optional type
                         refs = _to_optional_dict(refs_v1)
-                        (refs, peeled) = split_peeled_refs(refs)
+                        # TODO: split_peeled_refs should accept Optional values
+                        (refs, peeled) = split_peeled_refs(refs)  # type: ignore[arg-type,assignment]
                         (symrefs, _agent) = _extract_symrefs_and_agent(
                             server_capabilities
                         )
@@ -3646,12 +3656,13 @@ class AbstractHttpGitClient(GitClient):
                 from typing import cast
 
                 info_refs = read_info_refs(BytesIO(data))
-                (refs, peeled) = split_peeled_refs(
-                    cast(dict[bytes, bytes | None], info_refs)
-                )
+                (refs_nonopt, peeled) = split_peeled_refs(info_refs)
                 if ref_prefix is not None:
-                    refs = filter_ref_prefix(refs, ref_prefix)
-                return refs, set(), base_url, {}, peeled
+                    refs_nonopt = filter_ref_prefix(refs_nonopt, ref_prefix)
+                refs_result: dict[Ref, ObjectID | None] = cast(
+                    dict[Ref, ObjectID | None], refs_nonopt
+                )
+                return refs_result, set(), base_url, {}, peeled
         finally:
             resp.close()
 
@@ -3687,7 +3698,7 @@ class AbstractHttpGitClient(GitClient):
     def send_pack(
         self,
         path: str | bytes,
-        update_refs: Callable[[dict[bytes, bytes]], dict[bytes, bytes]],
+        update_refs: Callable[[dict[Ref, ObjectID]], dict[Ref, ObjectID]],
         generate_pack_data: "GeneratePackDataFunc",
         progress: Callable[[bytes], None] | None = None,
     ) -> SendPackResult:
@@ -3726,13 +3737,14 @@ class AbstractHttpGitClient(GitClient):
         assert all(v is not None for v in old_refs.values()), (
             "old_refs should not contain None values"
         )
-        old_refs_typed: dict[bytes, bytes] = old_refs  # type: ignore[assignment]
+        old_refs_typed: dict[Ref, ObjectID] = old_refs  # type: ignore[assignment]
         new_refs = update_refs(dict(old_refs_typed))
         if new_refs is None:
             # Determine wants function is aborting the push.
             # Convert to Optional type for SendPackResult
-            old_refs_optional: dict[bytes, bytes | None] = old_refs
-            return SendPackResult(old_refs_optional, agent=agent, ref_status={})
+            return SendPackResult(
+                _to_optional_dict(old_refs_typed), agent=agent, ref_status={}
+            )
         if set(new_refs.items()).issubset(set(old_refs_typed.items())):
             # Convert to Optional type for SendPackResult
             return SendPackResult(
@@ -3777,7 +3789,7 @@ class AbstractHttpGitClient(GitClient):
         pack_data: Callable[[bytes], int],
         progress: Callable[[bytes], None] | None = None,
         depth: int | None = None,
-        ref_prefix: Sequence[Ref] | None = None,
+        ref_prefix: Sequence[bytes] | None = None,
         filter_spec: bytes | None = None,
         protocol_version: int | None = None,
         shallow_since: str | None = None,
@@ -3850,7 +3862,7 @@ class AbstractHttpGitClient(GitClient):
                 )
             )
 
-            symrefs[b"HEAD"] = dumb_repo.get_head()
+            symrefs[HEADREF] = dumb_repo.get_head()
 
             # Write pack data
             if pack_data_list:
@@ -3923,7 +3935,7 @@ class AbstractHttpGitClient(GitClient):
         self,
         path: str | bytes,
         protocol_version: int | None = None,
-        ref_prefix: Sequence[Ref] | None = None,
+        ref_prefix: Sequence[bytes] | None = None,
     ) -> LsRemoteResult:
         """Retrieve the current refs from a git smart server."""
         url = self._get_url(path)
@@ -3934,7 +3946,7 @@ class AbstractHttpGitClient(GitClient):
             ref_prefix=ref_prefix,
         )
         for refname, refvalue in peeled.items():
-            refs[refname + PEELED_TAG_SUFFIX] = refvalue
+            refs[Ref(refname + PEELED_TAG_SUFFIX)] = refvalue
         return LsRemoteResult(refs, symrefs)
 
     def get_url(self, path: str) -> str:

+ 27 - 32
dulwich/commit_graph.py

@@ -26,7 +26,7 @@ from .file import _GitFile
 if TYPE_CHECKING:
     from .object_store import BaseObjectStore
 
-from .objects import Commit, ObjectID, hex_to_sha, sha_to_hex
+from .objects import Commit, ObjectID, RawObjectID, hex_to_sha, sha_to_hex
 
 # File format constants
 COMMIT_GRAPH_SIGNATURE = b"CGPH"
@@ -188,9 +188,9 @@ class CommitGraph:
         for i in range(num_commits):
             start = i * self._hash_size
             end = start + self._hash_size
-            oid = oid_lookup_data[start:end]
+            oid = RawObjectID(oid_lookup_data[start:end])
             oids.append(oid)
-            self._oid_to_index[oid] = i
+            self._oid_to_index[sha_to_hex(oid)] = i
 
         # Parse commit data chunk
         commit_data = self.chunks[CHUNK_COMMIT_DATA].data
@@ -205,7 +205,7 @@ class CommitGraph:
             offset = i * (self._hash_size + 16)
 
             # Tree OID
-            tree_id = commit_data[offset : offset + self._hash_size]
+            tree_id = RawObjectID(commit_data[offset : offset + self._hash_size])
             offset += self._hash_size
 
             # Parent positions (2 x 4 bytes)
@@ -271,14 +271,7 @@ class CommitGraph:
 
     def get_entry_by_oid(self, oid: ObjectID) -> CommitGraphEntry | None:
         """Get commit graph entry by commit OID."""
-        # Convert hex ObjectID to binary if needed for lookup
-        if isinstance(oid, bytes) and len(oid) == 40:
-            # Input is hex ObjectID, convert to binary for internal lookup
-            lookup_oid = hex_to_sha(oid)
-        else:
-            # Input is already binary
-            lookup_oid = oid
-        index = self._oid_to_index.get(lookup_oid)
+        index = self._oid_to_index.get(oid)
         if index is not None:
             return self.entries[index]
         return None
@@ -288,7 +281,7 @@ class CommitGraph:
         entry = self.get_entry_by_oid(oid)
         return entry.generation if entry else None
 
-    def get_parents(self, oid: ObjectID) -> list[bytes] | None:
+    def get_parents(self, oid: ObjectID) -> list[ObjectID] | None:
         """Get parent commit IDs for a commit."""
         entry = self.get_entry_by_oid(oid)
         return entry.parents if entry else None
@@ -443,20 +436,20 @@ def generate_commit_graph(
 
     # Ensure all commit_ids are in the correct format for object store access
     # DiskObjectStore expects hex ObjectIDs (40-byte hex strings)
-    normalized_commit_ids = []
+    normalized_commit_ids: list[ObjectID] = []
     for commit_id in commit_ids:
         if isinstance(commit_id, bytes) and len(commit_id) == 40:
             # Already hex ObjectID
-            normalized_commit_ids.append(commit_id)
+            normalized_commit_ids.append(ObjectID(commit_id))
         elif isinstance(commit_id, bytes) and len(commit_id) == 20:
             # Binary SHA, convert to hex ObjectID
-            normalized_commit_ids.append(sha_to_hex(commit_id))
+            normalized_commit_ids.append(sha_to_hex(RawObjectID(commit_id)))
         else:
             # Assume it's already correct format
-            normalized_commit_ids.append(commit_id)
+            normalized_commit_ids.append(ObjectID(commit_id))
 
     # Build a map of all commits and their metadata
-    commit_map: dict[bytes, Commit] = {}
+    commit_map: dict[ObjectID, Commit] = {}
     for commit_id in normalized_commit_ids:
         try:
             commit_obj = object_store[commit_id]
@@ -503,19 +496,20 @@ def generate_commit_graph(
     # Build commit graph entries
     for commit_id, commit_obj in commit_map.items():
         # commit_id is already hex ObjectID from normalized_commit_ids
-        commit_hex = commit_id
+        commit_hex: ObjectID = commit_id
 
         # Handle tree ID - might already be hex ObjectID
+        tree_hex: ObjectID
         if isinstance(commit_obj.tree, bytes) and len(commit_obj.tree) == 40:
-            tree_hex = commit_obj.tree  # Already hex ObjectID
+            tree_hex = ObjectID(commit_obj.tree)  # Already hex ObjectID
         else:
             tree_hex = sha_to_hex(commit_obj.tree)  # Binary, convert to hex
 
         # Handle parent IDs - might already be hex ObjectIDs
-        parents_hex = []
+        parents_hex: list[ObjectID] = []
         for parent_id in commit_obj.parents:
             if isinstance(parent_id, bytes) and len(parent_id) == 40:
-                parents_hex.append(parent_id)  # Already hex ObjectID
+                parents_hex.append(ObjectID(parent_id))  # Already hex ObjectID
             else:
                 parents_hex.append(sha_to_hex(parent_id))  # Binary, convert to hex
 
@@ -531,8 +525,7 @@ def generate_commit_graph(
     # Build the OID to index mapping for lookups
     graph._oid_to_index = {}
     for i, entry in enumerate(graph.entries):
-        binary_oid = hex_to_sha(entry.commit_id.decode())
-        graph._oid_to_index[binary_oid] = i
+        graph._oid_to_index[entry.commit_id] = i
 
     return graph
 
@@ -582,25 +575,27 @@ def get_reachable_commits(
     Returns:
         List of all reachable commit IDs (including the starting commits)
     """
-    visited = set()
-    reachable = []
-    stack = []
+    visited: set[ObjectID] = set()
+    reachable: list[ObjectID] = []
+    stack: list[ObjectID] = []
 
     # Normalize commit IDs for object store access and tracking
     for commit_id in start_commits:
         if isinstance(commit_id, bytes) and len(commit_id) == 40:
             # Hex ObjectID - use directly for object store access
-            if commit_id not in visited:
-                stack.append(commit_id)
+            oid = ObjectID(commit_id)
+            if oid not in visited:
+                stack.append(oid)
         elif isinstance(commit_id, bytes) and len(commit_id) == 20:
             # Binary SHA, convert to hex ObjectID for object store access
-            hex_id = sha_to_hex(commit_id)
+            hex_id = sha_to_hex(RawObjectID(commit_id))
             if hex_id not in visited:
                 stack.append(hex_id)
         else:
             # Assume it's already correct format
-            if commit_id not in visited:
-                stack.append(commit_id)
+            oid = ObjectID(commit_id)
+            if oid not in visited:
+                stack.append(oid)
 
     while stack:
         commit_id = stack.pop()
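
The commit-graph changes hinge on keeping 40-byte hex ObjectIDs separate from 20-byte binary RawObjectIDs. A condensed sketch of the normalisation pattern used in generate_commit_graph and get_reachable_commits above (the real code passes other inputs through unchanged; the error branch here is illustrative):

    from dulwich.objects import ObjectID, RawObjectID, hex_to_sha, sha_to_hex

    def normalize(commit_id: bytes) -> ObjectID:
        if len(commit_id) == 40:
            # Already a hex object ID; just tag the type for mypy.
            return ObjectID(commit_id)
        if len(commit_id) == 20:
            # Binary SHA: convert to the hex form the object store expects.
            return sha_to_hex(RawObjectID(commit_id))
        raise ValueError(f"unexpected SHA length: {len(commit_id)}")

    hex_id = ObjectID(b"ab" * 20)
    assert normalize(hex_to_sha(hex_id)) == hex_id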

+ 24 - 17
dulwich/contrib/swift.py

@@ -47,7 +47,7 @@ from ..file import _GitFile
 from ..greenthreads import GreenThreadsMissingObjectFinder
 from ..lru_cache import LRUSizeCache
 from ..object_store import INFODIR, PACKDIR, PackBasedObjectStore
-from ..objects import S_ISGITLINK, Blob, Commit, Tag, Tree
+from ..objects import S_ISGITLINK, Blob, Commit, ObjectID, Tag, Tree
 from ..pack import (
     ObjectContainer,
     Pack,
@@ -66,7 +66,14 @@ from ..pack import (
     write_pack_object,
 )
 from ..protocol import TCP_GIT_PORT
-from ..refs import InfoRefsContainer, read_info_refs, split_peeled_refs, write_info_refs
+from ..refs import (
+    HEADREF,
+    InfoRefsContainer,
+    Ref,
+    read_info_refs,
+    split_peeled_refs,
+    write_info_refs,
+)
 from ..repo import OBJECTDIR, BaseRepo
 from ..server import Backend, BackendRepo, TCPGitServer
 
@@ -763,7 +770,7 @@ class SwiftObjectStore(PackBasedObjectStore):
         """Loose objects are not supported by this repository."""
         return iter([])
 
-    def pack_info_get(self, sha: bytes) -> tuple[Any, ...] | None:
+    def pack_info_get(self, sha: ObjectID) -> tuple[Any, ...] | None:
         """Get pack info for a specific SHA.
 
         Args:
@@ -786,7 +793,7 @@ class SwiftObjectStore(PackBasedObjectStore):
         if common is None:
             common = set()
 
-        def _find_parents(commit: bytes) -> list[Any]:
+        def _find_parents(commit: ObjectID) -> list[Any]:
             for pack in self.packs:
                 if commit in pack:
                     try:
@@ -983,8 +990,8 @@ class SwiftInfoRefsContainer(InfoRefsContainer):
         super().__init__(f)
 
     def _load_check_ref(
-        self, name: bytes, old_ref: bytes | None
-    ) -> dict[bytes, bytes] | bool:
+        self, name: Ref, old_ref: ObjectID | None
+    ) -> dict[Ref, ObjectID] | bool:
         self._check_refname(name)
         obj = self.scon.get_object(self.filename)
         if not obj:
@@ -1000,23 +1007,23 @@ class SwiftInfoRefsContainer(InfoRefsContainer):
                 return False
         return refs
 
-    def _write_refs(self, refs: Mapping[bytes, bytes]) -> None:
+    def _write_refs(self, refs: Mapping[Ref, ObjectID]) -> None:
         f = BytesIO()
         f.writelines(write_info_refs(refs, cast("ObjectContainer", self.store)))
         self.scon.put_object(self.filename, f)
 
     def set_if_equals(
         self,
-        name: bytes,
-        old_ref: bytes | None,
-        new_ref: bytes,
+        name: Ref,
+        old_ref: ObjectID | None,
+        new_ref: ObjectID,
         committer: bytes | None = None,
         timestamp: float | None = None,
         timezone: int | None = None,
         message: bytes | None = None,
     ) -> bool:
         """Set a refname to new_ref only if it currently equals old_ref."""
-        if name == b"HEAD":
+        if name == HEADREF:
             return True
         refs = self._load_check_ref(name, old_ref)
         if not isinstance(refs, dict):
@@ -1028,15 +1035,15 @@ class SwiftInfoRefsContainer(InfoRefsContainer):
 
     def remove_if_equals(
         self,
-        name: bytes,
-        old_ref: bytes | None,
+        name: Ref,
+        old_ref: ObjectID | None,
         committer: object = None,
         timestamp: object = None,
         timezone: object = None,
         message: object = None,
     ) -> bool:
         """Remove a refname only if it currently equals old_ref."""
-        if name == b"HEAD":
+        if name == HEADREF:
             return True
         refs = self._load_check_ref(name, old_ref)
         if not isinstance(refs, dict):
@@ -1046,14 +1053,14 @@ class SwiftInfoRefsContainer(InfoRefsContainer):
         del self._refs[name]
         return True
 
-    def allkeys(self) -> set[bytes]:
+    def allkeys(self) -> set[Ref]:
         """Get all reference names.
 
         Returns:
-          Set of reference names as bytes
+          Set of reference names as Ref
         """
         try:
-            self._refs[b"HEAD"] = self._refs[b"refs/heads/master"]
+            self._refs[HEADREF] = self._refs[Ref(b"refs/heads/master")]
         except KeyError:
             pass
         return set(self._refs.keys())
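
The Swift ref container now compares against HEADREF rather than the b"HEAD" literal. Since Ref is a newtype over bytes, equality with raw bytes is unaffected; a tiny sketch, assuming HEADREF is defined as Ref(b"HEAD") as the substitutions in this diff imply:

    from dulwich.refs import HEADREF, Ref

    assert HEADREF == b"HEAD"                      # newtypes are erased at runtime
    assert Ref(b"refs/heads/master") == b"refs/heads/master"
    # Comparison behaviour is identical; the gain is that functions
    # annotated to take a Ref reject untyped bytes under mypy.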

+ 6 - 4
dulwich/diff.py

@@ -54,7 +54,7 @@ from typing import BinaryIO
 from ._typing import Buffer
 from .index import ConflictedIndexEntry, commit_index
 from .object_store import iter_tree_contents
-from .objects import S_ISGITLINK, Blob, Commit
+from .objects import S_ISGITLINK, Blob, Commit, ObjectID
 from .patch import write_blob_diff, write_object_diff
 from .repo import Repo
 
@@ -79,7 +79,7 @@ def should_include_path(path: bytes, paths: Sequence[bytes] | None) -> bool:
 def diff_index_to_tree(
     repo: Repo,
     outstream: BinaryIO,
-    commit_sha: bytes | None = None,
+    commit_sha: ObjectID | None = None,
     paths: Sequence[bytes] | None = None,
     diff_algorithm: str | None = None,
 ) -> None:
@@ -94,7 +94,9 @@ def diff_index_to_tree(
     """
     if commit_sha is None:
         try:
-            commit_sha = repo.refs[b"HEAD"]
+            from dulwich.refs import HEADREF
+
+            commit_sha = repo.refs[HEADREF]
             old_commit = repo[commit_sha]
             assert isinstance(old_commit, Commit)
             old_tree = old_commit.tree
@@ -124,7 +126,7 @@ def diff_index_to_tree(
 def diff_working_tree_to_tree(
     repo: Repo,
     outstream: BinaryIO,
-    commit_sha: bytes,
+    commit_sha: ObjectID,
     paths: Sequence[bytes] | None = None,
     diff_algorithm: str | None = None,
 ) -> None:
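
diff_index_to_tree now resolves HEAD through the HEADREF constant when no commit_sha is given. A usage sketch, assuming it is run inside an existing repository working directory:

    import sys

    from dulwich.diff import diff_index_to_tree
    from dulwich.repo import Repo

    repo = Repo(".")                       # any local repository
    # With commit_sha omitted, the index is diffed against HEAD's tree.
    diff_index_to_tree(repo, sys.stdout.buffer)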

+ 19 - 13
dulwich/dumb.py

@@ -36,6 +36,7 @@ from .objects import (
     Blob,
     Commit,
     ObjectID,
+    RawObjectID,
     ShaFile,
     Tag,
     Tree,
@@ -201,7 +202,7 @@ class DumbHTTPObjectStore(BaseObjectStore):
                 return idx
         raise KeyError(f"Pack not found: {pack_name}")
 
-    def _fetch_from_pack(self, sha: bytes) -> tuple[int, bytes]:
+    def _fetch_from_pack(self, sha: RawObjectID | ObjectID) -> tuple[int, bytes]:
         """Try to fetch an object from pack files.
 
         Args:
@@ -215,7 +216,10 @@ class DumbHTTPObjectStore(BaseObjectStore):
         """
         self._load_packs()
         # Convert hex to binary for pack operations
-        binsha = hex_to_sha(sha)
+        if len(sha) == 20:
+            binsha = RawObjectID(sha)  # Already binary
+        else:
+            binsha = hex_to_sha(ObjectID(sha))  # Convert hex to binary
 
         for pack_name, pack_idx in self._packs or []:
             if pack_idx is None:
@@ -251,7 +255,7 @@ class DumbHTTPObjectStore(BaseObjectStore):
 
         raise KeyError(sha)
 
-    def get_raw(self, sha: bytes) -> tuple[int, bytes]:
+    def get_raw(self, sha: RawObjectID | ObjectID) -> tuple[int, bytes]:
         """Obtain the raw text for an object.
 
         Args:
@@ -276,7 +280,7 @@ class DumbHTTPObjectStore(BaseObjectStore):
         self._cached_objects[sha] = result
         return result
 
-    def contains_loose(self, sha: bytes) -> bool:
+    def contains_loose(self, sha: RawObjectID | ObjectID) -> bool:
         """Check if a particular object is present by SHA1 and is loose."""
         try:
             self._fetch_loose_object(sha)
@@ -284,7 +288,7 @@ class DumbHTTPObjectStore(BaseObjectStore):
         except KeyError:
             return False
 
-    def __contains__(self, sha: bytes) -> bool:
+    def __contains__(self, sha: RawObjectID | ObjectID) -> bool:
         """Check if a particular object is present by SHA1."""
         if sha in self._cached_objects:
             return True
@@ -303,7 +307,7 @@ class DumbHTTPObjectStore(BaseObjectStore):
         except KeyError:
             return False
 
-    def __iter__(self) -> Iterator[bytes]:
+    def __iter__(self) -> Iterator[ObjectID]:
         """Iterate over all SHAs in the store.
 
         Note: This is inefficient for dumb HTTP as it requires
@@ -322,7 +326,7 @@ class DumbHTTPObjectStore(BaseObjectStore):
             for sha in idx:
                 if sha not in seen:
                     seen.add(sha)
-                    yield sha_to_hex(sha)
+                    yield sha_to_hex(RawObjectID(sha))
 
     @property
     def packs(self) -> list[Any]:
@@ -405,7 +409,10 @@ class DumbRemoteHTTPRepo:
 
             refs_hex = read_info_refs(BytesIO(refs_data))
             # Keep SHAs as hex
-            self._refs, self._peeled = split_peeled_refs(refs_hex)
+            refs_raw, peeled_raw = split_peeled_refs(refs_hex)
+            # Convert to typed dicts
+            self._refs = {Ref(k): ObjectID(v) for k, v in refs_raw.items()}
+            self._peeled = peeled_raw
 
         return dict(self._refs)
 
@@ -417,13 +424,12 @@ class DumbRemoteHTTPRepo:
         """
         head_resp_bytes = self._fetch_url("HEAD")
         head_split = head_resp_bytes.replace(b"\n", b"").split(b" ")
-        head_target = head_split[1] if len(head_split) > 1 else head_split[0]
+        head_target_bytes = head_split[1] if len(head_split) > 1 else head_split[0]
         # handle HEAD legacy format containing a commit id instead of a ref name
         for ref_name, ret_target in self.get_refs().items():
-            if ret_target == head_target:
-                head_target = ref_name
-                break
-        return head_target
+            if ret_target == head_target_bytes:
+                return ref_name
+        return Ref(head_target_bytes)
 
     def get_peeled(self, ref: Ref) -> ObjectID:
         """Get the peeled value of a ref."""

+ 12 - 7
dulwich/fastexport.py

@@ -66,7 +66,7 @@ class GitFastExporter:
         """
         self.outf = outf
         self.store = store
-        self.markers: dict[bytes, bytes] = {}
+        self.markers: dict[bytes, ObjectID] = {}
         self._marker_idx = 0
 
     def print_cmd(self, cmd: object) -> None:
@@ -117,7 +117,7 @@ class GitFastExporter:
         return marker
 
     def _iter_files(
-        self, base_tree: bytes | None, new_tree: bytes | None
+        self, base_tree: ObjectID | None, new_tree: ObjectID | None
     ) -> Generator[Any, None, None]:
         for (
             (old_path, new_path),
@@ -216,7 +216,7 @@ class GitImportProcessor(processor.ImportProcessor):  # type: ignore[misc,unused
         processor.ImportProcessor.__init__(self, params, verbose)  # type: ignore[no-untyped-call,unused-ignore]
         self.repo = repo
         self.last_commit = ZERO_SHA
-        self.markers: dict[bytes, bytes] = {}
+        self.markers: dict[bytes, ObjectID] = {}
         self._contents: dict[bytes, tuple[int, bytes]] = {}
 
     def lookup_object(self, objectish: bytes) -> ObjectID:
@@ -230,9 +230,9 @@ class GitImportProcessor(processor.ImportProcessor):  # type: ignore[misc,unused
         """
         if objectish.startswith(b":"):
             return self.markers[objectish[1:]]
-        return objectish
+        return ObjectID(objectish)
 
-    def import_stream(self, stream: BinaryIO) -> dict[bytes, bytes]:
+    def import_stream(self, stream: BinaryIO) -> dict[bytes, ObjectID]:
         """Import from a fast-import stream.
 
         Args:
@@ -314,9 +314,14 @@ class GitImportProcessor(processor.ImportProcessor):  # type: ignore[misc,unused
                 self._contents = {}
             else:
                 raise Exception(f"Command {filecmd.name!r} not supported")
+        from dulwich.objects import ObjectID
+
         commit.tree = commit_tree(
             self.repo.object_store,
-            ((path, hexsha, mode) for (path, (mode, hexsha)) in self._contents.items()),
+            (
+                (path, ObjectID(hexsha), mode)
+                for (path, (mode, hexsha)) in self._contents.items()
+            ),
         )
         if self.last_commit != ZERO_SHA:
             commit.parents.append(self.last_commit)
@@ -363,7 +368,7 @@ class GitImportProcessor(processor.ImportProcessor):  # type: ignore[misc,unused
         else:
             from_ = self.lookup_object(cmd.from_)
         self._reset_base(from_)
-        self.repo.refs[cmd.ref] = from_
+        self.repo.refs[Ref(cmd.ref)] = from_
 
     def tag_handler(self, cmd: commands.TagCommand) -> None:
         """Process a TagCommand."""

+ 26 - 24
dulwich/filter_branch.py

@@ -29,8 +29,8 @@ from typing import TypedDict
 
 from .index import Index, build_index_from_tree
 from .object_store import BaseObjectStore
-from .objects import Commit, Tag, Tree
-from .refs import RefsContainer, local_tag_name
+from .objects import Commit, ObjectID, Tag, Tree
+from .refs import Ref, RefsContainer, local_tag_name
 
 
 class CommitData(TypedDict, total=False):
@@ -57,10 +57,10 @@ class CommitFilter:
         filter_author: Callable[[bytes], bytes | None] | None = None,
         filter_committer: Callable[[bytes], bytes | None] | None = None,
         filter_message: Callable[[bytes], bytes | None] | None = None,
-        tree_filter: Callable[[bytes, str], bytes | None] | None = None,
-        index_filter: Callable[[bytes, str], bytes | None] | None = None,
-        parent_filter: Callable[[Sequence[bytes]], list[bytes]] | None = None,
-        commit_filter: Callable[[Commit, bytes], bytes | None] | None = None,
+        tree_filter: Callable[[ObjectID, str], ObjectID | None] | None = None,
+        index_filter: Callable[[ObjectID, str], ObjectID | None] | None = None,
+        parent_filter: Callable[[Sequence[ObjectID]], list[ObjectID]] | None = None,
+        commit_filter: Callable[[Commit, ObjectID], ObjectID | None] | None = None,
         subdirectory_filter: bytes | None = None,
         prune_empty: bool = False,
         tag_name_filter: Callable[[bytes], bytes | None] | None = None,
@@ -101,13 +101,13 @@ class CommitFilter:
         self.subdirectory_filter = subdirectory_filter
         self.prune_empty = prune_empty
         self.tag_name_filter = tag_name_filter
-        self._old_to_new: dict[bytes, bytes] = {}
-        self._processed: set[bytes] = set()
-        self._tree_cache: dict[bytes, bytes] = {}  # Cache for filtered trees
+        self._old_to_new: dict[ObjectID, ObjectID] = {}
+        self._processed: set[ObjectID] = set()
+        self._tree_cache: dict[ObjectID, ObjectID] = {}  # Cache for filtered trees
 
     def _filter_tree_with_subdirectory(
-        self, tree_sha: bytes, subdirectory: bytes
-    ) -> bytes | None:
+        self, tree_sha: ObjectID, subdirectory: bytes
+    ) -> ObjectID | None:
         """Extract a subdirectory from a tree as the new root.
 
         Args:
@@ -153,7 +153,7 @@ class CommitFilter:
         # Return the subdirectory tree
         return current_tree.id
 
-    def _apply_tree_filter(self, tree_sha: bytes) -> bytes:
+    def _apply_tree_filter(self, tree_sha: ObjectID) -> ObjectID:
         """Apply tree filter by checking out tree and running filter.
 
         Args:
@@ -181,7 +181,7 @@ class CommitFilter:
             self._tree_cache[tree_sha] = new_tree_sha
             return new_tree_sha
 
-    def _apply_index_filter(self, tree_sha: bytes) -> bytes:
+    def _apply_index_filter(self, tree_sha: ObjectID) -> ObjectID:
         """Apply index filter by creating temp index and running filter.
 
         Args:
@@ -217,7 +217,7 @@ class CommitFilter:
         finally:
             os.unlink(tmp_index_path)
 
-    def process_commit(self, commit_sha: bytes) -> bytes | None:
+    def process_commit(self, commit_sha: ObjectID) -> ObjectID | None:
         """Process a single commit, creating a filtered version.
 
         Args:
@@ -366,7 +366,7 @@ class CommitFilter:
             self._old_to_new[commit_sha] = commit_sha
             return commit_sha
 
-    def get_mapping(self) -> dict[bytes, bytes]:
+    def get_mapping(self) -> dict[ObjectID, ObjectID]:
         """Get the mapping of old commit SHAs to new commit SHAs.
 
         Returns:
@@ -383,8 +383,8 @@ def filter_refs(
     *,
     keep_original: bool = True,
     force: bool = False,
-    tag_callback: Callable[[bytes, bytes], None] | None = None,
-) -> dict[bytes, bytes]:
+    tag_callback: Callable[[Ref, Ref], None] | None = None,
+) -> dict[ObjectID, ObjectID]:
     """Filter commits reachable from the given refs.
 
     Args:
@@ -405,7 +405,7 @@ def filter_refs(
     # Check if already filtered
     if keep_original and not force:
         for ref in ref_names:
-            original_ref = b"refs/original/" + ref
+            original_ref = Ref(b"refs/original/" + ref)
             if original_ref in refs:
                 raise ValueError(
                     f"Branch {ref.decode()} appears to have been filtered already. "
@@ -416,8 +416,9 @@ def filter_refs(
     for ref in ref_names:
         try:
             # Get the commit SHA for this ref
-            if ref in refs:
-                ref_sha = refs[ref]
+            ref_obj = Ref(ref)
+            if ref_obj in refs:
+                ref_sha = refs[ref_obj]
                 if ref_sha:
                     commit_filter.process_commit(ref_sha)
         except KeyError:
@@ -429,18 +430,19 @@ def filter_refs(
     mapping = commit_filter.get_mapping()
     for ref in ref_names:
         try:
-            if ref in refs:
-                old_sha = refs[ref]
+            ref_obj = Ref(ref)
+            if ref_obj in refs:
+                old_sha = refs[ref_obj]
                 new_sha = mapping.get(old_sha, old_sha)
 
                 if old_sha != new_sha:
                     # Save original ref if requested
                     if keep_original:
-                        original_ref = b"refs/original/" + ref
+                        original_ref = Ref(b"refs/original/" + ref)
                         refs[original_ref] = old_sha
 
                     # Update ref to new commit
-                    refs[ref] = new_sha
+                    refs[ref_obj] = new_sha
         except KeyError:
             # Not a valid ref, skip updating
             warnings.warn(f"Could not update ref {ref!r}: ref not found")

+ 7 - 7
dulwich/gc.py

@@ -28,7 +28,7 @@ DEFAULT_GC_AUTO_PACK_LIMIT = 50
 class GCStats:
     """Statistics from garbage collection."""
 
-    pruned_objects: set[bytes] = field(default_factory=set)
+    pruned_objects: set[ObjectID] = field(default_factory=set)
     bytes_freed: int = 0
     packs_before: int = 0
     packs_after: int = 0
@@ -41,7 +41,7 @@ def find_reachable_objects(
     refs_container: RefsContainer,
     include_reflogs: bool = True,
     progress: Callable[[str], None] | None = None,
-) -> set[bytes]:
+) -> set[ObjectID]:
     """Find all reachable objects in the repository.
 
     Args:
@@ -53,7 +53,7 @@ def find_reachable_objects(
     Returns:
         Set of reachable object SHAs
     """
-    reachable = set()
+    reachable: set[ObjectID] = set()
     pending: deque[ObjectID] = deque()
 
     # Start with all refs
@@ -115,7 +115,7 @@ def find_unreachable_objects(
     refs_container: RefsContainer,
     include_reflogs: bool = True,
     progress: Callable[[str], None] | None = None,
-) -> set[bytes]:
+) -> set[ObjectID]:
     """Find all unreachable objects in the repository.
 
     Args:
@@ -131,7 +131,7 @@ def find_unreachable_objects(
         object_store, refs_container, include_reflogs, progress
     )
 
-    unreachable = set()
+    unreachable: set[ObjectID] = set()
     for sha in object_store:
         if sha not in reachable:
             unreachable.add(sha)
@@ -145,7 +145,7 @@ def prune_unreachable_objects(
     grace_period: int | None = None,
     dry_run: bool = False,
     progress: Callable[[str], None] | None = None,
-) -> tuple[set[bytes], int]:
+) -> tuple[set[ObjectID], int]:
     """Remove unreachable objects from the repository.
 
     Args:
@@ -162,7 +162,7 @@ def prune_unreachable_objects(
         object_store, refs_container, progress=progress
     )
 
-    pruned = set()
+    pruned: set[ObjectID] = set()
     bytes_freed = 0
 
     for sha in unreachable:
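
The gc helpers now return set[ObjectID]. A usage sketch against a local repository (run inside a working checkout; nothing is deleted here):

    from dulwich.gc import find_unreachable_objects
    from dulwich.repo import Repo

    repo = Repo(".")
    unreachable = find_unreachable_objects(repo.object_store, repo.refs)
    print(f"{len(unreachable)} unreachable objects")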

+ 3 - 3
dulwich/graph.py

@@ -96,7 +96,7 @@ def _find_lcas(
         List of lowest common ancestor commit IDs
     """
     cands = []
-    cstates = {}
+    cstates: dict[ObjectID, int] = {}
 
     # Flags to Record State
     _ANC_OF_1 = 1  # ancestor of commit 1
@@ -124,7 +124,7 @@ def _find_lcas(
 
     # initialize the working list states with ancestry info
     # note possibility of c1 being one of c2s should be handled
-    wlst: WorkList[bytes] = WorkList()
+    wlst: WorkList[ObjectID] = WorkList()
     cstates[c1] = _ANC_OF_1
     try:
         wlst.add((lookup_stamp(c1), c1))
@@ -298,7 +298,7 @@ def find_octopus_base(
     return lcas
 
 
-def can_fast_forward(repo: "BaseRepo", c1: bytes, c2: bytes) -> bool:
+def can_fast_forward(repo: "BaseRepo", c1: ObjectID, c2: ObjectID) -> bool:
     """Is it possible to fast-forward from c1 to c2?
 
     Args:

+ 1 - 1
dulwich/greenthreads.py

@@ -137,4 +137,4 @@ class GreenThreadsMissingObjectFinder(MissingObjectFinder):
             self.progress: Callable[[bytes], None] = lambda x: None
         else:
             self.progress = progress
-        self._tagged = (get_tagged and get_tagged()) or {}
+        self._tagged: dict[ObjectID, ObjectID] = (get_tagged and get_tagged()) or {}

+ 22 - 18
dulwich/index.py

@@ -69,7 +69,7 @@ from .objects import (
 from .pack import ObjectContainer, SHA1Reader, SHA1Writer
 
 # Type alias for recursive tree structure used in commit_tree
-TreeDict = dict[bytes, "TreeDict | tuple[int, bytes]"]
+TreeDict = dict[bytes, "TreeDict | tuple[int, ObjectID]"]
 
 # 2-bit stage (during merge)
 FLAG_STAGEMASK = 0x3000
@@ -294,7 +294,7 @@ class SerializedIndexEntry:
     uid: int
     gid: int
     size: int
-    sha: bytes
+    sha: ObjectID
     flags: int
     extended_flags: int
 
@@ -505,7 +505,7 @@ class IndexEntry:
     uid: int
     gid: int
     size: int
-    sha: bytes
+    sha: ObjectID
     flags: int = 0
     extended_flags: int = 0
 
@@ -1168,7 +1168,7 @@ class Index:
         """Check if a path exists in the index."""
         return key in self._byname
 
-    def get_sha1(self, path: bytes) -> bytes:
+    def get_sha1(self, path: bytes) -> ObjectID:
         """Return the (git object) SHA1 for the object at a path."""
         value = self[path]
         if isinstance(value, ConflictedIndexEntry):
@@ -1182,7 +1182,7 @@ class Index:
             raise UnmergedEntries
         return value.mode
 
-    def iterobjects(self) -> Iterable[tuple[bytes, bytes, int]]:
+    def iterobjects(self) -> Iterable[tuple[bytes, ObjectID, int]]:
         """Iterate over path, sha, mode tuples for use with commit_tree."""
         for path in self:
             entry = self[path]
@@ -1291,7 +1291,7 @@ class Index:
             want_unchanged=want_unchanged,
         )
 
-    def commit(self, object_store: ObjectContainer) -> bytes:
+    def commit(self, object_store: ObjectContainer) -> ObjectID:
         """Create a new tree from an index.
 
         Args:
@@ -1398,7 +1398,7 @@ class Index:
     def convert_to_sparse(
         self,
         object_store: "BaseObjectStore",
-        tree_sha: bytes,
+        tree_sha: ObjectID,
         sparse_dirs: Set[bytes],
     ) -> None:
         """Convert full index entries to sparse directory entries.
@@ -1443,6 +1443,8 @@ class Index:
 
             # Create a sparse directory entry
             # Use minimal metadata since it's not a real file
+            from dulwich.objects import ObjectID
+
             sparse_entry = IndexEntry(
                 ctime=0,
                 mtime=0,
@@ -1452,7 +1454,7 @@ class Index:
                 uid=0,
                 gid=0,
                 size=0,
-                sha=subtree_sha,
+                sha=ObjectID(subtree_sha),
                 flags=0,
                 extended_flags=EXTENDED_FLAG_SKIP_WORKTREE,
             )
@@ -1505,8 +1507,8 @@ class Index:
 
 
 def commit_tree(
-    object_store: ObjectContainer, blobs: Iterable[tuple[bytes, bytes, int]]
-) -> bytes:
+    object_store: ObjectContainer, blobs: Iterable[tuple[bytes, ObjectID, int]]
+) -> ObjectID:
     """Commit a new tree.
 
     Args:
@@ -1533,7 +1535,7 @@ def commit_tree(
         tree = add_tree(tree_path)
         tree[basename] = (mode, sha)
 
-    def build_tree(path: bytes) -> bytes:
+    def build_tree(path: bytes) -> ObjectID:
         tree = Tree()
         for basename, entry in trees[path].items():
             if isinstance(entry, dict):
@@ -1548,7 +1550,7 @@ def commit_tree(
     return build_tree(b"")
 
 
-def commit_index(object_store: ObjectContainer, index: Index) -> bytes:
+def commit_index(object_store: ObjectContainer, index: Index) -> ObjectID:
     """Create a new tree from an index.
 
     Args:
@@ -1564,7 +1566,7 @@ def changes_from_tree(
     names: Iterable[bytes],
     lookup_entry: Callable[[bytes], tuple[bytes, int]],
     object_store: ObjectContainer,
-    tree: bytes | None,
+    tree: ObjectID | None,
     want_unchanged: bool = False,
 ) -> Iterable[
     tuple[
@@ -1625,6 +1627,8 @@ def index_entry_from_stat(
     if mode is None:
         mode = cleanup_mode(stat_val.st_mode)
 
+    from dulwich.objects import ObjectID
+
     return IndexEntry(
         ctime=stat_val.st_ctime,
         mtime=stat_val.st_mtime,
@@ -1634,7 +1638,7 @@ def index_entry_from_stat(
         uid=stat_val.st_uid,
         gid=stat_val.st_gid,
         size=stat_val.st_size,
-        sha=hex_sha,
+        sha=ObjectID(hex_sha),
         flags=0,
         extended_flags=0,
     )
@@ -1884,7 +1888,7 @@ def build_index_from_tree(
     root_path: str | bytes,
     index_path: str | bytes,
     object_store: ObjectContainer,
-    tree_id: bytes,
+    tree_id: ObjectID,
     honor_filemode: bool = True,
     validate_path_element: Callable[[bytes], bool] = validate_path_element_default,
     symlink_fn: Callable[
@@ -2132,7 +2136,7 @@ def _remove_empty_parents(path: bytes, stop_at: bytes) -> None:
 
 
 def _check_symlink_matches(
-    full_path: bytes, repo_object_store: "BaseObjectStore", entry_sha: bytes
+    full_path: bytes, repo_object_store: "BaseObjectStore", entry_sha: ObjectID
 ) -> bool:
     """Check if symlink target matches expected target.
 
@@ -2158,7 +2162,7 @@ def _check_symlink_matches(
 def _check_file_matches(
     repo_object_store: "BaseObjectStore",
     full_path: bytes,
-    entry_sha: bytes,
+    entry_sha: ObjectID,
     entry_mode: int,
     current_stat: os.stat_result,
     honor_filemode: bool,
@@ -3045,7 +3049,7 @@ def iter_fresh_objects(
     root_path: bytes,
     include_deleted: bool = False,
     object_store: ObjectContainer | None = None,
-) -> Iterator[tuple[bytes, bytes | None, int | None]]:
+) -> Iterator[tuple[bytes, ObjectID | None, int | None]]:
     """Iterate over versions of objects on disk referenced by index.
 
     Args:
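
commit_tree and Index.commit now return ObjectID and expect ObjectID SHAs in the (path, sha, mode) tuples. A minimal sketch using an in-memory store; the file name and contents are illustrative:

    from dulwich.index import commit_tree
    from dulwich.object_store import MemoryObjectStore
    from dulwich.objects import Blob

    store = MemoryObjectStore()
    blob = Blob.from_string(b"hello\n")
    store.add_object(blob)
    # blob.id is already a hex ObjectID; mode 0o100644 is a regular file.
    tree_id = commit_tree(store, [(b"hello.txt", blob.id, 0o100644)])
    print(tree_id)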

+ 5 - 5
dulwich/merge.py

@@ -16,7 +16,7 @@ from dulwich.attrs import GitAttributes
 from dulwich.config import Config
 from dulwich.merge_drivers import get_merge_driver_registry
 from dulwich.object_store import BaseObjectStore
-from dulwich.objects import S_ISGITLINK, Blob, Commit, Tree, is_blob, is_tree
+from dulwich.objects import S_ISGITLINK, Blob, Commit, ObjectID, Tree, is_blob, is_tree
 
 
 def make_merge3(
@@ -303,7 +303,7 @@ class Merger:
             tuple of (merged_tree, list_of_conflicted_paths)
         """
         conflicts: list[bytes] = []
-        merged_entries: dict[bytes, tuple[int | None, bytes | None]] = {}
+        merged_entries: dict[bytes, tuple[int | None, ObjectID | None]] = {}
 
         # Get all paths from all trees
         all_paths = set()
@@ -481,7 +481,7 @@ class Merger:
 def _create_virtual_commit(
     object_store: BaseObjectStore,
     tree: Tree,
-    parents: list[bytes],
+    parents: list[ObjectID],
     message: bytes = b"Virtual merge base",
 ) -> Commit:
     """Create a virtual commit object for recursive merging.
@@ -519,7 +519,7 @@ def _create_virtual_commit(
 
 def recursive_merge(
     object_store: BaseObjectStore,
-    merge_bases: list[bytes],
+    merge_bases: list[ObjectID],
     ours_commit: Commit,
     theirs_commit: Commit,
     gitattributes: GitAttributes | None = None,
@@ -671,7 +671,7 @@ def three_way_merge(
 
 def octopus_merge(
     object_store: BaseObjectStore,
-    merge_bases: list[bytes],
+    merge_bases: list[ObjectID],
     head_commit: Commit,
     other_commits: list[Commit],
     gitattributes: GitAttributes | None = None,

+ 15 - 12
dulwich/notes.py

@@ -24,7 +24,8 @@ import stat
 from collections.abc import Iterator, Sequence
 from typing import TYPE_CHECKING
 
-from .objects import Blob, Tree
+from .objects import Blob, ObjectID, Tree
+from .refs import Ref
 
 if TYPE_CHECKING:
     from .config import StackedConfig
@@ -240,7 +241,7 @@ class NotesTree:
 
             # Build new tree structure
             def update_tree(
-                tree: Tree, components: Sequence[bytes], blob_sha: bytes
+                tree: Tree, components: Sequence[bytes], blob_sha: ObjectID
             ) -> Tree:
                 """Update tree with new note entry.
 
@@ -302,7 +303,7 @@ class NotesTree:
             self._object_store.add_object(self._tree)
 
     def _update_tree_entry(
-        self, tree: Tree, name: bytes, mode: int, sha: bytes
+        self, tree: Tree, name: bytes, mode: int, sha: ObjectID
     ) -> Tree:
         """Update a tree entry and return the updated tree.
 
@@ -333,7 +334,7 @@ class NotesTree:
 
         return new_tree
 
-    def _get_note_sha(self, object_sha: bytes) -> bytes | None:
+    def _get_note_sha(self, object_sha: bytes) -> ObjectID | None:
         """Get the SHA of the note blob for an object.
 
         Args:
@@ -412,7 +413,7 @@ class NotesTree:
 
         # Build new tree structure
         def update_tree(
-            tree: Tree, components: Sequence[bytes], blob_sha: bytes
+            tree: Tree, components: Sequence[bytes], blob_sha: ObjectID
         ) -> Tree:
             """Update tree with new note entry.
 
@@ -546,14 +547,16 @@ class NotesTree:
         self._fanout_level = self._detect_fanout_level()
         return new_tree
 
-    def list_notes(self) -> Iterator[tuple[bytes, bytes]]:
+    def list_notes(self) -> Iterator[tuple[ObjectID, ObjectID]]:
         """List all notes in this tree.
 
         Yields:
             Tuples of (object_sha, note_sha)
         """
 
-        def walk_tree(tree: Tree, prefix: bytes = b"") -> Iterator[tuple[bytes, bytes]]:
+        def walk_tree(
+            tree: Tree, prefix: bytes = b""
+        ) -> Iterator[tuple[ObjectID, ObjectID]]:
             """Walk the notes tree recursively.
 
             Args:
@@ -572,7 +575,7 @@ class NotesTree:
                 elif stat.S_ISREG(mode):  # File
                     # Reconstruct the full hex SHA from the path
                     full_hex = prefix + name
-                    yield (full_hex, sha)
+                    yield (ObjectID(full_hex), sha)
 
         yield from walk_tree(self._tree)
 
@@ -610,7 +613,7 @@ class Notes:
         self,
         notes_ref: bytes | None = None,
         config: "StackedConfig | None" = None,
-    ) -> bytes:
+    ) -> Ref:
         """Get the notes reference to use.
 
         Args:
@@ -625,7 +628,7 @@ class Notes:
                 notes_ref = config.get((b"notes",), b"displayRef")
             if notes_ref is None:
                 notes_ref = DEFAULT_NOTES_REF
-        return notes_ref
+        return Ref(notes_ref)
 
     def get_note(
         self,
@@ -838,7 +841,7 @@ class Notes:
         self,
         notes_ref: bytes | None = None,
         config: "StackedConfig | None" = None,
-    ) -> list[tuple[bytes, bytes]]:
+    ) -> list[tuple[ObjectID, bytes]]:
         """List all notes in a notes ref.
 
         Args:
@@ -870,7 +873,7 @@ class Notes:
             return []
 
         notes_tree_obj = NotesTree(notes_tree, self._object_store)
-        result = []
+        result: list[tuple[ObjectID, bytes]] = []
         for object_sha, note_sha in notes_tree_obj.list_notes():
             note_obj = self._object_store[note_sha]
             if isinstance(note_obj, Blob):
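
Notes._expand_notes_ref now returns a typed Ref, falling back to the notes.displayRef config and then to the module default. A condensed sketch of that fallback, with the config lookup elided and assuming DEFAULT_NOTES_REF is the module-level default referenced above:

    from dulwich.notes import DEFAULT_NOTES_REF
    from dulwich.refs import Ref

    def expand_notes_ref(notes_ref: bytes | None) -> Ref:
        # Without a config object, fall straight back to the default ref.
        return Ref(notes_ref if notes_ref is not None else DEFAULT_NOTES_REF)

    print(expand_notes_ref(None))
    print(expand_notes_ref(b"refs/notes/reviews"))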

+ 157 - 163
dulwich/object_store.py

@@ -37,6 +37,7 @@ from typing import (
     TYPE_CHECKING,
     BinaryIO,
     Protocol,
+    cast,
 )
 
 from .errors import NotTreeError
@@ -47,6 +48,7 @@ from .objects import (
     Blob,
     Commit,
     ObjectID,
+    RawObjectID,
     ShaFile,
     Tag,
     Tree,
@@ -78,8 +80,8 @@ from .pack import (
     write_pack_data,
     write_pack_index,
 )
-from .protocol import DEPTH_INFINITE
-from .refs import PEELED_TAG_SUFFIX, Ref
+from .protocol import DEPTH_INFINITE, PEELED_TAG_SUFFIX
+from .refs import Ref
 
 if TYPE_CHECKING:
     from .bitmap import EWAHBitmap
@@ -92,11 +94,11 @@ if TYPE_CHECKING:
 class GraphWalker(Protocol):
     """Protocol for graph walker objects."""
 
-    def __next__(self) -> bytes | None:
+    def __next__(self) -> ObjectID | None:
         """Return the next object SHA to visit."""
         ...
 
-    def ack(self, sha: bytes) -> None:
+    def ack(self, sha: ObjectID) -> None:
         """Acknowledge that an object has been received."""
         ...
 
@@ -114,10 +116,10 @@ class ObjectReachabilityProvider(Protocol):
 
     def get_reachable_commits(
         self,
-        heads: Iterable[bytes],
-        exclude: Iterable[bytes] | None = None,
-        shallow: Set[bytes] | None = None,
-    ) -> set[bytes]:
+        heads: Iterable[ObjectID],
+        exclude: Iterable[ObjectID] | None = None,
+        shallow: Set[ObjectID] | None = None,
+    ) -> set[ObjectID]:
         """Get all commits reachable from heads, excluding those in exclude.
 
         Args:
@@ -132,9 +134,9 @@ class ObjectReachabilityProvider(Protocol):
 
     def get_reachable_objects(
         self,
-        commits: Iterable[bytes],
-        exclude_commits: Iterable[bytes] | None = None,
-    ) -> set[bytes]:
+        commits: Iterable[ObjectID],
+        exclude_commits: Iterable[ObjectID] | None = None,
+    ) -> set[ObjectID]:
         """Get all objects (commits + trees + blobs) reachable from commits.
 
         Args:
@@ -148,8 +150,8 @@ class ObjectReachabilityProvider(Protocol):
 
     def get_tree_objects(
         self,
-        tree_shas: Iterable[bytes],
-    ) -> set[bytes]:
+        tree_shas: Iterable[ObjectID],
+    ) -> set[ObjectID]:
         """Get all trees and blobs reachable from the given trees.
 
         Args:
@@ -175,8 +177,8 @@ DEFAULT_TEMPFILE_GRACE_PERIOD = 14 * 24 * 60 * 60  # 2 weeks
 
 
 def find_shallow(
-    store: ObjectContainer, heads: Iterable[bytes], depth: int
-) -> tuple[set[bytes], set[bytes]]:
+    store: ObjectContainer, heads: Iterable[ObjectID], depth: int
+) -> tuple[set[ObjectID], set[ObjectID]]:
     """Find shallow commits according to a given depth.
 
     Args:
@@ -188,10 +190,10 @@ def find_shallow(
         considered shallow and unshallow according to the arguments. Note that
         these sets may overlap if a commit is reachable along multiple paths.
     """
-    parents: dict[bytes, list[bytes]] = {}
+    parents: dict[ObjectID, list[ObjectID]] = {}
     commit_graph = store.get_commit_graph()
 
-    def get_parents(sha: bytes) -> list[bytes]:
+    def get_parents(sha: ObjectID) -> list[ObjectID]:
         result = parents.get(sha, None)
         if not result:
             # Try to use commit graph first if available
@@ -234,8 +236,8 @@ def find_shallow(
 
 def get_depth(
     store: ObjectContainer,
-    head: bytes,
-    get_parents: Callable[..., list[bytes]] = lambda commit: commit.parents,
+    head: ObjectID,
+    get_parents: Callable[..., list[ObjectID]] = lambda commit: commit.parents,
     max_depth: int | None = None,
 ) -> int:
     """Return the current available depth for the given head.
@@ -291,7 +293,7 @@ class BaseObjectStore:
     ) -> list[ObjectID]:
         """Determine which objects are wanted based on refs."""
 
-        def _want_deepen(sha: bytes) -> bool:
+        def _want_deepen(sha: ObjectID) -> bool:
             if not depth:
                 return False
             if depth == DEPTH_INFINITE:
@@ -306,15 +308,15 @@ class BaseObjectStore:
             and not sha == ZERO_SHA
         ]
 
-    def contains_loose(self, sha: bytes) -> bool:
+    def contains_loose(self, sha: ObjectID | RawObjectID) -> bool:
         """Check if a particular object is present by SHA1 and is loose."""
         raise NotImplementedError(self.contains_loose)
 
-    def contains_packed(self, sha: bytes) -> bool:
+    def contains_packed(self, sha: ObjectID | RawObjectID) -> bool:
         """Check if a particular object is present by SHA1 and is packed."""
         return False  # Default implementation for stores that don't support packing
 
-    def __contains__(self, sha1: bytes) -> bool:
+    def __contains__(self, sha1: ObjectID | RawObjectID) -> bool:
         """Check if a particular object is present by SHA1.
 
         This method makes no distinction between loose and packed objects.
@@ -326,7 +328,7 @@ class BaseObjectStore:
         """Iterable of pack objects."""
         raise NotImplementedError
 
-    def get_raw(self, name: bytes) -> tuple[int, bytes]:
+    def get_raw(self, name: RawObjectID | ObjectID) -> tuple[int, bytes]:
         """Obtain the raw text for an object.
 
         Args:
@@ -335,12 +337,12 @@ class BaseObjectStore:
         """
         raise NotImplementedError(self.get_raw)
 
-    def __getitem__(self, sha1: ObjectID) -> ShaFile:
+    def __getitem__(self, sha1: ObjectID | RawObjectID) -> ShaFile:
         """Obtain an object by SHA1."""
         type_num, uncomp = self.get_raw(sha1)
         return ShaFile.from_raw_string(type_num, uncomp, sha=sha1)
 
-    def __iter__(self) -> Iterator[bytes]:
+    def __iter__(self) -> Iterator[ObjectID]:
         """Iterate over the SHAs that are present in this store."""
         raise NotImplementedError(self.__iter__)
 
@@ -381,8 +383,8 @@ class BaseObjectStore:
 
     def tree_changes(
         self,
-        source: bytes | None,
-        target: bytes | None,
+        source: ObjectID | None,
+        target: ObjectID | None,
         want_unchanged: bool = False,
         include_trees: bool = False,
         change_type_same: bool = False,
@@ -392,7 +394,7 @@ class BaseObjectStore:
         tuple[
             tuple[bytes | None, bytes | None],
             tuple[int | None, int | None],
-            tuple[bytes | None, bytes | None],
+            tuple[ObjectID | None, ObjectID | None],
         ]
     ]:
         """Find the differences between the contents of two trees.
@@ -434,7 +436,7 @@ class BaseObjectStore:
             )
 
     def iter_tree_contents(
-        self, tree_id: bytes, include_trees: bool = False
+        self, tree_id: ObjectID, include_trees: bool = False
     ) -> Iterator[TreeEntry]:
         """Iterate the contents of a tree and all subtrees.
 
@@ -454,7 +456,7 @@ class BaseObjectStore:
         return iter_tree_contents(self, tree_id, include_trees=include_trees)
 
     def iterobjects_subset(
-        self, shas: Iterable[bytes], *, allow_missing: bool = False
+        self, shas: Iterable[ObjectID], *, allow_missing: bool = False
     ) -> Iterator[ShaFile]:
         """Iterate over a subset of objects in the store.
 
@@ -477,7 +479,7 @@ class BaseObjectStore:
 
     def iter_unpacked_subset(
         self,
-        shas: Iterable[bytes],
+        shas: Iterable[ObjectID | RawObjectID],
         include_comp: bool = False,
         allow_missing: bool = False,
         convert_ofs_delta: bool = True,
@@ -518,13 +520,13 @@ class BaseObjectStore:
 
     def find_missing_objects(
         self,
-        haves: Iterable[bytes],
-        wants: Iterable[bytes],
-        shallow: Set[bytes] | None = None,
+        haves: Iterable[ObjectID],
+        wants: Iterable[ObjectID],
+        shallow: Set[ObjectID] | None = None,
         progress: Callable[..., None] | None = None,
-        get_tagged: Callable[[], dict[bytes, bytes]] | None = None,
-        get_parents: Callable[..., list[bytes]] = lambda commit: commit.parents,
-    ) -> Iterator[tuple[bytes, PackHint | None]]:
+        get_tagged: Callable[[], dict[ObjectID, ObjectID]] | None = None,
+        get_parents: Callable[..., list[ObjectID]] = lambda commit: commit.parents,
+    ) -> Iterator[tuple[ObjectID, PackHint | None]]:
         """Find the missing objects required for a set of revisions.
 
         Args:
@@ -551,7 +553,7 @@ class BaseObjectStore:
         )
         return iter(finder)
 
-    def find_common_revisions(self, graphwalker: GraphWalker) -> list[bytes]:
+    def find_common_revisions(self, graphwalker: GraphWalker) -> list[ObjectID]:
         """Find which revisions this store has in common using graphwalker.
 
         Args:
@@ -569,10 +571,10 @@ class BaseObjectStore:
 
     def generate_pack_data(
         self,
-        have: Iterable[bytes],
-        want: Iterable[bytes],
+        have: Iterable[ObjectID],
+        want: Iterable[ObjectID],
         *,
-        shallow: Set[bytes] | None = None,
+        shallow: Set[ObjectID] | None = None,
         progress: Callable[..., None] | None = None,
         ofs_delta: bool = True,
     ) -> tuple[int, Iterator[UnpackedObject]]:
@@ -597,7 +599,7 @@ class BaseObjectStore:
             progress=progress,
         )
 
-    def peel_sha(self, sha: bytes) -> bytes:
+    def peel_sha(self, sha: ObjectID | RawObjectID) -> ObjectID:
         """Peel all tags from a SHA.
 
         Args:
@@ -615,8 +617,8 @@ class BaseObjectStore:
 
     def _get_depth(
         self,
-        head: bytes,
-        get_parents: Callable[..., list[bytes]] = lambda commit: commit.parents,
+        head: ObjectID,
+        get_parents: Callable[..., list[ObjectID]] = lambda commit: commit.parents,
         max_depth: int | None = None,
     ) -> int:
         """Return the current available depth for the given head.
@@ -667,7 +669,7 @@ class BaseObjectStore:
         return None
 
     def write_commit_graph(
-        self, refs: Sequence[bytes] | None = None, reachable: bool = True
+        self, refs: Iterable[ObjectID] | None = None, reachable: bool = True
     ) -> None:
         """Write a commit graph file for this object store.
 
@@ -682,7 +684,7 @@ class BaseObjectStore:
         """
         raise NotImplementedError(self.write_commit_graph)
 
-    def get_object_mtime(self, sha: bytes) -> float:
+    def get_object_mtime(self, sha: ObjectID) -> float:
         """Get the modification time of an object.
 
         Args:
@@ -729,7 +731,7 @@ class PackCapableObjectStore(BaseObjectStore, PackedObjectContainer):
         raise NotImplementedError(self.add_pack_data)
 
     def get_unpacked_object(
-        self, sha1: bytes, *, include_comp: bool = False
+        self, sha1: ObjectID | RawObjectID, *, include_comp: bool = False
     ) -> "UnpackedObject":
         """Get a raw unresolved object.
 
@@ -746,7 +748,7 @@ class PackCapableObjectStore(BaseObjectStore, PackedObjectContainer):
         return UnpackedObject(obj.type_num, sha=sha1, decomp_chunks=obj.as_raw_chunks())
 
     def iterobjects_subset(
-        self, shas: Iterable[bytes], *, allow_missing: bool = False
+        self, shas: Iterable[ObjectID], *, allow_missing: bool = False
     ) -> Iterator[ShaFile]:
         """Iterate over a subset of objects.
 
@@ -878,7 +880,7 @@ class PackBasedObjectStore(PackCapableObjectStore, PackedObjectContainer):
         """Return list of alternate object stores."""
         return []
 
-    def contains_packed(self, sha: bytes) -> bool:
+    def contains_packed(self, sha: ObjectID | RawObjectID) -> bool:
         """Check if a particular object is present by SHA1 and is packed.
 
         This does not check alternates.
@@ -891,7 +893,7 @@ class PackBasedObjectStore(PackCapableObjectStore, PackedObjectContainer):
                 pass
         return False
 
-    def __contains__(self, sha: bytes) -> bool:
+    def __contains__(self, sha: ObjectID | RawObjectID) -> bool:
         """Check if a particular object is present by SHA1.
 
         This method makes no distinction between loose and packed objects.
@@ -913,10 +915,10 @@ class PackBasedObjectStore(PackCapableObjectStore, PackedObjectContainer):
 
     def generate_pack_data(
         self,
-        have: Iterable[bytes],
-        want: Iterable[bytes],
+        have: Iterable[ObjectID],
+        want: Iterable[ObjectID],
         *,
-        shallow: Set[bytes] | None = None,
+        shallow: Set[ObjectID] | None = None,
         progress: Callable[..., None] | None = None,
         ofs_delta: bool = True,
     ) -> tuple[int, Iterator[UnpackedObject]]:
@@ -981,19 +983,19 @@ class PackBasedObjectStore(PackCapableObjectStore, PackedObjectContainer):
                 count += 1
         return count
 
-    def _iter_alternate_objects(self) -> Iterator[bytes]:
+    def _iter_alternate_objects(self) -> Iterator[ObjectID]:
         """Iterate over the SHAs of all the objects in alternate stores."""
         for alternate in self.alternates:
             yield from alternate
 
-    def _iter_loose_objects(self) -> Iterator[bytes]:
+    def _iter_loose_objects(self) -> Iterator[ObjectID]:
         """Iterate over the SHAs of all loose objects."""
         raise NotImplementedError(self._iter_loose_objects)
 
-    def _get_loose_object(self, sha: bytes) -> ShaFile | None:
+    def _get_loose_object(self, sha: ObjectID | RawObjectID) -> ShaFile | None:
         raise NotImplementedError(self._get_loose_object)
 
-    def delete_loose_object(self, sha: bytes) -> None:
+    def delete_loose_object(self, sha: ObjectID) -> None:
         """Delete a loose object.
 
         This method only handles loose objects. For packed objects,
@@ -1079,7 +1081,7 @@ class PackBasedObjectStore(PackCapableObjectStore, PackedObjectContainer):
 
     def generate_pack_bitmaps(
         self,
-        refs: dict[bytes, bytes],
+        refs: dict[Ref, ObjectID],
         *,
         commit_interval: int | None = None,
         progress: Callable[[str], None] | None = None,
@@ -1109,7 +1111,7 @@ class PackBasedObjectStore(PackCapableObjectStore, PackedObjectContainer):
 
         return count
 
-    def __iter__(self) -> Iterator[bytes]:
+    def __iter__(self) -> Iterator[ObjectID]:
         """Iterate over the SHAs that are present in this store."""
         self._update_pack_cache()
         for pack in self._iter_cached_packs():
@@ -1120,14 +1122,14 @@ class PackBasedObjectStore(PackCapableObjectStore, PackedObjectContainer):
         yield from self._iter_loose_objects()
         yield from self._iter_alternate_objects()
 
-    def contains_loose(self, sha: bytes) -> bool:
+    def contains_loose(self, sha: ObjectID | RawObjectID) -> bool:
         """Check if a particular object is present by SHA1 and is loose.
 
         This does not check alternates.
         """
         return self._get_loose_object(sha) is not None
 
-    def get_raw(self, name: bytes) -> tuple[int, bytes]:
+    def get_raw(self, name: RawObjectID | ObjectID) -> tuple[int, bytes]:
         """Obtain the raw fulltext for an object.
 
         Args:
@@ -1137,10 +1139,10 @@ class PackBasedObjectStore(PackCapableObjectStore, PackedObjectContainer):
         if name == ZERO_SHA:
             raise KeyError(name)
         if len(name) == 40:
-            sha = hex_to_sha(name)
-            hexsha = name
+            sha = hex_to_sha(cast(ObjectID, name))
+            hexsha = cast(ObjectID, name)
         elif len(name) == 20:
-            sha = name
+            sha = cast(RawObjectID, name)
             hexsha = None
         else:
             raise AssertionError(f"Invalid object name {name!r}")
@@ -1150,7 +1152,7 @@ class PackBasedObjectStore(PackCapableObjectStore, PackedObjectContainer):
             except (KeyError, PackFileDisappeared):
                 pass
         if hexsha is None:
-            hexsha = sha_to_hex(name)
+            hexsha = sha_to_hex(sha)
         ret = self._get_loose_object(hexsha)
         if ret is not None:
             return ret.type_num, ret.as_raw_string()
@@ -1170,7 +1172,7 @@ class PackBasedObjectStore(PackCapableObjectStore, PackedObjectContainer):
 
     def iter_unpacked_subset(
         self,
-        shas: Iterable[bytes],
+        shas: Iterable[ObjectID | RawObjectID],
         include_comp: bool = False,
         allow_missing: bool = False,
         convert_ofs_delta: bool = True,
@@ -1189,7 +1191,7 @@ class PackBasedObjectStore(PackCapableObjectStore, PackedObjectContainer):
         Raises:
           KeyError: If an object is missing and allow_missing is False
         """
-        todo: set[bytes] = set(shas)
+        todo: set[ObjectID | RawObjectID] = set(shas)
         for p in self._iter_cached_packs():
             for unpacked in p.iter_unpacked_subset(
                 todo,
@@ -1225,7 +1227,7 @@ class PackBasedObjectStore(PackCapableObjectStore, PackedObjectContainer):
                 todo.remove(hexsha)
 
     def iterobjects_subset(
-        self, shas: Iterable[bytes], *, allow_missing: bool = False
+        self, shas: Iterable[ObjectID], *, allow_missing: bool = False
     ) -> Iterator[ShaFile]:
         """Iterate over a subset of objects in the store.
 
@@ -1241,7 +1243,7 @@ class PackBasedObjectStore(PackCapableObjectStore, PackedObjectContainer):
         Raises:
           KeyError: If an object is missing and allow_missing is False
         """
-        todo: set[bytes] = set(shas)
+        todo: set[ObjectID] = set(shas)
         for p in self._iter_cached_packs():
             for o in p.iterobjects_subset(todo, allow_missing=True):
                 yield o
@@ -1275,10 +1277,10 @@ class PackBasedObjectStore(PackCapableObjectStore, PackedObjectContainer):
         if sha1 == ZERO_SHA:
             raise KeyError(sha1)
         if len(sha1) == 40:
-            sha = hex_to_sha(sha1)
-            hexsha = sha1
+            sha = hex_to_sha(cast(ObjectID, sha1))
+            hexsha = cast(ObjectID, sha1)
         elif len(sha1) == 20:
-            sha = sha1
+            sha = cast(RawObjectID, sha1)
             hexsha = None
         else:
             raise AssertionError(f"Invalid object sha1 {sha1!r}")
@@ -1288,7 +1290,7 @@ class PackBasedObjectStore(PackCapableObjectStore, PackedObjectContainer):
             except (KeyError, PackFileDisappeared):
                 pass
         if hexsha is None:
-            hexsha = sha_to_hex(sha1)
+            hexsha = sha_to_hex(sha)
         # Maybe something else has added a pack with the object
         # in the mean time?
         for pack in self._update_pack_cache():
@@ -1614,11 +1616,11 @@ class DiskObjectStore(PackBasedObjectStore):
             self._pack_cache.pop(f).close()
         return new_packs
 
-    def _get_shafile_path(self, sha: bytes) -> str:
+    def _get_shafile_path(self, sha: ObjectID | RawObjectID) -> str:
         # Check from object dir
         return hex_to_filename(os.fspath(self.path), sha)
 
-    def _iter_loose_objects(self) -> Iterator[bytes]:
+    def _iter_loose_objects(self) -> Iterator[ObjectID]:
         for base in os.listdir(self.path):
             if len(base) != 2:
                 continue
@@ -1626,7 +1628,7 @@ class DiskObjectStore(PackBasedObjectStore):
                 sha = os.fsencode(base + rest)
                 if not valid_hexsha(sha):
                     continue
-                yield sha
+                yield ObjectID(sha)
 
     def count_loose_objects(self) -> int:
         """Count the number of loose objects in the object store.
@@ -1654,14 +1656,14 @@ class DiskObjectStore(PackBasedObjectStore):
 
         return count
 
-    def _get_loose_object(self, sha: bytes) -> ShaFile | None:
+    def _get_loose_object(self, sha: ObjectID | RawObjectID) -> ShaFile | None:
         path = self._get_shafile_path(sha)
         try:
             return ShaFile.from_path(path)
         except FileNotFoundError:
             return None
 
-    def delete_loose_object(self, sha: bytes) -> None:
+    def delete_loose_object(self, sha: ObjectID) -> None:
         """Delete a loose object from disk.
 
         Args:
@@ -1672,7 +1674,7 @@ class DiskObjectStore(PackBasedObjectStore):
         """
         os.remove(self._get_shafile_path(sha))
 
-    def get_object_mtime(self, sha: bytes) -> float:
+    def get_object_mtime(self, sha: ObjectID) -> float:
         """Get the modification time of an object.
 
         Args:
@@ -1732,7 +1734,7 @@ class DiskObjectStore(PackBasedObjectStore):
         num_objects: int,
         indexer: PackIndexer,
         progress: Callable[..., None] | None = None,
-        refs: dict[bytes, bytes] | None = None,
+        refs: dict[Ref, ObjectID] | None = None,
     ) -> Pack:
         """Move a specific file containing a pack into the pack directory.
 
@@ -1974,14 +1976,14 @@ class DiskObjectStore(PackBasedObjectStore):
             os.chmod(pack_path, dir_mode)
         return cls(path, file_mode=file_mode, dir_mode=dir_mode)
 
-    def iter_prefix(self, prefix: bytes) -> Iterator[bytes]:
+    def iter_prefix(self, prefix: bytes) -> Iterator[ObjectID]:
         """Iterate over all object SHAs with the given prefix.
 
         Args:
           prefix: Hex prefix to search for (as bytes)
 
         Returns:
-          Iterator of object SHAs (as bytes) matching the prefix
+          Iterator of object SHAs (as ObjectID) matching the prefix
         """
         if len(prefix) < 2:
             yield from super().iter_prefix(prefix)
@@ -1992,7 +1994,7 @@ class DiskObjectStore(PackBasedObjectStore):
         try:
             for name in os.listdir(os.path.join(self.path, dir)):
                 if name.startswith(rest):
-                    sha = os.fsencode(dir + name)
+                    sha = ObjectID(os.fsencode(dir + name))
                     if sha not in seen:
                         seen.add(sha)
                         yield sha
@@ -2005,8 +2007,8 @@ class DiskObjectStore(PackBasedObjectStore):
                 if len(prefix) % 2 == 0
                 else binascii.unhexlify(prefix[:-1])
             )
-            for sha in p.index.iter_prefix(bin_prefix):
-                sha = sha_to_hex(sha)
+            for bin_sha in p.index.iter_prefix(bin_prefix):
+                sha = sha_to_hex(bin_sha)
                 if sha.startswith(prefix) and sha not in seen:
                     seen.add(sha)
                     yield sha
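
The two hunks above keep iter_prefix yielding hex ObjectIDs from both sources: loose fan-out directories already contain hex names, while pack indexes return raw 20-byte SHAs that are converted back with sha_to_hex. A minimal sketch of how a caller might expand an abbreviated SHA with this API; the object-directory path and the prefix are placeholders, not values from this change:

    from dulwich.object_store import DiskObjectStore

    store = DiskObjectStore(".git/objects")       # placeholder path
    matches = list(store.iter_prefix(b"deadbe"))  # placeholder prefix
    if len(matches) != 1:
        raise KeyError(b"deadbe")  # missing or ambiguous abbreviation
    full_id = matches[0]           # a 40-byte hex ObjectID
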
@@ -2035,7 +2037,7 @@ class DiskObjectStore(PackBasedObjectStore):
         return self._commit_graph
 
     def write_commit_graph(
-        self, refs: Iterable[bytes] | None = None, reachable: bool = True
+        self, refs: Iterable[ObjectID] | None = None, reachable: bool = True
     ) -> None:
         """Write a commit graph file for this object store.
 
@@ -2068,20 +2070,8 @@ class DiskObjectStore(PackBasedObjectStore):
             # Get all reachable commits
             commit_ids = get_reachable_commits(self, all_refs)
         else:
-            # Just use the direct ref targets - ensure they're hex ObjectIDs
-            commit_ids = []
-            for ref in all_refs:
-                if isinstance(ref, bytes) and len(ref) == 40:
-                    # Already hex ObjectID
-                    commit_ids.append(ref)
-                elif isinstance(ref, bytes) and len(ref) == 20:
-                    # Binary SHA, convert to hex ObjectID
-                    from .objects import sha_to_hex
-
-                    commit_ids.append(sha_to_hex(ref))
-                else:
-                    # Assume it's already correct format
-                    commit_ids.append(ref)
+            # Just use the direct ref targets (already ObjectIDs)
+            commit_ids = all_refs
 
         if commit_ids:
             # Write commit graph directly to our object store path
@@ -2169,26 +2159,26 @@ class MemoryObjectStore(PackCapableObjectStore):
         Creates an empty in-memory object store.
         """
         super().__init__()
-        self._data: dict[bytes, ShaFile] = {}
+        self._data: dict[ObjectID, ShaFile] = {}
         self.pack_compression_level = -1
 
-    def _to_hexsha(self, sha: bytes) -> bytes:
+    def _to_hexsha(self, sha: ObjectID | RawObjectID) -> ObjectID:
         if len(sha) == 40:
-            return sha
+            return cast(ObjectID, sha)
         elif len(sha) == 20:
-            return sha_to_hex(sha)
+            return sha_to_hex(cast(RawObjectID, sha))
         else:
             raise ValueError(f"Invalid sha {sha!r}")
 
-    def contains_loose(self, sha: bytes) -> bool:
+    def contains_loose(self, sha: ObjectID | RawObjectID) -> bool:
         """Check if a particular object is present by SHA1 and is loose."""
         return self._to_hexsha(sha) in self._data
 
-    def contains_packed(self, sha: bytes) -> bool:
+    def contains_packed(self, sha: ObjectID | RawObjectID) -> bool:
         """Check if a particular object is present by SHA1 and is packed."""
         return False
 
-    def __iter__(self) -> Iterator[bytes]:
+    def __iter__(self) -> Iterator[ObjectID]:
         """Iterate over the SHAs that are present in this store."""
         return iter(self._data.keys())
 
@@ -2197,7 +2187,7 @@ class MemoryObjectStore(PackCapableObjectStore):
         """List with pack objects."""
         return []
 
-    def get_raw(self, name: ObjectID) -> tuple[int, bytes]:
+    def get_raw(self, name: RawObjectID | ObjectID) -> tuple[int, bytes]:
         """Obtain the raw text for an object.
 
         Args:
@@ -2207,7 +2197,7 @@ class MemoryObjectStore(PackCapableObjectStore):
         obj = self[self._to_hexsha(name)]
         return obj.type_num, obj.as_raw_string()
 
-    def __getitem__(self, name: ObjectID) -> ShaFile:
+    def __getitem__(self, name: ObjectID | RawObjectID) -> ShaFile:
         """Retrieve an object by SHA.
 
         Args:
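
Because _to_hexsha normalizes both widths, the in-memory store can be indexed with either a 40-byte hex ObjectID or a 20-byte RawObjectID. A small illustrative sketch using only public dulwich APIs:

    from dulwich.object_store import MemoryObjectStore
    from dulwich.objects import Blob, hex_to_sha

    store = MemoryObjectStore()
    blob = Blob.from_string(b"hello\n")
    store.add_object(blob)

    # Both SHA forms resolve to the same stored object.
    assert store[blob.id].as_raw_string() == b"hello\n"
    assert store[hex_to_sha(blob.id)].as_raw_string() == b"hello\n"
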
@@ -2350,8 +2340,10 @@ class ObjectIterator(Protocol):
 
 
 def tree_lookup_path(
-    lookup_obj: Callable[[bytes], ShaFile], root_sha: bytes, path: bytes
-) -> tuple[int, bytes]:
+    lookup_obj: Callable[[ObjectID | RawObjectID], ShaFile],
+    root_sha: ObjectID | RawObjectID,
+    path: bytes,
+) -> tuple[int, ObjectID]:
     """Look up an object in a Git tree.
 
     Args:
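
tree_lookup_path resolves a slash-separated path against a root tree and returns the entry's mode and ObjectID. A self-contained sketch of the lookup, built against an in-memory store:

    from dulwich.object_store import MemoryObjectStore, tree_lookup_path
    from dulwich.objects import Blob, Tree

    store = MemoryObjectStore()
    blob = Blob.from_string(b"content\n")
    sub = Tree()
    sub.add(b"file.txt", 0o100644, blob.id)
    root = Tree()
    root.add(b"dir", 0o040000, sub.id)
    for obj in (blob, sub, root):
        store.add_object(obj)

    # Returns the mode and hex ObjectID of the entry at the given path.
    mode, sha = tree_lookup_path(store.__getitem__, root.id, b"dir/file.txt")
    assert (mode, sha) == (0o100644, blob.id)
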
@@ -2388,8 +2380,8 @@ def _collect_filetree_revs(
 
 
 def _split_commits_and_tags(
-    obj_store: ObjectContainer, lst: Iterable[bytes], *, ignore_unknown: bool = False
-) -> tuple[set[bytes], set[bytes], set[bytes]]:
+    obj_store: ObjectContainer, lst: Iterable[ObjectID], *, ignore_unknown: bool = False
+) -> tuple[set[ObjectID], set[ObjectID], set[ObjectID]]:
     """Split object id list into three lists with commit, tag, and other SHAs.
 
     Commits referenced by tags are included into commits
@@ -2404,9 +2396,9 @@ def _split_commits_and_tags(
         silently.
     Returns: A tuple of (commits, tags, others) SHA1s
     """
-    commits: set[bytes] = set()
-    tags: set[bytes] = set()
-    others: set[bytes] = set()
+    commits: set[ObjectID] = set()
+    tags: set[ObjectID] = set()
+    others: set[ObjectID] = set()
     for e in lst:
         try:
             o = obj_store[e]
@@ -2447,13 +2439,13 @@ class MissingObjectFinder:
     def __init__(
         self,
         object_store: BaseObjectStore,
-        haves: Iterable[bytes],
-        wants: Iterable[bytes],
+        haves: Iterable[ObjectID],
+        wants: Iterable[ObjectID],
         *,
-        shallow: Set[bytes] | None = None,
+        shallow: Set[ObjectID] | None = None,
         progress: Callable[[bytes], None] | None = None,
-        get_tagged: Callable[[], dict[bytes, bytes]] | None = None,
-        get_parents: Callable[[Commit], list[bytes]] = lambda commit: commit.parents,
+        get_tagged: Callable[[], dict[ObjectID, ObjectID]] | None = None,
+        get_parents: Callable[[Commit], list[ObjectID]] = lambda commit: commit.parents,
     ) -> None:
         """Initialize a MissingObjectFinder.
 
@@ -2501,7 +2493,7 @@ class MissingObjectFinder:
             get_parents=self._get_parents,
         )
 
-        self.remote_has: set[bytes] = set()
+        self.remote_has: set[ObjectID] = set()
         # Now, fill sha_done with commits and revisions of
         # files and directories known to be both locally
         # and on target. Thus these commits and files
@@ -2537,7 +2529,7 @@ class MissingObjectFinder:
             self.progress = progress
         self._tagged = (get_tagged and get_tagged()) or {}
 
-    def get_remote_has(self) -> set[bytes]:
+    def get_remote_has(self) -> set[ObjectID]:
         """Get the set of SHAs the remote has.
 
         Returns:
@@ -2555,7 +2547,7 @@ class MissingObjectFinder:
         """
         self.objects_to_send.update([e for e in entries if e[0] not in self.sha_done])
 
-    def __next__(self) -> tuple[bytes, PackHint | None]:
+    def __next__(self) -> tuple[ObjectID, PackHint | None]:
         """Get the next object to send.
 
         Returns:
@@ -2606,7 +2598,7 @@ class MissingObjectFinder:
             pack_hint = (type_num, name)
         return (sha, pack_hint)
 
-    def __iter__(self) -> Iterator[tuple[bytes, PackHint | None]]:
+    def __iter__(self) -> Iterator[tuple[ObjectID, PackHint | None]]:
         """Return iterator over objects to send.
 
         Returns:
@@ -2698,7 +2690,7 @@ class ObjectStoreGraphWalker:
 def commit_tree_changes(
     object_store: BaseObjectStore,
     tree: ObjectID | Tree,
-    changes: Sequence[tuple[bytes, int | None, bytes | None]],
+    changes: Sequence[tuple[bytes, int | None, ObjectID | None]],
 ) -> ObjectID:
     """Commit a specified set of changes to a tree structure.
 
@@ -2729,7 +2721,7 @@ def commit_tree_changes(
         sha_obj = object_store[tree]
         assert isinstance(sha_obj, Tree)
         tree_obj = sha_obj
-    nested_changes: dict[bytes, list[tuple[bytes, int | None, bytes | None]]] = {}
+    nested_changes: dict[bytes, list[tuple[bytes, int | None, ObjectID | None]]] = {}
     for path, new_mode, new_sha in changes:
         try:
             (dirname, subpath) = path.split(b"/", 1)
@@ -2743,7 +2735,7 @@ def commit_tree_changes(
             nested_changes.setdefault(dirname, []).append((subpath, new_mode, new_sha))
     for name, subchanges in nested_changes.items():
         try:
-            orig_subtree_id: bytes | Tree = tree_obj[name][1]
+            orig_subtree_id: ObjectID | Tree = tree_obj[name][1]
         except KeyError:
             # For new directories, pass an empty Tree object
             orig_subtree_id = Tree()
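
commit_tree_changes applies a flat list of (path, mode, sha) edits to a tree, recursing into subdirectories as needed, and returns the ObjectID of the rewritten root tree. A minimal sketch against an empty tree; it assumes, as the surrounding code suggests, that the rewritten trees are written back to the given object store:

    from dulwich.object_store import MemoryObjectStore, commit_tree_changes
    from dulwich.objects import Blob, Tree

    store = MemoryObjectStore()
    blob = Blob.from_string(b"hello\n")
    store.add_object(blob)

    new_tree_id = commit_tree_changes(
        store, Tree(), [(b"docs/readme.txt", 0o100644, blob.id)]
    )
    new_root = store[new_tree_id]   # the rewritten root tree, with docs/ nested
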
@@ -2832,7 +2824,7 @@ class OverlayObjectStore(BaseObjectStore):
                     done.add(o_id)
 
     def iterobjects_subset(
-        self, shas: Iterable[bytes], *, allow_missing: bool = False
+        self, shas: Iterable[ObjectID], *, allow_missing: bool = False
     ) -> Iterator[ShaFile]:
         """Iterate over a subset of objects from the overlaid stores.
 
@@ -2847,7 +2839,7 @@ class OverlayObjectStore(BaseObjectStore):
           KeyError: If an object is missing and allow_missing is False
         """
         todo = set(shas)
-        found: set[bytes] = set()
+        found: set[ObjectID] = set()
 
         for b in self.bases:
             # Create a copy of todo for each base to avoid modifying
@@ -2864,7 +2856,7 @@ class OverlayObjectStore(BaseObjectStore):
 
     def iter_unpacked_subset(
         self,
-        shas: Iterable[bytes],
+        shas: Iterable[ObjectID | RawObjectID],
         include_comp: bool = False,
         allow_missing: bool = False,
         convert_ofs_delta: bool = True,
@@ -2883,7 +2875,7 @@ class OverlayObjectStore(BaseObjectStore):
         Raises:
           KeyError: If an object is missing and allow_missing is False
         """
-        todo = set(shas)
+        todo: set[ObjectID | RawObjectID] = set(shas)
         for b in self.bases:
             for o in b.iter_unpacked_subset(
                 todo,
@@ -2896,7 +2888,7 @@ class OverlayObjectStore(BaseObjectStore):
         if todo and not allow_missing:
             raise KeyError(next(iter(todo)))
 
-    def get_raw(self, sha_id: ObjectID) -> tuple[int, bytes]:
+    def get_raw(self, sha_id: ObjectID | RawObjectID) -> tuple[int, bytes]:
         """Get the raw object data from the overlaid stores.
 
         Args:
@@ -2915,7 +2907,7 @@ class OverlayObjectStore(BaseObjectStore):
                 pass
         raise KeyError(sha_id)
 
-    def contains_packed(self, sha: bytes) -> bool:
+    def contains_packed(self, sha: ObjectID | RawObjectID) -> bool:
         """Check if an object is packed in any base store.
 
         Args:
@@ -2929,7 +2921,7 @@ class OverlayObjectStore(BaseObjectStore):
                 return True
         return False
 
-    def contains_loose(self, sha: bytes) -> bool:
+    def contains_loose(self, sha: ObjectID | RawObjectID) -> bool:
         """Check if an object is loose in any base store.
 
         Args:
@@ -2958,14 +2950,14 @@ def read_packs_file(f: BinaryIO) -> Iterator[str]:
 class BucketBasedObjectStore(PackBasedObjectStore):
     """Object store implementation that uses a bucket store like S3 as backend."""
 
-    def _iter_loose_objects(self) -> Iterator[bytes]:
+    def _iter_loose_objects(self) -> Iterator[ObjectID]:
         """Iterate over the SHAs of all loose objects."""
         return iter([])
 
-    def _get_loose_object(self, sha: bytes) -> None:
+    def _get_loose_object(self, sha: ObjectID | RawObjectID) -> None:
         return None
 
-    def delete_loose_object(self, sha: bytes) -> None:
+    def delete_loose_object(self, sha: ObjectID) -> None:
         """Delete a loose object (no-op for bucket stores).
 
         Bucket-based stores don't have loose objects, so this is a no-op.
@@ -3069,7 +3061,7 @@ def _collect_ancestors(
     heads: Iterable[ObjectID],
     common: frozenset[ObjectID] = frozenset(),
     shallow: frozenset[ObjectID] = frozenset(),
-    get_parents: Callable[[Commit], list[bytes]] = lambda commit: commit.parents,
+    get_parents: Callable[[Commit], list[ObjectID]] = lambda commit: commit.parents,
 ) -> tuple[set[ObjectID], set[ObjectID]]:
     """Collect all ancestors of heads up to (excluding) those in common.
 
@@ -3154,7 +3146,7 @@ def iter_tree_contents(
 
 def iter_commit_contents(
     store: ObjectContainer,
-    commit: Commit | bytes,
+    commit: Commit | ObjectID | RawObjectID,
     *,
     include: Sequence[str | bytes | Path] | None = None,
 ) -> Iterator[TreeEntry]:
@@ -3203,7 +3195,9 @@ def iter_commit_contents(
             yield TreeEntry(path, mode, obj_id)
 
 
-def peel_sha(store: ObjectContainer, sha: bytes) -> tuple[ShaFile, ShaFile]:
+def peel_sha(
+    store: ObjectContainer, sha: ObjectID | RawObjectID
+) -> tuple[ShaFile, ShaFile]:
     """Peel all tags from a SHA.
 
     Args:
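
peel_sha follows tag objects until it reaches a non-tag and returns both the original and the peeled object. A sketch, assuming an annotated tag with the placeholder name below exists, and assuming the (unpeeled, peeled) result order noted in the comments:

    from dulwich.object_store import peel_sha
    from dulwich.refs import Ref
    from dulwich.repo import Repo

    repo = Repo(".")
    tag_id = repo.refs[Ref(b"refs/tags/v1.0")]   # placeholder tag name
    unpeeled, peeled = peel_sha(repo.object_store, tag_id)
    # For an annotated tag, ``unpeeled`` is the Tag object and ``peeled``
    # is the object it points at; otherwise the two are the same object.
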
@@ -3240,10 +3234,10 @@ class GraphTraversalReachability:
 
     def get_reachable_commits(
         self,
-        heads: Iterable[bytes],
-        exclude: Iterable[bytes] | None = None,
-        shallow: Set[bytes] | None = None,
-    ) -> set[bytes]:
+        heads: Iterable[ObjectID],
+        exclude: Iterable[ObjectID] | None = None,
+        shallow: Set[ObjectID] | None = None,
+    ) -> set[ObjectID]:
         """Get all commits reachable from heads, excluding those in exclude.
 
         Uses _collect_ancestors for commit traversal.
@@ -3265,8 +3259,8 @@ class GraphTraversalReachability:
 
     def get_tree_objects(
         self,
-        tree_shas: Iterable[bytes],
-    ) -> set[bytes]:
+        tree_shas: Iterable[ObjectID],
+    ) -> set[ObjectID]:
         """Get all trees and blobs reachable from the given trees.
 
         Uses _collect_filetree_revs for tree traversal.
@@ -3277,16 +3271,16 @@ class GraphTraversalReachability:
         Returns:
           Set of tree and blob SHAs
         """
-        result: set[bytes] = set()
+        result: set[ObjectID] = set()
         for tree_sha in tree_shas:
             _collect_filetree_revs(self.store, tree_sha, result)
         return result
 
     def get_reachable_objects(
         self,
-        commits: Iterable[bytes],
-        exclude_commits: Iterable[bytes] | None = None,
-    ) -> set[bytes]:
+        commits: Iterable[ObjectID],
+        exclude_commits: Iterable[ObjectID] | None = None,
+    ) -> set[ObjectID]:
         """Get all objects (commits + trees + blobs) reachable from commits.
 
         Args:
@@ -3341,8 +3335,8 @@ class BitmapReachability:
 
     def _combine_commit_bitmaps(
         self,
-        commit_shas: set[bytes],
-        exclude_shas: set[bytes] | None = None,
+        commit_shas: set[ObjectID],
+        exclude_shas: set[ObjectID] | None = None,
     ) -> tuple["EWAHBitmap", "Pack"] | None:
         """Combine bitmaps for multiple commits using OR, with optional exclusion.
 
@@ -3413,10 +3407,10 @@ class BitmapReachability:
 
     def get_reachable_commits(
         self,
-        heads: Iterable[bytes],
-        exclude: Iterable[bytes] | None = None,
-        shallow: Set[bytes] | None = None,
-    ) -> set[bytes]:
+        heads: Iterable[ObjectID],
+        exclude: Iterable[ObjectID] | None = None,
+        shallow: Set[ObjectID] | None = None,
+    ) -> set[ObjectID]:
         """Get all commits reachable from heads using bitmaps where possible.
 
         Args:
@@ -3455,8 +3449,8 @@ class BitmapReachability:
 
     def get_tree_objects(
         self,
-        tree_shas: Iterable[bytes],
-    ) -> set[bytes]:
+        tree_shas: Iterable[ObjectID],
+    ) -> set[ObjectID]:
         """Get all trees and blobs reachable from the given trees.
 
         Args:
@@ -3470,9 +3464,9 @@ class BitmapReachability:
 
     def get_reachable_objects(
         self,
-        commits: Iterable[bytes],
-        exclude_commits: Iterable[bytes] | None = None,
-    ) -> set[bytes]:
+        commits: Iterable[ObjectID],
+        exclude_commits: Iterable[ObjectID] | None = None,
+    ) -> set[ObjectID]:
         """Get all objects reachable from commits using bitmaps.
 
         Args:

+ 39 - 28
dulwich/objects.py

@@ -43,7 +43,7 @@ if sys.version_info >= (3, 11):
 else:
     from typing_extensions import Self
 
-from typing import TypeGuard
+from typing import NewType, TypeGuard
 
 from . import replace_me
 from .errors import (
@@ -62,8 +62,6 @@ if TYPE_CHECKING:
 
     from .file import _GitFile
 
-ZERO_SHA = b"0" * 40
-
 # Header fields for commits
 _TREE_HEADER = b"tree"
 _PARENT_HEADER = b"parent"
@@ -93,7 +91,14 @@ SIGNATURE_PGP = b"pgp"
 SIGNATURE_SSH = b"ssh"
 
 
-ObjectID = bytes
+# Hex SHA type
+ObjectID = NewType("ObjectID", bytes)
+
+# Raw SHA type
+RawObjectID = NewType("RawObjectID", bytes)
+
+# Zero SHA constant
+ZERO_SHA: ObjectID = ObjectID(b"0" * 40)
 
 
 class EmptyFileException(FileFormatException):
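
NewType keeps both SHA types as plain bytes at runtime while letting mypy treat them as distinct, so hex and raw values can no longer be mixed silently. A minimal sketch of the pattern; the names mirror the definitions above but are re-declared locally for illustration:

    from typing import NewType

    ObjectID = NewType("ObjectID", bytes)        # 40-byte hex SHA
    RawObjectID = NewType("RawObjectID", bytes)  # 20-byte binary SHA

    def lookup(sha: ObjectID) -> None: ...

    hexsha = ObjectID(b"aa" * 20)
    lookup(hexsha)                     # accepted
    lookup(b"aa" * 20)                 # flagged by mypy: bytes is not ObjectID
    assert isinstance(hexsha, bytes)   # no wrapper object exists at runtime
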
@@ -117,18 +122,18 @@ def _decompress(string: bytes) -> bytes:
     return dcomped
 
 
-def sha_to_hex(sha: ObjectID) -> bytes:
+def sha_to_hex(sha: RawObjectID) -> ObjectID:
     """Takes a string and returns the hex of the sha within."""
     hexsha = binascii.hexlify(sha)
     assert len(hexsha) == 40, f"Incorrect length of sha1 string: {hexsha!r}"
-    return hexsha
+    return ObjectID(hexsha)
 
 
-def hex_to_sha(hex: bytes | str) -> bytes:
+def hex_to_sha(hex: ObjectID | str) -> RawObjectID:
     """Takes a hex sha and returns a binary sha."""
     assert len(hex) == 40, f"Incorrect length of hexsha: {hex!r}"
     try:
-        return binascii.unhexlify(hex)
+        return RawObjectID(binascii.unhexlify(hex))
     except TypeError as exc:
         if not isinstance(hex, bytes):
             raise
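
With these signatures, sha_to_hex and hex_to_sha are the typed bridge between the two forms. A short round trip, assuming the newtypes are importable from dulwich.objects as introduced in this change:

    from dulwich.objects import ObjectID, hex_to_sha, sha_to_hex

    hexsha = ObjectID(b"1" * 40)   # hex form, 40 bytes
    raw = hex_to_sha(hexsha)       # RawObjectID, 20 bytes
    assert len(raw) == 20
    assert sha_to_hex(raw) == hexsha
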
@@ -206,7 +211,7 @@ def filename_to_hex(filename: str | bytes) -> str:
         base_b, rest_b = names_b
         assert len(base_b) == 2 and len(rest_b) == 38, errmsg
         hex_bytes = base_b + rest_b
-    hex_to_sha(hex_bytes)
+    hex_to_sha(ObjectID(hex_bytes))
     return hex_bytes.decode("ascii")
 
 
@@ -337,7 +342,7 @@ class FixedSha:
         if not isinstance(hexsha, bytes):
             raise TypeError(f"Expected bytes for hexsha, got {hexsha!r}")
         self._hexsha = hexsha
-        self._sha = hex_to_sha(hexsha)
+        self._sha = hex_to_sha(ObjectID(hexsha))
 
     def digest(self) -> bytes:
         """Return the raw SHA digest."""
@@ -481,13 +486,17 @@ class ShaFile:
         """Return a string representing this object, fit for display."""
         return self.as_raw_string().decode("utf-8", "replace")
 
-    def set_raw_string(self, text: bytes, sha: ObjectID | None = None) -> None:
+    def set_raw_string(
+        self, text: bytes, sha: ObjectID | RawObjectID | None = None
+    ) -> None:
         """Set the contents of this object from a serialized string."""
         if not isinstance(text, bytes):
             raise TypeError(f"Expected bytes for text, got {text!r}")
         self.set_raw_chunks([text], sha)
 
-    def set_raw_chunks(self, chunks: list[bytes], sha: ObjectID | None = None) -> None:
+    def set_raw_chunks(
+        self, chunks: list[bytes], sha: ObjectID | RawObjectID | None = None
+    ) -> None:
         """Set the contents of this object from a list of chunks."""
         self._chunked_text = chunks
         self._deserialize(chunks)
@@ -571,7 +580,7 @@ class ShaFile:
 
     @staticmethod
     def from_raw_string(
-        type_num: int, string: bytes, sha: ObjectID | None = None
+        type_num: int, string: bytes, sha: ObjectID | RawObjectID | None = None
     ) -> "ShaFile":
         """Creates an object of the indicated type from the raw string given.
 
@@ -589,7 +598,7 @@ class ShaFile:
 
     @staticmethod
     def from_raw_chunks(
-        type_num: int, chunks: list[bytes], sha: ObjectID | None = None
+        type_num: int, chunks: list[bytes], sha: ObjectID | RawObjectID | None = None
     ) -> "ShaFile":
         """Creates an object of the indicated type from the raw chunks given.
 
@@ -673,9 +682,9 @@ class ShaFile:
         return obj_class.from_raw_string(self.type_num, self.as_raw_string(), self.id)
 
     @property
-    def id(self) -> bytes:
+    def id(self) -> ObjectID:
         """The hex SHA of this object."""
-        return self.sha().hexdigest().encode("ascii")
+        return ObjectID(self.sha().hexdigest().encode("ascii"))
 
     def __repr__(self) -> str:
         """Return string representation of this object."""
@@ -1178,7 +1187,7 @@ class TreeEntry(NamedTuple):
 
     path: bytes
     mode: int
-    sha: bytes
+    sha: ObjectID
 
     def in_path(self, path: bytes) -> "TreeEntry":
         """Return a copy of this entry with the given path prepended."""
@@ -1187,7 +1196,9 @@ class TreeEntry(NamedTuple):
         return TreeEntry(posixpath.join(path, self.path), self.mode, self.sha)
 
 
-def parse_tree(text: bytes, strict: bool = False) -> Iterator[tuple[bytes, int, bytes]]:
+def parse_tree(
+    text: bytes, strict: bool = False
+) -> Iterator[tuple[bytes, int, ObjectID]]:
     """Parse a tree text.
 
     Args:
@@ -1215,11 +1226,11 @@ def parse_tree(text: bytes, strict: bool = False) -> Iterator[tuple[bytes, int,
         sha = text[name_end + 1 : count]
         if len(sha) != 20:
             raise ObjectFormatException("Sha has invalid length")
-        hexsha = sha_to_hex(sha)
+        hexsha = sha_to_hex(RawObjectID(sha))
         yield (name, mode, hexsha)
 
 
-def serialize_tree(items: Iterable[tuple[bytes, int, bytes]]) -> Iterator[bytes]:
+def serialize_tree(items: Iterable[tuple[bytes, int, ObjectID]]) -> Iterator[bytes]:
     """Serialize the items in a tree to a text.
 
     Args:
@@ -1233,7 +1244,7 @@ def serialize_tree(items: Iterable[tuple[bytes, int, bytes]]) -> Iterator[bytes]
 
 
 def sorted_tree_items(
-    entries: dict[bytes, tuple[int, bytes]], name_order: bool
+    entries: dict[bytes, tuple[int, ObjectID]], name_order: bool
 ) -> Iterator[TreeEntry]:
     """Iterate over a tree entries dictionary.
 
@@ -1275,7 +1286,7 @@ def key_entry_name_order(entry: tuple[bytes, tuple[int, ObjectID]]) -> bytes:
 
 
 def pretty_format_tree_entry(
-    name: bytes, mode: int, hexsha: bytes, encoding: str = "utf-8"
+    name: bytes, mode: int, hexsha: ObjectID, encoding: str = "utf-8"
 ) -> str:
     """Pretty format tree entry.
 
@@ -1323,7 +1334,7 @@ class Tree(ShaFile):
     def __init__(self) -> None:
         """Initialize an empty Tree."""
         super().__init__()
-        self._entries: dict[bytes, tuple[int, bytes]] = {}
+        self._entries: dict[bytes, tuple[int, ObjectID]] = {}
 
     @classmethod
     def from_path(cls, filename: str | bytes) -> "Tree":
@@ -1377,7 +1388,7 @@ class Tree(ShaFile):
         """Iterate over tree entry names."""
         return iter(self._entries)
 
-    def add(self, name: bytes, mode: int, hexsha: bytes) -> None:
+    def add(self, name: bytes, mode: int, hexsha: ObjectID) -> None:
         """Add an entry to the tree.
 
         Args:
@@ -1698,7 +1709,7 @@ class Commit(ShaFile):
     def __init__(self) -> None:
         """Initialize an empty Commit."""
         super().__init__()
-        self._parents: list[bytes] = []
+        self._parents: list[ObjectID] = []
         self._encoding: bytes | None = None
         self._mergetag: list[Tag] = []
         self._gpgsig: bytes | None = None
@@ -1749,7 +1760,7 @@ class Commit(ShaFile):
                 self._tree = value
             elif field == _PARENT_HEADER:
                 assert value is not None
-                self._parents.append(value)
+                self._parents.append(ObjectID(value))
             elif field == _AUTHOR_HEADER:
                 if value is None:
                     raise ObjectFormatException("missing author value")
@@ -1976,11 +1987,11 @@ class Commit(ShaFile):
 
     tree = serializable_property("tree", "Tree that is the state of this commit")
 
-    def _get_parents(self) -> list[bytes]:
+    def _get_parents(self) -> list[ObjectID]:
         """Return a list of parents of this commit."""
         return self._parents
 
-    def _set_parents(self, value: list[bytes]) -> None:
+    def _set_parents(self, value: list[ObjectID]) -> None:
         """Set a list of parents of this commit."""
         self._needs_serialization = True
         self._parents = value

+ 27 - 18
dulwich/objectspec.py

@@ -24,8 +24,8 @@
 from collections.abc import Sequence
 from typing import TYPE_CHECKING
 
-from .objects import Commit, ShaFile, Tag, Tree
-from .refs import local_branch_name, local_tag_name
+from .objects import Commit, ObjectID, RawObjectID, ShaFile, Tag, Tree
+from .refs import Ref, local_branch_name, local_tag_name
 from .repo import BaseRepo
 
 if TYPE_CHECKING:
@@ -48,7 +48,9 @@ def to_bytes(text: str | bytes) -> bytes:
     return text
 
 
-def _resolve_object(repo: "Repo", ref: bytes) -> "ShaFile":
+def _resolve_object(
+    repo: "Repo", ref: Ref | ObjectID | RawObjectID | bytes
+) -> "ShaFile":
     """Resolve a reference to an object using multiple strategies."""
     try:
         return repo[ref]
@@ -58,7 +60,7 @@ def _resolve_object(repo: "Repo", ref: bytes) -> "ShaFile":
             return repo[ref_sha]
         except KeyError:
             try:
-                return repo.object_store[ref]
+                return repo.object_store[ref]  # type: ignore[index]
             except (KeyError, ValueError):
                 # Re-raise original KeyError for consistency
                 raise KeyError(ref)
@@ -262,12 +264,13 @@ def parse_tree(repo: "BaseRepo", treeish: bytes | str | Tree | Commit | Tag) ->
         treeish = treeish.id
     else:
         treeish = to_bytes(treeish)
+    treeish_typed: Ref | ObjectID
     try:
-        treeish = parse_ref(repo.refs, treeish)
+        treeish_typed = parse_ref(repo.refs, treeish)
     except KeyError:  # treeish is commit sha
-        pass
+        treeish_typed = ObjectID(treeish)
     try:
-        o = repo[treeish]
+        o = repo[treeish_typed]
     except KeyError:
         # Try parsing as commit (handles short hashes)
         try:
@@ -311,7 +314,7 @@ def parse_ref(container: "Repo | RefsContainer", refspec: str | bytes) -> "Ref":
     ]
     for ref in possible_refs:
         if ref in container:
-            return ref
+            return Ref(ref)
     raise KeyError(refspec)
 
 
@@ -336,25 +339,31 @@ def parse_reftuple(
     if refspec.startswith(b"+"):
         force = True
         refspec = refspec[1:]
-    lh: bytes | None
-    rh: bytes | None
+    lh_bytes: bytes | None
+    rh_bytes: bytes | None
     if b":" in refspec:
-        (lh, rh) = refspec.split(b":")
+        (lh_bytes, rh_bytes) = refspec.split(b":")
     else:
-        lh = rh = refspec
-    if lh == b"":
+        lh_bytes = rh_bytes = refspec
+
+    lh: Ref | None
+    if lh_bytes == b"":
         lh = None
     else:
-        lh = parse_ref(lh_container, lh)
-    if rh == b"":
+        lh = parse_ref(lh_container, lh_bytes)
+
+    rh: Ref | None
+    if rh_bytes == b"":
         rh = None
     else:
         try:
-            rh = parse_ref(rh_container, rh)
+            rh = parse_ref(rh_container, rh_bytes)
         except KeyError:
             # TODO: check force?
-            if b"/" not in rh:
-                rh = local_branch_name(rh)
+            if b"/" not in rh_bytes:
+                rh = Ref(local_branch_name(rh_bytes))
+            else:
+                rh = Ref(rh_bytes)
     return (lh, rh, force)
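
parse_reftuple splits a refspec into typed source and destination Refs plus a force flag. A sketch, assuming a repository that has the branch named below; the refspec itself is a placeholder:

    from dulwich.objectspec import parse_reftuple
    from dulwich.repo import Repo

    repo = Repo(".")
    lh, rh, force = parse_reftuple(
        repo.refs, repo.refs, b"+refs/heads/master:refs/heads/master"
    )
    # lh and rh are Ref values (None marks an empty side of the refspec);
    # force is True because of the leading "+".
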
 
 

+ 89 - 65
dulwich/pack.py

@@ -62,6 +62,7 @@ from typing import (
     Generic,
     Protocol,
     TypeVar,
+    cast,
 )
 
 try:
@@ -77,6 +78,7 @@ if TYPE_CHECKING:
     from .bitmap import PackBitmap
     from .commit_graph import CommitGraph
     from .object_store import BaseObjectStore
+    from .refs import Ref
 
 # For some reason the above try, except fails to set has_mmap = False for plan9
 if sys.platform == "Plan9":
@@ -86,7 +88,14 @@ from . import replace_me
 from .errors import ApplyDeltaError, ChecksumMismatch
 from .file import GitFile, _GitFile
 from .lru_cache import LRUSizeCache
-from .objects import ObjectID, ShaFile, hex_to_sha, object_header, sha_to_hex
+from .objects import (
+    ObjectID,
+    RawObjectID,
+    ShaFile,
+    hex_to_sha,
+    object_header,
+    sha_to_hex,
+)
 
 OFS_DELTA = 6
 REF_DELTA = 7
@@ -140,10 +149,10 @@ class ObjectContainer(Protocol):
         Returns: Optional Pack object of the objects written.
         """
 
-    def __contains__(self, sha1: bytes) -> bool:
+    def __contains__(self, sha1: "ObjectID") -> bool:
         """Check if a hex sha is present."""
 
-    def __getitem__(self, sha1: bytes) -> ShaFile:
+    def __getitem__(self, sha1: "ObjectID | RawObjectID") -> ShaFile:
         """Retrieve an object."""
 
     def get_commit_graph(self) -> "CommitGraph | None":
@@ -159,7 +168,7 @@ class PackedObjectContainer(ObjectContainer):
     """Container for objects packed in a pack file."""
 
     def get_unpacked_object(
-        self, sha1: bytes, *, include_comp: bool = False
+        self, sha1: "ObjectID | RawObjectID", *, include_comp: bool = False
     ) -> "UnpackedObject":
         """Get a raw unresolved object.
 
@@ -173,7 +182,7 @@ class PackedObjectContainer(ObjectContainer):
         raise NotImplementedError(self.get_unpacked_object)
 
     def iterobjects_subset(
-        self, shas: Iterable[bytes], *, allow_missing: bool = False
+        self, shas: Iterable["ObjectID"], *, allow_missing: bool = False
     ) -> Iterator[ShaFile]:
         """Iterate over a subset of objects.
 
@@ -188,7 +197,7 @@ class PackedObjectContainer(ObjectContainer):
 
     def iter_unpacked_subset(
         self,
-        shas: Iterable[bytes],
+        shas: Iterable["ObjectID | RawObjectID"],
         *,
         include_comp: bool = False,
         allow_missing: bool = False,
@@ -332,12 +341,12 @@ class UnpackedObject:
             self.obj_chunks = self.decomp_chunks
             self.delta_base = delta_base
 
-    def sha(self) -> bytes:
+    def sha(self) -> RawObjectID:
         """Return the binary SHA of this object."""
         if self._sha is None:
             assert self.obj_type_num is not None and self.obj_chunks is not None
             self._sha = obj_sha(self.obj_type_num, self.obj_chunks)
-        return self._sha
+        return RawObjectID(self._sha)
 
     def sha_file(self) -> ShaFile:
         """Return a ShaFile from this object."""
@@ -547,7 +556,7 @@ def bisect_find_sha(
     return None
 
 
-PackIndexEntry = tuple[bytes, int, int | None]
+PackIndexEntry = tuple[RawObjectID, int, int | None]
 
 
 class PackIndex:
@@ -581,9 +590,9 @@ class PackIndex:
         """Return the number of entries in this pack index."""
         raise NotImplementedError(self.__len__)
 
-    def __iter__(self) -> Iterator[bytes]:
+    def __iter__(self) -> Iterator[ObjectID]:
         """Iterate over the SHAs in this pack."""
-        return map(sha_to_hex, self._itersha())
+        return map(lambda sha: sha_to_hex(RawObjectID(sha)), self._itersha())
 
     def iterentries(self) -> Iterator[PackIndexEntry]:
         """Iterate over the entries in this pack index.
@@ -601,7 +610,7 @@ class PackIndex:
         raise NotImplementedError(self.get_pack_checksum)
 
     @replace_me(since="0.21.0", remove_in="0.23.0")
-    def object_index(self, sha: bytes) -> int:
+    def object_index(self, sha: ObjectID | RawObjectID) -> int:
         """Return the index for the given SHA.
 
         Args:
@@ -612,7 +621,7 @@ class PackIndex:
         """
         return self.object_offset(sha)
 
-    def object_offset(self, sha: bytes) -> int:
+    def object_offset(self, sha: ObjectID | RawObjectID) -> int:
         """Return the offset in to the corresponding packfile for the object.
 
         Given the name of an object it will return the offset that object
@@ -648,7 +657,7 @@ class PackIndex:
         """Yield all the SHA1's of the objects in the index, sorted."""
         raise NotImplementedError(self._itersha)
 
-    def iter_prefix(self, prefix: bytes) -> Iterator[bytes]:
+    def iter_prefix(self, prefix: bytes) -> Iterator[RawObjectID]:
         """Iterate over all SHA1s with the given prefix.
 
         Args:
@@ -658,7 +667,7 @@ class PackIndex:
         # Default implementation for PackIndex classes that don't override
         for sha, _, _ in self.iterentries():
             if sha.startswith(prefix):
-                yield sha
+                yield RawObjectID(sha)
 
     def close(self) -> None:
         """Close any open files."""
@@ -672,7 +681,7 @@ class MemoryPackIndex(PackIndex):
 
     def __init__(
         self,
-        entries: list[tuple[bytes, int, int | None]],
+        entries: list[PackIndexEntry],
         pack_checksum: bytes | None = None,
     ) -> None:
         """Create a new MemoryPackIndex.
@@ -697,7 +706,7 @@ class MemoryPackIndex(PackIndex):
         """Return the number of entries in this pack index."""
         return len(self._entries)
 
-    def object_offset(self, sha: bytes) -> int:
+    def object_offset(self, sha: ObjectID | RawObjectID) -> int:
         """Return the offset for the given SHA.
 
         Args:
@@ -705,8 +714,8 @@ class MemoryPackIndex(PackIndex):
         Returns: Offset in the pack file
         """
         if len(sha) == 40:
-            sha = hex_to_sha(sha)
-        return self._by_sha[sha]
+            sha = hex_to_sha(cast(ObjectID, sha))
+        return self._by_sha[cast(RawObjectID, sha)]
 
     def object_sha1(self, offset: int) -> bytes:
         """Return the SHA1 for the object at the given offset."""
@@ -880,7 +889,7 @@ class FilePackIndex(PackIndex):
         """
         return bytes(self._contents[-20:])
 
-    def object_offset(self, sha: bytes) -> int:
+    def object_offset(self, sha: ObjectID | RawObjectID) -> int:
         """Return the offset in to the corresponding packfile for the object.
 
         Given the name of an object it will return the offset that object
@@ -888,7 +897,7 @@ class FilePackIndex(PackIndex):
         have the object then None will be returned.
         """
         if len(sha) == 40:
-            sha = hex_to_sha(sha)
+            sha = hex_to_sha(cast(ObjectID, sha))
         try:
             return self._object_offset(sha)
         except ValueError as exc:
@@ -915,7 +924,7 @@ class FilePackIndex(PackIndex):
             raise KeyError(sha)
         return self._unpack_offset(i)
 
-    def iter_prefix(self, prefix: bytes) -> Iterator[bytes]:
+    def iter_prefix(self, prefix: bytes) -> Iterator[RawObjectID]:
         """Iterate over all SHA1s with the given prefix."""
         start = ord(prefix[:1])
         if start == 0:
@@ -932,7 +941,7 @@ class FilePackIndex(PackIndex):
         for i in range(start, end):
             name: bytes = self._unpack_name(i)
             if name.startswith(prefix):
-                yield name
+                yield RawObjectID(name)
                 started = True
             elif started:
                 break
@@ -960,9 +969,9 @@ class PackIndex1(FilePackIndex):
         self.version = 1
         self._fan_out_table = self._read_fan_out_table(0)
 
-    def _unpack_entry(self, i: int) -> tuple[bytes, int, None]:
+    def _unpack_entry(self, i: int) -> tuple[RawObjectID, int, None]:
         (offset, name) = unpack_from(">L20s", self._contents, (0x100 * 4) + (i * 24))
-        return (name, offset, None)
+        return (RawObjectID(name), offset, None)
 
     def _unpack_name(self, i: int) -> bytes:
         offset = (0x100 * 4) + (i * 24) + 4
@@ -1011,9 +1020,9 @@ class PackIndex2(FilePackIndex):
             self
         )
 
-    def _unpack_entry(self, i: int) -> tuple[bytes, int, int]:
+    def _unpack_entry(self, i: int) -> tuple[RawObjectID, int, int]:
         return (
-            self._unpack_name(i),
+            RawObjectID(self._unpack_name(i)),
             self._unpack_offset(i),
             self._unpack_crc32_checksum(i),
         )
@@ -1091,9 +1100,9 @@ class PackIndex3(FilePackIndex):
             self
         )
 
-    def _unpack_entry(self, i: int) -> tuple[bytes, int, int]:
+    def _unpack_entry(self, i: int) -> tuple[RawObjectID, int, int]:
         return (
-            self._unpack_name(i),
+            RawObjectID(self._unpack_name(i)),
             self._unpack_offset(i),
             self._unpack_crc32_checksum(i),
         )
@@ -1390,7 +1399,9 @@ class PackStreamReader:
 
         pack_sha = bytearray(self._trailer)
         if pack_sha != self.sha.digest():
-            raise ChecksumMismatch(sha_to_hex(bytes(pack_sha)), self.sha.hexdigest())
+            raise ChecksumMismatch(
+                sha_to_hex(RawObjectID(bytes(pack_sha))), self.sha.hexdigest()
+            )
 
 
 class PackStreamCopier(PackStreamReader):
@@ -1663,7 +1674,7 @@ class PackData:
         self,
         progress: Callable[[int, int], None] | None = None,
         resolve_ext_ref: ResolveExtRefFn | None = None,
-    ) -> Iterator[tuple[bytes, int, int | None]]:
+    ) -> Iterator[PackIndexEntry]:
         """Yield entries summarizing the contents of this pack.
 
         Args:
@@ -1683,7 +1694,7 @@ class PackData:
         self,
         progress: ProgressFn | None = None,
         resolve_ext_ref: ResolveExtRefFn | None = None,
-    ) -> list[tuple[bytes, int, int]]:
+    ) -> list[tuple[RawObjectID, int, int]]:
         """Return entries in this pack, sorted by SHA.
 
         Args:
@@ -1883,7 +1894,7 @@ class DeltaChainIterator(Generic[T]):
         self._pending_ofs: dict[int, list[int]] = defaultdict(list)
         self._pending_ref: dict[bytes, list[int]] = defaultdict(list)
         self._full_ofs: list[tuple[int, int]] = []
-        self._ext_refs: list[bytes] = []
+        self._ext_refs: list[RawObjectID] = []
 
     @classmethod
     def for_pack_data(
@@ -1908,7 +1919,7 @@ class DeltaChainIterator(Generic[T]):
     def for_pack_subset(
         cls,
         pack: "Pack",
-        shas: Iterable[bytes],
+        shas: Iterable[ObjectID | RawObjectID],
         *,
         allow_missing: bool = False,
         resolve_ext_ref: ResolveExtRefFn | None = None,
@@ -1928,7 +1939,6 @@ class DeltaChainIterator(Generic[T]):
         walker.set_pack_data(pack.data)
         todo = set()
         for sha in shas:
-            assert isinstance(sha, bytes)
             try:
                 off = pack.index.object_offset(sha)
             except KeyError:
@@ -1951,7 +1961,7 @@ class DeltaChainIterator(Generic[T]):
             elif unpacked.pack_type_num == REF_DELTA:
                 with suppress(KeyError):
                     assert isinstance(unpacked.delta_base, bytes)
-                    base_ofs = pack.index.object_index(unpacked.delta_base)
+                    base_ofs = pack.index.object_index(RawObjectID(unpacked.delta_base))
             if base_ofs is not None and base_ofs not in done:
                 todo.add(base_ofs)
         return walker
@@ -1992,7 +2002,9 @@ class DeltaChainIterator(Generic[T]):
 
     def _ensure_no_pending(self) -> None:
         if self._pending_ref:
-            raise UnresolvedDeltas([sha_to_hex(s) for s in self._pending_ref])
+            raise UnresolvedDeltas(
+                [sha_to_hex(RawObjectID(s)) for s in self._pending_ref]
+            )
 
     def _walk_ref_chains(self) -> Iterator[T]:
         if not self._resolve_ext_ref:
@@ -2009,7 +2021,7 @@ class DeltaChainIterator(Generic[T]):
                 # get popped via a _follow_chain call, or we will raise an
                 # error below.
                 continue
-            self._ext_refs.append(base_sha)
+            self._ext_refs.append(RawObjectID(base_sha))
             self._pending_ref.pop(base_sha)
             for new_offset in pending:
                 yield from self._follow_chain(new_offset, type_num, chunks)  # type: ignore[arg-type]
@@ -2063,7 +2075,7 @@ class DeltaChainIterator(Generic[T]):
         """Iterate over objects in the pack."""
         return self._walk_all_chains()
 
-    def ext_refs(self) -> list[bytes]:
+    def ext_refs(self) -> list[RawObjectID]:
         """Return external references."""
         return self._ext_refs
 
@@ -2088,7 +2100,7 @@ class PackIndexer(DeltaChainIterator[PackIndexEntry]):
 
     _compute_crc32 = True
 
-    def _result(self, unpacked: UnpackedObject) -> tuple[bytes, int, int | None]:
+    def _result(self, unpacked: UnpackedObject) -> PackIndexEntry:
         """Convert unpacked object to pack index entry.
 
         Args:
@@ -2154,9 +2166,12 @@ class SHA1Reader(BinaryIO):
         # If git option index.skipHash is set the index will be empty
         if stored != self.sha1.digest() and (
             not allow_empty
-            or sha_to_hex(stored) != b"0000000000000000000000000000000000000000"
+            or sha_to_hex(RawObjectID(stored))
+            != b"0000000000000000000000000000000000000000"
         ):
-            raise ChecksumMismatch(self.sha1.hexdigest(), sha_to_hex(stored))
+            raise ChecksumMismatch(
+                self.sha1.hexdigest(), sha_to_hex(RawObjectID(stored))
+            )
 
     def close(self) -> None:
         """Close the underlying file."""
@@ -2595,9 +2610,9 @@ def write_pack_header(
 
 def find_reusable_deltas(
     container: PackedObjectContainer,
-    object_ids: Set[bytes],
+    object_ids: Set[ObjectID],
     *,
-    other_haves: Set[bytes] | None = None,
+    other_haves: Set[ObjectID] | None = None,
     progress: Callable[..., None] | None = None,
 ) -> Iterator[UnpackedObject]:
     """Find deltas in a pack that can be reused.
@@ -2799,7 +2814,7 @@ def generate_unpacked_objects(
     deltify: bool | None = None,
     reuse_deltas: bool = True,
     ofs_delta: bool = True,
-    other_haves: set[bytes] | None = None,
+    other_haves: set[ObjectID] | None = None,
     progress: Callable[..., None] | None = None,
 ) -> Iterator[UnpackedObject]:
     """Create pack data from objects.
@@ -2811,7 +2826,7 @@ def generate_unpacked_objects(
         for unpack in find_reusable_deltas(
             container, set(todo), other_haves=other_haves, progress=progress
         ):
-            del todo[sha_to_hex(unpack.sha())]
+            del todo[sha_to_hex(RawObjectID(unpack.sha()))]
             yield unpack
     if deltify is None:
         # PERFORMANCE/TODO(jelmer): This should be enabled but is *much* too
@@ -2860,7 +2875,7 @@ def write_pack_from_container(
     deltify: bool | None = None,
     reuse_deltas: bool = True,
     compression_level: int = -1,
-    other_haves: set[bytes] | None = None,
+    other_haves: set[ObjectID] | None = None,
 ) -> tuple[dict[bytes, tuple[int, int]], bytes]:
     """Write a new pack data file.
 
@@ -3535,7 +3550,7 @@ class Pack:
     def ensure_bitmap(
         self,
         object_store: "BaseObjectStore",
-        refs: dict[bytes, bytes],
+        refs: dict["Ref", "ObjectID"],
         commit_interval: int | None = None,
         progress: Callable[[str], None] | None = None,
     ) -> "PackBitmap":
@@ -3618,7 +3633,7 @@ class Pack:
         """Return string representation of this pack."""
         return f"{self.__class__.__name__}({self._basename!r})"
 
-    def __iter__(self) -> Iterator[bytes]:
+    def __iter__(self) -> Iterator[ObjectID]:
         """Iterate over all the sha1s of the objects in this pack."""
         return iter(self.index)
 
@@ -3634,8 +3649,8 @@ class Pack:
             and idx_stored_checksum != data_stored_checksum
         ):
             raise ChecksumMismatch(
-                sha_to_hex(idx_stored_checksum),
-                sha_to_hex(data_stored_checksum),
+                sha_to_hex(RawObjectID(idx_stored_checksum)),
+                sha_to_hex(RawObjectID(data_stored_checksum)),
             )
 
     def check(self) -> None:
@@ -3658,7 +3673,7 @@ class Pack:
         """Return pack tuples for all objects in pack."""
         return [(o, None) for o in self.iterobjects()]
 
-    def __contains__(self, sha1: bytes) -> bool:
+    def __contains__(self, sha1: ObjectID | RawObjectID) -> bool:
         """Check whether this pack contains a particular SHA1."""
         try:
             self.index.object_offset(sha1)
@@ -3666,14 +3681,14 @@ class Pack:
         except KeyError:
             return False
 
-    def get_raw(self, sha1: bytes) -> tuple[int, bytes]:
+    def get_raw(self, sha1: RawObjectID | ObjectID) -> tuple[int, bytes]:
         """Get raw object data by SHA1."""
         offset = self.index.object_offset(sha1)
         obj_type, obj = self.data.get_object_at(offset)
         type_num, chunks = self.resolve_object(offset, obj_type, obj)
         return type_num, b"".join(chunks)  # type: ignore[arg-type]
 
-    def __getitem__(self, sha1: bytes) -> ShaFile:
+    def __getitem__(self, sha1: "ObjectID | RawObjectID") -> ShaFile:
         """Retrieve the specified SHA1."""
         type, uncomp = self.get_raw(sha1)
         return ShaFile.from_raw_string(type, uncomp, sha=sha1)
@@ -3701,7 +3716,7 @@ class Pack:
 
     def iter_unpacked_subset(
         self,
-        shas: Iterable[ObjectID],
+        shas: Iterable[ObjectID | RawObjectID],
         *,
         include_comp: bool = False,
         allow_missing: bool = False,
@@ -3710,12 +3725,12 @@ class Pack:
         """Iterate over unpacked objects in subset."""
         ofs_pending: dict[int, list[UnpackedObject]] = defaultdict(list)
         ofs: dict[int, bytes] = {}
-        todo = set(shas)
+        todo: set[ObjectID | RawObjectID] = set(shas)
         for unpacked in self.iter_unpacked(include_comp=include_comp):
             sha = unpacked.sha()
             if unpacked.offset is not None:
                 ofs[unpacked.offset] = sha
-            hexsha = sha_to_hex(sha)
+            hexsha = sha_to_hex(RawObjectID(sha))
             if hexsha in todo:
                 if unpacked.pack_type_num == OFS_DELTA:
                     assert isinstance(unpacked.delta_base, int)
@@ -3766,7 +3781,9 @@ class Pack:
                 keepfile.write(b"\n")
         return keepfile_name
 
-    def get_ref(self, sha: bytes) -> tuple[int | None, int, OldUnpackedObject]:
+    def get_ref(
+        self, sha: RawObjectID | ObjectID
+    ) -> tuple[int | None, int, OldUnpackedObject]:
         """Get the object for a ref SHA, only looking in this pack."""
         # TODO: cache these results
         try:
@@ -3786,7 +3803,9 @@ class Pack:
         offset: int,
         type: int,
         obj: OldUnpackedObject,
-        get_ref: Callable[[bytes], tuple[int | None, int, OldUnpackedObject]]
+        get_ref: Callable[
+            [RawObjectID | ObjectID], tuple[int | None, int, OldUnpackedObject]
+        ]
         | None = None,
     ) -> tuple[int, OldUnpackedObject]:
         """Resolve an object, possibly resolving deltas when necessary.
@@ -3795,7 +3814,7 @@ class Pack:
         """
         # Walk down the delta chain, building a stack of deltas to reach
         # the requested object.
-        base_offset = offset
+        base_offset: int | None = offset
         base_type = type
         base_obj = obj
         delta_stack = []
@@ -3809,13 +3828,14 @@ class Pack:
                 assert isinstance(delta_offset, int), (
                     f"Expected int, got {delta_offset.__class__}"
                 )
+                assert base_offset is not None
                 base_offset = base_offset - delta_offset
                 base_type, base_obj = self.data.get_object_at(base_offset)
                 assert isinstance(base_type, int)
             elif base_type == REF_DELTA:
                 (basename, delta) = base_obj
                 assert isinstance(basename, bytes) and len(basename) == 20
-                base_offset, base_type, base_obj = get_ref(basename)  # type: ignore[assignment]
+                base_offset, base_type, base_obj = get_ref(cast(RawObjectID, basename))
                 assert isinstance(base_type, int)
                 if base_offset == prev_offset:  # object is based on itself
                     raise UnresolvedDeltas([basename])
@@ -3876,7 +3896,11 @@ class Pack:
         )
 
     def get_unpacked_object(
-        self, sha: bytes, *, include_comp: bool = False, convert_ofs_delta: bool = True
+        self,
+        sha: ObjectID | RawObjectID,
+        *,
+        include_comp: bool = False,
+        convert_ofs_delta: bool = True,
     ) -> UnpackedObject:
         """Get the unpacked object for a sha.
 
@@ -3896,12 +3920,12 @@ class Pack:
 
 def extend_pack(
     f: BinaryIO,
-    object_ids: Set[ObjectID],
-    get_raw: Callable[[ObjectID], tuple[int, bytes]],
+    object_ids: Set["RawObjectID"],
+    get_raw: Callable[["RawObjectID | ObjectID"], tuple[int, bytes]],
     *,
     compression_level: int = -1,
     progress: Callable[[bytes], None] | None = None,
-) -> tuple[bytes, list[tuple[bytes, int, int]]]:
+) -> tuple[bytes, list[tuple["RawObjectID", int, int]]]:
     """Extend a pack file with more objects.
 
     The caller should make sure that object_ids does not contain any objects

+ 9 - 7
dulwich/patch.py

@@ -43,7 +43,7 @@ from typing import (
 if TYPE_CHECKING:
     from .object_store import BaseObjectStore
 
-from .objects import S_ISGITLINK, Blob, Commit
+from .objects import S_ISGITLINK, Blob, Commit, ObjectID, RawObjectID
 
 FIRST_FEW_BYTES = 8000
 
@@ -361,8 +361,8 @@ def patch_filename(p: bytes | None, root: bytes) -> bytes:
 def write_object_diff(
     f: IO[bytes],
     store: "BaseObjectStore",
-    old_file: tuple[bytes | None, int | None, bytes | None],
-    new_file: tuple[bytes | None, int | None, bytes | None],
+    old_file: tuple[bytes | None, int | None, ObjectID | None],
+    new_file: tuple[bytes | None, int | None, ObjectID | None],
     diff_binary: bool = False,
     diff_algorithm: str | None = None,
 ) -> None:
@@ -384,7 +384,7 @@ def write_object_diff(
     patched_old_path = patch_filename(old_path, b"a")
     patched_new_path = patch_filename(new_path, b"b")
 
-    def content(mode: int | None, hexsha: bytes | None) -> Blob:
+    def content(mode: int | None, hexsha: ObjectID | None) -> Blob:
         """Get blob content for a file.
 
         Args:
@@ -542,8 +542,8 @@ def write_blob_diff(
 def write_tree_diff(
     f: IO[bytes],
     store: "BaseObjectStore",
-    old_tree: bytes | None,
-    new_tree: bytes | None,
+    old_tree: ObjectID | None,
+    new_tree: ObjectID | None,
     diff_binary: bool = False,
     diff_algorithm: str | None = None,
 ) -> None:
@@ -731,7 +731,9 @@ def patch_id(diff_data: bytes) -> bytes:
     return hashlib.sha1(normalized).hexdigest().encode("ascii")
 
 
-def commit_patch_id(store: "BaseObjectStore", commit_id: bytes) -> bytes:
+def commit_patch_id(
+    store: "BaseObjectStore", commit_id: ObjectID | RawObjectID
+) -> bytes:
     """Compute patch ID for a commit.
 
     Args:

+ 137 - 127
dulwich/porcelain.py

@@ -169,6 +169,7 @@ from .object_store import BaseObjectStore, tree_lookup_path
 from .objects import (
     Blob,
     Commit,
+    ObjectID,
     Tag,
     Tree,
     TreeEntry,
@@ -193,6 +194,7 @@ from .patch import (
 )
 from .protocol import ZERO_SHA, Protocol
 from .refs import (
+    HEADREF,
     LOCAL_BRANCH_PREFIX,
     LOCAL_NOTES_PREFIX,
     LOCAL_REMOTE_PREFIX,
@@ -543,7 +545,7 @@ class DivergedBranches(Error):
         self.new_sha = new_sha
 
 
-def check_diverged(repo: BaseRepo, current_sha: bytes, new_sha: bytes) -> None:
+def check_diverged(repo: BaseRepo, current_sha: ObjectID, new_sha: ObjectID) -> None:
     """Check if updating to a sha can be done with fast forwarding.
 
     Args:
@@ -625,7 +627,7 @@ def symbolic_ref(repo: RepoPath, ref_name: str | bytes, force: bool = False) ->
                 else ref_name
             )
             raise Error(f"fatal: ref `{ref_name_str}` is not a ref")
-        repo_obj.refs.set_symbolic_ref(b"HEAD", ref_path)
+        repo_obj.refs.set_symbolic_ref(HEADREF, ref_path)
 
 
 def pack_refs(repo: RepoPath, all: bool = False) -> None:
@@ -878,7 +880,7 @@ def commit(
             )
             # Update HEAD to point to the new commit with reflog message
             try:
-                old_head = r.refs[b"HEAD"]
+                old_head = r.refs[HEADREF]
             except KeyError:
                 old_head = None
 
@@ -892,7 +894,7 @@ def commit(
                 default_message = default_message[:97] + b"..."
             reflog_message = _get_reflog_message(default_message)
 
-            r.refs.set_if_equals(b"HEAD", old_head, commit_sha, message=reflog_message)
+            r.refs.set_if_equals(HEADREF, old_head, commit_sha, message=reflog_message)
             return commit_sha
         else:
             return r.get_worktree().commit(
@@ -911,11 +913,11 @@ def commit(
 
 def commit_tree(
     repo: RepoPath,
-    tree: bytes,
+    tree: ObjectID,
     message: str | bytes | None = None,
     author: bytes | None = None,
     committer: bytes | None = None,
-) -> bytes:
+) -> ObjectID:
     """Create a new commit object.
 
     Args:
@@ -2002,18 +2004,18 @@ def diff_tree(
     """
     with open_repo_closing(repo) as r:
         if isinstance(old_tree, Tree):
-            old_tree_id: bytes | None = old_tree.id
+            old_tree_id: ObjectID | None = old_tree.id
         elif isinstance(old_tree, str):
-            old_tree_id = old_tree.encode()
+            old_tree_id = ObjectID(old_tree.encode())
         else:
-            old_tree_id = old_tree
+            old_tree_id = ObjectID(old_tree)
 
         if isinstance(new_tree, Tree):
-            new_tree_id: bytes | None = new_tree.id
+            new_tree_id: ObjectID | None = new_tree.id
         elif isinstance(new_tree, str):
-            new_tree_id = new_tree.encode()
+            new_tree_id = ObjectID(new_tree.encode())
         else:
-            new_tree_id = new_tree
+            new_tree_id = ObjectID(new_tree)
 
         write_tree_diff(outstream, r.object_store, old_tree_id, new_tree_id)
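
With the tree-ish arguments normalized to ObjectID, diff_tree still accepts Tree objects, str, or bytes SHAs. A sketch that diffs the trees of the last two commits, assuming HEAD has at least one parent and that outstream takes a binary stream as above:

    import sys

    from dulwich import porcelain
    from dulwich.repo import Repo

    repo = Repo(".")
    head = repo[repo.head()]
    parent = repo[head.parents[0]]
    porcelain.diff_tree(repo, parent.tree, head.tree, outstream=sys.stdout.buffer)
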
 
@@ -2323,7 +2325,7 @@ def submodule_update(
                     sub_config.write_to_path()
 
                     # Checkout the target commit
-                    sub_repo.refs[b"HEAD"] = target_sha
+                    sub_repo.refs[HEADREF] = target_sha
 
                     # Build the index and checkout files
                     tree = sub_repo[target_sha]
@@ -2346,7 +2348,7 @@ def submodule_update(
                     client.fetch(path_segments.encode(), sub_repo)
 
                     # Update to the target commit
-                    sub_repo.refs[b"HEAD"] = target_sha
+                    sub_repo.refs[HEADREF] = target_sha
 
                     # Reset the working directory
                     reset(sub_repo, "hard", target_sha)
@@ -2513,7 +2515,7 @@ def verify_tag(
         tag_obj.verify(keyids)
 
 
-def tag_list(repo: RepoPath, outstream: TextIO = sys.stdout) -> list[bytes]:
+def tag_list(repo: RepoPath, outstream: TextIO = sys.stdout) -> list[Ref]:
     """List all tags.
 
     Args:
@@ -2521,7 +2523,7 @@ def tag_list(repo: RepoPath, outstream: TextIO = sys.stdout) -> list[bytes]:
       outstream: Stream to write tags to
     """
     with open_repo_closing(repo) as r:
-        tags = sorted(r.refs.as_dict(b"refs/tags"))
+        tags: list[Ref] = sorted(r.refs.as_dict(Ref(b"refs/tags")))
         return tags
 
 
@@ -2666,7 +2668,7 @@ def notes_show(
         return r.notes.get_note(object_sha, notes_ref, config=config)
 
 
-def notes_list(repo: RepoPath, ref: bytes = b"commits") -> list[tuple[bytes, bytes]]:
+def notes_list(repo: RepoPath, ref: bytes = b"commits") -> list[tuple[ObjectID, bytes]]:
     """List all notes in a notes ref.
 
     Args:
@@ -2686,7 +2688,7 @@ def notes_list(repo: RepoPath, ref: bytes = b"commits") -> list[tuple[bytes, byt
         return r.notes.list_notes(notes_ref, config=config)
 
 
-def replace_list(repo: RepoPath) -> list[tuple[bytes, bytes]]:
+def replace_list(repo: RepoPath) -> list[tuple[ObjectID, ObjectID]]:
     """List all replacement refs.
 
     Args:
@@ -2697,16 +2699,16 @@ def replace_list(repo: RepoPath) -> list[tuple[bytes, bytes]]:
       object being replaced and replacement_sha is what it's replaced with
     """
     with open_repo_closing(repo) as r:
-        replacements = []
+        replacements: list[tuple[ObjectID, ObjectID]] = []
         for ref in r.refs.keys():
             if ref.startswith(LOCAL_REPLACE_PREFIX):
-                object_sha = ref[len(LOCAL_REPLACE_PREFIX) :]
+                object_sha = ObjectID(ref[len(LOCAL_REPLACE_PREFIX) :])
                 replacement_sha = r.refs[ref]
                 replacements.append((object_sha, replacement_sha))
         return replacements
 
 
-def replace_delete(repo: RepoPath, object_sha: bytes | str) -> None:
+def replace_delete(repo: RepoPath, object_sha: ObjectID | str) -> None:
     """Delete a replacement ref.
 
     Args:
@@ -2714,24 +2716,24 @@ def replace_delete(repo: RepoPath, object_sha: bytes | str) -> None:
       object_sha: SHA of the object whose replacement should be removed
     """
     with open_repo_closing(repo) as r:
-        # Convert to bytes if string
+        # Convert to ObjectID if string
         if isinstance(object_sha, str):
-            object_sha_hex = object_sha.encode("ascii")
+            object_sha_id = ObjectID(object_sha.encode("ascii"))
         else:
-            object_sha_hex = object_sha
+            object_sha_id = object_sha
 
-        replace_ref = _make_replace_ref(object_sha_hex)
+        replace_ref = _make_replace_ref(object_sha_id)
         if replace_ref not in r.refs:
             raise KeyError(
-                f"No replacement ref found for {object_sha_hex.decode('ascii')}"
+                f"No replacement ref found for {object_sha_id.decode('ascii')}"
             )
         del r.refs[replace_ref]
 
 
 def replace_create(
     repo: RepoPath,
-    object_sha: str | bytes,
-    replacement_sha: str | bytes,
+    object_sha: str | ObjectID,
+    replacement_sha: str | ObjectID,
 ) -> None:
     """Create a replacement ref to replace one object with another.
 
@@ -2741,20 +2743,20 @@ def replace_create(
       replacement_sha: SHA of the replacement object
     """
     with open_repo_closing(repo) as r:
-        # Convert to bytes if string
+        # Convert to ObjectID if string
         if isinstance(object_sha, str):
-            object_sha_hex = object_sha.encode("ascii")
+            object_sha_id = ObjectID(object_sha.encode("ascii"))
         else:
-            object_sha_hex = object_sha
+            object_sha_id = object_sha
 
         if isinstance(replacement_sha, str):
-            replacement_sha_hex = replacement_sha.encode("ascii")
+            replacement_sha_id = ObjectID(replacement_sha.encode("ascii"))
         else:
-            replacement_sha_hex = replacement_sha
+            replacement_sha_id = replacement_sha
 
         # Create the replacement ref
-        replace_ref = _make_replace_ref(object_sha_hex)
-        r.refs[replace_ref] = replacement_sha_hex
+        replace_ref = _make_replace_ref(object_sha_id)
+        r.refs[replace_ref] = replacement_sha_id
 
 
 def reset(
@@ -2783,7 +2785,7 @@ def reset(
         if target_commit is not None:
             # Get the current HEAD value for set_if_equals
             try:
-                old_head = r.refs[b"HEAD"]
+                old_head = r.refs[HEADREF]
             except KeyError:
                 old_head = None
 
@@ -2800,7 +2802,7 @@ def reset(
 
             # Update HEAD with reflog message
             r.refs.set_if_equals(
-                b"HEAD", old_head, target_commit.id, message=reflog_message
+                HEADREF, old_head, target_commit.id, message=reflog_message
             )
 
         if mode == "soft":
@@ -2978,14 +2980,14 @@ def push(
         )
 
         selected_refs = []
-        remote_changed_refs: dict[bytes, bytes | None] = {}
+        remote_changed_refs: dict[Ref, ObjectID | None] = {}
 
-        def update_refs(refs: dict[bytes, bytes]) -> dict[bytes, bytes]:
-            remote_refs = DictRefsContainer(refs)
+        def update_refs(refs: dict[Ref, ObjectID]) -> dict[Ref, ObjectID]:
+            remote_refs = DictRefsContainer(refs)  # type: ignore[arg-type]
             selected_refs.extend(
                 parse_reftuples(r.refs, remote_refs, refspecs_bytes, force=force)
             )
-            new_refs = {}
+            new_refs: dict[Ref, ObjectID] = {}
 
             # In mirror mode, delete remote refs that don't exist locally
             if mirror_mode:
@@ -3019,8 +3021,8 @@ def push(
         try:
 
             def generate_pack_data_wrapper(
-                have: AbstractSet[bytes],
-                want: AbstractSet[bytes],
+                have: AbstractSet[ObjectID],
+                want: AbstractSet[ObjectID],
                 *,
                 ofs_delta: bool = False,
                 progress: Callable[..., None] | None = None,
@@ -3046,7 +3048,7 @@ def push(
                 b"Push to " + remote_location.encode(err_encoding) + b" successful.\n"
             )
 
-        for ref, error in (result.ref_status or {}).items():
+        for ref, error in (result.ref_status or {}).items():  # type: ignore[assignment]
             if error is not None:
                 errstream.write(
                     f"Push of ref {ref.decode('utf-8', 'replace')} failed: {error}\n".encode(
@@ -3125,9 +3127,9 @@ def pull(
                     refspecs_normalized.append(spec)
 
         def determine_wants(
-            remote_refs: dict[bytes, bytes], depth: int | None = None
-        ) -> list[bytes]:
-            remote_refs_container = DictRefsContainer(remote_refs)
+            remote_refs: dict[Ref, ObjectID], depth: int | None = None
+        ) -> list[ObjectID]:
+            remote_refs_container = DictRefsContainer(remote_refs)  # type: ignore[arg-type]
             selected_refs.extend(
                 parse_reftuples(
                     remote_refs_container, r.refs, refspecs_normalized, force=force
@@ -3166,7 +3168,7 @@ def pull(
 
         # Store the old HEAD tree before making changes
         try:
-            old_head = r.refs[b"HEAD"]
+            old_head = r.refs[HEADREF]
             old_commit = r[old_head]
             assert isinstance(old_commit, Commit)
             old_tree_id = old_commit.tree
@@ -3202,7 +3204,7 @@ def pull(
             if rh is not None and lh is not None:
                 lh_value = fetch_result.refs[lh]
                 if lh_value is not None:
-                    r.refs[rh] = lh_value
+                    r.refs[Ref(rh)] = lh_value
 
         # Only update HEAD if we didn't perform a merge
         if selected_refs and not merged:
@@ -3833,7 +3835,7 @@ def _make_tag_ref(name: str | bytes) -> Ref:
     return local_tag_name(name)
 
 
-def _make_replace_ref(name: str | bytes) -> Ref:
+def _make_replace_ref(name: str | bytes | ObjectID) -> Ref:
     if isinstance(name, str):
         name = name.encode(DEFAULT_ENCODING)
     return local_replace_name(name)
@@ -3881,7 +3883,7 @@ def branch_create(
             else objectish
         )
 
-        if b"refs/remotes/" + objectish_bytes in r.refs:
+        if Ref(b"refs/remotes/" + objectish_bytes) in r.refs:
             objectish = b"refs/remotes/" + objectish_bytes
         elif local_branch_name(objectish_bytes) in r.refs:
             objectish = local_branch_name(objectish_bytes)
@@ -3918,15 +3920,15 @@ def branch_create(
                 else original_objectish
             )
 
-            if objectish_bytes in r.refs:
+            if Ref(objectish_bytes) in r.refs:
                 objectish_ref = objectish_bytes
-            elif b"refs/remotes/" + objectish_bytes in r.refs:
+            elif Ref(b"refs/remotes/" + objectish_bytes) in r.refs:
                 objectish_ref = b"refs/remotes/" + objectish_bytes
             elif local_branch_name(objectish_bytes) in r.refs:
                 objectish_ref = local_branch_name(objectish_bytes)
         else:
             # HEAD might point to a remote-tracking branch
-            head_ref = r.refs.follow(b"HEAD")[0][1]
+            head_ref = r.refs.follow(HEADREF)[0][1]
             if head_ref.startswith(b"refs/remotes/"):
                 objectish_ref = head_ref
 
@@ -3974,7 +3976,7 @@ def filter_branches_by_pattern(branches: Iterable[bytes], pattern: str) -> list[
     ]
 
 
-def branch_list(repo: RepoPath) -> list[bytes]:
+def branch_list(repo: RepoPath) -> list[Ref]:
     """List all branches.
 
     Args:
@@ -3983,7 +3985,7 @@ def branch_list(repo: RepoPath) -> list[bytes]:
       List of branch names (without refs/heads/ prefix)
     """
     with open_repo_closing(repo) as r:
-        branches = list(r.refs.keys(base=LOCAL_BRANCH_PREFIX))
+        branches: list[Ref] = list(r.refs.keys(base=Ref(LOCAL_BRANCH_PREFIX)))
 
         # Check for branch.sort configuration
         config = r.get_config_stack()
@@ -4040,7 +4042,7 @@ def branch_remotes_list(repo: RepoPath) -> list[bytes]:
       List of branch names (without refs/remotes/ prefix, and without remote name; e.g. 'main' from 'origin/main')
     """
     with open_repo_closing(repo) as r:
-        branches = list(r.refs.keys(base=LOCAL_REMOTE_PREFIX))
+        branches = [bytes(ref) for ref in r.refs.keys(base=Ref(LOCAL_REMOTE_PREFIX))]
 
         config = r.get_config_stack()
         try:
@@ -4063,7 +4065,7 @@ def branch_remotes_list(repo: RepoPath) -> list[bytes]:
             # Sort by date
             def get_commit_date(branch_name: bytes) -> int:
                 ref = LOCAL_REMOTE_PREFIX + branch_name
-                sha = r.refs[ref]
+                sha = r.refs[Ref(ref)]
                 commit = r.object_store[sha]
                 assert isinstance(commit, Commit)
                 if sort_key == "committerdate":
@@ -4100,9 +4102,9 @@ def _get_branch_merge_status(repo: RepoPath) -> Iterator[tuple[bytes, bool]]:
         - ``is_merged``: True if branch is merged into HEAD, False otherwise
     """
     with open_repo_closing(repo) as r:
-        current_sha = r.refs[b"HEAD"]
+        current_sha = r.refs[HEADREF]
 
-        for branch_ref, branch_sha in r.refs.as_dict(base=b"refs/heads/").items():
+        for branch_ref, branch_sha in r.refs.as_dict(base=Ref(b"refs/heads/")).items():
             # Check if branch is an ancestor of HEAD (fully merged)
             is_merged = can_fast_forward(r, branch_sha, current_sha)
             yield branch_ref, is_merged
@@ -4154,7 +4156,9 @@ def branches_containing(repo: RepoPath, commit: str) -> Iterator[bytes]:
         commit_obj = parse_commit(r, commit)
         commit_sha = commit_obj.id
 
-        for branch_ref, branch_sha in r.refs.as_dict(base=LOCAL_BRANCH_PREFIX).items():
+        for branch_ref, branch_sha in r.refs.as_dict(
+            base=Ref(LOCAL_BRANCH_PREFIX)
+        ).items():
             if can_fast_forward(r, commit_sha, branch_sha):
                 yield branch_ref
 
@@ -4171,7 +4175,7 @@ def active_branch(repo: RepoPath) -> bytes:
       IndexError: if HEAD is floating
     """
     with open_repo_closing(repo) as r:
-        active_ref = r.refs.follow(b"HEAD")[0][1]
+        active_ref = r.refs.follow(HEADREF)[0][1]
         if not active_ref.startswith(LOCAL_BRANCH_PREFIX):
             raise ValueError(active_ref)
         return active_ref[len(LOCAL_BRANCH_PREFIX) :]
@@ -4352,7 +4356,7 @@ def for_each_ref(
         refs = r.get_refs()
 
     if pattern:
-        matching_refs: dict[bytes, bytes] = {}
+        matching_refs: dict[Ref, ObjectID] = {}
         pattern_parts = pattern.split(b"/")
         for ref, sha in refs.items():
             matches = False
@@ -4429,12 +4433,12 @@ def show_ref(
             filtered_refs = filter_ref_prefix(refs, [b"refs/"])
 
         # Add HEAD if requested
-        if head and b"HEAD" in refs:
-            filtered_refs[b"HEAD"] = refs[b"HEAD"]
+        if head and HEADREF in refs:
+            filtered_refs[HEADREF] = refs[HEADREF]
 
         # Filter by patterns if specified
         if byte_patterns:
-            matching_refs: dict[bytes, bytes] = {}
+            matching_refs: dict[Ref, ObjectID] = {}
             for ref, sha in filtered_refs.items():
                 for pattern in byte_patterns:
                     if verify:
@@ -4527,7 +4531,7 @@ def show_branch(
         refs = r.get_refs()
 
         # Determine which branches to show
-        branch_refs: dict[bytes, bytes] = {}
+        branch_refs: dict[Ref, ObjectID] = {}
 
         if branches:
             # Specific branches requested
@@ -4536,18 +4540,19 @@ def show_branch(
                     os.fsencode(branch) if isinstance(branch, str) else branch
                 )
                 # Try as full ref name first
-                if branch_bytes in refs:
-                    branch_refs[branch_bytes] = refs[branch_bytes]
+                branch_ref_check = Ref(branch_bytes)
+                if branch_ref_check in refs:
+                    branch_refs[branch_ref_check] = refs[branch_ref_check]
                 else:
                     # Try as branch name
                     branch_ref = local_branch_name(branch_bytes)
                     if branch_ref in refs:
                         branch_refs[branch_ref] = refs[branch_ref]
                     # Try as remote branch
-                    elif LOCAL_REMOTE_PREFIX + branch_bytes in refs:
-                        branch_refs[LOCAL_REMOTE_PREFIX + branch_bytes] = refs[
-                            LOCAL_REMOTE_PREFIX + branch_bytes
-                        ]
+                    else:
+                        remote_ref = Ref(LOCAL_REMOTE_PREFIX + branch_bytes)
+                        if remote_ref in refs:
+                            branch_refs[remote_ref] = refs[remote_ref]
         else:
             # Default behavior: show local branches
             if all_branches:
@@ -4565,7 +4570,7 @@ def show_branch(
         # Add current branch if requested and not already included
         if current:
             try:
-                head_refs, _ = r.refs.follow(b"HEAD")
+                head_refs, _ = r.refs.follow(HEADREF)
                 if head_refs:
                     head_ref = head_refs[0]
                     if head_ref not in branch_refs and head_ref in refs:
@@ -4579,7 +4584,7 @@ def show_branch(
 
         # Sort branches for consistent output
         sorted_branches = sorted(branch_refs.items(), key=lambda x: x[0])
-        branch_sha_list = [sha for _, sha in sorted_branches]
+        branch_sha_list: list[ObjectID] = [sha for _, sha in sorted_branches]
 
         # Handle --independent flag
         if independent_branches:
@@ -4604,7 +4609,7 @@ def show_branch(
         # Get current branch for marking
         current_branch: bytes | None = None
         try:
-            head_refs, _ = r.refs.follow(b"HEAD")
+            head_refs, _ = r.refs.follow(HEADREF)
             if head_refs:
                 current_branch = head_refs[0]
         except (KeyError, TypeError):
@@ -4837,7 +4842,7 @@ def repack(repo: RepoPath, write_bitmaps: bool = False) -> None:
 
 def pack_objects(
     repo: RepoPath,
-    object_ids: Sequence[bytes],
+    object_ids: Sequence[ObjectID],
     packf: BinaryIO,
     idxf: BinaryIO | None,
     delta_window_size: int | None = None,
@@ -4889,7 +4894,7 @@ def ls_tree(
       name_only: Only print item name
     """
 
-    def list_tree(store: BaseObjectStore, treeid: bytes, base: bytes) -> None:
+    def list_tree(store: BaseObjectStore, treeid: ObjectID, base: bytes) -> None:
         tree = store[treeid]
         assert isinstance(tree, Tree)
         for name, mode, sha in tree.iteritems():
@@ -5057,7 +5062,7 @@ def check_ignore(
                 yield _quote_path(output_path) if quote_path else output_path
 
 
-def _get_current_head_tree(repo: Repo) -> bytes | None:
+def _get_current_head_tree(repo: Repo) -> ObjectID | None:
     """Get the current HEAD tree ID.
 
     Args:
@@ -5067,10 +5072,10 @@ def _get_current_head_tree(repo: Repo) -> bytes | None:
       Tree ID of current HEAD, or None if no HEAD exists (empty repo)
     """
     try:
-        current_head = repo.refs[b"HEAD"]
+        current_head = repo.refs[HEADREF]
         current_commit = repo[current_head]
         assert isinstance(current_commit, Commit), "Expected a Commit object"
-        tree_id: bytes = current_commit.tree
+        tree_id: ObjectID = current_commit.tree
         return tree_id
     except KeyError:
         # No HEAD yet (empty repo)
@@ -5078,7 +5083,7 @@ def _get_current_head_tree(repo: Repo) -> bytes | None:
 
 
 def _check_uncommitted_changes(
-    repo: Repo, target_tree_id: bytes, force: bool = False
+    repo: Repo, target_tree_id: ObjectID, force: bool = False
 ) -> None:
     """Check for uncommitted changes that would conflict with a checkout/switch.
 
@@ -5182,8 +5187,8 @@ def _get_worktree_update_config(
 
 def _perform_tree_switch(
     repo: Repo,
-    current_tree_id: bytes | None,
-    target_tree_id: bytes,
+    current_tree_id: ObjectID | None,
+    target_tree_id: ObjectID,
     force: bool = False,
 ) -> None:
     """Perform the actual working tree switch.
@@ -5239,7 +5244,7 @@ def update_head(
         if new_branch is not None:
             to_set = _make_branch_ref(new_branch)
         else:
-            to_set = b"HEAD"
+            to_set = HEADREF
         if detached:
             # TODO(jelmer): Provide some way so that the actual ref gets
             # updated rather than what it points to, so the delete isn't
@@ -5249,7 +5254,7 @@ def update_head(
         else:
             r.refs.set_symbolic_ref(to_set, parse_ref(r, target))
         if new_branch is not None:
-            r.refs.set_symbolic_ref(b"HEAD", to_set)
+            r.refs.set_symbolic_ref(HEADREF, to_set)
 
 
 def checkout(
@@ -5297,7 +5302,7 @@ def checkout(
             # If no target specified, use HEAD
             if target is None:
                 try:
-                    target = r.refs[b"HEAD"]
+                    target = r.refs[HEADREF]
                 except KeyError:
                     raise CheckoutError("No HEAD reference found")
             else:
@@ -5464,7 +5469,7 @@ def restore(
             if staged:
                 # Restoring staged files from HEAD
                 try:
-                    source = r.refs[b"HEAD"]
+                    source = r.refs[HEADREF]
                 except KeyError:
                     raise CheckoutError("No HEAD reference found")
             elif worktree:
@@ -5626,8 +5631,6 @@ def switch(
             update_head(r, create)
 
             # Set up tracking if creating from a remote branch
-            from .refs import LOCAL_REMOTE_PREFIX, local_branch_name, parse_remote_ref
-
             if isinstance(original_target, bytes) and target_bytes.startswith(
                 LOCAL_REMOTE_PREFIX
             ):
@@ -6121,13 +6124,13 @@ def write_tree(repo: RepoPath) -> bytes:
 
 def _do_merge(
     r: Repo,
-    merge_commit_id: bytes,
+    merge_commit_id: ObjectID,
     no_commit: bool = False,
     no_ff: bool = False,
     message: bytes | None = None,
     author: bytes | None = None,
     committer: bytes | None = None,
-) -> tuple[bytes | None, list[bytes]]:
+) -> tuple[ObjectID | None, list[bytes]]:
     """Internal merge implementation that operates on an open repository.
 
     Args:
@@ -6148,7 +6151,7 @@ def _do_merge(
 
     # Get HEAD commit
     try:
-        head_commit_id = r.refs[b"HEAD"]
+        head_commit_id = r.refs[HEADREF]
     except KeyError:
         raise Error("No HEAD reference found")
 
@@ -6174,7 +6177,7 @@ def _do_merge(
     # Check for fast-forward
     if base_commit_id == head_commit_id and not no_ff:
         # Fast-forward merge
-        r.refs[b"HEAD"] = merge_commit_id
+        r.refs[HEADREF] = merge_commit_id
         # Update the working directory
         changes = tree_changes(r.object_store, head_commit.tree, merge_commit.tree)
         update_working_tree(
@@ -6235,20 +6238,20 @@ def _do_merge(
     r.object_store.add_object(merge_commit_obj)
 
     # Update HEAD
-    r.refs[b"HEAD"] = merge_commit_obj.id
+    r.refs[HEADREF] = merge_commit_obj.id
 
     return (merge_commit_obj.id, [])
 
 
 def _do_octopus_merge(
     r: Repo,
-    merge_commit_ids: list[bytes],
+    merge_commit_ids: list[ObjectID],
     no_commit: bool = False,
     no_ff: bool = False,
     message: bytes | None = None,
     author: bytes | None = None,
     committer: bytes | None = None,
-) -> tuple[bytes | None, list[bytes]]:
+) -> tuple[ObjectID | None, list[bytes]]:
     """Internal octopus merge implementation that operates on an open repository.
 
     Args:
@@ -6269,7 +6272,7 @@ def _do_octopus_merge(
 
     # Get HEAD commit
     try:
-        head_commit_id = r.refs[b"HEAD"]
+        head_commit_id = r.refs[HEADREF]
     except KeyError:
         raise Error("No HEAD reference found")
 
@@ -6367,7 +6370,7 @@ def _do_octopus_merge(
     r.object_store.add_object(merge_commit_obj)
 
     # Update HEAD
-    r.refs[b"HEAD"] = merge_commit_obj.id
+    r.refs[HEADREF] = merge_commit_obj.id
 
     return (merge_commit_obj.id, [])
 
@@ -6544,7 +6547,7 @@ def cherry(
         if upstream is None:
             # Try to find tracking branch
             upstream_found = False
-            head_refs, _ = r.refs.follow(b"HEAD")
+            head_refs, _ = r.refs.follow(HEADREF)
             if head_refs:
                 head_ref = head_refs[0]
                 if head_ref.startswith(b"refs/heads/"):
@@ -6566,7 +6569,7 @@ def cherry(
 
                         if remote_name:
                             # Build the tracking branch ref
-                            upstream_refname = (
+                            upstream_refname = Ref(
                                 b"refs/remotes/"
                                 + remote_name
                                 + b"/"
@@ -6578,7 +6581,7 @@ def cherry(
 
             if not upstream_found:
                 # Default to HEAD^ if no tracking branch found
-                head_commit = r[b"HEAD"]
+                head_commit = r[HEADREF]
                 if isinstance(head_commit, Commit) and head_commit.parents:
                     upstream = head_commit.parents[0]
                 else:
@@ -6873,7 +6876,7 @@ def revert(
 
         # Get current HEAD
         try:
-            head_commit_id = r.refs[b"HEAD"]
+            head_commit_id = r.refs[HEADREF]
         except KeyError:
             raise Error("No HEAD reference found")
 
@@ -6973,7 +6976,7 @@ def revert(
                 r.object_store.add_object(revert_commit)
 
                 # Update HEAD
-                r.refs[b"HEAD"] = revert_commit.id
+                r.refs[HEADREF] = revert_commit.id
                 head_commit_id = revert_commit.id
 
         return head_commit_id if not no_commit else None
@@ -7362,17 +7365,17 @@ def filter_branch(
     filter_author: Callable[[bytes], bytes | None] | None = None,
     filter_committer: Callable[[bytes], bytes | None] | None = None,
     filter_message: Callable[[bytes], bytes | None] | None = None,
-    tree_filter: Callable[[bytes, str], bytes | None] | None = None,
-    index_filter: Callable[[bytes, str], bytes | None] | None = None,
-    parent_filter: Callable[[Sequence[bytes]], list[bytes]] | None = None,
-    commit_filter: Callable[[Commit, bytes], bytes | None] | None = None,
+    tree_filter: Callable[[ObjectID, str], ObjectID | None] | None = None,
+    index_filter: Callable[[ObjectID, str], ObjectID | None] | None = None,
+    parent_filter: Callable[[Sequence[ObjectID]], list[ObjectID]] | None = None,
+    commit_filter: Callable[[Commit, ObjectID], ObjectID | None] | None = None,
     subdirectory_filter: str | bytes | None = None,
     prune_empty: bool = False,
     tag_name_filter: Callable[[bytes], bytes | None] | None = None,
     force: bool = False,
     keep_original: bool = True,
     refs: list[bytes] | None = None,
-) -> dict[bytes, bytes]:
+) -> dict[ObjectID, ObjectID]:
     """Rewrite branch history by creating new commits with filtered properties.
 
     This is similar to git filter-branch, allowing you to rewrite commit
@@ -7422,7 +7425,7 @@ def filter_branch(
             if branch == b"HEAD":
                 # Resolve HEAD to actual branch
                 try:
-                    resolved = r.refs.follow(b"HEAD")
+                    resolved = r.refs.follow(HEADREF)
                     if resolved and resolved[0]:
                         # resolved is a list of (refname, sha) tuples
                         resolved_ref = resolved[0][-1]
@@ -7465,7 +7468,7 @@ def filter_branch(
         )
 
         # Tag callback for renaming tags
-        def rename_tag(old_ref: bytes, new_ref: bytes) -> None:
+        def rename_tag(old_ref: Ref, new_ref: Ref) -> None:
             # Copy tag to new name
             r.refs[new_ref] = r.refs[old_ref]
             # Delete old tag
@@ -7488,7 +7491,7 @@ def filter_branch(
 
 def format_patch(
     repo: RepoPath = ".",
-    committish: bytes | tuple[bytes, bytes] | None = None,
+    committish: ObjectID | tuple[ObjectID, ObjectID] | None = None,
     outstream: TextIO = sys.stdout,
     outdir: str | os.PathLike[str] | None = None,
     n: int = 1,
@@ -7664,7 +7667,7 @@ def bisect_start(
                 old_commit = r[r.head()]
                 assert isinstance(old_commit, Commit)
                 old_tree = old_commit.tree if r.head() else None
-                r.refs[b"HEAD"] = next_sha
+                r.refs[HEADREF] = next_sha
                 commit = r[next_sha]
                 assert isinstance(commit, Commit)
                 changes = tree_changes(r.object_store, old_tree, commit.tree)
@@ -7696,7 +7699,7 @@ def bisect_bad(
             old_commit = r[r.head()]
             assert isinstance(old_commit, Commit)
             old_tree = old_commit.tree if r.head() else None
-            r.refs[b"HEAD"] = next_sha
+            r.refs[HEADREF] = next_sha
             commit = r[next_sha]
             assert isinstance(commit, Commit)
             changes = tree_changes(r.object_store, old_tree, commit.tree)
@@ -7728,7 +7731,7 @@ def bisect_good(
             old_commit = r[r.head()]
             assert isinstance(old_commit, Commit)
             old_tree = old_commit.tree if r.head() else None
-            r.refs[b"HEAD"] = next_sha
+            r.refs[HEADREF] = next_sha
             commit = r[next_sha]
             assert isinstance(commit, Commit)
             changes = tree_changes(r.object_store, old_tree, commit.tree)
@@ -7773,7 +7776,7 @@ def bisect_skip(
             old_commit = r[r.head()]
             assert isinstance(old_commit, Commit)
             old_tree = old_commit.tree if r.head() else None
-            r.refs[b"HEAD"] = next_sha
+            r.refs[HEADREF] = next_sha
             commit = r[next_sha]
             assert isinstance(commit, Commit)
             changes = tree_changes(r.object_store, old_tree, commit.tree)
@@ -7942,7 +7945,7 @@ def reflog_expire(
             refs_to_process = [ref]
 
         # Build set of reachable objects if we have unreachable expiration time
-        reachable_objects: set[bytes] | None = None
+        reachable_objects: set[ObjectID] | None = None
         if expire_unreachable_time is not None:
             from .gc import find_reachable_objects
 
@@ -8461,9 +8464,13 @@ def lfs_fetch(
 
         for ref in refs:
             if isinstance(ref, str):
-                ref = ref.encode()
+                ref_key = Ref(ref.encode())
+            elif isinstance(ref, bytes):
+                ref_key = Ref(ref)
+            else:
+                ref_key = ref
             try:
-                commit = r[r.refs[ref]]
+                commit = r[r.refs[ref_key]]
             except KeyError:
                 continue
 
@@ -8583,19 +8590,21 @@ def lfs_push(
         # Find all LFS objects to push
         if refs is None:
             # Push current branch
-            head_ref = r.refs.read_ref(b"HEAD")
+            head_ref = r.refs.read_ref(HEADREF)
             refs = [head_ref] if head_ref else []
 
         objects_to_push = set()
 
         for ref in refs:
             if isinstance(ref, str):
-                ref = ref.encode()
+                ref_bytes = ref.encode()
+            else:
+                ref_bytes = ref
             try:
-                if ref.startswith(b"refs/"):
-                    commit = r[r.refs[ref]]
+                if ref_bytes.startswith(b"refs/"):
+                    commit = r[r.refs[Ref(ref_bytes)]]
                 else:
-                    commit = r[ref]
+                    commit = r[ref_bytes]
             except KeyError:
                 continue
 
@@ -8736,8 +8745,9 @@ def worktree_add(
 
     with open_repo_closing(repo) as r:
         commit_bytes = commit.encode() if isinstance(commit, str) else commit
+        commit_id = ObjectID(commit_bytes) if commit_bytes is not None else None
         wt_repo = add_worktree(
-            r, path, branch=branch, commit=commit_bytes, detach=detach, force=force
+            r, path, branch=branch, commit=commit_id, detach=detach, force=force
         )
         return wt_repo.path
 
@@ -8867,7 +8877,7 @@ def merge_base(
     committishes: Sequence[str | bytes] | None = None,
     all: bool = False,
     octopus: bool = False,
-) -> list[bytes]:
+) -> list[ObjectID]:
     """Find the best common ancestor(s) between commits.
 
     Args:
@@ -8946,7 +8956,7 @@ def is_ancestor(
 def independent_commits(
     repo: RepoPath = ".",
     committishes: Sequence[str | bytes] | None = None,
-) -> list[bytes]:
+) -> list[ObjectID]:
     """Filter commits to only those that are not reachable from others.
 
     Args:

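The porcelain.py changes above are largely mechanical: wherever a raw bytes SHA or ref name crosses an API boundary it is now wrapped in ObjectID(...) or Ref(...), and return types are annotated accordingly. A minimal sketch of the caller-facing effect (illustrative only, not part of this diff) — the newtypes are erased at runtime, so existing callers that treat results as plain bytes keep working:

    from dulwich import porcelain

    # Return values are now typed as Ref/ObjectID for mypy,
    # but remain plain bytes objects at runtime.
    branches = porcelain.branch_list(".")   # list[Ref]
    tags = porcelain.tag_list(".")          # list[Ref]
    for name in branches:
        print(name.decode("utf-8"))         # a Ref is still bytes underneath
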
+ 5 - 1
dulwich/protocol.py

@@ -30,6 +30,7 @@ from os import SEEK_END
 import dulwich
 
 from .errors import GitProtocolError, HangupException
+from .objects import ObjectID
 
 TCP_GIT_PORT = 9418
 
@@ -49,7 +50,10 @@ GIT_PROTOCOL_VERSIONS = [0, 1, 2]
 DEFAULT_GIT_PROTOCOL_VERSION_FETCH = 2
 DEFAULT_GIT_PROTOCOL_VERSION_SEND = 0
 
-ZERO_SHA = b"0" * 40
+# Suffix used in the Git protocol to indicate peeled tag references
+PEELED_TAG_SUFFIX = b"^{}"
+
+ZERO_SHA: ObjectID = ObjectID(b"0" * 40)
 
 SINGLE_ACK = 0
 MULTI_ACK = 1

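For context on the two protocol.py constants: ZERO_SHA is the all-zeros object id the wire protocol uses to signal ref creation or deletion, and PEELED_TAG_SUFFIX is the ^{} marker appended to an advertised tag ref to report the peeled (tagged) commit. A hedged sketch of typical usage, assuming standard git protocol semantics rather than anything shown in this diff:

    from dulwich.protocol import PEELED_TAG_SUFFIX, ZERO_SHA

    # A push that deletes a ref reports the zero id as the new value.
    def is_ref_deletion(new_sha: bytes) -> bool:
        return new_sha == ZERO_SHA

    # Peeled tag advertisements take the form
    # b"refs/tags/v1.0" + PEELED_TAG_SUFFIX, i.e. b"refs/tags/v1.0^{}",
    # and point at the commit the tag object references.
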
+ 39 - 35
dulwich/rebase.py

@@ -31,9 +31,9 @@ from typing import Protocol, TypedDict
 
 from dulwich.graph import find_merge_base
 from dulwich.merge import three_way_merge
-from dulwich.objects import Commit
+from dulwich.objects import Commit, ObjectID
 from dulwich.objectspec import parse_commit
-from dulwich.refs import local_branch_name
+from dulwich.refs import HEADREF, Ref, local_branch_name, set_ref_from_raw
 from dulwich.repo import BaseRepo, Repo
 
 
@@ -119,7 +119,7 @@ class RebaseTodoEntry:
     """Represents a single entry in a rebase todo list."""
 
     command: RebaseTodoCommand
-    commit_sha: bytes | None = None  # Store as hex string encoded as bytes
+    commit_sha: ObjectID | None = None  # Store as hex string encoded as bytes
     short_message: str | None = None
     arguments: str | None = None
 
@@ -209,7 +209,7 @@ class RebaseTodoEntry:
             # Commands that operate on commits
             if len(parts) > 1:
                 # Store SHA as hex string encoded as bytes
-                commit_sha = parts[1].encode()
+                commit_sha = ObjectID(parts[1].encode())
 
                 # Parse commit message if present
                 if len(parts) > 2:
@@ -374,8 +374,8 @@ class RebaseStateManager(Protocol):
     def save(
         self,
         original_head: bytes | None,
-        rebasing_branch: bytes | None,
-        onto: bytes | None,
+        rebasing_branch: Ref | None,
+        onto: ObjectID | None,
         todo: list[Commit],
         done: list[Commit],
     ) -> None:
@@ -386,8 +386,8 @@ class RebaseStateManager(Protocol):
         self,
     ) -> tuple[
         bytes | None,  # original_head
-        bytes | None,  # rebasing_branch
-        bytes | None,  # onto
+        Ref | None,  # rebasing_branch
+        ObjectID | None,  # onto
         list[Commit],  # todo
         list[Commit],  # done
     ]:
@@ -425,8 +425,8 @@ class DiskRebaseStateManager:
     def save(
         self,
         original_head: bytes | None,
-        rebasing_branch: bytes | None,
-        onto: bytes | None,
+        rebasing_branch: Ref | None,
+        onto: ObjectID | None,
         todo: list[Commit],
         done: list[Commit],
     ) -> None:
@@ -467,22 +467,26 @@ class DiskRebaseStateManager:
         self,
     ) -> tuple[
         bytes | None,
-        bytes | None,
-        bytes | None,
+        Ref | None,
+        ObjectID | None,
         list[Commit],
         list[Commit],
     ]:
         """Load rebase state from disk."""
         original_head = None
-        rebasing_branch = None
-        onto = None
+        rebasing_branch_bytes = None
+        onto_bytes = None
         todo: list[Commit] = []
         done: list[Commit] = []
 
         # Load rebase state files
         original_head = self._read_file("orig-head")
-        rebasing_branch = self._read_file("head-name")
-        onto = self._read_file("onto")
+        rebasing_branch_bytes = self._read_file("head-name")
+        rebasing_branch = (
+            Ref(rebasing_branch_bytes) if rebasing_branch_bytes is not None else None
+        )
+        onto_bytes = self._read_file("onto")
+        onto = ObjectID(onto_bytes) if onto_bytes is not None else None
 
         return original_head, rebasing_branch, onto, todo, done
 
@@ -532,8 +536,8 @@ class RebaseState(TypedDict):
     """Type definition for rebase state."""
 
     original_head: bytes | None
-    rebasing_branch: bytes | None
-    onto: bytes | None
+    rebasing_branch: Ref | None
+    onto: ObjectID | None
     todo: list[Commit]
     done: list[Commit]
 
@@ -554,8 +558,8 @@ class MemoryRebaseStateManager:
     def save(
         self,
         original_head: bytes | None,
-        rebasing_branch: bytes | None,
-        onto: bytes | None,
+        rebasing_branch: Ref | None,
+        onto: ObjectID | None,
         todo: list[Commit],
         done: list[Commit],
     ) -> None:
@@ -572,8 +576,8 @@ class MemoryRebaseStateManager:
         self,
     ) -> tuple[
         bytes | None,
-        bytes | None,
-        bytes | None,
+        Ref | None,
+        ObjectID | None,
         list[Commit],
         list[Commit],
     ]:
@@ -630,10 +634,10 @@ class Rebaser:
 
         # Initialize state
         self._original_head: bytes | None = None
-        self._onto: bytes | None = None
+        self._onto: ObjectID | None = None
         self._todo: list[Commit] = []
         self._done: list[Commit] = []
-        self._rebasing_branch: bytes | None = None
+        self._rebasing_branch: Ref | None = None
 
         # Load any existing rebase state
         self._load_rebase_state()
@@ -653,7 +657,7 @@ class Rebaser:
         # Get the branch commit
         if branch is None:
             # Use current HEAD
-            _head_ref, head_sha = self.repo.refs.follow(b"HEAD")
+            _head_ref, head_sha = self.repo.refs.follow(HEADREF)
             if head_sha is None:
                 raise ValueError("HEAD does not point to a valid commit")
             branch_commit = self.repo[head_sha]
@@ -688,8 +692,8 @@ class Rebaser:
         return list(reversed(commits))
 
     def _cherry_pick(
-        self, commit: Commit, onto: bytes
-    ) -> tuple[bytes | None, list[bytes]]:
+        self, commit: Commit, onto: ObjectID
+    ) -> tuple[ObjectID | None, list[bytes]]:
         """Cherry-pick a commit onto another commit.
 
         Args:
@@ -754,22 +758,22 @@ class Rebaser:
             List of commits that will be rebased
         """
         # Save original HEAD
-        self._original_head = self.repo.refs.read_ref(b"HEAD")
+        self._original_head = self.repo.refs.read_ref(HEADREF)
 
         # Save which branch we're rebasing (for later update)
         if branch is not None:
             # Parse the branch ref
             if branch.startswith(b"refs/heads/"):
-                self._rebasing_branch = branch
+                self._rebasing_branch = Ref(branch)
             else:
                 # Assume it's a branch name
-                self._rebasing_branch = local_branch_name(branch)
+                self._rebasing_branch = Ref(local_branch_name(branch))
         else:
             # Use current branch
             if self._original_head is not None and self._original_head.startswith(
                 b"ref: "
             ):
-                self._rebasing_branch = self._original_head[5:]
+                self._rebasing_branch = Ref(self._original_head[5:])
             else:
                 self._rebasing_branch = None
 
@@ -844,7 +848,7 @@ class Rebaser:
         # Restore original HEAD
         if self._original_head is None:
             raise RebaseError("No original HEAD to restore")
-        self.repo.refs[b"HEAD"] = self._original_head
+        set_ref_from_raw(self.repo.refs, HEADREF, self._original_head)
 
         # Clean up rebase state
         self._clean_rebase_state()
@@ -870,13 +874,13 @@ class Rebaser:
             # If HEAD was pointing to this branch, it will follow automatically
         else:
             # If we don't know which branch, check current HEAD
-            head_ref = self.repo.refs[b"HEAD"]
+            head_ref = self.repo.refs[HEADREF]
             if head_ref.startswith(b"ref: "):
-                branch_ref = head_ref[5:]
+                branch_ref = Ref(head_ref[5:])
                 self.repo.refs[branch_ref] = last_commit.id
             else:
                 # Detached HEAD
-                self.repo.refs[b"HEAD"] = last_commit.id
+                self.repo.refs[HEADREF] = last_commit.id
 
         # Clean up rebase state
         self._clean_rebase_state()

File diff suppressed because it is too large
+ 210 - 171
dulwich/refs.py
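
Since the refs.py diff is suppressed, here is a rough sketch of the core of the change for orientation: ObjectID (defined in dulwich/objects.py) and Ref (defined in dulwich/refs.py) become typing newtypes over bytes, with constants such as HEADREF built on top. This is an assumption based on the rest of the commit, not the suppressed diff itself:

    from typing import NewType

    ObjectID = NewType("ObjectID", bytes)  # hex sha, e.g. ObjectID(b"0" * 40)
    Ref = NewType("Ref", bytes)            # full ref name, e.g. Ref(b"refs/heads/main")
    HEADREF = Ref(b"HEAD")
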


+ 56 - 47
dulwich/reftable.py

@@ -22,6 +22,7 @@ from typing import BinaryIO
 from dulwich.objects import ObjectID
 from dulwich.refs import (
     SYMREF,
+    Ref,
     RefsContainer,
 )
 
@@ -909,7 +910,7 @@ class ReftableRefsContainer(RefsContainer):
                     files.append(os.path.join(self.reftable_dir, table_name))
         return files
 
-    def _read_all_tables(self) -> dict[bytes, tuple[int, bytes]]:
+    def _read_all_tables(self) -> dict[Ref, tuple[int, bytes]]:
         """Read all reftable files and merge results."""
         # First, read all tables and sort them by min_update_index
         table_data = []
@@ -924,19 +925,20 @@ class ReftableRefsContainer(RefsContainer):
         table_data.sort(key=lambda x: x[0])
 
         # Merge results in chronological order
-        all_refs: dict[bytes, tuple[int, bytes]] = {}
+        all_refs: dict[Ref, tuple[int, bytes]] = {}
         for min_update_index, table_file, refs in table_data:
             # Apply updates from this table
             for refname, (value_type, value) in refs.items():
+                ref = Ref(refname)
                 if value_type == REF_VALUE_DELETE:
                     # Remove ref if it exists
-                    all_refs.pop(refname, None)
+                    all_refs.pop(ref, None)
                 else:
                     # Add/update ref
-                    all_refs[refname] = (value_type, value)
+                    all_refs[ref] = (value_type, value)
         return all_refs
 
-    def allkeys(self) -> set[bytes]:
+    def allkeys(self) -> set[Ref]:
         """Return set of all ref names."""
         refs = self._read_all_tables()
         result = set(refs.keys())
@@ -946,17 +948,17 @@ class ReftableRefsContainer(RefsContainer):
             if value_type == REF_VALUE_SYMREF:
                 # Add the target ref as an implicit ref
                 target = value
-                result.add(target)
+                result.add(Ref(target))
 
         return result
 
-    def follow(self, name: bytes) -> tuple[list[bytes], bytes]:
+    def follow(self, name: Ref) -> tuple[list[Ref], ObjectID | None]:
         """Follow a reference name.
 
         Returns: a tuple of (refnames, sha), where refnames are the names of
             references in the chain
         """
-        refnames = []
+        refnames: list[Ref] = []
         current = name
         refs = self._read_all_tables()
 
@@ -968,11 +970,11 @@ class ReftableRefsContainer(RefsContainer):
 
             value_type, value = ref_data
             if value_type == REF_VALUE_REF:
-                return refnames, value
+                return refnames, ObjectID(value)
             if value_type == REF_VALUE_PEELED:
-                return refnames, value[:SHA1_HEX_SIZE]  # First SHA1 hex chars
+                return refnames, ObjectID(value[:SHA1_HEX_SIZE])  # First SHA1 hex chars
             if value_type == REF_VALUE_SYMREF:
-                current = value
+                current = Ref(value)
                 continue
 
             # Unknown value type
@@ -981,7 +983,7 @@ class ReftableRefsContainer(RefsContainer):
         # Too many levels of indirection
         raise ValueError(f"Too many levels of symbolic ref indirection for {name!r}")
 
-    def __getitem__(self, name: bytes) -> ObjectID:
+    def __getitem__(self, name: Ref) -> ObjectID:
         """Get the SHA1 for a reference name.
 
         This method follows all symbolic references.
@@ -991,7 +993,7 @@ class ReftableRefsContainer(RefsContainer):
             raise KeyError(name)
         return sha
 
-    def read_loose_ref(self, name: bytes) -> bytes:
+    def read_loose_ref(self, name: Ref) -> bytes:
         """Read a reference value without following symbolic refs.
 
         Args:
@@ -1019,18 +1021,18 @@ class ReftableRefsContainer(RefsContainer):
 
         raise ValueError(f"Unknown ref value type: {value_type}")
 
-    def get_packed_refs(self) -> dict[bytes, bytes]:
+    def get_packed_refs(self) -> dict[Ref, ObjectID]:
         """Get packed refs. Reftable doesn't distinguish packed/loose."""
         refs = self._read_all_tables()
         result = {}
         for name, (value_type, value) in refs.items():
             if value_type == REF_VALUE_REF:
-                result[name] = value
+                result[name] = ObjectID(value)
             elif value_type == REF_VALUE_PEELED:
-                result[name] = value[:SHA1_HEX_SIZE]  # First SHA1 hex chars
+                result[name] = ObjectID(value[:SHA1_HEX_SIZE])  # First SHA1 hex chars
         return result
 
-    def get_peeled(self, name: bytes) -> bytes | None:
+    def get_peeled(self, name: Ref) -> ObjectID | None:
         """Return the cached peeled value of a ref, if available.
 
         Args:
@@ -1048,10 +1050,10 @@ class ReftableRefsContainer(RefsContainer):
         value_type, value = ref_data
         if value_type == REF_VALUE_PEELED:
             # Return the peeled SHA (second 40 hex chars)
-            return value[40:80]
+            return ObjectID(value[40:80])
         elif value_type == REF_VALUE_REF:
             # Known not to be peeled
-            return value
+            return ObjectID(value)
         else:
             # Symbolic ref or other - no peeled info
             return None
@@ -1073,12 +1075,14 @@ class ReftableRefsContainer(RefsContainer):
         table_name = f"0x{min_idx:016x}-0x{max_idx:016x}-{hash_part:08x}.ref"
         return os.path.join(self.reftable_dir, table_name)
 
-    def add_packed_refs(self, new_refs: Mapping[bytes, bytes | None]) -> None:
+    def add_packed_refs(self, new_refs: Mapping[Ref, ObjectID | None]) -> None:
         """Add packed refs. Creates a new reftable file with all refs consolidated."""
         if not new_refs:
             return
 
-        self._write_batch_updates(new_refs)
+        # Convert to bytes for internal use
+        byte_refs = {bytes(k): bytes(v) if v else None for k, v in new_refs.items()}
+        self._write_batch_updates(byte_refs)
 
     def _write_batch_updates(self, updates: Mapping[bytes, bytes | None]) -> None:
         """Write multiple ref updates to a single reftable file."""
@@ -1100,9 +1104,9 @@ class ReftableRefsContainer(RefsContainer):
 
     def set_if_equals(
         self,
-        name: bytes,
-        old_ref: bytes | None,
-        new_ref: bytes | None,
+        name: Ref,
+        old_ref: ObjectID | None,
+        new_ref: ObjectID,
         committer: bytes | None = None,
         timestamp: int | None = None,
         timezone: int | None = None,
@@ -1116,22 +1120,19 @@ class ReftableRefsContainer(RefsContainer):
         except KeyError:
             current = None
 
-        if current != old_ref:
+        old_ref_bytes = bytes(old_ref) if old_ref else None
+        if current != old_ref_bytes:
             return False
 
-        if new_ref is None:
-            # Delete ref
-            self._write_ref_update(name, REF_VALUE_DELETE, b"")
-        else:
-            # Update ref
-            self._write_ref_update(name, REF_VALUE_REF, new_ref)
+        # Update ref
+        self._write_ref_update(bytes(name), REF_VALUE_REF, bytes(new_ref))
 
         return True
 
     def add_if_new(
         self,
-        name: bytes,
-        ref: bytes,
+        name: Ref,
+        ref: ObjectID,
         committer: bytes | None = None,
         timestamp: int | None = None,
         timezone: int | None = None,
@@ -1143,28 +1144,31 @@ class ReftableRefsContainer(RefsContainer):
             return False  # Ref exists
         except KeyError:
             pass  # Ref doesn't exist, continue
-        self._write_ref_update(name, REF_VALUE_REF, ref)
+        self._write_ref_update(bytes(name), REF_VALUE_REF, bytes(ref))
         return True
 
     def remove_if_equals(
         self,
-        name: bytes,
-        old_ref: bytes | None,
+        name: Ref,
+        old_ref: ObjectID | None,
         committer: bytes | None = None,
         timestamp: int | None = None,
         timezone: int | None = None,
         message: bytes | None = None,
     ) -> bool:
         """Remove a ref if it equals old_ref."""
-        return self.set_if_equals(
-            name,
-            old_ref,
-            None,
-            committer=committer,
-            timestamp=timestamp,
-            timezone=timezone,
-            message=message,
-        )
+        # For deletion, we need to use the internal method since set_if_equals requires new_ref
+        try:
+            current = self.read_loose_ref(name)
+        except KeyError:
+            current = None
+
+        old_ref_bytes = bytes(old_ref) if old_ref else None
+        if current != old_ref_bytes:
+            return False
+
+        self._write_ref_update(bytes(name), REF_VALUE_DELETE, b"")
+        return True
 
     def set_symbolic_ref(
         self,
@@ -1234,14 +1238,19 @@ class ReftableRefsContainer(RefsContainer):
         # Get next update index - all refs in batch get the SAME index
         batch_update_index = self._get_next_update_index()
 
+        # Convert Ref keys to bytes for internal methods
+        all_refs_bytes = {bytes(k): v for k, v in all_refs.items()}
+
         # Apply updates to get final state
         self._apply_batch_updates(
-            all_refs, other_updates, head_update, batch_update_index
+            all_refs_bytes, other_updates, head_update, batch_update_index
         )
 
         # Write consolidated batch file
         created_files = (
-            self._write_batch_file(all_refs, batch_update_index) if all_refs else []
+            self._write_batch_file(all_refs_bytes, batch_update_index)
+            if all_refs_bytes
+            else []
         )
 
         # Update tables list with new files (don't compact, keep separate)

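The back-and-forth conversions in reftable.py (Ref(refname), ObjectID(value), bytes(name)) exist purely for the type checker; a typing.NewType call returns its argument unchanged at runtime, so none of this copies data or changes behaviour. A quick illustration, assuming standard typing.NewType semantics:

    from dulwich.objects import ObjectID

    sha = b"a" * 40
    typed = ObjectID(sha)
    assert typed is sha           # NewType is an identity function at runtime
    assert bytes(typed) == sha    # converting back is equally cheap
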
+ 74 - 59
dulwich/repo.py

@@ -92,6 +92,7 @@ from .objects import (
     Blob,
     Commit,
     ObjectID,
+    RawObjectID,
     ShaFile,
     Tag,
     Tree,
@@ -100,7 +101,7 @@ from .objects import (
 )
 from .pack import generate_unpacked_objects
 from .refs import (
-    ANNOTATED_TAG_SUFFIX,  # noqa: F401
+    HEADREF,
     LOCAL_TAG_PREFIX,  # noqa: F401
     SYMREF,  # noqa: F401
     DictRefsContainer,
@@ -268,7 +269,7 @@ def check_user_identity(identity: bytes) -> None:
 
 def parse_graftpoints(
     graftpoints: Iterable[bytes],
-) -> dict[bytes, list[bytes]]:
+) -> dict[ObjectID, list[ObjectID]]:
     """Convert a list of graftpoints into a dict.
 
     Args:
@@ -282,13 +283,13 @@ def parse_graftpoints(
 
     https://git.wiki.kernel.org/index.php/GraftPoint
     """
-    grafts = {}
+    grafts: dict[ObjectID, list[ObjectID]] = {}
     for line in graftpoints:
         raw_graft = line.split(None, 1)
 
-        commit = raw_graft[0]
+        commit = ObjectID(raw_graft[0])
         if len(raw_graft) == 2:
-            parents = raw_graft[1].split()
+            parents = [ObjectID(p) for p in raw_graft[1].split()]
         else:
             parents = []
 
@@ -299,7 +300,7 @@ def parse_graftpoints(
     return grafts
 
 
-def serialize_graftpoints(graftpoints: Mapping[bytes, Sequence[bytes]]) -> bytes:
+def serialize_graftpoints(graftpoints: Mapping[ObjectID, Sequence[ObjectID]]) -> bytes:
     """Convert a dictionary of grafts into string.
 
     The graft dictionary is:
@@ -414,8 +415,8 @@ class ParentsProvider:
     def __init__(
         self,
         store: "BaseObjectStore",
-        grafts: dict[bytes, list[bytes]] = {},
-        shallows: Iterable[bytes] = [],
+        grafts: dict[ObjectID, list[ObjectID]] = {},
+        shallows: Iterable[ObjectID] = [],
     ) -> None:
         """Initialize ParentsProvider.
 
@@ -432,8 +433,8 @@ class ParentsProvider:
         self.commit_graph = store.get_commit_graph()
 
     def get_parents(
-        self, commit_id: bytes, commit: Commit | None = None
-    ) -> list[bytes]:
+        self, commit_id: ObjectID, commit: Commit | None = None
+    ) -> list[ObjectID]:
         """Get parents for a commit using the parents provider."""
         try:
             return self.grafts[commit_id]
@@ -453,9 +454,8 @@ class ParentsProvider:
             obj = self.store[commit_id]
             assert isinstance(obj, Commit)
             commit = obj
-        parents = commit.parents
-        assert isinstance(parents, list)
-        return parents
+        result: list[ObjectID] = commit.parents
+        return result
 
 
 class BaseRepo:
@@ -486,7 +486,7 @@ class BaseRepo:
         self.object_store = object_store
         self.refs = refs
 
-        self._graftpoints: dict[bytes, list[bytes]] = {}
+        self._graftpoints: dict[ObjectID, list[ObjectID]] = {}
         self.hooks: dict[str, Hook] = {}
 
     def _determine_file_mode(self) -> bool:
@@ -585,11 +585,11 @@ class BaseRepo:
     def fetch(
         self,
         target: "BaseRepo",
-        determine_wants: Callable[[Mapping[bytes, bytes], int | None], list[bytes]]
+        determine_wants: Callable[[Mapping[Ref, ObjectID], int | None], list[ObjectID]]
         | None = None,
         progress: Callable[..., None] | None = None,
         depth: int | None = None,
-    ) -> dict[bytes, bytes]:
+    ) -> dict[Ref, ObjectID]:
         """Fetch objects into another repository.
 
         Args:
@@ -613,11 +613,11 @@ class BaseRepo:
 
     def fetch_pack_data(
         self,
-        determine_wants: Callable[[Mapping[bytes, bytes], int | None], list[bytes]],
+        determine_wants: Callable[[Mapping[Ref, ObjectID], int | None], list[ObjectID]],
         graph_walker: "GraphWalker",
         progress: Callable[[bytes], None] | None,
         *,
-        get_tagged: Callable[[], dict[bytes, bytes]] | None = None,
+        get_tagged: Callable[[], dict[ObjectID, ObjectID]] | None = None,
         depth: int | None = None,
     ) -> tuple[int, Iterator["UnpackedObject"]]:
         """Fetch the pack data required for a set of revisions.
@@ -648,11 +648,11 @@ class BaseRepo:
 
     def find_missing_objects(
         self,
-        determine_wants: Callable[[Mapping[bytes, bytes], int | None], list[bytes]],
+        determine_wants: Callable[[Mapping[Ref, ObjectID], int | None], list[ObjectID]],
         graph_walker: "GraphWalker",
         progress: Callable[[bytes], None] | None,
         *,
-        get_tagged: Callable[[], dict[bytes, bytes]] | None = None,
+        get_tagged: Callable[[], dict[ObjectID, ObjectID]] | None = None,
         depth: int | None = None,
     ) -> MissingObjectFinder | None:
         """Fetch the missing objects required for a set of revisions.
@@ -670,9 +670,11 @@ class BaseRepo:
           depth: Shallow fetch depth
         Returns: iterator over objects, with __len__ implemented
         """
+        # TODO: serialize_refs returns dict[bytes, ObjectID] with peeled refs (^{}),
+        # but determine_wants expects Mapping[Ref, ObjectID]. Need to reconcile this.
         refs = serialize_refs(self.object_store, self.get_refs())
 
-        wants = determine_wants(refs, depth)
+        wants = determine_wants(refs, depth)  # type: ignore[arg-type]
         if not isinstance(wants, list):
             raise TypeError("determine_wants() did not return a list")
 
@@ -721,7 +723,7 @@ class BaseRepo:
 
         parents_provider = ParentsProvider(self.object_store, shallows=current_shallow)
 
-        def get_parents(commit: Commit) -> list[bytes]:
+        def get_parents(commit: Commit) -> list[ObjectID]:
             """Get parents for a commit using the parents provider.
 
             Args:
@@ -785,7 +787,7 @@ class BaseRepo:
         if heads is None:
             heads = [
                 sha
-                for sha in self.refs.as_dict(b"refs/heads").values()
+                for sha in self.refs.as_dict(Ref(b"refs/heads")).values()
                 if sha in self.object_store
             ]
         parents_provider = ParentsProvider(self.object_store)
@@ -796,21 +798,22 @@ class BaseRepo:
             update_shallow=self.update_shallow,
         )
 
-    def get_refs(self) -> dict[bytes, bytes]:
+    def get_refs(self) -> dict[Ref, ObjectID]:
         """Get dictionary with all refs.
 
         Returns: A ``dict`` mapping ref names to SHA1s
         """
         return self.refs.as_dict()
 
-    def head(self) -> bytes:
+    def head(self) -> ObjectID:
         """Return the SHA1 pointed at by HEAD."""
         # TODO: move this method to WorkTree
-        return self.refs[b"HEAD"]
+        return self.refs[HEADREF]
 
     def _get_object(self, sha: bytes, cls: type[T]) -> T:
         assert len(sha) in (20, 40)
-        ret = self.get_object(sha)
+        obj_id = ObjectID(sha) if len(sha) == 40 else RawObjectID(sha)
+        ret = self.get_object(obj_id)
         if not isinstance(ret, cls):
             if cls is Commit:
                 raise NotCommitError(ret.id)
@@ -824,7 +827,7 @@ class BaseRepo:
                 raise Exception(f"Type invalid: {ret.type_name!r} != {cls.type_name!r}")
         return ret
 
-    def get_object(self, sha: bytes) -> ShaFile:
+    def get_object(self, sha: ObjectID | RawObjectID) -> ShaFile:
         """Retrieve the object with the specified SHA.
 
         Args:
@@ -847,7 +850,9 @@ class BaseRepo:
             shallows=self.get_shallow(),
         )
 
-    def get_parents(self, sha: bytes, commit: Commit | None = None) -> list[bytes]:
+    def get_parents(
+        self, sha: ObjectID, commit: Commit | None = None
+    ) -> list[ObjectID]:
         """Retrieve the parents of a specific commit.
 
         If the specific commit is a graftpoint, the graft parents
@@ -940,10 +945,10 @@ class BaseRepo:
         if f is None:
             return set()
         with f:
-            return {line.strip() for line in f}
+            return {ObjectID(line.strip()) for line in f}
 
     def update_shallow(
-        self, new_shallow: set[bytes] | None, new_unshallow: set[bytes] | None
+        self, new_shallow: set[ObjectID] | None, new_unshallow: set[ObjectID] | None
     ) -> None:
         """Update the list of shallow objects.
 
@@ -988,8 +993,8 @@ class BaseRepo:
 
     def get_walker(
         self,
-        include: Sequence[bytes] | None = None,
-        exclude: Sequence[bytes] | None = None,
+        include: Sequence[ObjectID] | None = None,
+        exclude: Sequence[ObjectID] | None = None,
         order: str = "date",
         reverse: bool = False,
         max_entries: int | None = None,
@@ -1047,7 +1052,7 @@ class BaseRepo:
             queue_cls=queue_cls if queue_cls is not None else _CommitTimeQueue,
         )
 
-    def __getitem__(self, name: ObjectID | Ref) -> "ShaFile":
+    def __getitem__(self, name: ObjectID | Ref | bytes) -> "ShaFile":
         """Retrieve a Git object by SHA1 or ref.
 
         Args:
@@ -1060,11 +1065,14 @@ class BaseRepo:
             raise TypeError(f"'name' must be bytestring, not {type(name).__name__:.80}")
         if len(name) in (20, 40):
             try:
-                return self.object_store[name]
+                # Try as ObjectID/RawObjectID
+                return self.object_store[
+                    ObjectID(name) if len(name) == 40 else RawObjectID(name)
+                ]
             except (KeyError, ValueError):
                 pass
         try:
-            return self.object_store[self.refs[name]]
+            return self.object_store[self.refs[Ref(name)]]
         except RefFormatError as exc:
             raise KeyError(name) from exc
 
@@ -1074,10 +1082,12 @@ class BaseRepo:
         Args:
           name: Git object SHA1 or ref name
         """
-        if len(name) == 20 or (len(name) == 40 and valid_hexsha(name)):
-            return name in self.object_store or name in self.refs
+        if len(name) == 20:
+            return RawObjectID(name) in self.object_store or Ref(name) in self.refs
+        elif len(name) == 40 and valid_hexsha(name):
+            return ObjectID(name) in self.object_store or Ref(name) in self.refs
         else:
-            return name in self.refs
+            return Ref(name) in self.refs
 
     def __setitem__(self, name: bytes, value: ShaFile | bytes) -> None:
         """Set a ref.
@@ -1086,11 +1096,12 @@ class BaseRepo:
           name: ref name
           value: Ref value - either a ShaFile object, or a hex sha
         """
-        if name.startswith(b"refs/") or name == b"HEAD":
+        if name.startswith(b"refs/") or name == HEADREF:
+            ref_name = Ref(name)
             if isinstance(value, ShaFile):
-                self.refs[name] = value.id
+                self.refs[ref_name] = value.id
             elif isinstance(value, bytes):
-                self.refs[name] = value
+                self.refs[ref_name] = ObjectID(value)
             else:
                 raise TypeError(value)
         else:
@@ -1102,8 +1113,8 @@ class BaseRepo:
         Args:
           name: Name of the ref to remove
         """
-        if name.startswith(b"refs/") or name == b"HEAD":
-            del self.refs[name]
+        if name.startswith(b"refs/") or name == HEADREF:
+            del self.refs[Ref(name)]
         else:
             raise ValueError(name)
 
@@ -1117,7 +1128,9 @@ class BaseRepo:
         )
         return get_user_identity(config)
 
-    def _add_graftpoints(self, updated_graftpoints: dict[bytes, list[bytes]]) -> None:
+    def _add_graftpoints(
+        self, updated_graftpoints: dict[ObjectID, list[ObjectID]]
+    ) -> None:
         """Add or modify graftpoints.
 
         Args:
@@ -1130,7 +1143,7 @@ class BaseRepo:
 
         self._graftpoints.update(updated_graftpoints)
 
-    def _remove_graftpoints(self, to_remove: Sequence[bytes] = ()) -> None:
+    def _remove_graftpoints(self, to_remove: Sequence[ObjectID] = ()) -> None:
         """Remove graftpoints.
 
         Args:
@@ -1139,12 +1152,12 @@ class BaseRepo:
         for sha in to_remove:
             del self._graftpoints[sha]
 
-    def _read_heads(self, name: str) -> list[bytes]:
+    def _read_heads(self, name: str) -> list[ObjectID]:
         f = self.get_named_file(name)
         if f is None:
             return []
         with f:
-            return [line.strip() for line in f.readlines() if line.strip()]
+            return [ObjectID(line.strip()) for line in f.readlines() if line.strip()]
 
     def get_worktree(self) -> "WorkTree":
         """Get the working tree for this repository.
@@ -1171,7 +1184,7 @@ class BaseRepo:
         author_timezone: int | None = None,
         tree: ObjectID | None = None,
         encoding: bytes | None = None,
-        ref: Ref | None = b"HEAD",
+        ref: Ref | None = HEADREF,
         merge_heads: list[ObjectID] | None = None,
         no_verify: bool = False,
         sign: bool = False,
@@ -1782,19 +1795,21 @@ class Repo(BaseRepo):
                 ref_message = b"clone: from " + encoded_path
                 self.fetch(target, depth=depth)
                 target.refs.import_refs(
-                    b"refs/remotes/" + origin,
-                    self.refs.as_dict(b"refs/heads"),
+                    Ref(b"refs/remotes/" + origin),
+                    self.refs.as_dict(Ref(b"refs/heads")),
                     message=ref_message,
                 )
                 target.refs.import_refs(
-                    b"refs/tags", self.refs.as_dict(b"refs/tags"), message=ref_message
+                    Ref(b"refs/tags"),
+                    self.refs.as_dict(Ref(b"refs/tags")),
+                    message=ref_message,
                 )
 
-                head_chain, origin_sha = self.refs.follow(b"HEAD")
+                head_chain, origin_sha = self.refs.follow(HEADREF)
                 origin_head = head_chain[-1] if head_chain else None
                 if origin_sha and not origin_head:
                     # set detached HEAD
-                    target.refs[b"HEAD"] = origin_sha
+                    target.refs[HEADREF] = origin_sha
                 else:
                     _set_origin_head(target.refs, origin, origin_head)
                     head_ref = _set_default_branch(
@@ -1821,7 +1836,7 @@ class Repo(BaseRepo):
         return target
 
     @replace_me(remove_in="0.26.0")
-    def reset_index(self, tree: bytes | None = None) -> None:
+    def reset_index(self, tree: ObjectID | None = None) -> None:
         """Reset the index back to a specific tree.
 
         Args:
@@ -1896,7 +1911,7 @@ class Repo(BaseRepo):
             """
             try:
                 # Get the current branch using refs
-                ref_chain, _ = self.refs.follow(b"HEAD")
+                ref_chain, _ = self.refs.follow(HEADREF)
                 head_ref = ref_chain[-1]  # Get the final resolved ref
             except KeyError:
                 pass
@@ -2037,7 +2052,7 @@ class Repo(BaseRepo):
                 default_branch = config.get("init", "defaultBranch")
             except KeyError:
                 default_branch = DEFAULT_BRANCH
-        ret.refs.set_symbolic_ref(b"HEAD", local_branch_name(default_branch))
+        ret.refs.set_symbolic_ref(HEADREF, local_branch_name(default_branch))
         ret._init_files(
             bare=bare,
             symlinks=symlinks,
@@ -2559,7 +2574,7 @@ class MemoryRepo(BaseRepo):
         author_timezone: int | None = None,
         tree: ObjectID | None = None,
         encoding: bytes | None = None,
-        ref: Ref | None = b"HEAD",
+        ref: Ref | None = HEADREF,
         merge_heads: list[ObjectID] | None = None,
         no_verify: bool = False,
         sign: bool = False,
@@ -2682,7 +2697,7 @@ class MemoryRepo(BaseRepo):
     def init_bare(
         cls,
         objects: Iterable[ShaFile],
-        refs: Mapping[bytes, bytes],
+        refs: Mapping[Ref, ObjectID],
         format: int | None = None,
     ) -> "MemoryRepo":
         """Create a new bare repository in memory.

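These hunks repeatedly wrap plain bytes in ObjectID(...), RawObjectID(...) and Ref(...), and dispatch on 20-byte binary versus 40-byte hex SHAs. A minimal sketch of how such newtypes behave, with hypothetical declarations (the real definitions live in dulwich/objects.py and dulwich/refs.py and may differ):

    from typing import NewType

    # Hypothetical declarations, for illustration only.
    ObjectID = NewType("ObjectID", bytes)        # 40-byte hex SHA, e.g. b"1" * 40
    RawObjectID = NewType("RawObjectID", bytes)  # 20-byte binary SHA
    Ref = NewType("Ref", bytes)                  # ref name, e.g. b"refs/heads/master"
    HEADREF = Ref(b"HEAD")

    def as_object_id(sha: bytes) -> ObjectID | RawObjectID:
        # Mirrors the 20- vs 40-byte dispatch used in _get_object above.
        return ObjectID(sha) if len(sha) == 40 else RawObjectID(sha)

At runtime a NewType call returns its argument unchanged, so these wrappers cost nothing; they only give mypy a distinct static type to check against.
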
+ 47 - 45
dulwich/server.py

@@ -99,6 +99,7 @@ from .protocol import (
     MULTI_ACK,
     MULTI_ACK_DETAILED,
     NAK_LINE,
+    PEELED_TAG_SUFFIX,
     SIDE_BAND_CHANNEL_DATA,
     SIDE_BAND_CHANNEL_FATAL,
     SIDE_BAND_CHANNEL_PROGRESS,
@@ -118,7 +119,7 @@ from .protocol import (
     format_unshallow_line,
     symref_capabilities,
 )
-from .refs import PEELED_TAG_SUFFIX, Ref, RefsContainer, write_info_refs
+from .refs import Ref, RefsContainer, write_info_refs
 from .repo import Repo
 
 logger = log_utils.getLogger(__name__)
@@ -156,7 +157,7 @@ class BackendRepo(TypingProtocol):
         """
         raise NotImplementedError
 
-    def get_peeled(self, name: bytes) -> bytes | None:
+    def get_peeled(self, name: bytes) -> ObjectID | None:
         """Return the cached peeled value of a ref, if available.
 
         Args:
@@ -170,11 +171,11 @@ class BackendRepo(TypingProtocol):
 
     def find_missing_objects(
         self,
-        determine_wants: Callable[[Mapping[bytes, bytes], int | None], list[bytes]],
+        determine_wants: Callable[[Mapping[Ref, ObjectID], int | None], list[ObjectID]],
         graph_walker: "_ProtocolGraphWalker",
         progress: Callable[[bytes], None] | None,
         *,
-        get_tagged: Callable[[], dict[bytes, bytes]] | None = None,
+        get_tagged: Callable[[], dict[ObjectID, ObjectID]] | None = None,
         depth: int | None = None,
     ) -> "MissingObjectFinder | None":
         """Yield the objects required for a list of commits.
@@ -496,8 +497,8 @@ class UploadPackHandler(PackHandler):
         tagged = {}
         for name, sha in refs.items():
             peeled_sha = repo.get_peeled(name)
-            if peeled_sha is not None and peeled_sha != sha:
-                tagged[peeled_sha] = sha
+            if peeled_sha is not None and peeled_sha != ObjectID(sha):
+                tagged[peeled_sha] = ObjectID(sha)
         return tagged
 
     def handle(self) -> None:
@@ -517,11 +518,11 @@ class UploadPackHandler(PackHandler):
             self.repo.get_peeled,
             self.repo.refs.get_symrefs,
         )
-        wants = []
+        wants: list[ObjectID] = []
 
         def wants_wrapper(
-            refs: Mapping[bytes, bytes], depth: int | None = None
-        ) -> list[bytes]:
+            refs: Mapping[Ref, ObjectID], depth: int | None = None
+        ) -> list[ObjectID]:
             wants.extend(graph_walker.determine_wants(refs, depth))
             return wants
 
@@ -612,7 +613,7 @@ def _split_proto_line(
 
 
 def _want_satisfied(
-    store: ObjectContainer, haves: set[bytes], want: bytes, earliest: int
+    store: ObjectContainer, haves: set[ObjectID], want: ObjectID, earliest: int
 ) -> bool:
     """Check if a specific want is satisfied by a set of haves.
 
@@ -646,7 +647,7 @@ def _want_satisfied(
 
 
 def _all_wants_satisfied(
-    store: ObjectContainer, haves: AbstractSet[bytes], wants: set[bytes]
+    store: ObjectContainer, haves: AbstractSet[ObjectID], wants: set[ObjectID]
 ) -> bool:
     """Check whether all the current wants are satisfied by a set of haves.
 
@@ -712,8 +713,8 @@ class _ProtocolGraphWalker:
         self,
         handler: PackHandler,
         object_store: ObjectContainer,
-        get_peeled: Callable[[bytes], bytes | None],
-        get_symrefs: Callable[[], dict[bytes, bytes]],
+        get_peeled: Callable[[bytes], ObjectID | None],
+        get_symrefs: Callable[[], dict[Ref, Ref]],
     ) -> None:
         """Initialize a ProtocolGraphWalker.
 
@@ -730,18 +731,18 @@ class _ProtocolGraphWalker:
         self.proto = handler.proto
         self.stateless_rpc = handler.stateless_rpc
         self.advertise_refs = handler.advertise_refs
-        self._wants: list[bytes] = []
-        self.shallow: set[bytes] = set()
-        self.client_shallow: set[bytes] = set()
-        self.unshallow: set[bytes] = set()
+        self._wants: list[ObjectID] = []
+        self.shallow: set[ObjectID] = set()
+        self.client_shallow: set[ObjectID] = set()
+        self.unshallow: set[ObjectID] = set()
         self._cached = False
-        self._cache: list[bytes] = []
+        self._cache: list[ObjectID] = []
         self._cache_index = 0
         self._impl: AckGraphWalkerImpl | None = None
 
     def determine_wants(
-        self, heads: Mapping[bytes, bytes], depth: int | None = None
-    ) -> list[bytes]:
+        self, heads: Mapping[Ref, ObjectID], depth: int | None = None
+    ) -> list[ObjectID]:
         """Determine the wants for a set of heads.
 
         The given heads are advertised to the client, who then specifies which
@@ -804,12 +805,12 @@ class _ProtocolGraphWalker:
         allowed = (COMMAND_WANT, COMMAND_SHALLOW, COMMAND_DEEPEN, None)
         command, sha_result = _split_proto_line(line, allowed)
 
-        want_revs = []
+        want_revs: list[ObjectID] = []
         while command == COMMAND_WANT:
             assert isinstance(sha_result, bytes)
             if sha_result not in values:
                 raise GitProtocolError(f"Client wants invalid object {sha_result!r}")
-            want_revs.append(sha_result)
+            want_revs.append(ObjectID(sha_result))
             command, sha_result = self.read_proto_line(allowed)
 
         self.set_wants(want_revs)
@@ -841,7 +842,7 @@ class _ProtocolGraphWalker:
     def nak(self) -> None:
         """Send a NAK response."""
 
-    def ack(self, have_ref: bytes) -> None:
+    def ack(self, have_ref: ObjectID) -> None:
         """Acknowledge a have reference.
 
         Args:
@@ -860,7 +861,7 @@ class _ProtocolGraphWalker:
         self._cached = True
         self._cache_index = 0
 
-    def next(self) -> bytes | None:
+    def next(self) -> ObjectID | None:
         """Get the next SHA from the graph walker.
 
         Returns: Next SHA or None if done
@@ -891,7 +892,7 @@ class _ProtocolGraphWalker:
         """
         return _split_proto_line(self.proto.read_pkt_line(), allowed)
 
-    def _handle_shallow_request(self, wants: Sequence[bytes]) -> None:
+    def _handle_shallow_request(self, wants: Sequence[ObjectID]) -> None:
         """Handle shallow clone requests from the client.
 
         Args:
@@ -904,7 +905,7 @@ class _ProtocolGraphWalker:
                 depth = val
                 break
             assert isinstance(val, bytes)
-            self.client_shallow.add(val)
+            self.client_shallow.add(ObjectID(val))
         self.read_proto_line((None,))  # consume client's flush-pkt
 
         shallow, not_shallow = find_shallow(self.store, wants, depth)
@@ -918,7 +919,7 @@ class _ProtocolGraphWalker:
         self.update_shallow(new_shallow, unshallow)
 
     def update_shallow(
-        self, new_shallow: AbstractSet[bytes], unshallow: AbstractSet[bytes]
+        self, new_shallow: AbstractSet[ObjectID], unshallow: AbstractSet[ObjectID]
     ) -> None:
         """Update shallow/unshallow information to the client.
 
@@ -938,7 +939,7 @@ class _ProtocolGraphWalker:
         # relay the message down to the handler.
         self.handler.notify_done()
 
-    def send_ack(self, sha: bytes, ack_type: bytes = b"") -> None:
+    def send_ack(self, sha: ObjectID, ack_type: bytes = b"") -> None:
         """Send an ACK to the client.
 
         Args:
@@ -963,7 +964,7 @@ class _ProtocolGraphWalker:
         assert self._impl is not None
         return self._impl.handle_done(done_required, done_received)
 
-    def set_wants(self, wants: list[bytes]) -> None:
+    def set_wants(self, wants: list[ObjectID]) -> None:
         """Set the list of wanted objects.
 
         Args:
@@ -971,7 +972,7 @@ class _ProtocolGraphWalker:
         """
         self._wants = wants
 
-    def all_wants_satisfied(self, haves: AbstractSet[bytes]) -> bool:
+    def all_wants_satisfied(self, haves: AbstractSet[ObjectID]) -> bool:
         """Check whether all the current wants are satisfied by a set of haves.
 
         Args:
@@ -1008,9 +1009,9 @@ class SingleAckGraphWalkerImpl(AckGraphWalkerImpl):
           walker: Parent ProtocolGraphWalker instance
         """
         self.walker = walker
-        self._common: list[bytes] = []
+        self._common: list[ObjectID] = []
 
-    def ack(self, have_ref: bytes) -> None:
+    def ack(self, have_ref: ObjectID) -> None:
         """Acknowledge a have reference.
 
         Args:
@@ -1020,7 +1021,7 @@ class SingleAckGraphWalkerImpl(AckGraphWalkerImpl):
             self.walker.send_ack(have_ref)
             self._common.append(have_ref)
 
-    def next(self) -> bytes | None:
+    def next(self) -> ObjectID | None:
         """Get next SHA from graph walker.
 
         Returns:
@@ -1033,7 +1034,7 @@ class SingleAckGraphWalkerImpl(AckGraphWalkerImpl):
             return None
         elif command == COMMAND_HAVE:
             assert isinstance(sha, bytes)
-            return sha
+            return ObjectID(sha)
         return None
 
     __next__ = next
@@ -1079,9 +1080,9 @@ class MultiAckGraphWalkerImpl(AckGraphWalkerImpl):
         """
         self.walker = walker
         self._found_base = False
-        self._common: list[bytes] = []
+        self._common: list[ObjectID] = []
 
-    def ack(self, have_ref: bytes) -> None:
+    def ack(self, have_ref: ObjectID) -> None:
         """Acknowledge a have reference.
 
         Args:
@@ -1094,7 +1095,7 @@ class MultiAckGraphWalkerImpl(AckGraphWalkerImpl):
                 self._found_base = True
         # else we blind ack within next
 
-    def next(self) -> bytes | None:
+    def next(self) -> ObjectID | None:
         """Get next SHA from graph walker.
 
         Returns:
@@ -1112,10 +1113,11 @@ class MultiAckGraphWalkerImpl(AckGraphWalkerImpl):
                 return None
             elif command == COMMAND_HAVE:
                 assert isinstance(sha, bytes)
+                sha_id = ObjectID(sha)
                 if self._found_base:
                     # blind ack
-                    self.walker.send_ack(sha, b"continue")
-                return sha
+                    self.walker.send_ack(sha_id, b"continue")
+                return sha_id
 
     __next__ = next
 
@@ -1162,9 +1164,9 @@ class MultiAckDetailedGraphWalkerImpl(AckGraphWalkerImpl):
             walker: Parent ProtocolGraphWalker instance
         """
         self.walker = walker
-        self._common: list[bytes] = []
+        self._common: list[ObjectID] = []
 
-    def ack(self, have_ref: bytes) -> None:
+    def ack(self, have_ref: ObjectID) -> None:
         """Acknowledge a have reference.
 
         Args:
@@ -1174,7 +1176,7 @@ class MultiAckDetailedGraphWalkerImpl(AckGraphWalkerImpl):
         self._common.append(have_ref)
         self.walker.send_ack(have_ref, b"common")
 
-    def next(self) -> bytes | None:
+    def next(self) -> ObjectID | None:
         """Get next SHA from graph walker.
 
         Returns:
@@ -1203,7 +1205,7 @@ class MultiAckDetailedGraphWalkerImpl(AckGraphWalkerImpl):
                 # return the sha and let the caller ACK it with the
                 # above ack method.
                 assert isinstance(sha, bytes)
-                return sha
+                return ObjectID(sha)
         # don't nak unless no common commits were found, even if not
         # everything is satisfied
         return None
@@ -1432,7 +1434,7 @@ class ReceivePackHandler(PackHandler):
         # client will now send us a list of (oldsha, newsha, ref)
         while ref_line:
             (oldsha, newsha, ref_name) = ref_line.split()
-            client_refs.append((oldsha, newsha, ref_name))
+            client_refs.append((ObjectID(oldsha), ObjectID(newsha), Ref(ref_name)))
             ref_line = self.proto.read_pkt_line()
 
         # backend can now deal with this refs and read a pack using self.read
@@ -1492,7 +1494,7 @@ class UploadArchiveHandler(Handler):
                 i += 1
                 format = arguments[i].decode("ascii")
             else:
-                commit_sha = self.repo.refs[argument]
+                commit_sha = self.repo.refs[Ref(argument)]
                 commit_obj = store[commit_sha]
                 assert isinstance(commit_obj, Commit)
                 tree_obj = store[commit_obj.tree]

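The determine_wants signature threaded through the server code above, Callable[[Mapping[Ref, ObjectID], int | None], list[ObjectID]], describes the callback that picks which advertised objects to fetch. A minimal sketch of a conforming callback under those annotations (the function name here is made up):

    from collections.abc import Mapping

    from dulwich.objects import ObjectID
    from dulwich.refs import Ref

    def wants_branch_heads(
        refs: Mapping[Ref, ObjectID], depth: int | None = None
    ) -> list[ObjectID]:
        # Ref and ObjectID are still bytes at runtime, so bytes methods work;
        # only the static types are narrowed.
        return [sha for ref, sha in refs.items() if ref.startswith(b"refs/heads/")]
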
+ 4 - 2
dulwich/stash.py

@@ -58,7 +58,7 @@ class CommitKwargs(TypedDict, total=False):
     author: bytes
 
 
-DEFAULT_STASH_REF = b"refs/stash"
+DEFAULT_STASH_REF = Ref(b"refs/stash")
 
 
 class Stash:
@@ -135,7 +135,9 @@ class Stash:
 
         # Get current HEAD to determine if we can apply cleanly
         try:
-            current_head = self._repo.refs[b"HEAD"]
+            from dulwich.refs import HEADREF
+
+            current_head = self._repo.refs[HEADREF]
         except KeyError:
             raise ValueError("Cannot pop stash: no HEAD")
 

+ 3 - 2
dulwich/submodule.py

@@ -29,13 +29,14 @@ from .object_store import iter_tree_contents
 from .objects import S_ISGITLINK
 
 if TYPE_CHECKING:
+    from .objects import ObjectID
     from .pack import ObjectContainer
     from .repo import Repo
 
 
 def iter_cached_submodules(
-    store: "ObjectContainer", root_tree_id: bytes
-) -> Iterator[tuple[bytes, bytes]]:
+    store: "ObjectContainer", root_tree_id: "ObjectID"
+) -> Iterator[tuple[bytes, "ObjectID"]]:
     """Iterate over cached submodules.
 
     Args:

+ 14 - 7
dulwich/tests/test_object_store.py

@@ -38,12 +38,14 @@ from dulwich.object_store import (
 from dulwich.objects import (
     Blob,
     Commit,
+    ObjectID,
     ShaFile,
     Tag,
     Tree,
     TreeEntry,
 )
 from dulwich.protocol import DEPTH_INFINITE
+from dulwich.refs import Ref
 
 from .utils import make_commit, make_object, make_tag
 
@@ -71,20 +73,25 @@ class ObjectStoreTests:
     def test_determine_wants_all(self) -> None:
         """Test determine_wants_all with valid ref."""
         self.assertEqual(
-            [b"1" * 40],
-            self.store.determine_wants_all({b"refs/heads/foo": b"1" * 40}),
+            [ObjectID(b"1" * 40)],
+            self.store.determine_wants_all(
+                {Ref(b"refs/heads/foo"): ObjectID(b"1" * 40)}
+            ),
         )
 
     def test_determine_wants_all_zero(self) -> None:
         """Test determine_wants_all with zero ref."""
         self.assertEqual(
-            [], self.store.determine_wants_all({b"refs/heads/foo": b"0" * 40})
+            [],
+            self.store.determine_wants_all(
+                {Ref(b"refs/heads/foo"): ObjectID(b"0" * 40)}
+            ),
         )
 
     def test_determine_wants_all_depth(self) -> None:
         """Test determine_wants_all with depth parameter."""
         self.store.add_object(testobject)
-        refs = {b"refs/heads/foo": testobject.id}
+        refs = {Ref(b"refs/heads/foo"): testobject.id}
         with patch.object(self.store, "_get_depth", return_value=1) as m:
             self.assertEqual([], self.store.determine_wants_all(refs, depth=0))
             self.assertEqual(
@@ -124,7 +131,7 @@ class ObjectStoreTests:
 
     def test_get_nonexistant(self) -> None:
         """Test getting non-existent object raises KeyError."""
-        self.assertRaises(KeyError, lambda: self.store[b"a" * 40])
+        self.assertRaises(KeyError, lambda: self.store[ObjectID(b"a" * 40)])
 
     def test_contains_nonexistant(self) -> None:
         """Test checking for non-existent object."""
@@ -300,7 +307,7 @@ class ObjectStoreTests:
         """Test iterating with missing objects when not allowed."""
         blob1 = make_object(Blob, data=b"blob 1 data")
         self.store.add_object(blob1)
-        missing_sha = b"1" * 40
+        missing_sha = ObjectID(b"1" * 40)
 
         self.assertRaises(
             KeyError,
@@ -311,7 +318,7 @@ class ObjectStoreTests:
         """Test iterating with missing objects when allowed."""
         blob1 = make_object(Blob, data=b"blob 1 data")
         self.store.add_object(blob1)
-        missing_sha = b"1" * 40
+        missing_sha = ObjectID(b"1" * 40)
 
         objects = list(
             self.store.iterobjects_subset([blob1.id, missing_sha], allow_missing=True)

+ 4 - 4
dulwich/walk.py

@@ -274,8 +274,8 @@ class Walker:
     def __init__(
         self,
         store: "BaseObjectStore",
-        include: Sequence[bytes],
-        exclude: Sequence[bytes] | None = None,
+        include: ObjectID | Sequence[ObjectID],
+        exclude: Sequence[ObjectID] | None = None,
         order: str = "date",
         reverse: bool = False,
         max_entries: int | None = None,
@@ -284,7 +284,7 @@ class Walker:
         follow: bool = False,
         since: int | None = None,
         until: int | None = None,
-        get_parents: Callable[[Commit], list[bytes]] = lambda commit: commit.parents,
+        get_parents: Callable[[Commit], list[ObjectID]] = lambda commit: commit.parents,
         queue_cls: type = _CommitTimeQueue,
     ) -> None:
         """Constructor.
@@ -468,7 +468,7 @@ class Walker:
 
 def _topo_reorder(
     entries: Iterator[WalkEntry],
-    get_parents: Callable[[Commit], list[bytes]] = lambda commit: commit.parents,
+    get_parents: Callable[[Commit], list[ObjectID]] = lambda commit: commit.parents,
 ) -> Iterator[WalkEntry]:
     """Reorder an iterable of entries topologically.
 

+ 2 - 1
dulwich/web.py

@@ -107,6 +107,7 @@ else:
 from dulwich import log_utils
 
 from .errors import NotGitRepository
+from .objects import ObjectID
 from .protocol import ReceivableProtocol
 from .repo import BaseRepo, Repo
 from .server import (
@@ -309,7 +310,7 @@ def get_loose_object(
     Returns:
       Iterator yielding object contents as bytes
     """
-    sha = (mat.group(1) + mat.group(2)).encode("ascii")
+    sha = cast(ObjectID, (mat.group(1) + mat.group(2)).encode("ascii"))
     logger.info("Sending loose object %s", sha)
     object_store = get_repo(backend, mat).object_store
     if not object_store.contains_loose(sha):

+ 10 - 8
dulwich/worktree.py

@@ -67,7 +67,7 @@ class WorkTreeInfo:
         self,
         path: str,
         head: bytes | None = None,
-        branch: bytes | None = None,
+        branch: Ref | None = None,
         bare: bool = False,
         detached: bool = False,
         locked: bool = False,
@@ -358,7 +358,7 @@ class WorkTree:
 
         index = self._repo.open_index()
         try:
-            commit = self._repo[b"HEAD"]
+            commit = self._repo[Ref(b"HEAD")]
         except KeyError:
             # no HEAD means no commits in the repo
             for fs_path in fs_paths:
@@ -425,7 +425,7 @@ class WorkTree:
         author_timezone: int | None = None,
         tree: ObjectID | None = None,
         encoding: bytes | None = None,
-        ref: Ref | None = b"HEAD",
+        ref: Ref | None = Ref(b"HEAD"),
         merge_heads: Sequence[ObjectID] | None = None,
         no_verify: bool = False,
         sign: bool | None = None,
@@ -670,7 +670,7 @@ class WorkTree:
 
         return c.id
 
-    def reset_index(self, tree: bytes | None = None) -> None:
+    def reset_index(self, tree: ObjectID | None = None) -> None:
         """Reset the index back to a specific tree.
 
         Args:
@@ -685,7 +685,7 @@ class WorkTree:
         )
 
         if tree is None:
-            head = self._repo[b"HEAD"]
+            head = self._repo[Ref(b"HEAD")]
             if isinstance(head, Tag):
                 _cls, obj = head.object
                 head = self._repo.get_object(obj)
@@ -840,7 +840,7 @@ def list_worktrees(repo: Repo) -> list[WorkTreeInfo]:
         with open(os.path.join(repo.controldir(), "HEAD"), "rb") as f:
             head_contents = f.read().strip()
             if head_contents.startswith(SYMREF):
-                ref_name = head_contents[len(SYMREF) :].strip()
+                ref_name = Ref(head_contents[len(SYMREF) :].strip())
                 main_wt_info.branch = ref_name
             else:
                 main_wt_info.detached = True
@@ -892,7 +892,7 @@ def list_worktrees(repo: Repo) -> list[WorkTreeInfo]:
                 with open(head_path, "rb") as f:
                     head_contents = f.read().strip()
                     if head_contents.startswith(SYMREF):
-                        ref_name = head_contents[len(SYMREF) :].strip()
+                        ref_name = Ref(head_contents[len(SYMREF) :].strip())
                         wt_info.branch = ref_name
                         # Resolve ref to get commit sha
                         try:
@@ -1005,7 +1005,9 @@ def add_worktree(
     else:
         # Point to branch
         assert branch is not None  # Should be guaranteed by logic above
-        wt_repo.refs.set_symbolic_ref(b"HEAD", branch)
+        from dulwich.refs import HEADREF
+
+        wt_repo.refs.set_symbolic_ref(HEADREF, branch)
 
     # Reset index to match HEAD
     wt_repo.get_worktree().reset_index()

+ 3 - 3
tests/compat/test_reftable.py

@@ -498,7 +498,7 @@ class ReftableCompatTestCase(CompatTestCase):
 
         # Delete a ref using dulwich
         with repo.refs.batch_update():
-            repo.refs.set_if_equals(b"refs/heads/feature", commit_sha2, None)
+            repo.refs.remove_if_equals(b"refs/heads/feature", commit_sha2)
 
         repo.close()
 
@@ -725,8 +725,8 @@ class ReftableCompatTestCase(CompatTestCase):
             repo.refs.set_if_equals(
                 b"refs/heads/develop", commits[1], commits[4]
             )  # Update develop
-            repo.refs.set_if_equals(
-                b"refs/heads/feature", commits[3], None
+            repo.refs.remove_if_equals(
+                b"refs/heads/feature", commits[3]
             )  # Delete feature
             repo.refs.set_symbolic_ref(
                 b"HEAD", b"refs/heads/develop"

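The test change above swaps set_if_equals(name, old, None) for remove_if_equals(name, old) when deleting a ref; both delete the ref only if it still points at the expected old value. A minimal usage sketch (repo and old_sha are placeholders):

    # Delete refs/heads/feature only if it still points at old_sha;
    # in dulwich, remove_if_equals returns True when the deletion happened.
    repo.refs.remove_if_equals(b"refs/heads/feature", old_sha)
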
+ 3 - 1
tests/test_bundle.py

@@ -413,7 +413,9 @@ class BundleTests(TestCase):
         repo.refs[b"refs/heads/feature"] = commit.id
 
         # Create bundle with only master ref
-        bundle = create_bundle_from_repo(repo, refs=[b"refs/heads/master"])
+        from dulwich.refs import Ref
+
+        bundle = create_bundle_from_repo(repo, refs=[Ref(b"refs/heads/master")])
 
         # Verify only master ref is included
         self.assertEqual(len(bundle.references), 1)

Some files were not shown because too many files changed