Преглед изворни кода

Add type annotations to tests/test_dumb.py

Jelmer Vernooij пре 5 месеци
родитељ
комит
c7f1b1c9ef
1 измењених фајлова са 37 додато и 35 уклоњено
  1. 37 35
      tests/test_dumb.py

+ 37 - 35
tests/test_dumb.py

@@ -22,45 +22,46 @@
 """Tests for dumb HTTP git repositories."""
 
 import zlib
+from typing import Callable, Optional, Union
 from unittest import TestCase
 from unittest.mock import Mock
 
 from dulwich.dumb import DumbHTTPObjectStore, DumbRemoteHTTPRepo
 from dulwich.errors import NotGitRepository
-from dulwich.objects import Blob, Commit, Tag, Tree, sha_to_hex
+from dulwich.objects import Blob, Commit, ShaFile, Tag, Tree, sha_to_hex
 
 
 class MockResponse:
-    def __init__(self, status=200, content=b"", headers=None):
+    def __init__(self, status: int = 200, content: bytes = b"", headers: Optional[dict[str, str]] = None) -> None:
         self.status = status
         self.content = content
         self.headers = headers or {}
         self.closed = False
 
-    def close(self):
+    def close(self) -> None:
         self.closed = True
 
 
 class DumbHTTPObjectStoreTests(TestCase):
     """Tests for DumbHTTPObjectStore."""
 
-    def setUp(self):
+    def setUp(self) -> None:
         self.base_url = "https://example.com/repo.git/"
-        self.responses = {}
+        self.responses: dict[str, dict[str, Union[int, bytes]]] = {}
         self.store = DumbHTTPObjectStore(self.base_url, self._mock_http_request)
 
-    def _mock_http_request(self, url, headers):
+    def _mock_http_request(self, url: str, headers: dict[str, str]) -> tuple[MockResponse, Callable[[Optional[int]], bytes]]:
         """Mock HTTP request function."""
         if url in self.responses:
             resp_data = self.responses[url]
             resp = MockResponse(
-                resp_data.get("status", 200), resp_data.get("content", b"")
+                int(resp_data.get("status", 200)), bytes(resp_data.get("content", b""))
             )
             # Create a mock read function that behaves like urllib3's read
             content = resp.content
             offset = [0]  # Use list to make it mutable in closure
 
-            def read_func(size=None):
+            def read_func(size: Optional[int] = None) -> bytes:
                 if offset[0] >= len(content):
                     return b""
                 if size is None:
@@ -76,12 +77,12 @@ class DumbHTTPObjectStoreTests(TestCase):
             resp = MockResponse(404)
             return resp, lambda size: b""
 
-    def _add_response(self, path, content, status=200):
+    def _add_response(self, path: str, content: bytes, status: int = 200) -> None:
         """Add a mock response for a given path."""
         url = self.base_url + path
         self.responses[url] = {"status": status, "content": content}
 
-    def _make_object(self, obj):
+    def _make_object(self, obj: ShaFile) -> bytes:
         """Create compressed git object data."""
         type_name = {
             Blob.type_num: b"blob",
@@ -94,7 +95,7 @@ class DumbHTTPObjectStoreTests(TestCase):
         header = type_name + b" " + str(len(content)).encode() + b"\x00"
         return zlib.compress(header + content)
 
-    def test_fetch_loose_object_blob(self):
+    def test_fetch_loose_object_blob(self) -> None:
         # Create a blob object
         blob = Blob()
         blob.data = b"Hello, world!"
@@ -109,32 +110,33 @@ class DumbHTTPObjectStoreTests(TestCase):
         self.assertEqual(Blob.type_num, type_num)
         self.assertEqual(b"Hello, world!", content)
 
-    def test_fetch_loose_object_not_found(self):
+    def test_fetch_loose_object_not_found(self) -> None:
         hex_sha = b"1" * 40
         self.assertRaises(KeyError, self.store._fetch_loose_object, hex_sha)
 
-    def test_fetch_loose_object_invalid_format(self):
+    def test_fetch_loose_object_invalid_format(self) -> None:
         sha = b"1" * 20
         hex_sha = sha_to_hex(sha)
-        path = f"objects/{hex_sha[:2]}/{hex_sha[2:]}"
+        path = f"objects/{hex_sha[:2].decode('ascii')}/{hex_sha[2:].decode('ascii')}"
 
         # Add invalid compressed data
         self._add_response(path, b"invalid data")
 
         self.assertRaises(Exception, self.store._fetch_loose_object, sha)
 
-    def test_load_packs_empty(self):
+    def test_load_packs_empty(self) -> None:
         # No packs file
         self.store._load_packs()
         self.assertEqual([], self.store._packs)
 
-    def test_load_packs_with_entries(self):
+    def test_load_packs_with_entries(self) -> None:
         packs_content = b"""P pack-1234567890abcdef1234567890abcdef12345678.pack
 P pack-abcdef1234567890abcdef1234567890abcdef12.pack
 """
         self._add_response("objects/info/packs", packs_content)
 
         self.store._load_packs()
+        assert self.store._packs is not None
         self.assertEqual(2, len(self.store._packs))
         self.assertEqual(
             "pack-1234567890abcdef1234567890abcdef12345678", self.store._packs[0][0]
@@ -143,7 +145,7 @@ P pack-abcdef1234567890abcdef1234567890abcdef12.pack
             "pack-abcdef1234567890abcdef1234567890abcdef12", self.store._packs[1][0]
         )
 
-    def test_get_raw_from_cache(self):
+    def test_get_raw_from_cache(self) -> None:
         sha = b"1" * 40
         self.store._cached_objects[sha] = (Blob.type_num, b"cached content")
 
@@ -151,7 +153,7 @@ P pack-abcdef1234567890abcdef1234567890abcdef12.pack
         self.assertEqual(Blob.type_num, type_num)
         self.assertEqual(b"cached content", content)
 
-    def test_contains_loose(self):
+    def test_contains_loose(self) -> None:
         # Create a blob object
         blob = Blob()
         blob.data = b"Test blob"
@@ -164,35 +166,35 @@ P pack-abcdef1234567890abcdef1234567890abcdef12.pack
         self.assertTrue(self.store.contains_loose(hex_sha))
         self.assertFalse(self.store.contains_loose(b"0" * 40))
 
-    def test_add_object_not_implemented(self):
+    def test_add_object_not_implemented(self) -> None:
         blob = Blob()
         blob.data = b"test"
         self.assertRaises(NotImplementedError, self.store.add_object, blob)
 
-    def test_add_objects_not_implemented(self):
+    def test_add_objects_not_implemented(self) -> None:
         self.assertRaises(NotImplementedError, self.store.add_objects, [])
 
 
 class DumbRemoteHTTPRepoTests(TestCase):
     """Tests for DumbRemoteHTTPRepo."""
 
-    def setUp(self):
+    def setUp(self) -> None:
         self.base_url = "https://example.com/repo.git/"
-        self.responses = {}
+        self.responses: dict[str, dict[str, Union[int, bytes]]] = {}
         self.repo = DumbRemoteHTTPRepo(self.base_url, self._mock_http_request)
 
-    def _mock_http_request(self, url, headers):
+    def _mock_http_request(self, url: str, headers: dict[str, str]) -> tuple[MockResponse, Callable[[Optional[int]], bytes]]:
         """Mock HTTP request function."""
         if url in self.responses:
             resp_data = self.responses[url]
             resp = MockResponse(
-                resp_data.get("status", 200), resp_data.get("content", b"")
+                int(resp_data.get("status", 200)), bytes(resp_data.get("content", b""))
             )
             # Create a mock read function that behaves like urllib3's read
             content = resp.content
             offset = [0]  # Use list to make it mutable in closure
 
-            def read_func(size=None):
+            def read_func(size: Optional[int] = None) -> bytes:
                 if offset[0] >= len(content):
                     return b""
                 if size is None:
@@ -208,12 +210,12 @@ class DumbRemoteHTTPRepoTests(TestCase):
             resp = MockResponse(404)
             return resp, lambda size: b""
 
-    def _add_response(self, path, content, status=200):
+    def _add_response(self, path: str, content: bytes, status: int = 200) -> None:
         """Add a mock response for a given path."""
         url = self.base_url + path
         self.responses[url] = {"status": status, "content": content}
 
-    def test_get_refs(self):
+    def test_get_refs(self) -> None:
         refs_content = b"""0123456789abcdef0123456789abcdef01234567\trefs/heads/master
 abcdef0123456789abcdef0123456789abcdef01\trefs/heads/develop
 fedcba9876543210fedcba9876543210fedcba98\trefs/tags/v1.0
@@ -235,10 +237,10 @@ fedcba9876543210fedcba9876543210fedcba98\trefs/tags/v1.0
             refs[b"refs/tags/v1.0"],
         )
 
-    def test_get_refs_not_found(self):
+    def test_get_refs_not_found(self) -> None:
         self.assertRaises(NotGitRepository, self.repo.get_refs)
 
-    def test_get_peeled(self):
+    def test_get_peeled(self) -> None:
         refs_content = b"0123456789abcdef0123456789abcdef01234567\trefs/heads/master\n"
         self._add_response("info/refs", refs_content)
 
@@ -246,19 +248,19 @@ fedcba9876543210fedcba9876543210fedcba98\trefs/tags/v1.0
         peeled = self.repo.get_peeled(b"refs/heads/master")
         self.assertEqual(b"0123456789abcdef0123456789abcdef01234567", peeled)
 
-    def test_fetch_pack_data_no_wants(self):
+    def test_fetch_pack_data_no_wants(self) -> None:
         refs_content = b"0123456789abcdef0123456789abcdef01234567\trefs/heads/master\n"
         self._add_response("info/refs", refs_content)
 
         graph_walker = Mock()
 
-        def determine_wants(refs):
+        def determine_wants(refs: dict[bytes, bytes]) -> list[bytes]:
             return []
 
         result = list(self.repo.fetch_pack_data(graph_walker, determine_wants))
         self.assertEqual([], result)
 
-    def test_fetch_pack_data_with_blob(self):
+    def test_fetch_pack_data_with_blob(self) -> None:
         # Set up refs
         refs_content = b"0123456789abcdef0123456789abcdef01234567\trefs/heads/master\n"
         self._add_response("info/refs", refs_content)
@@ -277,10 +279,10 @@ fedcba9876543210fedcba9876543210fedcba98\trefs/tags/v1.0
         graph_walker = Mock()
         graph_walker.ack.return_value = []  # No existing objects
 
-        def determine_wants(refs):
+        def determine_wants(refs: dict[bytes, bytes]) -> list[bytes]:
             return [blob_sha]
 
-        def progress(msg):
+        def progress(msg: bytes) -> None:
             assert isinstance(msg, bytes)
 
         result = list(
@@ -290,6 +292,6 @@ fedcba9876543210fedcba9876543210fedcba98\trefs/tags/v1.0
         self.assertEqual(Blob.type_num, result[0].pack_type_num)
         self.assertEqual([blob.as_raw_string()], result[0].obj_chunks)
 
-    def test_object_store_property(self):
+    def test_object_store_property(self) -> None:
         self.assertIsInstance(self.repo.object_store, DumbHTTPObjectStore)
         self.assertEqual(self.base_url, self.repo.object_store.base_url)