
contrib: Add typing

Jelmer Vernooij 3 months ago
parent
commit
dbb2d9e179
5 changed files with 317 additions and 191 deletions
  1. dulwich/contrib/diffstat.py  +1 -1
  2. dulwich/contrib/paramiko_vendor.py  +25 -20
  3. dulwich/contrib/release_robot.py  +33 -20
  4. dulwich/contrib/requests_vendor.py  +45 -23
  5. dulwich/contrib/swift.py  +213 -127

+ 1 - 1
dulwich/contrib/diffstat.py

@@ -111,7 +111,7 @@ def _parse_patch(
 
 # note must all done using bytes not string because on linux filenames
 # may not be encodable even to utf-8
-def diffstat(lines, max_width=80):
+def diffstat(lines: list[bytes], max_width: int = 80) -> bytes:
     """Generate summary statistics from a git style diff ala
        (git diff tag1 tag2 --stat).
 

+ 25 - 20
dulwich/contrib/paramiko_vendor.py

@@ -31,12 +31,14 @@ the dulwich.client.get_ssh_vendor attribute:
 This implementation is experimental and does not have any tests.
 """
 
+from typing import Any, BinaryIO, Optional, cast
+
 import paramiko
 import paramiko.client
 
 
 class _ParamikoWrapper:
-    def __init__(self, client, channel) -> None:
+    def __init__(self, client: paramiko.SSHClient, channel: paramiko.Channel) -> None:
         self.client = client
         self.channel = channel
 
@@ -44,17 +46,17 @@ class _ParamikoWrapper:
         self.channel.setblocking(True)
 
     @property
-    def stderr(self):
-        return self.channel.makefile_stderr("rb")
+    def stderr(self) -> BinaryIO:
+        return cast(BinaryIO, self.channel.makefile_stderr("rb"))
 
-    def can_read(self):
+    def can_read(self) -> bool:
         return self.channel.recv_ready()
 
-    def write(self, data):
+    def write(self, data: bytes) -> None:
         return self.channel.sendall(data)
 
-    def read(self, n=None):
-        data = self.channel.recv(n)
+    def read(self, n: Optional[int] = None) -> bytes:
+        data = self.channel.recv(n or 4096)
         data_len = len(data)
 
         # Closed socket
@@ -74,24 +76,24 @@ class _ParamikoWrapper:
 class ParamikoSSHVendor:
     # http://docs.paramiko.org/en/2.4/api/client.html
 
-    def __init__(self, **kwargs) -> None:
+    def __init__(self, **kwargs: object) -> None:
         self.kwargs = kwargs
 
     def run_command(
         self,
-        host,
-        command,
-        username=None,
-        port=None,
-        password=None,
-        pkey=None,
-        key_filename=None,
-        protocol_version=None,
-        **kwargs,
-    ):
+        host: str,
+        command: str,
+        username: Optional[str] = None,
+        port: Optional[int] = None,
+        password: Optional[str] = None,
+        pkey: Optional[paramiko.PKey] = None,
+        key_filename: Optional[str] = None,
+        protocol_version: Optional[int] = None,
+        **kwargs: object,
+    ) -> _ParamikoWrapper:
         client = paramiko.SSHClient()
 
-        connection_kwargs = {"hostname": host}
+        connection_kwargs: dict[str, Any] = {"hostname": host}
         connection_kwargs.update(self.kwargs)
         if username:
             connection_kwargs["username"] = username
@@ -110,7 +112,10 @@ class ParamikoSSHVendor:
         client.connect(**connection_kwargs)
 
         # Open SSH session
-        channel = client.get_transport().open_session()
+        transport = client.get_transport()
+        if transport is None:
+            raise RuntimeError("Transport is None")
+        channel = transport.open_session()
 
         if protocol_version is None or protocol_version == 2:
             channel.set_environment_variable(name="GIT_PROTOCOL", value="version=2")
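
A hedged sketch of calling the now fully typed ParamikoSSHVendor; the host, user, port and command are placeholders, and any extra keyword arguments given to the vendor are forwarded to paramiko's SSHClient.connect():

from dulwich.contrib.paramiko_vendor import ParamikoSSHVendor

vendor = ParamikoSSHVendor(allow_agent=True)  # stored in self.kwargs, merged into connect()
con = vendor.run_command(
    host="git.example.com",                   # placeholder host
    command="git-upload-pack '/srv/git/project.git'",
    username="git",
    port=22,
)                                             # -> _ParamikoWrapper per the new return type
if con.can_read():                            # -> bool
    data = con.read(4096)                     # -> bytes; recv() falls back to 4096 when n is None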

+ 33 - 20
dulwich/contrib/release_robot.py

@@ -46,9 +46,11 @@ EG::
 """
 
 import datetime
+import logging
 import re
 import sys
 import time
+from typing import Any, Optional, cast
 
 from ..repo import Repo
 
@@ -57,7 +59,7 @@ PROJDIR = "."
 PATTERN = r"[ a-zA-Z_\-]*([\d\.]+[\-\w\.]*)"
 
 
-def get_recent_tags(projdir=PROJDIR):
+def get_recent_tags(projdir: str = PROJDIR) -> list[tuple[str, list[Any]]]:
     """Get list of tags in order from newest to oldest and their datetimes.
 
     Args:
@@ -74,8 +76,8 @@ def get_recent_tags(projdir=PROJDIR):
         refs = project.get_refs()  # dictionary of refs and their SHA-1 values
         tags = {}  # empty dictionary to hold tags, commits and datetimes
         # iterate over refs in repository
-        for key, value in refs.items():
-            key = key.decode("utf-8")  # compatible with Python-3
+        for key_bytes, value in refs.items():
+            key = key_bytes.decode("utf-8")  # compatible with Python-3
             obj = project.get_object(value)  # dulwich object from SHA-1
             # don't just check if object is "tag" b/c it could be a "commit"
             # instead check if "tags" is in the ref-name
@@ -85,25 +87,27 @@ def get_recent_tags(projdir=PROJDIR):
             # strip the leading text from refs to get "tag name"
             _, tag = key.rsplit("/", 1)
             # check if tag object is "commit" or "tag" pointing to a "commit"
-            try:
-                commit = obj.object  # a tuple (commit class, commit id)
-            except AttributeError:
-                commit = obj
-                tag_meta = None
-            else:
+            from ..objects import Commit, Tag
+
+            if isinstance(obj, Tag):
+                commit_info = obj.object  # a tuple (commit class, commit id)
                 tag_meta = (
                     datetime.datetime(*time.gmtime(obj.tag_time)[:6]),
                     obj.id.decode("utf-8"),
                     obj.name.decode("utf-8"),
                 )  # compatible with Python-3
-                commit = project.get_object(commit[1])  # commit object
+                commit = project.get_object(commit_info[1])  # commit object
+            else:
+                commit = obj
+                tag_meta = None
             # get tag commit datetime, but dulwich returns seconds since
             # beginning of epoch, so use Python time module to convert it to
             # timetuple then convert to datetime
+            commit_obj = cast(Commit, commit)
             tags[tag] = [
-                datetime.datetime(*time.gmtime(commit.commit_time)[:6]),
-                commit.id.decode("utf-8"),
-                commit.author.decode("utf-8"),
+                datetime.datetime(*time.gmtime(commit_obj.commit_time)[:6]),
+                commit_obj.id.decode("utf-8"),
+                commit_obj.author.decode("utf-8"),
                 tag_meta,
             ]  # compatible with Python-3
 
@@ -111,7 +115,11 @@ def get_recent_tags(projdir=PROJDIR):
     return sorted(tags.items(), key=lambda tag: tag[1][0], reverse=True)
 
 
-def get_current_version(projdir=PROJDIR, pattern=PATTERN, logger=None):
+def get_current_version(
+    projdir: str = PROJDIR,
+    pattern: str = PATTERN,
+    logger: Optional[logging.Logger] = None,
+) -> Optional[str]:
     """Return the most recent tag, using an options regular expression pattern.
 
     The default pattern will strip any characters preceding the first semantic
@@ -129,15 +137,20 @@ def get_current_version(projdir=PROJDIR, pattern=PATTERN, logger=None):
     try:
         tag = tags[0][0]
     except IndexError:
-        return
+        return None
     matches = re.match(pattern, tag)
-    try:
-        current_version = matches.group(1)
-    except (IndexError, AttributeError) as err:
+    if matches:
+        try:
+            current_version = matches.group(1)
+            return current_version
+        except IndexError as err:
+            if logger:
+                logger.debug("Pattern %r didn't match tag %r: %s", pattern, tag, err)
+            return tag
+    else:
         if logger:
-            logger.debug("Pattern %r didn't match tag %r: %s", pattern, tag, err)
+            logger.debug("Pattern %r didn't match tag %r", pattern, tag)
         return tag
-    return current_version
 
 
 if __name__ == "__main__":
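
As a usage note for the typed release_robot helpers, a small sketch; it assumes the current directory is a git checkout readable by dulwich, and the logger name is arbitrary:

import logging

from dulwich.contrib.release_robot import get_current_version, get_recent_tags

logging.basicConfig(level=logging.DEBUG)
log = logging.getLogger("release-robot-demo")

# get_recent_tags() -> list[tuple[str, list[Any]]], newest tag first; each value is
# [commit datetime, commit id, author, tag_meta-or-None]
for name, (tag_time, sha, author, tag_meta) in get_recent_tags(".")[:3]:
    print(name, tag_time.isoformat(), sha[:8], author)

version = get_current_version(".", logger=log)  # Optional[str]: None when there are no tags
print(version if version is not None else "no tags found")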

+ 45 - 23
dulwich/contrib/requests_vendor.py

@@ -32,6 +32,10 @@ This implementation is experimental and does not have any tests.
 """
 
 from io import BytesIO
+from typing import TYPE_CHECKING, Any, Callable, Optional
+
+if TYPE_CHECKING:
+    from ..config import ConfigFile
 
 from requests import Session
 
@@ -46,7 +50,13 @@ from ..errors import GitProtocolError, NotGitRepository
 
 class RequestsHttpGitClient(AbstractHttpGitClient):
     def __init__(
-        self, base_url, dumb=None, config=None, username=None, password=None, **kwargs
+        self,
+        base_url: str,
+        dumb: Optional[bool] = None,
+        config: Optional["ConfigFile"] = None,
+        username: Optional[str] = None,
+        password: Optional[str] = None,
+        **kwargs: object,
     ) -> None:
         self._username = username
         self._password = password
@@ -54,12 +64,20 @@ class RequestsHttpGitClient(AbstractHttpGitClient):
         self.session = get_session(config)
 
         if username is not None:
-            self.session.auth = (username, password)
-
-        super().__init__(base_url=base_url, dumb=dumb, **kwargs)
-
-    def _http_request(self, url, headers=None, data=None, allow_compression=False):
-        req_headers = self.session.headers.copy()
+            self.session.auth = (username, password)  # type: ignore[assignment]
+
+        super().__init__(
+            base_url=base_url, dumb=bool(dumb) if dumb is not None else False, **kwargs
+        )
+
+    def _http_request(
+        self,
+        url: str,
+        headers: Optional[dict[str, str]] = None,
+        data: Optional[bytes] = None,
+        allow_compression: bool = False,
+    ) -> tuple[Any, Callable[[int], bytes]]:
+        req_headers = self.session.headers.copy()  # type: ignore[attr-defined]
         if headers is not None:
             req_headers.update(headers)
 
@@ -83,34 +101,37 @@ class RequestsHttpGitClient(AbstractHttpGitClient):
             raise GitProtocolError(f"unexpected http resp {resp.status_code} for {url}")
 
         # Add required fields as stated in AbstractHttpGitClient._http_request
-        resp.content_type = resp.headers.get("Content-Type")
-        resp.redirect_location = ""
+        resp.content_type = resp.headers.get("Content-Type")  # type: ignore[attr-defined]
+        resp.redirect_location = ""  # type: ignore[attr-defined]
         if resp.history:
-            resp.redirect_location = resp.url
+            resp.redirect_location = resp.url  # type: ignore[attr-defined]
 
         read = BytesIO(resp.content).read
 
         return resp, read
 
 
-def get_session(config):
+def get_session(config: Optional["ConfigFile"]) -> Session:
     session = Session()
     session.headers.update({"Pragma": "no-cache"})
 
-    proxy_server = user_agent = ca_certs = ssl_verify = None
+    proxy_server: Optional[str] = None
+    user_agent: Optional[str] = None
+    ca_certs: Optional[str] = None
+    ssl_verify: Optional[bool] = None
 
     if config is not None:
         try:
-            proxy_server = config.get(b"http", b"proxy")
-            if isinstance(proxy_server, bytes):
-                proxy_server = proxy_server.decode()
+            proxy_bytes = config.get(b"http", b"proxy")
+            if isinstance(proxy_bytes, bytes):
+                proxy_server = proxy_bytes.decode()
         except KeyError:
             pass
 
         try:
-            user_agent = config.get(b"http", b"useragent")
-            if isinstance(user_agent, bytes):
-                user_agent = user_agent.decode()
+            agent_bytes = config.get(b"http", b"useragent")
+            if isinstance(agent_bytes, bytes):
+                user_agent = agent_bytes.decode()
         except KeyError:
             pass
 
@@ -120,21 +141,22 @@ def get_session(config):
             ssl_verify = True
 
         try:
-            ca_certs = config.get(b"http", b"sslCAInfo")
-            if isinstance(ca_certs, bytes):
-                ca_certs = ca_certs.decode()
+            certs_bytes = config.get(b"http", b"sslCAInfo")
+            if isinstance(certs_bytes, bytes):
+                ca_certs = certs_bytes.decode()
         except KeyError:
             ca_certs = None
 
     if user_agent is None:
         user_agent = default_user_agent_string()
-    session.headers.update({"User-agent": user_agent})
+    if user_agent is not None:
+        session.headers.update({"User-agent": user_agent})
 
     if ca_certs:
         session.verify = ca_certs
     elif ssl_verify is False:
         session.verify = ssl_verify
 
-    if proxy_server:
+    if proxy_server is not None:
         session.proxies.update({"http": proxy_server, "https": proxy_server})
     return session
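
A hedged sketch of the typed requests vendor; the repository URL, proxy and user agent below are made up for illustration, and requests must be installed:

from dulwich.config import ConfigFile
from dulwich.contrib.requests_vendor import RequestsHttpGitClient, get_session

config = ConfigFile()
config.set(b"http", b"useragent", b"dulwich-demo/1.0")
config.set(b"http", b"proxy", b"http://proxy.example.com:3128")

session = get_session(config)              # requests.Session with proxy and User-agent applied
print(session.headers.get("User-agent"), session.proxies)

client = RequestsHttpGitClient(
    "https://git.example.com/project.git", # base_url: str
    config=config,                         # Optional["ConfigFile"]
    username=None,                         # auth is only attached when a username is given
)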

+ 213 - 127
dulwich/contrib/swift.py

@@ -28,6 +28,7 @@
 # TODO(fbo): More logs for operations
 
 import json
+import logging
 import os
 import posixpath
 import stat
@@ -35,19 +36,21 @@ import sys
 import tempfile
 import urllib.parse as urlparse
 import zlib
+from collections.abc import Iterator
 from configparser import ConfigParser
 from io import BytesIO
-from typing import Optional
+from typing import BinaryIO, Callable, Optional, cast
 
 from geventhttpclient import HTTPClient
 
 from ..greenthreads import GreenThreadsMissingObjectFinder
 from ..lru_cache import LRUSizeCache
-from ..object_store import INFODIR, PACKDIR, PackBasedObjectStore
+from ..object_store import INFODIR, PACKDIR, ObjectContainer, PackBasedObjectStore
 from ..objects import S_ISGITLINK, Blob, Commit, Tag, Tree
 from ..pack import (
     Pack,
     PackData,
+    PackIndex,
     PackIndexer,
     PackStreamCopier,
     _compute_object_size,
@@ -63,7 +66,7 @@ from ..pack import (
 from ..protocol import TCP_GIT_PORT
 from ..refs import InfoRefsContainer, read_info_refs, split_peeled_refs, write_info_refs
 from ..repo import OBJECTDIR, BaseRepo
-from ..server import Backend, TCPGitServer
+from ..server import Backend, BackendRepo, TCPGitServer
 
 """
 # Configuration file sample
@@ -94,29 +97,47 @@ cache_length = 20
 
 
 class PackInfoMissingObjectFinder(GreenThreadsMissingObjectFinder):
-    def next(self):
+    def next(self) -> Optional[tuple[bytes, int, bytes | None]]:
         while True:
             if not self.objects_to_send:
                 return None
-            (sha, name, leaf) = self.objects_to_send.pop()
+            (sha, name, leaf, _) = self.objects_to_send.pop()
             if sha not in self.sha_done:
                 break
         if not leaf:
-            info = self.object_store.pack_info_get(sha)
-            if info[0] == Commit.type_num:
-                self.add_todo([(info[2], "", False)])
-            elif info[0] == Tree.type_num:
-                self.add_todo([tuple(i) for i in info[1]])
-            elif info[0] == Tag.type_num:
-                self.add_todo([(info[1], None, False)])
-            if sha in self._tagged:
-                self.add_todo([(self._tagged[sha], None, True)])
+            try:
+                obj = self.object_store[sha]
+                if isinstance(obj, Commit):
+                    self.add_todo([(obj.tree, b"", None, False)])
+                elif isinstance(obj, Tree):
+                    tree_items = [
+                        (
+                            item.sha,
+                            item.path
+                            if isinstance(item.path, bytes)
+                            else item.path.encode("utf-8"),
+                            None,
+                            False,
+                        )
+                        for item in obj.items()
+                    ]
+                    self.add_todo(tree_items)
+                elif isinstance(obj, Tag):
+                    self.add_todo([(obj.object[1], None, None, False)])
+                if sha in self._tagged:
+                    self.add_todo([(self._tagged[sha], None, None, True)])
+            except KeyError:
+                pass
         self.sha_done.add(sha)
         self.progress(f"counting objects: {len(self.sha_done)}\r")
-        return (sha, name)
+        return (
+            sha,
+            0,
+            name if isinstance(name, bytes) else name.encode("utf-8") if name else None,
+        )
 
 
-def load_conf(path=None, file=None):
+def load_conf(path: Optional[str] = None, file: Optional[str] = None) -> ConfigParser:
     """Load configuration in global var CONF.
 
     Args:
@@ -125,27 +146,23 @@ def load_conf(path=None, file=None):
     """
     conf = ConfigParser()
     if file:
-        try:
-            conf.read_file(file, path)
-        except AttributeError:
-            # read_file only exists in Python3
-            conf.readfp(file)
-        return conf
-    confpath = None
-    if not path:
-        try:
-            confpath = os.environ["DULWICH_SWIFT_CFG"]
-        except KeyError as exc:
-            raise Exception("You need to specify a configuration file") from exc
+        conf.read_file(file, path)
     else:
-        confpath = path
-    if not os.path.isfile(confpath):
-        raise Exception(f"Unable to read configuration file {confpath}")
-    conf.read(confpath)
+        confpath = None
+        if not path:
+            try:
+                confpath = os.environ["DULWICH_SWIFT_CFG"]
+            except KeyError as exc:
+                raise Exception("You need to specify a configuration file") from exc
+        else:
+            confpath = path
+        if not os.path.isfile(confpath):
+            raise Exception(f"Unable to read configuration file {confpath}")
+        conf.read(confpath)
     return conf
 
 
-def swift_load_pack_index(scon, filename):
+def swift_load_pack_index(scon: "SwiftConnector", filename: str) -> "PackIndex":
     """Read a pack index file from Swift.
 
     Args:
@@ -153,45 +170,66 @@ def swift_load_pack_index(scon, filename):
       filename: Path to the index file objectise
     Returns: a `PackIndexer` instance
     """
-    with scon.get_object(filename) as f:
-        return load_pack_index_file(filename, f)
+    f = scon.get_object(filename)
+    if f is None:
+        raise Exception(f"Could not retrieve index file {filename}")
+    if isinstance(f, bytes):
+        f = BytesIO(f)
+    return load_pack_index_file(filename, f)
 
 
-def pack_info_create(pack_data, pack_index):
+def pack_info_create(pack_data: "PackData", pack_index: "PackIndex") -> bytes:
     pack = Pack.from_objects(pack_data, pack_index)
-    info = {}
+    info: dict = {}
     for obj in pack.iterobjects():
         # Commit
         if obj.type_num == Commit.type_num:
-            info[obj.id] = (obj.type_num, obj.parents, obj.tree)
+            commit_obj = obj
+            assert isinstance(commit_obj, Commit)
+            info[obj.id] = (obj.type_num, commit_obj.parents, commit_obj.tree)
         # Tree
         elif obj.type_num == Tree.type_num:
+            tree_obj = obj
+            assert isinstance(tree_obj, Tree)
             shas = [
                 (s, n, not stat.S_ISDIR(m))
-                for n, m, s in obj.items()
+                for n, m, s in tree_obj.items()
                 if not S_ISGITLINK(m)
             ]
             info[obj.id] = (obj.type_num, shas)
         # Blob
         elif obj.type_num == Blob.type_num:
-            info[obj.id] = None
+            info[obj.id] = (obj.type_num,)
         # Tag
         elif obj.type_num == Tag.type_num:
-            info[obj.id] = (obj.type_num, obj.object[1])
-    return zlib.compress(json.dumps(info))
+            tag_obj = obj
+            assert isinstance(tag_obj, Tag)
+            info[obj.id] = (obj.type_num, tag_obj.object[1])
+    return zlib.compress(json.dumps(info).encode("utf-8"))
 
 
-def load_pack_info(filename, scon=None, file=None):
+def load_pack_info(
+    filename: str,
+    scon: Optional["SwiftConnector"] = None,
+    file: Optional[BinaryIO] = None,
+) -> Optional[dict]:
     if not file:
-        f = scon.get_object(filename)
+        if scon is None:
+            return None
+        obj = scon.get_object(filename)
+        if obj is None:
+            return None
+        if isinstance(obj, bytes):
+            return json.loads(zlib.decompress(obj))
+        else:
+            f: BinaryIO = obj
     else:
         f = file
-    if not f:
-        return None
     try:
         return json.loads(zlib.decompress(f.read()))
     finally:
-        f.close()
+        if hasattr(f, "close"):
+            f.close()
 
 
 class SwiftException(Exception):
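
The hunk above pins down the pack-info round trip: pack_info_create() JSON-encodes, UTF-8 encodes and zlib-compresses the info dict, and load_pack_info() reverses those steps. A stdlib-only sketch of that round trip, with fabricated SHA strings:

import json
import zlib
from io import BytesIO

# Shape of a commit entry: (type_num, parents, tree); the hex strings are fake.
info = {"a" * 40: (1, ["b" * 40], "c" * 40)}

blob = zlib.compress(json.dumps(info).encode("utf-8"))        # what pack_info_create() returns
restored = json.loads(zlib.decompress(BytesIO(blob).read()))  # what load_pack_info() hands back
print(restored["a" * 40])                                     # JSON maps the tuple to a list
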
@@ -201,7 +239,7 @@ class SwiftException(Exception):
 class SwiftConnector:
     """A Connector to swift that manage authentication and errors catching."""
 
-    def __init__(self, root, conf) -> None:
+    def __init__(self, root: str, conf: ConfigParser) -> None:
         """Initialize a SwiftConnector.
 
         Args:
@@ -242,7 +280,7 @@ class SwiftConnector:
             posixpath.join(urlparse.urlparse(self.storage_url).path, self.root)
         )
 
-    def swift_auth_v1(self):
+    def swift_auth_v1(self) -> tuple[str, str]:
         self.user = self.user.replace(";", ":")
         auth_httpclient = HTTPClient.from_url(
             self.auth_url,
@@ -265,7 +303,7 @@ class SwiftConnector:
         token = ret["X-Auth-Token"]
         return storage_url, token
 
-    def swift_auth_v2(self):
+    def swift_auth_v2(self) -> tuple[str, str]:
         self.tenant, self.user = self.user.split(";")
         auth_dict = {}
         auth_dict["auth"] = {
@@ -331,7 +369,7 @@ class SwiftConnector:
                     f"PUT request failed with error code {ret.status_code}"
                 )
 
-    def get_container_objects(self):
+    def get_container_objects(self) -> Optional[list[dict]]:
         """Retrieve objects list in a container.
 
         Returns: A list of dict that describe objects
@@ -349,7 +387,7 @@ class SwiftConnector:
         content = ret.read()
         return json.loads(content)
 
-    def get_object_stat(self, name):
+    def get_object_stat(self, name: str) -> Optional[dict]:
         """Retrieve object stat.
 
         Args:
@@ -370,7 +408,7 @@ class SwiftConnector:
             resp_headers[header.lower()] = value
         return resp_headers
 
-    def put_object(self, name, content) -> None:
+    def put_object(self, name: str, content: BinaryIO) -> None:
         """Put an object.
 
         Args:
@@ -384,7 +422,7 @@ class SwiftConnector:
         path = self.base_path + "/" + name
         headers = {"Content-Length": str(len(data))}
 
-        def _send():
+        def _send() -> object:
             ret = self.httpclient.request("PUT", path, body=data, headers=headers)
             return ret
 
@@ -395,12 +433,14 @@ class SwiftConnector:
             # Second attempt work
             ret = _send()
 
-        if ret.status_code < 200 or ret.status_code > 300:
+        if ret.status_code < 200 or ret.status_code > 300:  # type: ignore
             raise SwiftException(
-                f"PUT request failed with error code {ret.status_code}"
+                f"PUT request failed with error code {ret.status_code}"  # type: ignore
             )
 
-    def get_object(self, name, range=None):
+    def get_object(
+        self, name: str, range: Optional[str] = None
+    ) -> Optional[bytes | BytesIO]:
         """Retrieve an object.
 
         Args:
@@ -427,7 +467,7 @@ class SwiftConnector:
             return content
         return BytesIO(content)
 
-    def del_object(self, name) -> None:
+    def del_object(self, name: str) -> None:
         """Delete an object.
 
         Args:
@@ -448,8 +488,10 @@ class SwiftConnector:
         Raises:
           SwiftException: if unable to delete
         """
-        for obj in self.get_container_objects():
-            self.del_object(obj["name"])
+        objects = self.get_container_objects()
+        if objects:
+            for obj in objects:
+                self.del_object(obj["name"])
         ret = self.httpclient.request("DELETE", self.base_path)
         if ret.status_code < 200 or ret.status_code > 300:
             raise SwiftException(
@@ -467,7 +509,7 @@ class SwiftPackReader:
     to read from Swift.
     """
 
-    def __init__(self, scon, filename, pack_length) -> None:
+    def __init__(self, scon: SwiftConnector, filename: str, pack_length: int) -> None:
         """Initialize a SwiftPackReader.
 
         Args:
@@ -483,15 +525,20 @@ class SwiftPackReader:
         self.buff = b""
         self.buff_length = self.scon.chunk_length
 
-    def _read(self, more=False) -> None:
+    def _read(self, more: bool = False) -> None:
         if more:
             self.buff_length = self.buff_length * 2
         offset = self.base_offset
         r = min(self.base_offset + self.buff_length, self.pack_length)
         ret = self.scon.get_object(self.filename, range=f"{offset}-{r}")
-        self.buff = ret
+        if ret is None:
+            self.buff = b""
+        elif isinstance(ret, bytes):
+            self.buff = ret
+        else:
+            self.buff = ret.read()
 
-    def read(self, length):
+    def read(self, length: int) -> bytes:
         """Read a specified amount of Bytes form the pack object.
 
         Args:
@@ -512,7 +559,7 @@ class SwiftPackReader:
         self.offset = end
         return data
 
-    def seek(self, offset) -> None:
+    def seek(self, offset: int) -> None:
         """Seek to a specified offset.
 
         Args:
@@ -522,12 +569,18 @@ class SwiftPackReader:
         self._read()
         self.offset = 0
 
-    def read_checksum(self):
+    def read_checksum(self) -> bytes:
         """Read the checksum from the pack.
 
         Returns: the checksum bytestring
         """
-        return self.scon.get_object(self.filename, range="-20")
+        ret = self.scon.get_object(self.filename, range="-20")
+        if ret is None:
+            return b""
+        elif isinstance(ret, bytes):
+            return ret
+        else:
+            return ret.read()
 
 
 class SwiftPackData(PackData):
@@ -537,7 +590,7 @@ class SwiftPackData(PackData):
     using the Range header feature of Swift.
     """
 
-    def __init__(self, scon, filename) -> None:
+    def __init__(self, scon: SwiftConnector, filename: str) -> None:
         """Initialize a SwiftPackReader.
 
         Args:
@@ -548,6 +601,8 @@ class SwiftPackData(PackData):
         self._filename = filename
         self._header_size = 12
         headers = self.scon.get_object_stat(self._filename)
+        if headers is None:
+            raise Exception(f"Could not get stats for {self._filename}")
         self.pack_length = int(headers["content-length"])
         pack_reader = SwiftPackReader(self.scon, self._filename, self.pack_length)
         (version, self._num_objects) = read_pack_header(pack_reader.read)
@@ -557,16 +612,19 @@ class SwiftPackData(PackData):
         )
         self.pack = None
 
-    def get_object_at(self, offset):
+    def get_object_at(
+        self, offset: int
+    ) -> tuple[int, tuple[bytes | int, list[bytes]] | list[bytes]]:
         if offset in self._offset_cache:
             return self._offset_cache[offset]
         assert offset >= self._header_size
         pack_reader = SwiftPackReader(self.scon, self._filename, self.pack_length)
         pack_reader.seek(offset)
         unpacked, _ = unpack_object(pack_reader.read)
-        return (unpacked.pack_type_num, unpacked._obj())
+        obj_data = unpacked._obj()
+        return (unpacked.pack_type_num, obj_data)
 
-    def get_stored_checksum(self):
+    def get_stored_checksum(self) -> bytes:
         pack_reader = SwiftPackReader(self.scon, self._filename, self.pack_length)
         return pack_reader.read_checksum()
 
@@ -582,18 +640,18 @@ class SwiftPack(Pack):
     PackData.
     """
 
-    def __init__(self, *args, **kwargs) -> None:
+    def __init__(self, *args: object, **kwargs: object) -> None:
         self.scon = kwargs["scon"]
         del kwargs["scon"]
-        super().__init__(*args, **kwargs)
+        super().__init__(*args, **kwargs)  # type: ignore
         self._pack_info_path = self._basename + ".info"
-        self._pack_info = None
-        self._pack_info_load = lambda: load_pack_info(self._pack_info_path, self.scon)
-        self._idx_load = lambda: swift_load_pack_index(self.scon, self._idx_path)
-        self._data_load = lambda: SwiftPackData(self.scon, self._data_path)
+        self._pack_info: Optional[dict] = None
+        self._pack_info_load = lambda: load_pack_info(self._pack_info_path, self.scon)  # type: ignore
+        self._idx_load = lambda: swift_load_pack_index(self.scon, self._idx_path)  # type: ignore
+        self._data_load = lambda: SwiftPackData(self.scon, self._data_path)  # type: ignore
 
     @property
-    def pack_info(self):
+    def pack_info(self) -> Optional[dict]:
         """The pack data object being used."""
         if self._pack_info is None:
             self._pack_info = self._pack_info_load()
@@ -607,7 +665,7 @@ class SwiftObjectStore(PackBasedObjectStore):
     This object store only supports pack files and not loose objects.
     """
 
-    def __init__(self, scon) -> None:
+    def __init__(self, scon: SwiftConnector) -> None:
         """Open a Swift object store.
 
         Args:
@@ -619,8 +677,10 @@ class SwiftObjectStore(PackBasedObjectStore):
         self.pack_dir = posixpath.join(OBJECTDIR, PACKDIR)
         self._alternates = None
 
-    def _update_pack_cache(self):
+    def _update_pack_cache(self) -> list:
         objects = self.scon.get_container_objects()
+        if objects is None:
+            return []
         pack_files = [
             o["name"].replace(".pack", "")
             for o in objects
@@ -633,25 +693,37 @@ class SwiftObjectStore(PackBasedObjectStore):
             ret.append(pack)
         return ret
 
-    def _iter_loose_objects(self):
+    def _iter_loose_objects(self) -> Iterator:
         """Loose objects are not supported by this repository."""
-        return []
+        return iter([])
 
-    def pack_info_get(self, sha):
+    def pack_info_get(self, sha: bytes) -> Optional[tuple]:
         for pack in self.packs:
             if sha in pack:
-                return pack.pack_info[sha]
+                if hasattr(pack, "pack_info"):
+                    pack_info = pack.pack_info
+                    if pack_info is not None:
+                        return pack_info.get(sha)
+        return None
 
-    def _collect_ancestors(self, heads, common=set()):
-        def _find_parents(commit):
+    def _collect_ancestors(
+        self, heads: list, common: Optional[set] = None
+    ) -> tuple[set, set]:
+        if common is None:
+            common = set()
+
+        def _find_parents(commit: bytes) -> list:
             for pack in self.packs:
                 if commit in pack:
                     try:
-                        parents = pack.pack_info[commit][1]
+                        if hasattr(pack, "pack_info"):
+                            pack_info = pack.pack_info
+                            if pack_info is not None:
+                                return pack_info[commit][1]
                     except KeyError:
                         # Seems to have no parents
                         return []
-                    return parents
+            return []
 
         bases = set()
         commits = set()
@@ -667,7 +739,7 @@ class SwiftObjectStore(PackBasedObjectStore):
                 queue.extend(parents)
         return (commits, bases)
 
-    def add_pack(self):
+    def add_pack(self) -> tuple[BytesIO, Callable, Callable]:
         """Add a new pack to this object store.
 
         Returns: Fileobject to write to and a commit function to
@@ -675,14 +747,14 @@ class SwiftObjectStore(PackBasedObjectStore):
         """
         f = BytesIO()
 
-        def commit():
+        def commit() -> Optional["SwiftPack"]:
             f.seek(0)
             pack = PackData(file=f, filename="")
             entries = pack.sorted_entries()
             if entries:
                 basename = posixpath.join(
                     self.pack_dir,
-                    f"pack-{iter_sha1(entry[0] for entry in entries)}",
+                    f"pack-{iter_sha1(entry[0] for entry in entries).decode('ascii')}",
                 )
                 index = BytesIO()
                 write_pack_index_v2(index, entries, pack.get_stored_checksum())
@@ -702,20 +774,20 @@ class SwiftObjectStore(PackBasedObjectStore):
 
         return f, commit, abort
 
-    def add_object(self, obj) -> None:
+    def add_object(self, obj: object) -> None:
         self.add_objects(
             [
-                (obj, None),
+                (obj, None),  # type: ignore
             ]
         )
 
     def _pack_cache_stale(self) -> bool:
         return False
 
-    def _get_loose_object(self, sha) -> None:
+    def _get_loose_object(self, sha: bytes) -> None:
         return None
 
-    def add_thin_pack(self, read_all, read_some):
+    def add_thin_pack(self, read_all: Callable, read_some: Callable) -> "SwiftPack":
         """Read a thin pack.
 
         Read it from a stream and complete it in a temporary file.
@@ -724,7 +796,7 @@ class SwiftObjectStore(PackBasedObjectStore):
         fd, path = tempfile.mkstemp(prefix="tmp_pack_")
         f = os.fdopen(fd, "w+b")
         try:
-            indexer = PackIndexer(f, resolve_ext_ref=self.get_raw)
+            indexer = PackIndexer(f, resolve_ext_ref=None)
             copier = PackStreamCopier(read_all, read_some, f, delta_iter=indexer)
             copier.verify()
             return self._complete_thin_pack(f, path, copier, indexer)
@@ -732,12 +804,14 @@ class SwiftObjectStore(PackBasedObjectStore):
             f.close()
             os.unlink(path)
 
-    def _complete_thin_pack(self, f, path, copier, indexer):
-        entries = list(indexer)
+    def _complete_thin_pack(
+        self, f: BinaryIO, path: str, copier: object, indexer: object
+    ) -> "SwiftPack":
+        entries = list(indexer)  # type: ignore
 
         # Update the header with the new number of objects.
         f.seek(0)
-        write_pack_header(f, len(entries) + len(indexer.ext_refs()))
+        write_pack_header(f, len(entries) + len(indexer.ext_refs()))  # type: ignore
 
         # Must flush before reading (http://bugs.python.org/issue3207)
         f.flush()
@@ -749,11 +823,11 @@ class SwiftObjectStore(PackBasedObjectStore):
         f.seek(0, os.SEEK_CUR)
 
         # Complete the pack.
-        for ext_sha in indexer.ext_refs():
+        for ext_sha in indexer.ext_refs():  # type: ignore
             assert len(ext_sha) == 20
             type_num, data = self.get_raw(ext_sha)
             offset = f.tell()
-            crc32 = write_pack_object(f, type_num, data, sha=new_sha)
+            crc32 = write_pack_object(f, type_num, data, sha=new_sha)  # type: ignore
             entries.append((ext_sha, offset, crc32))
         pack_sha = new_sha.digest()
         f.write(pack_sha)
@@ -796,20 +870,26 @@ class SwiftObjectStore(PackBasedObjectStore):
 class SwiftInfoRefsContainer(InfoRefsContainer):
     """Manage references in info/refs object."""
 
-    def __init__(self, scon, store) -> None:
+    def __init__(self, scon: SwiftConnector, store: object) -> None:
         self.scon = scon
         self.filename = "info/refs"
         self.store = store
         f = self.scon.get_object(self.filename)
         if not f:
             f = BytesIO(b"")
+        elif isinstance(f, bytes):
+            f = BytesIO(f)
         super().__init__(f)
 
-    def _load_check_ref(self, name, old_ref):
+    def _load_check_ref(self, name: bytes, old_ref: Optional[bytes]) -> dict | bool:
         self._check_refname(name)
-        f = self.scon.get_object(self.filename)
-        if not f:
+        obj = self.scon.get_object(self.filename)
+        if not obj:
             return {}
+        if isinstance(obj, bytes):
+            f = BytesIO(obj)
+        else:
+            f = obj
         refs = read_info_refs(f)
         (refs, peeled) = split_peeled_refs(refs)
         if old_ref is not None:
@@ -817,20 +897,20 @@ class SwiftInfoRefsContainer(InfoRefsContainer):
                 return False
         return refs
 
-    def _write_refs(self, refs) -> None:
+    def _write_refs(self, refs: dict) -> None:
         f = BytesIO()
-        f.writelines(write_info_refs(refs, self.store))
+        f.writelines(write_info_refs(refs, cast("ObjectContainer", self.store)))
         self.scon.put_object(self.filename, f)
 
     def set_if_equals(
         self,
-        name,
-        old_ref,
-        new_ref,
-        committer=None,
-        timestamp=None,
-        timezone=None,
-        message=None,
+        name: bytes,
+        old_ref: Optional[bytes],
+        new_ref: bytes,
+        committer: Optional[bytes] = None,
+        timestamp: Optional[float] = None,
+        timezone: Optional[int] = None,
+        message: Optional[bytes] = None,
     ) -> bool:
         """Set a refname to new_ref only if it currently equals old_ref."""
         if name == "HEAD":
@@ -844,7 +924,13 @@ class SwiftInfoRefsContainer(InfoRefsContainer):
         return True
 
     def remove_if_equals(
-        self, name, old_ref, committer=None, timestamp=None, timezone=None, message=None
+        self,
+        name: bytes,
+        old_ref: Optional[bytes],
+        committer: object = None,
+        timestamp: object = None,
+        timezone: object = None,
+        message: object = None,
     ) -> bool:
         """Remove a refname only if it currently equals old_ref."""
         if name == "HEAD":
@@ -857,16 +943,16 @@ class SwiftInfoRefsContainer(InfoRefsContainer):
         del self._refs[name]
         return True
 
-    def allkeys(self):
+    def allkeys(self) -> Iterator[bytes]:
         try:
-            self._refs["HEAD"] = self._refs["refs/heads/master"]
+            self._refs[b"HEAD"] = self._refs[b"refs/heads/master"]
         except KeyError:
             pass
-        return self._refs.keys()
+        return iter(self._refs.keys())
 
 
 class SwiftRepo(BaseRepo):
-    def __init__(self, root, conf) -> None:
+    def __init__(self, root: str, conf: ConfigParser) -> None:
         """Init a Git bare Repository on top of a Swift container.
 
         References are managed in info/refs objects by
@@ -899,7 +985,7 @@ class SwiftRepo(BaseRepo):
         """
         return False
 
-    def _put_named_file(self, filename, contents) -> None:
+    def _put_named_file(self, filename: str, contents: bytes) -> None:
         """Put an object in a Swift container.
 
         Args:
@@ -911,7 +997,7 @@ class SwiftRepo(BaseRepo):
             self.scon.put_object(filename, f)
 
     @classmethod
-    def init_bare(cls, scon, conf):
+    def init_bare(cls, scon: SwiftConnector, conf: ConfigParser) -> "SwiftRepo":
         """Create a new bare repository.
 
         Args:
@@ -932,16 +1018,16 @@ class SwiftRepo(BaseRepo):
 
 
 class SwiftSystemBackend(Backend):
-    def __init__(self, logger, conf) -> None:
+    def __init__(self, logger: "logging.Logger", conf: ConfigParser) -> None:
         self.conf = conf
         self.logger = logger
 
-    def open_repository(self, path):
+    def open_repository(self, path: str) -> "BackendRepo":
         self.logger.info("opening repository at %s", path)
-        return SwiftRepo(path, self.conf)
+        return cast("BackendRepo", SwiftRepo(path, self.conf))
 
 
-def cmd_daemon(args) -> None:
+def cmd_daemon(args: list) -> None:
     """Entry point for starting a TCP git server."""
     import optparse
 
@@ -993,7 +1079,7 @@ def cmd_daemon(args) -> None:
     server.serve_forever()
 
 
-def cmd_init(args) -> None:
+def cmd_init(args: list) -> None:
     import optparse
 
     parser = optparse.OptionParser()
@@ -1014,7 +1100,7 @@ def cmd_init(args) -> None:
     SwiftRepo.init_bare(scon, conf)
 
 
-def main(argv=sys.argv) -> None:
+def main(argv: list = sys.argv) -> None:
     commands = {
         "init": cmd_init,
         "daemon": cmd_daemon,