Jelmer Vernooij 1 week ago
Parent
Commit
3631ba1f95

+ 3 - 0
docs/tutorial/remote.txt

@@ -89,4 +89,7 @@ importing the received pack file into the local repository::
 
 Let's shut down the server now that all tests have been run::
 
+   >>> client.close()
    >>> dul_server.shutdown()
+   >>> dul_server.server_close()
+   >>> repo.close()

+ 48 - 38
dulwich/__init__.py

@@ -23,7 +23,7 @@
 """Python implementation of the Git file formats and protocols."""
 
 from collections.abc import Callable
-from typing import Any, ParamSpec, TypeVar
+from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar
 
 __version__ = (0, 25, 0)
 
@@ -33,51 +33,61 @@ P = ParamSpec("P")
 R = TypeVar("R")
 F = TypeVar("F", bound=Callable[..., Any])
 
-try:
-    from dissolve import replace_me as replace_me
-except ImportError:
-    # if dissolve is not installed, then just provide a basic implementation
-    # of its replace_me decorator
+if TYPE_CHECKING:
+    # For type checking, always use our typed signature
     def replace_me(
         since: tuple[int, ...] | str | None = None,
         remove_in: tuple[int, ...] | str | None = None,
-    ) -> Callable[[F], F]:
-        """Decorator to mark functions as deprecated.
+    ) -> Callable[[Callable[P, R]], Callable[P, R]]:
+        """Decorator to mark functions as deprecated."""
+        ...
 
-        Args:
-            since: Version when the function was deprecated
-            remove_in: Version when the function will be removed
+else:
+    try:
+        from dissolve import replace_me as replace_me
+    except ImportError:
+        # if dissolve is not installed, then just provide a basic implementation
+        # of its replace_me decorator
+        def replace_me(
+            since: tuple[int, ...] | str | None = None,
+            remove_in: tuple[int, ...] | str | None = None,
+        ) -> Callable[[Callable[P, R]], Callable[P, R]]:
+            """Decorator to mark functions as deprecated.
 
-        Returns:
-            Decorator function
-        """
+            Args:
+                since: Version when the function was deprecated
+                remove_in: Version when the function will be removed
 
-        def decorator(func: Callable[P, R]) -> Callable[P, R]:
-            import functools
-            import warnings
+            Returns:
+                Decorator function
+            """
 
-            m = f"{func.__name__} is deprecated"
-            since_str = str(since) if since is not None else None
-            remove_in_str = str(remove_in) if remove_in is not None else None
+            def decorator(func: Callable[P, R]) -> Callable[P, R]:
+                import functools
+                import warnings
 
-            if since_str is not None and remove_in_str is not None:
-                m += f" since {since_str} and will be removed in {remove_in_str}"
-            elif since_str is not None:
-                m += f" since {since_str}"
-            elif remove_in_str is not None:
-                m += f" and will be removed in {remove_in_str}"
-            else:
-                m += " and will be removed in a future version"
+                m = f"{func.__name__} is deprecated"
+                since_str = str(since) if since is not None else None
+                remove_in_str = str(remove_in) if remove_in is not None else None
 
-            @functools.wraps(func)
-            def _wrapped_func(*args: P.args, **kwargs: P.kwargs) -> R:
-                warnings.warn(
-                    m,
-                    DeprecationWarning,
-                    stacklevel=2,
-                )
-                return func(*args, **kwargs)
+                if since_str is not None and remove_in_str is not None:
+                    m += f" since {since_str} and will be removed in {remove_in_str}"
+                elif since_str is not None:
+                    m += f" since {since_str}"
+                elif remove_in_str is not None:
+                    m += f" and will be removed in {remove_in_str}"
+                else:
+                    m += " and will be removed in a future version"
 
-            return _wrapped_func
+                @functools.wraps(func)
+                def _wrapped_func(*args: P.args, **kwargs: P.kwargs) -> R:
+                    warnings.warn(
+                        m,
+                        DeprecationWarning,
+                        stacklevel=2,
+                    )
+                    return func(*args, **kwargs)
 
-        return decorator  # type: ignore[return-value]
+                return _wrapped_func
+
+            return decorator
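
As a usage sketch of the fallback path (i.e. with dissolve not installed; `old_api` is a hypothetical function, not part of dulwich):

    import warnings

    from dulwich import replace_me

    @replace_me(since=(0, 24, 0), remove_in=(0, 26, 0))
    def old_api() -> str:  # hypothetical deprecated function
        return "result"

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        old_api()

    # The fallback decorator emits a DeprecationWarning reading:
    #   old_api is deprecated since (0, 24, 0) and will be removed in (0, 26, 0)
    print(caught[0].category, caught[0].message)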

+ 21 - 12
dulwich/aiohttp/server.py

@@ -30,6 +30,7 @@ from aiohttp import web
 
 from .. import log_utils
 from ..errors import HangupException
+from ..objects import ObjectID
 from ..protocol import ReceivableProtocol
 from ..repo import Repo
 from ..server import (
@@ -43,6 +44,11 @@ from ..web import NO_CACHE_HEADERS, cache_forever_headers
 
 logger = log_utils.getLogger(__name__)
 
+# Application keys for type-safe access to app state
+REPO_KEY = web.AppKey("repo", Repo)
+HANDLERS_KEY = web.AppKey("handlers", dict)
+DUMB_KEY = web.AppKey("dumb", bool)
+
 
 async def send_file(
     req: web.Request, f: BinaryIO | None, headers: dict[str, str]
@@ -80,9 +86,11 @@ async def get_loose_object(request: web.Request) -> web.Response:
       request: aiohttp request object
     Returns: Response with the loose object data
     """
-    sha = (request.match_info["dir"] + request.match_info["file"]).encode("ascii")
+    sha = ObjectID(
+        (request.match_info["dir"] + request.match_info["file"]).encode("ascii")
+    )
     logger.info("Sending loose object %s", sha)
-    object_store = request.app["repo"].object_store
+    object_store = request.app[REPO_KEY].object_store
     if not object_store.contains_loose(sha):
         raise web.HTTPNotFound(text="Object not found")
     try:
@@ -105,7 +113,7 @@ async def get_text_file(request: web.Request) -> web.StreamResponse:
     headers.update(NO_CACHE_HEADERS)
     path = request.match_info["file"]
     logger.info("Sending plain text file %s", path)
-    repo = request.app["repo"]
+    repo = request.app[REPO_KEY]
     return await send_file(request, repo.get_named_file(path), headers)
 
 
@@ -169,8 +177,8 @@ async def get_info_refs(request: web.Request) -> web.StreamResponse | web.Respon
       request: aiohttp request object
     Returns: Response with refs information
     """
-    repo = request.app["repo"]
-    return await refs_request(repo, request, request.app["handlers"])
+    repo = request.app[REPO_KEY]
+    return await refs_request(repo, request, request.app[HANDLERS_KEY])
 
 
 async def get_info_packs(request: web.Request) -> web.Response:
@@ -184,7 +192,8 @@ async def get_info_packs(request: web.Request) -> web.Response:
     headers.update(NO_CACHE_HEADERS)
     logger.info("Emulating dumb info/packs")
     return web.Response(
-        body=b"".join(generate_objects_info_packs(request.app["repo"])), headers=headers
+        body=b"".join(generate_objects_info_packs(request.app[REPO_KEY])),
+        headers=headers,
     )
 
 
@@ -202,7 +211,7 @@ async def get_pack_file(request: web.Request) -> web.StreamResponse:
     logger.info("Sending pack file %s", path)
     return await send_file(
         request,
-        request.app["repo"].get_named_file(path),
+        request.app[REPO_KEY].get_named_file(path),
         headers=headers,
     )
 
@@ -281,9 +290,9 @@ async def handle_service_request(request: web.Request) -> web.StreamResponse:
       request: aiohttp request object
     Returns: Response with service result
     """
-    repo = request.app["repo"]
+    repo = request.app[REPO_KEY]
 
-    return await service_request(repo, request, request.app["handlers"])
+    return await service_request(repo, request, request.app[HANDLERS_KEY])
 
 
 def create_repo_app(
@@ -298,11 +307,11 @@ def create_repo_app(
     Returns: Configured aiohttp Application
     """
     app = web.Application()
-    app["repo"] = repo
+    app[REPO_KEY] = repo
     if handlers is None:
         handlers = dict(DEFAULT_HANDLERS)
-    app["handlers"] = handlers
-    app["dumb"] = dumb
+    app[HANDLERS_KEY] = handlers
+    app[DUMB_KEY] = dumb
     app.router.add_get("/info/refs", get_info_refs)
     app.router.add_post(
         "/{service:git-upload-pack|git-receive-pack}", handle_service_request

+ 39 - 0
dulwich/bundle.py

@@ -29,6 +29,7 @@ __all__ = [
     "write_bundle",
 ]
 
+import types
 from collections.abc import Callable, Iterator, Sequence
 from typing import (
     TYPE_CHECKING,
@@ -60,6 +61,10 @@ class PackDataLike(Protocol):
         """Iterate over unpacked objects in the pack."""
         ...
 
+    def close(self) -> None:
+        """Close any open resources."""
+        ...
+
 
 if TYPE_CHECKING:
     from .object_store import BaseObjectStore
@@ -101,6 +106,37 @@ class Bundle:
             return False
         return True
 
+    def close(self) -> None:
+        """Close any open resources in this bundle."""
+        if self.pack_data is not None:
+            self.pack_data.close()
+            self.pack_data = None
+
+    def __enter__(self) -> "Bundle":
+        """Enter context manager."""
+        return self
+
+    def __exit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc_val: BaseException | None,
+        exc_tb: types.TracebackType | None,
+    ) -> None:
+        """Exit context manager and close bundle."""
+        self.close()
+
+    def __del__(self) -> None:
+        """Warn if bundle was not explicitly closed."""
+        if self.pack_data is not None:
+            import warnings
+
+            warnings.warn(
+                f"Bundle {self!r} was not explicitly closed. "
+                "Please use bundle.close() or a context manager.",
+                ResourceWarning,
+                stacklevel=2,
+            )
+
     def store_objects(
         self,
         object_store: "BaseObjectStore",
@@ -332,6 +368,9 @@ def create_bundle_from_repo(
         def iter_unpacked(self) -> Iterator[UnpackedObject]:
             return iter(self._objects)
 
+        def close(self) -> None:
+            """Close pack data (no-op for in-memory pack data)."""
+
     pack_data = _BundlePackData(pack_count, pack_objects, repo.object_format)
 
     # Create bundle object
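
With `close()`, the context-manager methods, and the `__del__` warning in place, a bundle's pack file can be scoped to a `with` block. A sketch using the real `read_bundle` API on a hypothetical file:

    from dulwich.bundle import read_bundle

    with open("repo.bundle", "rb") as f:  # hypothetical bundle file
        with read_bundle(f) as bundle:    # Bundle.__exit__ calls close()
            for ref, sha in bundle.references.items():
                print(ref, sha)
    # pack_data is closed on exit, so __del__ stays silent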

+ 12 - 2
dulwich/client.py

@@ -1020,6 +1020,14 @@ class GitClient:
             self._fetch_capabilities.add(CAPABILITY_INCLUDE_TAG)
         self.protocol_version = 0  # will be overridden later
 
+    def close(self) -> None:
+        """Close the client and release any resources.
+
+        Default implementation does nothing as most clients don't maintain
+        persistent connections. Subclasses that hold resources should override
+        this method to properly clean them up.
+        """
+
     def get_url(self, path: str) -> str:
         """Retrieves full url to given path.
 
@@ -2924,7 +2932,10 @@ class BundleClient(GitClient):
 
         pack_io = BytesIO(pack_bytes)
         pack_data = PackData.from_file(pack_io, object_format=DEFAULT_OBJECT_FORMAT)
-        target.object_store.add_pack_data(len(pack_data), pack_data.iter_unpacked())
+        try:
+            target.object_store.add_pack_data(len(pack_data), pack_data.iter_unpacked())
+        finally:
+            pack_data.close()
 
         # Apply ref filtering if specified
         if ref_prefix:
@@ -4595,7 +4606,6 @@ class Urllib3HttpGitClient(AbstractHttpGitClient):
             if resp.status != 200:
                 raise GitProtocolError(f"unexpected http resp {resp.status} for {url}")
 
-        # With urllib3 >= 2.2, geturl() is always available
         resp.content_type = resp.headers.get("Content-Type")  # type: ignore[union-attr]
         resp_url = resp.geturl()
         resp.redirect_location = resp_url if resp_url != url else ""  # type: ignore[union-attr]
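
Because `close()` now exists on the `GitClient` base class, any client can be wrapped in `contextlib.closing` regardless of whether its transport holds resources. A sketch, with the URL purely illustrative:

    from contextlib import closing

    from dulwich.client import get_transport_and_path

    client, path = get_transport_and_path("https://example.com/repo.git")
    with closing(client) as c:  # close() is a no-op for stateless clients
        refs = c.get_refs(path)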

+ 52 - 43
dulwich/mbox.py

@@ -78,46 +78,52 @@ def split_mbox(
         raise ValueError(f"Output path is not a directory: {output_dir}")
 
     # Open the mbox file
+    mbox_obj: mailbox.mbox | None = None
     mbox_iter: Iterable[mailbox.mboxMessage]
     if isinstance(input_file, (str, bytes)):
         if isinstance(input_file, bytes):
             input_file = input_file.decode("utf-8")
-        mbox_iter = mailbox.mbox(input_file)
+        mbox_obj = mailbox.mbox(input_file)
+        mbox_iter = mbox_obj
     else:
         # For file-like objects, we need to read and parse manually
         mbox_iter = _parse_mbox_from_file(input_file)
 
-    output_files = []
-    msg_number = start_number
+    try:
+        output_files = []
+        msg_number = start_number
 
-    for message in mbox_iter:
-        # Format the output filename with the specified precision
-        output_filename = f"{msg_number:0{precision}d}"
-        output_file_path = output_path / output_filename
+        for message in mbox_iter:
+            # Format the output filename with the specified precision
+            output_filename = f"{msg_number:0{precision}d}"
+            output_file_path = output_path / output_filename
 
-        # Write the message to the output file
-        with open(output_file_path, "wb") as f:
-            message_bytes = bytes(message)
+            # Write the message to the output file
+            with open(output_file_path, "wb") as f:
+                message_bytes = bytes(message)
 
-            # Handle mboxrd format - reverse the escaping
-            if mboxrd:
-                message_bytes = _reverse_mboxrd_escaping(message_bytes)
+                # Handle mboxrd format - reverse the escaping
+                if mboxrd:
+                    message_bytes = _reverse_mboxrd_escaping(message_bytes)
 
-            # Handle CR/LF if needed
-            if not keep_cr:
-                message_bytes = message_bytes.replace(b"\r\n", b"\n")
+                # Handle CR/LF if needed
+                if not keep_cr:
+                    message_bytes = message_bytes.replace(b"\r\n", b"\n")
 
-            # Strip trailing newlines (mailbox module adds separator newlines)
-            message_bytes = message_bytes.rstrip(b"\n")
-            if message_bytes:
-                message_bytes += b"\n"
+                # Strip trailing newlines (mailbox module adds separator newlines)
+                message_bytes = message_bytes.rstrip(b"\n")
+                if message_bytes:
+                    message_bytes += b"\n"
 
-            f.write(message_bytes)
+                f.write(message_bytes)
 
-        output_files.append(str(output_file_path))
-        msg_number += 1
+            output_files.append(str(output_file_path))
+            msg_number += 1
 
-    return output_files
+        return output_files
+    finally:
+        if mbox_obj is not None:
+            mbox_obj.close()
 
 
 def split_maildir(
@@ -167,33 +173,36 @@ def split_maildir(
     # Open the Maildir
     md = mailbox.Maildir(str(maildir), factory=None)
 
-    # Get all messages and sort by their keys to ensure consistent ordering
-    sorted_keys = sorted(md.keys())
+    try:
+        # Get all messages and sort by their keys to ensure consistent ordering
+        sorted_keys = sorted(md.keys())
 
-    output_files = []
-    msg_number = start_number
+        output_files = []
+        msg_number = start_number
 
-    for key in sorted_keys:
-        message = md[key]
+        for key in sorted_keys:
+            message = md[key]
 
-        # Format the output filename with the specified precision
-        output_filename = f"{msg_number:0{precision}d}"
-        output_file_path = output_path / output_filename
+            # Format the output filename with the specified precision
+            output_filename = f"{msg_number:0{precision}d}"
+            output_file_path = output_path / output_filename
 
-        # Write the message to the output file
-        with open(output_file_path, "wb") as f:
-            message_bytes = bytes(message)
+            # Write the message to the output file
+            with open(output_file_path, "wb") as f:
+                message_bytes = bytes(message)
 
-            # Handle CR/LF if needed
-            if not keep_cr:
-                message_bytes = message_bytes.replace(b"\r\n", b"\n")
+                # Handle CR/LF if needed
+                if not keep_cr:
+                    message_bytes = message_bytes.replace(b"\r\n", b"\n")
 
-            f.write(message_bytes)
+                f.write(message_bytes)
 
-        output_files.append(str(output_file_path))
-        msg_number += 1
+            output_files.append(str(output_file_path))
+            msg_number += 1
 
-    return output_files
+        return output_files
+    finally:
+        md.close()
 
 
 def _parse_mbox_from_file(file_obj: BinaryIO) -> Iterator[mailbox.mboxMessage]:
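
The explicit `try`/`finally` is needed because the stdlib `mailbox` classes expose `close()` but do not implement the context-manager protocol; `contextlib.closing` expresses the same guarantee. A sketch on a hypothetical path:

    import mailbox
    from contextlib import closing

    # mailbox.mbox has close() but no __enter__/__exit__, so wrap it:
    with closing(mailbox.mbox("/tmp/inbox.mbox")) as mb:
        for message in mb:
            print(message["Subject"])
    # mb.close() runs even if iteration raises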

+ 20 - 11
dulwich/object_store.py

@@ -847,6 +847,7 @@ class PackBasedObjectStore(PackCapableObjectStore, PackedObjectContainer):
         """
         super().__init__(object_format=object_format)
         self._pack_cache: dict[str, Pack] = {}
+        self._closed = False
         self.pack_compression_level = pack_compression_level
         self.pack_index_version = pack_index_version
         self.pack_delta_window_size = pack_delta_window_size
@@ -1011,11 +1012,14 @@ class PackBasedObjectStore(PackCapableObjectStore, PackedObjectContainer):
 
         This method closes all cached pack files and frees associated resources.
         """
+        self._closed = True
         self._clear_cached_packs()
 
     @property
     def packs(self) -> list[Pack]:
         """List with pack objects."""
+        if self._closed:
+            raise ValueError("Cannot access packs on a closed object store")
         return list(self._iter_cached_packs()) + list(self._update_pack_cache())
 
     def count_pack_files(self) -> int:
@@ -1663,7 +1667,6 @@ class DiskObjectStore(PackBasedObjectStore):
         try:
             pack_dir_contents = os.listdir(self.pack_dir)
         except FileNotFoundError:
-            self.close()
             return []
         pack_files = set()
         for name in pack_dir_contents:
@@ -1672,15 +1675,16 @@ class DiskObjectStore(PackBasedObjectStore):
                 # fully written)
                 idx_name = os.path.splitext(name)[0] + ".idx"
                 if idx_name in pack_dir_contents:
-                    pack_name = name[: -len(".pack")]
-                    pack_files.add(pack_name)
+                    # Extract just the hash (remove "pack-" prefix and ".pack" suffix)
+                    pack_hash = name[len("pack-") : -len(".pack")]
+                    pack_files.add(pack_hash)
 
         # Open newly appeared pack files
         new_packs = []
-        for f in pack_files:
-            if f not in self._pack_cache:
+        for pack_hash in pack_files:
+            if pack_hash not in self._pack_cache:
                 pack = Pack(
-                    os.path.join(self.pack_dir, f),
+                    os.path.join(self.pack_dir, "pack-" + pack_hash),
                     object_format=self.object_format,
                     delta_window_size=self.pack_delta_window_size,
                     window_memory=self.pack_window_memory,
@@ -1690,7 +1694,7 @@ class DiskObjectStore(PackBasedObjectStore):
                     big_file_threshold=self.pack_big_file_threshold,
                 )
                 new_packs.append(pack)
-                self._pack_cache[f] = pack
+                self._pack_cache[pack_hash] = pack
         # Remove disappeared pack files
         for f in set(self._pack_cache) - pack_files:
             self._pack_cache.pop(f).close()
@@ -1799,10 +1803,13 @@ class DiskObjectStore(PackBasedObjectStore):
             del self._pack_cache[os.path.basename(pack._basename)]
         except KeyError:
             pass
+        # Store paths before closing to avoid re-opening files on Windows
+        data_path = pack._data_path
+        idx_path = pack._idx_path
         pack.close()
-        os.remove(pack.data.path)
-        if hasattr(pack.index, "path"):
-            os.remove(pack.index.path)
+        os.remove(data_path)
+        if os.path.exists(idx_path):
+            os.remove(idx_path)
 
     def _get_pack_basepath(
         self, entries: Iterable[tuple[bytes, int, int | None]]
@@ -1937,7 +1944,9 @@ class DiskObjectStore(PackBasedObjectStore):
             big_file_threshold=self.pack_big_file_threshold,
         )
         final_pack.check_length_and_checksum()
-        self._add_cached_pack(pack_base_name, final_pack)
+        # Extract just the hash from pack_base_name (/path/to/pack-HASH -> HASH)
+        pack_hash = os.path.basename(pack_base_name)[len("pack-") :]
+        self._add_cached_pack(pack_hash, final_pack)
         return final_pack
 
     def add_thin_pack(
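
The pack cache is now keyed by the bare hash instead of the `pack-HASH` file stem, so both code paths above reduce to the same slicing. Illustrative values:

    import os

    name = "pack-abc123.pack"                       # hypothetical file name
    pack_hash = name[len("pack-") : -len(".pack")]  # "abc123"

    pack_base_name = "/objects/pack/pack-abc123"    # hypothetical base path
    assert os.path.basename(pack_base_name)[len("pack-"):] == pack_hash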

+ 1 - 1
dulwich/objects.py

@@ -2098,7 +2098,7 @@ class Commit(ShaFile):
             gpgsig,
             message,
             extra,
-        ) = parse_commit(chunks)
+        ) = _parse_commit(chunks)
 
         self._tree = tree
         self._parents = [ObjectID(p) for p in parents]

+ 36 - 1
dulwich/pack.py

@@ -1813,7 +1813,26 @@ class PackData:
 
     def close(self) -> None:
         """Close the underlying pack file."""
-        self._file.close()
+        if self._file is not None:
+            self._file.close()
+            self._file = None  # type: ignore
+
+    def __del__(self) -> None:
+        """Ensure pack file is closed when PackData is garbage collected."""
+        if self._file is not None:
+            import warnings
+
+            warnings.warn(
+                f"unclosed PackData {self!r}",
+                ResourceWarning,
+                stacklevel=2,
+                source=self,
+            )
+            try:
+                self.close()
+            except Exception:
+                # Ignore errors during cleanup
+                pass
 
     def __enter__(self) -> "PackData":
         """Enter context manager."""
@@ -4066,8 +4085,24 @@ class Pack:
         """Close the pack file and index."""
         if self._data is not None:
             self._data.close()
+            self._data = None
         if self._idx is not None:
             self._idx.close()
+            self._idx = None
+
+    def __del__(self) -> None:
+        """Ensure pack file is closed when Pack is garbage collected."""
+        if self._data is not None or self._idx is not None:
+            import warnings
+
+            warnings.warn(
+                f"unclosed Pack {self!r}", ResourceWarning, stacklevel=2, source=self
+            )
+            try:
+                self.close()
+            except Exception:
+                # Ignore errors during cleanup
+                pass
 
     def __enter__(self) -> "Pack":
         """Enter context manager."""

+ 18 - 0
dulwich/protocol.py

@@ -401,6 +401,24 @@ class Protocol:
         """Close the underlying transport if a close function was provided."""
         if self._close:
             self._close()
+            self._close = None  # Prevent double-close
+
+    def __del__(self) -> None:
+        """Ensure transport is closed when Protocol is garbage collected."""
+        if self._close is not None:
+            import warnings
+
+            warnings.warn(
+                f"unclosed Protocol {self!r}",
+                ResourceWarning,
+                stacklevel=2,
+                source=self,
+            )
+            try:
+                self.close()
+            except Exception:
+                # Ignore errors during cleanup
+                pass
 
     def __enter__(self) -> "Protocol":
         """Enter context manager."""

+ 2 - 0
dulwich/repo.py

@@ -2746,6 +2746,8 @@ class MemoryRepo(BaseRepo):
         if self.filter_context is not None:
             self.filter_context.close()
             self.filter_context = None
+        # Close object store to release pack files
+        self.object_store.close()
 
     def do_commit(
         self,

+ 1 - 2
dulwich/tests/test_object_store.py

@@ -329,8 +329,7 @@ class PackBasedObjectStoreTests(ObjectStoreTests):
 
     def tearDown(self) -> None:
         """Clean up by closing all packs."""
-        for pack in self.store.packs:
-            pack.close()
+        self.store.close()
 
     def test_empty_packs(self) -> None:
         """Test that new store has no packs."""

+ 1 - 0
pyproject.toml

@@ -104,6 +104,7 @@ warn_return_any = false
 [tool.setuptools]
 packages = [
     "dulwich",
+    "dulwich.aiohttp",
     "dulwich.cloud",
     "dulwich.contrib",
     "dulwich.porcelain",

+ 5 - 3
tests/compat/server_utils.py

@@ -228,9 +228,11 @@ class ServerTests:
         )
 
         # compare the two clones; they should be equal
-        self.assertReposEqual(
-            Repo(self._stub_repo_git.path), Repo(self._stub_repo_dw.path)
-        )
+        repo_git = Repo(self._stub_repo_git.path)
+        self.addCleanup(repo_git.close)
+        repo_dw = Repo(self._stub_repo_dw.path)
+        self.addCleanup(repo_dw.close)
+        self.assertReposEqual(repo_git, repo_dw)
 
     def test_fetch_same_depth_into_shallow_clone_from_dulwich(self) -> None:
         require_git_version(self.min_single_branch_version)

+ 3 - 1
tests/compat/test_aiohttp.py

@@ -88,12 +88,14 @@ class AiohttpServerTests(ServerTests):
         # Cleanup function
         def cleanup():
             async def stop():
+                await site.stop()
                 await runner.cleanup()
 
             future = asyncio.run_coroutine_threadsafe(stop(), loop)
             future.result(timeout=5)
             loop.call_soon_threadsafe(loop.stop)
-            thread.join(timeout=1.0)
+            thread.join()
+            loop.close()
 
         self.addCleanup(cleanup)
         self._server = runner
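
`run_coroutine_threadsafe` schedules the async teardown on the loop's thread and blocks the test thread until it finishes; the loop may only be closed after its thread has actually stopped. A minimal sketch of the same shutdown sequence:

    import asyncio
    import threading

    loop = asyncio.new_event_loop()
    thread = threading.Thread(target=loop.run_forever, daemon=True)
    thread.start()

    async def stop() -> None:
        ...  # e.g. await site.stop(); await runner.cleanup()

    future = asyncio.run_coroutine_threadsafe(stop(), loop)
    future.result(timeout=5)              # wait for the coroutine to finish
    loop.call_soon_threadsafe(loop.stop)  # ask run_forever() to return
    thread.join()                         # wait for the loop thread to exit
    loop.close()                          # only safe once the thread is gone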

+ 7 - 0
tests/compat/test_bundle.py

@@ -62,6 +62,7 @@ class CompatBundleTestCase(CompatTestCase):
 
         # Use create_bundle_from_repo helper
         bundle = create_bundle_from_repo(self.repo)
+        self.addCleanup(bundle.close)
 
         with open(bundle_path, "wb") as f:
             write_bundle(f, bundle)
@@ -94,6 +95,7 @@ class CompatBundleTestCase(CompatTestCase):
         # Read bundle using dulwich
         with open(bundle_path, "rb") as f:
             bundle = read_bundle(f)
+        self.addCleanup(bundle.close)
 
         # Verify bundle contents
         self.assertEqual(2, bundle.version)
@@ -138,6 +140,7 @@ class CompatBundleTestCase(CompatTestCase):
         # Read bundle using dulwich
         with open(bundle_path, "rb") as f:
             bundle = read_bundle(f)
+        self.addCleanup(bundle.close)
 
         # Verify bundle contains all refs
         self.assertIn(b"refs/heads/master", bundle.references)
@@ -185,6 +188,7 @@ class CompatBundleTestCase(CompatTestCase):
         # Read bundle using dulwich
         with open(bundle_path, "rb") as f:
             bundle = read_bundle(f)
+        self.addCleanup(bundle.close)
 
         # Verify bundle has prerequisites
         self.assertGreater(len(bundle.prerequisites), 0)
@@ -234,6 +238,7 @@ class CompatBundleTestCase(CompatTestCase):
         # Read bundle using dulwich
         with open(bundle_path, "rb") as f:
             bundle = read_bundle(f)
+        self.addCleanup(bundle.close)
 
         # Verify bundle contents
         self.assertEqual(2, bundle.version)
@@ -266,6 +271,7 @@ class CompatBundleTestCase(CompatTestCase):
         # Read bundle using dulwich
         with open(bundle_path, "rb") as f:
             bundle = read_bundle(f)
+        self.addCleanup(bundle.close)
 
         # Verify bundle was read correctly
         self.assertEqual(2, bundle.version)
@@ -319,6 +325,7 @@ class CompatBundleTestCase(CompatTestCase):
         # Read the bundle and store objects using dulwich
         with open(bundle_path, "rb") as f:
             bundle = read_bundle(f)
+        self.addCleanup(bundle.close)
 
         # Use the bundle's store_objects method to unbundle
         bundle.store_objects(unbundle_repo.object_store)

+ 30 - 31
tests/compat/test_client.py

@@ -66,9 +66,8 @@ class DulwichClientTestBase:
         self.dest = os.path.join(self.gitroot, "dest")
         file.ensure_dir_exists(self.dest)
         run_git_or_fail(["init", "--quiet", "--bare"], cwd=self.dest)
-
-    def tearDown(self) -> None:
-        rmtree_ro(self.gitroot)
+        # Register cleanup to run after test's cleanup handlers
+        self.addCleanup(rmtree_ro, self.gitroot)
 
     def assertDestEqualsSrc(self) -> None:
         repo_dir = os.path.join(self.gitroot, "server_new.export")
@@ -211,6 +210,7 @@ class DulwichClientTestBase:
     def disable_ff_and_make_dummy_commit(self):
         # disable non-fast-forward pushes to the server
         dest = repo.Repo(os.path.join(self.gitroot, "dest"))
+        self.addCleanup(dest.close)
         run_git_or_fail(
             ["config", "receive.denyNonFastForwards", "true"], cwd=dest.path
         )
@@ -293,6 +293,7 @@ class DulwichClientTestBase:
     def test_fetch_pack_with_nondefault_symref(self) -> None:
         c = self._client()
         src = repo.Repo(os.path.join(self.gitroot, "server_new.export"))
+        self.addCleanup(src.close)
         src.refs.add_if_new(b"refs/heads/main", src.refs[b"refs/heads/master"])
         src.refs.set_symbolic_ref(b"HEAD", b"refs/heads/main")
         with repo.Repo(os.path.join(self.gitroot, "dest")) as dest:
@@ -405,35 +406,37 @@ class DulwichClientTestBase:
 
     def test_repeat(self) -> None:
         c = self._client()
-        with repo.Repo(os.path.join(self.gitroot, "dest")) as dest:
-            result = c.fetch(self._build_path("/server_new.export"), dest)
-            for r in result.refs.items():
-                dest.refs.set_if_equals(r[0], None, r[1])
-            self.assertDestEqualsSrc()
-            result = c.fetch(self._build_path("/server_new.export"), dest)
-            for r in result.refs.items():
-                dest.refs.set_if_equals(r[0], None, r[1])
-            self.assertDestEqualsSrc()
+        dest = repo.Repo(os.path.join(self.gitroot, "dest"))
+        self.addCleanup(dest.close)
+        result = c.fetch(self._build_path("/server_new.export"), dest)
+        for r in result.refs.items():
+            dest.refs.set_if_equals(r[0], None, r[1])
+        self.assertDestEqualsSrc()
+        result = c.fetch(self._build_path("/server_new.export"), dest)
+        for r in result.refs.items():
+            dest.refs.set_if_equals(r[0], None, r[1])
+        self.assertDestEqualsSrc()
 
     def test_fetch_empty_pack(self) -> None:
         c = self._client()
-        with repo.Repo(os.path.join(self.gitroot, "dest")) as dest:
-            result = c.fetch(self._build_path("/server_new.export"), dest)
-            for r in result.refs.items():
-                dest.refs.set_if_equals(r[0], None, r[1])
-            self.assertDestEqualsSrc()
+        dest = repo.Repo(os.path.join(self.gitroot, "dest"))
+        self.addCleanup(dest.close)
+        result = c.fetch(self._build_path("/server_new.export"), dest)
+        for r in result.refs.items():
+            dest.refs.set_if_equals(r[0], None, r[1])
+        self.assertDestEqualsSrc()
 
-            def dw(refs, **kwargs):
-                return list(refs.values())
+        def dw(refs, **kwargs):
+            return list(refs.values())
 
-            result = c.fetch(
-                self._build_path("/server_new.export"),
-                dest,
-                determine_wants=dw,
-            )
-            for r in result.refs.items():
-                dest.refs.set_if_equals(r[0], None, r[1])
-            self.assertDestEqualsSrc()
+        result = c.fetch(
+            self._build_path("/server_new.export"),
+            dest,
+            determine_wants=dw,
+        )
+        for r in result.refs.items():
+            dest.refs.set_if_equals(r[0], None, r[1])
+        self.assertDestEqualsSrc()
 
     def test_incremental_fetch_pack(self) -> None:
         self.test_fetch_pack()
@@ -561,7 +564,6 @@ class DulwichTCPClientTest(CompatTestCase, DulwichClientTestBase):
         self.process.wait()
         self.process.stdout.close()
         self.process.stderr.close()
-        DulwichClientTestBase.tearDown(self)
         CompatTestCase.tearDown(self)
 
     def _client(self):
@@ -633,7 +635,6 @@ class DulwichMockSSHClientTest(CompatTestCase, DulwichClientTestBase):
         client.get_ssh_vendor = TestSSHVendor
 
     def tearDown(self) -> None:
-        DulwichClientTestBase.tearDown(self)
         CompatTestCase.tearDown(self)
         client.get_ssh_vendor = self.real_vendor
 
@@ -655,7 +656,6 @@ class DulwichSubprocessClientTest(CompatTestCase, DulwichClientTestBase):
         DulwichClientTestBase.setUp(self)
 
     def tearDown(self) -> None:
-        DulwichClientTestBase.tearDown(self)
         CompatTestCase.tearDown(self)
 
     def _client(self):
@@ -843,7 +843,6 @@ class DulwichHttpClientTest(CompatTestCase, DulwichClientTestBase):
         run_git_or_fail(["config", "http.receivepack", "true"], cwd=self.dest)
 
     def tearDown(self) -> None:
-        DulwichClientTestBase.tearDown(self)
         CompatTestCase.tearDown(self)
         self._httpd.shutdown()
         self._httpd.socket.close()
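
Replacing `tearDown` with `addCleanup` matters for ordering: cleanups run last-in-first-out after `tearDown` completes, so the `gitroot` removal registered in `setUp` now runs after every per-test cleanup. A sketch of the ordering:

    import unittest

    class OrderingDemo(unittest.TestCase):
        def setUp(self):
            self.addCleanup(print, "3: registered first, runs last")

        def test_demo(self):
            self.addCleanup(print, "2: registered last, runs first")

        def tearDown(self):
            print("1: tearDown runs before any cleanup")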

+ 2 - 0
tests/compat/test_dumb.py

@@ -97,6 +97,7 @@ class DumbHTTPGitServer:
     def stop(self):
         """Stop the HTTP server."""
         self.server.shutdown()
+        self.server.server_close()
         if self.thread:
             self.thread.join()
 
@@ -174,6 +175,7 @@ class DumbHTTPClientNoPackTests(CompatTestCase):
         dest_path = os.path.join(self.temp_dir, "cloned")
         # Use a dummy errstream to suppress progress output
         repo = clone(self.server.url, dest_path, errstream=io.BytesIO())
+        self.addCleanup(repo.close)
         assert b"HEAD" in repo
 
     def test_clone_from_dumb_http(self):

+ 2 - 0
tests/compat/test_index.py

@@ -1063,6 +1063,7 @@ class SparseIndexCompatTestCase(CompatTestCase):
 
         # Create a tree structure using Dulwich
         repo = Repo(repo_path)
+        self.addCleanup(repo.close)
 
         # Create blobs
         blob1 = Blob()
@@ -1171,6 +1172,7 @@ class SparseIndexCompatTestCase(CompatTestCase):
         if idx.is_sparse():
             # Expand the index
             repo = Repo(repo_path)
+            self.addCleanup(repo.close)
             idx.ensure_full_index(repo.object_store)
 
             # Should no longer be sparse

+ 29 - 21
tests/compat/test_server.py

@@ -59,13 +59,17 @@ class GitServerTestCase(ServerTests, CompatTestCase):
 
         receive_pack_handler_cls = dul_server.handlers[b"git-receive-pack"]
         # Create a handler instance to check capabilities
-        handler = receive_pack_handler_cls(
-            dul_server.backend,
-            [b"/"],
-            Protocol(lambda x: b"", lambda x: None),
-        )
-        caps = handler.capabilities()
-        self.assertNotIn(b"side-band-64k", caps)
+        proto = Protocol(lambda x: b"", lambda x: None)
+        try:
+            handler = receive_pack_handler_cls(
+                dul_server.backend,
+                [b"/"],
+                proto,
+            )
+            caps = handler.capabilities()
+            self.assertNotIn(b"side-band-64k", caps)
+        finally:
+            proto.close()
 
     def _start_server(self, repo):
         backend = DictBackend({b"/": repo})
@@ -77,12 +81,12 @@ class GitServerTestCase(ServerTests, CompatTestCase):
         server_thread.daemon = True  # Make thread daemon so it dies with main thread
         server_thread.start()
 
-        # Add cleanup in the correct order
+        # Add cleanup in the correct order - shutdown first, then close
         def cleanup_server():
             dul_server.shutdown()
-            dul_server.server_close()
-            # Give thread a moment to exit cleanly
+            # Give thread a moment to exit cleanly before closing socket
             server_thread.join(timeout=1.0)
+            dul_server.server_close()
 
         self.addCleanup(cleanup_server)
         self._server = dul_server
@@ -113,13 +117,17 @@ class GitServerSideBand64kTestCase(GitServerTestCase):
 
         receive_pack_handler_cls = server.handlers[b"git-receive-pack"]
         # Create a handler instance to check capabilities
-        handler = receive_pack_handler_cls(
-            server.backend,
-            [b"/"],
-            Protocol(lambda x: b"", lambda x: None),
-        )
-        caps = handler.capabilities()
-        self.assertIn(b"side-band-64k", caps)
+        proto = Protocol(lambda x: b"", lambda x: None)
+        try:
+            handler = receive_pack_handler_cls(
+                server.backend,
+                [b"/"],
+                proto,
+            )
+            caps = handler.capabilities()
+            self.assertIn(b"side-band-64k", caps)
+        finally:
+            proto.close()
 
 
 @skipIf(sys.platform == "win32", "Broken on windows, with very long fail time.")
@@ -143,11 +151,12 @@ class GitServerSHA256TestCase(CompatTestCase):
         server_thread.daemon = True
         server_thread.start()
 
-        # Add cleanup
+        # Add cleanup - shutdown first, then close
         def cleanup_server():
             dul_server.shutdown()
-            dul_server.server_close()
+            # Give thread a moment to exit cleanly before closing socket
             server_thread.join(timeout=1.0)
+            dul_server.server_close()
 
         self.addCleanup(cleanup_server)
         self._server = dul_server
@@ -163,6 +172,7 @@ class GitServerSHA256TestCase(CompatTestCase):
         repo_path = tempfile.mkdtemp()
         self.addCleanup(shutil.rmtree, repo_path)
         source_repo = Repo.init(repo_path, mkdir=False, object_format="sha256")
+        self.addCleanup(source_repo.close)
 
         # Create test content
         blob = Blob.from_string(b"Test SHA-256 content from dulwich server")
@@ -213,5 +223,3 @@ class GitServerSHA256TestCase(CompatTestCase):
         # Verify git can read the commit
         log_output = run_git_or_fail(["log", "--format=%s", "-n", "1"], cwd=clone_dir)
         self.assertEqual(log_output.strip(), b"Test SHA-256 commit")
-
-        source_repo.close()
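
The revised ordering follows the `socketserver` contract: `shutdown()` asks `serve_forever()` to return, the `join()` gives the serving thread time to leave the loop, and only then is the listening socket closed. A generic sketch:

    import socketserver
    import threading

    class EchoHandler(socketserver.BaseRequestHandler):
        def handle(self):
            self.request.sendall(self.request.recv(1024))

    server = socketserver.TCPServer(("127.0.0.1", 0), EchoHandler)
    thread = threading.Thread(target=server.serve_forever, daemon=True)
    thread.start()

    server.shutdown()         # 1. stop the serve_forever() loop
    thread.join(timeout=1.0)  # 2. let the serving thread exit
    server.server_close()     # 3. now close the listening socket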

+ 24 - 0
tests/test_bundle.py

@@ -43,6 +43,7 @@ class BundleTests(TestCase):
     def test_bundle_repr(self) -> None:
         """Test the Bundle.__repr__ method."""
         bundle = Bundle()
+        self.addCleanup(bundle.close)
         bundle.version = 3
         bundle.capabilities = {"foo": "bar"}
         bundle.prerequisites = [(b"cc" * 20, "comment")]
@@ -53,6 +54,7 @@ class BundleTests(TestCase):
         write_pack_objects(b.write, [], object_format=DEFAULT_OBJECT_FORMAT)
         b.seek(0)
         bundle.pack_data = PackData.from_file(b, object_format=DEFAULT_OBJECT_FORMAT)
+        self.addCleanup(bundle.pack_data.close)
 
         # Check the repr output
         rep = repr(bundle)
@@ -65,6 +67,7 @@ class BundleTests(TestCase):
         """Test the Bundle.__eq__ method."""
         # Create two identical bundles
         bundle1 = Bundle()
+        self.addCleanup(bundle1.close)
         bundle1.version = 3
         bundle1.capabilities = {"foo": "bar"}
         bundle1.prerequisites = [(b"cc" * 20, "comment")]
@@ -76,6 +79,7 @@ class BundleTests(TestCase):
         bundle1.pack_data = PackData.from_file(b1, object_format=DEFAULT_OBJECT_FORMAT)
 
         bundle2 = Bundle()
+        self.addCleanup(bundle2.close)
         bundle2.version = 3
         bundle2.capabilities = {"foo": "bar"}
         bundle2.prerequisites = [(b"cc" * 20, "comment")]
@@ -91,6 +95,7 @@ class BundleTests(TestCase):
 
         # Test inequality by changing different attributes
         bundle3 = Bundle()
+        self.addCleanup(bundle3.close)
         bundle3.version = 2  # Different version
         bundle3.capabilities = {"foo": "bar"}
         bundle3.prerequisites = [(b"cc" * 20, "comment")]
@@ -102,6 +107,7 @@ class BundleTests(TestCase):
         self.assertNotEqual(bundle1, bundle3)
 
         bundle4 = Bundle()
+        self.addCleanup(bundle4.close)
         bundle4.version = 3
         bundle4.capabilities = {"different": "value"}  # Different capabilities
         bundle4.prerequisites = [(b"cc" * 20, "comment")]
@@ -113,6 +119,7 @@ class BundleTests(TestCase):
         self.assertNotEqual(bundle1, bundle4)
 
         bundle5 = Bundle()
+        self.addCleanup(bundle5.close)
         bundle5.version = 3
         bundle5.capabilities = {"foo": "bar"}
         bundle5.prerequisites = [(b"dd" * 20, "different")]  # Different prerequisites
@@ -124,6 +131,7 @@ class BundleTests(TestCase):
         self.assertNotEqual(bundle1, bundle5)
 
         bundle6 = Bundle()
+        self.addCleanup(bundle6.close)
         bundle6.version = 3
         bundle6.capabilities = {"foo": "bar"}
         bundle6.prerequisites = [(b"cc" * 20, "comment")]
@@ -153,6 +161,7 @@ class BundleTests(TestCase):
         f.seek(0)
 
         bundle = read_bundle(f)
+        self.addCleanup(bundle.close)
         self.assertEqual(2, bundle.version)
         self.assertEqual({}, bundle.capabilities)
         self.assertEqual([(b"cc" * 20, b"prerequisite comment")], bundle.prerequisites)
@@ -174,6 +183,7 @@ class BundleTests(TestCase):
         f.seek(0)
 
         bundle = read_bundle(f)
+        self.addCleanup(bundle.close)
         self.assertEqual(3, bundle.version)
         self.assertEqual(
             {"capability1": None, "capability2": "value2"}, bundle.capabilities
@@ -193,6 +203,7 @@ class BundleTests(TestCase):
     def test_write_bundle_v2(self) -> None:
         """Test writing a v2 bundle."""
         bundle = Bundle()
+        self.addCleanup(bundle.close)
         bundle.version = 2
         bundle.capabilities = {}
         bundle.prerequisites = [(b"cc" * 20, b"prerequisite comment")]
@@ -219,6 +230,7 @@ class BundleTests(TestCase):
     def test_write_bundle_v3(self) -> None:
         """Test writing a v3 bundle with capabilities."""
         bundle = Bundle()
+        self.addCleanup(bundle.close)
         bundle.version = 3
         bundle.capabilities = {"capability1": None, "capability2": "value2"}
         bundle.prerequisites = [(b"cc" * 20, b"prerequisite comment")]
@@ -248,6 +260,7 @@ class BundleTests(TestCase):
         """Test writing a bundle with auto-detected version."""
         # Create a bundle with no explicit version but capabilities
         bundle1 = Bundle()
+        self.addCleanup(bundle1.close)
         bundle1.version = None
         bundle1.capabilities = {"capability1": "value1"}
         bundle1.prerequisites = [(b"cc" * 20, b"prerequisite comment")]
@@ -266,6 +279,7 @@ class BundleTests(TestCase):
 
         # Create a bundle with no explicit version and no capabilities
         bundle2 = Bundle()
+        self.addCleanup(bundle2.close)
         bundle2.version = None
         bundle2.capabilities = {}
         bundle2.prerequisites = [(b"cc" * 20, b"prerequisite comment")]
@@ -285,6 +299,7 @@ class BundleTests(TestCase):
     def test_write_bundle_invalid_version(self) -> None:
         """Test writing a bundle with an invalid version."""
         bundle = Bundle()
+        self.addCleanup(bundle.close)
         bundle.version = 4  # Invalid version
         bundle.capabilities = {}
         bundle.prerequisites = []
@@ -301,6 +316,7 @@ class BundleTests(TestCase):
 
     def test_roundtrip_bundle(self) -> None:
         origbundle = Bundle()
+        self.addCleanup(origbundle.close)
         origbundle.version = 3
         origbundle.capabilities = {"foo": None}
         origbundle.references = {b"refs/heads/master": b"ab" * 20}
@@ -317,6 +333,7 @@ class BundleTests(TestCase):
 
             with open(os.path.join(td, "foo"), "rb") as f:
                 newbundle = read_bundle(f)
+                self.addCleanup(newbundle.close)
 
                 self.assertEqual(origbundle, newbundle)
 
@@ -324,6 +341,7 @@ class BundleTests(TestCase):
         """Test creating a bundle from a repository."""
         # Create a simple repository
         repo = MemoryRepo()
+        self.addCleanup(repo.close)
 
         # Create a blob
         blob = Blob.from_string(b"Hello world")
@@ -348,6 +366,7 @@ class BundleTests(TestCase):
 
         # Create bundle from repository
         bundle = create_bundle_from_repo(repo)
+        self.addCleanup(bundle.close)
 
         # Verify bundle contents
         self.assertEqual(bundle.references, {b"refs/heads/master": commit.id})
@@ -387,6 +406,7 @@ class BundleTests(TestCase):
         # Create bundle with prerequisites
         prereq_id = b"aa" * 20  # hex string like other object ids
         bundle = create_bundle_from_repo(repo, prerequisites=[prereq_id])
+        self.addCleanup(bundle.close)
 
         # Verify prerequisites are included
         self.assertEqual(len(bundle.prerequisites), 1)
@@ -419,6 +439,7 @@ class BundleTests(TestCase):
         from dulwich.refs import Ref
 
         bundle = create_bundle_from_repo(repo, refs=[Ref(b"refs/heads/master")])
+        self.addCleanup(bundle.close)
 
         # Verify only master ref is included
         self.assertEqual(len(bundle.references), 1)
@@ -450,6 +471,7 @@ class BundleTests(TestCase):
         # Create bundle with capabilities
         capabilities = {"object-format": "sha1"}
         bundle = create_bundle_from_repo(repo, capabilities=capabilities, version=3)
+        self.addCleanup(bundle.close)
 
         # Verify capabilities are included
         self.assertEqual(bundle.capabilities, capabilities)
@@ -482,6 +504,7 @@ class BundleTests(TestCase):
 
         # Use blob.id directly (40-byte hex bytestring)
         bundle = create_bundle_from_repo(repo, prerequisites=[prereq_blob.id])
+        self.addCleanup(bundle.close)
 
         # Verify the prerequisite was added correctly
         self.assertEqual(len(bundle.prerequisites), 1)
@@ -513,6 +536,7 @@ class BundleTests(TestCase):
         prereq_hex = b"aa" * 20
 
         bundle = create_bundle_from_repo(repo, prerequisites=[prereq_hex])
+        self.addCleanup(bundle.close)
 
         # Verify the prerequisite was added correctly
         self.assertEqual(len(bundle.prerequisites), 1)

+ 27 - 12
tests/test_client.py

@@ -933,12 +933,18 @@ class TestSSHVendor:
         self.protocol_version = protocol_version
 
         class Subprocess:
-            pass
+            def read(self, *args):
+                return None
+
+            def write(self, *args):
+                return None
+
+            def close(self):
+                pass
+
+            def can_read(self):
+                return None
 
-        Subprocess.read = lambda: None
-        Subprocess.write = lambda: None
-        Subprocess.close = lambda: None
-        Subprocess.can_read = lambda: None
         return Subprocess()
 
 
@@ -996,13 +1002,19 @@ class SSHGitClientTests(TestCase):
         client.username = b"username"
         client.port = 1337
 
-        client._connect(b"command", b"/path/to/repo")
-        self.assertEqual(b"username", server.username)
-        self.assertEqual(1337, server.port)
-        self.assertEqual(b"git-command '/path/to/repo'", server.command)
+        proto, _, _ = client._connect(b"command", b"/path/to/repo")
+        try:
+            self.assertEqual(b"username", server.username)
+            self.assertEqual(1337, server.port)
+            self.assertEqual(b"git-command '/path/to/repo'", server.command)
+        finally:
+            proto.close()
 
-        client._connect(b"relative-command", b"/~/path/to/repo")
-        self.assertEqual(b"git-relative-command '~/path/to/repo'", server.command)
+        proto, _, _ = client._connect(b"relative-command", b"/~/path/to/repo")
+        try:
+            self.assertEqual(b"git-relative-command '~/path/to/repo'", server.command)
+        finally:
+            proto.close()
 
     def test_ssh_command_precedence(self) -> None:
         self.overrideEnv("GIT_SSH", "/path/to/ssh")
@@ -1053,7 +1065,8 @@ class SSHGitClientTests(TestCase):
         client.key_filename = "/path/to/key"
 
         # Connect and verify all kwargs are passed through
-        client._connect(b"upload-pack", b"/path/to/repo")
+        proto, _, _ = client._connect(b"upload-pack", b"/path/to/repo")
+        self.addCleanup(proto.close)
 
         self.assertEqual(server.ssh_command, "custom-ssh-wrapper.sh -o Option=Value")
         self.assertEqual(server.password, "test-password")
@@ -1373,6 +1386,7 @@ class BundleClientTests(TestCase):
 
         # Create bundle
         bundle = create_bundle_from_repo(repo)
+        self.addCleanup(bundle.close)
 
         # Write bundle to file
         bundle_path = os.path.join(self.tempdir, "test.bundle")
@@ -1440,6 +1454,7 @@ class BundleClientTests(TestCase):
 
         client = BundleClient()
         target_repo = MemoryRepo()
+        self.addCleanup(target_repo.close)
 
         result = client.fetch(bundle_path, target_repo)
 

+ 2 - 0
tests/test_commit_graph.py

@@ -631,6 +631,7 @@ class CommitGraphGenerationTests(unittest.TestCase):
         object_store_path = os.path.join(self.tempdir, "objects")
         os.makedirs(object_store_path, exist_ok=True)
         object_store = DiskObjectStore(object_store_path)
+        self.addCleanup(object_store.close)
 
         # Create a tree and commit
         tree = Tree()
@@ -752,6 +753,7 @@ class CommitGraphGenerationTests(unittest.TestCase):
         object_store_no_graph_path = os.path.join(self.tempdir, "objects2")
         os.makedirs(object_store_no_graph_path, exist_ok=True)
         object_store_no_graph = DiskObjectStore(object_store_no_graph_path)
+        self.addCleanup(object_store_no_graph.close)
         object_store_no_graph.add_object(tree)
         object_store_no_graph.add_object(commit1)
         object_store_no_graph.add_object(commit2)

+ 1 - 0
tests/test_fastexport.py

@@ -89,6 +89,7 @@ class GitImportProcessorTests(TestCase):
     def setUp(self) -> None:
         super().setUp()
         self.repo = MemoryRepo()
+        self.addCleanup(self.repo.close)
         try:
             from dulwich.fastexport import GitImportProcessor
         except ImportError as exc:

+ 182 - 178
tests/test_gc.py

@@ -447,6 +447,7 @@ class AutoGCTestCase(TestCase):
     def test_should_run_gc_disabled(self):
         """Test that auto GC doesn't run when gc.auto is 0."""
         r = MemoryRepo()
+        self.addCleanup(r.close)
         config = ConfigDict()
         config.set(b"gc", b"auto", b"0")
 
@@ -455,6 +456,7 @@ class AutoGCTestCase(TestCase):
     def test_should_run_gc_disabled_by_env_var(self):
         """Test that auto GC doesn't run when GIT_AUTO_GC environment variable is 0."""
         r = MemoryRepo()
+        self.addCleanup(r.close)
         config = ConfigDict()
         config.set(b"gc", b"auto", b"10")  # Should normally run
 
@@ -464,6 +466,7 @@ class AutoGCTestCase(TestCase):
     def test_should_run_gc_disabled_programmatically(self):
         """Test that auto GC doesn't run when disabled via _autogc_disabled attribute."""
         r = MemoryRepo()
+        self.addCleanup(r.close)
         config = ConfigDict()
         config.set(b"gc", b"auto", b"10")  # Should normally run
 
@@ -479,6 +482,7 @@ class AutoGCTestCase(TestCase):
     def test_should_run_gc_default_values(self):
         """Test auto GC with default configuration values."""
         r = MemoryRepo()
+        self.addCleanup(r.close)
         config = ConfigDict()
 
         # Should not run with empty repo
@@ -487,98 +491,97 @@ class AutoGCTestCase(TestCase):
     def test_should_run_gc_with_loose_objects(self):
         """Test that auto GC triggers based on loose object count."""
         with tempfile.TemporaryDirectory() as tmpdir:
-            r = Repo.init(tmpdir)
-            config = ConfigDict()
-            config.set(b"gc", b"auto", b"10")  # Low threshold for testing
+            with Repo.init(tmpdir) as r:
+                config = ConfigDict()
+                config.set(b"gc", b"auto", b"10")  # Low threshold for testing
 
-            # Add some loose objects
-            for i in range(15):
-                blob = Blob()
-                blob.data = f"test blob {i}".encode()
-                r.object_store.add_object(blob)
+                # Add some loose objects
+                for i in range(15):
+                    blob = Blob()
+                    blob.data = f"test blob {i}".encode()
+                    r.object_store.add_object(blob)
 
-            self.assertTrue(should_run_gc(r, config))
+                self.assertTrue(should_run_gc(r, config))
 
     def test_should_run_gc_with_pack_limit(self):
         """Test that auto GC triggers based on pack file count."""
         with tempfile.TemporaryDirectory() as tmpdir:
-            r = Repo.init(tmpdir)
-            config = ConfigDict()
-            config.set(b"gc", b"autoPackLimit", b"2")  # Low threshold for testing
+            with Repo.init(tmpdir) as r:
+                config = ConfigDict()
+                config.set(b"gc", b"autoPackLimit", b"2")  # Low threshold for testing
 
-            # Create some pack files by repacking
-            for i in range(3):
-                blob = Blob()
-                blob.data = f"test blob {i}".encode()
-                r.object_store.add_object(blob)
-                r.object_store.pack_loose_objects(progress=no_op_progress)
+                # Create some pack files by repacking
+                for i in range(3):
+                    blob = Blob()
+                    blob.data = f"test blob {i}".encode()
+                    r.object_store.add_object(blob)
+                    r.object_store.pack_loose_objects(progress=no_op_progress)
 
-            # Force re-enumeration of packs
-            r.object_store._update_pack_cache()
+                # Force re-enumeration of packs
+                r.object_store._update_pack_cache()
 
-            self.assertTrue(should_run_gc(r, config))
+                self.assertTrue(should_run_gc(r, config))
 
     def test_count_loose_objects(self):
         """Test counting loose objects."""
         with tempfile.TemporaryDirectory() as tmpdir:
-            r = Repo.init(tmpdir)
-
-            # Initially should have no loose objects
-            count = r.object_store.count_loose_objects()
-            self.assertEqual(0, count)
+            with Repo.init(tmpdir) as r:
+                # Initially should have no loose objects
+                count = r.object_store.count_loose_objects()
+                self.assertEqual(0, count)
 
-            # Add some loose objects
-            for i in range(5):
-                blob = Blob()
-                blob.data = f"test blob {i}".encode()
-                r.object_store.add_object(blob)
+                # Add some loose objects
+                for i in range(5):
+                    blob = Blob()
+                    blob.data = f"test blob {i}".encode()
+                    r.object_store.add_object(blob)
 
-            count = r.object_store.count_loose_objects()
-            self.assertEqual(5, count)
+                count = r.object_store.count_loose_objects()
+                self.assertEqual(5, count)
 
     def test_count_pack_files(self):
         """Test counting pack files."""
         with tempfile.TemporaryDirectory() as tmpdir:
-            r = Repo.init(tmpdir)
-
-            # Initially should have no packs
-            count = r.object_store.count_pack_files()
-            self.assertEqual(0, count)
+            with Repo.init(tmpdir) as r:
+                # Initially should have no packs
+                count = r.object_store.count_pack_files()
+                self.assertEqual(0, count)
 
-            # Create a pack
-            blob = Blob()
-            blob.data = b"test blob"
-            r.object_store.add_object(blob)
-            r.object_store.pack_loose_objects(progress=no_op_progress)
+                # Create a pack
+                blob = Blob()
+                blob.data = b"test blob"
+                r.object_store.add_object(blob)
+                r.object_store.pack_loose_objects(progress=no_op_progress)
 
-            # Force re-enumeration of packs
-            r.object_store._update_pack_cache()
+                # Force re-enumeration of packs
+                r.object_store._update_pack_cache()
 
-            count = r.object_store.count_pack_files()
-            self.assertEqual(1, count)
+                count = r.object_store.count_pack_files()
+                self.assertEqual(1, count)
 
     def test_maybe_auto_gc_runs_when_needed(self):
         """Test that auto GC runs when thresholds are exceeded."""
         with tempfile.TemporaryDirectory() as tmpdir:
-            r = Repo.init(tmpdir)
-            config = ConfigDict()
-            config.set(b"gc", b"auto", b"5")  # Low threshold for testing
+            with Repo.init(tmpdir) as r:
+                config = ConfigDict()
+                config.set(b"gc", b"auto", b"5")  # Low threshold for testing
 
-            # Add enough loose objects to trigger GC
-            for i in range(10):
-                blob = Blob()
-                blob.data = f"test blob {i}".encode()
-                r.object_store.add_object(blob)
+                # Add enough loose objects to trigger GC
+                for i in range(10):
+                    blob = Blob()
+                    blob.data = f"test blob {i}".encode()
+                    r.object_store.add_object(blob)
 
-            with patch("dulwich.gc.garbage_collect") as mock_gc:
-                result = maybe_auto_gc(r, config, progress=no_op_progress)
+                with patch("dulwich.gc.garbage_collect") as mock_gc:
+                    result = maybe_auto_gc(r, config, progress=no_op_progress)
 
-            self.assertTrue(result)
-            mock_gc.assert_called_once_with(r, auto=True, progress=no_op_progress)
+                self.assertTrue(result)
+                mock_gc.assert_called_once_with(r, auto=True, progress=no_op_progress)
 
     def test_maybe_auto_gc_skips_when_not_needed(self):
         """Test that auto GC doesn't run when thresholds are not exceeded."""
         r = MemoryRepo()
+        self.addCleanup(r.close)
         config = ConfigDict()
 
         with patch("dulwich.gc.garbage_collect") as mock_gc:
@@ -590,151 +593,152 @@ class AutoGCTestCase(TestCase):
     def test_maybe_auto_gc_with_gc_log(self):
         """Test that auto GC is skipped when gc.log exists and is recent."""
         with tempfile.TemporaryDirectory() as tmpdir:
-            r = Repo.init(tmpdir)
-            config = ConfigDict()
-            config.set(b"gc", b"auto", b"1")  # Low threshold
+            with Repo.init(tmpdir) as r:
+                config = ConfigDict()
+                config.set(b"gc", b"auto", b"1")  # Low threshold
 
-            # Create gc.log file
-            gc_log_path = os.path.join(r.controldir(), "gc.log")
-            with open(gc_log_path, "wb") as f:
-                f.write(b"Previous GC failed\n")
+                # Create gc.log file
+                gc_log_path = os.path.join(r.controldir(), "gc.log")
+                with open(gc_log_path, "wb") as f:
+                    f.write(b"Previous GC failed\n")
 
-            # Add objects to trigger GC
-            blob = Blob()
-            blob.data = b"test"
-            r.object_store.add_object(blob)
+                # Add objects to trigger GC
+                blob = Blob()
+                blob.data = b"test"
+                r.object_store.add_object(blob)
 
-            # Capture log messages
-            import logging
+                # Capture log messages
+                import logging
 
-            with self.assertLogs(level=logging.INFO) as cm:
-                result = maybe_auto_gc(r, config, progress=no_op_progress)
+                with self.assertLogs(level=logging.INFO) as cm:
+                    result = maybe_auto_gc(r, config, progress=no_op_progress)
 
-            self.assertFalse(result)
-            # Verify gc.log contents were logged
-            self.assertTrue(any("Previous GC failed" in msg for msg in cm.output))
+                self.assertFalse(result)
+                # Verify gc.log contents were logged
+                self.assertTrue(any("Previous GC failed" in msg for msg in cm.output))
 
     def test_maybe_auto_gc_with_expired_gc_log(self):
         """Test that auto GC runs when gc.log exists but is expired."""
         with tempfile.TemporaryDirectory() as tmpdir:
-            r = Repo.init(tmpdir)
-            config = ConfigDict()
-            config.set(b"gc", b"auto", b"1")  # Low threshold
-            config.set(b"gc", b"logExpiry", b"0.days")  # Expire immediately
+            with Repo.init(tmpdir) as r:
+                config = ConfigDict()
+                config.set(b"gc", b"auto", b"1")  # Low threshold
+                config.set(b"gc", b"logExpiry", b"0.days")  # Expire immediately
 
-            # Create gc.log file
-            gc_log_path = os.path.join(r.controldir(), "gc.log")
-            with open(gc_log_path, "wb") as f:
-                f.write(b"Previous GC failed\n")
+                # Create gc.log file
+                gc_log_path = os.path.join(r.controldir(), "gc.log")
+                with open(gc_log_path, "wb") as f:
+                    f.write(b"Previous GC failed\n")
 
-            # Make the file old
-            old_time = time.time() - 86400  # 1 day ago
-            os.utime(gc_log_path, (old_time, old_time))
+                # Make the file old
+                old_time = time.time() - 86400  # 1 day ago
+                os.utime(gc_log_path, (old_time, old_time))
 
-            # Add objects to trigger GC
-            blob = Blob()
-            blob.data = b"test"
-            r.object_store.add_object(blob)
+                # Add objects to trigger GC
+                blob = Blob()
+                blob.data = b"test"
+                r.object_store.add_object(blob)
 
-            with patch("dulwich.gc.garbage_collect") as mock_gc:
-                result = maybe_auto_gc(r, config, progress=no_op_progress)
+                with patch("dulwich.gc.garbage_collect") as mock_gc:
+                    result = maybe_auto_gc(r, config, progress=no_op_progress)
 
-            self.assertTrue(result)
-            mock_gc.assert_called_once_with(r, auto=True, progress=no_op_progress)
-            # gc.log should be removed after successful GC
-            self.assertFalse(os.path.exists(gc_log_path))
+                self.assertTrue(result)
+                mock_gc.assert_called_once_with(r, auto=True, progress=no_op_progress)
+                # gc.log should be removed after successful GC
+                self.assertFalse(os.path.exists(gc_log_path))
 
     def test_maybe_auto_gc_handles_gc_failure(self):
         """Test that auto GC handles failures gracefully."""
         with tempfile.TemporaryDirectory() as tmpdir:
-            r = Repo.init(tmpdir)
-            config = ConfigDict()
-            config.set(b"gc", b"auto", b"1")  # Low threshold
+            with Repo.init(tmpdir) as r:
+                config = ConfigDict()
+                config.set(b"gc", b"auto", b"1")  # Low threshold
 
-            # Add objects to trigger GC
-            blob = Blob()
-            blob.data = b"test"
-            r.object_store.add_object(blob)
+                # Add objects to trigger GC
+                blob = Blob()
+                blob.data = b"test"
+                r.object_store.add_object(blob)
 
-            with patch(
-                "dulwich.gc.garbage_collect", side_effect=OSError("GC failed")
-            ) as mock_gc:
-                result = maybe_auto_gc(r, config, progress=no_op_progress)
+                with patch(
+                    "dulwich.gc.garbage_collect", side_effect=OSError("GC failed")
+                ) as mock_gc:
+                    result = maybe_auto_gc(r, config, progress=no_op_progress)
 
-            self.assertFalse(result)
-            mock_gc.assert_called_once_with(r, auto=True, progress=no_op_progress)
+                self.assertFalse(result)
+                mock_gc.assert_called_once_with(r, auto=True, progress=no_op_progress)
 
-            # Check that error was written to gc.log
-            gc_log_path = os.path.join(r.controldir(), "gc.log")
-            self.assertTrue(os.path.exists(gc_log_path))
-            with open(gc_log_path, "rb") as f:
-                content = f.read()
-                self.assertIn(b"Auto GC failed: GC failed", content)
+                # Check that error was written to gc.log
+                gc_log_path = os.path.join(r.controldir(), "gc.log")
+                self.assertTrue(os.path.exists(gc_log_path))
+                with open(gc_log_path, "rb") as f:
+                    content = f.read()
+                    self.assertIn(b"Auto GC failed: GC failed", content)
 
     def test_gc_log_expiry_singular_day(self):
         """Test that gc.logExpiry supports singular '.day' format."""
         with tempfile.TemporaryDirectory() as tmpdir:
-            r = Repo.init(tmpdir)
-            config = ConfigDict()
-            config.set(b"gc", b"auto", b"1")  # Low threshold
-            config.set(b"gc", b"logExpiry", b"1.day")  # Singular form
+            with Repo.init(tmpdir) as r:
+                config = ConfigDict()
+                config.set(b"gc", b"auto", b"1")  # Low threshold
+                config.set(b"gc", b"logExpiry", b"1.day")  # Singular form
 
-            # Create gc.log file
-            gc_log_path = os.path.join(r.controldir(), "gc.log")
-            with open(gc_log_path, "wb") as f:
-                f.write(b"Previous GC failed\n")
+                # Create gc.log file
+                gc_log_path = os.path.join(r.controldir(), "gc.log")
+                with open(gc_log_path, "wb") as f:
+                    f.write(b"Previous GC failed\n")
 
-            # Make the file 2 days old (older than 1 day expiry)
-            old_time = time.time() - (2 * 86400)
-            os.utime(gc_log_path, (old_time, old_time))
+                # Make the file 2 days old (older than 1 day expiry)
+                old_time = time.time() - (2 * 86400)
+                os.utime(gc_log_path, (old_time, old_time))
 
-            # Add objects to trigger GC
-            blob = Blob()
-            blob.data = b"test"
-            r.object_store.add_object(blob)
+                # Add objects to trigger GC
+                blob = Blob()
+                blob.data = b"test"
+                r.object_store.add_object(blob)
 
-            with patch("dulwich.gc.garbage_collect") as mock_gc:
-                result = maybe_auto_gc(r, config, progress=no_op_progress)
+                with patch("dulwich.gc.garbage_collect") as mock_gc:
+                    result = maybe_auto_gc(r, config, progress=no_op_progress)
 
-            self.assertTrue(result)
-            mock_gc.assert_called_once_with(r, auto=True, progress=no_op_progress)
+                self.assertTrue(result)
+                mock_gc.assert_called_once_with(r, auto=True, progress=no_op_progress)
 
     def test_gc_log_expiry_invalid_format(self):
         """Test that invalid gc.logExpiry format defaults to 1 day."""
         with tempfile.TemporaryDirectory() as tmpdir:
-            r = Repo.init(tmpdir)
-            config = ConfigDict()
-            config.set(b"gc", b"auto", b"1")  # Low threshold
-            config.set(b"gc", b"logExpiry", b"invalid")  # Invalid format
+            with Repo.init(tmpdir) as r:
+                config = ConfigDict()
+                config.set(b"gc", b"auto", b"1")  # Low threshold
+                config.set(b"gc", b"logExpiry", b"invalid")  # Invalid format
 
-            # Create gc.log file
-            gc_log_path = os.path.join(r.controldir(), "gc.log")
-            with open(gc_log_path, "wb") as f:
-                f.write(b"Previous GC failed\n")
+                # Create gc.log file
+                gc_log_path = os.path.join(r.controldir(), "gc.log")
+                with open(gc_log_path, "wb") as f:
+                    f.write(b"Previous GC failed\n")
 
-            # Make the file recent (within default 1 day)
-            recent_time = time.time() - 3600  # 1 hour ago
-            os.utime(gc_log_path, (recent_time, recent_time))
+                # Make the file recent (within default 1 day)
+                recent_time = time.time() - 3600  # 1 hour ago
+                os.utime(gc_log_path, (recent_time, recent_time))
 
-            # Add objects to trigger GC
-            blob = Blob()
-            blob.data = b"test"
-            r.object_store.add_object(blob)
+                # Add objects to trigger GC
+                blob = Blob()
+                blob.data = b"test"
+                r.object_store.add_object(blob)
 
-            # Capture log messages
-            import logging
+                # Capture log messages
+                import logging
 
-            with self.assertLogs(level=logging.INFO) as cm:
-                result = maybe_auto_gc(r, config, progress=no_op_progress)
+                with self.assertLogs(level=logging.INFO) as cm:
+                    result = maybe_auto_gc(r, config, progress=no_op_progress)
 
-            # Should not run GC because gc.log is recent (within default 1 day)
-            self.assertFalse(result)
-            # Check that gc.log content was logged
-            self.assertTrue(any("gc.log content:" in msg for msg in cm.output))
+                # Should not run GC because gc.log is recent (within default 1 day)
+                self.assertFalse(result)
+                # Check that gc.log content was logged
+                self.assertTrue(any("gc.log content:" in msg for msg in cm.output))
 
     def test_maybe_auto_gc_non_disk_repo(self):
         """Test auto GC on non-disk repository (MemoryRepo)."""
         r = MemoryRepo()
+        self.addCleanup(r.close)
         config = ConfigDict()
         config.set(b"gc", b"auto", b"1")  # Would trigger if it were disk-based
 
@@ -752,27 +756,27 @@ class AutoGCTestCase(TestCase):
     def test_gc_removes_existing_gc_log_on_success(self):
         """Test that successful GC removes pre-existing gc.log file."""
         with tempfile.TemporaryDirectory() as tmpdir:
-            r = Repo.init(tmpdir)
-            config = ConfigDict()
-            config.set(b"gc", b"auto", b"1")  # Low threshold
+            with Repo.init(tmpdir) as r:
+                config = ConfigDict()
+                config.set(b"gc", b"auto", b"1")  # Low threshold
 
-            # Create gc.log file from previous failure
-            gc_log_path = os.path.join(r.controldir(), "gc.log")
-            with open(gc_log_path, "wb") as f:
-                f.write(b"Previous GC failed\n")
+                # Create gc.log file from previous failure
+                gc_log_path = os.path.join(r.controldir(), "gc.log")
+                with open(gc_log_path, "wb") as f:
+                    f.write(b"Previous GC failed\n")
 
-            # Make it old enough to be expired
-            old_time = time.time() - (2 * 86400)  # 2 days ago
-            os.utime(gc_log_path, (old_time, old_time))
+                # Make it old enough to be expired
+                old_time = time.time() - (2 * 86400)  # 2 days ago
+                os.utime(gc_log_path, (old_time, old_time))
 
-            # Add objects to trigger GC
-            blob = Blob()
-            blob.data = b"test"
-            r.object_store.add_object(blob)
+                # Add objects to trigger GC
+                blob = Blob()
+                blob.data = b"test"
+                r.object_store.add_object(blob)
 
-            # Run auto GC
-            result = maybe_auto_gc(r, config, progress=no_op_progress)
+                # Run auto GC
+                result = maybe_auto_gc(r, config, progress=no_op_progress)
 
-            self.assertTrue(result)
-            # gc.log should be removed after successful GC
-            self.assertFalse(os.path.exists(gc_log_path))
+                self.assertTrue(result)
+                # gc.log should be removed after successful GC
+                self.assertFalse(os.path.exists(gc_log_path))
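
The hunks above all make one mechanical change: each Repo.init(tmpdir) that
was previously left open is wrapped in a "with" block and the test body is
re-indented underneath it. A minimal, self-contained sketch of the pattern
(the print call is illustrative only; Repo and close() are real dulwich API,
as the hunks themselves show):

    import tempfile

    from dulwich.repo import Repo

    with tempfile.TemporaryDirectory() as tmpdir:
        # Repo is a context manager: __exit__ calls close(), releasing
        # open pack file handles even if the body raises. That matters
        # on platforms such as Windows, where the TemporaryDirectory
        # cleanup cannot remove files that are still open.
        with Repo.init(tmpdir) as r:
            print(r.controldir())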

+ 1 - 0
tests/test_grafts.py

@@ -193,6 +193,7 @@ class GraftsInMemoryRepoTests(GraftsInRepositoryBase, TestCase):
     def setUp(self) -> None:
         super().setUp()
         r = self._repo = MemoryRepo()
+        self.addCleanup(r.close)
 
         self._shas = []
 

+ 19 - 0
tests/test_graph.py

@@ -178,11 +178,13 @@ class FindMergeBaseTests(TestCase):
 class FindMergeBaseFunctionTests(TestCase):
     def test_find_merge_base_empty(self) -> None:
         r = MemoryRepo()
+        self.addCleanup(r.close)
         # Empty list of commits
         self.assertEqual([], find_merge_base(r, []))
 
     def test_find_merge_base_single(self) -> None:
         r = MemoryRepo()
+        self.addCleanup(r.close)
         base = make_commit()
         r.object_store.add_objects([(base, None)])
         # Single commit returns itself
@@ -190,6 +192,7 @@ class FindMergeBaseFunctionTests(TestCase):
 
     def test_find_merge_base_identical(self) -> None:
         r = MemoryRepo()
+        self.addCleanup(r.close)
         base = make_commit()
         r.object_store.add_objects([(base, None)])
         # When the same commit is in both positions
@@ -197,6 +200,7 @@ class FindMergeBaseFunctionTests(TestCase):
 
     def test_find_merge_base_linear(self) -> None:
         r = MemoryRepo()
+        self.addCleanup(r.close)
         base = make_commit()
         c1 = make_commit(parents=[base.id])
         c2 = make_commit(parents=[c1.id])
@@ -208,6 +212,7 @@ class FindMergeBaseFunctionTests(TestCase):
 
     def test_find_merge_base_diverged(self) -> None:
         r = MemoryRepo()
+        self.addCleanup(r.close)
         base = make_commit()
         c1 = make_commit(parents=[base.id])
         c2a = make_commit(parents=[c1.id], message=b"2a")
@@ -218,6 +223,7 @@ class FindMergeBaseFunctionTests(TestCase):
 
     def test_find_merge_base_with_min_stamp(self) -> None:
         r = MemoryRepo()
+        self.addCleanup(r.close)
         base = make_commit(commit_time=100)
         c1 = make_commit(parents=[base.id], commit_time=200)
         c2 = make_commit(parents=[c1.id], commit_time=300)
@@ -228,6 +234,7 @@ class FindMergeBaseFunctionTests(TestCase):
 
     def test_find_merge_base_multiple_common_ancestors(self) -> None:
         r = MemoryRepo()
+        self.addCleanup(r.close)
         base = make_commit(commit_time=100)
         c1a = make_commit(parents=[base.id], commit_time=200, message=b"c1a")
         c1b = make_commit(parents=[base.id], commit_time=201, message=b"c1b")
@@ -247,11 +254,13 @@ class FindMergeBaseFunctionTests(TestCase):
 class FindOctopusBaseTests(TestCase):
     def test_find_octopus_base_empty(self) -> None:
         r = MemoryRepo()
+        self.addCleanup(r.close)
         # Empty list of commits
         self.assertEqual([], find_octopus_base(r, []))
 
     def test_find_octopus_base_single(self) -> None:
         r = MemoryRepo()
+        self.addCleanup(r.close)
         base = make_commit()
         r.object_store.add_objects([(base, None)])
         # Single commit returns itself
@@ -259,6 +268,7 @@ class FindOctopusBaseTests(TestCase):
 
     def test_find_octopus_base_two_commits(self) -> None:
         r = MemoryRepo()
+        self.addCleanup(r.close)
         base = make_commit()
         c1 = make_commit(parents=[base.id])
         c2 = make_commit(parents=[c1.id])
@@ -268,6 +278,7 @@ class FindOctopusBaseTests(TestCase):
 
     def test_find_octopus_base_multiple(self) -> None:
         r = MemoryRepo()
+        self.addCleanup(r.close)
         base = make_commit()
         c1 = make_commit(parents=[base.id])
         c2a = make_commit(parents=[c1.id], message=b"2a")
@@ -283,6 +294,7 @@ class FindOctopusBaseTests(TestCase):
 class CanFastForwardTests(TestCase):
     def test_ff(self) -> None:
         r = MemoryRepo()
+        self.addCleanup(r.close)
         base = make_commit()
         c1 = make_commit(parents=[base.id])
         c2 = make_commit(parents=[c1.id])
@@ -294,6 +306,7 @@ class CanFastForwardTests(TestCase):
 
     def test_diverged(self) -> None:
         r = MemoryRepo()
+        self.addCleanup(r.close)
         base = make_commit()
         c1 = make_commit(parents=[base.id])
         c2a = make_commit(parents=[c1.id], message=b"2a")
@@ -306,6 +319,7 @@ class CanFastForwardTests(TestCase):
 
     def test_shallow_repository(self) -> None:
         r = MemoryRepo()
+        self.addCleanup(r.close)
         # Create a shallow repository structure:
         # base (missing) -> c1 -> c2
         # We only have c1 and c2; base is missing (shallow boundary at c1)
@@ -552,11 +566,13 @@ class WorkListTest(TestCase):
 class IndependentTests(TestCase):
     def test_independent_empty(self) -> None:
         r = MemoryRepo()
+        self.addCleanup(r.close)
         # Empty list of commits
         self.assertEqual([], independent(r, []))
 
     def test_independent_single(self) -> None:
         r = MemoryRepo()
+        self.addCleanup(r.close)
         base = make_commit()
         r.object_store.add_objects([(base, None)])
         # Single commit is independent
@@ -564,6 +580,7 @@ class IndependentTests(TestCase):
 
     def test_independent_linear(self) -> None:
         r = MemoryRepo()
+        self.addCleanup(r.close)
         base = make_commit()
         c1 = make_commit(parents=[base.id])
         c2 = make_commit(parents=[c1.id])
@@ -573,6 +590,7 @@ class IndependentTests(TestCase):
 
     def test_independent_diverged(self) -> None:
         r = MemoryRepo()
+        self.addCleanup(r.close)
         base = make_commit()
         c1 = make_commit(parents=[base.id])
         c2a = make_commit(parents=[c1.id], message=b"2a")
@@ -586,6 +604,7 @@ class IndependentTests(TestCase):
 
     def test_independent_mixed(self) -> None:
         r = MemoryRepo()
+        self.addCleanup(r.close)
         base = make_commit()
         c1 = make_commit(parents=[base.id])
         c2a = make_commit(parents=[c1.id], message=b"2a")

+ 3 - 0
tests/test_merge.py

@@ -16,6 +16,7 @@ class MergeTests(unittest.TestCase):
 
     def setUp(self):
         self.repo = MemoryRepo()
+        self.addCleanup(self.repo.close)
         # Check if merge3 module is available
         if importlib.util.find_spec("merge3") is None:
             raise DependencyMissing("merge3")
@@ -300,6 +301,7 @@ class RecursiveMergeTests(unittest.TestCase):
 
     def setUp(self):
         self.repo = MemoryRepo()
+        self.addCleanup(self.repo.close)
         # Check if merge3 module is available
         if importlib.util.find_spec("merge3") is None:
             raise DependencyMissing("merge3")
@@ -738,6 +740,7 @@ class OctopusMergeTests(unittest.TestCase):
 
     def setUp(self):
         self.repo = MemoryRepo()
+        self.addCleanup(self.repo.close)
         # Check if merge3 module is available
         if importlib.util.find_spec("merge3") is None:
             raise DependencyMissing("merge3")

+ 6 - 0
tests/test_object_store.py

@@ -169,6 +169,7 @@ class DiskObjectStoreTests(PackBasedObjectStoreTests, TestCase):
         alternate_dir = tempfile.mkdtemp()
         self.addCleanup(shutil.rmtree, alternate_dir)
         alternate_store = DiskObjectStore(alternate_dir, loose_compression_level=6)
+        self.addCleanup(alternate_store.close)
         b2 = make_object(Blob, data=b"yummy data")
         alternate_store.add_object(b2)
 
@@ -176,9 +177,11 @@ class DiskObjectStoreTests(PackBasedObjectStoreTests, TestCase):
         alternate_dir = tempfile.mkdtemp()
         self.addCleanup(shutil.rmtree, alternate_dir)
         alternate_store = DiskObjectStore(alternate_dir)
+        self.addCleanup(alternate_store.close)
         b2 = make_object(Blob, data=b"yummy data")
         alternate_store.add_object(b2)
         store = DiskObjectStore(self.store_dir)
+        self.addCleanup(store.close)
         self.assertRaises(KeyError, store.__getitem__, b2.id)
         store.add_alternate_path(alternate_dir)
         self.assertIn(b2.id, store)
@@ -186,6 +189,7 @@ class DiskObjectStoreTests(PackBasedObjectStoreTests, TestCase):
 
     def test_read_alternate_paths(self) -> None:
         store = DiskObjectStore(self.store_dir)
+        self.addCleanup(store.close)
 
         abs_path = os.path.abspath(os.path.normpath("/abspath"))
         # in particular, ensures that the alternates file exists
@@ -283,9 +287,11 @@ class DiskObjectStoreTests(PackBasedObjectStoreTests, TestCase):
         alternate_dir = tempfile.mkdtemp()
         self.addCleanup(shutil.rmtree, alternate_dir)
         alternate_store = DiskObjectStore(alternate_dir)
+        self.addCleanup(alternate_store.close)
         b2 = make_object(Blob, data=b"yummy data")
         alternate_store.add_object(b2)
         store = DiskObjectStore(self.store_dir)
+        self.addCleanup(store.close)
         self.assertRaises(KeyError, store.__getitem__, b2.id)
         store.add_alternate_path(os.path.relpath(alternate_dir, self.store_dir))
         self.assertEqual(list(alternate_store), list(store.alternates[0]))

+ 30 - 0
tests/test_objectspec.py

@@ -45,16 +45,19 @@ class ParseObjectTests(TestCase):
 
     def test_nonexistent(self) -> None:
         r = MemoryRepo()
+        self.addCleanup(r.close)
         self.assertRaises(KeyError, parse_object, r, "thisdoesnotexist")
 
     def test_blob_by_sha(self) -> None:
         r = MemoryRepo()
+        self.addCleanup(r.close)
         b = Blob.from_string(b"Blah")
         r.object_store.add_object(b)
         self.assertEqual(b, parse_object(r, b.id))
 
     def test_parent_caret(self) -> None:
         r = MemoryRepo()
+        self.addCleanup(r.close)
         c1, c2, c3 = build_commit_graph(r.object_store, [[1], [2, 1], [3, 1, 2]])
         # c3's parents are [c1, c2]
         self.assertEqual(c1, parse_object(r, c3.id + b"^1"))
@@ -63,6 +66,7 @@ class ParseObjectTests(TestCase):
 
     def test_parent_tilde(self) -> None:
         r = MemoryRepo()
+        self.addCleanup(r.close)
         c1, c2, c3 = build_commit_graph(r.object_store, [[1], [2, 1], [3, 2]])
         self.assertEqual(c2, parse_object(r, c3.id + b"~"))
         self.assertEqual(c2, parse_object(r, c3.id + b"~1"))
@@ -70,6 +74,7 @@ class ParseObjectTests(TestCase):
 
     def test_combined_operators(self) -> None:
         r = MemoryRepo()
+        self.addCleanup(r.close)
         c1, c2, _c3, c4 = build_commit_graph(
             r.object_store, [[1], [2, 1], [3, 1, 2], [4, 3]]
         )
@@ -80,6 +85,7 @@ class ParseObjectTests(TestCase):
 
     def test_with_ref(self) -> None:
         r = MemoryRepo()
+        self.addCleanup(r.close)
         c1, c2, c3 = build_commit_graph(r.object_store, [[1], [2, 1], [3, 2]])
         r.refs[b"refs/heads/master"] = c3.id
         self.assertEqual(c2, parse_object(r, b"master~"))
@@ -87,6 +93,7 @@ class ParseObjectTests(TestCase):
 
     def test_caret_zero(self) -> None:
         r = MemoryRepo()
+        self.addCleanup(r.close)
         c1, c2 = build_commit_graph(r.object_store, [[1], [2, 1]])
         # ^0 means the commit itself
         self.assertEqual(c2, parse_object(r, c2.id + b"^0"))
@@ -94,6 +101,7 @@ class ParseObjectTests(TestCase):
 
     def test_missing_parent(self) -> None:
         r = MemoryRepo()
+        self.addCleanup(r.close)
         c1, c2 = build_commit_graph(r.object_store, [[1], [2, 1]])
         # c2 only has 1 parent, so ^2 should fail
         self.assertRaises(ValueError, parse_object, r, c2.id + b"^2")
@@ -102,11 +110,13 @@ class ParseObjectTests(TestCase):
 
     def test_empty_base(self) -> None:
         r = MemoryRepo()
+        self.addCleanup(r.close)
         self.assertRaises(ValueError, parse_object, r, b"~1")
         self.assertRaises(ValueError, parse_object, r, b"^1")
 
     def test_non_commit_with_operators(self) -> None:
         r = MemoryRepo()
+        self.addCleanup(r.close)
         b = Blob.from_string(b"Blah")
         r.object_store.add_object(b)
         # Can't apply ~ or ^ to a blob
@@ -114,6 +124,7 @@ class ParseObjectTests(TestCase):
 
     def test_tag_dereference(self) -> None:
         r = MemoryRepo()
+        self.addCleanup(r.close)
         [c1] = build_commit_graph(r.object_store, [[1]])
         # Create an annotated tag
         tag = Tag()
@@ -129,6 +140,7 @@ class ParseObjectTests(TestCase):
 
     def test_nested_tag_dereference(self) -> None:
         r = MemoryRepo()
+        self.addCleanup(r.close)
         [c1] = build_commit_graph(r.object_store, [[1]])
         # Create a tag pointing to a commit
         tag1 = Tag()
@@ -155,6 +167,7 @@ class ParseObjectTests(TestCase):
 
     def test_path_in_tree(self) -> None:
         r = MemoryRepo()
+        self.addCleanup(r.close)
         # Create a blob
         b = Blob.from_string(b"Test content")
 
@@ -168,6 +181,7 @@ class ParseObjectTests(TestCase):
 
     def test_path_in_tree_nested(self) -> None:
         r = MemoryRepo()
+        self.addCleanup(r.close)
         # Create blobs
         b1 = Blob.from_string(b"Content 1")
         b2 = Blob.from_string(b"Content 2")
@@ -423,15 +437,18 @@ class ParseCommitRangeTests(TestCase):
 
     def test_nonexistent(self) -> None:
         r = MemoryRepo()
+        self.addCleanup(r.close)
         self.assertRaises(KeyError, parse_commit_range, r, "thisdoesnotexist..HEAD")
 
     def test_commit_by_sha(self) -> None:
         r = MemoryRepo()
+        self.addCleanup(r.close)
         c1, _c2, _c3 = build_commit_graph(r.object_store, [[1], [2, 1], [3, 1, 2]])
         self.assertIsNone(parse_commit_range(r, c1.id))
 
     def test_commit_range(self) -> None:
         r = MemoryRepo()
+        self.addCleanup(r.close)
         c1, c2, _c3 = build_commit_graph(r.object_store, [[1], [2, 1], [3, 1, 2]])
         result = parse_commit_range(r, f"{c1.id.decode()}..{c2.id.decode()}")
         self.assertIsNotNone(result)
@@ -445,20 +462,24 @@ class ParseCommitTests(TestCase):
 
     def test_nonexistent(self) -> None:
         r = MemoryRepo()
+        self.addCleanup(r.close)
         self.assertRaises(KeyError, parse_commit, r, "thisdoesnotexist")
 
     def test_commit_by_sha(self) -> None:
         r = MemoryRepo()
+        self.addCleanup(r.close)
         [c1] = build_commit_graph(r.object_store, [[1]])
         self.assertEqual(c1, parse_commit(r, c1.id))
 
     def test_commit_by_short_sha(self) -> None:
         r = MemoryRepo()
+        self.addCleanup(r.close)
         [c1] = build_commit_graph(r.object_store, [[1]])
         self.assertEqual(c1, parse_commit(r, c1.id[:10]))
 
     def test_annotated_tag(self) -> None:
         r = MemoryRepo()
+        self.addCleanup(r.close)
         [c1] = build_commit_graph(r.object_store, [[1]])
         # Create an annotated tag pointing to the commit
         tag = Tag()
@@ -474,6 +495,7 @@ class ParseCommitTests(TestCase):
 
     def test_nested_tags(self) -> None:
         r = MemoryRepo()
+        self.addCleanup(r.close)
         [c1] = build_commit_graph(r.object_store, [[1]])
         # Create an annotated tag pointing to the commit
         tag1 = Tag()
@@ -516,6 +538,7 @@ class ParseCommitTests(TestCase):
 
     def test_tag_to_blob(self) -> None:
         r = MemoryRepo()
+        self.addCleanup(r.close)
         # Create a blob
         blob = Blob.from_string(b"Test content")
         r.object_store.add_object(blob)
@@ -535,6 +558,7 @@ class ParseCommitTests(TestCase):
 
     def test_commit_object(self) -> None:
         r = MemoryRepo()
+        self.addCleanup(r.close)
         [c1] = build_commit_graph(r.object_store, [[1]])
         # Test that passing a Commit object directly returns the same object
         self.assertEqual(c1, parse_commit(r, c1))
@@ -701,22 +725,26 @@ class ParseTreeTests(TestCase):
 
     def test_nonexistent(self) -> None:
         r = MemoryRepo()
+        self.addCleanup(r.close)
         self.assertRaises(KeyError, parse_tree, r, "thisdoesnotexist")
 
     def test_from_commit(self) -> None:
         r = MemoryRepo()
+        self.addCleanup(r.close)
         c1, _c2, _c3 = build_commit_graph(r.object_store, [[1], [2, 1], [3, 1, 2]])
         self.assertEqual(r[c1.tree], parse_tree(r, c1.id))
         self.assertEqual(r[c1.tree], parse_tree(r, c1.tree))
 
     def test_from_ref(self) -> None:
         r = MemoryRepo()
+        self.addCleanup(r.close)
         c1, _c2, _c3 = build_commit_graph(r.object_store, [[1], [2, 1], [3, 1, 2]])
         r.refs[b"refs/heads/foo"] = c1.id
         self.assertEqual(r[c1.tree], parse_tree(r, b"foo"))
 
     def test_tree_object(self) -> None:
         r = MemoryRepo()
+        self.addCleanup(r.close)
         [c1] = build_commit_graph(r.object_store, [[1]])
         tree = r[c1.tree]
         # Test that passing a Tree object directly returns the same object
@@ -724,12 +752,14 @@ class ParseTreeTests(TestCase):
 
     def test_commit_object(self) -> None:
         r = MemoryRepo()
+        self.addCleanup(r.close)
         [c1] = build_commit_graph(r.object_store, [[1]])
         # Test that passing a Commit object returns its tree
         self.assertEqual(r[c1.tree], parse_tree(r, c1))
 
     def test_tag_object(self) -> None:
         r = MemoryRepo()
+        self.addCleanup(r.close)
         [c1] = build_commit_graph(r.object_store, [[1]])
         # Create an annotated tag pointing to the commit
         tag = Tag()

+ 19 - 3
tests/test_pack.py

@@ -478,7 +478,10 @@ class TestPackData(PackTests):
             self.datadir, "pack-{}.pack".format(pack1_sha.decode("ascii"))
         )
         with open(path, "rb") as f:
-            PackData.from_file(f, DEFAULT_OBJECT_FORMAT, os.path.getsize(path))
+            pack_data = PackData.from_file(
+                f, DEFAULT_OBJECT_FORMAT, os.path.getsize(path)
+            )
+            pack_data.close()
 
     def test_pack_len(self) -> None:
         with self.get_pack_data(pack1_sha) as p:
@@ -788,7 +791,9 @@ class TestPack(PackTests):
     def test_length_mismatch(self) -> None:
         with self.get_pack_data(pack1_sha) as data:
             index = self.get_pack_index(pack1_sha)
-            Pack.from_objects(data, index).check_length_and_checksum()
+            pack = Pack.from_objects(data, index)
+            self.addCleanup(pack.close)
+            pack.check_length_and_checksum()
 
             data._file.seek(12)
             bad_file = BytesIO()
@@ -796,19 +801,25 @@ class TestPack(PackTests):
             bad_file.write(data._file.read())
             bad_file = BytesIO(bad_file.getvalue())
             bad_data = PackData("", file=bad_file, object_format=DEFAULT_OBJECT_FORMAT)
+            self.addCleanup(bad_data.close)
             bad_pack = Pack.from_lazy_objects(lambda: bad_data, lambda: index)
+            self.addCleanup(bad_pack.close)
             self.assertRaises(AssertionError, lambda: bad_pack.data)
             self.assertRaises(AssertionError, bad_pack.check_length_and_checksum)
 
     def test_checksum_mismatch(self) -> None:
         with self.get_pack_data(pack1_sha) as data:
             index = self.get_pack_index(pack1_sha)
-            Pack.from_objects(data, index).check_length_and_checksum()
+            pack = Pack.from_objects(data, index)
+            self.addCleanup(pack.close)
+            pack.check_length_and_checksum()
 
             data._file.seek(0)
             bad_file = BytesIO(data._file.read()[:-20] + (b"\xff" * 20))
             bad_data = PackData("", file=bad_file, object_format=DEFAULT_OBJECT_FORMAT)
+            self.addCleanup(bad_data.close)
             bad_pack = Pack.from_lazy_objects(lambda: bad_data, lambda: index)
+            self.addCleanup(bad_pack.close)
             self.assertRaises(ChecksumMismatch, lambda: bad_pack.data)
             self.assertRaises(ChecksumMismatch, bad_pack.check_length_and_checksum)
 
@@ -1548,6 +1559,7 @@ class DeltaChainIteratorTests(TestCase):
             thin = bool(list(self.store))
         resolve_ext_ref = (thin and self.get_raw_no_repeat) or None
         data = PackData("test.pack", file=f, object_format=DEFAULT_OBJECT_FORMAT)
+        self.addCleanup(data.close)
         return TestPackIterator.for_pack_data(data, resolve_ext_ref=resolve_ext_ref)
 
     def make_pack_iter_subset(self, f, subset, thin=None):
@@ -1558,6 +1570,7 @@ class DeltaChainIteratorTests(TestCase):
         assert data
         index = MemoryPackIndex.for_pack(data)
         pack = Pack.from_objects(data, index)
+        self.addCleanup(pack.close)
         return TestPackIterator.for_pack_subset(
             pack, subset, resolve_ext_ref=resolve_ext_ref
         )
@@ -1837,6 +1850,9 @@ class DeltaChainIteratorTests(TestCase):
             pack[b1.id]
         except UnresolvedDeltas as e:
             self.assertEqual([b1.id], [sha_to_hex(sha) for sha in e.shas])
+        finally:
+            pack.close()
+            packdata.close()
 
 
 class DeltaEncodeSizeTests(TestCase):
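
The test_pack.py hunks close every PackData and Pack the tests construct,
either by registering addCleanup right after the object is created or, where
the object only exists inside a try block, in an explicit finally clause. A
sketch of the addCleanup form, assuming hypothetical fixture files
pack1.pack and pack1.idx next to the test (the basename is illustrative):

    import unittest

    from dulwich.pack import Pack


    class PackCleanupExample(unittest.TestCase):
        def test_checksums(self) -> None:
            # Pack(basename) opens basename.pack / basename.idx lazily.
            pack = Pack("pack1")
            # The pack keeps file handles open once read, so close()
            # is registered before anything in the test can fail.
            self.addCleanup(pack.close)
            pack.check_length_and_checksum()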

+ 2 - 0
tests/test_rebase.py

@@ -49,6 +49,7 @@ class RebaserTestCase(TestCase):
         """Set up test repository."""
         super().setUp()
         self.repo = MemoryRepo()
+        self.addCleanup(self.repo.close)
 
     def _setup_initial_commit(self):
         """Set up initial commit for tests."""
@@ -478,6 +479,7 @@ class InteractiveRebaseTestCase(TestCase):
         """Set up test repository."""
         super().setUp()
         self.repo = MemoryRepo()
+        self.addCleanup(self.repo.close)
         self._setup_initial_commit()
 
     def _setup_initial_commit(self):

+ 1 - 0
tests/test_reflog.py

@@ -203,6 +203,7 @@ class RepoReflogTests(TestCase):
         TestCase.tearDown(self)
         import shutil
 
+        self.repo.close()
         shutil.rmtree(self.test_dir)
 
     def test_read_reflog_nonexistent(self) -> None:

+ 2 - 3
tests/test_repository.py

@@ -638,7 +638,7 @@ class RepositoryRootTests(TestCase):
         self.addCleanup(shutil.rmtree, tmp_dir)
 
         o = Repo.init(os.path.join(tmp_dir, "s"), mkdir=True)
-        o.close()
+        self.addCleanup(o.close)
         os.symlink("foo", os.path.join(tmp_dir, "s", "bar"))
         o.get_worktree().stage("bar")
         o.get_worktree().commit(
@@ -646,11 +646,10 @@ class RepositoryRootTests(TestCase):
         )
 
         t = o.clone(os.path.join(tmp_dir, "t"), symlinks=False)
+        self.addCleanup(t.close)
         with open(os.path.join(tmp_dir, "t", "bar")) as f:
             self.assertEqual("foo", f.read())
 
-        t.close()
-
     def test_reset_index_protect_hfs(self) -> None:
         tmp_dir = self.mkdtemp()
         self.addCleanup(shutil.rmtree, tmp_dir)
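
The test_repository.py change replaces an o.close() that ran immediately
after Repo.init, while the repository was still being staged and committed
to, with a deferred self.addCleanup(o.close). Because unittest runs
cleanups in last-in, first-out order, the repository is still closed before
the shutil.rmtree cleanup registered earlier removes its directory. A
sketch of that ordering (the class and test names are illustrative):

    import os
    import shutil
    import tempfile
    import unittest

    from dulwich.repo import Repo


    class CleanupOrderExample(unittest.TestCase):
        def test_deferred_close(self) -> None:
            tmp_dir = tempfile.mkdtemp()
            self.addCleanup(shutil.rmtree, tmp_dir)  # registered first, runs last
            o = Repo.init(os.path.join(tmp_dir, "s"), mkdir=True)
            self.addCleanup(o.close)  # registered last, runs first
            # o stays usable for the whole test body and is closed
            # before its directory tree is removed.
            self.assertTrue(os.path.isdir(o.controldir()))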