Просмотр исходного кода

Add object_format support to commit_graph module

Replace hardcoded SHA-1 byte lengths with object_format attributes to
enable SHA-256 support in commit graph files.
Jelmer Vernooij 2 месяца назад
Родитель
Commit
f561be1f02
1 изменённый файл: 56 добавлений и 32 удаления
  1. 56 32
      dulwich/commit_graph.py

+ 56 - 32
dulwich/commit_graph.py

@@ -156,17 +156,27 @@ class CommitGraphChunk:
 class CommitGraph:
     """Git commit graph file reader/writer."""
 
-    def __init__(self, hash_version: int = HASH_VERSION_SHA1) -> None:
+    def __init__(self, *, object_format=None) -> None:
         """Initialize CommitGraph.
 
         Args:
-          hash_version: Hash version to use (SHA1 or SHA256)
+          object_format: Object format to use (defaults to SHA1)
         """
-        self.hash_version = hash_version
+        import warnings
+        from .object_format import DEFAULT_OBJECT_FORMAT, SHA256
+
+        if object_format is None:
+            warnings.warn(
+                "CommitGraph() should be called with object_format parameter",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            object_format = DEFAULT_OBJECT_FORMAT
+        self.object_format = object_format
+        self.hash_version = HASH_VERSION_SHA256 if object_format == SHA256 else HASH_VERSION_SHA1
         self.chunks: dict[bytes, CommitGraphChunk] = {}
         self.entries: list[CommitGraphEntry] = []
         self._oid_to_index: dict[ObjectID, int] = {}
-        self._hash_size = 20 if hash_version == HASH_VERSION_SHA1 else 32
 
     @classmethod
     def from_file(cls, f: BinaryIO) -> "CommitGraph":
@@ -187,10 +197,15 @@ class CommitGraph:
             raise ValueError(f"Unsupported commit graph version: {version}")
 
         self.hash_version = struct.unpack(">B", f.read(1))[0]
-        if self.hash_version not in (HASH_VERSION_SHA1, HASH_VERSION_SHA256):
-            raise ValueError(f"Unsupported hash version: {self.hash_version}")
 
-        self._hash_size = 20 if self.hash_version == HASH_VERSION_SHA1 else 32
+        # Set object_format based on hash_version from file
+        from .object_format import SHA1, SHA256
+        if self.hash_version == HASH_VERSION_SHA1:
+            self.object_format = SHA1
+        elif self.hash_version == HASH_VERSION_SHA256:
+            self.object_format = SHA256
+        else:
+            raise ValueError(f"Unsupported hash version: {self.hash_version}")
 
         num_chunks = struct.unpack(">B", f.read(1))[0]
         struct.unpack(">B", f.read(1))[0]
@@ -225,19 +240,19 @@ class CommitGraph:
 
         # Parse OID lookup chunk
         oid_lookup_data = self.chunks[CHUNK_OID_LOOKUP].data
-        num_commits = len(oid_lookup_data) // self._hash_size
+        num_commits = len(oid_lookup_data) // self.object_format.oid_length
 
         oids = []
         for i in range(num_commits):
-            start = i * self._hash_size
-            end = start + self._hash_size
-            oid = RawObjectID(oid_lookup_data[start:end])
+            start = i * self.object_format.oid_length
+            end = start + self.object_format.oid_length
+            oid = oid_lookup_data[start:end]
             oids.append(oid)
             self._oid_to_index[sha_to_hex(oid)] = i
 
         # Parse commit data chunk
         commit_data = self.chunks[CHUNK_COMMIT_DATA].data
-        expected_size = num_commits * (self._hash_size + 16)
+        expected_size = num_commits * (self.object_format.oid_length + 16)
         if len(commit_data) != expected_size:
             raise ValueError(
                 f"Invalid commit data chunk size: {len(commit_data)}, expected {expected_size}"
@@ -245,11 +260,11 @@ class CommitGraph:
 
         self.entries = []
         for i in range(num_commits):
-            offset = i * (self._hash_size + 16)
+            offset = i * (self.object_format.oid_length + 16)
 
             # Tree OID
-            tree_id = RawObjectID(commit_data[offset : offset + self._hash_size])
-            offset += self._hash_size
+            tree_id = commit_data[offset : offset + self.object_format.oid_length]
+            offset += self.object_format.oid_length
 
             # Parent positions (2 x 4 bytes)
             parent1_pos, parent2_pos = struct.unpack(
@@ -314,7 +329,14 @@ class CommitGraph:
 
     def get_entry_by_oid(self, oid: ObjectID) -> CommitGraphEntry | None:
         """Get commit graph entry by commit OID."""
-        index = self._oid_to_index.get(oid)
+        # Convert hex ObjectID to binary if needed for lookup
+        if isinstance(oid, bytes) and len(oid) == self.object_format.hex_length:
+            # Input is hex ObjectID, convert to binary for internal lookup
+            lookup_oid = hex_to_sha(oid)
+        else:
+            # Input is already binary
+            lookup_oid = oid
+        index = self._oid_to_index.get(lookup_oid)
         if index is not None:
             return self.entries[index]
         return None
@@ -472,19 +494,20 @@ def generate_commit_graph(
     Returns:
         CommitGraph object containing the specified commits
     """
-    graph = CommitGraph()
+    graph = CommitGraph(object_format=object_store.object_format)
 
     if not commit_ids:
         return graph
 
     # Ensure all commit_ids are in the correct format for object store access
-    # DiskObjectStore expects hex ObjectIDs (40-byte hex strings)
-    normalized_commit_ids: list[ObjectID] = []
+    hex_length = object_store.object_format.hex_length
+    oid_length = object_store.object_format.oid_length
+    normalized_commit_ids = []
     for commit_id in commit_ids:
-        if isinstance(commit_id, bytes) and len(commit_id) == 40:
+        if isinstance(commit_id, bytes) and len(commit_id) == hex_length:
             # Already hex ObjectID
-            normalized_commit_ids.append(ObjectID(commit_id))
-        elif isinstance(commit_id, bytes) and len(commit_id) == 20:
+            normalized_commit_ids.append(commit_id)
+        elif isinstance(commit_id, bytes) and len(commit_id) == oid_length:
             # Binary SHA, convert to hex ObjectID
             normalized_commit_ids.append(sha_to_hex(RawObjectID(commit_id)))
         else:
@@ -542,17 +565,16 @@ def generate_commit_graph(
         commit_hex: ObjectID = commit_id
 
         # Handle tree ID - might already be hex ObjectID
-        tree_hex: ObjectID
-        if isinstance(commit_obj.tree, bytes) and len(commit_obj.tree) == 40:
-            tree_hex = ObjectID(commit_obj.tree)  # Already hex ObjectID
+        if isinstance(commit_obj.tree, bytes) and len(commit_obj.tree) == hex_length:
+            tree_hex = commit_obj.tree  # Already hex ObjectID
         else:
             tree_hex = sha_to_hex(commit_obj.tree)  # Binary, convert to hex
 
         # Handle parent IDs - might already be hex ObjectIDs
         parents_hex: list[ObjectID] = []
         for parent_id in commit_obj.parents:
-            if isinstance(parent_id, bytes) and len(parent_id) == 40:
-                parents_hex.append(ObjectID(parent_id))  # Already hex ObjectID
+            if isinstance(parent_id, bytes) and len(parent_id) == hex_length:
+                parents_hex.append(parent_id)  # Already hex ObjectID
             else:
                 parents_hex.append(sha_to_hex(parent_id))  # Binary, convert to hex
 
@@ -622,14 +644,16 @@ def get_reachable_commits(
     reachable: list[ObjectID] = []
     stack: list[ObjectID] = []
 
+    hex_length = object_store.object_format.hex_length
+    oid_length = object_store.object_format.oid_length
+
     # Normalize commit IDs for object store access and tracking
     for commit_id in start_commits:
-        if isinstance(commit_id, bytes) and len(commit_id) == 40:
+        if isinstance(commit_id, bytes) and len(commit_id) == hex_length:
             # Hex ObjectID - use directly for object store access
-            oid = ObjectID(commit_id)
-            if oid not in visited:
-                stack.append(oid)
-        elif isinstance(commit_id, bytes) and len(commit_id) == 20:
+            if commit_id not in visited:
+                stack.append(commit_id)
+        elif isinstance(commit_id, bytes) and len(commit_id) == oid_length:
             # Binary SHA, convert to hex ObjectID for object store access
             hex_id = sha_to_hex(RawObjectID(commit_id))
             if hex_id not in visited: