Просмотр исходного кода

Replace hardcoded SHA-1 byte lengths in pack.py with object_format attributes

Jelmer Vernooij 2 месяцев назад
Родитель
Коммит
5d25da17a2
1 изменённый файл: 47 добавлений и 17 удалений
  1. 47 17
      dulwich/pack.py

+ 47 - 17
dulwich/pack.py

@@ -758,12 +758,15 @@ class MemoryPackIndex(PackIndex):
         self,
         self,
         entries: list[PackIndexEntry],
         entries: list[PackIndexEntry],
         pack_checksum: bytes | None = None,
         pack_checksum: bytes | None = None,
+        *,
+        object_format: ObjectFormat | None = None,
     ) -> None:
     ) -> None:
         """Create a new MemoryPackIndex.
         """Create a new MemoryPackIndex.
 
 
         Args:
         Args:
           entries: Sequence of name, idx, crc32 (sorted)
           entries: Sequence of name, idx, crc32 (sorted)
           pack_checksum: Optional pack checksum
           pack_checksum: Optional pack checksum
+          object_format: Object format (hash algorithm) to use
         """
         """
         self._by_sha = {}
         self._by_sha = {}
         self._by_offset = {}
         self._by_offset = {}
@@ -773,6 +776,17 @@ class MemoryPackIndex(PackIndex):
         self._entries = entries
         self._entries = entries
         self._pack_checksum = pack_checksum
         self._pack_checksum = pack_checksum
 
 
+        # Set hash size from object format
+        if object_format:
+            self.hash_size = object_format.oid_length
+        else:
+            warnings.warn(
+                "MemoryPackIndex() should be called with object_format parameter",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            self.hash_size = 20  # Default to SHA1
+
     def get_pack_checksum(self) -> bytes | None:
     def get_pack_checksum(self) -> bytes | None:
         """Return the SHA checksum stored for the corresponding packfile."""
         """Return the SHA checksum stored for the corresponding packfile."""
         return self._pack_checksum
         return self._pack_checksum
@@ -788,9 +802,9 @@ class MemoryPackIndex(PackIndex):
           sha: SHA to look up (binary or hex)
           sha: SHA to look up (binary or hex)
         Returns: Offset in the pack file
         Returns: Offset in the pack file
         """
         """
-        if len(sha) in (40, 64):  # Hex string (SHA1 or SHA256)
-            sha = hex_to_sha(cast(ObjectID, sha))
-        return self._by_sha[cast(RawObjectID, sha)]
+        if len(sha) == self.hash_size * 2:  # hex string
+            sha = hex_to_sha(sha)
+        return self._by_sha[sha]
 
 
     def object_sha1(self, offset: int) -> bytes:
     def object_sha1(self, offset: int) -> bytes:
         """Return the SHA1 for the object at the given offset."""
         """Return the SHA1 for the object at the given offset."""
@@ -971,8 +985,8 @@ class FilePackIndex(PackIndex):
         lives at within the corresponding pack file. If the pack file doesn't
         lives at within the corresponding pack file. If the pack file doesn't
         have the object then None will be returned.
         have the object then None will be returned.
         """
         """
-        if len(sha) == 40:
-            sha = hex_to_sha(cast(ObjectID, sha))
+        if len(sha) == self.hash_size * 2:  # hex string
+            sha = hex_to_sha(sha)
         try:
         try:
             return self._object_offset(sha)
             return self._object_offset(sha)
         except ValueError as exc:
         except ValueError as exc:
@@ -1056,12 +1070,9 @@ class PackIndex1(FilePackIndex):
 
 
     def _unpack_entry(self, i: int) -> tuple[RawObjectID, int, None]:
     def _unpack_entry(self, i: int) -> tuple[RawObjectID, int, None]:
         base_offset = (0x100 * 4) + (i * self._entry_size)
         base_offset = (0x100 * 4) + (i * self._entry_size)
-        if self.hash_size == 20:
-            (offset, name) = unpack_from(">L20s", self._contents, base_offset)
-        else:  # SHA256
-            offset = unpack_from(">L", self._contents, base_offset)[0]
-            name = self._contents[base_offset + 4 : base_offset + 4 + self.hash_size]
-        return (RawObjectID(name), offset, None)
+        offset = unpack_from(">L", self._contents, base_offset)[0]
+        name = self._contents[base_offset + 4 : base_offset + 4 + self.hash_size]
+        return (name, offset, None)
 
 
     def _unpack_name(self, i: int) -> bytes:
     def _unpack_name(self, i: int) -> bytes:
         offset = (0x100 * 4) + (i * self._entry_size) + 4
         offset = (0x100 * 4) + (i * self._entry_size) + 4
@@ -2552,6 +2563,7 @@ class SHA1Writer(BinaryIO):
 
 
 def pack_object_header(
 def pack_object_header(
     type_num: int, delta_base: bytes | int | None, size: int
     type_num: int, delta_base: bytes | int | None, size: int
+    object_format: "ObjectFormat" | None = None
 ) -> bytearray:
 ) -> bytearray:
     """Create a pack object header for the given object info.
     """Create a pack object header for the given object info.
 
 
@@ -2559,8 +2571,18 @@ def pack_object_header(
       type_num: Numeric type of the object.
       type_num: Numeric type of the object.
       delta_base: Delta base offset or ref, or None for whole objects.
       delta_base: Delta base offset or ref, or None for whole objects.
       size: Uncompressed object size.
       size: Uncompressed object size.
+      object_format: Object format (hash algorithm) to use.
     Returns: A header for a packed object.
     Returns: A header for a packed object.
     """
     """
+    from .object_format import DEFAULT_OBJECT_FORMAT
+    if object_format is None:
+        warnings.warn(
+            "pack_object_header() should be called with object_format parameter",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        object_format = DEFAULT_OBJECT_FORMAT
+
     header = []
     header = []
     c = (type_num << 4) | (size & 15)
     c = (type_num << 4) | (size & 15)
     size >>= 4
     size >>= 4
@@ -2579,8 +2601,7 @@ def pack_object_header(
             delta_base >>= 7
             delta_base >>= 7
         header.extend(ret)
         header.extend(ret)
     elif type_num == REF_DELTA:
     elif type_num == REF_DELTA:
-        assert isinstance(delta_base, bytes)
-        assert len(delta_base) == 20
+        assert len(delta_base) == object_format.oid_length
         header += delta_base
         header += delta_base
     return bytearray(header)
     return bytearray(header)
 
 
@@ -3955,8 +3976,8 @@ class Pack:
                 assert isinstance(base_type, int)
                 assert isinstance(base_type, int)
             elif base_type == REF_DELTA:
             elif base_type == REF_DELTA:
                 (basename, delta) = base_obj
                 (basename, delta) = base_obj
-                assert isinstance(basename, bytes) and len(basename) == 20
-                base_offset, base_type, base_obj = get_ref(cast(RawObjectID, basename))
+                assert isinstance(basename, bytes) and len(basename) == self.object_format.oid_length
+                base_offset, base_type, base_obj = get_ref(basename)
                 assert isinstance(base_type, int)
                 assert isinstance(base_type, int)
                 if base_offset == prev_offset:  # object is based on itself
                 if base_offset == prev_offset:  # object is based on itself
                     raise UnresolvedDeltas([basename])
                     raise UnresolvedDeltas([basename])
@@ -4046,12 +4067,21 @@ def extend_pack(
     *,
     *,
     compression_level: int = -1,
     compression_level: int = -1,
     progress: Callable[[bytes], None] | None = None,
     progress: Callable[[bytes], None] | None = None,
-) -> tuple[bytes, list[tuple["RawObjectID", int, int]]]:
+    object_format: ObjectFormat | None = None,
+) -> tuple[bytes, list[tuple[bytes, int, int]]]:
     """Extend a pack file with more objects.
     """Extend a pack file with more objects.
 
 
     The caller should make sure that object_ids does not contain any objects
     The caller should make sure that object_ids does not contain any objects
     that are already in the pack
     that are already in the pack
     """
     """
+    from .object_format import DEFAULT_OBJECT_FORMAT
+    if object_format is None:
+        warnings.warn(
+            "extend_pack() should be called with object_format parameter",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        object_format = DEFAULT_OBJECT_FORMAT
     # Update the header with the new number of objects.
     # Update the header with the new number of objects.
     f.seek(0)
     f.seek(0)
     _version, num_objects = read_pack_header(f.read)
     _version, num_objects = read_pack_header(f.read)
@@ -4077,7 +4107,7 @@ def extend_pack(
             progress(
             progress(
                 (f"writing extra base objects: {i}/{len(object_ids)}\r").encode("ascii")
                 (f"writing extra base objects: {i}/{len(object_ids)}\r").encode("ascii")
             )
             )
-        assert len(object_id) == 20
+        assert len(object_id) == object_format.oid_length
         type_num, data = get_raw(object_id)
         type_num, data = get_raw(object_id)
         offset = f.tell()
         offset = f.tell()
         crc32 = write_pack_object(
         crc32 = write_pack_object(