浏览代码

bundle: Add typing

Jelmer Vernooij 1 月之前
父节点
当前提交
c3f2d42332
共有 1 个文件被更改,包括 10 次插入12 次删除
  1. 10 12
      dulwich/bundle.py

+ 10 - 12
dulwich/bundle.py

@@ -21,8 +21,7 @@
 
 """Bundle format support."""
 
-from collections.abc import Sequence
-from typing import Optional, Union
+from typing import BinaryIO, Optional
 
 from .pack import PackData, write_pack_data
 
@@ -30,10 +29,10 @@ from .pack import PackData, write_pack_data
 class Bundle:
     version: Optional[int]
 
-    capabilities: dict[str, str]
+    capabilities: dict[str, Optional[str]]
     prerequisites: list[tuple[bytes, str]]
-    references: dict[str, bytes]
-    pack_data: Union[PackData, Sequence[bytes]]
+    references: dict[bytes, bytes]
+    pack_data: PackData
 
     def __repr__(self) -> str:
         return (
@@ -43,7 +42,7 @@ class Bundle:
             f"references={self.references})>"
         )
 
-    def __eq__(self, other):
+    def __eq__(self, other: object) -> bool:
         if not isinstance(other, type(self)):
             return False
         if self.version != other.version:
@@ -59,7 +58,7 @@ class Bundle:
         return True
 
 
-def _read_bundle(f, version):
+def _read_bundle(f: BinaryIO, version: int) -> Bundle:
     capabilities = {}
     prerequisites = []
     references = {}
@@ -68,12 +67,11 @@ def _read_bundle(f, version):
         while line.startswith(b"@"):
             line = line[1:].rstrip(b"\n")
             try:
-                key, value = line.split(b"=", 1)
+                key, value_bytes = line.split(b"=", 1)
+                value = value_bytes.decode("utf-8")
             except ValueError:
                 key = line
                 value = None
-            else:
-                value = value.decode("utf-8")
             capabilities[key.decode("utf-8")] = value
             line = f.readline()
     while line.startswith(b"-"):
@@ -94,7 +92,7 @@ def _read_bundle(f, version):
     return ret
 
 
-def read_bundle(f):
+def read_bundle(f: BinaryIO) -> Bundle:
     """Read a bundle file."""
     firstline = f.readline()
     if firstline == b"# v2 git bundle\n":
@@ -104,7 +102,7 @@ def read_bundle(f):
     raise AssertionError(f"unsupported bundle format header: {firstline!r}")
 
 
-def write_bundle(f, bundle) -> None:
+def write_bundle(f: BinaryIO, bundle: Bundle) -> None:
     version = bundle.version
     if version is None:
         if bundle.capabilities: