|
@@ -18,9 +18,13 @@ https://git-scm.com/docs/gitformat-commit-graph
|
|
|
|
|
|
import os
|
|
|
import struct
|
|
|
-from typing import BinaryIO, Optional, Union
|
|
|
+from collections.abc import Iterator
|
|
|
+from typing import TYPE_CHECKING, BinaryIO, Optional, Union
|
|
|
|
|
|
-from .objects import ObjectID, hex_to_sha, sha_to_hex
|
|
|
+if TYPE_CHECKING:
|
|
|
+ from .object_store import BaseObjectStore
|
|
|
+
|
|
|
+from .objects import Commit, ObjectID, hex_to_sha, sha_to_hex
|
|
|
|
|
|
# File format constants
|
|
|
COMMIT_GRAPH_SIGNATURE = b"CGPH"
|
|
@@ -358,7 +362,7 @@ class CommitGraph:
|
|
|
"""Return number of commits in the graph."""
|
|
|
return len(self.entries)
|
|
|
|
|
|
- def __iter__(self):
|
|
|
+ def __iter__(self) -> Iterator["CommitGraphEntry"]:
|
|
|
"""Iterate over commit graph entries."""
|
|
|
return iter(self.entries)
|
|
|
|
|
@@ -396,7 +400,9 @@ def find_commit_graph_file(git_dir: Union[str, bytes]) -> Optional[bytes]:
|
|
|
return None
|
|
|
|
|
|
|
|
|
-def generate_commit_graph(object_store, commit_ids: list[ObjectID]) -> CommitGraph:
|
|
|
+def generate_commit_graph(
|
|
|
+ object_store: "BaseObjectStore", commit_ids: list[ObjectID]
|
|
|
+) -> CommitGraph:
|
|
|
"""Generate a commit graph from a set of commits.
|
|
|
|
|
|
Args:
|
|
@@ -426,12 +432,13 @@ def generate_commit_graph(object_store, commit_ids: list[ObjectID]) -> CommitGra
|
|
|
normalized_commit_ids.append(commit_id)
|
|
|
|
|
|
# Build a map of all commits and their metadata
|
|
|
- commit_map = {}
|
|
|
+ commit_map: dict[bytes, Commit] = {}
|
|
|
for commit_id in normalized_commit_ids:
|
|
|
try:
|
|
|
commit_obj = object_store[commit_id]
|
|
|
if commit_obj.type_name != b"commit":
|
|
|
continue
|
|
|
+ assert isinstance(commit_obj, Commit)
|
|
|
commit_map[commit_id] = commit_obj
|
|
|
except KeyError:
|
|
|
# Commit not found, skip
|
|
@@ -440,7 +447,7 @@ def generate_commit_graph(object_store, commit_ids: list[ObjectID]) -> CommitGra
|
|
|
# Calculate generation numbers using topological sort
|
|
|
generation_map: dict[bytes, int] = {}
|
|
|
|
|
|
- def calculate_generation(commit_id):
|
|
|
+ def calculate_generation(commit_id: ObjectID) -> int:
|
|
|
if commit_id in generation_map:
|
|
|
return generation_map[commit_id]
|
|
|
|
|
@@ -507,7 +514,9 @@ def generate_commit_graph(object_store, commit_ids: list[ObjectID]) -> CommitGra
|
|
|
|
|
|
|
|
|
def write_commit_graph(
|
|
|
- git_dir: Union[str, bytes], object_store, commit_ids: list[ObjectID]
|
|
|
+ git_dir: Union[str, bytes],
|
|
|
+ object_store: "BaseObjectStore",
|
|
|
+ commit_ids: list[ObjectID],
|
|
|
) -> None:
|
|
|
"""Write a commit graph file for the given commits.
|
|
|
|
|
@@ -534,11 +543,13 @@ def write_commit_graph(
|
|
|
|
|
|
graph_path = os.path.join(info_dir, b"commit-graph")
|
|
|
with GitFile(graph_path, "wb") as f:
|
|
|
- graph.write_to_file(f)
|
|
|
+ from typing import BinaryIO, cast
|
|
|
+
|
|
|
+ graph.write_to_file(cast(BinaryIO, f))
|
|
|
|
|
|
|
|
|
def get_reachable_commits(
|
|
|
- object_store, start_commits: list[ObjectID]
|
|
|
+ object_store: "BaseObjectStore", start_commits: list[ObjectID]
|
|
|
) -> list[ObjectID]:
|
|
|
"""Get all commits reachable from the given starting commits.
|
|
|
|
|
@@ -578,7 +589,7 @@ def get_reachable_commits(
|
|
|
|
|
|
try:
|
|
|
commit_obj = object_store[commit_id]
|
|
|
- if commit_obj.type_name != b"commit":
|
|
|
+ if not isinstance(commit_obj, Commit):
|
|
|
continue
|
|
|
|
|
|
# Add to reachable list (commit_id is already hex ObjectID)
|