Browse Source

patch: Add typing

Jelmer Vernooij 3 months ago
parent
commit
99fa92c431
1 changed files with 91 additions and 34 deletions
  1. 91 34
      dulwich/patch.py

+ 91 - 34
dulwich/patch.py

@@ -27,28 +27,46 @@ on.
 
 import email.parser
 import time
+from collections.abc import Generator
 from difflib import SequenceMatcher
-from typing import BinaryIO, Optional, TextIO, Union
+from typing import (
+    TYPE_CHECKING,
+    BinaryIO,
+    Optional,
+    TextIO,
+    Union,
+)
+
+if TYPE_CHECKING:
+    import email.message
+
+    from .object_store import BaseObjectStore
 
 from .objects import S_ISGITLINK, Blob, Commit
-from .pack import ObjectContainer
 
 FIRST_FEW_BYTES = 8000
 
 
 def write_commit_patch(
-    f, commit, contents, progress, version=None, encoding=None
+    f: BinaryIO,
+    commit: "Commit",
+    contents: Union[str, bytes],
+    progress: tuple[int, int],
+    version: Optional[str] = None,
+    encoding: Optional[str] = None,
 ) -> None:
     """Write a individual file patch.
 
     Args:
       commit: Commit object
-      progress: Tuple with current patch number and total.
+      progress: tuple with current patch number and total.
 
     Returns:
       tuple with filename and contents
     """
     encoding = encoding or getattr(f, "encoding", "ascii")
+    if encoding is None:
+        encoding = "ascii"
     if isinstance(contents, str):
         contents = contents.encode(encoding)
     (num, total) = progress
@@ -87,10 +105,12 @@ def write_commit_patch(
 
         f.write(b"Dulwich %d.%d.%d\n" % dulwich_version)
     else:
+        if encoding is None:
+            encoding = "ascii"
         f.write(version.encode(encoding) + b"\n")
 
 
-def get_summary(commit):
+def get_summary(commit: "Commit") -> str:
     """Determine the summary line for use in a filename.
 
     Args:
@@ -102,7 +122,7 @@ def get_summary(commit):
 
 
 #  Unified Diff
-def _format_range_unified(start, stop) -> str:
+def _format_range_unified(start: int, stop: int) -> str:
     """Convert range to the "ed" format."""
     # Per the diff spec at http://www.unix.org/single_unix_specification/
     beginning = start + 1  # lines start numbering with one
@@ -115,17 +135,17 @@ def _format_range_unified(start, stop) -> str:
 
 
 def unified_diff(
-    a,
-    b,
-    fromfile="",
-    tofile="",
-    fromfiledate="",
-    tofiledate="",
-    n=3,
-    lineterm="\n",
-    tree_encoding="utf-8",
-    output_encoding="utf-8",
-):
+    a: list[bytes],
+    b: list[bytes],
+    fromfile: bytes = b"",
+    tofile: bytes = b"",
+    fromfiledate: str = "",
+    tofiledate: str = "",
+    n: int = 3,
+    lineterm: str = "\n",
+    tree_encoding: str = "utf-8",
+    output_encoding: str = "utf-8",
+) -> Generator[bytes, None, None]:
     """difflib.unified_diff that can detect "No newline at end of file" as
     original "git diff" does.
 
@@ -166,7 +186,7 @@ def unified_diff(
                     yield b"+" + line
 
 
-def is_binary(content):
+def is_binary(content: bytes) -> bool:
     """See if the first few bytes contain any null characters.
 
     Args:
@@ -175,14 +195,14 @@ def is_binary(content):
     return b"\0" in content[:FIRST_FEW_BYTES]
 
 
-def shortid(hexsha):
+def shortid(hexsha: Optional[bytes]) -> bytes:
     if hexsha is None:
         return b"0" * 7
     else:
         return hexsha[:7]
 
 
-def patch_filename(p, root):
+def patch_filename(p: Optional[bytes], root: bytes) -> bytes:
     if p is None:
         return b"/dev/null"
     else:
@@ -190,7 +210,11 @@ def patch_filename(p, root):
 
 
 def write_object_diff(
-    f, store: ObjectContainer, old_file, new_file, diff_binary=False
+    f: BinaryIO,
+    store: "BaseObjectStore",
+    old_file: tuple[Optional[bytes], Optional[int], Optional[bytes]],
+    new_file: tuple[Optional[bytes], Optional[int], Optional[bytes]],
+    diff_binary: bool = False,
 ) -> None:
     """Write the diff for an object.
 
@@ -209,15 +233,22 @@ def write_object_diff(
     patched_old_path = patch_filename(old_path, b"a")
     patched_new_path = patch_filename(new_path, b"b")
 
-    def content(mode, hexsha):
+    def content(mode: Optional[int], hexsha: Optional[bytes]) -> Blob:
+        from typing import cast
+
         if hexsha is None:
-            return Blob.from_string(b"")
-        elif S_ISGITLINK(mode):
-            return Blob.from_string(b"Subproject commit " + hexsha + b"\n")
+            return cast(Blob, Blob.from_string(b""))
+        elif mode is not None and S_ISGITLINK(mode):
+            return cast(Blob, Blob.from_string(b"Subproject commit " + hexsha + b"\n"))
         else:
-            return store[hexsha]
+            obj = store[hexsha]
+            if isinstance(obj, Blob):
+                return obj
+            else:
+                # Fallback for non-blob objects
+                return cast(Blob, Blob.from_string(obj.as_raw_string()))
 
-    def lines(content):
+    def lines(content: "Blob") -> list[bytes]:
         if not content:
             return []
         else:
@@ -249,7 +280,11 @@ def write_object_diff(
 
 
 # TODO(jelmer): Support writing unicode, rather than bytes.
-def gen_diff_header(paths, modes, shas):
+def gen_diff_header(
+    paths: tuple[Optional[bytes], Optional[bytes]],
+    modes: tuple[Optional[int], Optional[int]],
+    shas: tuple[Optional[bytes], Optional[bytes]],
+) -> Generator[bytes, None, None]:
     """Write a blob diff header.
 
     Args:
@@ -282,7 +317,11 @@ def gen_diff_header(paths, modes, shas):
 
 
 # TODO(jelmer): Support writing unicode, rather than bytes.
-def write_blob_diff(f, old_file, new_file) -> None:
+def write_blob_diff(
+    f: BinaryIO,
+    old_file: tuple[Optional[bytes], Optional[int], Optional["Blob"]],
+    new_file: tuple[Optional[bytes], Optional[int], Optional["Blob"]],
+) -> None:
     """Write blob diff.
 
     Args:
@@ -297,7 +336,7 @@ def write_blob_diff(f, old_file, new_file) -> None:
     patched_old_path = patch_filename(old_path, b"a")
     patched_new_path = patch_filename(new_path, b"b")
 
-    def lines(blob):
+    def lines(blob: Optional["Blob"]) -> list[bytes]:
         if blob is not None:
             return blob.splitlines()
         else:
@@ -317,7 +356,13 @@ def write_blob_diff(f, old_file, new_file) -> None:
     )
 
 
-def write_tree_diff(f, store, old_tree, new_tree, diff_binary=False) -> None:
+def write_tree_diff(
+    f: BinaryIO,
+    store: "BaseObjectStore",
+    old_tree: Optional[bytes],
+    new_tree: Optional[bytes],
+    diff_binary: bool = False,
+) -> None:
     """Write tree diff.
 
     Args:
@@ -338,7 +383,9 @@ def write_tree_diff(f, store, old_tree, new_tree, diff_binary=False) -> None:
         )
 
 
-def git_am_patch_split(f: Union[TextIO, BinaryIO], encoding: Optional[str] = None):
+def git_am_patch_split(
+    f: Union[TextIO, BinaryIO], encoding: Optional[str] = None
+) -> tuple["Commit", bytes, Optional[bytes]]:
     """Parse a git-am-style patch and split it up into bits.
 
     Args:
@@ -358,7 +405,9 @@ def git_am_patch_split(f: Union[TextIO, BinaryIO], encoding: Optional[str] = Non
     return parse_patch_message(msg, encoding)
 
 
-def parse_patch_message(msg, encoding=None):
+def parse_patch_message(
+    msg: "email.message.Message", encoding: Optional[str] = None
+) -> tuple["Commit", bytes, Optional[bytes]]:
     """Extract a Commit object and patch from an e-mail message.
 
     Args:
@@ -367,6 +416,8 @@ def parse_patch_message(msg, encoding=None):
     Returns: Tuple with commit object, diff contents and git version
     """
     c = Commit()
+    if encoding is None:
+        encoding = "ascii"
     c.author = msg["from"].encode(encoding)
     c.committer = msg["from"].encode(encoding)
     try:
@@ -380,7 +431,13 @@ def parse_patch_message(msg, encoding=None):
     first = True
 
     body = msg.get_payload(decode=True)
-    lines = body.splitlines(True)
+    if isinstance(body, str):
+        body = body.encode(encoding)
+    if isinstance(body, bytes):
+        lines = body.splitlines(True)
+    else:
+        # Handle other types by converting to string first
+        lines = str(body).encode(encoding).splitlines(True)
     line_iter = iter(lines)
 
     for line in line_iter: