Explorar el Código

Fix regressions in typing

Jelmer Vernooij hace 5 meses
padre
commit
3636e3fefc
Se han modificado 4 ficheros con 31 adiciones y 13 borrados
  1. 28 1
      dulwich/file.py
  2. 2 2
      dulwich/pack.py
  3. 0 9
      dulwich/refs.py
  4. 1 1
      tests/test_pack.py

+ 28 - 1
dulwich/file.py

@@ -26,7 +26,7 @@ import sys
 import warnings
 from collections.abc import Iterator
 from types import TracebackType
-from typing import IO, Any, ClassVar, Optional, Union
+from typing import IO, Any, ClassVar, Literal, Optional, Union, overload
 
 
 def ensure_dir_exists(dirname: Union[str, bytes, os.PathLike]) -> None:
@@ -59,6 +59,33 @@ def _fancy_rename(oldname: Union[str, bytes], newname: Union[str, bytes]) -> Non
     os.remove(tmpfile)
 
 
+@overload
+def GitFile(
+    filename: Union[str, bytes, os.PathLike],
+    mode: Literal["wb"],
+    bufsize: int = -1,
+    mask: int = 0o644,
+) -> "_GitFile": ...
+
+
+@overload
+def GitFile(
+    filename: Union[str, bytes, os.PathLike],
+    mode: Literal["rb"] = "rb",
+    bufsize: int = -1,
+    mask: int = 0o644,
+) -> IO[bytes]: ...
+
+
+@overload
+def GitFile(
+    filename: Union[str, bytes, os.PathLike],
+    mode: str = "rb",
+    bufsize: int = -1,
+    mask: int = 0o644,
+) -> Union[IO[bytes], "_GitFile"]: ...
+
+
 def GitFile(
     filename: Union[str, bytes, os.PathLike],
     mode: str = "rb",

+ 2 - 2
dulwich/pack.py

@@ -609,9 +609,9 @@ class MemoryPackIndex(PackIndex):
         return iter(self._entries)
 
     @classmethod
-    def for_pack(cls, pack: "Pack") -> "MemoryPackIndex":
+    def for_pack(cls, pack_data: "PackData") -> "MemoryPackIndex":
         return MemoryPackIndex(
-            list(pack.sorted_entries()), pack.data.get_stored_checksum()
+            list(pack_data.sorted_entries()), pack_data.get_stored_checksum()
         )
 
     @classmethod

+ 0 - 9
dulwich/refs.py

@@ -876,7 +876,6 @@ class DiskRefsContainer(RefsContainer):
         self._check_refname(other)
         filename = self.refpath(name)
         f = GitFile(filename, "wb")
-        assert isinstance(f, _GitFile)  # GitFile in write mode always returns _GitFile
         try:
             f.write(SYMREF + other + b"\n")
             sha = self.follow(name)[-1]
@@ -936,9 +935,6 @@ class DiskRefsContainer(RefsContainer):
 
         ensure_dir_exists(os.path.dirname(filename))
         with GitFile(filename, "wb") as f:
-            assert isinstance(
-                f, _GitFile
-            )  # GitFile in write mode always returns _GitFile
             if old_ref is not None:
                 try:
                     # read again while holding the lock to handle race conditions
@@ -1010,9 +1006,6 @@ class DiskRefsContainer(RefsContainer):
         filename = self.refpath(realname)
         ensure_dir_exists(os.path.dirname(filename))
         with GitFile(filename, "wb") as f:
-            assert isinstance(
-                f, _GitFile
-            )  # GitFile in write mode always returns _GitFile
             if os.path.exists(filename) or name in self.get_packed_refs():
                 f.abort()
                 return False
@@ -1058,7 +1051,6 @@ class DiskRefsContainer(RefsContainer):
         filename = self.refpath(name)
         ensure_dir_exists(os.path.dirname(filename))
         f = GitFile(filename, "wb")
-        assert isinstance(f, _GitFile)  # GitFile in write mode always returns _GitFile
         try:
             if old_ref is not None:
                 orig_ref = self.read_loose_ref(name)
@@ -1410,7 +1402,6 @@ class locked_ref:
         filename = self._refs_container.refpath(self._realname)
         ensure_dir_exists(os.path.dirname(filename))
         f = GitFile(filename, "wb")
-        assert isinstance(f, _GitFile)  # GitFile in write mode always returns _GitFile
         self._file = f
         return self
 

+ 1 - 1
tests/test_pack.py

@@ -1594,7 +1594,7 @@ class DeltaChainIteratorTests(TestCase):
             # Attempting to open this REF_DELTA object would loop forever
             pack[b1.id]
         except UnresolvedDeltas as e:
-            self.assertEqual((b1.id), e.shas)
+            self.assertEqual([hex_to_sha(b1.id)], e.shas)
 
 
 class DeltaEncodeSizeTests(TestCase):