Ver Fonte

archive: add typing

Jelmer Vernooij há 1 mês atrás
pai
commit
91d990068d
1 ficheiro alterado com 43 adições e 13 exclusões
  1. 43 13
      dulwich/archive.py

+ 43 - 13
dulwich/archive.py

@@ -26,9 +26,17 @@ import posixpath
 import stat
 import struct
 import tarfile
+from collections.abc import Generator
 from contextlib import closing
 from io import BytesIO
 from os import SEEK_END
+from typing import TYPE_CHECKING, Optional
+
+if TYPE_CHECKING:
+    from .object_store import BaseObjectStore
+    from .objects import TreeEntry
+
+from .objects import Tree
 
 
 class ChunkedBytesIO:
@@ -42,33 +50,43 @@ class ChunkedBytesIO:
             list_of_bytestrings)
     """
 
-    def __init__(self, contents) -> None:
+    def __init__(self, contents: list[bytes]) -> None:
         self.contents = contents
         self.pos = (0, 0)
 
-    def read(self, maxbytes=None):
-        if maxbytes < 0:
-            maxbytes = float("inf")
+    def read(self, maxbytes: Optional[int] = None) -> bytes:
+        if maxbytes is None or maxbytes < 0:
+            remaining = None
+        else:
+            remaining = maxbytes
 
         buf = []
         chunk, cursor = self.pos
 
         while chunk < len(self.contents):
-            if maxbytes < len(self.contents[chunk]) - cursor:
-                buf.append(self.contents[chunk][cursor : cursor + maxbytes])
-                cursor += maxbytes
+            chunk_remainder = len(self.contents[chunk]) - cursor
+            if remaining is not None and remaining < chunk_remainder:
+                buf.append(self.contents[chunk][cursor : cursor + remaining])
+                cursor += remaining
                 self.pos = (chunk, cursor)
                 break
             else:
                 buf.append(self.contents[chunk][cursor:])
-                maxbytes -= len(self.contents[chunk]) - cursor
+                if remaining is not None:
+                    remaining -= chunk_remainder
                 chunk += 1
                 cursor = 0
                 self.pos = (chunk, cursor)
         return b"".join(buf)
 
 
-def tar_stream(store, tree, mtime, prefix=b"", format=""):
+def tar_stream(
+    store: "BaseObjectStore",
+    tree: "Tree",
+    mtime: int,
+    prefix: bytes = b"",
+    format: str = "",
+) -> Generator[bytes, None, None]:
     """Generate a tar stream for the contents of a Git tree.
 
     Returns a generator that lazily assembles a .tar.gz archive, yielding it in
@@ -85,7 +103,11 @@ def tar_stream(store, tree, mtime, prefix=b"", format=""):
       Bytestrings
     """
     buf = BytesIO()
-    with closing(tarfile.open(None, f"w:{format}", buf)) as tar:
+    mode = "w:" + format if format else "w"
+    from typing import Any, cast
+
+    # tarfile.open has complex overloads; cast to Any to skip overload type checking
+    with closing(cast(Any, tarfile.open)(name=None, mode=mode, fileobj=buf)) as tar:
         if format == "gz":
             # Manually correct the gzip header file modification time so that
             # archives created from the same Git tree are always identical.
@@ -105,7 +127,11 @@ def tar_stream(store, tree, mtime, prefix=b"", format=""):
                 # Entry probably refers to a submodule, which we don't yet
                 # support.
                 continue
-            data = ChunkedBytesIO(blob.chunked)
+            if hasattr(blob, "chunked"):
+                data = ChunkedBytesIO(blob.chunked)
+            else:
+                # Fallback for objects without chunked attribute
+                data = ChunkedBytesIO([blob.as_raw_string()])
 
             info = tarfile.TarInfo()
             # tarfile only works with ascii.
@@ -121,13 +147,17 @@ def tar_stream(store, tree, mtime, prefix=b"", format=""):
     yield buf.getvalue()
 
 
-def _walk_tree(store, tree, root=b""):
+def _walk_tree(
+    store: "BaseObjectStore", tree: "Tree", root: bytes = b""
+) -> Generator[tuple[bytes, "TreeEntry"], None, None]:
     """Recursively walk a dulwich Tree, yielding tuples of
     (absolute path, TreeEntry) along the way.
     """
     for entry in tree.iteritems():
         entry_abspath = posixpath.join(root, entry.path)
         if stat.S_ISDIR(entry.mode):
-            yield from _walk_tree(store, store[entry.sha], entry_abspath)
+            subtree = store[entry.sha]
+            if isinstance(subtree, Tree):
+                yield from _walk_tree(store, subtree, entry_abspath)
         else:
             yield (entry_abspath, entry)