Explorar el Código

Replace hardcoded SHA-1 byte lengths in pack.py with object_format attributes

Jelmer Vernooij hace 2 meses
padre
commit
5d25da17a2
Se ha modificado 1 fichero con 47 adiciones y 17 borrados
  1. 47 17
      dulwich/pack.py

+ 47 - 17
dulwich/pack.py

@@ -758,12 +758,15 @@ class MemoryPackIndex(PackIndex):
         self,
         entries: list[PackIndexEntry],
         pack_checksum: bytes | None = None,
+        *,
+        object_format: ObjectFormat | None = None,
     ) -> None:
         """Create a new MemoryPackIndex.
 
         Args:
           entries: Sequence of name, idx, crc32 (sorted)
           pack_checksum: Optional pack checksum
+          object_format: Object format (hash algorithm) to use
         """
         self._by_sha = {}
         self._by_offset = {}
@@ -773,6 +776,17 @@ class MemoryPackIndex(PackIndex):
         self._entries = entries
         self._pack_checksum = pack_checksum
 
+        # Set hash size from object format
+        if object_format:
+            self.hash_size = object_format.oid_length
+        else:
+            warnings.warn(
+                "MemoryPackIndex() should be called with object_format parameter",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            self.hash_size = 20  # Default to SHA1
+
     def get_pack_checksum(self) -> bytes | None:
         """Return the SHA checksum stored for the corresponding packfile."""
         return self._pack_checksum
@@ -788,9 +802,9 @@ class MemoryPackIndex(PackIndex):
           sha: SHA to look up (binary or hex)
         Returns: Offset in the pack file
         """
-        if len(sha) in (40, 64):  # Hex string (SHA1 or SHA256)
-            sha = hex_to_sha(cast(ObjectID, sha))
-        return self._by_sha[cast(RawObjectID, sha)]
+        if len(sha) == self.hash_size * 2:  # hex string
+            sha = hex_to_sha(sha)
+        return self._by_sha[sha]
 
     def object_sha1(self, offset: int) -> bytes:
         """Return the SHA1 for the object at the given offset."""
@@ -971,8 +985,8 @@ class FilePackIndex(PackIndex):
         lives at within the corresponding pack file. If the pack file doesn't
         have the object then None will be returned.
         """
-        if len(sha) == 40:
-            sha = hex_to_sha(cast(ObjectID, sha))
+        if len(sha) == self.hash_size * 2:  # hex string
+            sha = hex_to_sha(sha)
         try:
             return self._object_offset(sha)
         except ValueError as exc:
@@ -1056,12 +1070,9 @@ class PackIndex1(FilePackIndex):
 
     def _unpack_entry(self, i: int) -> tuple[RawObjectID, int, None]:
         base_offset = (0x100 * 4) + (i * self._entry_size)
-        if self.hash_size == 20:
-            (offset, name) = unpack_from(">L20s", self._contents, base_offset)
-        else:  # SHA256
-            offset = unpack_from(">L", self._contents, base_offset)[0]
-            name = self._contents[base_offset + 4 : base_offset + 4 + self.hash_size]
-        return (RawObjectID(name), offset, None)
+        offset = unpack_from(">L", self._contents, base_offset)[0]
+        name = self._contents[base_offset + 4 : base_offset + 4 + self.hash_size]
+        return (name, offset, None)
 
     def _unpack_name(self, i: int) -> bytes:
         offset = (0x100 * 4) + (i * self._entry_size) + 4
@@ -2552,6 +2563,7 @@ class SHA1Writer(BinaryIO):
 
 def pack_object_header(
     type_num: int, delta_base: bytes | int | None, size: int
+    object_format: "ObjectFormat" | None = None
 ) -> bytearray:
     """Create a pack object header for the given object info.
 
@@ -2559,8 +2571,18 @@ def pack_object_header(
       type_num: Numeric type of the object.
       delta_base: Delta base offset or ref, or None for whole objects.
       size: Uncompressed object size.
+      object_format: Object format (hash algorithm) to use.
     Returns: A header for a packed object.
     """
+    from .object_format import DEFAULT_OBJECT_FORMAT
+    if object_format is None:
+        warnings.warn(
+            "pack_object_header() should be called with object_format parameter",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        object_format = DEFAULT_OBJECT_FORMAT
+
     header = []
     c = (type_num << 4) | (size & 15)
     size >>= 4
@@ -2579,8 +2601,7 @@ def pack_object_header(
             delta_base >>= 7
         header.extend(ret)
     elif type_num == REF_DELTA:
-        assert isinstance(delta_base, bytes)
-        assert len(delta_base) == 20
+        assert len(delta_base) == object_format.oid_length
         header += delta_base
     return bytearray(header)
 
@@ -3955,8 +3976,8 @@ class Pack:
                 assert isinstance(base_type, int)
             elif base_type == REF_DELTA:
                 (basename, delta) = base_obj
-                assert isinstance(basename, bytes) and len(basename) == 20
-                base_offset, base_type, base_obj = get_ref(cast(RawObjectID, basename))
+                assert isinstance(basename, bytes) and len(basename) == self.object_format.oid_length
+                base_offset, base_type, base_obj = get_ref(basename)
                 assert isinstance(base_type, int)
                 if base_offset == prev_offset:  # object is based on itself
                     raise UnresolvedDeltas([basename])
@@ -4046,12 +4067,21 @@ def extend_pack(
     *,
     compression_level: int = -1,
     progress: Callable[[bytes], None] | None = None,
-) -> tuple[bytes, list[tuple["RawObjectID", int, int]]]:
+    object_format: ObjectFormat | None = None,
+) -> tuple[bytes, list[tuple[bytes, int, int]]]:
     """Extend a pack file with more objects.
 
     The caller should make sure that object_ids does not contain any objects
     that are already in the pack
     """
+    from .object_format import DEFAULT_OBJECT_FORMAT
+    if object_format is None:
+        warnings.warn(
+            "extend_pack() should be called with object_format parameter",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        object_format = DEFAULT_OBJECT_FORMAT
     # Update the header with the new number of objects.
     f.seek(0)
     _version, num_objects = read_pack_header(f.read)
@@ -4077,7 +4107,7 @@ def extend_pack(
             progress(
                 (f"writing extra base objects: {i}/{len(object_ids)}\r").encode("ascii")
             )
-        assert len(object_id) == 20
+        assert len(object_id) == object_format.oid_length
         type_num, data = get_raw(object_id)
         offset = f.tell()
         crc32 = write_pack_object(