Explorar el Código

commit_graph: Add typing

Jelmer Vernooij hace 2 meses
padre
commit
2995fdb39b
Se han modificado 1 ficheros con 21 adiciones y 10 borrados
  1. 21 10
      dulwich/commit_graph.py

+ 21 - 10
dulwich/commit_graph.py

@@ -18,9 +18,13 @@ https://git-scm.com/docs/gitformat-commit-graph
 
 import os
 import struct
-from typing import BinaryIO, Optional, Union
+from collections.abc import Iterator
+from typing import TYPE_CHECKING, BinaryIO, Optional, Union
 
-from .objects import ObjectID, hex_to_sha, sha_to_hex
+if TYPE_CHECKING:
+    from .object_store import BaseObjectStore
+
+from .objects import Commit, ObjectID, hex_to_sha, sha_to_hex
 
 # File format constants
 COMMIT_GRAPH_SIGNATURE = b"CGPH"
@@ -358,7 +362,7 @@ class CommitGraph:
         """Return number of commits in the graph."""
         return len(self.entries)
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator["CommitGraphEntry"]:
         """Iterate over commit graph entries."""
         return iter(self.entries)
 
@@ -396,7 +400,9 @@ def find_commit_graph_file(git_dir: Union[str, bytes]) -> Optional[bytes]:
     return None
 
 
-def generate_commit_graph(object_store, commit_ids: list[ObjectID]) -> CommitGraph:
+def generate_commit_graph(
+    object_store: "BaseObjectStore", commit_ids: list[ObjectID]
+) -> CommitGraph:
     """Generate a commit graph from a set of commits.
 
     Args:
@@ -426,12 +432,13 @@ def generate_commit_graph(object_store, commit_ids: list[ObjectID]) -> CommitGra
             normalized_commit_ids.append(commit_id)
 
     # Build a map of all commits and their metadata
-    commit_map = {}
+    commit_map: dict[bytes, Commit] = {}
     for commit_id in normalized_commit_ids:
         try:
             commit_obj = object_store[commit_id]
             if commit_obj.type_name != b"commit":
                 continue
+            assert isinstance(commit_obj, Commit)
             commit_map[commit_id] = commit_obj
         except KeyError:
             # Commit not found, skip
@@ -440,7 +447,7 @@ def generate_commit_graph(object_store, commit_ids: list[ObjectID]) -> CommitGra
     # Calculate generation numbers using topological sort
     generation_map: dict[bytes, int] = {}
 
-    def calculate_generation(commit_id):
+    def calculate_generation(commit_id: ObjectID) -> int:
         if commit_id in generation_map:
             return generation_map[commit_id]
 
@@ -507,7 +514,9 @@ def generate_commit_graph(object_store, commit_ids: list[ObjectID]) -> CommitGra
 
 
 def write_commit_graph(
-    git_dir: Union[str, bytes], object_store, commit_ids: list[ObjectID]
+    git_dir: Union[str, bytes],
+    object_store: "BaseObjectStore",
+    commit_ids: list[ObjectID],
 ) -> None:
     """Write a commit graph file for the given commits.
 
@@ -534,11 +543,13 @@ def write_commit_graph(
 
     graph_path = os.path.join(info_dir, b"commit-graph")
     with GitFile(graph_path, "wb") as f:
-        graph.write_to_file(f)
+        from typing import BinaryIO, cast
+
+        graph.write_to_file(cast(BinaryIO, f))
 
 
 def get_reachable_commits(
-    object_store, start_commits: list[ObjectID]
+    object_store: "BaseObjectStore", start_commits: list[ObjectID]
 ) -> list[ObjectID]:
     """Get all commits reachable from the given starting commits.
 
@@ -578,7 +589,7 @@ def get_reachable_commits(
 
         try:
             commit_obj = object_store[commit_id]
-            if commit_obj.type_name != b"commit":
+            if not isinstance(commit_obj, Commit):
                 continue
 
             # Add to reachable list (commit_id is already hex ObjectID)