Просмотр исходного кода

Add type annotations to tests/test_dumb.py

Jelmer Vernooij 5 месяцев назад
Родитель
Сommit
c7f1b1c9ef
1 измененных файлов с 37 добавлено и 35 удалено
  1. 37 35
      tests/test_dumb.py

+ 37 - 35
tests/test_dumb.py

@@ -22,45 +22,46 @@
 """Tests for dumb HTTP git repositories."""
 """Tests for dumb HTTP git repositories."""
 
 
 import zlib
 import zlib
+from typing import Callable, Optional, Union
 from unittest import TestCase
 from unittest import TestCase
 from unittest.mock import Mock
 from unittest.mock import Mock
 
 
 from dulwich.dumb import DumbHTTPObjectStore, DumbRemoteHTTPRepo
 from dulwich.dumb import DumbHTTPObjectStore, DumbRemoteHTTPRepo
 from dulwich.errors import NotGitRepository
 from dulwich.errors import NotGitRepository
-from dulwich.objects import Blob, Commit, Tag, Tree, sha_to_hex
+from dulwich.objects import Blob, Commit, ShaFile, Tag, Tree, sha_to_hex
 
 
 
 
 class MockResponse:
 class MockResponse:
-    def __init__(self, status=200, content=b"", headers=None):
+    def __init__(self, status: int = 200, content: bytes = b"", headers: Optional[dict[str, str]] = None) -> None:
         self.status = status
         self.status = status
         self.content = content
         self.content = content
         self.headers = headers or {}
         self.headers = headers or {}
         self.closed = False
         self.closed = False
 
 
-    def close(self):
+    def close(self) -> None:
         self.closed = True
         self.closed = True
 
 
 
 
 class DumbHTTPObjectStoreTests(TestCase):
 class DumbHTTPObjectStoreTests(TestCase):
     """Tests for DumbHTTPObjectStore."""
     """Tests for DumbHTTPObjectStore."""
 
 
-    def setUp(self):
+    def setUp(self) -> None:
         self.base_url = "https://example.com/repo.git/"
         self.base_url = "https://example.com/repo.git/"
-        self.responses = {}
+        self.responses: dict[str, dict[str, Union[int, bytes]]] = {}
         self.store = DumbHTTPObjectStore(self.base_url, self._mock_http_request)
         self.store = DumbHTTPObjectStore(self.base_url, self._mock_http_request)
 
 
-    def _mock_http_request(self, url, headers):
+    def _mock_http_request(self, url: str, headers: dict[str, str]) -> tuple[MockResponse, Callable[[Optional[int]], bytes]]:
         """Mock HTTP request function."""
         """Mock HTTP request function."""
         if url in self.responses:
         if url in self.responses:
             resp_data = self.responses[url]
             resp_data = self.responses[url]
             resp = MockResponse(
             resp = MockResponse(
-                resp_data.get("status", 200), resp_data.get("content", b"")
+                int(resp_data.get("status", 200)), bytes(resp_data.get("content", b""))
             )
             )
             # Create a mock read function that behaves like urllib3's read
             # Create a mock read function that behaves like urllib3's read
             content = resp.content
             content = resp.content
             offset = [0]  # Use list to make it mutable in closure
             offset = [0]  # Use list to make it mutable in closure
 
 
-            def read_func(size=None):
+            def read_func(size: Optional[int] = None) -> bytes:
                 if offset[0] >= len(content):
                 if offset[0] >= len(content):
                     return b""
                     return b""
                 if size is None:
                 if size is None:
@@ -76,12 +77,12 @@ class DumbHTTPObjectStoreTests(TestCase):
             resp = MockResponse(404)
             resp = MockResponse(404)
             return resp, lambda size: b""
             return resp, lambda size: b""
 
 
-    def _add_response(self, path, content, status=200):
+    def _add_response(self, path: str, content: bytes, status: int = 200) -> None:
         """Add a mock response for a given path."""
         """Add a mock response for a given path."""
         url = self.base_url + path
         url = self.base_url + path
         self.responses[url] = {"status": status, "content": content}
         self.responses[url] = {"status": status, "content": content}
 
 
-    def _make_object(self, obj):
+    def _make_object(self, obj: ShaFile) -> bytes:
         """Create compressed git object data."""
         """Create compressed git object data."""
         type_name = {
         type_name = {
             Blob.type_num: b"blob",
             Blob.type_num: b"blob",
@@ -94,7 +95,7 @@ class DumbHTTPObjectStoreTests(TestCase):
         header = type_name + b" " + str(len(content)).encode() + b"\x00"
         header = type_name + b" " + str(len(content)).encode() + b"\x00"
         return zlib.compress(header + content)
         return zlib.compress(header + content)
 
 
-    def test_fetch_loose_object_blob(self):
+    def test_fetch_loose_object_blob(self) -> None:
         # Create a blob object
         # Create a blob object
         blob = Blob()
         blob = Blob()
         blob.data = b"Hello, world!"
         blob.data = b"Hello, world!"
@@ -109,32 +110,33 @@ class DumbHTTPObjectStoreTests(TestCase):
         self.assertEqual(Blob.type_num, type_num)
         self.assertEqual(Blob.type_num, type_num)
         self.assertEqual(b"Hello, world!", content)
         self.assertEqual(b"Hello, world!", content)
 
 
-    def test_fetch_loose_object_not_found(self):
+    def test_fetch_loose_object_not_found(self) -> None:
         hex_sha = b"1" * 40
         hex_sha = b"1" * 40
         self.assertRaises(KeyError, self.store._fetch_loose_object, hex_sha)
         self.assertRaises(KeyError, self.store._fetch_loose_object, hex_sha)
 
 
-    def test_fetch_loose_object_invalid_format(self):
+    def test_fetch_loose_object_invalid_format(self) -> None:
         sha = b"1" * 20
         sha = b"1" * 20
         hex_sha = sha_to_hex(sha)
         hex_sha = sha_to_hex(sha)
-        path = f"objects/{hex_sha[:2]}/{hex_sha[2:]}"
+        path = f"objects/{hex_sha[:2].decode('ascii')}/{hex_sha[2:].decode('ascii')}"
 
 
         # Add invalid compressed data
         # Add invalid compressed data
         self._add_response(path, b"invalid data")
         self._add_response(path, b"invalid data")
 
 
         self.assertRaises(Exception, self.store._fetch_loose_object, sha)
         self.assertRaises(Exception, self.store._fetch_loose_object, sha)
 
 
-    def test_load_packs_empty(self):
+    def test_load_packs_empty(self) -> None:
         # No packs file
         # No packs file
         self.store._load_packs()
         self.store._load_packs()
         self.assertEqual([], self.store._packs)
         self.assertEqual([], self.store._packs)
 
 
-    def test_load_packs_with_entries(self):
+    def test_load_packs_with_entries(self) -> None:
         packs_content = b"""P pack-1234567890abcdef1234567890abcdef12345678.pack
         packs_content = b"""P pack-1234567890abcdef1234567890abcdef12345678.pack
 P pack-abcdef1234567890abcdef1234567890abcdef12.pack
 P pack-abcdef1234567890abcdef1234567890abcdef12.pack
 """
 """
         self._add_response("objects/info/packs", packs_content)
         self._add_response("objects/info/packs", packs_content)
 
 
         self.store._load_packs()
         self.store._load_packs()
+        assert self.store._packs is not None
         self.assertEqual(2, len(self.store._packs))
         self.assertEqual(2, len(self.store._packs))
         self.assertEqual(
         self.assertEqual(
             "pack-1234567890abcdef1234567890abcdef12345678", self.store._packs[0][0]
             "pack-1234567890abcdef1234567890abcdef12345678", self.store._packs[0][0]
@@ -143,7 +145,7 @@ P pack-abcdef1234567890abcdef1234567890abcdef12.pack
             "pack-abcdef1234567890abcdef1234567890abcdef12", self.store._packs[1][0]
             "pack-abcdef1234567890abcdef1234567890abcdef12", self.store._packs[1][0]
         )
         )
 
 
-    def test_get_raw_from_cache(self):
+    def test_get_raw_from_cache(self) -> None:
         sha = b"1" * 40
         sha = b"1" * 40
         self.store._cached_objects[sha] = (Blob.type_num, b"cached content")
         self.store._cached_objects[sha] = (Blob.type_num, b"cached content")
 
 
@@ -151,7 +153,7 @@ P pack-abcdef1234567890abcdef1234567890abcdef12.pack
         self.assertEqual(Blob.type_num, type_num)
         self.assertEqual(Blob.type_num, type_num)
         self.assertEqual(b"cached content", content)
         self.assertEqual(b"cached content", content)
 
 
-    def test_contains_loose(self):
+    def test_contains_loose(self) -> None:
         # Create a blob object
         # Create a blob object
         blob = Blob()
         blob = Blob()
         blob.data = b"Test blob"
         blob.data = b"Test blob"
@@ -164,35 +166,35 @@ P pack-abcdef1234567890abcdef1234567890abcdef12.pack
         self.assertTrue(self.store.contains_loose(hex_sha))
         self.assertTrue(self.store.contains_loose(hex_sha))
         self.assertFalse(self.store.contains_loose(b"0" * 40))
         self.assertFalse(self.store.contains_loose(b"0" * 40))
 
 
-    def test_add_object_not_implemented(self):
+    def test_add_object_not_implemented(self) -> None:
         blob = Blob()
         blob = Blob()
         blob.data = b"test"
         blob.data = b"test"
         self.assertRaises(NotImplementedError, self.store.add_object, blob)
         self.assertRaises(NotImplementedError, self.store.add_object, blob)
 
 
-    def test_add_objects_not_implemented(self):
+    def test_add_objects_not_implemented(self) -> None:
         self.assertRaises(NotImplementedError, self.store.add_objects, [])
         self.assertRaises(NotImplementedError, self.store.add_objects, [])
 
 
 
 
 class DumbRemoteHTTPRepoTests(TestCase):
 class DumbRemoteHTTPRepoTests(TestCase):
     """Tests for DumbRemoteHTTPRepo."""
     """Tests for DumbRemoteHTTPRepo."""
 
 
-    def setUp(self):
+    def setUp(self) -> None:
         self.base_url = "https://example.com/repo.git/"
         self.base_url = "https://example.com/repo.git/"
-        self.responses = {}
+        self.responses: dict[str, dict[str, Union[int, bytes]]] = {}
         self.repo = DumbRemoteHTTPRepo(self.base_url, self._mock_http_request)
         self.repo = DumbRemoteHTTPRepo(self.base_url, self._mock_http_request)
 
 
-    def _mock_http_request(self, url, headers):
+    def _mock_http_request(self, url: str, headers: dict[str, str]) -> tuple[MockResponse, Callable[[Optional[int]], bytes]]:
         """Mock HTTP request function."""
         """Mock HTTP request function."""
         if url in self.responses:
         if url in self.responses:
             resp_data = self.responses[url]
             resp_data = self.responses[url]
             resp = MockResponse(
             resp = MockResponse(
-                resp_data.get("status", 200), resp_data.get("content", b"")
+                int(resp_data.get("status", 200)), bytes(resp_data.get("content", b""))
             )
             )
             # Create a mock read function that behaves like urllib3's read
             # Create a mock read function that behaves like urllib3's read
             content = resp.content
             content = resp.content
             offset = [0]  # Use list to make it mutable in closure
             offset = [0]  # Use list to make it mutable in closure
 
 
-            def read_func(size=None):
+            def read_func(size: Optional[int] = None) -> bytes:
                 if offset[0] >= len(content):
                 if offset[0] >= len(content):
                     return b""
                     return b""
                 if size is None:
                 if size is None:
@@ -208,12 +210,12 @@ class DumbRemoteHTTPRepoTests(TestCase):
             resp = MockResponse(404)
             resp = MockResponse(404)
             return resp, lambda size: b""
             return resp, lambda size: b""
 
 
-    def _add_response(self, path, content, status=200):
+    def _add_response(self, path: str, content: bytes, status: int = 200) -> None:
         """Add a mock response for a given path."""
         """Add a mock response for a given path."""
         url = self.base_url + path
         url = self.base_url + path
         self.responses[url] = {"status": status, "content": content}
         self.responses[url] = {"status": status, "content": content}
 
 
-    def test_get_refs(self):
+    def test_get_refs(self) -> None:
         refs_content = b"""0123456789abcdef0123456789abcdef01234567\trefs/heads/master
         refs_content = b"""0123456789abcdef0123456789abcdef01234567\trefs/heads/master
 abcdef0123456789abcdef0123456789abcdef01\trefs/heads/develop
 abcdef0123456789abcdef0123456789abcdef01\trefs/heads/develop
 fedcba9876543210fedcba9876543210fedcba98\trefs/tags/v1.0
 fedcba9876543210fedcba9876543210fedcba98\trefs/tags/v1.0
@@ -235,10 +237,10 @@ fedcba9876543210fedcba9876543210fedcba98\trefs/tags/v1.0
             refs[b"refs/tags/v1.0"],
             refs[b"refs/tags/v1.0"],
         )
         )
 
 
-    def test_get_refs_not_found(self):
+    def test_get_refs_not_found(self) -> None:
         self.assertRaises(NotGitRepository, self.repo.get_refs)
         self.assertRaises(NotGitRepository, self.repo.get_refs)
 
 
-    def test_get_peeled(self):
+    def test_get_peeled(self) -> None:
         refs_content = b"0123456789abcdef0123456789abcdef01234567\trefs/heads/master\n"
         refs_content = b"0123456789abcdef0123456789abcdef01234567\trefs/heads/master\n"
         self._add_response("info/refs", refs_content)
         self._add_response("info/refs", refs_content)
 
 
@@ -246,19 +248,19 @@ fedcba9876543210fedcba9876543210fedcba98\trefs/tags/v1.0
         peeled = self.repo.get_peeled(b"refs/heads/master")
         peeled = self.repo.get_peeled(b"refs/heads/master")
         self.assertEqual(b"0123456789abcdef0123456789abcdef01234567", peeled)
         self.assertEqual(b"0123456789abcdef0123456789abcdef01234567", peeled)
 
 
-    def test_fetch_pack_data_no_wants(self):
+    def test_fetch_pack_data_no_wants(self) -> None:
         refs_content = b"0123456789abcdef0123456789abcdef01234567\trefs/heads/master\n"
         refs_content = b"0123456789abcdef0123456789abcdef01234567\trefs/heads/master\n"
         self._add_response("info/refs", refs_content)
         self._add_response("info/refs", refs_content)
 
 
         graph_walker = Mock()
         graph_walker = Mock()
 
 
-        def determine_wants(refs):
+        def determine_wants(refs: dict[bytes, bytes]) -> list[bytes]:
             return []
             return []
 
 
         result = list(self.repo.fetch_pack_data(graph_walker, determine_wants))
         result = list(self.repo.fetch_pack_data(graph_walker, determine_wants))
         self.assertEqual([], result)
         self.assertEqual([], result)
 
 
-    def test_fetch_pack_data_with_blob(self):
+    def test_fetch_pack_data_with_blob(self) -> None:
         # Set up refs
         # Set up refs
         refs_content = b"0123456789abcdef0123456789abcdef01234567\trefs/heads/master\n"
         refs_content = b"0123456789abcdef0123456789abcdef01234567\trefs/heads/master\n"
         self._add_response("info/refs", refs_content)
         self._add_response("info/refs", refs_content)
@@ -277,10 +279,10 @@ fedcba9876543210fedcba9876543210fedcba98\trefs/tags/v1.0
         graph_walker = Mock()
         graph_walker = Mock()
         graph_walker.ack.return_value = []  # No existing objects
         graph_walker.ack.return_value = []  # No existing objects
 
 
-        def determine_wants(refs):
+        def determine_wants(refs: dict[bytes, bytes]) -> list[bytes]:
             return [blob_sha]
             return [blob_sha]
 
 
-        def progress(msg):
+        def progress(msg: bytes) -> None:
             assert isinstance(msg, bytes)
             assert isinstance(msg, bytes)
 
 
         result = list(
         result = list(
@@ -290,6 +292,6 @@ fedcba9876543210fedcba9876543210fedcba98\trefs/tags/v1.0
         self.assertEqual(Blob.type_num, result[0].pack_type_num)
         self.assertEqual(Blob.type_num, result[0].pack_type_num)
         self.assertEqual([blob.as_raw_string()], result[0].obj_chunks)
         self.assertEqual([blob.as_raw_string()], result[0].obj_chunks)
 
 
-    def test_object_store_property(self):
+    def test_object_store_property(self) -> None:
         self.assertIsInstance(self.repo.object_store, DumbHTTPObjectStore)
         self.assertIsInstance(self.repo.object_store, DumbHTTPObjectStore)
         self.assertEqual(self.base_url, self.repo.object_store.base_url)
         self.assertEqual(self.base_url, self.repo.object_store.base_url)