2
0
Эх сурвалжийг харах

Import upstream version 0.21.0

Jelmer Vernooij 2 жил өмнө
parent
commit
e40d52fb20
98 өөрчлөгдсөн 2055 нэмэгдсэн , 1481 устгасан
  1. 5 0
      .github/CODEOWNERS
  2. 2 0
      .github/workflows/docs.yml
  3. 6 14
      .github/workflows/pythontest.yml
  4. 15 0
      NEWS
  5. 3 4
      PKG-INFO
  6. 3 4
      docs/conf.py
  7. 1 0
      docs/tutorial/remote.txt
  8. 3 4
      dulwich.egg-info/PKG-INFO
  9. 2 0
      dulwich.egg-info/SOURCES.txt
  10. 3 0
      dulwich.egg-info/requires.txt
  11. 1 1
      dulwich/__init__.py
  12. 2 3
      dulwich/archive.py
  13. 13 7
      dulwich/bundle.py
  14. 45 17
      dulwich/cli.py
  15. 137 80
      dulwich/client.py
  16. 3 3
      dulwich/cloud/gcs.py
  17. 9 9
      dulwich/config.py
  18. 7 4
      dulwich/contrib/diffstat.py
  19. 2 2
      dulwich/contrib/paramiko_vendor.py
  20. 2 2
      dulwich/contrib/release_robot.py
  21. 1 1
      dulwich/contrib/requests_vendor.py
  22. 7 30
      dulwich/contrib/swift.py
  23. 2 2
      dulwich/contrib/test_paramiko_vendor.py
  24. 7 7
      dulwich/contrib/test_swift.py
  25. 9 8
      dulwich/diff_tree.py
  26. 9 7
      dulwich/errors.py
  27. 5 2
      dulwich/fastexport.py
  28. 19 19
      dulwich/file.py
  29. 2 2
      dulwich/graph.py
  30. 8 42
      dulwich/greenthreads.py
  31. 1 1
      dulwich/hooks.py
  32. 11 11
      dulwich/ignore.py
  33. 18 20
      dulwich/index.py
  34. 1 1
      dulwich/lfs.py
  35. 3 2
      dulwich/line_ending.py
  36. 59 35
      dulwich/lru_cache.py
  37. 1 1
      dulwich/mailmap.py
  38. 320 321
      dulwich/object_store.py
  39. 40 26
      dulwich/objects.py
  40. 294 176
      dulwich/pack.py
  41. 14 12
      dulwich/patch.py
  42. 34 14
      dulwich/porcelain.py
  43. 6 7
      dulwich/protocol.py
  44. 71 13
      dulwich/refs.py
  45. 57 39
      dulwich/repo.py
  46. 80 55
      dulwich/server.py
  47. 1 2
      dulwich/stash.py
  48. 2 1
      dulwich/submodule.py
  49. 16 10
      dulwich/tests/__init__.py
  50. 4 4
      dulwich/tests/compat/server_utils.py
  51. 10 12
      dulwich/tests/compat/test_client.py
  52. 6 6
      dulwich/tests/compat/test_pack.py
  53. 1 1
      dulwich/tests/compat/test_patch.py
  54. 3 3
      dulwich/tests/compat/test_porcelain.py
  55. 4 4
      dulwich/tests/compat/test_repository.py
  56. 1 1
      dulwich/tests/compat/test_server.py
  57. 2 2
      dulwich/tests/compat/test_utils.py
  58. 3 3
      dulwich/tests/compat/test_web.py
  59. 5 5
      dulwich/tests/compat/utils.py
  60. 1 1
      dulwich/tests/test_archive.py
  61. 2 2
      dulwich/tests/test_blackbox.py
  62. 9 0
      dulwich/tests/test_bundle.py
  63. 189 47
      dulwich/tests/test_client.py
  64. 2 10
      dulwich/tests/test_config.py
  65. 3 3
      dulwich/tests/test_diff_tree.py
  66. 2 2
      dulwich/tests/test_fastexport.py
  67. 6 6
      dulwich/tests/test_file.py
  68. 4 4
      dulwich/tests/test_grafts.py
  69. 7 8
      dulwich/tests/test_graph.py
  70. 1 39
      dulwich/tests/test_greenthreads.py
  71. 1 1
      dulwich/tests/test_hooks.py
  72. 2 2
      dulwich/tests/test_ignore.py
  73. 12 13
      dulwich/tests/test_index.py
  74. 1 1
      dulwich/tests/test_lfs.py
  75. 0 1
      dulwich/tests/test_line_ending.py
  76. 9 9
      dulwich/tests/test_missing_obj_finder.py
  77. 14 11
      dulwich/tests/test_object_store.py
  78. 5 5
      dulwich/tests/test_objects.py
  79. 81 63
      dulwich/tests/test_pack.py
  80. 114 49
      dulwich/tests/test_porcelain.py
  81. 1 1
      dulwich/tests/test_protocol.py
  82. 0 1
      dulwich/tests/test_reflog.py
  83. 62 3
      dulwich/tests/test_refs.py
  84. 12 24
      dulwich/tests/test_repository.py
  85. 36 28
      dulwich/tests/test_server.py
  86. 1 1
      dulwich/tests/test_utils.py
  87. 4 4
      dulwich/tests/test_walk.py
  88. 23 25
      dulwich/tests/test_web.py
  89. 18 16
      dulwich/walk.py
  90. 12 13
      dulwich/web.py
  91. 1 1
      examples/clone.py
  92. 2 2
      examples/latest_change.py
  93. 4 4
      examples/memoryrepo.py
  94. 1 1
      examples/rename-branch.py
  95. 3 0
      pyproject.toml
  96. 2 2
      setup.cfg
  97. 6 2
      setup.py
  98. 1 14
      tox.ini

+ 5 - 0
.github/CODEOWNERS

@@ -0,0 +1,5 @@
+* @jelmer
+
+# Release robot
+dulwich/contrib/release_robot.py @mikofski
+dulwich/contrib/test_release_robot.py @mikofski

+ 2 - 0
.github/workflows/docs.yml

@@ -15,6 +15,8 @@ jobs:
       - uses: actions/checkout@v2
       - name: Set up Python
         uses: actions/setup-python@v2
+        with:
+          python-version: "3.10"
       - name: Install pydoctor
         run: |
           sudo apt-get update && sudo apt -y install -y pydoctor python3-pip

+ 6 - 14
.github/workflows/pythontest.yml

@@ -13,14 +13,7 @@ jobs:
       matrix:
         os: [ubuntu-latest, macos-latest, windows-latest]
         python-version:
-          ["3.6", "3.7", "3.8", "3.9", "3.10", "3.11.0-rc - 3.11", pypy3]
-        exclude:
-          # sqlite3 exit handling seems to get in the way
-          - os: macos-latest
-            python-version: pypy3
-          # doesn't support passing in bytestrings to os.scandir
-          - os: windows-latest
-            python-version: pypy3
+          ["3.7", "3.8", "3.9", "3.10", "3.11"]
       fail-fast: false
 
     steps:
@@ -38,19 +31,17 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          pip install -U pip coverage flake8 fastimport paramiko urllib3
+          pip install -U ".[fastimport,paramiko,https]"
       - name: Install gpg on supported platforms
-        run: pip install -U gpg
+        run: pip install -U ".[pgp]"
         if: "matrix.os != 'windows-latest' && matrix.python-version != 'pypy3'"
-      - name: Install mypy
-        run: |
-          pip install -U mypy types-paramiko types-requests
-        if: "matrix.python-version != 'pypy3'"
       - name: Style checks
         run: |
+          pip install -U flake8
           python -m flake8
       - name: Typing checks
         run: |
+          pip install -U mypy types-paramiko types-requests
           python -m mypy dulwich
         if: "matrix.python-version != 'pypy3'"
       - name: Build
@@ -58,4 +49,5 @@ jobs:
           python setup.py build_ext -i
       - name: Coverage test suite run
         run: |
+          pip install -U coverage
           python -m coverage run -p -m unittest dulwich.tests.test_suite

+ 15 - 0
NEWS

@@ -1,5 +1,20 @@
+0.21.0	2023-01-16
+
+ * Pack internals have been significantly refactored, including
+   significant low-level API changes.
+
+   As a consequence of this, Dulwich now reuses pack deltas
+   when communicating with remote servers, which brings a
+   big boost to network performance.
+   (Jelmer Vernooij)
+
 0.20.50	2022-10-30
 
+ * Add --deltify option to ``dulwich pack-objects`` which enables
+   deltification, and add initial support for reusing suitable
+   deltas found in an existing pack file.
+   (Stefan Sperling)
+
  * Fix Repo.reset_index.
    Previously, it instead took the union with the given tree.
    (Christian Sattler, #1072)

+ 3 - 4
PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: dulwich
-Version: 0.20.50
+Version: 0.21.0
 Summary: Python Git Library
 Home-page: https://www.dulwich.io/
 Author: Jelmer Vernooij
@@ -12,7 +12,6 @@ Project-URL: Bug Tracker, https://github.com/dulwich/dulwich/issues
 Keywords: vcs,git
 Classifier: Development Status :: 4 - Beta
 Classifier: License :: OSI Approved :: Apache Software License
-Classifier: Programming Language :: Python :: 3.6
 Classifier: Programming Language :: Python :: 3.7
 Classifier: Programming Language :: Python :: 3.8
 Classifier: Programming Language :: Python :: 3.9
@@ -23,11 +22,11 @@ Classifier: Programming Language :: Python :: Implementation :: PyPy
 Classifier: Operating System :: POSIX
 Classifier: Operating System :: Microsoft :: Windows
 Classifier: Topic :: Software Development :: Version Control
-Requires-Python: >=3.6
+Requires-Python: >=3.7
 Provides-Extra: fastimport
 Provides-Extra: https
-Provides-Extra: pgp
 Provides-Extra: paramiko
+Provides-Extra: pgp
 License-File: COPYING
 
 Dulwich

+ 3 - 4
docs/conf.py

@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 #
 # dulwich documentation build configuration file, created by
 # sphinx-quickstart on Thu Feb 18 23:18:28 2010.
@@ -48,8 +47,8 @@ source_suffix = '.txt'
 master_doc = 'index'
 
 # General information about the project.
-project = u'dulwich'
-copyright = u'2011-2019 Jelmer Vernooij'
+project = 'dulwich'
+copyright = '2011-2023 Jelmer Vernooij'
 
 # The version info for the project you're documenting, acts as replacement for
 # |version| and |release|, also used in various other places throughout the
@@ -186,7 +185,7 @@ htmlhelp_basename = 'dulwichdoc'
 # (source start file, target name, title, author, documentclass
 # [howto/manual]).
 latex_documents = [
-    ('index', 'dulwich.tex', u'dulwich Documentation',
+    ('index', 'dulwich.tex', 'dulwich Documentation',
      'Jelmer Vernooij', 'manual'),
 ]
 

+ 1 - 0
docs/tutorial/remote.txt

@@ -55,6 +55,7 @@ which claims that the client doesn't have any objects::
 
    >>> class DummyGraphWalker(object):
    ...     def ack(self, sha): pass
+   ...     def nak(self): pass
    ...     def next(self): pass
    ...     def __next__(self): pass
 

+ 3 - 4
dulwich.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: dulwich
-Version: 0.20.50
+Version: 0.21.0
 Summary: Python Git Library
 Home-page: https://www.dulwich.io/
 Author: Jelmer Vernooij
@@ -12,7 +12,6 @@ Project-URL: Bug Tracker, https://github.com/dulwich/dulwich/issues
 Keywords: vcs,git
 Classifier: Development Status :: 4 - Beta
 Classifier: License :: OSI Approved :: Apache Software License
-Classifier: Programming Language :: Python :: 3.6
 Classifier: Programming Language :: Python :: 3.7
 Classifier: Programming Language :: Python :: 3.8
 Classifier: Programming Language :: Python :: 3.9
@@ -23,11 +22,11 @@ Classifier: Programming Language :: Python :: Implementation :: PyPy
 Classifier: Operating System :: POSIX
 Classifier: Operating System :: Microsoft :: Windows
 Classifier: Topic :: Software Development :: Version Control
-Requires-Python: >=3.6
+Requires-Python: >=3.7
 Provides-Extra: fastimport
 Provides-Extra: https
-Provides-Extra: pgp
 Provides-Extra: paramiko
+Provides-Extra: pgp
 License-File: COPYING
 
 Dulwich

+ 2 - 0
dulwich.egg-info/SOURCES.txt

@@ -18,11 +18,13 @@ SECURITY.md
 TODO
 disperse.conf
 dulwich.cfg
+pyproject.toml
 requirements.txt
 setup.cfg
 setup.py
 status.yaml
 tox.ini
+.github/CODEOWNERS
 .github/FUNDING.yml
 .github/workflows/disperse.yml
 .github/workflows/docs.yml

+ 3 - 0
dulwich.egg-info/requires.txt

@@ -1,5 +1,8 @@
 urllib3>=1.25
 
+[:python_version <= "3.7"]
+typing_extensions
+
 [fastimport]
 fastimport
 

+ 1 - 1
dulwich/__init__.py

@@ -22,4 +22,4 @@
 
 """Python implementation of the Git file formats and protocols."""
 
-__version__ = (0, 20, 50)
+__version__ = (0, 21, 0)

+ 2 - 3
dulwich/archive.py

@@ -32,7 +32,7 @@ from io import BytesIO
 from contextlib import closing
 
 
-class ChunkedBytesIO(object):
+class ChunkedBytesIO:
     """Turn a list of bytestrings into a file-like object.
 
     This is similar to creating a `BytesIO` from a concatenation of the
@@ -129,7 +129,6 @@ def _walk_tree(store, tree, root=b""):
     for entry in tree.iteritems():
         entry_abspath = posixpath.join(root, entry.path)
         if stat.S_ISDIR(entry.mode):
-            for _ in _walk_tree(store, store[entry.sha], entry_abspath):
-                yield _
+            yield from _walk_tree(store, store[entry.sha], entry_abspath)
         else:
             yield (entry_abspath, entry)

+ 13 - 7
dulwich/bundle.py

@@ -25,14 +25,20 @@ from typing import Dict, List, Tuple, Optional, Union, Sequence
 from .pack import PackData, write_pack_data
 
 
-class Bundle(object):
+class Bundle:
 
-    version = None  # type: Optional[int]
+    version: Optional[int] = None
 
-    capabilities = {}  # type: Dict[str, str]
-    prerequisites = []  # type: List[Tuple[bytes, str]]
-    references = {}  # type: Dict[str, bytes]
-    pack_data = []  # type: Union[PackData, Sequence[bytes]]
+    capabilities: Dict[str, str] = {}
+    prerequisites: List[Tuple[bytes, str]] = []
+    references: Dict[str, bytes] = {}
+    pack_data: Union[PackData, Sequence[bytes]] = []
+
+    def __repr__(self):
+        return (f"<{type(self).__name__}(version={self.version}, "
+                f"capabilities={self.capabilities}, "
+                f"prerequisites={self.prerequisites}, "
+                f"references={self.references})>")
 
     def __eq__(self, other):
         if not isinstance(other, type(self)):
@@ -119,4 +125,4 @@ def write_bundle(f, bundle):
     for ref, obj_id in bundle.references.items():
         f.write(b"%s %s\n" % (obj_id, ref))
     f.write(b"\n")
-    write_pack_data(f.write, records=bundle.pack_data)
+    write_pack_data(f.write, num_records=len(bundle.pack_data), records=bundle.pack_data.iter_unpacked())

+ 45 - 17
dulwich/cli.py

@@ -37,7 +37,7 @@ import signal
 from typing import Dict, Type, Optional
 
 from dulwich import porcelain
-from dulwich.client import get_transport_and_path
+from dulwich.client import get_transport_and_path, GitProtocolError
 from dulwich.errors import ApplyDeltaError
 from dulwich.index import Index
 from dulwich.objectspec import parse_commit
@@ -55,7 +55,7 @@ def signal_quit(signal, frame):
     pdb.set_trace()
 
 
-class Command(object):
+class Command:
     """A Dulwich subcommand."""
 
     def run(self, args):
@@ -139,7 +139,7 @@ class cmd_fsck(Command):
         opts, args = getopt(args, "", [])
         opts = dict(opts)
         for (obj, msg) in porcelain.fsck("."):
-            print("%s: %s" % (obj, msg))
+            print("{}: {}".format(obj, msg))
 
 
 class cmd_log(Command):
@@ -202,9 +202,9 @@ class cmd_dump_pack(Command):
             try:
                 print("\t%s" % x[name])
             except KeyError as k:
-                print("\t%s: Unable to resolve base %s" % (name, k))
+                print("\t{}: Unable to resolve base {}".format(name, k))
             except ApplyDeltaError as e:
-                print("\t%s: Unable to apply delta: %r" % (name, e))
+                print("\t{}: Unable to apply delta: {!r}".format(name, e))
 
 
 class cmd_dump_index(Command):
@@ -263,8 +263,11 @@ class cmd_clone(Command):
         else:
             target = None
 
-        porcelain.clone(source, target, bare=options.bare, depth=options.depth,
-                        branch=options.branch)
+        try:
+            porcelain.clone(source, target, bare=options.bare, depth=options.depth,
+                            branch=options.branch)
+        except GitProtocolError as e:
+            print("%s" % e)
 
 
 class cmd_commit(Command):
@@ -300,6 +303,18 @@ class cmd_symbolic_ref(Command):
         porcelain.symbolic_ref(".", ref_name=ref_name, force="--force" in args)
 
 
+class cmd_pack_refs(Command):
+    def run(self, argv):
+        parser = argparse.ArgumentParser()
+        parser.add_argument('--all', action='store_true')
+        # ignored, we never prune
+        parser.add_argument('--no-prune', action='store_true')
+
+        args = parser.parse_args(argv)
+
+        porcelain.pack_refs(".", all=args.all)
+
+
 class cmd_show(Command):
     def run(self, argv):
         parser = argparse.ArgumentParser()
@@ -471,7 +486,7 @@ class cmd_status(Command):
             for kind, names in status.staged.items():
                 for name in names:
                     sys.stdout.write(
-                        "\t%s: %s\n" % (kind, name.decode(sys.getfilesystemencoding()))
+                        "\t{}: {}\n".format(kind, name.decode(sys.getfilesystemencoding()))
                     )
             sys.stdout.write("\n")
         if status.unstaged:
@@ -494,7 +509,7 @@ class cmd_ls_remote(Command):
             sys.exit(1)
         refs = porcelain.ls_remote(args[0])
         for ref in sorted(refs):
-            sys.stdout.write("%s\t%s\n" % (ref, refs[ref]))
+            sys.stdout.write("{}\t{}\n".format(ref, refs[ref]))
 
 
 class cmd_ls_tree(Command):
@@ -523,22 +538,34 @@ class cmd_ls_tree(Command):
 
 class cmd_pack_objects(Command):
     def run(self, args):
-        opts, args = getopt(args, "", ["stdout"])
+        deltify = False
+        reuse_deltas = True
+        opts, args = getopt(args, "", ["stdout", "deltify", "no-reuse-deltas"])
         opts = dict(opts)
-        if len(args) < 1 and "--stdout" not in args:
+        if len(args) < 1 and "--stdout" not in opts.keys():
             print("Usage: dulwich pack-objects basename")
             sys.exit(1)
         object_ids = [line.strip() for line in sys.stdin.readlines()]
-        basename = args[0]
-        if "--stdout" in opts:
+        if "--deltify" in opts.keys():
+            deltify = True
+        if "--no-reuse-deltas" in opts.keys():
+            reuse_deltas = False
+        if "--stdout" in opts.keys():
             packf = getattr(sys.stdout, "buffer", sys.stdout)
             idxf = None
             close = []
         else:
-            packf = open(basename + ".pack", "w")
-            idxf = open(basename + ".idx", "w")
+            basename = args[0]
+            packf = open(basename + ".pack", "wb")
+            idxf = open(basename + ".idx", "wb")
             close = [packf, idxf]
-        porcelain.pack_objects(".", object_ids, packf, idxf)
+        porcelain.pack_objects(
+            ".",
+            object_ids,
+            packf,
+            idxf,
+            deltify=deltify,
+            reuse_deltas=reuse_deltas)
         for f in close:
             f.close()
 
@@ -606,7 +633,7 @@ class cmd_submodule_list(Command):
         parser = argparse.ArgumentParser()
         parser.parse_args(argv)
         for path, sha in porcelain.submodule_list("."):
-            sys.stdout.write(' %s %s\n' % (sha, path))
+            sys.stdout.write(' {} {}\n'.format(sha, path))
 
 
 class cmd_submodule_init(Command):
@@ -744,6 +771,7 @@ commands = {
     "ls-remote": cmd_ls_remote,
     "ls-tree": cmd_ls_tree,
     "pack-objects": cmd_pack_objects,
+    "pack-refs": cmd_pack_refs,
     "pull": cmd_pull,
     "push": cmd_push,
     "receive-pack": cmd_receive_pack,

+ 137 - 80
dulwich/client.py

@@ -51,6 +51,8 @@ from typing import (
     Callable,
     Dict,
     List,
+    Iterable,
+    Iterator,
     Optional,
     Set,
     Tuple,
@@ -71,7 +73,6 @@ from urllib.parse import (
 if TYPE_CHECKING:
     import urllib3
 
-
 import dulwich
 from dulwich.config import get_xdg_config_home_path, Config, apply_instead_of
 from dulwich.errors import (
@@ -118,7 +119,8 @@ from dulwich.protocol import (
     pkt_line,
 )
 from dulwich.pack import (
-    write_pack_objects,
+    write_pack_from_container,
+    UnpackedObject,
     PackChunkGenerator,
 )
 from dulwich.refs import (
@@ -195,7 +197,7 @@ RECEIVE_CAPABILITIES = [
 ] + COMMON_CAPABILITIES
 
 
-class ReportStatusParser(object):
+class ReportStatusParser:
     """Handle status as reported by servers with 'report-status' capability."""
 
     def __init__(self):
@@ -259,13 +261,13 @@ def read_pkt_refs(pkt_seq):
         refs[ref] = sha
 
     if len(refs) == 0:
-        return {}, set([])
+        return {}, set()
     if refs == {CAPABILITIES_REF: ZERO_SHA}:
         refs = {}
     return refs, set(server_capabilities)
 
 
-class FetchPackResult(object):
+class FetchPackResult:
     """Result of a fetch-pack operation.
 
     Attributes:
@@ -337,10 +339,10 @@ class FetchPackResult(object):
         if name in type(self)._FORWARDED_ATTRS:
             self._warn_deprecated()
             return getattr(self.refs, name)
-        return super(FetchPackResult, self).__getattribute__(name)
+        return super().__getattribute__(name)
 
     def __repr__(self):
-        return "%s(%r, %r, %r)" % (
+        return "{}({!r}, {!r}, {!r})".format(
             self.__class__.__name__,
             self.refs,
             self.symrefs,
@@ -348,7 +350,7 @@ class FetchPackResult(object):
         )
 
 
-class SendPackResult(object):
+class SendPackResult:
     """Result of a upload-pack operation.
 
     Attributes:
@@ -415,10 +417,10 @@ class SendPackResult(object):
         if name in type(self)._FORWARDED_ATTRS:
             self._warn_deprecated()
             return getattr(self.refs, name)
-        return super(SendPackResult, self).__getattribute__(name)
+        return super().__getattribute__(name)
 
     def __repr__(self):
-        return "%s(%r, %r)" % (self.__class__.__name__, self.refs, self.agent)
+        return "{}({!r}, {!r})".format(self.__class__.__name__, self.refs, self.agent)
 
 
 def _read_shallow_updates(pkt_seq):
@@ -435,7 +437,7 @@ def _read_shallow_updates(pkt_seq):
     return (new_shallow, new_unshallow)
 
 
-class _v1ReceivePackHeader(object):
+class _v1ReceivePackHeader:
 
     def __init__(self, capabilities, old_refs, new_refs):
         self.want = []
@@ -465,12 +467,12 @@ class _v1ReceivePackHeader(object):
             old_sha1 = old_refs.get(refname, ZERO_SHA)
             if not isinstance(old_sha1, bytes):
                 raise TypeError(
-                    "old sha1 for %s is not a bytestring: %r" % (refname, old_sha1)
+                    "old sha1 for {!r} is not a bytestring: {!r}".format(refname, old_sha1)
                 )
             new_sha1 = new_refs.get(refname, ZERO_SHA)
             if not isinstance(new_sha1, bytes):
                 raise TypeError(
-                    "old sha1 for %s is not a bytestring %r" % (refname, new_sha1)
+                    "old sha1 for {!r} is not a bytestring {!r}".format(refname, new_sha1)
                 )
 
             if old_sha1 != new_sha1:
@@ -495,27 +497,17 @@ class _v1ReceivePackHeader(object):
         yield None
 
 
-def _read_side_band64k_data(pkt_seq, channel_callbacks):
+def _read_side_band64k_data(pkt_seq: Iterable[bytes]) -> Iterator[Tuple[int, bytes]]:
     """Read per-channel data.
 
     This requires the side-band-64k capability.
 
     Args:
       pkt_seq: Sequence of packets to read
-      channel_callbacks: Dictionary mapping channels to packet
-        handlers to use. None for a callback discards channel data.
     """
     for pkt in pkt_seq:
         channel = ord(pkt[:1])
-        pkt = pkt[1:]
-        try:
-            cb = channel_callbacks[channel]
-        except KeyError as exc:
-            raise AssertionError(
-                "Invalid sideband channel %d" % channel) from exc
-        else:
-            if cb is not None:
-                cb(pkt)
+        yield channel, pkt[1:]
 
 
 def _handle_upload_pack_head(
@@ -588,9 +580,9 @@ def _handle_upload_pack_head(
 
 def _handle_upload_pack_tail(
     proto,
-    capabilities,
+    capabilities: Set[bytes],
     graph_walker,
-    pack_data,
+    pack_data: Callable[[bytes], None],
     progress=None,
     rbufsize=_RBUFSIZE,
 ):
@@ -612,6 +604,8 @@ def _handle_upload_pack_tail(
         parts = pkt.rstrip(b"\n").split(b" ")
         if parts[0] == b"ACK":
             graph_walker.ack(parts[1])
+        if parts[0] == b"NAK":
+            graph_walker.nak()
         if len(parts) < 3 or parts[2] not in (
             b"ready",
             b"continue",
@@ -626,13 +620,14 @@ def _handle_upload_pack_tail(
             def progress(x):
                 pass
 
-        _read_side_band64k_data(
-            proto.read_pkt_seq(),
-            {
-                SIDE_BAND_CHANNEL_DATA: pack_data,
-                SIDE_BAND_CHANNEL_PROGRESS: progress,
-            },
-        )
+        for chan, data in _read_side_band64k_data(proto.read_pkt_seq()):
+            if chan == SIDE_BAND_CHANNEL_DATA:
+                pack_data(data)
+            elif chan == SIDE_BAND_CHANNEL_PROGRESS:
+                progress(data)
+            else:
+                raise AssertionError(
+                    "Invalid sideband channel %d" % chan)
     else:
         while True:
             data = proto.read(rbufsize)
@@ -644,7 +639,7 @@ def _handle_upload_pack_tail(
 # TODO(durin42): this doesn't correctly degrade if the server doesn't
 # support some capabilities. This should work properly with servers
 # that don't support multi_ack.
-class GitClient(object):
+class GitClient:
     """Git smart server client."""
 
     def __init__(
@@ -700,7 +695,7 @@ class GitClient(object):
         """
         raise NotImplementedError(cls.from_parsedurl)
 
-    def send_pack(self, path, update_refs, generate_pack_data, progress=None):
+    def send_pack(self, path, update_refs, generate_pack_data: Callable[[Set[bytes], Set[bytes], bool], Tuple[int, Iterator[UnpackedObject]]], progress=None):
         """Upload a pack to a remote repository.
 
         Args:
@@ -817,15 +812,13 @@ class GitClient(object):
         if determine_wants is None:
             determine_wants = target.object_store.determine_wants_all
         if CAPABILITY_THIN_PACK in self._fetch_capabilities:
-            # TODO(jelmer): Avoid reading entire file into memory and
-            # only processing it after the whole file has been fetched.
             from tempfile import SpooledTemporaryFile
-            f = SpooledTemporaryFile()  # type: IO[bytes]
+            f: IO[bytes] = SpooledTemporaryFile()
 
             def commit():
                 if f.tell():
                     f.seek(0)
-                    target.object_store.add_thin_pack(f.read, None)
+                    target.object_store.add_thin_pack(f.read, None, progress=progress)
                 f.close()
 
             def abort():
@@ -856,6 +849,7 @@ class GitClient(object):
         determine_wants,
         graph_walker,
         pack_data,
+        *,
         progress=None,
         depth=None,
     ):
@@ -910,7 +904,7 @@ class GitClient(object):
         self,
         proto: Protocol,
         capabilities: Set[bytes],
-        progress: Callable[[bytes], None] = None,
+        progress: Optional[Callable[[bytes], None]] = None,
     ) -> Optional[Dict[bytes, Optional[str]]]:
         """Handle the tail of a 'git-receive-pack' request.
 
@@ -930,12 +924,17 @@ class GitClient(object):
                 def progress(x):
                     pass
 
-            channel_callbacks = {2: progress}
             if CAPABILITY_REPORT_STATUS in capabilities:
-                channel_callbacks[1] = PktLineParser(
-                    self._report_status_parser.handle_packet
-                ).parse
-            _read_side_band64k_data(proto.read_pkt_seq(), channel_callbacks)
+                pktline_parser = PktLineParser(self._report_status_parser.handle_packet)
+            for chan, data in _read_side_band64k_data(proto.read_pkt_seq()):
+                if chan == SIDE_BAND_CHANNEL_DATA:
+                    if CAPABILITY_REPORT_STATUS in capabilities:
+                        pktline_parser.parse(data)
+                elif chan == SIDE_BAND_CHANNEL_PROGRESS:
+                    progress(data)
+                else:
+                    raise AssertionError(
+                        "Invalid sideband channel %d" % chan)
         else:
             if CAPABILITY_REPORT_STATUS in capabilities:
                 for pkt in proto.read_pkt_seq():
@@ -1013,7 +1012,7 @@ class TraditionalGitClient(GitClient):
 
     def __init__(self, path_encoding=DEFAULT_ENCODING, **kwargs):
         self._remote_path_encoding = path_encoding
-        super(TraditionalGitClient, self).__init__(**kwargs)
+        super().__init__(**kwargs)
 
     async def _connect(self, cmd, path):
         """Create a connection to the server.
@@ -1107,10 +1106,11 @@ class TraditionalGitClient(GitClient):
                 header_handler.have,
                 header_handler.want,
                 ofs_delta=(CAPABILITY_OFS_DELTA in negotiated_capabilities),
+                progress=progress,
             )
 
             if self._should_send_pack(new_refs):
-                for chunk in PackChunkGenerator(pack_data_count, pack_data):
+                for chunk in PackChunkGenerator(pack_data_count, pack_data, progress=progress):
                     proto.write(chunk)
 
             ref_status = self._handle_receive_pack_tail(
@@ -1238,14 +1238,15 @@ class TraditionalGitClient(GitClient):
             ret = proto.read_pkt_line()
             if ret is not None:
                 raise AssertionError("expected pkt tail")
-            _read_side_band64k_data(
-                proto.read_pkt_seq(),
-                {
-                    SIDE_BAND_CHANNEL_DATA: write_data,
-                    SIDE_BAND_CHANNEL_PROGRESS: progress,
-                    SIDE_BAND_CHANNEL_FATAL: write_error,
-                },
-            )
+            for chan, data in _read_side_band64k_data(proto.read_pkt_seq()):
+                if chan == SIDE_BAND_CHANNEL_DATA:
+                    write_data(data)
+                elif chan == SIDE_BAND_CHANNEL_PROGRESS:
+                    progress(data)
+                elif chan == SIDE_BAND_CHANNEL_FATAL:
+                    write_error(data)
+                else:
+                    raise AssertionError("Invalid sideband channel %d" % chan)
 
 
 class TCPGitClient(TraditionalGitClient):
@@ -1256,7 +1257,7 @@ class TCPGitClient(TraditionalGitClient):
             port = TCP_GIT_PORT
         self._host = host
         self._port = port
-        super(TCPGitClient, self).__init__(**kwargs)
+        super().__init__(**kwargs)
 
     @classmethod
     def from_parsedurl(cls, parsedurl, **kwargs):
@@ -1284,7 +1285,7 @@ class TCPGitClient(TraditionalGitClient):
             try:
                 s.connect(sockaddr)
                 break
-            except socket.error as e:
+            except OSError as e:
                 err = e
                 if s is not None:
                     s.close()
@@ -1314,7 +1315,7 @@ class TCPGitClient(TraditionalGitClient):
         return proto, lambda: _fileno_can_read(s), None
 
 
-class SubprocessWrapper(object):
+class SubprocessWrapper:
     """A socket-like object that talks to a subprocess via pipes."""
 
     def __init__(self, proc):
@@ -1473,7 +1474,7 @@ class LocalGitClient(GitClient):
                 old_sha1 = old_refs.get(refname, ZERO_SHA)
                 if new_sha1 != ZERO_SHA:
                     if not target.refs.set_if_equals(refname, old_sha1, new_sha1):
-                        msg = "unable to set %s to %s" % (refname, new_sha1)
+                        msg = "unable to set {} to {}".format(refname, new_sha1)
                         progress(msg)
                         ref_status[refname] = msg
                 else:
@@ -1516,7 +1517,7 @@ class LocalGitClient(GitClient):
         pack_data,
         progress=None,
         depth=None,
-    ):
+    ) -> FetchPackResult:
         """Retrieve a pack from a git smart server.
 
         Args:
@@ -1534,17 +1535,19 @@ class LocalGitClient(GitClient):
 
         """
         with self._open_repo(path) as r:
-            objects_iter = r.fetch_objects(
+            missing_objects = r.find_missing_objects(
                 determine_wants, graph_walker, progress=progress, depth=depth
             )
+            other_haves = missing_objects.get_remote_has()
+            object_ids = list(missing_objects)
             symrefs = r.refs.get_symrefs()
             agent = agent_string()
 
             # Did the process short-circuit (e.g. in a stateless RPC call)?
             # Note that the client still expects a 0-object pack in most cases.
-            if objects_iter is None:
+            if object_ids is None:
                 return FetchPackResult(None, symrefs, agent)
-            write_pack_objects(pack_data, objects_iter)
+            write_pack_from_container(pack_data, r.object_store, object_ids, other_haves=other_haves)
             return FetchPackResult(r.get_refs(), symrefs, agent)
 
     def get_refs(self, path):
@@ -1558,7 +1561,7 @@ class LocalGitClient(GitClient):
 default_local_git_client_cls = LocalGitClient
 
 
-class SSHVendor(object):
+class SSHVendor:
     """A client side SSH implementation."""
 
     def run_command(
@@ -1595,7 +1598,7 @@ class StrangeHostname(Exception):
     """Refusing to connect to strange SSH hostname."""
 
     def __init__(self, hostname):
-        super(StrangeHostname, self).__init__(hostname)
+        super().__init__(hostname)
 
 
 class SubprocessSSHVendor(SSHVendor):
@@ -1631,7 +1634,7 @@ class SubprocessSSHVendor(SSHVendor):
             args.extend(["-i", str(key_filename)])
 
         if username:
-            host = "%s@%s" % (username, host)
+            host = "{}@{}".format(username, host)
         if host.startswith("-"):
             raise StrangeHostname(hostname=host)
         args.append(host)
@@ -1685,7 +1688,7 @@ class PLinkSSHVendor(SSHVendor):
             args.extend(["-i", str(key_filename)])
 
         if username:
-            host = "%s@%s" % (username, host)
+            host = "{}@{}".format(username, host)
         if host.startswith("-"):
             raise StrangeHostname(hostname=host)
         args.append(host)
@@ -1737,7 +1740,7 @@ class SSHGitClient(TraditionalGitClient):
         self.ssh_command = ssh_command or os.environ.get(
             "GIT_SSH_COMMAND", os.environ.get("GIT_SSH")
         )
-        super(SSHGitClient, self).__init__(**kwargs)
+        super().__init__(**kwargs)
         self.alternative_paths = {}
         if vendor is not None:
             self.ssh_vendor = vendor
@@ -1811,7 +1814,7 @@ def default_user_agent_string():
 
 
 def default_urllib3_manager(   # noqa: C901
-    config, pool_manager_cls=None, proxy_manager_cls=None, **override_kwargs
+    config, pool_manager_cls=None, proxy_manager_cls=None, base_url=None, **override_kwargs
 ) -> Union["urllib3.ProxyManager", "urllib3.PoolManager"]:
     """Return urllib3 connection pool manager.
 
@@ -1833,9 +1836,13 @@ def default_urllib3_manager(   # noqa: C901
     if proxy_server is None:
         for proxyname in ("https_proxy", "http_proxy", "all_proxy"):
             proxy_server = os.environ.get(proxyname)
-            if proxy_server is not None:
+            if proxy_server:
                 break
 
+    if proxy_server:
+        if check_for_proxy_bypass(base_url):
+            proxy_server = None
+    
     if config is not None:
         if proxy_server is None:
             try:
@@ -1892,6 +1899,54 @@ def default_urllib3_manager(   # noqa: C901
     return manager
 
 
+def check_for_proxy_bypass(base_url):
+    # Check if a proxy bypass is defined with the no_proxy environment variable
+    if base_url:  # only check if base_url is provided
+        no_proxy_str = os.environ.get("no_proxy")
+        if no_proxy_str:
+            # implementation based on curl behavior: https://curl.se/libcurl/c/CURLOPT_NOPROXY.html
+            # get hostname of provided parsed url
+            parsed_url = urlparse(base_url)
+            hostname = parsed_url.hostname
+
+            if hostname:
+                import ipaddress
+
+                # check if hostname is an ip address
+                try:
+                    hostname_ip = ipaddress.ip_address(hostname)
+                except ValueError:
+                    hostname_ip = None
+
+                no_proxy_values = no_proxy_str.split(',')
+                for no_proxy_value in no_proxy_values:
+                    no_proxy_value = no_proxy_value.strip()
+                    if no_proxy_value:
+                        no_proxy_value = no_proxy_value.lower()
+                        no_proxy_value = no_proxy_value.lstrip('.')  # ignore leading dots
+
+                        if hostname_ip:
+                            # check if no_proxy_value is a ip network
+                            try:
+                                no_proxy_value_network = ipaddress.ip_network(no_proxy_value, strict=False)
+                            except ValueError:
+                                no_proxy_value_network = None
+                            if no_proxy_value_network:
+                                # if hostname is a ip address and no_proxy_value is a ip network -> check if ip address is part of network
+                                if hostname_ip in no_proxy_value_network:
+                                    return True
+                                
+                        if no_proxy_value == '*':
+                            # '*' is special case for always bypass proxy
+                            return True
+                        if hostname == no_proxy_value:
+                            return True
+                        no_proxy_value = '.' + no_proxy_value   # add a dot to only match complete domains
+                        if hostname.endswith(no_proxy_value):
+                            return True
+    return False
+
+
 class AbstractHttpGitClient(GitClient):
     """Abstract base class for HTTP Git Clients.
 
@@ -1943,7 +1998,9 @@ class AbstractHttpGitClient(GitClient):
             base_url = resp.redirect_location[: -len(tail)]
 
         try:
-            self.dumb = not resp.content_type.startswith("application/x-git-")
+            self.dumb = (
+                resp.content_type is None
+                or not resp.content_type.startswith("application/x-git-"))
             if not self.dumb:
                 proto = Protocol(read, None)
                 # The first line should mention the service
@@ -2142,14 +2199,14 @@ class AbstractHttpGitClient(GitClient):
             kwargs["username"] = urlunquote(username)
         netloc = parsedurl.hostname
         if parsedurl.port:
-            netloc = "%s:%s" % (netloc, parsedurl.port)
+            netloc = "{}:{}".format(netloc, parsedurl.port)
         if parsedurl.username:
-            netloc = "%s@%s" % (parsedurl.username, netloc)
+            netloc = "{}@{}".format(parsedurl.username, netloc)
         parsedurl = parsedurl._replace(netloc=netloc)
         return cls(urlunparse(parsedurl), **kwargs)
 
     def __repr__(self):
-        return "%s(%r, dumb=%r)" % (
+        return "{}({!r}, dumb={!r})".format(
             type(self).__name__,
             self._base_url,
             self.dumb,
@@ -2171,7 +2228,7 @@ class Urllib3HttpGitClient(AbstractHttpGitClient):
         self._password = password
 
         if pool_manager is None:
-            self.pool_manager = default_urllib3_manager(config)
+            self.pool_manager = default_urllib3_manager(config, base_url=base_url)
         else:
             self.pool_manager = pool_manager
 
@@ -2186,7 +2243,7 @@ class Urllib3HttpGitClient(AbstractHttpGitClient):
 
         self.config = config
 
-        super(Urllib3HttpGitClient, self).__init__(
+        super().__init__(
             base_url=base_url, dumb=dumb, **kwargs)
 
     def _get_url(self, path):
@@ -2213,15 +2270,15 @@ class Urllib3HttpGitClient(AbstractHttpGitClient):
         if resp.status == 404:
             raise NotGitRepository()
         if resp.status == 401:
-            raise HTTPUnauthorized(resp.getheader("WWW-Authenticate"), url)
+            raise HTTPUnauthorized(resp.headers.get("WWW-Authenticate"), url)
         if resp.status == 407:
-            raise HTTPProxyUnauthorized(resp.getheader("Proxy-Authenticate"), url)
+            raise HTTPProxyUnauthorized(resp.headers.get("Proxy-Authenticate"), url)
         if resp.status != 200:
             raise GitProtocolError(
                 "unexpected http resp %d for %s" % (resp.status, url)
             )
 
-        resp.content_type = resp.getheader("Content-Type")
+        resp.content_type = resp.headers.get("Content-Type")
         # Check if geturl() is available (urllib3 version >= 1.23)
         try:
             resp_url = resp.geturl()

+ 3 - 3
dulwich/cloud/gcs.py

@@ -34,12 +34,12 @@ from ..pack import PackData, Pack, load_pack_index_file
 class GcsObjectStore(BucketBasedObjectStore):
 
     def __init__(self, bucket, subpath=''):
-        super(GcsObjectStore, self).__init__()
+        super().__init__()
         self.bucket = bucket
         self.subpath = subpath
 
     def __repr__(self):
-        return "%s(%r, subpath=%r)" % (
+        return "{}({!r}, subpath={!r})".format(
             type(self).__name__, self.bucket, self.subpath)
 
     def _remove_pack(self, name):
@@ -53,7 +53,7 @@ class GcsObjectStore(BucketBasedObjectStore):
             name, ext = posixpath.splitext(posixpath.basename(blob.name))
             packs.setdefault(name, set()).add(ext)
         for name, exts in packs.items():
-            if exts == set(['.pack', '.idx']):
+            if exts == {'.pack', '.idx'}:
                 yield name
 
     def _load_pack_data(self, name):

+ 9 - 9
dulwich/config.py

@@ -52,7 +52,7 @@ def lower_key(key):
         return key.lower()
 
     if isinstance(key, Iterable):
-        return type(key)(map(lower_key, key))
+        return type(key)(map(lower_key, key))  # type: ignore
 
     return key
 
@@ -145,7 +145,7 @@ Value = bytes
 ValueLike = Union[bytes, str]
 
 
-class Config(object):
+class Config:
     """A Git configuration."""
 
     def get(self, section: SectionLike, name: NameLike) -> Value:
@@ -265,7 +265,7 @@ class ConfigDict(Config, MutableMapping[Section, MutableMapping[Name, Value]]):
         self._values = CaseInsensitiveOrderedMultiDict.make(values)
 
     def __repr__(self) -> str:
-        return "%s(%r)" % (self.__class__.__name__, self._values)
+        return "{}({!r})".format(self.__class__.__name__, self._values)
 
     def __eq__(self, other: object) -> bool:
         return isinstance(other, self.__class__) and other._values == self._values
@@ -534,7 +534,7 @@ class ConfigFile(ConfigDict):
         ] = None,
         encoding: Union[str, None] = None
     ) -> None:
-        super(ConfigFile, self).__init__(values=values, encoding=encoding)
+        super().__init__(values=values, encoding=encoding)
         self.path: Optional[str] = None
 
     @classmethod  # noqa: C901
@@ -651,11 +651,11 @@ def _find_git_in_win_reg():
             "Uninstall\\Git_is1"
         )
 
-    for key in (winreg.HKEY_CURRENT_USER, winreg.HKEY_LOCAL_MACHINE):
+    for key in (winreg.HKEY_CURRENT_USER, winreg.HKEY_LOCAL_MACHINE):  # type: ignore
         try:
-            with winreg.OpenKey(key, subkey) as k:
-                val, typ = winreg.QueryValueEx(k, "InstallLocation")
-                if typ == winreg.REG_SZ:
+            with winreg.OpenKey(key, subkey) as k:  # type: ignore
+                val, typ = winreg.QueryValueEx(k, "InstallLocation")  # type: ignore
+                if typ == winreg.REG_SZ:  # type: ignore
                     yield val
         except OSError:
             pass
@@ -687,7 +687,7 @@ class StackedConfig(Config):
         self.writable = writable
 
     def __repr__(self) -> str:
-        return "<%s for %r>" % (self.__class__.__name__, self.backends)
+        return "<{} for {!r}>".format(self.__class__.__name__, self.backends)
 
     @classmethod
     def default(cls) -> "StackedConfig":

+ 7 - 4
dulwich/contrib/diffstat.py

@@ -1,5 +1,4 @@
 #!/usr/bin/env python
-# -*- coding: utf-8 -*-
 # vim:ts=4:sw=4:softtabstop=4:smarttab:expandtab
 
 # Copyright (c) 2020 Kevin B. Hendricks, Stratford Ontario Canada
@@ -35,6 +34,7 @@
 
 import sys
 import re
+from typing import Optional, Tuple, List
 
 # only needs to detect git style diffs as this is for
 # use with dulwich
@@ -55,7 +55,7 @@ _GIT_UNCHANGED_START = b" "
 # properly interface with diffstat routine
 
 
-def _parse_patch(lines):
+def _parse_patch(lines: List[bytes]) -> Tuple[List[bytes], List[bool], List[Tuple[int, int]]]:
     """Parse a git style diff or patch to generate diff stats.
 
     Args:
@@ -66,7 +66,7 @@ def _parse_patch(lines):
     nametypes = []
     counts = []
     in_patch_chunk = in_git_header = binaryfile = False
-    currentfile = None
+    currentfile: Optional[bytes] = None
     added = deleted = 0
     for line in lines:
         if line.startswith(_GIT_HEADER_START):
@@ -74,7 +74,9 @@ def _parse_patch(lines):
                 names.append(currentfile)
                 nametypes.append(binaryfile)
                 counts.append((added, deleted))
-            currentfile = _git_header_name.search(line).group(2)
+            m = _git_header_name.search(line)
+            assert m
+            currentfile = m.group(2)
             binaryfile = False
             added = deleted = 0
             in_git_header = True
@@ -85,6 +87,7 @@ def _parse_patch(lines):
         elif line.startswith(_GIT_RENAMEFROM_START) and in_git_header:
             currentfile = line[12:]
         elif line.startswith(_GIT_RENAMETO_START) and in_git_header:
+            assert currentfile
             currentfile += b" => %s" % line[10:]
         elif line.startswith(_GIT_CHUNK_START) and (in_patch_chunk or in_git_header):
             in_patch_chunk = True

+ 2 - 2
dulwich/contrib/paramiko_vendor.py

@@ -34,7 +34,7 @@ import paramiko
 import paramiko.client
 
 
-class _ParamikoWrapper(object):
+class _ParamikoWrapper:
     def __init__(self, client, channel):
         self.client = client
         self.channel = channel
@@ -70,7 +70,7 @@ class _ParamikoWrapper(object):
         self.channel.close()
 
 
-class ParamikoSSHVendor(object):
+class ParamikoSSHVendor:
     # http://docs.paramiko.org/en/2.4/api/client.html
 
     def __init__(self, **kwargs):

+ 2 - 2
dulwich/contrib/release_robot.py

@@ -78,11 +78,11 @@ def get_recent_tags(projdir=PROJDIR):
             obj = project.get_object(value)  # dulwich object from SHA-1
             # don't just check if object is "tag" b/c it could be a "commit"
             # instead check if "tags" is in the ref-name
-            if u"tags" not in key:
+            if "tags" not in key:
                 # skip ref if not a tag
                 continue
             # strip the leading text from refs to get "tag name"
-            _, tag = key.rsplit(u"/", 1)
+            _, tag = key.rsplit("/", 1)
             # check if tag object is "commit" or "tag" pointing to a "commit"
             try:
                 commit = obj.object  # a tuple (commit class, commit id)

+ 1 - 1
dulwich/contrib/requests_vendor.py

@@ -55,7 +55,7 @@ class RequestsHttpGitClient(AbstractHttpGitClient):
         if username is not None:
             self.session.auth = (username, password)
 
-        super(RequestsHttpGitClient, self).__init__(
+        super().__init__(
             base_url=base_url, dumb=dumb, **kwargs)
 
     def _http_request(self, url, headers=None, data=None, allow_compression=False):

+ 7 - 30
dulwich/contrib/swift.py

@@ -40,7 +40,6 @@ from geventhttpclient import HTTPClient
 
 from dulwich.greenthreads import (
     GreenThreadsMissingObjectFinder,
-    GreenThreadsObjectStoreIterator,
 )
 
 from dulwich.lru_cache import LRUSizeCache
@@ -119,15 +118,6 @@ cache_length = 20
 """
 
 
-class PackInfoObjectStoreIterator(GreenThreadsObjectStoreIterator):
-    def __len__(self):
-        while self.finder.objects_to_send:
-            for _ in range(0, len(self.finder.objects_to_send)):
-                sha = self.finder.next()
-                self._shas.append(sha)
-        return len(self._shas)
-
-
 class PackInfoMissingObjectFinder(GreenThreadsMissingObjectFinder):
     def next(self):
         while True:
@@ -234,7 +224,7 @@ class SwiftException(Exception):
     pass
 
 
-class SwiftConnector(object):
+class SwiftConnector:
     """A Connector to swift that manage authentication and errors catching"""
 
     def __init__(self, root, conf):
@@ -501,7 +491,7 @@ class SwiftConnector(object):
             )
 
 
-class SwiftPackReader(object):
+class SwiftPackReader:
     """A SwiftPackReader that mimic read and sync method
 
     The reader allows to read a specified amount of bytes from
@@ -532,7 +522,7 @@ class SwiftPackReader(object):
             self.buff_length = self.buff_length * 2
         offset = self.base_offset
         r = min(self.base_offset + self.buff_length, self.pack_length)
-        ret = self.scon.get_object(self.filename, range="%s-%s" % (offset, r))
+        ret = self.scon.get_object(self.filename, range="{}-{}".format(offset, r))
         self.buff = ret
 
     def read(self, length):
@@ -629,7 +619,7 @@ class SwiftPack(Pack):
     def __init__(self, *args, **kwargs):
         self.scon = kwargs["scon"]
         del kwargs["scon"]
-        super(SwiftPack, self).__init__(*args, **kwargs)
+        super().__init__(*args, **kwargs)
         self._pack_info_path = self._basename + ".info"
         self._pack_info = None
         self._pack_info_load = lambda: load_pack_info(self._pack_info_path, self.scon)
@@ -657,7 +647,7 @@ class SwiftObjectStore(PackBasedObjectStore):
         Args:
           scon: A `SwiftConnector` instance
         """
-        super(SwiftObjectStore, self).__init__()
+        super().__init__()
         self.scon = scon
         self.root = self.scon.root
         self.pack_dir = posixpath.join(OBJECTDIR, PACKDIR)
@@ -681,19 +671,6 @@ class SwiftObjectStore(PackBasedObjectStore):
         """Loose objects are not supported by this repository"""
         return []
 
-    def iter_shas(self, finder):
-        """An iterator over pack's ObjectStore.
-
-        Returns: a `ObjectStoreIterator` or `GreenThreadsObjectStoreIterator`
-                 instance if gevent is enabled
-        """
-        shas = iter(finder.next, None)
-        return PackInfoObjectStoreIterator(self, shas, finder, self.scon.concurrency)
-
-    def find_missing_objects(self, *args, **kwargs):
-        kwargs["concurrency"] = self.scon.concurrency
-        return PackInfoMissingObjectFinder(self, *args, **kwargs)
-
     def pack_info_get(self, sha):
         for pack in self.packs:
             if sha in pack:
@@ -860,7 +837,7 @@ class SwiftInfoRefsContainer(InfoRefsContainer):
         f = self.scon.get_object(self.filename)
         if not f:
             f = BytesIO(b"")
-        super(SwiftInfoRefsContainer, self).__init__(f)
+        super().__init__(f)
 
     def _load_check_ref(self, name, old_ref):
         self._check_refname(name)
@@ -1066,7 +1043,7 @@ def main(argv=sys.argv):
     }
 
     if len(sys.argv) < 2:
-        print("Usage: %s <%s> [OPTIONS...]" % (sys.argv[0], "|".join(commands.keys())))
+        print("Usage: {} <{}> [OPTIONS...]".format(sys.argv[0], "|".join(commands.keys())))
         sys.exit(1)
 
     cmd = sys.argv[1]

+ 2 - 2
dulwich/contrib/test_paramiko_vendor.py

@@ -38,7 +38,7 @@ else:
     class Server(paramiko.ServerInterface):
         """http://docs.paramiko.org/en/2.4/api/server.html"""
         def __init__(self, commands, *args, **kwargs):
-            super(Server, self).__init__(*args, **kwargs)
+            super().__init__(*args, **kwargs)
             self.commands = commands
 
         def check_channel_exec_request(self, channel, command):
@@ -152,7 +152,7 @@ class ParamikoSSHVendorTests(TestCase):
     def _run(self):
         try:
             conn, addr = self.socket.accept()
-        except socket.error:
+        except OSError:
             return False
         self.transport = paramiko.Transport(conn)
         self.addCleanup(self.transport.close)

+ 7 - 7
dulwich/contrib/test_swift.py

@@ -99,7 +99,7 @@ def create_swift_connector(store={}):
     return lambda root, conf: FakeSwiftConnector(root, conf=conf, store=store)
 
 
-class Response(object):
+class Response:
     def __init__(self, headers={}, status=200, content=None):
         self.headers = headers
         self.status_code = status
@@ -183,14 +183,14 @@ def create_commit(data, marker=b"Default", blob=None):
 def create_commits(length=1, marker=b"Default"):
     data = []
     for i in range(0, length):
-        _marker = ("%s_%s" % (marker, i)).encode()
+        _marker = ("{}_{}".format(marker, i)).encode()
         blob, tree, tag, cmt = create_commit(data, _marker)
         data.extend([blob, tree, tag, cmt])
     return data
 
 
 @skipIf(missing_libs, skipmsg)
-class FakeSwiftConnector(object):
+class FakeSwiftConnector:
     def __init__(self, root, conf, store=None):
         if store:
             self.store = store
@@ -246,7 +246,7 @@ class FakeSwiftConnector(object):
 @skipIf(missing_libs, skipmsg)
 class TestSwiftRepo(TestCase):
     def setUp(self):
-        super(TestSwiftRepo, self).setUp()
+        super().setUp()
         self.conf = swift.load_conf(file=StringIO(config_file % def_config_file))
 
     def test_init(self):
@@ -302,7 +302,7 @@ class TestSwiftRepo(TestCase):
 @skipIf(missing_libs, skipmsg)
 class TestSwiftInfoRefsContainer(TestCase):
     def setUp(self):
-        super(TestSwiftInfoRefsContainer, self).setUp()
+        super().setUp()
         content = (
             b"22effb216e3a82f97da599b8885a6cadb488b4c5\trefs/heads/master\n"
             b"cca703b0e1399008b53a1a236d6b4584737649e4\trefs/heads/dev"
@@ -343,7 +343,7 @@ class TestSwiftInfoRefsContainer(TestCase):
 @skipIf(missing_libs, skipmsg)
 class TestSwiftConnector(TestCase):
     def setUp(self):
-        super(TestSwiftConnector, self).setUp()
+        super().setUp()
         self.conf = swift.load_conf(file=StringIO(config_file % def_config_file))
         with patch("geventhttpclient.HTTPClient.request", fake_auth_request_v1):
             self.conn = swift.SwiftConnector("fakerepo", conf=self.conf)
@@ -409,7 +409,7 @@ class TestSwiftConnector(TestCase):
         with patch(
             "geventhttpclient.HTTPClient.request",
             lambda *args: Response(
-                content=json.dumps((({"name": "a"}, {"name": "b"})))
+                content=json.dumps(({"name": "a"}, {"name": "b"}))
             ),
         ):
             self.assertEqual(len(self.conn.get_container_objects()), 2)

+ 9 - 8
dulwich/diff_tree.py

@@ -28,10 +28,12 @@ from collections import (
 from io import BytesIO
 from itertools import chain
 import stat
+from typing import List, Dict, Optional
 
 from dulwich.objects import (
     S_ISGITLINK,
     TreeEntry,
+    Tree,
 )
 
 
@@ -65,8 +67,8 @@ class TreeChange(namedtuple("TreeChange", ["type", "old", "new"])):
         return cls(CHANGE_DELETE, old, _NULL_ENTRY)
 
 
-def _tree_entries(path, tree):
-    result = []
+def _tree_entries(path: str, tree: Tree) -> List[TreeEntry]:
+    result: List[TreeEntry] = []
     if not tree:
         return result
     for entry in tree.iteritems(name_order=True):
@@ -189,13 +191,12 @@ def tree_changes(
         source and target tree.
     """
     if rename_detector is not None and tree1_id is not None and tree2_id is not None:
-        for change in rename_detector.changes_with_renames(
+        yield from rename_detector.changes_with_renames(
             tree1_id,
             tree2_id,
             want_unchanged=want_unchanged,
             include_trees=include_trees,
-        ):
-            yield change
+        )
         return
 
     entries = walk_trees(
@@ -269,7 +270,7 @@ def tree_changes_for_merge(store, parent_tree_ids, tree_id, rename_detector=None
         for t in parent_tree_ids
     ]
     num_parents = len(parent_tree_ids)
-    changes_by_path = defaultdict(lambda: [None] * num_parents)
+    changes_by_path: Dict[str, List[Optional[TreeChange]]] = defaultdict(lambda: [None] * num_parents)
 
     # Organize by path.
     for i, parent_changes in enumerate(all_parent_changes):
@@ -315,7 +316,7 @@ def _count_blocks(obj):
     Returns:
       A dict of block hashcode -> total bytes occurring.
     """
-    block_counts = defaultdict(int)
+    block_counts: Dict[int, int] = defaultdict(int)
     block = BytesIO()
     n = 0
 
@@ -400,7 +401,7 @@ def _tree_change_key(entry):
     return (path1, path2)
 
 
-class RenameDetector(object):
+class RenameDetector:
     """Object for handling rename detection between two trees."""
 
     def __init__(

+ 9 - 7
dulwich/errors.py

@@ -43,12 +43,12 @@ class ChecksumMismatch(Exception):
         if self.extra is None:
             Exception.__init__(
                 self,
-                "Checksum mismatch: Expected %s, got %s" % (expected, got),
+                "Checksum mismatch: Expected {}, got {}".format(expected, got),
             )
         else:
             Exception.__init__(
                 self,
-                "Checksum mismatch: Expected %s, got %s; %s" % (expected, got, extra),
+                "Checksum mismatch: Expected {}, got {}; {}".format(expected, got, extra),
             )
 
 
@@ -61,8 +61,10 @@ class WrongObjectException(Exception):
     was expected if they were raised.
     """
 
+    type_name: str
+
     def __init__(self, sha, *args, **kwargs):
-        Exception.__init__(self, "%s is not a %s" % (sha, self.type_name))
+        Exception.__init__(self, "{} is not a {}".format(sha, self.type_name))
 
 
 class NotCommitError(WrongObjectException):
@@ -140,7 +142,7 @@ class UpdateRefsError(GitProtocolError):
 
     def __init__(self, *args, **kwargs):
         self.ref_status = kwargs.pop("ref_status")
-        super(UpdateRefsError, self).__init__(*args, **kwargs)
+        super().__init__(*args, **kwargs)
 
 
 class HangupException(GitProtocolError):
@@ -148,13 +150,13 @@ class HangupException(GitProtocolError):
 
     def __init__(self, stderr_lines=None):
         if stderr_lines:
-            super(HangupException, self).__init__(
+            super().__init__(
                 "\n".join(
                     [line.decode("utf-8", "surrogateescape") for line in stderr_lines]
                 )
             )
         else:
-            super(HangupException, self).__init__(
+            super().__init__(
                 "The remote server unexpectedly closed the connection."
             )
         self.stderr_lines = stderr_lines
@@ -171,7 +173,7 @@ class UnexpectedCommandError(GitProtocolError):
             command = "flush-pkt"
         else:
             command = "command %s" % command
-        super(UnexpectedCommandError, self).__init__(
+        super().__init__(
             "Protocol got unexpected %s" % command
         )
 

+ 5 - 2
dulwich/fastexport.py

@@ -30,6 +30,9 @@ from dulwich.objects import (
     Tag,
     ZERO_SHA,
 )
+from dulwich.object_store import (
+    iter_tree_contents,
+)
 from fastimport import (
     commands,
     errors as fastimport_errors,
@@ -45,7 +48,7 @@ def split_email(text):
     return (name, email.rstrip(b">"))
 
 
-class GitFastExporter(object):
+class GitFastExporter:
     """Generate a fast-export output stream for Git objects."""
 
     def __init__(self, outf, store):
@@ -232,7 +235,7 @@ class GitImportProcessor(processor.ImportProcessor):
                 path,
                 mode,
                 hexsha,
-            ) in self.repo.object_store.iter_tree_contents(tree_id):
+            ) in iter_tree_contents(self.repo.object_store, tree_id):
                 self._contents[path] = (mode, hexsha)
 
     def reset_handler(self, cmd):

+ 19 - 19
dulwich/file.py

@@ -20,7 +20,6 @@
 
 """Safe access to git files."""
 
-import io
 import os
 import sys
 
@@ -83,15 +82,15 @@ def GitFile(filename, mode="rb", bufsize=-1, mask=0o644):
 
     """
     if "a" in mode:
-        raise IOError("append mode not supported for Git files")
+        raise OSError("append mode not supported for Git files")
     if "+" in mode:
-        raise IOError("read/write mode not supported for Git files")
+        raise OSError("read/write mode not supported for Git files")
     if "b" not in mode:
-        raise IOError("text mode not supported for Git files")
+        raise OSError("text mode not supported for Git files")
     if "w" in mode:
         return _GitFile(filename, mode, bufsize, mask)
     else:
-        return io.open(filename, mode, bufsize)
+        return open(filename, mode, bufsize)
 
 
 class FileLocked(Exception):
@@ -100,10 +99,10 @@ class FileLocked(Exception):
     def __init__(self, filename, lockfilename):
         self.filename = filename
         self.lockfilename = lockfilename
-        super(FileLocked, self).__init__(filename, lockfilename)
+        super().__init__(filename, lockfilename)
 
 
-class _GitFile(object):
+class _GitFile:
     """File that follows the git locking protocol for writes.
 
     All writes to a file foo will be written into foo.lock in the same
@@ -114,17 +113,15 @@ class _GitFile(object):
         released. Typically this will happen in a finally block.
     """
 
-    PROXY_PROPERTIES = set(
-        [
-            "closed",
-            "encoding",
-            "errors",
-            "mode",
-            "name",
-            "newlines",
-            "softspace",
-        ]
-    )
+    PROXY_PROPERTIES = {
+        "closed",
+        "encoding",
+        "errors",
+        "mode",
+        "name",
+        "newlines",
+        "softspace",
+    }
     PROXY_METHODS = (
         "__iter__",
         "flush",
@@ -209,7 +206,10 @@ class _GitFile(object):
         return self
 
     def __exit__(self, exc_type, exc_val, exc_tb):
-        self.close()
+        if exc_type is not None:
+            self.abort()
+        else:
+            self.close()
 
     def __getattr__(self, name):
         """Proxy property calls to the underlying file."""

+ 2 - 2
dulwich/graph.py

@@ -1,5 +1,4 @@
 #!/usr/bin/env python
-# -*- coding: utf-8 -*-
 # vim:ts=4:sw=4:softtabstop=4:smarttab:expandtab
 # Copyright (c) 2020 Kevin B. Hendricks, Stratford Ontario Canada
 #
@@ -23,6 +22,7 @@
 Implementation of merge-base following the approach of git
 """
 
+from typing import Deque
 from collections import deque
 
 
@@ -44,7 +44,7 @@ def _find_lcas(lookup_parents, c1, c2s):
         return False
 
     # initialize the working list
-    wlst = deque()
+    wlst: Deque[int] = deque()
     cstates[c1] = _ANC_OF_1
     wlst.append(c1)
     for c2 in c2s:

+ 8 - 42
dulwich/greenthreads.py

@@ -31,12 +31,12 @@ from dulwich.objects import (
 )
 from dulwich.object_store import (
     MissingObjectFinder,
+    _collect_ancestors,
     _collect_filetree_revs,
-    ObjectStoreIterator,
 )
 
 
-def _split_commits_and_tags(obj_store, lst, ignore_unknown=False, pool=None):
+def _split_commits_and_tags(obj_store, lst, *, ignore_unknown=False, pool=None):
     """Split object id list into two list with commit SHA1s and tag SHA1s.
 
     Same implementation as object_store._split_commits_and_tags
@@ -90,11 +90,11 @@ class GreenThreadsMissingObjectFinder(MissingObjectFinder):
         self.object_store = object_store
         p = pool.Pool(size=concurrency)
 
-        have_commits, have_tags = _split_commits_and_tags(object_store, haves, True, p)
-        want_commits, want_tags = _split_commits_and_tags(object_store, wants, False, p)
-        all_ancestors = object_store._collect_ancestors(have_commits)[0]
-        missing_commits, common_commits = object_store._collect_ancestors(
-            want_commits, all_ancestors
+        have_commits, have_tags = _split_commits_and_tags(object_store, haves, ignore_unknown=True, pool=p)
+        want_commits, want_tags = _split_commits_and_tags(object_store, wants, ignore_unknown=False, pool=p)
+        all_ancestors = _collect_ancestors(object_store, have_commits)[0]
+        missing_commits, common_commits = _collect_ancestors(
+            object_store, want_commits, all_ancestors
         )
 
         self.sha_done = set()
@@ -104,43 +104,9 @@ class GreenThreadsMissingObjectFinder(MissingObjectFinder):
             self.sha_done.add(t)
         missing_tags = want_tags.difference(have_tags)
         wants = missing_commits.union(missing_tags)
-        self.objects_to_send = set([(w, None, False) for w in wants])
+        self.objects_to_send = {(w, None, False) for w in wants}
         if progress is None:
             self.progress = lambda x: None
         else:
             self.progress = progress
         self._tagged = get_tagged and get_tagged() or {}
-
-
-class GreenThreadsObjectStoreIterator(ObjectStoreIterator):
-    """ObjectIterator that works on top of an ObjectStore.
-
-    Same implementation as object_store.ObjectStoreIterator
-    except we use gevent to parallelize object retrieval.
-    """
-
-    def __init__(self, store, shas, finder, concurrency=1):
-        self.finder = finder
-        self.p = pool.Pool(size=concurrency)
-        super(GreenThreadsObjectStoreIterator, self).__init__(store, shas)
-
-    def retrieve(self, args):
-        sha, path = args
-        return self.store[sha], path
-
-    def __iter__(self):
-        for sha, path in self.p.imap_unordered(self.retrieve, self.itershas()):
-            yield sha, path
-
-    def __len__(self):
-        if len(self._shas) > 0:
-            return len(self._shas)
-        while self.finder.objects_to_send:
-            jobs = []
-            for _ in range(0, len(self.finder.objects_to_send)):
-                jobs.append(self.p.spawn(self.finder.next))
-            gevent.joinall(jobs)
-            for j in jobs:
-                if j.value is not None:
-                    self._shas.append(j.value)
-        return len(self._shas)

+ 1 - 1
dulwich/hooks.py

@@ -28,7 +28,7 @@ from dulwich.errors import (
 )
 
 
-class Hook(object):
+class Hook:
     """Generic hook object."""
 
     def execute(self, *args):

+ 11 - 11
dulwich/ignore.py

@@ -154,7 +154,7 @@ def match_pattern(path: bytes, pattern: bytes, ignorecase: bool = False) -> bool
     return Pattern(pattern, ignorecase).match(path)
 
 
-class Pattern(object):
+class Pattern:
     """A single ignore pattern."""
 
     def __init__(self, pattern: bytes, ignorecase: bool = False):
@@ -186,7 +186,7 @@ class Pattern(object):
         )
 
     def __repr__(self) -> str:
-        return "%s(%r, %r)" % (
+        return "{}({!r}, {!r})".format(
             type(self).__name__,
             self.pattern,
             self.ignorecase,
@@ -202,9 +202,9 @@ class Pattern(object):
         return bool(self._re.match(path))
 
 
-class IgnoreFilter(object):
+class IgnoreFilter:
     def __init__(self, patterns: Iterable[bytes], ignorecase: bool = False, path=None):
-        self._patterns = []  # type: List[Pattern]
+        self._patterns: List[Pattern] = []
         self._ignorecase = ignorecase
         self._path = path
         for pattern in patterns:
@@ -249,12 +249,12 @@ class IgnoreFilter(object):
     def __repr__(self) -> str:
         path = getattr(self, "_path", None)
         if path is not None:
-            return "%s.from_path(%r)" % (type(self).__name__, path)
+            return "{}.from_path({!r})".format(type(self).__name__, path)
         else:
             return "<%s>" % (type(self).__name__)
 
 
-class IgnoreFilterStack(object):
+class IgnoreFilterStack:
     """Check for ignore status in multiple filters."""
 
     def __init__(self, filters):
@@ -295,7 +295,7 @@ def default_user_ignore_filter_path(config: Config) -> str:
     return get_xdg_config_home_path("git", "ignore")
 
 
-class IgnoreFilterManager(object):
+class IgnoreFilterManager:
     """Ignore file manager."""
 
     def __init__(
@@ -304,13 +304,13 @@ class IgnoreFilterManager(object):
         global_filters: List[IgnoreFilter],
         ignorecase: bool,
     ):
-        self._path_filters = {}  # type: Dict[str, Optional[IgnoreFilter]]
+        self._path_filters: Dict[str, Optional[IgnoreFilter]] = {}
         self._top_path = top_path
         self._global_filters = global_filters
         self._ignorecase = ignorecase
 
     def __repr__(self) -> str:
-        return "%s(%s, %r, %r)" % (
+        return "{}({}, {!r}, {!r})".format(
             type(self).__name__,
             self._top_path,
             self._global_filters,
@@ -326,7 +326,7 @@ class IgnoreFilterManager(object):
         p = os.path.join(self._top_path, path, ".gitignore")
         try:
             self._path_filters[path] = IgnoreFilter.from_path(p, self._ignorecase)
-        except IOError:
+        except OSError:
             self._path_filters[path] = None
         return self._path_filters[path]
 
@@ -389,7 +389,7 @@ class IgnoreFilterManager(object):
         ]:
             try:
                 global_filters.append(IgnoreFilter.from_path(os.path.expanduser(p)))
-            except IOError:
+            except OSError:
                 pass
         config = repo.get_config_stack()
         ignorecase = config.get_boolean((b"core"), (b"ignorecase"), False)

+ 18 - 20
dulwich/index.py

@@ -32,16 +32,12 @@ from typing import (
     Dict,
     List,
     Optional,
-    TYPE_CHECKING,
     Iterable,
     Iterator,
     Tuple,
     Union,
 )
 
-if TYPE_CHECKING:
-    from dulwich.object_store import BaseObjectStore
-
 from dulwich.file import GitFile
 from dulwich.objects import (
     Blob,
@@ -52,9 +48,11 @@ from dulwich.objects import (
     sha_to_hex,
     ObjectID,
 )
+from dulwich.object_store import iter_tree_contents
 from dulwich.pack import (
     SHA1Reader,
     SHA1Writer,
+    ObjectContainer,
 )
 
 
@@ -174,7 +172,7 @@ def read_cache_entry(f, version: int) -> Tuple[str, IndexEntry]:
         (extended_flags, ) = struct.unpack(">H", f.read(2))
     else:
         extended_flags = 0
-    name = f.read((flags & 0x0FFF))
+    name = f.read(flags & 0x0FFF)
     # Padding:
     if version < 4:
         real_size = (f.tell() - beginoffset + 8) & ~7
@@ -313,7 +311,7 @@ def cleanup_mode(mode: int) -> int:
     return ret
 
 
-class Index(object):
+class Index:
     """A Git Index file."""
 
     def __init__(self, filename: Union[bytes, str], read=True):
@@ -335,7 +333,7 @@ class Index(object):
         return self._filename
 
     def __repr__(self):
-        return "%s(%r)" % (self.__class__.__name__, self._filename)
+        return "{}({!r})".format(self.__class__.__name__, self._filename)
 
     def write(self) -> None:
         """Write current contents of index to disk."""
@@ -431,14 +429,13 @@ class Index(object):
             entry = self[path]
             return entry.sha, cleanup_mode(entry.mode)
 
-        for (name, mode, sha) in changes_from_tree(
+        yield from changes_from_tree(
             self._byname.keys(),
             lookup_entry,
             object_store,
             tree,
             want_unchanged=want_unchanged,
-        ):
-            yield (name, mode, sha)
+        )
 
     def commit(self, object_store):
         """Create a new tree from an index.
@@ -452,7 +449,7 @@ class Index(object):
 
 
 def commit_tree(
-    object_store: "BaseObjectStore", blobs: Iterable[Tuple[bytes, bytes, int]]
+    object_store: ObjectContainer, blobs: Iterable[Tuple[bytes, bytes, int]]
 ) -> bytes:
     """Commit a new tree.
 
@@ -462,7 +459,7 @@ def commit_tree(
     Returns:
       SHA1 of the created tree.
     """
-    trees = {b"": {}}  # type: Dict[bytes, Any]
+    trees: Dict[bytes, Any] = {b"": {}}
 
     def add_tree(path):
         if path in trees:
@@ -495,7 +492,7 @@ def commit_tree(
     return build_tree(b"")
 
 
-def commit_index(object_store: "BaseObjectStore", index: Index) -> bytes:
+def commit_index(object_store: ObjectContainer, index: Index) -> bytes:
     """Create a new tree from an index.
 
     Args:
@@ -510,7 +507,7 @@ def commit_index(object_store: "BaseObjectStore", index: Index) -> bytes:
 def changes_from_tree(
     names: Iterable[bytes],
     lookup_entry: Callable[[bytes], Tuple[bytes, int]],
-    object_store: "BaseObjectStore",
+    object_store: ObjectContainer,
     tree: Optional[bytes],
     want_unchanged=False,
 ) -> Iterable[
@@ -536,7 +533,7 @@ def changes_from_tree(
     other_names = set(names)
 
     if tree is not None:
-        for (name, mode, sha) in object_store.iter_tree_contents(tree):
+        for (name, mode, sha) in iter_tree_contents(object_store, tree):
             try:
                 (other_sha, other_mode) = lookup_entry(name)
             except KeyError:
@@ -687,7 +684,7 @@ def validate_path(path: bytes,
 def build_index_from_tree(
     root_path: Union[str, bytes],
     index_path: Union[str, bytes],
-    object_store: "BaseObjectStore",
+    object_store: ObjectContainer,
     tree_id: bytes,
     honor_filemode: bool = True,
     validate_path_element=validate_path_element_default,
@@ -712,7 +709,7 @@ def build_index_from_tree(
     if not isinstance(root_path, bytes):
         root_path = os.fsencode(root_path)
 
-    for entry in object_store.iter_tree_contents(tree_id):
+    for entry in iter_tree_contents(object_store, tree_id):
         if not validate_path(entry.path, validate_path_element):
             continue
         full_path = _tree_to_fs_path(root_path, entry.path)
@@ -728,6 +725,7 @@ def build_index_from_tree(
             # TODO(jelmer): record and return submodule paths
         else:
             obj = object_store[entry.sha]
+            assert isinstance(obj, Blob)
             st = build_file_from_blob(
                 obj, entry.mode, full_path,
                 honor_filemode=honor_filemode,
@@ -928,7 +926,7 @@ def index_entry_from_directory(st, path: bytes) -> Optional[IndexEntry]:
 
 
 def index_entry_from_path(
-        path: bytes, object_store: Optional["BaseObjectStore"] = None
+        path: bytes, object_store: Optional[ObjectContainer] = None
 ) -> Optional[IndexEntry]:
     """Create an index from a filesystem path.
 
@@ -958,7 +956,7 @@ def index_entry_from_path(
 
 def iter_fresh_entries(
     paths: Iterable[bytes], root_path: bytes,
-    object_store: Optional["BaseObjectStore"] = None
+    object_store: Optional[ObjectContainer] = None
 ) -> Iterator[Tuple[bytes, Optional[IndexEntry]]]:
     """Iterate over current versions of index entries on disk.
 
@@ -1014,7 +1012,7 @@ def refresh_index(index: Index, root_path: bytes):
             index[path] = entry
 
 
-class locked_index(object):
+class locked_index:
     """Lock the index while making modifications.
 
     Works as a context manager.

+ 1 - 1
dulwich/lfs.py

@@ -23,7 +23,7 @@ import os
 import tempfile
 
 
-class LFSStore(object):
+class LFSStore:
     """Stores objects on disk, indexed by SHA256."""
 
     def __init__(self, path):

+ 3 - 2
dulwich/line_ending.py

@@ -136,6 +136,7 @@ Sources:
 - https://adaptivepatchwork.com/2012/03/01/mind-the-end-of-your-line/
 """
 
+from dulwich.object_store import iter_tree_contents
 from dulwich.objects import Blob
 from dulwich.patch import is_binary
 
@@ -214,7 +215,7 @@ def get_checkin_filter_autocrlf(core_autocrlf):
     return None
 
 
-class BlobNormalizer(object):
+class BlobNormalizer:
     """An object to store computation result of which filter to apply based
     on configuration, gitattributes, path and operation (checkin or checkout)
     """
@@ -290,7 +291,7 @@ class TreeBlobNormalizer(BlobNormalizer):
         if tree:
             self.existing_paths = {
                 name
-                for name, _, _ in object_store.iter_tree_contents(tree)
+                for name, _, _ in iter_tree_contents(object_store, tree)
             }
         else:
             self.existing_paths = set()

+ 59 - 35
dulwich/lru_cache.py

@@ -1,5 +1,6 @@
 # lru_cache.py -- Simple LRU cache for dulwich
 # Copyright (C) 2006, 2008 Canonical Ltd
+# Copyright (C) 2022 Jelmer Vernooij <jelmer@jelmer.uk>
 #
 # Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
 # General Public License as public by the Free Software Foundation; version 2.0
@@ -20,17 +21,28 @@
 
 """A simple least-recently-used (LRU) cache."""
 
+from typing import Generic, TypeVar, Optional, Callable, Dict, Iterable, Iterator
+
+
 _null_key = object()
 
 
-class _LRUNode(object):
+K = TypeVar('K')
+V = TypeVar('V')
+
+
+class _LRUNode(Generic[K, V]):
     """This maintains the linked-list which is the lru internals."""
 
     __slots__ = ("prev", "next_key", "key", "value", "cleanup", "size")
 
-    def __init__(self, key, value, cleanup=None):
+    prev: Optional["_LRUNode[K, V]"]
+    next_key: K
+    size: Optional[int]
+
+    def __init__(self, key: K, value: V, cleanup=None):
         self.prev = None
-        self.next_key = _null_key
+        self.next_key = _null_key  # type: ignore
         self.key = key
         self.value = value
         self.cleanup = cleanup
@@ -44,36 +56,39 @@ class _LRUNode(object):
             prev_key = None
         else:
             prev_key = self.prev.key
-        return "%s(%r n:%r p:%r)" % (
+        return "{}({!r} n:{!r} p:{!r})".format(
             self.__class__.__name__,
             self.key,
             self.next_key,
             prev_key,
         )
 
-    def run_cleanup(self):
+    def run_cleanup(self) -> None:
         if self.cleanup is not None:
             self.cleanup(self.key, self.value)
         self.cleanup = None
         # Just make sure to break any refcycles, etc
-        self.value = None
+        del self.value
 
 
-class LRUCache(object):
+class LRUCache(Generic[K, V]):
     """A class which manages a cache of entries, removing unused ones."""
 
-    def __init__(self, max_cache=100, after_cleanup_count=None):
-        self._cache = {}
+    _least_recently_used: Optional[_LRUNode[K, V]]
+    _most_recently_used: Optional[_LRUNode[K, V]]
+
+    def __init__(self, max_cache: int = 100, after_cleanup_count: Optional[int] = None) -> None:
+        self._cache: Dict[K, _LRUNode[K, V]] = {}
         # The "HEAD" of the lru linked list
         self._most_recently_used = None
         # The "TAIL" of the lru linked list
         self._least_recently_used = None
         self._update_max_cache(max_cache, after_cleanup_count)
 
-    def __contains__(self, key):
+    def __contains__(self, key: K) -> bool:
         return key in self._cache
 
-    def __getitem__(self, key):
+    def __getitem__(self, key: K) -> V:
         cache = self._cache
         node = cache[key]
         # Inlined from _record_access to decrease the overhead of __getitem__
@@ -96,6 +111,8 @@ class LRUCache(object):
         else:
             node_next = cache[next_key]
             node_next.prev = node_prev
+        assert node_prev
+        assert mru
         node_prev.next_key = next_key
         # Insert this node at the front of the list
         node.next_key = mru.key
@@ -104,10 +121,10 @@ class LRUCache(object):
         node.prev = None
         return node.value
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self._cache)
 
-    def _walk_lru(self):
+    def _walk_lru(self) -> Iterator[_LRUNode[K, V]]:
         """Walk the LRU list, only meant to be used in tests."""
         node = self._most_recently_used
         if node is not None:
@@ -144,7 +161,7 @@ class LRUCache(object):
             yield node
             node = node_next
 
-    def add(self, key, value, cleanup=None):
+    def add(self, key: K, value: V, cleanup: Optional[Callable[[K, V], None]] = None) -> None:
         """Add a new value to the cache.
 
         Also, if the entry is ever removed from the cache, call
@@ -172,18 +189,18 @@ class LRUCache(object):
             # Trigger the cleanup
             self.cleanup()
 
-    def cache_size(self):
+    def cache_size(self) -> int:
         """Get the number of entries we will cache."""
         return self._max_cache
 
-    def get(self, key, default=None):
+    def get(self, key: K, default: Optional[V] = None) -> Optional[V]:
         node = self._cache.get(key, None)
         if node is None:
             return default
         self._record_access(node)
         return node.value
 
-    def keys(self):
+    def keys(self) -> Iterable[K]:
         """Get the list of keys currently cached.
 
         Note that values returned here may not be available by the time you
@@ -194,7 +211,7 @@ class LRUCache(object):
         """
         return self._cache.keys()
 
-    def items(self):
+    def items(self) -> Dict[K, V]:
         """Get the key:value pairs as a dict."""
         return {k: n.value for k, n in self._cache.items()}
 
@@ -208,11 +225,11 @@ class LRUCache(object):
         while len(self._cache) > self._after_cleanup_count:
             self._remove_lru()
 
-    def __setitem__(self, key, value):
+    def __setitem__(self, key: K, value: V) -> None:
         """Add a value to the cache, there will be no cleanup function."""
         self.add(key, value, cleanup=None)
 
-    def _record_access(self, node):
+    def _record_access(self, node: _LRUNode[K, V]) -> None:
         """Record that key was accessed."""
         # Move 'node' to the front of the queue
         if self._most_recently_used is None:
@@ -238,7 +255,7 @@ class LRUCache(object):
         self._most_recently_used = node
         node.prev = None
 
-    def _remove_node(self, node):
+    def _remove_node(self, node: _LRUNode[K, V]) -> None:
         if node is self._least_recently_used:
             self._least_recently_used = node.prev
         self._cache.pop(node.key)
@@ -254,23 +271,24 @@ class LRUCache(object):
             node_next.prev = node.prev
         # And remove this node's pointers
         node.prev = None
-        node.next_key = _null_key
+        node.next_key = _null_key  # type: ignore
 
-    def _remove_lru(self):
+    def _remove_lru(self) -> None:
         """Remove one entry from the lru, and handle consequences.
 
         If there are no more references to the lru, then this entry should be
         removed from the cache.
         """
+        assert self._least_recently_used
         self._remove_node(self._least_recently_used)
 
-    def clear(self):
+    def clear(self) -> None:
         """Clear out all of the cache."""
         # Clean up in LRU order
         while self._cache:
             self._remove_lru()
 
-    def resize(self, max_cache, after_cleanup_count=None):
+    def resize(self, max_cache: int, after_cleanup_count: Optional[int] = None) -> None:
         """Change the number of entries that will be cached."""
         self._update_max_cache(max_cache, after_cleanup_count=after_cleanup_count)
 
@@ -283,7 +301,7 @@ class LRUCache(object):
         self.cleanup()
 
 
-class LRUSizeCache(LRUCache):
+class LRUSizeCache(LRUCache[K, V]):
     """An LRUCache that removes things based on the size of the values.
 
     This differs in that it doesn't care how many actual items there are,
@@ -293,9 +311,12 @@ class LRUSizeCache(LRUCache):
     defaults to len() if not supplied.
     """
 
+    _compute_size: Callable[[V], int]
+
     def __init__(
-        self, max_size=1024 * 1024, after_cleanup_size=None, compute_size=None
-    ):
+            self, max_size: int = 1024 * 1024, after_cleanup_size: Optional[int] = None,
+            compute_size: Optional[Callable[[V], int]] = None
+    ) -> None:
         """Create a new LRUSizeCache.
 
         Args:
@@ -311,13 +332,14 @@ class LRUSizeCache(LRUCache):
             If not supplied, it defaults to 'len()'
         """
         self._value_size = 0
-        self._compute_size = compute_size
         if compute_size is None:
-            self._compute_size = len
+            self._compute_size = len  # type: ignore
+        else:
+            self._compute_size = compute_size
         self._update_max_size(max_size, after_cleanup_size=after_cleanup_size)
         LRUCache.__init__(self, max_cache=max(int(max_size / 512), 1))
 
-    def add(self, key, value, cleanup=None):
+    def add(self, key: K, value: V, cleanup: Optional[Callable[[K, V], None]] = None) -> None:
         """Add a new value to the cache.
 
         Also, if the entry is ever removed from the cache, call
@@ -346,6 +368,7 @@ class LRUSizeCache(LRUCache):
             node = _LRUNode(key, value, cleanup=cleanup)
             self._cache[key] = node
         else:
+            assert node.size is not None
             self._value_size -= node.size
         node.size = value_len
         self._value_size += value_len
@@ -355,7 +378,7 @@ class LRUSizeCache(LRUCache):
             # Time to cleanup
             self.cleanup()
 
-    def cleanup(self):
+    def cleanup(self) -> None:
         """Clear the cache until it shrinks to the requested size.
 
         This does not completely wipe the cache, just makes sure it is under
@@ -365,17 +388,18 @@ class LRUSizeCache(LRUCache):
         while self._value_size > self._after_cleanup_size:
             self._remove_lru()
 
-    def _remove_node(self, node):
+    def _remove_node(self, node: _LRUNode[K, V]) -> None:
+        assert node.size is not None
         self._value_size -= node.size
         LRUCache._remove_node(self, node)
 
-    def resize(self, max_size, after_cleanup_size=None):
+    def resize(self, max_size: int, after_cleanup_size: Optional[int] = None) -> None:
         """Change the number of bytes that will be cached."""
         self._update_max_size(max_size, after_cleanup_size=after_cleanup_size)
         max_cache = max(int(max_size / 512), 1)
         self._update_max_cache(max_cache)
 
-    def _update_max_size(self, max_size, after_cleanup_size=None):
+    def _update_max_size(self, max_size: int, after_cleanup_size: Optional[int] = None) -> None:
         self._max_size = max_size
         if after_cleanup_size is None:
             self._after_cleanup_size = self._max_size * 8 // 10

+ 1 - 1
dulwich/mailmap.py

@@ -58,7 +58,7 @@ def read_mailmap(f):
         yield parsed_canonical_identity, parsed_from_identity
 
 
-class Mailmap(object):
+class Mailmap:
     """Class for accessing a mailmap file."""
 
     def __init__(self, map=None):

Файлын зөрүү хэтэрхий том тул дарагдсан байна
+ 320 - 321
dulwich/object_store.py


+ 40 - 26
dulwich/objects.py

@@ -37,6 +37,7 @@ from typing import (
     List,
 )
 import zlib
+from _hashlib import HASH
 from hashlib import sha1
 
 from dulwich.errors import (
@@ -104,7 +105,7 @@ def _decompress(string):
 def sha_to_hex(sha):
     """Takes a string and returns the hex of the sha within"""
     hexsha = binascii.hexlify(sha)
-    assert len(hexsha) == 40, "Incorrect length of sha1 string: %s" % hexsha
+    assert len(hexsha) == 40, "Incorrect length of sha1 string: %r" % hexsha
     return hexsha
 
 
@@ -135,7 +136,7 @@ def hex_to_filename(path, hex):
     # os.path.join accepts bytes or unicode, but all args must be of the same
     # type. Make sure that hex which is expected to be bytes, is the same type
     # as path.
-    if getattr(path, "encode", None) is not None:
+    if type(path) != type(hex) and getattr(path, "encode", None) is not None:
         hex = hex.decode("ascii")
     dir = hex[:2]
     file = hex[2:]
@@ -198,7 +199,7 @@ def check_hexsha(hex, error_msg):
       ObjectFormatException: Raised when the string is not valid
     """
     if not valid_hexsha(hex):
-        raise ObjectFormatException("%s %s" % (error_msg, hex))
+        raise ObjectFormatException("{} {}".format(error_msg, hex))
 
 
 def check_identity(identity: bytes, error_msg: str) -> None:
@@ -242,7 +243,7 @@ def git_line(*items):
     return b" ".join(items) + b"\n"
 
 
-class FixedSha(object):
+class FixedSha:
     """SHA object that behaves like hashlib's but is given a fixed value."""
 
     __slots__ = ("_hexsha", "_sha")
@@ -264,7 +265,7 @@ class FixedSha(object):
         return self._hexsha.decode("ascii")
 
 
-class ShaFile(object):
+class ShaFile:
     """A git SHA file."""
 
     __slots__ = ("_chunked_text", "_sha", "_needs_serialization")
@@ -273,6 +274,7 @@ class ShaFile(object):
     type_name: bytes
     type_num: int
     _chunked_text: Optional[List[bytes]]
+    _sha: Union[FixedSha, None, HASH]
 
     @staticmethod
     def _parse_legacy_object_header(magic, f) -> "ShaFile":
@@ -454,7 +456,10 @@ class ShaFile(object):
           string: The raw uncompressed contents.
           sha: Optional known sha for the object
         """
-        obj = object_class(type_num)()
+        cls = object_class(type_num)
+        if cls is None:
+            raise AssertionError("unsupported class type num: %d" % type_num)
+        obj = cls()
         obj.set_raw_string(string, sha)
         return obj
 
@@ -523,10 +528,7 @@ class ShaFile(object):
 
     def raw_length(self) -> int:
         """Returns the length of the raw string of this object."""
-        ret = 0
-        for chunk in self.as_raw_chunks():
-            ret += len(chunk)
-        return ret
+        return sum(map(len, self.as_raw_chunks()))
 
     def sha(self):
         """The SHA1 object that is the name of this object."""
@@ -542,6 +544,8 @@ class ShaFile(object):
     def copy(self):
         """Create a new copy of this SHA1 object from its raw string"""
         obj_class = object_class(self.type_num)
+        if obj_class is None:
+            raise AssertionError('invalid type num %d' % self.type_num)
         return obj_class.from_raw_string(self.type_num, self.as_raw_string(), self.id)
 
     @property
@@ -550,7 +554,7 @@ class ShaFile(object):
         return self.sha().hexdigest().encode("ascii")
 
     def __repr__(self):
-        return "<%s %s>" % (self.__class__.__name__, self.id)
+        return "<{} {}>".format(self.__class__.__name__, self.id)
 
     def __ne__(self, other):
         """Check whether this object does not match the other."""
@@ -581,8 +585,10 @@ class Blob(ShaFile):
     type_name = b"blob"
     type_num = 3
 
+    _chunked_text: List[bytes]
+
     def __init__(self):
-        super(Blob, self).__init__()
+        super().__init__()
         self._chunked_text = []
         self._needs_serialization = False
 
@@ -599,7 +605,7 @@ class Blob(ShaFile):
     def _get_chunked(self):
         return self._chunked_text
 
-    def _set_chunked(self, chunks):
+    def _set_chunked(self, chunks: List[bytes]):
         self._chunked_text = chunks
 
     def _serialize(self):
@@ -627,7 +633,7 @@ class Blob(ShaFile):
         Raises:
           ObjectFormatException: if the object is malformed in some way
         """
-        super(Blob, self).check()
+        super().check()
 
     def splitlines(self) -> List[bytes]:
         """Return list of lines in this blob.
@@ -729,8 +735,10 @@ class Tag(ShaFile):
         "_signature",
     )
 
+    _tagger: Optional[bytes]
+
     def __init__(self):
-        super(Tag, self).__init__()
+        super().__init__()
         self._tagger = None
         self._tag_time = None
         self._tag_timezone = None
@@ -750,7 +758,8 @@ class Tag(ShaFile):
         Raises:
           ObjectFormatException: if the object is malformed in some way
         """
-        super(Tag, self).check()
+        super().check()
+        assert self._chunked_text is not None
         self._check_has_member("_object_sha", "missing object sha")
         self._check_has_member("_object_class", "missing object type")
         self._check_has_member("_name", "missing tag name")
@@ -760,7 +769,7 @@ class Tag(ShaFile):
 
         check_hexsha(self._object_sha, "invalid object sha")
 
-        if getattr(self, "_tagger", None):
+        if self._tagger is not None:
             check_identity(self._tagger, "invalid tagger")
 
         self._check_has_member("_tag_time", "missing tag time")
@@ -977,7 +986,7 @@ def serialize_tree(items):
         )
 
 
-def sorted_tree_items(entries, name_order):
+def sorted_tree_items(entries, name_order: bool):
     """Iterate over a tree entries dictionary.
 
     Args:
@@ -987,7 +996,10 @@ def sorted_tree_items(entries, name_order):
       entries: Dictionary mapping names to (mode, sha) tuples
     Returns: Iterator over (name, mode, hexsha)
     """
-    key_func = name_order and key_entry_name_order or key_entry
+    if name_order:
+        key_func = key_entry_name_order
+    else:
+        key_func = key_entry
     for name, entry in sorted(entries.items(), key=key_func):
         mode, hexsha = entry
         # Stricter type checks than normal to mirror checks in the C version.
@@ -1027,7 +1039,7 @@ def pretty_format_tree_entry(name, mode, hexsha, encoding="utf-8"):
         kind = "tree"
     else:
         kind = "blob"
-    return "%04o %s %s\t%s\n" % (
+    return "{:04o} {} {}\t{}\n".format(
         mode,
         kind,
         hexsha.decode("ascii"),
@@ -1052,7 +1064,7 @@ class Tree(ShaFile):
     __slots__ = "_entries"
 
     def __init__(self):
-        super(Tree, self).__init__()
+        super().__init__()
         self._entries = {}
 
     @classmethod
@@ -1129,7 +1141,7 @@ class Tree(ShaFile):
         # TODO: list comprehension is for efficiency in the common (small)
         # case; if memory efficiency in the large case is a concern, use a
         # genexp.
-        self._entries = dict([(n, (m, s)) for n, m, s in parsed_entries])
+        self._entries = {n: (m, s) for n, m, s in parsed_entries}
 
     def check(self):
         """Check this object for internal consistency.
@@ -1137,7 +1149,8 @@ class Tree(ShaFile):
         Raises:
           ObjectFormatException: if the object is malformed in some way
         """
-        super(Tree, self).check()
+        super().check()
+        assert self._chunked_text is not None
         last = None
         allowed_modes = (
             stat.S_IFREG | 0o755,
@@ -1346,7 +1359,7 @@ class Commit(ShaFile):
     )
 
     def __init__(self):
-        super(Commit, self).__init__()
+        super().__init__()
         self._parents = []
         self._encoding = None
         self._mergetag = []
@@ -1391,7 +1404,8 @@ class Commit(ShaFile):
         Raises:
           ObjectFormatException: if the object is malformed in some way
         """
-        super(Commit, self).check()
+        super().check()
+        assert self._chunked_text is not None
         self._check_has_member("_tree", "missing tree")
         self._check_has_member("_author", "missing author")
         self._check_has_member("_committer", "missing committer")
@@ -1523,7 +1537,7 @@ class Commit(ShaFile):
                 chunks[-1] = chunks[-1][:-2]
         for k, v in self.extra:
             if b"\n" in k or b"\n" in v:
-                raise AssertionError("newline in extra data: %r -> %r" % (k, v))
+                raise AssertionError("newline in extra data: {!r} -> {!r}".format(k, v))
             chunks.append(git_line(k, v))
         if self.gpgsig:
             sig_chunks = self.gpgsig.split(b"\n")

Файлын зөрүү хэтэрхий том тул дарагдсан байна
+ 294 - 176
dulwich/pack.py


+ 14 - 12
dulwich/patch.py

@@ -27,12 +27,14 @@ on.
 from difflib import SequenceMatcher
 import email.parser
 import time
+from typing import Union, TextIO, BinaryIO, Optional
 
 from dulwich.objects import (
     Blob,
     Commit,
     S_ISGITLINK,
 )
+from dulwich.pack import ObjectContainer
 
 FIRST_FEW_BYTES = 8000
 
@@ -108,10 +110,10 @@ def _format_range_unified(start, stop):
     beginning = start + 1  # lines start numbering with one
     length = stop - start
     if length == 1:
-        return "{}".format(beginning)
+        return f"{beginning}"
     if not length:
         beginning -= 1  # empty ranges begin at line just before the range
-    return "{},{}".format(beginning, length)
+    return f"{beginning},{length}"
 
 
 def unified_diff(
@@ -135,8 +137,8 @@ def unified_diff(
     for group in SequenceMatcher(None, a, b).get_grouped_opcodes(n):
         if not started:
             started = True
-            fromdate = "\t{}".format(fromfiledate) if fromfiledate else ""
-            todate = "\t{}".format(tofiledate) if tofiledate else ""
+            fromdate = f"\t{fromfiledate}" if fromfiledate else ""
+            todate = f"\t{tofiledate}" if tofiledate else ""
             yield "--- {}{}{}".format(
                 fromfile.decode(tree_encoding), fromdate, lineterm
             ).encode(output_encoding)
@@ -147,7 +149,7 @@ def unified_diff(
         first, last = group[0], group[-1]
         file1_range = _format_range_unified(first[1], last[2])
         file2_range = _format_range_unified(first[3], last[4])
-        yield "@@ -{} +{} @@{}".format(file1_range, file2_range, lineterm).encode(
+        yield f"@@ -{file1_range} +{file2_range} @@{lineterm}".encode(
             output_encoding
         )
 
@@ -191,7 +193,7 @@ def patch_filename(p, root):
         return root + b"/" + p
 
 
-def write_object_diff(f, store, old_file, new_file, diff_binary=False):
+def write_object_diff(f, store: ObjectContainer, old_file, new_file, diff_binary=False):
     """Write the diff for an object.
 
     Args:
@@ -338,7 +340,7 @@ def write_tree_diff(f, store, old_tree, new_tree, diff_binary=False):
         )
 
 
-def git_am_patch_split(f, encoding=None):
+def git_am_patch_split(f: Union[TextIO, BinaryIO], encoding: Optional[str] = None):
     """Parse a git-am-style patch and split it up into bits.
 
     Args:
@@ -349,12 +351,12 @@ def git_am_patch_split(f, encoding=None):
     encoding = encoding or getattr(f, "encoding", "ascii")
     encoding = encoding or "ascii"
     contents = f.read()
-    if isinstance(contents, bytes) and getattr(email.parser, "BytesParser", None):
-        parser = email.parser.BytesParser()
-        msg = parser.parsebytes(contents)
+    if isinstance(contents, bytes):
+        bparser = email.parser.BytesParser()
+        msg = bparser.parsebytes(contents)
     else:
-        parser = email.parser.Parser()
-        msg = parser.parsestr(contents)
+        uparser = email.parser.Parser()
+        msg = uparser.parsestr(contents)
     return parse_patch_message(msg, encoding)
 
 

+ 34 - 14
dulwich/porcelain.py

@@ -134,7 +134,7 @@ from dulwich.objectspec import (
 )
 from dulwich.pack import (
     write_pack_index,
-    write_pack_objects,
+    write_pack_from_container,
 )
 from dulwich.patch import write_tree_diff
 from dulwich.protocol import (
@@ -187,7 +187,7 @@ class Error(Exception):
     """Porcelain-based error. """
 
     def __init__(self, msg):
-        super(Error, self).__init__(msg)
+        super().__init__(msg)
 
 
 class RemoteExists(Error):
@@ -407,6 +407,18 @@ def symbolic_ref(repo, ref_name, force=False):
         repo_obj.refs.set_symbolic_ref(b"HEAD", ref_path)
 
 
+def pack_refs(repo, all=False):
+    with open_repo_closing(repo) as repo_obj:
+        refs = repo_obj.refs
+        packed_refs = {
+            ref: refs[ref]
+            for ref in refs
+            if (all or ref.startswith(LOCAL_TAG_PREFIX)) and ref != b"HEAD"
+        }
+
+        refs.add_packed_refs(packed_refs)
+
+
 def commit(
     repo=".",
     message=None,
@@ -687,7 +699,7 @@ def remove(repo=".", paths=None, cached=False):
                 else:
                     try:
                         blob = blob_from_path_and_stat(full_path, st)
-                    except IOError:
+                    except OSError:
                         pass
                     else:
                         try:
@@ -1023,7 +1035,7 @@ def submodule_list(repo):
     from .submodule import iter_cached_submodules
     with open_repo_closing(repo) as r:
         for path, sha in iter_cached_submodules(r.object_store, r[r.head()].tree):
-            yield path.decode(DEFAULT_ENCODING), sha.decode(DEFAULT_ENCODING)
+            yield path, sha.decode(DEFAULT_ENCODING)
 
 
 def tag_create(
@@ -1146,7 +1158,7 @@ def get_remote_repo(
 
     section = (b"remote", encoded_location)
 
-    remote_name = None  # type: Optional[str]
+    remote_name: Optional[str] = None
 
     if config.has_section(section):
         remote_name = encoded_location.decode()
@@ -1741,7 +1753,7 @@ def repack(repo):
         r.object_store.pack_loose_objects()
 
 
-def pack_objects(repo, object_ids, packf, idxf, delta_window_size=None):
+def pack_objects(repo, object_ids, packf, idxf, delta_window_size=None, deltify=None, reuse_deltas=True):
     """Pack objects into a file.
 
     Args:
@@ -1749,12 +1761,19 @@ def pack_objects(repo, object_ids, packf, idxf, delta_window_size=None):
       object_ids: List of object ids to write
       packf: File-like object to write to
       idxf: File-like object to write to (can be None)
+      delta_window_size: Sliding window size for searching for deltas;
+                         Set to None for default window size.
+      deltify: Whether to deltify objects
+      reuse_deltas: Allow reuse of existing deltas while deltifying
     """
     with open_repo_closing(repo) as r:
-        entries, data_sum = write_pack_objects(
+        entries, data_sum = write_pack_from_container(
             packf.write,
-            r.object_store.iter_shas((oid, None) for oid in object_ids),
+            r.object_store,
+            [(oid, None) for oid in object_ids],
+            deltify=deltify,
             delta_window_size=delta_window_size,
+            reuse_deltas=reuse_deltas,
         )
     if idxf is not None:
         entries = sorted([(k, v[0], v[1]) for (k, v) in entries.items()])
@@ -1985,11 +2004,12 @@ def find_unique_abbrev(object_store, object_id):
     return object_id.decode("ascii")[:7]
 
 
-def describe(repo):
+def describe(repo, abbrev=7):
     """Describe the repository version.
 
     Args:
       repo: git repository
+      abbrev: number of characters of commit to take, default is 7
     Returns: a string description of the current git revision
 
     Examples: "gabcdefh", "v0.1" or "v0.1-5-gabcdefh".
@@ -2002,10 +2022,10 @@ def describe(repo):
         for key, value in refs.items():
             key = key.decode()
             obj = r.get_object(value)
-            if u"tags" not in key:
+            if "tags" not in key:
                 continue
 
-            _, tag = key.rsplit(u"/", 1)
+            _, tag = key.rsplit("/", 1)
 
             try:
                 commit = obj.object
@@ -2022,7 +2042,7 @@ def describe(repo):
 
         # If there are no tags, return the current commit
         if len(sorted_tags) == 0:
-            return "g{}".format(find_unique_abbrev(r.object_store, r[r.head()].id))
+            return f"g{find_unique_abbrev(r.object_store, r[r.head()].id)}"
 
         # We're now 0 commits from the top
         commit_count = 0
@@ -2045,13 +2065,13 @@ def describe(repo):
                         return "{}-{}-g{}".format(
                             tag_name,
                             commit_count,
-                            latest_commit.id.decode("ascii")[:7],
+                            latest_commit.id.decode("ascii")[:abbrev],
                         )
 
             commit_count += 1
 
         # Return plain commit if no parent tag can be found
-        return "g{}".format(latest_commit.id.decode("ascii")[:7])
+        return "g{}".format(latest_commit.id.decode("ascii")[:abbrev])
 
 
 def get_object_by_path(repo, path, committish=None):

+ 6 - 7
dulwich/protocol.py

@@ -25,7 +25,6 @@ from io import BytesIO
 from os import (
     SEEK_END,
 )
-import socket
 
 import dulwich
 from dulwich.errors import (
@@ -171,7 +170,7 @@ def pkt_line(data):
     return ("%04x" % (len(data) + 4)).encode("ascii") + data
 
 
-class Protocol(object):
+class Protocol:
     """Class for interacting with a remote git process over the wire.
 
     Parts of the git wire protocol use 'pkt-lines' to communicate. A pkt-line
@@ -228,7 +227,7 @@ class Protocol(object):
             pkt_contents = read(size - 4)
         except ConnectionResetError as exc:
             raise HangupException() from exc
-        except socket.error as exc:
+        except OSError as exc:
             raise GitProtocolError(exc) from exc
         else:
             if len(pkt_contents) + 4 != size:
@@ -291,7 +290,7 @@ class Protocol(object):
             self.write(line)
             if self.report_activity:
                 self.report_activity(len(line), "write")
-        except socket.error as exc:
+        except OSError as exc:
             raise GitProtocolError(exc) from exc
 
     def write_sideband(self, channel, blob):
@@ -348,7 +347,7 @@ class ReceivableProtocol(Protocol):
     def __init__(
         self, recv, write, close=None, report_activity=None, rbufsize=_RBUFSIZE
     ):
-        super(ReceivableProtocol, self).__init__(
+        super().__init__(
             self.read, write, close=close, report_activity=report_activity
         )
         self._recv = recv
@@ -480,7 +479,7 @@ def ack_type(capabilities):
     return SINGLE_ACK
 
 
-class BufferedPktLineWriter(object):
+class BufferedPktLineWriter:
     """Writer that wraps its data in pkt-lines and has an independent buffer.
 
     Consecutive calls to write() wrap the data in a pkt-line and then buffers
@@ -524,7 +523,7 @@ class BufferedPktLineWriter(object):
         self._wbuf = BytesIO()
 
 
-class PktLineParser(object):
+class PktLineParser:
     """Packet line parser that hands completed packets off to a callback."""
 
     def __init__(self, handle_pkt):

+ 71 - 13
dulwich/refs.py

@@ -22,6 +22,7 @@
 """Ref handling.
 
 """
+from contextlib import suppress
 import os
 from typing import Dict, Optional
 
@@ -36,6 +37,7 @@ from dulwich.objects import (
     Tag,
     ObjectID,
 )
+from dulwich.pack import ObjectContainer
 from dulwich.file import (
     GitFile,
     ensure_dir_exists,
@@ -105,7 +107,7 @@ def check_ref_format(refname: Ref):
     return True
 
 
-class RefsContainer(object):
+class RefsContainer:
     """A container for refs."""
 
     def __init__(self, logger=None):
@@ -155,6 +157,15 @@ class RefsContainer(object):
         """
         raise NotImplementedError(self.get_packed_refs)
 
+    def add_packed_refs(self, new_refs: Dict[Ref, Optional[ObjectID]]):
+        """Add the given refs as packed refs.
+
+        Args:
+          new_refs: A mapping of ref names to targets; if a target is None that
+            means remove the ref
+        """
+        raise NotImplementedError(self.add_packed_refs)
+
     def get_peeled(self, name):
         """Return the cached peeled value of a ref, if available.
 
@@ -437,7 +448,7 @@ class DictRefsContainer(RefsContainer):
     """
 
     def __init__(self, refs, logger=None):
-        super(DictRefsContainer, self).__init__(logger=logger)
+        super().__init__(logger=logger)
         self._refs = refs
         self._peeled = {}
         self._watchers = set()
@@ -612,7 +623,7 @@ class DiskRefsContainer(RefsContainer):
     """Refs container that reads refs from disk."""
 
     def __init__(self, path, worktree_path=None, logger=None):
-        super(DiskRefsContainer, self).__init__(logger=logger)
+        super().__init__(logger=logger)
         if getattr(path, "encode", None) is not None:
             path = os.fsencode(path)
         self.path = path
@@ -625,7 +636,7 @@ class DiskRefsContainer(RefsContainer):
         self._peeled_refs = None
 
     def __repr__(self):
-        return "%s(%r)" % (self.__class__.__name__, self.path)
+        return "{}({!r})".format(self.__class__.__name__, self.path)
 
     def subkeys(self, base):
         subkeys = set()
@@ -706,6 +717,44 @@ class DiskRefsContainer(RefsContainer):
                         self._packed_refs[name] = sha
         return self._packed_refs
 
+    def add_packed_refs(self, new_refs: Dict[Ref, Optional[ObjectID]]):
+        """Add the given refs as packed refs.
+
+        Args:
+          new_refs: A mapping of ref names to targets; if a target is None that
+            means remove the ref
+        """
+        if not new_refs:
+            return
+
+        path = os.path.join(self.path, b"packed-refs")
+
+        with GitFile(path, "wb") as f:
+            # reread cached refs from disk, while holding the lock
+            packed_refs = self.get_packed_refs().copy()
+
+            for ref, target in new_refs.items():
+                # sanity check
+                if ref == HEADREF:
+                    raise ValueError("cannot pack HEAD")
+
+                # remove any loose refs pointing to this one -- please
+                # note that this bypasses remove_if_equals as we don't
+                # want to affect packed refs in here
+                try:
+                    os.remove(self.refpath(ref))
+                except OSError:
+                    pass
+
+                if target is not None:
+                    packed_refs[ref] = target
+                else:
+                    packed_refs.pop(ref, None)
+
+            write_packed_refs(f, packed_refs, self._peeled_refs)
+
+            self._packed_refs = packed_refs
+
     def get_peeled(self, name):
         """Return the cached peeled value of a ref, if available.
 
@@ -748,7 +797,10 @@ class DiskRefsContainer(RefsContainer):
                 else:
                     # Read only the first 40 bytes
                     return header + f.read(40 - len(SYMREF))
-        except (FileNotFoundError, IsADirectoryError, NotADirectoryError):
+        except (OSError, UnicodeError):
+            # don't assume anything specific about the error; in
+            # particular, invalid or forbidden paths can raise weird
+            # errors depending on the specific operating system
             return None
 
     def _remove_packed_ref(self, name):
@@ -765,7 +817,7 @@ class DiskRefsContainer(RefsContainer):
                 return
 
             del self._packed_refs[name]
-            if name in self._peeled_refs:
+            with suppress(KeyError):
                 del self._peeled_refs[name]
             write_packed_refs(f, self._packed_refs, self._peeled_refs)
             f.close()
@@ -860,12 +912,12 @@ class DiskRefsContainer(RefsContainer):
                     if orig_ref != old_ref:
                         f.abort()
                         return False
-                except (OSError, IOError):
+                except OSError:
                     f.abort()
                     raise
             try:
                 f.write(new_ref + b"\n")
-            except (OSError, IOError):
+            except OSError:
                 f.abort()
                 raise
             self._log(
@@ -915,7 +967,7 @@ class DiskRefsContainer(RefsContainer):
                 return False
             try:
                 f.write(ref + b"\n")
-            except (OSError, IOError):
+            except OSError:
                 f.abort()
                 raise
             else:
@@ -965,9 +1017,13 @@ class DiskRefsContainer(RefsContainer):
 
             # remove the reference file itself
             try:
+                found = os.path.lexists(filename)
+            except OSError:
+                # may only be packed, or otherwise unstorable
+                found = False
+
+            if found:
                 os.remove(filename)
-            except FileNotFoundError:
-                pass  # may only be packed
 
             self._remove_packed_ref(name)
             self._log(
@@ -1095,8 +1151,10 @@ def read_info_refs(f):
     return ret
 
 
-def write_info_refs(refs, store):
+def write_info_refs(refs, store: ObjectContainer):
     """Generate info refs."""
+    # Avoid recursive import :(
+    from dulwich.object_store import peel_sha
     for name, sha in sorted(refs.items()):
         # get_refs() includes HEAD as a special case, but we don't want to
         # advertise it
@@ -1106,7 +1164,7 @@ def write_info_refs(refs, store):
             o = store[sha]
         except KeyError:
             continue
-        peeled = store.peel_sha(sha)
+        peeled = peel_sha(store, sha)
         yield o.id + b"\t" + name + b"\n"
         if o.id != peeled.id:
             yield peeled.id + b"\t" + name + ANNOTATED_TAG_SUFFIX + b"\n"

+ 57 - 39
dulwich/repo.py

@@ -39,6 +39,7 @@ from typing import (
     Callable,
     Tuple,
     TYPE_CHECKING,
+    FrozenSet,
     List,
     Dict,
     Union,
@@ -70,8 +71,10 @@ from dulwich.file import (
 from dulwich.object_store import (
     DiskObjectStore,
     MemoryObjectStore,
-    BaseObjectStore,
+    MissingObjectFinder,
+    PackBasedObjectStore,
     ObjectStoreGraphWalker,
+    peel_sha,
 )
 from dulwich.objects import (
     check_hexsha,
@@ -84,7 +87,7 @@ from dulwich.objects import (
     ObjectID,
 )
 from dulwich.pack import (
-    pack_objects_to_data,
+    generate_unpacked_objects
 )
 
 from dulwich.hooks import (
@@ -180,7 +183,7 @@ def _get_default_identity() -> Tuple[str, str]:
         fullname = username
     email = os.environ.get("EMAIL")
     if email is None:
-        email = "{}@{}".format(username, socket.gethostname())
+        email = f"{username}@{socket.gethostname()}"
     return (fullname, email)  # type: ignore
 
 
@@ -205,8 +208,8 @@ def get_user_identity(config: "StackedConfig", kind: Optional[str] = None) -> by
     Returns:
       A user identity
     """
-    user = None  # type: Optional[bytes]
-    email = None  # type: Optional[bytes]
+    user: Optional[bytes] = None
+    email: Optional[bytes] = None
     if kind:
         user_uc = os.environ.get("GIT_" + kind + "_NAME")
         if user_uc is not None:
@@ -329,7 +332,7 @@ def _set_filesystem_hidden(path):
     # Could implement other platform specific filesystem hiding here
 
 
-class ParentsProvider(object):
+class ParentsProvider:
     def __init__(self, store, grafts={}, shallows=[]):
         self.store = store
         self.grafts = grafts
@@ -347,7 +350,7 @@ class ParentsProvider(object):
         return commit.parents
 
 
-class BaseRepo(object):
+class BaseRepo:
     """Base class for a git repository.
 
     This base class is meant to be used for Repository implementations that e.g.
@@ -360,7 +363,7 @@ class BaseRepo(object):
         repository
     """
 
-    def __init__(self, object_store: BaseObjectStore, refs: RefsContainer):
+    def __init__(self, object_store: PackBasedObjectStore, refs: RefsContainer):
         """Open a repository.
 
         This shouldn't be called directly, but rather through one of the
@@ -373,8 +376,8 @@ class BaseRepo(object):
         self.object_store = object_store
         self.refs = refs
 
-        self._graftpoints = {}  # type: Dict[bytes, List[bytes]]
-        self.hooks = {}  # type: Dict[str, Hook]
+        self._graftpoints: Dict[bytes, List[bytes]] = {}
+        self.hooks: Dict[str, Hook] = {}
 
     def _determine_file_mode(self) -> bool:
         """Probe the file-system to determine whether permissions can be trusted.
@@ -482,20 +485,23 @@ class BaseRepo(object):
           depth: Shallow fetch depth
         Returns: count and iterator over pack data
         """
-        # TODO(jelmer): Fetch pack data directly, don't create objects first.
-        objects = self.fetch_objects(
+        missing_objects = self.find_missing_objects(
             determine_wants, graph_walker, progress, get_tagged, depth=depth
         )
-        return pack_objects_to_data(objects)
+        remote_has = missing_objects.get_remote_has()
+        object_ids = list(missing_objects)
+        return len(object_ids), generate_unpacked_objects(
+            self.object_store, object_ids, progress=progress,
+            other_haves=remote_has)
 
-    def fetch_objects(
+    def find_missing_objects(
         self,
         determine_wants,
         graph_walker,
         progress,
         get_tagged=None,
         depth=None,
-    ):
+    ) -> Optional[MissingObjectFinder]:
         """Fetch the missing objects required for a set of revisions.
 
         Args:
@@ -534,8 +540,8 @@ class BaseRepo(object):
         if not isinstance(wants, list):
             raise TypeError("determine_wants() did not return a list")
 
-        shallows = getattr(graph_walker, "shallow", frozenset())
-        unshallows = getattr(graph_walker, "unshallow", frozenset())
+        shallows: FrozenSet[ObjectID] = getattr(graph_walker, "shallow", frozenset())
+        unshallows: FrozenSet[ObjectID] = getattr(graph_walker, "unshallow", frozenset())
 
         if wants == []:
             # TODO(dborowitz): find a way to short-circuit that doesn't change
@@ -545,7 +551,18 @@ class BaseRepo(object):
                 # Do not send a pack in shallow short-circuit path
                 return None
 
-            return []
+            class DummyMissingObjectFinder:
+
+                def get_remote_has(self):
+                    return None
+
+                def __len__(self):
+                    return 0
+
+                def __iter__(self):
+                    yield from []
+
+            return DummyMissingObjectFinder()  # type: ignore
 
         # If the graph walker is set up with an implementation that can
         # ACK/NAK to the wire, it will write data to the client through
@@ -564,16 +581,14 @@ class BaseRepo(object):
         def get_parents(commit):
             return parents_provider.get_parents(commit.id, commit)
 
-        return self.object_store.iter_shas(
-            self.object_store.find_missing_objects(
-                haves,
-                wants,
-                self.get_shallow(),
-                progress,
-                get_tagged,
-                get_parents=get_parents,
-            )
-        )
+        return MissingObjectFinder(
+            self.object_store,
+            haves=haves,
+            wants=wants,
+            shallow=self.get_shallow(),
+            progress=progress,
+            get_tagged=get_tagged,
+            get_parents=get_parents)
 
     def generate_pack_data(self, have: List[ObjectID], want: List[ObjectID],
                            progress: Optional[Callable[[str], None]] = None,
@@ -595,7 +610,8 @@ class BaseRepo(object):
         )
 
     def get_graph_walker(
-            self, heads: List[ObjectID] = None) -> ObjectStoreGraphWalker:
+            self,
+            heads: Optional[List[ObjectID]] = None) -> ObjectStoreGraphWalker:
         """Retrieve a graph walker.
 
         A graph walker is used by a remote repository (or proxy)
@@ -641,7 +657,7 @@ class BaseRepo(object):
                 raise NotTagError(ret)
             else:
                 raise Exception(
-                    "Type invalid: %r != %r" % (ret.type_name, cls.type_name)
+                    "Type invalid: {!r} != {!r}".format(ret.type_name, cls.type_name)
                 )
         return ret
 
@@ -663,7 +679,8 @@ class BaseRepo(object):
             shallows=self.get_shallow(),
         )
 
-    def get_parents(self, sha: bytes, commit: Commit = None) -> List[bytes]:
+    def get_parents(self, sha: bytes,
+                    commit: Optional[Commit] = None) -> List[bytes]:
         """Retrieve the parents of a specific commit.
 
         If the specific commit is a graftpoint, the graft parents
@@ -755,7 +772,7 @@ class BaseRepo(object):
         cached = self.refs.get_peeled(ref)
         if cached is not None:
             return cached
-        return self.object_store.peel_sha(self.refs[ref]).id
+        return peel_sha(self.object_store, self.refs[ref]).id
 
     def get_walker(self, include: Optional[List[bytes]] = None,
                    *args, **kwargs):
@@ -855,7 +872,8 @@ class BaseRepo(object):
         else:
             raise ValueError(name)
 
-    def _get_user_identity(self, config: "StackedConfig", kind: str = None) -> bytes:
+    def _get_user_identity(self, config: "StackedConfig",
+                           kind: Optional[str] = None) -> bytes:
         """Determine the identity to use for new commits."""
         # TODO(jelmer): Deprecate this function in favor of get_user_identity
         return get_user_identity(config)
@@ -1110,7 +1128,7 @@ class Repo(BaseRepo):
     def __init__(
         self,
         root: str,
-        object_store: Optional[BaseObjectStore] = None,
+        object_store: Optional[PackBasedObjectStore] = None,
         bare: Optional[bool] = None
     ) -> None:
         self.symlink_fn = None
@@ -1130,7 +1148,7 @@ class Repo(BaseRepo):
         self.bare = bare
         if bare is False:
             if os.path.isfile(hidden_path):
-                with open(hidden_path, "r") as f:
+                with open(hidden_path) as f:
                     path = read_gitfile(f)
                 self._controldir = os.path.join(root, path)
             else:
@@ -1427,7 +1445,9 @@ class Repo(BaseRepo):
         for fs_path in fs_paths:
             tree_path = _fs_to_tree_path(fs_path)
             try:
-                tree_entry = self.object_store[tree_id].lookup_path(
+                tree = self.object_store[tree_id]
+                assert isinstance(tree, Tree)
+                tree_entry = tree.lookup_path(
                     self.object_store.__getitem__, tree_path)
             except KeyError:
                 # if tree_entry didn't exist, this file was being added, so
@@ -1490,9 +1510,7 @@ class Repo(BaseRepo):
         Returns: Created repository as `Repo`
         """
 
-        encoded_path = self.path
-        if not isinstance(encoded_path, bytes):
-            encoded_path = os.fsencode(encoded_path)
+        encoded_path = os.fsencode(self.path)
 
         if mkdir:
             os.mkdir(target_path)

+ 80 - 55
dulwich/server.py

@@ -1,6 +1,6 @@
 # server.py -- Implementation of the server side git protocols
 # Copyright (C) 2008 John Carr <john.carr@unrouted.co.uk>
-# Coprygith (C) 2011-2012 Jelmer Vernooij <jelmer@jelmer.uk>
+# Copyright(C) 2011-2012 Jelmer Vernooij <jelmer@jelmer.uk>
 #
 # Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
 # General Public License as public by the Free Software Foundation; version 2.0
@@ -43,11 +43,18 @@ Currently supported capabilities:
 """
 
 import collections
+from functools import partial
 import os
 import socket
 import sys
 import time
-from typing import List, Tuple, Dict, Optional, Iterable
+from typing import List, Tuple, Dict, Optional, Iterable, Set
+
+try:
+    from typing import Protocol as TypingProtocol
+except ImportError:  # python < 3.8
+    from typing_extensions import Protocol as TypingProtocol  # type: ignore
+
 import zlib
 
 import socketserver
@@ -65,10 +72,16 @@ from dulwich.errors import (
 from dulwich import log_utils
 from dulwich.objects import (
     Commit,
+    ObjectID,
     valid_hexsha,
 )
+from dulwich.object_store import (
+    peel_sha,
+)
 from dulwich.pack import (
-    write_pack_objects,
+    write_pack_from_container,
+    ObjectContainer,
+    PackedObjectContainer,
 )
 from dulwich.protocol import (
     BufferedPktLineWriter,
@@ -114,6 +127,7 @@ from dulwich.protocol import (
     NAK_LINE,
 )
 from dulwich.refs import (
+    RefsContainer,
     ANNOTATED_TAG_SUFFIX,
     write_info_refs,
 )
@@ -126,7 +140,7 @@ from dulwich.repo import (
 logger = log_utils.getLogger(__name__)
 
 
-class Backend(object):
+class Backend:
     """A backend for the Git smart server implementation."""
 
     def open_repository(self, path):
@@ -141,15 +155,15 @@ class Backend(object):
         raise NotImplementedError(self.open_repository)
 
 
-class BackendRepo(object):
+class BackendRepo(TypingProtocol):
     """Repository abstraction used by the Git server.
 
     The methods required here are a subset of those provided by
     dulwich.repo.Repo.
     """
 
-    object_store = None
-    refs = None
+    object_store: PackedObjectContainer
+    refs: RefsContainer
 
     def get_refs(self) -> Dict[bytes, bytes]:
         """
@@ -171,7 +185,7 @@ class BackendRepo(object):
         """
         return None
 
-    def fetch_objects(self, determine_wants, graph_walker, progress, get_tagged=None):
+    def find_missing_objects(self, determine_wants, graph_walker, progress, get_tagged=None):
         """
         Yield the objects required for a list of commits.
 
@@ -203,7 +217,7 @@ class FileSystemBackend(Backend):
     """Simple backend looking up Git repositories in the local file system."""
 
     def __init__(self, root=os.sep):
-        super(FileSystemBackend, self).__init__()
+        super().__init__()
         self.root = (os.path.abspath(root) + os.sep).replace(os.sep * 2, os.sep)
 
     def open_repository(self, path):
@@ -212,11 +226,11 @@ class FileSystemBackend(Backend):
         normcase_abspath = os.path.normcase(abspath)
         normcase_root = os.path.normcase(self.root)
         if not normcase_abspath.startswith(normcase_root):
-            raise NotGitRepository("Path %r not inside root %r" % (path, self.root))
+            raise NotGitRepository("Path {!r} not inside root {!r}".format(path, self.root))
         return Repo(abspath)
 
 
-class Handler(object):
+class Handler:
     """Smart protocol command handler base class."""
 
     def __init__(self, backend, proto, stateless_rpc=False):
@@ -232,7 +246,7 @@ class PackHandler(Handler):
     """Protocol handler for packs."""
 
     def __init__(self, backend, proto, stateless_rpc=False):
-        super(PackHandler, self).__init__(backend, proto, stateless_rpc)
+        super().__init__(backend, proto, stateless_rpc)
         self._client_capabilities = None
         # Flags needed for the no-done capability
         self._done_received = False
@@ -289,7 +303,7 @@ class UploadPackHandler(PackHandler):
     """Protocol handler for uploading a pack to the client."""
 
     def __init__(self, backend, args, proto, stateless_rpc=False, advertise_refs=False):
-        super(UploadPackHandler, self).__init__(
+        super().__init__(
             backend, proto, stateless_rpc=stateless_rpc
         )
         self.repo = backend.open_repository(args[0])
@@ -322,12 +336,21 @@ class UploadPackHandler(PackHandler):
             CAPABILITY_OFS_DELTA,
         )
 
-    def progress(self, message):
-        if self.has_capability(CAPABILITY_NO_PROGRESS) or self._processing_have_lines:
-            return
-        self.proto.write_sideband(SIDE_BAND_CHANNEL_PROGRESS, message)
+    def progress(self, message: bytes):
+        pass
+
+    def _start_pack_send_phase(self):
+        if self.has_capability(CAPABILITY_SIDE_BAND_64K):
+            # The provided haves are processed, and it is safe to send side-
+            # band data now.
+            if not self.has_capability(CAPABILITY_NO_PROGRESS):
+                self.progress = partial(self.proto.write_sideband, SIDE_BAND_CHANNEL_PROGRESS)
 
-    def get_tagged(self, refs=None, repo=None):
+            self.write_pack_data = partial(self.proto.write_sideband, SIDE_BAND_CHANNEL_DATA)
+        else:
+            self.write_pack_data = self.proto.write
+
+    def get_tagged(self, refs=None, repo=None) -> Dict[ObjectID, ObjectID]:
         """Get a dict of peeled values of tags to their original tag shas.
 
         Args:
@@ -351,7 +374,7 @@ class UploadPackHandler(PackHandler):
                 # TODO: fix behavior when missing
                 return {}
         # TODO(jelmer): Integrate this with the refs logic in
-        # Repo.fetch_objects
+        # Repo.find_missing_objects
         tagged = {}
         for name, sha in refs.items():
             peeled_sha = repo.get_peeled(name)
@@ -360,8 +383,10 @@ class UploadPackHandler(PackHandler):
         return tagged
 
     def handle(self):
-        def write(x):
-            return self.proto.write_sideband(SIDE_BAND_CHANNEL_DATA, x)
+        # Note the fact that client is only processing responses related
+        # to the have lines it sent, and any other data (including side-
+        # band) will be be considered a fatal error.
+        self._processing_have_lines = True
 
         graph_walker = _ProtocolGraphWalker(
             self,
@@ -375,17 +400,14 @@ class UploadPackHandler(PackHandler):
             wants.extend(graph_walker.determine_wants(refs, **kwargs))
             return wants
 
-        objects_iter = self.repo.fetch_objects(
+        missing_objects = self.repo.find_missing_objects(
             wants_wrapper,
             graph_walker,
             self.progress,
             get_tagged=self.get_tagged,
         )
 
-        # Note the fact that client is only processing responses related
-        # to the have lines it sent, and any other data (including side-
-        # band) will be be considered a fatal error.
-        self._processing_have_lines = True
+        object_ids = list(missing_objects)
 
         # Did the process short-circuit (e.g. in a stateless RPC call)? Note
         # that the client still expects a 0-object pack in most cases.
@@ -396,19 +418,17 @@ class UploadPackHandler(PackHandler):
         if len(wants) == 0:
             return
 
-        # The provided haves are processed, and it is safe to send side-
-        # band data now.
-        self._processing_have_lines = False
-
         if not graph_walker.handle_done(
             not self.has_capability(CAPABILITY_NO_DONE), self._done_received
         ):
             return
 
+        self._start_pack_send_phase()
         self.progress(
-            ("counting objects: %d, done.\n" % len(objects_iter)).encode("ascii")
+            ("counting objects: %d, done.\n" % len(object_ids)).encode("ascii")
         )
-        write_pack_objects(write, objects_iter)
+
+        write_pack_from_container(self.write_pack_data, self.repo.object_store, object_ids)
         # we are done
         self.proto.write_pkt_line(None)
 
@@ -456,7 +476,7 @@ def _split_proto_line(line, allowed):
     raise GitProtocolError("Received invalid line from client: %r" % line)
 
 
-def _find_shallow(store, heads, depth):
+def _find_shallow(store: ObjectContainer, heads, depth):
     """Find shallow commits according to a given depth.
 
     Args:
@@ -468,7 +488,7 @@ def _find_shallow(store, heads, depth):
         considered shallow and unshallow according to the arguments. Note that
         these sets may overlap if a commit is reachable along multiple paths.
     """
-    parents = {}
+    parents: Dict[bytes, List[bytes]] = {}
 
     def get_parents(sha):
         result = parents.get(sha, None)
@@ -479,7 +499,7 @@ def _find_shallow(store, heads, depth):
 
     todo = []  # stack of (sha, depth)
     for head_sha in heads:
-        obj = store.peel_sha(head_sha)
+        obj = peel_sha(store, head_sha)
         if isinstance(obj, Commit):
             todo.append((obj.id, 1))
 
@@ -497,15 +517,15 @@ def _find_shallow(store, heads, depth):
     return shallow, not_shallow
 
 
-def _want_satisfied(store, haves, want, earliest):
+def _want_satisfied(store: ObjectContainer, haves, want, earliest):
     o = store[want]
     pending = collections.deque([o])
-    known = set([want])
+    known = {want}
     while pending:
         commit = pending.popleft()
         if commit.id in haves:
             return True
-        if commit.type_name != b"commit":
+        if not isinstance(commit, Commit):
             # non-commit wants are assumed to be satisfied
             continue
         for parent in commit.parents:
@@ -513,13 +533,14 @@ def _want_satisfied(store, haves, want, earliest):
                 continue
             known.add(parent)
             parent_obj = store[parent]
+            assert isinstance(parent_obj, Commit)
             # TODO: handle parents with later commit times than children
             if parent_obj.commit_time >= earliest:
                 pending.append(parent_obj)
     return False
 
 
-def _all_wants_satisfied(store, haves, wants):
+def _all_wants_satisfied(store: ObjectContainer, haves, wants):
     """Check whether all the current wants are satisfied by a set of haves.
 
     Args:
@@ -531,7 +552,8 @@ def _all_wants_satisfied(store, haves, wants):
     """
     haves = set(haves)
     if haves:
-        earliest = min([store[h].commit_time for h in haves])
+        have_objs = [store[h] for h in haves]
+        earliest = min([h.commit_time for h in have_objs if isinstance(h, Commit)])
     else:
         earliest = 0
     for want in wants:
@@ -541,7 +563,7 @@ def _all_wants_satisfied(store, haves, wants):
     return True
 
 
-class _ProtocolGraphWalker(object):
+class _ProtocolGraphWalker:
     """A graph walker that knows the git protocol.
 
     As a graph walker, this class implements ack(), next(), and reset(). It
@@ -555,20 +577,20 @@ class _ProtocolGraphWalker(object):
     any calls to next() or ack() are made.
     """
 
-    def __init__(self, handler, object_store, get_peeled, get_symrefs):
+    def __init__(self, handler, object_store: ObjectContainer, get_peeled, get_symrefs):
         self.handler = handler
-        self.store = object_store
+        self.store: ObjectContainer = object_store
         self.get_peeled = get_peeled
         self.get_symrefs = get_symrefs
         self.proto = handler.proto
         self.stateless_rpc = handler.stateless_rpc
         self.advertise_refs = handler.advertise_refs
-        self._wants = []
-        self.shallow = set()
-        self.client_shallow = set()
-        self.unshallow = set()
+        self._wants: List[bytes] = []
+        self.shallow: Set[bytes] = set()
+        self.client_shallow: Set[bytes] = set()
+        self.unshallow: Set[bytes] = set()
         self._cached = False
-        self._cache = []
+        self._cache: List[bytes] = []
         self._cache_index = 0
         self._impl = None
 
@@ -598,7 +620,7 @@ class _ProtocolGraphWalker(object):
                     peeled_sha = self.get_peeled(ref)
                 except KeyError:
                     # Skip refs that are inaccessible
-                    # TODO(jelmer): Integrate with Repo.fetch_objects refs
+                    # TODO(jelmer): Integrate with Repo.find_missing_objects refs
                     # logic.
                     continue
                 if i == 0:
@@ -657,6 +679,9 @@ class _ProtocolGraphWalker(object):
             value = str(value).encode("ascii")
         self.proto.unread_pkt_line(command + b" " + value)
 
+    def nak(self):
+        pass
+
     def ack(self, have_ref):
         if len(have_ref) != 40:
             raise ValueError("invalid sha %r" % have_ref)
@@ -752,7 +777,7 @@ class _ProtocolGraphWalker(object):
 _GRAPH_WALKER_COMMANDS = (COMMAND_HAVE, COMMAND_DONE, None)
 
 
-class SingleAckGraphWalkerImpl(object):
+class SingleAckGraphWalkerImpl:
     """Graph walker implementation that speaks the single-ack protocol."""
 
     def __init__(self, walker):
@@ -796,7 +821,7 @@ class SingleAckGraphWalkerImpl(object):
         return True
 
 
-class MultiAckGraphWalkerImpl(object):
+class MultiAckGraphWalkerImpl:
     """Graph walker implementation that speaks the multi-ack protocol."""
 
     def __init__(self, walker):
@@ -855,7 +880,7 @@ class MultiAckGraphWalkerImpl(object):
         return True
 
 
-class MultiAckDetailedGraphWalkerImpl(object):
+class MultiAckDetailedGraphWalkerImpl:
     """Graph walker implementation speaking the multi-ack-detailed protocol."""
 
     def __init__(self, walker):
@@ -924,7 +949,7 @@ class ReceivePackHandler(PackHandler):
     """Protocol handler for downloading a pack from the client."""
 
     def __init__(self, backend, args, proto, stateless_rpc=False, advertise_refs=False):
-        super(ReceivePackHandler, self).__init__(
+        super().__init__(
             backend, proto, stateless_rpc=stateless_rpc
         )
         self.repo = backend.open_repository(args[0])
@@ -1088,7 +1113,7 @@ class ReceivePackHandler(PackHandler):
 
 class UploadArchiveHandler(Handler):
     def __init__(self, backend, args, proto, stateless_rpc=False):
-        super(UploadArchiveHandler, self).__init__(backend, proto, stateless_rpc)
+        super().__init__(backend, proto, stateless_rpc)
         self.repo = backend.open_repository(args[0])
 
     def handle(self):
@@ -1104,7 +1129,7 @@ class UploadArchiveHandler(Handler):
         prefix = b""
         format = "tar"
         i = 0
-        store = self.repo.object_store
+        store: ObjectContainer = self.repo.object_store
         while i < len(arguments):
             argument = arguments[i]
             if argument == b"--prefix":

+ 1 - 2
dulwich/stash.py

@@ -20,7 +20,6 @@
 
 """Stash handling."""
 
-from __future__ import absolute_import
 
 import os
 
@@ -35,7 +34,7 @@ from dulwich.reflog import drop_reflog_entry, read_reflog
 DEFAULT_STASH_REF = b"refs/stash"
 
 
-class Stash(object):
+class Stash:
     """A Git stash.
 
     Note that this doesn't currently update the working tree.

+ 2 - 1
dulwich/submodule.py

@@ -22,6 +22,7 @@
 """
 
 from typing import Iterator, Tuple
+from .object_store import iter_tree_contents
 from .objects import S_ISGITLINK
 
 
@@ -35,6 +36,6 @@ def iter_cached_submodules(store, root_tree_id: bytes) -> Iterator[Tuple[str, by
     Returns:
       Iterator over over (path, sha) tuples
     """
-    for entry in store.iter_tree_contents(root_tree_id):
+    for entry in iter_tree_contents(store, root_tree_id):
         if S_ISGITLINK(entry.mode):
             yield entry.path, entry.sha

+ 16 - 10
dulwich/tests/__init__.py

@@ -48,17 +48,23 @@ from unittest import (  # noqa: F401
 
 class TestCase(_TestCase):
     def setUp(self):
-        super(TestCase, self).setUp()
-        self._old_home = os.environ.get("HOME")
-        os.environ["HOME"] = "/nonexistent"
-        os.environ["GIT_CONFIG_NOSYSTEM"] = "1"
-
-    def tearDown(self):
-        super(TestCase, self).tearDown()
-        if self._old_home:
-            os.environ["HOME"] = self._old_home
+        super().setUp()
+        self.overrideEnv("HOME", "/nonexistent")
+        self.overrideEnv("GIT_CONFIG_NOSYSTEM", "1")
+
+    def overrideEnv(self, name, value):
+        def restore():
+            if oldval is not None:
+                os.environ[name] = oldval
+            else:
+                del os.environ[name]
+
+        oldval = os.environ.get(name)
+        if value is not None:
+            os.environ[name] = value
         else:
-            del os.environ["HOME"]
+            del os.environ[name]
+        self.addCleanup(restore)
 
 
 class BlackboxTestCase(TestCase):

+ 4 - 4
dulwich/tests/compat/server_utils.py

@@ -43,7 +43,7 @@ from dulwich.tests.compat.utils import (
 from dulwich.tests.compat.utils import require_git_version
 
 
-class _StubRepo(object):
+class _StubRepo:
     """A stub repo that just contains a path to tear down."""
 
     def __init__(self, name):
@@ -70,7 +70,7 @@ def _get_shallow(repo):
     return shallows
 
 
-class ServerTests(object):
+class ServerTests:
     """Base tests for testing servers.
 
     Does not inherit from TestCase so tests are not automatically run.
@@ -87,12 +87,12 @@ class ServerTests(object):
         self._new_repo = self.import_repo("server_new.export")
 
     def url(self, port):
-        return "%s://localhost:%s/" % (self.protocol, port)
+        return "{}://localhost:{}/".format(self.protocol, port)
 
     def branch_args(self, branches=None):
         if branches is None:
             branches = ["master", "branch"]
-        return ["%s:%s" % (b, b) for b in branches]
+        return ["{}:{}".format(b, b) for b in branches]
 
     def test_push_to_dulwich(self):
         self.import_repos()

+ 10 - 12
dulwich/tests/compat/test_client.py

@@ -62,7 +62,7 @@ if sys.platform == "win32":
     import ctypes
 
 
-class DulwichClientTestBase(object):
+class DulwichClientTestBase:
     """Tests for client/server compatibility."""
 
     def setUp(self):
@@ -248,12 +248,10 @@ class DulwichClientTestBase(object):
                 dest.refs.set_if_equals(r[0], None, r[1])
             self.assertEqual(
                 dest.get_shallow(),
-                set(
-                    [
-                        b"35e0b59e187dd72a0af294aedffc213eaa4d03ff",
-                        b"514dc6d3fbfe77361bcaef320c4d21b72bc10be9",
-                    ]
-                ),
+                {
+                    b"35e0b59e187dd72a0af294aedffc213eaa4d03ff",
+                    b"514dc6d3fbfe77361bcaef320c4d21b72bc10be9",
+                },
             )
 
     def test_repeat(self):
@@ -331,7 +329,7 @@ class DulwichClientTestBase(object):
             sendrefs[b"refs/heads/abranch"] = b"00" * 20
             del sendrefs[b"HEAD"]
 
-            def gen_pack(have, want, ofs_delta=False):
+            def gen_pack(have, want, ofs_delta=False, progress=None):
                 return 0, []
 
             c = self._client()
@@ -346,7 +344,7 @@ class DulwichClientTestBase(object):
             dest.refs[b"refs/heads/abranch"] = dummy_commit
             sendrefs = {b"refs/heads/bbranch": dummy_commit}
 
-            def gen_pack(have, want, ofs_delta=False):
+            def gen_pack(have, want, ofs_delta=False, progress=None):
                 return 0, []
 
             c = self._client()
@@ -409,7 +407,7 @@ class DulwichTCPClientTest(CompatTestCase, DulwichClientTestBase):
             try:
                 os.kill(pid, signal.SIGKILL)
                 os.unlink(self.pidfile)
-            except (OSError, IOError):
+            except OSError:
                 pass
         self.process.wait()
         self.process.stdout.close()
@@ -435,7 +433,7 @@ class DulwichTCPClientTest(CompatTestCase, DulwichClientTestBase):
         self.skipTest('skip flaky test; see #1015')
 
 
-class TestSSHVendor(object):
+class TestSSHVendor:
     @staticmethod
     def run_command(
         host,
@@ -648,7 +646,7 @@ class HTTPGitServer(http.server.HTTPServer):
         self.server_name = "localhost"
 
     def get_url(self):
-        return "http://%s:%s/" % (self.server_name, self.server_port)
+        return "http://{}:{}/".format(self.server_name, self.server_port)
 
 
 class DulwichHttpClientTest(CompatTestCase, DulwichClientTestBase):

+ 6 - 6
dulwich/tests/compat/test_pack.py

@@ -66,7 +66,7 @@ class TestPack(PackTests):
 
     def setUp(self):
         require_git_version((1, 5, 0))
-        super(TestPack, self).setUp()
+        super().setUp()
         self._tempdir = tempfile.mkdtemp()
         self.addCleanup(shutil.rmtree, self._tempdir)
 
@@ -84,7 +84,7 @@ class TestPack(PackTests):
             orig_blob = orig_pack[a_sha]
             new_blob = Blob()
             new_blob.data = orig_blob.data + b"x"
-            all_to_pack = list(orig_pack.pack_tuples()) + [(new_blob, None)]
+            all_to_pack = [(o, None) for o in orig_pack.iterobjects()] + [(new_blob, None)]
         pack_path = os.path.join(self._tempdir, "pack_with_deltas")
         write_pack(pack_path, all_to_pack, deltify=True)
         output = run_git_or_fail(["verify-pack", "-v", pack_path])
@@ -115,8 +115,8 @@ class TestPack(PackTests):
                 (new_blob, None),
                 (new_blob_2, None),
             ]
-        pack_path = os.path.join(self._tempdir, "pack_with_deltas")
-        write_pack(pack_path, all_to_pack, deltify=True)
+            pack_path = os.path.join(self._tempdir, "pack_with_deltas")
+            write_pack(pack_path, all_to_pack, deltify=True)
         output = run_git_or_fail(["verify-pack", "-v", pack_path])
         self.assertEqual(
             {x[0].id for x in all_to_pack},
@@ -154,8 +154,8 @@ class TestPack(PackTests):
                 (new_blob, None),
                 (new_blob_2, None),
             ]
-        pack_path = os.path.join(self._tempdir, "pack_with_deltas")
-        write_pack(pack_path, all_to_pack, deltify=True)
+            pack_path = os.path.join(self._tempdir, "pack_with_deltas")
+            write_pack(pack_path, all_to_pack, deltify=True)
         output = run_git_or_fail(["verify-pack", "-v", pack_path])
         self.assertEqual(
             {x[0].id for x in all_to_pack},

+ 1 - 1
dulwich/tests/compat/test_patch.py

@@ -36,7 +36,7 @@ from dulwich.tests.compat.utils import (
 
 class CompatPatchTestCase(CompatTestCase):
     def setUp(self):
-        super(CompatPatchTestCase, self).setUp()
+        super().setUp()
         self.test_dir = tempfile.mkdtemp()
         self.addCleanup(shutil.rmtree, self.test_dir)
         self.repo_path = os.path.join(self.test_dir, "repo")

+ 3 - 3
dulwich/tests/compat/test_porcelain.py

@@ -41,7 +41,7 @@ from dulwich.tests.test_porcelain import (
 @skipIf(platform.python_implementation() == "PyPy" or sys.platform == "win32", "gpgme not easily available or supported on Windows and PyPy")
 class TagCreateSignTestCase(PorcelainGpgTestCase, CompatTestCase):
     def setUp(self):
-        super(TagCreateSignTestCase, self).setUp()
+        super().setUp()
 
     def test_sign(self):
         # Test that dulwich signatures can be verified by CGit
@@ -64,7 +64,7 @@ class TagCreateSignTestCase(PorcelainGpgTestCase, CompatTestCase):
 
         run_git_or_fail(
             [
-                "--git-dir={}".format(self.repo.controldir()),
+                f"--git-dir={self.repo.controldir()}",
                 "tag",
                 "-v",
                 "tryme"
@@ -82,7 +82,7 @@ class TagCreateSignTestCase(PorcelainGpgTestCase, CompatTestCase):
 
         run_git_or_fail(
             [
-                "--git-dir={}".format(self.repo.controldir()),
+                f"--git-dir={self.repo.controldir()}",
                 "tag",
                 "-u",
                 PorcelainGpgTestCase.DEFAULT_KEY_ID,

+ 4 - 4
dulwich/tests/compat/test_repository.py

@@ -45,7 +45,7 @@ class ObjectStoreTestCase(CompatTestCase):
     """Tests for git repository compatibility."""
 
     def setUp(self):
-        super(ObjectStoreTestCase, self).setUp()
+        super().setUp()
         self._repo = self.import_repo("server_new.export")
 
     def _run_git(self, args):
@@ -147,7 +147,7 @@ class WorkingTreeTestCase(ObjectStoreTestCase):
         return temp_dir
 
     def setUp(self):
-        super(WorkingTreeTestCase, self).setUp()
+        super().setUp()
         self._worktree_path = self.create_new_worktree(self._repo.path, "branch")
         self._worktree_repo = Repo(self._worktree_path)
         self.addCleanup(self._worktree_repo.close)
@@ -156,7 +156,7 @@ class WorkingTreeTestCase(ObjectStoreTestCase):
         self._repo = self._worktree_repo
 
     def test_refs(self):
-        super(WorkingTreeTestCase, self).test_refs()
+        super().test_refs()
         self.assertEqual(
             self._mainworktree_repo.refs.allkeys(), self._repo.refs.allkeys()
         )
@@ -225,7 +225,7 @@ class InitNewWorkingDirectoryTestCase(WorkingTreeTestCase):
     min_git_version = (2, 5, 0)
 
     def setUp(self):
-        super(InitNewWorkingDirectoryTestCase, self).setUp()
+        super().setUp()
         self._other_worktree = self._repo
         worktree_repo_path = tempfile.mkdtemp()
         self.addCleanup(rmtree_ro, worktree_repo_path)

+ 1 - 1
dulwich/tests/compat/test_server.py

@@ -81,7 +81,7 @@ class GitServerSideBand64kTestCase(GitServerTestCase):
     min_git_version = (1, 7, 0, 2)
 
     def setUp(self):
-        super(GitServerSideBand64kTestCase, self).setUp()
+        super().setUp()
         # side-band-64k is broken in the windows client.
         # https://github.com/msysgit/git/issues/101
         # Fix has landed for the 1.9.3 release.

+ 2 - 2
dulwich/tests/compat/test_utils.py

@@ -29,7 +29,7 @@ from dulwich.tests.compat import utils
 
 class GitVersionTests(TestCase):
     def setUp(self):
-        super(GitVersionTests, self).setUp()
+        super().setUp()
         self._orig_run_git = utils.run_git
         self._version_str = None  # tests can override to set stub version
 
@@ -40,7 +40,7 @@ class GitVersionTests(TestCase):
         utils.run_git = run_git
 
     def tearDown(self):
-        super(GitVersionTests, self).tearDown()
+        super().tearDown()
         utils.run_git = self._orig_run_git
 
     def test_git_version_none(self):

+ 3 - 3
dulwich/tests/compat/test_web.py

@@ -90,7 +90,7 @@ class SmartWebTestCase(WebTests, CompatTestCase):
     This server test case does not use side-band-64k in git-receive-pack.
     """
 
-    min_git_version = (1, 6, 6)  # type: Tuple[int, ...]
+    min_git_version: Tuple[int, ...] = (1, 6, 6)
 
     def _handlers(self):
         return {b"git-receive-pack": NoSideBand64kReceivePackHandler}
@@ -135,10 +135,10 @@ class SmartWebSideBand64kTestCase(SmartWebTestCase):
     def setUp(self):
         self.o_uph_cap = patch_capabilities(UploadPackHandler, (b"no-done",))
         self.o_rph_cap = patch_capabilities(ReceivePackHandler, (b"no-done",))
-        super(SmartWebSideBand64kTestCase, self).setUp()
+        super().setUp()
 
     def tearDown(self):
-        super(SmartWebSideBand64kTestCase, self).tearDown()
+        super().tearDown()
         UploadPackHandler.capabilities = self.o_uph_cap
         ReceivePackHandler.capabilities = self.o_rph_cap
 

+ 5 - 5
dulwich/tests/compat/utils.py

@@ -94,7 +94,7 @@ def require_git_version(required_version, git_path=_DEFAULT_GIT):
     found_version = git_version(git_path=git_path)
     if found_version is None:
         raise SkipTest(
-            "Test requires git >= %s, but c git not found" % (required_version,)
+            "Test requires git >= {}, but c git not found".format(required_version)
         )
 
     if len(required_version) > _VERSION_LEN:
@@ -112,7 +112,7 @@ def require_git_version(required_version, git_path=_DEFAULT_GIT):
         required_version = ".".join(map(str, required_version))
         found_version = ".".join(map(str, found_version))
         raise SkipTest(
-            "Test requires git >= %s, found %s" % (required_version, found_version)
+            "Test requires git >= {}, found {}".format(required_version, found_version)
         )
 
 
@@ -216,7 +216,7 @@ def check_for_daemon(limit=10, delay=0.1, timeout=0.1, port=TCP_GIT_PORT):
             return True
         except socket.timeout:
             pass
-        except socket.error as e:
+        except OSError as e:
             if getattr(e, "errno", False) and e.errno != errno.ECONNREFUSED:
                 raise
             elif e.args[0] != errno.ECONNREFUSED:
@@ -233,10 +233,10 @@ class CompatTestCase(TestCase):
     min_git_version.
     """
 
-    min_git_version = (1, 5, 0)  # type: Tuple[int, ...]
+    min_git_version: Tuple[int, ...] = (1, 5, 0)
 
     def setUp(self):
-        super(CompatTestCase, self).setUp()
+        super().setUp()
         require_git_version(self.min_git_version)
 
     def assertObjectStoreEqual(self, store1, store2):

+ 1 - 1
dulwich/tests/test_archive.py

@@ -78,7 +78,7 @@ class ArchiveTests(TestCase):
         b1 = Blob.from_string(b"somedata")
         store.add_object(b1)
         t1 = Tree()
-        t1.add("ő".encode('utf-8'), 0o100644, b1.id)
+        t1.add("ő".encode(), 0o100644, b1.id)
         store.add_object(t1)
         stream = b"".join(tar_stream(store, t1, mtime=0))
         tf = tarfile.TarFile(fileobj=BytesIO(stream))

+ 2 - 2
dulwich/tests/test_blackbox.py

@@ -35,7 +35,7 @@ class GitReceivePackTests(BlackboxTestCase):
     """Blackbox tests for dul-receive-pack."""
 
     def setUp(self):
-        super(GitReceivePackTests, self).setUp()
+        super().setUp()
         self.path = tempfile.mkdtemp()
         self.addCleanup(shutil.rmtree, self.path)
         self.repo = Repo.init(self.path)
@@ -60,7 +60,7 @@ class GitUploadPackTests(BlackboxTestCase):
     """Blackbox tests for dul-upload-pack."""
 
     def setUp(self):
-        super(GitUploadPackTests, self).setUp()
+        super().setUp()
         self.path = tempfile.mkdtemp()
         self.addCleanup(shutil.rmtree, self.path)
         self.repo = Repo.init(self.path)

+ 9 - 0
dulwich/tests/test_bundle.py

@@ -20,6 +20,7 @@
 
 """Tests for bundle support."""
 
+from io import BytesIO
 import os
 import tempfile
 
@@ -32,6 +33,10 @@ from dulwich.bundle import (
     read_bundle,
     write_bundle,
 )
+from dulwich.pack import (
+    PackData,
+    write_pack_objects,
+)
 
 
 class BundleTests(TestCase):
@@ -41,6 +46,10 @@ class BundleTests(TestCase):
         origbundle.capabilities = {"foo": None}
         origbundle.references = {b"refs/heads/master": b"ab" * 20}
         origbundle.prerequisites = [(b"cc" * 20, "comment")]
+        b = BytesIO()
+        write_pack_objects(b.write, [])
+        b.seek(0)
+        origbundle.pack_data = PackData.from_file(b)
         with tempfile.TemporaryDirectory() as td:
             with open(os.path.join(td, "foo"), "wb") as f:
                 write_bundle(f, origbundle)

+ 189 - 47
dulwich/tests/test_client.py

@@ -118,7 +118,7 @@ class DummyPopen:
 # TODO(durin42): add unit-level tests of GitClient
 class GitClientTests(TestCase):
     def setUp(self):
-        super(GitClientTests, self).setUp()
+        super().setUp()
         self.rout = BytesIO()
         self.rin = BytesIO()
         self.client = DummyClient(lambda x: True, self.rin.read, self.rout.write)
@@ -126,29 +126,25 @@ class GitClientTests(TestCase):
     def test_caps(self):
         agent_cap = ("agent=dulwich/%d.%d.%d" % dulwich.__version__).encode("ascii")
         self.assertEqual(
-            set(
-                [
-                    b"multi_ack",
-                    b"side-band-64k",
-                    b"ofs-delta",
-                    b"thin-pack",
-                    b"multi_ack_detailed",
-                    b"shallow",
-                    agent_cap,
-                ]
-            ),
+            {
+                b"multi_ack",
+                b"side-band-64k",
+                b"ofs-delta",
+                b"thin-pack",
+                b"multi_ack_detailed",
+                b"shallow",
+                agent_cap,
+            },
             set(self.client._fetch_capabilities),
         )
         self.assertEqual(
-            set(
-                [
-                    b"delete-refs",
-                    b"ofs-delta",
-                    b"report-status",
-                    b"side-band-64k",
-                    agent_cap,
-                ]
-            ),
+            {
+                b"delete-refs",
+                b"ofs-delta",
+                b"report-status",
+                b"side-band-64k",
+                agent_cap,
+            },
             set(self.client._send_capabilities),
         )
 
@@ -204,7 +200,7 @@ class GitClientTests(TestCase):
         self.assertEqual({}, ret.symrefs)
         self.assertEqual(self.rout.getvalue(), b"0000")
 
-    def test_send_pack_no_sideband64k_with_update_ref_error(self):
+    def test_send_pack_no_sideband64k_with_update_ref_error(self) -> None:
         # No side-bank-64k reported by server shouldn't try to parse
         # side band data
         pkts = [
@@ -237,11 +233,11 @@ class GitClientTests(TestCase):
                 b"refs/foo/bar": commit.id,
             }
 
-        def generate_pack_data(have, want, ofs_delta=False):
+        def generate_pack_data(have, want, ofs_delta=False, progress=None):
             return pack_objects_to_data(
                 [
                     (commit, None),
-                    (tree, ""),
+                    (tree, b""),
                 ]
             )
 
@@ -264,7 +260,7 @@ class GitClientTests(TestCase):
         def update_refs(refs):
             return {b"refs/heads/master": b"310ca9477129b8586fa2afc779c1f57cf64bba6c"}
 
-        def generate_pack_data(have, want, ofs_delta=False):
+        def generate_pack_data(have, want, ofs_delta=False, progress=None):
             return 0, []
 
         self.client.send_pack(b"/", update_refs, generate_pack_data)
@@ -284,7 +280,7 @@ class GitClientTests(TestCase):
         def update_refs(refs):
             return {b"refs/heads/master": b"0" * 40}
 
-        def generate_pack_data(have, want, ofs_delta=False):
+        def generate_pack_data(have, want, ofs_delta=False, progress=None):
             return 0, []
 
         self.client.send_pack(b"/", update_refs, generate_pack_data)
@@ -308,7 +304,7 @@ class GitClientTests(TestCase):
         def update_refs(refs):
             return {b"refs/heads/master": b"0" * 40}
 
-        def generate_pack_data(have, want, ofs_delta=False):
+        def generate_pack_data(have, want, ofs_delta=False, progress=None):
             return 0, []
 
         self.client.send_pack(b"/", update_refs, generate_pack_data)
@@ -335,11 +331,11 @@ class GitClientTests(TestCase):
                 b"refs/heads/master": b"310ca9477129b8586fa2afc779c1f57cf64bba6c",
             }
 
-        def generate_pack_data(have, want, ofs_delta=False):
+        def generate_pack_data(have, want, ofs_delta=False, progress=None):
             return 0, []
 
         f = BytesIO()
-        write_pack_objects(f.write, {})
+        write_pack_objects(f.write, [])
         self.client.send_pack("/", update_refs, generate_pack_data)
         self.assertEqual(
             self.rout.getvalue(),
@@ -375,7 +371,7 @@ class GitClientTests(TestCase):
                 b"refs/heads/master": b"310ca9477129b8586fa2afc779c1f57cf64bba6c",
             }
 
-        def generate_pack_data(have, want, ofs_delta=False):
+        def generate_pack_data(have, want, ofs_delta=False, progress=None):
             return pack_objects_to_data(
                 [
                     (commit, None),
@@ -384,7 +380,8 @@ class GitClientTests(TestCase):
             )
 
         f = BytesIO()
-        write_pack_data(f.write, *generate_pack_data(None, None))
+        count, records = generate_pack_data(None, None)
+        write_pack_data(f.write, records, num_records=count)
         self.client.send_pack(b"/", update_refs, generate_pack_data)
         self.assertEqual(
             self.rout.getvalue(),
@@ -411,7 +408,7 @@ class GitClientTests(TestCase):
         def update_refs(refs):
             return {b"refs/heads/master": b"0" * 40}
 
-        def generate_pack_data(have, want, ofs_delta=False):
+        def generate_pack_data(have, want, ofs_delta=False, progress=None):
             return 0, []
 
         result = self.client.send_pack(b"/", update_refs, generate_pack_data)
@@ -720,7 +717,7 @@ class TestGetTransportAndPathFromUrl(TestCase):
                     c, path = get_transport_and_path(remote_url)
 
 
-class TestSSHVendor(object):
+class TestSSHVendor:
     def __init__(self):
         self.host = None
         self.command = ""
@@ -759,7 +756,7 @@ class TestSSHVendor(object):
 
 class SSHGitClientTests(TestCase):
     def setUp(self):
-        super(SSHGitClientTests, self).setUp()
+        super().setUp()
 
         self.server = TestSSHVendor()
         self.real_vendor = client.get_ssh_vendor
@@ -768,7 +765,7 @@ class SSHGitClientTests(TestCase):
         self.client = SSHGitClient("git.samba.org")
 
     def tearDown(self):
-        super(SSHGitClientTests, self).tearDown()
+        super().tearDown()
         client.get_ssh_vendor = self.real_vendor
 
     def test_get_url(self):
@@ -820,20 +817,17 @@ class SSHGitClientTests(TestCase):
         self.assertEqual("git-relative-command '~/path/to/repo'", server.command)
 
     def test_ssh_command_precedence(self):
-        os.environ["GIT_SSH"] = "/path/to/ssh"
+        self.overrideEnv("GIT_SSH", "/path/to/ssh")
         test_client = SSHGitClient("git.samba.org")
         self.assertEqual(test_client.ssh_command, "/path/to/ssh")
 
-        os.environ["GIT_SSH_COMMAND"] = "/path/to/ssh -o Option=Value"
+        self.overrideEnv("GIT_SSH_COMMAND", "/path/to/ssh -o Option=Value")
         test_client = SSHGitClient("git.samba.org")
         self.assertEqual(test_client.ssh_command, "/path/to/ssh -o Option=Value")
 
         test_client = SSHGitClient("git.samba.org", ssh_command="ssh -o Option1=Value1")
         self.assertEqual(test_client.ssh_command, "ssh -o Option1=Value1")
 
-        del os.environ["GIT_SSH"]
-        del os.environ["GIT_SSH_COMMAND"]
-
 
 class ReportStatusParserTests(TestCase):
     def test_invalid_pack(self):
@@ -868,7 +862,9 @@ class LocalGitClientTests(TestCase):
 
     def test_fetch_into_empty(self):
         c = LocalGitClient()
-        t = MemoryRepo()
+        target = tempfile.mkdtemp()
+        self.addCleanup(shutil.rmtree, target)
+        t = Repo.init_bare(target)
         s = open_repo("a.git")
         self.addCleanup(tear_down_repo, s)
         self.assertEqual(s.get_refs(), c.fetch(s.path, t).refs)
@@ -1016,7 +1012,7 @@ class HttpGitClientTests(TestCase):
         self.assertEqual("passwd", c._password)
 
         basic_auth = c.pool_manager.headers["authorization"]
-        auth_string = "%s:%s" % ("user", "passwd")
+        auth_string = "{}:{}".format("user", "passwd")
         b64_credentials = base64.b64encode(auth_string.encode("latin1"))
         expected_basic_auth = "Basic %s" % b64_credentials.decode("latin1")
         self.assertEqual(basic_auth, expected_basic_auth)
@@ -1072,7 +1068,7 @@ class HttpGitClientTests(TestCase):
         self.assertEqual(original_password, c._password)
 
         basic_auth = c.pool_manager.headers["authorization"]
-        auth_string = "%s:%s" % (original_username, original_password)
+        auth_string = "{}:{}".format(original_username, original_password)
         b64_credentials = base64.b64encode(auth_string.encode("latin1"))
         expected_basic_auth = "Basic %s" % b64_credentials.decode("latin1")
         self.assertEqual(basic_auth, expected_basic_auth)
@@ -1229,14 +1225,160 @@ class DefaultUrllib3ManagerTest(TestCase):
         import urllib3
 
         config = ConfigDict()
-        os.environ["http_proxy"] = "http://myproxy:8080"
+        self.overrideEnv("http_proxy", "http://myproxy:8080")
         manager = default_urllib3_manager(config=config)
         self.assertIsInstance(manager, urllib3.ProxyManager)
         self.assertTrue(hasattr(manager, "proxy"))
         self.assertEqual(manager.proxy.scheme, "http")
         self.assertEqual(manager.proxy.host, "myproxy")
         self.assertEqual(manager.proxy.port, 8080)
-        del os.environ["http_proxy"]
+
+    def test_environment_empty_proxy(self):
+        import urllib3
+
+        config = ConfigDict()
+        self.overrideEnv("http_proxy", "")
+        manager = default_urllib3_manager(config=config)
+        self.assertNotIsInstance(manager, urllib3.ProxyManager)
+        self.assertIsInstance(manager, urllib3.PoolManager)
+
+    def test_environment_no_proxy_1(self):
+        import urllib3
+
+        config = ConfigDict()
+        self.overrideEnv("http_proxy", "http://myproxy:8080")
+        self.overrideEnv("no_proxy", "xyz,abc.def.gh,abc.gh")
+        base_url = "http://xyz.abc.def.gh:8080/path/port"
+        manager = default_urllib3_manager(config=config, base_url=base_url)
+        self.assertNotIsInstance(manager, urllib3.ProxyManager)
+        self.assertIsInstance(manager, urllib3.PoolManager)
+
+    def test_environment_no_proxy_2(self):
+        import urllib3
+
+        config = ConfigDict()
+        self.overrideEnv("http_proxy", "http://myproxy:8080")
+        self.overrideEnv("no_proxy", "xyz,abc.def.gh,abc.gh,ample.com")
+        base_url = "http://ample.com/path/port"
+        manager = default_urllib3_manager(config=config, base_url=base_url)
+        self.assertNotIsInstance(manager, urllib3.ProxyManager)
+        self.assertIsInstance(manager, urllib3.PoolManager)
+
+    def test_environment_no_proxy_3(self):
+        import urllib3
+
+        config = ConfigDict()
+        self.overrideEnv("http_proxy", "http://myproxy:8080")
+        self.overrideEnv("no_proxy", "xyz,abc.def.gh,abc.gh,ample.com")
+        base_url = "http://ample.com:80/path/port"
+        manager = default_urllib3_manager(config=config, base_url=base_url)
+        self.assertNotIsInstance(manager, urllib3.ProxyManager)
+        self.assertIsInstance(manager, urllib3.PoolManager)
+
+    def test_environment_no_proxy_4(self):
+        import urllib3
+
+        config = ConfigDict()
+        self.overrideEnv("http_proxy", "http://myproxy:8080")
+        self.overrideEnv("no_proxy", "xyz,abc.def.gh,abc.gh,ample.com")
+        base_url = "http://www.ample.com/path/port"
+        manager = default_urllib3_manager(config=config, base_url=base_url)
+        self.assertNotIsInstance(manager, urllib3.ProxyManager)
+        self.assertIsInstance(manager, urllib3.PoolManager)
+
+    def test_environment_no_proxy_5(self):
+        import urllib3
+
+        config = ConfigDict()
+        self.overrideEnv("http_proxy", "http://myproxy:8080")
+        self.overrideEnv("no_proxy", "xyz,abc.def.gh,abc.gh,ample.com")
+        base_url = "http://www.example.com/path/port"
+        manager = default_urllib3_manager(config=config, base_url=base_url)
+        self.assertIsInstance(manager, urllib3.ProxyManager)
+        self.assertTrue(hasattr(manager, "proxy"))
+        self.assertEqual(manager.proxy.scheme, "http")
+        self.assertEqual(manager.proxy.host, "myproxy")
+        self.assertEqual(manager.proxy.port, 8080)
+
+    def test_environment_no_proxy_6(self):
+        import urllib3
+
+        config = ConfigDict()
+        self.overrideEnv("http_proxy", "http://myproxy:8080")
+        self.overrideEnv("no_proxy", "xyz,abc.def.gh,abc.gh,ample.com")
+        base_url = "http://ample.com.org/path/port"
+        manager = default_urllib3_manager(config=config, base_url=base_url)
+        self.assertIsInstance(manager, urllib3.ProxyManager)
+        self.assertTrue(hasattr(manager, "proxy"))
+        self.assertEqual(manager.proxy.scheme, "http")
+        self.assertEqual(manager.proxy.host, "myproxy")
+        self.assertEqual(manager.proxy.port, 8080)
+
+    def test_environment_no_proxy_ipv4_address_1(self):
+        import urllib3
+
+        config = ConfigDict()
+        self.overrideEnv("http_proxy", "http://myproxy:8080")
+        self.overrideEnv("no_proxy", "xyz,abc.def.gh,192.168.0.10,ample.com")
+        base_url = "http://192.168.0.10/path/port"
+        manager = default_urllib3_manager(config=config, base_url=base_url)
+        self.assertNotIsInstance(manager, urllib3.ProxyManager)
+        self.assertIsInstance(manager, urllib3.PoolManager)
+
+    def test_environment_no_proxy_ipv4_address_2(self):
+        import urllib3
+
+        config = ConfigDict()
+        self.overrideEnv("http_proxy", "http://myproxy:8080")
+        self.overrideEnv("no_proxy", "xyz,abc.def.gh,192.168.0.10,ample.com")
+        base_url = "http://192.168.0.10:8888/path/port"
+        manager = default_urllib3_manager(config=config, base_url=base_url)
+        self.assertNotIsInstance(manager, urllib3.ProxyManager)
+        self.assertIsInstance(manager, urllib3.PoolManager)
+
+    def test_environment_no_proxy_ipv4_address_3(self):
+        import urllib3
+
+        config = ConfigDict()
+        self.overrideEnv("http_proxy", "http://myproxy:8080")
+        self.overrideEnv("no_proxy", "xyz,abc.def.gh,ff80:1::/64,192.168.0.0/24,ample.com")
+        base_url = "http://192.168.0.10/path/port"
+        manager = default_urllib3_manager(config=config, base_url=base_url)
+        self.assertNotIsInstance(manager, urllib3.ProxyManager)
+        self.assertIsInstance(manager, urllib3.PoolManager)
+
+    def test_environment_no_proxy_ipv6_address_1(self):
+        import urllib3
+
+        config = ConfigDict()
+        self.overrideEnv("http_proxy", "http://myproxy:8080")
+        self.overrideEnv("no_proxy", "xyz,abc.def.gh,ff80:1::affe,ample.com")
+        base_url = "http://[ff80:1::affe]/path/port"
+        manager = default_urllib3_manager(config=config, base_url=base_url)
+        self.assertNotIsInstance(manager, urllib3.ProxyManager)
+        self.assertIsInstance(manager, urllib3.PoolManager)
+
+    def test_environment_no_proxy_ipv6_address_2(self):
+        import urllib3
+
+        config = ConfigDict()
+        self.overrideEnv("http_proxy", "http://myproxy:8080")
+        self.overrideEnv("no_proxy", "xyz,abc.def.gh,ff80:1::affe,ample.com")
+        base_url = "http://[ff80:1::affe]:1234/path/port"
+        manager = default_urllib3_manager(config=config, base_url=base_url)
+        self.assertNotIsInstance(manager, urllib3.ProxyManager)
+        self.assertIsInstance(manager, urllib3.PoolManager)
+
+    def test_environment_no_proxy_ipv6_address_3(self):
+        import urllib3
+
+        config = ConfigDict()
+        self.overrideEnv("http_proxy", "http://myproxy:8080")
+        self.overrideEnv("no_proxy", "xyz,abc.def.gh,192.168.0.0/24,ff80:1::/64,ample.com")
+        base_url = "http://[ff80:1::affe]/path/port"
+        manager = default_urllib3_manager(config=config, base_url=base_url)
+        self.assertNotIsInstance(manager, urllib3.ProxyManager)
+        self.assertIsInstance(manager, urllib3.PoolManager)
 
     def test_config_proxy_custom_cls(self):
         import urllib3
@@ -1384,7 +1526,7 @@ class PLinkSSHVendorTests(TestCase):
                 break
         else:
             raise AssertionError(
-                "Expected warning %r not in %r" % (expected_warning, warnings_list)
+                "Expected warning {!r} not in {!r}".format(expected_warning, warnings_list)
             )
 
         args = command.proc.args
@@ -1429,7 +1571,7 @@ class PLinkSSHVendorTests(TestCase):
                 break
         else:
             raise AssertionError(
-                "Expected warning %r not in %r" % (expected_warning, warnings_list)
+                "Expected warning {!r} not in {!r}".format(expected_warning, warnings_list)
             )
 
         args = command.proc.args

+ 2 - 10
dulwich/tests/test_config.py

@@ -307,14 +307,6 @@ class ConfigDictTests(TestCase):
 
 
 class StackedConfigTests(TestCase):
-    def setUp(self):
-        super(StackedConfigTests, self).setUp()
-        self._old_path = os.environ.get("PATH")
-
-    def tearDown(self):
-        super(StackedConfigTests, self).tearDown()
-        os.environ["PATH"] = self._old_path
-
     def test_default_backends(self):
         StackedConfig.default_backends()
 
@@ -323,7 +315,7 @@ class StackedConfigTests(TestCase):
         from dulwich.config import get_win_system_paths
 
         install_dir = os.path.join("C:", "foo", "Git")
-        os.environ["PATH"] = os.path.join(install_dir, "cmd")
+        self.overrideEnv("PATH", os.path.join(install_dir, "cmd"))
         with patch("os.path.exists", return_value=True):
             paths = set(get_win_system_paths())
         self.assertEqual(
@@ -340,7 +332,7 @@ class StackedConfigTests(TestCase):
 
         from dulwich.config import get_win_system_paths
 
-        del os.environ["PATH"]
+        self.overrideEnv("PATH", None)
         install_dir = os.path.join("C:", "foo", "Git")
         with patch("winreg.OpenKey"):
             with patch(

+ 3 - 3
dulwich/tests/test_diff_tree.py

@@ -64,7 +64,7 @@ from dulwich.tests.utils import (
 
 class DiffTestCase(TestCase):
     def setUp(self):
-        super(DiffTestCase, self).setUp()
+        super().setUp()
         self.store = MemoryObjectStore()
         self.empty_tree = self.commit_tree([])
 
@@ -87,7 +87,7 @@ class DiffTestCase(TestCase):
 
 class TreeChangesTest(DiffTestCase):
     def setUp(self):
-        super(TreeChangesTest, self).setUp()
+        super().setUp()
         self.detector = RenameDetector(self.store)
 
     def assertMergeFails(self, merge_entries, name, mode, sha):
@@ -699,7 +699,7 @@ class RenameDetectionTest(DiffTestCase):
 
         block_cache = {}
         self.assertEqual(50, _similarity_score(blob1, blob2, block_cache=block_cache))
-        self.assertEqual(set([blob1.id, blob2.id]), set(block_cache))
+        self.assertEqual({blob1.id, blob2.id}, set(block_cache))
 
         def fail_chunks():
             self.fail("Unexpected call to as_raw_chunks()")

+ 2 - 2
dulwich/tests/test_fastexport.py

@@ -47,7 +47,7 @@ class GitFastExporterTests(TestCase):
     """Tests for the GitFastExporter tests."""
 
     def setUp(self):
-        super(GitFastExporterTests, self).setUp()
+        super().setUp()
         self.store = MemoryObjectStore()
         self.stream = BytesIO()
         try:
@@ -96,7 +96,7 @@ class GitImportProcessorTests(TestCase):
     """Tests for the GitImportProcessor tests."""
 
     def setUp(self):
-        super(GitImportProcessorTests, self).setUp()
+        super().setUp()
         self.repo = MemoryRepo()
         try:
             from dulwich.fastexport import GitImportProcessor

+ 6 - 6
dulwich/tests/test_file.py

@@ -33,7 +33,7 @@ from dulwich.tests import (
 
 class FancyRenameTests(TestCase):
     def setUp(self):
-        super(FancyRenameTests, self).setUp()
+        super().setUp()
         self._tempdir = tempfile.mkdtemp()
         self.foo = self.path("foo")
         self.bar = self.path("bar")
@@ -41,7 +41,7 @@ class FancyRenameTests(TestCase):
 
     def tearDown(self):
         shutil.rmtree(self._tempdir)
-        super(FancyRenameTests, self).tearDown()
+        super().tearDown()
 
     def path(self, filename):
         return os.path.join(self._tempdir, filename)
@@ -89,7 +89,7 @@ class FancyRenameTests(TestCase):
 
 class GitFileTests(TestCase):
     def setUp(self):
-        super(GitFileTests, self).setUp()
+        super().setUp()
         self._tempdir = tempfile.mkdtemp()
         f = open(self.path("foo"), "wb")
         f.write(b"foo contents")
@@ -97,7 +97,7 @@ class GitFileTests(TestCase):
 
     def tearDown(self):
         shutil.rmtree(self._tempdir)
-        super(GitFileTests, self).tearDown()
+        super().tearDown()
 
     def path(self, filename):
         return os.path.join(self._tempdir, filename)
@@ -191,14 +191,14 @@ class GitFileTests(TestCase):
         f.abort()
         try:
             f.close()
-        except (IOError, OSError):
+        except OSError:
             self.fail()
 
         f = GitFile(foo, "wb")
         f.close()
         try:
             f.abort()
-        except (IOError, OSError):
+        except OSError:
             self.fail()
 
     def test_abort_close_removed(self):

+ 4 - 4
dulwich/tests/test_grafts.py

@@ -104,9 +104,9 @@ class GraftSerializerTests(TestCase):
         )
 
 
-class GraftsInRepositoryBase(object):
+class GraftsInRepositoryBase:
     def tearDown(self):
-        super(GraftsInRepositoryBase, self).tearDown()
+        super().tearDown()
 
     def get_repo_with_grafts(self, grafts):
         r = self._repo
@@ -148,7 +148,7 @@ class GraftsInRepositoryBase(object):
 
 class GraftsInRepoTests(GraftsInRepositoryBase, TestCase):
     def setUp(self):
-        super(GraftsInRepoTests, self).setUp()
+        super().setUp()
         self._repo_dir = os.path.join(tempfile.mkdtemp())
         r = self._repo = Repo.init(self._repo_dir)
         self.addCleanup(shutil.rmtree, self._repo_dir)
@@ -188,7 +188,7 @@ class GraftsInRepoTests(GraftsInRepositoryBase, TestCase):
 
 class GraftsInMemoryRepoTests(GraftsInRepositoryBase, TestCase):
     def setUp(self):
-        super(GraftsInMemoryRepoTests, self).setUp()
+        super().setUp()
         r = self._repo = MemoryRepo()
 
         self._shas = []

+ 7 - 8
dulwich/tests/test_graph.py

@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 # test_index.py -- Tests for merge
 # encoding: utf-8
 # Copyright (c) 2020 Kevin B. Hendricks, Stratford Ontario Canada
@@ -48,7 +47,7 @@ class FindMergeBaseTests(TestCase):
             "1": [],
             "0": [],
         }
-        self.assertEqual(self.run_test(graph, ["4", "5"]), set(["1", "2"]))
+        self.assertEqual(self.run_test(graph, ["4", "5"]), {"1", "2"})
 
     def test_no_common_ancestor(self):
         # no common ancestor
@@ -59,7 +58,7 @@ class FindMergeBaseTests(TestCase):
             "1": ["0"],
             "0": [],
         }
-        self.assertEqual(self.run_test(graph, ["4", "3"]), set([]))
+        self.assertEqual(self.run_test(graph, ["4", "3"]), set())
 
     def test_ancestor(self):
         # ancestor
@@ -72,7 +71,7 @@ class FindMergeBaseTests(TestCase):
             "B": ["A"],
             "A": [],
         }
-        self.assertEqual(self.run_test(graph, ["D", "C"]), set(["C"]))
+        self.assertEqual(self.run_test(graph, ["D", "C"]), {"C"})
 
     def test_direct_parent(self):
         # parent
@@ -85,7 +84,7 @@ class FindMergeBaseTests(TestCase):
             "B": ["A"],
             "A": [],
         }
-        self.assertEqual(self.run_test(graph, ["G", "D"]), set(["D"]))
+        self.assertEqual(self.run_test(graph, ["G", "D"]), {"D"})
 
     def test_another_crossover(self):
         # Another cross over
@@ -98,7 +97,7 @@ class FindMergeBaseTests(TestCase):
             "B": ["A"],
             "A": [],
         }
-        self.assertEqual(self.run_test(graph, ["D", "F"]), set(["E", "C"]))
+        self.assertEqual(self.run_test(graph, ["D", "F"]), {"E", "C"})
 
     def test_three_way_merge_lca(self):
         # three way merge commit straight from git docs
@@ -121,7 +120,7 @@ class FindMergeBaseTests(TestCase):
         }
         # assumes a theoretical merge M exists that merges B and C first
         # which actually means find the first LCA from either of B OR C with A
-        self.assertEqual(self.run_test(graph, ["A", "B", "C"]), set(["1"]))
+        self.assertEqual(self.run_test(graph, ["A", "B", "C"]), {"1"})
 
     def test_octopus(self):
         # octopus algorithm test
@@ -156,7 +155,7 @@ class FindMergeBaseTests(TestCase):
                 res = _find_lcas(lookup_parents, cmt, [ca])
                 next_lcas.extend(res)
             lcas = next_lcas[:]
-        self.assertEqual(set(lcas), set(["2"]))
+        self.assertEqual(set(lcas), {"2"})
 
 
 class CanFastForwardTests(TestCase):

+ 1 - 39
dulwich/tests/test_greenthreads.py

@@ -28,7 +28,6 @@ from dulwich.tests import (
 )
 from dulwich.object_store import (
     MemoryObjectStore,
-    MissingObjectFinder,
 )
 from dulwich.objects import (
     Commit,
@@ -46,7 +45,6 @@ except ImportError:
 
 if gevent_support:
     from dulwich.greenthreads import (
-        GreenThreadsObjectStoreIterator,
         GreenThreadsMissingObjectFinder,
     )
 
@@ -77,46 +75,10 @@ def init_store(store, count=1):
     return ret
 
 
-@skipIf(not gevent_support, skipmsg)
-class TestGreenThreadsObjectStoreIterator(TestCase):
-    def setUp(self):
-        super(TestGreenThreadsObjectStoreIterator, self).setUp()
-        self.store = MemoryObjectStore()
-        self.cmt_amount = 10
-        self.objs = init_store(self.store, self.cmt_amount)
-
-    def test_len(self):
-        wants = [sha.id for sha in self.objs if isinstance(sha, Commit)]
-        finder = MissingObjectFinder(self.store, (), wants)
-        iterator = GreenThreadsObjectStoreIterator(
-            self.store, iter(finder.next, None), finder
-        )
-        # One commit refers one tree and one blob
-        self.assertEqual(len(iterator), self.cmt_amount * 3)
-        haves = wants[0 : self.cmt_amount - 1]
-        finder = MissingObjectFinder(self.store, haves, wants)
-        iterator = GreenThreadsObjectStoreIterator(
-            self.store, iter(finder.next, None), finder
-        )
-        self.assertEqual(len(iterator), 3)
-
-    def test_iter(self):
-        wants = [sha.id for sha in self.objs if isinstance(sha, Commit)]
-        finder = MissingObjectFinder(self.store, (), wants)
-        iterator = GreenThreadsObjectStoreIterator(
-            self.store, iter(finder.next, None), finder
-        )
-        objs = []
-        for sha, path in iterator:
-            self.assertIn(sha, self.objs)
-            objs.append(sha)
-        self.assertEqual(len(objs), len(self.objs))
-
-
 @skipIf(not gevent_support, skipmsg)
 class TestGreenThreadsMissingObjectFinder(TestCase):
     def setUp(self):
-        super(TestGreenThreadsMissingObjectFinder, self).setUp()
+        super().setUp()
         self.store = MemoryObjectStore()
         self.cmt_amount = 10
         self.objs = init_store(self.store, self.cmt_amount)

+ 1 - 1
dulwich/tests/test_hooks.py

@@ -38,7 +38,7 @@ from dulwich.tests import TestCase
 
 class ShellHookTests(TestCase):
     def setUp(self):
-        super(ShellHookTests, self).setUp()
+        super().setUp()
         if os.name != "posix":
             self.skipTest("shell hook tests requires POSIX shell")
         self.assertTrue(os.path.exists("/bin/sh"))

+ 2 - 2
dulwich/tests/test_ignore.py

@@ -133,14 +133,14 @@ class MatchPatternTests(TestCase):
         for (path, pattern) in POSITIVE_MATCH_TESTS:
             self.assertTrue(
                 match_pattern(path, pattern),
-                "path: %r, pattern: %r" % (path, pattern),
+                "path: {!r}, pattern: {!r}".format(path, pattern),
             )
 
     def test_no_matches(self):
         for (path, pattern) in NEGATIVE_MATCH_TESTS:
             self.assertFalse(
                 match_pattern(path, pattern),
-                "path: %r, pattern: %r" % (path, pattern),
+                "path: {!r}, pattern: {!r}".format(path, pattern),
             )
 
 

+ 12 - 13
dulwich/tests/test_index.py

@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 # test_index.py -- Tests for the git index
 # encoding: utf-8
 # Copyright (C) 2008-2009 Jelmer Vernooij <jelmer@jelmer.uk>
@@ -203,7 +202,7 @@ class ReadIndexDictTests(IndexTestCase):
 
 class CommitTreeTests(TestCase):
     def setUp(self):
-        super(CommitTreeTests, self).setUp()
+        super().setUp()
         self.store = MemoryObjectStore()
 
     def test_single_blob(self):
@@ -214,7 +213,7 @@ class CommitTreeTests(TestCase):
         rootid = commit_tree(self.store, blobs)
         self.assertEqual(rootid, b"1a1e80437220f9312e855c37ac4398b68e5c1d50")
         self.assertEqual((stat.S_IFREG, blob.id), self.store[rootid][b"bla"])
-        self.assertEqual(set([rootid, blob.id]), set(self.store._data.keys()))
+        self.assertEqual({rootid, blob.id}, set(self.store._data.keys()))
 
     def test_nested(self):
         blob = Blob()
@@ -227,12 +226,12 @@ class CommitTreeTests(TestCase):
         self.assertEqual(dirid, b"c1a1deb9788150829579a8b4efa6311e7b638650")
         self.assertEqual((stat.S_IFDIR, dirid), self.store[rootid][b"bla"])
         self.assertEqual((stat.S_IFREG, blob.id), self.store[dirid][b"bar"])
-        self.assertEqual(set([rootid, dirid, blob.id]), set(self.store._data.keys()))
+        self.assertEqual({rootid, dirid, blob.id}, set(self.store._data.keys()))
 
 
 class CleanupModeTests(TestCase):
     def assertModeEqual(self, expected, got):
-        self.assertEqual(expected, got, "%o != %o" % (expected, got))
+        self.assertEqual(expected, got, "{:o} != {:o}".format(expected, got))
 
     def test_file(self):
         self.assertModeEqual(0o100644, cleanup_mode(0o100000))
@@ -547,9 +546,9 @@ class BuildIndexTests(TestCase):
             file = Blob.from_string(b"foo")
 
             tree = Tree()
-            latin1_name = u"À".encode("latin1")
+            latin1_name = "À".encode("latin1")
             latin1_path = os.path.join(repo_dir_bytes, latin1_name)
-            utf8_name = u"À".encode("utf8")
+            utf8_name = "À".encode()
             utf8_path = os.path.join(repo_dir_bytes, utf8_name)
             tree[latin1_name] = (stat.S_IFREG | 0o644, file.id)
             tree[utf8_name] = (stat.S_IFREG | 0o644, file.id)
@@ -795,19 +794,19 @@ class TestValidatePathElement(TestCase):
 
 class TestTreeFSPathConversion(TestCase):
     def test_tree_to_fs_path(self):
-        tree_path = u"délwíçh/foo".encode("utf8")
+        tree_path = "délwíçh/foo".encode()
         fs_path = _tree_to_fs_path(b"/prefix/path", tree_path)
         self.assertEqual(
             fs_path,
-            os.fsencode(os.path.join(u"/prefix/path", u"délwíçh", u"foo")),
+            os.fsencode(os.path.join("/prefix/path", "délwíçh", "foo")),
         )
 
     def test_fs_to_tree_path_str(self):
-        fs_path = os.path.join(os.path.join(u"délwíçh", u"foo"))
+        fs_path = os.path.join(os.path.join("délwíçh", "foo"))
         tree_path = _fs_to_tree_path(fs_path)
-        self.assertEqual(tree_path, u"délwíçh/foo".encode("utf-8"))
+        self.assertEqual(tree_path, "délwíçh/foo".encode())
 
     def test_fs_to_tree_path_bytes(self):
-        fs_path = os.path.join(os.fsencode(os.path.join(u"délwíçh", u"foo")))
+        fs_path = os.path.join(os.fsencode(os.path.join("délwíçh", "foo")))
         tree_path = _fs_to_tree_path(fs_path)
-        self.assertEqual(tree_path, u"délwíçh/foo".encode("utf-8"))
+        self.assertEqual(tree_path, "délwíçh/foo".encode())

+ 1 - 1
dulwich/tests/test_lfs.py

@@ -28,7 +28,7 @@ import tempfile
 
 class LFSTests(TestCase):
     def setUp(self):
-        super(LFSTests, self).setUp()
+        super().setUp()
         self.test_dir = tempfile.mkdtemp()
         self.addCleanup(shutil.rmtree, self.test_dir)
         self.lfs = LFSStore.create(self.test_dir)

+ 0 - 1
dulwich/tests/test_line_ending.py

@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 # test_line_ending.py -- Tests for the line ending functions
 # encoding: utf-8
 # Copyright (C) 2018-2019 Boris Feld <boris.feld@comet.ml>

+ 9 - 9
dulwich/tests/test_missing_obj_finder.py

@@ -20,6 +20,7 @@
 
 from dulwich.object_store import (
     MemoryObjectStore,
+    MissingObjectFinder,
 )
 from dulwich.objects import (
     Blob,
@@ -34,7 +35,7 @@ from dulwich.tests.utils import (
 
 class MissingObjectFinderTest(TestCase):
     def setUp(self):
-        super(MissingObjectFinderTest, self).setUp()
+        super().setUp()
         self.store = MemoryObjectStore()
         self.commits = []
 
@@ -42,24 +43,24 @@ class MissingObjectFinderTest(TestCase):
         return self.commits[n - 1]
 
     def assertMissingMatch(self, haves, wants, expected):
-        for sha, path in self.store.find_missing_objects(haves, wants, set()):
+        for sha, path in MissingObjectFinder(self.store, haves, wants, shallow=set()):
             self.assertIn(
                 sha,
                 expected,
-                "(%s,%s) erroneously reported as missing" % (sha, path)
+                "({},{}) erroneously reported as missing".format(sha, path)
             )
             expected.remove(sha)
 
         self.assertEqual(
             len(expected),
             0,
-            "some objects are not reported as missing: %s" % (expected,),
+            "some objects are not reported as missing: {}".format(expected),
         )
 
 
 class MOFLinearRepoTest(MissingObjectFinderTest):
     def setUp(self):
-        super(MOFLinearRepoTest, self).setUp()
+        super().setUp()
         # present in 1, removed in 3
         f1_1 = make_object(Blob, data=b"f1")
         # present in all revisions, changed in 2 and 3
@@ -115,8 +116,7 @@ class MOFLinearRepoTest(MissingObjectFinderTest):
         haves = [self.cmt(1).id]
         wants = [self.cmt(3).id, bogus_sha]
         self.assertRaises(
-            KeyError, self.store.find_missing_objects, haves, wants, set()
-        )
+            KeyError, MissingObjectFinder, self.store, haves, wants, shallow=set())
 
     def test_no_changes(self):
         self.assertMissingMatch([self.cmt(3).id], [self.cmt(3).id], [])
@@ -130,7 +130,7 @@ class MOFMergeForkRepoTest(MissingObjectFinderTest):
     #             5
 
     def setUp(self):
-        super(MOFMergeForkRepoTest, self).setUp()
+        super().setUp()
         f1_1 = make_object(Blob, data=b"f1")
         f1_2 = make_object(Blob, data=b"f1-2")
         f1_4 = make_object(Blob, data=b"f1-4")
@@ -256,7 +256,7 @@ class MOFMergeForkRepoTest(MissingObjectFinderTest):
 
 class MOFTagsTest(MissingObjectFinderTest):
     def setUp(self):
-        super(MOFTagsTest, self).setUp()
+        super().setUp()
         f1_1 = make_object(Blob, data=b"f1")
         commit_spec = [[1]]
         trees = {1: [(b"f1", f1_1)]}

+ 14 - 11
dulwich/tests/test_object_store.py

@@ -51,6 +51,8 @@ from dulwich.object_store import (
     OverlayObjectStore,
     ObjectStoreGraphWalker,
     commit_tree_changes,
+    iter_tree_contents,
+    peel_sha,
     read_packs_file,
     tree_lookup_path,
 )
@@ -77,7 +79,7 @@ except ImportError:
 testobject = make_object(Blob, data=b"yummy data")
 
 
-class ObjectStoreTests(object):
+class ObjectStoreTests:
     def test_determine_wants_all(self):
         self.assertEqual(
             [b"1" * 40],
@@ -164,7 +166,7 @@ class ObjectStoreTests(object):
 
     def test_add_object(self):
         self.store.add_object(testobject)
-        self.assertEqual(set([testobject.id]), set(self.store))
+        self.assertEqual({testobject.id}, set(self.store))
         self.assertIn(testobject.id, self.store)
         r = self.store[testobject.id]
         self.assertEqual(r, testobject)
@@ -172,7 +174,7 @@ class ObjectStoreTests(object):
     def test_add_objects(self):
         data = [(testobject, "mypath")]
         self.store.add_objects(data)
-        self.assertEqual(set([testobject.id]), set(self.store))
+        self.assertEqual({testobject.id}, set(self.store))
         self.assertIn(testobject.id, self.store)
         r = self.store[testobject.id]
         self.assertEqual(r, testobject)
@@ -219,8 +221,9 @@ class ObjectStoreTests(object):
         tree_id = commit_tree(self.store, blobs)
         self.assertEqual(
             [TreeEntry(p, m, h) for (p, h, m) in blobs],
-            list(self.store.iter_tree_contents(tree_id)),
+            list(iter_tree_contents(self.store, tree_id)),
         )
+        self.assertEqual([], list(iter_tree_contents(self.store, None)))
 
     def test_iter_tree_contents_include_trees(self):
         blob_a = make_object(Blob, data=b"a")
@@ -247,7 +250,7 @@ class ObjectStoreTests(object):
             TreeEntry(b"ad/bd", 0o040000, tree_bd.id),
             TreeEntry(b"ad/bd/c", 0o100755, blob_c.id),
         ]
-        actual = self.store.iter_tree_contents(tree_id, include_trees=True)
+        actual = iter_tree_contents(self.store, tree_id, include_trees=True)
         self.assertEqual(expected, list(actual))
 
     def make_tag(self, name, obj):
@@ -261,7 +264,7 @@ class ObjectStoreTests(object):
         tag2 = self.make_tag(b"2", testobject)
         tag3 = self.make_tag(b"3", testobject)
         for obj in [testobject, tag1, tag2, tag3]:
-            self.assertEqual(testobject, self.store.peel_sha(obj.id))
+            self.assertEqual(testobject, peel_sha(self.store, obj.id))
 
     def test_get_raw(self):
         self.store.add_object(testobject)
@@ -302,7 +305,7 @@ class MemoryObjectStoreTests(ObjectStoreTests, TestCase):
     def test_add_pack_emtpy(self):
         o = MemoryObjectStore()
         f, commit, abort = o.add_pack()
-        commit()
+        self.assertRaises(AssertionError, commit)
 
     def test_add_thin_pack(self):
         o = MemoryObjectStore()
@@ -626,9 +629,9 @@ class TreeLookupPathTests(TestCase):
 
 class ObjectStoreGraphWalkerTests(TestCase):
     def get_walker(self, heads, parent_map):
-        new_parent_map = dict(
-            [(k * 40, [(p * 40) for p in ps]) for (k, ps) in parent_map.items()]
-        )
+        new_parent_map = {
+            k * 40: [(p * 40) for p in ps] for (k, ps) in parent_map.items()
+        }
         return ObjectStoreGraphWalker(
             [x * 40 for x in heads], new_parent_map.__getitem__
         )
@@ -707,7 +710,7 @@ class ObjectStoreGraphWalkerTests(TestCase):
 
 class CommitTreeChangesTests(TestCase):
     def setUp(self):
-        super(CommitTreeChangesTests, self).setUp()
+        super().setUp()
         self.store = MemoryObjectStore()
         self.blob_a = make_object(Blob, data=b"a")
         self.blob_b = make_object(Blob, data=b"b")

+ 5 - 5
dulwich/tests/test_objects.py

@@ -980,7 +980,7 @@ class TreeTests(ShaFileCheckTests):
     def test_iter(self):
         t = Tree()
         t[b"foo"] = (0o100644, a_sha)
-        self.assertEqual(set([b"foo"]), set(t))
+        self.assertEqual({b"foo"}, set(t))
 
 
 class TagSerializeTests(TestCase):
@@ -1143,7 +1143,7 @@ class TagParseTests(ShaFileCheckTests):
 
     def test_check_tag_with_overflow_time(self):
         """Date with overflow should raise an ObjectFormatException when checked"""
-        author = "Some Dude <some@dude.org> %s +0000" % (MAX_TIME + 1,)
+        author = "Some Dude <some@dude.org> {} +0000".format(MAX_TIME + 1)
         tag = Tag.from_string(self.make_tag_text(tagger=(author.encode())))
         with self.assertRaises(ObjectFormatException):
             tag.check()
@@ -1301,14 +1301,14 @@ class TimezoneTests(TestCase):
         self.assertEqual(b"-0440", format_timezone(int(((-4 * 60) - 40) * 60)))
 
     def test_format_timezone_double_negative(self):
-        self.assertEqual(b"--700", format_timezone(int(((7 * 60)) * 60), True))
+        self.assertEqual(b"--700", format_timezone(int((7 * 60) * 60), True))
 
     def test_parse_timezone_pdt_half(self):
         self.assertEqual((((-4 * 60) - 40) * 60, False), parse_timezone(b"-0440"))
 
     def test_parse_timezone_double_negative(self):
-        self.assertEqual((int(((7 * 60)) * 60), False), parse_timezone(b"+700"))
-        self.assertEqual((int(((7 * 60)) * 60), True), parse_timezone(b"--700"))
+        self.assertEqual((int((7 * 60) * 60), False), parse_timezone(b"+700"))
+        self.assertEqual((int((7 * 60) * 60), True), parse_timezone(b"--700"))
 
 
 class ShaFileCopyTests(TestCase):

+ 81 - 63
dulwich/tests/test_pack.py

@@ -91,7 +91,7 @@ class PackTests(TestCase):
     """Base class for testing packs"""
 
     def setUp(self):
-        super(PackTests, self).setUp()
+        super().setUp()
         self.tempdir = tempfile.mkdtemp()
         self.addCleanup(shutil.rmtree, self.tempdir)
 
@@ -122,13 +122,13 @@ class PackTests(TestCase):
 class PackIndexTests(PackTests):
     """Class that tests the index of packfiles"""
 
-    def test_object_index(self):
+    def test_object_offset(self):
         """Tests that the correct object offset is returned from the index."""
         p = self.get_pack_index(pack1_sha)
-        self.assertRaises(KeyError, p.object_index, pack1_sha)
-        self.assertEqual(p.object_index(a_sha), 178)
-        self.assertEqual(p.object_index(tree_sha), 138)
-        self.assertEqual(p.object_index(commit_sha), 12)
+        self.assertRaises(KeyError, p.object_offset, pack1_sha)
+        self.assertEqual(p.object_offset(a_sha), 178)
+        self.assertEqual(p.object_offset(tree_sha), 138)
+        self.assertEqual(p.object_offset(commit_sha), 12)
 
     def test_object_sha1(self):
         """Tests that the correct object offset is returned from the index."""
@@ -171,7 +171,7 @@ class PackIndexTests(PackTests):
 
     def test_iter(self):
         p = self.get_pack_index(pack1_sha)
-        self.assertEqual(set([tree_sha, commit_sha, a_sha]), set(p))
+        self.assertEqual({tree_sha, commit_sha, a_sha}, set(p))
 
 
 class TestPackDeltas(TestCase):
@@ -284,7 +284,7 @@ class TestPackData(PackTests):
         with self.get_pack_data(pack1_sha) as p:
             self.assertSucceeds(p.check)
 
-    def test_iterobjects(self):
+    def test_iter_unpacked(self):
         with self.get_pack_data(pack1_sha) as p:
             commit_data = (
                 b"tree b2a2766a2879c209ab1176e7e778b81ae422eeaa\n"
@@ -297,14 +297,12 @@ class TestPackData(PackTests):
             )
             blob_sha = b"6f670c0fb53f9463760b7295fbb814e965fb20c8"
             tree_data = b"100644 a\0" + hex_to_sha(blob_sha)
-            actual = []
-            for offset, type_num, chunks, crc32 in p.iterobjects():
-                actual.append((offset, type_num, b"".join(chunks), crc32))
+            actual = list(p.iter_unpacked())
             self.assertEqual(
                 [
-                    (12, 1, commit_data, 3775879613),
-                    (138, 2, tree_data, 912998690),
-                    (178, 3, b"test 1\n", 1373561701),
+                    UnpackedObject(offset=12, pack_type_num=1, decomp_chunks=[commit_data], crc32=None),
+                    UnpackedObject(offset=138, pack_type_num=2, decomp_chunks=[tree_data], crc32=None),
+                    UnpackedObject(offset=178, pack_type_num=3, decomp_chunks=[b"test 1\n"], crc32=None),
                 ],
                 actual,
             )
@@ -313,25 +311,23 @@ class TestPackData(PackTests):
         with self.get_pack_data(pack1_sha) as p:
             entries = {(sha_to_hex(s), o, c) for s, o, c in p.iterentries()}
             self.assertEqual(
-                set(
-                    [
-                        (
-                            b"6f670c0fb53f9463760b7295fbb814e965fb20c8",
-                            178,
-                            1373561701,
-                        ),
-                        (
-                            b"b2a2766a2879c209ab1176e7e778b81ae422eeaa",
-                            138,
-                            912998690,
-                        ),
-                        (
-                            b"f18faa16531ac570a3fdc8c7ca16682548dafd12",
-                            12,
-                            3775879613,
-                        ),
-                    ]
-                ),
+                {
+                    (
+                        b"6f670c0fb53f9463760b7295fbb814e965fb20c8",
+                        178,
+                        1373561701,
+                    ),
+                    (
+                        b"b2a2766a2879c209ab1176e7e778b81ae422eeaa",
+                        138,
+                        912998690,
+                    ),
+                    (
+                        b"f18faa16531ac570a3fdc8c7ca16682548dafd12",
+                        12,
+                        3775879613,
+                    ),
+                },
                 entries,
             )
 
@@ -399,17 +395,17 @@ class TestPack(PackTests):
 
     def test_iter(self):
         with self.get_pack(pack1_sha) as p:
-            self.assertEqual(set([tree_sha, commit_sha, a_sha]), set(p))
+            self.assertEqual({tree_sha, commit_sha, a_sha}, set(p))
 
     def test_iterobjects(self):
         with self.get_pack(pack1_sha) as p:
-            expected = set([p[s] for s in [commit_sha, tree_sha, a_sha]])
+            expected = {p[s] for s in [commit_sha, tree_sha, a_sha]}
             self.assertEqual(expected, set(list(p.iterobjects())))
 
     def test_pack_tuples(self):
         with self.get_pack(pack1_sha) as p:
             tuples = p.pack_tuples()
-            expected = set([(p[s], None) for s in [commit_sha, tree_sha, a_sha]])
+            expected = {(p[s], None) for s in [commit_sha, tree_sha, a_sha]}
             self.assertEqual(expected, set(list(tuples)))
             self.assertEqual(expected, set(list(tuples)))
             self.assertEqual(3, len(tuples))
@@ -468,7 +464,7 @@ class TestPack(PackTests):
         # file should exist
         self.assertTrue(os.path.exists(keepfile_name))
 
-        with open(keepfile_name, "r") as f:
+        with open(keepfile_name) as f:
             buf = f.read()
             self.assertEqual("", buf)
 
@@ -535,7 +531,7 @@ class TestPack(PackTests):
 
 class TestThinPack(PackTests):
     def setUp(self):
-        super(TestThinPack, self).setUp()
+        super().setUp()
         self.store = MemoryObjectStore()
         self.blobs = {}
         for blob in (b"foo", b"bar", b"foo1234", b"bar2468"):
@@ -580,24 +576,28 @@ class TestThinPack(PackTests):
         with self.make_pack(True) as p:
             self.assertEqual((3, b"foo1234"), p.get_raw(self.blobs[b"foo1234"].id))
 
-    def test_get_raw_unresolved(self):
+    def test_get_unpacked_object(self):
+        self.maxDiff = None
         with self.make_pack(False) as p:
-            self.assertEqual(
-                (
-                    7,
-                    b"\x19\x10(\x15f=#\xf8\xb7ZG\xe7\xa0\x19e\xdc\xdc\x96F\x8c",
-                    [b"x\x9ccf\x9f\xc0\xccbhdl\x02\x00\x06f\x01l"],
-                ),
-                p.get_raw_unresolved(self.blobs[b"foo1234"].id),
+            expected = UnpackedObject(
+                7,
+                delta_base=b"\x19\x10(\x15f=#\xf8\xb7ZG\xe7\xa0\x19e\xdc\xdc\x96F\x8c",
+                decomp_chunks=[b'\x03\x07\x90\x03\x041234'],
             )
+            expected.offset = 12
+            got = p.get_unpacked_object(self.blobs[b"foo1234"].id)
+            self.assertEqual(expected, got)
         with self.make_pack(True) as p:
+            expected = UnpackedObject(
+                7,
+                delta_base=b"\x19\x10(\x15f=#\xf8\xb7ZG\xe7\xa0\x19e\xdc\xdc\x96F\x8c",
+                decomp_chunks=[b'\x03\x07\x90\x03\x041234'],
+            )
+            expected.offset = 12
+            got = p.get_unpacked_object(self.blobs[b"foo1234"].id)
             self.assertEqual(
-                (
-                    7,
-                    b"\x19\x10(\x15f=#\xf8\xb7ZG\xe7\xa0\x19e\xdc\xdc\x96F\x8c",
-                    [b"x\x9ccf\x9f\xc0\xccbhdl\x02\x00\x06f\x01l"],
-                ),
-                p.get_raw_unresolved(self.blobs[b"foo1234"].id),
+                expected,
+                got,
             )
 
     def test_iterobjects(self):
@@ -664,7 +664,7 @@ class WritePackTests(TestCase):
 pack_checksum = hex_to_sha("721980e866af9a5f93ad674144e1459b8ba3e7b7")
 
 
-class BaseTestPackIndexWriting(object):
+class BaseTestPackIndexWriting:
     def assertSucceeds(self, func, *args, **kwargs):
         try:
             func(*args, **kwargs)
@@ -801,9 +801,9 @@ class ReadZlibTests(TestCase):
     extra = b"nextobject"
 
     def setUp(self):
-        super(ReadZlibTests, self).setUp()
+        super().setUp()
         self.read = BytesIO(self.comp + self.extra).read
-        self.unpacked = UnpackedObject(Tree.type_num, None, len(self.decomp), 0)
+        self.unpacked = UnpackedObject(Tree.type_num, decomp_len=len(self.decomp), crc32=0)
 
     def test_decompress_size(self):
         good_decomp_len = len(self.decomp)
@@ -822,7 +822,7 @@ class ReadZlibTests(TestCase):
         self.assertRaises(zlib.error, read_zlib_chunks, read, self.unpacked)
 
     def test_decompress_empty(self):
-        unpacked = UnpackedObject(Tree.type_num, None, 0, None)
+        unpacked = UnpackedObject(Tree.type_num, decomp_len=0)
         comp = zlib.compress(b"")
         read = BytesIO(comp + self.extra).read
         unused = read_zlib_chunks(read, unpacked)
@@ -874,7 +874,7 @@ class DeltifyTests(TestCase):
     def test_single(self):
         b = Blob.from_string(b"foo")
         self.assertEqual(
-            [(b.type_num, b.sha().digest(), None, b.as_raw_chunks())],
+            [UnpackedObject(b.type_num, sha=b.sha().digest(), delta_base=None, decomp_chunks=b.as_raw_chunks())],
             list(deltify_pack_objects([(b, b"")])),
         )
 
@@ -884,8 +884,8 @@ class DeltifyTests(TestCase):
         delta = list(create_delta(b1.as_raw_chunks(), b2.as_raw_chunks()))
         self.assertEqual(
             [
-                (b1.type_num, b1.sha().digest(), None, b1.as_raw_chunks()),
-                (b2.type_num, b2.sha().digest(), b1.sha().digest(), delta),
+                UnpackedObject(b1.type_num, sha=b1.sha().digest(), delta_base=None, decomp_chunks=b1.as_raw_chunks()),
+                UnpackedObject(b2.type_num, sha=b2.sha().digest(), delta_base=b1.sha().digest(), decomp_chunks=delta),
             ],
             list(deltify_pack_objects([(b1, b""), (b2, b"")])),
         )
@@ -945,7 +945,7 @@ class TestPackStreamReader(TestCase):
 
     def test_read_objects_empty(self):
         reader = PackStreamReader(BytesIO().read)
-        self.assertEqual([], list(reader.read_objects()))
+        self.assertRaises(AssertionError, list, reader.read_objects())
 
 
 class TestPackIterator(DeltaChainIterator):
@@ -953,7 +953,7 @@ class TestPackIterator(DeltaChainIterator):
     _compute_crc32 = True
 
     def __init__(self, *args, **kwargs):
-        super(TestPackIterator, self).__init__(*args, **kwargs)
+        super().__init__(*args, **kwargs)
         self._unpacked_offsets = set()
 
     def _result(self, unpacked):
@@ -971,14 +971,14 @@ class TestPackIterator(DeltaChainIterator):
             "Attempted to re-inflate offset %i" % offset
         )
         self._unpacked_offsets.add(offset)
-        return super(TestPackIterator, self)._resolve_object(
+        return super()._resolve_object(
             offset, pack_type_num, base_chunks
         )
 
 
 class DeltaChainIteratorTests(TestCase):
     def setUp(self):
-        super(DeltaChainIteratorTests, self).setUp()
+        super().setUp()
         self.store = MemoryObjectStore()
         self.fetched = set()
 
@@ -1008,6 +1008,16 @@ class DeltaChainIteratorTests(TestCase):
         data = PackData("test.pack", file=f)
         return TestPackIterator.for_pack_data(data, resolve_ext_ref=resolve_ext_ref)
 
+    def make_pack_iter_subset(self, f, subset, thin=None):
+        if thin is None:
+            thin = bool(list(self.store))
+        resolve_ext_ref = thin and self.get_raw_no_repeat or None
+        data = PackData("test.pack", file=f)
+        assert data
+        index = MemoryPackIndex.for_pack(data)
+        pack = Pack.from_objects(data, index)
+        return TestPackIterator.for_pack_subset(pack, subset, resolve_ext_ref=resolve_ext_ref)
+
     def assertEntriesMatch(self, expected_indexes, entries, pack_iter):
         expected = [entries[i] for i in expected_indexes]
         self.assertEqual(expected, list(pack_iter._walk_all_chains()))
@@ -1023,6 +1033,10 @@ class DeltaChainIteratorTests(TestCase):
             ],
         )
         self.assertEntriesMatch([0, 1, 2], entries, self.make_pack_iter(f))
+        f.seek(0)
+        self.assertEntriesMatch([], entries, self.make_pack_iter_subset(f, []))
+        f.seek(0)
+        self.assertEntriesMatch([1, 0], entries, self.make_pack_iter_subset(f, [entries[0][3], entries[1][3]]))
 
     def test_ofs_deltas(self):
         f = BytesIO()
@@ -1036,6 +1050,10 @@ class DeltaChainIteratorTests(TestCase):
         )
         # Delta resolution changed to DFS
         self.assertEntriesMatch([0, 2, 1], entries, self.make_pack_iter(f))
+        f.seek(0)
+        self.assertEntriesMatch(
+            [0, 2, 1], entries,
+            self.make_pack_iter_subset(f, [entries[1][3], entries[2][3]]))
 
     def test_ofs_deltas_chain(self):
         f = BytesIO()

+ 114 - 49
dulwich/tests/test_porcelain.py

@@ -86,7 +86,7 @@ def flat_walk_dir(dir_to_walk):
 
 class PorcelainTestCase(TestCase):
     def setUp(self):
-        super(PorcelainTestCase, self).setUp()
+        super().setUp()
         self.test_dir = tempfile.mkdtemp()
         self.addCleanup(shutil.rmtree, self.test_dir)
         self.repo_path = os.path.join(self.test_dir, "repo")
@@ -274,19 +274,14 @@ ya6JVZCRbMXfdCy8lVPgtNQ6VlHaj8Wvnn2FLbWWO2n2r3s=
     NON_DEFAULT_KEY_ID = "6A93393F50C5E6ACD3D6FB45B936212EDB4E14C0"
 
     def setUp(self):
-        super(PorcelainGpgTestCase, self).setUp()
+        super().setUp()
         self.gpg_dir = os.path.join(self.test_dir, "gpg")
         os.mkdir(self.gpg_dir, mode=0o700)
         # Ignore errors when deleting GNUPGHOME, because of race conditions
         # (e.g. the gpg-agent socket having been deleted). See
         # https://github.com/jelmer/dulwich/issues/1000
         self.addCleanup(shutil.rmtree, self.gpg_dir, ignore_errors=True)
-        self._old_gnupghome = os.environ.get("GNUPGHOME")
-        os.environ["GNUPGHOME"] = self.gpg_dir
-        if self._old_gnupghome is None:
-            self.addCleanup(os.environ.__delitem__, "GNUPGHOME")
-        else:
-            self.addCleanup(os.environ.__setitem__, "GNUPGHOME", self._old_gnupghome)
+        self.overrideEnv('GNUPGHOME', self.gpg_dir)
 
     def import_default_key(self):
         subprocess.run(
@@ -294,7 +289,7 @@ ya6JVZCRbMXfdCy8lVPgtNQ6VlHaj8Wvnn2FLbWWO2n2r3s=
             stdout=subprocess.DEVNULL,
             stderr=subprocess.DEVNULL,
             input=PorcelainGpgTestCase.DEFAULT_KEY,
-            universal_newlines=True,
+            text=True,
         )
 
     def import_non_default_key(self):
@@ -303,7 +298,7 @@ ya6JVZCRbMXfdCy8lVPgtNQ6VlHaj8Wvnn2FLbWWO2n2r3s=
             stdout=subprocess.DEVNULL,
             stderr=subprocess.DEVNULL,
             input=PorcelainGpgTestCase.NON_DEFAULT_KEY,
-            universal_newlines=True,
+            text=True,
         )
 
 
@@ -441,7 +436,8 @@ class CommitTests(PorcelainTestCase):
         self.assertEqual(commit._author_timezone, 18000)
         self.assertEqual(commit._commit_timezone, 18000)
 
-        os.environ["GIT_AUTHOR_DATE"] = os.environ["GIT_COMMITTER_DATE"] = "1995-11-20T19:12:08-0501"
+        self.overrideEnv("GIT_AUTHOR_DATE", "1995-11-20T19:12:08-0501")
+        self.overrideEnv("GIT_COMMITTER_DATE", "1995-11-20T19:12:08-0501")
 
         sha = porcelain.commit(
             self.repo.path,
@@ -456,8 +452,9 @@ class CommitTests(PorcelainTestCase):
         self.assertEqual(commit._author_timezone, -18060)
         self.assertEqual(commit._commit_timezone, -18060)
 
-        del os.environ["GIT_AUTHOR_DATE"]
-        del os.environ["GIT_COMMITTER_DATE"]
+        self.overrideEnv("GIT_AUTHOR_DATE", None)
+        self.overrideEnv("GIT_COMMITTER_DATE", None)
+
         local_timezone = time.localtime().tm_gmtoff
 
         sha = porcelain.commit(
@@ -541,7 +538,8 @@ class CommitSignTests(PorcelainGpgTestCase):
 class TimezoneTests(PorcelainTestCase):
 
     def put_envs(self, value):
-        os.environ["GIT_AUTHOR_DATE"] = os.environ["GIT_COMMITTER_DATE"] = value
+        self.overrideEnv("GIT_AUTHOR_DATE", value)
+        self.overrideEnv("GIT_COMMITTER_DATE", value)
 
     def fallback(self, value):
         self.put_envs(value)
@@ -588,8 +586,8 @@ class TimezoneTests(PorcelainTestCase):
         self.fallback("20.11.1995")
 
     def test_different_envs(self):
-        os.environ["GIT_AUTHOR_DATE"] = "0 +0500"
-        os.environ["GIT_COMMITTER_DATE"] = "0 +0501"
+        self.overrideEnv("GIT_AUTHOR_DATE", "0 +0500")
+        self.overrideEnv("GIT_COMMITTER_DATE", "0 +0501")
         self.assertTupleEqual((18000, 18060), porcelain.get_user_timezones())
 
     def test_no_envs(self):
@@ -598,16 +596,16 @@ class TimezoneTests(PorcelainTestCase):
         self.put_envs("0 +0500")
         self.assertTupleEqual((18000, 18000), porcelain.get_user_timezones())
 
-        del os.environ["GIT_COMMITTER_DATE"]
+        self.overrideEnv("GIT_COMMITTER_DATE", None)
         self.assertTupleEqual((18000, local_timezone), porcelain.get_user_timezones())
 
         self.put_envs("0 +0500")
-        del os.environ["GIT_AUTHOR_DATE"]
+        self.overrideEnv("GIT_AUTHOR_DATE", None)
         self.assertTupleEqual((local_timezone, 18000), porcelain.get_user_timezones())
 
         self.put_envs("0 +0500")
-        del os.environ["GIT_AUTHOR_DATE"]
-        del os.environ["GIT_COMMITTER_DATE"]
+        self.overrideEnv("GIT_AUTHOR_DATE", None)
+        self.overrideEnv("GIT_COMMITTER_DATE", None)
         self.assertTupleEqual((local_timezone, local_timezone), porcelain.get_user_timezones())
 
 
@@ -908,7 +906,7 @@ class AddTests(PorcelainTestCase):
         cwd = os.getcwd()
         try:
             os.chdir(self.repo.path)
-            self.assertEqual(set(["foo", "blah", "adir", ".git"]), set(os.listdir(".")))
+            self.assertEqual({"foo", "blah", "adir", ".git"}, set(os.listdir(".")))
             self.assertEqual(
                 (["foo", os.path.join("adir", "afile")], set()),
                 porcelain.add(self.repo.path),
@@ -969,8 +967,8 @@ class AddTests(PorcelainTestCase):
             ],
         )
         self.assertIn(b"bar", self.repo.open_index())
-        self.assertEqual(set(["bar"]), set(added))
-        self.assertEqual(set(["foo", os.path.join("subdir", "")]), ignored)
+        self.assertEqual({"bar"}, set(added))
+        self.assertEqual({"foo", os.path.join("subdir", "")}, ignored)
 
     def test_add_file_absolute_path(self):
         # Absolute paths are (not yet) supported
@@ -1556,7 +1554,7 @@ class ResetFileTests(PorcelainTestCase):
             f.write('something new')
         porcelain.reset_file(self.repo, file, target=sha)
 
-        with open(full_path, 'r') as f:
+        with open(full_path) as f:
             self.assertEqual('hello', f.read())
 
     def test_reset_remove_file_to_commit(self):
@@ -1575,7 +1573,7 @@ class ResetFileTests(PorcelainTestCase):
         os.remove(full_path)
         porcelain.reset_file(self.repo, file, target=sha)
 
-        with open(full_path, 'r') as f:
+        with open(full_path) as f:
             self.assertEqual('hello', f.read())
 
     def test_resetfile_with_dir(self):
@@ -1600,7 +1598,7 @@ class ResetFileTests(PorcelainTestCase):
             author=b"John <john@example.com>",
         )
         porcelain.reset_file(self.repo, os.path.join('new_dir', 'foo'), target=sha)
-        with open(full_path, 'r') as f:
+        with open(full_path) as f:
             self.assertEqual('hello', f.read())
 
 
@@ -1618,7 +1616,7 @@ class SubmoduleTests(PorcelainTestCase):
 
     def test_add(self):
         porcelain.submodule_add(self.repo, "../bar.git", "bar")
-        with open('%s/.gitmodules' % self.repo.path, 'r') as f:
+        with open('%s/.gitmodules' % self.repo.path) as f:
             self.assertEqual("""\
 [submodule "bar"]
 \turl = ../bar.git
@@ -1903,7 +1901,7 @@ class PushTests(PorcelainTestCase):
 
 class PullTests(PorcelainTestCase):
     def setUp(self):
-        super(PullTests, self).setUp()
+        super().setUp()
         # create a file for initial commit
         handle, fullpath = tempfile.mkstemp(dir=self.repo.path)
         os.close(handle)
@@ -2303,7 +2301,7 @@ class StatusTests(PorcelainTestCase):
             os.path.join(self.repo.path, "link"),
         )
         self.assertEqual(
-            set(["ignored", "notignored", ".gitignore", "link"]),
+            {"ignored", "notignored", ".gitignore", "link"},
             set(
                 porcelain.get_untracked_paths(
                     self.repo.path, self.repo.path, self.repo.open_index()
@@ -2311,11 +2309,11 @@ class StatusTests(PorcelainTestCase):
             ),
         )
         self.assertEqual(
-            set([".gitignore", "notignored", "link"]),
+            {".gitignore", "notignored", "link"},
             set(porcelain.status(self.repo).untracked),
         )
         self.assertEqual(
-            set([".gitignore", "notignored", "ignored", "link"]),
+            {".gitignore", "notignored", "ignored", "link"},
             set(porcelain.status(self.repo, ignored=True).untracked),
         )
 
@@ -2334,7 +2332,7 @@ class StatusTests(PorcelainTestCase):
             f.write("blop\n")
 
         self.assertEqual(
-            set([".gitignore", "notignored", os.path.join("nested", "")]),
+            {".gitignore", "notignored", os.path.join("nested", "")},
             set(
                 porcelain.get_untracked_paths(
                     self.repo.path, self.repo.path, self.repo.open_index()
@@ -2342,7 +2340,7 @@ class StatusTests(PorcelainTestCase):
             ),
         )
         self.assertEqual(
-            set([".gitignore", "notignored"]),
+            {".gitignore", "notignored"},
             set(
                 porcelain.get_untracked_paths(
                     self.repo.path,
@@ -2353,7 +2351,7 @@ class StatusTests(PorcelainTestCase):
             ),
         )
         self.assertEqual(
-            set(["ignored", "with", "manager"]),
+            {"ignored", "with", "manager"},
             set(
                 porcelain.get_untracked_paths(
                     subrepo.path, subrepo.path, subrepo.open_index()
@@ -2371,9 +2369,9 @@ class StatusTests(PorcelainTestCase):
             ),
         )
         self.assertEqual(
-            set([os.path.join('nested', 'ignored'),
+            {os.path.join('nested', 'ignored'),
                 os.path.join('nested', 'with'),
-                os.path.join('nested', 'manager')]),
+                os.path.join('nested', 'manager')},
             set(
                 porcelain.get_untracked_paths(
                     self.repo.path,
@@ -2395,14 +2393,12 @@ class StatusTests(PorcelainTestCase):
             f.write("foo")
 
         self.assertEqual(
-            set(
-                [
-                    ".gitignore",
-                    "notignored",
-                    "ignored",
-                    os.path.join("subdir", ""),
-                ]
-            ),
+            {
+                ".gitignore",
+                "notignored",
+                "ignored",
+                os.path.join("subdir", ""),
+            },
             set(
                 porcelain.get_untracked_paths(
                     self.repo.path,
@@ -2412,7 +2408,7 @@ class StatusTests(PorcelainTestCase):
             )
         )
         self.assertEqual(
-            set([".gitignore", "notignored"]),
+            {".gitignore", "notignored"},
             set(
                 porcelain.get_untracked_paths(
                     self.repo.path,
@@ -2490,14 +2486,14 @@ class ReceivePackTests(PorcelainTestCase):
 
 class BranchListTests(PorcelainTestCase):
     def test_standard(self):
-        self.assertEqual(set([]), set(porcelain.branch_list(self.repo)))
+        self.assertEqual(set(), set(porcelain.branch_list(self.repo)))
 
     def test_new_branch(self):
         [c1] = build_commit_graph(self.repo.object_store, [[1]])
         self.repo[b"HEAD"] = c1.id
         porcelain.branch_create(self.repo, b"foo")
         self.assertEqual(
-            set([b"master", b"foo"]), set(porcelain.branch_list(self.repo))
+            {b"master", b"foo"}, set(porcelain.branch_list(self.repo))
         )
 
 
@@ -2514,7 +2510,7 @@ class BranchCreateTests(PorcelainTestCase):
         self.repo[b"HEAD"] = c1.id
         porcelain.branch_create(self.repo, b"foo")
         self.assertEqual(
-            set([b"master", b"foo"]), set(porcelain.branch_list(self.repo))
+            {b"master", b"foo"}, set(porcelain.branch_list(self.repo))
         )
 
 
@@ -2990,10 +2986,42 @@ class DescribeTests(PorcelainTestCase):
             porcelain.describe(self.repo.path),
         )
 
+    def test_tag_and_commit_full(self):
+        fullpath = os.path.join(self.repo.path, "foo")
+        with open(fullpath, "w") as f:
+            f.write("BAR")
+        porcelain.add(repo=self.repo.path, paths=[fullpath])
+        porcelain.commit(
+            self.repo.path,
+            message=b"Some message",
+            author=b"Joe <joe@example.com>",
+            committer=b"Bob <bob@example.com>",
+        )
+        porcelain.tag_create(
+            self.repo.path,
+            b"tryme",
+            b"foo <foo@bar.com>",
+            b"bar",
+            annotated=True,
+        )
+        with open(fullpath, "w") as f:
+            f.write("BAR2")
+        porcelain.add(repo=self.repo.path, paths=[fullpath])
+        sha = porcelain.commit(
+            self.repo.path,
+            message=b"Some message",
+            author=b"Joe <joe@example.com>",
+            committer=b"Bob <bob@example.com>",
+        )
+        self.assertEqual(
+            "tryme-1-g{}".format(sha.decode("ascii")),
+            porcelain.describe(self.repo.path, abbrev=40),
+        )
+
 
 class PathToTreeTests(PorcelainTestCase):
     def setUp(self):
-        super(PathToTreeTests, self).setUp()
+        super().setUp()
         self.fp = os.path.join(self.test_dir, "bar")
         with open(self.fp, "w") as f:
             f.write("something")
@@ -3110,6 +3138,43 @@ class FindUniqueAbbrevTests(PorcelainTestCase):
             porcelain.find_unique_abbrev(self.repo.object_store, c1.id))
 
 
+class PackRefsTests(PorcelainTestCase):
+    def test_all(self):
+        c1, c2, c3 = build_commit_graph(
+            self.repo.object_store, [[1], [2, 1], [3, 1, 2]]
+        )
+        self.repo.refs[b"HEAD"] = c3.id
+        self.repo.refs[b"refs/heads/master"] = c2.id
+        self.repo.refs[b"refs/tags/foo"] = c1.id
+
+        porcelain.pack_refs(self.repo, all=True)
+
+        self.assertEqual(
+            self.repo.refs.get_packed_refs(),
+            {
+                b"refs/heads/master": c2.id,
+                b"refs/tags/foo": c1.id,
+            },
+        )
+
+    def test_not_all(self):
+        c1, c2, c3 = build_commit_graph(
+            self.repo.object_store, [[1], [2, 1], [3, 1, 2]]
+        )
+        self.repo.refs[b"HEAD"] = c3.id
+        self.repo.refs[b"refs/heads/master"] = c2.id
+        self.repo.refs[b"refs/tags/foo"] = c1.id
+
+        porcelain.pack_refs(self.repo)
+
+        self.assertEqual(
+            self.repo.refs.get_packed_refs(),
+            {
+                b"refs/tags/foo": c1.id,
+            },
+        )
+
+
 class ServerTests(PorcelainTestCase):
     @contextlib.contextmanager
     def _serving(self):

+ 1 - 1
dulwich/tests/test_protocol.py

@@ -42,7 +42,7 @@ from dulwich.protocol import (
 from dulwich.tests import TestCase
 
 
-class BaseProtocolTests(object):
+class BaseProtocolTests:
     def test_write_pkt_line_none(self):
         self.proto.write_pkt_line(None)
         self.assertEqual(self.rout.getvalue(), b"0000")

+ 0 - 1
dulwich/tests/test_reflog.py

@@ -1,5 +1,4 @@
 # test_reflog.py -- tests for reflog.py
-# encoding: utf-8
 # Copyright (C) 2015 Jelmer Vernooij <jelmer@jelmer.uk>
 #
 # Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU

+ 62 - 3
dulwich/tests/test_refs.py

@@ -1,5 +1,4 @@
 # test_refs.py -- tests for refs.py
-# encoding: utf-8
 # Copyright (C) 2013 Jelmer Vernooij <jelmer@jelmer.uk>
 #
 # Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
@@ -172,7 +171,7 @@ _TEST_REFS = {
 }
 
 
-class RefsContainerTests(object):
+class RefsContainerTests:
     def test_keys(self):
         actual_keys = set(self._refs.keys())
         self.assertEqual(set(self._refs.allkeys()), actual_keys)
@@ -447,6 +446,66 @@ class DiskRefsContainerTests(RefsContainerTests, TestCase):
             b"42d06bd4b77fed026b154d16493e5deab78f02ec",
         )
 
+        # this shouldn't overwrite the packed refs
+        self.assertEqual(
+            {b"refs/heads/packed": b"42d06bd4b77fed026b154d16493e5deab78f02ec"},
+            self._refs.get_packed_refs(),
+        )
+
+    def test_add_packed_refs(self):
+        # first, create a non-packed ref
+        self._refs[b"refs/heads/packed"] = b"3ec9c43c84ff242e3ef4a9fc5bc111fd780a76a8"
+
+        packed_ref_path = os.path.join(self._refs.path, b"refs", b"heads", b"packed")
+        self.assertTrue(os.path.exists(packed_ref_path))
+
+        # now overwrite that with a packed ref
+        packed_refs_file_path = os.path.join(self._refs.path, b"packed-refs")
+        self._refs.add_packed_refs(
+            {
+                b"refs/heads/packed": b"42d06bd4b77fed026b154d16493e5deab78f02ec",
+            }
+        )
+
+        # that should kill the file
+        self.assertFalse(os.path.exists(packed_ref_path))
+
+        # now delete the packed ref
+        self._refs.add_packed_refs(
+            {
+                b"refs/heads/packed": None,
+            }
+        )
+
+        # and it's gone!
+        self.assertFalse(os.path.exists(packed_ref_path))
+
+        self.assertRaises(
+            KeyError,
+            self._refs.__getitem__,
+            b"refs/heads/packed",
+        )
+
+        # just in case, make sure we can't pack HEAD
+        self.assertRaises(
+            ValueError,
+            self._refs.add_packed_refs,
+            {b"HEAD": "02ac81614bcdbd585a37b4b0edf8cb8a"},
+        )
+
+        # delete all packed refs
+        self._refs.add_packed_refs({ref: None for ref in self._refs.get_packed_refs()})
+
+        self.assertEqual({}, self._refs.get_packed_refs())
+
+        # remove the packed ref file, and check that adding nothing doesn't affect that
+        os.remove(packed_refs_file_path)
+
+        # adding nothing doesn't make it reappear
+        self._refs.add_packed_refs({})
+
+        self.assertFalse(os.path.exists(packed_refs_file_path))
+
     def test_setitem_symbolic(self):
         ones = b"1" * 40
         self._refs[b"HEAD"] = ones
@@ -640,7 +699,7 @@ class DiskRefsContainerTests(RefsContainerTests, TestCase):
 
     def test_non_ascii(self):
         try:
-            encoded_ref = os.fsencode(u"refs/tags/schön")
+            encoded_ref = os.fsencode("refs/tags/schön")
         except UnicodeEncodeError as exc:
             raise SkipTest(
                 "filesystem encoding doesn't support special character"

+ 12 - 24
dulwich/tests/test_repository.py

@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 # test_repository.py -- tests for repository.py
 # Copyright (C) 2007 James Westby <jw+debian@jameswestby.net>
 #
@@ -332,7 +331,7 @@ class RepositoryRootTests(TestCase):
         self.assertFilesystemHidden(os.path.join(repo_dir, ".git"))
 
     def test_init_mkdir_unicode(self):
-        repo_name = u"\xa7"
+        repo_name = "\xa7"
         try:
             os.fsencode(repo_name)
         except UnicodeEncodeError:
@@ -540,10 +539,10 @@ class RepositoryRootTests(TestCase):
         This test demonstrates that ``find_common_revisions()`` actually
         returns common heads, not revisions; dulwich already uses
         ``find_common_revisions()`` in such a manner (see
-        ``Repo.fetch_objects()``).
+        ``Repo.find_objects()``).
         """
 
-        expected_shas = set([b"60dacdc733de308bb77bb76ce0fb0f9b44c9769e"])
+        expected_shas = {b"60dacdc733de308bb77bb76ce0fb0f9b44c9769e"}
 
         # Source for objects.
         r_base = self.open_repo("simple_merge.git")
@@ -690,9 +689,9 @@ exit 0
         if os.name != "posix":
             self.skipTest("shell hook tests requires POSIX shell")
 
-        pre_commit_contents = """#!%(executable)s
+        pre_commit_contents = """#!{executable}
 import sys
-sys.path.extend(%(path)r)
+sys.path.extend({path!r})
 from dulwich.repo import Repo
 
 with open('foo', 'w') as f:
@@ -700,9 +699,9 @@ with open('foo', 'w') as f:
 
 r = Repo('.')
 r.stage(['foo'])
-""" % {
-            'executable': sys.executable,
-            'path': [os.path.join(os.path.dirname(__file__), '..', '..')] + sys.path}
+""".format(
+            executable=sys.executable,
+            path=[os.path.join(os.path.dirname(__file__), '..', '..')] + sys.path)
 
         repo_dir = os.path.join(self.mkdtemp())
         self.addCleanup(shutil.rmtree, repo_dir)
@@ -732,7 +731,7 @@ r.stage(['foo'])
         self.assertEqual([], r[commit_sha].parents)
 
         tree = r[r[commit_sha].tree]
-        self.assertEqual(set([b'blah', b'foo']), set(tree))
+        self.assertEqual({b'blah', b'foo'}, set(tree))
 
     def test_shell_hook_post_commit(self):
         if os.name != "posix":
@@ -814,7 +813,7 @@ exit 1
                 break
         else:
             raise AssertionError(
-                "Expected warning %r not in %r" % (expected_warning, warnings_list)
+                "Expected warning {!r} not in {!r}".format(expected_warning, warnings_list)
             )
         self.assertEqual([commit_sha], r[commit_sha2].parents)
 
@@ -887,7 +886,7 @@ class BuildRepoRootTests(TestCase):
         return os.path.join(tempfile.mkdtemp(), "test")
 
     def setUp(self):
-        super(BuildRepoRootTests, self).setUp()
+        super().setUp()
         self._repo_dir = self.get_repo_dir()
         os.makedirs(self._repo_dir)
         r = self._repo = Repo.init(self._repo_dir)
@@ -1154,17 +1153,6 @@ class BuildRepoRootTests(TestCase):
         self.assertEqual(b"Jelmer <jelmer@apache.org>", r[commit_sha].author)
         self.assertEqual(b"Jelmer <jelmer@apache.org>", r[commit_sha].committer)
 
-    def overrideEnv(self, name, value):
-        def restore():
-            if oldval is not None:
-                os.environ[name] = oldval
-            else:
-                del os.environ[name]
-
-        oldval = os.environ.get(name)
-        os.environ[name] = value
-        self.addCleanup(restore)
-
     def test_commit_config_identity_from_env(self):
         # commit falls back to the users' identity if it wasn't specified
         self.overrideEnv("GIT_COMMITTER_NAME", "joe")
@@ -1445,7 +1433,7 @@ class BuildRepoRootTests(TestCase):
         r = self._repo
         repo_path_bytes = os.fsencode(r.path)
         encodings = ("utf8", "latin1")
-        names = [u"À".encode(encoding) for encoding in encodings]
+        names = ["À".encode(encoding) for encoding in encodings]
         for name, encoding in zip(names, encodings):
             full_path = os.path.join(repo_path_bytes, name)
             with open(full_path, "wb") as f:

+ 36 - 28
dulwich/tests/test_server.py

@@ -75,7 +75,7 @@ FIVE = b"5" * 40
 SIX = b"6" * 40
 
 
-class TestProto(object):
+class TestProto:
     def __init__(self):
         self._output = []
         self._received = {0: [], 1: [], 2: [], 3: []}
@@ -120,7 +120,7 @@ class TestGenericPackHandler(PackHandler):
 
 class HandlerTestCase(TestCase):
     def setUp(self):
-        super(HandlerTestCase, self).setUp()
+        super().setUp()
         self._handler = TestGenericPackHandler()
 
     def assertSucceeds(self, func, *args, **kwargs):
@@ -164,8 +164,11 @@ class HandlerTestCase(TestCase):
 
 class UploadPackHandlerTestCase(TestCase):
     def setUp(self):
-        super(UploadPackHandlerTestCase, self).setUp()
-        self._repo = MemoryRepo.init_bare([], {})
+        super().setUp()
+        self.path = tempfile.mkdtemp()
+        self.addCleanup(shutil.rmtree, self.path)
+        self.repo = Repo.init(self.path)
+        self._repo = Repo.init_bare(self.path)
         backend = DictBackend({b"/": self._repo})
         self._handler = UploadPackHandler(
             backend, [b"/", b"host=lolcathost"], TestProto()
@@ -174,6 +177,7 @@ class UploadPackHandlerTestCase(TestCase):
     def test_progress(self):
         caps = self._handler.required_capabilities()
         self._handler.set_client_capabilities(caps)
+        self._handler._start_pack_send_phase()
         self._handler.progress(b"first message")
         self._handler.progress(b"second message")
         self.assertEqual(b"first message", self._handler.proto.get_received_line(2))
@@ -195,12 +199,14 @@ class UploadPackHandlerTestCase(TestCase):
         }
         # repo needs to peel this object
         self._repo.object_store.add_object(make_commit(id=FOUR))
-        self._repo.refs._update(refs)
+        for name, sha in refs.items():
+            self._repo.refs[name] = sha
         peeled = {
             b"refs/tags/tag1": b"1234" * 10,
             b"refs/tags/tag2": b"5678" * 10,
         }
-        self._repo.refs._update_peeled(peeled)
+        self._repo.refs._peeled_refs = peeled
+        self._repo.refs.add_packed_refs(refs)
 
         caps = list(self._handler.required_capabilities()) + [b"include-tag"]
         self._handler.set_client_capabilities(caps)
@@ -221,7 +227,8 @@ class UploadPackHandlerTestCase(TestCase):
         tree = Tree()
         self._repo.object_store.add_object(tree)
         self._repo.object_store.add_object(make_commit(id=ONE, tree=tree))
-        self._repo.refs._update(refs)
+        for name, sha in refs.items():
+            self._repo.refs[name] = sha
         self._handler.proto.set_output(
             [
                 b"want " + ONE + b" side-band-64k thin-pack ofs-delta",
@@ -241,7 +248,8 @@ class UploadPackHandlerTestCase(TestCase):
         tree = Tree()
         self._repo.object_store.add_object(tree)
         self._repo.object_store.add_object(make_commit(id=ONE, tree=tree))
-        self._repo.refs._update(refs)
+        for ref, sha in refs.items():
+            self._repo.refs[ref] = sha
         self._handler.proto.set_output([None])
         self._handler.handle()
         # The server should not send a pack, since the client didn't ask for
@@ -251,7 +259,7 @@ class UploadPackHandlerTestCase(TestCase):
 
 class FindShallowTests(TestCase):
     def setUp(self):
-        super(FindShallowTests, self).setUp()
+        super().setUp()
         self._store = MemoryObjectStore()
 
     def make_commit(self, **attrs):
@@ -274,18 +282,18 @@ class FindShallowTests(TestCase):
         c1, c2, c3 = self.make_linear_commits(3)
 
         self.assertEqual(
-            (set([c3.id]), set([])), _find_shallow(self._store, [c3.id], 1)
+            ({c3.id}, set()), _find_shallow(self._store, [c3.id], 1)
         )
         self.assertEqual(
-            (set([c2.id]), set([c3.id])),
+            ({c2.id}, {c3.id}),
             _find_shallow(self._store, [c3.id], 2),
         )
         self.assertEqual(
-            (set([c1.id]), set([c2.id, c3.id])),
+            ({c1.id}, {c2.id, c3.id}),
             _find_shallow(self._store, [c3.id], 3),
         )
         self.assertEqual(
-            (set([]), set([c1.id, c2.id, c3.id])),
+            (set(), {c1.id, c2.id, c3.id}),
             _find_shallow(self._store, [c3.id], 4),
         )
 
@@ -296,7 +304,7 @@ class FindShallowTests(TestCase):
         heads = [a[1].id, b[1].id, c[1].id]
 
         self.assertEqual(
-            (set([a[0].id, b[0].id, c[0].id]), set(heads)),
+            ({a[0].id, b[0].id, c[0].id}, set(heads)),
             _find_shallow(self._store, heads, 2),
         )
 
@@ -311,7 +319,7 @@ class FindShallowTests(TestCase):
 
         # 1 is shallow along the path from 4, but not along the path from 2.
         self.assertEqual(
-            (set([c1.id]), set([c1.id, c2.id, c3.id, c4.id])),
+            ({c1.id}, {c1.id, c2.id, c3.id, c4.id}),
             _find_shallow(self._store, [c2.id, c4.id], 3),
         )
 
@@ -321,7 +329,7 @@ class FindShallowTests(TestCase):
         c3 = self.make_commit(parents=[c1.id, c2.id])
 
         self.assertEqual(
-            (set([c1.id, c2.id]), set([c3.id])),
+            ({c1.id, c2.id}, {c3.id}),
             _find_shallow(self._store, [c3.id], 2),
         )
 
@@ -331,7 +339,7 @@ class FindShallowTests(TestCase):
         self._store.add_object(tag)
 
         self.assertEqual(
-            (set([c1.id]), set([c2.id])),
+            ({c1.id}, {c2.id}),
             _find_shallow(self._store, [tag.id], 2),
         )
 
@@ -344,7 +352,7 @@ class TestUploadPackHandler(UploadPackHandler):
 
 class ReceivePackHandlerTestCase(TestCase):
     def setUp(self):
-        super(ReceivePackHandlerTestCase, self).setUp()
+        super().setUp()
         self._repo = MemoryRepo.init_bare([], {})
         backend = DictBackend({b"/": self._repo})
         self._handler = ReceivePackHandler(
@@ -367,7 +375,7 @@ class ReceivePackHandlerTestCase(TestCase):
 
 class ProtocolGraphWalkerEmptyTestCase(TestCase):
     def setUp(self):
-        super(ProtocolGraphWalkerEmptyTestCase, self).setUp()
+        super().setUp()
         self._repo = MemoryRepo.init_bare([], {})
         backend = DictBackend({b"/": self._repo})
         self._walker = _ProtocolGraphWalker(
@@ -390,7 +398,7 @@ class ProtocolGraphWalkerEmptyTestCase(TestCase):
 
 class ProtocolGraphWalkerTestCase(TestCase):
     def setUp(self):
-        super(ProtocolGraphWalkerTestCase, self).setUp()
+        super().setUp()
         # Create the following commit tree:
         #   3---5
         #  /
@@ -555,7 +563,7 @@ class ProtocolGraphWalkerTestCase(TestCase):
 
     def test_handle_shallow_request_no_client_shallows(self):
         self._handle_shallow_request([b"deepen 2\n"], [FOUR, FIVE])
-        self.assertEqual(set([TWO, THREE]), self._walker.shallow)
+        self.assertEqual({TWO, THREE}, self._walker.shallow)
         self.assertReceived(
             [
                 b"shallow " + TWO,
@@ -570,7 +578,7 @@ class ProtocolGraphWalkerTestCase(TestCase):
             b"deepen 2\n",
         ]
         self._handle_shallow_request(lines, [FOUR, FIVE])
-        self.assertEqual(set([TWO, THREE]), self._walker.shallow)
+        self.assertEqual({TWO, THREE}, self._walker.shallow)
         self.assertReceived([])
 
     def test_handle_shallow_request_unshallows(self):
@@ -579,7 +587,7 @@ class ProtocolGraphWalkerTestCase(TestCase):
             b"deepen 3\n",
         ]
         self._handle_shallow_request(lines, [FOUR, FIVE])
-        self.assertEqual(set([ONE]), self._walker.shallow)
+        self.assertEqual({ONE}, self._walker.shallow)
         self.assertReceived(
             [
                 b"shallow " + ONE,
@@ -589,7 +597,7 @@ class ProtocolGraphWalkerTestCase(TestCase):
         )
 
 
-class TestProtocolGraphWalker(object):
+class TestProtocolGraphWalker:
     def __init__(self):
         self.acks = []
         self.lines = []
@@ -639,7 +647,7 @@ class AckGraphWalkerImplTestCase(TestCase):
     """Base setup and asserts for AckGraphWalker tests."""
 
     def setUp(self):
-        super(AckGraphWalkerImplTestCase, self).setUp()
+        super().setUp()
         self._walker = TestProtocolGraphWalker()
         self._walker.lines = [
             (b"have", TWO),
@@ -1064,7 +1072,7 @@ class FileSystemBackendTests(TestCase):
     """Tests for FileSystemBackend."""
 
     def setUp(self):
-        super(FileSystemBackendTests, self).setUp()
+        super().setUp()
         self.path = tempfile.mkdtemp()
         self.addCleanup(shutil.rmtree, self.path)
         self.repo = Repo.init(self.path)
@@ -1124,7 +1132,7 @@ class ServeCommandTests(TestCase):
     """Tests for serve_command."""
 
     def setUp(self):
-        super(ServeCommandTests, self).setUp()
+        super().setUp()
         self.backend = DictBackend({})
 
     def serve_command(self, handler_cls, args, inf, outf):
@@ -1159,7 +1167,7 @@ class UpdateServerInfoTests(TestCase):
     """Tests for update_server_info."""
 
     def setUp(self):
-        super(UpdateServerInfoTests, self).setUp()
+        super().setUp()
         self.path = tempfile.mkdtemp()
         self.addCleanup(shutil.rmtree, self.path)
         self.repo = Repo.init(self.path)

+ 1 - 1
dulwich/tests/test_utils.py

@@ -37,7 +37,7 @@ from dulwich.tests.utils import (
 
 class BuildCommitGraphTest(TestCase):
     def setUp(self):
-        super(BuildCommitGraphTest, self).setUp()
+        super().setUp()
         self.store = MemoryObjectStore()
 
     def test_linear(self):

+ 4 - 4
dulwich/tests/test_walk.py

@@ -51,13 +51,13 @@ from dulwich.tests.utils import (
 )
 
 
-class TestWalkEntry(object):
+class TestWalkEntry:
     def __init__(self, commit, changes):
         self.commit = commit
         self.changes = changes
 
     def __repr__(self):
-        return "<TestWalkEntry commit=%s, changes=%r>" % (
+        return "<TestWalkEntry commit={}, changes={!r}>".format(
             self.commit.id,
             self.changes,
         )
@@ -72,7 +72,7 @@ class TestWalkEntry(object):
 
 class WalkerTest(TestCase):
     def setUp(self):
-        super(WalkerTest, self).setUp()
+        super().setUp()
         self.store = MemoryObjectStore()
 
     def make_commits(self, commit_spec, **kwargs):
@@ -500,7 +500,7 @@ class WalkerTest(TestCase):
 
 class WalkEntryTest(TestCase):
     def setUp(self):
-        super(WalkEntryTest, self).setUp()
+        super().setUp()
         self.store = MemoryObjectStore()
 
     def make_commits(self, commit_spec, **kwargs):

+ 23 - 25
dulwich/tests/test_web.py

@@ -67,7 +67,7 @@ from dulwich.tests.utils import (
 )
 
 
-class MinimalistWSGIInputStream(object):
+class MinimalistWSGIInputStream:
     """WSGI input stream with no 'seek()' and 'tell()' methods."""
 
     def __init__(self, data):
@@ -110,10 +110,10 @@ class TestHTTPGitRequest(HTTPGitRequest):
 class WebTestCase(TestCase):
     """Base TestCase with useful instance vars and utility functions."""
 
-    _req_class = TestHTTPGitRequest  # type: Type[HTTPGitRequest]
+    _req_class: Type[HTTPGitRequest] = TestHTTPGitRequest
 
     def setUp(self):
-        super(WebTestCase, self).setUp()
+        super().setUp()
         self._environ = {}
         self._req = self._req_class(
             self._environ, self._start_response, handlers=self._handlers()
@@ -168,7 +168,7 @@ class DumbHandlersTestCase(WebTestCase):
         self.assertTrue(f.closed)
 
     def test_send_file_error(self):
-        class TestFile(object):
+        class TestFile:
             def __init__(self, exc_class):
                 self.closed = False
                 self._exc_class = exc_class
@@ -221,7 +221,7 @@ class DumbHandlersTestCase(WebTestCase):
         mat = re.search("^(..)(.{38})$", blob.id.decode("ascii"))
 
         def as_legacy_object_error(self):
-            raise IOError
+            raise OSError
 
         self.addCleanup(setattr, Blob, "as_legacy_object", Blob.as_legacy_object)
         Blob.as_legacy_object = as_legacy_object_error
@@ -296,11 +296,11 @@ class DumbHandlersTestCase(WebTestCase):
         self.assertContentTypeEquals("text/plain")
 
     def test_get_info_packs(self):
-        class TestPackData(object):
+        class TestPackData:
             def __init__(self, sha):
                 self.filename = "pack-%s.pack" % sha
 
-        class TestPack(object):
+        class TestPack:
             def __init__(self, sha):
                 self.data = TestPackData(sha)
 
@@ -327,7 +327,7 @@ class DumbHandlersTestCase(WebTestCase):
 
 
 class SmartHandlersTestCase(WebTestCase):
-    class _TestUploadPackHandler(object):
+    class _TestUploadPackHandler:
         def __init__(
             self,
             backend,
@@ -364,7 +364,7 @@ class SmartHandlersTestCase(WebTestCase):
             self._environ["CONTENT_LENGTH"] = content_length
         mat = re.search(".*", "/git-upload-pack")
 
-        class Backend(object):
+        class Backend:
             def open_repository(self, path):
                 return None
 
@@ -390,7 +390,7 @@ class SmartHandlersTestCase(WebTestCase):
     def test_get_info_refs_unknown(self):
         self._environ["QUERY_STRING"] = "service=git-evil-handler"
 
-        class Backend(object):
+        class Backend:
             def open_repository(self, url):
                 return None
 
@@ -404,7 +404,7 @@ class SmartHandlersTestCase(WebTestCase):
         self._environ["wsgi.input"] = BytesIO(b"foo")
         self._environ["QUERY_STRING"] = "service=git-upload-pack"
 
-        class Backend(object):
+        class Backend:
             def open_repository(self, url):
                 return None
 
@@ -454,14 +454,14 @@ class HTTPGitRequestTestCase(WebTestCase):
         message = "Something not found"
         self.assertEqual(message.encode("ascii"), self._req.not_found(message))
         self.assertEqual(HTTP_NOT_FOUND, self._status)
-        self.assertEqual(set([("Content-Type", "text/plain")]), set(self._headers))
+        self.assertEqual({("Content-Type", "text/plain")}, set(self._headers))
 
     def test_forbidden(self):
         self._req.cache_forever()  # cache headers should be discarded
         message = "Something not found"
         self.assertEqual(message.encode("ascii"), self._req.forbidden(message))
         self.assertEqual(HTTP_FORBIDDEN, self._status)
-        self.assertEqual(set([("Content-Type", "text/plain")]), set(self._headers))
+        self.assertEqual({("Content-Type", "text/plain")}, set(self._headers))
 
     def test_respond_ok(self):
         self._req.respond()
@@ -476,16 +476,14 @@ class HTTPGitRequestTestCase(WebTestCase):
             headers=[("X-Foo", "foo"), ("X-Bar", "bar")],
         )
         self.assertEqual(
-            set(
-                [
-                    ("X-Foo", "foo"),
-                    ("X-Bar", "bar"),
-                    ("Content-Type", "some/type"),
-                    ("Expires", "Fri, 01 Jan 1980 00:00:00 GMT"),
-                    ("Pragma", "no-cache"),
-                    ("Cache-Control", "no-cache, max-age=0, must-revalidate"),
-                ]
-            ),
+            {
+                ("X-Foo", "foo"),
+                ("X-Bar", "bar"),
+                ("Content-Type", "some/type"),
+                ("Expires", "Fri, 01 Jan 1980 00:00:00 GMT"),
+                ("Pragma", "no-cache"),
+                ("Cache-Control", "no-cache, max-age=0, must-revalidate"),
+            },
             set(self._headers),
         )
         self.assertEqual(402, self._status)
@@ -493,7 +491,7 @@ class HTTPGitRequestTestCase(WebTestCase):
 
 class HTTPGitApplicationTestCase(TestCase):
     def setUp(self):
-        super(HTTPGitApplicationTestCase, self).setUp()
+        super().setUp()
         self._app = HTTPGitApplication("backend")
 
         self._environ = {
@@ -533,7 +531,7 @@ class GunzipTestCase(HTTPGitApplicationTestCase):
     example_text = __doc__.encode("ascii")
 
     def setUp(self):
-        super(GunzipTestCase, self).setUp()
+        super().setUp()
         self._app = GunzipFilter(self._app)
         self._environ["HTTP_CONTENT_ENCODING"] = "gzip"
         self._environ["REQUEST_METHOD"] = "POST"

+ 18 - 16
dulwich/walk.py

@@ -24,7 +24,7 @@
 import collections
 import heapq
 from itertools import chain
-from typing import List, Tuple, Set
+from typing import List, Tuple, Set, Deque, Optional
 
 from dulwich.diff_tree import (
     RENAME_CHANGE_TYPES,
@@ -50,7 +50,7 @@ ALL_ORDERS = (ORDER_DATE, ORDER_TOPO)
 _MAX_EXTRA_COMMITS = 5
 
 
-class WalkEntry(object):
+class WalkEntry:
     """Object encapsulating a single result from a walk."""
 
     def __init__(self, walker, commit):
@@ -122,13 +122,13 @@ class WalkEntry(object):
         return self._changes[path_prefix]
 
     def __repr__(self):
-        return "<WalkEntry commit=%s, changes=%r>" % (
+        return "<WalkEntry commit={}, changes={!r}>".format(
             self.commit.id,
             self.changes(),
         )
 
 
-class _CommitTimeQueue(object):
+class _CommitTimeQueue:
     """Priority queue of WalkEntry objects by commit time."""
 
     def __init__(self, walker: "Walker"):
@@ -232,7 +232,7 @@ class _CommitTimeQueue(object):
     __next__ = next
 
 
-class Walker(object):
+class Walker:
     """Object for performing a walk of commits in a store.
 
     Walker objects are initialized with a store and other options and can then
@@ -242,16 +242,16 @@ class Walker(object):
     def __init__(
         self,
         store,
-        include,
-        exclude=None,
-        order=ORDER_DATE,
-        reverse=False,
-        max_entries=None,
-        paths=None,
-        rename_detector=None,
-        follow=False,
-        since=None,
-        until=None,
+        include: List[bytes],
+        exclude: Optional[List[bytes]] = None,
+        order: str = 'date',
+        reverse: bool = False,
+        max_entries: Optional[int] = None,
+        paths: Optional[List[bytes]] = None,
+        rename_detector: Optional[RenameDetector] = None,
+        follow: bool = False,
+        since: Optional[int] = None,
+        until: Optional[int] = None,
         get_parents=lambda commit: commit.parents,
         queue_cls=_CommitTimeQueue,
     ):
@@ -306,11 +306,13 @@ class Walker(object):
 
         self._num_entries = 0
         self._queue = queue_cls(self)
-        self._out_queue = collections.deque()
+        self._out_queue: Deque[WalkEntry] = collections.deque()
 
     def _path_matches(self, changed_path):
         if changed_path is None:
             return False
+        if self.paths is None:
+            return True
         for followed_path in self.paths:
             if changed_path == followed_path:
                 return True

+ 12 - 13
dulwich/web.py

@@ -157,7 +157,7 @@ def send_file(req, f, content_type):
             if not data:
                 break
             yield data
-    except IOError:
+    except OSError:
         yield req.error("Error reading file")
     finally:
         f.close()
@@ -183,7 +183,7 @@ def get_loose_object(req, backend, mat):
         return
     try:
         data = object_store[sha].as_legacy_object()
-    except IOError:
+    except OSError:
         yield req.error("Error reading object")
         return
     req.cache_forever()
@@ -245,8 +245,7 @@ def get_info_refs(req, backend, mat):
         req.nocache()
         req.respond(HTTP_OK, "text/plain")
         logger.info("Emulating dumb info/refs")
-        for text in generate_info_refs(repo):
-            yield text
+        yield from generate_info_refs(repo)
 
 
 def get_info_packs(req, backend, mat):
@@ -266,7 +265,7 @@ def _chunk_iter(f):
         yield chunk[:-2]
 
 
-class ChunkReader(object):
+class ChunkReader:
 
     def __init__(self, f):
         self._iter = _chunk_iter(f)
@@ -284,7 +283,7 @@ class ChunkReader(object):
         return ret
 
 
-class _LengthLimitedFile(object):
+class _LengthLimitedFile:
     """Wrapper class to limit the length of reads from a file-like object.
 
     This is used to ensure EOF is read from the wsgi.input object once
@@ -332,7 +331,7 @@ def handle_service_request(req, backend, mat):
     handler.handle()
 
 
-class HTTPGitRequest(object):
+class HTTPGitRequest:
     """Class encapsulating the state of a single git HTTP request.
 
     Attributes:
@@ -344,8 +343,8 @@ class HTTPGitRequest(object):
         self.dumb = dumb
         self.handlers = handlers
         self._start_response = start_response
-        self._cache_headers = []  # type: List[Tuple[str, str]]
-        self._headers = []  # type: List[Tuple[str, str]]
+        self._cache_headers: List[Tuple[str, str]] = []
+        self._headers: List[Tuple[str, str]] = []
 
     def add_header(self, name, value):
         """Add a header to the response."""
@@ -396,7 +395,7 @@ class HTTPGitRequest(object):
         self._cache_headers = cache_forever_headers()
 
 
-class HTTPGitApplication(object):
+class HTTPGitApplication:
     """Class encapsulating the state of a git WSGI application.
 
     Attributes:
@@ -458,7 +457,7 @@ class HTTPGitApplication(object):
         return handler(req, self.backend, mat)
 
 
-class GunzipFilter(object):
+class GunzipFilter:
     """WSGI middleware that unzips gzip-encoded requests before
     passing on to the underlying application.
     """
@@ -471,7 +470,7 @@ class GunzipFilter(object):
             try:
                 environ["wsgi.input"].tell()
                 wsgi_input = environ["wsgi.input"]
-            except (AttributeError, IOError, NotImplementedError):
+            except (AttributeError, OSError, NotImplementedError):
                 # The gzip implementation in the standard library of Python 2.x
                 # requires working '.seek()' and '.tell()' methods on the input
                 # stream.  Read the data into a temporary file to work around
@@ -490,7 +489,7 @@ class GunzipFilter(object):
         return self.app(environ, start_response)
 
 
-class LimitedInputFilter(object):
+class LimitedInputFilter:
     """WSGI middleware that limits the input length of a request to that
     specified in Content-Length.
     """

+ 1 - 1
examples/clone.py

@@ -22,7 +22,7 @@ _, args = getopt(sys.argv, "", [])
 
 
 if len(args) < 2:
-    print("usage: %s host:path path" % (args[0], ))
+    print("usage: {} host:path path".format(args[0]))
     sys.exit(1)
 
 elif len(args) < 3:

+ 2 - 2
examples/latest_change.py

@@ -6,7 +6,7 @@ import time
 from dulwich.repo import Repo
 
 if len(sys.argv) < 2:
-    print("usage: %s filename" % (sys.argv[0], ))
+    print("usage: {} filename".format(sys.argv[0]))
     sys.exit(1)
 
 r = Repo(".")
@@ -19,5 +19,5 @@ try:
 except StopIteration:
     print("No file %s anywhere in history." % sys.argv[1])
 else:
-    print("%s was last changed by %s at %s (commit %s)" % (
+    print("{} was last changed by {} at {} (commit {})".format(
         sys.argv[1], c.author, time.ctime(c.author_time), c.id))

+ 4 - 4
examples/memoryrepo.py

@@ -14,12 +14,12 @@ from dulwich.repo import MemoryRepo
 
 local_repo = MemoryRepo()
 local_repo.refs.set_symbolic_ref(b'HEAD', b'refs/heads/master')
-print(local_repo.refs.as_dict())
 
-porcelain.fetch(local_repo, sys.argv[1])
-local_repo['refs/heads/master'] = local_repo['refs/remotes/origin/master']
+fetch_result = porcelain.fetch(local_repo, sys.argv[1])
+local_repo.refs[b'refs/heads/master'] = fetch_result.refs[b'refs/heads/master']
+print(local_repo.refs.as_dict())
 
-last_tree = local_repo[local_repo['HEAD'].tree]
+last_tree = local_repo[local_repo[b'HEAD'].tree]
 new_blob = Blob.from_string(b'Some contents')
 local_repo.object_store.add_object(new_blob)
 last_tree.add(b'test', stat.S_IFREG, new_blob.id)

+ 1 - 1
examples/rename-branch.py

@@ -26,4 +26,4 @@ def update_refs(refs):
 
 
 client.send_pack(path, update_refs, generate_pack_data)
-print("Renamed %s to %s" % (args.old_ref, args.new_ref))
+print("Renamed {} to {}".format(args.old_ref, args.new_ref))

+ 3 - 0
pyproject.toml

@@ -0,0 +1,3 @@
+[build-system]
+requires = ["setuptools"]
+build-backend = "setuptools.build_meta"

+ 2 - 2
setup.cfg

@@ -14,7 +14,6 @@ keywords = vcs, git
 classifiers = 
 	Development Status :: 4 - Beta
 	License :: OSI Approved :: Apache Software License
-	Programming Language :: Python :: 3.6
 	Programming Language :: Python :: 3.7
 	Programming Language :: Python :: 3.8
 	Programming Language :: Python :: 3.9
@@ -42,7 +41,7 @@ console_scripts =
 	dulwich = dulwich.cli:main
 
 [options]
-python_requires = >=3.6
+python_requires = >=3.7
 packages = 
 	dulwich
 	dulwich.cloud
@@ -52,6 +51,7 @@ packages =
 include_package_data = True
 install_requires = 
 	urllib3>=1.25
+	typing_extensions;python_version<="3.7"
 zip_safe = False
 scripts = 
 	bin/dul-receive-pack

+ 6 - 2
setup.py

@@ -1,5 +1,4 @@
 #!/usr/bin/python3
-# encoding: utf-8
 # Setup file for dulwich
 # Copyright (C) 2008-2022 Jelmer Vernooij <jelmer@jelmer.uk>
 
@@ -39,7 +38,6 @@ if '__pypy__' not in sys.modules and sys.platform != 'win32':
 
 optional = os.environ.get('CIBUILDWHEEL', '0') != '1'
 
-
 ext_modules = [
     Extension('dulwich._objects', ['dulwich/_objects.c'],
               optional=optional),
@@ -49,6 +47,12 @@ ext_modules = [
               optional=optional),
 ]
 
+# Ideally, setuptools would just provide a way to do this
+if '--pure' in sys.argv:
+    sys.argv.remove('--pure')
+    ext_modules = []
+
+
 setup(package_data={'': ['../docs/tutorial/*.txt', 'py.typed']},
       ext_modules=ext_modules,
       tests_require=tests_require)

+ 1 - 14
tox.ini

@@ -1,6 +1,5 @@
 [tox]
 downloadcache = {toxworkdir}/cache/
-envlist = py35, py35-noext, py36, py36-noext, py37, py37-noext, py38, py38-noext
 
 [testenv]
 
@@ -8,17 +7,5 @@ commands = make check
 recreate = True
 whitelist_externals = make
 
-[testenv:pypy-noext]
-commands = make check-noextensions
-
-[testenv:py35-noext]
-commands = make check-noextensions
-
-[testenv:py36-noext]
-commands = make check-noextensions
-
-[testenv:py37-noext]
-commands = make check-noextensions
-
-[testenv:py38-noext]
+[testenv:noext]
 commands = make check-noextensions

Энэ ялгаанд хэт олон файл өөрчлөгдсөн тул зарим файлыг харуулаагүй болно