瀏覽代碼

fastexport: Add typing

Jelmer Vernooij 1 月之前
父節點
當前提交
77d302b580
共有 1 個文件被更改,包括 63 次插入和 29 次删除
  1. 63 29
      dulwich/fastexport.py

+ 63 - 29
dulwich/fastexport.py

@@ -23,16 +23,23 @@
 """Fast export/import functionality."""
 
 import stat
+from collections.abc import Generator
+from typing import TYPE_CHECKING, Any, BinaryIO, Optional
 
 from fastimport import commands, parser, processor
 from fastimport import errors as fastimport_errors
 
 from .index import commit_tree
 from .object_store import iter_tree_contents
-from .objects import ZERO_SHA, Blob, Commit, Tag
+from .objects import ZERO_SHA, Blob, Commit, ObjectID, Tag
+from .refs import Ref
 
+if TYPE_CHECKING:
+    from .object_store import BaseObjectStore
+    from .repo import BaseRepo
 
-def split_email(text):
+
+def split_email(text: bytes) -> tuple[bytes, bytes]:
     # TODO(jelmer): Dedupe this and the same functionality in
     # format_annotate_line.
     (name, email) = text.rsplit(b" <", 1)
@@ -42,41 +49,53 @@ def split_email(text):
 class GitFastExporter:
     """Generate a fast-export output stream for Git objects."""
 
-    def __init__(self, outf, store) -> None:
+    def __init__(self, outf: BinaryIO, store: "BaseObjectStore") -> None:
         self.outf = outf
         self.store = store
         self.markers: dict[bytes, bytes] = {}
         self._marker_idx = 0
 
-    def print_cmd(self, cmd) -> None:
-        self.outf.write(getattr(cmd, "__bytes__", cmd.__repr__)() + b"\n")
+    def print_cmd(self, cmd: object) -> None:
+        if hasattr(cmd, "__bytes__"):
+            output = cmd.__bytes__()
+        else:
+            output = cmd.__repr__().encode("utf-8")
+        self.outf.write(output + b"\n")
 
-    def _allocate_marker(self):
+    def _allocate_marker(self) -> bytes:
         self._marker_idx += 1
         return str(self._marker_idx).encode("ascii")
 
-    def _export_blob(self, blob):
+    def _export_blob(self, blob: Blob) -> tuple[Any, bytes]:
         marker = self._allocate_marker()
         self.markers[marker] = blob.id
         return (commands.BlobCommand(marker, blob.data), marker)
 
-    def emit_blob(self, blob):
+    def emit_blob(self, blob: Blob) -> bytes:
         (cmd, marker) = self._export_blob(blob)
         self.print_cmd(cmd)
         return marker
 
-    def _iter_files(self, base_tree, new_tree):
+    def _iter_files(
+        self, base_tree: Optional[bytes], new_tree: Optional[bytes]
+    ) -> Generator[Any, None, None]:
         for (
             (old_path, new_path),
             (old_mode, new_mode),
             (old_hexsha, new_hexsha),
         ) in self.store.tree_changes(base_tree, new_tree):
             if new_path is None:
-                yield commands.FileDeleteCommand(old_path)
+                if old_path is not None:
+                    yield commands.FileDeleteCommand(old_path)
                 continue
-            if not stat.S_ISDIR(new_mode):
-                blob = self.store[new_hexsha]
-                marker = self.emit_blob(blob)
+            marker = b""
+            if new_mode is not None and not stat.S_ISDIR(new_mode):
+                if new_hexsha is not None:
+                    blob = self.store[new_hexsha]
+                    from .objects import Blob
+
+                    if isinstance(blob, Blob):
+                        marker = self.emit_blob(blob)
             if old_path != new_path and old_path is not None:
                 yield commands.FileRenameCommand(old_path, new_path)
             if old_mode != new_mode or old_hexsha != new_hexsha:
@@ -85,7 +104,9 @@ class GitFastExporter:
                     new_path, new_mode, prefixed_marker, None
                 )
 
-    def _export_commit(self, commit, ref, base_tree=None):
+    def _export_commit(
+        self, commit: Commit, ref: Ref, base_tree: Optional[ObjectID] = None
+    ) -> tuple[Any, bytes]:
         file_cmds = list(self._iter_files(base_tree, commit.tree))
         marker = self._allocate_marker()
         if commit.parents:
@@ -113,7 +134,9 @@ class GitFastExporter:
         )
         return (cmd, marker)
 
-    def emit_commit(self, commit, ref, base_tree=None):
+    def emit_commit(
+        self, commit: Commit, ref: Ref, base_tree: Optional[ObjectID] = None
+    ) -> bytes:
         cmd, marker = self._export_commit(commit, ref, base_tree)
         self.print_cmd(cmd)
         return marker
@@ -124,34 +147,40 @@ class GitImportProcessor(processor.ImportProcessor):
 
     # FIXME: Batch creation of objects?
 
-    def __init__(self, repo, params=None, verbose=False, outf=None) -> None:
+    def __init__(
+        self,
+        repo: "BaseRepo",
+        params: Optional[Any] = None,  # noqa: ANN401
+        verbose: bool = False,
+        outf: Optional[BinaryIO] = None,
+    ) -> None:
         processor.ImportProcessor.__init__(self, params, verbose)
         self.repo = repo
         self.last_commit = ZERO_SHA
         self.markers: dict[bytes, bytes] = {}
         self._contents: dict[bytes, tuple[int, bytes]] = {}
 
-    def lookup_object(self, objectish):
+    def lookup_object(self, objectish: bytes) -> ObjectID:
         if objectish.startswith(b":"):
             return self.markers[objectish[1:]]
         return objectish
 
-    def import_stream(self, stream):
+    def import_stream(self, stream: BinaryIO) -> dict[bytes, bytes]:
         p = parser.ImportParser(stream)
         self.process(p.iter_commands)
         return self.markers
 
-    def blob_handler(self, cmd) -> None:
+    def blob_handler(self, cmd: commands.BlobCommand) -> None:
         """Process a BlobCommand."""
         blob = Blob.from_string(cmd.data)
         self.repo.object_store.add_object(blob)
         if cmd.mark:
             self.markers[cmd.mark] = blob.id
 
-    def checkpoint_handler(self, cmd) -> None:
+    def checkpoint_handler(self, cmd: commands.CheckpointCommand) -> None:
         """Process a CheckpointCommand."""
 
-    def commit_handler(self, cmd) -> None:
+    def commit_handler(self, cmd: commands.CommitCommand) -> None:
         """Process a CommitCommand."""
         commit = Commit()
         if cmd.author is not None:
@@ -180,7 +209,7 @@ class GitImportProcessor(processor.ImportProcessor):
             if filecmd.name == b"filemodify":
                 if filecmd.data is not None:
                     blob = Blob.from_string(filecmd.data)
-                    self.repo.object_store.add(blob)
+                    self.repo.object_store.add_object(blob)
                     blob_id = blob.id
                 else:
                     blob_id = self.lookup_object(filecmd.dataref)
@@ -210,16 +239,21 @@ class GitImportProcessor(processor.ImportProcessor):
         if cmd.mark:
             self.markers[cmd.mark] = commit.id
 
-    def progress_handler(self, cmd) -> None:
+    def progress_handler(self, cmd: commands.ProgressCommand) -> None:
         """Process a ProgressCommand."""
 
-    def _reset_base(self, commit_id) -> None:
+    def _reset_base(self, commit_id: ObjectID) -> None:
         if self.last_commit == commit_id:
             return
         self._contents = {}
         self.last_commit = commit_id
         if commit_id != ZERO_SHA:
-            tree_id = self.repo[commit_id].tree
+            from .objects import Commit
+
+            commit = self.repo[commit_id]
+            tree_id = commit.tree if isinstance(commit, Commit) else None
+            if tree_id is None:
+                return
             for (
                 path,
                 mode,
@@ -227,7 +261,7 @@ class GitImportProcessor(processor.ImportProcessor):
             ) in iter_tree_contents(self.repo.object_store, tree_id):
                 self._contents[path] = (mode, hexsha)
 
-    def reset_handler(self, cmd) -> None:
+    def reset_handler(self, cmd: commands.ResetCommand) -> None:
         """Process a ResetCommand."""
         if cmd.from_ is None:
             from_ = ZERO_SHA
@@ -236,15 +270,15 @@ class GitImportProcessor(processor.ImportProcessor):
         self._reset_base(from_)
         self.repo.refs[cmd.ref] = from_
 
-    def tag_handler(self, cmd) -> None:
+    def tag_handler(self, cmd: commands.TagCommand) -> None:
         """Process a TagCommand."""
         tag = Tag()
         tag.tagger = cmd.tagger
         tag.message = cmd.message
         tag.name = cmd.tag
-        self.repo.add_object(tag)
+        self.repo.object_store.add_object(tag)
         self.repo.refs["refs/tags/" + tag.name] = tag.id
 
-    def feature_handler(self, cmd):
+    def feature_handler(self, cmd: commands.FeatureCommand) -> None:
         """Process a FeatureCommand."""
         raise fastimport_errors.UnknownFeature(cmd.feature_name)