Explorar o código

Add type annotations to tests/test_archive.py

Jelmer Vernooij hai 5 meses
pai
achega
7c6231e504
Modificáronse 1 ficheiros con 5 adicións e 10 borrados
  1. 5 10
      tests/test_archive.py

+ 5 - 10
tests/test_archive.py

@@ -24,7 +24,8 @@
 import struct
 import tarfile
 from io import BytesIO
-from unittest import skipUnless
+from typing import Optional
+from unittest.mock import patch
 
 from dulwich.archive import tar_stream
 from dulwich.object_store import MemoryObjectStore
@@ -33,11 +34,6 @@ from dulwich.tests.utils import build_commit_graph
 
 from . import TestCase
 
-try:
-    from unittest.mock import patch
-except ImportError:
-    patch = None
-
 
 class ArchiveTests(TestCase):
     def test_empty(self) -> None:
@@ -50,14 +46,14 @@ class ArchiveTests(TestCase):
         self.addCleanup(tf.close)
         self.assertEqual([], tf.getnames())
 
-    def _get_example_tar_stream(self, *tar_stream_args, **tar_stream_kwargs):
+    def _get_example_tar_stream(self, mtime: int, prefix: bytes = b"", format: str = "") -> BytesIO:
         store = MemoryObjectStore()
         b1 = Blob.from_string(b"somedata")
         store.add_object(b1)
         t1 = Tree()
         t1.add(b"somename", 0o100644, b1.id)
         store.add_object(t1)
-        stream = b"".join(tar_stream(store, t1, *tar_stream_args, **tar_stream_kwargs))
+        stream = b"".join(tar_stream(store, t1, mtime, prefix, format))
         return BytesIO(stream)
 
     def test_simple(self) -> None:
@@ -89,9 +85,8 @@ class ArchiveTests(TestCase):
         expected_mtime = struct.pack("<L", 1234)
         self.assertEqual(stream.getvalue()[4:8], expected_mtime)
 
-    @skipUnless(patch, "Required mock.patch")
     def test_same_file(self) -> None:
-        contents = [None, None]
+        contents: list[Optional[bytes]] = [None, None]
         for format in ["", "gz", "bz2"]:
             for i in [0, 1]:
                 with patch("time.time", return_value=i):