
Import upstream version 0.20.23

Jelmer Vernooij, 3 years ago
parent
commit 0c434308d1
100 changed files with 11311 additions and 7798 deletions
  1. .deepsource.toml (+12 -0)
  2. .flake8 (+5 -0)
  3. .github/workflows/pythonpackage.yml (+9 -0)
  4. .github/workflows/pythonpublish.yml (+17 -0)
  5. AUTHORS (+1 -0)
  6. NEWS (+86 -0)
  7. PKG-INFO (+6 -3)
  8. README.rst (+4 -1)
  9. docs/tutorial/file-format.txt (+3 -2)
  10. docs/tutorial/remote.txt (+4 -1)
  11. dulwich.egg-info/PKG-INFO (+6 -3)
  12. dulwich.egg-info/SOURCES.txt (+10 -1)
  13. dulwich/__init__.py (+1 -1)
  14. dulwich/archive.py (+10 -9)
  15. dulwich/bundle.py (+22 -23)
  16. dulwich/cli.py (+188 -168)
  17. dulwich/client.py (+322 -183)
  18. dulwich/cloud/__init__.py (+0 -0)
  19. dulwich/cloud/gcs.py (+82 -0)
  20. dulwich/config.py (+115 -64)
  21. dulwich/contrib/__init__.py (+5 -4)
  22. dulwich/contrib/diffstat.py (+45 -36)
  23. dulwich/contrib/paramiko_vendor.py (+18 -12)
  24. dulwich/contrib/release_robot.py (+11 -11)
  25. dulwich/contrib/swift.py (+221 -188)
  26. dulwich/contrib/test_paramiko_vendor.py (+189 -0)
  27. dulwich/contrib/test_release_robot.py (+28 -21)
  28. dulwich/contrib/test_swift.py (+215 -185)
  29. dulwich/contrib/test_swift_smoke.py (+102 -99)
  30. dulwich/diff_tree.py (+74 -53)
  31. dulwich/errors.py (+22 -19)
  32. dulwich/fastexport.py (+39 -24)
  33. dulwich/file.py (+41 -20)
  34. dulwich/graph.py (+2 -2)
  35. dulwich/greenthreads.py (+23 -19)
  36. dulwich/hooks.py (+38 -28)
  37. dulwich/ignore.py (+86 -85)
  38. dulwich/index.py (+216 -115)
  39. dulwich/lfs.py (+7 -8)
  40. dulwich/line_ending.py (+40 -12)
  41. dulwich/log_utils.py (+6 -3)
  42. dulwich/lru_cache.py (+35 -26)
  43. dulwich/mailmap.py (+6 -5)
  44. dulwich/object_store.py (+292 -115)
  45. dulwich/objects.py (+330 -213)
  46. dulwich/objectspec.py (+8 -5)
  47. dulwich/pack.py (+230 -189)
  48. dulwich/patch.py (+100 -69)
  49. dulwich/porcelain.py (+320 -186)
  50. dulwich/protocol.py (+84 -75)
  51. dulwich/reflog.py (+88 -13)
  52. dulwich/refs.py (+295 -142)
  53. dulwich/repo.py (+278 -191)
  54. dulwich/server.py (+209 -157)
  55. dulwich/stash.py (+37 -17)
  56. dulwich/tests/__init__.py (+70 -56)
  57. dulwich/tests/compat/__init__.py (+10 -9)
  58. dulwich/tests/compat/server_utils.py (+148 -94)
  59. dulwich/tests/compat/test_client.py (+220 -176)
  60. dulwich/tests/compat/test_pack.py (+55 -39)
  61. dulwich/tests/compat/test_patch.py (+6 -6)
  62. dulwich/tests/compat/test_porcelain.py (+101 -0)
  63. dulwich/tests/compat/test_repository.py (+41 -43)
  64. dulwich/tests/compat/test_server.py (+14 -17)
  65. dulwich/tests/compat/test_utils.py (+12 -14)
  66. dulwich/tests/compat/test_web.py (+36 -34)
  67. dulwich/tests/compat/utils.py (+44 -35)
  68. dulwich/tests/test_archive.py (+15 -18)
  69. dulwich/tests/test_blackbox.py (+9 -9)
  70. dulwich/tests/test_bundle.py (+7 -8)
  71. dulwich/tests/test_client.py (+312 -286)
  72. dulwich/tests/test_config.py (+172 -101)
  73. dulwich/tests/test_diff_tree.py (+787 -584)
  74. dulwich/tests/test_fastexport.py (+112 -56)
  75. dulwich/tests/test_file.py (+62 -64)
  76. dulwich/tests/test_grafts.py (+67 -65)
  77. dulwich/tests/test_graph.py (+76 -77)
  78. dulwich/tests/test_greenthreads.py (+18 -15)
  79. dulwich/tests/test_hooks.py (+51 -34)
  80. dulwich/tests/test_ignore.py (+126 -115)
  81. dulwich/tests/test_index.py (+296 -221)
  82. dulwich/tests/test_lfs.py (+3 -5)
  83. dulwich/tests/test_line_ending.py (+4 -13)
  84. dulwich/tests/test_lru_cache.py (+89 -88)
  85. dulwich/tests/test_mailmap.py (+51 -40)
  86. dulwich/tests/test_missing_obj_finder.py (+142 -85)
  87. dulwich/tests/test_object_store.py (+257 -170)
  88. dulwich/tests/test_objects.py (+316 -284)
  89. dulwich/tests/test_objectspec.py (+78 -56)
  90. dulwich/tests/test_pack.py (+313 -250)
  91. dulwich/tests/test_patch.py (+311 -216)
  92. dulwich/tests/test_porcelain.py (+470 -201)
  93. dulwich/tests/test_protocol.py (+84 -84)
  94. dulwich/tests/test_reflog.py (+102 -28)
  95. dulwich/tests/test_refs.py (+432 -354)
  96. dulwich/tests/test_repository.py (+416 -322)
  97. dulwich/tests/test_server.py (+297 -256)
  98. dulwich/tests/test_utils.py (+29 -23)
  99. dulwich/tests/test_walk.py (+220 -171)
  100. dulwich/tests/test_web.py (+187 -170)

+ 12 - 0
.deepsource.toml

@@ -0,0 +1,12 @@
+version = 1
+
+test_patterns = ["dulwich/**test_*.py"]
+
+exclude_patterns = ["examples/**"]
+
+[[analyzers]]
+name = "python"
+enabled = true
+
+  [analyzers.meta]
+  runtime_version = "3.x.x"

+ 5 - 0
.flake8

@@ -0,0 +1,5 @@
+[flake8]
+extend-ignore = E203, E266, E501, W293, W291
+max-line-length = 88
+max-complexity = 18
+select = B,C,E,F,W,T4,B9

+ 9 - 0
.github/workflows/pythonpackage.yml

@@ -31,10 +31,19 @@ jobs:
       uses: actions/setup-python@v2
       with:
         python-version: ${{ matrix.python-version }}
+    - name: Install native dependencies (Ubuntu)
+      run: sudo apt-get update && sudo apt-get install -y libgpgme-dev libgpg-error-dev
+      if: "matrix.os == 'ubuntu-latest'"
+    - name: Install native dependencies (MacOS)
+      run: brew install swig gpgme
+      if: "matrix.os == 'macos-latest'"
     - name: Install dependencies
       run: |
         python -m pip install --upgrade pip
         pip install -U pip coverage codecov flake8 fastimport
+    - name: Install gpg on supported platforms
+      run: pip install -U gpg
+      if: "matrix.os != 'windows-latest' && matrix.python-version != 'pypy3'"
     - name: Install mypy
       run: |
         pip install -U mypy

+ 17 - 0
.github/workflows/pythonpublish.yml

@@ -30,10 +30,19 @@ jobs:
       uses: actions/setup-python@v2
       with:
         python-version: ${{ matrix.python-version }}
+    - name: Install native dependencies (Ubuntu)
+      run: sudo apt-get update && sudo apt-get install -y libgpgme-dev libgpg-error-dev
+      if: "matrix.os == 'ubuntu-latest'"
+    - name: Install native dependencies (MacOS)
+      run: brew install swig gpgme
+      if: "matrix.os == 'macos-latest'"
     - name: Install dependencies
       run: |
         python -m pip install --upgrade pip
         pip install setuptools wheel twine fastimport
+    - name: Install gpg on supported platforms
+      run: pip install -U gpg
+      if: "matrix.os != 'windows-latest' && matrix.python-version != 'pypy3'"
     - name: Run test suite
       run: |
         python -m unittest dulwich.tests.test_suite
@@ -41,6 +50,14 @@ jobs:
       run: |
         python setup.py sdist bdist_wheel
       if: "matrix.os != 'ubuntu-latest'"
+    - uses: docker/setup-qemu-action@v1
+      name: Set up QEMU
+      if: "matrix.os == 'ubuntu-latest'"
+    - name: Build and publish (Linux aarch64)
+      uses: RalfG/python-wheels-manylinux-build@v0.3.3-manylinux2014_aarch64
+      with:
+        python-versions: 'cp36-cp36m cp37-cp37m cp38-cp38 cp39-cp39'
+      if: "matrix.os == 'ubuntu-latest'"
     - name: Build and publish (Linux)
       uses: RalfG/python-wheels-manylinux-build@v0.3.1
       with:

+ 1 - 0
AUTHORS

@@ -150,5 +150,6 @@ Antoine Lambert <anlambert@softwareheritage.org>
 Lane Barlow <lane.barlow@gmail.com>
 Manuel Jacob <me@manueljacob.de>
 Brecht Machiels <brecht@mos6581.org>
+Peter Rowlands <peter@pmrowla.com>
 
 If you contributed but are missing from this list, please send me an e-mail.

+ 86 - 0
NEWS

@@ -1,3 +1,88 @@
+0.20.23	2021-05-24
+
+ * Fix installation of GPG during package publishing.
+   (Ruslan Kuprieiev)
+
+0.20.22	2021-05-24
+
+ * Prevent removal of refs directory when the last ref is
+   deleted. (Jelmer Vernooij)
+
+ * Fix filename: MERGE_HEADS => MERGE_HEAD.
+   (Jelmer Vernooij, #861)
+
+ * For ignored directories, porcelain.add and porcelain.status now only return
+   the path to directory itself in the list of ignored paths. Previously, paths
+   for all files within the directory would also be included in the list.
+   (Peter Rowlands, #853)
+
+ * Provide depth argument to ``determine_wants``.
+   (Peter Rowlands)
+
+ * Various tag signature handling improvements.
+   (Daniel Murphy)
+
+ * Add separate Tag.verify().  (Peter Rowlands)
+
+ * Add support for version 3 index files. (Jelmer Vernooij)
+
+ * Fix autocrlf=input handling. (Peter Rowlands, Boris Feld)
+
+ * Attempt to find C Git global config on Windows.
+   (Peter Rowlands)
+
+ API CHANGES
+
+ * The APIs for writing and reading individual index entries have changed
+   to handle lists of (name, entry) tuples rather than tuples.
+
+0.20.21	2021-03-20
+
+ * Add basic support for a GcsObjectStore that stores
+   pack files in gcs. (Jelmer Vernooij)
+
+ * In porcelain.push, default to local active branch.
+   (Jelmer Vernooij, #846)
+
+ * Support fetching symrefs.
+   (Jelmer Vernooij, #485, #847)
+
+ * Add aarch64 wheel building.
+   (odidev, Jelmer Vernooij)
+
+0.20.20	2021-03-03
+
+ * Implement ``Stash.drop``. (Peter Rowlands)
+
+ * Support untracked symlinks to paths outside the
+   repository. (Peter Rowlands, #842)
+
+0.20.19	2021-02-11
+
+ * Fix handling of negative matches in nested gitignores.
+   (Corentin Hembise, #836)
+
+0.20.18	2021-02-04
+
+ * Fix formatting in setup.py. (Jelmer Vernooij)
+
+ * Add release configuration. (Jelmer Vernooij)
+
+0.20.17	2021-02-04
+
+ * credentials: ignore end-of-line character. (Georges Racinet)
+
+ * Fix failure in get_untracked_paths when the repository contains symlinks.
+   (#830, #793, mattseddon)
+
+ * docs: Clarify that Git objects are created on `git add`.
+   (Utku Gultopu)
+
+0.20.16	2021-01-16
+
+ * Add flag to only attempt to fetch ignored untracked files when specifically requested.
+   (Matt Seddon)
+
 0.20.15	2020-12-23
 
  * Add some functions for parsing and writing bundles.
@@ -994,6 +1079,7 @@
 
   * In dulwich.index.build_index_from_tree, by default
     refuse to create entries that start with .git/.
+    (Jelmer Vernooij, CVE-2014-9706)
 
   * Fix running of testsuite when installed.
     (Jelmer Vernooij, #223)

+ 6 - 3
PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: dulwich
-Version: 0.20.15
+Version: 0.20.23
 Summary: Python Git Library
 Home-page: https://www.dulwich.io/
 Author: Jelmer Vernooij
@@ -9,7 +9,10 @@ License: Apachev2 or later or GPLv2
 Project-URL: Bug Tracker, https://github.com/dulwich/dulwich/issues
 Project-URL: Repository, https://www.dulwich.io/code/
 Project-URL: GitHub, https://github.com/dulwich/dulwich
-Description: This is the Dulwich project.
+Description: Dulwich
+        =======
+        
+        This is the Dulwich project.
         
         It aims to provide an interface to git repos (both local and remote) that
         doesn't call out to git directly but instead uses pure Python.
@@ -80,7 +83,7 @@ Description: This is the Dulwich project.
         Help
         ----
         
-        There is a *#dulwich* IRC channel on the `Freenode <https://www.freenode.net/>`_, and
+        There is a *#dulwich* IRC channel on the `OFTC <https://www.oftc.net/>`_, and
         `dulwich-announce <https://groups.google.com/forum/#!forum/dulwich-announce>`_
         and `dulwich-discuss <https://groups.google.com/forum/#!forum/dulwich-discuss>`_
         mailing lists.

+ 4 - 1
README.rst

@@ -1,3 +1,6 @@
+Dulwich
+=======
+
 This is the Dulwich project.
 
 It aims to provide an interface to git repos (both local and remote) that
@@ -69,7 +72,7 @@ doc``. It can also be found `on the web <https://www.dulwich.io/docs/>`_.
 Help
 ----
 
-There is a *#dulwich* IRC channel on the `Freenode <https://www.freenode.net/>`_, and
+There is a *#dulwich* IRC channel on the `OFTC <https://www.oftc.net/>`_, and
 `dulwich-announce <https://groups.google.com/forum/#!forum/dulwich-announce>`_
 and `dulwich-discuss <https://groups.google.com/forum/#!forum/dulwich-discuss>`_
 mailing lists.

+ 3 - 2
docs/tutorial/file-format.txt

@@ -69,8 +69,9 @@ A blob file looks like this::
 
   blob <content length><NUL><content>
 
-If you change a single line, another blob will be generated by Git at commit
-time. This is how Git can fastly checkout any version in time.
+If you change a single line, another blob will be generated by Git each time you
+successfully run ``git add``. This is how Git can fastly checkout any version in
+time.
 
 On the opposite, several identical files with different filenames generate
 only one blob. That's mostly how renames are so cheap and efficient in Git.
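
A quick, hedged check of the format described above (assuming dulwich 0.20.x): the ``Blob`` class reproduces the well-known Git hash of ``b"hello\n"``, i.e. the SHA-1 of ``blob 6<NUL>hello\n``::

   >>> from dulwich.objects import Blob
   >>> blob = Blob.from_string(b"hello\n")
   >>> blob.id
   b'ce013625030ba8dba906f756967f9e9ca394464a'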

+ 4 - 1
docs/tutorial/remote.txt

@@ -41,10 +41,13 @@ The client object can then be used to retrieve a pack. The ``fetch_pack``
 method takes a ``determine_wants`` callback argument, which allows the
 client to determine which objects it wants to end up with::
 
-   >>> def determine_wants(refs):
+   >>> def determine_wants(refs, depth=None):
    ...    # retrieve all objects
    ...    return refs.values()
 
+Note that the ``depth`` keyword argument will contain an optional requested
+shallow fetch depth.
+
 Another required object is a "graph walker", which is used to determine
 which objects that the client already has should not be sent again
 by the server. Here in the tutorial we'll just use a dummy graph walker
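
A hedged end-to-end sketch (the URL and target path are illustrative, and network access is assumed): the ``determine_wants`` callback above can be passed straight to ``client.fetch``, which performs the ``fetch_pack`` negotiation and stores the result in the target repository::

   >>> from dulwich.repo import Repo
   >>> from dulwich.client import get_transport_and_path
   >>> client, path = get_transport_and_path("https://github.com/dulwich/dulwich")
   >>> local = Repo.init("local-copy", mkdir=True)
   >>> result = client.fetch(path, local, determine_wants=determine_wants)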

+ 6 - 3
dulwich.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: dulwich
-Version: 0.20.15
+Version: 0.20.23
 Summary: Python Git Library
 Home-page: https://www.dulwich.io/
 Author: Jelmer Vernooij
@@ -9,7 +9,10 @@ License: Apachev2 or later or GPLv2
 Project-URL: Bug Tracker, https://github.com/dulwich/dulwich/issues
 Project-URL: Repository, https://www.dulwich.io/code/
 Project-URL: GitHub, https://github.com/dulwich/dulwich
-Description: This is the Dulwich project.
+Description: Dulwich
+        =======
+        
+        This is the Dulwich project.
         
         It aims to provide an interface to git repos (both local and remote) that
         doesn't call out to git directly but instead uses pure Python.
@@ -80,7 +83,7 @@ Description: This is the Dulwich project.
         Help
         ----
         
-        There is a *#dulwich* IRC channel on the `Freenode <https://www.freenode.net/>`_, and
+        There is a *#dulwich* IRC channel on the `OFTC <https://www.oftc.net/>`_, and
         `dulwich-announce <https://groups.google.com/forum/#!forum/dulwich-announce>`_
         and `dulwich-discuss <https://groups.google.com/forum/#!forum/dulwich-discuss>`_
         mailing lists.

+ 10 - 1
dulwich.egg-info/SOURCES.txt

@@ -1,4 +1,6 @@
 .coveragerc
+.deepsource.toml
+.flake8
 .gitignore
 .mailmap
 .testr.conf
@@ -15,6 +17,7 @@ SECURITY.md
 TODO
 build.cmd
 dulwich.cfg
+releaser.conf
 requirements.txt
 setup.cfg
 setup.py
@@ -92,12 +95,15 @@ dulwich.egg-info/dependency_links.txt
 dulwich.egg-info/entry_points.txt
 dulwich.egg-info/requires.txt
 dulwich.egg-info/top_level.txt
+dulwich/cloud/__init__.py
+dulwich/cloud/gcs.py
 dulwich/contrib/README.md
 dulwich/contrib/__init__.py
 dulwich/contrib/diffstat.py
 dulwich/contrib/paramiko_vendor.py
 dulwich/contrib/release_robot.py
 dulwich/contrib/swift.py
+dulwich/contrib/test_paramiko_vendor.py
 dulwich/contrib/test_release_robot.py
 dulwich/contrib/test_swift.py
 dulwich/contrib/test_swift_smoke.py
@@ -142,6 +148,7 @@ dulwich/tests/compat/server_utils.py
 dulwich/tests/compat/test_client.py
 dulwich/tests/compat/test_pack.py
 dulwich/tests/compat/test_patch.py
+dulwich/tests/compat/test_porcelain.py
 dulwich/tests/compat/test_repository.py
 dulwich/tests/compat/test_server.py
 dulwich/tests/compat/test_utils.py
@@ -229,5 +236,7 @@ dulwich/tests/data/trees/70/c190eb48fa8bbb50ddc692a17b44cb781af7f6
 examples/clone.py
 examples/config.py
 examples/diff.py
+examples/gcs.py
 examples/latest_change.py
-examples/memoryrepo.py
+examples/memoryrepo.py
+examples/rename-branch.py

+ 1 - 1
dulwich/__init__.py

@@ -22,4 +22,4 @@
 
 """Python implementation of the Git file formats and protocols."""
 
-__version__ = (0, 20, 15)
+__version__ = (0, 20, 23)

+ 10 - 9
dulwich/archive.py

@@ -42,20 +42,21 @@ class ChunkedBytesIO(object):
         BytesIO(b''.join(list_of_bytestrings)) =~= ChunkedBytesIO(
             list_of_bytestrings)
     """
+
     def __init__(self, contents):
         self.contents = contents
         self.pos = (0, 0)
 
     def read(self, maxbytes=None):
         if maxbytes < 0:
-            maxbytes = float('inf')
+            maxbytes = float("inf")
 
         buf = []
         chunk, cursor = self.pos
 
         while chunk < len(self.contents):
             if maxbytes < len(self.contents[chunk]) - cursor:
-                buf.append(self.contents[chunk][cursor:cursor+maxbytes])
+                buf.append(self.contents[chunk][cursor : cursor + maxbytes])
                 cursor += maxbytes
                 self.pos = (chunk, cursor)
                 break
@@ -65,10 +66,10 @@ class ChunkedBytesIO(object):
                 chunk += 1
                 cursor = 0
                 self.pos = (chunk, cursor)
-        return b''.join(buf)
+        return b"".join(buf)
 
 
-def tar_stream(store, tree, mtime, prefix=b'', format=''):
+def tar_stream(store, tree, mtime, prefix=b"", format=""):
     """Generate a tar stream for the contents of a Git tree.
 
     Returns a generator that lazily assembles a .tar.gz archive, yielding it in
@@ -86,16 +87,16 @@ def tar_stream(store, tree, mtime, prefix=b'', format=''):
     """
     buf = BytesIO()
     with closing(tarfile.open(None, "w:%s" % format, buf)) as tar:
-        if format == 'gz':
+        if format == "gz":
             # Manually correct the gzip header file modification time so that
             # archives created from the same Git tree are always identical.
             # The gzip header file modification time is not currenctly
             # accessible from the tarfile API, see:
             # https://bugs.python.org/issue31526
             buf.seek(0)
-            assert buf.read(2) == b'\x1f\x8b', 'Invalid gzip header'
+            assert buf.read(2) == b"\x1f\x8b", "Invalid gzip header"
             buf.seek(4)
-            buf.write(struct.pack('<L', mtime))
+            buf.write(struct.pack("<L", mtime))
             buf.seek(0, SEEK_END)
 
         for entry_abspath, entry in _walk_tree(store, tree, prefix):
@@ -109,7 +110,7 @@ def tar_stream(store, tree, mtime, prefix=b'', format=''):
 
             info = tarfile.TarInfo()
             # tarfile only works with ascii.
-            info.name = entry_abspath.decode('ascii')
+            info.name = entry_abspath.decode("ascii")
             info.size = blob.raw_length()
             info.mode = entry.mode
             info.mtime = mtime
@@ -121,7 +122,7 @@ def tar_stream(store, tree, mtime, prefix=b'', format=''):
     yield buf.getvalue()
 
 
-def _walk_tree(store, tree, root=b''):
+def _walk_tree(store, tree, root=b""):
     """Recursively walk a dulwich Tree, yielding tuples of
     (absolute path, TreeEntry) along the way.
     """

+ 22 - 23
dulwich/bundle.py

@@ -56,23 +56,23 @@ def _read_bundle(f, version):
     references = {}
     line = f.readline()
     if version >= 3:
-        while line.startswith(b'@'):
-            line = line[1:].rstrip(b'\n')
+        while line.startswith(b"@"):
+            line = line[1:].rstrip(b"\n")
             try:
-                key, value = line.split(b'=', 1)
+                key, value = line.split(b"=", 1)
             except ValueError:
                 key = line
                 value = None
             else:
-                value = value.decode('utf-8')
-            capabilities[key.decode('utf-8')] = value
+                value = value.decode("utf-8")
+            capabilities[key.decode("utf-8")] = value
             line = f.readline()
-    while line.startswith(b'-'):
-        (obj_id, comment) = line[1:].rstrip(b'\n').split(b' ', 1)
-        prerequisites.append((obj_id, comment.decode('utf-8')))
+    while line.startswith(b"-"):
+        (obj_id, comment) = line[1:].rstrip(b"\n").split(b" ", 1)
+        prerequisites.append((obj_id, comment.decode("utf-8")))
         line = f.readline()
-    while line != b'\n':
-        (obj_id, ref) = line.rstrip(b'\n').split(b' ', 1)
+    while line != b"\n":
+        (obj_id, ref) = line.rstrip(b"\n").split(b" ", 1)
         references[ref] = obj_id
         line = f.readline()
     pack_data = PackData.from_file(f)
@@ -88,12 +88,11 @@ def _read_bundle(f, version):
 def read_bundle(f):
     """Read a bundle file."""
     firstline = f.readline()
-    if firstline == b'# v2 git bundle\n':
+    if firstline == b"# v2 git bundle\n":
         return _read_bundle(f, 2)
-    if firstline == b'# v3 git bundle\n':
+    if firstline == b"# v3 git bundle\n":
         return _read_bundle(f, 3)
-    raise AssertionError(
-        'unsupported bundle format header: %r' % firstline)
+    raise AssertionError("unsupported bundle format header: %r" % firstline)
 
 
 def write_bundle(f, bundle):
@@ -104,20 +103,20 @@ def write_bundle(f, bundle):
         else:
             version = 2
     if version == 2:
-        f.write(b'# v2 git bundle\n')
+        f.write(b"# v2 git bundle\n")
     elif version == 3:
-        f.write(b'# v3 git bundle\n')
+        f.write(b"# v3 git bundle\n")
     else:
-        raise AssertionError('unknown version %d' % version)
+        raise AssertionError("unknown version %d" % version)
     if version == 3:
         for key, value in bundle.capabilities.items():
-            f.write(b'@' + key.encode('utf-8'))
+            f.write(b"@" + key.encode("utf-8"))
             if value is not None:
-                f.write(b'=' + value.encode('utf-8'))
-            f.write(b'\n')
+                f.write(b"=" + value.encode("utf-8"))
+            f.write(b"\n")
     for (obj_id, comment) in bundle.prerequisites:
-        f.write(b'-%s %s\n' % (obj_id, comment.encode('utf-8')))
+        f.write(b"-%s %s\n" % (obj_id, comment.encode("utf-8")))
     for ref, obj_id in bundle.references.items():
-        f.write(b'%s %s\n' % (obj_id, ref))
-    f.write(b'\n')
+        f.write(b"%s %s\n" % (obj_id, ref))
+    f.write(b"\n")
     write_pack_data(f, len(bundle.pack_data), iter(bundle.pack_data))
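
A hedged sketch of the reading side (the bundle path is illustrative; such a file can be produced with ``git bundle create``). The attribute names mirror the ones ``write_bundle`` consumes above:

    from dulwich.bundle import read_bundle

    with open("repo.bundle", "rb") as f:
        bundle = read_bundle(f)  # AssertionError on an unrecognised header
    print(bundle.version)        # 2 or 3
    print(bundle.references)     # {ref: object id}
    print(bundle.prerequisites)  # [(object id, comment), ...]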

+ 188 - 168
dulwich/cli.py

@@ -31,6 +31,7 @@ a way to test Dulwich.
 import os
 import sys
 from getopt import getopt
+import argparse
 import optparse
 import signal
 from typing import Dict, Type
@@ -50,6 +51,7 @@ def signal_int(signal, frame):
 
 def signal_quit(signal, frame):
     import pdb
+
     pdb.set_trace()
 
 
@@ -62,57 +64,65 @@ class Command(object):
 
 
 class cmd_archive(Command):
-
     def run(self, args):
-        parser = optparse.OptionParser()
-        parser.add_option("--remote", type=str,
-                          help="Retrieve archive from specified remote repo")
-        options, args = parser.parse_args(args)
-        committish = args.pop(0)
-        if options.remote:
-            client, path = get_transport_and_path(options.remote)
+        parser = argparse.ArgumentParser()
+        parser.add_argument(
+            "--remote",
+            type=str,
+            help="Retrieve archive from specified remote repo",
+        )
+        parser.add_argument('committish', type=str, nargs='?')
+        args = parser.parse_args(args)
+        if args.remote:
+            client, path = get_transport_and_path(args.remote)
             client.archive(
-                path, committish, sys.stdout.write,
-                write_error=sys.stderr.write)
+                path,
+                args.committish,
+                sys.stdout.write,
+                write_error=sys.stderr.write,
+            )
         else:
             porcelain.archive(
-                '.', committish, outstream=sys.stdout,
-                errstream=sys.stderr)
+                ".", args.committish, outstream=sys.stdout.buffer,
+                errstream=sys.stderr
+            )
 
 
 class cmd_add(Command):
-
-    def run(self, args):
-        opts, args = getopt(args, "", [])
+    def run(self, argv):
+        parser = argparse.ArgumentParser()
+        args = parser.parse_args(argv)
 
         porcelain.add(".", paths=args)
 
 
 class cmd_rm(Command):
-
-    def run(self, args):
-        opts, args = getopt(args, "", [])
+    def run(self, argv):
+        parser = argparse.ArgumentParser()
+        args = parser.parse_args(argv)
 
         porcelain.rm(".", paths=args)
 
 
 class cmd_fetch_pack(Command):
-
-    def run(self, args):
-        opts, args = getopt(args, "", ["all"])
-        opts = dict(opts)
-        client, path = get_transport_and_path(args.pop(0))
+    def run(self, argv):
+        parser = argparse.ArgumentParser()
+        parser.add_argument('--all', action='store_true')
+        parser.add_argument('location', nargs='?', type=str)
+        args = parser.parse_args(argv)
+        client, path = get_transport_and_path(args.location)
         r = Repo(".")
-        if "--all" in opts:
+        if args.all:
             determine_wants = r.object_store.determine_wants_all
         else:
-            def determine_wants(x):
+
+            def determine_wants(x, **kwargs):
                 return [y for y in args if y not in r.object_store]
+
         client.fetch(path, r, determine_wants)
 
 
 class cmd_fetch(Command):
-
     def run(self, args):
         opts, args = getopt(args, "", [])
         opts = dict(opts)
@@ -125,32 +135,40 @@ class cmd_fetch(Command):
 
 
 class cmd_fsck(Command):
-
     def run(self, args):
         opts, args = getopt(args, "", [])
         opts = dict(opts)
-        for (obj, msg) in porcelain.fsck('.'):
+        for (obj, msg) in porcelain.fsck("."):
             print("%s: %s" % (obj, msg))
 
 
 class cmd_log(Command):
-
     def run(self, args):
         parser = optparse.OptionParser()
-        parser.add_option("--reverse", dest="reverse", action="store_true",
-                          help="Reverse order in which entries are printed")
-        parser.add_option("--name-status", dest="name_status",
-                          action="store_true",
-                          help="Print name/status for each changed file")
+        parser.add_option(
+            "--reverse",
+            dest="reverse",
+            action="store_true",
+            help="Reverse order in which entries are printed",
+        )
+        parser.add_option(
+            "--name-status",
+            dest="name_status",
+            action="store_true",
+            help="Print name/status for each changed file",
+        )
         options, args = parser.parse_args(args)
 
-        porcelain.log(".", paths=args, reverse=options.reverse,
-                      name_status=options.name_status,
-                      outstream=sys.stdout)
+        porcelain.log(
+            ".",
+            paths=args,
+            reverse=options.reverse,
+            name_status=options.name_status,
+            outstream=sys.stdout,
+        )
 
 
 class cmd_diff(Command):
-
     def run(self, args):
         opts, args = getopt(args, "", [])
 
@@ -162,12 +180,10 @@ class cmd_diff(Command):
         commit_id = args[0]
         commit = r[commit_id]
         parent_commit = r[commit.parents[0]]
-        write_tree_diff(
-            sys.stdout, r.object_store, parent_commit.tree, commit.tree)
+        write_tree_diff(sys.stdout, r.object_store, parent_commit.tree, commit.tree)
 
 
 class cmd_dump_pack(Command):
-
     def run(self, args):
         opts, args = getopt(args, "", [])
 
@@ -192,7 +208,6 @@ class cmd_dump_pack(Command):
 
 
 class cmd_dump_index(Command):
-
     def run(self, args):
         opts, args = getopt(args, "", [])
 
@@ -208,7 +223,6 @@ class cmd_dump_index(Command):
 
 
 class cmd_init(Command):
-
     def run(self, args):
         opts, args = getopt(args, "", ["bare"])
         opts = dict(opts)
@@ -222,14 +236,17 @@ class cmd_init(Command):
 
 
 class cmd_clone(Command):
-
     def run(self, args):
         parser = optparse.OptionParser()
-        parser.add_option("--bare", dest="bare",
-                          help="Whether to create a bare repository.",
-                          action="store_true")
-        parser.add_option("--depth", dest="depth",
-                          type=int, help="Depth at which to fetch")
+        parser.add_option(
+            "--bare",
+            dest="bare",
+            help="Whether to create a bare repository.",
+            action="store_true",
+        )
+        parser.add_option(
+            "--depth", dest="depth", type=int, help="Depth at which to fetch"
+        )
         options, args = parser.parse_args(args)
 
         if args == []:
@@ -246,7 +263,6 @@ class cmd_clone(Command):
 
 
 class cmd_commit(Command):
-
     def run(self, args):
         opts, args = getopt(args, "", ["message"])
         opts = dict(opts)
@@ -254,7 +270,6 @@ class cmd_commit(Command):
 
 
 class cmd_commit_tree(Command):
-
     def run(self, args):
         opts, args = getopt(args, "", ["message"])
         if args == []:
@@ -265,13 +280,11 @@ class cmd_commit_tree(Command):
 
 
 class cmd_update_server_info(Command):
-
     def run(self, args):
         porcelain.update_server_info(".")
 
 
 class cmd_symbolic_ref(Command):
-
     def run(self, args):
         opts, args = getopt(args, "", ["ref-name", "force"])
         if not args:
@@ -279,18 +292,18 @@ class cmd_symbolic_ref(Command):
             sys.exit(1)
 
         ref_name = args.pop(0)
-        porcelain.symbolic_ref(".", ref_name=ref_name, force='--force' in args)
+        porcelain.symbolic_ref(".", ref_name=ref_name, force="--force" in args)
 
 
 class cmd_show(Command):
-
-    def run(self, args):
-        opts, args = getopt(args, "", [])
-        porcelain.show(".", args)
+    def run(self, argv):
+        parser = argparse.ArgumentParser()
+        parser.add_argument('objectish', type=str, nargs='*')
+        args = parser.parse_args(argv)
+        porcelain.show(".", args.objectish or None)
 
 
 class cmd_diff_tree(Command):
-
     def run(self, args):
         opts, args = getopt(args, "", [])
         if len(args) < 2:
@@ -300,41 +313,40 @@ class cmd_diff_tree(Command):
 
 
 class cmd_rev_list(Command):
-
     def run(self, args):
         opts, args = getopt(args, "", [])
         if len(args) < 1:
-            print('Usage: dulwich rev-list COMMITID...')
+            print("Usage: dulwich rev-list COMMITID...")
             sys.exit(1)
-        porcelain.rev_list('.', args)
+        porcelain.rev_list(".", args)
 
 
 class cmd_tag(Command):
-
     def run(self, args):
         parser = optparse.OptionParser()
         parser.add_option(
-            "-a", "--annotated", help="Create an annotated tag.",
-            action="store_true")
+            "-a",
+            "--annotated",
+            help="Create an annotated tag.",
+            action="store_true",
+        )
         parser.add_option(
-            "-s", "--sign", help="Sign the annotated tag.",
-            action="store_true")
+            "-s", "--sign", help="Sign the annotated tag.", action="store_true"
+        )
         options, args = parser.parse_args(args)
         porcelain.tag_create(
-            '.', args[0], annotated=options.annotated,
-            sign=options.sign)
+            ".", args[0], annotated=options.annotated, sign=options.sign
+        )
 
 
 class cmd_repack(Command):
-
     def run(self, args):
         opts, args = getopt(args, "", [])
         opts = dict(opts)
-        porcelain.repack('.')
+        porcelain.repack(".")
 
 
 class cmd_reset(Command):
-
     def run(self, args):
         opts, args = getopt(args, "", ["hard", "soft", "mixed"])
         opts = dict(opts)
@@ -345,110 +357,122 @@ class cmd_reset(Command):
             mode = "soft"
         elif "--mixed" in opts:
             mode = "mixed"
-        porcelain.reset('.', mode=mode, *args)
+        porcelain.reset(".", mode=mode, *args)
 
 
 class cmd_daemon(Command):
-
     def run(self, args):
         from dulwich import log_utils
         from dulwich.protocol import TCP_GIT_PORT
+
         parser = optparse.OptionParser()
-        parser.add_option("-l", "--listen_address", dest="listen_address",
-                          default="localhost",
-                          help="Binding IP address.")
-        parser.add_option("-p", "--port", dest="port", type=int,
-                          default=TCP_GIT_PORT,
-                          help="Binding TCP port.")
+        parser.add_option(
+            "-l",
+            "--listen_address",
+            dest="listen_address",
+            default="localhost",
+            help="Binding IP address.",
+        )
+        parser.add_option(
+            "-p",
+            "--port",
+            dest="port",
+            type=int,
+            default=TCP_GIT_PORT,
+            help="Binding TCP port.",
+        )
         options, args = parser.parse_args(args)
 
         log_utils.default_logging_config()
         if len(args) >= 1:
             gitdir = args[0]
         else:
-            gitdir = '.'
-        from dulwich import porcelain
-        porcelain.daemon(gitdir, address=options.listen_address,
-                         port=options.port)
+            gitdir = "."
 
+        porcelain.daemon(gitdir, address=options.listen_address, port=options.port)
 
-class cmd_web_daemon(Command):
 
+class cmd_web_daemon(Command):
     def run(self, args):
         from dulwich import log_utils
+
         parser = optparse.OptionParser()
-        parser.add_option("-l", "--listen_address", dest="listen_address",
-                          default="",
-                          help="Binding IP address.")
-        parser.add_option("-p", "--port", dest="port", type=int,
-                          default=8000,
-                          help="Binding TCP port.")
+        parser.add_option(
+            "-l",
+            "--listen_address",
+            dest="listen_address",
+            default="",
+            help="Binding IP address.",
+        )
+        parser.add_option(
+            "-p",
+            "--port",
+            dest="port",
+            type=int,
+            default=8000,
+            help="Binding TCP port.",
+        )
         options, args = parser.parse_args(args)
 
         log_utils.default_logging_config()
         if len(args) >= 1:
             gitdir = args[0]
         else:
-            gitdir = '.'
-        from dulwich import porcelain
-        porcelain.web_daemon(gitdir, address=options.listen_address,
-                             port=options.port)
+            gitdir = "."
 
+        porcelain.web_daemon(gitdir, address=options.listen_address, port=options.port)
 
-class cmd_write_tree(Command):
 
+class cmd_write_tree(Command):
     def run(self, args):
         parser = optparse.OptionParser()
         options, args = parser.parse_args(args)
-        sys.stdout.write('%s\n' % porcelain.write_tree('.'))
+        sys.stdout.write("%s\n" % porcelain.write_tree("."))
 
 
 class cmd_receive_pack(Command):
-
     def run(self, args):
         parser = optparse.OptionParser()
         options, args = parser.parse_args(args)
         if len(args) >= 1:
             gitdir = args[0]
         else:
-            gitdir = '.'
+            gitdir = "."
         porcelain.receive_pack(gitdir)
 
 
 class cmd_upload_pack(Command):
-
     def run(self, args):
         parser = optparse.OptionParser()
         options, args = parser.parse_args(args)
         if len(args) >= 1:
             gitdir = args[0]
         else:
-            gitdir = '.'
+            gitdir = "."
         porcelain.upload_pack(gitdir)
 
 
 class cmd_status(Command):
-
     def run(self, args):
         parser = optparse.OptionParser()
         options, args = parser.parse_args(args)
         if len(args) >= 1:
             gitdir = args[0]
         else:
-            gitdir = '.'
+            gitdir = "."
         status = porcelain.status(gitdir)
         if any(names for (kind, names) in status.staged.items()):
             sys.stdout.write("Changes to be committed:\n\n")
             for kind, names in status.staged.items():
                 for name in names:
-                    sys.stdout.write("\t%s: %s\n" % (
-                        kind, name.decode(sys.getfilesystemencoding())))
+                    sys.stdout.write(
+                        "\t%s: %s\n" % (kind, name.decode(sys.getfilesystemencoding()))
+                    )
             sys.stdout.write("\n")
         if status.unstaged:
             sys.stdout.write("Changes not staged for commit:\n\n")
             for name in status.unstaged:
-                sys.stdout.write(
-                    "\t%s\n" % name.decode(sys.getfilesystemencoding()))
+                sys.stdout.write("\t%s\n" % name.decode(sys.getfilesystemencoding()))
             sys.stdout.write("\n")
         if status.untracked:
             sys.stdout.write("Untracked files:\n\n")
@@ -458,11 +482,10 @@ class cmd_status(Command):
 
 
 class cmd_ls_remote(Command):
-
     def run(self, args):
-        opts, args = getopt(args, '', [])
+        opts, args = getopt(args, "", [])
         if len(args) < 1:
-            print('Usage: dulwich ls-remote URL')
+            print("Usage: dulwich ls-remote URL")
             sys.exit(1)
         refs = porcelain.ls_remote(args[0])
         for ref in sorted(refs):
@@ -470,48 +493,52 @@ class cmd_ls_remote(Command):
 
 
 class cmd_ls_tree(Command):
-
     def run(self, args):
         parser = optparse.OptionParser()
-        parser.add_option("-r", "--recursive", action="store_true",
-                          help="Recusively list tree contents.")
-        parser.add_option("--name-only", action="store_true",
-                          help="Only display name.")
+        parser.add_option(
+            "-r",
+            "--recursive",
+            action="store_true",
+            help="Recusively list tree contents.",
+        )
+        parser.add_option("--name-only", action="store_true", help="Only display name.")
         options, args = parser.parse_args(args)
         try:
             treeish = args.pop(0)
         except IndexError:
             treeish = None
         porcelain.ls_tree(
-            '.', treeish, outstream=sys.stdout, recursive=options.recursive,
-            name_only=options.name_only)
+            ".",
+            treeish,
+            outstream=sys.stdout,
+            recursive=options.recursive,
+            name_only=options.name_only,
+        )
 
 
 class cmd_pack_objects(Command):
-
     def run(self, args):
-        opts, args = getopt(args, '', ['stdout'])
+        opts, args = getopt(args, "", ["stdout"])
         opts = dict(opts)
-        if len(args) < 1 and '--stdout' not in args:
-            print('Usage: dulwich pack-objects basename')
+        if len(args) < 1 and "--stdout" not in args:
+            print("Usage: dulwich pack-objects basename")
             sys.exit(1)
         object_ids = [line.strip() for line in sys.stdin.readlines()]
         basename = args[0]
-        if '--stdout' in opts:
-            packf = getattr(sys.stdout, 'buffer', sys.stdout)
+        if "--stdout" in opts:
+            packf = getattr(sys.stdout, "buffer", sys.stdout)
             idxf = None
             close = []
         else:
-            packf = open(basename + '.pack', 'w')
-            idxf = open(basename + '.idx', 'w')
+            packf = open(basename + ".pack", "w")
+            idxf = open(basename + ".idx", "w")
             close = [packf, idxf]
-        porcelain.pack_objects('.', object_ids, packf, idxf)
+        porcelain.pack_objects(".", object_ids, packf, idxf)
         for f in close:
             f.close()
 
 
 class cmd_pull(Command):
-
     def run(self, args):
         parser = optparse.OptionParser()
         options, args = parser.parse_args(args)
@@ -519,28 +546,24 @@ class cmd_pull(Command):
             from_location = args[0]
         except IndexError:
             from_location = None
-        porcelain.pull('.', from_location)
+        porcelain.pull(".", from_location)
 
 
 class cmd_push(Command):
 
-    def run(self, args):
-        parser = optparse.OptionParser()
-        options, args = parser.parse_args(args)
-        if len(args) < 2:
-            print("Usage: dulwich push TO-LOCATION REFSPEC..")
-            sys.exit(1)
-        to_location = args[0]
-        refspecs = args[1:]
-        porcelain.push('.', to_location, refspecs)
+    def run(self, argv):
+        parser = argparse.ArgumentParser()
+        parser.add_argument('to_location', type=str)
+        parser.add_argument('refspec', type=str, nargs='*')
+        args = parser.parse_args(argv)
+        porcelain.push('.', args.to_location, args.refspec or None)
 
 
 class cmd_remote_add(Command):
-
     def run(self, args):
         parser = optparse.OptionParser()
         options, args = parser.parse_args(args)
-        porcelain.remote_add('.', args[0], args[1])
+        porcelain.remote_add(".", args[0], args[1])
 
 
 class SuperCommand(Command):
@@ -549,14 +572,13 @@ class SuperCommand(Command):
 
     def run(self, args):
         if not args:
-            print("Supported subcommands: %s" %
-                  ', '.join(self.subcommands.keys()))
+            print("Supported subcommands: %s" % ", ".join(self.subcommands.keys()))
             return False
         cmd = args[0]
         try:
             cmd_kls = self.subcommands[cmd]
         except KeyError:
-            print('No such subcommand: %s' % args[0])
+            print("No such subcommand: %s" % args[0])
             return False
         return cmd_kls().run(args[1:])
 
@@ -569,51 +591,46 @@ class cmd_remote(SuperCommand):
 
 
 class cmd_check_ignore(Command):
-
     def run(self, args):
         parser = optparse.OptionParser()
         options, args = parser.parse_args(args)
         ret = 1
-        for path in porcelain.check_ignore('.', args):
+        for path in porcelain.check_ignore(".", args):
             print(path)
             ret = 0
         return ret
 
 
 class cmd_check_mailmap(Command):
-
     def run(self, args):
         parser = optparse.OptionParser()
         options, args = parser.parse_args(args)
         for arg in args:
-            canonical_identity = porcelain.check_mailmap('.', arg)
+            canonical_identity = porcelain.check_mailmap(".", arg)
             print(canonical_identity)
 
 
 class cmd_stash_list(Command):
-
     def run(self, args):
         parser = optparse.OptionParser()
         options, args = parser.parse_args(args)
-        for i, entry in porcelain.stash_list('.'):
-            print("stash@{%d}: %s" % (i, entry.message.rstrip('\n')))
+        for i, entry in porcelain.stash_list("."):
+            print("stash@{%d}: %s" % (i, entry.message.rstrip("\n")))
 
 
 class cmd_stash_push(Command):
-
     def run(self, args):
         parser = optparse.OptionParser()
         options, args = parser.parse_args(args)
-        porcelain.stash_push('.')
+        porcelain.stash_push(".")
         print("Saved working directory and index state")
 
 
 class cmd_stash_pop(Command):
-
     def run(self, args):
         parser = optparse.OptionParser()
         options, args = parser.parse_args(args)
-        porcelain.stash_pop('.')
+        porcelain.stash_pop(".")
         print("Restrored working directory and index state")
 
 
@@ -627,42 +644,45 @@ class cmd_stash(SuperCommand):
 
 
 class cmd_ls_files(Command):
-
     def run(self, args):
         parser = optparse.OptionParser()
         options, args = parser.parse_args(args)
-        for name in porcelain.ls_files('.'):
+        for name in porcelain.ls_files("."):
             print(name)
 
 
 class cmd_describe(Command):
-
     def run(self, args):
         parser = optparse.OptionParser()
         options, args = parser.parse_args(args)
-        print(porcelain.describe('.'))
+        print(porcelain.describe("."))
 
 
 class cmd_help(Command):
-
     def run(self, args):
         parser = optparse.OptionParser()
-        parser.add_option("-a", "--all", dest="all",
-                          action="store_true",
-                          help="List all commands.")
+        parser.add_option(
+            "-a",
+            "--all",
+            dest="all",
+            action="store_true",
+            help="List all commands.",
+        )
         options, args = parser.parse_args(args)
 
         if options.all:
-            print('Available commands:')
+            print("Available commands:")
             for cmd in sorted(commands):
-                print('  %s' % cmd)
+                print("  %s" % cmd)
         else:
-            print("""\
+            print(
+                """\
 The dulwich command line tool is currently a very basic frontend for the
 Dulwich python module. For full functionality, please see the API reference.
 
 For a list of supported commands, see 'dulwich help -a'.
-""")
+"""
+            )
 
 
 commands = {
@@ -706,7 +726,7 @@ commands = {
     "upload-pack": cmd_upload_pack,
     "web-daemon": cmd_web_daemon,
     "write-tree": cmd_write_tree,
-    }
+}
 
 
 def main(argv=None):
@@ -727,8 +747,8 @@ def main(argv=None):
     return cmd_kls().run(argv[1:])
 
 
-if __name__ == '__main__':
-    if 'DULWICH_PDB' in os.environ and getattr(signal, 'SIGQUIT', None):
+if __name__ == "__main__":
+    if "DULWICH_PDB" in os.environ and getattr(signal, "SIGQUIT", None):
         signal.signal(signal.SIGQUIT, signal_quit)  # type: ignore
     signal.signal(signal.SIGINT, signal_int)
 

File diff suppressed because it is too large
+ 322 - 183
dulwich/client.py


+ 0 - 0
dulwich/cloud/__init__.py


+ 82 - 0
dulwich/cloud/gcs.py

@@ -0,0 +1,82 @@
+# object_store.py -- Object store for git objects
+# Copyright (C) 2021 Jelmer Vernooij <jelmer@jelmer.uk>
+#
+# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
+# General Public License as public by the Free Software Foundation; version 2.0
+# or (at your option) any later version. You can redistribute it and/or
+# modify it under the terms of either of these two licenses.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# You should have received a copy of the licenses; if not, see
+# <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
+# and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
+# License, Version 2.0.
+#
+
+
+"""Storage of repositories on GCS."""
+
+import posixpath
+import tempfile
+
+from ..object_store import BucketBasedObjectStore
+from ..pack import PackData, Pack, load_pack_index_file
+
+
+# TODO(jelmer): For performance, read ranges?
+
+
+class GcsObjectStore(BucketBasedObjectStore):
+
+    def __init__(self, bucket, subpath=''):
+        super(GcsObjectStore, self).__init__()
+        self.bucket = bucket
+        self.subpath = subpath
+
+    def __repr__(self):
+        return "%s(%r, subpath=%r)" % (
+            type(self).__name__, self.bucket, self.subpath)
+
+    def _remove_pack(self, name):
+        self.bucket.delete_blobs([
+            posixpath.join(self.subpath, name) + '.' + ext
+            for ext in ['pack', 'idx']])
+
+    def _iter_pack_names(self):
+        packs = {}
+        for blob in self.bucket.list_blobs(prefix=self.subpath):
+            name, ext = posixpath.splitext(posixpath.basename(blob.name))
+            packs.setdefault(name, set()).add(ext)
+        for name, exts in packs.items():
+            if exts == set(['.pack', '.idx']):
+                yield name
+
+    def _load_pack_data(self, name):
+        b = self.bucket.blob(posixpath.join(self.subpath, name + '.pack'))
+        f = tempfile.SpooledTemporaryFile()
+        b.download_to_file(f)
+        f.seek(0)
+        return PackData(name + '.pack', f)
+
+    def _load_pack_index(self, name):
+        b = self.bucket.blob(posixpath.join(self.subpath, name + '.idx'))
+        f = tempfile.SpooledTemporaryFile()
+        b.download_to_file(f)
+        f.seek(0)
+        return load_pack_index_file(name + '.idx', f)
+
+    def _get_pack(self, name):
+        return Pack.from_lazy_objects(
+            lambda: self._load_pack_data(name),
+            lambda: self._load_pack_index(name))
+
+    def _upload_pack(self, basename, pack_file, index_file):
+        idxblob = self.bucket.blob(posixpath.join(self.subpath, basename + '.idx'))
+        datablob = self.bucket.blob(posixpath.join(self.subpath, basename + '.pack'))
+        idxblob.upload_from_file(index_file)
+        datablob.upload_from_file(pack_file)
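
A hedged usage sketch (assumes the ``google-cloud-storage`` package and configured credentials; the bucket name and subpath are illustrative):

    from google.cloud import storage

    from dulwich.cloud.gcs import GcsObjectStore

    client = storage.Client()
    bucket = client.bucket("my-pack-bucket")
    store = GcsObjectStore(bucket, subpath="myrepo")
    # BucketBasedObjectStore drives the _iter_pack_names()/_get_pack() hooks
    # defined above, so listing packs round-trips through GCS blob listings.
    for pack in store.packs:
        print(pack)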

+ 115 - 64
dulwich/config.py

@@ -33,17 +33,18 @@ from typing import BinaryIO, Tuple, Optional
 
 from collections import (
     OrderedDict,
-    )
+)
+
 try:
     from collections.abc import (
         Iterable,
         MutableMapping,
-        )
+    )
 except ImportError:  # python < 3.7
     from collections import (
         Iterable,
         MutableMapping,
-        )
+    )
 
 from dulwich.file import GitFile
 
@@ -56,15 +57,12 @@ def lower_key(key):
         return key.lower()
 
     if isinstance(key, Iterable):
-        return type(key)(
-            map(lower_key, key)
-        )
+        return type(key)(map(lower_key, key))
 
     return key
 
 
 class CaseInsensitiveDict(OrderedDict):
-
     @classmethod
     def make(cls, dict_in=None):
 
@@ -87,7 +85,7 @@ class CaseInsensitiveDict(OrderedDict):
     def __setitem__(self, key, value, **kwargs):
         key = lower_key(key)
 
-        super(CaseInsensitiveDict, self).__setitem__(key, value,  **kwargs)
+        super(CaseInsensitiveDict, self).__setitem__(key, value, **kwargs)
 
     def __getitem__(self, item):
         key = lower_key(item)
@@ -188,7 +186,7 @@ class Config(object):
         Returns:
           boolean indicating whether the section exists
         """
-        return (name in self.itersections())
+        return name in self.itersections()
 
 
 class ConfigDict(Config, MutableMapping):
@@ -205,9 +203,7 @@ class ConfigDict(Config, MutableMapping):
         return "%s(%r)" % (self.__class__.__name__, self._values)
 
     def __eq__(self, other):
-        return (
-            isinstance(other, self.__class__) and
-            other._values == self._values)
+        return isinstance(other, self.__class__) and other._values == self._values
 
     def __getitem__(self, key):
         return self._values.__getitem__(key)
@@ -234,13 +230,16 @@ class ConfigDict(Config, MutableMapping):
 
     def _check_section_and_name(self, section, name):
         if not isinstance(section, tuple):
-            section = (section, )
-
-        section = tuple([
-            subsection.encode(self.encoding)
-            if not isinstance(subsection, bytes) else subsection
-            for subsection in section
-            ])
+            section = (section,)
+
+        section = tuple(
+            [
+                subsection.encode(self.encoding)
+                if not isinstance(subsection, bytes)
+                else subsection
+                for subsection in section
+            ]
+        )
 
         if not isinstance(name, bytes):
             name = name.encode(self.encoding)
@@ -274,11 +273,13 @@ class ConfigDict(Config, MutableMapping):
 
 
 def _format_string(value):
-    if (value.startswith(b" ") or
-            value.startswith(b"\t") or
-            value.endswith(b" ") or
-            b'#' in value or
-            value.endswith(b"\t")):
+    if (
+        value.startswith(b" ")
+        or value.startswith(b"\t")
+        or value.endswith(b" ")
+        or b"#" in value
+        or value.endswith(b"\t")
+    ):
         return b'"' + _escape_value(value) + b'"'
     else:
         return _escape_value(value)
@@ -286,11 +287,11 @@ def _format_string(value):
 
 _ESCAPE_TABLE = {
     ord(b"\\"): ord(b"\\"),
-    ord(b"\""): ord(b"\""),
+    ord(b'"'): ord(b'"'),
     ord(b"n"): ord(b"\n"),
     ord(b"t"): ord(b"\t"),
     ord(b"b"): ord(b"\b"),
-    }
+}
 _COMMENT_CHARS = [ord(b"#"), ord(b";")]
 _WHITESPACE_CHARS = [ord(b"\t"), ord(b" ")]
 
@@ -309,18 +310,19 @@ def _parse_string(value):
                 v = _ESCAPE_TABLE[value[i]]
             except IndexError:
                 raise ValueError(
-                    "escape character in %r at %d before end of string" %
-                    (value, i))
+                    "escape character in %r at %d before end of string" % (value, i)
+                )
             except KeyError:
                 raise ValueError(
                     "escape character followed by unknown character "
-                    "%s at %d in %r" % (value[i], i, value))
+                    "%s at %d in %r" % (value[i], i, value)
+                )
             if whitespace:
                 ret.extend(whitespace)
                 whitespace = bytearray()
             ret.append(v)
-        elif c == ord(b"\""):
-            in_quotes = (not in_quotes)
+        elif c == ord(b'"'):
+            in_quotes = not in_quotes
         elif c in _COMMENT_CHARS and not in_quotes:
             # the rest of the line is a comment
             break
@@ -344,22 +346,22 @@ def _escape_value(value):
     value = value.replace(b"\\", b"\\\\")
     value = value.replace(b"\n", b"\\n")
     value = value.replace(b"\t", b"\\t")
-    value = value.replace(b"\"", b"\\\"")
+    value = value.replace(b'"', b'\\"')
     return value
 
 
 def _check_variable_name(name):
     for i in range(len(name)):
-        c = name[i:i+1]
-        if not c.isalnum() and c != b'-':
+        c = name[i : i + 1]
+        if not c.isalnum() and c != b"-":
             return False
     return True
 
 
 def _check_section_name(name):
     for i in range(len(name)):
-        c = name[i:i+1]
-        if not c.isalnum() and c not in (b'-', b'.'):
+        c = name[i : i + 1]
+        if not c.isalnum() and c not in (b"-", b"."):
             return False
     return True
 
@@ -379,15 +381,14 @@ def _strip_comments(line):
 
 
 class ConfigFile(ConfigDict):
-    """A Git configuration file, like .git/config or ~/.gitconfig.
-    """
+    """A Git configuration file, like .git/config or ~/.gitconfig."""
 
     def __init__(self, values=None, encoding=None):
         super(ConfigFile, self).__init__(values=values, encoding=encoding)
         self.path = None
 
     @classmethod
-    def from_file(cls, f: BinaryIO) -> 'ConfigFile':
+    def from_file(cls, f: BinaryIO) -> "ConfigFile":
         """Read configuration from a file-like object."""
         ret = cls()
         section = None  # type: Optional[Tuple[bytes, ...]]
@@ -404,26 +405,23 @@ class ConfigFile(ConfigDict):
                     except ValueError:
                         raise ValueError("expected trailing ]")
                     pts = line[1:last].split(b" ", 1)
-                    line = line[last+1:]
+                    line = line[last + 1 :]
                     if len(pts) == 2:
-                        if pts[1][:1] != b"\"" or pts[1][-1:] != b"\"":
-                            raise ValueError(
-                                "Invalid subsection %r" % pts[1])
+                        if pts[1][:1] != b'"' or pts[1][-1:] != b'"':
+                            raise ValueError("Invalid subsection %r" % pts[1])
                         else:
                             pts[1] = pts[1][1:-1]
                         if not _check_section_name(pts[0]):
-                            raise ValueError("invalid section name %r" %
-                                             pts[0])
+                            raise ValueError("invalid section name %r" % pts[0])
                         section = (pts[0], pts[1])
                     else:
                         if not _check_section_name(pts[0]):
-                            raise ValueError(
-                                "invalid section name %r" % pts[0])
+                            raise ValueError("invalid section name %r" % pts[0])
                         pts = pts[0].split(b".", 1)
                         if len(pts) == 2:
                             section = (pts[0], pts[1])
                         else:
-                            section = (pts[0], )
+                            section = (pts[0],)
                     ret._values.setdefault(section)
                 if _strip_comments(line).strip() == b"":
                     continue
@@ -456,9 +454,9 @@ class ConfigFile(ConfigDict):
         return ret
 
     @classmethod
-    def from_path(cls, path) -> 'ConfigFile':
+    def from_path(cls, path) -> "ConfigFile":
         """Read configuration from a file on disk."""
-        with GitFile(path, 'rb') as f:
+        with GitFile(path, "rb") as f:
             ret = cls.from_file(f)
             ret.path = path
             return ret
@@ -467,7 +465,7 @@ class ConfigFile(ConfigDict):
         """Write configuration to a file on disk."""
         if path is None:
             path = self.path
-        with GitFile(path, 'wb') as f:
+        with GitFile(path, "wb") as f:
             self.write_to_file(f)
 
     def write_to_file(self, f: BinaryIO) -> None:
@@ -476,13 +474,12 @@ class ConfigFile(ConfigDict):
             try:
                 section_name, subsection_name = section
             except ValueError:
-                (section_name, ) = section
+                (section_name,) = section
                 subsection_name = None
             if subsection_name is None:
                 f.write(b"[" + section_name + b"]\n")
             else:
-                f.write(b"[" + section_name +
-                        b" \"" + subsection_name + b"\"]\n")
+                f.write(b"[" + section_name + b' "' + subsection_name + b'"]\n')
             for key, value in values.items():
                 if value is True:
                     value = b"true"
@@ -495,11 +492,63 @@ class ConfigFile(ConfigDict):
 
 def get_xdg_config_home_path(*path_segments):
     xdg_config_home = os.environ.get(
-        "XDG_CONFIG_HOME", os.path.expanduser("~/.config/"),
+        "XDG_CONFIG_HOME",
+        os.path.expanduser("~/.config/"),
     )
     return os.path.join(xdg_config_home, *path_segments)
 
 
+def _find_git_in_win_path():
+    for exe in ("git.exe", "git.cmd"):
+        for path in os.environ.get("PATH", "").split(";"):
+            if os.path.exists(os.path.join(path, exe)):
+                # exe path is .../Git/bin/git.exe or .../Git/cmd/git.exe
+                git_dir, _bin_dir = os.path.split(path)
+                yield git_dir
+                break
+
+
+def _find_git_in_win_reg():
+    import platform
+    import winreg
+
+    if platform.machine() == "AMD64":
+        subkey = (
+            "SOFTWARE\\Wow6432Node\\Microsoft\\Windows\\"
+            "CurrentVersion\\Uninstall\\Git_is1"
+        )
+    else:
+        subkey = (
+            "SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\"
+            "Uninstall\\Git_is1"
+        )
+
+    for key in (winreg.HKEY_CURRENT_USER, winreg.HKEY_LOCAL_MACHINE):
+        try:
+            with winreg.OpenKey(key, subkey) as k:
+                val, typ = winreg.QueryValueEx(k, "InstallLocation")
+                if typ == winreg.REG_SZ:
+                    yield val
+        except OSError:
+            pass
+
+
+# There is no set standard for system config dirs on Windows. We try the
+# following:
+#   - %PROGRAMDATA%/Git/config - (deprecated) Windows config dir per CGit docs
+#   - %PROGRAMFILES%/Git/etc/gitconfig - Git for Windows (msysgit) config dir,
+#     used if a CGit installation (Git/bin/git.exe) is found in PATH or in the
+#     system registry
+def get_win_system_paths():
+    if "PROGRAMDATA" in os.environ:
+        yield os.path.join(os.environ["PROGRAMDATA"], "Git", "config")
+
+    for git_dir in _find_git_in_win_path():
+        yield os.path.join(git_dir, "etc", "gitconfig")
+    for git_dir in _find_git_in_win_reg():
+        yield os.path.join(git_dir, "etc", "gitconfig")
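
A quick way to see what this discovery contributes on a given machine (a minimal sketch, assuming this version where the helper lives in dulwich.config; output depends on the local Git installation):

    # Sketch: print the system-wide gitconfig candidates on Windows.
    # Guarded because _find_git_in_win_reg() imports winreg, which only
    # exists on Windows.
    import sys
    from dulwich.config import get_win_system_paths

    if sys.platform == "win32":
        for candidate in get_win_system_paths():
            print(candidate)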
+
+
 class StackedConfig(Config):
     """Configuration which reads from multiple config files.."""
 
@@ -526,6 +575,8 @@ class StackedConfig(Config):
 
         if "GIT_CONFIG_NOSYSTEM" not in os.environ:
             paths.append("/etc/gitconfig")
+            if sys.platform == "win32":
+                paths.extend(get_win_system_paths())
 
         backends = []
         for path in paths:
@@ -538,7 +589,7 @@ class StackedConfig(Config):
 
     def get(self, section, name):
         if not isinstance(section, tuple):
-            section = (section, )
+            section = (section,)
         for backend in self.backends:
             try:
                 return backend.get(section, name)
@@ -555,15 +606,15 @@ class StackedConfig(Config):
 def parse_submodules(config):
     """Parse a gitmodules GitConfig file, returning submodules.
 
-   Args:
-     config: A `ConfigFile`
-   Returns:
-     list of tuples (submodule path, url, name),
-       where name is quoted part of the section's name.
+    Args:
+      config: A `ConfigFile`
+    Returns:
+      list of tuples (submodule path, url, name),
+        where name is quoted part of the section's name.
     """
     for section in config.keys():
         section_kind, section_name = section
-        if section_kind == b'submodule':
-            sm_path = config.get(section, b'path')
-            sm_url = config.get(section, b'url')
+        if section_kind == b"submodule":
+            sm_path = config.get(section, b"path")
+            sm_url = config.get(section, b"url")
             yield (sm_path, sm_url, section_name)
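
Typical call sites decode the three byte fields directly (a hedged sketch against this module; the .gitmodules path is illustrative):

    # Sketch: enumerate submodules declared in a .gitmodules file.
    from dulwich.config import ConfigFile, parse_submodules

    config = ConfigFile.from_path(".gitmodules")  # illustrative path
    for path, url, name in parse_submodules(config):
        print(path.decode(), url.decode(), name.decode())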

+ 5 - 4
dulwich/contrib/__init__.py

@@ -21,10 +21,11 @@
 
 def test_suite():
     import unittest
+
     names = [
-        'release_robot',
-        'swift',
-        ]
-    module_names = ['dulwich.contrib.test_' + name for name in names]
+        "release_robot",
+        "swift",
+    ]
+    module_names = ["dulwich.contrib.test_" + name for name in names]
     loader = unittest.TestLoader()
     return loader.loadTestsFromNames(module_names)

+ 45 - 36
dulwich/contrib/diffstat.py

@@ -39,16 +39,16 @@ import re
 # only needs to detect git style diffs as this is for
 # use with dulwich
 
-_git_header_name = re.compile(br'diff --git a/(.*) b/(.*)')
+_git_header_name = re.compile(br"diff --git a/(.*) b/(.*)")
 
-_GIT_HEADER_START = b'diff --git a/'
-_GIT_BINARY_START = b'Binary file'
-_GIT_RENAMEFROM_START = b'rename from'
-_GIT_RENAMETO_START = b'rename to'
-_GIT_CHUNK_START = b'@@'
-_GIT_ADDED_START = b'+'
-_GIT_DELETED_START = b'-'
-_GIT_UNCHANGED_START = b' '
+_GIT_HEADER_START = b"diff --git a/"
+_GIT_BINARY_START = b"Binary file"
+_GIT_RENAMEFROM_START = b"rename from"
+_GIT_RENAMETO_START = b"rename to"
+_GIT_CHUNK_START = b"@@"
+_GIT_ADDED_START = b"+"
+_GIT_DELETED_START = b"-"
+_GIT_UNCHANGED_START = b" "
 
 # emulate original full Patch class by just extracting
 # filename and minimal chunk added/deleted information to
@@ -89,9 +89,8 @@ def _parse_patch(lines):
         elif line.startswith(_GIT_RENAMEFROM_START) and in_git_header:
             currentfile = line[12:]
         elif line.startswith(_GIT_RENAMETO_START) and in_git_header:
-            currentfile += b' => %s' % line[10:]
-        elif line.startswith(_GIT_CHUNK_START) and \
-                (in_patch_chunk or in_git_header):
+            currentfile += b" => %s" % line[10:]
+        elif line.startswith(_GIT_CHUNK_START) and (in_patch_chunk or in_git_header):
             in_patch_chunk = True
             in_git_header = False
         elif line.startswith(_GIT_ADDED_START) and in_patch_chunk:
@@ -130,8 +129,8 @@ def diffstat(lines, max_width=80):
         insert.append(i)
         delete.append(d)
         namelen = max(namelen, len(filename))
-        maxdiff = max(maxdiff, i+d)
-    output = b''
+        maxdiff = max(maxdiff, i + d)
+    output = b""
     statlen = len(str(maxdiff))  # stats column width
     for i, n in enumerate(names):
         binaryfile = nametypes[i]
@@ -139,16 +138,21 @@ def diffstat(lines, max_width=80):
         # note b'%d' % namelen is not supported until Python 3.5
         # To convert an int to a format width specifier for byte
         # strings use str(namelen).encode('ascii')
-        format = b' %-' + str(namelen).encode('ascii') + \
-            b's | %' + str(statlen).encode('ascii') + b's %s\n'
-        binformat = b' %-' + str(namelen).encode('ascii') + b's | %s\n'
+        format = (
+            b" %-"
+            + str(namelen).encode("ascii")
+            + b"s | %"
+            + str(statlen).encode("ascii")
+            + b"s %s\n"
+        )
+        binformat = b" %-" + str(namelen).encode("ascii") + b"s | %s\n"
         if not binaryfile:
-            hist = b''
+            hist = b""
             # -- calculating histogram --
-            width = len(format % (b'', b'', b''))
+            width = len(format % (b"", b"", b""))
             histwidth = max(2, max_width - width)
             if maxdiff < histwidth:
-                hist = b'+'*insert[i] + b'-'*delete[i]
+                hist = b"+" * insert[i] + b"-" * delete[i]
             else:
                 iratio = (float(insert[i]) / maxdiff) * histwidth
                 dratio = (float(delete[i]) / maxdiff) * histwidth
@@ -165,15 +169,20 @@ def diffstat(lines, max_width=80):
                     dwidth = int(dratio)
                     if dwidth == 0 and 0 < dratio < 1:
                         dwidth = 1
-                hist = b'+'*int(iwidth) + b'-'*int(dwidth)
-            output += (format % (bytes(names[i]),
-                                 str(insert[i] + delete[i]).encode('ascii'),
-                                 hist))
+                hist = b"+" * int(iwidth) + b"-" * int(dwidth)
+            output += format % (
+                bytes(names[i]),
+                str(insert[i] + delete[i]).encode("ascii"),
+                hist,
+            )
         else:
-            output += (binformat % (bytes(names[i]), b'Bin'))
+            output += binformat % (bytes(names[i]), b"Bin")
 
-    output += (b' %d files changed, %d insertions(+), %d deletions(-)'
-               % (len(names), sum(insert), sum(delete)))
+    output += b" %d files changed, %d insertions(+), %d deletions(-)" % (
+        len(names),
+        sum(insert),
+        sum(delete),
+    )
     return output
 
 
@@ -182,12 +191,12 @@ def main():
     # allow diffstat.py to also be used from the command line
     if len(sys.argv) > 1:
         diffpath = sys.argv[1]
-        data = b''
-        with open(diffpath, 'rb') as f:
+        data = b""
+        with open(diffpath, "rb") as f:
             data = f.read()
-        lines = data.split(b'\n')
+        lines = data.split(b"\n")
         result = diffstat(lines)
-        print(result.decode('utf-8'))
+        print(result.decode("utf-8"))
         return 0
 
     # if no path argument to a diff file is passed in, run
@@ -314,7 +323,7 @@ index 3b41fd80..64914c78 100644
  2. open Sigil.app to the normal nearly blank template epub it generates when opened
  3. use Plugins->Manage Plugins menu and make sure the "Use Bundled Python" checkbox is checked
  4. use the "Add Plugin" button to navigate to and add testplugin.zip and then hit "Okay" to exit the Manage Plugins Dialog
-"""     # noqa: E501 W293
+"""  # noqa: E501 W293
 
     testoutput = b""" docs/qt512.7_remove_bad_workaround.patch            | 15 ++++++++++++
  docs/testplugin_v017.zip                            | Bin
@@ -324,17 +333,17 @@ index 3b41fd80..64914c78 100644
  5 files changed, 16 insertions(+), 27 deletions(-)"""  # noqa: W291
 
     # return 0 on success otherwise return -1
-    result = diffstat(selftest.split(b'\n'))
+    result = diffstat(selftest.split(b"\n"))
     if result == testoutput:
         print("self test passed")
         return 0
     print("self test failed")
     print("Received:")
-    print(result.decode('utf-8'))
+    print(result.decode("utf-8"))
     print("Expected:")
-    print(testoutput.decode('utf-8'))
+    print(testoutput.decode("utf-8"))
     return -1
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     sys.exit(main())
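
The module works as a library as well as a script: hand diffstat() the byte lines of a git-style diff and it returns the summary block as bytes (a small sketch; the patch below is made up for illustration):

    # Sketch: summarize a made-up one-line patch.
    from dulwich.contrib.diffstat import diffstat

    patch_lines = [
        b"diff --git a/hello.txt b/hello.txt",
        b"--- a/hello.txt",
        b"+++ b/hello.txt",
        b"@@ -1,1 +1,2 @@",
        b" hello",
        b"+world",
    ]
    print(diffstat(patch_lines).decode("utf-8"))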

+ 18 - 12
dulwich/contrib/paramiko_vendor.py

@@ -35,7 +35,6 @@ import paramiko.client
 
 
 class _ParamikoWrapper(object):
-
     def __init__(self, client, channel):
         self.client = client
         self.channel = channel
@@ -59,7 +58,7 @@ class _ParamikoWrapper(object):
 
         # Closed socket
         if not data:
-            return b''
+            return b""
 
         # Read more if needed
         if n and data_len < n:
@@ -77,25 +76,32 @@ class ParamikoSSHVendor(object):
     def __init__(self, **kwargs):
         self.kwargs = kwargs
 
-    def run_command(self, host, command,
-                    username=None, port=None,
-                    password=None, pkey=None,
-                    key_filename=None, **kwargs):
+    def run_command(
+        self,
+        host,
+        command,
+        username=None,
+        port=None,
+        password=None,
+        pkey=None,
+        key_filename=None,
+        **kwargs
+    ):
 
         client = paramiko.SSHClient()
 
-        connection_kwargs = {'hostname': host}
+        connection_kwargs = {"hostname": host}
         connection_kwargs.update(self.kwargs)
         if username:
-            connection_kwargs['username'] = username
+            connection_kwargs["username"] = username
         if port:
-            connection_kwargs['port'] = port
+            connection_kwargs["port"] = port
         if password:
-            connection_kwargs['password'] = password
+            connection_kwargs["password"] = password
         if pkey:
-            connection_kwargs['pkey'] = pkey
+            connection_kwargs["pkey"] = pkey
         if key_filename:
-            connection_kwargs['key_filename'] = key_filename
+            connection_kwargs["key_filename"] = key_filename
         connection_kwargs.update(kwargs)
 
         policy = paramiko.client.MissingHostKeyPolicy()
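
Driving the vendor directly shows the shape of the API (a hedged sketch: the host, command, and key path are placeholders, not a real endpoint):

    # Sketch: run a git pack command over SSH via paramiko.
    from dulwich.contrib.paramiko_vendor import ParamikoSSHVendor

    vendor = ParamikoSSHVendor(allow_agent=False, look_for_keys=False)
    con = vendor.run_command(
        "git.example.com",                    # placeholder host
        "git-upload-pack '/srv/repo.git'",    # placeholder command
        username="git",
        port=22,
        key_filename="/home/me/.ssh/id_rsa",  # placeholder key path
    )
    print(con.read(4096))  # first bytes of the server's response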

+ 11 - 11
dulwich/contrib/release_robot.py

@@ -52,8 +52,8 @@ import time
 from dulwich.repo import Repo
 
 # CONSTANTS
-PROJDIR = '.'
-PATTERN = r'[ a-zA-Z_\-]*([\d\.]+[\-\w\.]*)'
+PROJDIR = "."
+PATTERN = r"[ a-zA-Z_\-]*([\d\.]+[\-\w\.]*)"
 
 
 def get_recent_tags(projdir=PROJDIR):
@@ -74,15 +74,15 @@ def get_recent_tags(projdir=PROJDIR):
         tags = {}  # empty dictionary to hold tags, commits and datetimes
         # iterate over refs in repository
         for key, value in refs.items():
-            key = key.decode('utf-8')  # compatible with Python-3
+            key = key.decode("utf-8")  # compatible with Python-3
             obj = project.get_object(value)  # dulwich object from SHA-1
             # don't just check if object is "tag" b/c it could be a "commit"
             # instead check if "tags" is in the ref-name
-            if u'tags' not in key:
+            if u"tags" not in key:
                 # skip ref if not a tag
                 continue
             # strip the leading text from refs to get "tag name"
-            _, tag = key.rsplit(u'/', 1)
+            _, tag = key.rsplit(u"/", 1)
             # check if tag object is "commit" or "tag" pointing to a "commit"
             try:
                 commit = obj.object  # a tuple (commit class, commit id)
@@ -92,8 +92,8 @@ def get_recent_tags(projdir=PROJDIR):
             else:
                 tag_meta = (
                     datetime.datetime(*time.gmtime(obj.tag_time)[:6]),
-                    obj.id.decode('utf-8'),
-                    obj.name.decode('utf-8')
+                    obj.id.decode("utf-8"),
+                    obj.name.decode("utf-8"),
                 )  # compatible with Python-3
                 commit = project.get_object(commit[1])  # commit object
             # get tag commit datetime, but dulwich returns seconds since
@@ -101,9 +101,9 @@ def get_recent_tags(projdir=PROJDIR):
             # timetuple then convert to datetime
             tags[tag] = [
                 datetime.datetime(*time.gmtime(commit.commit_time)[:6]),
-                commit.id.decode('utf-8'),
-                commit.author.decode('utf-8'),
-                tag_meta
+                commit.id.decode("utf-8"),
+                commit.author.decode("utf-8"),
+                tag_meta,
             ]  # compatible with Python-3
 
     # return list of tags sorted by their datetimes from newest to oldest
@@ -139,7 +139,7 @@ def get_current_version(projdir=PROJDIR, pattern=PATTERN, logger=None):
     return current_version
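
Both helpers operate on a repository path, so usage reduces to (a minimal sketch; "." assumes the current directory is a git checkout with tags):

    # Sketch: newest tags and the derived version string.
    from dulwich.contrib.release_robot import get_recent_tags, get_current_version

    tags = get_recent_tags(".")  # sorted newest to oldest
    if tags:
        print(tags[0])
    print(get_current_version("."))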
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     if len(sys.argv) > 1:
         _PROJDIR = sys.argv[1]
     else:

+ 221 - 188
dulwich/contrib/swift.py

@@ -41,7 +41,7 @@ from geventhttpclient import HTTPClient
 from dulwich.greenthreads import (
     GreenThreadsMissingObjectFinder,
     GreenThreadsObjectStoreIterator,
-    )
+)
 
 from dulwich.lru_cache import LRUSizeCache
 from dulwich.objects import (
@@ -50,12 +50,12 @@ from dulwich.objects import (
     Tree,
     Tag,
     S_ISGITLINK,
-    )
+)
 from dulwich.object_store import (
     PackBasedObjectStore,
     PACKDIR,
     INFODIR,
-    )
+)
 from dulwich.pack import (
     PackData,
     Pack,
@@ -70,21 +70,21 @@ from dulwich.pack import (
     _compute_object_size,
     unpack_object,
     write_pack_object,
-    )
+)
 from dulwich.protocol import TCP_GIT_PORT
 from dulwich.refs import (
     InfoRefsContainer,
     read_info_refs,
     write_info_refs,
-    )
+)
 from dulwich.repo import (
     BaseRepo,
     OBJECTDIR,
-    )
+)
 from dulwich.server import (
     Backend,
     TCPGitServer,
-    )
+)
 
 import json
 
@@ -120,9 +120,8 @@ cache_length = 20
 
 
 class PackInfoObjectStoreIterator(GreenThreadsObjectStoreIterator):
-
     def __len__(self):
-        while len(self.finder.objects_to_send):
+        while self.finder.objects_to_send:
             for _ in range(0, len(self.finder.objects_to_send)):
                 sha = self.finder.next()
                 self._shas.append(sha)
@@ -130,7 +129,6 @@ class PackInfoObjectStoreIterator(GreenThreadsObjectStoreIterator):
 
 
 class PackInfoMissingObjectFinder(GreenThreadsMissingObjectFinder):
-
     def next(self):
         while True:
             if not self.objects_to_send:
@@ -171,7 +169,7 @@ def load_conf(path=None, file=None):
     confpath = None
     if not path:
         try:
-            confpath = os.environ['DULWICH_SWIFT_CFG']
+            confpath = os.environ["DULWICH_SWIFT_CFG"]
         except KeyError:
             raise Exception("You need to specify a configuration file")
     else:
@@ -203,8 +201,11 @@ def pack_info_create(pack_data, pack_index):
             info[obj.id] = (obj.type_num, obj.parents, obj.tree)
         # Tree
         elif obj.type_num == Tree.type_num:
-            shas = [(s, n, not stat.S_ISDIR(m)) for
-                    n, m, s in obj.items() if not S_ISGITLINK(m)]
+            shas = [
+                (s, n, not stat.S_ISDIR(m))
+                for n, m, s in obj.items()
+                if not S_ISGITLINK(m)
+            ]
             info[obj.id] = (obj.type_num, shas)
         # Blob
         elif obj.type_num == Blob.type_num:
@@ -233,11 +234,10 @@ class SwiftException(Exception):
 
 
 class SwiftConnector(object):
-    """A Connector to swift that manage authentication and errors catching
-    """
+    """A Connector to swift that manage authentication and errors catching"""
 
     def __init__(self, root, conf):
-        """ Initialize a SwiftConnector
+        """Initialize a SwiftConnector
 
         Args:
           root: The swift container that will act as a Git bare repository
@@ -246,18 +246,15 @@ class SwiftConnector(object):
         self.conf = conf
         self.auth_ver = self.conf.get("swift", "auth_ver")
         if self.auth_ver not in ["1", "2"]:
-            raise NotImplementedError(
-                "Wrong authentication version use either 1 or 2")
+            raise NotImplementedError("Wrong authentication version use either 1 or 2")
         self.auth_url = self.conf.get("swift", "auth_url")
         self.user = self.conf.get("swift", "username")
         self.password = self.conf.get("swift", "password")
-        self.concurrency = self.conf.getint('swift', 'concurrency') or 10
-        self.http_timeout = self.conf.getint('swift', 'http_timeout') or 20
-        self.http_pool_length = \
-            self.conf.getint('swift', 'http_pool_length') or 10
+        self.concurrency = self.conf.getint("swift", "concurrency") or 10
+        self.http_timeout = self.conf.getint("swift", "http_timeout") or 20
+        self.http_pool_length = self.conf.getint("swift", "http_pool_length") or 10
         self.region_name = self.conf.get("swift", "region_name") or "RegionOne"
-        self.endpoint_type = \
-            self.conf.get("swift", "endpoint_type") or "internalURL"
+        self.endpoint_type = self.conf.get("swift", "endpoint_type") or "internalURL"
         self.cache_length = self.conf.getint("swift", "cache_length") or 20
         self.chunk_length = self.conf.getint("swift", "chunk_length") or 12228
         self.root = root
@@ -267,16 +264,18 @@ class SwiftConnector(object):
         else:
             self.storage_url, self.token = self.swift_auth_v2()
 
-        token_header = {'X-Auth-Token': str(self.token)}
-        self.httpclient = \
-            HTTPClient.from_url(str(self.storage_url),
-                                concurrency=self.http_pool_length,
-                                block_size=block_size,
-                                connection_timeout=self.http_timeout,
-                                network_timeout=self.http_timeout,
-                                headers=token_header)
-        self.base_path = str(posixpath.join(
-                urlparse.urlparse(self.storage_url).path, self.root))
+        token_header = {"X-Auth-Token": str(self.token)}
+        self.httpclient = HTTPClient.from_url(
+            str(self.storage_url),
+            concurrency=self.http_pool_length,
+            block_size=block_size,
+            connection_timeout=self.http_timeout,
+            network_timeout=self.http_timeout,
+            headers=token_header,
+        )
+        self.base_path = str(
+            posixpath.join(urlparse.urlparse(self.storage_url).path, self.root)
+        )
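
Every option read above lives in a single [swift] section of the file handed to load_conf(), so a minimal v2 setup looks like this (all values are placeholders; unset integer options fall back to the defaults shown above):

    # Placeholder configuration; note the "tenant;user" username form
    # that swift_auth_v2 splits on ";".
    import textwrap

    SAMPLE_CONF = textwrap.dedent("""\
        [swift]
        auth_ver = 2
        auth_url = http://keystone.example.com:5000/v2.0
        username = demo-tenant;demo
        password = secret
        region_name = RegionOne
        endpoint_type = internalURL
        """)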
 
     def swift_auth_v1(self):
         self.user = self.user.replace(";", ":")
@@ -284,62 +283,68 @@ class SwiftConnector(object):
             self.auth_url,
             connection_timeout=self.http_timeout,
             network_timeout=self.http_timeout,
-            )
-        headers = {'X-Auth-User': self.user,
-                   'X-Auth-Key': self.password}
+        )
+        headers = {"X-Auth-User": self.user, "X-Auth-Key": self.password}
         path = urlparse.urlparse(self.auth_url).path
 
-        ret = auth_httpclient.request('GET', path, headers=headers)
+        ret = auth_httpclient.request("GET", path, headers=headers)
 
         # Should do something with redirections (301 in my case)
 
         if ret.status_code < 200 or ret.status_code >= 300:
-            raise SwiftException('AUTH v1.0 request failed on ' +
-                                 '%s with error code %s (%s)'
-                                 % (str(auth_httpclient.get_base_url()) +
-                                    path, ret.status_code,
-                                    str(ret.items())))
-        storage_url = ret['X-Storage-Url']
-        token = ret['X-Auth-Token']
+            raise SwiftException(
+                "AUTH v1.0 request failed on "
+                + "%s with error code %s (%s)"
+                % (
+                    str(auth_httpclient.get_base_url()) + path,
+                    ret.status_code,
+                    str(ret.items()),
+                )
+            )
+        storage_url = ret["X-Storage-Url"]
+        token = ret["X-Auth-Token"]
         return storage_url, token
 
     def swift_auth_v2(self):
-        self.tenant, self.user = self.user.split(';')
+        self.tenant, self.user = self.user.split(";")
         auth_dict = {}
-        auth_dict['auth'] = {'passwordCredentials':
-                             {
-                                 'username': self.user,
-                                 'password': self.password,
-                             },
-                             'tenantName': self.tenant}
+        auth_dict["auth"] = {
+            "passwordCredentials": {
+                "username": self.user,
+                "password": self.password,
+            },
+            "tenantName": self.tenant,
+        }
         auth_json = json.dumps(auth_dict)
-        headers = {'Content-Type': 'application/json'}
+        headers = {"Content-Type": "application/json"}
         auth_httpclient = HTTPClient.from_url(
             self.auth_url,
             connection_timeout=self.http_timeout,
             network_timeout=self.http_timeout,
-            )
+        )
         path = urlparse.urlparse(self.auth_url).path
-        if not path.endswith('tokens'):
-            path = posixpath.join(path, 'tokens')
-        ret = auth_httpclient.request('POST', path,
-                                      body=auth_json,
-                                      headers=headers)
+        if not path.endswith("tokens"):
+            path = posixpath.join(path, "tokens")
+        ret = auth_httpclient.request("POST", path, body=auth_json, headers=headers)
 
         if ret.status_code < 200 or ret.status_code >= 300:
-            raise SwiftException('AUTH v2.0 request failed on ' +
-                                 '%s with error code %s (%s)'
-                                 % (str(auth_httpclient.get_base_url()) +
-                                    path, ret.status_code,
-                                    str(ret.items())))
+            raise SwiftException(
+                "AUTH v2.0 request failed on "
+                + "%s with error code %s (%s)"
+                % (
+                    str(auth_httpclient.get_base_url()) + path,
+                    ret.status_code,
+                    str(ret.items()),
+                )
+            )
         auth_ret_json = json.loads(ret.read())
-        token = auth_ret_json['access']['token']['id']
-        catalogs = auth_ret_json['access']['serviceCatalog']
-        object_store = [o_store for o_store in catalogs if
-                        o_store['type'] == 'object-store'][0]
-        endpoints = object_store['endpoints']
-        endpoint = [endp for endp in endpoints if
-                    endp["region"] == self.region_name][0]
+        token = auth_ret_json["access"]["token"]["id"]
+        catalogs = auth_ret_json["access"]["serviceCatalog"]
+        object_store = [
+            o_store for o_store in catalogs if o_store["type"] == "object-store"
+        ][0]
+        endpoints = object_store["endpoints"]
+        endpoint = [endp for endp in endpoints if endp["region"] == self.region_name][0]
         return endpoint[self.endpoint_type], token
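
For reference, the request body assembled above serializes to the standard Keystone v2 password-credentials document (placeholder credentials):

    import json

    # The JSON body POSTed to .../tokens by swift_auth_v2; Keystone's reply
    # carries a token plus a service catalog, from which the object-store
    # endpoint matching region_name is picked.
    body = json.dumps({
        "auth": {
            "passwordCredentials": {"username": "demo", "password": "secret"},
            "tenantName": "demo-tenant",
        }
    })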
 
     def test_root_exists(self):
@@ -347,12 +352,13 @@ class SwiftConnector(object):
 
         Returns: True if it exists, or None if not
         """
-        ret = self.httpclient.request('HEAD', self.base_path)
+        ret = self.httpclient.request("HEAD", self.base_path)
         if ret.status_code == 404:
             return None
         if ret.status_code < 200 or ret.status_code > 300:
-            raise SwiftException('HEAD request failed with error code %s'
-                                 % ret.status_code)
+            raise SwiftException(
+                "HEAD request failed with error code %s" % ret.status_code
+            )
         return True
 
     def create_root(self):
@@ -362,10 +368,11 @@ class SwiftConnector(object):
           SwiftException: if unable to create
         """
         if not self.test_root_exists():
-            ret = self.httpclient.request('PUT', self.base_path)
+            ret = self.httpclient.request("PUT", self.base_path)
             if ret.status_code < 200 or ret.status_code > 300:
-                raise SwiftException('PUT request failed with error code %s'
-                                     % ret.status_code)
+                raise SwiftException(
+                    "PUT request failed with error code %s" % ret.status_code
+                )
 
     def get_container_objects(self):
         """Retrieve objects list in a container
@@ -373,14 +380,15 @@ class SwiftConnector(object):
         Returns: A list of dict that describe objects
                  or None if container does not exist
         """
-        qs = '?format=json'
+        qs = "?format=json"
         path = self.base_path + qs
-        ret = self.httpclient.request('GET', path)
+        ret = self.httpclient.request("GET", path)
         if ret.status_code == 404:
             return None
         if ret.status_code < 200 or ret.status_code > 300:
-            raise SwiftException('GET request failed with error code %s'
-                                 % ret.status_code)
+            raise SwiftException(
+                "GET request failed with error code %s" % ret.status_code
+            )
         content = ret.read()
         return json.loads(content)
 
@@ -392,13 +400,14 @@ class SwiftConnector(object):
         Returns:
           A dict that describes the object, or None if the object does not exist
         """
-        path = self.base_path + '/' + name
-        ret = self.httpclient.request('HEAD', path)
+        path = self.base_path + "/" + name
+        ret = self.httpclient.request("HEAD", path)
         if ret.status_code == 404:
             return None
         if ret.status_code < 200 or ret.status_code > 300:
-            raise SwiftException('HEAD request failed with error code %s'
-                                 % ret.status_code)
+            raise SwiftException(
+                "HEAD request failed with error code %s" % ret.status_code
+            )
         resp_headers = {}
         for header, value in ret.items():
             resp_headers[header.lower()] = value
@@ -415,13 +424,11 @@ class SwiftConnector(object):
         """
         content.seek(0)
         data = content.read()
-        path = self.base_path + '/' + name
-        headers = {'Content-Length': str(len(data))}
+        path = self.base_path + "/" + name
+        headers = {"Content-Length": str(len(data))}
 
         def _send():
-            ret = self.httpclient.request('PUT', path,
-                                          body=data,
-                                          headers=headers)
+            ret = self.httpclient.request("PUT", path, body=data, headers=headers)
             return ret
 
         try:
@@ -432,8 +439,9 @@ class SwiftConnector(object):
             ret = _send()
 
         if ret.status_code < 200 or ret.status_code > 300:
-            raise SwiftException('PUT request failed with error code %s'
-                                 % ret.status_code)
+            raise SwiftException(
+                "PUT request failed with error code %s" % ret.status_code
+            )
 
     def get_object(self, name, range=None):
         """Retrieve an object
@@ -447,14 +455,15 @@ class SwiftConnector(object):
         """
         headers = {}
         if range:
-            headers['Range'] = 'bytes=%s' % range
-        path = self.base_path + '/' + name
-        ret = self.httpclient.request('GET', path, headers=headers)
+            headers["Range"] = "bytes=%s" % range
+        path = self.base_path + "/" + name
+        ret = self.httpclient.request("GET", path, headers=headers)
         if ret.status_code == 404:
             return None
         if ret.status_code < 200 or ret.status_code > 300:
-            raise SwiftException('GET request failed with error code %s'
-                                 % ret.status_code)
+            raise SwiftException(
+                "GET request failed with error code %s" % ret.status_code
+            )
         content = ret.read()
 
         if range:
@@ -469,11 +478,12 @@ class SwiftConnector(object):
         Raises:
           SwiftException: if unable to delete
         """
-        path = self.base_path + '/' + name
-        ret = self.httpclient.request('DELETE', path)
+        path = self.base_path + "/" + name
+        ret = self.httpclient.request("DELETE", path)
         if ret.status_code < 200 or ret.status_code > 300:
-            raise SwiftException('DELETE request failed with error code %s'
-                                 % ret.status_code)
+            raise SwiftException(
+                "DELETE request failed with error code %s" % ret.status_code
+            )
 
     def del_root(self):
         """Delete the root container by removing container content
@@ -482,11 +492,12 @@ class SwiftConnector(object):
           SwiftException: if unable to delete
         """
         for obj in self.get_container_objects():
-            self.del_object(obj['name'])
-        ret = self.httpclient.request('DELETE', self.base_path)
+            self.del_object(obj["name"])
+        ret = self.httpclient.request("DELETE", self.base_path)
         if ret.status_code < 200 or ret.status_code > 300:
-            raise SwiftException('DELETE request failed with error code %s'
-                                 % ret.status_code)
+            raise SwiftException(
+                "DELETE request failed with error code %s" % ret.status_code
+            )
 
 
 class SwiftPackReader(object):
@@ -512,7 +523,7 @@ class SwiftPackReader(object):
         self.pack_length = pack_length
         self.offset = 0
         self.base_offset = 0
-        self.buff = b''
+        self.buff = b""
         self.buff_length = self.scon.chunk_length
 
     def _read(self, more=False):
@@ -531,16 +542,16 @@ class SwiftPackReader(object):
         Returns:
           a bytestring
         """
-        end = self.offset+length
+        end = self.offset + length
         if self.base_offset + end > self.pack_length:
-            data = self.buff[self.offset:]
+            data = self.buff[self.offset :]
             self.offset = end
             return data
         if end > len(self.buff):
             # Need to read more from swift
             self._read(more=True)
             return self.read(length)
-        data = self.buff[self.offset:end]
+        data = self.buff[self.offset : end]
         self.offset = end
         return data
 
@@ -570,7 +581,7 @@ class SwiftPackData(PackData):
     """
 
     def __init__(self, scon, filename):
-        """ Initialize a SwiftPackReader
+        """Initialize a SwiftPackReader
 
         Args:
           scon: a `SwiftConnector` instance
@@ -580,27 +591,26 @@ class SwiftPackData(PackData):
         self._filename = filename
         self._header_size = 12
         headers = self.scon.get_object_stat(self._filename)
-        self.pack_length = int(headers['content-length'])
-        pack_reader = SwiftPackReader(self.scon, self._filename,
-                                      self.pack_length)
+        self.pack_length = int(headers["content-length"])
+        pack_reader = SwiftPackReader(self.scon, self._filename, self.pack_length)
         (version, self._num_objects) = read_pack_header(pack_reader.read)
-        self._offset_cache = LRUSizeCache(1024*1024*self.scon.cache_length,
-                                          compute_size=_compute_object_size)
+        self._offset_cache = LRUSizeCache(
+            1024 * 1024 * self.scon.cache_length,
+            compute_size=_compute_object_size,
+        )
         self.pack = None
 
     def get_object_at(self, offset):
         if offset in self._offset_cache:
             return self._offset_cache[offset]
         assert offset >= self._header_size
-        pack_reader = SwiftPackReader(self.scon, self._filename,
-                                      self.pack_length)
+        pack_reader = SwiftPackReader(self.scon, self._filename, self.pack_length)
         pack_reader.seek(offset)
         unpacked, _ = unpack_object(pack_reader.read)
         return (unpacked.pack_type_num, unpacked._obj())
 
     def get_stored_checksum(self):
-        pack_reader = SwiftPackReader(self.scon, self._filename,
-                                      self.pack_length)
+        pack_reader = SwiftPackReader(self.scon, self._filename, self.pack_length)
         return pack_reader.read_checksum()
 
     def close(self):
@@ -616,15 +626,13 @@ class SwiftPack(Pack):
     """
 
     def __init__(self, *args, **kwargs):
-        self.scon = kwargs['scon']
-        del kwargs['scon']
+        self.scon = kwargs["scon"]
+        del kwargs["scon"]
         super(SwiftPack, self).__init__(*args, **kwargs)
-        self._pack_info_path = self._basename + '.info'
+        self._pack_info_path = self._basename + ".info"
         self._pack_info = None
-        self._pack_info_load = lambda: load_pack_info(self._pack_info_path,
-                                                      self.scon)
-        self._idx_load = lambda: swift_load_pack_index(self.scon,
-                                                       self._idx_path)
+        self._pack_info_load = lambda: load_pack_info(self._pack_info_path, self.scon)
+        self._idx_load = lambda: swift_load_pack_index(self.scon, self._idx_path)
         self._data_load = lambda: SwiftPackData(self.scon, self._data_path)
 
     @property
@@ -641,6 +649,7 @@ class SwiftObjectStore(PackBasedObjectStore):
     Allows managing a bare Git repository from OpenStack Swift.
     This object store only supports pack files and not loose objects.
     """
+
     def __init__(self, scon):
         """Open a Swift object store.
 
@@ -655,8 +664,11 @@ class SwiftObjectStore(PackBasedObjectStore):
 
     def _update_pack_cache(self):
         objects = self.scon.get_container_objects()
-        pack_files = [o['name'].replace(".pack", "")
-                      for o in objects if o['name'].endswith(".pack")]
+        pack_files = [
+            o["name"].replace(".pack", "")
+            for o in objects
+            if o["name"].endswith(".pack")
+        ]
         ret = []
         for basename in pack_files:
             pack = SwiftPack(basename, scon=self.scon)
@@ -665,8 +677,7 @@ class SwiftObjectStore(PackBasedObjectStore):
         return ret
 
     def _iter_loose_objects(self):
-        """Loose objects are not supported by this repository
-        """
+        """Loose objects are not supported by this repository"""
         return []
 
     def iter_shas(self, finder):
@@ -676,11 +687,10 @@ class SwiftObjectStore(PackBasedObjectStore):
                  instance if gevent is enabled
         """
         shas = iter(finder.next, None)
-        return PackInfoObjectStoreIterator(
-            self, shas, finder, self.scon.concurrency)
+        return PackInfoObjectStoreIterator(self, shas, finder, self.scon.concurrency)
 
     def find_missing_objects(self, *args, **kwargs):
-        kwargs['concurrency'] = self.scon.concurrency
+        kwargs["concurrency"] = self.scon.concurrency
         return PackInfoMissingObjectFinder(self, *args, **kwargs)
 
     def pack_info_get(self, sha):
@@ -725,11 +735,11 @@ class SwiftObjectStore(PackBasedObjectStore):
             f.seek(0)
             pack = PackData(file=f, filename="")
             entries = pack.sorted_entries()
-            if len(entries):
-                basename = posixpath.join(self.pack_dir,
-                                          "pack-%s" %
-                                          iter_sha1(entry[0] for
-                                                    entry in entries))
+            if entries:
+                basename = posixpath.join(
+                    self.pack_dir,
+                    "pack-%s" % iter_sha1(entry[0] for entry in entries),
+                )
                 index = BytesIO()
                 write_pack_index_v2(index, entries, pack.get_stored_checksum())
                 self.scon.put_object(basename + ".pack", f)
@@ -745,10 +755,15 @@ class SwiftObjectStore(PackBasedObjectStore):
 
         def abort():
             pass
+
         return f, commit, abort
 
     def add_object(self, obj):
-        self.add_objects([(obj, None), ])
+        self.add_objects(
+            [
+                (obj, None),
+            ]
+        )
 
     def _pack_cache_stale(self):
         return False
@@ -762,12 +777,11 @@ class SwiftObjectStore(PackBasedObjectStore):
         Read it from a stream and complete it in a temporary file.
         Then the pack and the corresponding index file are uploaded to Swift.
         """
-        fd, path = tempfile.mkstemp(prefix='tmp_pack_')
-        f = os.fdopen(fd, 'w+b')
+        fd, path = tempfile.mkstemp(prefix="tmp_pack_")
+        f = os.fdopen(fd, "w+b")
         try:
             indexer = PackIndexer(f, resolve_ext_ref=self.get_raw)
-            copier = PackStreamCopier(read_all, read_some, f,
-                                      delta_iter=indexer)
+            copier = PackStreamCopier(read_all, read_some, f, delta_iter=indexer)
             copier.verify()
             return self._complete_thin_pack(f, path, copier, indexer)
         finally:
@@ -805,11 +819,12 @@ class SwiftObjectStore(PackBasedObjectStore):
         entries.sort()
         pack_base_name = posixpath.join(
             self.pack_dir,
-            'pack-' + os.fsdecode(iter_sha1(e[0] for e in entries)))
-        self.scon.put_object(pack_base_name + '.pack', f)
+            "pack-" + os.fsdecode(iter_sha1(e[0] for e in entries)),
+        )
+        self.scon.put_object(pack_base_name + ".pack", f)
 
         # Write the index.
-        filename = pack_base_name + '.idx'
+        filename = pack_base_name + ".idx"
         index_file = BytesIO()
         write_pack_index_v2(index_file, entries, pack_sha)
         self.scon.put_object(filename, index_file)
@@ -818,12 +833,12 @@ class SwiftObjectStore(PackBasedObjectStore):
         f.seek(0)
         pack_data = PackData(filename="", file=f)
         index_file.seek(0)
-        pack_index = load_pack_index_file('', index_file)
+        pack_index = load_pack_index_file("", index_file)
         serialized_pack_info = pack_info_create(pack_data, pack_index)
         f.close()
         index_file.close()
         pack_info_file = BytesIO(serialized_pack_info)
-        filename = pack_base_name + '.info'
+        filename = pack_base_name + ".info"
         self.scon.put_object(filename, pack_info_file)
         pack_info_file.close()
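
So each uploaded pack materializes as three sibling objects in the container; the naming is easy to reproduce (illustrative only, with a placeholder digest):

    # Illustrative: the three Swift objects written per pack, named after
    # the SHA-1 over the sorted entry SHAs ("objects/pack" is pack_dir here).
    sha = "0123456789abcdef0123456789abcdef01234567"  # placeholder digest
    base = "objects/pack/pack-" + sha
    artifacts = [base + ".pack", base + ".idx", base + ".info"]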
 
@@ -835,16 +850,15 @@ class SwiftObjectStore(PackBasedObjectStore):
 
 
 class SwiftInfoRefsContainer(InfoRefsContainer):
-    """Manage references in info/refs object.
-    """
+    """Manage references in info/refs object."""
 
     def __init__(self, scon, store):
         self.scon = scon
-        self.filename = 'info/refs'
+        self.filename = "info/refs"
         self.store = store
         f = self.scon.get_object(self.filename)
         if not f:
-            f = BytesIO(b'')
+            f = BytesIO(b"")
         super(SwiftInfoRefsContainer, self).__init__(f)
 
     def _load_check_ref(self, name, old_ref):
@@ -864,9 +878,8 @@ class SwiftInfoRefsContainer(InfoRefsContainer):
         self.scon.put_object(self.filename, f)
 
     def set_if_equals(self, name, old_ref, new_ref):
-        """Set a refname to new_ref only if it currently equals old_ref.
-        """
-        if name == 'HEAD':
+        """Set a refname to new_ref only if it currently equals old_ref."""
+        if name == "HEAD":
             return True
         refs = self._load_check_ref(name, old_ref)
         if not isinstance(refs, dict):
@@ -877,9 +890,8 @@ class SwiftInfoRefsContainer(InfoRefsContainer):
         return True
 
     def remove_if_equals(self, name, old_ref):
-        """Remove a refname only if it currently equals old_ref.
-        """
-        if name == 'HEAD':
+        """Remove a refname only if it currently equals old_ref."""
+        if name == "HEAD":
             return True
         refs = self._load_check_ref(name, old_ref)
         if not isinstance(refs, dict):
@@ -891,14 +903,13 @@ class SwiftInfoRefsContainer(InfoRefsContainer):
 
     def allkeys(self):
         try:
-            self._refs['HEAD'] = self._refs['refs/heads/master']
+            self._refs["HEAD"] = self._refs["refs/heads/master"]
         except KeyError:
             pass
         return self._refs.keys()
 
 
 class SwiftRepo(BaseRepo):
-
     def __init__(self, root, conf):
         """Init a Git bare Repository on top of a Swift container.
 
@@ -910,15 +921,15 @@ class SwiftRepo(BaseRepo):
           root: The container which contains the bare repo
           conf: A ConfigParser object
         """
-        self.root = root.lstrip('/')
+        self.root = root.lstrip("/")
         self.conf = conf
         self.scon = SwiftConnector(self.root, self.conf)
         objects = self.scon.get_container_objects()
         if not objects:
-            raise Exception('There is not any GIT repo here : %s' % self.root)
-        objects = [o['name'].split('/')[0] for o in objects]
+            raise Exception("There is not any GIT repo here : %s" % self.root)
+        objects = [o["name"].split("/")[0] for o in objects]
         if OBJECTDIR not in objects:
-            raise Exception('This repository (%s) is not bare.' % self.root)
+            raise Exception("This repository (%s) is not bare." % self.root)
         self.bare = True
         self._controldir = self.root
         object_store = SwiftObjectStore(self.scon)
@@ -954,66 +965,89 @@ class SwiftRepo(BaseRepo):
           a `SwiftRepo` instance
         """
         scon.create_root()
-        for obj in [posixpath.join(OBJECTDIR, PACKDIR),
-                    posixpath.join(INFODIR, 'refs')]:
-            scon.put_object(obj, BytesIO(b''))
+        for obj in [
+            posixpath.join(OBJECTDIR, PACKDIR),
+            posixpath.join(INFODIR, "refs"),
+        ]:
+            scon.put_object(obj, BytesIO(b""))
         ret = cls(scon.root, conf)
         ret._init_files(True)
         return ret
 
 
 class SwiftSystemBackend(Backend):
-
     def __init__(self, logger, conf):
         self.conf = conf
         self.logger = logger
 
     def open_repository(self, path):
-        self.logger.info('opening repository at %s', path)
+        self.logger.info("opening repository at %s", path)
         return SwiftRepo(path, self.conf)
 
 
 def cmd_daemon(args):
     """Entry point for starting a TCP git server."""
     import optparse
+
     parser = optparse.OptionParser()
-    parser.add_option("-l", "--listen_address", dest="listen_address",
-                      default="127.0.0.1",
-                      help="Binding IP address.")
-    parser.add_option("-p", "--port", dest="port", type=int,
-                      default=TCP_GIT_PORT,
-                      help="Binding TCP port.")
-    parser.add_option("-c", "--swift_config", dest="swift_config",
-                      default="",
-                      help="Path to the configuration file for Swift backend.")
+    parser.add_option(
+        "-l",
+        "--listen_address",
+        dest="listen_address",
+        default="127.0.0.1",
+        help="Binding IP address.",
+    )
+    parser.add_option(
+        "-p",
+        "--port",
+        dest="port",
+        type=int,
+        default=TCP_GIT_PORT,
+        help="Binding TCP port.",
+    )
+    parser.add_option(
+        "-c",
+        "--swift_config",
+        dest="swift_config",
+        default="",
+        help="Path to the configuration file for Swift backend.",
+    )
     options, args = parser.parse_args(args)
 
     try:
         import gevent
         import geventhttpclient  # noqa: F401
     except ImportError:
-        print("gevent and geventhttpclient libraries are mandatory "
-              " for use the Swift backend.")
+        print(
+            "gevent and geventhttpclient libraries are mandatory "
+            " for use the Swift backend."
+        )
         sys.exit(1)
     import gevent.monkey
+
     gevent.monkey.patch_socket()
     from dulwich import log_utils
+
     logger = log_utils.getLogger(__name__)
     conf = load_conf(options.swift_config)
     backend = SwiftSystemBackend(logger, conf)
 
     log_utils.default_logging_config()
-    server = TCPGitServer(backend, options.listen_address,
-                          port=options.port)
+    server = TCPGitServer(backend, options.listen_address, port=options.port)
     server.serve_forever()
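
Starting a server is then a one-liner once a configuration file exists (a sketch; swift.cfg is a placeholder, and gevent plus geventhttpclient must be installed):

    # Equivalent to: python -m dulwich.contrib.swift daemon -c swift.cfg
    from dulwich.contrib import swift

    swift.cmd_daemon(["-c", "swift.cfg", "-l", "0.0.0.0", "-p", "9418"])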
 
 
 def cmd_init(args):
     import optparse
+
     parser = optparse.OptionParser()
-    parser.add_option("-c", "--swift_config", dest="swift_config",
-                      default="",
-                      help="Path to the configuration file for Swift backend.")
+    parser.add_option(
+        "-c",
+        "--swift_config",
+        dest="swift_config",
+        default="",
+        help="Path to the configuration file for Swift backend.",
+    )
     options, args = parser.parse_args(args)
 
     conf = load_conf(options.swift_config)
@@ -1031,8 +1065,7 @@ def main(argv=sys.argv):
     }
 
     if len(sys.argv) < 2:
-        print("Usage: %s <%s> [OPTIONS...]" % (
-                sys.argv[0], "|".join(commands.keys())))
+        print("Usage: %s <%s> [OPTIONS...]" % (sys.argv[0], "|".join(commands.keys())))
         sys.exit(1)
 
     cmd = sys.argv[1]
@@ -1042,5 +1075,5 @@ def main(argv=sys.argv):
     commands[cmd](sys.argv[2:])
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()

+ 189 - 0
dulwich/contrib/test_paramiko_vendor.py

@@ -0,0 +1,189 @@
+# test_paramiko_vendor.py
+#
+# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
+# General Public License as public by the Free Software Foundation; version 2.0
+# or (at your option) any later version. You can redistribute it and/or
+# modify it under the terms of either of these two licenses.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# You should have received a copy of the licenses; if not, see
+# <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
+# and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
+# License, Version 2.0.
+#
+
+"""Tests for paramiko_vendor."""
+
+import socket
+import paramiko
+import threading
+
+from dulwich.tests import TestCase
+from dulwich.contrib.paramiko_vendor import ParamikoSSHVendor
+
+try:
+    from StringIO import StringIO
+except ImportError:
+    from io import StringIO
+
+
+USER = 'testuser'
+PASSWORD = 'test'
+SERVER_KEY = """\
+-----BEGIN RSA PRIVATE KEY-----
+MIIEpAIBAAKCAQEAy/L1sSYAzxsMprtNXW4u/1jGXXkQmQ2xtmKVlR+RlIL3a1BH
+bzTpPlZyjltAAwzIP8XRh0iJFKz5y3zSQChhX47ZGN0NvQsVct8R+YwsUonwfAJ+
+JN0KBKKvC8fPHlzqBr3gX+ZxqsFH934tQ6wdQPH5eQWtdM8L826lMsH1737uyTGk
++mCSDjL3c6EzY83g7qhkJU2R4qbi6ne01FaWADzG8sOzXnHT+xpxtk8TTT8yCVUY
+MmBNsSoA/ka3iWz70ghB+6Xb0WpFJZXWq1oYovviPAfZGZSrxBZMxsWMye70SdLl
+TqsBEt0+miIcm9s0fvjWvQuhaHX6mZs5VO4r5QIDAQABAoIBAGYqeYWaYgFdrYLA
+hUrubUCg+g3NHdFuGL4iuIgRXl4lFUh+2KoOuWDu8Uf60iA1AQNhV0sLvQ/Mbv3O
+s4xMLisuZfaclctDiCUZNenqnDFkxEF7BjH1QJV94W5nU4wEQ3/JEmM4D2zYkfKb
+FJW33JeyH6TOgUvohDYYEU1R+J9V8qA243p+ui1uVtNI6Pb0TXJnG5y9Ny4vkSWH
+Fi0QoMPR1r9xJ4SEearGzA/crb4SmmDTKhGSoMsT3d5ATieLmwcS66xWz8w4oFGJ
+yzDq24s4Fp9ccNjMf/xR8XRiekJv835gjEqwF9IXyvgOaq6XJ1iCqGPFDKa25nui
+JnEstOkCgYEA/ZXk7aIanvdeJlTqpX578sJfCnrXLydzE8emk1b7+5mrzGxQ4/pM
+PBQs2f8glT3t0O0mRX9NoRqnwrid88/b+cY4NCOICFZeasX336/gYQxyVeRLJS6Z
+hnGEQqry8qS7PdKAyeHMNmZFrUh4EiHiObymEfQS+mkRUObn0cGBTw8CgYEAzeQU
+D2baec1DawjppKaRynAvWjp+9ry1lZx9unryKVRwjRjkEpw+b3/+hdaF1IvsVSce
+cNj+6W2guZ2tyHuPhZ64/4SJVyE2hKDSKD4xTb2nVjsMeN0bLD2UWXC9mwbx8nWa
+2tmtUZ7a/okQb2cSdosJinRewLNqXIsBXamT1csCgYEA0cXb2RCOQQ6U3dTFPx4A
+3vMXuA2iUKmrsqMoEx6T2LBow/Sefdkik1iFOdipVYwjXP+w9zC2QR1Rxez/DR/X
+8ymceNUjxPHdrSoTQQG29dFcC92MpDeGXQcuyA+uZjcLhbrLOzYEvsOfxBb87NMG
+14hNQPDNekTMREafYo9WrtUCgYAREK54+FVzcwf7fymedA/xb4r9N4v+d3W1iNsC
+8d3Qfyc1CrMct8aVB07ZWQaOr2pPRIbJY7L9NhD0UZVt4I/sy1MaGqonhqE2LP4+
+R6legDG2e/50ph7yc8gwAaA1kUXMiuLi8Nfkw/3yyvmJwklNegi4aRzRbA2Mzhi2
+4q9WMQKBgQCb0JNyxHG4pvLWCF/j0Sm1FfvrpnqSv5678n1j4GX7Ka/TubOK1Y4K
+U+Oib7dKa/zQMWehVFNTayrsq6bKVZ6q7zG+IHiRLw4wjeAxREFH6WUjDrn9vl2l
+D48DKbBuBwuVOJWyq3qbfgJXojscgNQklrsPdXVhDwOF0dYxP89HnA==
+-----END RSA PRIVATE KEY-----"""
+CLIENT_KEY = """\
+-----BEGIN RSA PRIVATE KEY-----
+MIIEpAIBAAKCAQEAxvREKSElPOm/0z/nPO+j5rk2tjdgGcGc7We1QZ6TRXYLu7nN
+GeEFIL4p8N1i6dmB+Eydt7xqCU79MWD6Yy4prFe1+/K1wCDUxIbFMxqQcX5zjJzd
+i8j8PbcaUlVhP/OkjtkSxrXaGDO1BzfdV4iEBtTV/2l3zmLKJlt3jnOHLczP24CB
+DTQKp3rKshbRefzot9Y+wnaK692RsYgsyo9YEP0GyWKG9topCHk13r46J6vGLeuj
+ryUKqmbLJkzbJbIcEqwTDo5iHaCVqaMr5Hrb8BdMucSseqZQJsXSd+9tdRcIblUQ
+38kZjmFMm4SFbruJcpZCNM2wNSZPIRX+3eiwNwIDAQABAoIBAHSacOBSJsr+jIi5
+KUOTh9IPtzswVUiDKwARCjB9Sf8p4lKR4N1L/n9kNJyQhApeikgGT2GCMftmqgoo
+tlculQoHFgemBlOmak0MV8NNzF5YKEy/GzF0CDH7gJfEpoyetVFrdA+2QS5yD6U9
+XqKQxiBi2VEqdScmyyeT8AwzNYTnPeH/DOEcnbdRjqiy/CD79F49CQ1lX1Fuqm0K
+I7BivBH1xo/rVnUP4F+IzocDqoga+Pjdj0LTXIgJlHQDSbhsQqWujWQDDuKb+MAw
+sNK4Zf8ErV3j1PyA7f/M5LLq6zgstkW4qikDHo4SpZX8kFOO8tjqb7kujj7XqeaB
+CxqrOTECgYEA73uWkrohcmDJ4KqbuL3tbExSCOUiaIV+sT1eGPNi7GCmXD4eW5Z4
+75v2IHymW83lORSu/DrQ6sKr1nkuRpqr2iBzRmQpl/H+wahIhBXlnJ25uUjDsuPO
+1Pq2LcmyD+jTxVnmbSe/q7O09gZQw3I6H4+BMHmpbf8tC97lqimzpJ0CgYEA1K0W
+ZL70Xtn9quyHvbtae/BW07NZnxvUg4UaVIAL9Zu34JyplJzyzbIjrmlDbv6aRogH
+/KtuG9tfbf55K/jjqNORiuRtzt1hUN1ye4dyW7tHx2/7lXdlqtyK40rQl8P0kqf8
+zaS6BqjnobgSdSpg32rWoL/pcBHPdJCJEgQ8zeMCgYEA0/PK8TOhNIzrP1dgGSKn
+hkkJ9etuB5nW5mEM7gJDFDf6JPupfJ/xiwe6z0fjKK9S57EhqgUYMB55XYnE5iIw
+ZQ6BV9SAZ4V7VsRs4dJLdNC3tn/rDGHJBgCaym2PlbsX6rvFT+h1IC8dwv0V79Ui
+Ehq9WTzkMoE8yhvNokvkPZUCgYEAgBAFxv5xGdh79ftdtXLmhnDvZ6S8l6Fjcxqo
+Ay/jg66Tp43OU226iv/0mmZKM8Dd1xC8dnon4GBVc19jSYYiWBulrRPlx0Xo/o+K
+CzZBN1lrXH1i6dqufpc0jq8TMf/N+q1q/c1uMupsKCY1/xVYpc+ok71b7J7c49zQ
+nOeuUW8CgYA9Infooy65FTgbzca0c9kbCUBmcAPQ2ItH3JcPKWPQTDuV62HcT00o
+fZdIV47Nez1W5Clk191RMy8TXuqI54kocciUWpThc6j44hz49oUueb8U4bLcEHzA
+WxtWBWHwxfSmqgTXilEA3ALJp0kNolLnEttnhENwJpZHlqtes0ZA4w==
+-----END RSA PRIVATE KEY-----"""
+
+
+class Server(paramiko.ServerInterface):
+    """http://docs.paramiko.org/en/2.4/api/server.html"""
+    def __init__(self, commands, *args, **kwargs):
+        super(Server, self).__init__(*args, **kwargs)
+        self.commands = commands
+
+    def check_channel_exec_request(self, channel, command):
+        self.commands.append(command)
+        return True
+
+    def check_auth_password(self, username, password):
+        if username == USER and password == PASSWORD:
+            return paramiko.AUTH_SUCCESSFUL
+        return paramiko.AUTH_FAILED
+
+    def check_auth_publickey(self, username, key):
+        pubkey = paramiko.RSAKey.from_private_key(StringIO(CLIENT_KEY))
+        if username == USER and key == pubkey:
+            return paramiko.AUTH_SUCCESSFUL
+        return paramiko.AUTH_FAILED
+
+    def check_channel_request(self, kind, chanid):
+        if kind == "session":
+            return paramiko.OPEN_SUCCEEDED
+        return paramiko.OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
+
+    def get_allowed_auths(self, username):
+        return "password,publickey"
+
+
+class ParamikoSSHVendorTests(TestCase):
+    def setUp(self):
+        self.commands = []
+        socket.setdefaulttimeout(10)
+        self.addCleanup(socket.setdefaulttimeout, None)
+        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        self.socket.bind(('127.0.0.1', 0))
+        self.socket.listen(5)
+        self.addCleanup(self.socket.close)
+        self.port = self.socket.getsockname()[1]
+        self.thread = threading.Thread(target=self._run)
+        self.thread.start()
+
+    def tearDown(self):
+        pass
+
+    def _run(self):
+        try:
+            conn, addr = self.socket.accept()
+        except socket.error:
+            return False
+        self.transport = paramiko.Transport(conn)
+        self.addCleanup(self.transport.close)
+        host_key = paramiko.RSAKey.from_private_key(StringIO(SERVER_KEY))
+        self.transport.add_server_key(host_key)
+        server = Server(self.commands)
+        self.transport.start_server(server=server)
+
+    def test_run_command_password(self):
+        vendor = ParamikoSSHVendor(allow_agent=False, look_for_keys=False,)
+        vendor.run_command(
+            '127.0.0.1', 'test_run_command_password',
+            username=USER, port=self.port, password=PASSWORD)
+
+        self.assertIn(b'test_run_command_password', self.commands)
+
+    def test_run_command_with_privkey(self):
+        key = paramiko.RSAKey.from_private_key(StringIO(CLIENT_KEY))
+
+        vendor = ParamikoSSHVendor(allow_agent=False, look_for_keys=False,)
+        vendor.run_command(
+            '127.0.0.1', 'test_run_command_with_privkey',
+            username=USER, port=self.port, pkey=key)
+
+        self.assertIn(b'test_run_command_with_privkey', self.commands)
+
+    def test_run_command_data_transfer(self):
+        vendor = ParamikoSSHVendor(allow_agent=False, look_for_keys=False)
+        con = vendor.run_command(
+            '127.0.0.1', 'test_run_command_data_transfer',
+            username=USER, port=self.port, password=PASSWORD)
+
+        self.assertIn(b'test_run_command_data_transfer', self.commands)
+
+        channel = self.transport.accept(5)
+        channel.send(b'stdout\n')
+        channel.send_stderr(b'stderr\n')
+        channel.close()
+
+        # FIXME: can_read() currently returns False here
+        # self.assertTrue(con.can_read())
+
+        self.assertEqual(b'stdout\n', con.read(4096))
+
+        # FIXME: read_stderr() currently returns an empty string here
+        # self.assertEqual(b'stderr\n', con.read_stderr(4096))

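The tests above double as usage documentation for ParamikoSSHVendor. A minimal
sketch of driving it against a real SSH server follows; the host, port,
credentials, and repository path are all placeholders, and the object returned
by run_command is the same wrapper exercised in test_run_command_data_transfer,
with the read/write/close semantics shown there.

    from dulwich.contrib.paramiko_vendor import ParamikoSSHVendor

    vendor = ParamikoSSHVendor(allow_agent=False, look_for_keys=False)
    con = vendor.run_command(
        "git.example.com", "git-upload-pack '/srv/repo.git'",
        username="git", port=22, password="secret")
    data = con.read(4096)  # remote stdout, e.g. a ref advertisement
    con.close()
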
+ 28 - 21
dulwich/contrib/test_release_robot.py

@@ -44,10 +44,17 @@ class TagPatternTests(unittest.TestCase):
     def test_tag_pattern(self):
         """test tag patterns"""
         test_cases = {
-            '0.3': '0.3', 'v0.3': '0.3', 'release0.3': '0.3',
-            'Release-0.3': '0.3', 'v0.3rc1': '0.3rc1', 'v0.3-rc1': '0.3-rc1',
-            'v0.3-rc.1': '0.3-rc.1', 'version 0.3': '0.3',
-            'version_0.3_rc_1': '0.3_rc_1', 'v1': '1', '0.3rc1': '0.3rc1'
+            "0.3": "0.3",
+            "v0.3": "0.3",
+            "release0.3": "0.3",
+            "Release-0.3": "0.3",
+            "v0.3rc1": "0.3rc1",
+            "v0.3-rc1": "0.3-rc1",
+            "v0.3-rc.1": "0.3-rc.1",
+            "version 0.3": "0.3",
+            "version_0.3_rc_1": "0.3_rc_1",
+            "v1": "1",
+            "0.3rc1": "0.3rc1",
         }
         for testcase, version in test_cases.items():
             matches = re.match(release_robot.PATTERN, testcase)
@@ -58,12 +65,12 @@ class GetRecentTagsTest(unittest.TestCase):
     """test get recent tags"""
 
     # Git repo for dulwich project
-    test_repo = os.path.join(BASEDIR, 'dulwich_test_repo.zip')
+    test_repo = os.path.join(BASEDIR, "dulwich_test_repo.zip")
     committer = b"Mark Mikofski <mark.mikofski@sunpowercorp.com>"
-    test_tags = [b'v0.1a', b'v0.1']
+    test_tags = [b"v0.1a", b"v0.1"]
     tag_test_data = {
-        test_tags[0]: [1484788003, b'3' * 40, None],
-        test_tags[1]: [1484788314, b'1' * 40, (1484788401, b'2' * 40)]
+        test_tags[0]: [1484788003, b"3" * 40, None],
+        test_tags[1]: [1484788314, b"1" * 40, (1484788401, b"2" * 40)],
     }
 
     @classmethod
@@ -75,20 +82,20 @@ class GetRecentTagsTest(unittest.TestCase):
         cls.c1 = make_commit(
             id=cls.tag_test_data[cls.test_tags[0]][1],
             commit_time=cls.tag_test_data[cls.test_tags[0]][0],
-            message=b'unannotated tag',
-            author=cls.committer
+            message=b"unannotated tag",
+            author=cls.committer,
         )
         obj_store.add_object(cls.c1)
         # tag 1: unannotated
         cls.t1 = cls.test_tags[0]
-        cls.repo[b'refs/tags/' + cls.t1] = cls.c1.id  # add unannotated tag
+        cls.repo[b"refs/tags/" + cls.t1] = cls.c1.id  # add unannotated tag
         # commit 2 ('2017-01-19T01:11:54')
         cls.c2 = make_commit(
             id=cls.tag_test_data[cls.test_tags[1]][1],
             commit_time=cls.tag_test_data[cls.test_tags[1]][0],
-            message=b'annotated tag',
+            message=b"annotated tag",
             parents=[cls.c1.id],
-            author=cls.committer
+            author=cls.committer,
         )
         obj_store.add_object(cls.c2)
         # tag 2: annotated ('2017-01-19T01:13:21')
@@ -96,11 +103,11 @@ class GetRecentTagsTest(unittest.TestCase):
             cls.c2,
             id=cls.tag_test_data[cls.test_tags[1]][2][1],
             name=cls.test_tags[1],
-            tag_time=cls.tag_test_data[cls.test_tags[1]][2][0]
+            tag_time=cls.tag_test_data[cls.test_tags[1]][2][0],
         )
         obj_store.add_object(cls.t2)
-        cls.repo[b'refs/heads/master'] = cls.c2.id
-        cls.repo[b'refs/tags/' + cls.t2.name] = cls.t2.id  # add annotated tag
+        cls.repo[b"refs/heads/master"] = cls.c2.id
+        cls.repo[b"refs/tags/" + cls.t2.name] = cls.t2.id  # add annotated tag
 
     @classmethod
     def tearDownClass(cls):
@@ -111,17 +118,17 @@ class GetRecentTagsTest(unittest.TestCase):
         """test get recent tags"""
         tags = release_robot.get_recent_tags(self.projdir)  # get test tags
         for tag, metadata in tags:
-            tag = tag.encode('utf-8')
+            tag = tag.encode("utf-8")
             test_data = self.tag_test_data[tag]  # test data tag
             # test commit date, id and author name
             self.assertEqual(metadata[0], gmtime_to_datetime(test_data[0]))
-            self.assertEqual(metadata[1].encode('utf-8'), test_data[1])
-            self.assertEqual(metadata[2].encode('utf-8'), self.committer)
+            self.assertEqual(metadata[1].encode("utf-8"), test_data[1])
+            self.assertEqual(metadata[2].encode("utf-8"), self.committer)
             # skip unannotated tags
             tag_obj = test_data[2]
             if not tag_obj:
                 continue
             # tag date, id and name
             self.assertEqual(metadata[3][0], gmtime_to_datetime(tag_obj[0]))
-            self.assertEqual(metadata[3][1].encode('utf-8'), tag_obj[1])
-            self.assertEqual(metadata[3][2].encode('utf-8'), tag)
+            self.assertEqual(metadata[3][1].encode("utf-8"), tag_obj[1])
+            self.assertEqual(metadata[3][2].encode("utf-8"), tag)

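For readers reusing release_robot outside the test suite, a minimal sketch of
the API under test; the repository path is a placeholder, and metadata is the
same four-element list checked in test_get_recent_tags (commit datetime,
commit id, author, and annotated-tag details or None).

    import re

    from dulwich.contrib import release_robot

    # PATTERN strips a "v"/"release-" style prefix, as in the cases above.
    assert re.match(release_robot.PATTERN, "Release-0.3").group(1) == "0.3"

    for tag, metadata in release_robot.get_recent_tags("/path/to/repo"):
        print(tag, metadata[0], metadata[1])  # tag, commit date, commit id
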
+ 215 - 185
dulwich/contrib/test_swift.py

@@ -31,17 +31,17 @@ from unittest import skipIf
 
 from dulwich.tests import (
     TestCase,
-    )
+)
 from dulwich.tests.test_object_store import (
     ObjectStoreTests,
-    )
+)
 from dulwich.objects import (
     Blob,
     Commit,
     Tree,
     Tag,
     parse_timezone,
-    )
+)
 
 import json
 
@@ -82,25 +82,24 @@ http_pool_length = %(http_pool_length)s
 http_timeout = %(http_timeout)s
 """
 
-def_config_file = {'version_str': 'v1.0',
-                   'version_int': 1,
-                   'concurrency': 1,
-                   'chunk_length': 12228,
-                   'cache_length': 1,
-                   'region_name': 'test',
-                   'endpoint_type': 'internalURL',
-                   'http_pool_length': 1,
-                   'http_timeout': 1}
+def_config_file = {
+    "version_str": "v1.0",
+    "version_int": 1,
+    "concurrency": 1,
+    "chunk_length": 12228,
+    "cache_length": 1,
+    "region_name": "test",
+    "endpoint_type": "internalURL",
+    "http_pool_length": 1,
+    "http_timeout": 1,
+}
 
 
 def create_swift_connector(store={}):
-    return lambda root, conf: FakeSwiftConnector(root,
-                                                 conf=conf,
-                                                 store=store)
+    return lambda root, conf: FakeSwiftConnector(root, conf=conf, store=store)
 
 
 class Response(object):
-
     def __init__(self, headers={}, status=200, content=None):
         self.headers = headers
         self.status_code = status
@@ -117,40 +116,46 @@ class Response(object):
 
 
 def fake_auth_request_v1(*args, **kwargs):
-    ret = Response({'X-Storage-Url':
-                    'http://127.0.0.1:8080/v1.0/AUTH_fakeuser',
-                    'X-Auth-Token': '12' * 10},
-                   200)
+    ret = Response(
+        {
+            "X-Storage-Url": "http://127.0.0.1:8080/v1.0/AUTH_fakeuser",
+            "X-Auth-Token": "12" * 10,
+        },
+        200,
+    )
     return ret
 
 
 def fake_auth_request_v1_error(*args, **kwargs):
-    ret = Response({},
-                   401)
+    ret = Response({}, 401)
     return ret
 
 
 def fake_auth_request_v2(*args, **kwargs):
-    s_url = 'http://127.0.0.1:8080/v1.0/AUTH_fakeuser'
-    resp = {'access': {'token': {'id': '12' * 10},
-                       'serviceCatalog':
-                       [
-                           {'type': 'object-store',
-                            'endpoints': [{'region': 'test',
-                                          'internalURL': s_url,
-                                           },
-                                          ]
-                            },
-                       ]
-                       }
-            }
+    s_url = "http://127.0.0.1:8080/v1.0/AUTH_fakeuser"
+    resp = {
+        "access": {
+            "token": {"id": "12" * 10},
+            "serviceCatalog": [
+                {
+                    "type": "object-store",
+                    "endpoints": [
+                        {
+                            "region": "test",
+                            "internalURL": s_url,
+                        },
+                    ],
+                },
+            ],
+        }
+    }
     ret = Response(status=200, content=json.dumps(resp))
     return ret
 
 
-def create_commit(data, marker=b'Default', blob=None):
+def create_commit(data, marker=b"Default", blob=None):
     if not blob:
-        blob = Blob.from_string(b'The blob content ' + marker)
+        blob = Blob.from_string(b"The blob content " + marker)
     tree = Tree()
     tree.add(b"thefile_" + marker, 0o100644, blob.id)
     cmt = Commit()
@@ -160,7 +165,7 @@ def create_commit(data, marker=b'Default', blob=None):
     cmt.tree = tree.id
     author = b"John Doe " + marker + b" <john@doe.net>"
     cmt.author = cmt.committer = author
-    tz = parse_timezone(b'-0200')[0]
+    tz = parse_timezone(b"-0200")[0]
     cmt.commit_time = cmt.author_time = int(time())
     cmt.commit_timezone = cmt.author_timezone = tz
     cmt.encoding = b"UTF-8"
@@ -168,14 +173,14 @@ def create_commit(data, marker=b'Default', blob=None):
     tag = Tag()
     tag.tagger = b"john@doe.net"
     tag.message = b"Annotated tag"
-    tag.tag_timezone = parse_timezone(b'-0200')[0]
+    tag.tag_timezone = parse_timezone(b"-0200")[0]
     tag.tag_time = cmt.author_time
     tag.object = (Commit, cmt.id)
     tag.name = b"v_" + marker + b"_0.1"
     return blob, tree, tag, cmt
 
 
-def create_commits(length=1, marker=b'Default'):
+def create_commits(length=1, marker=b"Default"):
     data = []
     for i in range(0, length):
         _marker = ("%s_%s" % (marker, i)).encode()
@@ -186,7 +191,6 @@ def create_commits(length=1, marker=b'Default'):
 
 @skipIf(missing_libs, skipmsg)
 class FakeSwiftConnector(object):
-
     def __init__(self, root, conf, store=None):
         if store:
             self.store = store
@@ -200,7 +204,7 @@ class FakeSwiftConnector(object):
 
     def put_object(self, name, content):
         name = posixpath.join(self.root, name)
-        if hasattr(content, 'seek'):
+        if hasattr(content, "seek"):
             content.seek(0)
             content = content.read()
         self.store[name] = content
@@ -213,96 +217,99 @@ class FakeSwiftConnector(object):
             except KeyError:
                 return None
         else:
-            l, r = range.split('-')
+            l, r = range.split("-")
             try:
                 if not l:
                     r = -int(r)
                     return self.store[name][r:]
                 else:
-                    return self.store[name][int(l):int(r)]
+                    return self.store[name][int(l) : int(r)]
             except KeyError:
                 return None
 
     def get_container_objects(self):
-        return [{'name': k.replace(self.root + '/', '')}
-                for k in self.store]
+        return [{"name": k.replace(self.root + "/", "")} for k in self.store]
 
     def create_root(self):
         if self.root in self.store.keys():
             pass
         else:
-            self.store[self.root] = ''
+            self.store[self.root] = ""
 
     def get_object_stat(self, name):
         name = posixpath.join(self.root, name)
         if name not in self.store:
             return None
-        return {'content-length': len(self.store[name])}
+        return {"content-length": len(self.store[name])}
 
 
 @skipIf(missing_libs, skipmsg)
 class TestSwiftRepo(TestCase):
-
     def setUp(self):
         super(TestSwiftRepo, self).setUp()
-        self.conf = swift.load_conf(file=StringIO(config_file %
-                                                  def_config_file))
+        self.conf = swift.load_conf(file=StringIO(config_file % def_config_file))
 
     def test_init(self):
-        store = {'fakerepo/objects/pack': ''}
-        with patch('dulwich.contrib.swift.SwiftConnector',
-                   new_callable=create_swift_connector,
-                   store=store):
-            swift.SwiftRepo('fakerepo', conf=self.conf)
+        store = {"fakerepo/objects/pack": ""}
+        with patch(
+            "dulwich.contrib.swift.SwiftConnector",
+            new_callable=create_swift_connector,
+            store=store,
+        ):
+            swift.SwiftRepo("fakerepo", conf=self.conf)
 
     def test_init_no_data(self):
-        with patch('dulwich.contrib.swift.SwiftConnector',
-                   new_callable=create_swift_connector):
-            self.assertRaises(Exception, swift.SwiftRepo,
-                              'fakerepo', self.conf)
+        with patch(
+            "dulwich.contrib.swift.SwiftConnector",
+            new_callable=create_swift_connector,
+        ):
+            self.assertRaises(Exception, swift.SwiftRepo, "fakerepo", self.conf)
 
     def test_init_bad_data(self):
-        store = {'fakerepo/.git/objects/pack': ''}
-        with patch('dulwich.contrib.swift.SwiftConnector',
-                   new_callable=create_swift_connector,
-                   store=store):
-            self.assertRaises(Exception, swift.SwiftRepo,
-                              'fakerepo', self.conf)
+        store = {"fakerepo/.git/objects/pack": ""}
+        with patch(
+            "dulwich.contrib.swift.SwiftConnector",
+            new_callable=create_swift_connector,
+            store=store,
+        ):
+            self.assertRaises(Exception, swift.SwiftRepo, "fakerepo", self.conf)
 
     def test_put_named_file(self):
-        store = {'fakerepo/objects/pack': ''}
-        with patch('dulwich.contrib.swift.SwiftConnector',
-                   new_callable=create_swift_connector,
-                   store=store):
-            repo = swift.SwiftRepo('fakerepo', conf=self.conf)
-            desc = b'Fake repo'
-            repo._put_named_file('description', desc)
-        self.assertEqual(repo.scon.store['fakerepo/description'],
-                         desc)
+        store = {"fakerepo/objects/pack": ""}
+        with patch(
+            "dulwich.contrib.swift.SwiftConnector",
+            new_callable=create_swift_connector,
+            store=store,
+        ):
+            repo = swift.SwiftRepo("fakerepo", conf=self.conf)
+            desc = b"Fake repo"
+            repo._put_named_file("description", desc)
+        self.assertEqual(repo.scon.store["fakerepo/description"], desc)
 
     def test_init_bare(self):
-        fsc = FakeSwiftConnector('fakeroot', conf=self.conf)
-        with patch('dulwich.contrib.swift.SwiftConnector',
-                   new_callable=create_swift_connector,
-                   store=fsc.store):
+        fsc = FakeSwiftConnector("fakeroot", conf=self.conf)
+        with patch(
+            "dulwich.contrib.swift.SwiftConnector",
+            new_callable=create_swift_connector,
+            store=fsc.store,
+        ):
             swift.SwiftRepo.init_bare(fsc, conf=self.conf)
-        self.assertIn('fakeroot/objects/pack', fsc.store)
-        self.assertIn('fakeroot/info/refs', fsc.store)
-        self.assertIn('fakeroot/description', fsc.store)
+        self.assertIn("fakeroot/objects/pack", fsc.store)
+        self.assertIn("fakeroot/info/refs", fsc.store)
+        self.assertIn("fakeroot/description", fsc.store)
 
 
 @skipIf(missing_libs, skipmsg)
 class TestSwiftInfoRefsContainer(TestCase):
-
     def setUp(self):
         super(TestSwiftInfoRefsContainer, self).setUp()
         content = (
             b"22effb216e3a82f97da599b8885a6cadb488b4c5\trefs/heads/master\n"
-            b"cca703b0e1399008b53a1a236d6b4584737649e4\trefs/heads/dev")
-        self.store = {'fakerepo/info/refs': content}
-        self.conf = swift.load_conf(file=StringIO(config_file %
-                                                  def_config_file))
-        self.fsc = FakeSwiftConnector('fakerepo', conf=self.conf)
+            b"cca703b0e1399008b53a1a236d6b4584737649e4\trefs/heads/dev"
+        )
+        self.store = {"fakerepo/info/refs": content}
+        self.conf = swift.load_conf(file=StringIO(config_file % def_config_file))
+        self.fsc = FakeSwiftConnector("fakerepo", conf=self.conf)
         self.object_store = {}
 
     def test_init(self):
@@ -311,160 +318,183 @@ class TestSwiftInfoRefsContainer(TestCase):
         self.assertEqual(len(irc._refs), 0)
         self.fsc.store = self.store
         irc = swift.SwiftInfoRefsContainer(self.fsc, self.object_store)
-        self.assertIn(b'refs/heads/dev', irc.allkeys())
-        self.assertIn(b'refs/heads/master', irc.allkeys())
+        self.assertIn(b"refs/heads/dev", irc.allkeys())
+        self.assertIn(b"refs/heads/master", irc.allkeys())
 
     def test_set_if_equals(self):
         self.fsc.store = self.store
         irc = swift.SwiftInfoRefsContainer(self.fsc, self.object_store)
-        irc.set_if_equals(b'refs/heads/dev',
-                          b"cca703b0e1399008b53a1a236d6b4584737649e4", b'1'*40)
-        self.assertEqual(irc[b'refs/heads/dev'], b'1'*40)
+        irc.set_if_equals(
+            b"refs/heads/dev",
+            b"cca703b0e1399008b53a1a236d6b4584737649e4",
+            b"1" * 40,
+        )
+        self.assertEqual(irc[b"refs/heads/dev"], b"1" * 40)
 
     def test_remove_if_equals(self):
         self.fsc.store = self.store
         irc = swift.SwiftInfoRefsContainer(self.fsc, self.object_store)
-        irc.remove_if_equals(b'refs/heads/dev',
-                             b"cca703b0e1399008b53a1a236d6b4584737649e4")
-        self.assertNotIn(b'refs/heads/dev', irc.allkeys())
+        irc.remove_if_equals(
+            b"refs/heads/dev", b"cca703b0e1399008b53a1a236d6b4584737649e4"
+        )
+        self.assertNotIn(b"refs/heads/dev", irc.allkeys())
 
 
 @skipIf(missing_libs, skipmsg)
 class TestSwiftConnector(TestCase):
-
     def setUp(self):
         super(TestSwiftConnector, self).setUp()
-        self.conf = swift.load_conf(file=StringIO(config_file %
-                                                  def_config_file))
-        with patch('geventhttpclient.HTTPClient.request',
-                   fake_auth_request_v1):
-            self.conn = swift.SwiftConnector('fakerepo', conf=self.conf)
+        self.conf = swift.load_conf(file=StringIO(config_file % def_config_file))
+        with patch("geventhttpclient.HTTPClient.request", fake_auth_request_v1):
+            self.conn = swift.SwiftConnector("fakerepo", conf=self.conf)
 
     def test_init_connector(self):
-        self.assertEqual(self.conn.auth_ver, '1')
-        self.assertEqual(self.conn.auth_url,
-                         'http://127.0.0.1:8080/auth/v1.0')
-        self.assertEqual(self.conn.user, 'test:tester')
-        self.assertEqual(self.conn.password, 'testing')
-        self.assertEqual(self.conn.root, 'fakerepo')
-        self.assertEqual(self.conn.storage_url,
-                         'http://127.0.0.1:8080/v1.0/AUTH_fakeuser')
-        self.assertEqual(self.conn.token, '12' * 10)
+        self.assertEqual(self.conn.auth_ver, "1")
+        self.assertEqual(self.conn.auth_url, "http://127.0.0.1:8080/auth/v1.0")
+        self.assertEqual(self.conn.user, "test:tester")
+        self.assertEqual(self.conn.password, "testing")
+        self.assertEqual(self.conn.root, "fakerepo")
+        self.assertEqual(
+            self.conn.storage_url, "http://127.0.0.1:8080/v1.0/AUTH_fakeuser"
+        )
+        self.assertEqual(self.conn.token, "12" * 10)
         self.assertEqual(self.conn.http_timeout, 1)
         self.assertEqual(self.conn.http_pool_length, 1)
         self.assertEqual(self.conn.concurrency, 1)
-        self.conf.set('swift', 'auth_ver', '2')
-        self.conf.set('swift', 'auth_url', 'http://127.0.0.1:8080/auth/v2.0')
-        with patch('geventhttpclient.HTTPClient.request',
-                   fake_auth_request_v2):
-            conn = swift.SwiftConnector('fakerepo', conf=self.conf)
-        self.assertEqual(conn.user, 'tester')
-        self.assertEqual(conn.tenant, 'test')
-        self.conf.set('swift', 'auth_ver', '1')
-        self.conf.set('swift', 'auth_url', 'http://127.0.0.1:8080/auth/v1.0')
-        with patch('geventhttpclient.HTTPClient.request',
-                   fake_auth_request_v1_error):
-            self.assertRaises(swift.SwiftException,
-                              lambda: swift.SwiftConnector('fakerepo',
-                                                           conf=self.conf))
+        self.conf.set("swift", "auth_ver", "2")
+        self.conf.set("swift", "auth_url", "http://127.0.0.1:8080/auth/v2.0")
+        with patch("geventhttpclient.HTTPClient.request", fake_auth_request_v2):
+            conn = swift.SwiftConnector("fakerepo", conf=self.conf)
+        self.assertEqual(conn.user, "tester")
+        self.assertEqual(conn.tenant, "test")
+        self.conf.set("swift", "auth_ver", "1")
+        self.conf.set("swift", "auth_url", "http://127.0.0.1:8080/auth/v1.0")
+        with patch("geventhttpclient.HTTPClient.request", fake_auth_request_v1_error):
+            self.assertRaises(
+                swift.SwiftException,
+                lambda: swift.SwiftConnector("fakerepo", conf=self.conf),
+            )
 
     def test_root_exists(self):
-        with patch('geventhttpclient.HTTPClient.request',
-                   lambda *args: Response()):
+        with patch("geventhttpclient.HTTPClient.request", lambda *args: Response()):
             self.assertEqual(self.conn.test_root_exists(), True)
 
     def test_root_not_exists(self):
-        with patch('geventhttpclient.HTTPClient.request',
-                   lambda *args: Response(status=404)):
+        with patch(
+            "geventhttpclient.HTTPClient.request",
+            lambda *args: Response(status=404),
+        ):
             self.assertEqual(self.conn.test_root_exists(), None)
 
     def test_create_root(self):
-        with patch('dulwich.contrib.swift.SwiftConnector.test_root_exists',
-                   lambda *args: None):
-            with patch('geventhttpclient.HTTPClient.request',
-                       lambda *args: Response()):
+        with patch(
+            "dulwich.contrib.swift.SwiftConnector.test_root_exists",
+            lambda *args: None,
+        ):
+            with patch("geventhttpclient.HTTPClient.request", lambda *args: Response()):
                 self.assertEqual(self.conn.create_root(), None)
 
     def test_create_root_fails(self):
-        with patch('dulwich.contrib.swift.SwiftConnector.test_root_exists',
-                   lambda *args: None):
-            with patch('geventhttpclient.HTTPClient.request',
-                       lambda *args: Response(status=404)):
-                self.assertRaises(swift.SwiftException,
-                                  lambda: self.conn.create_root())
+        with patch(
+            "dulwich.contrib.swift.SwiftConnector.test_root_exists",
+            lambda *args: None,
+        ):
+            with patch(
+                "geventhttpclient.HTTPClient.request",
+                lambda *args: Response(status=404),
+            ):
+                self.assertRaises(swift.SwiftException, self.conn.create_root)
 
     def test_get_container_objects(self):
-        with patch('geventhttpclient.HTTPClient.request',
-                   lambda *args: Response(content=json.dumps(
-                       (({'name': 'a'}, {'name': 'b'}))))):
+        with patch(
+            "geventhttpclient.HTTPClient.request",
+            lambda *args: Response(
+                content=json.dumps((({"name": "a"}, {"name": "b"})))
+            ),
+        ):
             self.assertEqual(len(self.conn.get_container_objects()), 2)
 
     def test_get_container_objects_fails(self):
-        with patch('geventhttpclient.HTTPClient.request',
-                   lambda *args: Response(status=404)):
+        with patch(
+            "geventhttpclient.HTTPClient.request",
+            lambda *args: Response(status=404),
+        ):
             self.assertEqual(self.conn.get_container_objects(), None)
 
     def test_get_object_stat(self):
-        with patch('geventhttpclient.HTTPClient.request',
-                   lambda *args: Response(headers={'content-length': '10'})):
-            self.assertEqual(self.conn.get_object_stat('a')['content-length'],
-                             '10')
+        with patch(
+            "geventhttpclient.HTTPClient.request",
+            lambda *args: Response(headers={"content-length": "10"}),
+        ):
+            self.assertEqual(self.conn.get_object_stat("a")["content-length"], "10")
 
     def test_get_object_stat_fails(self):
-        with patch('geventhttpclient.HTTPClient.request',
-                   lambda *args: Response(status=404)):
-            self.assertEqual(self.conn.get_object_stat('a'), None)
+        with patch(
+            "geventhttpclient.HTTPClient.request",
+            lambda *args: Response(status=404),
+        ):
+            self.assertEqual(self.conn.get_object_stat("a"), None)
 
     def test_put_object(self):
-        with patch('geventhttpclient.HTTPClient.request',
-                   lambda *args, **kwargs: Response()):
-            self.assertEqual(self.conn.put_object('a', BytesIO(b'content')),
-                             None)
+        with patch(
+            "geventhttpclient.HTTPClient.request",
+            lambda *args, **kwargs: Response(),
+        ):
+            self.assertEqual(self.conn.put_object("a", BytesIO(b"content")), None)
 
     def test_put_object_fails(self):
-        with patch('geventhttpclient.HTTPClient.request',
-                   lambda *args, **kwargs: Response(status=400)):
-            self.assertRaises(swift.SwiftException,
-                              lambda: self.conn.put_object(
-                                  'a', BytesIO(b'content')))
+        with patch(
+            "geventhttpclient.HTTPClient.request",
+            lambda *args, **kwargs: Response(status=400),
+        ):
+            self.assertRaises(
+                swift.SwiftException,
+                lambda: self.conn.put_object("a", BytesIO(b"content")),
+            )
 
     def test_get_object(self):
-        with patch('geventhttpclient.HTTPClient.request',
-                   lambda *args, **kwargs: Response(content=b'content')):
-            self.assertEqual(self.conn.get_object('a').read(), b'content')
-        with patch('geventhttpclient.HTTPClient.request',
-                   lambda *args, **kwargs: Response(content=b'content')):
-            self.assertEqual(
-                    self.conn.get_object('a', range='0-6'),
-                    b'content')
+        with patch(
+            "geventhttpclient.HTTPClient.request",
+            lambda *args, **kwargs: Response(content=b"content"),
+        ):
+            self.assertEqual(self.conn.get_object("a").read(), b"content")
+        with patch(
+            "geventhttpclient.HTTPClient.request",
+            lambda *args, **kwargs: Response(content=b"content"),
+        ):
+            self.assertEqual(self.conn.get_object("a", range="0-6"), b"content")
 
     def test_get_object_fails(self):
-        with patch('geventhttpclient.HTTPClient.request',
-                   lambda *args, **kwargs: Response(status=404)):
-            self.assertEqual(self.conn.get_object('a'), None)
+        with patch(
+            "geventhttpclient.HTTPClient.request",
+            lambda *args, **kwargs: Response(status=404),
+        ):
+            self.assertEqual(self.conn.get_object("a"), None)
 
     def test_del_object(self):
-        with patch('geventhttpclient.HTTPClient.request',
-                   lambda *args: Response()):
-            self.assertEqual(self.conn.del_object('a'), None)
+        with patch("geventhttpclient.HTTPClient.request", lambda *args: Response()):
+            self.assertEqual(self.conn.del_object("a"), None)
 
     def test_del_root(self):
-        with patch('dulwich.contrib.swift.SwiftConnector.del_object',
-                   lambda *args: None):
-            with patch('dulwich.contrib.swift.SwiftConnector.'
-                       'get_container_objects',
-                       lambda *args: ({'name': 'a'}, {'name': 'b'})):
-                with patch('geventhttpclient.HTTPClient.request',
-                           lambda *args: Response()):
+        with patch(
+            "dulwich.contrib.swift.SwiftConnector.del_object",
+            lambda *args: None,
+        ):
+            with patch(
+                "dulwich.contrib.swift.SwiftConnector." "get_container_objects",
+                lambda *args: ({"name": "a"}, {"name": "b"}),
+            ):
+                with patch(
+                    "geventhttpclient.HTTPClient.request",
+                    lambda *args: Response(),
+                ):
                     self.assertEqual(self.conn.del_root(), None)
 
 
 @skipIf(missing_libs, skipmsg)
 class SwiftObjectStoreTests(ObjectStoreTests, TestCase):
-
     def setUp(self):
         TestCase.setUp(self)
-        conf = swift.load_conf(file=StringIO(config_file %
-                               def_config_file))
-        fsc = FakeSwiftConnector('fakerepo', conf=conf)
+        conf = swift.load_conf(file=StringIO(config_file % def_config_file))
+        fsc = FakeSwiftConnector("fakerepo", conf=conf)
         self.store = swift.SwiftObjectStore(fsc)

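The FakeSwiftConnector above is also handy outside the test cases for poking
at the Swift backing store in memory. A minimal sketch, assuming the optional
gevent/geventhttpclient dependencies are installed so the modules import:

    from io import StringIO

    from dulwich.contrib import swift
    from dulwich.contrib.test_swift import (
        FakeSwiftConnector,
        config_file,
        def_config_file,
    )

    conf = swift.load_conf(file=StringIO(config_file % def_config_file))
    fsc = FakeSwiftConnector("fakerepo", conf=conf)
    store = swift.SwiftObjectStore(fsc)
    # Objects are keyed under "<root>/<name>", as put_object shows above.
    fsc.put_object("description", b"Fake repo")
    assert fsc.store["fakerepo/description"] == b"Fake repo"
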
+ 102 - 99
dulwich/contrib/test_swift_smoke.py

@@ -40,6 +40,7 @@ import shutil
 
 import gevent
 from gevent import monkey
+
 monkey.patch_all()
 
 from dulwich import (  # noqa:E402
@@ -48,21 +49,19 @@ from dulwich import (  # noqa:E402
     index,
     client,
     objects,
-    )
+)
 from dulwich.contrib import swift  # noqa:E402
 
 
-class DulwichServer():
-    """Start the TCPGitServer with Swift backend
-    """
+class DulwichServer:
+    """Start the TCPGitServer with Swift backend"""
+
     def __init__(self, backend, port):
         self.port = port
         self.backend = backend
 
     def run(self):
-        self.server = server.TCPGitServer(self.backend,
-                                          'localhost',
-                                          port=self.port)
+        self.server = server.TCPGitServer(self.backend, "localhost", port=self.port)
         self.job = gevent.spawn(self.server.serve_forever)
 
     def stop(self):
@@ -71,19 +70,17 @@ class DulwichServer():
 
 
 class SwiftSystemBackend(server.Backend):
-
     def open_repository(self, path):
         return swift.SwiftRepo(path, conf=swift.load_conf())
 
 
 class SwiftRepoSmokeTest(unittest.TestCase):
-
     @classmethod
     def setUpClass(cls):
         cls.backend = SwiftSystemBackend()
         cls.port = 9148
-        cls.server_address = 'localhost'
-        cls.fakerepo = 'fakerepo'
+        cls.server_address = "localhost"
+        cls.fakerepo = "fakerepo"
         cls.th_server = DulwichServer(cls.backend, cls.port)
         cls.th_server.run()
         cls.conf = swift.load_conf()
@@ -116,103 +113,103 @@ class SwiftRepoSmokeTest(unittest.TestCase):
         swift.SwiftRepo.init_bare(self.scon, self.conf)
         self.assertTrue(self.scon.test_root_exists())
         obj = self.scon.get_container_objects()
-        filtered = [o for o in obj if o['name'] == 'info/refs'
-                    or o['name'] == 'objects/pack']
+        filtered = [
+            o for o in obj if o["name"] == "info/refs" or o["name"] == "objects/pack"
+        ]
         self.assertEqual(len(filtered), 2)
 
     def test_clone_bare(self):
         local_repo = repo.Repo.init(self.temp_d, mkdir=True)
         swift.SwiftRepo.init_bare(self.scon, self.conf)
-        tcp_client = client.TCPGitClient(self.server_address,
-                                         port=self.port)
+        tcp_client = client.TCPGitClient(self.server_address, port=self.port)
         remote_refs = tcp_client.fetch(self.fakerepo, local_repo)
         # The remote repo is empty (no refs retrieved)
         self.assertEqual(remote_refs, None)
 
     def test_push_commit(self):
-        def determine_wants(*args):
+        def determine_wants(*args, **kwargs):
             return {"refs/heads/master": local_repo.refs["HEAD"]}
 
         local_repo = repo.Repo.init(self.temp_d, mkdir=True)
         # Nothing in the staging area
-        local_repo.do_commit('Test commit', 'fbo@localhost')
-        sha = local_repo.refs.read_loose_ref('refs/heads/master')
+        local_repo.do_commit("Test commit", "fbo@localhost")
+        sha = local_repo.refs.read_loose_ref("refs/heads/master")
         swift.SwiftRepo.init_bare(self.scon, self.conf)
-        tcp_client = client.TCPGitClient(self.server_address,
-                                         port=self.port)
-        tcp_client.send_pack(self.fakerepo, determine_wants,
-                             local_repo.generate_pack_data)
+        tcp_client = client.TCPGitClient(self.server_address, port=self.port)
+        tcp_client.send_pack(
+            self.fakerepo, determine_wants, local_repo.generate_pack_data
+        )
         swift_repo = swift.SwiftRepo("fakerepo", self.conf)
-        remote_sha = swift_repo.refs.read_loose_ref('refs/heads/master')
+        remote_sha = swift_repo.refs.read_loose_ref("refs/heads/master")
         self.assertEqual(sha, remote_sha)
 
     def test_push_branch(self):
-        def determine_wants(*args):
-            return {"refs/heads/mybranch":
-                    local_repo.refs["refs/heads/mybranch"]}
+        def determine_wants(*args, **kwargs):
+            return {"refs/heads/mybranch": local_repo.refs["refs/heads/mybranch"]}
 
         local_repo = repo.Repo.init(self.temp_d, mkdir=True)
         # Nothing in the staging area
-        local_repo.do_commit('Test commit', 'fbo@localhost',
-                             ref='refs/heads/mybranch')
-        sha = local_repo.refs.read_loose_ref('refs/heads/mybranch')
+        local_repo.do_commit("Test commit", "fbo@localhost", ref="refs/heads/mybranch")
+        sha = local_repo.refs.read_loose_ref("refs/heads/mybranch")
         swift.SwiftRepo.init_bare(self.scon, self.conf)
-        tcp_client = client.TCPGitClient(self.server_address,
-                                         port=self.port)
-        tcp_client.send_pack("/fakerepo", determine_wants,
-                             local_repo.generate_pack_data)
+        tcp_client = client.TCPGitClient(self.server_address, port=self.port)
+        tcp_client.send_pack(
+            "/fakerepo", determine_wants, local_repo.generate_pack_data
+        )
         swift_repo = swift.SwiftRepo(self.fakerepo, self.conf)
-        remote_sha = swift_repo.refs.read_loose_ref('refs/heads/mybranch')
+        remote_sha = swift_repo.refs.read_loose_ref("refs/heads/mybranch")
         self.assertEqual(sha, remote_sha)
 
     def test_push_multiple_branch(self):
-        def determine_wants(*args):
-            return {"refs/heads/mybranch":
-                    local_repo.refs["refs/heads/mybranch"],
-                    "refs/heads/master":
-                    local_repo.refs["refs/heads/master"],
-                    "refs/heads/pullr-108":
-                    local_repo.refs["refs/heads/pullr-108"]}
+        def determine_wants(*args, **kwargs):
+            return {
+                "refs/heads/mybranch": local_repo.refs["refs/heads/mybranch"],
+                "refs/heads/master": local_repo.refs["refs/heads/master"],
+                "refs/heads/pullr-108": local_repo.refs["refs/heads/pullr-108"],
+            }
 
         local_repo = repo.Repo.init(self.temp_d, mkdir=True)
         # Nothing in the staging area
         local_shas = {}
         remote_shas = {}
-        for branch in ('master', 'mybranch', 'pullr-108'):
+        for branch in ("master", "mybranch", "pullr-108"):
             local_shas[branch] = local_repo.do_commit(
-                'Test commit %s' % branch, 'fbo@localhost',
-                ref='refs/heads/%s' % branch)
+                "Test commit %s" % branch,
+                "fbo@localhost",
+                ref="refs/heads/%s" % branch,
+            )
         swift.SwiftRepo.init_bare(self.scon, self.conf)
-        tcp_client = client.TCPGitClient(self.server_address,
-                                         port=self.port)
-        tcp_client.send_pack(self.fakerepo, determine_wants,
-                             local_repo.generate_pack_data)
+        tcp_client = client.TCPGitClient(self.server_address, port=self.port)
+        tcp_client.send_pack(
+            self.fakerepo, determine_wants, local_repo.generate_pack_data
+        )
         swift_repo = swift.SwiftRepo("fakerepo", self.conf)
-        for branch in ('master', 'mybranch', 'pullr-108'):
+        for branch in ("master", "mybranch", "pullr-108"):
             remote_shas[branch] = swift_repo.refs.read_loose_ref(
-                'refs/heads/%s' % branch)
+                "refs/heads/%s" % branch
+            )
         self.assertDictEqual(local_shas, remote_shas)
 
     def test_push_data_branch(self):
-        def determine_wants(*args):
+        def determine_wants(*args, **kwargs):
             return {"refs/heads/master": local_repo.refs["HEAD"]}
+
         local_repo = repo.Repo.init(self.temp_d, mkdir=True)
         os.mkdir(os.path.join(self.temp_d, "dir"))
-        files = ('testfile', 'testfile2', 'dir/testfile3')
+        files = ("testfile", "testfile2", "dir/testfile3")
         i = 0
         for f in files:
-            open(os.path.join(self.temp_d, f), 'w').write("DATA %s" % i)
+            open(os.path.join(self.temp_d, f), "w").write("DATA %s" % i)
             i += 1
         local_repo.stage(files)
-        local_repo.do_commit('Test commit', 'fbo@localhost',
-                             ref='refs/heads/master')
+        local_repo.do_commit("Test commit", "fbo@localhost", ref="refs/heads/master")
         swift.SwiftRepo.init_bare(self.scon, self.conf)
-        tcp_client = client.TCPGitClient(self.server_address,
-                                         port=self.port)
-        tcp_client.send_pack(self.fakerepo, determine_wants,
-                             local_repo.generate_pack_data)
+        tcp_client = client.TCPGitClient(self.server_address, port=self.port)
+        tcp_client.send_pack(
+            self.fakerepo, determine_wants, local_repo.generate_pack_data
+        )
         swift_repo = swift.SwiftRepo("fakerepo", self.conf)
-        commit_sha = swift_repo.refs.read_loose_ref('refs/heads/master')
+        commit_sha = swift_repo.refs.read_loose_ref("refs/heads/master")
         otype, data = swift_repo.object_store.get_raw(commit_sha)
         commit = objects.ShaFile.from_raw_string(otype, data)
         otype, data = swift_repo.object_store.get_raw(commit._tree)
@@ -222,8 +219,8 @@ class SwiftRepoSmokeTest(unittest.TestCase):
         for tree_entry in objs:
             objs_.append(swift_repo.object_store.get_raw(tree_entry.sha))
         # Blob
-        self.assertEqual(objs_[1][1], 'DATA 0')
-        self.assertEqual(objs_[2][1], 'DATA 1')
+        self.assertEqual(objs_[1][1], "DATA 0")
+        self.assertEqual(objs_[2][1], "DATA 1")
         # Tree
         self.assertEqual(objs_[0][0], 2)
 
@@ -231,80 +228,86 @@ class SwiftRepoSmokeTest(unittest.TestCase):
         self.test_push_data_branch()
         shutil.rmtree(self.temp_d)
         local_repo = repo.Repo.init(self.temp_d, mkdir=True)
-        tcp_client = client.TCPGitClient(self.server_address,
-                                         port=self.port)
+        tcp_client = client.TCPGitClient(self.server_address, port=self.port)
         remote_refs = tcp_client.fetch(self.fakerepo, local_repo)
-        files = (os.path.join(self.temp_d, 'testfile'),
-                 os.path.join(self.temp_d, 'testfile2'))
+        files = (
+            os.path.join(self.temp_d, "testfile"),
+            os.path.join(self.temp_d, "testfile2"),
+        )
         local_repo["HEAD"] = remote_refs["refs/heads/master"]
         indexfile = local_repo.index_path()
         tree = local_repo["HEAD"].tree
-        index.build_index_from_tree(local_repo.path, indexfile,
-                                    local_repo.object_store, tree)
+        index.build_index_from_tree(
+            local_repo.path, indexfile, local_repo.object_store, tree
+        )
         for f in files:
             self.assertEqual(os.path.isfile(f), True)
 
-        def determine_wants(*args):
+        def determine_wants(*args, **kwargs):
             return {"refs/heads/master": local_repo.refs["HEAD"]}
+
         os.mkdir(os.path.join(self.temp_d, "test"))
-        files = ('testfile11', 'testfile22', 'test/testfile33')
+        files = ("testfile11", "testfile22", "test/testfile33")
         i = 0
         for f in files:
-            open(os.path.join(self.temp_d, f), 'w').write("DATA %s" % i)
+            open(os.path.join(self.temp_d, f), "w").write("DATA %s" % i)
             i += 1
         local_repo.stage(files)
-        local_repo.do_commit('Test commit', 'fbo@localhost',
-                             ref='refs/heads/master')
-        tcp_client.send_pack("/fakerepo", determine_wants,
-                             local_repo.generate_pack_data)
+        local_repo.do_commit("Test commit", "fbo@localhost", ref="refs/heads/master")
+        tcp_client.send_pack(
+            "/fakerepo", determine_wants, local_repo.generate_pack_data
+        )
 
     def test_push_remove_branch(self):
-        def determine_wants(*args):
-            return {"refs/heads/pullr-108": objects.ZERO_SHA,
-                    "refs/heads/master":
-                    local_repo.refs['refs/heads/master'],
-                    "refs/heads/mybranch":
-                    local_repo.refs['refs/heads/mybranch'],
-                    }
+        def determine_wants(*args, **kwargs):
+            return {
+                "refs/heads/pullr-108": objects.ZERO_SHA,
+                "refs/heads/master": local_repo.refs["refs/heads/master"],
+                "refs/heads/mybranch": local_repo.refs["refs/heads/mybranch"],
+            }
+
         self.test_push_multiple_branch()
         local_repo = repo.Repo(self.temp_d)
-        tcp_client = client.TCPGitClient(self.server_address,
-                                         port=self.port)
-        tcp_client.send_pack(self.fakerepo, determine_wants,
-                             local_repo.generate_pack_data)
+        tcp_client = client.TCPGitClient(self.server_address, port=self.port)
+        tcp_client.send_pack(
+            self.fakerepo, determine_wants, local_repo.generate_pack_data
+        )
         swift_repo = swift.SwiftRepo("fakerepo", self.conf)
-        self.assertNotIn('refs/heads/pullr-108', swift_repo.refs.allkeys())
+        self.assertNotIn("refs/heads/pullr-108", swift_repo.refs.allkeys())
 
     def test_push_annotated_tag(self):
-        def determine_wants(*args):
-            return {"refs/heads/master": local_repo.refs["HEAD"],
-                    "refs/tags/v1.0": local_repo.refs["refs/tags/v1.0"]}
+        def determine_wants(*args, **kwargs):
+            return {
+                "refs/heads/master": local_repo.refs["HEAD"],
+                "refs/tags/v1.0": local_repo.refs["refs/tags/v1.0"],
+            }
+
         local_repo = repo.Repo.init(self.temp_d, mkdir=True)
         # Nothing in the staging area
-        sha = local_repo.do_commit('Test commit', 'fbo@localhost')
+        sha = local_repo.do_commit("Test commit", "fbo@localhost")
         otype, data = local_repo.object_store.get_raw(sha)
         commit = objects.ShaFile.from_raw_string(otype, data)
         tag = objects.Tag()
         tag.tagger = "fbo@localhost"
         tag.message = "Annotated tag"
-        tag.tag_timezone = objects.parse_timezone('-0200')[0]
+        tag.tag_timezone = objects.parse_timezone("-0200")[0]
         tag.tag_time = commit.author_time
         tag.object = (objects.Commit, commit.id)
         tag.name = "v0.1"
         local_repo.object_store.add_object(tag)
-        local_repo.refs['refs/tags/v1.0'] = tag.id
+        local_repo.refs["refs/tags/v1.0"] = tag.id
         swift.SwiftRepo.init_bare(self.scon, self.conf)
-        tcp_client = client.TCPGitClient(self.server_address,
-                                         port=self.port)
-        tcp_client.send_pack(self.fakerepo, determine_wants,
-                             local_repo.generate_pack_data)
+        tcp_client = client.TCPGitClient(self.server_address, port=self.port)
+        tcp_client.send_pack(
+            self.fakerepo, determine_wants, local_repo.generate_pack_data
+        )
         swift_repo = swift.SwiftRepo(self.fakerepo, self.conf)
-        tag_sha = swift_repo.refs.read_loose_ref('refs/tags/v1.0')
+        tag_sha = swift_repo.refs.read_loose_ref("refs/tags/v1.0")
         otype, data = swift_repo.object_store.get_raw(tag_sha)
         rtag = objects.ShaFile.from_raw_string(otype, data)
         self.assertEqual(rtag.object[1], commit.id)
         self.assertEqual(rtag.id, tag.id)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()

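The smoke tests above all follow the same push pattern. Stripped to its core,
and assuming a TCPGitServer (such as the Swift-backed one started by
DulwichServer) is listening on the placeholder host and port, it looks like:

    from dulwich import client, repo

    local_repo = repo.Repo.init("/tmp/demo", mkdir=True)
    local_repo.do_commit(b"Test commit", b"demo@localhost")

    tcp_client = client.TCPGitClient("localhost", port=9148)
    # determine_wants maps remote ref names to the local shas to push.
    tcp_client.send_pack(
        "/fakerepo",
        lambda refs, **kwargs: {b"refs/heads/master": local_repo.refs[b"HEAD"]},
        local_repo.generate_pack_data,
    )
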
+ 74 - 53
dulwich/diff_tree.py

@@ -23,7 +23,7 @@
 from collections import (
     defaultdict,
     namedtuple,
-    )
+)
 
 from io import BytesIO
 from itertools import chain
@@ -32,16 +32,16 @@ import stat
 from dulwich.objects import (
     S_ISGITLINK,
     TreeEntry,
-    )
+)
 
 
 # TreeChange type constants.
-CHANGE_ADD = 'add'
-CHANGE_MODIFY = 'modify'
-CHANGE_DELETE = 'delete'
-CHANGE_RENAME = 'rename'
-CHANGE_COPY = 'copy'
-CHANGE_UNCHANGED = 'unchanged'
+CHANGE_ADD = "add"
+CHANGE_MODIFY = "modify"
+CHANGE_DELETE = "delete"
+CHANGE_RENAME = "rename"
+CHANGE_COPY = "copy"
+CHANGE_UNCHANGED = "unchanged"
 
 RENAME_CHANGE_TYPES = (CHANGE_RENAME, CHANGE_COPY)
 
@@ -53,7 +53,7 @@ MAX_FILES = 200
 REWRITE_THRESHOLD = None
 
 
-class TreeChange(namedtuple('TreeChange', ['type', 'old', 'new'])):
+class TreeChange(namedtuple("TreeChange", ["type", "old", "new"])):
     """Named tuple a single change between two trees."""
 
     @classmethod
@@ -142,7 +142,7 @@ def walk_trees(store, tree1_id, tree2_id, prune_identical=False):
     # case.
     mode1 = tree1_id and stat.S_IFDIR or None
     mode2 = tree2_id and stat.S_IFDIR or None
-    todo = [(TreeEntry(b'', mode1, tree1_id), TreeEntry(b'', mode2, tree2_id))]
+    todo = [(TreeEntry(b"", mode1, tree1_id), TreeEntry(b"", mode2, tree2_id))]
     while todo:
         entry1, entry2 = todo.pop()
         is_tree1 = _is_tree(entry1)
@@ -163,9 +163,15 @@ def _skip_tree(entry, include_trees):
     return entry
 
 
-def tree_changes(store, tree1_id, tree2_id, want_unchanged=False,
-                 rename_detector=None, include_trees=False,
-                 change_type_same=False):
+def tree_changes(
+    store,
+    tree1_id,
+    tree2_id,
+    want_unchanged=False,
+    rename_detector=None,
+    include_trees=False,
+    change_type_same=False,
+):
     """Find the differences between the contents of two trees.
 
     Args:
@@ -182,16 +188,19 @@ def tree_changes(store, tree1_id, tree2_id, want_unchanged=False,
       Iterator over TreeChange instances for each change between the
         source and target tree.
     """
-    if (rename_detector is not None and tree1_id is not None and
-            tree2_id is not None):
+    if rename_detector is not None and tree1_id is not None and tree2_id is not None:
         for change in rename_detector.changes_with_renames(
-                tree1_id, tree2_id, want_unchanged=want_unchanged,
-                include_trees=include_trees):
+            tree1_id,
+            tree2_id,
+            want_unchanged=want_unchanged,
+            include_trees=include_trees,
+        ):
             yield change
         return
 
-    entries = walk_trees(store, tree1_id, tree2_id,
-                         prune_identical=(not want_unchanged))
+    entries = walk_trees(
+        store, tree1_id, tree2_id, prune_identical=(not want_unchanged)
+    )
     for entry1, entry2 in entries:
         if entry1 == entry2 and not want_unchanged:
             continue
@@ -201,8 +210,10 @@ def tree_changes(store, tree1_id, tree2_id, want_unchanged=False,
         entry2 = _skip_tree(entry2, include_trees)
 
         if entry1 != _NULL_ENTRY and entry2 != _NULL_ENTRY:
-            if (stat.S_IFMT(entry1.mode) != stat.S_IFMT(entry2.mode)
-                    and not change_type_same):
+            if (
+                stat.S_IFMT(entry1.mode) != stat.S_IFMT(entry2.mode)
+                and not change_type_same
+            ):
                 # File type changed: report as delete/add.
                 yield TreeChange.delete(entry1)
                 entry1 = _NULL_ENTRY
@@ -232,8 +243,7 @@ def _all_same(seq, key):
     return _all_eq(seq[1:], key, key(seq[0]))
 
 
-def tree_changes_for_merge(store, parent_tree_ids, tree_id,
-                           rename_detector=None):
+def tree_changes_for_merge(store, parent_tree_ids, tree_id, rename_detector=None):
     """Get the tree changes for a merge tree relative to all its parents.
 
     Args:
@@ -254,9 +264,10 @@ def tree_changes_for_merge(store, parent_tree_ids, tree_id,
       in the merge tree is not found in any of the parents, or in the case of
       deletes, if not all of the old SHAs match.
     """
-    all_parent_changes = [tree_changes(store, t, tree_id,
-                                       rename_detector=rename_detector)
-                          for t in parent_tree_ids]
+    all_parent_changes = [
+        tree_changes(store, t, tree_id, rename_detector=rename_detector)
+        for t in parent_tree_ids
+    ]
     num_parents = len(parent_tree_ids)
     changes_by_path = defaultdict(lambda: [None] * num_parents)
 
@@ -315,10 +326,10 @@ def _count_blocks(obj):
     block_getvalue = block.getvalue
 
     for c in chain.from_iterable(obj.as_raw_chunks()):
-        c = c.to_bytes(1, 'big')
+        c = c.to_bytes(1, "big")
         block_write(c)
         n += 1
-        if c == b'\n' or n == _BLOCK_SIZE:
+        if c == b"\n" or n == _BLOCK_SIZE:
             value = block_getvalue()
             block_counts[hash(value)] += len(value)
             block_seek(0)
@@ -392,10 +403,14 @@ def _tree_change_key(entry):
 class RenameDetector(object):
     """Object for handling rename detection between two trees."""
 
-    def __init__(self, store, rename_threshold=RENAME_THRESHOLD,
-                 max_files=MAX_FILES,
-                 rewrite_threshold=REWRITE_THRESHOLD,
-                 find_copies_harder=False):
+    def __init__(
+        self,
+        store,
+        rename_threshold=RENAME_THRESHOLD,
+        max_files=MAX_FILES,
+        rewrite_threshold=REWRITE_THRESHOLD,
+        find_copies_harder=False,
+    ):
         """Initialize the rename detector.
 
         Args:
@@ -426,8 +441,11 @@ class RenameDetector(object):
         self._changes = []
 
     def _should_split(self, change):
-        if (self._rewrite_threshold is None or change.type != CHANGE_MODIFY or
-                change.old.sha == change.new.sha):
+        if (
+            self._rewrite_threshold is None
+            or change.type != CHANGE_MODIFY
+            or change.old.sha == change.new.sha
+        ):
             return False
         old_obj = self._store[change.old.sha]
         new_obj = self._store[change.new.sha]
@@ -441,8 +459,9 @@ class RenameDetector(object):
         elif self._should_split(change):
             self._deletes.append(TreeChange.delete(change.old))
             self._adds.append(TreeChange.add(change.new))
-        elif ((self._find_copies_harder and change.type == CHANGE_UNCHANGED)
-              or change.type == CHANGE_MODIFY):
+        elif (
+            self._find_copies_harder and change.type == CHANGE_UNCHANGED
+        ) or change.type == CHANGE_MODIFY:
             # Treat all modifies as potential deletes for rename detection,
             # but don't split them (to avoid spurious renames). Setting
             # find_copies_harder means we treat unchanged the same as
@@ -453,15 +472,18 @@ class RenameDetector(object):
 
     def _collect_changes(self, tree1_id, tree2_id):
         want_unchanged = self._find_copies_harder or self._want_unchanged
-        for change in tree_changes(self._store, tree1_id, tree2_id,
-                                   want_unchanged=want_unchanged,
-                                   include_trees=self._include_trees):
+        for change in tree_changes(
+            self._store,
+            tree1_id,
+            tree2_id,
+            want_unchanged=want_unchanged,
+            include_trees=self._include_trees,
+        ):
             self._add_change(change)
 
     def _prune(self, add_paths, delete_paths):
         self._adds = [a for a in self._adds if a.new.path not in add_paths]
-        self._deletes = [d for d in self._deletes
-                         if d.old.path not in delete_paths]
+        self._deletes = [d for d in self._deletes if d.old.path not in delete_paths]
 
     def _find_exact_renames(self):
         add_map = defaultdict(list)
@@ -534,8 +556,7 @@ class RenameDetector(object):
                 if stat.S_IFMT(delete.old.mode) != stat.S_IFMT(add.new.mode):
                     continue
                 new_obj = self._store[add.new.sha]
-                score = _similarity_score(old_obj, new_obj,
-                                          block_cache=block_cache)
+                score = _similarity_score(old_obj, new_obj, block_cache=block_cache)
                 if score > self._rename_threshold:
                     new_type = self._rename_type(check_paths, delete, add)
                     rename = TreeChange(new_type, delete.old, add.new)
@@ -570,17 +591,17 @@ class RenameDetector(object):
             return
 
         modifies = {}
-        delete_map = dict((d.old.path, d) for d in self._deletes)
+        delete_map = {d.old.path: d for d in self._deletes}
         for add in self._adds:
             path = add.new.path
             delete = delete_map.get(path)
-            if (delete is not None and
-                    stat.S_IFMT(delete.old.mode) == stat.S_IFMT(add.new.mode)):
+            if delete is not None and stat.S_IFMT(delete.old.mode) == stat.S_IFMT(
+                add.new.mode
+            ):
                 modifies[path] = TreeChange(CHANGE_MODIFY, delete.old, add.new)
 
         self._adds = [a for a in self._adds if a.new.path not in modifies]
-        self._deletes = [a for a in self._deletes if a.new.path not in
-                         modifies]
+        self._deletes = [a for a in self._deletes if a.new.path not in modifies]
         self._changes += modifies.values()
 
     def _sorted_changes(self):
@@ -594,11 +615,11 @@ class RenameDetector(object):
     def _prune_unchanged(self):
         if self._want_unchanged:
             return
-        self._deletes = [
-            d for d in self._deletes if d.type != CHANGE_UNCHANGED]
+        self._deletes = [d for d in self._deletes if d.type != CHANGE_UNCHANGED]
 
-    def changes_with_renames(self, tree1_id, tree2_id, want_unchanged=False,
-                             include_trees=False):
+    def changes_with_renames(
+        self, tree1_id, tree2_id, want_unchanged=False, include_trees=False
+    ):
         """Iterate TreeChanges between two tree SHAs, with rename detection."""
         self._reset()
         self._want_unchanged = want_unchanged
@@ -622,6 +643,6 @@ try:
         _is_tree,
         _merge_entries,
         _count_blocks,
-        )
+    )
 except ImportError:
     pass

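As a quick orientation to the API reformatted above, a minimal sketch of
rename-aware diffing between the two most recent commits of a repository;
the path is a placeholder and at least two commits are assumed to exist:

    from dulwich.diff_tree import RenameDetector, tree_changes
    from dulwich.repo import Repo

    r = Repo("/path/to/repo")
    # The walker yields newest-first, so the first entry is the child commit.
    new, old = [entry.commit for entry in r.get_walker(max_entries=2)]
    detector = RenameDetector(r.object_store)
    for change in tree_changes(
        r.object_store, old.tree, new.tree, rename_detector=detector
    ):
        # change.old/.new are TreeEntry tuples; paths are None for add/delete.
        print(change.type, change.old.path, change.new.path)
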
+ 22 - 19
dulwich/errors.py

@@ -37,12 +37,14 @@ class ChecksumMismatch(Exception):
         self.extra = extra
         if self.extra is None:
             Exception.__init__(
-                self, "Checksum mismatch: Expected %s, got %s" %
-                (expected, got))
+                self,
+                "Checksum mismatch: Expected %s, got %s" % (expected, got),
+            )
         else:
             Exception.__init__(
-                self, "Checksum mismatch: Expected %s, got %s; %s" %
-                (expected, got, extra))
+                self,
+                "Checksum mismatch: Expected %s, got %s; %s" % (expected, got, extra),
+            )
 
 
 class WrongObjectException(Exception):
@@ -61,25 +63,25 @@ class WrongObjectException(Exception):
 class NotCommitError(WrongObjectException):
     """Indicates that the sha requested does not point to a commit."""
 
-    type_name = 'commit'
+    type_name = "commit"
 
 
 class NotTreeError(WrongObjectException):
     """Indicates that the sha requested does not point to a tree."""
 
-    type_name = 'tree'
+    type_name = "tree"
 
 
 class NotTagError(WrongObjectException):
     """Indicates that the sha requested does not point to a tag."""
 
-    type_name = 'tag'
+    type_name = "tag"
 
 
 class NotBlobError(WrongObjectException):
     """Indicates that the sha requested does not point to a blob."""
 
-    type_name = 'blob'
+    type_name = "blob"
 
 
 class MissingCommitError(Exception):
@@ -132,7 +134,7 @@ class UpdateRefsError(GitProtocolError):
     """The server reported errors updating refs."""
 
     def __init__(self, *args, **kwargs):
-        self.ref_status = kwargs.pop('ref_status')
+        self.ref_status = kwargs.pop("ref_status")
         super(UpdateRefsError, self).__init__(*args, **kwargs)
 
 
@@ -142,18 +144,18 @@ class HangupException(GitProtocolError):
     def __init__(self, stderr_lines=None):
         if stderr_lines:
             super(HangupException, self).__init__(
-                '\n'.join(
-                    [line.decode('utf-8', 'surrogateescape')
-                     for line in stderr_lines]))
+                "\n".join(
+                    [line.decode("utf-8", "surrogateescape") for line in stderr_lines]
+                )
+            )
         else:
             super(HangupException, self).__init__(
-                "The remote server unexpectedly closed the connection.")
+                "The remote server unexpectedly closed the connection."
+            )
         self.stderr_lines = stderr_lines
 
     def __eq__(self, other):
-        return (
-            isinstance(self, type(other)) and
-            self.stderr_lines == other.stderr_lines)
+        return isinstance(self, type(other)) and self.stderr_lines == other.stderr_lines
 
 
 class UnexpectedCommandError(GitProtocolError):
@@ -161,11 +163,12 @@ class UnexpectedCommandError(GitProtocolError):
 
     def __init__(self, command):
         if command is None:
-            command = 'flush-pkt'
+            command = "flush-pkt"
         else:
-            command = 'command %s' % command
+            command = "command %s" % command
         super(UnexpectedCommandError, self).__init__(
-            'Protocol got unexpected %s' % command)
+            "Protocol got unexpected %s" % command
+        )
 
 
 class FileFormatException(Exception):

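For reference, a small sketch of how the reformatted exception above is typically raised and inspected; the hex digests are placeholders:

    from dulwich.errors import ChecksumMismatch

    try:
        raise ChecksumMismatch("1" * 40, "2" * 40, extra="while reading pack")
    except ChecksumMismatch as e:
        # expected, got and extra stay available as attributes for callers.
        print(e.expected, e.got, e.extra)
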
+ 39 - 24
dulwich/fastexport.py

@@ -23,19 +23,19 @@
 
 from dulwich.index import (
     commit_tree,
-    )
+)
 from dulwich.objects import (
     Blob,
     Commit,
     Tag,
     ZERO_SHA,
-    )
+)
 from fastimport import (  # noqa: E402
     commands,
     errors as fastimport_errors,
     parser,
     processor,
-    )
+)
 
 import stat  # noqa: E402
 
@@ -59,7 +59,7 @@ class GitFastExporter(object):
 
     def _allocate_marker(self):
         self._marker_idx += 1
-        return ("%d" % (self._marker_idx,)).encode('ascii')
+        return ("%d" % (self._marker_idx,)).encode("ascii")
 
     def _export_blob(self, blob):
         marker = self._allocate_marker()
@@ -72,9 +72,11 @@ class GitFastExporter(object):
         return marker
 
     def _iter_files(self, base_tree, new_tree):
-        for ((old_path, new_path), (old_mode, new_mode),
-             (old_hexsha, new_hexsha)) in \
-                self.store.tree_changes(base_tree, new_tree):
+        for (
+            (old_path, new_path),
+            (old_mode, new_mode),
+            (old_hexsha, new_hexsha),
+        ) in self.store.tree_changes(base_tree, new_tree):
             if new_path is None:
                 yield commands.FileDeleteCommand(old_path)
                 continue
@@ -84,7 +86,7 @@ class GitFastExporter(object):
             if old_path != new_path and old_path is not None:
                 yield commands.FileRenameCommand(old_path, new_path)
             if old_mode != new_mode or old_hexsha != new_hexsha:
-                prefixed_marker = b':' + marker
+                prefixed_marker = b":" + marker
                 yield commands.FileModifyCommand(
                     new_path, new_mode, prefixed_marker, None
                 )
@@ -101,11 +103,20 @@ class GitFastExporter(object):
         author, author_email = split_email(commit.author)
         committer, committer_email = split_email(commit.committer)
         cmd = commands.CommitCommand(
-            ref, marker,
+            ref,
+            marker,
             (author, author_email, commit.author_time, commit.author_timezone),
-            (committer, committer_email, commit.commit_time,
-                commit.commit_timezone),
-            commit.message, from_, merges, file_cmds)
+            (
+                committer,
+                committer_email,
+                commit.commit_time,
+                commit.commit_timezone,
+            ),
+            commit.message,
+            from_,
+            merges,
+            file_cmds,
+        )
         return (cmd, marker)
 
     def emit_commit(self, commit, ref, base_tree=None):
@@ -115,9 +126,8 @@ class GitFastExporter(object):
 
 
 class GitImportProcessor(processor.ImportProcessor):
-    """An import processor that imports into a Git repository using Dulwich.
+    """An import processor that imports into a Git repository using Dulwich."""
 
-    """
     # FIXME: Batch creation of objects?
 
     def __init__(self, repo, params=None, verbose=False, outf=None):
@@ -156,8 +166,12 @@ class GitImportProcessor(processor.ImportProcessor):
         else:
             author = cmd.committer
         (author_name, author_email, author_timestamp, author_timezone) = author
-        (committer_name, committer_email, commit_timestamp,
-            commit_timezone) = cmd.committer
+        (
+            committer_name,
+            committer_email,
+            commit_timestamp,
+            commit_timezone,
+        ) = cmd.committer
         commit.author = author_name + b" <" + author_email + b">"
         commit.author_timezone = author_timezone
         commit.author_time = int(author_timestamp)
@@ -181,11 +195,9 @@ class GitImportProcessor(processor.ImportProcessor):
             elif filecmd.name == b"filedelete":
                 del self._contents[filecmd.path]
             elif filecmd.name == b"filecopy":
-                self._contents[filecmd.dest_path] = self._contents[
-                    filecmd.src_path]
+                self._contents[filecmd.dest_path] = self._contents[filecmd.src_path]
             elif filecmd.name == b"filerename":
-                self._contents[filecmd.new_path] = self._contents[
-                    filecmd.old_path]
+                self._contents[filecmd.new_path] = self._contents[filecmd.old_path]
                 del self._contents[filecmd.old_path]
             elif filecmd.name == b"filedeleteall":
                 self._contents = {}
@@ -193,8 +205,8 @@ class GitImportProcessor(processor.ImportProcessor):
                 raise Exception("Command %s not supported" % filecmd.name)
         commit.tree = commit_tree(
             self.repo.object_store,
-            ((path, hexsha, mode) for (path, (mode, hexsha)) in
-                self._contents.items()))
+            ((path, hexsha, mode) for (path, (mode, hexsha)) in self._contents.items()),
+        )
         if self.last_commit != ZERO_SHA:
             commit.parents.append(self.last_commit)
         for merge in cmd.merges:
@@ -216,8 +228,11 @@ class GitImportProcessor(processor.ImportProcessor):
         self.last_commit = commit_id
         if commit_id != ZERO_SHA:
             tree_id = self.repo[commit_id].tree
-            for (path, mode, hexsha) in (
-                    self.repo.object_store.iter_tree_contents(tree_id)):
+            for (
+                path,
+                mode,
+                hexsha,
+            ) in self.repo.object_store.iter_tree_contents(tree_id):
                 self._contents[path] = (mode, hexsha)
 
     def reset_handler(self, cmd):

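A rough sketch of driving GitFastExporter as reformatted above; it assumes the third-party fastimport package is installed, and the repository path and branch ref are placeholders:

    from io import BytesIO

    from dulwich.fastexport import GitFastExporter
    from dulwich.repo import Repo

    repo = Repo("/path/to/repo")  # placeholder path
    out = BytesIO()
    exporter = GitFastExporter(out, repo.object_store)
    # Emits blob commands for the tree, then a commit command for HEAD.
    exporter.emit_commit(repo[repo.head()], b"refs/heads/master")
    print(out.getvalue().decode("utf-8", "replace"))
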
+ 41 - 20
dulwich/file.py

@@ -44,6 +44,7 @@ def _fancy_rename(oldname, newname):
 
     # Defer the tempfile import since it pulls in a lot of other things.
     import tempfile
+
     # destination file exists
     try:
         (fd, tmpfile) = tempfile.mkstemp(".tmp", prefix=oldname, dir=".")
@@ -56,7 +57,7 @@ def _fancy_rename(oldname, newname):
     try:
         os.rename(newname, tmpfile)
     except OSError:
-        raise   # no rename occurred
+        raise  # no rename occurred
     try:
         os.rename(oldname, newname)
     except OSError:
@@ -65,7 +66,7 @@ def _fancy_rename(oldname, newname):
     os.remove(tmpfile)
 
 
-def GitFile(filename, mode='rb', bufsize=-1):
+def GitFile(filename, mode="rb", bufsize=-1):
     """Create a file object that obeys the git file locking protocol.
 
     Returns: a builtin file object or a _GitFile object
@@ -77,13 +78,13 @@ def GitFile(filename, mode='rb', bufsize=-1):
     the fact that opening a file for write does not actually open the file you
     request.
     """
-    if 'a' in mode:
-        raise IOError('append mode not supported for Git files')
-    if '+' in mode:
-        raise IOError('read/write mode not supported for Git files')
-    if 'b' not in mode:
-        raise IOError('text mode not supported for Git files')
-    if 'w' in mode:
+    if "a" in mode:
+        raise IOError("append mode not supported for Git files")
+    if "+" in mode:
+        raise IOError("read/write mode not supported for Git files")
+    if "b" not in mode:
+        raise IOError("text mode not supported for Git files")
+    if "w" in mode:
         return _GitFile(filename, mode, bufsize)
     else:
         return io.open(filename, mode, bufsize)
@@ -109,23 +110,43 @@ class _GitFile(object):
         released. Typically this will happen in a finally block.
     """
 
-    PROXY_PROPERTIES = set(['closed', 'encoding', 'errors', 'mode', 'name',
-                            'newlines', 'softspace'])
-    PROXY_METHODS = ('__iter__', 'flush', 'fileno', 'isatty', 'read',
-                     'readline', 'readlines', 'seek', 'tell',
-                     'truncate', 'write', 'writelines')
+    PROXY_PROPERTIES = set(
+        [
+            "closed",
+            "encoding",
+            "errors",
+            "mode",
+            "name",
+            "newlines",
+            "softspace",
+        ]
+    )
+    PROXY_METHODS = (
+        "__iter__",
+        "flush",
+        "fileno",
+        "isatty",
+        "read",
+        "readline",
+        "readlines",
+        "seek",
+        "tell",
+        "truncate",
+        "write",
+        "writelines",
+    )
 
     def __init__(self, filename, mode, bufsize):
         self._filename = filename
         if isinstance(self._filename, bytes):
-            self._lockfilename = self._filename + b'.lock'
+            self._lockfilename = self._filename + b".lock"
         else:
-            self._lockfilename = self._filename + '.lock'
+            self._lockfilename = self._filename + ".lock"
         try:
             fd = os.open(
                 self._lockfilename,
-                os.O_RDWR | os.O_CREAT | os.O_EXCL |
-                getattr(os, "O_BINARY", 0))
+                os.O_RDWR | os.O_CREAT | os.O_EXCL | getattr(os, "O_BINARY", 0),
+            )
         except FileExistsError:
             raise FileLocked(filename, self._lockfilename)
         self._file = os.fdopen(fd, mode, bufsize)
@@ -166,10 +187,10 @@ class _GitFile(object):
         os.fsync(self._file.fileno())
         self._file.close()
         try:
-            if getattr(os, 'replace', None) is not None:
+            if getattr(os, "replace", None) is not None:
                 os.replace(self._lockfilename, self._filename)
             else:
-                if sys.platform != 'win32':
+                if sys.platform != "win32":
                     os.rename(self._lockfilename, self._filename)
                 else:
                     # Windows versions prior to Vista don't support atomic

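A minimal sketch of the locking protocol GitFile implements for writes: all output goes to "<name>.lock", which only replaces the real file on close(); the filename here is a placeholder:

    from dulwich.file import GitFile

    f = GitFile("config", "wb")  # takes "config.lock" exclusively
    try:
        f.write(b"[core]\n\trepositoryformatversion = 0\n")
    except BaseException:
        f.abort()  # remove the lock file, leave "config" untouched
        raise
    else:
        f.close()  # atomically rename "config.lock" over "config"
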
+ 2 - 2
dulwich/graph.py

@@ -33,8 +33,8 @@ def _find_lcas(lookup_parents, c1, c2s):
     # Flags to Record State
     _ANC_OF_1 = 1  # ancestor of commit 1
     _ANC_OF_2 = 2  # ancestor of commit 2
-    _DNC = 4       # Do Not Consider
-    _LCA = 8       # potential LCA
+    _DNC = 4  # Do Not Consider
+    _LCA = 8  # potential LCA
 
     def _has_candidates(wlst, cstates):
         for cmt in wlst:

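The realigned comments above describe plain bit flags; a commit's state is the bitwise OR of whatever has reached it, for example:

    _ANC_OF_1 = 1  # ancestor of commit 1
    _ANC_OF_2 = 2  # ancestor of commit 2

    state = _ANC_OF_1 | _ANC_OF_2
    if state & _ANC_OF_1 and state & _ANC_OF_2:
        print("reachable from both sides: potential LCA")
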
+ 23 - 19
dulwich/greenthreads.py

@@ -28,16 +28,15 @@ from gevent import pool
 from dulwich.objects import (
     Commit,
     Tag,
-    )
+)
 from dulwich.object_store import (
     MissingObjectFinder,
     _collect_filetree_revs,
     ObjectStoreIterator,
-    )
+)
 
 
-def _split_commits_and_tags(obj_store, lst,
-                            ignore_unknown=False, pool=None):
+def _split_commits_and_tags(obj_store, lst, ignore_unknown=False, pool=None):
     """Split object id list into two list with commit SHA1s and tag SHA1s.
 
     Same implementation as object_store._split_commits_and_tags
@@ -59,7 +58,8 @@ def _split_commits_and_tags(obj_store, lst,
                 tags.add(sha)
                 commits.add(o.object[1])
             else:
-                raise KeyError('Not a commit or a tag: %s' % sha)
+                raise KeyError("Not a commit or a tag: %s" % sha)
+
     jobs = [pool.spawn(find_commit_type, s) for s in lst]
     gevent.joinall(jobs)
     return (commits, tags)
@@ -71,10 +71,17 @@ class GreenThreadsMissingObjectFinder(MissingObjectFinder):
     Same implementation as object_store.MissingObjectFinder
     except we use gevent to parallelize object retrieval.
     """
-    def __init__(self, object_store, haves, wants,
-                 progress=None, get_tagged=None,
-                 concurrency=1, get_parents=None):
 
+    def __init__(
+        self,
+        object_store,
+        haves,
+        wants,
+        progress=None,
+        get_tagged=None,
+        concurrency=1,
+        get_parents=None,
+    ):
         def collect_tree_sha(sha):
             self.sha_done.add(sha)
             cmt = object_store[sha]
@@ -83,15 +90,12 @@ class GreenThreadsMissingObjectFinder(MissingObjectFinder):
         self.object_store = object_store
         p = pool.Pool(size=concurrency)
 
-        have_commits, have_tags = \
-            _split_commits_and_tags(object_store, haves,
-                                    True, p)
-        want_commits, want_tags = \
-            _split_commits_and_tags(object_store, wants,
-                                    False, p)
+        have_commits, have_tags = _split_commits_and_tags(object_store, haves, True, p)
+        want_commits, want_tags = _split_commits_and_tags(object_store, wants, False, p)
         all_ancestors = object_store._collect_ancestors(have_commits)[0]
-        missing_commits, common_commits = \
-            object_store._collect_ancestors(want_commits, all_ancestors)
+        missing_commits, common_commits = object_store._collect_ancestors(
+            want_commits, all_ancestors
+        )
 
         self.sha_done = set()
         jobs = [p.spawn(collect_tree_sha, c) for c in common_commits]
@@ -114,6 +118,7 @@ class GreenThreadsObjectStoreIterator(ObjectStoreIterator):
     Same implementation as object_store.ObjectStoreIterator
     except we use gevent to parallelize object retrieval.
     """
+
     def __init__(self, store, shas, finder, concurrency=1):
         self.finder = finder
         self.p = pool.Pool(size=concurrency)
@@ -124,14 +129,13 @@ class GreenThreadsObjectStoreIterator(ObjectStoreIterator):
         return self.store[sha], path
 
     def __iter__(self):
-        for sha, path in self.p.imap_unordered(self.retrieve,
-                                               self.itershas()):
+        for sha, path in self.p.imap_unordered(self.retrieve, self.itershas()):
             yield sha, path
 
     def __len__(self):
         if len(self._shas) > 0:
             return len(self._shas)
-        while len(self.finder.objects_to_send):
+        while self.finder.objects_to_send:
             jobs = []
             for _ in range(0, len(self.finder.objects_to_send)):
                 jobs.append(self.p.spawn(self.finder.next))

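A hedged sketch of the private helper reformatted above; it needs gevent installed, and the repository path is a placeholder:

    from gevent import pool

    from dulwich.greenthreads import _split_commits_and_tags
    from dulwich.repo import Repo

    repo = Repo("/path/to/repo")  # placeholder path
    p = pool.Pool(size=4)
    # Classifies each id concurrently; non-commit, non-tag ids raise KeyError.
    commits, tags = _split_commits_and_tags(
        repo.object_store, [repo.head()], ignore_unknown=False, pool=p
    )
    print(len(commits), "commits,", len(tags), "tags")
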
+ 38 - 28
dulwich/hooks.py

@@ -52,9 +52,15 @@ class ShellHook(Hook):
     [0] http://www.kernel.org/pub/software/scm/git/docs/githooks.html
     """
 
-    def __init__(self, name, path, numparam,
-                 pre_exec_callback=None, post_exec_callback=None,
-                 cwd=None):
+    def __init__(
+        self,
+        name,
+        path,
+        numparam,
+        pre_exec_callback=None,
+        post_exec_callback=None,
+        cwd=None,
+    ):
         """Setup shell hook definition
 
         Args:
@@ -84,24 +90,27 @@ class ShellHook(Hook):
         """Execute the hook with given args"""
 
         if len(args) != self.numparam:
-            raise HookError("Hook %s executed with wrong number of args. \
+            raise HookError(
+                "Hook %s executed with wrong number of args. \
                             Expected %d. Saw %d. args: %s"
-                            % (self.name, self.numparam, len(args), args))
+                % (self.name, self.numparam, len(args), args)
+            )
 
-        if (self.pre_exec_callback is not None):
+        if self.pre_exec_callback is not None:
             args = self.pre_exec_callback(*args)
 
         try:
             ret = subprocess.call([self.filepath] + list(args), cwd=self.cwd)
             if ret != 0:
-                if (self.post_exec_callback is not None):
+                if self.post_exec_callback is not None:
                     self.post_exec_callback(0, *args)
-                raise HookError("Hook %s exited with non-zero status %d"
-                                % (self.name, ret))
-            if (self.post_exec_callback is not None):
+                raise HookError(
+                    "Hook %s exited with non-zero status %d" % (self.name, ret)
+                )
+            if self.post_exec_callback is not None:
                 return self.post_exec_callback(1, *args)
         except OSError:  # no file. silent failure.
-            if (self.post_exec_callback is not None):
+            if self.post_exec_callback is not None:
                 self.post_exec_callback(0, *args)
 
 
@@ -109,18 +118,18 @@ class PreCommitShellHook(ShellHook):
     """pre-commit shell hook"""
 
     def __init__(self, controldir):
-        filepath = os.path.join(controldir, 'hooks', 'pre-commit')
+        filepath = os.path.join(controldir, "hooks", "pre-commit")
 
-        ShellHook.__init__(self, 'pre-commit', filepath, 0, cwd=controldir)
+        ShellHook.__init__(self, "pre-commit", filepath, 0, cwd=controldir)
 
 
 class PostCommitShellHook(ShellHook):
     """post-commit shell hook"""
 
     def __init__(self, controldir):
-        filepath = os.path.join(controldir, 'hooks', 'post-commit')
+        filepath = os.path.join(controldir, "hooks", "post-commit")
 
-        ShellHook.__init__(self, 'post-commit', filepath, 0, cwd=controldir)
+        ShellHook.__init__(self, "post-commit", filepath, 0, cwd=controldir)
 
 
 class CommitMsgShellHook(ShellHook):
@@ -133,27 +142,29 @@ class CommitMsgShellHook(ShellHook):
     """
 
     def __init__(self, controldir):
-        filepath = os.path.join(controldir, 'hooks', 'commit-msg')
+        filepath = os.path.join(controldir, "hooks", "commit-msg")
 
         def prepare_msg(*args):
             import tempfile
+
             (fd, path) = tempfile.mkstemp()
 
-            with os.fdopen(fd, 'wb') as f:
+            with os.fdopen(fd, "wb") as f:
                 f.write(args[0])
 
             return (path,)
 
         def clean_msg(success, *args):
             if success:
-                with open(args[0], 'rb') as f:
+                with open(args[0], "rb") as f:
                     new_msg = f.read()
                 os.unlink(args[0])
                 return new_msg
             os.unlink(args[0])
 
-        ShellHook.__init__(self, 'commit-msg', filepath, 1,
-                           prepare_msg, clean_msg, controldir)
+        ShellHook.__init__(
+            self, "commit-msg", filepath, 1, prepare_msg, clean_msg, controldir
+        )
 
 
 class PostReceiveShellHook(ShellHook):
@@ -161,8 +172,8 @@ class PostReceiveShellHook(ShellHook):
 
     def __init__(self, controldir):
         self.controldir = controldir
-        filepath = os.path.join(controldir, 'hooks', 'post-receive')
-        ShellHook.__init__(self, 'post-receive', filepath, 0)
+        filepath = os.path.join(controldir, "hooks", "post-receive")
+        ShellHook.__init__(self, "post-receive", filepath, 0)
 
     def execute(self, client_refs):
         # do nothing if the script doesn't exist
@@ -171,26 +182,25 @@ class PostReceiveShellHook(ShellHook):
 
         try:
             env = os.environ.copy()
-            env['GIT_DIR'] = self.controldir
+            env["GIT_DIR"] = self.controldir
 
             p = subprocess.Popen(
                 self.filepath,
                 stdin=subprocess.PIPE,
                 stdout=subprocess.PIPE,
                 stderr=subprocess.PIPE,
-                env=env
+                env=env,
             )
 
             # client_refs is a list of (oldsha, newsha, ref)
-            in_data = '\n'.join([' '.join(ref) for ref in client_refs])
+            in_data = "\n".join([" ".join(ref) for ref in client_refs])
 
             out_data, err_data = p.communicate(in_data)
 
             if (p.returncode != 0) or err_data:
-                err_fmt = "post-receive exit code: %d\n" \
-                    + "stdout:\n%s\nstderr:\n%s"
+                err_fmt = "post-receive exit code: %d\n" + "stdout:\n%s\nstderr:\n%s"
                 err_msg = err_fmt % (p.returncode, out_data, err_data)
-                raise HookError(err_msg)
+                raise HookError(err_msg.decode('utf-8', 'backslashreplace'))
             return out_data
         except OSError as err:
             raise HookError(repr(err))

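A small sketch of the callback plumbing ShellHook exposes, mirroring how CommitMsgShellHook uses it above; the hook name and paths are hypothetical:

    from dulwich.hooks import ShellHook

    def pre(*args):
        # May rewrite the argument list before the script runs.
        return args

    def post(success, *args):
        # Receives 1 when the script succeeded, 0 otherwise.
        pass

    hook = ShellHook(
        "example-hook",                           # hypothetical name
        "/path/to/repo/.git/hooks/example-hook",  # placeholder path
        numparam=0,
        pre_exec_callback=pre,
        post_exec_callback=post,
        cwd="/path/to/repo",
    )
    hook.execute()  # a missing script file fails silently, as above
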
+ 86 - 85
dulwich/ignore.py

@@ -32,7 +32,7 @@ from typing import (
     TYPE_CHECKING,
     Dict,
     Union,
-    )
+)
 
 if TYPE_CHECKING:
     from dulwich.repo import Repo
@@ -42,34 +42,34 @@ from dulwich.config import get_xdg_config_home_path, Config
 
 def _translate_segment(segment: bytes) -> bytes:
     if segment == b"*":
-        return b'[^/]+'
+        return b"[^/]+"
     res = b""
     i, n = 0, len(segment)
     while i < n:
-        c = segment[i:i+1]
-        i = i+1
-        if c == b'*':
-            res += b'[^/]*'
-        elif c == b'?':
-            res += b'[^/]'
-        elif c == b'[':
+        c = segment[i : i + 1]
+        i = i + 1
+        if c == b"*":
+            res += b"[^/]*"
+        elif c == b"?":
+            res += b"[^/]"
+        elif c == b"[":
             j = i
-            if j < n and segment[j:j+1] == b'!':
-                j = j+1
-            if j < n and segment[j:j+1] == b']':
-                j = j+1
-            while j < n and segment[j:j+1] != b']':
-                j = j+1
+            if j < n and segment[j : j + 1] == b"!":
+                j = j + 1
+            if j < n and segment[j : j + 1] == b"]":
+                j = j + 1
+            while j < n and segment[j : j + 1] != b"]":
+                j = j + 1
             if j >= n:
-                res += b'\\['
+                res += b"\\["
             else:
-                stuff = segment[i:j].replace(b'\\', b'\\\\')
-                i = j+1
-                if stuff.startswith(b'!'):
-                    stuff = b'^' + stuff[1:]
-                elif stuff.startswith(b'^'):
-                    stuff = b'\\' + stuff
-                res += b'[' + stuff + b']'
+                stuff = segment[i:j].replace(b"\\", b"\\\\")
+                i = j + 1
+                if stuff.startswith(b"!"):
+                    stuff = b"^" + stuff[1:]
+                elif stuff.startswith(b"^"):
+                    stuff = b"\\" + stuff
+                res += b"[" + stuff + b"]"
         else:
             res += re.escape(c)
     return res
@@ -84,32 +84,31 @@ def translate(pat: bytes) -> bytes:
     to cope with features in Git ignore patterns.
     """
 
-    res = b'(?ms)'
+    res = b"(?ms)"
 
-    if b'/' not in pat[:-1]:
+    if b"/" not in pat[:-1]:
         # If there's no slash, this is a filename-based match
-        res += b'(.*/)?'
+        res += b"(.*/)?"
 
-    if pat.startswith(b'**/'):
+    if pat.startswith(b"**/"):
         # Leading **/
         pat = pat[2:]
-        res += b'(.*/)?'
+        res += b"(.*/)?"
 
-    if pat.startswith(b'/'):
+    if pat.startswith(b"/"):
         pat = pat[1:]
 
-    for i, segment in enumerate(pat.split(b'/')):
-        if segment == b'**':
-            res += b'(/.*)?'
+    for i, segment in enumerate(pat.split(b"/")):
+        if segment == b"**":
+            res += b"(/.*)?"
             continue
         else:
-            res += ((re.escape(b'/') if i > 0 else b'') +
-                    _translate_segment(segment))
+            res += (re.escape(b"/") if i > 0 else b"") + _translate_segment(segment)
 
-    if not pat.endswith(b'/'):
-        res += b'/?'
+    if not pat.endswith(b"/"):
+        res += b"/?"
 
-    return res + b'\\Z'
+    return res + b"\\Z"
 
 
 def read_ignore_patterns(f: BinaryIO) -> Iterable[bytes]:
@@ -127,20 +126,19 @@ def read_ignore_patterns(f: BinaryIO) -> Iterable[bytes]:
         if not line:
             continue
 
-        if line.startswith(b'#'):
+        if line.startswith(b"#"):
             # Comment
             continue
 
         # Trailing spaces are ignored unless they are quoted with a backslash.
-        while line.endswith(b' ') and not line.endswith(b'\\ '):
+        while line.endswith(b" ") and not line.endswith(b"\\ "):
             line = line[:-1]
-        line = line.replace(b'\\ ', b' ')
+        line = line.replace(b"\\ ", b" ")
 
         yield line
 
 
-def match_pattern(
-        path: bytes, pattern: bytes, ignorecase: bool = False) -> bool:
+def match_pattern(path: bytes, pattern: bytes, ignorecase: bool = False) -> bool:
     """Match a gitignore-style pattern against a path.
 
     Args:
@@ -159,11 +157,11 @@ class Pattern(object):
     def __init__(self, pattern: bytes, ignorecase: bool = False):
         self.pattern = pattern
         self.ignorecase = ignorecase
-        if pattern[0:1] == b'!':
+        if pattern[0:1] == b"!":
             self.is_exclude = False
             pattern = pattern[1:]
         else:
-            if pattern[0:1] == b'\\':
+            if pattern[0:1] == b"\\":
                 pattern = pattern[1:]
             self.is_exclude = True
         flags = 0
@@ -178,13 +176,18 @@ class Pattern(object):
         return os.fsdecode(self.pattern)
 
     def __eq__(self, other: object) -> bool:
-        return (isinstance(other, type(self)) and
-                self.pattern == other.pattern and
-                self.ignorecase == other.ignorecase)
+        return (
+            isinstance(other, type(self))
+            and self.pattern == other.pattern
+            and self.ignorecase == other.ignorecase
+        )
 
     def __repr__(self) -> str:
         return "%s(%r, %r)" % (
-            type(self).__name__, self.pattern, self.ignorecase)
+            type(self).__name__,
+            self.pattern,
+            self.ignorecase,
+        )
 
     def match(self, path: bytes) -> bool:
         """Try to match a path against this ignore pattern.
@@ -197,9 +200,7 @@ class Pattern(object):
 
 
 class IgnoreFilter(object):
-
-    def __init__(self, patterns: Iterable[bytes], ignorecase: bool = False,
-                 path=None):
+    def __init__(self, patterns: Iterable[bytes], ignorecase: bool = False, path=None):
         self._patterns = []  # type: List[Pattern]
         self._ignorecase = ignorecase
         self._path = path
@@ -238,15 +239,14 @@ class IgnoreFilter(object):
         return status
 
     @classmethod
-    def from_path(cls, path, ignorecase: bool = False) -> 'IgnoreFilter':
-        with open(path, 'rb') as f:
+    def from_path(cls, path, ignorecase: bool = False) -> "IgnoreFilter":
+        with open(path, "rb") as f:
             return cls(read_ignore_patterns(f), ignorecase, path=path)
 
     def __repr__(self) -> str:
-        path = getattr(self, '_path', None)
+        path = getattr(self, "_path", None)
         if path is not None:
-            return "%s.from_path(%r)" % (
-                type(self).__name__, path)
+            return "%s.from_path(%r)" % (type(self).__name__, path)
         else:
             return "<%s>" % (type(self).__name__)
 
@@ -283,19 +283,22 @@ def default_user_ignore_filter_path(config: Config) -> str:
       Path to a global ignore file
     """
     try:
-        return config.get((b'core', ), b'excludesFile')
+        return config.get((b"core",), b"excludesFile")
     except KeyError:
         pass
 
-    return get_xdg_config_home_path('git', 'ignore')
+    return get_xdg_config_home_path("git", "ignore")
 
 
 class IgnoreFilterManager(object):
     """Ignore file manager."""
 
     def __init__(
-            self, top_path: str, global_filters: List[IgnoreFilter],
-            ignorecase: bool):
+        self,
+        top_path: str,
+        global_filters: List[IgnoreFilter],
+        ignorecase: bool,
+    ):
         self._path_filters = {}  # type: Dict[str, Optional[IgnoreFilter]]
         self._top_path = top_path
         self._global_filters = global_filters
@@ -303,9 +306,11 @@ class IgnoreFilterManager(object):
 
     def __repr__(self) -> str:
         return "%s(%s, %r, %r)" % (
-            type(self).__name__, self._top_path,
+            type(self).__name__,
+            self._top_path,
             self._global_filters,
-            self._ignorecase)
+            self._ignorecase,
+        )
 
     def _load_path(self, path: str) -> Optional[IgnoreFilter]:
         try:
@@ -313,10 +318,9 @@ class IgnoreFilterManager(object):
         except KeyError:
             pass
 
-        p = os.path.join(self._top_path, path, '.gitignore')
+        p = os.path.join(self._top_path, path, ".gitignore")
         try:
-            self._path_filters[path] = IgnoreFilter.from_path(
-                p, self._ignorecase)
+            self._path_filters[path] = IgnoreFilter.from_path(p, self._ignorecase)
         except IOError:
             self._path_filters[path] = None
         return self._path_filters[path]
@@ -324,34 +328,31 @@ class IgnoreFilterManager(object):
     def find_matching(self, path: str) -> Iterable[Pattern]:
         """Find matching patterns for path.
 
-        Stops after the first ignore file with matches.
-
         Args:
           path: Path to check
         Returns:
           Iterator over Pattern instances
         """
         if os.path.isabs(path):
-            raise ValueError('%s is an absolute path' % path)
+            raise ValueError("%s is an absolute path" % path)
         filters = [(0, f) for f in self._global_filters]
-        if os.path.sep != '/':
-            path = path.replace(os.path.sep, '/')
-        parts = path.split('/')
-        for i in range(len(parts)+1):
-            dirname = '/'.join(parts[:i])
+        if os.path.sep != "/":
+            path = path.replace(os.path.sep, "/")
+        parts = path.split("/")
+        matches = []
+        for i in range(len(parts) + 1):
+            dirname = "/".join(parts[:i])
             for s, f in filters:
-                relpath = '/'.join(parts[s:i])
+                relpath = "/".join(parts[s:i])
                 if i < len(parts):
                     # Paths leading up to the final part are all directories,
                     # so need a trailing slash.
-                    relpath += '/'
-                matches = list(f.find_matching(relpath))
-                if matches:
-                    return iter(matches)
+                    relpath += "/"
+                matches += list(f.find_matching(relpath))
             ignore_filter = self._load_path(dirname)
             if ignore_filter is not None:
                 filters.insert(0, (i, ignore_filter))
-        return iter([])
+        return iter(matches)
 
     def is_ignored(self, path: str) -> Optional[bool]:
         """Check whether a path is explicitly included or excluded in ignores.
@@ -368,7 +369,7 @@ class IgnoreFilterManager(object):
         return None
 
     @classmethod
-    def from_repo(cls, repo: 'Repo') -> 'IgnoreFilterManager':
+    def from_repo(cls, repo: "Repo") -> "IgnoreFilterManager":
         """Create a IgnoreFilterManager from a repository.
 
         Args:
@@ -378,13 +379,13 @@ class IgnoreFilterManager(object):
         """
         global_filters = []
         for p in [
-                os.path.join(repo.controldir(), 'info', 'exclude'),
-                default_user_ignore_filter_path(repo.get_config_stack())]:
+            os.path.join(repo.controldir(), "info", "exclude"),
+            default_user_ignore_filter_path(repo.get_config_stack()),
+        ]:
             try:
-                global_filters.append(
-                    IgnoreFilter.from_path(os.path.expanduser(p)))
+                global_filters.append(IgnoreFilter.from_path(os.path.expanduser(p)))
             except IOError:
                 pass
         config = repo.get_config_stack()
-        ignorecase = config.get_boolean((b'core'), (b'ignorecase'), False)
+        ignorecase = config.get_boolean((b"core"), (b"ignorecase"), False)
         return cls(repo.path, global_filters, ignorecase)

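Note the behavioural change in find_matching() above: instead of stopping at the first ignore file that yields matches, it now accumulates matches from every applicable filter, and is_ignored() lets the last match win. A short usage sketch against an existing checkout, with a placeholder path:

    from dulwich.ignore import IgnoreFilterManager
    from dulwich.repo import Repo

    repo = Repo("/path/to/repo")  # placeholder path
    manager = IgnoreFilterManager.from_repo(repo)
    print(manager.is_ignored("build/output.o"))  # True, False or None
    for pattern in manager.find_matching("build/output.o"):
        print(pattern)  # every matching pattern, across all ignore files
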
+ 216 - 115
dulwich/index.py

@@ -36,7 +36,7 @@ from typing import (
     Iterable,
     Iterator,
     Tuple,
-    )
+)
 
 if TYPE_CHECKING:
     from dulwich.object_store import BaseObjectStore
@@ -49,24 +49,49 @@ from dulwich.objects import (
     Tree,
     hex_to_sha,
     sha_to_hex,
-    )
+)
 from dulwich.pack import (
     SHA1Reader,
     SHA1Writer,
-    )
+)
 
 
+# TODO(jelmer): Switch to dataclass?
 IndexEntry = collections.namedtuple(
-    'IndexEntry', [
-        'ctime', 'mtime', 'dev', 'ino', 'mode', 'uid', 'gid', 'size', 'sha',
-        'flags'])
-
-
+    "IndexEntry",
+    [
+        "ctime",
+        "mtime",
+        "dev",
+        "ino",
+        "mode",
+        "uid",
+        "gid",
+        "size",
+        "sha",
+        "flags",
+        "extended_flags",
+    ],
+)
+
+
+# 2-bit stage (during merge)
 FLAG_STAGEMASK = 0x3000
+
+# assume-valid
 FLAG_VALID = 0x8000
+
+# extended flag (must be zero in version 2)
 FLAG_EXTENDED = 0x4000
 
 
+# used by sparse checkout
+EXTENDED_FLAG_SKIP_WORKTREE = 0x4000
+
+# used by "git add -N"
+EXTENDED_FLAG_INTEND_TO_ADD = 0x2000
+
+
 DEFAULT_VERSION = 2
 
 
@@ -87,9 +112,7 @@ def pathsplit(path):
 
 
 def pathjoin(*args):
-    """Join a /-delimited path.
-
-    """
+    """Join a /-delimited path."""
     return b"/".join([p for p in args if p])
 
 
@@ -121,57 +144,101 @@ def write_cache_time(f, t):
     f.write(struct.pack(">LL", *t))
 
 
-def read_cache_entry(f):
+def read_cache_entry(f, version: int) -> Tuple[str, IndexEntry]:
     """Read an entry from a cache file.
 
     Args:
       f: File-like object to read from
     Returns:
-      tuple with: device, inode, mode, uid, gid, size, sha, flags
+      tuple with: name, IndexEntry
     """
     beginoffset = f.tell()
     ctime = read_cache_time(f)
     mtime = read_cache_time(f)
-    (dev, ino, mode, uid, gid, size, sha, flags, ) = \
-        struct.unpack(">LLLLLL20sH", f.read(20 + 4 * 6 + 2))
-    name = f.read((flags & 0x0fff))
+    (
+        dev,
+        ino,
+        mode,
+        uid,
+        gid,
+        size,
+        sha,
+        flags,
+    ) = struct.unpack(">LLLLLL20sH", f.read(20 + 4 * 6 + 2))
+    if flags & FLAG_EXTENDED:
+        if version < 3:
+            raise AssertionError(
+                'extended flag set in index with version < 3')
+        (extended_flags, ) = struct.unpack(">H", f.read(2))
+    else:
+        extended_flags = 0
+    name = f.read((flags & 0x0FFF))
     # Padding:
-    real_size = ((f.tell() - beginoffset + 8) & ~7)
-    f.read((beginoffset + real_size) - f.tell())
-    return (name, ctime, mtime, dev, ino, mode, uid, gid, size,
-            sha_to_hex(sha), flags & ~0x0fff)
-
-
-def write_cache_entry(f, entry):
+    if version < 4:
+        real_size = (f.tell() - beginoffset + 8) & ~7
+        f.read((beginoffset + real_size) - f.tell())
+    return (
+        name,
+        IndexEntry(
+            ctime,
+            mtime,
+            dev,
+            ino,
+            mode,
+            uid,
+            gid,
+            size,
+            sha_to_hex(sha),
+            flags & ~0x0FFF,
+            extended_flags,
+        ))
+
+
+def write_cache_entry(f, name, entry, version):
     """Write an index entry to a file.
 
     Args:
       f: File object
-      entry: Entry to write, tuple with:
-        (name, ctime, mtime, dev, ino, mode, uid, gid, size, sha, flags)
+      entry: IndexEntry to write
     """
     beginoffset = f.tell()
-    (name, ctime, mtime, dev, ino, mode, uid, gid, size, sha, flags) = entry
-    write_cache_time(f, ctime)
-    write_cache_time(f, mtime)
-    flags = len(name) | (flags & ~0x0fff)
-    f.write(struct.pack(
-            b'>LLLLLL20sH', dev & 0xFFFFFFFF, ino & 0xFFFFFFFF,
-            mode, uid, gid, size, hex_to_sha(sha), flags))
+    write_cache_time(f, entry.ctime)
+    write_cache_time(f, entry.mtime)
+    flags = len(name) | (entry.flags & ~0x0FFF)
+    if entry.extended_flags:
+        flags |= FLAG_EXTENDED
+    if flags & FLAG_EXTENDED and version is not None and version < 3:
+        raise AssertionError('unable to use extended flags in version < 3')
+    f.write(
+        struct.pack(
+            b">LLLLLL20sH",
+            entry.dev & 0xFFFFFFFF,
+            entry.ino & 0xFFFFFFFF,
+            entry.mode,
+            entry.uid,
+            entry.gid,
+            entry.size,
+            hex_to_sha(entry.sha),
+            flags,
+        )
+    )
+    if flags & FLAG_EXTENDED:
+        f.write(struct.pack(b">H", entry.extended_flags))
     f.write(name)
-    real_size = ((f.tell() - beginoffset + 8) & ~7)
-    f.write(b'\0' * ((beginoffset + real_size) - f.tell()))
+    if version < 4:
+        real_size = (f.tell() - beginoffset + 8) & ~7
+        f.write(b"\0" * ((beginoffset + real_size) - f.tell()))
 
 
 def read_index(f: BinaryIO):
     """Read an index file, yielding the individual entries."""
     header = f.read(4)
-    if header != b'DIRC':
+    if header != b"DIRC":
         raise AssertionError("Invalid index file header: %r" % header)
-    (version, num_entries) = struct.unpack(b'>LL', f.read(4 * 2))
-    assert version in (1, 2)
+    (version, num_entries) = struct.unpack(b">LL", f.read(4 * 2))
+    assert version in (1, 2, 3), "index version is %r" % version
     for i in range(num_entries):
-        yield read_cache_entry(f)
+        yield read_cache_entry(f, version)
 
 
 def read_index_dict(f):
@@ -181,14 +248,12 @@ def read_index_dict(f):
       f: File object to read from
     """
     ret = {}
-    for x in read_index(f):
-        ret[x[0]] = IndexEntry(*x[1:])
+    for name, entry in read_index(f):
+        ret[name] = entry
     return ret
 
 
-def write_index(
-        f: BinaryIO,
-        entries: List[Any], version: Optional[int] = None):
+def write_index(f: BinaryIO, entries: List[Tuple[bytes, IndexEntry]], version: Optional[int] = None):
     """Write an index file.
 
     Args:
@@ -198,21 +263,21 @@ def write_index(
     """
     if version is None:
         version = DEFAULT_VERSION
-    f.write(b'DIRC')
-    f.write(struct.pack(b'>LL', version, len(entries)))
-    for x in entries:
-        write_cache_entry(f, x)
+    f.write(b"DIRC")
+    f.write(struct.pack(b">LL", version, len(entries)))
+    for name, entry in entries:
+        write_cache_entry(f, name, entry, version)
 
 
 def write_index_dict(
-        f: BinaryIO, entries: Dict[bytes, IndexEntry],
-        version: Optional[int] = None) -> None:
-    """Write an index file based on the contents of a dictionary.
-
-    """
+    f: BinaryIO,
+    entries: Dict[bytes, IndexEntry],
+    version: Optional[int] = None,
+) -> None:
+    """Write an index file based on the contents of a dictionary."""
     entries_list = []
     for name in sorted(entries):
-        entries_list.append((name,) + tuple(entries[name]))
+        entries_list.append((name, entries[name]))
     write_index(f, entries_list, version=version)
 
 
@@ -262,7 +327,7 @@ class Index(object):
 
     def write(self) -> None:
         """Write current contents of index to disk."""
-        f = GitFile(self._filename, 'wb')
+        f = GitFile(self._filename, "wb")
         try:
             f = SHA1Writer(f)
             write_index_dict(f, self._byname, version=self._version)
@@ -273,13 +338,13 @@ class Index(object):
         """Read current contents of index from disk."""
         if not os.path.exists(self._filename):
             return
-        f = GitFile(self._filename, 'rb')
+        f = GitFile(self._filename, "rb")
         try:
             f = SHA1Reader(f)
-            for x in read_index(f):
-                self[x[0]] = IndexEntry(*x[1:])
+            for name, entry in read_index(f):
+                self[name] = entry
             # FIXME: Additional data?
-            f.read(os.path.getsize(self._filename)-f.tell()-20)
+            f.read(os.path.getsize(self._filename) - f.tell() - 20)
             f.check_sha()
         finally:
             f.close()
@@ -316,7 +381,8 @@ class Index(object):
 
     def iterblobs(self):
         import warnings
-        warnings.warn('Use iterobjects() instead.', PendingDeprecationWarning)
+
+        warnings.warn("Use iterobjects() instead.", PendingDeprecationWarning)
         return self.iterobjects()
 
     def clear(self):
@@ -325,7 +391,7 @@ class Index(object):
 
     def __setitem__(self, name, x):
         assert isinstance(name, bytes)
-        assert len(x) == 10
+        assert len(x) == len(IndexEntry._fields)
         # Remove the old entry if any
         self._byname[name] = IndexEntry(*x)
 
@@ -353,12 +419,18 @@ class Index(object):
         Returns: Iterator over tuples with (oldpath, newpath), (oldmode,
             newmode), (oldsha, newsha)
         """
+
         def lookup_entry(path):
             entry = self[path]
             return entry.sha, cleanup_mode(entry.mode)
+
         for (name, mode, sha) in changes_from_tree(
-                self._byname.keys(), lookup_entry, object_store, tree,
-                want_unchanged=want_unchanged):
+            self._byname.keys(),
+            lookup_entry,
+            object_store,
+            tree,
+            want_unchanged=want_unchanged,
+        ):
             yield (name, mode, sha)
 
     def commit(self, object_store):
@@ -373,8 +445,8 @@ class Index(object):
 
 
 def commit_tree(
-        object_store: 'BaseObjectStore',
-        blobs: Iterable[Tuple[bytes, bytes, int]]) -> bytes:
+    object_store: "BaseObjectStore", blobs: Iterable[Tuple[bytes, bytes, int]]
+) -> bytes:
     """Commit a new tree.
 
     Args:
@@ -383,7 +455,7 @@ def commit_tree(
     Returns:
       SHA1 of the created tree.
     """
-    trees = {b'': {}}  # type: Dict[bytes, Any]
+    trees = {b"": {}}  # type: Dict[bytes, Any]
 
     def add_tree(path):
         if path in trees:
@@ -412,10 +484,11 @@ def commit_tree(
             tree.add(basename, mode, sha)
         object_store.add_object(tree)
         return tree.id
-    return build_tree(b'')
+
+    return build_tree(b"")
 
 
-def commit_index(object_store: 'BaseObjectStore', index: Index) -> bytes:
+def commit_index(object_store: "BaseObjectStore", index: Index) -> bytes:
     """Create a new tree from an index.
 
     Args:
@@ -428,14 +501,18 @@ def commit_index(object_store: 'BaseObjectStore', index: Index) -> bytes:
 
 
 def changes_from_tree(
-        names: Iterable[bytes],
-        lookup_entry: Callable[[bytes], Tuple[bytes, int]],
-        object_store: 'BaseObjectStore', tree: Optional[bytes],
-        want_unchanged=False) -> Iterable[
-            Tuple[
-                Tuple[Optional[bytes], Optional[bytes]],
-                Tuple[Optional[int], Optional[int]],
-                Tuple[Optional[bytes], Optional[bytes]]]]:
+    names: Iterable[bytes],
+    lookup_entry: Callable[[bytes], Tuple[bytes, int]],
+    object_store: "BaseObjectStore",
+    tree: Optional[bytes],
+    want_unchanged=False,
+) -> Iterable[
+    Tuple[
+        Tuple[Optional[bytes], Optional[bytes]],
+        Tuple[Optional[int], Optional[int]],
+        Tuple[Optional[bytes], Optional[bytes]],
+    ]
+]:
     """Find the differences between the contents of a tree and
     a working copy.
 
@@ -460,7 +537,7 @@ def changes_from_tree(
                 yield ((name, None), (mode, None), (sha, None))
             else:
                 other_names.remove(name)
-                if (want_unchanged or other_sha != sha or other_mode != mode):
+                if want_unchanged or other_sha != sha or other_mode != mode:
                     yield ((name, name), (mode, other_mode), (sha, other_sha))
 
     # Mention added files
@@ -474,8 +551,9 @@ def changes_from_tree(
 
 
 def index_entry_from_stat(
-        stat_val, hex_sha: bytes, flags: int,
-        mode: Optional[int] = None):
+    stat_val, hex_sha: bytes, flags: int, mode: Optional[int] = None,
+    extended_flags: Optional[int] = None
+):
     """Create a new index entry from a stat value.
 
     Args:
@@ -487,13 +565,23 @@ def index_entry_from_stat(
         mode = cleanup_mode(stat_val.st_mode)
 
     return IndexEntry(
-            stat_val.st_ctime, stat_val.st_mtime, stat_val.st_dev,
-            stat_val.st_ino, mode, stat_val.st_uid,
-            stat_val.st_gid, stat_val.st_size, hex_sha, flags)
+        stat_val.st_ctime,
+        stat_val.st_mtime,
+        stat_val.st_dev,
+        stat_val.st_ino,
+        mode,
+        stat_val.st_uid,
+        stat_val.st_gid,
+        stat_val.st_size,
+        hex_sha,
+        flags,
+        extended_flags
+    )
 
 
-def build_file_from_blob(blob, mode, target_path, honor_filemode=True,
-                         tree_encoding='utf-8'):
+def build_file_from_blob(
+    blob, mode, target_path, honor_filemode=True, tree_encoding="utf-8"
+):
     """Build a file or symlink on disk based on a Git object.
 
     Args:
@@ -513,18 +601,18 @@ def build_file_from_blob(blob, mode, target_path, honor_filemode=True,
         # FIXME: This will fail on Windows. What should we do instead?
         if oldstat:
             os.unlink(target_path)
-        if sys.platform == 'win32':
+        if sys.platform == "win32":
             # os.readlink on Python3 on Windows requires a unicode string.
             contents = contents.decode(tree_encoding)
             target_path = target_path.decode(tree_encoding)
         os.symlink(contents, target_path)
     else:
         if oldstat is not None and oldstat.st_size == len(contents):
-            with open(target_path, 'rb') as f:
+            with open(target_path, "rb") as f:
                 if f.read() == contents:
                     return oldstat
 
-        with open(target_path, 'wb') as f:
+        with open(target_path, "wb") as f:
             # Write out file
             f.write(contents)
 
@@ -560,9 +648,14 @@ def validate_path(path, element_validator=validate_path_element_default):
         return True
 
 
-def build_index_from_tree(root_path, index_path, object_store, tree_id,
-                          honor_filemode=True,
-                          validate_path_element=validate_path_element_default):
+def build_index_from_tree(
+    root_path,
+    index_path,
+    object_store,
+    tree_id,
+    honor_filemode=True,
+    validate_path_element=validate_path_element_default,
+):
     """Generate and materialize index from a tree
 
     Args:
@@ -600,23 +693,33 @@ def build_index_from_tree(root_path, index_path, object_store, tree_id,
         else:
             obj = object_store[entry.sha]
             st = build_file_from_blob(
-                obj, entry.mode, full_path, honor_filemode=honor_filemode)
+                obj, entry.mode, full_path, honor_filemode=honor_filemode
+            )
 
         # Add file to index
         if not honor_filemode or S_ISGITLINK(entry.mode):
             # we can not use tuple slicing to build a new tuple,
             # because on windows that will convert the times to
             # longs, which causes errors further along
-            st_tuple = (entry.mode, st.st_ino, st.st_dev, st.st_nlink,
-                        st.st_uid, st.st_gid, st.st_size, st.st_atime,
-                        st.st_mtime, st.st_ctime)
+            st_tuple = (
+                entry.mode,
+                st.st_ino,
+                st.st_dev,
+                st.st_nlink,
+                st.st_uid,
+                st.st_gid,
+                st.st_size,
+                st.st_atime,
+                st.st_mtime,
+                st.st_ctime,
+            )
             st = st.__class__(st_tuple)
         index[entry.path] = index_entry_from_stat(st, entry.sha, 0)
 
     index.write()
 
 
-def blob_from_path_and_mode(fs_path, mode, tree_encoding='utf-8'):
+def blob_from_path_and_mode(fs_path, mode, tree_encoding="utf-8"):
     """Create a blob from a path and a stat object.
 
     Args:
@@ -627,19 +730,19 @@ def blob_from_path_and_mode(fs_path, mode, tree_encoding='utf-8'):
     assert isinstance(fs_path, bytes)
     blob = Blob()
     if stat.S_ISLNK(mode):
-        if sys.platform == 'win32':
+        if sys.platform == "win32":
             # os.readlink on Python3 on Windows requires a unicode string.
             fs_path = os.fsdecode(fs_path)
             blob.data = os.readlink(fs_path).encode(tree_encoding)
         else:
             blob.data = os.readlink(fs_path)
     else:
-        with open(fs_path, 'rb') as f:
+        with open(fs_path, "rb") as f:
             blob.data = f.read()
     return blob
 
 
-def blob_from_path_and_stat(fs_path, st, tree_encoding='utf-8'):
+def blob_from_path_and_stat(fs_path, st, tree_encoding="utf-8"):
     """Create a blob from a path and a stat object.
 
     Args:
@@ -659,6 +762,7 @@ def read_submodule_head(path):
     """
     from dulwich.errors import NotGitRepository
     from dulwich.repo import Repo
+
     # Repo currently expects a "str", so decode if necessary.
     # TODO(jelmer): Perhaps move this into Repo() ?
     if not isinstance(path, str):
@@ -686,7 +790,7 @@ def _has_directory_changed(tree_path, entry):
     otherwise or if the path is not a directory.
     """
     # This is actually a directory
-    if os.path.exists(os.path.join(tree_path, b'.git')):
+    if os.path.exists(os.path.join(tree_path, b".git")):
         # Submodule
         head = read_submodule_head(tree_path)
         if entry.sha != head:
@@ -735,7 +839,7 @@ def get_unstaged_changes(index: Index, root_path, filter_blob_callback=None):
                 yield tree_path
 
 
-os_sep_bytes = os.sep.encode('ascii')
+os_sep_bytes = os.sep.encode("ascii")
 
 
 def _tree_to_fs_path(root_path, tree_path: bytes):
@@ -748,8 +852,8 @@ def _tree_to_fs_path(root_path, tree_path: bytes):
     Returns: File system path.
     """
     assert isinstance(tree_path, bytes)
-    if os_sep_bytes != b'/':
-        sep_corrected_path = tree_path.replace(b'/', os_sep_bytes)
+    if os_sep_bytes != b"/":
+        sep_corrected_path = tree_path.replace(b"/", os_sep_bytes)
     else:
         sep_corrected_path = tree_path
     return os.path.join(root_path, sep_corrected_path)
@@ -767,8 +871,8 @@ def _fs_to_tree_path(fs_path):
         fs_path_bytes = os.fsencode(fs_path)
     else:
         fs_path_bytes = fs_path
-    if os_sep_bytes != b'/':
-        tree_path = fs_path_bytes.replace(os_sep_bytes, b'/')
+    if os_sep_bytes != b"/":
+        tree_path = fs_path_bytes.replace(os_sep_bytes, b"/")
     else:
         tree_path = fs_path_bytes
     return tree_path
@@ -790,12 +894,11 @@ def index_entry_from_path(path, object_store=None):
     assert isinstance(path, bytes)
     st = os.lstat(path)
     if stat.S_ISDIR(st.st_mode):
-        if os.path.exists(os.path.join(path, b'.git')):
+        if os.path.exists(os.path.join(path, b".git")):
             head = read_submodule_head(path)
             if head is None:
                 return None
-            return index_entry_from_stat(
-                st, head, 0, mode=S_IFGITLINK)
+            return index_entry_from_stat(st, head, 0, mode=S_IFGITLINK)
         return None
 
     if stat.S_ISREG(st.st_mode) or stat.S_ISLNK(st.st_mode):
@@ -808,7 +911,8 @@ def index_entry_from_path(path, object_store=None):
 
 
 def iter_fresh_entries(
-        paths, root_path, object_store: Optional['BaseObjectStore'] = None):
+    paths, root_path, object_store: Optional["BaseObjectStore"] = None
+):
     """Iterate over current versions of index entries on disk.
 
     Args:
@@ -839,18 +943,16 @@ def iter_fresh_blobs(index, root_path):
     Returns: Iterator over path, sha, mode
     """
     import warnings
-    warnings.warn(PendingDeprecationWarning,
-                  "Use iter_fresh_objects instead.")
-    for entry in iter_fresh_objects(
-            index, root_path, include_deleted=True):
+
+    warnings.warn("Use iter_fresh_objects instead.", PendingDeprecationWarning)
+    for entry in iter_fresh_objects(index, root_path, include_deleted=True):
         if entry[1] is None:
             del index[entry[0]]
         else:
             yield entry
 
 
-def iter_fresh_objects(paths, root_path, include_deleted=False,
-                       object_store=None):
+def iter_fresh_objects(paths, root_path, include_deleted=False, object_store=None):
     """Iterate over versions of objecs on disk referenced by index.
 
     Args:
@@ -860,8 +962,7 @@ def iter_fresh_objects(paths, root_path, include_deleted=False,
       object_store: Optional object store to report new items to
     Returns: Iterator over path, sha, mode
     """
-    for path, entry in iter_fresh_entries(paths, root_path,
-                                          object_store=object_store):
+    for path, entry in iter_fresh_entries(paths, root_path, object_store=object_store):
         if entry is None:
             if include_deleted:
                 yield path, None, None

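A short sketch of reading the new extended_flags field through the Index class these helpers feed; the index path is a placeholder, and version 3 only matters once extended flags are actually set:

    from dulwich.index import EXTENDED_FLAG_SKIP_WORKTREE, Index

    index = Index("/path/to/repo/.git/index")  # placeholder path
    for name, entry in index.iteritems():
        if entry.extended_flags & EXTENDED_FLAG_SKIP_WORKTREE:
            print(name, "is marked skip-worktree")
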
+ 7 - 8
dulwich/lfs.py

@@ -33,24 +33,24 @@ class LFSStore(object):
     def create(cls, lfs_dir):
         if not os.path.isdir(lfs_dir):
             os.mkdir(lfs_dir)
-        os.mkdir(os.path.join(lfs_dir, 'tmp'))
-        os.mkdir(os.path.join(lfs_dir, 'objects'))
+        os.mkdir(os.path.join(lfs_dir, "tmp"))
+        os.mkdir(os.path.join(lfs_dir, "objects"))
         return cls(lfs_dir)
 
     @classmethod
     def from_repo(cls, repo, create=False):
-        lfs_dir = os.path.join(repo.controldir, 'lfs')
+        lfs_dir = os.path.join(repo.controldir, "lfs")
         if create:
             return cls.create(lfs_dir)
         return cls(lfs_dir)
 
     def _sha_path(self, sha):
-        return os.path.join(self.path, 'objects', sha[0:2], sha[2:4], sha)
+        return os.path.join(self.path, "objects", sha[0:2], sha[2:4], sha)
 
     def open_object(self, sha):
         """Open an object by sha."""
         try:
-            return open(self._sha_path(sha), 'rb')
+            return open(self._sha_path(sha), "rb")
         except FileNotFoundError:
             raise KeyError(sha)
 
@@ -60,9 +60,8 @@ class LFSStore(object):
         Returns: object SHA
         """
         sha = hashlib.sha256()
-        tmpdir = os.path.join(self.path, 'tmp')
-        with tempfile.NamedTemporaryFile(
-                dir=tmpdir, mode='wb', delete=False) as f:
+        tmpdir = os.path.join(self.path, "tmp")
+        with tempfile.NamedTemporaryFile(dir=tmpdir, mode="wb", delete=False) as f:
             for chunk in chunks:
                 sha.update(chunk)
                 f.write(chunk)

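A minimal round-trip sketch for LFSStore as reformatted above; write_object() hashes with SHA-256 and files the result under objects/<aa>/<bb>/<sha>:

    import tempfile

    from dulwich.lfs import LFSStore

    store = LFSStore.create(tempfile.mkdtemp())
    sha = store.write_object([b"chunk one ", b"chunk two"])
    with store.open_object(sha) as f:
        assert f.read() == b"chunk one chunk two"
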
+ 40 - 12
dulwich/line_ending.py

@@ -31,6 +31,16 @@ The normalization is a two-fold process that happens at two moments:
   when doing a `git add` call. We call this process the write filter in this
   module.
 
+Note that when checking status (getting unstaged changes), whether or not
+normalization is done on write depends on whether or not the file in the
+working dir has also been normalized on read:
+
+- For autocrlf=true all files are always normalized on both read and write.
+- For autocrlf=input files are only normalized on write if they are newly
+  "added". Since files which are already committed are not normalized on
+  checkout into the working tree, they are also left alone when staging
+  modifications into the index.
+
 One thing to know is that Git does line-ending normalization only on text
 files. How does Git know that a file is text? We can either mark a file as a
 text file, a binary file or ask Git to automatically decide. Git has an
@@ -156,8 +166,7 @@ def convert_lf_to_crlf(text_hunk):
 
 
 def get_checkout_filter(core_eol, core_autocrlf, git_attributes):
-    """ Returns the correct checkout filter based on the passed arguments
-    """
+    """Returns the correct checkout filter based on the passed arguments"""
     # TODO this function should process the git_attributes for the path and if
     # the text attribute is not defined, fallback on the
     # get_checkout_filter_autocrlf function with the autocrlf value
@@ -165,8 +174,7 @@ def get_checkout_filter(core_eol, core_autocrlf, git_attributes):
 
 
 def get_checkin_filter(core_eol, core_autocrlf, git_attributes):
-    """ Returns the correct checkin filter based on the passed arguments
-    """
+    """Returns the correct checkin filter based on the passed arguments"""
     # TODO this function should process the git_attributes for the path and if
     # the text attribute is not defined, fallback on the
     # get_checkin_filter_autocrlf function with the autocrlf value
@@ -174,7 +182,7 @@ def get_checkin_filter(core_eol, core_autocrlf, git_attributes):
 
 
 def get_checkout_filter_autocrlf(core_autocrlf):
-    """ Returns the correct checkout filter base on autocrlf value
+    """Returns the correct checkout filter base on autocrlf value
 
     Args:
       core_autocrlf: The bytes configuration value of core.autocrlf.
@@ -190,7 +198,7 @@ def get_checkout_filter_autocrlf(core_autocrlf):
 
 
 def get_checkin_filter_autocrlf(core_autocrlf):
-    """ Returns the correct checkin filter base on autocrlf value
+    """Returns the correct checkin filter base on autocrlf value
 
     Args:
       core_autocrlf: The bytes configuration value of core.autocrlf.
@@ -207,7 +215,7 @@ def get_checkin_filter_autocrlf(core_autocrlf):
 
 
 class BlobNormalizer(object):
-    """ An object to store computation result of which filter to apply based
+    """An object to store computation result of which filter to apply based
     on configuration, gitattributes, path and operation (checkin or checkout)
     """
 
@@ -234,8 +242,7 @@ class BlobNormalizer(object):
         )
 
     def checkin_normalize(self, blob, tree_path):
-        """ Normalize a blob during a checkin operation
-        """
+        """Normalize a blob during a checkin operation"""
         if self.fallback_write_filter is not None:
             return normalize_blob(
                 blob, self.fallback_write_filter, binary_detection=True
@@ -244,8 +251,7 @@ class BlobNormalizer(object):
         return blob
 
     def checkout_normalize(self, blob, tree_path):
-        """ Normalize a blob during a checkout operation
-        """
+        """Normalize a blob during a checkout operation"""
         if self.fallback_read_filter is not None:
             return normalize_blob(
                 blob, self.fallback_read_filter, binary_detection=True
@@ -255,7 +261,7 @@ class BlobNormalizer(object):
 
 
 def normalize_blob(blob, conversion, binary_detection):
-    """ Takes a blob as input returns either the original blob if
+    """Takes a blob as input returns either the original blob if
     binary_detection is True and the blob content looks like binary, else
     return a new blob with converted data
     """
@@ -276,3 +282,25 @@ def normalize_blob(blob, conversion, binary_detection):
     new_blob.data = converted_data
 
     return new_blob
+
+
+class TreeBlobNormalizer(BlobNormalizer):
+    def __init__(self, config_stack, git_attributes, object_store, tree=None):
+        super().__init__(config_stack, git_attributes)
+        if tree:
+            self.existing_paths = {
+                name
+                for name, _, _ in object_store.iter_tree_contents(tree)
+            }
+        else:
+            self.existing_paths = set()
+
+    def checkin_normalize(self, blob, tree_path):
+        # Existing files should only be normalized on checkin if it was
+        # previously normalized on checkout
+        if (
+            self.fallback_read_filter is not None
+            or tree_path not in self.existing_paths
+        ):
+            return super().checkin_normalize(blob, tree_path)
+        return blob
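
A small illustration of the checkin side described above, using only the autocrlf helpers shown (the sample content is arbitrary):

    from dulwich.line_ending import (
        convert_crlf_to_lf,
        get_checkin_filter_autocrlf,
    )

    # For autocrlf=true (and autocrlf=input) checkin normalizes CRLF to LF
    checkin = get_checkin_filter_autocrlf(b"true")
    assert checkin is convert_crlf_to_lf
    assert checkin(b"one\r\ntwo\r\n") == b"one\ntwo\n"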

+ 6 - 3
dulwich/log_utils.py

@@ -49,15 +49,18 @@ class _NullHandler(logging.Handler):
 
 
 _NULL_HANDLER = _NullHandler()
-_DULWICH_LOGGER = getLogger('dulwich')
+_DULWICH_LOGGER = getLogger("dulwich")
 _DULWICH_LOGGER.addHandler(_NULL_HANDLER)
 
 
 def default_logging_config():
     """Set up the default Dulwich loggers."""
     remove_null_handler()
-    logging.basicConfig(level=logging.INFO, stream=sys.stderr,
-                        format='%(asctime)s %(levelname)s: %(message)s')
+    logging.basicConfig(
+        level=logging.INFO,
+        stream=sys.stderr,
+        format="%(asctime)s %(levelname)s: %(message)s",
+    )
 
 
 def remove_null_handler():
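
For reference, the intended use of the default configuration above (a minimal sketch; "dulwich" is the logger name registered here):

    import logging

    from dulwich.log_utils import default_logging_config

    default_logging_config()  # swaps the null handler for an INFO stderr handler
    logging.getLogger("dulwich").info("now visible on stderr")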

+ 35 - 26
dulwich/lru_cache.py

@@ -26,7 +26,7 @@ _null_key = object()
 class _LRUNode(object):
     """This maintains the linked-list which is the lru internals."""
 
-    __slots__ = ('prev', 'next_key', 'key', 'value', 'cleanup', 'size')
+    __slots__ = ("prev", "next_key", "key", "value", "cleanup", "size")
 
     def __init__(self, key, value, cleanup=None):
         self.prev = None
@@ -44,8 +44,12 @@ class _LRUNode(object):
             prev_key = None
         else:
             prev_key = self.prev.key
-        return '%s(%r n:%r p:%r)' % (self.__class__.__name__, self.key,
-                                     self.next_key, prev_key)
+        return "%s(%r n:%r p:%r)" % (
+            self.__class__.__name__,
+            self.key,
+            self.next_key,
+            prev_key,
+        )
 
     def run_cleanup(self):
         if self.cleanup is not None:
@@ -108,29 +112,35 @@ class LRUCache(object):
         node = self._most_recently_used
         if node is not None:
             if node.prev is not None:
-                raise AssertionError('the _most_recently_used entry is not'
-                                     ' supposed to have a previous entry'
-                                     ' %s' % (node,))
+                raise AssertionError(
+                    "the _most_recently_used entry is not"
+                    " supposed to have a previous entry"
+                    " %s" % (node,)
+                )
         while node is not None:
             if node.next_key is _null_key:
                 if node is not self._least_recently_used:
-                    raise AssertionError('only the last node should have'
-                                         ' no next value: %s' % (node,))
+                    raise AssertionError(
+                        "only the last node should have" " no next value: %s" % (node,)
+                    )
                 node_next = None
             else:
                 node_next = self._cache[node.next_key]
                 if node_next.prev is not node:
-                    raise AssertionError('inconsistency found, node.next.prev'
-                                         ' != node: %s' % (node,))
+                    raise AssertionError(
+                        "inconsistency found, node.next.prev" " != node: %s" % (node,)
+                    )
             if node.prev is None:
                 if node is not self._most_recently_used:
-                    raise AssertionError('only the _most_recently_used should'
-                                         ' not have a previous node: %s'
-                                         % (node,))
+                    raise AssertionError(
+                        "only the _most_recently_used should"
+                        " not have a previous node: %s" % (node,)
+                    )
             else:
                 if node.prev.next_key != node.key:
-                    raise AssertionError('inconsistency found, node.prev.next'
-                                         ' != node: %s' % (node,))
+                    raise AssertionError(
+                        "inconsistency found, node.prev.next" " != node: %s" % (node,)
+                    )
             yield node
             node = node_next
 
@@ -147,7 +157,7 @@ class LRUCache(object):
                         'value' should be cleaned up.
         """
         if key is _null_key:
-            raise ValueError('cannot use _null_key as a key')
+            raise ValueError("cannot use _null_key as a key")
         if key in self._cache:
             node = self._cache[key]
             node.run_cleanup()
@@ -186,7 +196,7 @@ class LRUCache(object):
 
     def items(self):
         """Get the key:value pairs as a dict."""
-        return dict((k, n.value) for k, n in self._cache.items())
+        return {k: n.value for k, n in self._cache.items()}
 
     def cleanup(self):
         """Clear the cache until it shrinks to the requested size.
@@ -262,16 +272,14 @@ class LRUCache(object):
 
     def resize(self, max_cache, after_cleanup_count=None):
         """Change the number of entries that will be cached."""
-        self._update_max_cache(max_cache,
-                               after_cleanup_count=after_cleanup_count)
+        self._update_max_cache(max_cache, after_cleanup_count=after_cleanup_count)
 
     def _update_max_cache(self, max_cache, after_cleanup_count=None):
         self._max_cache = max_cache
         if after_cleanup_count is None:
             self._after_cleanup_count = self._max_cache * 8 / 10
         else:
-            self._after_cleanup_count = min(after_cleanup_count,
-                                            self._max_cache)
+            self._after_cleanup_count = min(after_cleanup_count, self._max_cache)
         self.cleanup()
 
 
@@ -285,8 +293,9 @@ class LRUSizeCache(LRUCache):
     defaults to len() if not supplied.
     """
 
-    def __init__(self, max_size=1024*1024, after_cleanup_size=None,
-                 compute_size=None):
+    def __init__(
+        self, max_size=1024 * 1024, after_cleanup_size=None, compute_size=None
+    ):
         """Create a new LRUSizeCache.
 
         Args:
@@ -306,7 +315,7 @@ class LRUSizeCache(LRUCache):
         if compute_size is None:
             self._compute_size = len
         self._update_max_size(max_size, after_cleanup_size=after_cleanup_size)
-        LRUCache.__init__(self, max_cache=max(int(max_size/512), 1))
+        LRUCache.__init__(self, max_cache=max(int(max_size / 512), 1))
 
     def add(self, key, value, cleanup=None):
         """Add a new value to the cache.
@@ -321,7 +330,7 @@ class LRUSizeCache(LRUCache):
                         'value' should be cleaned up.
         """
         if key is _null_key:
-            raise ValueError('cannot use _null_key as a key')
+            raise ValueError("cannot use _null_key as a key")
         node = self._cache.get(key, None)
         value_len = self._compute_size(value)
         if value_len >= self._after_cleanup_size:
@@ -363,7 +372,7 @@ class LRUSizeCache(LRUCache):
     def resize(self, max_size, after_cleanup_size=None):
         """Change the number of bytes that will be cached."""
         self._update_max_size(max_size, after_cleanup_size=after_cleanup_size)
-        max_cache = max(int(max_size/512), 1)
+        max_cache = max(int(max_size / 512), 1)
         self._update_max_cache(max_cache)
 
     def _update_max_size(self, max_size, after_cleanup_size=None):
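
A short sketch of the cache semantics above (the sizes are illustrative; once the cache overflows max_cache, cleanup evicts least-recently-used entries down to after_cleanup_count):

    from dulwich.lru_cache import LRUCache, LRUSizeCache

    cache = LRUCache(max_cache=2)
    cache.add("a", 1)
    cache.add("b", 2)
    cache.add("c", 3)  # overflow: "a" (and possibly "b") are evicted
    assert "a" not in cache

    # LRUSizeCache bounds total value size instead of entry count
    size_cache = LRUSizeCache(max_size=1024, compute_size=len)
    size_cache.add("blob", b"x" * 100)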

+ 6 - 5
dulwich/mailmap.py

@@ -44,11 +44,11 @@ def read_mailmap(f):
     """
     for line in f:
         # Remove comments
-        line = line.split(b'#')[0]
+        line = line.split(b"#")[0]
         line = line.strip()
         if not line:
             continue
-        (canonical_identity, from_identity) = line.split(b'>', 1)
+        (canonical_identity, from_identity) = line.split(b">", 1)
         canonical_identity += b">"
         if from_identity.strip():
             parsed_from_identity = parse_identity(from_identity)
@@ -99,8 +99,9 @@ class Mailmap(object):
             canonical_identity = self._table.get(query)
             if canonical_identity is not None:
                 identity = (
-                        canonical_identity[0] or identity[0],
-                        canonical_identity[1] or identity[1])
+                    canonical_identity[0] or identity[0],
+                    canonical_identity[1] or identity[1],
+                )
                 break
         if was_tuple:
             return identity
@@ -109,5 +110,5 @@ class Mailmap(object):
 
     @classmethod
     def from_path(cls, path):
-        with open(path, 'rb') as f:
+        with open(path, "rb") as f:
             return cls(read_mailmap(f))
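
As a usage sketch for the parsing above (the identities are made up; inline comments after '#' are stripped, as read_mailmap() shows):

    from io import BytesIO

    from dulwich.mailmap import Mailmap, read_mailmap

    line = b"Jane Doe <jane@example.com> <jdoe@old.example.com>  # old address\n"
    mm = Mailmap(read_mailmap(BytesIO(line)))
    print(mm.lookup(b"Jane Doe <jdoe@old.example.com>"))
    # -> b'Jane Doe <jane@example.com>'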

+ 292 - 115
dulwich/object_store.py

@@ -30,10 +30,10 @@ import sys
 from dulwich.diff_tree import (
     tree_changes,
     walk_trees,
-    )
+)
 from dulwich.errors import (
     NotTreeError,
-    )
+)
 from dulwich.file import GitFile
 from dulwich.objects import (
     Commit,
@@ -47,12 +47,13 @@ from dulwich.objects import (
     S_ISGITLINK,
     object_class,
     valid_hexsha,
-    )
+)
 from dulwich.pack import (
     Pack,
     PackData,
     PackInflater,
     PackFileDisappeared,
+    load_pack_index_file,
     iter_sha1,
     pack_objects_to_data,
     write_pack_header,
@@ -62,21 +63,32 @@ from dulwich.pack import (
     compute_file_sha,
     PackIndexer,
     PackStreamCopier,
-    )
+)
+from dulwich.protocol import DEPTH_INFINITE
 from dulwich.refs import ANNOTATED_TAG_SUFFIX
 
-INFODIR = 'info'
-PACKDIR = 'pack'
+INFODIR = "info"
+PACKDIR = "pack"
 
 
 class BaseObjectStore(object):
     """Object store interface."""
 
-    def determine_wants_all(self, refs):
-        return [sha for (ref, sha) in refs.items()
-                if sha not in self and
-                not ref.endswith(ANNOTATED_TAG_SUFFIX) and
-                not sha == ZERO_SHA]
+    def determine_wants_all(self, refs, depth=None):
+        def _want_deepen(sha):
+            if not depth:
+                return False
+            if depth == DEPTH_INFINITE:
+                return True
+            return depth > self._get_depth(sha)
+
+        return [
+            sha
+            for (ref, sha) in refs.items()
+            if (sha not in self or _want_deepen(sha))
+            and not ref.endswith(ANNOTATED_TAG_SUFFIX)
+            and not sha == ZERO_SHA
+        ]
 
     def iter_shas(self, shas):
         """Iterate over the objects for the specified shas.
@@ -126,9 +138,7 @@ class BaseObjectStore(object):
         raise NotImplementedError(self.__iter__)
 
     def add_object(self, obj):
-        """Add a single object to this object store.
-
-        """
+        """Add a single object to this object store."""
         raise NotImplementedError(self.add_object)
 
     def add_objects(self, objects, progress=None):
@@ -152,17 +162,27 @@ class BaseObjectStore(object):
         f, commit, abort = self.add_pack()
         try:
             write_pack_data(
-                f, count, pack_data, progress,
-                compression_level=self.pack_compression_level)
+                f,
+                count,
+                pack_data,
+                progress,
+                compression_level=self.pack_compression_level,
+            )
         except BaseException:
             abort()
             raise
         else:
             return commit()
 
-    def tree_changes(self, source, target, want_unchanged=False,
-                     include_trees=False, change_type_same=False,
-                     rename_detector=None):
+    def tree_changes(
+        self,
+        source,
+        target,
+        want_unchanged=False,
+        include_trees=False,
+        change_type_same=False,
+        rename_detector=None,
+    ):
         """Find the differences between the contents of two trees
 
         Args:
@@ -175,14 +195,20 @@ class BaseObjectStore(object):
         Returns: Iterator over tuples with
             (oldpath, newpath), (oldmode, newmode), (oldsha, newsha)
         """
-        for change in tree_changes(self, source, target,
-                                   want_unchanged=want_unchanged,
-                                   include_trees=include_trees,
-                                   change_type_same=change_type_same,
-                                   rename_detector=rename_detector):
-            yield ((change.old.path, change.new.path),
-                   (change.old.mode, change.new.mode),
-                   (change.old.sha, change.new.sha))
+        for change in tree_changes(
+            self,
+            source,
+            target,
+            want_unchanged=want_unchanged,
+            include_trees=include_trees,
+            change_type_same=change_type_same,
+            rename_detector=rename_detector,
+        ):
+            yield (
+                (change.old.path, change.new.path),
+                (change.old.mode, change.new.mode),
+                (change.old.sha, change.new.sha),
+            )
 
     def iter_tree_contents(self, tree_id, include_trees=False):
         """Iterate the contents of a tree and all subtrees.
@@ -196,14 +222,21 @@ class BaseObjectStore(object):
             tree.
         """
         for entry, _ in walk_trees(self, tree_id, None):
-            if ((entry.mode is not None and
-                 not stat.S_ISDIR(entry.mode)) or include_trees):
+            if (
+                entry.mode is not None and not stat.S_ISDIR(entry.mode)
+            ) or include_trees:
                 yield entry
 
-    def find_missing_objects(self, haves, wants, shallow=None, progress=None,
-                             get_tagged=None,
-                             get_parents=lambda commit: commit.parents,
-                             depth=None):
+    def find_missing_objects(
+        self,
+        haves,
+        wants,
+        shallow=None,
+        progress=None,
+        get_tagged=None,
+        get_parents=lambda commit: commit.parents,
+        depth=None,
+    ):
         """Find the missing objects required for a set of revisions.
 
         Args:
@@ -218,8 +251,15 @@ class BaseObjectStore(object):
             commit.
         Returns: Iterator over (sha, path) pairs.
         """
-        finder = MissingObjectFinder(self, haves, wants, shallow, progress,
-                                     get_tagged, get_parents=get_parents)
+        finder = MissingObjectFinder(
+            self,
+            haves,
+            wants,
+            shallow,
+            progress,
+            get_tagged,
+            get_parents=get_parents,
+        )
         return iter(finder.next, None)
 
     def find_common_revisions(self, graphwalker):
@@ -250,8 +290,9 @@ class BaseObjectStore(object):
         missing = self.find_missing_objects(have, want, shallow, progress)
         return self.iter_shas(missing)
 
-    def generate_pack_data(self, have, want, shallow=None, progress=None,
-                           ofs_delta=True):
+    def generate_pack_data(
+        self, have, want, shallow=None, progress=None, ofs_delta=True
+    ):
         """Generate pack data objects for a set of wants/haves.
 
         Args:
@@ -263,7 +304,8 @@ class BaseObjectStore(object):
         """
         # TODO(jelmer): More efficient implementation
         return pack_objects_to_data(
-            self.generate_pack_contents(have, want, shallow, progress))
+            self.generate_pack_contents(have, want, shallow, progress)
+        )
 
     def peel_sha(self, sha):
         """Peel all tags from a SHA.
@@ -281,8 +323,13 @@ class BaseObjectStore(object):
             obj = self[sha]
         return obj
 
-    def _collect_ancestors(self, heads, common=set(), shallow=set(),
-                           get_parents=lambda commit: commit.parents):
+    def _collect_ancestors(
+        self,
+        heads,
+        common=set(),
+        shallow=set(),
+        get_parents=lambda commit: commit.parents,
+    ):
         """Collect all ancestors of heads up to (excluding) those in common.
 
         Args:
@@ -311,13 +358,42 @@ class BaseObjectStore(object):
                 queue.extend(get_parents(cmt))
         return (commits, bases)
 
+    def _get_depth(
+        self, head, get_parents=lambda commit: commit.parents, max_depth=None,
+    ):
+        """Return the current available depth for the given head.
+        For commits with multiple parents, the largest possible depth will be
+        returned.
+
+        Args:
+            head: commit to start from
+            get_parents: optional function for getting the parents of a commit
+            max_depth: maximum depth to search
+        """
+        if head not in self:
+            return 0
+        current_depth = 1
+        queue = [(head, current_depth)]
+        while queue and (max_depth is None or current_depth < max_depth):
+            e, depth = queue.pop(0)
+            current_depth = max(current_depth, depth)
+            cmt = self[e]
+            if isinstance(cmt, Tag):
+                _cls, sha = cmt.object
+                cmt = self[sha]
+            queue.extend(
+                (parent, depth + 1)
+                for parent in get_parents(cmt)
+                if parent in self
+            )
+        return current_depth
+
     def close(self):
         """Close any files opened by this object store."""
         # Default implementation is a NO-OP
 
 
 class PackBasedObjectStore(BaseObjectStore):
-
     def __init__(self, pack_compression_level=-1):
         self._pack_cache = {}
         self.pack_compression_level = pack_compression_level
@@ -352,9 +428,7 @@ class PackBasedObjectStore(BaseObjectStore):
         return False
 
     def _add_cached_pack(self, base_name, pack):
-        """Add a newly appeared pack to the cache by path.
-
-        """
+        """Add a newly appeared pack to the cache by path."""
         prev_pack = self._pack_cache.get(base_name)
         if prev_pack is not pack:
             self._pack_cache[base_name] = pack
@@ -380,8 +454,7 @@ class PackBasedObjectStore(BaseObjectStore):
     @property
     def packs(self):
         """List with pack objects."""
-        return (
-            list(self._iter_cached_packs()) + list(self._update_pack_cache()))
+        return list(self._iter_cached_packs()) + list(self._update_pack_cache())
 
     def _iter_alternate_objects(self):
         """Iterate over the SHAs of all the objects in alternate stores."""
@@ -480,7 +553,7 @@ class PackBasedObjectStore(BaseObjectStore):
             sha = name
             hexsha = None
         else:
-            raise AssertionError("Invalid object name %r" % (name, ))
+            raise AssertionError("Invalid object name %r" % (name,))
         for pack in self._iter_cached_packs():
             try:
                 return pack.get_raw(sha)
@@ -513,16 +586,13 @@ class PackBasedObjectStore(BaseObjectStore):
             __len__.
         Returns: Pack object of the objects written.
         """
-        return self.add_pack_data(
-                *pack_objects_to_data(objects),
-                progress=progress)
+        return self.add_pack_data(*pack_objects_to_data(objects), progress=progress)
 
 
 class DiskObjectStore(PackBasedObjectStore):
     """Git-style object store that exists on disk."""
 
-    def __init__(self, path, loose_compression_level=-1,
-                 pack_compression_level=-1):
+    def __init__(self, path, loose_compression_level=-1, pack_compression_level=-1):
         """Open an object store.
 
         Args:
@@ -531,7 +601,8 @@ class DiskObjectStore(PackBasedObjectStore):
           pack_compression_level: zlib compression level for pack objects
         """
         super(DiskObjectStore, self).__init__(
-            pack_compression_level=pack_compression_level)
+            pack_compression_level=pack_compression_level
+        )
         self.path = path
         self.pack_dir = os.path.join(self.path, PACKDIR)
         self._alternates = None
@@ -544,18 +615,21 @@ class DiskObjectStore(PackBasedObjectStore):
     @classmethod
     def from_config(cls, path, config):
         try:
-            default_compression_level = int(config.get(
-                (b'core', ), b'compression').decode())
+            default_compression_level = int(
+                config.get((b"core",), b"compression").decode()
+            )
         except KeyError:
             default_compression_level = -1
         try:
-            loose_compression_level = int(config.get(
-                (b'core', ), b'looseCompression').decode())
+            loose_compression_level = int(
+                config.get((b"core",), b"looseCompression").decode()
+            )
         except KeyError:
             loose_compression_level = default_compression_level
         try:
-            pack_compression_level = int(config.get(
-                (b'core', ), 'packCompression').decode())
+            pack_compression_level = int(
+                config.get((b"core",), "packCompression").decode()
+            )
         except KeyError:
             pack_compression_level = default_compression_level
         return cls(path, loose_compression_level, pack_compression_level)
@@ -571,7 +645,7 @@ class DiskObjectStore(PackBasedObjectStore):
 
     def _read_alternate_paths(self):
         try:
-            f = GitFile(os.path.join(self.path, INFODIR, "alternates"), 'rb')
+            f = GitFile(os.path.join(self.path, INFODIR, "alternates"), "rb")
         except FileNotFoundError:
             return
         with f:
@@ -582,20 +656,18 @@ class DiskObjectStore(PackBasedObjectStore):
                 if os.path.isabs(line):
                     yield os.fsdecode(line)
                 else:
-                    yield os.fsdecode(os.path.join(os.fsencode(self.path),
-                                                   line))
+                    yield os.fsdecode(os.path.join(os.fsencode(self.path), line))
 
     def add_alternate_path(self, path):
-        """Add an alternate path to this object store.
-        """
+        """Add an alternate path to this object store."""
         try:
             os.mkdir(os.path.join(self.path, INFODIR))
         except FileExistsError:
             pass
         alternates_path = os.path.join(self.path, INFODIR, "alternates")
-        with GitFile(alternates_path, 'wb') as f:
+        with GitFile(alternates_path, "wb") as f:
             try:
-                orig_f = open(alternates_path, 'rb')
+                orig_f = open(alternates_path, "rb")
             except FileNotFoundError:
                 pass
             else:
@@ -621,7 +693,7 @@ class DiskObjectStore(PackBasedObjectStore):
                 # fully written)
                 idx_name = os.path.splitext(name)[0] + ".idx"
                 if idx_name in pack_dir_contents:
-                    pack_name = name[:-len(".pack")]
+                    pack_name = name[: -len(".pack")]
                     pack_files.add(pack_name)
 
         # Open newly appeared pack files
@@ -645,7 +717,7 @@ class DiskObjectStore(PackBasedObjectStore):
             if len(base) != 2:
                 continue
             for rest in os.listdir(os.path.join(self.path, base)):
-                sha = os.fsencode(base+rest)
+                sha = os.fsencode(base + rest)
                 if not valid_hexsha(sha):
                     continue
                 yield sha
@@ -672,7 +744,7 @@ class DiskObjectStore(PackBasedObjectStore):
     def _get_pack_basepath(self, entries):
         suffix = iter_sha1(entry[0] for entry in entries)
         # TODO: Handle self.pack_dir being bytes
-        suffix = suffix.decode('ascii')
+        suffix = suffix.decode("ascii")
         return os.path.join(self.pack_dir, "pack-" + suffix)
 
     def _complete_thin_pack(self, f, path, copier, indexer):
@@ -708,8 +780,12 @@ class DiskObjectStore(PackBasedObjectStore):
             type_num, data = self.get_raw(ext_sha)
             offset = f.tell()
             crc32 = write_pack_object(
-                f, type_num, data, sha=new_sha,
-                compression_level=self.pack_compression_level)
+                f,
+                type_num,
+                data,
+                sha=new_sha,
+                compression_level=self.pack_compression_level,
+            )
             entries.append((ext_sha, offset, crc32))
         pack_sha = new_sha.digest()
         f.write(pack_sha)
@@ -718,8 +794,8 @@ class DiskObjectStore(PackBasedObjectStore):
         # Move the pack in.
         entries.sort()
         pack_base_name = self._get_pack_basepath(entries)
-        target_pack = pack_base_name + '.pack'
-        if sys.platform == 'win32':
+        target_pack = pack_base_name + ".pack"
+        if sys.platform == "win32":
             # Windows might have the target pack file lingering. Attempt
             # removal, silently passing if the target does not exist.
             try:
@@ -729,7 +805,7 @@ class DiskObjectStore(PackBasedObjectStore):
         os.rename(path, target_pack)
 
         # Write the index.
-        index_file = GitFile(pack_base_name + '.idx', 'wb')
+        index_file = GitFile(pack_base_name + ".idx", "wb")
         try:
             write_pack_index_v2(index_file, entries, pack_sha)
             index_file.close()
@@ -758,11 +834,11 @@ class DiskObjectStore(PackBasedObjectStore):
             objects/pack directory.
         """
         import tempfile
-        fd, path = tempfile.mkstemp(dir=self.path, prefix='tmp_pack_')
-        with os.fdopen(fd, 'w+b') as f:
+
+        fd, path = tempfile.mkstemp(dir=self.path, prefix="tmp_pack_")
+        with os.fdopen(fd, "w+b") as f:
             indexer = PackIndexer(f, resolve_ext_ref=self.get_raw)
-            copier = PackStreamCopier(read_all, read_some, f,
-                                      delta_iter=indexer)
+            copier = PackStreamCopier(read_all, read_some, f, delta_iter=indexer)
             copier.verify()
             return self._complete_thin_pack(f, path, copier, indexer)
 
@@ -785,8 +861,8 @@ class DiskObjectStore(PackBasedObjectStore):
         for pack in self.packs:
             if pack._basename == basename:
                 return pack
-        target_pack = basename + '.pack'
-        if sys.platform == 'win32':
+        target_pack = basename + ".pack"
+        if sys.platform == "win32":
             # Windows might have the target pack file lingering. Attempt
             # removal, silently passing if the target does not exist.
             try:
@@ -806,8 +882,9 @@ class DiskObjectStore(PackBasedObjectStore):
             function.
         """
         import tempfile
+
         fd, path = tempfile.mkstemp(dir=self.pack_dir, suffix=".pack")
-        f = os.fdopen(fd, 'wb')
+        f = os.fdopen(fd, "wb")
 
         def commit():
             f.flush()
@@ -822,6 +899,7 @@ class DiskObjectStore(PackBasedObjectStore):
         def abort():
             f.close()
             os.remove(path)
+
         return f, commit, abort
 
     def add_object(self, obj):
@@ -838,9 +916,10 @@ class DiskObjectStore(PackBasedObjectStore):
             pass
         if os.path.exists(path):
             return  # Already there, no need to write again
-        with GitFile(path, 'wb') as f:
-            f.write(obj.as_legacy_object(
-                compression_level=self.loose_compression_level))
+        with GitFile(path, "wb") as f:
+            f.write(
+                obj.as_legacy_object(compression_level=self.loose_compression_level)
+            )
 
     @classmethod
     def init(cls, path):
@@ -904,9 +983,7 @@ class MemoryObjectStore(BaseObjectStore):
         del self._data[self._to_hexsha(name)]
 
     def add_object(self, obj):
-        """Add a single object to this object store.
-
-        """
+        """Add a single object to this object store."""
         self._data[obj.id] = obj.copy()
 
     def add_objects(self, objects, progress=None):
@@ -937,6 +1014,7 @@ class MemoryObjectStore(BaseObjectStore):
 
         def abort():
             pass
+
         return f, commit, abort
 
     def _complete_thin_pack(self, f, indexer):
@@ -959,8 +1037,7 @@ class MemoryObjectStore(BaseObjectStore):
         for ext_sha in indexer.ext_refs():
             assert len(ext_sha) == 20
             type_num, data = self.get_raw(ext_sha)
-            write_pack_object(
-                f, type_num, data, sha=new_sha)
+            write_pack_object(f, type_num, data, sha=new_sha)
         pack_sha = new_sha.digest()
         f.write(pack_sha)
 
@@ -980,8 +1057,7 @@ class MemoryObjectStore(BaseObjectStore):
         f, commit, abort = self.add_pack()
         try:
             indexer = PackIndexer(f, resolve_ext_ref=self.get_raw)
-            copier = PackStreamCopier(read_all, read_some, f,
-                                      delta_iter=indexer)
+            copier = PackStreamCopier(read_all, read_some, f, delta_iter=indexer)
             copier.verify()
             self._complete_thin_pack(f, indexer)
         except BaseException:
@@ -1059,7 +1135,8 @@ class ObjectStoreIterator(ObjectIterator):
 
     def empty(self):
         import warnings
-        warnings.warn('Use bool() instead.', DeprecationWarning)
+
+        warnings.warn("Use bool() instead.", DeprecationWarning)
         return self._empty()
 
     def _empty(self):
@@ -1138,7 +1215,8 @@ def _split_commits_and_tags(obj_store, lst, ignore_unknown=False):
                 tags.add(e)
                 tagged = o.object[1]
                 c, t, o = _split_commits_and_tags(
-                    obj_store, [tagged], ignore_unknown=ignore_unknown)
+                    obj_store, [tagged], ignore_unknown=ignore_unknown
+                )
                 commits |= c
                 tags |= t
                 others |= o
@@ -1162,8 +1240,16 @@ class MissingObjectFinder(object):
       tagged: dict of pointed-to sha -> tag sha for including tags
     """
 
-    def __init__(self, object_store, haves, wants, shallow=None, progress=None,
-                 get_tagged=None, get_parents=lambda commit: commit.parents):
+    def __init__(
+        self,
+        object_store,
+        haves,
+        wants,
+        shallow=None,
+        progress=None,
+        get_tagged=None,
+        get_parents=lambda commit: commit.parents,
+    ):
         self.object_store = object_store
         if shallow is None:
             shallow = set()
@@ -1173,20 +1259,26 @@ class MissingObjectFinder(object):
         # and such SHAs would get filtered out by _split_commits_and_tags,
         # wants shall list only known SHAs, and otherwise
         # _split_commits_and_tags fails with KeyError
-        have_commits, have_tags, have_others = (
-            _split_commits_and_tags(object_store, haves, True))
-        want_commits, want_tags, want_others = (
-            _split_commits_and_tags(object_store, wants, False))
+        have_commits, have_tags, have_others = _split_commits_and_tags(
+            object_store, haves, True
+        )
+        want_commits, want_tags, want_others = _split_commits_and_tags(
+            object_store, wants, False
+        )
         # all_ancestors is a set of commits that shall not be sent
         # (complete repository up to 'haves')
         all_ancestors = object_store._collect_ancestors(
-            have_commits, shallow=shallow, get_parents=self._get_parents)[0]
+            have_commits, shallow=shallow, get_parents=self._get_parents
+        )[0]
         # all_missing - complete set of commits between haves and wants
         # common - commits from all_ancestors we hit into while
         # traversing parent hierarchy of wants
         missing_commits, common_commits = object_store._collect_ancestors(
-            want_commits, all_ancestors, shallow=shallow,
-            get_parents=self._get_parents)
+            want_commits,
+            all_ancestors,
+            shallow=shallow,
+            get_parents=self._get_parents,
+        )
         self.sha_done = set()
         # Now, fill sha_done with commits and revisions of
         # files and directories known to be both locally
@@ -1216,8 +1308,7 @@ class MissingObjectFinder(object):
         self._tagged = get_tagged and get_tagged() or {}
 
     def add_todo(self, entries):
-        self.objects_to_send.update([e for e in entries
-                                     if not e[0] in self.sha_done])
+        self.objects_to_send.update([e for e in entries if not e[0] in self.sha_done])
 
     def next(self):
         while True:
@@ -1231,16 +1322,19 @@ class MissingObjectFinder(object):
             if isinstance(o, Commit):
                 self.add_todo([(o.tree, "", False)])
             elif isinstance(o, Tree):
-                self.add_todo([(s, n, not stat.S_ISDIR(m))
-                               for n, m, s in o.iteritems()
-                               if not S_ISGITLINK(m)])
+                self.add_todo(
+                    [
+                        (s, n, not stat.S_ISDIR(m))
+                        for n, m, s in o.iteritems()
+                        if not S_ISGITLINK(m)
+                    ]
+                )
             elif isinstance(o, Tag):
                 self.add_todo([(o.object[1], None, False)])
         if sha in self._tagged:
             self.add_todo([(self._tagged[sha], None, True)])
         self.sha_done.add(sha)
-        self.progress(("counting objects: %d\r" %
-                       len(self.sha_done)).encode('ascii'))
+        self.progress(("counting objects: %d\r" % len(self.sha_done)).encode("ascii"))
         return (sha, name)
 
     __next__ = next
@@ -1297,10 +1391,12 @@ class ObjectStoreGraphWalker(object):
         """Iterate over ancestors of heads in the target."""
         if self.heads:
             ret = self.heads.pop()
-            ps = self.get_parents(ret)
+            try:
+                ps = self.get_parents(ret)
+            except KeyError:
+                return None
             self.parents[ret] = ps
-            self.heads.update(
-                [p for p in ps if p not in self.parents])
+            self.heads.update([p for p in ps if p not in self.parents])
             return ret
         return None
 
@@ -1333,15 +1429,14 @@ def commit_tree_changes(object_store, tree, changes):
     nested_changes = {}
     for (path, new_mode, new_sha) in changes:
         try:
-            (dirname, subpath) = path.split(b'/', 1)
+            (dirname, subpath) = path.split(b"/", 1)
         except ValueError:
             if new_sha is None:
                 del tree[path]
             else:
                 tree[path] = (new_mode, new_sha)
         else:
-            nested_changes.setdefault(dirname, []).append(
-                (subpath, new_mode, new_sha))
+            nested_changes.setdefault(dirname, []).append((subpath, new_mode, new_sha))
     for name, subchanges in nested_changes.items():
         try:
             orig_subtree = object_store[tree[name][1]]
@@ -1418,3 +1513,85 @@ def read_packs_file(f):
         if kind != b"P":
             continue
         yield os.fsdecode(name)
+
+
+class BucketBasedObjectStore(PackBasedObjectStore):
+    """Object store implementation that uses a bucket store like S3 as backend.
+    """
+
+    def _iter_loose_objects(self):
+        """Iterate over the SHAs of all loose objects."""
+        return iter([])
+
+    def _get_loose_object(self, sha):
+        return None
+
+    def _remove_loose_object(self, sha):
+        # Doesn't exist..
+        pass
+
+    def _remove_pack(self, name):
+        raise NotImplementedError(self._remove_pack)
+
+    def _iter_pack_names(self):
+        raise NotImplementedError(self._iter_pack_names)
+
+    def _get_pack(self, name):
+        raise NotImplementedError(self._get_pack)
+
+    def _update_pack_cache(self):
+        pack_files = set(self._iter_pack_names())
+
+        # Open newly appeared pack files
+        new_packs = []
+        for f in pack_files:
+            if f not in self._pack_cache:
+                pack = self._get_pack(f)
+                new_packs.append(pack)
+                self._pack_cache[f] = pack
+        # Remove disappeared pack files
+        for f in set(self._pack_cache) - pack_files:
+            self._pack_cache.pop(f).close()
+        return new_packs
+
+    def _upload_pack(self, basename, pack_file, index_file):
+        raise NotImplementedError
+
+    def add_pack(self):
+        """Add a new pack to this object store.
+
+        Returns: Fileobject to write to, a commit function to
+            call when the pack is finished and an abort
+            function.
+        """
+        import tempfile
+
+        pf = tempfile.SpooledTemporaryFile()
+
+        def commit():
+            if pf.tell() == 0:
+                pf.close()
+                return None
+
+            pf.seek(0)
+            p = PackData(pf.name, pf)
+            entries = p.sorted_entries()
+            basename = iter_sha1(entry[0] for entry in entries).decode('ascii')
+            idxf = tempfile.SpooledTemporaryFile()
+            checksum = p.get_stored_checksum()
+            write_pack_index_v2(idxf, entries, checksum)
+            idxf.seek(0)
+            idx = load_pack_index_file(basename + '.idx', idxf)
+            for pack in self.packs:
+                if pack.get_stored_checksum() == p.get_stored_checksum():
+                    p.close()
+                    idx.close()
+                    return pack
+            pf.seek(0)
+            idxf.seek(0)
+            self._upload_pack(basename, pf, idxf)
+            final_pack = Pack.from_objects(p, idx)
+            self._add_cached_pack(basename, final_pack)
+            return final_pack
+
+        return pf, commit, pf.close
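
A rough sketch exercising the new depth tracking on an in-memory store (_get_depth() is internal, so this is illustrative only; the commit fields below are the minimum needed to serialize):

    from dulwich.object_store import MemoryObjectStore
    from dulwich.objects import Commit, Tree

    store = MemoryObjectStore()
    tree = Tree()
    store.add_object(tree)

    def make_commit(parents, message):
        c = Commit()
        c.tree = tree.id
        c.parents = parents
        c.author = c.committer = b"Example <example@example.com>"
        c.author_time = c.commit_time = 0
        c.author_timezone = c.commit_timezone = 0
        c.message = message
        store.add_object(c)
        return c.id

    root = make_commit([], b"root")
    tip = make_commit([root], b"tip")
    # Both commits are present locally, so the available depth behind tip is 2
    assert store._get_depth(tip) == 2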

File diff too large to display
+ 330 - 213
dulwich/objects.py


+ 8 - 5
dulwich/objectspec.py

@@ -20,10 +20,12 @@
 
 """Object specification."""
 
+from typing import Union, List, Tuple
+
 
 def to_bytes(text):
     if getattr(text, "encode", None) is not None:
-        text = text.encode('ascii')
+        text = text.encode("ascii")
     return text
 
 
@@ -75,7 +77,7 @@ def parse_ref(container, refspec):
         b"refs/tags/" + refspec,
         b"refs/heads/" + refspec,
         b"refs/remotes/" + refspec,
-        b"refs/remotes/" + refspec + b"/HEAD"
+        b"refs/remotes/" + refspec + b"/HEAD",
     ]
     for ref in possible_refs:
         if ref in container:
@@ -119,7 +121,9 @@ def parse_reftuple(lh_container, rh_container, refspec, force=False):
 
 
 def parse_reftuples(
-        lh_container, rh_container, refspecs, force=False):
+        lh_container, rh_container,
+        refspecs: Union[bytes, List[bytes], List[Tuple[bytes, bytes]]],
+        force: bool = False):
     """Parse a list of reftuple specs to a list of reftuples.
 
     Args:
@@ -136,8 +140,7 @@ def parse_reftuples(
     ret = []
     # TODO: Support * in refspecs
     for refspec in refspecs:
-        ret.append(parse_reftuple(
-            lh_container, rh_container, refspec, force=force))
+        ret.append(parse_reftuple(lh_container, rh_container, refspec, force=force))
     return ret
 
 

File diff too large to display
+ 230 - 189
dulwich/pack.py


+ 100 - 69
dulwich/patch.py

@@ -32,13 +32,12 @@ from dulwich.objects import (
     Blob,
     Commit,
     S_ISGITLINK,
-    )
+)
 
 FIRST_FEW_BYTES = 8000
 
 
-def write_commit_patch(f, commit, contents, progress, version=None,
-                       encoding=None):
+def write_commit_patch(f, commit, contents, progress, version=None, encoding=None):
     """Write a individual file patch.
 
     Args:
@@ -51,19 +50,30 @@ def write_commit_patch(f, commit, contents, progress, version=None,
     if isinstance(contents, str):
         contents = contents.encode(encoding)
     (num, total) = progress
-    f.write(b"From " + commit.id + b" " +
-            time.ctime(commit.commit_time).encode(encoding) + b"\n")
+    f.write(
+        b"From "
+        + commit.id
+        + b" "
+        + time.ctime(commit.commit_time).encode(encoding)
+        + b"\n"
+    )
     f.write(b"From: " + commit.author + b"\n")
-    f.write(b"Date: " +
-            time.strftime("%a, %d %b %Y %H:%M:%S %Z").encode(encoding) + b"\n")
-    f.write(("Subject: [PATCH %d/%d] " % (num, total)).encode(encoding) +
-            commit.message + b"\n")
+    f.write(
+        b"Date: " + time.strftime("%a, %d %b %Y %H:%M:%S %Z").encode(encoding) + b"\n"
+    )
+    f.write(
+        ("Subject: [PATCH %d/%d] " % (num, total)).encode(encoding)
+        + commit.message
+        + b"\n"
+    )
     f.write(b"\n")
     f.write(b"---\n")
     try:
         import subprocess
-        p = subprocess.Popen(["diffstat"], stdout=subprocess.PIPE,
-                             stdin=subprocess.PIPE)
+
+        p = subprocess.Popen(
+            ["diffstat"], stdout=subprocess.PIPE, stdin=subprocess.PIPE
+        )
     except (ImportError, OSError):
         pass  # diffstat not available?
     else:
@@ -74,6 +84,7 @@ def write_commit_patch(f, commit, contents, progress, version=None,
     f.write(b"-- \n")
     if version is None:
         from dulwich import __version__ as dulwich_version
+
         f.write(b"Dulwich %d.%d.%d\n" % dulwich_version)
     else:
         f.write(version.encode(encoding) + b"\n")
@@ -86,7 +97,7 @@ def get_summary(commit):
       commit: Commit
     Returns: Summary string
     """
-    decoded = commit.message.decode(errors='replace')
+    decoded = commit.message.decode(errors="replace")
     return decoded.splitlines()[0].replace(" ", "-")
 
 
@@ -97,15 +108,24 @@ def _format_range_unified(start, stop):
     beginning = start + 1  # lines start numbering with one
     length = stop - start
     if length == 1:
-        return '{}'.format(beginning)
+        return "{}".format(beginning)
     if not length:
         beginning -= 1  # empty ranges begin at line just before the range
-    return '{},{}'.format(beginning, length)
-
-
-def unified_diff(a, b, fromfile='', tofile='', fromfiledate='',
-                 tofiledate='', n=3, lineterm='\n', tree_encoding='utf-8',
-                 output_encoding='utf-8'):
+    return "{},{}".format(beginning, length)
+
+
+def unified_diff(
+    a,
+    b,
+    fromfile="",
+    tofile="",
+    fromfiledate="",
+    tofiledate="",
+    n=3,
+    lineterm="\n",
+    tree_encoding="utf-8",
+    output_encoding="utf-8",
+):
     """difflib.unified_diff that can detect "No newline at end of file" as
     original "git diff" does.
 
@@ -115,43 +135,37 @@ def unified_diff(a, b, fromfile='', tofile='', fromfiledate='',
     for group in SequenceMatcher(None, a, b).get_grouped_opcodes(n):
         if not started:
             started = True
-            fromdate = '\t{}'.format(fromfiledate) if fromfiledate else ''
-            todate = '\t{}'.format(tofiledate) if tofiledate else ''
-            yield '--- {}{}{}'.format(
-                fromfile.decode(tree_encoding),
-                fromdate,
-                lineterm
-                ).encode(output_encoding)
-            yield '+++ {}{}{}'.format(
-                tofile.decode(tree_encoding),
-                todate,
-                lineterm
-                ).encode(output_encoding)
+            fromdate = "\t{}".format(fromfiledate) if fromfiledate else ""
+            todate = "\t{}".format(tofiledate) if tofiledate else ""
+            yield "--- {}{}{}".format(
+                fromfile.decode(tree_encoding), fromdate, lineterm
+            ).encode(output_encoding)
+            yield "+++ {}{}{}".format(
+                tofile.decode(tree_encoding), todate, lineterm
+            ).encode(output_encoding)
 
         first, last = group[0], group[-1]
         file1_range = _format_range_unified(first[1], last[2])
         file2_range = _format_range_unified(first[3], last[4])
-        yield '@@ -{} +{} @@{}'.format(
-            file1_range,
-            file2_range,
-            lineterm
-             ).encode(output_encoding)
+        yield "@@ -{} +{} @@{}".format(file1_range, file2_range, lineterm).encode(
+            output_encoding
+        )
 
         for tag, i1, i2, j1, j2 in group:
-            if tag == 'equal':
+            if tag == "equal":
                 for line in a[i1:i2]:
-                    yield b' ' + line
+                    yield b" " + line
                 continue
-            if tag in ('replace', 'delete'):
+            if tag in ("replace", "delete"):
                 for line in a[i1:i2]:
-                    if not line[-1:] == b'\n':
-                        line += b'\n\\ No newline at end of file\n'
-                    yield b'-' + line
-            if tag in ('replace', 'insert'):
+                    if not line[-1:] == b"\n":
+                        line += b"\n\\ No newline at end of file\n"
+                    yield b"-" + line
+            if tag in ("replace", "insert"):
                 for line in b[j1:j2]:
-                    if not line[-1:] == b'\n':
-                        line += b'\n\\ No newline at end of file\n'
-                    yield b'+' + line
+                    if not line[-1:] == b"\n":
+                        line += b"\n\\ No newline at end of file\n"
+                    yield b"+" + line
 
 
 def is_binary(content):
@@ -160,7 +174,7 @@ def is_binary(content):
     Args:
       content: Bytestring to check for binary content
     """
-    return b'\0' in content[:FIRST_FEW_BYTES]
+    return b"\0" in content[:FIRST_FEW_BYTES]
 
 
 def shortid(hexsha):
@@ -197,7 +211,7 @@ def write_object_diff(f, store, old_file, new_file, diff_binary=False):
 
     def content(mode, hexsha):
         if hexsha is None:
-            return Blob.from_string(b'')
+            return Blob.from_string(b"")
         elif S_ISGITLINK(mode):
             return Blob.from_string(b"Subproject commit " + hexsha + b"\n")
         else:
@@ -208,12 +222,13 @@ def write_object_diff(f, store, old_file, new_file, diff_binary=False):
             return []
         else:
             return content.splitlines()
-    f.writelines(gen_diff_header(
-        (old_path, new_path), (old_mode, new_mode), (old_id, new_id)))
+
+    f.writelines(
+        gen_diff_header((old_path, new_path), (old_mode, new_mode), (old_id, new_id))
+    )
     old_content = content(old_mode, old_id)
     new_content = content(new_mode, new_id)
-    if not diff_binary and (
-            is_binary(old_content.data) or is_binary(new_content.data)):
+    if not diff_binary and (is_binary(old_content.data) or is_binary(new_content.data)):
         binary_diff = (
             b"Binary files "
             + patched_old_path
@@ -223,8 +238,14 @@ def write_object_diff(f, store, old_file, new_file, diff_binary=False):
         )
         f.write(binary_diff)
     else:
-        f.writelines(unified_diff(lines(old_content), lines(new_content),
-                     patched_old_path, patched_new_path))
+        f.writelines(
+            unified_diff(
+                lines(old_content),
+                lines(new_content),
+                patched_old_path,
+                patched_new_path,
+            )
+        )
 
 
 # TODO(jelmer): Support writing unicode, rather than bytes.
@@ -250,13 +271,13 @@ def gen_diff_header(paths, modes, shas):
     if old_mode != new_mode:
         if new_mode is not None:
             if old_mode is not None:
-                yield ("old file mode %o\n" % old_mode).encode('ascii')
-            yield ("new file mode %o\n" % new_mode).encode('ascii')
+                yield ("old file mode %o\n" % old_mode).encode("ascii")
+            yield ("new file mode %o\n" % new_mode).encode("ascii")
         else:
-            yield ("deleted file mode %o\n" % old_mode).encode('ascii')
+            yield ("deleted file mode %o\n" % old_mode).encode("ascii")
     yield b"index " + shortid(old_sha) + b".." + shortid(new_sha)
     if new_mode is not None and old_mode is not None:
-        yield (" %o" % new_mode).encode('ascii')
+        yield (" %o" % new_mode).encode("ascii")
     yield b"\n"
 
 
@@ -281,13 +302,19 @@ def write_blob_diff(f, old_file, new_file):
             return blob.splitlines()
         else:
             return []
-    f.writelines(gen_diff_header(
-        (old_path, new_path), (old_mode, new_mode),
-        (getattr(old_blob, "id", None), getattr(new_blob, "id", None))))
+
+    f.writelines(
+        gen_diff_header(
+            (old_path, new_path),
+            (old_mode, new_mode),
+            (getattr(old_blob, "id", None), getattr(new_blob, "id", None)),
+        )
+    )
     old_contents = lines(old_blob)
     new_contents = lines(new_blob)
-    f.writelines(unified_diff(old_contents, new_contents,
-                 patched_old_path, patched_new_path))
+    f.writelines(
+        unified_diff(old_contents, new_contents, patched_old_path, patched_new_path)
+    )
 
 
 def write_tree_diff(f, store, old_tree, new_tree, diff_binary=False):
@@ -302,8 +329,13 @@ def write_tree_diff(f, store, old_tree, new_tree, diff_binary=False):
     """
     changes = store.tree_changes(old_tree, new_tree)
     for (oldpath, newpath), (oldmode, newmode), (oldsha, newsha) in changes:
-        write_object_diff(f, store, (oldpath, oldmode, oldsha),
-                          (newpath, newmode, newsha), diff_binary=diff_binary)
+        write_object_diff(
+            f,
+            store,
+            (oldpath, oldmode, oldsha),
+            (newpath, newmode, newsha),
+            diff_binary=diff_binary,
+        )
 
 
 def git_am_patch_split(f, encoding=None):
@@ -317,8 +349,7 @@ def git_am_patch_split(f, encoding=None):
     encoding = encoding or getattr(f, "encoding", "ascii")
     encoding = encoding or "ascii"
     contents = f.read()
-    if (isinstance(contents, bytes) and
-            getattr(email.parser, "BytesParser", None)):
+    if isinstance(contents, bytes) and getattr(email.parser, "BytesParser", None):
         parser = email.parser.BytesParser()
         msg = parser.parsebytes(contents)
     else:
@@ -344,7 +375,7 @@ def parse_patch_message(msg, encoding=None):
         subject = msg["subject"]
     else:
         close = msg["subject"].index("] ", patch_tag_start)
-        subject = msg["subject"][close+2:]
+        subject = msg["subject"][close + 2 :]
     c.message = (subject.replace("\n", "") + "\n").encode(encoding)
     first = True
 
@@ -357,7 +388,7 @@ def parse_patch_message(msg, encoding=None):
             break
         if first:
             if line.startswith(b"From: "):
-                c.author = line[len(b"From: "):].rstrip()
+                c.author = line[len(b"From: ") :].rstrip()
             else:
                 c.message += b"\n" + line
             first = False
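
A minimal sketch of the blob-diff path above (paths and contents are illustrative; write_blob_diff() adds the a/ and b/ prefixes itself):

    from io import BytesIO

    from dulwich.objects import Blob
    from dulwich.patch import write_blob_diff

    buf = BytesIO()
    old = Blob.from_string(b"hello\n")
    new = Blob.from_string(b"hello world\n")
    write_blob_diff(buf, (b"hello.txt", 0o100644, old),
                    (b"hello.txt", 0o100644, new))
    print(buf.getvalue().decode("ascii"))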

File diff too large to display
+ 320 - 186
dulwich/porcelain.py


+ 84 - 75
dulwich/protocol.py

@@ -24,14 +24,14 @@
 from io import BytesIO
 from os import (
     SEEK_END,
-    )
+)
 import socket
 
 import dulwich
 from dulwich.errors import (
     HangupException,
     GitProtocolError,
-    )
+)
 
 TCP_GIT_PORT = 9418
 
@@ -48,77 +48,86 @@ SIDE_BAND_CHANNEL_PROGRESS = 2
 # fatal error message just before stream aborts
 SIDE_BAND_CHANNEL_FATAL = 3
 
-CAPABILITY_ATOMIC = b'atomic'
-CAPABILITY_DEEPEN_SINCE = b'deepen-since'
-CAPABILITY_DEEPEN_NOT = b'deepen-not'
-CAPABILITY_DEEPEN_RELATIVE = b'deepen-relative'
-CAPABILITY_DELETE_REFS = b'delete-refs'
-CAPABILITY_INCLUDE_TAG = b'include-tag'
-CAPABILITY_MULTI_ACK = b'multi_ack'
-CAPABILITY_MULTI_ACK_DETAILED = b'multi_ack_detailed'
-CAPABILITY_NO_DONE = b'no-done'
-CAPABILITY_NO_PROGRESS = b'no-progress'
-CAPABILITY_OFS_DELTA = b'ofs-delta'
-CAPABILITY_QUIET = b'quiet'
-CAPABILITY_REPORT_STATUS = b'report-status'
-CAPABILITY_SHALLOW = b'shallow'
-CAPABILITY_SIDE_BAND = b'side-band'
-CAPABILITY_SIDE_BAND_64K = b'side-band-64k'
-CAPABILITY_THIN_PACK = b'thin-pack'
-CAPABILITY_AGENT = b'agent'
-CAPABILITY_SYMREF = b'symref'
-CAPABILITY_ALLOW_TIP_SHA1_IN_WANT = b'allow-tip-sha1-in-want'
-CAPABILITY_ALLOW_REACHABLE_SHA1_IN_WANT = b'allow-reachable-sha1-in-want'
+CAPABILITY_ATOMIC = b"atomic"
+CAPABILITY_DEEPEN_SINCE = b"deepen-since"
+CAPABILITY_DEEPEN_NOT = b"deepen-not"
+CAPABILITY_DEEPEN_RELATIVE = b"deepen-relative"
+CAPABILITY_DELETE_REFS = b"delete-refs"
+CAPABILITY_INCLUDE_TAG = b"include-tag"
+CAPABILITY_MULTI_ACK = b"multi_ack"
+CAPABILITY_MULTI_ACK_DETAILED = b"multi_ack_detailed"
+CAPABILITY_NO_DONE = b"no-done"
+CAPABILITY_NO_PROGRESS = b"no-progress"
+CAPABILITY_OFS_DELTA = b"ofs-delta"
+CAPABILITY_QUIET = b"quiet"
+CAPABILITY_REPORT_STATUS = b"report-status"
+CAPABILITY_SHALLOW = b"shallow"
+CAPABILITY_SIDE_BAND = b"side-band"
+CAPABILITY_SIDE_BAND_64K = b"side-band-64k"
+CAPABILITY_THIN_PACK = b"thin-pack"
+CAPABILITY_AGENT = b"agent"
+CAPABILITY_SYMREF = b"symref"
+CAPABILITY_ALLOW_TIP_SHA1_IN_WANT = b"allow-tip-sha1-in-want"
+CAPABILITY_ALLOW_REACHABLE_SHA1_IN_WANT = b"allow-reachable-sha1-in-want"
 
 # Magic ref that is used to attach capabilities to when
 # there are no refs. Should always be set to ZERO_SHA.
-CAPABILITIES_REF = b'capabilities^{}'
+CAPABILITIES_REF = b"capabilities^{}"
 
 COMMON_CAPABILITIES = [
     CAPABILITY_OFS_DELTA,
     CAPABILITY_SIDE_BAND,
     CAPABILITY_SIDE_BAND_64K,
     CAPABILITY_AGENT,
-    CAPABILITY_NO_PROGRESS]
-KNOWN_UPLOAD_CAPABILITIES = set(COMMON_CAPABILITIES + [
-    CAPABILITY_THIN_PACK,
-    CAPABILITY_MULTI_ACK,
-    CAPABILITY_MULTI_ACK_DETAILED,
-    CAPABILITY_INCLUDE_TAG,
-    CAPABILITY_DEEPEN_SINCE,
-    CAPABILITY_SYMREF,
-    CAPABILITY_SHALLOW,
-    CAPABILITY_DEEPEN_NOT,
-    CAPABILITY_DEEPEN_RELATIVE,
-    CAPABILITY_ALLOW_TIP_SHA1_IN_WANT,
-    CAPABILITY_ALLOW_REACHABLE_SHA1_IN_WANT,
-    ])
-KNOWN_RECEIVE_CAPABILITIES = set(COMMON_CAPABILITIES + [
-    CAPABILITY_REPORT_STATUS,
-    CAPABILITY_DELETE_REFS,
-    CAPABILITY_QUIET,
-    CAPABILITY_ATOMIC,
-    ])
+    CAPABILITY_NO_PROGRESS,
+]
+KNOWN_UPLOAD_CAPABILITIES = set(
+    COMMON_CAPABILITIES
+    + [
+        CAPABILITY_THIN_PACK,
+        CAPABILITY_MULTI_ACK,
+        CAPABILITY_MULTI_ACK_DETAILED,
+        CAPABILITY_INCLUDE_TAG,
+        CAPABILITY_DEEPEN_SINCE,
+        CAPABILITY_SYMREF,
+        CAPABILITY_SHALLOW,
+        CAPABILITY_DEEPEN_NOT,
+        CAPABILITY_DEEPEN_RELATIVE,
+        CAPABILITY_ALLOW_TIP_SHA1_IN_WANT,
+        CAPABILITY_ALLOW_REACHABLE_SHA1_IN_WANT,
+    ]
+)
+KNOWN_RECEIVE_CAPABILITIES = set(
+    COMMON_CAPABILITIES
+    + [
+        CAPABILITY_REPORT_STATUS,
+        CAPABILITY_DELETE_REFS,
+        CAPABILITY_QUIET,
+        CAPABILITY_ATOMIC,
+    ]
+)
+
+DEPTH_INFINITE = 0x7FFFFFFF
 
 
 def agent_string():
-    return ('dulwich/%d.%d.%d' % dulwich.__version__).encode('ascii')
+    return ("dulwich/%d.%d.%d" % dulwich.__version__).encode("ascii")
 
 
 def capability_agent():
-    return CAPABILITY_AGENT + b'=' + agent_string()
+    return CAPABILITY_AGENT + b"=" + agent_string()
 
 
 def capability_symref(from_ref, to_ref):
-    return CAPABILITY_SYMREF + b'=' + from_ref + b':' + to_ref
+    return CAPABILITY_SYMREF + b"=" + from_ref + b":" + to_ref
 
 
 def extract_capability_names(capabilities):
-    return set(parse_capability(c)[0] for c in capabilities)
+    return {parse_capability(c)[0] for c in capabilities}
 
 
 def parse_capability(capability):
-    parts = capability.split(b'=', 1)
+    parts = capability.split(b"=", 1)
     if len(parts) == 1:
         return (parts[0], None)
     return tuple(parts)
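
As a reviewer aid: the capability parsing above splits on the first "=" only.
A small sketch of the expected behaviour (illustrative, not part of the diff):

    from dulwich.protocol import extract_capability_names, parse_capability

    assert parse_capability(b"ofs-delta") == (b"ofs-delta", None)
    assert parse_capability(b"symref=HEAD:refs/heads/master") == (
        b"symref",
        b"HEAD:refs/heads/master",
    )
    assert extract_capability_names(
        [b"agent=dulwich/0.20.23", b"side-band-64k"]
    ) == {b"agent", b"side-band-64k"}
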
@@ -128,12 +137,12 @@ def symref_capabilities(symrefs):
     return [capability_symref(*k) for k in symrefs]
 
 
-COMMAND_DEEPEN = b'deepen'
-COMMAND_SHALLOW = b'shallow'
-COMMAND_UNSHALLOW = b'unshallow'
-COMMAND_DONE = b'done'
-COMMAND_WANT = b'want'
-COMMAND_HAVE = b'have'
+COMMAND_DEEPEN = b"deepen"
+COMMAND_SHALLOW = b"shallow"
+COMMAND_UNSHALLOW = b"unshallow"
+COMMAND_DONE = b"done"
+COMMAND_WANT = b"want"
+COMMAND_HAVE = b"have"
 
 
 class ProtocolFile(object):
@@ -156,7 +165,7 @@ def format_cmd_pkt(cmd, *args):
 
 def parse_cmd_pkt(line):
     splice_at = line.find(b" ")
-    cmd, args = line[:splice_at], line[splice_at+1:]
+    cmd, args = line[:splice_at], line[splice_at + 1 :]
     assert args[-1:] == b"\x00"
     return cmd, args[:-1].split(b"\0")
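
For context, parse_cmd_pkt expects the NUL-terminated argument layout that
format_cmd_pkt produces; format_cmd_pkt's body sits outside this hunk, so the
exact wire layout shown in the comment is an assumption. A round-trip sketch:

    from dulwich.protocol import format_cmd_pkt, parse_cmd_pkt

    pkt = format_cmd_pkt(b"git-upload-pack", b"/project.git", b"host=example.com")
    # assumed layout: b"git-upload-pack /project.git\x00host=example.com\x00"
    assert parse_cmd_pkt(pkt) == (
        b"git-upload-pack",
        [b"/project.git", b"host=example.com"],
    )
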
 
@@ -170,8 +179,8 @@ def pkt_line(data):
         None, returns the flush-pkt ('0000').
     """
     if data is None:
-        return b'0000'
-    return ('%04x' % (len(data) + 4)).encode('ascii') + data
+        return b"0000"
+    return ("%04x" % (len(data) + 4)).encode("ascii") + data
 
 
 class Protocol(object):
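
The pkt-line framing is easy to verify by hand: the four hex digits encode the
length of the payload plus the four-byte prefix itself. A minimal sketch
(illustrative only):

    from dulwich.protocol import pkt_line

    assert pkt_line(b"hello\n") == b"000ahello\n"  # 4 + 6 bytes == 0x000a
    assert pkt_line(None) == b"0000"               # flush-pkt
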
@@ -224,18 +233,19 @@ class Protocol(object):
             size = int(sizestr, 16)
             if size == 0:
                 if self.report_activity:
-                    self.report_activity(4, 'read')
+                    self.report_activity(4, "read")
                 return None
             if self.report_activity:
-                self.report_activity(size, 'read')
-            pkt_contents = read(size-4)
+                self.report_activity(size, "read")
+            pkt_contents = read(size - 4)
         except socket.error as e:
             raise GitProtocolError(e)
         else:
             if len(pkt_contents) + 4 != size:
                 raise GitProtocolError(
-                    'Length of pkt read %04x does not match length prefix %04x'
-                    % (len(pkt_contents) + 4, size))
+                    "Length of pkt read %04x does not match length prefix %04x"
+                    % (len(pkt_contents) + 4, size)
+                )
             return pkt_contents
 
     def eof(self):
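
A Protocol wired to an in-memory buffer round-trips pkt-lines, which is a handy
way to sanity-check the length validation above (illustrative sketch):

    import io

    from dulwich.protocol import Protocol

    buf = io.BytesIO()
    proto = Protocol(buf.read, buf.write)
    proto.write_pkt_line(b"hello\n")
    buf.seek(0)
    assert proto.read_pkt_line() == b"hello\n"
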
@@ -265,7 +275,7 @@ class Protocol(object):
           ValueError: If more than one pkt-line is unread.
         """
         if self._readahead is not None:
-            raise ValueError('Attempted to unread multiple pkt-lines.')
+            raise ValueError("Attempted to unread multiple pkt-lines.")
         self._readahead = BytesIO(pkt_line(data))
 
     def read_pkt_seq(self):
@@ -290,7 +300,7 @@ class Protocol(object):
             line = pkt_line(line)
             self.write(line)
             if self.report_activity:
-                self.report_activity(len(line), 'write')
+                self.report_activity(len(line), "write")
         except socket.error as e:
             raise GitProtocolError(e)
 
@@ -298,7 +308,6 @@ class Protocol(object):
         """Return a writable file-like object for this protocol."""
 
         class ProtocolFile(object):
-
             def __init__(self, proto):
                 self._proto = proto
                 self._offset = 0
@@ -366,10 +375,12 @@ class ReceivableProtocol(Protocol):
     will still block until at least one byte is read.
     """
 
-    def __init__(self, recv, write, close=None, report_activity=None,
-                 rbufsize=_RBUFSIZE):
+    def __init__(
+        self, recv, write, close=None, report_activity=None, rbufsize=_RBUFSIZE
+    ):
         super(ReceivableProtocol, self).__init__(
-                self.read, write, close=close, report_activity=report_activity)
+            self.read, write, close=close, report_activity=report_activity
+        )
         self._recv = recv
         self._rbuf = BytesIO()
         self._rbufsize = rbufsize
@@ -492,9 +503,9 @@ def extract_want_line_capabilities(text):
 
 def ack_type(capabilities):
     """Extract the ack type from a capabilities list."""
-    if b'multi_ack_detailed' in capabilities:
+    if b"multi_ack_detailed" in capabilities:
         return MULTI_ACK_DETAILED
-    elif b'multi_ack' in capabilities:
+    elif b"multi_ack" in capabilities:
         return MULTI_ACK
     return SINGLE_ACK
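
The ack-type selection prefers the most capable mode the client offered; a
short sketch of the mapping (illustrative only):

    from dulwich.protocol import MULTI_ACK, MULTI_ACK_DETAILED, SINGLE_ACK, ack_type

    assert ack_type([b"multi_ack_detailed", b"side-band-64k"]) == MULTI_ACK_DETAILED
    assert ack_type([b"multi_ack"]) == MULTI_ACK
    assert ack_type([b"ofs-delta"]) == SINGLE_ACK
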
 
@@ -544,16 +555,14 @@ class BufferedPktLineWriter(object):
 
 
 class PktLineParser(object):
-    """Packet line parser that hands completed packets off to a callback.
-    """
+    """Packet line parser that hands completed packets off to a callback."""
 
     def __init__(self, handle_pkt):
         self.handle_pkt = handle_pkt
         self._readahead = BytesIO()
 
     def parse(self, data):
-        """Parse a fragment of data and call back for any completed packets.
-        """
+        """Parse a fragment of data and call back for any completed packets."""
         self._readahead.write(data)
         buf = self._readahead.getvalue()
         if len(buf) < 4:

+ 88 - 13
dulwich/reflog.py

@@ -27,15 +27,15 @@ from dulwich.objects import (
     format_timezone,
     parse_timezone,
     ZERO_SHA,
-    )
+)
 
 Entry = collections.namedtuple(
-    'Entry', ['old_sha', 'new_sha', 'committer', 'timestamp', 'timezone',
-              'message'])
+    "Entry",
+    ["old_sha", "new_sha", "committer", "timestamp", "timezone", "message"],
+)
 
 
-def format_reflog_line(old_sha, new_sha, committer, timestamp, timezone,
-                       message):
+def format_reflog_line(old_sha, new_sha, committer, timestamp, timezone, message):
     """Generate a single reflog line.
 
     Args:
@@ -48,9 +48,19 @@ def format_reflog_line(old_sha, new_sha, committer, timestamp, timezone,
     """
     if old_sha is None:
         old_sha = ZERO_SHA
-    return (old_sha + b' ' + new_sha + b' ' + committer + b' ' +
-            str(int(timestamp)).encode('ascii') + b' ' +
-            format_timezone(timezone) + b'\t' + message)
+    return (
+        old_sha
+        + b" "
+        + new_sha
+        + b" "
+        + committer
+        + b" "
+        + str(int(timestamp)).encode("ascii")
+        + b" "
+        + format_timezone(timezone)
+        + b"\t"
+        + message
+    )
 
 
 def parse_reflog_line(line):
@@ -61,11 +71,17 @@ def parse_reflog_line(line):
     Returns: Tuple of (old_sha, new_sha, committer, timestamp, timezone,
         message)
     """
-    (begin, message) = line.split(b'\t', 1)
-    (old_sha, new_sha, rest) = begin.split(b' ', 2)
-    (committer, timestamp_str, timezone_str) = rest.rsplit(b' ', 2)
-    return Entry(old_sha, new_sha, committer, int(timestamp_str),
-                 parse_timezone(timezone_str)[0], message)
+    (begin, message) = line.split(b"\t", 1)
+    (old_sha, new_sha, rest) = begin.split(b" ", 2)
+    (committer, timestamp_str, timezone_str) = rest.rsplit(b" ", 2)
+    return Entry(
+        old_sha,
+        new_sha,
+        committer,
+        int(timestamp_str),
+        parse_timezone(timezone_str)[0],
+        message,
+    )
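
format_reflog_line and parse_reflog_line are inverses for well-formed input;
committer strings may contain spaces, which is why the parser rsplits only
twice. A round-trip sketch (illustrative values):

    from dulwich.reflog import format_reflog_line, parse_reflog_line

    line = format_reflog_line(
        None,  # a None old_sha is written out as ZERO_SHA
        b"2" * 40,
        b"Jane Doe <jane@example.com>",
        1446552482,
        0,
        b"clone: from git://example.com/repo",
    )
    entry = parse_reflog_line(line)
    assert entry.old_sha == b"0" * 40
    assert entry.timestamp == 1446552482
    assert entry.message == b"clone: from git://example.com/repo"
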
 
 
 def read_reflog(f):
@@ -77,3 +93,62 @@ def read_reflog(f):
     """
     for line in f:
         yield parse_reflog_line(line)
+
+
+def drop_reflog_entry(f, index, rewrite=False):
+    """Drop the specified reflog entry.
+
+    Args:
+        f: File-like object
+        index: Reflog entry index (in Git reflog reverse 0-indexed order)
+        rewrite: If a reflog entry's predecessor is removed, set its
+            old SHA to the new SHA of the entry that now precedes it
+    """
+    if index < 0:
+        raise ValueError("Invalid reflog index %d" % index)
+
+    log = []
+    offset = f.tell()
+    for line in f:
+        log.append((offset, parse_reflog_line(line)))
+        offset = f.tell()
+
+    inverse_index = len(log) - index - 1
+    write_offset = log[inverse_index][0]
+    f.seek(write_offset)
+
+    if index == 0:
+        f.truncate()
+        return
+
+    del log[inverse_index]
+    if rewrite and index > 0 and log:
+        if inverse_index == 0:
+            previous_new = ZERO_SHA
+        else:
+            previous_new = log[inverse_index - 1][1].new_sha
+        offset, entry = log[inverse_index]
+        log[inverse_index] = (
+            offset,
+            Entry(
+                previous_new,
+                entry.new_sha,
+                entry.committer,
+                entry.timestamp,
+                entry.timezone,
+                entry.message,
+            ),
+        )
+
+    for _, entry in log[inverse_index:]:
+        f.write(
+            format_reflog_line(
+                entry.old_sha,
+                entry.new_sha,
+                entry.committer,
+                entry.timestamp,
+                entry.timezone,
+                entry.message,
+            )
+        )
+    f.truncate()
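
The new drop_reflog_entry mirrors git's reflog numbering: index 0 is the most
recent entry, counted from the end of the file. A usage sketch against an
in-memory reflog (illustrative values):

    import io

    from dulwich.reflog import drop_reflog_entry, format_reflog_line, read_reflog

    committer = b"Jane Doe <jane@example.com>"
    f = io.BytesIO(
        format_reflog_line(None, b"1" * 40, committer, 1000, 0, b"commit: a")
        + b"\n"
        + format_reflog_line(b"1" * 40, b"2" * 40, committer, 2000, 0, b"commit: b")
        + b"\n"
    )
    drop_reflog_entry(f, 0)  # drop the newest entry ("commit: b")
    f.seek(0)
    assert [e.message for e in read_reflog(f)] == [b"commit: a\n"]
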

+ 295 - 142
dulwich/refs.py

@@ -27,23 +27,23 @@ import os
 from dulwich.errors import (
     PackedRefsException,
     RefFormatError,
-    )
+)
 from dulwich.objects import (
     git_line,
     valid_hexsha,
     ZERO_SHA,
-    )
+)
 from dulwich.file import (
     GitFile,
     ensure_dir_exists,
-    )
+)
 
 
-SYMREF = b'ref: '
-LOCAL_BRANCH_PREFIX = b'refs/heads/'
-LOCAL_TAG_PREFIX = b'refs/tags/'
-BAD_REF_CHARS = set(b'\177 ~^:?*[')
-ANNOTATED_TAG_SUFFIX = b'^{}'
+SYMREF = b"ref: "
+LOCAL_BRANCH_PREFIX = b"refs/heads/"
+LOCAL_TAG_PREFIX = b"refs/tags/"
+BAD_REF_CHARS = set(b"\177 ~^:?*[")
+ANNOTATED_TAG_SUFFIX = b"^{}"
 
 
 def parse_symref_value(contents):
@@ -54,7 +54,7 @@ def parse_symref_value(contents):
     Returns: Destination
     """
     if contents.startswith(SYMREF):
-        return contents[len(SYMREF):].rstrip(b'\r\n')
+        return contents[len(SYMREF) :].rstrip(b"\r\n")
     raise ValueError(contents)
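
parse_symref_value only strips the "ref: " prefix and any trailing newline;
anything else is rejected. A quick sketch (illustrative only):

    from dulwich.refs import parse_symref_value

    assert parse_symref_value(b"ref: refs/heads/master\n") == b"refs/heads/master"
    # Plain SHA contents raise ValueError instead:
    # parse_symref_value(b"1" * 40)
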
 
 
@@ -72,22 +72,22 @@ def check_ref_format(refname):
     """
     # These could be combined into one big expression, but are listed
     # separately to parallel [1].
-    if b'/.' in refname or refname.startswith(b'.'):
+    if b"/." in refname or refname.startswith(b"."):
         return False
-    if b'/' not in refname:
+    if b"/" not in refname:
         return False
-    if b'..' in refname:
+    if b".." in refname:
         return False
     for i, c in enumerate(refname):
-        if ord(refname[i:i+1]) < 0o40 or c in BAD_REF_CHARS:
+        if ord(refname[i : i + 1]) < 0o40 or c in BAD_REF_CHARS:
             return False
-    if refname[-1] in b'/.':
+    if refname[-1] in b"/.":
         return False
-    if refname.endswith(b'.lock'):
+    if refname.endswith(b".lock"):
         return False
-    if b'@{' in refname:
+    if b"@{" in refname:
         return False
-    if b'\\' in refname:
+    if b"\\" in refname:
         return False
     return True
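
A few representative inputs for the rules above (illustrative only; note that
the function requires at least one "/", so bare names are rejected):

    from dulwich.refs import check_ref_format

    assert check_ref_format(b"refs/heads/master")
    assert not check_ref_format(b"master")                # no "/"
    assert not check_ref_format(b"refs/heads/a..b")       # ".."
    assert not check_ref_format(b"refs/heads/main.lock")  # ".lock" suffix
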
 
@@ -98,17 +98,31 @@ class RefsContainer(object):
     def __init__(self, logger=None):
         self._logger = logger
 
-    def _log(self, ref, old_sha, new_sha, committer=None, timestamp=None,
-             timezone=None, message=None):
+    def _log(
+        self,
+        ref,
+        old_sha,
+        new_sha,
+        committer=None,
+        timestamp=None,
+        timezone=None,
+        message=None,
+    ):
         if self._logger is None:
             return
         if message is None:
             return
-        self._logger(ref, old_sha, new_sha, committer, timestamp,
-                     timezone, message)
-
-    def set_symbolic_ref(self, name, other, committer=None, timestamp=None,
-                         timezone=None, message=None):
+        self._logger(ref, old_sha, new_sha, committer, timestamp, timezone, message)
+
+    def set_symbolic_ref(
+        self,
+        name,
+        other,
+        committer=None,
+        timestamp=None,
+        timezone=None,
+        message=None,
+    ):
         """Make a ref point at another ref.
 
         Args:
@@ -139,8 +153,16 @@ class RefsContainer(object):
         """
         return None
 
-    def import_refs(self, base, other, committer=None, timestamp=None,
-                    timezone=None, message=None, prune=False):
+    def import_refs(
+        self,
+        base,
+        other,
+        committer=None,
+        timestamp=None,
+        timezone=None,
+        message=None,
+        prune=False,
+    ):
         if prune:
             to_delete = set(self.subkeys(base))
         else:
@@ -149,16 +171,16 @@ class RefsContainer(object):
             if value is None:
                 to_delete.add(name)
             else:
-                self.set_if_equals(b'/'.join((base, name)), None, value,
-                                   message=message)
+                self.set_if_equals(
+                    b"/".join((base, name)), None, value, message=message
+                )
             if to_delete:
                 try:
                     to_delete.remove(name)
                 except KeyError:
                     pass
         for ref in to_delete:
-            self.remove_if_equals(
-                b'/'.join((base, ref)), None, message=message)
+            self.remove_if_equals(b"/".join((base, ref)), None, message=message)
 
     def allkeys(self):
         """All refs present in this container."""
@@ -196,18 +218,16 @@ class RefsContainer(object):
         return keys
 
     def as_dict(self, base=None):
-        """Return the contents of this container as a dictionary.
-
-        """
+        """Return the contents of this container as a dictionary."""
         ret = {}
         keys = self.keys(base)
         if base is None:
-            base = b''
+            base = b""
         else:
-            base = base.rstrip(b'/')
+            base = base.rstrip(b"/")
         for key in keys:
             try:
-                ret[key] = self[(base + b'/' + key).strip(b'/')]
+                ret[key] = self[(base + b"/" + key).strip(b"/")]
             except KeyError:
                 continue  # Unable to resolve
 
@@ -226,9 +246,9 @@ class RefsContainer(object):
         Raises:
           KeyError: if a refname is not HEAD or is otherwise not valid.
         """
-        if name in (b'HEAD', b'refs/stash'):
+        if name in (b"HEAD", b"refs/stash"):
             return
-        if not name.startswith(b'refs/') or not check_ref_format(name[5:]):
+        if not name.startswith(b"refs/") or not check_ref_format(name[5:]):
             raise RefFormatError(name)
 
     def read_ref(self, refname):
@@ -264,7 +284,7 @@ class RefsContainer(object):
         depth = 0
         refnames = []
         while contents.startswith(SYMREF):
-            refname = contents[len(SYMREF):]
+            refname = contents[len(SYMREF) :]
             refnames.append(refname)
             contents = self.read_ref(refname)
             if not contents:
@@ -276,9 +296,11 @@ class RefsContainer(object):
 
     def _follow(self, name):
         import warnings
+
         warnings.warn(
-            "RefsContainer._follow is deprecated. Use RefsContainer.follow "
-            "instead.", DeprecationWarning)
+            "RefsContainer._follow is deprecated. Use RefsContainer.follow " "instead.",
+            DeprecationWarning,
+        )
         refnames, contents = self.follow(name)
         if not refnames:
             return (None, contents)
@@ -299,8 +321,16 @@ class RefsContainer(object):
             raise KeyError(name)
         return sha
 
-    def set_if_equals(self, name, old_ref, new_ref, committer=None,
-                      timestamp=None, timezone=None, message=None):
+    def set_if_equals(
+        self,
+        name,
+        old_ref,
+        new_ref,
+        committer=None,
+        timestamp=None,
+        timezone=None,
+        message=None,
+    ):
         """Set a refname to new_ref only if it currently equals old_ref.
 
         This method follows all symbolic references if applicable for the
@@ -343,8 +373,15 @@ class RefsContainer(object):
         """
         self.set_if_equals(name, None, ref)
 
-    def remove_if_equals(self, name, old_ref, committer=None,
-                         timestamp=None, timezone=None, message=None):
+    def remove_if_equals(
+        self,
+        name,
+        old_ref,
+        committer=None,
+        timestamp=None,
+        timezone=None,
+        message=None,
+    ):
         """Remove a refname only if it currently equals old_ref.
 
         This method does not follow symbolic references, even if applicable for
@@ -399,12 +436,12 @@ class RefsContainer(object):
 
 
 class _DictRefsWatcher(object):
-
     def __init__(self, refs):
         self._refs = refs
 
     def __enter__(self):
         from queue import Queue
+
         self.queue = Queue()
         self._refs._watchers.add(self)
         return self
@@ -449,17 +486,39 @@ class DictRefsContainer(RefsContainer):
     def watch(self):
         return _DictRefsWatcher(self)
 
-    def set_symbolic_ref(self, name, other, committer=None,
-                         timestamp=None, timezone=None, message=None):
+    def set_symbolic_ref(
+        self,
+        name,
+        other,
+        committer=None,
+        timestamp=None,
+        timezone=None,
+        message=None,
+    ):
         old = self.follow(name)[-1]
         new = SYMREF + other
         self._refs[name] = new
         self._notify(name, new)
-        self._log(name, old, new, committer=committer, timestamp=timestamp,
-                  timezone=timezone, message=message)
-
-    def set_if_equals(self, name, old_ref, new_ref, committer=None,
-                      timestamp=None, timezone=None, message=None):
+        self._log(
+            name,
+            old,
+            new,
+            committer=committer,
+            timestamp=timestamp,
+            timezone=timezone,
+            message=message,
+        )
+
+    def set_if_equals(
+        self,
+        name,
+        old_ref,
+        new_ref,
+        committer=None,
+        timestamp=None,
+        timezone=None,
+        message=None,
+    ):
         if old_ref is not None and self._refs.get(name, ZERO_SHA) != old_ref:
             return False
         realnames, _ = self.follow(name)
@@ -468,22 +527,50 @@ class DictRefsContainer(RefsContainer):
             old = self._refs.get(realname)
             self._refs[realname] = new_ref
             self._notify(realname, new_ref)
-            self._log(realname, old, new_ref, committer=committer,
-                      timestamp=timestamp, timezone=timezone, message=message)
+            self._log(
+                realname,
+                old,
+                new_ref,
+                committer=committer,
+                timestamp=timestamp,
+                timezone=timezone,
+                message=message,
+            )
         return True
 
-    def add_if_new(self, name, ref, committer=None, timestamp=None,
-                   timezone=None, message=None):
+    def add_if_new(
+        self,
+        name,
+        ref,
+        committer=None,
+        timestamp=None,
+        timezone=None,
+        message=None,
+    ):
         if name in self._refs:
             return False
         self._refs[name] = ref
         self._notify(name, ref)
-        self._log(name, None, ref, committer=committer, timestamp=timestamp,
-                  timezone=timezone, message=message)
+        self._log(
+            name,
+            None,
+            ref,
+            committer=committer,
+            timestamp=timestamp,
+            timezone=timezone,
+            message=message,
+        )
         return True
 
-    def remove_if_equals(self, name, old_ref, committer=None, timestamp=None,
-                         timezone=None, message=None):
+    def remove_if_equals(
+        self,
+        name,
+        old_ref,
+        committer=None,
+        timestamp=None,
+        timezone=None,
+        message=None,
+    ):
         if old_ref is not None and self._refs.get(name, ZERO_SHA) != old_ref:
             return False
         try:
@@ -492,8 +579,15 @@ class DictRefsContainer(RefsContainer):
             pass
         else:
             self._notify(name, None)
-            self._log(name, old, None, committer=committer,
-                      timestamp=timestamp, timezone=timezone, message=message)
+            self._log(
+                name,
+                old,
+                None,
+                committer=committer,
+                timestamp=timestamp,
+                timezone=timezone,
+                message=message,
+            )
         return True
 
     def get_peeled(self, name):
@@ -518,7 +612,7 @@ class InfoRefsContainer(RefsContainer):
         self._refs = {}
         self._peeled = {}
         for line in f.readlines():
-            sha, name = line.rstrip(b'\n').split(b'\t')
+            sha, name = line.rstrip(b"\n").split(b"\t")
             if name.endswith(ANNOTATED_TAG_SUFFIX):
                 name = name[:-3]
                 if not check_ref_format(name):
@@ -546,32 +640,35 @@ class InfoRefsContainer(RefsContainer):
 
 
 class _InotifyRefsWatcher(object):
-
     def __init__(self, path):
         import pyinotify
         from queue import Queue
+
         self.path = os.fsdecode(path)
         self.manager = pyinotify.WatchManager()
         self.manager.add_watch(
-            self.path, pyinotify.IN_DELETE |
-            pyinotify.IN_CLOSE_WRITE | pyinotify.IN_MOVED_TO, rec=True,
-            auto_add=True)
+            self.path,
+            pyinotify.IN_DELETE | pyinotify.IN_CLOSE_WRITE | pyinotify.IN_MOVED_TO,
+            rec=True,
+            auto_add=True,
+        )
 
         self.notifier = pyinotify.ThreadedNotifier(
-            self.manager, default_proc_fun=self._notify)
+            self.manager, default_proc_fun=self._notify
+        )
         self.queue = Queue()
 
     def _notify(self, event):
         if event.dir:
             return
-        if event.pathname.endswith('.lock'):
+        if event.pathname.endswith(".lock"):
             return
         ref = os.fsencode(os.path.relpath(event.pathname, self.path))
-        if event.maskname == 'IN_DELETE':
+        if event.maskname == "IN_DELETE":
             self.queue.put_nowait((ref, None))
-        elif event.maskname in ('IN_CLOSE_WRITE', 'IN_MOVED_TO'):
-            with open(event.pathname, 'rb') as f:
-                sha = f.readline().rstrip(b'\n\r')
+        elif event.maskname in ("IN_CLOSE_WRITE", "IN_MOVED_TO"):
+            with open(event.pathname, "rb") as f:
+                sha = f.readline().rstrip(b"\n\r")
                 self.queue.put_nowait((ref, sha))
 
     def __next__(self):
@@ -591,12 +688,12 @@ class DiskRefsContainer(RefsContainer):
 
     def __init__(self, path, worktree_path=None, logger=None):
         super(DiskRefsContainer, self).__init__(logger=logger)
-        if getattr(path, 'encode', None) is not None:
+        if getattr(path, "encode", None) is not None:
             path = os.fsencode(path)
         self.path = path
         if worktree_path is None:
             worktree_path = path
-        if getattr(worktree_path, 'encode', None) is not None:
+        if getattr(worktree_path, "encode", None) is not None:
             worktree_path = os.fsencode(worktree_path)
         self.worktree_path = worktree_path
         self._packed_refs = None
@@ -609,30 +706,30 @@ class DiskRefsContainer(RefsContainer):
         subkeys = set()
         path = self.refpath(base)
         for root, unused_dirs, files in os.walk(path):
-            dir = root[len(path):]
-            if os.path.sep != '/':
+            dir = root[len(path) :]
+            if os.path.sep != "/":
                 dir = dir.replace(os.fsencode(os.path.sep), b"/")
-            dir = dir.strip(b'/')
+            dir = dir.strip(b"/")
             for filename in files:
                 refname = b"/".join(([dir] if dir else []) + [filename])
                 # check_ref_format requires at least one /, so we prepend the
                 # base before calling it.
-                if check_ref_format(base + b'/' + refname):
+                if check_ref_format(base + b"/" + refname):
                     subkeys.add(refname)
         for key in self.get_packed_refs():
             if key.startswith(base):
-                subkeys.add(key[len(base):].strip(b'/'))
+                subkeys.add(key[len(base) :].strip(b"/"))
         return subkeys
 
     def allkeys(self):
         allkeys = set()
-        if os.path.exists(self.refpath(b'HEAD')):
-            allkeys.add(b'HEAD')
-        path = self.refpath(b'')
-        refspath = self.refpath(b'refs')
+        if os.path.exists(self.refpath(b"HEAD")):
+            allkeys.add(b"HEAD")
+        path = self.refpath(b"")
+        refspath = self.refpath(b"refs")
         for root, unused_dirs, files in os.walk(refspath):
-            dir = root[len(path):]
-            if os.path.sep != '/':
+            dir = root[len(path) :]
+            if os.path.sep != "/":
                 dir = dir.replace(os.fsencode(os.path.sep), b"/")
             for filename in files:
                 refname = b"/".join([dir, filename])
@@ -642,14 +739,12 @@ class DiskRefsContainer(RefsContainer):
         return allkeys
 
     def refpath(self, name):
-        """Return the disk path of a ref.
-
-        """
+        """Return the disk path of a ref."""
         if os.path.sep != "/":
             name = name.replace(b"/", os.fsencode(os.path.sep))
         # TODO: as the 'HEAD' reference is working tree specific, it
         # should actually not be a part of RefsContainer
-        if name == b'HEAD':
+        if name == b"HEAD":
             return os.path.join(self.worktree_path, name)
         else:
             return os.path.join(self.path, name)
@@ -668,15 +763,14 @@ class DiskRefsContainer(RefsContainer):
             # None if and only if _packed_refs is also None.
             self._packed_refs = {}
             self._peeled_refs = {}
-            path = os.path.join(self.path, b'packed-refs')
+            path = os.path.join(self.path, b"packed-refs")
             try:
-                f = GitFile(path, 'rb')
+                f = GitFile(path, "rb")
             except FileNotFoundError:
                 return {}
             with f:
                 first_line = next(iter(f)).rstrip()
-                if (first_line.startswith(b'# pack-refs') and b' peeled' in
-                        first_line):
+                if first_line.startswith(b"# pack-refs") and b" peeled" in first_line:
                     for sha, name, peeled in read_packed_refs_with_peeled(f):
                         self._packed_refs[name] = sha
                         if peeled:
@@ -721,11 +815,11 @@ class DiskRefsContainer(RefsContainer):
         """
         filename = self.refpath(name)
         try:
-            with GitFile(filename, 'rb') as f:
+            with GitFile(filename, "rb") as f:
                 header = f.read(len(SYMREF))
                 if header == SYMREF:
                     # Read only the first line
-                    return header + next(iter(f)).rstrip(b'\r\n')
+                    return header + next(iter(f)).rstrip(b"\r\n")
                 else:
                     # Read only the first 40 bytes
                     return header + f.read(40 - len(SYMREF))
@@ -735,9 +829,9 @@ class DiskRefsContainer(RefsContainer):
     def _remove_packed_ref(self, name):
         if self._packed_refs is None:
             return
-        filename = os.path.join(self.path, b'packed-refs')
+        filename = os.path.join(self.path, b"packed-refs")
         # reread cached refs from disk, while holding the lock
-        f = GitFile(filename, 'wb')
+        f = GitFile(filename, "wb")
         try:
             self._packed_refs = None
             self.get_packed_refs()
@@ -753,8 +847,15 @@ class DiskRefsContainer(RefsContainer):
         finally:
             f.abort()
 
-    def set_symbolic_ref(self, name, other, committer=None, timestamp=None,
-                         timezone=None, message=None):
+    def set_symbolic_ref(
+        self,
+        name,
+        other,
+        committer=None,
+        timestamp=None,
+        timezone=None,
+        message=None,
+    ):
         """Make a ref point at another ref.
 
         Args:
@@ -765,21 +866,35 @@ class DiskRefsContainer(RefsContainer):
         self._check_refname(name)
         self._check_refname(other)
         filename = self.refpath(name)
-        f = GitFile(filename, 'wb')
+        f = GitFile(filename, "wb")
         try:
-            f.write(SYMREF + other + b'\n')
+            f.write(SYMREF + other + b"\n")
             sha = self.follow(name)[-1]
-            self._log(name, sha, sha, committer=committer,
-                      timestamp=timestamp, timezone=timezone,
-                      message=message)
+            self._log(
+                name,
+                sha,
+                sha,
+                committer=committer,
+                timestamp=timestamp,
+                timezone=timezone,
+                message=message,
+            )
         except BaseException:
             f.abort()
             raise
         else:
             f.close()
 
-    def set_if_equals(self, name, old_ref, new_ref, committer=None,
-                      timestamp=None, timezone=None, message=None):
+    def set_if_equals(
+        self,
+        name,
+        old_ref,
+        new_ref,
+        committer=None,
+        timestamp=None,
+        timezone=None,
+        message=None,
+    ):
         """Set a refname to new_ref only if it currently equals old_ref.
 
         This method follows all symbolic references, and can be used to perform
@@ -810,14 +925,13 @@ class DiskRefsContainer(RefsContainer):
             probe_ref = os.path.dirname(probe_ref)
 
         ensure_dir_exists(os.path.dirname(filename))
-        with GitFile(filename, 'wb') as f:
+        with GitFile(filename, "wb") as f:
             if old_ref is not None:
                 try:
                     # read again while holding the lock
                     orig_ref = self.read_loose_ref(realname)
                     if orig_ref is None:
-                        orig_ref = self.get_packed_refs().get(
-                                realname, ZERO_SHA)
+                        orig_ref = self.get_packed_refs().get(realname, ZERO_SHA)
                     if orig_ref != old_ref:
                         f.abort()
                         return False
@@ -825,16 +939,30 @@ class DiskRefsContainer(RefsContainer):
                     f.abort()
                     raise
             try:
-                f.write(new_ref + b'\n')
+                f.write(new_ref + b"\n")
             except (OSError, IOError):
                 f.abort()
                 raise
-            self._log(realname, old_ref, new_ref, committer=committer,
-                      timestamp=timestamp, timezone=timezone, message=message)
+            self._log(
+                realname,
+                old_ref,
+                new_ref,
+                committer=committer,
+                timestamp=timestamp,
+                timezone=timezone,
+                message=message,
+            )
         return True
 
-    def add_if_new(self, name, ref, committer=None, timestamp=None,
-                   timezone=None, message=None):
+    def add_if_new(
+        self,
+        name,
+        ref,
+        committer=None,
+        timestamp=None,
+        timezone=None,
+        message=None,
+    ):
         """Add a new reference only if it does not already exist.
 
         This method follows symrefs, and only ensures that the last ref in the
@@ -856,23 +984,36 @@ class DiskRefsContainer(RefsContainer):
         self._check_refname(realname)
         filename = self.refpath(realname)
         ensure_dir_exists(os.path.dirname(filename))
-        with GitFile(filename, 'wb') as f:
+        with GitFile(filename, "wb") as f:
             if os.path.exists(filename) or name in self.get_packed_refs():
                 f.abort()
                 return False
             try:
-                f.write(ref + b'\n')
+                f.write(ref + b"\n")
             except (OSError, IOError):
                 f.abort()
                 raise
             else:
-                self._log(name, None, ref, committer=committer,
-                          timestamp=timestamp, timezone=timezone,
-                          message=message)
+                self._log(
+                    name,
+                    None,
+                    ref,
+                    committer=committer,
+                    timestamp=timestamp,
+                    timezone=timezone,
+                    message=message,
+                )
         return True
 
-    def remove_if_equals(self, name, old_ref, committer=None, timestamp=None,
-                         timezone=None, message=None):
+    def remove_if_equals(
+        self,
+        name,
+        old_ref,
+        committer=None,
+        timestamp=None,
+        timezone=None,
+        message=None,
+    ):
         """Remove a refname only if it currently equals old_ref.
 
         This method does not follow symbolic references. It can be used to
@@ -888,7 +1029,7 @@ class DiskRefsContainer(RefsContainer):
         self._check_refname(name)
         filename = self.refpath(name)
         ensure_dir_exists(os.path.dirname(filename))
-        f = GitFile(filename, 'wb')
+        f = GitFile(filename, "wb")
         try:
             if old_ref is not None:
                 orig_ref = self.read_loose_ref(name)
@@ -904,8 +1045,15 @@ class DiskRefsContainer(RefsContainer):
                 pass  # may only be packed
 
             self._remove_packed_ref(name)
-            self._log(name, old_ref, None, committer=committer,
-                      timestamp=timestamp, timezone=timezone, message=message)
+            self._log(
+                name,
+                old_ref,
+                None,
+                committer=committer,
+                timestamp=timestamp,
+                timezone=timezone,
+                message=message,
+            )
         finally:
             # never write, we just wanted the lock
             f.abort()
@@ -916,10 +1064,12 @@ class DiskRefsContainer(RefsContainer):
         parent = name
         while True:
             try:
-                parent, _ = parent.rsplit(b'/', 1)
+                parent, _ = parent.rsplit(b"/", 1)
             except ValueError:
                 break
 
+            if parent == b"refs":
+                break
             parent_filename = self.refpath(parent)
             try:
                 os.rmdir(parent_filename)
@@ -934,12 +1084,13 @@ class DiskRefsContainer(RefsContainer):
 
     def watch(self):
         import pyinotify  # noqa: F401
+
         return _InotifyRefsWatcher(self.path)
 
 
 def _split_ref_line(line):
     """Split a single ref line into a tuple of SHA1 and name."""
-    fields = line.rstrip(b'\n\r').split(b' ')
+    fields = line.rstrip(b"\n\r").split(b" ")
     if len(fields) != 2:
         raise PackedRefsException("invalid ref line %r" % line)
     sha, name = fields
@@ -958,12 +1109,11 @@ def read_packed_refs(f):
     Returns: Iterator over tuples with SHA1s and ref names.
     """
     for line in f:
-        if line.startswith(b'#'):
+        if line.startswith(b"#"):
             # Comment
             continue
-        if line.startswith(b'^'):
-            raise PackedRefsException(
-              "found peeled ref in packed-refs without peeled")
+        if line.startswith(b"^"):
+            raise PackedRefsException("found peeled ref in packed-refs without peeled")
         yield _split_ref_line(line)
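
read_packed_refs skips comment lines and refuses peeled ("^") lines, which only
the *_with_peeled variant accepts. A minimal sketch (illustrative only):

    import io

    from dulwich.refs import read_packed_refs

    f = io.BytesIO(
        b"# comment\n"
        b"1111111111111111111111111111111111111111 refs/heads/master\n"
    )
    assert list(read_packed_refs(f)) == [(b"1" * 40, b"refs/heads/master")]
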
 
 
@@ -978,10 +1128,10 @@ def read_packed_refs_with_peeled(f):
     """
     last = None
     for line in f:
-        if line[0] == b'#':
+        if line[0] == b"#":
             continue
-        line = line.rstrip(b'\r\n')
-        if line.startswith(b'^'):
+        line = line.rstrip(b"\r\n")
+        if line.startswith(b"^"):
             if not last:
                 raise PackedRefsException("unexpected peeled ref line")
             if not valid_hexsha(line[1:]):
@@ -1010,11 +1160,11 @@ def write_packed_refs(f, packed_refs, peeled_refs=None):
     if peeled_refs is None:
         peeled_refs = {}
     else:
-        f.write(b'# pack-refs with: peeled\n')
+        f.write(b"# pack-refs with: peeled\n")
     for refname in sorted(packed_refs.keys()):
         f.write(git_line(packed_refs[refname], refname))
         if refname in peeled_refs:
-            f.write(b'^' + peeled_refs[refname] + b'\n')
+            f.write(b"^" + peeled_refs[refname] + b"\n")
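
write_packed_refs only emits the "# pack-refs with: peeled" header when peeled
refs are passed. The expected output below assumes git_line (defined in
dulwich.objects, outside this diff) joins its arguments with spaces and appends
a newline. A sketch:

    import io

    from dulwich.refs import write_packed_refs

    out = io.BytesIO()
    write_packed_refs(
        out,
        {b"refs/tags/v1.0": b"1" * 40},
        peeled_refs={b"refs/tags/v1.0": b"2" * 40},
    )
    assert out.getvalue() == (
        b"# pack-refs with: peeled\n"
        + b"1" * 40 + b" refs/tags/v1.0\n"
        + b"^" + b"2" * 40 + b"\n"
    )
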
 
 
 def read_info_refs(f):
@@ -1030,16 +1180,16 @@ def write_info_refs(refs, store):
     for name, sha in sorted(refs.items()):
         # get_refs() includes HEAD as a special case, but we don't want to
         # advertise it
-        if name == b'HEAD':
+        if name == b"HEAD":
             continue
         try:
             o = store[sha]
         except KeyError:
             continue
         peeled = store.peel_sha(sha)
-        yield o.id + b'\t' + name + b'\n'
+        yield o.id + b"\t" + name + b"\n"
         if o.id != peeled.id:
-            yield peeled.id + b'\t' + name + ANNOTATED_TAG_SUFFIX + b'\n'
+            yield peeled.id + b"\t" + name + ANNOTATED_TAG_SUFFIX + b"\n"
 
 
 def is_local_branch(x):
@@ -1048,5 +1198,8 @@ def is_local_branch(x):
 
 def strip_peeled_refs(refs):
     """Remove all peeled refs"""
-    return {ref: sha for (ref, sha) in refs.items()
-            if not ref.endswith(ANNOTATED_TAG_SUFFIX)}
+    return {
+        ref: sha
+        for (ref, sha) in refs.items()
+        if not ref.endswith(ANNOTATED_TAG_SUFFIX)
+    }
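
strip_peeled_refs drops the "^{}"-suffixed entries that advertise peeled tag
targets, keeping the refs themselves (illustrative sketch):

    from dulwich.refs import strip_peeled_refs

    refs = {
        b"refs/tags/v1.0": b"1" * 40,
        b"refs/tags/v1.0^{}": b"2" * 40,
    }
    assert strip_peeled_refs(refs) == {b"refs/tags/v1.0": b"1" * 40}
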

File diff too large to display
+ 278 - 191
dulwich/repo.py


+ 209 - 157
dulwich/server.py

@@ -61,15 +61,15 @@ from dulwich.errors import (
     NotGitRepository,
     UnexpectedCommandError,
     ObjectFormatException,
-    )
+)
 from dulwich import log_utils
 from dulwich.objects import (
     Commit,
     valid_hexsha,
-    )
+)
 from dulwich.pack import (
     write_pack_objects,
-    )
+)
 from dulwich.protocol import (  # noqa: F401
     BufferedPktLineWriter,
     capability_agent,
@@ -108,15 +108,15 @@ from dulwich.protocol import (  # noqa: F401
     extract_capabilities,
     extract_want_line_capabilities,
     symref_capabilities,
-    )
+)
 from dulwich.refs import (
     ANNOTATED_TAG_SUFFIX,
     write_info_refs,
-    )
+)
 from dulwich.repo import (
     BaseRepo,
     Repo,
-    )
+)
 
 
 logger = log_utils.getLogger(__name__)
@@ -167,8 +167,7 @@ class BackendRepo(object):
         """
         return None
 
-    def fetch_objects(self, determine_wants, graph_walker, progress,
-                      get_tagged=None):
+    def fetch_objects(self, determine_wants, graph_walker, progress, get_tagged=None):
         """
         Yield the objects required for a list of commits.
 
@@ -187,7 +186,7 @@ class DictBackend(Backend):
         self.repos = repos
 
     def open_repository(self, path: str) -> BaseRepo:
-        logger.debug('Opening repository at %s', path)
+        logger.debug("Opening repository at %s", path)
         try:
             return self.repos[path]
         except KeyError:
@@ -201,18 +200,15 @@ class FileSystemBackend(Backend):
 
     def __init__(self, root=os.sep):
         super(FileSystemBackend, self).__init__()
-        self.root = (os.path.abspath(root) + os.sep).replace(
-                os.sep * 2, os.sep)
+        self.root = (os.path.abspath(root) + os.sep).replace(os.sep * 2, os.sep)
 
     def open_repository(self, path):
-        logger.debug('opening repository at %s', path)
+        logger.debug("opening repository at %s", path)
         abspath = os.path.abspath(os.path.join(self.root, path)) + os.sep
         normcase_abspath = os.path.normcase(abspath)
         normcase_root = os.path.normcase(self.root)
         if not normcase_abspath.startswith(normcase_root):
-            raise NotGitRepository(
-                    "Path %r not inside root %r" %
-                    (path, self.root))
+            raise NotGitRepository("Path %r not inside root %r" % (path, self.root))
         return Repo(abspath)
 
 
@@ -239,7 +235,7 @@ class PackHandler(Handler):
 
     @classmethod
     def capability_line(cls, capabilities):
-        logger.info('Sending capabilities: %s', capabilities)
+        logger.info("Sending capabilities: %s", capabilities)
         return b"".join([b" " + c for c in capabilities])
 
     @classmethod
@@ -248,9 +244,13 @@ class PackHandler(Handler):
 
     @classmethod
     def innocuous_capabilities(cls) -> Iterable[bytes]:
-        return [CAPABILITY_INCLUDE_TAG, CAPABILITY_THIN_PACK,
-                CAPABILITY_NO_PROGRESS, CAPABILITY_OFS_DELTA,
-                capability_agent()]
+        return [
+            CAPABILITY_INCLUDE_TAG,
+            CAPABILITY_THIN_PACK,
+            CAPABILITY_NO_PROGRESS,
+            CAPABILITY_OFS_DELTA,
+            capability_agent(),
+        ]
 
     @classmethod
     def required_capabilities(cls) -> Iterable[bytes]:
@@ -261,22 +261,25 @@ class PackHandler(Handler):
         allowable_caps = set(self.innocuous_capabilities())
         allowable_caps.update(self.capabilities())
         for cap in caps:
-            if cap.startswith(CAPABILITY_AGENT + b'='):
+            if cap.startswith(CAPABILITY_AGENT + b"="):
                 continue
             if cap not in allowable_caps:
-                raise GitProtocolError('Client asked for capability %r that '
-                                       'was not advertised.' % cap)
+                raise GitProtocolError(
+                    "Client asked for capability %r that " "was not advertised." % cap
+                )
         for cap in self.required_capabilities():
             if cap not in caps:
-                raise GitProtocolError('Client does not support required '
-                                       'capability %r.' % cap)
+                raise GitProtocolError(
+                    "Client does not support required " "capability %r." % cap
+                )
         self._client_capabilities = set(caps)
-        logger.info('Client capabilities: %s', caps)
+        logger.info("Client capabilities: %s", caps)
 
     def has_capability(self, cap: bytes) -> bool:
         if self._client_capabilities is None:
-            raise GitProtocolError('Server attempted to access capability %r '
-                                   'before asking client' % cap)
+            raise GitProtocolError(
+                "Server attempted to access capability %r " "before asking client" % cap
+            )
         return cap in self._client_capabilities
 
     def notify_done(self) -> None:
@@ -286,10 +289,10 @@ class PackHandler(Handler):
 class UploadPackHandler(PackHandler):
     """Protocol handler for uploading a pack to the client."""
 
-    def __init__(self, backend, args, proto, stateless_rpc=None,
-                 advertise_refs=False):
+    def __init__(self, backend, args, proto, stateless_rpc=None, advertise_refs=False):
         super(UploadPackHandler, self).__init__(
-                backend, proto, stateless_rpc=stateless_rpc)
+            backend, proto, stateless_rpc=stateless_rpc
+        )
         self.repo = backend.open_repository(args[0])
         self._graph_walker = None
         self.advertise_refs = advertise_refs
@@ -300,19 +303,28 @@ class UploadPackHandler(PackHandler):
 
     @classmethod
     def capabilities(cls):
-        return [CAPABILITY_MULTI_ACK_DETAILED, CAPABILITY_MULTI_ACK,
-                CAPABILITY_SIDE_BAND_64K, CAPABILITY_THIN_PACK,
-                CAPABILITY_OFS_DELTA, CAPABILITY_NO_PROGRESS,
-                CAPABILITY_INCLUDE_TAG, CAPABILITY_SHALLOW, CAPABILITY_NO_DONE]
+        return [
+            CAPABILITY_MULTI_ACK_DETAILED,
+            CAPABILITY_MULTI_ACK,
+            CAPABILITY_SIDE_BAND_64K,
+            CAPABILITY_THIN_PACK,
+            CAPABILITY_OFS_DELTA,
+            CAPABILITY_NO_PROGRESS,
+            CAPABILITY_INCLUDE_TAG,
+            CAPABILITY_SHALLOW,
+            CAPABILITY_NO_DONE,
+        ]
 
     @classmethod
     def required_capabilities(cls):
-        return (CAPABILITY_SIDE_BAND_64K, CAPABILITY_THIN_PACK,
-                CAPABILITY_OFS_DELTA)
+        return (
+            CAPABILITY_SIDE_BAND_64K,
+            CAPABILITY_THIN_PACK,
+            CAPABILITY_OFS_DELTA,
+        )
 
     def progress(self, message):
-        if (self.has_capability(CAPABILITY_NO_PROGRESS) or
-                self._processing_have_lines):
+        if self.has_capability(CAPABILITY_NO_PROGRESS) or self._processing_have_lines:
             return
         self.proto.write_sideband(SIDE_BAND_CHANNEL_PROGRESS, message)
 
@@ -353,17 +365,23 @@ class UploadPackHandler(PackHandler):
             return self.proto.write_sideband(SIDE_BAND_CHANNEL_DATA, x)
 
         graph_walker = _ProtocolGraphWalker(
-                self, self.repo.object_store, self.repo.get_peeled,
-                self.repo.refs.get_symrefs)
+            self,
+            self.repo.object_store,
+            self.repo.get_peeled,
+            self.repo.refs.get_symrefs,
+        )
         wants = []
 
-        def wants_wrapper(refs):
-            wants.extend(graph_walker.determine_wants(refs))
+        def wants_wrapper(refs, **kwargs):
+            wants.extend(graph_walker.determine_wants(refs, **kwargs))
             return wants
 
         objects_iter = self.repo.fetch_objects(
-            wants_wrapper, graph_walker, self.progress,
-            get_tagged=self.get_tagged)
+            wants_wrapper,
+            graph_walker,
+            self.progress,
+            get_tagged=self.get_tagged,
+        )
 
         # Note the fact that client is only processing responses related
         # to the have lines it sent, and any other data (including side-
@@ -384,13 +402,13 @@ class UploadPackHandler(PackHandler):
         self._processing_have_lines = False
 
         if not graph_walker.handle_done(
-                not self.has_capability(CAPABILITY_NO_DONE),
-                self._done_received):
+            not self.has_capability(CAPABILITY_NO_DONE), self._done_received
+        ):
             return
 
         self.progress(
-                ("counting objects: %d, done.\n" % len(objects_iter)).encode(
-                    'ascii'))
+            ("counting objects: %d, done.\n" % len(objects_iter)).encode("ascii")
+        )
         write_pack_objects(ProtocolFile(None, write), objects_iter)
         # we are done
         self.proto.write_pkt_line(None)
@@ -418,21 +436,25 @@ def _split_proto_line(line, allowed):
     if not line:
         fields = [None]
     else:
-        fields = line.rstrip(b'\n').split(b' ', 1)
+        fields = line.rstrip(b"\n").split(b" ", 1)
     command = fields[0]
     if allowed is not None and command not in allowed:
         raise UnexpectedCommandError(command)
     if len(fields) == 1 and command in (COMMAND_DONE, None):
         return (command, None)
     elif len(fields) == 2:
-        if command in (COMMAND_WANT, COMMAND_HAVE, COMMAND_SHALLOW,
-                       COMMAND_UNSHALLOW):
+        if command in (
+            COMMAND_WANT,
+            COMMAND_HAVE,
+            COMMAND_SHALLOW,
+            COMMAND_UNSHALLOW,
+        ):
             if not valid_hexsha(fields[1]):
                 raise GitProtocolError("Invalid sha")
             return tuple(fields)
         elif command == COMMAND_DEEPEN:
             return command, int(fields[1])
-    raise GitProtocolError('Received invalid line from client: %r' % line)
+    raise GitProtocolError("Received invalid line from client: %r" % line)
 
 
 def _find_shallow(store, heads, depth):
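
_split_proto_line is private to the server module, but its behaviour is simple
to illustrate: commands carrying a SHA return the raw bytes, while "deepen"
returns an int (a sketch against the parsing shown above):

    from dulwich.protocol import COMMAND_DEEPEN, COMMAND_WANT
    from dulwich.server import _split_proto_line

    sha = b"1" * 40
    assert _split_proto_line(b"want " + sha + b"\n", None) == (COMMAND_WANT, sha)
    assert _split_proto_line(b"deepen 1\n", None) == (COMMAND_DEEPEN, 1)
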
@@ -533,6 +555,7 @@ class _ProtocolGraphWalker(object):
     call to set_ack_type() is required to set up the implementation, before
     any calls to next() or ack() are made.
     """
+
     def __init__(self, handler, object_store, get_peeled, get_symrefs):
         self.handler = handler
         self.store = object_store
@@ -550,7 +573,7 @@ class _ProtocolGraphWalker(object):
         self._cache_index = 0
         self._impl = None
 
-    def determine_wants(self, heads):
+    def determine_wants(self, heads, depth=None):
         """Determine the wants for a set of heads.
 
         The given heads are advertised to the client, who then specifies which
@@ -579,16 +602,17 @@ class _ProtocolGraphWalker(object):
                     # TODO(jelmer): Integrate with Repo.fetch_objects refs
                     # logic.
                     continue
-                line = sha + b' ' + ref
+                line = sha + b" " + ref
                 if not i:
-                    line += (b'\x00' +
-                             self.handler.capability_line(
-                                 self.handler.capabilities() +
-                                 symref_capabilities(symrefs.items())))
-                self.proto.write_pkt_line(line + b'\n')
+                    line += b"\x00" + self.handler.capability_line(
+                        self.handler.capabilities()
+                        + symref_capabilities(symrefs.items())
+                    )
+                self.proto.write_pkt_line(line + b"\n")
                 if peeled_sha != sha:
                     self.proto.write_pkt_line(
-                        peeled_sha + b' ' + ref + ANNOTATED_TAG_SUFFIX + b'\n')
+                        peeled_sha + b" " + ref + ANNOTATED_TAG_SUFFIX + b"\n"
+                    )
 
             # i'm done..
             self.proto.write_pkt_line(None)
@@ -609,8 +633,7 @@ class _ProtocolGraphWalker(object):
         want_revs = []
         while command == COMMAND_WANT:
             if sha not in values:
-                raise GitProtocolError(
-                  'Client wants invalid object %s' % sha)
+                raise GitProtocolError("Client wants invalid object %s" % sha)
             want_revs.append(sha)
             command, sha = self.read_proto_line(allowed)
 
@@ -630,8 +653,8 @@ class _ProtocolGraphWalker(object):
 
     def unread_proto_line(self, command, value):
         if isinstance(value, int):
-            value = str(value).encode('ascii')
-        self.proto.unread_pkt_line(command + b' ' + value)
+            value = str(value).encode("ascii")
+        self.proto.unread_pkt_line(command + b" " + value)
 
     def ack(self, have_ref):
         if len(have_ref) != 40:
@@ -667,8 +690,7 @@ class _ProtocolGraphWalker(object):
 
     def _handle_shallow_request(self, wants):
         while True:
-            command, val = self.read_proto_line(
-                    (COMMAND_DEEPEN, COMMAND_SHALLOW))
+            command, val = self.read_proto_line((COMMAND_DEEPEN, COMMAND_SHALLOW))
             if command == COMMAND_DEEPEN:
                 depth = val
                 break
@@ -684,9 +706,9 @@ class _ProtocolGraphWalker(object):
         unshallow = self.unshallow = not_shallow & self.client_shallow
 
         for sha in sorted(new_shallow):
-            self.proto.write_pkt_line(COMMAND_SHALLOW + b' ' + sha)
+            self.proto.write_pkt_line(COMMAND_SHALLOW + b" " + sha)
         for sha in sorted(unshallow):
-            self.proto.write_pkt_line(COMMAND_UNSHALLOW + b' ' + sha)
+            self.proto.write_pkt_line(COMMAND_UNSHALLOW + b" " + sha)
 
         self.proto.write_pkt_line(None)
 
@@ -694,13 +716,13 @@ class _ProtocolGraphWalker(object):
         # relay the message down to the handler.
         self.handler.notify_done()
 
-    def send_ack(self, sha, ack_type=b''):
+    def send_ack(self, sha, ack_type=b""):
         if ack_type:
-            ack_type = b' ' + ack_type
-        self.proto.write_pkt_line(b'ACK ' + sha + ack_type + b'\n')
+            ack_type = b" " + ack_type
+        self.proto.write_pkt_line(b"ACK " + sha + ack_type + b"\n")
 
     def send_nak(self):
-        self.proto.write_pkt_line(b'NAK\n')
+        self.proto.write_pkt_line(b"NAK\n")
 
     def handle_done(self, done_required, done_received):
         # Delegate this to the implementation.
@@ -721,10 +743,10 @@ class _ProtocolGraphWalker(object):
 
     def set_ack_type(self, ack_type):
         impl_classes = {
-          MULTI_ACK: MultiAckGraphWalkerImpl,
-          MULTI_ACK_DETAILED: MultiAckDetailedGraphWalkerImpl,
-          SINGLE_ACK: SingleAckGraphWalkerImpl,
-          }
+            MULTI_ACK: MultiAckGraphWalkerImpl,
+            MULTI_ACK_DETAILED: MultiAckDetailedGraphWalkerImpl,
+            SINGLE_ACK: SingleAckGraphWalkerImpl,
+        }
         self._impl = impl_classes[ack_type](self)
 
 
@@ -786,7 +808,7 @@ class MultiAckGraphWalkerImpl(object):
     def ack(self, have_ref):
         self._common.append(have_ref)
         if not self._found_base:
-            self.walker.send_ack(have_ref, b'continue')
+            self.walker.send_ack(have_ref, b"continue")
             if self.walker.all_wants_satisfied(self._common):
                 self._found_base = True
         # else we blind ack within next
@@ -805,7 +827,7 @@ class MultiAckGraphWalkerImpl(object):
             elif command == COMMAND_HAVE:
                 if self._found_base:
                     # blind ack
-                    self.walker.send_ack(sha, b'continue')
+                    self.walker.send_ack(sha, b"continue")
                 return sha
 
     __next__ = next
@@ -844,14 +866,14 @@ class MultiAckDetailedGraphWalkerImpl(object):
     def ack(self, have_ref):
         # Should only be called iff have_ref is common
         self._common.append(have_ref)
-        self.walker.send_ack(have_ref, b'common')
+        self.walker.send_ack(have_ref, b"common")
 
     def next(self):
         while True:
             command, sha = self.walker.read_proto_line(_GRAPH_WALKER_COMMANDS)
             if command is None:
                 if self.walker.all_wants_satisfied(self._common):
-                    self.walker.send_ack(self._common[-1], b'ready')
+                    self.walker.send_ack(self._common[-1], b"ready")
                 self.walker.send_nak()
                 if self.walker.stateless_rpc:
                     # The HTTP version of this request a flush-pkt always
@@ -902,25 +924,37 @@ class MultiAckDetailedGraphWalkerImpl(object):
 class ReceivePackHandler(PackHandler):
     """Protocol handler for downloading a pack from the client."""
 
-    def __init__(self, backend, args, proto, stateless_rpc=None,
-                 advertise_refs=False):
+    def __init__(self, backend, args, proto, stateless_rpc=None, advertise_refs=False):
         super(ReceivePackHandler, self).__init__(
-                backend, proto, stateless_rpc=stateless_rpc)
+            backend, proto, stateless_rpc=stateless_rpc
+        )
         self.repo = backend.open_repository(args[0])
         self.advertise_refs = advertise_refs
 
     @classmethod
     def capabilities(cls) -> Iterable[bytes]:
-        return [CAPABILITY_REPORT_STATUS, CAPABILITY_DELETE_REFS,
-                CAPABILITY_QUIET, CAPABILITY_OFS_DELTA,
-                CAPABILITY_SIDE_BAND_64K, CAPABILITY_NO_DONE]
+        return [
+            CAPABILITY_REPORT_STATUS,
+            CAPABILITY_DELETE_REFS,
+            CAPABILITY_QUIET,
+            CAPABILITY_OFS_DELTA,
+            CAPABILITY_SIDE_BAND_64K,
+            CAPABILITY_NO_DONE,
+        ]
 
     def _apply_pack(
-            self, refs: List[Tuple[bytes, bytes, bytes]]
-            ) -> List[Tuple[bytes, bytes]]:
-        all_exceptions = (IOError, OSError, ChecksumMismatch, ApplyDeltaError,
-                          AssertionError, socket.error, zlib.error,
-                          ObjectFormatException)
+        self, refs: List[Tuple[bytes, bytes, bytes]]
+    ) -> List[Tuple[bytes, bytes]]:
+        all_exceptions = (
+            IOError,
+            OSError,
+            ChecksumMismatch,
+            ApplyDeltaError,
+            AssertionError,
+            socket.error,
+            zlib.error,
+            ObjectFormatException,
+        )
         status = []
         will_send_pack = False
 
@@ -934,36 +968,36 @@ class ReceivePackHandler(PackHandler):
             try:
                 recv = getattr(self.proto, "recv", None)
                 self.repo.object_store.add_thin_pack(self.proto.read, recv)
-                status.append((b'unpack', b'ok'))
+                status.append((b"unpack", b"ok"))
             except all_exceptions as e:
-                status.append(
-                    (b'unpack', str(e).replace('\n', '').encode('utf-8')))
+                status.append((b"unpack", str(e).replace("\n", "").encode("utf-8")))
                 # The pack may still have been moved in, but it may contain
                 # broken objects. We trust a later GC to clean it up.
         else:
             # The git protocol wants to find a status entry for the unpack
             # process even if no pack data has been sent.
-            status.append((b'unpack', b'ok'))
+            status.append((b"unpack", b"ok"))
 
         for oldsha, sha, ref in refs:
-            ref_status = b'ok'
+            ref_status = b"ok"
             try:
                 if sha == ZERO_SHA:
                     if CAPABILITY_DELETE_REFS not in self.capabilities():
                         raise GitProtocolError(
-                          'Attempted to delete refs without delete-refs '
-                          'capability.')
+                            "Attempted to delete refs without delete-refs capability."
+                        )
                     try:
                         self.repo.refs.remove_if_equals(ref, oldsha)
                     except all_exceptions:
-                        ref_status = b'failed to delete'
+                        ref_status = b"failed to delete"
                 else:
                     try:
                         self.repo.refs.set_if_equals(ref, oldsha, sha)
                     except all_exceptions:
-                        ref_status = b'failed to write'
+                        ref_status = b"failed to write"
             except KeyError:
-                ref_status = b'bad ref'
+                ref_status = b"bad ref"
             status.append((ref, ref_status))
 
         return status
@@ -971,12 +1005,14 @@ class ReceivePackHandler(PackHandler):
     def _report_status(self, status: List[Tuple[bytes, bytes]]) -> None:
         if self.has_capability(CAPABILITY_SIDE_BAND_64K):
             writer = BufferedPktLineWriter(
-              lambda d: self.proto.write_sideband(SIDE_BAND_CHANNEL_DATA, d))
+                lambda d: self.proto.write_sideband(SIDE_BAND_CHANNEL_DATA, d)
+            )
             write = writer.write
 
             def flush():
                 writer.flush()
                 self.proto.write_pkt_line(None)
+
         else:
             write = self.proto.write_pkt_line
 
@@ -984,17 +1020,17 @@ class ReceivePackHandler(PackHandler):
             def flush():
                 pass
 
         for name, msg in status:
-            if name == b'unpack':
-                write(b'unpack ' + msg + b'\n')
-            elif msg == b'ok':
-                write(b'ok ' + name + b'\n')
+            if name == b"unpack":
+                write(b"unpack " + msg + b"\n")
+            elif msg == b"ok":
+                write(b"ok " + name + b"\n")
             else:
-                write(b'ng ' + name + b' ' + msg + b'\n')
+                write(b"ng " + name + b" " + msg + b"\n")
         write(None)
         flush()
 
     def _on_post_receive(self, client_refs):
-        hook = self.repo.hooks.get('post-receive', None)
+        hook = self.repo.hooks.get("post-receive", None)
         if not hook:
             return
         try:
@@ -1002,7 +1038,7 @@ class ReceivePackHandler(PackHandler):
             if output:
                 self.proto.write_sideband(SIDE_BAND_CHANNEL_PROGRESS, output)
         except HookError as err:
-            self.proto.write_sideband(SIDE_BAND_CHANNEL_FATAL, repr(err))
+            self.proto.write_sideband(SIDE_BAND_CHANNEL_FATAL, str(err).encode("utf-8"))
 
     def handle(self) -> None:
         if self.advertise_refs or not self.stateless_rpc:
@@ -1012,12 +1048,18 @@ class ReceivePackHandler(PackHandler):
             if not refs:
                 refs = [(CAPABILITIES_REF, ZERO_SHA)]
             self.proto.write_pkt_line(
-              refs[0][1] + b' ' + refs[0][0] + b'\0' +
-              self.capability_line(
-                  self.capabilities() + symref_capabilities(symrefs)) + b'\n')
+                refs[0][1]
+                + b" "
+                + refs[0][0]
+                + b"\0"
+                + self.capability_line(
+                    self.capabilities() + symref_capabilities(symrefs)
+                )
+                + b"\n"
+            )
             for i in range(1, len(refs)):
                 ref = refs[i]
-                self.proto.write_pkt_line(ref[1] + b' ' + ref[0] + b'\n')
+                self.proto.write_pkt_line(ref[1] + b" " + ref[0] + b"\n")
 
             self.proto.write_pkt_line(None)
             if self.advertise_refs:
@@ -1050,55 +1092,54 @@ class ReceivePackHandler(PackHandler):
 
 
 class UploadArchiveHandler(Handler):
-
     def __init__(self, backend, args, proto, stateless_rpc=None):
-        super(UploadArchiveHandler, self).__init__(
-            backend, proto, stateless_rpc)
+        super(UploadArchiveHandler, self).__init__(backend, proto, stateless_rpc)
         self.repo = backend.open_repository(args[0])
 
     def handle(self):
         def write(x):
             return self.proto.write_sideband(SIDE_BAND_CHANNEL_DATA, x)
+
         arguments = []
         for pkt in self.proto.read_pkt_seq():
-            (key, value) = pkt.split(b' ', 1)
-            if key != b'argument':
-                raise GitProtocolError('unknown command %s' % key)
-            arguments.append(value.rstrip(b'\n'))
-        prefix = b''
-        format = 'tar'
+            (key, value) = pkt.split(b" ", 1)
+            if key != b"argument":
+                raise GitProtocolError("unknown command %s" % key)
+            arguments.append(value.rstrip(b"\n"))
+        prefix = b""
+        format = "tar"
         i = 0
         store = self.repo.object_store
         while i < len(arguments):
             argument = arguments[i]
-            if argument == b'--prefix':
+            if argument == b"--prefix":
                 i += 1
                 prefix = arguments[i]
-            elif argument == b'--format':
+            elif argument == b"--format":
                 i += 1
-                format = arguments[i].decode('ascii')
+                format = arguments[i].decode("ascii")
             else:
                 commit_sha = self.repo.refs[argument]
                 tree = store[store[commit_sha].tree]
             i += 1
-        self.proto.write_pkt_line(b'ACK')
+        self.proto.write_pkt_line(b"ACK")
         self.proto.write_pkt_line(None)
         for chunk in tar_stream(
-                store, tree, mtime=time.time(), prefix=prefix, format=format):
+            store, tree, mtime=time.time(), prefix=prefix, format=format
+        ):
             write(chunk)
         self.proto.write_pkt_line(None)
 
 
 # Default handler classes for git services.
 DEFAULT_HANDLERS = {
-  b'git-upload-pack': UploadPackHandler,
-  b'git-receive-pack': ReceivePackHandler,
-  b'git-upload-archive': UploadArchiveHandler,
+    b"git-upload-pack": UploadPackHandler,
+    b"git-receive-pack": ReceivePackHandler,
+    b"git-upload-archive": UploadArchiveHandler,
 }
 
 
 class TCPGitRequestHandler(socketserver.StreamRequestHandler):
-
     def __init__(self, handlers, *args, **kwargs):
         self.handlers = handlers
         socketserver.StreamRequestHandler.__init__(self, *args, **kwargs)
@@ -1106,11 +1147,11 @@ class TCPGitRequestHandler(socketserver.StreamRequestHandler):
     def handle(self):
         proto = ReceivableProtocol(self.connection.recv, self.wfile.write)
         command, args = proto.read_cmd()
-        logger.info('Handling %s request, args=%s', command, args)
+        logger.info("Handling %s request, args=%s", command, args)
 
         cls = self.handlers.get(command, None)
         if not callable(cls):
-            raise GitProtocolError('Invalid service %s' % command)
+            raise GitProtocolError("Invalid service %s" % command)
         h = cls(self.server.backend, args, proto)
         h.handle()
 
@@ -1128,45 +1169,56 @@ class TCPGitServer(socketserver.TCPServer):
         if handlers is not None:
             self.handlers.update(handlers)
         self.backend = backend
-        logger.info('Listening for TCP connections on %s:%d',
-                    listen_addr, port)
-        socketserver.TCPServer.__init__(self, (listen_addr, port),
-                                        self._make_handler)
+        logger.info("Listening for TCP connections on %s:%d", listen_addr, port)
+        socketserver.TCPServer.__init__(self, (listen_addr, port), self._make_handler)
 
     def verify_request(self, request, client_address):
-        logger.info('Handling request from %s', client_address)
+        logger.info("Handling request from %s", client_address)
         return True
 
     def handle_error(self, request, client_address):
-        logger.exception('Exception happened during processing of request '
-                         'from %s', client_address)
+        logger.exception(
+            "Exception happened during processing of request from %s",
+            client_address,
+        )
 
 
 def main(argv=sys.argv):
     """Entry point for starting a TCP git server."""
     import optparse
+
     parser = optparse.OptionParser()
-    parser.add_option("-l", "--listen_address", dest="listen_address",
-                      default="localhost",
-                      help="Binding IP address.")
-    parser.add_option("-p", "--port", dest="port", type=int,
-                      default=TCP_GIT_PORT,
-                      help="Binding TCP port.")
+    parser.add_option(
+        "-l",
+        "--listen_address",
+        dest="listen_address",
+        default="localhost",
+        help="Binding IP address.",
+    )
+    parser.add_option(
+        "-p",
+        "--port",
+        dest="port",
+        type=int,
+        default=TCP_GIT_PORT,
+        help="Binding TCP port.",
+    )
     options, args = parser.parse_args(argv)
 
     log_utils.default_logging_config()
     if len(args) > 1:
         gitdir = args[1]
     else:
-        gitdir = '.'
+        gitdir = "."
     # TODO(jelmer): Support git-daemon-export-ok and --export-all.
     backend = FileSystemBackend(gitdir)
     server = TCPGitServer(backend, options.listen_address, options.port)
     server.serve_forever()
 
 
-def serve_command(handler_cls, argv=sys.argv, backend=None, inf=sys.stdin,
-                  outf=sys.stdout):
+def serve_command(
+    handler_cls, argv=sys.argv, backend=None, inf=sys.stdin, outf=sys.stdout
+):
     """Serve a single command.
 
     This is mostly useful for the implementation of commands used by e.g.
@@ -1186,6 +1238,7 @@ def serve_command(handler_cls, argv=sys.argv, backend=None, inf=sys.stdin,
     def send_fn(data):
         outf.write(data)
         outf.flush()
+
     proto = Protocol(inf.read, send_fn)
     handler = handler_cls(backend, argv[1:], proto)
     # FIXME: Catch exceptions and write a single-line summary to outf.
@@ -1202,9 +1255,7 @@ def generate_info_refs(repo):
 def generate_objects_info_packs(repo):
     """Generate an index for for packs."""
     for pack in repo.object_store.packs:
-        yield (
-            b'P ' + os.fsencode(pack.data.filename) +
-            b'\n')
+        yield (b"P " + os.fsencode(pack.data.filename) + b"\n")
 
 
 def update_server_info(repo):
@@ -1214,13 +1265,14 @@ def update_server_info(repo):
     similar to "git update-server-info".
     """
     repo._put_named_file(
-        os.path.join('info', 'refs'),
-        b"".join(generate_info_refs(repo)))
+        os.path.join("info", "refs"), b"".join(generate_info_refs(repo))
+    )
 
     repo._put_named_file(
-        os.path.join('objects', 'info', 'packs'),
-        b"".join(generate_objects_info_packs(repo)))
+        os.path.join("objects", "info", "packs"),
+        b"".join(generate_objects_info_packs(repo)),
+    )
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()

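For reference, main() above simply wires a FileSystemBackend into a TCPGitServer. A minimal sketch of the same setup in code, under the assumption that dulwich is importable and "." contains git repositories (the repository path passed to update_server_info is a placeholder):

    # Hedged sketch, mirroring main() and update_server_info() above.
    from dulwich.repo import Repo
    from dulwich.server import FileSystemBackend, TCPGitServer, update_server_info

    update_server_info(Repo("path/to/repo"))  # placeholder path; regenerates
                                              # info/refs and objects/info/packs

    backend = FileSystemBackend(".")                   # resolve request paths under "."
    server = TCPGitServer(backend, "localhost", 9418)  # 9418 == TCP_GIT_PORT
    server.serve_forever()                             # blocks; serves git:// clients
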
+ 37 - 17
dulwich/stash.py

@@ -28,8 +28,8 @@ from dulwich.file import GitFile
 from dulwich.index import (
     commit_tree,
     iter_fresh_objects,
-    )
-from dulwich.reflog import read_reflog
+)
+from dulwich.reflog import drop_reflog_entry, read_reflog
 
 
 DEFAULT_STASH_REF = b"refs/stash"
@@ -45,11 +45,15 @@ class Stash(object):
         self._ref = ref
         self._repo = repo
 
+    @property
+    def _reflog_path(self):
+        return os.path.join(
+            self._repo.commondir(), "logs", os.fsdecode(self._ref)
+        )
+
     def stashes(self):
-        reflog_path = os.path.join(
-            self._repo.commondir(), 'logs', os.fsdecode(self._ref))
         try:
-            with GitFile(reflog_path, 'rb') as f:
+            with GitFile(self._reflog_path, "rb") as f:
                 return reversed(list(read_reflog(f)))
         except FileNotFoundError:
             return []
@@ -61,10 +65,17 @@ class Stash(object):
 
     def drop(self, index):
         """Drop entry with specified index."""
-        raise NotImplementedError(self.drop)
+        with open(self._reflog_path, "rb+") as f:
+            drop_reflog_entry(f, index, rewrite=True)
+        if len(self) == 0:
+            os.remove(self._reflog_path)
+            del self._repo.refs[self._ref]
+            return
+        if index == 0:
+            self._repo.refs[self._ref] = self[0].new_sha
 
     def pop(self, index):
-        raise NotImplementedError(self.drop)
+        raise NotImplementedError(self.pop)
 
     def push(self, committer=None, author=None, message=None):
         """Create a new stash.
@@ -77,24 +88,30 @@ class Stash(object):
         # First, create the index commit.
         commit_kwargs = {}
         if committer is not None:
-            commit_kwargs['committer'] = committer
+            commit_kwargs["committer"] = committer
         if author is not None:
-            commit_kwargs['author'] = author
+            commit_kwargs["author"] = author
 
         index = self._repo.open_index()
         index_tree_id = index.commit(self._repo.object_store)
         index_commit_id = self._repo.do_commit(
-            ref=None, tree=index_tree_id,
+            ref=None,
+            tree=index_tree_id,
             message=b"Index stash",
             merge_heads=[self._repo.head()],
-            **commit_kwargs)
+            no_verify=True,
+            **commit_kwargs
+        )
 
         # Then, the working tree one.
         stash_tree_id = commit_tree(
-                self._repo.object_store,
-                iter_fresh_objects(
-                    index, os.fsencode(self._repo.path),
-                    object_store=self._repo.object_store))
+            self._repo.object_store,
+            iter_fresh_objects(
+                index,
+                os.fsencode(self._repo.path),
+                object_store=self._repo.object_store,
+            ),
+        )
 
         if message is None:
             message = b"A stash on " + self._repo.head()
@@ -103,10 +120,13 @@ class Stash(object):
         self._repo.refs[self._ref] = self._repo.head()
 
         cid = self._repo.do_commit(
-            ref=self._ref, tree=stash_tree_id,
+            ref=self._ref,
+            tree=stash_tree_id,
             message=message,
             merge_heads=[index_commit_id],
-            **commit_kwargs)
+            no_verify=True,
+            **commit_kwargs
+        )
 
         return cid
 

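A short usage sketch of the Stash API changed above. The Stash(repo) signature and the refs/stash default are inferred from the constructor body and DEFAULT_STASH_REF; a working tree with local changes and a configured committer identity are assumed:

    # Hedged sketch of driving the stash; message and paths are illustrative.
    from dulwich.repo import Repo
    from dulwich.stash import Stash

    repo = Repo(".")
    stash = Stash(repo)            # operates on refs/stash
    stash.push(message=b"WIP")     # index commit, then working-tree commit
    print(len(stash.stashes()))    # reflog entries, newest first
    stash.drop(0)                  # rewrites the reflog; deletes it when empty
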
+ 70 - 56
dulwich/tests/__init__.py

@@ -35,15 +35,15 @@ from unittest import (  # noqa: F401
     TestCase as _TestCase,
     skipIf,
     expectedFailure,
-    )
+)
 
 
 class TestCase(_TestCase):
-
     def setUp(self):
         super(TestCase, self).setUp()
         self._old_home = os.environ.get("HOME")
         os.environ["HOME"] = "/nonexistant"
+        os.environ["GIT_CONFIG_NOSYSTEM"] = "1"
 
     def tearDown(self):
         super(TestCase, self).tearDown()
@@ -57,9 +57,11 @@ class BlackboxTestCase(TestCase):
     """Blackbox testing."""
 
     # TODO(jelmer): Include more possible binary paths.
-    bin_directories = [os.path.abspath(os.path.join(
-            os.path.dirname(__file__), "..", "..", "bin")), '/usr/bin',
-            '/usr/local/bin']
+    bin_directories = [
+        os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "bin")),
+        "/usr/bin",
+        "/usr/local/bin",
+    ]
 
     def bin_path(self, name):
         """Determine the full path of a binary.
@@ -92,50 +94,52 @@ class BlackboxTestCase(TestCase):
         # Save us from all that headache and call python with the bin script.
         argv = [sys.executable, self.bin_path(name)] + args
         return subprocess.Popen(
-                argv,
-                stdout=subprocess.PIPE,
-                stdin=subprocess.PIPE, stderr=subprocess.PIPE,
-                env=env)
+            argv,
+            stdout=subprocess.PIPE,
+            stdin=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            env=env,
+        )
 
 
 def self_test_suite():
     names = [
-        'archive',
-        'blackbox',
-        'bundle',
-        'client',
-        'config',
-        'diff_tree',
-        'fastexport',
-        'file',
-        'grafts',
-        'graph',
-        'greenthreads',
-        'hooks',
-        'ignore',
-        'index',
-        'lfs',
-        'line_ending',
-        'lru_cache',
-        'mailmap',
-        'objects',
-        'objectspec',
-        'object_store',
-        'missing_obj_finder',
-        'pack',
-        'patch',
-        'porcelain',
-        'protocol',
-        'reflog',
-        'refs',
-        'repository',
-        'server',
-        'stash',
-        'utils',
-        'walk',
-        'web',
-        ]
-    module_names = ['dulwich.tests.test_' + name for name in names]
+        "archive",
+        "blackbox",
+        "bundle",
+        "client",
+        "config",
+        "diff_tree",
+        "fastexport",
+        "file",
+        "grafts",
+        "graph",
+        "greenthreads",
+        "hooks",
+        "ignore",
+        "index",
+        "lfs",
+        "line_ending",
+        "lru_cache",
+        "mailmap",
+        "objects",
+        "objectspec",
+        "object_store",
+        "missing_obj_finder",
+        "pack",
+        "patch",
+        "porcelain",
+        "protocol",
+        "reflog",
+        "refs",
+        "repository",
+        "server",
+        "stash",
+        "utils",
+        "walk",
+        "web",
+    ]
+    module_names = ["dulwich.tests.test_" + name for name in names]
     loader = unittest.TestLoader()
     return loader.loadTestsFromNames(module_names)
 
@@ -148,28 +152,34 @@ def tutorial_test_suite():
     import dulwich.repo  # noqa: F401
     import dulwich.server  # noqa: F401
     import dulwich.patch  # noqa: F401
+
     tutorial = [
-        'introduction',
-        'file-format',
-        'repo',
-        'object-store',
-        'remote',
-        'conclusion',
-        ]
+        "introduction",
+        "file-format",
+        "repo",
+        "object-store",
+        "remote",
+        "conclusion",
+    ]
     tutorial_files = ["../../docs/tutorial/%s.txt" % name for name in tutorial]
 
     def setup(test):
         test.__old_cwd = os.getcwd()
         test.tempdir = tempfile.mkdtemp()
-        test.globs.update({'tempdir': test.tempdir})
+        test.globs.update({"tempdir": test.tempdir})
         os.chdir(test.tempdir)
 
     def teardown(test):
         os.chdir(test.__old_cwd)
         shutil.rmtree(test.tempdir)
+
     return doctest.DocFileSuite(
-            module_relative=True, package='dulwich.tests',
-            setUp=setup, tearDown=teardown, *tutorial_files)
+        module_relative=True,
+        package="dulwich.tests",
+        setUp=setup,
+        tearDown=teardown,
+        *tutorial_files
+    )
 
 
 def nocompat_test_suite():
@@ -177,6 +187,7 @@ def nocompat_test_suite():
     result.addTests(self_test_suite())
     result.addTests(tutorial_test_suite())
     from dulwich.contrib import test_suite as contrib_test_suite
+
     result.addTests(contrib_test_suite())
     return result
 
@@ -184,6 +195,7 @@ def nocompat_test_suite():
 def compat_test_suite():
     result = unittest.TestSuite()
     from dulwich.tests.compat import test_suite as compat_test_suite
+
     result.addTests(compat_test_suite())
     return result
 
@@ -191,10 +203,12 @@ def compat_test_suite():
 def test_suite():
     result = unittest.TestSuite()
     result.addTests(self_test_suite())
-    if sys.platform != 'win32':
+    if sys.platform != "win32":
         result.addTests(tutorial_test_suite())
     from dulwich.tests.compat import test_suite as compat_test_suite
+
     result.addTests(compat_test_suite())
     from dulwich.contrib import test_suite as contrib_test_suite
+
     result.addTests(contrib_test_suite())
     return result

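The suites defined here are ordinary unittest suites, so they can also be run directly; a sketch, assuming a source checkout with the test data in place:

    # Hedged sketch: run the non-compat tests programmatically.
    import unittest
    from dulwich.tests import self_test_suite

    unittest.TextTestRunner(verbosity=2).run(self_test_suite())
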
+ 10 - 9
dulwich/tests/compat/__init__.py

@@ -25,15 +25,16 @@ import unittest
 
 def test_suite():
     names = [
-        'client',
-        'pack',
-        'patch',
-        'repository',
-        'server',
-        'utils',
-        'web',
-        ]
-    module_names = ['dulwich.tests.compat.test_' + name for name in names]
+        "client",
+        "pack",
+        "patch",
+        "porcelain",
+        "repository",
+        "server",
+        "utils",
+        "web",
+    ]
+    module_names = ["dulwich.tests.compat.test_" + name for name in names]
     result = unittest.TestSuite()
     loader = unittest.TestLoader()
     suite = loader.loadTestsFromNames(module_names)

+ 148 - 94
dulwich/tests/compat/server_utils.py

@@ -30,16 +30,16 @@ from dulwich.repo import Repo
 from dulwich.objects import hex_to_sha
 from dulwich.protocol import (
     CAPABILITY_SIDE_BAND_64K,
-    )
+)
 from dulwich.server import (
     ReceivePackHandler,
-    )
+)
 from dulwich.tests.utils import (
     tear_down_repo,
-    )
+)
 from dulwich.tests.compat.utils import (
     run_git_or_fail,
-    )
+)
 from dulwich.tests.compat.utils import require_git_version
 
 
@@ -56,7 +56,7 @@ class _StubRepo(object):
 
 
 def _get_shallow(repo):
-    shallow_file = repo.get_named_file('shallow')
+    shallow_file = repo.get_named_file("shallow")
     if not shallow_file:
         return []
     shallows = []
@@ -76,70 +76,80 @@ class ServerTests(object):
     Does not inherit from TestCase so tests are not automatically run.
     """
 
-    min_single_branch_version = (1, 7, 10,)
+    min_single_branch_version = (
+        1,
+        7,
+        10,
+    )
 
     def import_repos(self):
-        self._old_repo = self.import_repo('server_old.export')
-        self._new_repo = self.import_repo('server_new.export')
+        self._old_repo = self.import_repo("server_old.export")
+        self._new_repo = self.import_repo("server_new.export")
 
     def url(self, port):
-        return '%s://localhost:%s/' % (self.protocol, port)
+        return "%s://localhost:%s/" % (self.protocol, port)
 
     def branch_args(self, branches=None):
         if branches is None:
-            branches = ['master', 'branch']
-        return ['%s:%s' % (b, b) for b in branches]
+            branches = ["master", "branch"]
+        return ["%s:%s" % (b, b) for b in branches]
 
     def test_push_to_dulwich(self):
         self.import_repos()
         self.assertReposNotEqual(self._old_repo, self._new_repo)
         port = self._start_server(self._old_repo)
 
-        run_git_or_fail(['push', self.url(port)] + self.branch_args(),
-                        cwd=self._new_repo.path)
+        run_git_or_fail(
+            ["push", self.url(port)] + self.branch_args(),
+            cwd=self._new_repo.path,
+        )
         self.assertReposEqual(self._old_repo, self._new_repo)
 
     def test_push_to_dulwich_no_op(self):
-        self._old_repo = self.import_repo('server_old.export')
-        self._new_repo = self.import_repo('server_old.export')
+        self._old_repo = self.import_repo("server_old.export")
+        self._new_repo = self.import_repo("server_old.export")
         self.assertReposEqual(self._old_repo, self._new_repo)
         port = self._start_server(self._old_repo)
 
-        run_git_or_fail(['push', self.url(port)] + self.branch_args(),
-                        cwd=self._new_repo.path)
+        run_git_or_fail(
+            ["push", self.url(port)] + self.branch_args(),
+            cwd=self._new_repo.path,
+        )
         self.assertReposEqual(self._old_repo, self._new_repo)
 
     def test_push_to_dulwich_remove_branch(self):
-        self._old_repo = self.import_repo('server_old.export')
-        self._new_repo = self.import_repo('server_old.export')
+        self._old_repo = self.import_repo("server_old.export")
+        self._new_repo = self.import_repo("server_old.export")
         self.assertReposEqual(self._old_repo, self._new_repo)
         port = self._start_server(self._old_repo)
 
-        run_git_or_fail(['push', self.url(port), ":master"],
-                        cwd=self._new_repo.path)
+        run_git_or_fail(["push", self.url(port), ":master"], cwd=self._new_repo.path)
 
-        self.assertEqual(
-            list(self._old_repo.get_refs().keys()), [b"refs/heads/branch"])
+        self.assertEqual(list(self._old_repo.get_refs().keys()), [b"refs/heads/branch"])
 
     def test_fetch_from_dulwich(self):
         self.import_repos()
         self.assertReposNotEqual(self._old_repo, self._new_repo)
         port = self._start_server(self._new_repo)
 
-        run_git_or_fail(['fetch', self.url(port)] + self.branch_args(),
-                        cwd=self._old_repo.path)
+        run_git_or_fail(
+            ["fetch", self.url(port)] + self.branch_args(),
+            cwd=self._old_repo.path,
+        )
         # flush the pack cache so any new packs are picked up
         self._old_repo.object_store._pack_cache_time = 0
         self.assertReposEqual(self._old_repo, self._new_repo)
 
     def test_fetch_from_dulwich_no_op(self):
-        self._old_repo = self.import_repo('server_old.export')
-        self._new_repo = self.import_repo('server_old.export')
+        self._old_repo = self.import_repo("server_old.export")
+        self._new_repo = self.import_repo("server_old.export")
         self.assertReposEqual(self._old_repo, self._new_repo)
         port = self._start_server(self._new_repo)
 
-        run_git_or_fail(['fetch', self.url(port)] + self.branch_args(),
-                        cwd=self._old_repo.path)
+        run_git_or_fail(
+            ["fetch", self.url(port)] + self.branch_args(),
+            cwd=self._old_repo.path,
+        )
         # flush the pack cache so any new packs are picked up
         self._old_repo.object_store._pack_cache_time = 0
         self.assertReposEqual(self._old_repo, self._new_repo)
@@ -152,146 +162,185 @@ class ServerTests(object):
 
         new_repo_base_dir = tempfile.mkdtemp()
         self.addCleanup(shutil.rmtree, new_repo_base_dir)
-        new_repo_dir = os.path.join(new_repo_base_dir, 'empty_new')
-        run_git_or_fail(['clone', self.url(port), new_repo_dir],
-                        cwd=new_repo_base_dir)
+        new_repo_dir = os.path.join(new_repo_base_dir, "empty_new")
+        run_git_or_fail(["clone", self.url(port), new_repo_dir], cwd=new_repo_base_dir)
         new_repo = Repo(new_repo_dir)
         self.assertReposEqual(self._old_repo, new_repo)
 
     def test_lsremote_from_dulwich(self):
-        self._repo = self.import_repo('server_old.export')
+        self._repo = self.import_repo("server_old.export")
         port = self._start_server(self._repo)
-        o = run_git_or_fail(['ls-remote', self.url(port)])
-        self.assertEqual(len(o.split(b'\n')), 4)
+        o = run_git_or_fail(["ls-remote", self.url(port)])
+        self.assertEqual(len(o.split(b"\n")), 4)
 
     def test_new_shallow_clone_from_dulwich(self):
         require_git_version(self.min_single_branch_version)
-        self._source_repo = self.import_repo('server_new.export')
-        self._stub_repo = _StubRepo('shallow')
+        self._source_repo = self.import_repo("server_new.export")
+        self._stub_repo = _StubRepo("shallow")
         self.addCleanup(tear_down_repo, self._stub_repo)
         port = self._start_server(self._source_repo)
 
         # Fetch at depth 1
         run_git_or_fail(
-            ['clone', '--mirror', '--depth=1', '--no-single-branch',
-             self.url(port), self._stub_repo.path])
+            [
+                "clone",
+                "--mirror",
+                "--depth=1",
+                "--no-single-branch",
+                self.url(port),
+                self._stub_repo.path,
+            ]
+        )
         clone = self._stub_repo = Repo(self._stub_repo.path)
-        expected_shallow = [b'35e0b59e187dd72a0af294aedffc213eaa4d03ff',
-                            b'514dc6d3fbfe77361bcaef320c4d21b72bc10be9']
+        expected_shallow = [
+            b"35e0b59e187dd72a0af294aedffc213eaa4d03ff",
+            b"514dc6d3fbfe77361bcaef320c4d21b72bc10be9",
+        ]
         self.assertEqual(expected_shallow, _get_shallow(clone))
         self.assertReposNotEqual(clone, self._source_repo)
 
     def test_shallow_clone_from_git_is_identical(self):
         require_git_version(self.min_single_branch_version)
-        self._source_repo = self.import_repo('server_new.export')
-        self._stub_repo_git = _StubRepo('shallow-git')
+        self._source_repo = self.import_repo("server_new.export")
+        self._stub_repo_git = _StubRepo("shallow-git")
         self.addCleanup(tear_down_repo, self._stub_repo_git)
-        self._stub_repo_dw = _StubRepo('shallow-dw')
+        self._stub_repo_dw = _StubRepo("shallow-dw")
         self.addCleanup(tear_down_repo, self._stub_repo_dw)
 
         # shallow clone using stock git, then using dulwich
         run_git_or_fail(
-            ['clone', '--mirror', '--depth=1', '--no-single-branch',
-             'file://' + self._source_repo.path, self._stub_repo_git.path])
+            [
+                "clone",
+                "--mirror",
+                "--depth=1",
+                "--no-single-branch",
+                "file://" + self._source_repo.path,
+                self._stub_repo_git.path,
+            ]
+        )
 
         port = self._start_server(self._source_repo)
         run_git_or_fail(
-            ['clone', '--mirror', '--depth=1', '--no-single-branch',
-             self.url(port), self._stub_repo_dw.path])
+            [
+                "clone",
+                "--mirror",
+                "--depth=1",
+                "--no-single-branch",
+                self.url(port),
+                self._stub_repo_dw.path,
+            ]
+        )
 
         # compare the two clones; they should be equal
-        self.assertReposEqual(Repo(self._stub_repo_git.path),
-                              Repo(self._stub_repo_dw.path))
+        self.assertReposEqual(
+            Repo(self._stub_repo_git.path), Repo(self._stub_repo_dw.path)
+        )
 
     def test_fetch_same_depth_into_shallow_clone_from_dulwich(self):
         require_git_version(self.min_single_branch_version)
-        self._source_repo = self.import_repo('server_new.export')
-        self._stub_repo = _StubRepo('shallow')
+        self._source_repo = self.import_repo("server_new.export")
+        self._stub_repo = _StubRepo("shallow")
         self.addCleanup(tear_down_repo, self._stub_repo)
         port = self._start_server(self._source_repo)
 
         # Fetch at depth 2
         run_git_or_fail(
-            ['clone', '--mirror', '--depth=2', '--no-single-branch',
-             self.url(port), self._stub_repo.path])
+            [
+                "clone",
+                "--mirror",
+                "--depth=2",
+                "--no-single-branch",
+                self.url(port),
+                self._stub_repo.path,
+            ]
+        )
         clone = self._stub_repo = Repo(self._stub_repo.path)
 
         # Fetching at the same depth is a no-op.
         run_git_or_fail(
-          ['fetch', '--depth=2', self.url(port)] + self.branch_args(),
-          cwd=self._stub_repo.path)
-        expected_shallow = [b'94de09a530df27ac3bb613aaecdd539e0a0655e1',
-                            b'da5cd81e1883c62a25bb37c4d1f8ad965b29bf8d']
+            ["fetch", "--depth=2", self.url(port)] + self.branch_args(),
+            cwd=self._stub_repo.path,
+        )
+        expected_shallow = [
+            b"94de09a530df27ac3bb613aaecdd539e0a0655e1",
+            b"da5cd81e1883c62a25bb37c4d1f8ad965b29bf8d",
+        ]
         self.assertEqual(expected_shallow, _get_shallow(clone))
         self.assertReposNotEqual(clone, self._source_repo)
 
     def test_fetch_full_depth_into_shallow_clone_from_dulwich(self):
         require_git_version(self.min_single_branch_version)
-        self._source_repo = self.import_repo('server_new.export')
-        self._stub_repo = _StubRepo('shallow')
+        self._source_repo = self.import_repo("server_new.export")
+        self._stub_repo = _StubRepo("shallow")
         self.addCleanup(tear_down_repo, self._stub_repo)
         port = self._start_server(self._source_repo)
 
         # Fetch at depth 2
         run_git_or_fail(
-            ['clone', '--mirror', '--depth=2', '--no-single-branch',
-             self.url(port), self._stub_repo.path])
+            [
+                "clone",
+                "--mirror",
+                "--depth=2",
+                "--no-single-branch",
+                self.url(port),
+                self._stub_repo.path,
+            ]
+        )
         clone = self._stub_repo = Repo(self._stub_repo.path)
 
         # Fetching at the same depth is a no-op.
         run_git_or_fail(
-          ['fetch', '--depth=2', self.url(port)] + self.branch_args(),
-          cwd=self._stub_repo.path)
+            ["fetch", "--depth=2", self.url(port)] + self.branch_args(),
+            cwd=self._stub_repo.path,
+        )
 
         # The whole repo only has depth 4, so it should equal server_new.
         run_git_or_fail(
-          ['fetch', '--depth=4', self.url(port)] + self.branch_args(),
-          cwd=self._stub_repo.path)
+            ["fetch", "--depth=4", self.url(port)] + self.branch_args(),
+            cwd=self._stub_repo.path,
+        )
         self.assertEqual([], _get_shallow(clone))
         self.assertReposEqual(clone, self._source_repo)
 
     def test_fetch_from_dulwich_issue_88_standard(self):
         # Basically an integration test to see that the ACK/NAK
         # generation works on repos with common head.
-        self._source_repo = self.import_repo(
-            'issue88_expect_ack_nak_server.export')
-        self._client_repo = self.import_repo(
-            'issue88_expect_ack_nak_client.export')
+        self._source_repo = self.import_repo("issue88_expect_ack_nak_server.export")
+        self._client_repo = self.import_repo("issue88_expect_ack_nak_client.export")
         port = self._start_server(self._source_repo)
 
-        run_git_or_fail(['fetch', self.url(port), 'master'],
-                        cwd=self._client_repo.path)
+        run_git_or_fail(["fetch", self.url(port), "master"], cwd=self._client_repo.path)
         self.assertObjectStoreEqual(
-            self._source_repo.object_store,
-            self._client_repo.object_store)
+            self._source_repo.object_store, self._client_repo.object_store
+        )
 
     def test_fetch_from_dulwich_issue_88_alternative(self):
         # likewise, but the case where the two repos have no common parent
-        self._source_repo = self.import_repo(
-            'issue88_expect_ack_nak_other.export')
-        self._client_repo = self.import_repo(
-            'issue88_expect_ack_nak_client.export')
+        self._source_repo = self.import_repo("issue88_expect_ack_nak_other.export")
+        self._client_repo = self.import_repo("issue88_expect_ack_nak_client.export")
         port = self._start_server(self._source_repo)
 
         self.assertRaises(
-            KeyError, self._client_repo.get_object,
-            b'02a14da1fc1fc13389bbf32f0af7d8899f2b2323')
-        run_git_or_fail(['fetch', self.url(port), 'master'],
-                        cwd=self._client_repo.path)
-        self.assertEqual(b'commit', self._client_repo.get_object(
-            b'02a14da1fc1fc13389bbf32f0af7d8899f2b2323').type_name)
+            KeyError,
+            self._client_repo.get_object,
+            b"02a14da1fc1fc13389bbf32f0af7d8899f2b2323",
+        )
+        run_git_or_fail(["fetch", self.url(port), "master"], cwd=self._client_repo.path)
+        self.assertEqual(
+            b"commit",
+            self._client_repo.get_object(
+                b"02a14da1fc1fc13389bbf32f0af7d8899f2b2323"
+            ).type_name,
+        )
 
     def test_push_to_dulwich_issue_88_standard(self):
         # Same thing, but we reverse the role of the server/client
         # and do a push instead.
-        self._source_repo = self.import_repo(
-            'issue88_expect_ack_nak_client.export')
-        self._client_repo = self.import_repo(
-            'issue88_expect_ack_nak_server.export')
+        self._source_repo = self.import_repo("issue88_expect_ack_nak_client.export")
+        self._client_repo = self.import_repo("issue88_expect_ack_nak_server.export")
         port = self._start_server(self._source_repo)
 
-        run_git_or_fail(['push', self.url(port), 'master'],
-                        cwd=self._client_repo.path)
+        run_git_or_fail(["push", self.url(port), "master"], cwd=self._client_repo.path)
         self.assertReposEqual(self._source_repo, self._client_repo)
 
 
@@ -303,12 +352,17 @@ class NoSideBand64kReceivePackHandler(ReceivePackHandler):
 
     @classmethod
     def capabilities(cls):
-        return [c for c in ReceivePackHandler.capabilities()
-                if c != CAPABILITY_SIDE_BAND_64K]
+        return [
+            c
+            for c in ReceivePackHandler.capabilities()
+            if c != CAPABILITY_SIDE_BAND_64K
+        ]
 
 
 def ignore_error(error):
     """Check whether this error is safe to ignore."""
     (e_type, e_value, e_tb) = error
-    return (issubclass(e_type, socket.error) and
-            e_value[0] in (errno.ECONNRESET, errno.EPIPE))
+    return issubclass(e_type, socket.error) and e_value[0] in (
+        errno.ECONNRESET,
+        errno.EPIPE,
+    )

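_get_shallow() above reads the repository's "shallow" file, which holds one 40-character hex commit id per line; a hand-rolled sketch of the same parse (the .git/shallow path assumes a non-bare shallow clone):

    # Hedged sketch of the shallow-file format checked by _get_shallow().
    from dulwich.objects import hex_to_sha

    with open(".git/shallow", "rb") as f:
        shallows = [line.strip() for line in f if line.strip()]
    for sha in shallows:
        hex_to_sha(sha)  # raises if a line is not a valid 40-char hex sha
    print(shallows)
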
+ 220 - 176
dulwich/tests/compat/test_client.py

@@ -43,11 +43,11 @@ from dulwich import (
     protocol,
     objects,
     repo,
-    )
+)
 from dulwich.tests import (
     SkipTest,
     expectedFailure,
-    )
+)
 from dulwich.tests.compat.utils import (
     CompatTestCase,
     check_for_daemon,
@@ -55,10 +55,10 @@ from dulwich.tests.compat.utils import (
     rmtree_ro,
     run_git_or_fail,
     _DEFAULT_GIT,
-    )
+)
 
 
-if sys.platform == 'win32':
+if sys.platform == "win32":
     import ctypes
 
 
@@ -67,17 +67,18 @@ class DulwichClientTestBase(object):
 
     def setUp(self):
         self.gitroot = os.path.dirname(
-                import_repo_to_dir('server_new.export').rstrip(os.sep))
-        self.dest = os.path.join(self.gitroot, 'dest')
+            import_repo_to_dir("server_new.export").rstrip(os.sep)
+        )
+        self.dest = os.path.join(self.gitroot, "dest")
         file.ensure_dir_exists(self.dest)
-        run_git_or_fail(['init', '--quiet', '--bare'], cwd=self.dest)
+        run_git_or_fail(["init", "--quiet", "--bare"], cwd=self.dest)
 
     def tearDown(self):
         rmtree_ro(self.gitroot)
 
     def assertDestEqualsSrc(self):
-        repo_dir = os.path.join(self.gitroot, 'server_new.export')
-        dest_repo_dir = os.path.join(self.gitroot, 'dest')
+        repo_dir = os.path.join(self.gitroot, "server_new.export")
+        dest_repo_dir = os.path.join(self.gitroot, "dest")
         with repo.Repo(repo_dir) as src:
             with repo.Repo(dest_repo_dir) as dest:
                 self.assertReposEqual(src, dest)
@@ -90,12 +91,15 @@ class DulwichClientTestBase(object):
 
     def _do_send_pack(self):
         c = self._client()
-        srcpath = os.path.join(self.gitroot, 'server_new.export')
+        srcpath = os.path.join(self.gitroot, "server_new.export")
         with repo.Repo(srcpath) as src:
             sendrefs = dict(src.get_refs())
-            del sendrefs[b'HEAD']
-            c.send_pack(self._build_path('/dest'), lambda _: sendrefs,
-                        src.generate_pack_data)
+            del sendrefs[b"HEAD"]
+            c.send_pack(
+                self._build_path("/dest"),
+                lambda _: sendrefs,
+                src.generate_pack_data,
+            )
 
     def test_send_pack(self):
         self._do_send_pack()
@@ -111,157 +115,175 @@ class DulwichClientTestBase(object):
     def _add_file(repo, tree_id, filename, contents):
         tree = repo[tree_id]
         blob = objects.Blob()
-        blob.data = contents.encode('utf-8')
+        blob.data = contents.encode("utf-8")
         repo.object_store.add_object(blob)
-        tree.add(filename.encode('utf-8'), stat.S_IFREG | 0o644, blob.id)
+        tree.add(filename.encode("utf-8"), stat.S_IFREG | 0o644, blob.id)
         repo.object_store.add_object(tree)
         return tree.id
 
     def test_send_pack_from_shallow_clone(self):
         c = self._client()
-        server_new_path = os.path.join(self.gitroot, 'server_new.export')
-        run_git_or_fail(['config', 'http.uploadpack', 'true'],
-                        cwd=server_new_path)
-        run_git_or_fail(['config', 'http.receivepack', 'true'],
-                        cwd=server_new_path)
-        remote_path = self._build_path('/server_new.export')
+        server_new_path = os.path.join(self.gitroot, "server_new.export")
+        run_git_or_fail(["config", "http.uploadpack", "true"], cwd=server_new_path)
+        run_git_or_fail(["config", "http.receivepack", "true"], cwd=server_new_path)
+        remote_path = self._build_path("/server_new.export")
         with repo.Repo(self.dest) as local:
             result = c.fetch(remote_path, local, depth=1)
             for r in result.refs.items():
                 local.refs.set_if_equals(r[0], None, r[1])
             tree_id = local[local.head()].tree
-            for filename, contents in [('bar', 'bar contents'),
-                                       ('zop', 'zop contents')]:
+            for filename, contents in [
+                ("bar", "bar contents"),
+                ("zop", "zop contents"),
+            ]:
                 tree_id = self._add_file(local, tree_id, filename, contents)
                 commit_id = local.do_commit(
-                    message=b"add " + filename.encode('utf-8'),
+                    message=b"add " + filename.encode("utf-8"),
                     committer=b"Joe Example <joe@example.com>",
-                    tree=tree_id)
+                    tree=tree_id,
+                )
             sendrefs = dict(local.get_refs())
-            del sendrefs[b'HEAD']
-            c.send_pack(remote_path, lambda _: sendrefs,
-                        local.generate_pack_data)
+            del sendrefs[b"HEAD"]
+            c.send_pack(remote_path, lambda _: sendrefs, local.generate_pack_data)
         with repo.Repo(server_new_path) as remote:
             self.assertEqual(remote.head(), commit_id)
 
     def test_send_without_report_status(self):
         c = self._client()
-        c._send_capabilities.remove(b'report-status')
-        srcpath = os.path.join(self.gitroot, 'server_new.export')
+        c._send_capabilities.remove(b"report-status")
+        srcpath = os.path.join(self.gitroot, "server_new.export")
         with repo.Repo(srcpath) as src:
             sendrefs = dict(src.get_refs())
-            del sendrefs[b'HEAD']
-            c.send_pack(self._build_path('/dest'), lambda _: sendrefs,
-                        src.generate_pack_data)
+            del sendrefs[b"HEAD"]
+            c.send_pack(
+                self._build_path("/dest"),
+                lambda _: sendrefs,
+                src.generate_pack_data,
+            )
             self.assertDestEqualsSrc()
 
     def make_dummy_commit(self, dest):
-        b = objects.Blob.from_string(b'hi')
+        b = objects.Blob.from_string(b"hi")
         dest.object_store.add_object(b)
-        t = index.commit_tree(dest.object_store, [(b'hi', b.id, 0o100644)])
+        t = index.commit_tree(dest.object_store, [(b"hi", b.id, 0o100644)])
         c = objects.Commit()
-        c.author = c.committer = b'Foo Bar <foo@example.com>'
+        c.author = c.committer = b"Foo Bar <foo@example.com>"
         c.author_time = c.commit_time = 0
         c.author_timezone = c.commit_timezone = 0
-        c.message = b'hi'
+        c.message = b"hi"
         c.tree = t
         dest.object_store.add_object(c)
         return c.id
 
     def disable_ff_and_make_dummy_commit(self):
         # disable non-fast-forward pushes to the server
-        dest = repo.Repo(os.path.join(self.gitroot, 'dest'))
-        run_git_or_fail(['config', 'receive.denyNonFastForwards', 'true'],
-                        cwd=dest.path)
+        dest = repo.Repo(os.path.join(self.gitroot, "dest"))
+        run_git_or_fail(
+            ["config", "receive.denyNonFastForwards", "true"], cwd=dest.path
+        )
         commit_id = self.make_dummy_commit(dest)
         return dest, commit_id
 
     def compute_send(self, src):
         sendrefs = dict(src.get_refs())
-        del sendrefs[b'HEAD']
+        del sendrefs[b"HEAD"]
         return sendrefs, src.generate_pack_data
 
     def test_send_pack_one_error(self):
         dest, dummy_commit = self.disable_ff_and_make_dummy_commit()
-        dest.refs[b'refs/heads/master'] = dummy_commit
-        repo_dir = os.path.join(self.gitroot, 'server_new.export')
+        dest.refs[b"refs/heads/master"] = dummy_commit
+        repo_dir = os.path.join(self.gitroot, "server_new.export")
         with repo.Repo(repo_dir) as src:
             sendrefs, gen_pack = self.compute_send(src)
             c = self._client()
             result = c.send_pack(
-                self._build_path('/dest'), lambda _: sendrefs, gen_pack)
-            self.assertEqual({b'refs/heads/branch': None,
-                              b'refs/heads/master': 'non-fast-forward'},
-                             result.ref_status)
+                self._build_path("/dest"), lambda _: sendrefs, gen_pack
+            )
+            self.assertEqual(
+                {
+                    b"refs/heads/branch": None,
+                    b"refs/heads/master": "non-fast-forward",
+                },
+                result.ref_status,
+            )
 
     def test_send_pack_multiple_errors(self):
         dest, dummy = self.disable_ff_and_make_dummy_commit()
         # set up for two non-ff errors
-        branch, master = b'refs/heads/branch', b'refs/heads/master'
+        branch, master = b"refs/heads/branch", b"refs/heads/master"
         dest.refs[branch] = dest.refs[master] = dummy
-        repo_dir = os.path.join(self.gitroot, 'server_new.export')
+        repo_dir = os.path.join(self.gitroot, "server_new.export")
         with repo.Repo(repo_dir) as src:
             sendrefs, gen_pack = self.compute_send(src)
             c = self._client()
             result = c.send_pack(
-                self._build_path('/dest'), lambda _: sendrefs, gen_pack)
-            self.assertEqual({branch: 'non-fast-forward',
-                              master: 'non-fast-forward'},
-                             result.ref_status)
+                self._build_path("/dest"), lambda _: sendrefs, gen_pack
+            )
+            self.assertEqual(
+                {branch: "non-fast-forward", master: "non-fast-forward"},
+                result.ref_status,
+            )
 
     def test_archive(self):
         c = self._client()
         f = BytesIO()
-        c.archive(self._build_path('/server_new.export'), b'HEAD', f.write)
+        c.archive(self._build_path("/server_new.export"), b"HEAD", f.write)
         f.seek(0)
         tf = tarfile.open(fileobj=f)
-        self.assertEqual(['baz', 'foo'], tf.getnames())
+        self.assertEqual(["baz", "foo"], tf.getnames())
 
     def test_fetch_pack(self):
         c = self._client()
-        with repo.Repo(os.path.join(self.gitroot, 'dest')) as dest:
-            result = c.fetch(self._build_path('/server_new.export'), dest)
+        with repo.Repo(os.path.join(self.gitroot, "dest")) as dest:
+            result = c.fetch(self._build_path("/server_new.export"), dest)
             for r in result.refs.items():
                 dest.refs.set_if_equals(r[0], None, r[1])
             self.assertDestEqualsSrc()
 
     def test_fetch_pack_depth(self):
         c = self._client()
-        with repo.Repo(os.path.join(self.gitroot, 'dest')) as dest:
-            result = c.fetch(self._build_path('/server_new.export'), dest,
-                             depth=1)
+        with repo.Repo(os.path.join(self.gitroot, "dest")) as dest:
+            result = c.fetch(self._build_path("/server_new.export"), dest, depth=1)
             for r in result.refs.items():
                 dest.refs.set_if_equals(r[0], None, r[1])
             self.assertEqual(
-                    dest.get_shallow(),
-                    set([b'35e0b59e187dd72a0af294aedffc213eaa4d03ff',
-                         b'514dc6d3fbfe77361bcaef320c4d21b72bc10be9']))
+                dest.get_shallow(),
+                set(
+                    [
+                        b"35e0b59e187dd72a0af294aedffc213eaa4d03ff",
+                        b"514dc6d3fbfe77361bcaef320c4d21b72bc10be9",
+                    ]
+                ),
+            )
 
     def test_repeat(self):
         c = self._client()
-        with repo.Repo(os.path.join(self.gitroot, 'dest')) as dest:
-            result = c.fetch(self._build_path('/server_new.export'), dest)
+        with repo.Repo(os.path.join(self.gitroot, "dest")) as dest:
+            result = c.fetch(self._build_path("/server_new.export"), dest)
             for r in result.refs.items():
                 dest.refs.set_if_equals(r[0], None, r[1])
             self.assertDestEqualsSrc()
-            result = c.fetch(self._build_path('/server_new.export'), dest)
+            result = c.fetch(self._build_path("/server_new.export"), dest)
             for r in result.refs.items():
                 dest.refs.set_if_equals(r[0], None, r[1])
             self.assertDestEqualsSrc()
 
     def test_fetch_empty_pack(self):
         c = self._client()
-        with repo.Repo(os.path.join(self.gitroot, 'dest')) as dest:
-            result = c.fetch(self._build_path('/server_new.export'), dest)
+        with repo.Repo(os.path.join(self.gitroot, "dest")) as dest:
+            result = c.fetch(self._build_path("/server_new.export"), dest)
             for r in result.refs.items():
                 dest.refs.set_if_equals(r[0], None, r[1])
             self.assertDestEqualsSrc()
 
-            def dw(refs):
+            def dw(refs, **kwargs):
                 return list(refs.values())
+
             result = c.fetch(
-                self._build_path('/server_new.export'), dest,
-                determine_wants=dw)
+                self._build_path("/server_new.export"),
+                dest,
+                determine_wants=dw,
+            )
             for r in result.refs.items():
                 dest.refs.set_if_equals(r[0], None, r[1])
             self.assertDestEqualsSrc()
@@ -269,20 +291,20 @@ class DulwichClientTestBase(object):
     def test_incremental_fetch_pack(self):
         self.test_fetch_pack()
         dest, dummy = self.disable_ff_and_make_dummy_commit()
-        dest.refs[b'refs/heads/master'] = dummy
+        dest.refs[b"refs/heads/master"] = dummy
         c = self._client()
-        repo_dir = os.path.join(self.gitroot, 'server_new.export')
+        repo_dir = os.path.join(self.gitroot, "server_new.export")
         with repo.Repo(repo_dir) as dest:
-            result = c.fetch(self._build_path('/dest'), dest)
+            result = c.fetch(self._build_path("/dest"), dest)
             for r in result.refs.items():
                 dest.refs.set_if_equals(r[0], None, r[1])
             self.assertDestEqualsSrc()
 
     def test_fetch_pack_no_side_band_64k(self):
         c = self._client()
-        c._fetch_capabilities.remove(b'side-band-64k')
-        with repo.Repo(os.path.join(self.gitroot, 'dest')) as dest:
-            result = c.fetch(self._build_path('/server_new.export'), dest)
+        c._fetch_capabilities.remove(b"side-band-64k")
+        with repo.Repo(os.path.join(self.gitroot, "dest")) as dest:
+            result = c.fetch(self._build_path("/server_new.export"), dest)
             for r in result.refs.items():
                 dest.refs.set_if_equals(r[0], None, r[1])
             self.assertDestEqualsSrc()
@@ -291,84 +313,96 @@ class DulwichClientTestBase(object):
         # zero sha1s are already present on the client, and should
         # be ignored
         c = self._client()
-        with repo.Repo(os.path.join(self.gitroot, 'dest')) as dest:
+        with repo.Repo(os.path.join(self.gitroot, "dest")) as dest:
             result = c.fetch(
-                self._build_path('/server_new.export'), dest,
-                lambda refs: [protocol.ZERO_SHA])
+                self._build_path("/server_new.export"),
+                dest,
+                lambda refs, **kwargs: [protocol.ZERO_SHA],
+            )
             for r in result.refs.items():
                 dest.refs.set_if_equals(r[0], None, r[1])
 
     def test_send_remove_branch(self):
-        with repo.Repo(os.path.join(self.gitroot, 'dest')) as dest:
+        with repo.Repo(os.path.join(self.gitroot, "dest")) as dest:
             dummy_commit = self.make_dummy_commit(dest)
-            dest.refs[b'refs/heads/master'] = dummy_commit
-            dest.refs[b'refs/heads/abranch'] = dummy_commit
+            dest.refs[b"refs/heads/master"] = dummy_commit
+            dest.refs[b"refs/heads/abranch"] = dummy_commit
             sendrefs = dict(dest.refs)
-            sendrefs[b'refs/heads/abranch'] = b"00" * 20
-            del sendrefs[b'HEAD']
+            sendrefs[b"refs/heads/abranch"] = b"00" * 20
+            del sendrefs[b"HEAD"]
 
             def gen_pack(have, want, ofs_delta=False):
                 return 0, []
+
             c = self._client()
             self.assertEqual(dest.refs[b"refs/heads/abranch"], dummy_commit)
-            c.send_pack(
-                self._build_path('/dest'), lambda _: sendrefs, gen_pack)
+            c.send_pack(self._build_path("/dest"), lambda _: sendrefs, gen_pack)
             self.assertFalse(b"refs/heads/abranch" in dest.refs)
 
     def test_send_new_branch_empty_pack(self):
-        with repo.Repo(os.path.join(self.gitroot, 'dest')) as dest:
+        with repo.Repo(os.path.join(self.gitroot, "dest")) as dest:
             dummy_commit = self.make_dummy_commit(dest)
-            dest.refs[b'refs/heads/master'] = dummy_commit
-            dest.refs[b'refs/heads/abranch'] = dummy_commit
-            sendrefs = {b'refs/heads/bbranch': dummy_commit}
+            dest.refs[b"refs/heads/master"] = dummy_commit
+            dest.refs[b"refs/heads/abranch"] = dummy_commit
+            sendrefs = {b"refs/heads/bbranch": dummy_commit}
 
             def gen_pack(have, want, ofs_delta=False):
                 return 0, []
+
             c = self._client()
             self.assertEqual(dest.refs[b"refs/heads/abranch"], dummy_commit)
-            c.send_pack(
-                self._build_path('/dest'), lambda _: sendrefs, gen_pack)
+            c.send_pack(self._build_path("/dest"), lambda _: sendrefs, gen_pack)
             self.assertEqual(dummy_commit, dest.refs[b"refs/heads/abranch"])
 
     def test_get_refs(self):
         c = self._client()
-        refs = c.get_refs(self._build_path('/server_new.export'))
+        refs = c.get_refs(self._build_path("/server_new.export"))
 
-        repo_dir = os.path.join(self.gitroot, 'server_new.export')
+        repo_dir = os.path.join(self.gitroot, "server_new.export")
         with repo.Repo(repo_dir) as dest:
             self.assertDictEqual(dest.refs.as_dict(), refs)
 
 
 class DulwichTCPClientTest(CompatTestCase, DulwichClientTestBase):
-
     def setUp(self):
         CompatTestCase.setUp(self)
         DulwichClientTestBase.setUp(self)
         if check_for_daemon(limit=1):
-            raise SkipTest('git-daemon was already running on port %s' %
-                           protocol.TCP_GIT_PORT)
-        fd, self.pidfile = tempfile.mkstemp(prefix='dulwich-test-git-client',
-                                            suffix=".pid")
+            raise SkipTest(
+                "git-daemon was already running on port %s" % protocol.TCP_GIT_PORT
+            )
+        fd, self.pidfile = tempfile.mkstemp(
+            prefix="dulwich-test-git-client", suffix=".pid"
+        )
         os.fdopen(fd).close()
-        args = [_DEFAULT_GIT, 'daemon', '--verbose', '--export-all',
-                '--pid-file=%s' % self.pidfile,
-                '--base-path=%s' % self.gitroot,
-                '--enable=receive-pack', '--enable=upload-archive',
-                '--listen=localhost', '--reuseaddr',
-                self.gitroot]
+        args = [
+            _DEFAULT_GIT,
+            "daemon",
+            "--verbose",
+            "--export-all",
+            "--pid-file=%s" % self.pidfile,
+            "--base-path=%s" % self.gitroot,
+            "--enable=receive-pack",
+            "--enable=upload-archive",
+            "--listen=localhost",
+            "--reuseaddr",
+            self.gitroot,
+        ]
         self.process = subprocess.Popen(
-            args, cwd=self.gitroot,
-            stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+            args,
+            cwd=self.gitroot,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+        )
         if not check_for_daemon():
-            raise SkipTest('git-daemon failed to start')
+            raise SkipTest("git-daemon failed to start")
 
     def tearDown(self):
         with open(self.pidfile) as f:
             pid = int(f.read().strip())
-        if sys.platform == 'win32':
+        if sys.platform == "win32":
             PROCESS_TERMINATE = 1
-            handle = ctypes.windll.kernel32.OpenProcess(
-                PROCESS_TERMINATE, False, pid)
+            handle = ctypes.windll.kernel32.OpenProcess(PROCESS_TERMINATE, False, pid)
             ctypes.windll.kernel32.TerminateProcess(handle, -1)
             ctypes.windll.kernel32.CloseHandle(handle)
         else:
@@ -384,32 +418,42 @@ class DulwichTCPClientTest(CompatTestCase, DulwichClientTestBase):
         CompatTestCase.tearDown(self)
 
     def _client(self):
-        return client.TCPGitClient('localhost')
+        return client.TCPGitClient("localhost")
 
     def _build_path(self, path):
         return path
 
-    if sys.platform == 'win32':
+    if sys.platform == "win32":
+
         @expectedFailure
         def test_fetch_pack_no_side_band_64k(self):
             DulwichClientTestBase.test_fetch_pack_no_side_band_64k(self)
 
 
 class TestSSHVendor(object):
-
     @staticmethod
-    def run_command(host, command, username=None, port=None,
-                    password=None, key_filename=None):
-        cmd, path = command.split(' ')
-        cmd = cmd.split('-', 1)
+    def run_command(
+        host,
+        command,
+        username=None,
+        port=None,
+        password=None,
+        key_filename=None,
+    ):
+        cmd, path = command.split(" ")
+        cmd = cmd.split("-", 1)
         path = path.replace("'", "")
-        p = subprocess.Popen(cmd + [path], bufsize=0, stdin=subprocess.PIPE,
-                             stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+        p = subprocess.Popen(
+            cmd + [path],
+            bufsize=0,
+            stdin=subprocess.PIPE,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+        )
         return client.SubprocessWrapper(p)
 
 
 class DulwichMockSSHClientTest(CompatTestCase, DulwichClientTestBase):
-
     def setUp(self):
         CompatTestCase.setUp(self)
         DulwichClientTestBase.setUp(self)
@@ -422,14 +466,13 @@ class DulwichMockSSHClientTest(CompatTestCase, DulwichClientTestBase):
         client.get_ssh_vendor = self.real_vendor
 
     def _client(self):
-        return client.SSHGitClient('localhost')
+        return client.SSHGitClient("localhost")
 
     def _build_path(self, path):
         return self.gitroot + path
 
 
 class DulwichSubprocessClientTest(CompatTestCase, DulwichClientTestBase):
-
     def setUp(self):
         CompatTestCase.setUp(self)
         DulwichClientTestBase.setUp(self)
@@ -461,11 +504,11 @@ class GitHTTPRequestHandler(http.server.SimpleHTTPRequestHandler):
     def send_head(self):
         return self.run_backend()
 
-    def log_request(self, code='-', size='-'):
+    def log_request(self, code="-", size="-"):
         # Let's be quiet, the test suite is noisy enough already
         pass
 
-    def run_backend(self):
+    def run_backend(self):  # noqa: C901
         """Call out to git http-backend."""
         # Based on CGIHTTPServer.CGIHTTPRequestHandler.run_cgi:
         # Copyright (c) 2001-2010 Python Software Foundation;
@@ -473,83 +516,88 @@ class GitHTTPRequestHandler(http.server.SimpleHTTPRequestHandler):
         # Licensed under the Python Software Foundation License.
         rest = self.path
         # find an explicit query string, if present.
-        i = rest.rfind('?')
+        i = rest.rfind("?")
         if i >= 0:
-            rest, query = rest[:i], rest[i+1:]
+            rest, query = rest[:i], rest[i + 1 :]
         else:
-            query = ''
+            query = ""
 
         env = copy.deepcopy(os.environ)
-        env['SERVER_SOFTWARE'] = self.version_string()
-        env['SERVER_NAME'] = self.server.server_name
-        env['GATEWAY_INTERFACE'] = 'CGI/1.1'
-        env['SERVER_PROTOCOL'] = self.protocol_version
-        env['SERVER_PORT'] = str(self.server.server_port)
-        env['GIT_PROJECT_ROOT'] = self.server.root_path
+        env["SERVER_SOFTWARE"] = self.version_string()
+        env["SERVER_NAME"] = self.server.server_name
+        env["GATEWAY_INTERFACE"] = "CGI/1.1"
+        env["SERVER_PROTOCOL"] = self.protocol_version
+        env["SERVER_PORT"] = str(self.server.server_port)
+        env["GIT_PROJECT_ROOT"] = self.server.root_path
         env["GIT_HTTP_EXPORT_ALL"] = "1"
-        env['REQUEST_METHOD'] = self.command
+        env["REQUEST_METHOD"] = self.command
         uqrest = unquote(rest)
-        env['PATH_INFO'] = uqrest
-        env['SCRIPT_NAME'] = "/"
+        env["PATH_INFO"] = uqrest
+        env["SCRIPT_NAME"] = "/"
         if query:
-            env['QUERY_STRING'] = query
+            env["QUERY_STRING"] = query
         host = self.address_string()
         if host != self.client_address[0]:
-            env['REMOTE_HOST'] = host
-        env['REMOTE_ADDR'] = self.client_address[0]
+            env["REMOTE_HOST"] = host
+        env["REMOTE_ADDR"] = self.client_address[0]
         authorization = self.headers.get("authorization")
         if authorization:
             authorization = authorization.split()
             if len(authorization) == 2:
                 import base64
                 import binascii
-                env['AUTH_TYPE'] = authorization[0]
+
+                env["AUTH_TYPE"] = authorization[0]
                 if authorization[0].lower() == "basic":
                     try:
                         authorization = base64.decodestring(authorization[1])
                     except binascii.Error:
                         pass
                     else:
-                        authorization = authorization.split(':')
+                        authorization = authorization.split(":")
                         if len(authorization) == 2:
-                            env['REMOTE_USER'] = authorization[0]
+                            env["REMOTE_USER"] = authorization[0]
         # XXX REMOTE_IDENT
-        content_type = self.headers.get('content-type')
+        content_type = self.headers.get("content-type")
         if content_type:
-            env['CONTENT_TYPE'] = content_type
-        length = self.headers.get('content-length')
+            env["CONTENT_TYPE"] = content_type
+        length = self.headers.get("content-length")
         if length:
-            env['CONTENT_LENGTH'] = length
-        referer = self.headers.get('referer')
+            env["CONTENT_LENGTH"] = length
+        referer = self.headers.get("referer")
         if referer:
-            env['HTTP_REFERER'] = referer
+            env["HTTP_REFERER"] = referer
         accept = []
-        for line in self.headers.getallmatchingheaders('accept'):
+        for line in self.headers.getallmatchingheaders("accept"):
             if line[:1] in "\t\n\r ":
                 accept.append(line.strip())
             else:
-                accept = accept + line[7:].split(',')
-        env['HTTP_ACCEPT'] = ','.join(accept)
-        ua = self.headers.get('user-agent')
+                accept = accept + line[7:].split(",")
+        env["HTTP_ACCEPT"] = ",".join(accept)
+        ua = self.headers.get("user-agent")
         if ua:
-            env['HTTP_USER_AGENT'] = ua
-        co = self.headers.get('cookie')
+            env["HTTP_USER_AGENT"] = ua
+        co = self.headers.get("cookie")
         if co:
-            env['HTTP_COOKIE'] = co
+            env["HTTP_COOKIE"] = co
         # XXX Other HTTP_* headers
         # Since we're setting the env in the parent, provide empty
         # values to override previously set values
-        for k in ('QUERY_STRING', 'REMOTE_HOST', 'CONTENT_LENGTH',
-                  'HTTP_USER_AGENT', 'HTTP_COOKIE', 'HTTP_REFERER'):
+        for k in (
+            "QUERY_STRING",
+            "REMOTE_HOST",
+            "CONTENT_LENGTH",
+            "HTTP_USER_AGENT",
+            "HTTP_COOKIE",
+            "HTTP_REFERER",
+        ):
             env.setdefault(k, "")
 
         self.wfile.write(b"HTTP/1.1 200 Script output follows\r\n")
-        self.wfile.write(
-            ("Server: %s\r\n" % self.server.server_name).encode('ascii'))
-        self.wfile.write(
-            ("Date: %s\r\n" % self.date_time_string()).encode('ascii'))
+        self.wfile.write(("Server: %s\r\n" % self.server.server_name).encode("ascii"))
+        self.wfile.write(("Date: %s\r\n" % self.date_time_string()).encode("ascii"))
 
-        decoded_query = query.replace('+', ' ')
+        decoded_query = query.replace("+", " ")
 
         try:
             nbytes = int(length)
@@ -559,16 +607,15 @@ class GitHTTPRequestHandler(http.server.SimpleHTTPRequestHandler):
             data = self.rfile.read(nbytes)
         else:
             data = None
-            env['CONTENT_LENGTH'] = '0'
+            env["CONTENT_LENGTH"] = "0"
         # throw away additional data [see bug #427345]
         while select.select([self.rfile._sock], [], [], 0)[0]:
             if not self.rfile._sock.recv(1):
                 break
-        args = ['http-backend']
-        if '=' not in decoded_query:
+        args = ["http-backend"]
+        if "=" not in decoded_query:
             args.append(decoded_query)
-        stdout = run_git_or_fail(
-            args, input=data, env=env, stderr=subprocess.PIPE)
+        stdout = run_git_or_fail(args, input=data, env=env, stderr=subprocess.PIPE)
         self.wfile.write(stdout)
 
 
@@ -577,13 +624,12 @@ class HTTPGitServer(http.server.HTTPServer):
     allow_reuse_address = True
 
     def __init__(self, server_address, root_path):
-        http.server.HTTPServer.__init__(
-            self, server_address, GitHTTPRequestHandler)
+        http.server.HTTPServer.__init__(self, server_address, GitHTTPRequestHandler)
         self.root_path = root_path
         self.server_name = "localhost"
 
     def get_url(self):
-        return 'http://%s:%s/' % (self.server_name, self.server_port)
+        return "http://%s:%s/" % (self.server_name, self.server_port)
 
 
 class DulwichHttpClientTest(CompatTestCase, DulwichClientTestBase):
@@ -596,10 +642,8 @@ class DulwichHttpClientTest(CompatTestCase, DulwichClientTestBase):
         self._httpd = HTTPGitServer(("localhost", 0), self.gitroot)
         self.addCleanup(self._httpd.shutdown)
         threading.Thread(target=self._httpd.serve_forever).start()
-        run_git_or_fail(['config', 'http.uploadpack', 'true'],
-                        cwd=self.dest)
-        run_git_or_fail(['config', 'http.receivepack', 'true'],
-                        cwd=self.dest)
+        run_git_or_fail(["config", "http.uploadpack", "true"], cwd=self.dest)
+        run_git_or_fail(["config", "http.receivepack", "true"], cwd=self.dest)
 
     def tearDown(self):
         DulwichClientTestBase.tearDown(self)

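The change above that turns `def dw(refs):` into `def dw(refs, **kwargs):` reflects that dulwich may now pass extra keyword arguments (such as `depth`) to `determine_wants` callbacks. A minimal sketch of a compatible callback, assuming a local git daemon exporting `server_new.export` as in these tests (paths are illustrative):

    from dulwich import repo
    from dulwich.client import TCPGitClient

    def wants_only_heads(refs, **kwargs):
        # refs maps ref names (bytes) to hex shas (bytes); accept and
        # ignore any extra keyword arguments the client may pass.
        return [sha for name, sha in refs.items()
                if name.startswith(b"refs/heads/")]

    c = TCPGitClient("localhost")
    with repo.Repo("dest") as dest:  # hypothetical local repository
        result = c.fetch("/server_new.export", dest,
                         determine_wants=wants_only_heads)
        for name, sha in result.refs.items():
            dest.refs.set_if_equals(name, None, sha)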
+ 55 - 39
dulwich/tests/compat/test_pack.py

@@ -29,24 +29,24 @@ import tempfile
 
 from dulwich.pack import (
     write_pack,
-    )
+)
 from dulwich.objects import (
     Blob,
-    )
+)
 from dulwich.tests import (
     SkipTest,
-    )
+)
 from dulwich.tests.test_pack import (
     a_sha,
     pack1_sha,
     PackTests,
-    )
+)
 from dulwich.tests.compat.utils import (
     require_git_version,
     run_git_or_fail,
-    )
+)
 
-_NON_DELTA_RE = re.compile(b'non delta: (?P<non_delta>\\d+) objects')
+_NON_DELTA_RE = re.compile(b"non delta: (?P<non_delta>\\d+) objects")
 
 
 def _git_verify_pack_object_list(output):
@@ -75,28 +75,32 @@ class TestPack(PackTests):
             self.assertSucceeds(origpack.index.check)
             pack_path = os.path.join(self._tempdir, "Elch")
             write_pack(pack_path, origpack.pack_tuples())
-            output = run_git_or_fail(['verify-pack', '-v', pack_path])
-            orig_shas = set(o.id for o in origpack.iterobjects())
+            output = run_git_or_fail(["verify-pack", "-v", pack_path])
+            orig_shas = {o.id for o in origpack.iterobjects()}
             self.assertEqual(orig_shas, _git_verify_pack_object_list(output))
 
     def test_deltas_work(self):
         with self.get_pack(pack1_sha) as orig_pack:
             orig_blob = orig_pack[a_sha]
             new_blob = Blob()
-            new_blob.data = orig_blob.data + b'x'
+            new_blob.data = orig_blob.data + b"x"
             all_to_pack = list(orig_pack.pack_tuples()) + [(new_blob, None)]
-        pack_path = os.path.join(self._tempdir, 'pack_with_deltas')
+        pack_path = os.path.join(self._tempdir, "pack_with_deltas")
         write_pack(pack_path, all_to_pack, deltify=True)
-        output = run_git_or_fail(['verify-pack', '-v', pack_path])
-        self.assertEqual(set(x[0].id for x in all_to_pack),
-                         _git_verify_pack_object_list(output))
+        output = run_git_or_fail(["verify-pack", "-v", pack_path])
+        self.assertEqual(
+            {x[0].id for x in all_to_pack},
+            _git_verify_pack_object_list(output),
+        )
         # We specifically made a new blob that should be a delta
         # against the blob a_sha, so make sure we really got only 3
         # non-delta objects:
-        got_non_delta = int(_NON_DELTA_RE.search(output).group('non_delta'))
+        got_non_delta = int(_NON_DELTA_RE.search(output).group("non_delta"))
         self.assertEqual(
-            3, got_non_delta,
-            'Expected 3 non-delta objects, got %d' % got_non_delta)
+            3,
+            got_non_delta,
+            "Expected 3 non-delta objects, got %d" % got_non_delta,
+        )
 
     def test_delta_medium_object(self):
         # This tests an object set that will have a copy operation
@@ -104,26 +108,32 @@ class TestPack(PackTests):
         with self.get_pack(pack1_sha) as orig_pack:
             orig_blob = orig_pack[a_sha]
             new_blob = Blob()
-            new_blob.data = orig_blob.data + (b'x' * 2 ** 20)
+            new_blob.data = orig_blob.data + (b"x" * 2 ** 20)
             new_blob_2 = Blob()
-            new_blob_2.data = new_blob.data + b'y'
-            all_to_pack = list(orig_pack.pack_tuples()) + [(new_blob, None),
-                                                           (new_blob_2, None)]
-        pack_path = os.path.join(self._tempdir, 'pack_with_deltas')
+            new_blob_2.data = new_blob.data + b"y"
+            all_to_pack = list(orig_pack.pack_tuples()) + [
+                (new_blob, None),
+                (new_blob_2, None),
+            ]
+        pack_path = os.path.join(self._tempdir, "pack_with_deltas")
         write_pack(pack_path, all_to_pack, deltify=True)
-        output = run_git_or_fail(['verify-pack', '-v', pack_path])
-        self.assertEqual(set(x[0].id for x in all_to_pack),
-                         _git_verify_pack_object_list(output))
+        output = run_git_or_fail(["verify-pack", "-v", pack_path])
+        self.assertEqual(
+            {x[0].id for x in all_to_pack},
+            _git_verify_pack_object_list(output),
+        )
         # We specifically made a new blob that should be a delta
         # against the blob a_sha, so make sure we really got only 3
         # non-delta objects:
-        got_non_delta = int(_NON_DELTA_RE.search(output).group('non_delta'))
+        got_non_delta = int(_NON_DELTA_RE.search(output).group("non_delta"))
         self.assertEqual(
-            3, got_non_delta,
-            'Expected 3 non-delta objects, got %d' % got_non_delta)
+            3,
+            got_non_delta,
+            "Expected 3 non-delta objects, got %d" % got_non_delta,
+        )
         # We expect one object to have a delta chain length of two
         # (new_blob_2), so let's verify that actually happens:
-        self.assertIn(b'chain length = 2', output)
+        self.assertIn(b"chain length = 2", output)
 
     # This test is SUPER slow: over 80 seconds on a 2012-era
     # laptop. This is because SequenceMatcher is worst-case quadratic
@@ -134,23 +144,29 @@ class TestPack(PackTests):
         # This tests an object set that will have a copy operation
         # 2**25 in size. This is a copy large enough that it requires
         # two copy operations in git's binary delta format.
-        raise SkipTest('skipping slow, large test')
+        raise SkipTest("skipping slow, large test")
         with self.get_pack(pack1_sha) as orig_pack:
             new_blob = Blob()
-            new_blob.data = 'big blob' + ('x' * 2 ** 25)
+            new_blob.data = "big blob" + ("x" * 2 ** 25)
             new_blob_2 = Blob()
-            new_blob_2.data = new_blob.data + 'y'
-            all_to_pack = list(orig_pack.pack_tuples()) + [(new_blob, None),
-                                                           (new_blob_2, None)]
+            new_blob_2.data = new_blob.data + "y"
+            all_to_pack = list(orig_pack.pack_tuples()) + [
+                (new_blob, None),
+                (new_blob_2, None),
+            ]
         pack_path = os.path.join(self._tempdir, "pack_with_deltas")
         write_pack(pack_path, all_to_pack, deltify=True)
-        output = run_git_or_fail(['verify-pack', '-v', pack_path])
-        self.assertEqual(set(x[0].id for x in all_to_pack),
-                         _git_verify_pack_object_list(output))
+        output = run_git_or_fail(["verify-pack", "-v", pack_path])
+        self.assertEqual(
+            {x[0].id for x in all_to_pack},
+            _git_verify_pack_object_list(output),
+        )
         # We specifically made a new blob that should be a delta
         # against the blob a_sha, so make sure we really got only 4
         # non-delta objects:
-        got_non_delta = int(_NON_DELTA_RE.search(output).group('non_delta'))
+        got_non_delta = int(_NON_DELTA_RE.search(output).group("non_delta"))
         self.assertEqual(
-            4, got_non_delta,
-            'Expected 4 non-delta objects, got %d' % got_non_delta)
+            4,
+            got_non_delta,
+            "Expected 4 non-delta objects, got %d" % got_non_delta,
+        )

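These tests drive `write_pack(..., deltify=True)` and then let C git check the result. A minimal sketch of the same flow outside the test harness (the blob contents and output path are illustrative):

    import os
    import tempfile

    from dulwich.objects import Blob
    from dulwich.pack import write_pack

    base = Blob()
    base.data = b"a" * 4096
    derived = Blob()
    derived.data = base.data + b"x"
    # With deltify=True, dulwich may store `derived` as a delta against
    # `base`; `git verify-pack -v` reports per-object delta chains.
    pack_path = os.path.join(tempfile.mkdtemp(), "example")
    write_pack(pack_path, [(base, None), (derived, None)], deltify=True)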
+ 6 - 6
dulwich/tests/compat/test_patch.py

@@ -27,20 +27,19 @@ import tempfile
 from dulwich import porcelain
 from dulwich.repo import (
     Repo,
-    )
+)
 from dulwich.tests.compat.utils import (
     CompatTestCase,
     run_git_or_fail,
-    )
+)
 
 
 class CompatPatchTestCase(CompatTestCase):
-
     def setUp(self):
         super(CompatPatchTestCase, self).setUp()
         self.test_dir = tempfile.mkdtemp()
         self.addCleanup(shutil.rmtree, self.test_dir)
-        self.repo_path = os.path.join(self.test_dir, 'repo')
+        self.repo_path = os.path.join(self.test_dir, "repo")
         self.repo = Repo.init(self.repo_path, mkdir=True)
         self.addCleanup(self.repo.close)
 
@@ -82,8 +81,9 @@ class CompatPatchTestCase(CompatTestCase):
         second_tree = self.repo[second_commit].tree
 
         outstream = BytesIO()
-        porcelain.diff_tree(self.repo.path, first_tree, second_tree,
-                            outstream=outstream)
+        porcelain.diff_tree(
+            self.repo.path, first_tree, second_tree, outstream=outstream
+        )
 
         # Save it on disk
         patch_path = os.path.join(self.test_dir, "patch.patch")

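The test exercises `porcelain.diff_tree`, whose output C git can consume with `git apply`. A sketch under the assumption of a repository with at least two commits (the path is hypothetical):

    from io import BytesIO

    from dulwich import porcelain
    from dulwich.repo import Repo

    with Repo("repo") as r:
        head = r[r.head()]
        parent = r[head.parents[0]]
        out = BytesIO()
        # Unified diff between the parent and HEAD trees.
        porcelain.diff_tree(r.path, parent.tree, head.tree, outstream=out)
        print(out.getvalue().decode("utf-8", "replace"))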
+ 101 - 0
dulwich/tests/compat/test_porcelain.py

@@ -0,0 +1,101 @@
+# test_porcelain.py -- Tests for dulwich.porcelain/CGit compatibility
+# Copyright (C) 2010 Google, Inc.
+#
+# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
+# General Public License as public by the Free Software Foundation; version 2.0
+# or (at your option) any later version. You can redistribute it and/or
+# modify it under the terms of either of these two licenses.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# You should have received a copy of the licenses; if not, see
+# <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
+# and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
+# License, Version 2.0.
+#
+
+"""Compatibility tests for dulwich.porcelain."""
+
+import os
+import platform
+import sys
+from unittest import skipIf
+
+from dulwich import porcelain
+from dulwich.tests.utils import (
+    build_commit_graph,
+)
+from dulwich.tests.compat.utils import (
+    run_git_or_fail,
+    CompatTestCase,
+)
+from dulwich.tests.test_porcelain import (
+    PorcelainGpgTestCase,
+)
+
+
+@skipIf(
+    platform.python_implementation() == "PyPy" or sys.platform == "win32",
+    "gpgme not easily available or supported on Windows and PyPy",
+)
+class TagCreateSignTestCase(PorcelainGpgTestCase, CompatTestCase):
+    def setUp(self):
+        super(TagCreateSignTestCase, self).setUp()
+
+    def test_sign(self):
+        # Test that dulwich signatures can be verified by CGit
+        c1, c2, c3 = build_commit_graph(
+            self.repo.object_store, [[1], [2, 1], [3, 1, 2]]
+        )
+        self.repo.refs[b"HEAD"] = c3.id
+        cfg = self.repo.get_config()
+        cfg.set(("user",), "signingKey", PorcelainGpgTestCase.DEFAULT_KEY_ID)
+        self.import_default_key()
+
+        porcelain.tag_create(
+            self.repo.path,
+            b"tryme",
+            b"foo <foo@bar.com>",
+            b"bar",
+            annotated=True,
+            sign=True,
+        )
+
+        run_git_or_fail(
+            [
+                "--git-dir={}".format(self.repo.controldir()),
+                "tag",
+                "-v",
+                "tryme"
+            ],
+            env={'GNUPGHOME': os.environ['GNUPGHOME']},
+        )
+
+    def test_verify(self):
+        # Test that CGit signatures can be verified by dulwich
+        c1, c2, c3 = build_commit_graph(
+            self.repo.object_store, [[1], [2, 1], [3, 1, 2]]
+        )
+        self.repo.refs[b"HEAD"] = c3.id
+        self.import_default_key()
+
+        run_git_or_fail(
+            [
+                "--git-dir={}".format(self.repo.controldir()),
+                "tag",
+                "-u",
+                PorcelainGpgTestCase.DEFAULT_KEY_ID,
+                "-m",
+                "foo",
+                "verifyme",
+            ],
+            env={
+                "GNUPGHOME": os.environ["GNUPGHOME"],
+                "GIT_COMMITTER_NAME": "Joe Example",
+                "GIT_COMMITTER_EMAIL": "joe@example.com",
+            },
+        )
+        tag = self.repo[b"refs/tags/verifyme"]
+        self.assertNotEqual(tag.signature, None)
+        tag.verify()

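The new compat test checks signature interop in both directions. A hedged sketch of the dulwich side, assuming GNUPGHOME points at a keyring containing the key configured as `user.signingKey` (repo path and tag name are illustrative):

    from dulwich import porcelain
    from dulwich.porcelain import open_repo_closing

    # Create a GPG-signed annotated tag; C git can then check it with
    # `git tag -v v1.0`.
    porcelain.tag_create(
        "repo",
        b"v1.0",
        b"Example Author <author@example.com>",
        b"release 1.0",
        annotated=True,
        sign=True,
    )

    # The reverse direction, as in test_verify: load the tag object
    # and verify its signature with dulwich.
    with open_repo_closing("repo") as r:
        tag = r[b"refs/tags/v1.0"]
        tag.verify()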
+ 41 - 43
dulwich/tests/compat/test_repository.py

@@ -28,17 +28,17 @@ import tempfile
 
 from dulwich.objects import (
     hex_to_sha,
-    )
+)
 from dulwich.repo import (
     check_ref_format,
     Repo,
-    )
+)
 from dulwich.tests.compat.utils import (
     require_git_version,
     rmtree_ro,
     run_git_or_fail,
     CompatTestCase,
-    )
+)
 
 
 class ObjectStoreTestCase(CompatTestCase):
@@ -46,7 +46,7 @@ class ObjectStoreTestCase(CompatTestCase):
 
     def setUp(self):
         super(ObjectStoreTestCase, self).setUp()
-        self._repo = self.import_repo('server_new.export')
+        self._repo = self.import_repo("server_new.export")
 
     def _run_git(self, args):
         return run_git_or_fail(args, cwd=self._repo.path)
@@ -54,7 +54,7 @@ class ObjectStoreTestCase(CompatTestCase):
     def _parse_refs(self, output):
         refs = {}
         for line in BytesIO(output):
-            fields = line.rstrip(b'\n').split(b' ')
+            fields = line.rstrip(b"\n").split(b" ")
             self.assertEqual(3, len(fields))
             refname, type_name, sha = fields
             check_ref_format(refname[5:])
@@ -63,26 +63,27 @@ class ObjectStoreTestCase(CompatTestCase):
         return refs
 
     def _parse_objects(self, output):
-        return set(s.rstrip(b'\n').split(b' ')[0] for s in BytesIO(output))
+        return {s.rstrip(b"\n").split(b" ")[0] for s in BytesIO(output)}
 
     def test_bare(self):
         self.assertTrue(self._repo.bare)
-        self.assertFalse(os.path.exists(os.path.join(self._repo.path, '.git')))
+        self.assertFalse(os.path.exists(os.path.join(self._repo.path, ".git")))
 
     def test_head(self):
-        output = self._run_git(['rev-parse', 'HEAD'])
-        head_sha = output.rstrip(b'\n')
+        output = self._run_git(["rev-parse", "HEAD"])
+        head_sha = output.rstrip(b"\n")
         hex_to_sha(head_sha)
-        self.assertEqual(head_sha, self._repo.refs[b'HEAD'])
+        self.assertEqual(head_sha, self._repo.refs[b"HEAD"])
 
     def test_refs(self):
         output = self._run_git(
-          ['for-each-ref', '--format=%(refname) %(objecttype) %(objectname)'])
+            ["for-each-ref", "--format=%(refname) %(objecttype) %(objectname)"]
+        )
         expected_refs = self._parse_refs(output)
 
         actual_refs = {}
         for refname, sha in self._repo.refs.as_dict().items():
-            if refname == b'HEAD':
+            if refname == b"HEAD":
                 continue  # handled in test_head
             obj = self._repo[sha]
             self.assertEqual(sha, obj.id)
@@ -92,12 +93,11 @@ class ObjectStoreTestCase(CompatTestCase):
     # TODO(dborowitz): peeled ref tests
 
     def _get_loose_shas(self):
-        output = self._run_git(
-            ['rev-list', '--all', '--objects', '--unpacked'])
+        output = self._run_git(["rev-list", "--all", "--objects", "--unpacked"])
         return self._parse_objects(output)
 
     def _get_all_shas(self):
-        output = self._run_git(['rev-list', '--all', '--objects'])
+        output = self._run_git(["rev-list", "--all", "--objects"])
         return self._parse_objects(output)
 
     def assertShasMatch(self, expected_shas, actual_shas_iter):
@@ -112,14 +112,14 @@ class ObjectStoreTestCase(CompatTestCase):
         # TODO(dborowitz): This is currently not very useful since
         # fast-imported repos only contained packed objects.
         expected_shas = self._get_loose_shas()
-        self.assertShasMatch(expected_shas,
-                             self._repo.object_store._iter_loose_objects())
+        self.assertShasMatch(
+            expected_shas, self._repo.object_store._iter_loose_objects()
+        )
 
     def test_packed_objects(self):
         expected_shas = self._get_all_shas() - self._get_loose_shas()
         self.assertShasMatch(
-            expected_shas,
-            chain.from_iterable(self._repo.object_store.packs)
+            expected_shas, chain.from_iterable(self._repo.object_store.packs)
         )
 
     def test_all_objects(self):
@@ -142,15 +142,13 @@ class WorkingTreeTestCase(ObjectStoreTestCase):
         Returns: The path to the new working tree.
         """
         temp_dir = tempfile.mkdtemp()
-        run_git_or_fail(['worktree', 'add', temp_dir, branch],
-                        cwd=repo_dir)
+        run_git_or_fail(["worktree", "add", temp_dir, branch], cwd=repo_dir)
         self.addCleanup(rmtree_ro, temp_dir)
         return temp_dir
 
     def setUp(self):
         super(WorkingTreeTestCase, self).setUp()
-        self._worktree_path = self.create_new_worktree(
-            self._repo.path, 'branch')
+        self._worktree_path = self.create_new_worktree(self._repo.path, "branch")
         self._worktree_repo = Repo(self._worktree_path)
         self.addCleanup(self._worktree_repo.close)
         self._mainworktree_repo = self._repo
@@ -159,42 +157,40 @@ class WorkingTreeTestCase(ObjectStoreTestCase):
 
     def test_refs(self):
         super(WorkingTreeTestCase, self).test_refs()
-        self.assertEqual(self._mainworktree_repo.refs.allkeys(),
-                         self._repo.refs.allkeys())
+        self.assertEqual(
+            self._mainworktree_repo.refs.allkeys(), self._repo.refs.allkeys()
+        )
 
     def test_head_equality(self):
-        self.assertNotEqual(self._repo.refs[b'HEAD'],
-                            self._mainworktree_repo.refs[b'HEAD'])
+        self.assertNotEqual(
+            self._repo.refs[b"HEAD"], self._mainworktree_repo.refs[b"HEAD"]
+        )
 
     def test_bare(self):
         self.assertFalse(self._repo.bare)
-        self.assertTrue(os.path.isfile(os.path.join(self._repo.path, '.git')))
+        self.assertTrue(os.path.isfile(os.path.join(self._repo.path, ".git")))
 
     def _parse_worktree_list(self, output):
         worktrees = []
         for line in BytesIO(output):
-            fields = line.rstrip(b'\n').split()
+            fields = line.rstrip(b"\n").split()
             worktrees.append(tuple(f.decode() for f in fields))
         return worktrees
 
     def test_git_worktree_list(self):
         # 'git worktree list' was introduced in 2.7.0
         require_git_version((2, 7, 0))
-        output = run_git_or_fail(['worktree', 'list'], cwd=self._repo.path)
+        output = run_git_or_fail(["worktree", "list"], cwd=self._repo.path)
         worktrees = self._parse_worktree_list(output)
         self.assertEqual(len(worktrees), self._number_of_working_tree)
-        self.assertEqual(worktrees[0][1], '(bare)')
-        self.assertTrue(
-            os.path.samefile(worktrees[0][0], self._mainworktree_repo.path))
+        self.assertEqual(worktrees[0][1], "(bare)")
+        self.assertTrue(os.path.samefile(worktrees[0][0], self._mainworktree_repo.path))
 
-        output = run_git_or_fail(
-            ['worktree', 'list'], cwd=self._mainworktree_repo.path)
+        output = run_git_or_fail(["worktree", "list"], cwd=self._mainworktree_repo.path)
         worktrees = self._parse_worktree_list(output)
         self.assertEqual(len(worktrees), self._number_of_working_tree)
-        self.assertEqual(worktrees[0][1], '(bare)')
-        self.assertTrue(os.path.samefile(
-            worktrees[0][0],
-            self._mainworktree_repo.path))
+        self.assertEqual(worktrees[0][1], "(bare)")
+        self.assertTrue(os.path.samefile(worktrees[0][0], self._mainworktree_repo.path))
 
 
 class InitNewWorkingDirectoryTestCase(WorkingTreeTestCase):
@@ -208,14 +204,16 @@ class InitNewWorkingDirectoryTestCase(WorkingTreeTestCase):
         worktree_repo_path = tempfile.mkdtemp()
         self.addCleanup(rmtree_ro, worktree_repo_path)
         self._repo = Repo._init_new_working_directory(
-            worktree_repo_path, self._mainworktree_repo)
+            worktree_repo_path, self._mainworktree_repo
+        )
         self.addCleanup(self._repo.close)
         self._number_of_working_tree = 3
 
     def test_head_equality(self):
-        self.assertEqual(self._repo.refs[b'HEAD'],
-                         self._mainworktree_repo.refs[b'HEAD'])
+        self.assertEqual(
+            self._repo.refs[b"HEAD"], self._mainworktree_repo.refs[b"HEAD"]
+        )
 
     def test_bare(self):
         self.assertFalse(self._repo.bare)
-        self.assertTrue(os.path.isfile(os.path.join(self._repo.path, '.git')))
+        self.assertTrue(os.path.isfile(os.path.join(self._repo.path, ".git")))

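These worktree tests rely on the fact that a linked working tree stores a `.git` *file* pointing back at the main repository. A short sketch (path hypothetical):

    from dulwich.repo import Repo

    # Repo() follows the .git file of a `git worktree add` checkout, so
    # refs are shared with the main repository while HEAD stays local.
    wt = Repo("/path/to/worktree")
    try:
        print(wt.bare)           # False: it has a working tree
        print(wt.refs[b"HEAD"])  # resolved via the shared ref store
    finally:
        wt.close()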
+ 14 - 17
dulwich/tests/compat/test_server.py

@@ -32,40 +32,38 @@ import sys
 from dulwich.server import (
     DictBackend,
     TCPGitServer,
-    )
+)
 from dulwich.tests import skipIf
 from dulwich.tests.compat.server_utils import (
     ServerTests,
     NoSideBand64kReceivePackHandler,
-    )
+)
 from dulwich.tests.compat.utils import (
     CompatTestCase,
     require_git_version,
-    )
+)
 
 
-@skipIf(sys.platform == 'win32',
-        'Broken on windows, with very long fail time.')
+@skipIf(sys.platform == "win32", "Broken on windows, with very long fail time.")
 class GitServerTestCase(ServerTests, CompatTestCase):
     """Tests for client/server compatibility.
 
     This server test case does not use side-band-64k in git-receive-pack.
     """
 
-    protocol = 'git'
+    protocol = "git"
 
     def _handlers(self):
-        return {b'git-receive-pack': NoSideBand64kReceivePackHandler}
+        return {b"git-receive-pack": NoSideBand64kReceivePackHandler}
 
     def _check_server(self, dul_server):
-        receive_pack_handler_cls = dul_server.handlers[b'git-receive-pack']
+        receive_pack_handler_cls = dul_server.handlers[b"git-receive-pack"]
         caps = receive_pack_handler_cls.capabilities()
-        self.assertFalse(b'side-band-64k' in caps)
+        self.assertFalse(b"side-band-64k" in caps)
 
     def _start_server(self, repo):
-        backend = DictBackend({b'/': repo})
-        dul_server = TCPGitServer(backend, b'localhost', 0,
-                                  handlers=self._handlers())
+        backend = DictBackend({b"/": repo})
+        dul_server = TCPGitServer(backend, b"localhost", 0, handlers=self._handlers())
         self._check_server(dul_server)
         self.addCleanup(dul_server.shutdown)
         self.addCleanup(dul_server.server_close)
@@ -75,8 +73,7 @@ class GitServerTestCase(ServerTests, CompatTestCase):
         return port
 
 
-@skipIf(sys.platform == 'win32',
-        'Broken on windows, with very long fail time.')
+@skipIf(sys.platform == "win32", "Broken on windows, with very long fail time.")
 class GitServerSideBand64kTestCase(GitServerTestCase):
     """Tests for client/server compatibility with side-band-64k support."""
 
@@ -88,13 +85,13 @@ class GitServerSideBand64kTestCase(GitServerTestCase):
         # side-band-64k is broken in the windows client.
         # https://github.com/msysgit/git/issues/101
         # Fix has landed for the 1.9.3 release.
-        if os.name == 'nt':
+        if os.name == "nt":
             require_git_version((1, 9, 3))
 
     def _handlers(self):
         return None  # default handlers include side-band-64k
 
     def _check_server(self, server):
-        receive_pack_handler_cls = server.handlers[b'git-receive-pack']
+        receive_pack_handler_cls = server.handlers[b"git-receive-pack"]
         caps = receive_pack_handler_cls.capabilities()
-        self.assertTrue(b'side-band-64k' in caps)
+        self.assertTrue(b"side-band-64k" in caps)

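For reference, `_start_server` above boils down to the following wiring; a sketch serving an (empty) in-memory repository over git:// on an ephemeral port:

    import threading

    from dulwich.repo import MemoryRepo
    from dulwich.server import DictBackend, TCPGitServer

    backend = DictBackend({b"/": MemoryRepo()})
    server = TCPGitServer(backend, b"localhost", 0)
    threading.Thread(target=server.serve_forever, daemon=True).start()
    port = server.socket.getsockname()[1]
    print("git://localhost:%d/" % port)
    server.shutdown()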
+ 12 - 14
dulwich/tests/compat/test_utils.py

@@ -23,20 +23,20 @@
 from dulwich.tests import (
     SkipTest,
     TestCase,
-    )
+)
 from dulwich.tests.compat import utils
 
 
 class GitVersionTests(TestCase):
-
     def setUp(self):
         super(GitVersionTests, self).setUp()
         self._orig_run_git = utils.run_git
         self._version_str = None  # tests can override to set stub version
 
         def run_git(args, **unused_kwargs):
-            self.assertEqual(['--version'], args)
+            self.assertEqual(["--version"], args)
             return 0, self._version_str
+
         utils.run_git = run_git
 
     def tearDown(self):
@@ -44,19 +44,19 @@ class GitVersionTests(TestCase):
         utils.run_git = self._orig_run_git
 
     def test_git_version_none(self):
-        self._version_str = b'not a git version'
+        self._version_str = b"not a git version"
         self.assertEqual(None, utils.git_version())
 
     def test_git_version_3(self):
-        self._version_str = b'git version 1.6.6'
+        self._version_str = b"git version 1.6.6"
         self.assertEqual((1, 6, 6, 0), utils.git_version())
 
     def test_git_version_4(self):
-        self._version_str = b'git version 1.7.0.2'
+        self._version_str = b"git version 1.7.0.2"
         self.assertEqual((1, 7, 0, 2), utils.git_version())
 
     def test_git_version_extra(self):
-        self._version_str = b'git version 1.7.0.3.295.gd8fa2'
+        self._version_str = b"git version 1.7.0.3.295.gd8fa2"
         self.assertEqual((1, 7, 0, 3), utils.git_version())
 
     def assertRequireSucceeds(self, required_version):
@@ -66,22 +66,20 @@ class GitVersionTests(TestCase):
             self.fail()
 
     def assertRequireFails(self, required_version):
-        self.assertRaises(SkipTest, utils.require_git_version,
-                          required_version)
+        self.assertRaises(SkipTest, utils.require_git_version, required_version)
 
     def test_require_git_version(self):
         try:
-            self._version_str = b'git version 1.6.6'
+            self._version_str = b"git version 1.6.6"
             self.assertRequireSucceeds((1, 6, 6))
             self.assertRequireSucceeds((1, 6, 6, 0))
             self.assertRequireSucceeds((1, 6, 5))
             self.assertRequireSucceeds((1, 6, 5, 99))
             self.assertRequireFails((1, 7, 0))
             self.assertRequireFails((1, 7, 0, 2))
-            self.assertRaises(ValueError, utils.require_git_version,
-                              (1, 6, 6, 0, 0))
+            self.assertRaises(ValueError, utils.require_git_version, (1, 6, 6, 0, 0))
 
-            self._version_str = b'git version 1.7.0.2'
+            self._version_str = b"git version 1.7.0.2"
             self.assertRequireSucceeds((1, 6, 6))
             self.assertRequireSucceeds((1, 6, 6, 0))
             self.assertRequireSucceeds((1, 7, 0))
@@ -90,4 +88,4 @@ class GitVersionTests(TestCase):
             self.assertRequireFails((1, 7, 1))
         except SkipTest as e:
             # This test is designed to catch all SkipTest exceptions.
-            self.fail('Test unexpectedly skipped: %s' % e)
+            self.fail("Test unexpectedly skipped: %s" % e)

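As the stubs above show, `git_version()` zero-pads the parsed version to a 4-tuple and `require_git_version()` raises SkipTest on an older installation. For example:

    from dulwich.tests.compat.utils import git_version, require_git_version

    # b"git version 1.7.0.2" parses to (1, 7, 0, 2); a plain
    # b"git version 1.6.6" pads to (1, 6, 6, 0); None means no git found.
    print(git_version())
    require_git_version((1, 7, 0))  # raises SkipTest if git is older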
+ 36 - 34
dulwich/tests/compat/test_web.py

@@ -34,29 +34,28 @@ from dulwich.server import (
     DictBackend,
     UploadPackHandler,
     ReceivePackHandler,
-    )
+)
 from dulwich.tests import (
     SkipTest,
     skipIf,
-    )
+)
 from dulwich.web import (
     make_wsgi_chain,
     HTTPGitApplication,
     WSGIRequestHandlerLogger,
     WSGIServerLogger,
-    )
+)
 
 from dulwich.tests.compat.server_utils import (
     ServerTests,
     NoSideBand64kReceivePackHandler,
-    )
+)
 from dulwich.tests.compat.utils import (
     CompatTestCase,
-    )
+)
 
 
-@skipIf(sys.platform == 'win32',
-        'Broken on windows, with very long fail time.')
+@skipIf(sys.platform == "win32", "Broken on windows, with very long fail time.")
 class WebTests(ServerTests):
     """Base tests for web server tests.
 
@@ -64,14 +63,18 @@ class WebTests(ServerTests):
     TestCase so tests are not automatically run.
     """
 
-    protocol = 'http'
+    protocol = "http"
 
     def _start_server(self, repo):
-        backend = DictBackend({'/': repo})
+        backend = DictBackend({"/": repo})
         app = self._make_app(backend)
         dul_server = simple_server.make_server(
-          'localhost', 0, app, server_class=WSGIServerLogger,
-          handler_class=WSGIRequestHandlerLogger)
+            "localhost",
+            0,
+            app,
+            server_class=WSGIServerLogger,
+            handler_class=WSGIRequestHandlerLogger,
+        )
         self.addCleanup(dul_server.shutdown)
         self.addCleanup(dul_server.server_close)
         threading.Thread(target=dul_server.serve_forever).start()
@@ -80,8 +83,7 @@ class WebTests(ServerTests):
         return port
 
 
-@skipIf(sys.platform == 'win32',
-        'Broken on windows, with very long fail time.')
+@skipIf(sys.platform == "win32", "Broken on windows, with very long fail time.")
 class SmartWebTestCase(WebTests, CompatTestCase):
     """Test cases for smart HTTP server.
 
@@ -91,12 +93,12 @@ class SmartWebTestCase(WebTests, CompatTestCase):
     min_git_version = (1, 6, 6)  # type: Tuple[int, ...]
 
     def _handlers(self):
-        return {b'git-receive-pack': NoSideBand64kReceivePackHandler}
+        return {b"git-receive-pack": NoSideBand64kReceivePackHandler}
 
     def _check_app(self, app):
-        receive_pack_handler_cls = app.handlers[b'git-receive-pack']
+        receive_pack_handler_cls = app.handlers[b"git-receive-pack"]
         caps = receive_pack_handler_cls.capabilities()
-        self.assertNotIn(b'side-band-64k', caps)
+        self.assertNotIn(b"side-band-64k", caps)
 
     def _make_app(self, backend):
         app = make_wsgi_chain(backend, handlers=self._handlers())
@@ -113,16 +115,17 @@ def patch_capabilities(handler, caps_removed):
     # removed, and return the original classmethod for restoration.
     original_capabilities = handler.capabilities
     filtered_capabilities = [
-        i for i in original_capabilities() if i not in caps_removed]
+        i for i in original_capabilities() if i not in caps_removed
+    ]
 
     def capabilities(cls):
         return filtered_capabilities
+
     handler.capabilities = classmethod(capabilities)
     return original_capabilities
 
 
-@skipIf(sys.platform == 'win32',
-        'Broken on windows, with very long fail time.')
+@skipIf(sys.platform == "win32", "Broken on windows, with very long fail time.")
 class SmartWebSideBand64kTestCase(SmartWebTestCase):
     """Test cases for smart HTTP server with side-band-64k support."""
 
@@ -143,10 +146,10 @@ class SmartWebSideBand64kTestCase(SmartWebTestCase):
         return None  # default handlers include side-band-64k
 
     def _check_app(self, app):
-        receive_pack_handler_cls = app.handlers[b'git-receive-pack']
+        receive_pack_handler_cls = app.handlers[b"git-receive-pack"]
         caps = receive_pack_handler_cls.capabilities()
-        self.assertIn(b'side-band-64k', caps)
-        self.assertNotIn(b'no-done', caps)
+        self.assertIn(b"side-band-64k", caps)
+        self.assertNotIn(b"no-done", caps)
 
 
 class SmartWebSideBand64kNoDoneTestCase(SmartWebTestCase):
@@ -161,14 +164,13 @@ class SmartWebSideBand64kNoDoneTestCase(SmartWebTestCase):
         return None  # default handlers include side-band-64k
 
     def _check_app(self, app):
-        receive_pack_handler_cls = app.handlers[b'git-receive-pack']
+        receive_pack_handler_cls = app.handlers[b"git-receive-pack"]
         caps = receive_pack_handler_cls.capabilities()
-        self.assertIn(b'side-band-64k', caps)
-        self.assertIn(b'no-done', caps)
+        self.assertIn(b"side-band-64k", caps)
+        self.assertIn(b"no-done", caps)
 
 
-@skipIf(sys.platform == 'win32',
-        'Broken on windows, with very long fail time.')
+@skipIf(sys.platform == "win32", "Broken on windows, with very long fail time.")
 class DumbWebTestCase(WebTests, CompatTestCase):
     """Test cases for dumb HTTP server."""
 
@@ -177,31 +179,31 @@ class DumbWebTestCase(WebTests, CompatTestCase):
 
     def test_push_to_dulwich(self):
         # Note: remove this if dulwich implements dumb web pushing.
-        raise SkipTest('Dumb web pushing not supported.')
+        raise SkipTest("Dumb web pushing not supported.")
 
     def test_push_to_dulwich_remove_branch(self):
         # Note: remove this if dumb pushing is supported
-        raise SkipTest('Dumb web pushing not supported.')
+        raise SkipTest("Dumb web pushing not supported.")
 
     def test_new_shallow_clone_from_dulwich(self):
         # Note: remove this if C git and dulwich implement dumb web shallow
         # clones.
-        raise SkipTest('Dumb web shallow cloning not supported.')
+        raise SkipTest("Dumb web shallow cloning not supported.")
 
     def test_shallow_clone_from_git_is_identical(self):
         # Note: remove this if C git and dulwich implement dumb web shallow
         # clones.
-        raise SkipTest('Dumb web shallow cloning not supported.')
+        raise SkipTest("Dumb web shallow cloning not supported.")
 
     def test_fetch_same_depth_into_shallow_clone_from_dulwich(self):
         # Note: remove this if C git and dulwich implement dumb web shallow
         # clones.
-        raise SkipTest('Dumb web shallow cloning not supported.')
+        raise SkipTest("Dumb web shallow cloning not supported.")
 
     def test_fetch_full_depth_into_shallow_clone_from_dulwich(self):
         # Note: remove this if C git and dulwich implement dumb web shallow
         # clones.
-        raise SkipTest('Dumb web shallow cloning not supported.')
+        raise SkipTest("Dumb web shallow cloning not supported.")
 
     def test_push_to_dulwich_issue_88_standard(self):
-        raise SkipTest('Dumb web pushing not supported.')
+        raise SkipTest("Dumb web pushing not supported.")

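`WebTests._start_server` reduces to the following; a sketch without the logging wrappers, serving an empty in-memory repository over smart HTTP (the port is arbitrary):

    from wsgiref import simple_server

    from dulwich.repo import MemoryRepo
    from dulwich.server import DictBackend
    from dulwich.web import make_wsgi_chain

    backend = DictBackend({"/": MemoryRepo()})
    app = make_wsgi_chain(backend)
    httpd = simple_server.make_server("localhost", 8000, app)
    httpd.handle_request()  # serve a single request, then return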
+ 44 - 35
dulwich/tests/compat/utils.py

@@ -38,12 +38,13 @@ from dulwich.protocol import TCP_GIT_PORT
 from dulwich.tests import (
     SkipTest,
     TestCase,
-    )
+)
 
-_DEFAULT_GIT = 'git'
+_DEFAULT_GIT = "git"
 _VERSION_LEN = 4
-_REPOS_DATA_DIR = os.path.abspath(os.path.join(
-    os.path.dirname(__file__), os.pardir, 'data', 'repos'))
+_REPOS_DATA_DIR = os.path.abspath(
+    os.path.join(os.path.dirname(__file__), os.pardir, "data", "repos")
+)
 
 
 def git_version(git_path=_DEFAULT_GIT):
@@ -56,14 +57,14 @@ def git_version(git_path=_DEFAULT_GIT):
         None if no git installation was found.
     """
     try:
-        output = run_git_or_fail(['--version'], git_path=git_path)
+        output = run_git_or_fail(["--version"], git_path=git_path)
     except OSError:
         return None
-    version_prefix = b'git version '
+    version_prefix = b"git version "
     if not output.startswith(version_prefix):
         return None
 
-    parts = output[len(version_prefix):].split(b'.')
+    parts = output[len(version_prefix) :].split(b".")
     nums = []
     for part in parts:
         try:
@@ -90,12 +91,15 @@ def require_git_version(required_version, git_path=_DEFAULT_GIT):
     """
     found_version = git_version(git_path=git_path)
     if found_version is None:
-        raise SkipTest('Test requires git >= %s, but c git not found' %
-                       (required_version, ))
+        raise SkipTest(
+            "Test requires git >= %s, but c git not found" % (required_version,)
+        )
 
     if len(required_version) > _VERSION_LEN:
-        raise ValueError('Invalid version tuple %s, expected %i parts' %
-                         (required_version, _VERSION_LEN))
+        raise ValueError(
+            "Invalid version tuple %s, expected %i parts"
+            % (required_version, _VERSION_LEN)
+        )
 
     required_version = list(required_version)
     while len(found_version) < len(required_version):
@@ -103,14 +107,16 @@ def require_git_version(required_version, git_path=_DEFAULT_GIT):
     required_version = tuple(required_version)
 
     if found_version < required_version:
-        required_version = '.'.join(map(str, required_version))
-        found_version = '.'.join(map(str, found_version))
-        raise SkipTest('Test requires git >= %s, found %s' %
-                       (required_version, found_version))
+        required_version = ".".join(map(str, required_version))
+        found_version = ".".join(map(str, found_version))
+        raise SkipTest(
+            "Test requires git >= %s, found %s" % (required_version, found_version)
+        )
 
 
-def run_git(args, git_path=_DEFAULT_GIT, input=None, capture_stdout=False,
-            **popen_kwargs):
+def run_git(
+    args, git_path=_DEFAULT_GIT, input=None, capture_stdout=False, **popen_kwargs
+):
     """Run a git command.
 
     Input is piped from the input parameter and output is sent to the standard
@@ -129,15 +135,15 @@ def run_git(args, git_path=_DEFAULT_GIT, input=None, capture_stdout=False,
       OSError: if the git executable was not found.
     """
 
-    env = popen_kwargs.pop('env', {})
-    env['LC_ALL'] = env['LANG'] = 'C'
+    env = popen_kwargs.pop("env", {})
+    env["LC_ALL"] = env["LANG"] = "C"
 
     args = [git_path] + args
-    popen_kwargs['stdin'] = subprocess.PIPE
+    popen_kwargs["stdin"] = subprocess.PIPE
     if capture_stdout:
-        popen_kwargs['stdout'] = subprocess.PIPE
+        popen_kwargs["stdout"] = subprocess.PIPE
     else:
-        popen_kwargs.pop('stdout', None)
+        popen_kwargs.pop("stdout", None)
     p = subprocess.Popen(args, env=env, **popen_kwargs)
     stdout, stderr = p.communicate(input=input)
     return (p.returncode, stdout)
@@ -145,13 +151,15 @@ def run_git(args, git_path=_DEFAULT_GIT, input=None, capture_stdout=False,
 
 def run_git_or_fail(args, git_path=_DEFAULT_GIT, input=None, **popen_kwargs):
     """Run a git command, capture stdout/stderr, and fail if git fails."""
-    if 'stderr' not in popen_kwargs:
-        popen_kwargs['stderr'] = subprocess.STDOUT
-    returncode, stdout = run_git(args, git_path=git_path, input=input,
-                                 capture_stdout=True, **popen_kwargs)
+    if "stderr" not in popen_kwargs:
+        popen_kwargs["stderr"] = subprocess.STDOUT
+    returncode, stdout = run_git(
+        args, git_path=git_path, input=input, capture_stdout=True, **popen_kwargs
+    )
     if returncode != 0:
-        raise AssertionError("git with args %r failed with %d: %r" % (
-            args, returncode, stdout))
+        raise AssertionError(
+            "git with args %r failed with %d: %r" % (args, returncode, stdout)
+        )
     return stdout
 
 
@@ -169,10 +177,9 @@ def import_repo_to_dir(name):
     temp_dir = tempfile.mkdtemp()
     export_path = os.path.join(_REPOS_DATA_DIR, name)
     temp_repo_dir = os.path.join(temp_dir, name)
-    export_file = open(export_path, 'rb')
-    run_git_or_fail(['init', '--quiet', '--bare', temp_repo_dir])
-    run_git_or_fail(['fast-import'], input=export_file.read(),
-                    cwd=temp_repo_dir)
+    export_file = open(export_path, "rb")
+    run_git_or_fail(["init", "--quiet", "--bare", temp_repo_dir])
+    run_git_or_fail(["fast-import"], input=export_file.read(), cwd=temp_repo_dir)
     export_file.close()
     return temp_repo_dir
 
@@ -195,12 +202,12 @@ def check_for_daemon(limit=10, delay=0.1, timeout=0.1, port=TCP_GIT_PORT):
         s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
         s.settimeout(delay)
         try:
-            s.connect(('localhost', port))
+            s.connect(("localhost", port))
             return True
         except socket.timeout:
             pass
         except socket.error as e:
-            if getattr(e, 'errno', False) and e.errno != errno.ECONNREFUSED:
+            if getattr(e, "errno", False) and e.errno != errno.ECONNREFUSED:
                 raise
             elif e.args[0] != errno.ECONNREFUSED:
                 raise
@@ -251,11 +258,13 @@ class CompatTestCase(TestCase):
         def cleanup():
             repo.close()
             rmtree_ro(os.path.dirname(path.rstrip(os.sep)))
+
         self.addCleanup(cleanup)
         return repo
 
 
-if sys.platform == 'win32':
+if sys.platform == "win32":
+
     def remove_ro(action, name, exc):
         os.chmod(name, stat.S_IWRITE)
         os.remove(name)

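Typical use of the helper reformatted above (the checkout path is hypothetical):

    from dulwich.tests.compat.utils import run_git_or_fail

    # stdout (and, by default, stderr) comes back as bytes; a non-zero
    # exit status raises AssertionError instead of returning.
    sha = run_git_or_fail(["rev-parse", "HEAD"], cwd="/path/to/repo")
    print(sha.strip().decode("ascii"))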
+ 15 - 18
dulwich/tests/test_archive.py

@@ -28,31 +28,30 @@ from unittest import skipUnless
 from dulwich.archive import tar_stream
 from dulwich.object_store import (
     MemoryObjectStore,
-    )
+)
 from dulwich.objects import (
     Blob,
     Tree,
-    )
+)
 from dulwich.tests import (
     TestCase,
-    )
+)
 from dulwich.tests.utils import (
     build_commit_graph,
-    )
+)
 
 try:
     from unittest.mock import patch
 except ImportError:
-    patch = None   # type: ignore
+    patch = None  # type: ignore
 
 
 class ArchiveTests(TestCase):
-
     def test_empty(self):
         store = MemoryObjectStore()
         c1, c2, c3 = build_commit_graph(store, [[1], [2, 1], [3, 1, 2]])
         tree = store[c3.tree]
-        stream = b''.join(tar_stream(store, tree, 10))
+        stream = b"".join(tar_stream(store, tree, 10))
         out = BytesIO(stream)
         tf = tarfile.TarFile(fileobj=out)
         self.addCleanup(tf.close)
@@ -65,8 +64,7 @@ class ArchiveTests(TestCase):
         t1 = Tree()
         t1.add(b"somename", 0o100644, b1.id)
         store.add_object(t1)
-        stream = b''.join(
-            tar_stream(store, t1, *tar_stream_args, **tar_stream_kwargs))
+        stream = b"".join(tar_stream(store, t1, *tar_stream_args, **tar_stream_kwargs))
         return BytesIO(stream)
 
     def test_simple(self):
@@ -76,27 +74,26 @@ class ArchiveTests(TestCase):
         self.assertEqual(["somename"], tf.getnames())
 
     def test_prefix(self):
-        stream = self._get_example_tar_stream(mtime=0, prefix=b'blah')
+        stream = self._get_example_tar_stream(mtime=0, prefix=b"blah")
         tf = tarfile.TarFile(fileobj=stream)
         self.addCleanup(tf.close)
         self.assertEqual(["blah/somename"], tf.getnames())
 
     def test_gzip_mtime(self):
-        stream = self._get_example_tar_stream(mtime=1234, format='gz')
-        expected_mtime = struct.pack('<L', 1234)
+        stream = self._get_example_tar_stream(mtime=1234, format="gz")
+        expected_mtime = struct.pack("<L", 1234)
         self.assertEqual(stream.getvalue()[4:8], expected_mtime)
 
     @skipUnless(patch, "Required mock.patch")
     def test_same_file(self):
         contents = [None, None]
-        for format in ['', 'gz', 'bz2']:
+        for format in ["", "gz", "bz2"]:
             for i in [0, 1]:
-                with patch('time.time', return_value=i):
-                    stream = self._get_example_tar_stream(
-                        mtime=0, format=format)
+                with patch("time.time", return_value=i):
+                    stream = self._get_example_tar_stream(mtime=0, format=format)
                     contents[i] = stream.getvalue()
             self.assertEqual(
                 contents[0],
                 contents[1],
-                "Different file contents for format %r" % format
-                )
+                "Different file contents for format %r" % format,
+            )
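
These archive tests double as a map of the dulwich.archive.tar_stream API: it yields tarball chunks for a tree, with mtime required and prefix/format optional. A short usage sketch built from the same kinds of objects the tests construct:

    import tarfile
    from io import BytesIO

    from dulwich.archive import tar_stream
    from dulwich.object_store import MemoryObjectStore
    from dulwich.objects import Blob, Tree

    store = MemoryObjectStore()
    blob = Blob.from_string(b"hello")
    tree = Tree()
    tree.add(b"hello.txt", 0o100644, blob.id)
    store.add_objects([(blob, None), (tree, None)])

    # tar_stream is a generator of byte chunks; join them to form the archive.
    data = b"".join(tar_stream(store, tree, mtime=0))
    with tarfile.TarFile(fileobj=BytesIO(data)) as tf:
        assert tf.getnames() == ["hello.txt"]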

+ 9 - 9
dulwich/tests/test_blackbox.py

@@ -25,10 +25,10 @@ import shutil
 
 from dulwich.repo import (
     Repo,
-    )
+)
 from dulwich.tests import (
     BlackboxTestCase,
-    )
+)
 
 
 class GitReceivePackTests(BlackboxTestCase):
@@ -43,16 +43,16 @@ class GitReceivePackTests(BlackboxTestCase):
     def test_basic(self):
         process = self.run_command("dul-receive-pack", [self.path])
         (stdout, stderr) = process.communicate(b"0000")
-        self.assertEqual(b'0000', stdout[-4:])
+        self.assertEqual(b"0000", stdout[-4:])
         self.assertEqual(0, process.returncode)
 
     def test_missing_arg(self):
         process = self.run_command("dul-receive-pack", [])
         (stdout, stderr) = process.communicate()
         self.assertEqual(
-            [b'usage: dul-receive-pack <git-dir>'],
-            stderr.splitlines()[-1:])
-        self.assertEqual(b'', stdout)
+            [b"usage: dul-receive-pack <git-dir>"], stderr.splitlines()[-1:]
+        )
+        self.assertEqual(b"", stdout)
         self.assertEqual(1, process.returncode)
 
 
@@ -69,7 +69,7 @@ class GitUploadPackTests(BlackboxTestCase):
         process = self.run_command("dul-upload-pack", [])
         (stdout, stderr) = process.communicate()
         self.assertEqual(
-            [b'usage: dul-upload-pack <git-dir>'],
-            stderr.splitlines()[-1:])
-        self.assertEqual(b'', stdout)
+            [b"usage: dul-upload-pack <git-dir>"], stderr.splitlines()[-1:]
+        )
+        self.assertEqual(b"", stdout)
         self.assertEqual(1, process.returncode)
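
A note on the b"0000" handshake these blackbox tests rely on: it is the git pkt-line flush packet, which tells dul-receive-pack the client is done. dulwich ships its own pkt-line code in dulwich.protocol; the hand-rolled sketch below is only for intuition about the framing:

    def pkt_line(data):
        # A pkt-line is a 4-hex-digit length (counting the 4-byte length
        # header itself) followed by the payload; None encodes a flush packet.
        if data is None:
            return b"0000"
        return ("%04x" % (len(data) + 4)).encode("ascii") + data

    pkt_line(b"hello\n")  # b"000ahello\n"
    pkt_line(None)        # b"0000" -- what the test sends and expects back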

+ 7 - 8
dulwich/tests/test_bundle.py

@@ -25,28 +25,27 @@ import tempfile
 
 from dulwich.tests import (
     TestCase,
-    )
+)
 
 from dulwich.bundle import (
     Bundle,
     read_bundle,
     write_bundle,
-    )
+)
 
 
 class BundleTests(TestCase):
-
     def test_roundtrip_bundle(self):
         origbundle = Bundle()
         origbundle.version = 3
-        origbundle.capabilities = {'foo': None}
-        origbundle.references = {b'refs/heads/master': b'ab' * 20}
-        origbundle.prerequisites = [(b'cc' * 20, 'comment')]
+        origbundle.capabilities = {"foo": None}
+        origbundle.references = {b"refs/heads/master": b"ab" * 20}
+        origbundle.prerequisites = [(b"cc" * 20, "comment")]
         with tempfile.TemporaryDirectory() as td:
-            with open(os.path.join(td, 'foo'), 'wb') as f:
+            with open(os.path.join(td, "foo"), "wb") as f:
                 write_bundle(f, origbundle)
 
-            with open(os.path.join(td, 'foo'), 'rb') as f:
+            with open(os.path.join(td, "foo"), "rb") as f:
                 newbundle = read_bundle(f)
 
                 self.assertEqual(origbundle, newbundle)
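
The roundtrip test above is effectively the whole dulwich.bundle surface: populate a Bundle, write_bundle it to a binary file, and read_bundle it back. The same flow outside the test harness, mirroring exactly the fields the test sets (other fields are untested here):

    import os
    import tempfile

    from dulwich.bundle import Bundle, read_bundle, write_bundle

    bundle = Bundle()
    bundle.version = 3
    bundle.capabilities = {"foo": None}
    bundle.references = {b"refs/heads/master": b"ab" * 20}
    bundle.prerequisites = [(b"cc" * 20, "comment")]

    with tempfile.TemporaryDirectory() as td:
        path = os.path.join(td, "repo.bundle")
        with open(path, "wb") as f:
            write_bundle(f, bundle)
        with open(path, "rb") as f:
            assert read_bundle(f) == bundle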

File diff too large to display
+ 312 - 286
dulwich/tests/test_client.py
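
The test_client.py diff is omitted above for size; those tests exercise dulwich's transport clients. For orientation, a minimal fetch against the public client API; the URL and branch name here are illustrative assumptions, not taken from the tests:

    from dulwich.client import get_transport_and_path
    from dulwich.repo import Repo

    client, path = get_transport_and_path("https://example.com/repo.git")
    local = Repo.init("repo-copy", mkdir=True)
    # fetch() negotiates refs with the remote and stores the received pack.
    result = client.fetch(path, local)
    # Point a local branch at a fetched head, assuming the remote has one.
    local.refs[b"refs/heads/master"] = result.refs[b"refs/heads/master"]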


+ 172 - 101
dulwich/tests/test_config.py

@@ -20,7 +20,12 @@
 
 """Tests for reading and writing configuration files."""
 
+import os
+import sys
 from io import BytesIO
+from unittest import skipIf
+from unittest.mock import patch
+
 from dulwich.config import (
     ConfigDict,
     ConfigFile,
@@ -31,14 +36,13 @@ from dulwich.config import (
     _escape_value,
     _parse_string,
     parse_submodules,
-    )
+)
 from dulwich.tests import (
     TestCase,
-    )
+)
 
 
 class ConfigFileTests(TestCase):
-
     def from_file(self, text):
         return ConfigFile.from_file(BytesIO(text))
 
@@ -49,17 +53,27 @@ class ConfigFileTests(TestCase):
         self.assertEqual(ConfigFile(), ConfigFile())
 
     def test_default_config(self):
-        cf = self.from_file(b"""[core]
+        cf = self.from_file(
+            b"""[core]
 \trepositoryformatversion = 0
 \tfilemode = true
 \tbare = false
 \tlogallrefupdates = true
-""")
-        self.assertEqual(ConfigFile({(b"core", ): {
-            b"repositoryformatversion": b"0",
-            b"filemode": b"true",
-            b"bare": b"false",
-            b"logallrefupdates": b"true"}}), cf)
+"""
+        )
+        self.assertEqual(
+            ConfigFile(
+                {
+                    (b"core",): {
+                        b"repositoryformatversion": b"0",
+                        b"filemode": b"true",
+                        b"bare": b"false",
+                        b"logallrefupdates": b"true",
+                    }
+                }
+            ),
+            cf,
+        )
 
     def test_from_file_empty(self):
         cf = self.from_file(b"")
@@ -67,81 +81,71 @@ class ConfigFileTests(TestCase):
 
     def test_empty_line_before_section(self):
         cf = self.from_file(b"\n[section]\n")
-        self.assertEqual(ConfigFile({(b"section", ): {}}), cf)
+        self.assertEqual(ConfigFile({(b"section",): {}}), cf)
 
     def test_comment_before_section(self):
         cf = self.from_file(b"# foo\n[section]\n")
-        self.assertEqual(ConfigFile({(b"section", ): {}}), cf)
+        self.assertEqual(ConfigFile({(b"section",): {}}), cf)
 
     def test_comment_after_section(self):
         cf = self.from_file(b"[section] # foo\n")
-        self.assertEqual(ConfigFile({(b"section", ): {}}), cf)
+        self.assertEqual(ConfigFile({(b"section",): {}}), cf)
 
     def test_comment_after_variable(self):
         cf = self.from_file(b"[section]\nbar= foo # a comment\n")
-        self.assertEqual(ConfigFile({(b"section", ): {b"bar": b"foo"}}), cf)
+        self.assertEqual(ConfigFile({(b"section",): {b"bar": b"foo"}}), cf)
 
     def test_comment_character_within_value_string(self):
-        cf = self.from_file(b"[section]\nbar= \"foo#bar\"\n")
-        self.assertEqual(
-            ConfigFile({(b"section", ): {b"bar": b"foo#bar"}}), cf)
+        cf = self.from_file(b'[section]\nbar= "foo#bar"\n')
+        self.assertEqual(ConfigFile({(b"section",): {b"bar": b"foo#bar"}}), cf)
 
     def test_comment_character_within_section_string(self):
-        cf = self.from_file(b"[branch \"foo#bar\"] # a comment\nbar= foo\n")
-        self.assertEqual(
-            ConfigFile({(b"branch", b"foo#bar"): {b"bar": b"foo"}}), cf)
+        cf = self.from_file(b'[branch "foo#bar"] # a comment\nbar= foo\n')
+        self.assertEqual(ConfigFile({(b"branch", b"foo#bar"): {b"bar": b"foo"}}), cf)
 
     def test_from_file_section(self):
         cf = self.from_file(b"[core]\nfoo = bar\n")
-        self.assertEqual(b"bar", cf.get((b"core", ), b"foo"))
+        self.assertEqual(b"bar", cf.get((b"core",), b"foo"))
         self.assertEqual(b"bar", cf.get((b"core", b"foo"), b"foo"))
 
     def test_from_file_section_case_insensitive_lower(self):
         cf = self.from_file(b"[cOre]\nfOo = bar\n")
-        self.assertEqual(b"bar", cf.get((b"core", ), b"foo"))
+        self.assertEqual(b"bar", cf.get((b"core",), b"foo"))
         self.assertEqual(b"bar", cf.get((b"core", b"foo"), b"foo"))
 
     def test_from_file_section_case_insensitive_mixed(self):
         cf = self.from_file(b"[cOre]\nfOo = bar\n")
-        self.assertEqual(b"bar", cf.get((b"core", ), b"fOo"))
+        self.assertEqual(b"bar", cf.get((b"core",), b"fOo"))
         self.assertEqual(b"bar", cf.get((b"cOre", b"fOo"), b"fOo"))
 
     def test_from_file_with_mixed_quoted(self):
-        cf = self.from_file(b"[core]\nfoo = \"bar\"la\n")
-        self.assertEqual(b"barla", cf.get((b"core", ), b"foo"))
+        cf = self.from_file(b'[core]\nfoo = "bar"la\n')
+        self.assertEqual(b"barla", cf.get((b"core",), b"foo"))
 
     def test_from_file_section_with_open_brackets(self):
         self.assertRaises(ValueError, self.from_file, b"[core\nfoo = bar\n")
 
     def test_from_file_value_with_open_quoted(self):
-        self.assertRaises(ValueError, self.from_file, b"[core]\nfoo = \"bar\n")
+        self.assertRaises(ValueError, self.from_file, b'[core]\nfoo = "bar\n')
 
     def test_from_file_with_quotes(self):
-        cf = self.from_file(
-            b"[core]\n"
-            b'foo = " bar"\n')
-        self.assertEqual(b" bar", cf.get((b"core", ), b"foo"))
+        cf = self.from_file(b"[core]\n" b'foo = " bar"\n')
+        self.assertEqual(b" bar", cf.get((b"core",), b"foo"))
 
     def test_from_file_with_interrupted_line(self):
-        cf = self.from_file(
-            b"[core]\n"
-            b'foo = bar\\\n'
-            b' la\n')
-        self.assertEqual(b"barla", cf.get((b"core", ), b"foo"))
+        cf = self.from_file(b"[core]\n" b"foo = bar\\\n" b" la\n")
+        self.assertEqual(b"barla", cf.get((b"core",), b"foo"))
 
     def test_from_file_with_boolean_setting(self):
-        cf = self.from_file(
-            b"[core]\n"
-            b'foo\n')
-        self.assertEqual(b"true", cf.get((b"core", ), b"foo"))
+        cf = self.from_file(b"[core]\n" b"foo\n")
+        self.assertEqual(b"true", cf.get((b"core",), b"foo"))
 
     def test_from_file_subsection(self):
-        cf = self.from_file(b"[branch \"foo\"]\nfoo = bar\n")
+        cf = self.from_file(b'[branch "foo"]\nfoo = bar\n')
         self.assertEqual(b"bar", cf.get((b"branch", b"foo"), b"foo"))
 
     def test_from_file_subsection_invalid(self):
-        self.assertRaises(
-                ValueError, self.from_file, b"[branch \"foo]\nfoo = bar\n")
+        self.assertRaises(ValueError, self.from_file, b'[branch "foo]\nfoo = bar\n')
 
     def test_from_file_subsection_not_quoted(self):
         cf = self.from_file(b"[branch.foo]\nfoo = bar\n")
@@ -155,7 +159,7 @@ class ConfigFileTests(TestCase):
 
     def test_write_to_file_section(self):
         c = ConfigFile()
-        c.set((b"core", ), b"foo", b"bar")
+        c.set((b"core",), b"foo", b"bar")
         f = BytesIO()
         c.write_to_file(f)
         self.assertEqual(b"[core]\n\tfoo = bar\n", f.getvalue())
@@ -165,100 +169,161 @@ class ConfigFileTests(TestCase):
         c.set((b"branch", b"blie"), b"foo", b"bar")
         f = BytesIO()
         c.write_to_file(f)
-        self.assertEqual(b"[branch \"blie\"]\n\tfoo = bar\n", f.getvalue())
+        self.assertEqual(b'[branch "blie"]\n\tfoo = bar\n', f.getvalue())
 
     def test_same_line(self):
         cf = self.from_file(b"[branch.foo] foo = bar\n")
         self.assertEqual(b"bar", cf.get((b"branch", b"foo"), b"foo"))
 
     def test_quoted(self):
-        cf = self.from_file(b"""[gui]
+        cf = self.from_file(
+            b"""[gui]
 \tfontdiff = -family \\\"Ubuntu Mono\\\" -size 11 -overstrike 0
-""")
-        self.assertEqual(ConfigFile({(b'gui', ): {
-            b'fontdiff': b'-family "Ubuntu Mono" -size 11 -overstrike 0',
-        }}), cf)
+"""
+        )
+        self.assertEqual(
+            ConfigFile(
+                {
+                    (b"gui",): {
+                        b"fontdiff": b'-family "Ubuntu Mono" -size 11 -overstrike 0',
+                    }
+                }
+            ),
+            cf,
+        )
 
     def test_quoted_multiline(self):
-        cf = self.from_file(b"""[alias]
+        cf = self.from_file(
+            b"""[alias]
 who = \"!who() {\\
   git log --no-merges --pretty=format:'%an - %ae' $@ | uniq -c | sort -rn;\\
 };\\
 who\"
-""")
-        self.assertEqual(ConfigFile({(b'alias', ): {
-            b'who': (b"!who() {git log --no-merges --pretty=format:'%an - "
-                     b"%ae' $@ | uniq -c | sort -rn;};who")
-            }}), cf)
+"""
+        )
+        self.assertEqual(
+            ConfigFile(
+                {
+                    (b"alias",): {
+                        b"who": (
+                            b"!who() {git log --no-merges --pretty=format:'%an - "
+                            b"%ae' $@ | uniq -c | sort -rn;};who"
+                        )
+                    }
+                }
+            ),
+            cf,
+        )
 
     def test_set_hash_gets_quoted(self):
         c = ConfigFile()
         c.set(b"xandikos", b"color", b"#665544")
         f = BytesIO()
         c.write_to_file(f)
-        self.assertEqual(b"[xandikos]\n\tcolor = \"#665544\"\n", f.getvalue())
+        self.assertEqual(b'[xandikos]\n\tcolor = "#665544"\n', f.getvalue())
 
 
 class ConfigDictTests(TestCase):
-
     def test_get_set(self):
         cd = ConfigDict()
         self.assertRaises(KeyError, cd.get, b"foo", b"core")
-        cd.set((b"core", ), b"foo", b"bla")
-        self.assertEqual(b"bla", cd.get((b"core", ), b"foo"))
-        cd.set((b"core", ), b"foo", b"bloe")
-        self.assertEqual(b"bloe", cd.get((b"core", ), b"foo"))
+        cd.set((b"core",), b"foo", b"bla")
+        self.assertEqual(b"bla", cd.get((b"core",), b"foo"))
+        cd.set((b"core",), b"foo", b"bloe")
+        self.assertEqual(b"bloe", cd.get((b"core",), b"foo"))
 
     def test_get_boolean(self):
         cd = ConfigDict()
-        cd.set((b"core", ), b"foo", b"true")
-        self.assertTrue(cd.get_boolean((b"core", ), b"foo"))
-        cd.set((b"core", ), b"foo", b"false")
-        self.assertFalse(cd.get_boolean((b"core", ), b"foo"))
-        cd.set((b"core", ), b"foo", b"invalid")
-        self.assertRaises(ValueError, cd.get_boolean, (b"core", ), b"foo")
+        cd.set((b"core",), b"foo", b"true")
+        self.assertTrue(cd.get_boolean((b"core",), b"foo"))
+        cd.set((b"core",), b"foo", b"false")
+        self.assertFalse(cd.get_boolean((b"core",), b"foo"))
+        cd.set((b"core",), b"foo", b"invalid")
+        self.assertRaises(ValueError, cd.get_boolean, (b"core",), b"foo")
 
     def test_dict(self):
         cd = ConfigDict()
-        cd.set((b"core", ), b"foo", b"bla")
-        cd.set((b"core2", ), b"foo", b"bloe")
+        cd.set((b"core",), b"foo", b"bla")
+        cd.set((b"core2",), b"foo", b"bloe")
 
-        self.assertEqual([(b"core", ), (b"core2", )], list(cd.keys()))
-        self.assertEqual(cd[(b"core", )], {b'foo': b'bla'})
+        self.assertEqual([(b"core",), (b"core2",)], list(cd.keys()))
+        self.assertEqual(cd[(b"core",)], {b"foo": b"bla"})
 
-        cd[b'a'] = b'b'
-        self.assertEqual(cd[b'a'], b'b')
+        cd[b"a"] = b"b"
+        self.assertEqual(cd[b"a"], b"b")
 
     def test_iteritems(self):
         cd = ConfigDict()
-        cd.set((b"core", ), b"foo", b"bla")
-        cd.set((b"core2", ), b"foo", b"bloe")
+        cd.set((b"core",), b"foo", b"bla")
+        cd.set((b"core2",), b"foo", b"bloe")
 
-        self.assertEqual(
-            [(b'foo', b'bla')],
-            list(cd.iteritems((b"core", ))))
+        self.assertEqual([(b"foo", b"bla")], list(cd.iteritems((b"core",))))
 
     def test_iteritems_nonexistant(self):
         cd = ConfigDict()
-        cd.set((b"core2", ), b"foo", b"bloe")
+        cd.set((b"core2",), b"foo", b"bloe")
 
-        self.assertEqual([], list(cd.iteritems((b"core", ))))
+        self.assertEqual([], list(cd.iteritems((b"core",))))
 
     def test_itersections(self):
         cd = ConfigDict()
-        cd.set((b"core2", ), b"foo", b"bloe")
+        cd.set((b"core2",), b"foo", b"bloe")
 
-        self.assertEqual([(b"core2", )], list(cd.itersections()))
+        self.assertEqual([(b"core2",)], list(cd.itersections()))
 
 
 class StackedConfigTests(TestCase):
+    def setUp(self):
+        super(StackedConfigTests, self).setUp()
+        self._old_path = os.environ.get("PATH")
+
+    def tearDown(self):
+        super(StackedConfigTests, self).tearDown()
+        # PATH may have been deleted by a test, or unset to begin with.
+        if self._old_path is not None:
+            os.environ["PATH"] = self._old_path
+        else:
+            os.environ.pop("PATH", None)
 
     def test_default_backends(self):
         StackedConfig.default_backends()
 
+    @skipIf(sys.platform != "win32", "Windows specific config location.")
+    def test_windows_config_from_path(self):
+        from dulwich.config import get_win_system_paths
+
+        install_dir = os.path.join("C:", "foo", "Git")
+        os.environ["PATH"] = os.path.join(install_dir, "cmd")
+        with patch("os.path.exists", return_value=True):
+            paths = set(get_win_system_paths())
+        self.assertEqual(
+            {
+                os.path.join(os.environ.get("PROGRAMDATA"), "Git", "config"),
+                os.path.join(install_dir, "etc", "gitconfig"),
+            },
+            paths,
+        )
+
+    @skipIf(sys.platform != "win32", "Windows specific config location.")
+    def test_windows_config_from_reg(self):
+        import winreg
+
+        from dulwich.config import get_win_system_paths
+
+        del os.environ["PATH"]
+        install_dir = os.path.join("C:", "foo", "Git")
+        with patch("winreg.OpenKey"):
+            with patch(
+                "winreg.QueryValueEx",
+                return_value=(install_dir, winreg.REG_SZ),
+            ):
+                paths = set(get_win_system_paths())
+        self.assertEqual(
+            {
+                os.path.join(os.environ.get("PROGRAMDATA"), "Git", "config"),
+                os.path.join(install_dir, "etc", "gitconfig"),
+            },
+            paths,
+        )
 
-class EscapeValueTests(TestCase):
 
+class EscapeValueTests(TestCase):
     def test_nothing(self):
         self.assertEqual(b"foo", _escape_value(b"foo"))
 
@@ -270,28 +335,26 @@ class EscapeValueTests(TestCase):
 
 
 class FormatStringTests(TestCase):
-
     def test_quoted(self):
         self.assertEqual(b'" foo"', _format_string(b" foo"))
         self.assertEqual(b'"\\tfoo"', _format_string(b"\tfoo"))
 
     def test_not_quoted(self):
-        self.assertEqual(b'foo', _format_string(b"foo"))
-        self.assertEqual(b'foo bar', _format_string(b"foo bar"))
+        self.assertEqual(b"foo", _format_string(b"foo"))
+        self.assertEqual(b"foo bar", _format_string(b"foo bar"))
 
 
 class ParseStringTests(TestCase):
-
     def test_quoted(self):
-        self.assertEqual(b' foo', _parse_string(b'" foo"'))
-        self.assertEqual(b'\tfoo', _parse_string(b'"\\tfoo"'))
+        self.assertEqual(b" foo", _parse_string(b'" foo"'))
+        self.assertEqual(b"\tfoo", _parse_string(b'"\\tfoo"'))
 
     def test_not_quoted(self):
-        self.assertEqual(b'foo', _parse_string(b"foo"))
-        self.assertEqual(b'foo bar', _parse_string(b"foo bar"))
+        self.assertEqual(b"foo", _parse_string(b"foo"))
+        self.assertEqual(b"foo bar", _parse_string(b"foo bar"))
 
     def test_nothing(self):
-        self.assertEqual(b"", _parse_string(b''))
+        self.assertEqual(b"", _parse_string(b""))
 
     def test_tab(self):
         self.assertEqual(b"\tbar\t", _parse_string(b"\\tbar\\t"))
@@ -300,11 +363,10 @@ class ParseStringTests(TestCase):
         self.assertEqual(b"\nbar\t", _parse_string(b"\\nbar\\t\t"))
 
     def test_quote(self):
-        self.assertEqual(b"\"foo\"", _parse_string(b"\\\"foo\\\""))
+        self.assertEqual(b'"foo"', _parse_string(b'\\"foo\\"'))
 
 
 class CheckVariableNameTests(TestCase):
-
     def test_invalid(self):
         self.assertFalse(_check_variable_name(b"foo "))
         self.assertFalse(_check_variable_name(b"bar,bar"))
@@ -317,7 +379,6 @@ class CheckVariableNameTests(TestCase):
 
 
 class CheckSectionNameTests(TestCase):
-
     def test_invalid(self):
         self.assertFalse(_check_section_name(b"foo "))
         self.assertFalse(_check_section_name(b"bar,bar"))
@@ -330,14 +391,24 @@ class CheckSectionNameTests(TestCase):
 
 
 class SubmodulesTests(TestCase):
-
     def testSubmodules(self):
-        cf = ConfigFile.from_file(BytesIO(b"""\
+        cf = ConfigFile.from_file(
+            BytesIO(
+                b"""\
 [submodule "core/lib"]
 \tpath = core/lib
 \turl = https://github.com/phhusson/QuasselC.git
-"""))
+"""
+            )
+        )
         got = list(parse_submodules(cf))
-        self.assertEqual([
-            (b'core/lib', b'https://github.com/phhusson/QuasselC.git',
-             b'core/lib')], got)
+        self.assertEqual(
+            [
+                (
+                    b"core/lib",
+                    b"https://github.com/phhusson/QuasselC.git",
+                    b"core/lib",
+                )
+            ],
+            got,
+        )
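
Taken together, the config tests above trace the full ConfigFile round trip; a condensed usage sketch of the same calls:

    from io import BytesIO

    from dulwich.config import ConfigFile

    cf = ConfigFile.from_file(BytesIO(b"[core]\n\tbare = false\n"))
    assert cf.get((b"core",), b"bare") == b"false"
    assert cf.get_boolean((b"core",), b"bare") is False

    # Section keys are bytes tuples; subsections add a second element.
    cf.set((b"branch", b"main"), b"remote", b"origin")
    out = BytesIO()
    cf.write_to_file(out)
    # out.getvalue() ends with: [branch "main"]\n\tremote = origin\n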

+ 787 - 584
dulwich/tests/test_diff_tree.py

@@ -37,33 +37,32 @@ from dulwich.diff_tree import (
     _tree_change_key,
     RenameDetector,
     _is_tree,
-    _is_tree_py
-    )
+    _is_tree_py,
+)
 from dulwich.index import (
     commit_tree,
-    )
+)
 from dulwich.object_store import (
     MemoryObjectStore,
-    )
+)
 from dulwich.objects import (
     ShaFile,
     Blob,
     TreeEntry,
     Tree,
-    )
+)
 from dulwich.tests import (
     TestCase,
-    )
+)
 from dulwich.tests.utils import (
     F,
     make_object,
     functest_builder,
     ext_functest_builder,
-    )
+)
 
 
 class DiffTestCase(TestCase):
-
     def setUp(self):
         super(DiffTestCase, self).setUp()
         self.store = MemoryObjectStore()
@@ -87,7 +86,6 @@ class DiffTestCase(TestCase):
 
 
 class TreeChangesTest(DiffTestCase):
-
     def setUp(self):
         super(TreeChangesTest, self).setUp()
         self.detector = RenameDetector(self.store)
@@ -95,62 +93,74 @@ class TreeChangesTest(DiffTestCase):
     def assertMergeFails(self, merge_entries, name, mode, sha):
         t = Tree()
         t[name] = (mode, sha)
-        self.assertRaises((TypeError, ValueError), merge_entries, '', t, t)
+        self.assertRaises((TypeError, ValueError), merge_entries, "", t, t)
 
     def _do_test_merge_entries(self, merge_entries):
-        blob_a1 = make_object(Blob, data=b'a1')
-        blob_a2 = make_object(Blob, data=b'a2')
-        blob_b1 = make_object(Blob, data=b'b1')
-        blob_c2 = make_object(Blob, data=b'c2')
-        tree1 = self.commit_tree([(b'a', blob_a1, 0o100644),
-                                  (b'b', blob_b1, 0o100755)])
-        tree2 = self.commit_tree([(b'a', blob_a2, 0o100644),
-                                  (b'c', blob_c2, 0o100755)])
-
-        self.assertEqual([], merge_entries(b'', self.empty_tree,
-                                           self.empty_tree))
+        blob_a1 = make_object(Blob, data=b"a1")
+        blob_a2 = make_object(Blob, data=b"a2")
+        blob_b1 = make_object(Blob, data=b"b1")
+        blob_c2 = make_object(Blob, data=b"c2")
+        tree1 = self.commit_tree([(b"a", blob_a1, 0o100644), (b"b", blob_b1, 0o100755)])
+        tree2 = self.commit_tree([(b"a", blob_a2, 0o100644), (b"c", blob_c2, 0o100755)])
+
+        self.assertEqual([], merge_entries(b"", self.empty_tree, self.empty_tree))
         self.assertEqual(
-            [((None, None, None), (b'a', 0o100644, blob_a1.id)),
-             ((None, None, None), (b'b', 0o100755, blob_b1.id)), ],
-            merge_entries(b'', self.empty_tree, tree1))
+            [
+                ((None, None, None), (b"a", 0o100644, blob_a1.id)),
+                ((None, None, None), (b"b", 0o100755, blob_b1.id)),
+            ],
+            merge_entries(b"", self.empty_tree, tree1),
+        )
         self.assertEqual(
-            [((None, None, None), (b'x/a', 0o100644, blob_a1.id)),
-             ((None, None, None), (b'x/b', 0o100755, blob_b1.id)), ],
-            merge_entries(b'x', self.empty_tree, tree1))
+            [
+                ((None, None, None), (b"x/a", 0o100644, blob_a1.id)),
+                ((None, None, None), (b"x/b", 0o100755, blob_b1.id)),
+            ],
+            merge_entries(b"x", self.empty_tree, tree1),
+        )
 
         self.assertEqual(
-            [((b'a', 0o100644, blob_a2.id), (None, None, None)),
-             ((b'c', 0o100755, blob_c2.id), (None, None, None)), ],
-            merge_entries(b'', tree2, self.empty_tree))
+            [
+                ((b"a", 0o100644, blob_a2.id), (None, None, None)),
+                ((b"c", 0o100755, blob_c2.id), (None, None, None)),
+            ],
+            merge_entries(b"", tree2, self.empty_tree),
+        )
 
         self.assertEqual(
-            [((b'a', 0o100644, blob_a1.id), (b'a', 0o100644, blob_a2.id)),
-             ((b'b', 0o100755, blob_b1.id), (None, None, None)),
-             ((None, None, None), (b'c', 0o100755, blob_c2.id)), ],
-            merge_entries(b'', tree1, tree2))
+            [
+                ((b"a", 0o100644, blob_a1.id), (b"a", 0o100644, blob_a2.id)),
+                ((b"b", 0o100755, blob_b1.id), (None, None, None)),
+                ((None, None, None), (b"c", 0o100755, blob_c2.id)),
+            ],
+            merge_entries(b"", tree1, tree2),
+        )
 
         self.assertEqual(
-            [((b'a', 0o100644, blob_a2.id), (b'a', 0o100644, blob_a1.id)),
-             ((None, None, None), (b'b', 0o100755, blob_b1.id)),
-             ((b'c', 0o100755, blob_c2.id), (None, None, None)), ],
-            merge_entries(b'', tree2, tree1))
-
-        self.assertMergeFails(merge_entries, 0xdeadbeef, 0o100644, '1' * 40)
-        self.assertMergeFails(merge_entries, b'a', b'deadbeef', '1' * 40)
-        self.assertMergeFails(merge_entries, b'a', 0o100644, 0xdeadbeef)
-
-    test_merge_entries = functest_builder(_do_test_merge_entries,
-                                          _merge_entries_py)
-    test_merge_entries_extension = ext_functest_builder(_do_test_merge_entries,
-                                                        _merge_entries)
+            [
+                ((b"a", 0o100644, blob_a2.id), (b"a", 0o100644, blob_a1.id)),
+                ((None, None, None), (b"b", 0o100755, blob_b1.id)),
+                ((b"c", 0o100755, blob_c2.id), (None, None, None)),
+            ],
+            merge_entries(b"", tree2, tree1),
+        )
+
+        self.assertMergeFails(merge_entries, 0xDEADBEEF, 0o100644, "1" * 40)
+        self.assertMergeFails(merge_entries, b"a", b"deadbeef", "1" * 40)
+        self.assertMergeFails(merge_entries, b"a", 0o100644, 0xDEADBEEF)
+
+    test_merge_entries = functest_builder(_do_test_merge_entries, _merge_entries_py)
+    test_merge_entries_extension = ext_functest_builder(
+        _do_test_merge_entries, _merge_entries
+    )
 
     def _do_test_is_tree(self, is_tree):
         self.assertFalse(is_tree(TreeEntry(None, None, None)))
-        self.assertFalse(is_tree(TreeEntry(b'a', 0o100644, b'a' * 40)))
-        self.assertFalse(is_tree(TreeEntry(b'a', 0o100755, b'a' * 40)))
-        self.assertFalse(is_tree(TreeEntry(b'a', 0o120000, b'a' * 40)))
-        self.assertTrue(is_tree(TreeEntry(b'a', 0o040000, b'a' * 40)))
-        self.assertRaises(TypeError, is_tree, TreeEntry(b'a', b'x', b'a' * 40))
+        self.assertFalse(is_tree(TreeEntry(b"a", 0o100644, b"a" * 40)))
+        self.assertFalse(is_tree(TreeEntry(b"a", 0o100755, b"a" * 40)))
+        self.assertFalse(is_tree(TreeEntry(b"a", 0o120000, b"a" * 40)))
+        self.assertTrue(is_tree(TreeEntry(b"a", 0o040000, b"a" * 40)))
+        self.assertRaises(TypeError, is_tree, TreeEntry(b"a", b"x", b"a" * 40))
         self.assertRaises(AttributeError, is_tree, 1234)
 
     test_is_tree = functest_builder(_do_test_is_tree, _is_tree_py)
@@ -166,243 +176,334 @@ class TreeChangesTest(DiffTestCase):
         self.assertChangesEqual([], self.empty_tree, self.empty_tree)
 
     def test_tree_changes_no_changes(self):
-        blob = make_object(Blob, data=b'blob')
-        tree = self.commit_tree([(b'a', blob), (b'b/c', blob)])
+        blob = make_object(Blob, data=b"blob")
+        tree = self.commit_tree([(b"a", blob), (b"b/c", blob)])
         self.assertChangesEqual([], self.empty_tree, self.empty_tree)
         self.assertChangesEqual([], tree, tree)
         self.assertChangesEqual(
-            [TreeChange(CHANGE_UNCHANGED, (b'a', F, blob.id),
-                        (b'a', F, blob.id)),
-             TreeChange(CHANGE_UNCHANGED, (b'b/c', F, blob.id),
-                        (b'b/c', F, blob.id))],
-            tree, tree, want_unchanged=True)
+            [
+                TreeChange(CHANGE_UNCHANGED, (b"a", F, blob.id), (b"a", F, blob.id)),
+                TreeChange(
+                    CHANGE_UNCHANGED,
+                    (b"b/c", F, blob.id),
+                    (b"b/c", F, blob.id),
+                ),
+            ],
+            tree,
+            tree,
+            want_unchanged=True,
+        )
 
     def test_tree_changes_add_delete(self):
-        blob_a = make_object(Blob, data=b'a')
-        blob_b = make_object(Blob, data=b'b')
-        tree = self.commit_tree([(b'a', blob_a, 0o100644),
-                                 (b'x/b', blob_b, 0o100755)])
+        blob_a = make_object(Blob, data=b"a")
+        blob_b = make_object(Blob, data=b"b")
+        tree = self.commit_tree([(b"a", blob_a, 0o100644), (b"x/b", blob_b, 0o100755)])
         self.assertChangesEqual(
-            [TreeChange.add((b'a', 0o100644, blob_a.id)),
-             TreeChange.add((b'x/b', 0o100755, blob_b.id))],
-            self.empty_tree, tree)
+            [
+                TreeChange.add((b"a", 0o100644, blob_a.id)),
+                TreeChange.add((b"x/b", 0o100755, blob_b.id)),
+            ],
+            self.empty_tree,
+            tree,
+        )
         self.assertChangesEqual(
-            [TreeChange.delete((b'a', 0o100644, blob_a.id)),
-             TreeChange.delete((b'x/b', 0o100755, blob_b.id))],
-            tree, self.empty_tree)
+            [
+                TreeChange.delete((b"a", 0o100644, blob_a.id)),
+                TreeChange.delete((b"x/b", 0o100755, blob_b.id)),
+            ],
+            tree,
+            self.empty_tree,
+        )
 
     def test_tree_changes_modify_contents(self):
-        blob_a1 = make_object(Blob, data=b'a1')
-        blob_a2 = make_object(Blob, data=b'a2')
-        tree1 = self.commit_tree([(b'a', blob_a1)])
-        tree2 = self.commit_tree([(b'a', blob_a2)])
+        blob_a1 = make_object(Blob, data=b"a1")
+        blob_a2 = make_object(Blob, data=b"a2")
+        tree1 = self.commit_tree([(b"a", blob_a1)])
+        tree2 = self.commit_tree([(b"a", blob_a2)])
         self.assertChangesEqual(
-            [TreeChange(CHANGE_MODIFY, (b'a', F, blob_a1.id),
-                        (b'a', F, blob_a2.id))],
-            tree1, tree2)
+            [TreeChange(CHANGE_MODIFY, (b"a", F, blob_a1.id), (b"a", F, blob_a2.id))],
+            tree1,
+            tree2,
+        )
 
     def test_tree_changes_modify_mode(self):
-        blob_a = make_object(Blob, data=b'a')
-        tree1 = self.commit_tree([(b'a', blob_a, 0o100644)])
-        tree2 = self.commit_tree([(b'a', blob_a, 0o100755)])
+        blob_a = make_object(Blob, data=b"a")
+        tree1 = self.commit_tree([(b"a", blob_a, 0o100644)])
+        tree2 = self.commit_tree([(b"a", blob_a, 0o100755)])
         self.assertChangesEqual(
-            [TreeChange(CHANGE_MODIFY, (b'a', 0o100644, blob_a.id),
-                        (b'a', 0o100755, blob_a.id))],
-            tree1, tree2)
+            [
+                TreeChange(
+                    CHANGE_MODIFY,
+                    (b"a", 0o100644, blob_a.id),
+                    (b"a", 0o100755, blob_a.id),
+                )
+            ],
+            tree1,
+            tree2,
+        )
 
     def test_tree_changes_change_type(self):
-        blob_a1 = make_object(Blob, data=b'a')
-        blob_a2 = make_object(Blob, data=b'/foo/bar')
-        tree1 = self.commit_tree([(b'a', blob_a1, 0o100644)])
-        tree2 = self.commit_tree([(b'a', blob_a2, 0o120000)])
+        blob_a1 = make_object(Blob, data=b"a")
+        blob_a2 = make_object(Blob, data=b"/foo/bar")
+        tree1 = self.commit_tree([(b"a", blob_a1, 0o100644)])
+        tree2 = self.commit_tree([(b"a", blob_a2, 0o120000)])
         self.assertChangesEqual(
-            [TreeChange.delete((b'a', 0o100644, blob_a1.id)),
-             TreeChange.add((b'a', 0o120000, blob_a2.id))],
-            tree1, tree2)
+            [
+                TreeChange.delete((b"a", 0o100644, blob_a1.id)),
+                TreeChange.add((b"a", 0o120000, blob_a2.id)),
+            ],
+            tree1,
+            tree2,
+        )
 
     def test_tree_changes_change_type_same(self):
-        blob_a1 = make_object(Blob, data=b'a')
-        blob_a2 = make_object(Blob, data=b'/foo/bar')
-        tree1 = self.commit_tree([(b'a', blob_a1, 0o100644)])
-        tree2 = self.commit_tree([(b'a', blob_a2, 0o120000)])
+        blob_a1 = make_object(Blob, data=b"a")
+        blob_a2 = make_object(Blob, data=b"/foo/bar")
+        tree1 = self.commit_tree([(b"a", blob_a1, 0o100644)])
+        tree2 = self.commit_tree([(b"a", blob_a2, 0o120000)])
         self.assertChangesEqual(
-            [TreeChange(CHANGE_MODIFY, (b'a', 0o100644, blob_a1.id),
-                        (b'a', 0o120000, blob_a2.id))],
-            tree1, tree2, change_type_same=True)
+            [
+                TreeChange(
+                    CHANGE_MODIFY,
+                    (b"a", 0o100644, blob_a1.id),
+                    (b"a", 0o120000, blob_a2.id),
+                )
+            ],
+            tree1,
+            tree2,
+            change_type_same=True,
+        )
 
     def test_tree_changes_to_tree(self):
-        blob_a = make_object(Blob, data=b'a')
-        blob_x = make_object(Blob, data=b'x')
-        tree1 = self.commit_tree([(b'a', blob_a)])
-        tree2 = self.commit_tree([(b'a/x', blob_x)])
+        blob_a = make_object(Blob, data=b"a")
+        blob_x = make_object(Blob, data=b"x")
+        tree1 = self.commit_tree([(b"a", blob_a)])
+        tree2 = self.commit_tree([(b"a/x", blob_x)])
         self.assertChangesEqual(
-            [TreeChange.delete((b'a', F, blob_a.id)),
-             TreeChange.add((b'a/x', F, blob_x.id))],
-            tree1, tree2)
+            [
+                TreeChange.delete((b"a", F, blob_a.id)),
+                TreeChange.add((b"a/x", F, blob_x.id)),
+            ],
+            tree1,
+            tree2,
+        )
 
     def test_tree_changes_complex(self):
-        blob_a_1 = make_object(Blob, data=b'a1_1')
-        blob_bx1_1 = make_object(Blob, data=b'bx1_1')
-        blob_bx2_1 = make_object(Blob, data=b'bx2_1')
-        blob_by1_1 = make_object(Blob, data=b'by1_1')
-        blob_by2_1 = make_object(Blob, data=b'by2_1')
-        tree1 = self.commit_tree([
-            (b'a', blob_a_1),
-            (b'b/x/1', blob_bx1_1),
-            (b'b/x/2', blob_bx2_1),
-            (b'b/y/1', blob_by1_1),
-            (b'b/y/2', blob_by2_1),
-        ])
-
-        blob_a_2 = make_object(Blob, data=b'a1_2')
+        blob_a_1 = make_object(Blob, data=b"a1_1")
+        blob_bx1_1 = make_object(Blob, data=b"bx1_1")
+        blob_bx2_1 = make_object(Blob, data=b"bx2_1")
+        blob_by1_1 = make_object(Blob, data=b"by1_1")
+        blob_by2_1 = make_object(Blob, data=b"by2_1")
+        tree1 = self.commit_tree(
+            [
+                (b"a", blob_a_1),
+                (b"b/x/1", blob_bx1_1),
+                (b"b/x/2", blob_bx2_1),
+                (b"b/y/1", blob_by1_1),
+                (b"b/y/2", blob_by2_1),
+            ]
+        )
+
+        blob_a_2 = make_object(Blob, data=b"a1_2")
         blob_bx1_2 = blob_bx1_1
-        blob_by_2 = make_object(Blob, data=b'by_2')
-        blob_c_2 = make_object(Blob, data=b'c_2')
-        tree2 = self.commit_tree([
-            (b'a', blob_a_2),
-            (b'b/x/1', blob_bx1_2),
-            (b'b/y', blob_by_2),
-            (b'c', blob_c_2),
-        ])
+        blob_by_2 = make_object(Blob, data=b"by_2")
+        blob_c_2 = make_object(Blob, data=b"c_2")
+        tree2 = self.commit_tree(
+            [
+                (b"a", blob_a_2),
+                (b"b/x/1", blob_bx1_2),
+                (b"b/y", blob_by_2),
+                (b"c", blob_c_2),
+            ]
+        )
 
         self.assertChangesEqual(
-            [TreeChange(CHANGE_MODIFY, (b'a', F, blob_a_1.id),
-                        (b'a', F, blob_a_2.id)),
-             TreeChange.delete((b'b/x/2', F, blob_bx2_1.id)),
-             TreeChange.add((b'b/y', F, blob_by_2.id)),
-             TreeChange.delete((b'b/y/1', F, blob_by1_1.id)),
-             TreeChange.delete((b'b/y/2', F, blob_by2_1.id)),
-             TreeChange.add((b'c', F, blob_c_2.id))],
-            tree1, tree2)
+            [
+                TreeChange(
+                    CHANGE_MODIFY,
+                    (b"a", F, blob_a_1.id),
+                    (b"a", F, blob_a_2.id),
+                ),
+                TreeChange.delete((b"b/x/2", F, blob_bx2_1.id)),
+                TreeChange.add((b"b/y", F, blob_by_2.id)),
+                TreeChange.delete((b"b/y/1", F, blob_by1_1.id)),
+                TreeChange.delete((b"b/y/2", F, blob_by2_1.id)),
+                TreeChange.add((b"c", F, blob_c_2.id)),
+            ],
+            tree1,
+            tree2,
+        )
 
     def test_tree_changes_name_order(self):
-        blob = make_object(Blob, data=b'a')
-        tree1 = self.commit_tree([(b'a', blob), (b'a.', blob), (b'a..', blob)])
+        blob = make_object(Blob, data=b"a")
+        tree1 = self.commit_tree([(b"a", blob), (b"a.", blob), (b"a..", blob)])
         # Tree order is the reverse of this, so if we used tree order, 'a..'
         # would not be merged.
-        tree2 = self.commit_tree(
-                [(b'a/x', blob), (b'a./x', blob), (b'a..', blob)])
+        tree2 = self.commit_tree([(b"a/x", blob), (b"a./x", blob), (b"a..", blob)])
 
         self.assertChangesEqual(
-            [TreeChange.delete((b'a', F, blob.id)),
-             TreeChange.add((b'a/x', F, blob.id)),
-             TreeChange.delete((b'a.', F, blob.id)),
-             TreeChange.add((b'a./x', F, blob.id))],
-            tree1, tree2)
+            [
+                TreeChange.delete((b"a", F, blob.id)),
+                TreeChange.add((b"a/x", F, blob.id)),
+                TreeChange.delete((b"a.", F, blob.id)),
+                TreeChange.add((b"a./x", F, blob.id)),
+            ],
+            tree1,
+            tree2,
+        )
 
     def test_tree_changes_prune(self):
-        blob_a1 = make_object(Blob, data=b'a1')
-        blob_a2 = make_object(Blob, data=b'a2')
-        blob_x = make_object(Blob, data=b'x')
-        tree1 = self.commit_tree([(b'a', blob_a1), (b'b/x', blob_x)])
-        tree2 = self.commit_tree([(b'a', blob_a2), (b'b/x', blob_x)])
+        blob_a1 = make_object(Blob, data=b"a1")
+        blob_a2 = make_object(Blob, data=b"a2")
+        blob_x = make_object(Blob, data=b"x")
+        tree1 = self.commit_tree([(b"a", blob_a1), (b"b/x", blob_x)])
+        tree2 = self.commit_tree([(b"a", blob_a2), (b"b/x", blob_x)])
         # Remove identical items so lookups will fail unless we prune.
-        subtree = self.store[tree1[b'b'][1]]
+        subtree = self.store[tree1[b"b"][1]]
         for entry in subtree.items():
             del self.store[entry.sha]
         del self.store[subtree.id]
 
         self.assertChangesEqual(
-            [TreeChange(CHANGE_MODIFY, (b'a', F, blob_a1.id),
-                        (b'a', F, blob_a2.id))],
-            tree1, tree2)
+            [TreeChange(CHANGE_MODIFY, (b"a", F, blob_a1.id), (b"a", F, blob_a2.id))],
+            tree1,
+            tree2,
+        )
 
     def test_tree_changes_rename_detector(self):
-        blob_a1 = make_object(Blob, data=b'a\nb\nc\nd\n')
-        blob_a2 = make_object(Blob, data=b'a\nb\nc\ne\n')
-        blob_b = make_object(Blob, data=b'b')
-        tree1 = self.commit_tree([(b'a', blob_a1), (b'b', blob_b)])
-        tree2 = self.commit_tree([(b'c', blob_a2), (b'b', blob_b)])
+        blob_a1 = make_object(Blob, data=b"a\nb\nc\nd\n")
+        blob_a2 = make_object(Blob, data=b"a\nb\nc\ne\n")
+        blob_b = make_object(Blob, data=b"b")
+        tree1 = self.commit_tree([(b"a", blob_a1), (b"b", blob_b)])
+        tree2 = self.commit_tree([(b"c", blob_a2), (b"b", blob_b)])
         detector = RenameDetector(self.store)
 
         self.assertChangesEqual(
-            [TreeChange.delete((b'a', F, blob_a1.id)),
-             TreeChange.add((b'c', F, blob_a2.id))],
-            tree1, tree2)
+            [
+                TreeChange.delete((b"a", F, blob_a1.id)),
+                TreeChange.add((b"c", F, blob_a2.id)),
+            ],
+            tree1,
+            tree2,
+        )
         self.assertChangesEqual(
-            [TreeChange.delete((b'a', F, blob_a1.id)),
-             TreeChange(CHANGE_UNCHANGED, (b'b', F, blob_b.id),
-                        (b'b', F, blob_b.id)),
-             TreeChange.add((b'c', F, blob_a2.id))],
-            tree1, tree2, want_unchanged=True)
+            [
+                TreeChange.delete((b"a", F, blob_a1.id)),
+                TreeChange(
+                    CHANGE_UNCHANGED,
+                    (b"b", F, blob_b.id),
+                    (b"b", F, blob_b.id),
+                ),
+                TreeChange.add((b"c", F, blob_a2.id)),
+            ],
+            tree1,
+            tree2,
+            want_unchanged=True,
+        )
         self.assertChangesEqual(
-            [TreeChange(CHANGE_RENAME, (b'a', F, blob_a1.id),
-                        (b'c', F, blob_a2.id))],
-            tree1, tree2, rename_detector=detector)
+            [TreeChange(CHANGE_RENAME, (b"a", F, blob_a1.id), (b"c", F, blob_a2.id))],
+            tree1,
+            tree2,
+            rename_detector=detector,
+        )
         self.assertChangesEqual(
-            [TreeChange(CHANGE_RENAME, (b'a', F, blob_a1.id),
-                        (b'c', F, blob_a2.id)),
-             TreeChange(CHANGE_UNCHANGED, (b'b', F, blob_b.id),
-                        (b'b', F, blob_b.id))],
-            tree1, tree2, rename_detector=detector, want_unchanged=True)
-
-    def assertChangesForMergeEqual(self, expected, parent_trees, merge_tree,
-                                   **kwargs):
+            [
+                TreeChange(CHANGE_RENAME, (b"a", F, blob_a1.id), (b"c", F, blob_a2.id)),
+                TreeChange(
+                    CHANGE_UNCHANGED,
+                    (b"b", F, blob_b.id),
+                    (b"b", F, blob_b.id),
+                ),
+            ],
+            tree1,
+            tree2,
+            rename_detector=detector,
+            want_unchanged=True,
+        )
+
+    def assertChangesForMergeEqual(self, expected, parent_trees, merge_tree, **kwargs):
         parent_tree_ids = [t.id for t in parent_trees]
-        actual = list(tree_changes_for_merge(
-          self.store, parent_tree_ids, merge_tree.id, **kwargs))
+        actual = list(
+            tree_changes_for_merge(self.store, parent_tree_ids, merge_tree.id, **kwargs)
+        )
         self.assertEqual(expected, actual)
 
         parent_tree_ids.reverse()
         expected = [list(reversed(cs)) for cs in expected]
-        actual = list(tree_changes_for_merge(
-          self.store, parent_tree_ids, merge_tree.id, **kwargs))
+        actual = list(
+            tree_changes_for_merge(self.store, parent_tree_ids, merge_tree.id, **kwargs)
+        )
         self.assertEqual(expected, actual)
 
     def test_tree_changes_for_merge_add_no_conflict(self):
-        blob = make_object(Blob, data=b'blob')
+        blob = make_object(Blob, data=b"blob")
         parent1 = self.commit_tree([])
-        parent2 = merge = self.commit_tree([(b'a', blob)])
+        parent2 = merge = self.commit_tree([(b"a", blob)])
         self.assertChangesForMergeEqual([], [parent1, parent2], merge)
         self.assertChangesForMergeEqual([], [parent2, parent2], merge)
 
     def test_tree_changes_for_merge_add_modify_conflict(self):
-        blob1 = make_object(Blob, data=b'1')
-        blob2 = make_object(Blob, data=b'2')
+        blob1 = make_object(Blob, data=b"1")
+        blob2 = make_object(Blob, data=b"2")
         parent1 = self.commit_tree([])
-        parent2 = self.commit_tree([(b'a', blob1)])
-        merge = self.commit_tree([(b'a', blob2)])
+        parent2 = self.commit_tree([(b"a", blob1)])
+        merge = self.commit_tree([(b"a", blob2)])
         self.assertChangesForMergeEqual(
-            [[TreeChange.add((b'a', F, blob2.id)),
-              TreeChange(CHANGE_MODIFY, (b'a', F, blob1.id),
-                         (b'a', F, blob2.id))]],
-            [parent1, parent2], merge)
+            [
+                [
+                    TreeChange.add((b"a", F, blob2.id)),
+                    TreeChange(CHANGE_MODIFY, (b"a", F, blob1.id), (b"a", F, blob2.id)),
+                ]
+            ],
+            [parent1, parent2],
+            merge,
+        )
 
     def test_tree_changes_for_merge_modify_modify_conflict(self):
-        blob1 = make_object(Blob, data=b'1')
-        blob2 = make_object(Blob, data=b'2')
-        blob3 = make_object(Blob, data=b'3')
-        parent1 = self.commit_tree([(b'a', blob1)])
-        parent2 = self.commit_tree([(b'a', blob2)])
-        merge = self.commit_tree([(b'a', blob3)])
+        blob1 = make_object(Blob, data=b"1")
+        blob2 = make_object(Blob, data=b"2")
+        blob3 = make_object(Blob, data=b"3")
+        parent1 = self.commit_tree([(b"a", blob1)])
+        parent2 = self.commit_tree([(b"a", blob2)])
+        merge = self.commit_tree([(b"a", blob3)])
         self.assertChangesForMergeEqual(
-            [[TreeChange(CHANGE_MODIFY, (b'a', F, blob1.id),
-                         (b'a', F, blob3.id)),
-              TreeChange(CHANGE_MODIFY, (b'a', F, blob2.id),
-                         (b'a', F, blob3.id))]],
-            [parent1, parent2], merge)
+            [
+                [
+                    TreeChange(CHANGE_MODIFY, (b"a", F, blob1.id), (b"a", F, blob3.id)),
+                    TreeChange(CHANGE_MODIFY, (b"a", F, blob2.id), (b"a", F, blob3.id)),
+                ]
+            ],
+            [parent1, parent2],
+            merge,
+        )
 
     def test_tree_changes_for_merge_modify_no_conflict(self):
-        blob1 = make_object(Blob, data=b'1')
-        blob2 = make_object(Blob, data=b'2')
-        parent1 = self.commit_tree([(b'a', blob1)])
-        parent2 = merge = self.commit_tree([(b'a', blob2)])
+        blob1 = make_object(Blob, data=b"1")
+        blob2 = make_object(Blob, data=b"2")
+        parent1 = self.commit_tree([(b"a", blob1)])
+        parent2 = merge = self.commit_tree([(b"a", blob2)])
         self.assertChangesForMergeEqual([], [parent1, parent2], merge)
 
     def test_tree_changes_for_merge_delete_delete_conflict(self):
-        blob1 = make_object(Blob, data=b'1')
-        blob2 = make_object(Blob, data=b'2')
-        parent1 = self.commit_tree([(b'a', blob1)])
-        parent2 = self.commit_tree([(b'a', blob2)])
+        blob1 = make_object(Blob, data=b"1")
+        blob2 = make_object(Blob, data=b"2")
+        parent1 = self.commit_tree([(b"a", blob1)])
+        parent2 = self.commit_tree([(b"a", blob2)])
         merge = self.commit_tree([])
         self.assertChangesForMergeEqual(
-            [[TreeChange.delete((b'a', F, blob1.id)),
-              TreeChange.delete((b'a', F, blob2.id))]],
-            [parent1, parent2], merge)
+            [
+                [
+                    TreeChange.delete((b"a", F, blob1.id)),
+                    TreeChange.delete((b"a", F, blob2.id)),
+                ]
+            ],
+            [parent1, parent2],
+            merge,
+        )
 
     def test_tree_changes_for_merge_delete_no_conflict(self):
-        blob = make_object(Blob, data=b'blob')
-        has = self.commit_tree([(b'a', blob)])
+        blob = make_object(Blob, data=b"blob")
+        has = self.commit_tree([(b"a", blob)])
         doesnt_have = self.commit_tree([])
         self.assertChangesForMergeEqual([], [has, has], doesnt_have)
         self.assertChangesForMergeEqual([], [has, doesnt_have], doesnt_have)
@@ -410,7 +511,7 @@ class TreeChangesTest(DiffTestCase):
     def test_tree_changes_for_merge_octopus_no_conflict(self):
         r = list(range(5))
         blobs = [make_object(Blob, data=bytes(i)) for i in r]
-        parents = [self.commit_tree([(b'a', blobs[i])]) for i in r]
+        parents = [self.commit_tree([(b"a", blobs[i])]) for i in r]
         for i in r:
             # Take the SHA from each of the parents.
             self.assertChangesForMergeEqual([], parents, parents[i])
@@ -421,134 +522,168 @@ class TreeChangesTest(DiffTestCase):
         # defined, so test it anyway.
         r = list(range(5))
         parent_blobs = [make_object(Blob, data=bytes(i)) for i in r]
-        merge_blob = make_object(Blob, data=b'merge')
-        parents = [self.commit_tree([(b'a', parent_blobs[i])]) for i in r]
-        merge = self.commit_tree([(b'a', merge_blob)])
-        expected = [[TreeChange(CHANGE_MODIFY, (b'a', F, parent_blobs[i].id),
-                                (b'a', F, merge_blob.id)) for i in r]]
+        merge_blob = make_object(Blob, data=b"merge")
+        parents = [self.commit_tree([(b"a", parent_blobs[i])]) for i in r]
+        merge = self.commit_tree([(b"a", merge_blob)])
+        expected = [
+            [
+                TreeChange(
+                    CHANGE_MODIFY,
+                    (b"a", F, parent_blobs[i].id),
+                    (b"a", F, merge_blob.id),
+                )
+                for i in r
+            ]
+        ]
         self.assertChangesForMergeEqual(expected, parents, merge)
 
     def test_tree_changes_for_merge_octopus_delete(self):
-        blob1 = make_object(Blob, data=b'1')
-        blob2 = make_object(Blob, data=b'3')
-        parent1 = self.commit_tree([(b'a', blob1)])
-        parent2 = self.commit_tree([(b'a', blob2)])
+        blob1 = make_object(Blob, data=b"1")
+        blob2 = make_object(Blob, data=b"3")
+        parent1 = self.commit_tree([(b"a", blob1)])
+        parent2 = self.commit_tree([(b"a", blob2)])
         parent3 = merge = self.commit_tree([])
         self.assertChangesForMergeEqual([], [parent1, parent1, parent1], merge)
         self.assertChangesForMergeEqual([], [parent1, parent1, parent3], merge)
         self.assertChangesForMergeEqual([], [parent1, parent3, parent3], merge)
         self.assertChangesForMergeEqual(
-            [[TreeChange.delete((b'a', F, blob1.id)),
-              TreeChange.delete((b'a', F, blob2.id)),
-              None]],
-            [parent1, parent2, parent3], merge)
+            [
+                [
+                    TreeChange.delete((b"a", F, blob1.id)),
+                    TreeChange.delete((b"a", F, blob2.id)),
+                    None,
+                ]
+            ],
+            [parent1, parent2, parent3],
+            merge,
+        )
 
     def test_tree_changes_for_merge_add_add_same_conflict(self):
-        blob = make_object(Blob, data=b'a\nb\nc\nd\n')
-        parent1 = self.commit_tree([(b'a', blob)])
+        blob = make_object(Blob, data=b"a\nb\nc\nd\n")
+        parent1 = self.commit_tree([(b"a", blob)])
         parent2 = self.commit_tree([])
-        merge = self.commit_tree([(b'b', blob)])
-        add = TreeChange.add((b'b', F, blob.id))
-        self.assertChangesForMergeEqual(
-                [[add, add]], [parent1, parent2], merge)
+        merge = self.commit_tree([(b"b", blob)])
+        add = TreeChange.add((b"b", F, blob.id))
+        self.assertChangesForMergeEqual([[add, add]], [parent1, parent2], merge)
 
     def test_tree_changes_for_merge_add_exact_rename_conflict(self):
-        blob = make_object(Blob, data=b'a\nb\nc\nd\n')
-        parent1 = self.commit_tree([(b'a', blob)])
+        blob = make_object(Blob, data=b"a\nb\nc\nd\n")
+        parent1 = self.commit_tree([(b"a", blob)])
         parent2 = self.commit_tree([])
-        merge = self.commit_tree([(b'b', blob)])
+        merge = self.commit_tree([(b"b", blob)])
         self.assertChangesForMergeEqual(
-            [[TreeChange(CHANGE_RENAME, (b'a', F, blob.id),
-                         (b'b', F, blob.id)),
-              TreeChange.add((b'b', F, blob.id))]],
-            [parent1, parent2], merge, rename_detector=self.detector)
+            [
+                [
+                    TreeChange(CHANGE_RENAME, (b"a", F, blob.id), (b"b", F, blob.id)),
+                    TreeChange.add((b"b", F, blob.id)),
+                ]
+            ],
+            [parent1, parent2],
+            merge,
+            rename_detector=self.detector,
+        )
 
     def test_tree_changes_for_merge_add_content_rename_conflict(self):
-        blob1 = make_object(Blob, data=b'a\nb\nc\nd\n')
-        blob2 = make_object(Blob, data=b'a\nb\nc\ne\n')
-        parent1 = self.commit_tree([(b'a', blob1)])
+        blob1 = make_object(Blob, data=b"a\nb\nc\nd\n")
+        blob2 = make_object(Blob, data=b"a\nb\nc\ne\n")
+        parent1 = self.commit_tree([(b"a", blob1)])
         parent2 = self.commit_tree([])
-        merge = self.commit_tree([(b'b', blob2)])
+        merge = self.commit_tree([(b"b", blob2)])
         self.assertChangesForMergeEqual(
-            [[TreeChange(CHANGE_RENAME, (b'a', F, blob1.id),
-                         (b'b', F, blob2.id)),
-              TreeChange.add((b'b', F, blob2.id))]],
-            [parent1, parent2], merge, rename_detector=self.detector)
+            [
+                [
+                    TreeChange(CHANGE_RENAME, (b"a", F, blob1.id), (b"b", F, blob2.id)),
+                    TreeChange.add((b"b", F, blob2.id)),
+                ]
+            ],
+            [parent1, parent2],
+            merge,
+            rename_detector=self.detector,
+        )
 
     def test_tree_changes_for_merge_modify_rename_conflict(self):
-        blob1 = make_object(Blob, data=b'a\nb\nc\nd\n')
-        blob2 = make_object(Blob, data=b'a\nb\nc\ne\n')
-        parent1 = self.commit_tree([(b'a', blob1)])
-        parent2 = self.commit_tree([(b'b', blob1)])
-        merge = self.commit_tree([(b'b', blob2)])
+        blob1 = make_object(Blob, data=b"a\nb\nc\nd\n")
+        blob2 = make_object(Blob, data=b"a\nb\nc\ne\n")
+        parent1 = self.commit_tree([(b"a", blob1)])
+        parent2 = self.commit_tree([(b"b", blob1)])
+        merge = self.commit_tree([(b"b", blob2)])
         self.assertChangesForMergeEqual(
-            [[TreeChange(CHANGE_RENAME, (b'a', F, blob1.id),
-                         (b'b', F, blob2.id)),
-              TreeChange(CHANGE_MODIFY, (b'b', F, blob1.id),
-                         (b'b', F, blob2.id))]],
-            [parent1, parent2], merge, rename_detector=self.detector)
+            [
+                [
+                    TreeChange(CHANGE_RENAME, (b"a", F, blob1.id), (b"b", F, blob2.id)),
+                    TreeChange(CHANGE_MODIFY, (b"b", F, blob1.id), (b"b", F, blob2.id)),
+                ]
+            ],
+            [parent1, parent2],
+            merge,
+            rename_detector=self.detector,
+        )
 
 
 class RenameDetectionTest(DiffTestCase):
-
     def _do_test_count_blocks(self, count_blocks):
-        blob = make_object(Blob, data=b'a\nb\na\n')
-        self.assertBlockCountEqual({b'a\n': 4, b'b\n': 2}, count_blocks(blob))
+        blob = make_object(Blob, data=b"a\nb\na\n")
+        self.assertBlockCountEqual({b"a\n": 4, b"b\n": 2}, count_blocks(blob))
 
-    test_count_blocks = functest_builder(_do_test_count_blocks,
-                                         _count_blocks_py)
-    test_count_blocks_extension = ext_functest_builder(_do_test_count_blocks,
-                                                       _count_blocks)
+    test_count_blocks = functest_builder(_do_test_count_blocks, _count_blocks_py)
+    test_count_blocks_extension = ext_functest_builder(
+        _do_test_count_blocks, _count_blocks
+    )
 
     def _do_test_count_blocks_no_newline(self, count_blocks):
-        blob = make_object(Blob, data=b'a\na')
-        self.assertBlockCountEqual({b'a\n': 2, b'a': 1}, _count_blocks(blob))
+        blob = make_object(Blob, data=b"a\na")
+        self.assertBlockCountEqual({b"a\n": 2, b"a": 1}, _count_blocks(blob))
 
     test_count_blocks_no_newline = functest_builder(
-        _do_test_count_blocks_no_newline, _count_blocks_py)
+        _do_test_count_blocks_no_newline, _count_blocks_py
+    )
     test_count_blocks_no_newline_extension = ext_functest_builder(
-        _do_test_count_blocks_no_newline, _count_blocks)
+        _do_test_count_blocks_no_newline, _count_blocks
+    )
 
     def assertBlockCountEqual(self, expected, got):
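+        # Compare hashes modulo 2**32; the two implementations may not use the same hash width.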
         self.assertEqual(
-            {(hash(l) & 0xffffffff): c for (l, c) in expected.items()},
-            {(h & 0xffffffff): c for (h, c) in got.items()})
+            {(hash(l) & 0xFFFFFFFF): c for (l, c) in expected.items()},
+            {(h & 0xFFFFFFFF): c for (h, c) in got.items()},
+        )
 
     def _do_test_count_blocks_chunks(self, count_blocks):
-        blob = ShaFile.from_raw_chunks(Blob.type_num, [b'a\nb', b'\na\n'])
-        self.assertBlockCountEqual({b'a\n': 4, b'b\n': 2}, _count_blocks(blob))
+        blob = ShaFile.from_raw_chunks(Blob.type_num, [b"a\nb", b"\na\n"])
+        self.assertBlockCountEqual({b"a\n": 4, b"b\n": 2}, _count_blocks(blob))
 
-    test_count_blocks_chunks = functest_builder(_do_test_count_blocks_chunks,
-                                                _count_blocks_py)
+    test_count_blocks_chunks = functest_builder(
+        _do_test_count_blocks_chunks, _count_blocks_py
+    )
     test_count_blocks_chunks_extension = ext_functest_builder(
-        _do_test_count_blocks_chunks, _count_blocks)
+        _do_test_count_blocks_chunks, _count_blocks
+    )
 
     def _do_test_count_blocks_long_lines(self, count_blocks):
-        a = b'a' * 64
-        data = a + b'xxx\ny\n' + a + b'zzz\n'
+        a = b"a" * 64
+        data = a + b"xxx\ny\n" + a + b"zzz\n"
         blob = make_object(Blob, data=data)
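+        # Counts are byte totals per block: the 64-byte run of "a"s occurs twice, hence 128.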
         self.assertBlockCountEqual(
-            {b'a' * 64: 128,
-             b'xxx\n': 4,
-             b'y\n': 2,
-             b'zzz\n': 4},
-            _count_blocks(blob))
+            {b"a" * 64: 128, b"xxx\n": 4, b"y\n": 2, b"zzz\n": 4},
+            count_blocks(blob),
+        )
 
     test_count_blocks_long_lines = functest_builder(
-        _do_test_count_blocks_long_lines, _count_blocks_py)
+        _do_test_count_blocks_long_lines, _count_blocks_py
+    )
     test_count_blocks_long_lines_extension = ext_functest_builder(
-        _do_test_count_blocks_long_lines, _count_blocks)
+        _do_test_count_blocks_long_lines, _count_blocks
+    )
 
     def assertSimilar(self, expected_score, blob1, blob2):
         self.assertEqual(expected_score, _similarity_score(blob1, blob2))
         self.assertEqual(expected_score, _similarity_score(blob2, blob1))
 
     def test_similarity_score(self):
-        blob0 = make_object(Blob, data=b'')
-        blob1 = make_object(Blob, data=b'ab\ncd\ncd\n')
-        blob2 = make_object(Blob, data=b'ab\n')
-        blob3 = make_object(Blob, data=b'cd\n')
-        blob4 = make_object(Blob, data=b'cd\ncd\n')
+        blob0 = make_object(Blob, data=b"")
+        blob1 = make_object(Blob, data=b"ab\ncd\ncd\n")
+        blob2 = make_object(Blob, data=b"ab\n")
+        blob3 = make_object(Blob, data=b"cd\n")
+        blob4 = make_object(Blob, data=b"cd\ncd\n")
 
         self.assertSimilar(100, blob0, blob0)
         self.assertSimilar(0, blob0, blob1)
@@ -559,396 +694,464 @@ class RenameDetectionTest(DiffTestCase):
         self.assertSimilar(50, blob3, blob4)
 
     def test_similarity_score_cache(self):
-        blob1 = make_object(Blob, data=b'ab\ncd\n')
-        blob2 = make_object(Blob, data=b'ab\n')
+        blob1 = make_object(Blob, data=b"ab\ncd\n")
+        blob2 = make_object(Blob, data=b"ab\n")
 
         block_cache = {}
-        self.assertEqual(
-            50, _similarity_score(blob1, blob2, block_cache=block_cache))
+        self.assertEqual(50, _similarity_score(blob1, blob2, block_cache=block_cache))
         self.assertEqual(set([blob1.id, blob2.id]), set(block_cache))
 
         def fail_chunks():
-            self.fail('Unexpected call to as_raw_chunks()')
+            self.fail("Unexpected call to as_raw_chunks()")
 
         blob1.as_raw_chunks = blob2.as_raw_chunks = fail_chunks
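+        # With both blobs' blocks cached, only the lengths should be needed for scoring.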
         blob1.raw_length = lambda: 6
         blob2.raw_length = lambda: 3
-        self.assertEqual(
-            50, _similarity_score(blob1, blob2, block_cache=block_cache))
+        self.assertEqual(50, _similarity_score(blob1, blob2, block_cache=block_cache))
 
     def test_tree_entry_sort(self):
-        sha = 'abcd' * 10
+        sha = "abcd" * 10
         expected_entries = [
-            TreeChange.add(TreeEntry(b'aaa', F, sha)),
-            TreeChange(CHANGE_COPY, TreeEntry(b'bbb', F, sha),
-                       TreeEntry(b'aab', F, sha)),
-            TreeChange(CHANGE_MODIFY, TreeEntry(b'bbb', F, sha),
-                       TreeEntry(b'bbb', F, b'dabc' * 10)),
-            TreeChange(CHANGE_RENAME, TreeEntry(b'bbc', F, sha),
-                       TreeEntry(b'ddd', F, sha)),
-            TreeChange.delete(TreeEntry(b'ccc', F, sha)),
+            TreeChange.add(TreeEntry(b"aaa", F, sha)),
+            TreeChange(
+                CHANGE_COPY,
+                TreeEntry(b"bbb", F, sha),
+                TreeEntry(b"aab", F, sha),
+            ),
+            TreeChange(
+                CHANGE_MODIFY,
+                TreeEntry(b"bbb", F, sha),
+                TreeEntry(b"bbb", F, b"dabc" * 10),
+            ),
+            TreeChange(
+                CHANGE_RENAME,
+                TreeEntry(b"bbc", F, sha),
+                TreeEntry(b"ddd", F, sha),
+            ),
+            TreeChange.delete(TreeEntry(b"ccc", F, sha)),
         ]
 
         for perm in permutations(expected_entries):
-            self.assertEqual(expected_entries,
-                             sorted(perm, key=_tree_change_key))
+            self.assertEqual(expected_entries, sorted(perm, key=_tree_change_key))
 
     def detect_renames(self, tree1, tree2, want_unchanged=False, **kwargs):
         detector = RenameDetector(self.store, **kwargs)
-        return detector.changes_with_renames(tree1.id, tree2.id,
-                                             want_unchanged=want_unchanged)
+        return detector.changes_with_renames(
+            tree1.id, tree2.id, want_unchanged=want_unchanged
+        )
 
     def test_no_renames(self):
-        blob1 = make_object(Blob, data=b'a\nb\nc\nd\n')
-        blob2 = make_object(Blob, data=b'a\nb\ne\nf\n')
-        blob3 = make_object(Blob, data=b'a\nb\ng\nh\n')
-        tree1 = self.commit_tree([(b'a', blob1), (b'b', blob2)])
-        tree2 = self.commit_tree([(b'a', blob1), (b'b', blob3)])
+        blob1 = make_object(Blob, data=b"a\nb\nc\nd\n")
+        blob2 = make_object(Blob, data=b"a\nb\ne\nf\n")
+        blob3 = make_object(Blob, data=b"a\nb\ng\nh\n")
+        tree1 = self.commit_tree([(b"a", blob1), (b"b", blob2)])
+        tree2 = self.commit_tree([(b"a", blob1), (b"b", blob3)])
         self.assertEqual(
-            [TreeChange(CHANGE_MODIFY, (b'b', F, blob2.id),
-                        (b'b', F, blob3.id))],
-            self.detect_renames(tree1, tree2))
+            [TreeChange(CHANGE_MODIFY, (b"b", F, blob2.id), (b"b", F, blob3.id))],
+            self.detect_renames(tree1, tree2),
+        )
 
     def test_exact_rename_one_to_one(self):
-        blob1 = make_object(Blob, data=b'1')
-        blob2 = make_object(Blob, data=b'2')
-        tree1 = self.commit_tree([(b'a', blob1), (b'b', blob2)])
-        tree2 = self.commit_tree([(b'c', blob1), (b'd', blob2)])
+        blob1 = make_object(Blob, data=b"1")
+        blob2 = make_object(Blob, data=b"2")
+        tree1 = self.commit_tree([(b"a", blob1), (b"b", blob2)])
+        tree2 = self.commit_tree([(b"c", blob1), (b"d", blob2)])
         self.assertEqual(
-                [TreeChange(CHANGE_RENAME, (b'a', F, blob1.id),
-                            (b'c', F, blob1.id)),
-                 TreeChange(CHANGE_RENAME, (b'b', F, blob2.id),
-                            (b'd', F, blob2.id))],
-                self.detect_renames(tree1, tree2))
+            [
+                TreeChange(CHANGE_RENAME, (b"a", F, blob1.id), (b"c", F, blob1.id)),
+                TreeChange(CHANGE_RENAME, (b"b", F, blob2.id), (b"d", F, blob2.id)),
+            ],
+            self.detect_renames(tree1, tree2),
+        )
 
     def test_exact_rename_split_different_type(self):
-        blob = make_object(Blob, data=b'/foo')
-        tree1 = self.commit_tree([(b'a', blob, 0o100644)])
-        tree2 = self.commit_tree([(b'a', blob, 0o120000)])
+        blob = make_object(Blob, data=b"/foo")
+        tree1 = self.commit_tree([(b"a", blob, 0o100644)])
+        tree2 = self.commit_tree([(b"a", blob, 0o120000)])
         self.assertEqual(
-            [TreeChange.add((b'a', 0o120000, blob.id)),
-             TreeChange.delete((b'a', 0o100644, blob.id))],
-            self.detect_renames(tree1, tree2))
+            [
+                TreeChange.add((b"a", 0o120000, blob.id)),
+                TreeChange.delete((b"a", 0o100644, blob.id)),
+            ],
+            self.detect_renames(tree1, tree2),
+        )
 
     def test_exact_rename_and_different_type(self):
-        blob1 = make_object(Blob, data=b'1')
-        blob2 = make_object(Blob, data=b'2')
-        tree1 = self.commit_tree([(b'a', blob1)])
-        tree2 = self.commit_tree([(b'a', blob2, 0o120000), (b'b', blob1)])
+        blob1 = make_object(Blob, data=b"1")
+        blob2 = make_object(Blob, data=b"2")
+        tree1 = self.commit_tree([(b"a", blob1)])
+        tree2 = self.commit_tree([(b"a", blob2, 0o120000), (b"b", blob1)])
         self.assertEqual(
-                [TreeChange.add((b'a', 0o120000, blob2.id)),
-                 TreeChange(CHANGE_RENAME, (b'a', F, blob1.id),
-                            (b'b', F, blob1.id))],
-                self.detect_renames(tree1, tree2))
+            [
+                TreeChange.add((b"a", 0o120000, blob2.id)),
+                TreeChange(CHANGE_RENAME, (b"a", F, blob1.id), (b"b", F, blob1.id)),
+            ],
+            self.detect_renames(tree1, tree2),
+        )
 
     def test_exact_rename_one_to_many(self):
-        blob = make_object(Blob, data=b'1')
-        tree1 = self.commit_tree([(b'a', blob)])
-        tree2 = self.commit_tree([(b'b', blob), (b'c', blob)])
+        blob = make_object(Blob, data=b"1")
+        tree1 = self.commit_tree([(b"a", blob)])
+        tree2 = self.commit_tree([(b"b", blob), (b"c", blob)])
         self.assertEqual(
-            [TreeChange(CHANGE_RENAME, (b'a', F, blob.id), (b'b', F, blob.id)),
-             TreeChange(CHANGE_COPY, (b'a', F, blob.id), (b'c', F, blob.id))],
-            self.detect_renames(tree1, tree2))
+            [
+                TreeChange(CHANGE_RENAME, (b"a", F, blob.id), (b"b", F, blob.id)),
+                TreeChange(CHANGE_COPY, (b"a", F, blob.id), (b"c", F, blob.id)),
+            ],
+            self.detect_renames(tree1, tree2),
+        )
 
     def test_exact_rename_many_to_one(self):
-        blob = make_object(Blob, data=b'1')
-        tree1 = self.commit_tree([(b'a', blob), (b'b', blob)])
-        tree2 = self.commit_tree([(b'c', blob)])
+        blob = make_object(Blob, data=b"1")
+        tree1 = self.commit_tree([(b"a", blob), (b"b", blob)])
+        tree2 = self.commit_tree([(b"c", blob)])
         self.assertEqual(
-            [TreeChange(CHANGE_RENAME, (b'a', F, blob.id), (b'c', F, blob.id)),
-             TreeChange.delete((b'b', F, blob.id))],
-            self.detect_renames(tree1, tree2))
+            [
+                TreeChange(CHANGE_RENAME, (b"a", F, blob.id), (b"c", F, blob.id)),
+                TreeChange.delete((b"b", F, blob.id)),
+            ],
+            self.detect_renames(tree1, tree2),
+        )
 
     def test_exact_rename_many_to_many(self):
-        blob = make_object(Blob, data=b'1')
-        tree1 = self.commit_tree([(b'a', blob), (b'b', blob)])
-        tree2 = self.commit_tree([(b'c', blob), (b'd', blob), (b'e', blob)])
-        self.assertEqual(
-                [TreeChange(CHANGE_RENAME, (b'a', F, blob.id),
-                            (b'c', F, blob.id)),
-                 TreeChange(CHANGE_COPY, (b'a', F, blob.id),
-                            (b'e', F, blob.id)),
-                 TreeChange(CHANGE_RENAME, (b'b', F, blob.id),
-                            (b'd', F, blob.id))],
-                self.detect_renames(tree1, tree2))
+        blob = make_object(Blob, data=b"1")
+        tree1 = self.commit_tree([(b"a", blob), (b"b", blob)])
+        tree2 = self.commit_tree([(b"c", blob), (b"d", blob), (b"e", blob)])
+        self.assertEqual(
+            [
+                TreeChange(CHANGE_RENAME, (b"a", F, blob.id), (b"c", F, blob.id)),
+                TreeChange(CHANGE_COPY, (b"a", F, blob.id), (b"e", F, blob.id)),
+                TreeChange(CHANGE_RENAME, (b"b", F, blob.id), (b"d", F, blob.id)),
+            ],
+            self.detect_renames(tree1, tree2),
+        )
 
     def test_exact_copy_modify(self):
-        blob1 = make_object(Blob, data=b'a\nb\nc\nd\n')
-        blob2 = make_object(Blob, data=b'a\nb\nc\ne\n')
-        tree1 = self.commit_tree([(b'a', blob1)])
-        tree2 = self.commit_tree([(b'a', blob2), (b'b', blob1)])
+        blob1 = make_object(Blob, data=b"a\nb\nc\nd\n")
+        blob2 = make_object(Blob, data=b"a\nb\nc\ne\n")
+        tree1 = self.commit_tree([(b"a", blob1)])
+        tree2 = self.commit_tree([(b"a", blob2), (b"b", blob1)])
         self.assertEqual(
-            [TreeChange(CHANGE_MODIFY, (b'a', F, blob1.id),
-                        (b'a', F, blob2.id)),
-             TreeChange(CHANGE_COPY, (b'a', F, blob1.id),
-                        (b'b', F, blob1.id))],
-            self.detect_renames(tree1, tree2))
+            [
+                TreeChange(CHANGE_MODIFY, (b"a", F, blob1.id), (b"a", F, blob2.id)),
+                TreeChange(CHANGE_COPY, (b"a", F, blob1.id), (b"b", F, blob1.id)),
+            ],
+            self.detect_renames(tree1, tree2),
+        )
 
     def test_exact_copy_change_mode(self):
-        blob = make_object(Blob, data=b'a\nb\nc\nd\n')
-        tree1 = self.commit_tree([(b'a', blob)])
-        tree2 = self.commit_tree([(b'a', blob, 0o100755), (b'b', blob)])
+        blob = make_object(Blob, data=b"a\nb\nc\nd\n")
+        tree1 = self.commit_tree([(b"a", blob)])
+        tree2 = self.commit_tree([(b"a", blob, 0o100755), (b"b", blob)])
         self.assertEqual(
-            [TreeChange(CHANGE_MODIFY, (b'a', F, blob.id),
-                        (b'a', 0o100755, blob.id)),
-             TreeChange(CHANGE_COPY, (b'a', F, blob.id), (b'b', F, blob.id))],
-            self.detect_renames(tree1, tree2))
+            [
+                TreeChange(
+                    CHANGE_MODIFY,
+                    (b"a", F, blob.id),
+                    (b"a", 0o100755, blob.id),
+                ),
+                TreeChange(CHANGE_COPY, (b"a", F, blob.id), (b"b", F, blob.id)),
+            ],
+            self.detect_renames(tree1, tree2),
+        )
 
     def test_rename_threshold(self):
-        blob1 = make_object(Blob, data=b'a\nb\nc\n')
-        blob2 = make_object(Blob, data=b'a\nb\nd\n')
-        tree1 = self.commit_tree([(b'a', blob1)])
-        tree2 = self.commit_tree([(b'b', blob2)])
+        blob1 = make_object(Blob, data=b"a\nb\nc\n")
+        blob2 = make_object(Blob, data=b"a\nb\nd\n")
+        tree1 = self.commit_tree([(b"a", blob1)])
+        tree2 = self.commit_tree([(b"b", blob2)])
         self.assertEqual(
-            [TreeChange(CHANGE_RENAME, (b'a', F, blob1.id),
-                        (b'b', F, blob2.id))],
-            self.detect_renames(tree1, tree2, rename_threshold=50))
+            [TreeChange(CHANGE_RENAME, (b"a", F, blob1.id), (b"b", F, blob2.id))],
+            self.detect_renames(tree1, tree2, rename_threshold=50),
+        )
         self.assertEqual(
-            [TreeChange.delete((b'a', F, blob1.id)),
-             TreeChange.add((b'b', F, blob2.id))],
-            self.detect_renames(tree1, tree2, rename_threshold=75))
+            [
+                TreeChange.delete((b"a", F, blob1.id)),
+                TreeChange.add((b"b", F, blob2.id)),
+            ],
+            self.detect_renames(tree1, tree2, rename_threshold=75),
+        )
 
     def test_content_rename_max_files(self):
-        blob1 = make_object(Blob, data=b'a\nb\nc\nd')
-        blob4 = make_object(Blob, data=b'a\nb\nc\ne\n')
-        blob2 = make_object(Blob, data=b'e\nf\ng\nh\n')
-        blob3 = make_object(Blob, data=b'e\nf\ng\ni\n')
-        tree1 = self.commit_tree([(b'a', blob1), (b'b', blob2)])
-        tree2 = self.commit_tree([(b'c', blob3), (b'd', blob4)])
-        self.assertEqual(
-            [TreeChange(CHANGE_RENAME, (b'a', F, blob1.id),
-                        (b'd', F, blob4.id)),
-             TreeChange(CHANGE_RENAME, (b'b', F, blob2.id),
-                        (b'c', F, blob3.id))],
-            self.detect_renames(tree1, tree2))
-        self.assertEqual(
-            [TreeChange.delete((b'a', F, blob1.id)),
-             TreeChange.delete((b'b', F, blob2.id)),
-             TreeChange.add((b'c', F, blob3.id)),
-             TreeChange.add((b'd', F, blob4.id))],
-            self.detect_renames(tree1, tree2, max_files=1))
+        blob1 = make_object(Blob, data=b"a\nb\nc\nd")
+        blob4 = make_object(Blob, data=b"a\nb\nc\ne\n")
+        blob2 = make_object(Blob, data=b"e\nf\ng\nh\n")
+        blob3 = make_object(Blob, data=b"e\nf\ng\ni\n")
+        tree1 = self.commit_tree([(b"a", blob1), (b"b", blob2)])
+        tree2 = self.commit_tree([(b"c", blob3), (b"d", blob4)])
+        self.assertEqual(
+            [
+                TreeChange(CHANGE_RENAME, (b"a", F, blob1.id), (b"d", F, blob4.id)),
+                TreeChange(CHANGE_RENAME, (b"b", F, blob2.id), (b"c", F, blob3.id)),
+            ],
+            self.detect_renames(tree1, tree2),
+        )
+        self.assertEqual(
+            [
+                TreeChange.delete((b"a", F, blob1.id)),
+                TreeChange.delete((b"b", F, blob2.id)),
+                TreeChange.add((b"c", F, blob3.id)),
+                TreeChange.add((b"d", F, blob4.id)),
+            ],
+            self.detect_renames(tree1, tree2, max_files=1),
+        )
 
     def test_content_rename_one_to_one(self):
-        b11 = make_object(Blob, data=b'a\nb\nc\nd\n')
-        b12 = make_object(Blob, data=b'a\nb\nc\ne\n')
-        b21 = make_object(Blob, data=b'e\nf\ng\n\nh')
-        b22 = make_object(Blob, data=b'e\nf\ng\n\ni')
-        tree1 = self.commit_tree([(b'a', b11), (b'b', b21)])
-        tree2 = self.commit_tree([(b'c', b12), (b'd', b22)])
+        b11 = make_object(Blob, data=b"a\nb\nc\nd\n")
+        b12 = make_object(Blob, data=b"a\nb\nc\ne\n")
+        b21 = make_object(Blob, data=b"e\nf\ng\n\nh")
+        b22 = make_object(Blob, data=b"e\nf\ng\n\ni")
+        tree1 = self.commit_tree([(b"a", b11), (b"b", b21)])
+        tree2 = self.commit_tree([(b"c", b12), (b"d", b22)])
         self.assertEqual(
-            [TreeChange(CHANGE_RENAME, (b'a', F, b11.id), (b'c', F, b12.id)),
-             TreeChange(CHANGE_RENAME, (b'b', F, b21.id), (b'd', F, b22.id))],
-            self.detect_renames(tree1, tree2))
+            [
+                TreeChange(CHANGE_RENAME, (b"a", F, b11.id), (b"c", F, b12.id)),
+                TreeChange(CHANGE_RENAME, (b"b", F, b21.id), (b"d", F, b22.id)),
+            ],
+            self.detect_renames(tree1, tree2),
+        )
 
     def test_content_rename_one_to_one_ordering(self):
-        blob1 = make_object(Blob, data=b'a\nb\nc\nd\ne\nf\n')
-        blob2 = make_object(Blob, data=b'a\nb\nc\nd\ng\nh\n')
+        blob1 = make_object(Blob, data=b"a\nb\nc\nd\ne\nf\n")
+        blob2 = make_object(Blob, data=b"a\nb\nc\nd\ng\nh\n")
         # 6/10 match to blob1, 8/10 match to blob2
-        blob3 = make_object(Blob, data=b'a\nb\nc\nd\ng\ni\n')
-        tree1 = self.commit_tree([(b'a', blob1), (b'b', blob2)])
-        tree2 = self.commit_tree([(b'c', blob3)])
+        blob3 = make_object(Blob, data=b"a\nb\nc\nd\ng\ni\n")
+        tree1 = self.commit_tree([(b"a", blob1), (b"b", blob2)])
+        tree2 = self.commit_tree([(b"c", blob3)])
         self.assertEqual(
-            [TreeChange.delete((b'a', F, blob1.id)),
-             TreeChange(CHANGE_RENAME, (b'b', F, blob2.id),
-                        (b'c', F, blob3.id))],
-            self.detect_renames(tree1, tree2))
-
-        tree3 = self.commit_tree([(b'a', blob2), (b'b', blob1)])
-        tree4 = self.commit_tree([(b'c', blob3)])
+            [
+                TreeChange.delete((b"a", F, blob1.id)),
+                TreeChange(CHANGE_RENAME, (b"b", F, blob2.id), (b"c", F, blob3.id)),
+            ],
+            self.detect_renames(tree1, tree2),
+        )
+
+        tree3 = self.commit_tree([(b"a", blob2), (b"b", blob1)])
+        tree4 = self.commit_tree([(b"c", blob3)])
         self.assertEqual(
-            [TreeChange(CHANGE_RENAME, (b'a', F, blob2.id),
-                        (b'c', F, blob3.id)),
-             TreeChange.delete((b'b', F, blob1.id))],
-            self.detect_renames(tree3, tree4))
+            [
+                TreeChange(CHANGE_RENAME, (b"a", F, blob2.id), (b"c", F, blob3.id)),
+                TreeChange.delete((b"b", F, blob1.id)),
+            ],
+            self.detect_renames(tree3, tree4),
+        )
 
     def test_content_rename_one_to_many(self):
-        blob1 = make_object(Blob, data=b'aa\nb\nc\nd\ne\n')
-        blob2 = make_object(Blob, data=b'ab\nb\nc\nd\ne\n')  # 8/11 match
-        blob3 = make_object(Blob, data=b'aa\nb\nc\nd\nf\n')  # 9/11 match
-        tree1 = self.commit_tree([(b'a', blob1)])
-        tree2 = self.commit_tree([(b'b', blob2), (b'c', blob3)])
+        blob1 = make_object(Blob, data=b"aa\nb\nc\nd\ne\n")
+        blob2 = make_object(Blob, data=b"ab\nb\nc\nd\ne\n")  # 8/11 match
+        blob3 = make_object(Blob, data=b"aa\nb\nc\nd\nf\n")  # 9/11 match
+        tree1 = self.commit_tree([(b"a", blob1)])
+        tree2 = self.commit_tree([(b"b", blob2), (b"c", blob3)])
         self.assertEqual(
-            [TreeChange(CHANGE_COPY, (b'a', F, blob1.id), (b'b', F, blob2.id)),
-             TreeChange(CHANGE_RENAME, (b'a', F, blob1.id),
-                        (b'c', F, blob3.id))],
-            self.detect_renames(tree1, tree2))
+            [
+                TreeChange(CHANGE_COPY, (b"a", F, blob1.id), (b"b", F, blob2.id)),
+                TreeChange(CHANGE_RENAME, (b"a", F, blob1.id), (b"c", F, blob3.id)),
+            ],
+            self.detect_renames(tree1, tree2),
+        )
 
     def test_content_rename_many_to_one(self):
-        blob1 = make_object(Blob, data=b'a\nb\nc\nd\n')
-        blob2 = make_object(Blob, data=b'a\nb\nc\ne\n')
-        blob3 = make_object(Blob, data=b'a\nb\nc\nf\n')
-        tree1 = self.commit_tree([(b'a', blob1), (b'b', blob2)])
-        tree2 = self.commit_tree([(b'c', blob3)])
+        blob1 = make_object(Blob, data=b"a\nb\nc\nd\n")
+        blob2 = make_object(Blob, data=b"a\nb\nc\ne\n")
+        blob3 = make_object(Blob, data=b"a\nb\nc\nf\n")
+        tree1 = self.commit_tree([(b"a", blob1), (b"b", blob2)])
+        tree2 = self.commit_tree([(b"c", blob3)])
         self.assertEqual(
-            [TreeChange(CHANGE_RENAME, (b'a', F, blob1.id),
-                        (b'c', F, blob3.id)),
-             TreeChange.delete((b'b', F, blob2.id))],
-            self.detect_renames(tree1, tree2))
+            [
+                TreeChange(CHANGE_RENAME, (b"a", F, blob1.id), (b"c", F, blob3.id)),
+                TreeChange.delete((b"b", F, blob2.id)),
+            ],
+            self.detect_renames(tree1, tree2),
+        )
 
     def test_content_rename_many_to_many(self):
-        blob1 = make_object(Blob, data=b'a\nb\nc\nd\n')
-        blob2 = make_object(Blob, data=b'a\nb\nc\ne\n')
-        blob3 = make_object(Blob, data=b'a\nb\nc\nf\n')
-        blob4 = make_object(Blob, data=b'a\nb\nc\ng\n')
-        tree1 = self.commit_tree([(b'a', blob1), (b'b', blob2)])
-        tree2 = self.commit_tree([(b'c', blob3), (b'd', blob4)])
+        blob1 = make_object(Blob, data=b"a\nb\nc\nd\n")
+        blob2 = make_object(Blob, data=b"a\nb\nc\ne\n")
+        blob3 = make_object(Blob, data=b"a\nb\nc\nf\n")
+        blob4 = make_object(Blob, data=b"a\nb\nc\ng\n")
+        tree1 = self.commit_tree([(b"a", blob1), (b"b", blob2)])
+        tree2 = self.commit_tree([(b"c", blob3), (b"d", blob4)])
         # TODO(dborowitz): Distribute renames rather than greedily choosing
         # copies.
         self.assertEqual(
-            [TreeChange(CHANGE_RENAME, (b'a', F, blob1.id),
-                        (b'c', F, blob3.id)),
-             TreeChange(CHANGE_COPY, (b'a', F, blob1.id), (b'd', F, blob4.id)),
-             TreeChange.delete((b'b', F, blob2.id))],
-            self.detect_renames(tree1, tree2))
+            [
+                TreeChange(CHANGE_RENAME, (b"a", F, blob1.id), (b"c", F, blob3.id)),
+                TreeChange(CHANGE_COPY, (b"a", F, blob1.id), (b"d", F, blob4.id)),
+                TreeChange.delete((b"b", F, blob2.id)),
+            ],
+            self.detect_renames(tree1, tree2),
+        )
 
     def test_content_rename_with_more_deletions(self):
-        blob1 = make_object(Blob, data=b'')
-        tree1 = self.commit_tree([(b'a', blob1), (b'b', blob1), (b'c', blob1),
-                                  (b'd', blob1)])
-        tree2 = self.commit_tree([(b'e', blob1), (b'f', blob1), (b'g', blob1)])
+        blob1 = make_object(Blob, data=b"")
+        tree1 = self.commit_tree(
+            [(b"a", blob1), (b"b", blob1), (b"c", blob1), (b"d", blob1)]
+        )
+        tree2 = self.commit_tree([(b"e", blob1), (b"f", blob1), (b"g", blob1)])
         self.maxDiff = None
         self.assertEqual(
-          [TreeChange(CHANGE_RENAME, (b'a', F, blob1.id), (b'e', F, blob1.id)),
-           TreeChange(CHANGE_RENAME, (b'b', F, blob1.id), (b'f', F, blob1.id)),
-           TreeChange(CHANGE_RENAME, (b'c', F, blob1.id), (b'g', F, blob1.id)),
-           TreeChange.delete((b'd', F, blob1.id))],
-          self.detect_renames(tree1, tree2))
+            [
+                TreeChange(CHANGE_RENAME, (b"a", F, blob1.id), (b"e", F, blob1.id)),
+                TreeChange(CHANGE_RENAME, (b"b", F, blob1.id), (b"f", F, blob1.id)),
+                TreeChange(CHANGE_RENAME, (b"c", F, blob1.id), (b"g", F, blob1.id)),
+                TreeChange.delete((b"d", F, blob1.id)),
+            ],
+            self.detect_renames(tree1, tree2),
+        )
 
     def test_content_rename_gitlink(self):
-        blob1 = make_object(Blob, data=b'blob1')
-        blob2 = make_object(Blob, data=b'blob2')
-        link1 = b'1' * 40
-        link2 = b'2' * 40
-        tree1 = self.commit_tree([(b'a', blob1), (b'b', link1, 0o160000)])
-        tree2 = self.commit_tree([(b'c', blob2), (b'd', link2, 0o160000)])
-        self.assertEqual(
-            [TreeChange.delete((b'a', 0o100644, blob1.id)),
-             TreeChange.delete((b'b', 0o160000, link1)),
-             TreeChange.add((b'c', 0o100644, blob2.id)),
-             TreeChange.add((b'd', 0o160000, link2))],
-            self.detect_renames(tree1, tree2))
+        blob1 = make_object(Blob, data=b"blob1")
+        blob2 = make_object(Blob, data=b"blob2")
+        link1 = b"1" * 40
+        link2 = b"2" * 40
+        tree1 = self.commit_tree([(b"a", blob1), (b"b", link1, 0o160000)])
+        tree2 = self.commit_tree([(b"c", blob2), (b"d", link2, 0o160000)])
+        self.assertEqual(
+            [
+                TreeChange.delete((b"a", 0o100644, blob1.id)),
+                TreeChange.delete((b"b", 0o160000, link1)),
+                TreeChange.add((b"c", 0o100644, blob2.id)),
+                TreeChange.add((b"d", 0o160000, link2)),
+            ],
+            self.detect_renames(tree1, tree2),
+        )
 
     def test_exact_rename_swap(self):
-        blob1 = make_object(Blob, data=b'1')
-        blob2 = make_object(Blob, data=b'2')
-        tree1 = self.commit_tree([(b'a', blob1), (b'b', blob2)])
-        tree2 = self.commit_tree([(b'a', blob2), (b'b', blob1)])
-        self.assertEqual(
-            [TreeChange(CHANGE_MODIFY, (b'a', F, blob1.id),
-                        (b'a', F, blob2.id)),
-             TreeChange(CHANGE_MODIFY, (b'b', F, blob2.id),
-                        (b'b', F, blob1.id))],
-            self.detect_renames(tree1, tree2))
-        self.assertEqual(
-            [TreeChange(CHANGE_RENAME, (b'a', F, blob1.id),
-                        (b'b', F, blob1.id)),
-             TreeChange(CHANGE_RENAME, (b'b', F, blob2.id),
-                        (b'a', F, blob2.id))],
-            self.detect_renames(tree1, tree2, rewrite_threshold=50))
+        blob1 = make_object(Blob, data=b"1")
+        blob2 = make_object(Blob, data=b"2")
+        tree1 = self.commit_tree([(b"a", blob1), (b"b", blob2)])
+        tree2 = self.commit_tree([(b"a", blob2), (b"b", blob1)])
+        self.assertEqual(
+            [
+                TreeChange(CHANGE_MODIFY, (b"a", F, blob1.id), (b"a", F, blob2.id)),
+                TreeChange(CHANGE_MODIFY, (b"b", F, blob2.id), (b"b", F, blob1.id)),
+            ],
+            self.detect_renames(tree1, tree2),
+        )
+        self.assertEqual(
+            [
+                TreeChange(CHANGE_RENAME, (b"a", F, blob1.id), (b"b", F, blob1.id)),
+                TreeChange(CHANGE_RENAME, (b"b", F, blob2.id), (b"a", F, blob2.id)),
+            ],
+            self.detect_renames(tree1, tree2, rewrite_threshold=50),
+        )
 
     def test_content_rename_swap(self):
-        blob1 = make_object(Blob, data=b'a\nb\nc\nd\n')
-        blob2 = make_object(Blob, data=b'e\nf\ng\nh\n')
-        blob3 = make_object(Blob, data=b'a\nb\nc\ne\n')
-        blob4 = make_object(Blob, data=b'e\nf\ng\ni\n')
-        tree1 = self.commit_tree([(b'a', blob1), (b'b', blob2)])
-        tree2 = self.commit_tree([(b'a', blob4), (b'b', blob3)])
-        self.assertEqual(
-            [TreeChange(CHANGE_RENAME, (b'a', F, blob1.id),
-                        (b'b', F, blob3.id)),
-             TreeChange(CHANGE_RENAME, (b'b', F, blob2.id),
-                        (b'a', F, blob4.id))],
-            self.detect_renames(tree1, tree2, rewrite_threshold=60))
+        blob1 = make_object(Blob, data=b"a\nb\nc\nd\n")
+        blob2 = make_object(Blob, data=b"e\nf\ng\nh\n")
+        blob3 = make_object(Blob, data=b"a\nb\nc\ne\n")
+        blob4 = make_object(Blob, data=b"e\nf\ng\ni\n")
+        tree1 = self.commit_tree([(b"a", blob1), (b"b", blob2)])
+        tree2 = self.commit_tree([(b"a", blob4), (b"b", blob3)])
+        self.assertEqual(
+            [
+                TreeChange(CHANGE_RENAME, (b"a", F, blob1.id), (b"b", F, blob3.id)),
+                TreeChange(CHANGE_RENAME, (b"b", F, blob2.id), (b"a", F, blob4.id)),
+            ],
+            self.detect_renames(tree1, tree2, rewrite_threshold=60),
+        )
 
     def test_rewrite_threshold(self):
-        blob1 = make_object(Blob, data=b'a\nb\nc\nd\n')
-        blob2 = make_object(Blob, data=b'a\nb\nc\ne\n')
-        blob3 = make_object(Blob, data=b'a\nb\nf\ng\n')
+        blob1 = make_object(Blob, data=b"a\nb\nc\nd\n")
+        blob2 = make_object(Blob, data=b"a\nb\nc\ne\n")
+        blob3 = make_object(Blob, data=b"a\nb\nf\ng\n")
 
-        tree1 = self.commit_tree([(b'a', blob1)])
-        tree2 = self.commit_tree([(b'a', blob3), (b'b', blob2)])
+        tree1 = self.commit_tree([(b"a", blob1)])
+        tree2 = self.commit_tree([(b"a", blob3), (b"b", blob2)])
 
         no_renames = [
-            TreeChange(CHANGE_MODIFY, (b'a', F, blob1.id),
-                       (b'a', F, blob3.id)),
-            TreeChange(CHANGE_COPY, (b'a', F, blob1.id), (b'b', F, blob2.id))]
-        self.assertEqual(
-            no_renames, self.detect_renames(tree1, tree2))
+            TreeChange(CHANGE_MODIFY, (b"a", F, blob1.id), (b"a", F, blob3.id)),
+            TreeChange(CHANGE_COPY, (b"a", F, blob1.id), (b"b", F, blob2.id)),
+        ]
+        self.assertEqual(no_renames, self.detect_renames(tree1, tree2))
         self.assertEqual(
-            no_renames, self.detect_renames(
-                tree1, tree2, rewrite_threshold=40))
+            no_renames, self.detect_renames(tree1, tree2, rewrite_threshold=40)
+        )
         self.assertEqual(
-            [TreeChange.add((b'a', F, blob3.id)),
-             TreeChange(CHANGE_RENAME, (b'a', F, blob1.id),
-                        (b'b', F, blob2.id))],
-            self.detect_renames(tree1, tree2, rewrite_threshold=80))
+            [
+                TreeChange.add((b"a", F, blob3.id)),
+                TreeChange(CHANGE_RENAME, (b"a", F, blob1.id), (b"b", F, blob2.id)),
+            ],
+            self.detect_renames(tree1, tree2, rewrite_threshold=80),
+        )
 
     def test_find_copies_harder_exact(self):
-        blob = make_object(Blob, data=b'blob')
-        tree1 = self.commit_tree([(b'a', blob)])
-        tree2 = self.commit_tree([(b'a', blob), (b'b', blob)])
-        self.assertEqual([TreeChange.add((b'b', F, blob.id))],
-                         self.detect_renames(tree1, tree2))
+        blob = make_object(Blob, data=b"blob")
+        tree1 = self.commit_tree([(b"a", blob)])
+        tree2 = self.commit_tree([(b"a", blob), (b"b", blob)])
+        self.assertEqual(
+            [TreeChange.add((b"b", F, blob.id))],
+            self.detect_renames(tree1, tree2),
+        )
         self.assertEqual(
-            [TreeChange(CHANGE_COPY, (b'a', F, blob.id), (b'b', F, blob.id))],
-            self.detect_renames(tree1, tree2, find_copies_harder=True))
+            [TreeChange(CHANGE_COPY, (b"a", F, blob.id), (b"b", F, blob.id))],
+            self.detect_renames(tree1, tree2, find_copies_harder=True),
+        )
 
     def test_find_copies_harder_content(self):
-        blob1 = make_object(Blob, data=b'a\nb\nc\nd\n')
-        blob2 = make_object(Blob, data=b'a\nb\nc\ne\n')
-        tree1 = self.commit_tree([(b'a', blob1)])
-        tree2 = self.commit_tree([(b'a', blob1), (b'b', blob2)])
-        self.assertEqual([TreeChange.add((b'b', F, blob2.id))],
-                         self.detect_renames(tree1, tree2))
+        blob1 = make_object(Blob, data=b"a\nb\nc\nd\n")
+        blob2 = make_object(Blob, data=b"a\nb\nc\ne\n")
+        tree1 = self.commit_tree([(b"a", blob1)])
+        tree2 = self.commit_tree([(b"a", blob1), (b"b", blob2)])
         self.assertEqual(
-            [TreeChange(CHANGE_COPY, (b'a', F, blob1.id),
-                        (b'b', F, blob2.id))],
-            self.detect_renames(tree1, tree2, find_copies_harder=True))
+            [TreeChange.add((b"b", F, blob2.id))],
+            self.detect_renames(tree1, tree2),
+        )
+        self.assertEqual(
+            [TreeChange(CHANGE_COPY, (b"a", F, blob1.id), (b"b", F, blob2.id))],
+            self.detect_renames(tree1, tree2, find_copies_harder=True),
+        )
 
     def test_find_copies_harder_with_rewrites(self):
-        blob_a1 = make_object(Blob, data=b'a\nb\nc\nd\n')
-        blob_a2 = make_object(Blob, data=b'f\ng\nh\ni\n')
-        blob_b2 = make_object(Blob, data=b'a\nb\nc\ne\n')
-        tree1 = self.commit_tree([(b'a', blob_a1)])
-        tree2 = self.commit_tree([(b'a', blob_a2), (b'b', blob_b2)])
-        self.assertEqual(
-            [TreeChange(CHANGE_MODIFY, (b'a', F, blob_a1.id),
-                        (b'a', F, blob_a2.id)),
-             TreeChange(CHANGE_COPY, (b'a', F, blob_a1.id),
-                        (b'b', F, blob_b2.id))],
-            self.detect_renames(tree1, tree2, find_copies_harder=True))
-        self.assertEqual(
-            [TreeChange.add((b'a', F, blob_a2.id)),
-             TreeChange(CHANGE_RENAME, (b'a', F, blob_a1.id),
-                        (b'b', F, blob_b2.id))],
-            self.detect_renames(tree1, tree2, rewrite_threshold=50,
-                                find_copies_harder=True))
+        blob_a1 = make_object(Blob, data=b"a\nb\nc\nd\n")
+        blob_a2 = make_object(Blob, data=b"f\ng\nh\ni\n")
+        blob_b2 = make_object(Blob, data=b"a\nb\nc\ne\n")
+        tree1 = self.commit_tree([(b"a", blob_a1)])
+        tree2 = self.commit_tree([(b"a", blob_a2), (b"b", blob_b2)])
+        self.assertEqual(
+            [
+                TreeChange(CHANGE_MODIFY, (b"a", F, blob_a1.id), (b"a", F, blob_a2.id)),
+                TreeChange(CHANGE_COPY, (b"a", F, blob_a1.id), (b"b", F, blob_b2.id)),
+            ],
+            self.detect_renames(tree1, tree2, find_copies_harder=True),
+        )
+        self.assertEqual(
+            [
+                TreeChange.add((b"a", F, blob_a2.id)),
+                TreeChange(CHANGE_RENAME, (b"a", F, blob_a1.id), (b"b", F, blob_b2.id)),
+            ],
+            self.detect_renames(
+                tree1, tree2, rewrite_threshold=50, find_copies_harder=True
+            ),
+        )
 
     def test_reuse_detector(self):
-        blob = make_object(Blob, data=b'blob')
-        tree1 = self.commit_tree([(b'a', blob)])
-        tree2 = self.commit_tree([(b'b', blob)])
+        blob = make_object(Blob, data=b"blob")
+        tree1 = self.commit_tree([(b"a", blob)])
+        tree2 = self.commit_tree([(b"b", blob)])
         detector = RenameDetector(self.store)
-        changes = [TreeChange(CHANGE_RENAME, (b'a', F, blob.id),
-                              (b'b', F, blob.id))]
-        self.assertEqual(changes,
-                         detector.changes_with_renames(tree1.id, tree2.id))
-        self.assertEqual(changes,
-                         detector.changes_with_renames(tree1.id, tree2.id))
+        changes = [TreeChange(CHANGE_RENAME, (b"a", F, blob.id), (b"b", F, blob.id))]
+        self.assertEqual(changes, detector.changes_with_renames(tree1.id, tree2.id))
+        self.assertEqual(changes, detector.changes_with_renames(tree1.id, tree2.id))
 
     def test_want_unchanged(self):
-        blob_a1 = make_object(Blob, data=b'a\nb\nc\nd\n')
-        blob_b = make_object(Blob, data=b'b')
-        blob_c2 = make_object(Blob, data=b'a\nb\nc\ne\n')
-        tree1 = self.commit_tree([(b'a', blob_a1), (b'b', blob_b)])
-        tree2 = self.commit_tree([(b'c', blob_c2), (b'b', blob_b)])
-        self.assertEqual(
-            [TreeChange(CHANGE_RENAME, (b'a', F, blob_a1.id),
-                        (b'c', F, blob_c2.id))],
-            self.detect_renames(tree1, tree2))
-        self.assertEqual(
-            [TreeChange(CHANGE_RENAME, (b'a', F, blob_a1.id),
-                        (b'c', F, blob_c2.id)),
-             TreeChange(CHANGE_UNCHANGED, (b'b', F, blob_b.id),
-                        (b'b', F, blob_b.id))],
-            self.detect_renames(tree1, tree2, want_unchanged=True))
+        blob_a1 = make_object(Blob, data=b"a\nb\nc\nd\n")
+        blob_b = make_object(Blob, data=b"b")
+        blob_c2 = make_object(Blob, data=b"a\nb\nc\ne\n")
+        tree1 = self.commit_tree([(b"a", blob_a1), (b"b", blob_b)])
+        tree2 = self.commit_tree([(b"c", blob_c2), (b"b", blob_b)])
+        self.assertEqual(
+            [TreeChange(CHANGE_RENAME, (b"a", F, blob_a1.id), (b"c", F, blob_c2.id))],
+            self.detect_renames(tree1, tree2),
+        )
+        self.assertEqual(
+            [
+                TreeChange(CHANGE_RENAME, (b"a", F, blob_a1.id), (b"c", F, blob_c2.id)),
+                TreeChange(
+                    CHANGE_UNCHANGED,
+                    (b"b", F, blob_b.id),
+                    (b"b", F, blob_b.id),
+                ),
+            ],
+            self.detect_renames(tree1, tree2, want_unchanged=True),
+        )

+ 112 - 56
dulwich/tests/test_fastexport.py

@@ -24,23 +24,23 @@ import stat
 
 from dulwich.object_store import (
     MemoryObjectStore,
-    )
+)
 from dulwich.objects import (
     Blob,
     Commit,
     Tree,
     ZERO_SHA,
-    )
+)
 from dulwich.repo import (
     MemoryRepo,
-    )
+)
 from dulwich.tests import (
     SkipTest,
     TestCase,
-    )
+)
 from dulwich.tests.utils import (
     build_commit_graph,
-    )
+)
 
 
 class GitFastExporterTests(TestCase):
@@ -60,8 +60,7 @@ class GitFastExporterTests(TestCase):
         b = Blob()
         b.data = b"fooBAR"
         self.fastexporter.emit_blob(b)
-        self.assertEqual(b'blob\nmark :1\ndata 6\nfooBAR\n',
-                         self.stream.getvalue())
+        self.assertEqual(b"blob\nmark :1\ndata 6\nfooBAR\n", self.stream.getvalue())
 
     def test_emit_commit(self):
         b = Blob()
@@ -76,7 +75,8 @@ class GitFastExporterTests(TestCase):
         c.tree = t.id
         self.store.add_objects([(b, None), (t, None), (c, None)])
         self.fastexporter.emit_commit(c, b"refs/heads/master")
-        self.assertEqual(b"""blob
+        self.assertEqual(
+            b"""blob
 mark :1
 data 3
 FOO
@@ -87,7 +87,9 @@ committer Jelmer <jelmer@host> 1271345553 +0000
 data 3
 msg
 M 644 :1 foo
-""", self.stream.getvalue())
+""",
+            self.stream.getvalue(),
+        )
 
 
 class GitImportProcessorTests(TestCase):
@@ -104,6 +106,7 @@ class GitImportProcessorTests(TestCase):
 
     def test_reset_handler(self):
         from fastimport import commands
+
         [c1] = build_commit_graph(self.repo.object_store, [[1]])
         cmd = commands.ResetCommand(b"refs/heads/foo", c1.id)
         self.processor.reset_handler(cmd)
@@ -112,14 +115,16 @@ class GitImportProcessorTests(TestCase):
 
     def test_reset_handler_marker(self):
         from fastimport import commands
+
         [c1, c2] = build_commit_graph(self.repo.object_store, [[1], [2]])
-        self.processor.markers[b'10'] = c1.id
-        cmd = commands.ResetCommand(b"refs/heads/foo", b':10')
+        self.processor.markers[b"10"] = c1.id
+        cmd = commands.ResetCommand(b"refs/heads/foo", b":10")
         self.processor.reset_handler(cmd)
         self.assertEqual(c1.id, self.repo.get_refs()[b"refs/heads/foo"])
 
     def test_reset_handler_default(self):
         from fastimport import commands
+
         [c1, c2] = build_commit_graph(self.repo.object_store, [[1], [2]])
         cmd = commands.ResetCommand(b"refs/heads/foo", None)
         self.processor.reset_handler(cmd)
@@ -127,11 +132,17 @@ class GitImportProcessorTests(TestCase):
 
     def test_commit_handler(self):
         from fastimport import commands
+
         cmd = commands.CommitCommand(
-                b"refs/heads/foo",  b"mrkr",
-                (b"Jelmer", b"jelmer@samba.org", 432432432.0, 3600),
-                (b"Jelmer", b"jelmer@samba.org", 432432432.0, 3600),
-                b"FOO", None, [], [])
+            b"refs/heads/foo",
+            b"mrkr",
+            (b"Jelmer", b"jelmer@samba.org", 432432432.0, 3600),
+            (b"Jelmer", b"jelmer@samba.org", 432432432.0, 3600),
+            b"FOO",
+            None,
+            [],
+            [],
+        )
         self.processor.commit_handler(cmd)
         commit = self.repo[self.processor.last_commit]
         self.assertEqual(b"Jelmer <jelmer@samba.org>", commit.author)
@@ -146,16 +157,21 @@ class GitImportProcessorTests(TestCase):
 
     def test_commit_handler_markers(self):
         from fastimport import commands
-        [c1, c2, c3] = build_commit_graph(self.repo.object_store,
-                                          [[1], [2], [3]])
-        self.processor.markers[b'10'] = c1.id
-        self.processor.markers[b'42'] = c2.id
-        self.processor.markers[b'98'] = c3.id
+
+        [c1, c2, c3] = build_commit_graph(self.repo.object_store, [[1], [2], [3]])
+        self.processor.markers[b"10"] = c1.id
+        self.processor.markers[b"42"] = c2.id
+        self.processor.markers[b"98"] = c3.id
         cmd = commands.CommitCommand(
-                b"refs/heads/foo",  b"mrkr",
-                (b"Jelmer", b"jelmer@samba.org", 432432432.0, 3600),
-                (b"Jelmer", b"jelmer@samba.org", 432432432.0, 3600),
-                b"FOO", b':10', [b':42', b':98'], [])
+            b"refs/heads/foo",
+            b"mrkr",
+            (b"Jelmer", b"jelmer@samba.org", 432432432.0, 3600),
+            (b"Jelmer", b"jelmer@samba.org", 432432432.0, 3600),
+            b"FOO",
+            b":10",
+            [b":42", b":98"],
+            [],
+        )
         self.processor.commit_handler(cmd)
         commit = self.repo[self.processor.last_commit]
         self.assertEqual(c1.id, commit.parents[0])
@@ -163,7 +179,9 @@ class GitImportProcessorTests(TestCase):
         self.assertEqual(c3.id, commit.parents[2])
 
     def test_import_stream(self):
-        markers = self.processor.import_stream(BytesIO(b"""blob
+        markers = self.processor.import_stream(
+            BytesIO(
+                b"""blob
 mark :1
 data 11
 text for a
@@ -175,37 +193,50 @@ data 20
 <The commit message>
 M 100644 :1 a
 
-"""))
+"""
+            )
+        )
         self.assertEqual(2, len(markers))
         self.assertTrue(isinstance(self.repo[markers[b"1"]], Blob))
         self.assertTrue(isinstance(self.repo[markers[b"2"]], Commit))
 
     def test_file_add(self):
         from fastimport import commands
+
         cmd = commands.BlobCommand(b"23", b"data")
         self.processor.blob_handler(cmd)
         cmd = commands.CommitCommand(
-                b"refs/heads/foo", b"mrkr",
-                (b"Jelmer", b"jelmer@samba.org", 432432432.0, 3600),
-                (b"Jelmer", b"jelmer@samba.org", 432432432.0, 3600),
-                b"FOO", None, [],
-                [commands.FileModifyCommand(b"path", 0o100644, b":23", None)])
+            b"refs/heads/foo",
+            b"mrkr",
+            (b"Jelmer", b"jelmer@samba.org", 432432432.0, 3600),
+            (b"Jelmer", b"jelmer@samba.org", 432432432.0, 3600),
+            b"FOO",
+            None,
+            [],
+            [commands.FileModifyCommand(b"path", 0o100644, b":23", None)],
+        )
         self.processor.commit_handler(cmd)
         commit = self.repo[self.processor.last_commit]
-        self.assertEqual([
-            (b'path', 0o100644, b'6320cd248dd8aeaab759d5871f8781b5c0505172')],
-            self.repo[commit.tree].items())
+        self.assertEqual(
+            [(b"path", 0o100644, b"6320cd248dd8aeaab759d5871f8781b5c0505172")],
+            self.repo[commit.tree].items(),
+        )
 
     def simple_commit(self):
         from fastimport import commands
+
         cmd = commands.BlobCommand(b"23", b"data")
         self.processor.blob_handler(cmd)
         cmd = commands.CommitCommand(
-                b"refs/heads/foo", b"mrkr",
-                (b"Jelmer", b"jelmer@samba.org", 432432432.0, 3600),
-                (b"Jelmer", b"jelmer@samba.org", 432432432.0, 3600),
-                b"FOO", None, [],
-                [commands.FileModifyCommand(b"path", 0o100644, b":23", None)])
+            b"refs/heads/foo",
+            b"mrkr",
+            (b"Jelmer", b"jelmer@samba.org", 432432432.0, 3600),
+            (b"Jelmer", b"jelmer@samba.org", 432432432.0, 3600),
+            b"FOO",
+            None,
+            [],
+            [commands.FileModifyCommand(b"path", 0o100644, b":23", None)],
+        )
         self.processor.commit_handler(cmd)
         commit = self.repo[self.processor.last_commit]
         return commit
@@ -218,44 +249,69 @@ M 100644 :1 a
         Returns: The created commit object
         """
         from fastimport import commands
+
         cmd = commands.CommitCommand(
-                b"refs/heads/foo", b"mrkr",
-                (b"Jelmer", b"jelmer@samba.org", 432432432.0, 3600),
-                (b"Jelmer", b"jelmer@samba.org", 432432432.0, 3600),
-                b"FOO", None, [], file_cmds)
+            b"refs/heads/foo",
+            b"mrkr",
+            (b"Jelmer", b"jelmer@samba.org", 432432432.0, 3600),
+            (b"Jelmer", b"jelmer@samba.org", 432432432.0, 3600),
+            b"FOO",
+            None,
+            [],
+            file_cmds,
+        )
         self.processor.commit_handler(cmd)
         return self.repo[self.processor.last_commit]
 
     def test_file_copy(self):
         from fastimport import commands
+
         self.simple_commit()
-        commit = self.make_file_commit(
-                [commands.FileCopyCommand(b"path", b"new_path")])
-        self.assertEqual([
-                (b'new_path', 0o100644,
-                 b'6320cd248dd8aeaab759d5871f8781b5c0505172'),
-                (b'path', 0o100644,
-                 b'6320cd248dd8aeaab759d5871f8781b5c0505172'),
-                ], self.repo[commit.tree].items())
+        commit = self.make_file_commit([commands.FileCopyCommand(b"path", b"new_path")])
+        self.assertEqual(
+            [
+                (
+                    b"new_path",
+                    0o100644,
+                    b"6320cd248dd8aeaab759d5871f8781b5c0505172",
+                ),
+                (
+                    b"path",
+                    0o100644,
+                    b"6320cd248dd8aeaab759d5871f8781b5c0505172",
+                ),
+            ],
+            self.repo[commit.tree].items(),
+        )
 
     def test_file_move(self):
         from fastimport import commands
+
         self.simple_commit()
         commit = self.make_file_commit(
-                [commands.FileRenameCommand(b"path", b"new_path")])
-        self.assertEqual([
-                (b'new_path', 0o100644,
-                 b'6320cd248dd8aeaab759d5871f8781b5c0505172'),
-                ], self.repo[commit.tree].items())
+            [commands.FileRenameCommand(b"path", b"new_path")]
+        )
+        self.assertEqual(
+            [
+                (
+                    b"new_path",
+                    0o100644,
+                    b"6320cd248dd8aeaab759d5871f8781b5c0505172",
+                ),
+            ],
+            self.repo[commit.tree].items(),
+        )
 
     def test_file_delete(self):
         from fastimport import commands
+
         self.simple_commit()
         commit = self.make_file_commit([commands.FileDeleteCommand(b"path")])
         self.assertEqual([], self.repo[commit.tree].items())
 
     def test_file_deleteall(self):
         from fastimport import commands
+
         self.simple_commit()
         commit = self.make_file_commit([commands.FileDeleteAllCommand()])
         self.assertEqual([], self.repo[commit.tree].items())

+ 62 - 64
dulwich/tests/test_file.py

@@ -28,17 +28,16 @@ from dulwich.file import FileLocked, GitFile, _fancy_rename
 from dulwich.tests import (
     SkipTest,
     TestCase,
-    )
+)
 
 
 class FancyRenameTests(TestCase):
-
     def setUp(self):
         super(FancyRenameTests, self).setUp()
         self._tempdir = tempfile.mkdtemp()
-        self.foo = self.path('foo')
-        self.bar = self.path('bar')
-        self.create(self.foo, b'foo contents')
+        self.foo = self.path("foo")
+        self.bar = self.path("bar")
+        self.create(self.foo, b"foo contents")
 
     def tearDown(self):
         shutil.rmtree(self._tempdir)
@@ -48,7 +47,7 @@ class FancyRenameTests(TestCase):
         return os.path.join(self._tempdir, filename)
 
     def create(self, path, contents):
-        f = open(path, 'wb')
+        f = open(path, "wb")
         f.write(contents)
         f.close()
 
@@ -57,44 +56,43 @@ class FancyRenameTests(TestCase):
         _fancy_rename(self.foo, self.bar)
         self.assertFalse(os.path.exists(self.foo))
 
-        new_f = open(self.bar, 'rb')
-        self.assertEqual(b'foo contents', new_f.read())
+        new_f = open(self.bar, "rb")
+        self.assertEqual(b"foo contents", new_f.read())
         new_f.close()
 
     def test_dest_exists(self):
-        self.create(self.bar, b'bar contents')
+        self.create(self.bar, b"bar contents")
         _fancy_rename(self.foo, self.bar)
         self.assertFalse(os.path.exists(self.foo))
 
-        new_f = open(self.bar, 'rb')
-        self.assertEqual(b'foo contents', new_f.read())
+        new_f = open(self.bar, "rb")
+        self.assertEqual(b"foo contents", new_f.read())
         new_f.close()
 
     def test_dest_opened(self):
         if sys.platform != "win32":
             raise SkipTest("platform allows overwriting open files")
-        self.create(self.bar, b'bar contents')
-        dest_f = open(self.bar, 'rb')
+        self.create(self.bar, b"bar contents")
+        dest_f = open(self.bar, "rb")
         self.assertRaises(OSError, _fancy_rename, self.foo, self.bar)
         dest_f.close()
-        self.assertTrue(os.path.exists(self.path('foo')))
+        self.assertTrue(os.path.exists(self.path("foo")))
 
-        new_f = open(self.foo, 'rb')
-        self.assertEqual(b'foo contents', new_f.read())
+        new_f = open(self.foo, "rb")
+        self.assertEqual(b"foo contents", new_f.read())
         new_f.close()
 
-        new_f = open(self.bar, 'rb')
-        self.assertEqual(b'bar contents', new_f.read())
+        new_f = open(self.bar, "rb")
+        self.assertEqual(b"bar contents", new_f.read())
         new_f.close()
 
 
 class GitFileTests(TestCase):
-
     def setUp(self):
         super(GitFileTests, self).setUp()
         self._tempdir = tempfile.mkdtemp()
-        f = open(self.path('foo'), 'wb')
-        f.write(b'foo contents')
+        f = open(self.path("foo"), "wb")
+        f.write(b"foo contents")
         f.close()
 
     def tearDown(self):
@@ -105,98 +103,98 @@ class GitFileTests(TestCase):
         return os.path.join(self._tempdir, filename)
 
     def test_invalid(self):
-        foo = self.path('foo')
-        self.assertRaises(IOError, GitFile, foo, mode='r')
-        self.assertRaises(IOError, GitFile, foo, mode='ab')
-        self.assertRaises(IOError, GitFile, foo, mode='r+b')
-        self.assertRaises(IOError, GitFile, foo, mode='w+b')
-        self.assertRaises(IOError, GitFile, foo, mode='a+bU')
+        foo = self.path("foo")
+        self.assertRaises(IOError, GitFile, foo, mode="r")
+        self.assertRaises(IOError, GitFile, foo, mode="ab")
+        self.assertRaises(IOError, GitFile, foo, mode="r+b")
+        self.assertRaises(IOError, GitFile, foo, mode="w+b")
+        self.assertRaises(IOError, GitFile, foo, mode="a+bU")
 
     def test_readonly(self):
-        f = GitFile(self.path('foo'), 'rb')
+        f = GitFile(self.path("foo"), "rb")
         self.assertTrue(isinstance(f, io.IOBase))
-        self.assertEqual(b'foo contents', f.read())
-        self.assertEqual(b'', f.read())
+        self.assertEqual(b"foo contents", f.read())
+        self.assertEqual(b"", f.read())
         f.seek(4)
-        self.assertEqual(b'contents', f.read())
+        self.assertEqual(b"contents", f.read())
         f.close()
 
     def test_default_mode(self):
-        f = GitFile(self.path('foo'))
-        self.assertEqual(b'foo contents', f.read())
+        f = GitFile(self.path("foo"))
+        self.assertEqual(b"foo contents", f.read())
         f.close()
 
     def test_write(self):
-        foo = self.path('foo')
-        foo_lock = '%s.lock' % foo
+        foo = self.path("foo")
+        foo_lock = "%s.lock" % foo
 
-        orig_f = open(foo, 'rb')
-        self.assertEqual(orig_f.read(), b'foo contents')
+        orig_f = open(foo, "rb")
+        self.assertEqual(orig_f.read(), b"foo contents")
         orig_f.close()
 
         self.assertFalse(os.path.exists(foo_lock))
-        f = GitFile(foo, 'wb')
+        f = GitFile(foo, "wb")
         self.assertFalse(f.closed)
-        self.assertRaises(AttributeError, getattr, f, 'not_a_file_property')
+        self.assertRaises(AttributeError, getattr, f, "not_a_file_property")
 
         self.assertTrue(os.path.exists(foo_lock))
-        f.write(b'new stuff')
+        f.write(b"new stuff")
         f.seek(4)
-        f.write(b'contents')
+        f.write(b"contents")
         f.close()
         self.assertFalse(os.path.exists(foo_lock))
 
-        new_f = open(foo, 'rb')
-        self.assertEqual(b'new contents', new_f.read())
+        new_f = open(foo, "rb")
+        self.assertEqual(b"new contents", new_f.read())
         new_f.close()
 
     def test_open_twice(self):
-        foo = self.path('foo')
-        f1 = GitFile(foo, 'wb')
-        f1.write(b'new')
+        foo = self.path("foo")
+        f1 = GitFile(foo, "wb")
+        f1.write(b"new")
         try:
-            f2 = GitFile(foo, 'wb')
+            f2 = GitFile(foo, "wb")
             self.fail()
         except FileLocked:
             pass
         else:
             f2.close()
-        f1.write(b' contents')
+        f1.write(b" contents")
         f1.close()
 
         # Ensure trying to open twice doesn't affect original.
-        f = open(foo, 'rb')
-        self.assertEqual(b'new contents', f.read())
+        f = open(foo, "rb")
+        self.assertEqual(b"new contents", f.read())
         f.close()
 
     def test_abort(self):
-        foo = self.path('foo')
-        foo_lock = '%s.lock' % foo
+        foo = self.path("foo")
+        foo_lock = "%s.lock" % foo
 
-        orig_f = open(foo, 'rb')
-        self.assertEqual(orig_f.read(), b'foo contents')
+        orig_f = open(foo, "rb")
+        self.assertEqual(orig_f.read(), b"foo contents")
         orig_f.close()
 
-        f = GitFile(foo, 'wb')
-        f.write(b'new contents')
+        f = GitFile(foo, "wb")
+        f.write(b"new contents")
         f.abort()
         self.assertTrue(f.closed)
         self.assertFalse(os.path.exists(foo_lock))
 
-        new_orig_f = open(foo, 'rb')
-        self.assertEqual(new_orig_f.read(), b'foo contents')
+        new_orig_f = open(foo, "rb")
+        self.assertEqual(new_orig_f.read(), b"foo contents")
         new_orig_f.close()
 
     def test_abort_close(self):
-        foo = self.path('foo')
-        f = GitFile(foo, 'wb')
+        foo = self.path("foo")
+        f = GitFile(foo, "wb")
         f.abort()
         try:
             f.close()
         except (IOError, OSError):
             self.fail()
 
-        f = GitFile(foo, 'wb')
+        f = GitFile(foo, "wb")
         f.close()
         try:
             f.abort()
@@ -204,11 +202,11 @@ class GitFileTests(TestCase):
             self.fail()
 
     def test_abort_close_removed(self):
-        foo = self.path('foo')
-        f = GitFile(foo, 'wb')
+        foo = self.path("foo")
+        f = GitFile(foo, "wb")
 
         f._file.close()
-        os.remove(foo+".lock")
+        os.remove(foo + ".lock")
 
         f.abort()
         self.assertTrue(f._closed)

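Aside: every assertion in the GitFile tests above falls out of one lock-file protocol: writes go to path + ".lock", close() publishes the lock file atomically over the target, and abort() discards it. A minimal sketch of that protocol, assuming only the behaviour the tests assert (this is not dulwich's actual GitFile class):

import os

class LockedWriter:
    """Minimal sketch of the lock-file protocol the GitFile tests pin
    down; not dulwich's actual GitFile implementation."""

    def __init__(self, path):
        self._path = path
        self._lockpath = path + ".lock"
        # O_EXCL makes a second concurrent open fail (the analogue of
        # dulwich raising FileLocked in test_open_twice).
        fd = os.open(self._lockpath, os.O_WRONLY | os.O_CREAT | os.O_EXCL)
        self._file = os.fdopen(fd, "wb")

    def write(self, data):
        self._file.write(data)

    def close(self):
        # Commit: atomically replace the target with the lock file, so
        # readers only ever see the old or the new contents in full.
        self._file.close()
        os.replace(self._lockpath, self._path)

    def abort(self):
        # Roll back: drop the lock file, leave the target untouched.
        self._file.close()
        os.unlink(self._lockpath)
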
+ 67 - 65
dulwich/tests/test_grafts.py

@@ -27,7 +27,7 @@ from dulwich.errors import ObjectFormatException
 from dulwich.tests import TestCase
 from dulwich.objects import (
     Tree,
-    )
+)
 from dulwich.repo import (
     parse_graftpoints,
     serialize_graftpoints,
@@ -37,11 +37,10 @@ from dulwich.repo import (
 
 
 def makesha(digit):
-    return (str(digit).encode('ascii') * 40)[:40]
+    return (str(digit).encode("ascii") * 40)[:40]
 
 
 class GraftParserTests(TestCase):
-
     def assertParse(self, expected, graftpoints):
         self.assertEqual(expected, parse_graftpoints(iter(graftpoints)))
 
@@ -52,49 +51,60 @@ class GraftParserTests(TestCase):
         self.assertParse({makesha(0): []}, [makesha(0)])
 
     def test_parents(self):
-        self.assertParse({makesha(0): [makesha(1), makesha(2)]},
-                         [b' '.join([makesha(0), makesha(1), makesha(2)])])
+        self.assertParse(
+            {makesha(0): [makesha(1), makesha(2)]},
+            [b" ".join([makesha(0), makesha(1), makesha(2)])],
+        )
 
     def test_multiple_hybrid(self):
         self.assertParse(
-            {makesha(0): [],
-             makesha(1): [makesha(2)],
-             makesha(3): [makesha(4), makesha(5)]},
-            [makesha(0),
-             b' '.join([makesha(1), makesha(2)]),
-             b' '.join([makesha(3), makesha(4), makesha(5)])])
+            {
+                makesha(0): [],
+                makesha(1): [makesha(2)],
+                makesha(3): [makesha(4), makesha(5)],
+            },
+            [
+                makesha(0),
+                b" ".join([makesha(1), makesha(2)]),
+                b" ".join([makesha(3), makesha(4), makesha(5)]),
+            ],
+        )
 
 
 class GraftSerializerTests(TestCase):
-
     def assertSerialize(self, expected, graftpoints):
-        self.assertEqual(
-            sorted(expected),
-            sorted(serialize_graftpoints(graftpoints)))
+        self.assertEqual(sorted(expected), sorted(serialize_graftpoints(graftpoints)))
 
     def test_no_grafts(self):
-        self.assertSerialize(b'', {})
+        self.assertSerialize(b"", {})
 
     def test_no_parents(self):
         self.assertSerialize(makesha(0), {makesha(0): []})
 
     def test_parents(self):
-        self.assertSerialize(b' '.join([makesha(0), makesha(1), makesha(2)]),
-                             {makesha(0): [makesha(1), makesha(2)]})
+        self.assertSerialize(
+            b" ".join([makesha(0), makesha(1), makesha(2)]),
+            {makesha(0): [makesha(1), makesha(2)]},
+        )
 
     def test_multiple_hybrid(self):
         self.assertSerialize(
-            b'\n'.join([
-                makesha(0),
-                b' '.join([makesha(1), makesha(2)]),
-                b' '.join([makesha(3), makesha(4), makesha(5)])]),
-            {makesha(0): [],
-             makesha(1): [makesha(2)],
-             makesha(3): [makesha(4), makesha(5)]})
+            b"\n".join(
+                [
+                    makesha(0),
+                    b" ".join([makesha(1), makesha(2)]),
+                    b" ".join([makesha(3), makesha(4), makesha(5)]),
+                ]
+            ),
+            {
+                makesha(0): [],
+                makesha(1): [makesha(2)],
+                makesha(3): [makesha(4), makesha(5)],
+            },
+        )
 
 
 class GraftsInRepositoryBase(object):
-
     def tearDown(self):
         super(GraftsInRepositoryBase, self).tearDown()
 
@@ -112,33 +122,31 @@ class GraftsInRepositoryBase(object):
     def test_no_parents_graft(self):
         r = self.get_repo_with_grafts({self._repo.head(): []})
 
-        self.assertEqual([e.commit.id for e in r.get_walker()],
-                         [r.head()])
+        self.assertEqual([e.commit.id for e in r.get_walker()], [r.head()])
 
     def test_existing_parent_graft(self):
         r = self.get_repo_with_grafts({self._shas[-1]: [self._shas[0]]})
 
-        self.assertEqual([e.commit.id for e in r.get_walker()],
-                         [self._shas[-1], self._shas[0]])
+        self.assertEqual(
+            [e.commit.id for e in r.get_walker()],
+            [self._shas[-1], self._shas[0]],
+        )
 
     def test_remove_graft(self):
         r = self.get_repo_with_grafts({self._repo.head(): []})
         r._remove_graftpoints([self._repo.head()])
 
-        self.assertEqual([e.commit.id for e in r.get_walker()],
-                         self._shas[::-1])
+        self.assertEqual([e.commit.id for e in r.get_walker()], self._shas[::-1])
 
     def test_object_store_fail_invalid_parents(self):
         r = self._repo
 
         self.assertRaises(
-            ObjectFormatException,
-            r._add_graftpoints,
-            {self._shas[-1]: ['1']})
+            ObjectFormatException, r._add_graftpoints, {self._shas[-1]: ["1"]}
+        )
 
 
 class GraftsInRepoTests(GraftsInRepositoryBase, TestCase):
-
     def setUp(self):
         super(GraftsInRepoTests, self).setUp()
         self._repo_dir = os.path.join(tempfile.mkdtemp())
@@ -148,24 +156,21 @@ class GraftsInRepoTests(GraftsInRepositoryBase, TestCase):
         self._shas = []
 
         commit_kwargs = {
-            'committer': b'Test Committer <test@nodomain.com>',
-            'author': b'Test Author <test@nodomain.com>',
-            'commit_timestamp': 12395,
-            'commit_timezone': 0,
-            'author_timestamp': 12395,
-            'author_timezone': 0,
+            "committer": b"Test Committer <test@nodomain.com>",
+            "author": b"Test Author <test@nodomain.com>",
+            "commit_timestamp": 12395,
+            "commit_timezone": 0,
+            "author_timestamp": 12395,
+            "author_timezone": 0,
         }
 
-        self._shas.append(r.do_commit(
-            b'empty commit', **commit_kwargs))
-        self._shas.append(r.do_commit(
-            b'empty commit', **commit_kwargs))
-        self._shas.append(r.do_commit(
-            b'empty commit', **commit_kwargs))
+        self._shas.append(r.do_commit(b"empty commit", **commit_kwargs))
+        self._shas.append(r.do_commit(b"empty commit", **commit_kwargs))
+        self._shas.append(r.do_commit(b"empty commit", **commit_kwargs))
 
     def test_init_with_empty_info_grafts(self):
         r = self._repo
-        r._put_named_file(os.path.join('info', 'grafts'), b'')
+        r._put_named_file(os.path.join("info", "grafts"), b"")
 
         r = Repo(self._repo_dir)
         self.assertEqual({}, r._graftpoints)
@@ -173,15 +178,15 @@ class GraftsInRepoTests(GraftsInRepositoryBase, TestCase):
     def test_init_with_info_grafts(self):
         r = self._repo
         r._put_named_file(
-            os.path.join('info', 'grafts'),
-            self._shas[-1] + b' ' + self._shas[0])
+            os.path.join("info", "grafts"),
+            self._shas[-1] + b" " + self._shas[0],
+        )
 
         r = Repo(self._repo_dir)
         self.assertEqual({self._shas[-1]: [self._shas[0]]}, r._graftpoints)
 
 
 class GraftsInMemoryRepoTests(GraftsInRepositoryBase, TestCase):
-
     def setUp(self):
         super(GraftsInMemoryRepoTests, self).setUp()
         r = self._repo = MemoryRepo()
@@ -191,18 +196,15 @@ class GraftsInMemoryRepoTests(GraftsInRepositoryBase, TestCase):
         tree = Tree()
 
         commit_kwargs = {
-            'committer': b'Test Committer <test@nodomain.com>',
-            'author': b'Test Author <test@nodomain.com>',
-            'commit_timestamp': 12395,
-            'commit_timezone': 0,
-            'author_timestamp': 12395,
-            'author_timezone': 0,
-            'tree': tree.id
+            "committer": b"Test Committer <test@nodomain.com>",
+            "author": b"Test Author <test@nodomain.com>",
+            "commit_timestamp": 12395,
+            "commit_timezone": 0,
+            "author_timestamp": 12395,
+            "author_timezone": 0,
+            "tree": tree.id,
         }
 
-        self._shas.append(r.do_commit(
-            b'empty commit', **commit_kwargs))
-        self._shas.append(r.do_commit(
-            b'empty commit', **commit_kwargs))
-        self._shas.append(r.do_commit(
-            b'empty commit', **commit_kwargs))
+        self._shas.append(r.do_commit(b"empty commit", **commit_kwargs))
+        self._shas.append(r.do_commit(b"empty commit", **commit_kwargs))
+        self._shas.append(r.do_commit(b"empty commit", **commit_kwargs))

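For orientation, the graft-file format these parser and serializer tests exercise is line-based: each info/grafts line holds a 40-hex commit id, optionally followed by space-separated replacement parent ids. A simplified sketch of a reader for that format (the real parse_graftpoints lives in dulwich.repo):

def parse_graftpoints_sketch(lines):
    # Each info/grafts line: <commit sha> [<parent sha> ...]
    grafts = {}
    for line in lines:
        fields = line.strip().split(b" ")
        grafts[fields[0]] = [f for f in fields[1:] if f]
    return grafts

# A commit grafted onto a single replacement parent:
assert parse_graftpoints_sketch([b"1" * 40 + b" " + b"2" * 40]) == {
    b"1" * 40: [b"2" * 40]
}
# A parentless graft, as in test_no_parents:
assert parse_graftpoints_sketch([b"0" * 40]) == {b"0" * 40: []}
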
+ 76 - 77
dulwich/tests/test_graph.py

@@ -29,11 +29,11 @@ from dulwich.graph import _find_lcas, can_fast_forward
 
 
 class FindMergeBaseTests(TestCase):
-
     @staticmethod
     def run_test(dag, inputs):
         def lookup_parents(commit_id):
             return dag[commit_id]
+
         c1 = inputs[0]
         c2s = inputs[1:]
         return set(_find_lcas(lookup_parents, c1, c2s))
@@ -41,125 +41,125 @@ class FindMergeBaseTests(TestCase):
     def test_multiple_lca(self):
         # two lowest common ancestors
         graph = {
-            '5': ['1', '2'],
-            '4': ['3', '1'],
-            '3': ['2'],
-            '2': ['0'],
-            '1': [],
-            '0': []
+            "5": ["1", "2"],
+            "4": ["3", "1"],
+            "3": ["2"],
+            "2": ["0"],
+            "1": [],
+            "0": [],
         }
-        self.assertEqual(self.run_test(graph, ['4', '5']), set(['1', '2']))
+        self.assertEqual(self.run_test(graph, ["4", "5"]), set(["1", "2"]))
 
     def test_no_common_ancestor(self):
         # no common ancestor
         graph = {
-            '4': ['2'],
-            '3': ['1'],
-            '2': [],
-            '1': ['0'],
-            '0': [],
+            "4": ["2"],
+            "3": ["1"],
+            "2": [],
+            "1": ["0"],
+            "0": [],
         }
-        self.assertEqual(self.run_test(graph, ['4', '3']), set([]))
+        self.assertEqual(self.run_test(graph, ["4", "3"]), set([]))
 
     def test_ancestor(self):
         # ancestor
         graph = {
-            'G': ['D', 'F'],
-            'F': ['E'],
-            'D': ['C'],
-            'C': ['B'],
-            'E': ['B'],
-            'B': ['A'],
-            'A': []
+            "G": ["D", "F"],
+            "F": ["E"],
+            "D": ["C"],
+            "C": ["B"],
+            "E": ["B"],
+            "B": ["A"],
+            "A": [],
         }
-        self.assertEqual(self.run_test(graph, ['D', 'C']), set(['C']))
+        self.assertEqual(self.run_test(graph, ["D", "C"]), set(["C"]))
 
     def test_direct_parent(self):
         # parent
         graph = {
-            'G': ['D', 'F'],
-            'F': ['E'],
-            'D': ['C'],
-            'C': ['B'],
-            'E': ['B'],
-            'B': ['A'],
-            'A': []
+            "G": ["D", "F"],
+            "F": ["E"],
+            "D": ["C"],
+            "C": ["B"],
+            "E": ["B"],
+            "B": ["A"],
+            "A": [],
         }
-        self.assertEqual(self.run_test(graph, ['G', 'D']), set(['D']))
+        self.assertEqual(self.run_test(graph, ["G", "D"]), set(["D"]))
 
     def test_another_crossover(self):
         # Another cross over
         graph = {
-            'G': ['D', 'F'],
-            'F': ['E', 'C'],
-            'D': ['C', 'E'],
-            'C': ['B'],
-            'E': ['B'],
-            'B': ['A'],
-            'A': []
+            "G": ["D", "F"],
+            "F": ["E", "C"],
+            "D": ["C", "E"],
+            "C": ["B"],
+            "E": ["B"],
+            "B": ["A"],
+            "A": [],
         }
-        self.assertEqual(self.run_test(graph, ['D', 'F']), set(['E', 'C']))
+        self.assertEqual(self.run_test(graph, ["D", "F"]), set(["E", "C"]))
 
     def test_three_way_merge_lca(self):
         # three way merge commit straight from git docs
         graph = {
-            'C': ['C1'],
-            'C1': ['C2'],
-            'C2': ['C3'],
-            'C3': ['C4'],
-            'C4': ['2'],
-            'B': ['B1'],
-            'B1': ['B2'],
-            'B2': ['B3'],
-            'B3': ['1'],
-            'A': ['A1'],
-            'A1': ['A2'],
-            'A2': ['A3'],
-            'A3': ['1'],
-            '1': ['2'],
-            '2': [],
+            "C": ["C1"],
+            "C1": ["C2"],
+            "C2": ["C3"],
+            "C3": ["C4"],
+            "C4": ["2"],
+            "B": ["B1"],
+            "B1": ["B2"],
+            "B2": ["B3"],
+            "B3": ["1"],
+            "A": ["A1"],
+            "A1": ["A2"],
+            "A2": ["A3"],
+            "A3": ["1"],
+            "1": ["2"],
+            "2": [],
         }
         # assumes a theoretical merge M exists that merges B and C first
         # which in practice means finding the first LCA of A with either B or C
-        self.assertEqual(self.run_test(graph, ['A', 'B', 'C']), set(['1']))
+        self.assertEqual(self.run_test(graph, ["A", "B", "C"]), set(["1"]))
 
     def test_octopus(self):
         # octopus algorithm test
         # test straight from git docs of A, B, and C
         # but this time use octopus to find lcas of A, B, and C simultaneously
         graph = {
-            'C': ['C1'],
-            'C1': ['C2'],
-            'C2': ['C3'],
-            'C3': ['C4'],
-            'C4': ['2'],
-            'B': ['B1'],
-            'B1': ['B2'],
-            'B2': ['B3'],
-            'B3': ['1'],
-            'A': ['A1'],
-            'A1': ['A2'],
-            'A2': ['A3'],
-            'A3': ['1'],
-            '1': ['2'],
-            '2': [],
+            "C": ["C1"],
+            "C1": ["C2"],
+            "C2": ["C3"],
+            "C3": ["C4"],
+            "C4": ["2"],
+            "B": ["B1"],
+            "B1": ["B2"],
+            "B2": ["B3"],
+            "B3": ["1"],
+            "A": ["A1"],
+            "A1": ["A2"],
+            "A2": ["A3"],
+            "A3": ["1"],
+            "1": ["2"],
+            "2": [],
         }
 
         def lookup_parents(cid):
             return graph[cid]
-        lcas = ['A']
-        others = ['B', 'C']
+
+        lcas = ["A"]
+        others = ["B", "C"]
         for cmt in others:
             next_lcas = []
             for ca in lcas:
                 res = _find_lcas(lookup_parents, cmt, [ca])
                 next_lcas.extend(res)
             lcas = next_lcas[:]
-        self.assertEqual(set(lcas), set(['2']))
+        self.assertEqual(set(lcas), set(["2"]))
 
 
 class CanFastForwardTests(TestCase):
-
     def test_ff(self):
         r = MemoryRepo()
         base = make_commit()
@@ -175,10 +175,9 @@ class CanFastForwardTests(TestCase):
         r = MemoryRepo()
         base = make_commit()
         c1 = make_commit(parents=[base.id])
-        c2a = make_commit(parents=[c1.id], message=b'2a')
-        c2b = make_commit(parents=[c1.id], message=b'2b')
-        r.object_store.add_objects(
-            [(base, None), (c1, None), (c2a, None), (c2b, None)])
+        c2a = make_commit(parents=[c1.id], message=b"2a")
+        c2b = make_commit(parents=[c1.id], message=b"2b")
+        r.object_store.add_objects([(base, None), (c1, None), (c2a, None), (c2b, None)])
         self.assertTrue(can_fast_forward(r, c1.id, c2a.id))
         self.assertTrue(can_fast_forward(r, c1.id, c2b.id))
         self.assertFalse(can_fast_forward(r, c2a.id, c2b.id))

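The merge-base fixtures above are small enough to verify against the brute-force definition of a lowest common ancestor: a common ancestor that is not a proper ancestor of any other common ancestor. A naive sketch under that definition, quadratic where dulwich's _find_lcas is not, but in agreement on these toy graphs:

def ancestors(lookup_parents, head):
    # Every commit reachable from head, including head itself.
    seen, todo = set(), [head]
    while todo:
        commit = todo.pop()
        if commit not in seen:
            seen.add(commit)
            todo.extend(lookup_parents(commit))
    return seen

def naive_lcas(lookup_parents, c1, c2):
    common = ancestors(lookup_parents, c1) & ancestors(lookup_parents, c2)
    # Keep only commits that no other common ancestor reaches.
    return {
        c for c in common
        if not any(c in ancestors(lookup_parents, d) - {d} for d in common)
    }

graph = {"5": ["1", "2"], "4": ["3", "1"], "3": ["2"], "2": ["0"], "1": [], "0": []}
assert naive_lcas(graph.__getitem__, "4", "5") == {"1", "2"}  # test_multiple_lca
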
+ 18 - 15
dulwich/tests/test_greenthreads.py

@@ -25,20 +25,21 @@ import time
 from dulwich.tests import (
     skipIf,
     TestCase,
-    )
+)
 from dulwich.object_store import (
     MemoryObjectStore,
     MissingObjectFinder,
-    )
+)
 from dulwich.objects import (
     Commit,
     Blob,
     Tree,
     parse_timezone,
-    )
+)
 
 try:
     import gevent  # noqa: F401
+
     gevent_support = True
 except ImportError:
     gevent_support = False
@@ -53,14 +54,14 @@ skipmsg = "Gevent library is not installed"
 
 
 def create_commit(marker=None):
-    blob = Blob.from_string(b'The blob content ' + marker)
+    blob = Blob.from_string(b"The blob content " + marker)
     tree = Tree()
     tree.add(b"thefile " + marker, 0o100644, blob.id)
     cmt = Commit()
     cmt.tree = tree.id
     cmt.author = cmt.committer = b"John Doe <john@doe.net>"
     cmt.message = marker
-    tz = parse_timezone(b'-0200')[0]
+    tz = parse_timezone(b"-0200")[0]
     cmt.commit_time = cmt.author_time = int(time.time())
     cmt.commit_timezone = cmt.author_timezone = tz
     return cmt, tree, blob
@@ -69,7 +70,7 @@ def create_commit(marker=None):
 def init_store(store, count=1):
     ret = []
     for i in range(0, count):
-        objs = create_commit(marker=("%d" % i).encode('ascii'))
+        objs = create_commit(marker=("%d" % i).encode("ascii"))
         for obj in objs:
             ret.append(obj)
             store.add_object(obj)
@@ -78,7 +79,6 @@ def init_store(store, count=1):
 
 @skipIf(not gevent_support, skipmsg)
 class TestGreenThreadsObjectStoreIterator(TestCase):
-
     def setUp(self):
         super(TestGreenThreadsObjectStoreIterator, self).setUp()
         self.store = MemoryObjectStore()
@@ -89,20 +89,23 @@ class TestGreenThreadsObjectStoreIterator(TestCase):
         wants = [sha.id for sha in self.objs if isinstance(sha, Commit)]
         finder = MissingObjectFinder(self.store, (), wants)
         iterator = GreenThreadsObjectStoreIterator(
-                self.store, iter(finder.next, None), finder)
+            self.store, iter(finder.next, None), finder
+        )
         # One commit refers to one tree and one blob
         self.assertEqual(len(iterator), self.cmt_amount * 3)
-        haves = wants[0:self.cmt_amount-1]
+        haves = wants[0 : self.cmt_amount - 1]
         finder = MissingObjectFinder(self.store, haves, wants)
         iterator = GreenThreadsObjectStoreIterator(
-            self.store, iter(finder.next, None), finder)
+            self.store, iter(finder.next, None), finder
+        )
         self.assertEqual(len(iterator), 3)
 
     def test_iter(self):
         wants = [sha.id for sha in self.objs if isinstance(sha, Commit)]
         finder = MissingObjectFinder(self.store, (), wants)
         iterator = GreenThreadsObjectStoreIterator(
-            self.store, iter(finder.next, None), finder)
+            self.store, iter(finder.next, None), finder
+        )
         objs = []
         for sha, path in iterator:
             self.assertIn(sha, self.objs)
@@ -112,7 +115,6 @@ class TestGreenThreadsObjectStoreIterator(TestCase):
 
 @skipIf(not gevent_support, skipmsg)
 class TestGreenThreadsMissingObjectFinder(TestCase):
-
     def setUp(self):
         super(TestGreenThreadsMissingObjectFinder, self).setUp()
         self.store = MemoryObjectStore()
@@ -126,7 +128,8 @@ class TestGreenThreadsMissingObjectFinder(TestCase):
         self.assertEqual(len(finder.objects_to_send), self.cmt_amount)
 
         finder = GreenThreadsMissingObjectFinder(
-            self.store, wants[0:int(self.cmt_amount/2)], wants)
+            self.store, wants[0 : int(self.cmt_amount / 2)], wants
+        )
         # sha_done will contain the commit id and the sha of the blob referred to in the tree
-        self.assertEqual(len(finder.sha_done), (self.cmt_amount/2)*2)
-        self.assertEqual(len(finder.objects_to_send), self.cmt_amount/2)
+        self.assertEqual(len(finder.sha_done), (self.cmt_amount / 2) * 2)
+        self.assertEqual(len(finder.objects_to_send), self.cmt_amount / 2)

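One idiom these iterator tests lean on is iter(finder.next, None): the two-argument form of iter() calls a callable repeatedly until it returns the sentinel, turning the finder into a plain iterator. The mechanics, shown on a list rather than a MissingObjectFinder:

items = ["a", "b", "c"]

def next_item():
    # Stand-in for finder.next: returns None when exhausted.
    return items.pop(0) if items else None

assert list(iter(next_item, None)) == ["a", "b", "c"]
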
+ 51 - 34
dulwich/tests/test_hooks.py

@@ -37,16 +37,15 @@ from dulwich.tests import TestCase
 
 
 class ShellHookTests(TestCase):
-
     def setUp(self):
         super(ShellHookTests, self).setUp()
-        if os.name != 'posix':
-            self.skipTest('shell hook tests requires POSIX shell')
-        self.assertTrue(os.path.exists('/bin/sh'))
+        if os.name != "posix":
+            self.skipTest("shell hook tests require a POSIX shell")
+        self.assertTrue(os.path.exists("/bin/sh"))
 
     def test_hook_pre_commit(self):
         repo_dir = os.path.join(tempfile.mkdtemp())
-        os.mkdir(os.path.join(repo_dir, 'hooks'))
+        os.mkdir(os.path.join(repo_dir, "hooks"))
         self.addCleanup(shutil.rmtree, repo_dir)
 
         pre_commit_fail = """#!/bin/sh
@@ -56,34 +55,40 @@ exit 1
         pre_commit_success = """#!/bin/sh
 exit 0
 """
-        pre_commit_cwd = """#!/bin/sh
-if [ "$(pwd)" != '""" + repo_dir + """' ]; then
-    echo "Expected path '""" + repo_dir + """', got '$(pwd)'"
+        pre_commit_cwd = (
+            """#!/bin/sh
+if [ "$(pwd)" != '"""
+            + repo_dir
+            + """' ]; then
+    echo "Expected path '"""
+            + repo_dir
+            + """', got '$(pwd)'"
     exit 1
 fi
 
 exit 0
 """
+        )
 
-        pre_commit = os.path.join(repo_dir, 'hooks', 'pre-commit')
+        pre_commit = os.path.join(repo_dir, "hooks", "pre-commit")
         hook = PreCommitShellHook(repo_dir)
 
-        with open(pre_commit, 'w') as f:
+        with open(pre_commit, "w") as f:
             f.write(pre_commit_fail)
         os.chmod(pre_commit, stat.S_IREAD | stat.S_IWRITE | stat.S_IEXEC)
 
         self.assertRaises(errors.HookError, hook.execute)
 
-        if sys.platform != 'darwin':
+        if sys.platform != "darwin":
             # Don't bother running this test on darwin since path
             # canonicalization messes with our simple string comparison.
-            with open(pre_commit, 'w') as f:
+            with open(pre_commit, "w") as f:
                 f.write(pre_commit_cwd)
             os.chmod(pre_commit, stat.S_IREAD | stat.S_IWRITE | stat.S_IEXEC)
 
             hook.execute()
 
-        with open(pre_commit, 'w') as f:
+        with open(pre_commit, "w") as f:
             f.write(pre_commit_success)
         os.chmod(pre_commit, stat.S_IREAD | stat.S_IWRITE | stat.S_IEXEC)
 
@@ -92,7 +97,7 @@ exit 0
     def test_hook_commit_msg(self):
 
         repo_dir = os.path.join(tempfile.mkdtemp())
-        os.mkdir(os.path.join(repo_dir, 'hooks'))
+        os.mkdir(os.path.join(repo_dir, "hooks"))
         self.addCleanup(shutil.rmtree, repo_dir)
 
         commit_msg_fail = """#!/bin/sh
@@ -103,32 +108,36 @@ exit 1
 exit 0
 """
 
-        commit_msg_cwd = """#!/bin/sh
-if [ "$(pwd)" = '""" + repo_dir + "' ]; then exit 0; else exit 1; fi\n"
+        commit_msg_cwd = (
+            """#!/bin/sh
+if [ "$(pwd)" = '"""
+            + repo_dir
+            + "' ]; then exit 0; else exit 1; fi\n"
+        )
 
-        commit_msg = os.path.join(repo_dir, 'hooks', 'commit-msg')
+        commit_msg = os.path.join(repo_dir, "hooks", "commit-msg")
         hook = CommitMsgShellHook(repo_dir)
 
-        with open(commit_msg, 'w') as f:
+        with open(commit_msg, "w") as f:
             f.write(commit_msg_fail)
         os.chmod(commit_msg, stat.S_IREAD | stat.S_IWRITE | stat.S_IEXEC)
 
-        self.assertRaises(errors.HookError, hook.execute, b'failed commit')
+        self.assertRaises(errors.HookError, hook.execute, b"failed commit")
 
-        if sys.platform != 'darwin':
+        if sys.platform != "darwin":
             # Don't bother running this test on darwin since path
             # canonicalization messes with our simple string comparison.
-            with open(commit_msg, 'w') as f:
+            with open(commit_msg, "w") as f:
                 f.write(commit_msg_cwd)
             os.chmod(commit_msg, stat.S_IREAD | stat.S_IWRITE | stat.S_IEXEC)
 
-            hook.execute(b'cwd test commit')
+            hook.execute(b"cwd test commit")
 
-        with open(commit_msg, 'w') as f:
+        with open(commit_msg, "w") as f:
             f.write(commit_msg_success)
         os.chmod(commit_msg, stat.S_IREAD | stat.S_IWRITE | stat.S_IEXEC)
 
-        hook.execute(b'empty commit')
+        hook.execute(b"empty commit")
 
     def test_hook_post_commit(self):
 
@@ -136,38 +145,46 @@ if [ "$(pwd)" = '""" + repo_dir + "' ]; then exit 0; else exit 1; fi\n"
         os.close(fd)
 
         repo_dir = os.path.join(tempfile.mkdtemp())
-        os.mkdir(os.path.join(repo_dir, 'hooks'))
+        os.mkdir(os.path.join(repo_dir, "hooks"))
         self.addCleanup(shutil.rmtree, repo_dir)
 
-        post_commit_success = """#!/bin/sh
-rm """ + path + "\n"
+        post_commit_success = (
+            """#!/bin/sh
+rm """
+            + path
+            + "\n"
+        )
 
         post_commit_fail = """#!/bin/sh
 exit 1
 """
 
-        post_commit_cwd = """#!/bin/sh
-if [ "$(pwd)" = '""" + repo_dir + "' ]; then exit 0; else exit 1; fi\n"
+        post_commit_cwd = (
+            """#!/bin/sh
+if [ "$(pwd)" = '"""
+            + repo_dir
+            + "' ]; then exit 0; else exit 1; fi\n"
+        )
 
-        post_commit = os.path.join(repo_dir, 'hooks', 'post-commit')
+        post_commit = os.path.join(repo_dir, "hooks", "post-commit")
         hook = PostCommitShellHook(repo_dir)
 
-        with open(post_commit, 'w') as f:
+        with open(post_commit, "w") as f:
             f.write(post_commit_fail)
         os.chmod(post_commit, stat.S_IREAD | stat.S_IWRITE | stat.S_IEXEC)
 
         self.assertRaises(errors.HookError, hook.execute)
 
-        if sys.platform != 'darwin':
+        if sys.platform != "darwin":
             # Don't bother running this test on darwin since path
             # canonicalization messes with our simple string comparison.
-            with open(post_commit, 'w') as f:
+            with open(post_commit, "w") as f:
                 f.write(post_commit_cwd)
             os.chmod(post_commit, stat.S_IREAD | stat.S_IWRITE | stat.S_IEXEC)
 
             hook.execute()
 
-        with open(post_commit, 'w') as f:
+        with open(post_commit, "w") as f:
             f.write(post_commit_success)
         os.chmod(post_commit, stat.S_IREAD | stat.S_IWRITE | stat.S_IEXEC)
 

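All three hook tests follow the same contract: the hook is an executable file under hooks/, it runs with the repository root as its working directory (exactly what the *_cwd scripts assert), and a nonzero exit status surfaces as a HookError. A hedged sketch of that contract, not dulwich's ShellHook classes:

import os
import subprocess

def run_hook_sketch(repo_dir, name, *args):
    path = os.path.join(repo_dir, "hooks", name)
    if not os.path.exists(path):
        return  # a missing hook is simply skipped
    # cwd=repo_dir is the behaviour the *_cwd hook scripts check.
    result = subprocess.run([path] + list(args), cwd=repo_dir)
    if result.returncode != 0:
        raise RuntimeError("hook %r exited with %d" % (name, result.returncode))
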
+ 126 - 115
dulwich/tests/test_ignore.py

@@ -35,7 +35,7 @@ from dulwich.ignore import (
     match_pattern,
     read_ignore_patterns,
     translate,
-    )
+)
 from dulwich.repo import Repo
 
 
@@ -65,44 +65,45 @@ NEGATIVE_MATCH_TESTS = [
     (b"foo/foo.c", b"/*.c"),
     (b"foo/bar/", b"/bar/"),
     (b"foo/bar/", b"foo/bar/*"),
-    (b"foo/bar", b"foo?bar")
+    (b"foo/bar", b"foo?bar"),
 ]
 
 
 TRANSLATE_TESTS = [
-    (b"*.c", b'(?ms)(.*/)?[^/]*\\.c/?\\Z'),
-    (b"foo.c", b'(?ms)(.*/)?foo\\.c/?\\Z'),
-    (b"/*.c", b'(?ms)[^/]*\\.c/?\\Z'),
-    (b"/foo.c", b'(?ms)foo\\.c/?\\Z'),
-    (b"foo.c", b'(?ms)(.*/)?foo\\.c/?\\Z'),
-    (b"foo.[ch]", b'(?ms)(.*/)?foo\\.[ch]/?\\Z'),
-    (b"bar/", b'(?ms)(.*/)?bar\\/\\Z'),
-    (b"foo/**", b'(?ms)foo(/.*)?/?\\Z'),
-    (b"foo/**/blie.c", b'(?ms)foo(/.*)?\\/blie\\.c/?\\Z'),
-    (b"**/bla.c", b'(?ms)(.*/)?bla\\.c/?\\Z'),
-    (b"foo/**/bar", b'(?ms)foo(/.*)?\\/bar/?\\Z'),
-    (b"foo/bar/*", b'(?ms)foo\\/bar\\/[^/]+/?\\Z'),
+    (b"*.c", b"(?ms)(.*/)?[^/]*\\.c/?\\Z"),
+    (b"foo.c", b"(?ms)(.*/)?foo\\.c/?\\Z"),
+    (b"/*.c", b"(?ms)[^/]*\\.c/?\\Z"),
+    (b"/foo.c", b"(?ms)foo\\.c/?\\Z"),
+    (b"foo.c", b"(?ms)(.*/)?foo\\.c/?\\Z"),
+    (b"foo.[ch]", b"(?ms)(.*/)?foo\\.[ch]/?\\Z"),
+    (b"bar/", b"(?ms)(.*/)?bar\\/\\Z"),
+    (b"foo/**", b"(?ms)foo(/.*)?/?\\Z"),
+    (b"foo/**/blie.c", b"(?ms)foo(/.*)?\\/blie\\.c/?\\Z"),
+    (b"**/bla.c", b"(?ms)(.*/)?bla\\.c/?\\Z"),
+    (b"foo/**/bar", b"(?ms)foo(/.*)?\\/bar/?\\Z"),
+    (b"foo/bar/*", b"(?ms)foo\\/bar\\/[^/]+/?\\Z"),
 ]
 
 
 class TranslateTests(TestCase):
-
     def test_translate(self):
         for (pattern, regex) in TRANSLATE_TESTS:
-            if re.escape(b'/') == b'/':
+            if re.escape(b"/") == b"/":
                 # Slash is no longer escaped in Python 3.7, so undo the escaping
                 # in the expected return value.
-                regex = regex.replace(b'\\/', b'/')
+                regex = regex.replace(b"\\/", b"/")
             self.assertEqual(
-                regex, translate(pattern),
-                "orig pattern: %r, regex: %r, expected: %r" %
-                (pattern, translate(pattern), regex))
+                regex,
+                translate(pattern),
+                "orig pattern: %r, regex: %r, expected: %r"
+                % (pattern, translate(pattern), regex),
+            )
 
 
 class ReadIgnorePatterns(TestCase):
-
     def test_read_file(self):
-        f = BytesIO(b"""
+        f = BytesIO(
+            b"""
 # a comment
 
 # and an empty line:
@@ -111,151 +112,161 @@ class ReadIgnorePatterns(TestCase):
 !negative
 with trailing whitespace 
 with escaped trailing whitespace\\ 
-""")  # noqa: W291
-        self.assertEqual(list(read_ignore_patterns(f)), [
-            b'\\#not a comment',
-            b'!negative',
-            b'with trailing whitespace',
-            b'with escaped trailing whitespace '
-        ])
+"""
+        )  # noqa: W291
+        self.assertEqual(
+            list(read_ignore_patterns(f)),
+            [
+                b"\\#not a comment",
+                b"!negative",
+                b"with trailing whitespace",
+                b"with escaped trailing whitespace ",
+            ],
+        )
 
 
 class MatchPatternTests(TestCase):
-
     def test_matches(self):
         for (path, pattern) in POSITIVE_MATCH_TESTS:
             self.assertTrue(
                 match_pattern(path, pattern),
-                "path: %r, pattern: %r" % (path, pattern))
+                "path: %r, pattern: %r" % (path, pattern),
+            )
 
     def test_no_matches(self):
         for (path, pattern) in NEGATIVE_MATCH_TESTS:
             self.assertFalse(
                 match_pattern(path, pattern),
-                "path: %r, pattern: %r" % (path, pattern))
+                "path: %r, pattern: %r" % (path, pattern),
+            )
 
 
 class IgnoreFilterTests(TestCase):
-
     def test_included(self):
-        filter = IgnoreFilter([b'a.c', b'b.c'])
-        self.assertTrue(filter.is_ignored(b'a.c'))
-        self.assertIs(None, filter.is_ignored(b'c.c'))
-        self.assertEqual(
-            [Pattern(b'a.c')],
-            list(filter.find_matching(b'a.c')))
-        self.assertEqual(
-            [],
-            list(filter.find_matching(b'c.c')))
+        filter = IgnoreFilter([b"a.c", b"b.c"])
+        self.assertTrue(filter.is_ignored(b"a.c"))
+        self.assertIs(None, filter.is_ignored(b"c.c"))
+        self.assertEqual([Pattern(b"a.c")], list(filter.find_matching(b"a.c")))
+        self.assertEqual([], list(filter.find_matching(b"c.c")))
 
     def test_included_ignorecase(self):
-        filter = IgnoreFilter([b'a.c', b'b.c'], ignorecase=False)
-        self.assertTrue(filter.is_ignored(b'a.c'))
-        self.assertFalse(filter.is_ignored(b'A.c'))
-        filter = IgnoreFilter([b'a.c', b'b.c'], ignorecase=True)
-        self.assertTrue(filter.is_ignored(b'a.c'))
-        self.assertTrue(filter.is_ignored(b'A.c'))
-        self.assertTrue(filter.is_ignored(b'A.C'))
+        filter = IgnoreFilter([b"a.c", b"b.c"], ignorecase=False)
+        self.assertTrue(filter.is_ignored(b"a.c"))
+        self.assertFalse(filter.is_ignored(b"A.c"))
+        filter = IgnoreFilter([b"a.c", b"b.c"], ignorecase=True)
+        self.assertTrue(filter.is_ignored(b"a.c"))
+        self.assertTrue(filter.is_ignored(b"A.c"))
+        self.assertTrue(filter.is_ignored(b"A.C"))
 
     def test_excluded(self):
-        filter = IgnoreFilter([b'a.c', b'b.c', b'!c.c'])
-        self.assertFalse(filter.is_ignored(b'c.c'))
-        self.assertIs(None, filter.is_ignored(b'd.c'))
-        self.assertEqual(
-            [Pattern(b'!c.c')],
-            list(filter.find_matching(b'c.c')))
-        self.assertEqual([], list(filter.find_matching(b'd.c')))
+        filter = IgnoreFilter([b"a.c", b"b.c", b"!c.c"])
+        self.assertFalse(filter.is_ignored(b"c.c"))
+        self.assertIs(None, filter.is_ignored(b"d.c"))
+        self.assertEqual([Pattern(b"!c.c")], list(filter.find_matching(b"c.c")))
+        self.assertEqual([], list(filter.find_matching(b"d.c")))
 
     def test_include_exclude_include(self):
-        filter = IgnoreFilter([b'a.c', b'!a.c', b'a.c'])
-        self.assertTrue(filter.is_ignored(b'a.c'))
+        filter = IgnoreFilter([b"a.c", b"!a.c", b"a.c"])
+        self.assertTrue(filter.is_ignored(b"a.c"))
         self.assertEqual(
-            [Pattern(b'a.c'), Pattern(b'!a.c'), Pattern(b'a.c')],
-            list(filter.find_matching(b'a.c')))
+            [Pattern(b"a.c"), Pattern(b"!a.c"), Pattern(b"a.c")],
+            list(filter.find_matching(b"a.c")),
+        )
 
     def test_manpage(self):
         # A specific example from the gitignore manpage
-        filter = IgnoreFilter([
-            b'/*',
-            b'!/foo',
-            b'/foo/*',
-            b'!/foo/bar'])
-        self.assertTrue(filter.is_ignored(b'a.c'))
-        self.assertTrue(filter.is_ignored(b'foo/blie'))
-        self.assertFalse(filter.is_ignored(b'foo'))
-        self.assertFalse(filter.is_ignored(b'foo/bar'))
-        self.assertFalse(filter.is_ignored(b'foo/bar/'))
-        self.assertFalse(filter.is_ignored(b'foo/bar/bloe'))
+        filter = IgnoreFilter([b"/*", b"!/foo", b"/foo/*", b"!/foo/bar"])
+        self.assertTrue(filter.is_ignored(b"a.c"))
+        self.assertTrue(filter.is_ignored(b"foo/blie"))
+        self.assertFalse(filter.is_ignored(b"foo"))
+        self.assertFalse(filter.is_ignored(b"foo/bar"))
+        self.assertFalse(filter.is_ignored(b"foo/bar/"))
+        self.assertFalse(filter.is_ignored(b"foo/bar/bloe"))
 
 
 class IgnoreFilterStackTests(TestCase):
-
     def test_stack_first(self):
-        filter1 = IgnoreFilter([b'[a].c', b'[b].c', b'![d].c'])
-        filter2 = IgnoreFilter([b'[a].c', b'![b],c', b'[c].c', b'[d].c'])
+        filter1 = IgnoreFilter([b"[a].c", b"[b].c", b"![d].c"])
+        filter2 = IgnoreFilter([b"[a].c", b"![b],c", b"[c].c", b"[d].c"])
         stack = IgnoreFilterStack([filter1, filter2])
-        self.assertIs(True, stack.is_ignored(b'a.c'))
-        self.assertIs(True, stack.is_ignored(b'b.c'))
-        self.assertIs(True, stack.is_ignored(b'c.c'))
-        self.assertIs(False, stack.is_ignored(b'd.c'))
-        self.assertIs(None, stack.is_ignored(b'e.c'))
+        self.assertIs(True, stack.is_ignored(b"a.c"))
+        self.assertIs(True, stack.is_ignored(b"b.c"))
+        self.assertIs(True, stack.is_ignored(b"c.c"))
+        self.assertIs(False, stack.is_ignored(b"d.c"))
+        self.assertIs(None, stack.is_ignored(b"e.c"))
 
 
 class IgnoreFilterManagerTests(TestCase):
-
     def test_load_ignore(self):
         tmp_dir = tempfile.mkdtemp()
         self.addCleanup(shutil.rmtree, tmp_dir)
         repo = Repo.init(tmp_dir)
+        with open(os.path.join(repo.path, ".gitignore"), "wb") as f:
+            f.write(b"/foo/bar\n")
+            f.write(b"/dir2\n")
+            f.write(b"/dir3/\n")
+        os.mkdir(os.path.join(repo.path, "dir"))
+        with open(os.path.join(repo.path, "dir", ".gitignore"), "wb") as f:
+            f.write(b"/blie\n")
+        with open(os.path.join(repo.path, "dir", "blie"), "wb") as f:
+            f.write(b"IGNORED")
+        p = os.path.join(repo.controldir(), "info", "exclude")
+        with open(p, "wb") as f:
+            f.write(b"/excluded\n")
+        m = IgnoreFilterManager.from_repo(repo)
+        self.assertTrue(m.is_ignored("dir/blie"))
+        self.assertIs(None, m.is_ignored(os.path.join("dir", "bloe")))
+        self.assertIs(None, m.is_ignored("dir"))
+        self.assertTrue(m.is_ignored(os.path.join("foo", "bar")))
+        self.assertTrue(m.is_ignored(os.path.join("excluded")))
+        self.assertTrue(m.is_ignored(os.path.join("dir2", "fileinignoreddir")))
+        self.assertFalse(m.is_ignored("dir3"))
+        self.assertTrue(m.is_ignored("dir3/"))
+        self.assertTrue(m.is_ignored("dir3/bla"))
+
+    def test_nested_gitignores(self):
+        tmp_dir = tempfile.mkdtemp()
+        self.addCleanup(shutil.rmtree, tmp_dir)
+        repo = Repo.init(tmp_dir)
+
         with open(os.path.join(repo.path, '.gitignore'), 'wb') as f:
-            f.write(b'/foo/bar\n')
-            f.write(b'/dir2\n')
-            f.write(b'/dir3/\n')
-        os.mkdir(os.path.join(repo.path, 'dir'))
-        with open(os.path.join(repo.path, 'dir', '.gitignore'), 'wb') as f:
-            f.write(b'/blie\n')
-        with open(os.path.join(repo.path, 'dir', 'blie'), 'wb') as f:
+            f.write(b'/*\n')
+            f.write(b'!/foo\n')
+
+        os.mkdir(os.path.join(repo.path, 'foo'))
+        with open(os.path.join(repo.path, 'foo', '.gitignore'), 'wb') as f:
+            f.write(b'/bar\n')
+
+        with open(os.path.join(repo.path, 'foo', 'bar'), 'wb') as f:
             f.write(b'IGNORED')
-        p = os.path.join(repo.controldir(), 'info', 'exclude')
-        with open(p, 'wb') as f:
-            f.write(b'/excluded\n')
+
         m = IgnoreFilterManager.from_repo(repo)
-        self.assertTrue(m.is_ignored('dir/blie'))
-        self.assertIs(None,
-                      m.is_ignored(os.path.join('dir', 'bloe')))
-        self.assertIs(None, m.is_ignored('dir'))
-        self.assertTrue(m.is_ignored(os.path.join('foo', 'bar')))
-        self.assertTrue(m.is_ignored(os.path.join('excluded')))
-        self.assertTrue(m.is_ignored(os.path.join(
-            'dir2', 'fileinignoreddir')))
-        self.assertFalse(m.is_ignored('dir3'))
-        self.assertTrue(m.is_ignored('dir3/'))
-        self.assertTrue(m.is_ignored('dir3/bla'))
+        self.assertTrue(m.is_ignored('foo/bar'))
 
     def test_load_ignore_ignorecase(self):
         tmp_dir = tempfile.mkdtemp()
         self.addCleanup(shutil.rmtree, tmp_dir)
         repo = Repo.init(tmp_dir)
         config = repo.get_config()
-        config.set(b'core', b'ignorecase', True)
+        config.set(b"core", b"ignorecase", True)
         config.write_to_path()
-        with open(os.path.join(repo.path, '.gitignore'), 'wb') as f:
-            f.write(b'/foo/bar\n')
-            f.write(b'/dir\n')
+        with open(os.path.join(repo.path, ".gitignore"), "wb") as f:
+            f.write(b"/foo/bar\n")
+            f.write(b"/dir\n")
         m = IgnoreFilterManager.from_repo(repo)
-        self.assertTrue(m.is_ignored(os.path.join('dir', 'blie')))
-        self.assertTrue(m.is_ignored(os.path.join('DIR', 'blie')))
+        self.assertTrue(m.is_ignored(os.path.join("dir", "blie")))
+        self.assertTrue(m.is_ignored(os.path.join("DIR", "blie")))
 
     def test_ignored_contents(self):
         tmp_dir = tempfile.mkdtemp()
         self.addCleanup(shutil.rmtree, tmp_dir)
         repo = Repo.init(tmp_dir)
-        with open(os.path.join(repo.path, '.gitignore'), 'wb') as f:
-            f.write(b'a/*\n')
-            f.write(b'!a/*.txt\n')
+        with open(os.path.join(repo.path, ".gitignore"), "wb") as f:
+            f.write(b"a/*\n")
+            f.write(b"!a/*.txt\n")
         m = IgnoreFilterManager.from_repo(repo)
-        os.mkdir(os.path.join(repo.path, 'a'))
-        self.assertIs(None, m.is_ignored('a'))
-        self.assertIs(None, m.is_ignored('a/'))
-        self.assertFalse(m.is_ignored('a/b.txt'))
-        self.assertTrue(m.is_ignored('a/c.dat'))
+        os.mkdir(os.path.join(repo.path, "a"))
+        self.assertIs(None, m.is_ignored("a"))
+        self.assertIs(None, m.is_ignored("a/"))
+        self.assertFalse(m.is_ignored("a/b.txt"))
+        self.assertTrue(m.is_ignored("a/c.dat"))

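The TRANSLATE_TESTS table above can be sanity-checked directly: each expected regex is anchored with \Z and ends in /?, so the same pattern matches both a file and the directory form of the path. Checking the first pair with the stdlib (on Python 3.7+, where re.escape leaves the slash alone):

import re

# "*.c" translates to a regex matching a .c entry at any depth.
regex = re.compile(b"(?ms)(.*/)?[^/]*\\.c/?\\Z")

assert regex.match(b"foo.c")
assert regex.match(b"foo/bar.c")
assert regex.match(b"foo/bar.c/")  # directory form also matches
assert not regex.match(b"foo.h")
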
+ 296 - 221
dulwich/tests/test_index.py

@@ -48,38 +48,39 @@ from dulwich.index import (
     write_index_dict,
     _tree_to_fs_path,
     _fs_to_tree_path,
-    )
+    IndexEntry,
+)
 from dulwich.object_store import (
     MemoryObjectStore,
-    )
+)
 from dulwich.objects import (
     Blob,
     Commit,
     Tree,
     S_IFGITLINK,
-    )
+)
 from dulwich.repo import Repo
 from dulwich.tests import (
     TestCase,
     skipIf,
-    )
+)
 from dulwich.tests.utils import (
     setup_warning_catcher,
-    )
+)
 
 
 def can_symlink():
     """Return whether running process can create symlinks."""
-    if sys.platform != 'win32':
+    if sys.platform != "win32":
         # Platforms other than Windows should allow symlinks without issues.
         return True
 
-    if not hasattr(os, 'symlink'):
+    if not hasattr(os, "symlink"):
         # Older Python versions do not have `os.symlink` on Windows.
         return False
 
     test_source = tempfile.mkdtemp()
-    test_target = test_source + 'can_symlink'
+    test_target = test_source + "can_symlink"
     try:
         os.symlink(test_source, test_target)
     except (NotImplementedError, OSError):
@@ -89,24 +90,24 @@ def can_symlink():
 
 class IndexTestCase(TestCase):
 
-    datadir = os.path.join(os.path.dirname(__file__), 'data/indexes')
+    datadir = os.path.join(os.path.dirname(__file__), "data/indexes")
 
     def get_simple_index(self, name):
         return Index(os.path.join(self.datadir, name))
 
 
 class SimpleIndexTestCase(IndexTestCase):
-
     def test_len(self):
         self.assertEqual(1, len(self.get_simple_index("index")))
 
     def test_iter(self):
-        self.assertEqual([b'bla'], list(self.get_simple_index("index")))
+        self.assertEqual([b"bla"], list(self.get_simple_index("index")))
 
     def test_iterobjects(self):
         self.assertEqual(
-                [(b'bla', b'e69de29bb2d1d6434b8b29ae775ad8c2e48c5391', 33188)],
-                list(self.get_simple_index("index").iterobjects()))
+            [(b"bla", b"e69de29bb2d1d6434b8b29ae775ad8c2e48c5391", 33188)],
+            list(self.get_simple_index("index").iterobjects()),
+        )
 
     def test_iterblobs(self):
         warnings.simplefilter("always", UserWarning)
@@ -115,26 +116,36 @@ class SimpleIndexTestCase(IndexTestCase):
         self.addCleanup(restore_warnings)
 
         self.assertEqual(
-                [(b'bla', b'e69de29bb2d1d6434b8b29ae775ad8c2e48c5391', 33188)],
-                list(self.get_simple_index("index").iterblobs()))
+            [(b"bla", b"e69de29bb2d1d6434b8b29ae775ad8c2e48c5391", 33188)],
+            list(self.get_simple_index("index").iterblobs()),
+        )
 
-        expected_warning = PendingDeprecationWarning(
-            'Use iterobjects() instead.')
+        expected_warning = PendingDeprecationWarning("Use iterobjects() instead.")
         for w in warnings_list:
-            if (type(w) == type(expected_warning) and
-                    w.args == expected_warning.args):
+            if type(w) == type(expected_warning) and w.args == expected_warning.args:
                 break
         else:
             raise AssertionError(
-                'Expected warning %r not in %r' %
-                (expected_warning, warnings_list))
+                "Expected warning %r not in %r" % (expected_warning, warnings_list)
+            )
 
     def test_getitem(self):
         self.assertEqual(
-                ((1230680220, 0), (1230680220, 0), 2050, 3761020,
-                 33188, 1000, 1000, 0,
-                 b'e69de29bb2d1d6434b8b29ae775ad8c2e48c5391', 0),
-                self.get_simple_index("index")[b"bla"])
+            (
+                (1230680220, 0),
+                (1230680220, 0),
+                2050,
+                3761020,
+                33188,
+                1000,
+                1000,
+                0,
+                b"e69de29bb2d1d6434b8b29ae775ad8c2e48c5391",
+                0,
+                0,
+            ),
+            self.get_simple_index("index")[b"bla"],
+        )
 
     def test_empty(self):
         i = self.get_simple_index("notanindex")
@@ -146,12 +157,11 @@ class SimpleIndexTestCase(IndexTestCase):
         changes = list(i.changes_from_tree(MemoryObjectStore(), None))
         self.assertEqual(1, len(changes))
         (oldname, newname), (oldmode, newmode), (oldsha, newsha) = changes[0]
-        self.assertEqual(b'bla', newname)
-        self.assertEqual(b'e69de29bb2d1d6434b8b29ae775ad8c2e48c5391', newsha)
+        self.assertEqual(b"bla", newname)
+        self.assertEqual(b"e69de29bb2d1d6434b8b29ae775ad8c2e48c5391", newsha)
 
 
 class SimpleIndexWriterTestCase(IndexTestCase):
-
     def setUp(self):
         IndexTestCase.setUp(self)
         self.tempdir = tempfile.mkdtemp()
@@ -161,19 +171,32 @@ class SimpleIndexWriterTestCase(IndexTestCase):
         shutil.rmtree(self.tempdir)
 
     def test_simple_write(self):
-        entries = [(b'barbla', (1230680220, 0), (1230680220, 0), 2050, 3761020,
-                    33188, 1000, 1000, 0,
-                    b'e69de29bb2d1d6434b8b29ae775ad8c2e48c5391', 0)]
-        filename = os.path.join(self.tempdir, 'test-simple-write-index')
-        with open(filename, 'wb+') as x:
+        entries = [
+            (
+                b"barbla",
+                IndexEntry(
+                    (1230680220, 0),
+                    (1230680220, 0),
+                    2050,
+                    3761020,
+                    33188,
+                    1000,
+                    1000,
+                    0,
+                    b"e69de29bb2d1d6434b8b29ae775ad8c2e48c5391",
+                    0,
+                    0)
+            )
+        ]
+        filename = os.path.join(self.tempdir, "test-simple-write-index")
+        with open(filename, "wb+") as x:
             write_index(x, entries)
 
-        with open(filename, 'rb') as x:
+        with open(filename, "rb") as x:
             self.assertEqual(entries, list(read_index(x)))
 
 
 class ReadIndexDictTests(IndexTestCase):
-
     def setUp(self):
         IndexTestCase.setUp(self)
         self.tempdir = tempfile.mkdtemp()
@@ -184,20 +207,29 @@ class ReadIndexDictTests(IndexTestCase):
 
     def test_simple_write(self):
         entries = {
-                b'barbla':
-                ((1230680220, 0), (1230680220, 0), 2050, 3761020, 33188,
-                 1000, 1000, 0,
-                 b'e69de29bb2d1d6434b8b29ae775ad8c2e48c5391', 0)}
-        filename = os.path.join(self.tempdir, 'test-simple-write-index')
-        with open(filename, 'wb+') as x:
+            b"barbla": IndexEntry(
+                (1230680220, 0),
+                (1230680220, 0),
+                2050,
+                3761020,
+                33188,
+                1000,
+                1000,
+                0,
+                b"e69de29bb2d1d6434b8b29ae775ad8c2e48c5391",
+                0,
+                0,
+            )
+        }
+        filename = os.path.join(self.tempdir, "test-simple-write-index")
+        with open(filename, "wb+") as x:
             write_index_dict(x, entries)
 
-        with open(filename, 'rb') as x:
+        with open(filename, "rb") as x:
             self.assertEqual(entries, read_index_dict(x))
 
 
 class CommitTreeTests(TestCase):
-
     def setUp(self):
         super(CommitTreeTests, self).setUp()
         self.store = MemoryObjectStore()
@@ -223,14 +255,12 @@ class CommitTreeTests(TestCase):
         self.assertEqual(dirid, b"c1a1deb9788150829579a8b4efa6311e7b638650")
         self.assertEqual((stat.S_IFDIR, dirid), self.store[rootid][b"bla"])
         self.assertEqual((stat.S_IFREG, blob.id), self.store[dirid][b"bar"])
-        self.assertEqual(set([rootid, dirid, blob.id]),
-                         set(self.store._data.keys()))
+        self.assertEqual(set([rootid, dirid, blob.id]), set(self.store._data.keys()))
 
 
 class CleanupModeTests(TestCase):
-
     def assertModeEqual(self, expected, got):
-        self.assertEqual(expected, got, '%o != %o' % (expected, got))
+        self.assertEqual(expected, got, "%o != %o" % (expected, got))
 
     def test_file(self):
         self.assertModeEqual(0o100644, cleanup_mode(0o100000))
@@ -250,7 +280,6 @@ class CleanupModeTests(TestCase):
 
 
 class WriteCacheTimeTests(TestCase):
-
     def test_write_string(self):
         f = BytesIO()
         self.assertRaises(TypeError, write_cache_time, f, "foo")
@@ -272,46 +301,74 @@ class WriteCacheTimeTests(TestCase):
 
 
 class IndexEntryFromStatTests(TestCase):
-
     def test_simple(self):
         st = os.stat_result(
-                (16877, 131078, 64769, 154, 1000, 1000, 12288,
-                 1323629595, 1324180496, 1324180496))
+            (
+                16877,
+                131078,
+                64769,
+                154,
+                1000,
+                1000,
+                12288,
+                1323629595,
+                1324180496,
+                1324180496,
+            )
+        )
         entry = index_entry_from_stat(st, "22" * 20, 0)
-        self.assertEqual(entry, (
-            1324180496,
-            1324180496,
-            64769,
-            131078,
-            16384,
-            1000,
-            1000,
-            12288,
-            '2222222222222222222222222222222222222222',
-            0))
+        self.assertEqual(
+            entry,
+            IndexEntry(
+                1324180496,
+                1324180496,
+                64769,
+                131078,
+                16384,
+                1000,
+                1000,
+                12288,
+                "2222222222222222222222222222222222222222",
+                0,
+                None,
+            ),
+        )
 
     def test_override_mode(self):
         st = os.stat_result(
-                (stat.S_IFREG + 0o644, 131078, 64769,
-                 154, 1000, 1000, 12288,
-                 1323629595, 1324180496, 1324180496))
-        entry = index_entry_from_stat(
-            st, "22" * 20, 0, mode=stat.S_IFREG + 0o755)
-        self.assertEqual(entry, (
-            1324180496,
-            1324180496,
-            64769,
-            131078,
-            33261,
-            1000,
-            1000,
-            12288,
-            '2222222222222222222222222222222222222222',
-            0))
+            (
+                stat.S_IFREG + 0o644,
+                131078,
+                64769,
+                154,
+                1000,
+                1000,
+                12288,
+                1323629595,
+                1324180496,
+                1324180496,
+            )
+        )
+        entry = index_entry_from_stat(st, "22" * 20, 0, mode=stat.S_IFREG + 0o755)
+        self.assertEqual(
+            entry,
+            IndexEntry(
+                1324180496,
+                1324180496,
+                64769,
+                131078,
+                33261,
+                1000,
+                1000,
+                12288,
+                "2222222222222222222222222222222222222222",
+                0,
+                None,
+            ),
+        )
 
 
 class BuildIndexTests(TestCase):
-
     def assertReasonableIndexEntry(self, index_entry, mode, filesize, sha):
         self.assertEqual(index_entry[4], mode)  # mode
         self.assertEqual(index_entry[7], filesize)  # filesize
@@ -321,7 +378,7 @@ class BuildIndexTests(TestCase):
         if symlink:
             self.assertEqual(os.readlink(path), contents)
         else:
-            with open(path, 'rb') as f:
+            with open(path, "rb") as f:
                 self.assertEqual(f.read(), contents)
 
     def test_empty(self):
@@ -332,15 +389,15 @@ class BuildIndexTests(TestCase):
             repo.object_store.add_object(tree)
 
             build_index_from_tree(
-                    repo.path, repo.index_path(),
-                    repo.object_store, tree.id)
+                repo.path, repo.index_path(), repo.object_store, tree.id
+            )
 
             # Verify index entries
             index = repo.open_index()
             self.assertEqual(len(index), 0)
 
             # Verify no files
-            self.assertEqual(['.git'], os.listdir(repo.path))
+            self.assertEqual([".git"], os.listdir(repo.path))
 
     def test_git_dir(self):
         repo_dir = tempfile.mkdtemp()
@@ -348,33 +405,34 @@ class BuildIndexTests(TestCase):
         with Repo.init(repo_dir) as repo:
 
             # Populate repo
-            filea = Blob.from_string(b'file a')
-            filee = Blob.from_string(b'd')
+            filea = Blob.from_string(b"file a")
+            filee = Blob.from_string(b"d")
 
             tree = Tree()
-            tree[b'.git/a'] = (stat.S_IFREG | 0o644, filea.id)
-            tree[b'c/e'] = (stat.S_IFREG | 0o644, filee.id)
+            tree[b".git/a"] = (stat.S_IFREG | 0o644, filea.id)
+            tree[b"c/e"] = (stat.S_IFREG | 0o644, filee.id)
 
-            repo.object_store.add_objects(
-                    [(o, None) for o in [filea, filee, tree]])
+            repo.object_store.add_objects([(o, None) for o in [filea, filee, tree]])
 
             build_index_from_tree(
-                repo.path, repo.index_path(), repo.object_store, tree.id)
+                repo.path, repo.index_path(), repo.object_store, tree.id
+            )
 
             # Verify index entries
             index = repo.open_index()
             self.assertEqual(len(index), 1)
 
             # filea
-            apath = os.path.join(repo.path, '.git', 'a')
+            apath = os.path.join(repo.path, ".git", "a")
             self.assertFalse(os.path.exists(apath))
 
             # filee
-            epath = os.path.join(repo.path, 'c', 'e')
+            epath = os.path.join(repo.path, "c", "e")
             self.assertTrue(os.path.exists(epath))
             self.assertReasonableIndexEntry(
-                index[b'c/e'], stat.S_IFREG | 0o644, 1, filee.id)
-            self.assertFileContents(epath, b'd')
+                index[b"c/e"], stat.S_IFREG | 0o644, 1, filee.id
+            )
+            self.assertFileContents(epath, b"d")
 
     def test_nonempty(self):
         repo_dir = tempfile.mkdtemp()
@@ -382,122 +440,130 @@ class BuildIndexTests(TestCase):
         with Repo.init(repo_dir) as repo:
 
             # Populate repo
-            filea = Blob.from_string(b'file a')
-            fileb = Blob.from_string(b'file b')
-            filed = Blob.from_string(b'file d')
+            filea = Blob.from_string(b"file a")
+            fileb = Blob.from_string(b"file b")
+            filed = Blob.from_string(b"file d")
 
             tree = Tree()
-            tree[b'a'] = (stat.S_IFREG | 0o644, filea.id)
-            tree[b'b'] = (stat.S_IFREG | 0o644, fileb.id)
-            tree[b'c/d'] = (stat.S_IFREG | 0o644, filed.id)
+            tree[b"a"] = (stat.S_IFREG | 0o644, filea.id)
+            tree[b"b"] = (stat.S_IFREG | 0o644, fileb.id)
+            tree[b"c/d"] = (stat.S_IFREG | 0o644, filed.id)
 
             repo.object_store.add_objects(
-                [(o, None) for o in [filea, fileb, filed, tree]])
+                [(o, None) for o in [filea, fileb, filed, tree]]
+            )
 
             build_index_from_tree(
-                repo.path, repo.index_path(), repo.object_store, tree.id)
+                repo.path, repo.index_path(), repo.object_store, tree.id
+            )
 
             # Verify index entries
             index = repo.open_index()
             self.assertEqual(len(index), 3)
 
             # filea
-            apath = os.path.join(repo.path, 'a')
+            apath = os.path.join(repo.path, "a")
             self.assertTrue(os.path.exists(apath))
             self.assertReasonableIndexEntry(
-                    index[b'a'], stat.S_IFREG | 0o644, 6, filea.id)
-            self.assertFileContents(apath, b'file a')
+                index[b"a"], stat.S_IFREG | 0o644, 6, filea.id
+            )
+            self.assertFileContents(apath, b"file a")
 
             # fileb
-            bpath = os.path.join(repo.path, 'b')
+            bpath = os.path.join(repo.path, "b")
             self.assertTrue(os.path.exists(bpath))
             self.assertReasonableIndexEntry(
-                    index[b'b'], stat.S_IFREG | 0o644, 6, fileb.id)
-            self.assertFileContents(bpath, b'file b')
+                index[b"b"], stat.S_IFREG | 0o644, 6, fileb.id
+            )
+            self.assertFileContents(bpath, b"file b")
 
             # filed
-            dpath = os.path.join(repo.path, 'c', 'd')
+            dpath = os.path.join(repo.path, "c", "d")
             self.assertTrue(os.path.exists(dpath))
             self.assertReasonableIndexEntry(
-                    index[b'c/d'], stat.S_IFREG | 0o644, 6, filed.id)
-            self.assertFileContents(dpath, b'file d')
+                index[b"c/d"], stat.S_IFREG | 0o644, 6, filed.id
+            )
+            self.assertFileContents(dpath, b"file d")
 
             # Verify no extra files
-            self.assertEqual(
-                    ['.git', 'a', 'b', 'c'], sorted(os.listdir(repo.path)))
-            self.assertEqual(
-                    ['d'], sorted(os.listdir(os.path.join(repo.path, 'c'))))
+            self.assertEqual([".git", "a", "b", "c"], sorted(os.listdir(repo.path)))
+            self.assertEqual(["d"], sorted(os.listdir(os.path.join(repo.path, "c"))))
 
-    @skipIf(not getattr(os, 'sync', None), 'Requires sync support')
+    @skipIf(not getattr(os, "sync", None), "Requires sync support")
     def test_norewrite(self):
         repo_dir = tempfile.mkdtemp()
         self.addCleanup(shutil.rmtree, repo_dir)
         with Repo.init(repo_dir) as repo:
             # Populate repo
-            filea = Blob.from_string(b'file a')
-            filea_path = os.path.join(repo_dir, 'a')
+            filea = Blob.from_string(b"file a")
+            filea_path = os.path.join(repo_dir, "a")
             tree = Tree()
-            tree[b'a'] = (stat.S_IFREG | 0o644, filea.id)
+            tree[b"a"] = (stat.S_IFREG | 0o644, filea.id)
 
             repo.object_store.add_objects([(o, None) for o in [filea, tree]])
 
             # First Write
-            build_index_from_tree(repo.path, repo.index_path(),
-                                  repo.object_store, tree.id)
+            build_index_from_tree(
+                repo.path, repo.index_path(), repo.object_store, tree.id
+            )
             # Use sync as metadata can be cached on some FS
             os.sync()
             mtime = os.stat(filea_path).st_mtime
 
             # Test Rewrite
-            build_index_from_tree(repo.path, repo.index_path(),
-                                  repo.object_store, tree.id)
+            build_index_from_tree(
+                repo.path, repo.index_path(), repo.object_store, tree.id
+            )
             os.sync()
             self.assertEqual(mtime, os.stat(filea_path).st_mtime)
 
             # Modify content
-            with open(filea_path, 'wb') as fh:
-                fh.write(b'test a')
+            with open(filea_path, "wb") as fh:
+                fh.write(b"test a")
             os.sync()
             mtime = os.stat(filea_path).st_mtime
 
             # Test rewrite
-            build_index_from_tree(repo.path, repo.index_path(),
-                                  repo.object_store, tree.id)
+            build_index_from_tree(
+                repo.path, repo.index_path(), repo.object_store, tree.id
+            )
             os.sync()
-            with open(filea_path, 'rb') as fh:
-                self.assertEqual(b'file a', fh.read())
+            with open(filea_path, "rb") as fh:
+                self.assertEqual(b"file a", fh.read())
 
-    @skipIf(not can_symlink(), 'Requires symlink support')
+    @skipIf(not can_symlink(), "Requires symlink support")
     def test_symlink(self):
         repo_dir = tempfile.mkdtemp()
         self.addCleanup(shutil.rmtree, repo_dir)
         with Repo.init(repo_dir) as repo:
 
             # Populate repo
-            filed = Blob.from_string(b'file d')
-            filee = Blob.from_string(b'd')
+            filed = Blob.from_string(b"file d")
+            filee = Blob.from_string(b"d")
 
             tree = Tree()
-            tree[b'c/d'] = (stat.S_IFREG | 0o644, filed.id)
-            tree[b'c/e'] = (stat.S_IFLNK, filee.id)  # symlink
+            tree[b"c/d"] = (stat.S_IFREG | 0o644, filed.id)
+            tree[b"c/e"] = (stat.S_IFLNK, filee.id)  # symlink
 
-            repo.object_store.add_objects(
-                    [(o, None) for o in [filed, filee, tree]])
+            repo.object_store.add_objects([(o, None) for o in [filed, filee, tree]])
 
             build_index_from_tree(
-                    repo.path, repo.index_path(), repo.object_store, tree.id)
+                repo.path, repo.index_path(), repo.object_store, tree.id
+            )
 
             # Verify index entries
             index = repo.open_index()
 
             # symlink to d
-            epath = os.path.join(repo.path, 'c', 'e')
+            epath = os.path.join(repo.path, "c", "e")
             self.assertTrue(os.path.exists(epath))
             self.assertReasonableIndexEntry(
-                index[b'c/e'], stat.S_IFLNK,
-                0 if sys.platform == 'win32' else 1,
-                filee.id)
-            self.assertFileContents(epath, 'd', symlink=True)
+                index[b"c/e"],
+                stat.S_IFLNK,
+                0 if sys.platform == "win32" else 1,
+                filee.id,
+            )
+            self.assertFileContents(epath, "d", symlink=True)
 
     def test_no_decode_encode(self):
         repo_dir = tempfile.mkdtemp()
@@ -506,33 +572,32 @@ class BuildIndexTests(TestCase):
         with Repo.init(repo_dir) as repo:
 
             # Populate repo
-            file = Blob.from_string(b'foo')
+            file = Blob.from_string(b"foo")
 
             tree = Tree()
-            latin1_name = u'À'.encode('latin1')
+            latin1_name = u"À".encode("latin1")
             latin1_path = os.path.join(repo_dir_bytes, latin1_name)
-            utf8_name = u'À'.encode('utf8')
+            utf8_name = u"À".encode("utf8")
             utf8_path = os.path.join(repo_dir_bytes, utf8_name)
             tree[latin1_name] = (stat.S_IFREG | 0o644, file.id)
             tree[utf8_name] = (stat.S_IFREG | 0o644, file.id)
 
-            repo.object_store.add_objects(
-                [(o, None) for o in [file, tree]])
+            repo.object_store.add_objects([(o, None) for o in [file, tree]])
 
             try:
                 build_index_from_tree(
-                    repo.path, repo.index_path(),
-                    repo.object_store, tree.id)
+                    repo.path, repo.index_path(), repo.object_store, tree.id
+                )
             except OSError as e:
-                if e.errno == 92 and sys.platform == 'darwin':
+                if e.errno == 92 and sys.platform == "darwin":
                     # Our filename isn't supported by the platform :(
-                    self.skipTest('can not write filename %r' % e.filename)
+                    self.skipTest("can not write filename %r" % e.filename)
                 else:
                     raise
             except UnicodeDecodeError:
                 # This happens e.g. with python3.6 on Windows.
                 # It implicitly decodes using utf8, which doesn't work.
-                self.skipTest('can not implicitly convert as utf8')
+                self.skipTest("can not implicitly convert as utf8")
 
             # Verify index entries
             index = repo.open_index()
@@ -547,86 +612,85 @@ class BuildIndexTests(TestCase):
         repo_dir = tempfile.mkdtemp()
         self.addCleanup(shutil.rmtree, repo_dir)
         with Repo.init(repo_dir) as repo:
-            filea = Blob.from_string(b'file alalala')
+            filea = Blob.from_string(b"file alalala")
 
             subtree = Tree()
-            subtree[b'a'] = (stat.S_IFREG | 0o644, filea.id)
+            subtree[b"a"] = (stat.S_IFREG | 0o644, filea.id)
 
             c = Commit()
             c.tree = subtree.id
-            c.committer = c.author = b'Somebody <somebody@example.com>'
+            c.committer = c.author = b"Somebody <somebody@example.com>"
             c.commit_time = c.author_time = 42342
             c.commit_timezone = c.author_timezone = 0
             c.parents = []
-            c.message = b'Subcommit'
+            c.message = b"Subcommit"
 
             tree = Tree()
-            tree[b'c'] = (S_IFGITLINK, c.id)
+            tree[b"c"] = (S_IFGITLINK, c.id)
 
-            repo.object_store.add_objects(
-                [(o, None) for o in [tree]])
+            repo.object_store.add_objects([(o, None) for o in [tree]])
 
             build_index_from_tree(
-                    repo.path, repo.index_path(), repo.object_store, tree.id)
+                repo.path, repo.index_path(), repo.object_store, tree.id
+            )
 
             # Verify index entries
             index = repo.open_index()
             self.assertEqual(len(index), 1)
 
             # filea
-            apath = os.path.join(repo.path, 'c/a')
+            apath = os.path.join(repo.path, "c/a")
             self.assertFalse(os.path.exists(apath))
 
             # dir c
-            cpath = os.path.join(repo.path, 'c')
+            cpath = os.path.join(repo.path, "c")
             self.assertTrue(os.path.isdir(cpath))
-            self.assertEqual(index[b'c'][4], S_IFGITLINK)  # mode
-            self.assertEqual(index[b'c'][8], c.id)  # sha
+            self.assertEqual(index[b"c"][4], S_IFGITLINK)  # mode
+            self.assertEqual(index[b"c"][8], c.id)  # sha
 
     def test_git_submodule_exists(self):
         repo_dir = tempfile.mkdtemp()
         self.addCleanup(shutil.rmtree, repo_dir)
         with Repo.init(repo_dir) as repo:
-            filea = Blob.from_string(b'file alalala')
+            filea = Blob.from_string(b"file alalala")
 
             subtree = Tree()
-            subtree[b'a'] = (stat.S_IFREG | 0o644, filea.id)
+            subtree[b"a"] = (stat.S_IFREG | 0o644, filea.id)
 
             c = Commit()
             c.tree = subtree.id
-            c.committer = c.author = b'Somebody <somebody@example.com>'
+            c.committer = c.author = b"Somebody <somebody@example.com>"
             c.commit_time = c.author_time = 42342
             c.commit_timezone = c.author_timezone = 0
             c.parents = []
-            c.message = b'Subcommit'
+            c.message = b"Subcommit"
 
             tree = Tree()
-            tree[b'c'] = (S_IFGITLINK, c.id)
+            tree[b"c"] = (S_IFGITLINK, c.id)
 
-            os.mkdir(os.path.join(repo_dir, 'c'))
-            repo.object_store.add_objects(
-                [(o, None) for o in [tree]])
+            os.mkdir(os.path.join(repo_dir, "c"))
+            repo.object_store.add_objects([(o, None) for o in [tree]])
 
             build_index_from_tree(
-                    repo.path, repo.index_path(), repo.object_store, tree.id)
+                repo.path, repo.index_path(), repo.object_store, tree.id
+            )
 
             # Verify index entries
             index = repo.open_index()
             self.assertEqual(len(index), 1)
 
             # filea
-            apath = os.path.join(repo.path, 'c/a')
+            apath = os.path.join(repo.path, "c/a")
             self.assertFalse(os.path.exists(apath))
 
             # dir c
-            cpath = os.path.join(repo.path, 'c')
+            cpath = os.path.join(repo.path, "c")
             self.assertTrue(os.path.isdir(cpath))
-            self.assertEqual(index[b'c'][4], S_IFGITLINK)  # mode
-            self.assertEqual(index[b'c'][8], c.id)  # sha
+            self.assertEqual(index[b"c"][4], S_IFGITLINK)  # mode
+            self.assertEqual(index[b"c"][8], c.id)  # sha
 
 
 class GetUnstagedChangesTests(TestCase):
-
     def test_get_unstaged_changes(self):
         """Unit test for get_unstaged_changes."""
 
@@ -635,27 +699,30 @@ class GetUnstagedChangesTests(TestCase):
         with Repo.init(repo_dir) as repo:
 
             # Commit a dummy file then modify it
-            foo1_fullpath = os.path.join(repo_dir, 'foo1')
-            with open(foo1_fullpath, 'wb') as f:
-                f.write(b'origstuff')
+            foo1_fullpath = os.path.join(repo_dir, "foo1")
+            with open(foo1_fullpath, "wb") as f:
+                f.write(b"origstuff")
 
-            foo2_fullpath = os.path.join(repo_dir, 'foo2')
-            with open(foo2_fullpath, 'wb') as f:
-                f.write(b'origstuff')
+            foo2_fullpath = os.path.join(repo_dir, "foo2")
+            with open(foo2_fullpath, "wb") as f:
+                f.write(b"origstuff")
 
-            repo.stage(['foo1', 'foo2'])
-            repo.do_commit(b'test status', author=b'author <email>',
-                           committer=b'committer <email>')
+            repo.stage(["foo1", "foo2"])
+            repo.do_commit(
+                b"test status",
+                author=b"author <email>",
+                committer=b"committer <email>",
+            )
 
-            with open(foo1_fullpath, 'wb') as f:
-                f.write(b'newstuff')
+            with open(foo1_fullpath, "wb") as f:
+                f.write(b"newstuff")
 
             # modify access and modify time of path
             os.utime(foo1_fullpath, (0, 0))
 
             changes = get_unstaged_changes(repo.open_index(), repo_dir)
 
-            self.assertEqual(list(changes), [b'foo1'])
+            self.assertEqual(list(changes), [b"foo1"])
 
     def test_get_unstaged_deleted_changes(self):
         """Unit test for get_unstaged_changes."""
@@ -665,19 +732,22 @@ class GetUnstagedChangesTests(TestCase):
         with Repo.init(repo_dir) as repo:
 
             # Commit a dummy file then remove it
-            foo1_fullpath = os.path.join(repo_dir, 'foo1')
-            with open(foo1_fullpath, 'wb') as f:
-                f.write(b'origstuff')
+            foo1_fullpath = os.path.join(repo_dir, "foo1")
+            with open(foo1_fullpath, "wb") as f:
+                f.write(b"origstuff")
 
-            repo.stage(['foo1'])
-            repo.do_commit(b'test status', author=b'author <email>',
-                           committer=b'committer <email>')
+            repo.stage(["foo1"])
+            repo.do_commit(
+                b"test status",
+                author=b"author <email>",
+                committer=b"committer <email>",
+            )
 
             os.unlink(foo1_fullpath)
 
             changes = get_unstaged_changes(repo.open_index(), repo_dir)
 
-            self.assertEqual(list(changes), [b'foo1'])
+            self.assertEqual(list(changes), [b"foo1"])
 
     def test_get_unstaged_changes_removed_replaced_by_directory(self):
         """Unit test for get_unstaged_changes."""
@@ -687,22 +757,25 @@ class GetUnstagedChangesTests(TestCase):
         with Repo.init(repo_dir) as repo:
 
             # Commit a dummy file then modify it
-            foo1_fullpath = os.path.join(repo_dir, 'foo1')
-            with open(foo1_fullpath, 'wb') as f:
-                f.write(b'origstuff')
+            foo1_fullpath = os.path.join(repo_dir, "foo1")
+            with open(foo1_fullpath, "wb") as f:
+                f.write(b"origstuff")
 
-            repo.stage(['foo1'])
-            repo.do_commit(b'test status', author=b'author <email>',
-                           committer=b'committer <email>')
+            repo.stage(["foo1"])
+            repo.do_commit(
+                b"test status",
+                author=b"author <email>",
+                committer=b"committer <email>",
+            )
 
             os.remove(foo1_fullpath)
             os.mkdir(foo1_fullpath)
 
             changes = get_unstaged_changes(repo.open_index(), repo_dir)
 
-            self.assertEqual(list(changes), [b'foo1'])
+            self.assertEqual(list(changes), [b"foo1"])
 
-    @skipIf(not can_symlink(), 'Requires symlink support')
+    @skipIf(not can_symlink(), "Requires symlink support")
     def test_get_unstaged_changes_removed_replaced_by_link(self):
         """Unit test for get_unstaged_changes."""
 
@@ -711,24 +784,26 @@ class GetUnstagedChangesTests(TestCase):
         with Repo.init(repo_dir) as repo:
 
             # Commit a dummy file then modify it
-            foo1_fullpath = os.path.join(repo_dir, 'foo1')
-            with open(foo1_fullpath, 'wb') as f:
-                f.write(b'origstuff')
+            foo1_fullpath = os.path.join(repo_dir, "foo1")
+            with open(foo1_fullpath, "wb") as f:
+                f.write(b"origstuff")
 
-            repo.stage(['foo1'])
-            repo.do_commit(b'test status', author=b'author <email>',
-                           committer=b'committer <email>')
+            repo.stage(["foo1"])
+            repo.do_commit(
+                b"test status",
+                author=b"author <email>",
+                committer=b"committer <email>",
+            )
 
             os.remove(foo1_fullpath)
             os.symlink(os.path.dirname(foo1_fullpath), foo1_fullpath)
 
             changes = get_unstaged_changes(repo.open_index(), repo_dir)
 
-            self.assertEqual(list(changes), [b'foo1'])
+            self.assertEqual(list(changes), [b"foo1"])
 
 
 class TestValidatePathElement(TestCase):
-
     def test_default(self):
         self.assertTrue(validate_path_element_default(b"bla"))
         self.assertTrue(validate_path_element_default(b".bla"))
@@ -747,20 +822,20 @@ class TestValidatePathElement(TestCase):
 
 
 class TestTreeFSPathConversion(TestCase):
-
     def test_tree_to_fs_path(self):
-        tree_path = u'délwíçh/foo'.encode('utf8')
-        fs_path = _tree_to_fs_path(b'/prefix/path', tree_path)
+        tree_path = u"délwíçh/foo".encode("utf8")
+        fs_path = _tree_to_fs_path(b"/prefix/path", tree_path)
         self.assertEqual(
             fs_path,
-            os.fsencode(os.path.join(u'/prefix/path', u'délwíçh', u'foo')))
+            os.fsencode(os.path.join(u"/prefix/path", u"délwíçh", u"foo")),
+        )
 
     def test_fs_to_tree_path_str(self):
-        fs_path = os.path.join(os.path.join(u'délwíçh', u'foo'))
+        fs_path = os.path.join(os.path.join(u"délwíçh", u"foo"))
         tree_path = _fs_to_tree_path(fs_path)
-        self.assertEqual(tree_path, u'délwíçh/foo'.encode('utf-8'))
+        self.assertEqual(tree_path, u"délwíçh/foo".encode("utf-8"))
 
     def test_fs_to_tree_path_bytes(self):
-        fs_path = os.path.join(os.fsencode(os.path.join(u'délwíçh', u'foo')))
+        fs_path = os.path.join(os.fsencode(os.path.join(u"délwíçh", u"foo")))
         tree_path = _fs_to_tree_path(fs_path)
-        self.assertEqual(tree_path, u'délwíçh/foo'.encode('utf-8'))
+        self.assertEqual(tree_path, u"délwíçh/foo".encode("utf-8"))
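
The index tests above exercise the dulwich.index API end to end. A minimal
sketch of the same flow, using only calls shown in these tests (the temporary
directory and file names are illustrative):

    import os
    import shutil
    import stat
    import tempfile

    from dulwich.index import (
        _fs_to_tree_path,
        _tree_to_fs_path,
        build_index_from_tree,
        get_unstaged_changes,
    )
    from dulwich.objects import Blob, Tree
    from dulwich.repo import Repo

    repo_dir = tempfile.mkdtemp()
    with Repo.init(repo_dir) as repo:
        blob = Blob.from_string(b"file a")
        tree = Tree()
        tree[b"a"] = (stat.S_IFREG | 0o644, blob.id)
        repo.object_store.add_objects([(o, None) for o in (blob, tree)])

        # Check the tree out into the index and the working directory.
        build_index_from_tree(repo.path, repo.index_path(), repo.object_store, tree.id)

        # Modify the file on disk; it is now reported as unstaged.
        with open(os.path.join(repo_dir, "a"), "wb") as f:
            f.write(b"changed")
        assert list(get_unstaged_changes(repo.open_index(), repo_dir)) == [b"a"]
    shutil.rmtree(repo_dir)

    # Tree paths are bytes with "/" separators; the private helpers convert
    # to and from native filesystem paths, as TestTreeFSPathConversion shows.
    fs_path = _tree_to_fs_path(b"/prefix", b"dir/file")
    assert _fs_to_tree_path(os.path.join("dir", "file")) == b"dir/file"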

+ 3 - 5
dulwich/tests/test_lfs.py

@@ -27,7 +27,6 @@ import tempfile
 
 
 class LFSTests(TestCase):
-
     def setUp(self):
         super(LFSTests, self).setUp()
         self.test_dir = tempfile.mkdtemp()
@@ -35,10 +34,9 @@ class LFSTests(TestCase):
         self.lfs = LFSStore.create(self.test_dir)
 
     def test_create(self):
-        sha = self.lfs.write_object([b'a', b'b'])
+        sha = self.lfs.write_object([b"a", b"b"])
         with self.lfs.open_object(sha) as f:
-            self.assertEqual(b'ab', f.read())
+            self.assertEqual(b"ab", f.read())
 
     def test_missing(self):
-        self.assertRaises(
-            KeyError, self.lfs.open_object, 'abcdeabcdeabcdeabcde')
+        self.assertRaises(KeyError, self.lfs.open_object, "abcdeabcdeabcdeabcde")
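
For context, the LFS store API used here is small. A minimal sketch, assuming
LFSStore is importable from dulwich.lfs (the module these tests cover) and
with an illustrative temporary directory:

    import tempfile

    from dulwich.lfs import LFSStore

    store = LFSStore.create(tempfile.mkdtemp())
    # write_object() takes an iterable of chunks and returns the object's sha.
    sha = store.write_object([b"chunk one, ", b"chunk two"])
    with store.open_object(sha) as f:
        assert f.read() == b"chunk one, chunk two"
    # Opening an unknown sha raises KeyError, as test_missing pins down.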

+ 4 - 13
dulwich/tests/test_line_ending.py

@@ -40,31 +40,22 @@ class LineEndingConversion(TestCase):
         self.assertEqual(convert_crlf_to_lf(b"foobar"), b"foobar")
 
     def test_convert_crlf_to_lf(self):
-        self.assertEqual(
-            convert_crlf_to_lf(b"line1\r\nline2"), b"line1\nline2"
-        )
+        self.assertEqual(convert_crlf_to_lf(b"line1\r\nline2"), b"line1\nline2")
 
     def test_convert_crlf_to_lf_mixed(self):
-        self.assertEqual(
-            convert_crlf_to_lf(b"line1\r\n\nline2"), b"line1\n\nline2"
-        )
+        self.assertEqual(convert_crlf_to_lf(b"line1\r\n\nline2"), b"line1\n\nline2")
 
     def test_convert_lf_to_crlf_no_op(self):
         self.assertEqual(convert_lf_to_crlf(b"foobar"), b"foobar")
 
     def test_convert_lf_to_crlf(self):
-        self.assertEqual(
-            convert_lf_to_crlf(b"line1\nline2"), b"line1\r\nline2"
-        )
+        self.assertEqual(convert_lf_to_crlf(b"line1\nline2"), b"line1\r\nline2")
 
     def test_convert_lf_to_crlf_mixed(self):
-        self.assertEqual(
-            convert_lf_to_crlf(b"line1\r\n\nline2"), b"line1\r\n\r\nline2"
-        )
+        self.assertEqual(convert_lf_to_crlf(b"line1\r\n\nline2"), b"line1\r\n\r\nline2")
 
 
 class GetLineEndingAutocrlfFilters(TestCase):
-
     def test_get_checkin_filter_autocrlf_default(self):
         checkin_filter = get_checkin_filter_autocrlf(b"false")
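
The conversion semantics pinned down above are asymmetric: CRLF-to-LF leaves
lone LFs untouched, while LF-to-CRLF also normalizes mixed input. A quick
sketch:

    from dulwich.line_ending import convert_crlf_to_lf, convert_lf_to_crlf

    # Lone LFs survive the CRLF-to-LF direction unchanged.
    assert convert_crlf_to_lf(b"line1\r\nline2\nline3") == b"line1\nline2\nline3"
    # Mixed input comes out uniformly CRLF-terminated.
    assert convert_lf_to_crlf(b"line1\r\n\nline2") == b"line1\r\n\r\nline2"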
 

+ 89 - 88
dulwich/tests/test_lru_cache.py

@@ -21,10 +21,10 @@
 
 from dulwich import (
     lru_cache,
-    )
+)
 from dulwich.tests import (
     TestCase,
-    )
+)
 
 
 class TestLRUCache(TestCase):
@@ -43,13 +43,13 @@ class TestLRUCache(TestCase):
     def test_missing(self):
         cache = lru_cache.LRUCache(max_cache=10)
 
-        self.assertFalse('foo' in cache)
-        self.assertRaises(KeyError, cache.__getitem__, 'foo')
+        self.assertFalse("foo" in cache)
+        self.assertRaises(KeyError, cache.__getitem__, "foo")
 
-        cache['foo'] = 'bar'
-        self.assertEqual('bar', cache['foo'])
-        self.assertTrue('foo' in cache)
-        self.assertFalse('bar' in cache)
+        cache["foo"] = "bar"
+        self.assertEqual("bar", cache["foo"])
+        self.assertTrue("foo" in cache)
+        self.assertFalse("bar" in cache)
 
     def test_map_None(self):
         # Make sure that we can properly map None as a key.
@@ -76,28 +76,28 @@ class TestLRUCache(TestCase):
         """Adding extra entries will pop out old ones."""
         cache = lru_cache.LRUCache(max_cache=1, after_cleanup_count=1)
 
-        cache['foo'] = 'bar'
+        cache["foo"] = "bar"
         # With a max cache of 1, adding 'baz' should pop out 'foo'
-        cache['baz'] = 'biz'
+        cache["baz"] = "biz"
 
-        self.assertFalse('foo' in cache)
-        self.assertTrue('baz' in cache)
+        self.assertFalse("foo" in cache)
+        self.assertTrue("baz" in cache)
 
-        self.assertEqual('biz', cache['baz'])
+        self.assertEqual("biz", cache["baz"])
 
     def test_by_usage(self):
         """Accessing entries bumps them up in priority."""
         cache = lru_cache.LRUCache(max_cache=2)
 
-        cache['baz'] = 'biz'
-        cache['foo'] = 'bar'
+        cache["baz"] = "biz"
+        cache["foo"] = "bar"
 
-        self.assertEqual('biz', cache['baz'])
+        self.assertEqual("biz", cache["baz"])
 
         # This must kick out 'foo' because it was the last accessed
-        cache['nub'] = 'in'
+        cache["nub"] = "in"
 
-        self.assertFalse('foo' in cache)
+        self.assertFalse("foo" in cache)
 
     def test_cleanup(self):
         """Test that we can use a cleanup function."""
@@ -108,17 +108,16 @@ class TestLRUCache(TestCase):
 
         cache = lru_cache.LRUCache(max_cache=2, after_cleanup_count=2)
 
-        cache.add('baz', '1', cleanup=cleanup_func)
-        cache.add('foo', '2', cleanup=cleanup_func)
-        cache.add('biz', '3', cleanup=cleanup_func)
+        cache.add("baz", "1", cleanup=cleanup_func)
+        cache.add("foo", "2", cleanup=cleanup_func)
+        cache.add("biz", "3", cleanup=cleanup_func)
 
-        self.assertEqual([('baz', '1')], cleanup_called)
+        self.assertEqual([("baz", "1")], cleanup_called)
 
         # 'foo' is now most recent, so final cleanup will call it last
-        cache['foo']
+        cache["foo"]
         cache.clear()
-        self.assertEqual([('baz', '1'), ('biz', '3'), ('foo', '2')],
-                         cleanup_called)
+        self.assertEqual([("baz", "1"), ("biz", "3"), ("foo", "2")], cleanup_called)
 
     def test_cleanup_on_replace(self):
         """Replacing an object should cleanup the old value."""
@@ -166,8 +165,10 @@ class TestLRUCache(TestCase):
 
         # We hit the max
         self.assertEqual(10, len(cache))
-        self.assertEqual([11, 10, 9, 1, 8, 7, 6, 5, 4, 3],
-                         [n.key for n in cache._walk_lru()])
+        self.assertEqual(
+            [11, 10, 9, 1, 8, 7, 6, 5, 4, 3],
+            [n.key for n in cache._walk_lru()],
+        )
 
     def test_cleanup_shrinks_to_after_clean_count(self):
         cache = lru_cache.LRUCache(max_cache=5, after_cleanup_count=3)
@@ -293,11 +294,10 @@ class TestLRUCache(TestCase):
 
 
 class TestLRUSizeCache(TestCase):
-
     def test_basic_init(self):
         cache = lru_cache.LRUSizeCache()
         self.assertEqual(2048, cache._max_cache)
-        self.assertEqual(int(cache._max_size*0.8), cache._after_cleanup_size)
+        self.assertEqual(int(cache._max_size * 0.8), cache._after_cleanup_size)
         self.assertEqual(0, cache._value_size)
 
     def test_add__null_key(self):
@@ -307,15 +307,15 @@ class TestLRUSizeCache(TestCase):
     def test_add_tracks_size(self):
         cache = lru_cache.LRUSizeCache()
         self.assertEqual(0, cache._value_size)
-        cache.add('my key', 'my value text')
+        cache.add("my key", "my value text")
         self.assertEqual(13, cache._value_size)
 
     def test_remove_tracks_size(self):
         cache = lru_cache.LRUSizeCache()
         self.assertEqual(0, cache._value_size)
-        cache.add('my key', 'my value text')
+        cache.add("my key", "my value text")
         self.assertEqual(13, cache._value_size)
-        node = cache._cache['my key']
+        node = cache._cache["my key"]
         cache._remove_node(node)
         self.assertEqual(0, cache._value_size)
 
@@ -324,21 +324,21 @@ class TestLRUSizeCache(TestCase):
         cache = lru_cache.LRUSizeCache(max_size=10, after_cleanup_size=5)
         self.assertEqual(0, cache._value_size)
         self.assertEqual({}, cache.items())
-        cache.add('test', 'key')
+        cache.add("test", "key")
         self.assertEqual(3, cache._value_size)
-        self.assertEqual({'test': 'key'}, cache.items())
-        cache.add('test2', 'key that is too big')
+        self.assertEqual({"test": "key"}, cache.items())
+        cache.add("test2", "key that is too big")
         self.assertEqual(3, cache._value_size)
-        self.assertEqual({'test': 'key'}, cache.items())
+        self.assertEqual({"test": "key"}, cache.items())
         # If we would add a key, only to cleanup and remove all cached entries,
         # then obviously that value should not be stored
-        cache.add('test3', 'bigkey')
+        cache.add("test3", "bigkey")
         self.assertEqual(3, cache._value_size)
-        self.assertEqual({'test': 'key'}, cache.items())
+        self.assertEqual({"test": "key"}, cache.items())
 
-        cache.add('test4', 'bikey')
+        cache.add("test4", "bikey")
         self.assertEqual(3, cache._value_size)
-        self.assertEqual({'test': 'key'}, cache.items())
+        self.assertEqual({"test": "key"}, cache.items())
 
     def test_no_add_over_size_cleanup(self):
         """If a large value is not cached, we will call cleanup right away."""
@@ -350,63 +350,64 @@ class TestLRUSizeCache(TestCase):
         cache = lru_cache.LRUSizeCache(max_size=10, after_cleanup_size=5)
         self.assertEqual(0, cache._value_size)
         self.assertEqual({}, cache.items())
-        cache.add('test', 'key that is too big', cleanup=cleanup)
+        cache.add("test", "key that is too big", cleanup=cleanup)
         # key was not added
         self.assertEqual(0, cache._value_size)
         self.assertEqual({}, cache.items())
         # and cleanup was called
-        self.assertEqual([('test', 'key that is too big')], cleanup_calls)
+        self.assertEqual([("test", "key that is too big")], cleanup_calls)
 
     def test_adding_clears_cache_based_on_size(self):
         """The cache is cleared in LRU order until small enough"""
         cache = lru_cache.LRUSizeCache(max_size=20)
-        cache.add('key1', 'value')  # 5 chars
-        cache.add('key2', 'value2')  # 6 chars
-        cache.add('key3', 'value23')  # 7 chars
-        self.assertEqual(5+6+7, cache._value_size)
-        cache['key2']  # reference key2 so it gets a newer reference time
-        cache.add('key4', 'value234')  # 8 chars, over limit
+        cache.add("key1", "value")  # 5 chars
+        cache.add("key2", "value2")  # 6 chars
+        cache.add("key3", "value23")  # 7 chars
+        self.assertEqual(5 + 6 + 7, cache._value_size)
+        cache["key2"]  # reference key2 so it gets a newer reference time
+        cache.add("key4", "value234")  # 8 chars, over limit
         # We have to remove 2 keys to get back under limit
-        self.assertEqual(6+8, cache._value_size)
-        self.assertEqual({'key2': 'value2', 'key4': 'value234'},
-                         cache.items())
+        self.assertEqual(6 + 8, cache._value_size)
+        self.assertEqual({"key2": "value2", "key4": "value234"}, cache.items())
 
     def test_adding_clears_to_after_cleanup_size(self):
         cache = lru_cache.LRUSizeCache(max_size=20, after_cleanup_size=10)
-        cache.add('key1', 'value')  # 5 chars
-        cache.add('key2', 'value2')  # 6 chars
-        cache.add('key3', 'value23')  # 7 chars
-        self.assertEqual(5+6+7, cache._value_size)
-        cache['key2']  # reference key2 so it gets a newer reference time
-        cache.add('key4', 'value234')  # 8 chars, over limit
+        cache.add("key1", "value")  # 5 chars
+        cache.add("key2", "value2")  # 6 chars
+        cache.add("key3", "value23")  # 7 chars
+        self.assertEqual(5 + 6 + 7, cache._value_size)
+        cache["key2"]  # reference key2 so it gets a newer reference time
+        cache.add("key4", "value234")  # 8 chars, over limit
         # We have to remove 3 keys to get back under limit
         self.assertEqual(8, cache._value_size)
-        self.assertEqual({'key4': 'value234'}, cache.items())
+        self.assertEqual({"key4": "value234"}, cache.items())
 
     def test_custom_sizes(self):
         def size_of_list(lst):
             return sum(len(x) for x in lst)
-        cache = lru_cache.LRUSizeCache(max_size=20, after_cleanup_size=10,
-                                       compute_size=size_of_list)
-
-        cache.add('key1', ['val', 'ue'])  # 5 chars
-        cache.add('key2', ['val', 'ue2'])  # 6 chars
-        cache.add('key3', ['val', 'ue23'])  # 7 chars
-        self.assertEqual(5+6+7, cache._value_size)
-        cache['key2']  # reference key2 so it gets a newer reference time
-        cache.add('key4', ['value', '234'])  # 8 chars, over limit
+
+        cache = lru_cache.LRUSizeCache(
+            max_size=20, after_cleanup_size=10, compute_size=size_of_list
+        )
+
+        cache.add("key1", ["val", "ue"])  # 5 chars
+        cache.add("key2", ["val", "ue2"])  # 6 chars
+        cache.add("key3", ["val", "ue23"])  # 7 chars
+        self.assertEqual(5 + 6 + 7, cache._value_size)
+        cache["key2"]  # reference key2 so it gets a newer reference time
+        cache.add("key4", ["value", "234"])  # 8 chars, over limit
         # We have to remove 3 keys to get back under limit
         self.assertEqual(8, cache._value_size)
-        self.assertEqual({'key4': ['value', '234']}, cache.items())
+        self.assertEqual({"key4": ["value", "234"]}, cache.items())
 
     def test_cleanup(self):
         cache = lru_cache.LRUSizeCache(max_size=20, after_cleanup_size=10)
 
         # Add these in order
-        cache.add('key1', 'value')  # 5 chars
-        cache.add('key2', 'value2')  # 6 chars
-        cache.add('key3', 'value23')  # 7 chars
-        self.assertEqual(5+6+7, cache._value_size)
+        cache.add("key1", "value")  # 5 chars
+        cache.add("key2", "value2")  # 6 chars
+        cache.add("key3", "value23")  # 7 chars
+        self.assertEqual(5 + 6 + 7, cache._value_size)
 
         cache.cleanup()
         # Only the most recent fits after cleaning up
@@ -415,40 +416,40 @@ class TestLRUSizeCache(TestCase):
     def test_keys(self):
         cache = lru_cache.LRUSizeCache(max_size=10)
 
-        cache[1] = 'a'
-        cache[2] = 'b'
-        cache[3] = 'cdef'
+        cache[1] = "a"
+        cache[2] = "b"
+        cache[3] = "cdef"
         self.assertEqual([1, 2, 3], sorted(cache.keys()))
 
     def test_resize_smaller(self):
         cache = lru_cache.LRUSizeCache(max_size=10, after_cleanup_size=9)
-        cache[1] = 'abc'
-        cache[2] = 'def'
-        cache[3] = 'ghi'
-        cache[4] = 'jkl'
+        cache[1] = "abc"
+        cache[2] = "def"
+        cache[3] = "ghi"
+        cache[4] = "jkl"
         # Triggers a cleanup
         self.assertEqual([2, 3, 4], sorted(cache.keys()))
         # Resize should also cleanup again
         cache.resize(max_size=6, after_cleanup_size=4)
         self.assertEqual([4], sorted(cache.keys()))
         # Adding should use the new max size
-        cache[5] = 'mno'
+        cache[5] = "mno"
         self.assertEqual([4, 5], sorted(cache.keys()))
-        cache[6] = 'pqr'
+        cache[6] = "pqr"
         self.assertEqual([6], sorted(cache.keys()))
 
     def test_resize_larger(self):
         cache = lru_cache.LRUSizeCache(max_size=10, after_cleanup_size=9)
-        cache[1] = 'abc'
-        cache[2] = 'def'
-        cache[3] = 'ghi'
-        cache[4] = 'jkl'
+        cache[1] = "abc"
+        cache[2] = "def"
+        cache[3] = "ghi"
+        cache[4] = "jkl"
         # Triggers a cleanup
         self.assertEqual([2, 3, 4], sorted(cache.keys()))
         cache.resize(max_size=15, after_cleanup_size=12)
         self.assertEqual([2, 3, 4], sorted(cache.keys()))
-        cache[5] = 'mno'
-        cache[6] = 'pqr'
+        cache[5] = "mno"
+        cache[6] = "pqr"
         self.assertEqual([2, 3, 4, 5, 6], sorted(cache.keys()))
-        cache[7] = 'stu'
+        cache[7] = "stu"
         self.assertEqual([4, 5, 6, 7], sorted(cache.keys()))
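
LRUSizeCache bounds the total computed size of the values rather than the
entry count: once an add pushes the total past max_size, entries are evicted
in least-recently-used order until at most after_cleanup_size remains. A
sketch mirroring test_adding_clears_to_after_cleanup_size:

    from dulwich.lru_cache import LRUSizeCache

    cache = LRUSizeCache(max_size=20, after_cleanup_size=10)
    cache.add("key1", "value")     # 5 chars; len() is the default size measure
    cache.add("key2", "value2")    # 6 chars
    cache.add("key3", "value23")   # 7 chars
    cache["key2"]                  # touch key2 so it is most recently used
    cache.add("key4", "value234")  # 8 chars pushes the total past max_size
    # Eviction removed key1, key3 and key2 to get back under 10:
    assert cache.items() == {"key4": "value234"}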

+ 51 - 40
dulwich/tests/test_mailmap.py

@@ -28,9 +28,9 @@ from dulwich.mailmap import Mailmap, read_mailmap
 
 
 class ReadMailmapTests(TestCase):
-
     def test_read(self):
-        b = BytesIO(b"""\
+        b = BytesIO(
+            b"""\
 Jane Doe         <jane@desktop.(none)>
 Joe R. Developer <joe@example.com>
 # A comment
@@ -39,52 +39,63 @@ Some Dude <some@dude.xx>         nick1 <bugs@company.xx>
 Other Author <other@author.xx>   nick2 <bugs@company.xx>
 Other Author <other@author.xx>         <nick2@company.xx>
 Santa Claus <santa.claus@northpole.xx> <me@company.xx>
-""")
-        self.assertEqual([
-            ((b'Jane Doe', b'jane@desktop.(none)'), None),
-            ((b'Joe R. Developer', b'joe@example.com'), None),
-            ((None, b'cto@company.xx'), (None, b'cto@coompany.xx')),
-            ((b'Some Dude', b'some@dude.xx'), (b'nick1', b'bugs@company.xx')),
-            ((b'Other Author', b'other@author.xx'),
-                (b'nick2', b'bugs@company.xx')),
-            ((b'Other Author', b'other@author.xx'),
-                (None, b'nick2@company.xx')),
-            ((b'Santa Claus', b'santa.claus@northpole.xx'),
-                (None, b'me@company.xx'))],
-            list(read_mailmap(b)))
+"""
+        )
+        self.assertEqual(
+            [
+                ((b"Jane Doe", b"jane@desktop.(none)"), None),
+                ((b"Joe R. Developer", b"joe@example.com"), None),
+                ((None, b"cto@company.xx"), (None, b"cto@coompany.xx")),
+                (
+                    (b"Some Dude", b"some@dude.xx"),
+                    (b"nick1", b"bugs@company.xx"),
+                ),
+                (
+                    (b"Other Author", b"other@author.xx"),
+                    (b"nick2", b"bugs@company.xx"),
+                ),
+                (
+                    (b"Other Author", b"other@author.xx"),
+                    (None, b"nick2@company.xx"),
+                ),
+                (
+                    (b"Santa Claus", b"santa.claus@northpole.xx"),
+                    (None, b"me@company.xx"),
+                ),
+            ],
+            list(read_mailmap(b)),
+        )
 
 
 class MailmapTests(TestCase):
-
     def test_lookup(self):
         m = Mailmap()
-        m.add_entry((b'Jane Doe', b'jane@desktop.(none)'), (None, None))
-        m.add_entry((b'Joe R. Developer', b'joe@example.com'), None)
-        m.add_entry((None, b'cto@company.xx'), (None, b'cto@coompany.xx'))
+        m.add_entry((b"Jane Doe", b"jane@desktop.(none)"), (None, None))
+        m.add_entry((b"Joe R. Developer", b"joe@example.com"), None)
+        m.add_entry((None, b"cto@company.xx"), (None, b"cto@coompany.xx"))
+        m.add_entry((b"Some Dude", b"some@dude.xx"), (b"nick1", b"bugs@company.xx"))
         m.add_entry(
-                (b'Some Dude', b'some@dude.xx'),
-                (b'nick1', b'bugs@company.xx'))
+            (b"Other Author", b"other@author.xx"),
+            (b"nick2", b"bugs@company.xx"),
+        )
+        m.add_entry((b"Other Author", b"other@author.xx"), (None, b"nick2@company.xx"))
         m.add_entry(
-                (b'Other Author', b'other@author.xx'),
-                (b'nick2', b'bugs@company.xx'))
-        m.add_entry(
-                (b'Other Author', b'other@author.xx'),
-                (None, b'nick2@company.xx'))
-        m.add_entry(
-                (b'Santa Claus', b'santa.claus@northpole.xx'),
-                (None, b'me@company.xx'))
-        self.assertEqual(
-            b'Jane Doe <jane@desktop.(none)>',
-            m.lookup(b'Jane Doe <jane@desktop.(none)>'))
+            (b"Santa Claus", b"santa.claus@northpole.xx"),
+            (None, b"me@company.xx"),
+        )
         self.assertEqual(
-            b'Jane Doe <jane@desktop.(none)>',
-            m.lookup(b'Jane Doe <jane@example.com>'))
+            b"Jane Doe <jane@desktop.(none)>",
+            m.lookup(b"Jane Doe <jane@desktop.(none)>"),
+        )
         self.assertEqual(
-            b'Jane Doe <jane@desktop.(none)>',
-            m.lookup(b'Jane D. <jane@desktop.(none)>'))
+            b"Jane Doe <jane@desktop.(none)>",
+            m.lookup(b"Jane Doe <jane@example.com>"),
+        )
         self.assertEqual(
-            b'Some Dude <some@dude.xx>',
-            m.lookup(b'nick1 <bugs@company.xx>'))
+            b"Jane Doe <jane@desktop.(none)>",
+            m.lookup(b"Jane D. <jane@desktop.(none)>"),
+        )
         self.assertEqual(
-            b'CTO <cto@company.xx>',
-            m.lookup(b'CTO <cto@coompany.xx>'))
+            b"Some Dude <some@dude.xx>", m.lookup(b"nick1 <bugs@company.xx>")
+        )
+        self.assertEqual(b"CTO <cto@company.xx>", m.lookup(b"CTO <cto@coompany.xx>"))

+ 142 - 85
dulwich/tests/test_missing_obj_finder.py

@@ -20,57 +20,60 @@
 
 from dulwich.object_store import (
     MemoryObjectStore,
-    )
+)
 from dulwich.objects import (
     Blob,
-    )
+)
 from dulwich.tests import TestCase
 from dulwich.tests.utils import (
     make_object,
     make_tag,
     build_commit_graph,
-    )
+)
 
 
 class MissingObjectFinderTest(TestCase):
-
     def setUp(self):
         super(MissingObjectFinderTest, self).setUp()
         self.store = MemoryObjectStore()
         self.commits = []
 
     def cmt(self, n):
-        return self.commits[n-1]
+        return self.commits[n - 1]
 
     def assertMissingMatch(self, haves, wants, expected):
         for sha, path in self.store.find_missing_objects(haves, wants, set()):
             self.assertTrue(
-                    sha in expected,
-                    "(%s,%s) erroneously reported as missing" % (sha, path))
+                sha in expected,
+                "(%s,%s) erroneously reported as missing" % (sha, path),
+            )
             expected.remove(sha)
 
         self.assertEqual(
-                len(expected), 0,
-                "some objects are not reported as missing: %s" % (expected, ))
+            len(expected),
+            0,
+            "some objects are not reported as missing: %s" % (expected,),
+        )
 
 
 class MOFLinearRepoTest(MissingObjectFinderTest):
-
     def setUp(self):
         super(MOFLinearRepoTest, self).setUp()
         # present in 1, removed in 3
-        f1_1 = make_object(Blob, data=b'f1')
+        f1_1 = make_object(Blob, data=b"f1")
         # present in all revisions, changed in 2 and 3
-        f2_1 = make_object(Blob, data=b'f2')
-        f2_2 = make_object(Blob, data=b'f2-changed')
-        f2_3 = make_object(Blob, data=b'f2-changed-again')
+        f2_1 = make_object(Blob, data=b"f2")
+        f2_2 = make_object(Blob, data=b"f2-changed")
+        f2_3 = make_object(Blob, data=b"f2-changed-again")
         # added in 2, left unmodified in 3
-        f3_2 = make_object(Blob, data=b'f3')
+        f3_2 = make_object(Blob, data=b"f3")
 
         commit_spec = [[1], [2, 1], [3, 2]]
-        trees = {1: [(b'f1', f1_1), (b'f2', f2_1)],
-                 2: [(b'f1', f1_1), (b'f2', f2_2), (b'f3', f3_2)],
-                 3: [(b'f2', f2_3), (b'f3', f3_2)]}
+        trees = {
+            1: [(b"f1", f1_1), (b"f2", f2_1)],
+            2: [(b"f1", f1_1), (b"f2", f2_2), (b"f3", f3_2)],
+            3: [(b"f2", f2_3), (b"f3", f3_2)],
+        }
         # commit 1: f1 and f2
         # commit 2: f3 added, f2 changed. Missing shall report commit id and a
         # tree referenced by commit
@@ -80,24 +83,23 @@ class MOFLinearRepoTest(MissingObjectFinderTest):
         self.missing_1_2 = [self.cmt(2).id, self.cmt(2).tree, f2_2.id, f3_2.id]
         self.missing_2_3 = [self.cmt(3).id, self.cmt(3).tree, f2_3.id]
         self.missing_1_3 = [
-            self.cmt(2).id, self.cmt(3).id,
-            self.cmt(2).tree, self.cmt(3).tree,
-            f2_2.id, f3_2.id, f2_3.id]
+            self.cmt(2).id,
+            self.cmt(3).id,
+            self.cmt(2).tree,
+            self.cmt(3).tree,
+            f2_2.id,
+            f3_2.id,
+            f2_3.id,
+        ]
 
     def test_1_to_2(self):
-        self.assertMissingMatch(
-                [self.cmt(1).id], [self.cmt(2).id],
-                self.missing_1_2)
+        self.assertMissingMatch([self.cmt(1).id], [self.cmt(2).id], self.missing_1_2)
 
     def test_2_to_3(self):
-        self.assertMissingMatch(
-                [self.cmt(2).id], [self.cmt(3).id],
-                self.missing_2_3)
+        self.assertMissingMatch([self.cmt(2).id], [self.cmt(3).id], self.missing_2_3)
 
     def test_1_to_3(self):
-        self.assertMissingMatch(
-                [self.cmt(1).id], [self.cmt(3).id],
-                self.missing_1_3)
+        self.assertMissingMatch([self.cmt(1).id], [self.cmt(3).id], self.missing_1_3)
 
     def test_bogus_haves(self):
         """Ensure non-existent SHA in haves are tolerated"""
@@ -112,7 +114,8 @@ class MOFLinearRepoTest(MissingObjectFinderTest):
         haves = [self.cmt(1).id]
         wants = [self.cmt(3).id, bogus_sha]
         self.assertRaises(
-                KeyError, self.store.find_missing_objects, haves, wants, set())
+            KeyError, self.store.find_missing_objects, haves, wants, set()
+        )
 
     def test_no_changes(self):
         self.assertMissingMatch([self.cmt(3).id], [self.cmt(3).id], [])
@@ -127,25 +130,27 @@ class MOFMergeForkRepoTest(MissingObjectFinderTest):
 
     def setUp(self):
         super(MOFMergeForkRepoTest, self).setUp()
-        f1_1 = make_object(Blob, data=b'f1')
-        f1_2 = make_object(Blob, data=b'f1-2')
-        f1_4 = make_object(Blob, data=b'f1-4')
-        f1_7 = make_object(Blob, data=b'f1-2')  # same data as in rev 2
-        f2_1 = make_object(Blob, data=b'f2')
-        f2_3 = make_object(Blob, data=b'f2-3')
-        f3_3 = make_object(Blob, data=b'f3')
-        f3_5 = make_object(Blob, data=b'f3-5')
+        f1_1 = make_object(Blob, data=b"f1")
+        f1_2 = make_object(Blob, data=b"f1-2")
+        f1_4 = make_object(Blob, data=b"f1-4")
+        f1_7 = make_object(Blob, data=b"f1-2")  # same data as in rev 2
+        f2_1 = make_object(Blob, data=b"f2")
+        f2_3 = make_object(Blob, data=b"f2-3")
+        f3_3 = make_object(Blob, data=b"f3")
+        f3_5 = make_object(Blob, data=b"f3-5")
         commit_spec = [[1], [2, 1], [3, 2], [4, 2], [5, 3], [6, 3, 4], [7, 6]]
-        trees = {1: [(b'f1', f1_1), (b'f2', f2_1)],
-                 2: [(b'f1', f1_2), (b'f2', f2_1)],  # f1 changed
-                 # f3 added, f2 changed
-                 3: [(b'f1', f1_2), (b'f2', f2_3), (b'f3', f3_3)],
-                 4: [(b'f1', f1_4), (b'f2', f2_1)],  # f1 changed
-                 5: [(b'f1', f1_2), (b'f3', f3_5)],  # f2 removed, f3 changed
-                 # merged 3 and 4
-                 6: [(b'f1', f1_4), (b'f2', f2_3), (b'f3', f3_3)],
-                 # f1 changed to match rev2. f3 removed
-                 7: [(b'f1', f1_7), (b'f2', f2_3)]}
+        trees = {
+            1: [(b"f1", f1_1), (b"f2", f2_1)],
+            2: [(b"f1", f1_2), (b"f2", f2_1)],  # f1 changed
+            # f3 added, f2 changed
+            3: [(b"f1", f1_2), (b"f2", f2_3), (b"f3", f3_3)],
+            4: [(b"f1", f1_4), (b"f2", f2_1)],  # f1 changed
+            5: [(b"f1", f1_2), (b"f3", f3_5)],  # f2 removed, f3 changed
+            # merged 3 and 4
+            6: [(b"f1", f1_4), (b"f2", f2_3), (b"f3", f3_3)],
+            # f1 changed to match rev2. f3 removed
+            7: [(b"f1", f1_7), (b"f2", f2_3)],
+        }
         self.commits = build_commit_graph(self.store, commit_spec, trees)
 
         self.f1_2_id = f1_2.id
@@ -164,54 +169,96 @@ class MOFMergeForkRepoTest(MissingObjectFinderTest):
         # doesn't record f1_2 was known prior to that, hence can't detect f1_7
         # is in fact f1_2 and shall not be reported)
         self.assertMissingMatch(
-                [self.cmt(6).id], [self.cmt(7).id],
-                [self.cmt(7).id, self.cmt(7).tree, self.f1_7_id])
+            [self.cmt(6).id],
+            [self.cmt(7).id],
+            [self.cmt(7).id, self.cmt(7).tree, self.f1_7_id],
+        )
 
     def test_have4_want7(self):
         # have 4, want 7. Shall not include rev5 as it is not in the tree
         # between 4 and 7 (well, it is, but its SHA's are irrelevant for 4..7
         # commit hierarchy)
-        self.assertMissingMatch([self.cmt(4).id], [self.cmt(7).id], [
-            self.cmt(7).id, self.cmt(6).id, self.cmt(3).id,
-            self.cmt(7).tree, self.cmt(6).tree, self.cmt(3).tree,
-            self.f2_3_id, self.f3_3_id])
+        self.assertMissingMatch(
+            [self.cmt(4).id],
+            [self.cmt(7).id],
+            [
+                self.cmt(7).id,
+                self.cmt(6).id,
+                self.cmt(3).id,
+                self.cmt(7).tree,
+                self.cmt(6).tree,
+                self.cmt(3).tree,
+                self.f2_3_id,
+                self.f3_3_id,
+            ],
+        )
 
     def test_have1_want6(self):
         # have 1, want 6. Shall not include rev5
-        self.assertMissingMatch([self.cmt(1).id], [self.cmt(6).id], [
-            self.cmt(6).id, self.cmt(4).id, self.cmt(3).id, self.cmt(2).id,
-            self.cmt(6).tree, self.cmt(4).tree, self.cmt(3).tree,
-            self.cmt(2).tree, self.f1_2_id, self.f1_4_id, self.f2_3_id,
-            self.f3_3_id])
+        self.assertMissingMatch(
+            [self.cmt(1).id],
+            [self.cmt(6).id],
+            [
+                self.cmt(6).id,
+                self.cmt(4).id,
+                self.cmt(3).id,
+                self.cmt(2).id,
+                self.cmt(6).tree,
+                self.cmt(4).tree,
+                self.cmt(3).tree,
+                self.cmt(2).tree,
+                self.f1_2_id,
+                self.f1_4_id,
+                self.f2_3_id,
+                self.f3_3_id,
+            ],
+        )
 
     def test_have3_want6(self):
         # have 3, want 7. Shall not report rev2 and its tree, because
         # haves(3) means has parents, i.e. rev2, too
         # BUT shall report any changes descending rev2 (excluding rev3)
-        # Shall NOT report f1_7 as it's techically == f1_2
-        self.assertMissingMatch([self.cmt(3).id], [self.cmt(7).id], [
-              self.cmt(7).id, self.cmt(6).id, self.cmt(4).id,
-              self.cmt(7).tree, self.cmt(6).tree, self.cmt(4).tree,
-              self.f1_4_id])
+        # Shall NOT report f1_7 as it's technically == f1_2
+        self.assertMissingMatch(
+            [self.cmt(3).id],
+            [self.cmt(7).id],
+            [
+                self.cmt(7).id,
+                self.cmt(6).id,
+                self.cmt(4).id,
+                self.cmt(7).tree,
+                self.cmt(6).tree,
+                self.cmt(4).tree,
+                self.f1_4_id,
+            ],
+        )
 
     def test_have5_want7(self):
         # have 5, want 7. Common parent is rev2, hence children of rev2 from
         # a descent line other than rev5 shall be reported
         # expects f1_4 from rev6. f3_5 is known in rev5;
         # f1_7 shall be the same as f1_2 (known, too)
-        self.assertMissingMatch([self.cmt(5).id], [self.cmt(7).id], [
-              self.cmt(7).id, self.cmt(6).id, self.cmt(4).id,
-              self.cmt(7).tree, self.cmt(6).tree, self.cmt(4).tree,
-              self.f1_4_id])
+        self.assertMissingMatch(
+            [self.cmt(5).id],
+            [self.cmt(7).id],
+            [
+                self.cmt(7).id,
+                self.cmt(6).id,
+                self.cmt(4).id,
+                self.cmt(7).tree,
+                self.cmt(6).tree,
+                self.cmt(4).tree,
+                self.f1_4_id,
+            ],
+        )
 
 
 class MOFTagsTest(MissingObjectFinderTest):
-
     def setUp(self):
         super(MOFTagsTest, self).setUp()
-        f1_1 = make_object(Blob, data=b'f1')
+        f1_1 = make_object(Blob, data=b"f1")
         commit_spec = [[1]]
-        trees = {1: [(b'f1', f1_1)]}
+        trees = {1: [(b"f1", f1_1)]}
         self.commits = build_commit_graph(self.store, commit_spec, trees)
 
         self._normal_tag = make_tag(self.cmt(1))
@@ -234,28 +281,38 @@ class MOFTagsTest(MissingObjectFinderTest):
     def test_tagged_commit(self):
         # The user already has the tagged commit, all they want is the tag,
         # so send them only the tag object.
-        self.assertMissingMatch([self.cmt(1).id], [self._normal_tag.id],
-                                [self._normal_tag.id])
+        self.assertMissingMatch(
+            [self.cmt(1).id], [self._normal_tag.id], [self._normal_tag.id]
+        )
 
     # The remaining cases are unusual, but do happen in the wild.
     def test_tagged_tag(self):
         # User already has tagged tag, send only tag of tag
-        self.assertMissingMatch([self._normal_tag.id], [self._tag_of_tag.id],
-                                [self._tag_of_tag.id])
+        self.assertMissingMatch(
+            [self._normal_tag.id], [self._tag_of_tag.id], [self._tag_of_tag.id]
+        )
         # User needs both tags, but already has commit
-        self.assertMissingMatch([self.cmt(1).id], [self._tag_of_tag.id],
-                                [self._normal_tag.id, self._tag_of_tag.id])
+        self.assertMissingMatch(
+            [self.cmt(1).id],
+            [self._tag_of_tag.id],
+            [self._normal_tag.id, self._tag_of_tag.id],
+        )
 
     def test_tagged_tree(self):
         self.assertMissingMatch(
-            [], [self._tag_of_tree.id],
-            [self._tag_of_tree.id, self.cmt(1).tree, self.f1_1_id])
+            [],
+            [self._tag_of_tree.id],
+            [self._tag_of_tree.id, self.cmt(1).tree, self.f1_1_id],
+        )
 
     def test_tagged_blob(self):
-        self.assertMissingMatch([], [self._tag_of_blob.id],
-                                [self._tag_of_blob.id, self.f1_1_id])
+        self.assertMissingMatch(
+            [], [self._tag_of_blob.id], [self._tag_of_blob.id, self.f1_1_id]
+        )
 
     def test_tagged_tagged_blob(self):
-        self.assertMissingMatch([], [self._tag_of_tag_of_blob.id],
-                                [self._tag_of_tag_of_blob.id,
-                                 self._tag_of_blob.id, self.f1_1_id])
+        self.assertMissingMatch(
+            [],
+            [self._tag_of_tag_of_blob.id],
+            [self._tag_of_tag_of_blob.id, self._tag_of_blob.id, self.f1_1_id],
+        )
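
find_missing_objects() walks the object graph from the wants and stops at
anything reachable from the haves, yielding (sha, path) pairs; the third
argument is the set of shallow commits. A self-contained sketch against a
MemoryObjectStore, built with the same object-construction calls the tests
use:

    import stat

    from dulwich.object_store import MemoryObjectStore
    from dulwich.objects import Blob, Commit, Tree

    store = MemoryObjectStore()
    blob = Blob.from_string(b"f1")
    tree = Tree()
    tree[b"f1"] = (stat.S_IFREG | 0o644, blob.id)
    commit = Commit()
    commit.tree = tree.id
    commit.author = commit.committer = b"Somebody <somebody@example.com>"
    commit.commit_time = commit.author_time = 0
    commit.commit_timezone = commit.author_timezone = 0
    commit.parents = []
    commit.message = b"initial"
    for obj in (blob, tree, commit):
        store.add_object(obj)

    # With no haves, everything reachable from the want is missing.
    missing = {sha for sha, _ in store.find_missing_objects([], [commit.id], set())}
    assert missing == {commit.id, tree.id, blob.id}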

+ 257 - 170
dulwich/tests/test_object_store.py

@@ -23,6 +23,7 @@
 
 from contextlib import closing
 from io import BytesIO
+from unittest import skipUnless
 import os
 import shutil
 import stat
@@ -30,17 +31,17 @@ import tempfile
 
 from dulwich.index import (
     commit_tree,
-    )
+)
 from dulwich.errors import (
     NotTreeError,
-    )
+)
 from dulwich.objects import (
     sha_to_hex,
     Blob,
     Tree,
     TreeEntry,
     EmptyFileException,
-    )
+)
 from dulwich.object_store import (
     DiskObjectStore,
     MemoryObjectStore,
@@ -49,34 +50,83 @@ from dulwich.object_store import (
     commit_tree_changes,
     read_packs_file,
     tree_lookup_path,
-    )
+)
 from dulwich.pack import (
     REF_DELTA,
     write_pack_objects,
-    )
+)
+from dulwich.protocol import DEPTH_INFINITE
 from dulwich.tests import (
     TestCase,
-    )
+)
 from dulwich.tests.utils import (
     make_object,
     make_tag,
     build_pack,
-    )
+)
+
+try:
+    from unittest.mock import patch
+except ImportError:
+    patch = None  # type: ignore
 
 
 testobject = make_object(Blob, data=b"yummy data")
 
 
 class ObjectStoreTests(object):
-
     def test_determine_wants_all(self):
         self.assertEqual(
             [b"1" * 40],
-            self.store.determine_wants_all({b"refs/heads/foo": b"1" * 40}))
+            self.store.determine_wants_all({b"refs/heads/foo": b"1" * 40}),
+        )
 
     def test_determine_wants_all_zero(self):
         self.assertEqual(
-            [], self.store.determine_wants_all({b"refs/heads/foo": b"0" * 40}))
+            [], self.store.determine_wants_all({b"refs/heads/foo": b"0" * 40})
+        )
+
+    @skipUnless(patch, "Required mock.patch")
+    def test_determine_wants_all_depth(self):
+        self.store.add_object(testobject)
+        refs = {b"refs/heads/foo": testobject.id}
+        with patch.object(self.store, "_get_depth", return_value=1) as m:
+            self.assertEqual(
+                [], self.store.determine_wants_all(refs, depth=0)
+            )
+            self.assertEqual(
+                [testobject.id],
+                self.store.determine_wants_all(refs, depth=DEPTH_INFINITE),
+            )
+            m.assert_not_called()
+
+            self.assertEqual(
+                [], self.store.determine_wants_all(refs, depth=1)
+            )
+            m.assert_called_with(testobject.id)
+            self.assertEqual(
+                [testobject.id], self.store.determine_wants_all(refs, depth=2)
+            )
+
+    def test_get_depth(self):
+        self.assertEqual(
+            0, self.store._get_depth(testobject.id)
+        )
+
+        self.store.add_object(testobject)
+        self.assertEqual(
+            1, self.store._get_depth(testobject.id, get_parents=lambda x: [])
+        )
+
+        parent = make_object(Blob, data=b"parent data")
+        self.store.add_object(parent)
+        self.assertEqual(
+            2,
+            self.store._get_depth(
+                testobject.id,
+                get_parents=lambda x: [parent.id] if x == testobject else [],
+            ),
+        )
 
     def test_iter(self):
         self.assertEqual([], list(self.store))
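
The depth handling introduced in the hunk above works as follows: depth=0
requests nothing, DEPTH_INFINITE skips the depth check entirely (already
present refs are still wanted, so a shallow clone can be deepened), and any
other value consults _get_depth() to see whether the known history is already
deep enough. A sketch mirroring test_determine_wants_all_depth:

    from dulwich.object_store import MemoryObjectStore
    from dulwich.objects import Blob
    from dulwich.protocol import DEPTH_INFINITE

    store = MemoryObjectStore()
    blob = Blob.from_string(b"yummy data")
    store.add_object(blob)
    refs = {b"refs/heads/foo": blob.id}

    assert store.determine_wants_all(refs, depth=0) == []
    # Deepening to full history re-requests even objects we already have:
    assert store.determine_wants_all(refs, depth=DEPTH_INFINITE) == [blob.id]
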
@@ -99,11 +149,11 @@ class ObjectStoreTests(object):
         """Test if updating an existing stored object doesn't erase the
         object from the store.
         """
-        test_object = make_object(Blob, data=b'data')
+        test_object = make_object(Blob, data=b"data")
 
         self.store.add_object(test_object)
         test_object_id = test_object.id
-        test_object.data = test_object.data + b'update'
+        test_object.data = test_object.data + b"update"
         stored_test_object = self.store[test_object_id]
 
         self.assertNotEqual(test_object.id, stored_test_object.id)
@@ -125,69 +175,75 @@ class ObjectStoreTests(object):
         self.assertEqual(r, testobject)
 
     def test_tree_changes(self):
-        blob_a1 = make_object(Blob, data=b'a1')
-        blob_a2 = make_object(Blob, data=b'a2')
-        blob_b = make_object(Blob, data=b'b')
+        blob_a1 = make_object(Blob, data=b"a1")
+        blob_a2 = make_object(Blob, data=b"a2")
+        blob_b = make_object(Blob, data=b"b")
         for blob in [blob_a1, blob_a2, blob_b]:
             self.store.add_object(blob)
 
-        blobs_1 = [(b'a', blob_a1.id, 0o100644), (b'b', blob_b.id, 0o100644)]
+        blobs_1 = [(b"a", blob_a1.id, 0o100644), (b"b", blob_b.id, 0o100644)]
         tree1_id = commit_tree(self.store, blobs_1)
-        blobs_2 = [(b'a', blob_a2.id, 0o100644), (b'b', blob_b.id, 0o100644)]
+        blobs_2 = [(b"a", blob_a2.id, 0o100644), (b"b", blob_b.id, 0o100644)]
         tree2_id = commit_tree(self.store, blobs_2)
-        change_a = ((b'a', b'a'), (0o100644, 0o100644),
-                    (blob_a1.id, blob_a2.id))
-        self.assertEqual([change_a],
-                         list(self.store.tree_changes(tree1_id, tree2_id)))
+        change_a = (
+            (b"a", b"a"),
+            (0o100644, 0o100644),
+            (blob_a1.id, blob_a2.id),
+        )
+        self.assertEqual([change_a], list(self.store.tree_changes(tree1_id, tree2_id)))
         self.assertEqual(
-            [change_a, ((b'b', b'b'), (0o100644, 0o100644),
-             (blob_b.id, blob_b.id))],
-            list(self.store.tree_changes(tree1_id, tree2_id,
-                 want_unchanged=True)))
+            [
+                change_a,
+                ((b"b", b"b"), (0o100644, 0o100644), (blob_b.id, blob_b.id)),
+            ],
+            list(self.store.tree_changes(tree1_id, tree2_id, want_unchanged=True)),
+        )
 
     def test_iter_tree_contents(self):
-        blob_a = make_object(Blob, data=b'a')
-        blob_b = make_object(Blob, data=b'b')
-        blob_c = make_object(Blob, data=b'c')
+        blob_a = make_object(Blob, data=b"a")
+        blob_b = make_object(Blob, data=b"b")
+        blob_c = make_object(Blob, data=b"c")
         for blob in [blob_a, blob_b, blob_c]:
             self.store.add_object(blob)
 
         blobs = [
-            (b'a', blob_a.id, 0o100644),
-            (b'ad/b', blob_b.id, 0o100644),
-            (b'ad/bd/c', blob_c.id, 0o100755),
-            (b'ad/c', blob_c.id, 0o100644),
-            (b'c', blob_c.id, 0o100644),
+            (b"a", blob_a.id, 0o100644),
+            (b"ad/b", blob_b.id, 0o100644),
+            (b"ad/bd/c", blob_c.id, 0o100755),
+            (b"ad/c", blob_c.id, 0o100644),
+            (b"c", blob_c.id, 0o100644),
         ]
         tree_id = commit_tree(self.store, blobs)
-        self.assertEqual([TreeEntry(p, m, h) for (p, h, m) in blobs],
-                         list(self.store.iter_tree_contents(tree_id)))
+        self.assertEqual(
+            [TreeEntry(p, m, h) for (p, h, m) in blobs],
+            list(self.store.iter_tree_contents(tree_id)),
+        )
 
     def test_iter_tree_contents_include_trees(self):
-        blob_a = make_object(Blob, data=b'a')
-        blob_b = make_object(Blob, data=b'b')
-        blob_c = make_object(Blob, data=b'c')
+        blob_a = make_object(Blob, data=b"a")
+        blob_b = make_object(Blob, data=b"b")
+        blob_c = make_object(Blob, data=b"c")
         for blob in [blob_a, blob_b, blob_c]:
             self.store.add_object(blob)
 
         blobs = [
-          (b'a', blob_a.id, 0o100644),
-          (b'ad/b', blob_b.id, 0o100644),
-          (b'ad/bd/c', blob_c.id, 0o100755),
-          ]
+            (b"a", blob_a.id, 0o100644),
+            (b"ad/b", blob_b.id, 0o100644),
+            (b"ad/bd/c", blob_c.id, 0o100755),
+        ]
         tree_id = commit_tree(self.store, blobs)
         tree = self.store[tree_id]
-        tree_ad = self.store[tree[b'ad'][1]]
-        tree_bd = self.store[tree_ad[b'bd'][1]]
+        tree_ad = self.store[tree[b"ad"][1]]
+        tree_bd = self.store[tree_ad[b"bd"][1]]
 
         expected = [
-          TreeEntry(b'', 0o040000, tree_id),
-          TreeEntry(b'a', 0o100644, blob_a.id),
-          TreeEntry(b'ad', 0o040000, tree_ad.id),
-          TreeEntry(b'ad/b', 0o100644, blob_b.id),
-          TreeEntry(b'ad/bd', 0o040000, tree_bd.id),
-          TreeEntry(b'ad/bd/c', 0o100755, blob_c.id),
-          ]
+            TreeEntry(b"", 0o040000, tree_id),
+            TreeEntry(b"a", 0o100644, blob_a.id),
+            TreeEntry(b"ad", 0o040000, tree_ad.id),
+            TreeEntry(b"ad/b", 0o100644, blob_b.id),
+            TreeEntry(b"ad/bd", 0o040000, tree_bd.id),
+            TreeEntry(b"ad/bd/c", 0o100755, blob_c.id),
+        ]
         actual = self.store.iter_tree_contents(tree_id, include_trees=True)
         self.assertEqual(expected, list(actual))
 
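The two tests above round-trip commit_tree and iter_tree_contents: commit_tree builds nested Tree objects from flat (path, sha, mode) tuples, and iter_tree_contents yields TreeEntry tuples back in path order (optionally including the trees themselves). A minimal standalone sketch of that round trip, assuming only the dulwich APIs exercised in this diff:

    from dulwich.object_store import MemoryObjectStore, commit_tree
    from dulwich.objects import Blob

    store = MemoryObjectStore()
    blob = Blob.from_string(b"a")
    store.add_object(blob)
    # commit_tree takes flat (path, sha, mode) tuples and returns the root
    # tree id, creating the intermediate b"ad" tree on the fly.
    tree_id = commit_tree(
        store, [(b"a", blob.id, 0o100644), (b"ad/b", blob.id, 0o100644)]
    )
    for entry in store.iter_tree_contents(tree_id):
        print(entry.path, entry.mode, entry.sha)
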
@@ -198,16 +254,17 @@ class ObjectStoreTests(object):
 
     def test_peel_sha(self):
         self.store.add_object(testobject)
-        tag1 = self.make_tag(b'1', testobject)
-        tag2 = self.make_tag(b'2', testobject)
-        tag3 = self.make_tag(b'3', testobject)
+        tag1 = self.make_tag(b"1", testobject)
+        tag2 = self.make_tag(b"2", testobject)
+        tag3 = self.make_tag(b"3", testobject)
         for obj in [testobject, tag1, tag2, tag3]:
             self.assertEqual(testobject, self.store.peel_sha(obj.id))
 
     def test_get_raw(self):
         self.store.add_object(testobject)
-        self.assertEqual((Blob.type_num, b'yummy data'),
-                         self.store.get_raw(testobject.id))
+        self.assertEqual(
+            (Blob.type_num, b"yummy data"), self.store.get_raw(testobject.id)
+        )
 
     def test_close(self):
         # For now, just check that close doesn't barf.
@@ -216,7 +273,6 @@ class ObjectStoreTests(object):
 
 
 class OverlayObjectStoreTests(ObjectStoreTests, TestCase):
-
     def setUp(self):
         TestCase.setUp(self)
         self.bases = [MemoryObjectStore(), MemoryObjectStore()]
@@ -224,7 +280,6 @@ class OverlayObjectStoreTests(ObjectStoreTests, TestCase):
 
 
 class MemoryObjectStoreTests(ObjectStoreTests, TestCase):
-
     def setUp(self):
         TestCase.setUp(self)
         self.store = MemoryObjectStore()
@@ -248,17 +303,22 @@ class MemoryObjectStoreTests(ObjectStoreTests, TestCase):
 
     def test_add_thin_pack(self):
         o = MemoryObjectStore()
-        blob = make_object(Blob, data=b'yummy data')
+        blob = make_object(Blob, data=b"yummy data")
         o.add_object(blob)
 
         f = BytesIO()
-        entries = build_pack(f, [
-            (REF_DELTA, (blob.id, b'more yummy data')),
-            ], store=o)
+        entries = build_pack(
+            f,
+            [
+                (REF_DELTA, (blob.id, b"more yummy data")),
+            ],
+            store=o,
+        )
         o.add_thin_pack(f.read, None)
         packed_blob_sha = sha_to_hex(entries[0][3])
-        self.assertEqual((Blob.type_num, b'more yummy data'),
-                         o.get_raw(packed_blob_sha))
+        self.assertEqual(
+            (Blob.type_num, b"more yummy data"), o.get_raw(packed_blob_sha)
+        )
 
     def test_add_thin_pack_empty(self):
         o = MemoryObjectStore()
@@ -270,7 +330,6 @@ class MemoryObjectStoreTests(ObjectStoreTests, TestCase):
 
 
 class PackBasedObjectStoreTests(ObjectStoreTests):
-
     def tearDown(self):
         for pack in self.store.packs:
             pack.close()
@@ -303,8 +362,7 @@ class PackBasedObjectStoreTests(ObjectStoreTests):
         b5 = make_object(Blob, data=b"and more data")
         b6 = make_object(Blob, data=b"and some more data")
         self.store.add_objects([(b5, None), (b6, None)])
-        self.assertEqual({b1.id, b2.id, b3.id, b4.id, b5.id, b6.id},
-                         set(self.store))
+        self.assertEqual({b1.id, b2.id, b3.id, b4.id, b5.id, b6.id}, set(self.store))
         self.assertEqual(2, len(self.store.packs))
         self.assertEqual(6, self.store.repack())
         self.assertEqual(1, len(self.store.packs))
@@ -331,7 +389,6 @@ class PackBasedObjectStoreTests(ObjectStoreTests):
 
 
 class DiskObjectStoreTests(PackBasedObjectStoreTests, TestCase):
-
     def setUp(self):
         TestCase.setUp(self)
         self.store_dir = tempfile.mkdtemp()
@@ -345,8 +402,7 @@ class DiskObjectStoreTests(PackBasedObjectStoreTests, TestCase):
     def test_loose_compression_level(self):
         alternate_dir = tempfile.mkdtemp()
         self.addCleanup(shutil.rmtree, alternate_dir)
-        alternate_store = DiskObjectStore(
-            alternate_dir, loose_compression_level=6)
+        alternate_store = DiskObjectStore(alternate_dir, loose_compression_level=6)
         b2 = make_object(Blob, data=b"yummy data")
         alternate_store.add_object(b2)
 
@@ -365,14 +421,16 @@ class DiskObjectStoreTests(PackBasedObjectStoreTests, TestCase):
     def test_read_alternate_paths(self):
         store = DiskObjectStore(self.store_dir)
 
-        abs_path = os.path.abspath(os.path.normpath('/abspath'))
+        abs_path = os.path.abspath(os.path.normpath("/abspath"))
         # ensures in particular existence of the alternates file
         store.add_alternate_path(abs_path)
         self.assertEqual(set(store._read_alternate_paths()), {abs_path})
 
         store.add_alternate_path("relative-path")
-        self.assertIn(os.path.join(store.path, "relative-path"),
-                      set(store._read_alternate_paths()))
+        self.assertIn(
+            os.path.join(store.path, "relative-path"),
+            set(store._read_alternate_paths()),
+        )
 
         # arguably, add_alternate_path() could strip comments.
         # Meanwhile it's more convenient to use it than to import INFODIR
@@ -383,16 +441,17 @@ class DiskObjectStoreTests(PackBasedObjectStoreTests, TestCase):
     def test_corrupted_object_raise_exception(self):
         """Corrupted sha1 disk file should raise specific exception"""
         self.store.add_object(testobject)
-        self.assertEqual((Blob.type_num, b'yummy data'),
-                         self.store.get_raw(testobject.id))
+        self.assertEqual(
+            (Blob.type_num, b"yummy data"), self.store.get_raw(testobject.id)
+        )
         self.assertTrue(self.store.contains_loose(testobject.id))
         self.assertIsNotNone(self.store._get_loose_object(testobject.id))
 
         path = self.store._get_shafile_path(testobject.id)
-        with open(path, 'wb') as f:  # corrupt the file
-            f.write(b'')
+        with open(path, "wb") as f:  # corrupt the file
+            f.write(b"")
 
-        expected_error_msg = 'Corrupted empty file detected'
+        expected_error_msg = "Corrupted empty file detected"
         try:
             self.store.contains_loose(testobject.id)
         except EmptyFileException as e:
@@ -404,13 +463,11 @@ class DiskObjectStoreTests(PackBasedObjectStoreTests, TestCase):
             self.assertEqual(str(e), expected_error_msg)
 
         # this does not change iteration on loose objects though
-        self.assertEqual([testobject.id],
-                         list(self.store._iter_loose_objects()))
+        self.assertEqual([testobject.id], list(self.store._iter_loose_objects()))
 
     def test_tempfile_in_loose_store(self):
         self.store.add_object(testobject)
-        self.assertEqual([testobject.id],
-                         list(self.store._iter_loose_objects()))
+        self.assertEqual([testobject.id], list(self.store._iter_loose_objects()))
 
         # add temporary files to the loose store
         for i in range(256):
@@ -420,8 +477,7 @@ class DiskObjectStoreTests(PackBasedObjectStoreTests, TestCase):
             fd, n = tempfile.mkstemp(prefix="tmp_obj_", dir=dirname)
             os.close(fd)
 
-        self.assertEqual([testobject.id],
-                         list(self.store._iter_loose_objects()))
+        self.assertEqual([testobject.id], list(self.store._iter_loose_objects()))
 
     def test_add_alternate_path(self):
         store = DiskObjectStore(self.store_dir)
@@ -430,8 +486,8 @@ class DiskObjectStoreTests(PackBasedObjectStoreTests, TestCase):
         self.assertEqual(["/foo/path"], list(store._read_alternate_paths()))
         store.add_alternate_path("/bar/path")
         self.assertEqual(
-            ["/foo/path", "/bar/path"],
-            list(store._read_alternate_paths()))
+            ["/foo/path", "/bar/path"], list(store._read_alternate_paths())
+        )
 
     def test_rel_alternative_path(self):
         alternate_dir = tempfile.mkdtemp()
@@ -441,8 +497,7 @@ class DiskObjectStoreTests(PackBasedObjectStoreTests, TestCase):
         alternate_store.add_object(b2)
         store = DiskObjectStore(self.store_dir)
         self.assertRaises(KeyError, store.__getitem__, b2.id)
-        store.add_alternate_path(
-            os.path.relpath(alternate_dir, self.store_dir))
+        store.add_alternate_path(os.path.relpath(alternate_dir, self.store_dir))
         self.assertEqual(list(alternate_store), list(store.alternates[0]))
         self.assertIn(b2.id, store)
         self.assertEqual(b2, store[b2.id])
@@ -466,23 +521,28 @@ class DiskObjectStoreTests(PackBasedObjectStoreTests, TestCase):
     def test_add_thin_pack(self):
         o = DiskObjectStore(self.store_dir)
         try:
-            blob = make_object(Blob, data=b'yummy data')
+            blob = make_object(Blob, data=b"yummy data")
             o.add_object(blob)
 
             f = BytesIO()
-            entries = build_pack(f, [
-              (REF_DELTA, (blob.id, b'more yummy data')),
-              ], store=o)
+            entries = build_pack(
+                f,
+                [
+                    (REF_DELTA, (blob.id, b"more yummy data")),
+                ],
+                store=o,
+            )
 
             with o.add_thin_pack(f.read, None) as pack:
                 packed_blob_sha = sha_to_hex(entries[0][3])
                 pack.check_length_and_checksum()
-                self.assertEqual(
-                    sorted([blob.id, packed_blob_sha]), list(pack))
+                self.assertEqual(sorted([blob.id, packed_blob_sha]), list(pack))
                 self.assertTrue(o.contains_packed(packed_blob_sha))
                 self.assertTrue(o.contains_packed(blob.id))
-                self.assertEqual((Blob.type_num, b'more yummy data'),
-                                 o.get_raw(packed_blob_sha))
+                self.assertEqual(
+                    (Blob.type_num, b"more yummy data"),
+                    o.get_raw(packed_blob_sha),
+                )
         finally:
             o.close()
 
@@ -495,58 +555,62 @@ class DiskObjectStoreTests(PackBasedObjectStoreTests, TestCase):
 
 
 class TreeLookupPathTests(TestCase):
-
     def setUp(self):
         TestCase.setUp(self)
         self.store = MemoryObjectStore()
-        blob_a = make_object(Blob, data=b'a')
-        blob_b = make_object(Blob, data=b'b')
-        blob_c = make_object(Blob, data=b'c')
+        blob_a = make_object(Blob, data=b"a")
+        blob_b = make_object(Blob, data=b"b")
+        blob_c = make_object(Blob, data=b"c")
         for blob in [blob_a, blob_b, blob_c]:
             self.store.add_object(blob)
 
         blobs = [
-          (b'a', blob_a.id, 0o100644),
-          (b'ad/b', blob_b.id, 0o100644),
-          (b'ad/bd/c', blob_c.id, 0o100755),
-          (b'ad/c', blob_c.id, 0o100644),
-          (b'c', blob_c.id, 0o100644),
-          ]
+            (b"a", blob_a.id, 0o100644),
+            (b"ad/b", blob_b.id, 0o100644),
+            (b"ad/bd/c", blob_c.id, 0o100755),
+            (b"ad/c", blob_c.id, 0o100644),
+            (b"c", blob_c.id, 0o100644),
+        ]
         self.tree_id = commit_tree(self.store, blobs)
 
     def get_object(self, sha):
         return self.store[sha]
 
     def test_lookup_blob(self):
-        o_id = tree_lookup_path(self.get_object, self.tree_id, b'a')[1]
+        o_id = tree_lookup_path(self.get_object, self.tree_id, b"a")[1]
         self.assertTrue(isinstance(self.store[o_id], Blob))
 
     def test_lookup_tree(self):
-        o_id = tree_lookup_path(self.get_object, self.tree_id, b'ad')[1]
+        o_id = tree_lookup_path(self.get_object, self.tree_id, b"ad")[1]
         self.assertTrue(isinstance(self.store[o_id], Tree))
-        o_id = tree_lookup_path(self.get_object, self.tree_id, b'ad/bd')[1]
+        o_id = tree_lookup_path(self.get_object, self.tree_id, b"ad/bd")[1]
         self.assertTrue(isinstance(self.store[o_id], Tree))
-        o_id = tree_lookup_path(self.get_object, self.tree_id, b'ad/bd/')[1]
+        o_id = tree_lookup_path(self.get_object, self.tree_id, b"ad/bd/")[1]
         self.assertTrue(isinstance(self.store[o_id], Tree))
 
     def test_lookup_nonexistent(self):
         self.assertRaises(
-            KeyError, tree_lookup_path, self.get_object, self.tree_id, b'j')
+            KeyError, tree_lookup_path, self.get_object, self.tree_id, b"j"
+        )
 
     def test_lookup_not_tree(self):
         self.assertRaises(
-            NotTreeError, tree_lookup_path, self.get_object, self.tree_id,
-            b'ad/b/j')
+            NotTreeError,
+            tree_lookup_path,
+            self.get_object,
+            self.tree_id,
+            b"ad/b/j",
+        )
 
 
 class ObjectStoreGraphWalkerTests(TestCase):
-
     def get_walker(self, heads, parent_map):
         new_parent_map = dict(
-                [(k * 40, [(p * 40) for p in ps])
-                 for (k, ps) in parent_map.items()])
-        return ObjectStoreGraphWalker([x * 40 for x in heads],
-                                      new_parent_map.__getitem__)
+            [(k * 40, [(p * 40) for p in ps]) for (k, ps) in parent_map.items()]
+        )
+        return ObjectStoreGraphWalker(
+            [x * 40 for x in heads], new_parent_map.__getitem__
+        )
 
     def test_ack_invalid_value(self):
         gw = self.get_walker([], {})
@@ -587,13 +651,16 @@ class ObjectStoreGraphWalkerTests(TestCase):
         # c  d
         # \ /
         #  e
-        gw = self.get_walker([b"a", b"b"], {
+        gw = self.get_walker(
+            [b"a", b"b"],
+            {
                 b"a": [b"c"],
                 b"b": [b"d"],
                 b"c": [b"e"],
                 b"d": [b"e"],
                 b"e": [],
-                })
+            },
+        )
         walk = []
         acked = False
         walk.append(next(gw))
@@ -612,86 +679,106 @@ class ObjectStoreGraphWalkerTests(TestCase):
         walk.append(next(gw))
         self.assertIs(None, next(gw))
 
-        self.assertEqual([b"a" * 40, b"b" * 40, b"c" * 40, b"d" * 40],
-                         sorted(walk))
+        self.assertEqual([b"a" * 40, b"b" * 40, b"c" * 40, b"d" * 40], sorted(walk))
         self.assertLess(walk.index(b"a" * 40), walk.index(b"c" * 40))
         self.assertLess(walk.index(b"b" * 40), walk.index(b"d" * 40))
 
 
 class CommitTreeChangesTests(TestCase):
-
     def setUp(self):
         super(CommitTreeChangesTests, self).setUp()
         self.store = MemoryObjectStore()
-        self.blob_a = make_object(Blob, data=b'a')
-        self.blob_b = make_object(Blob, data=b'b')
-        self.blob_c = make_object(Blob, data=b'c')
+        self.blob_a = make_object(Blob, data=b"a")
+        self.blob_b = make_object(Blob, data=b"b")
+        self.blob_c = make_object(Blob, data=b"c")
         for blob in [self.blob_a, self.blob_b, self.blob_c]:
             self.store.add_object(blob)
 
         blobs = [
-          (b'a', self.blob_a.id, 0o100644),
-          (b'ad/b', self.blob_b.id, 0o100644),
-          (b'ad/bd/c', self.blob_c.id, 0o100755),
-          (b'ad/c', self.blob_c.id, 0o100644),
-          (b'c', self.blob_c.id, 0o100644),
-          ]
+            (b"a", self.blob_a.id, 0o100644),
+            (b"ad/b", self.blob_b.id, 0o100644),
+            (b"ad/bd/c", self.blob_c.id, 0o100755),
+            (b"ad/c", self.blob_c.id, 0o100644),
+            (b"c", self.blob_c.id, 0o100644),
+        ]
         self.tree_id = commit_tree(self.store, blobs)
 
     def test_no_changes(self):
         self.assertEqual(
-                self.store[self.tree_id],
-                commit_tree_changes(self.store, self.store[self.tree_id], []))
+            self.store[self.tree_id],
+            commit_tree_changes(self.store, self.store[self.tree_id], []),
+        )
 
     def test_add_blob(self):
-        blob_d = make_object(Blob, data=b'd')
+        blob_d = make_object(Blob, data=b"d")
         new_tree = commit_tree_changes(
-                self.store, self.store[self.tree_id], [
-                    (b'd', 0o100644, blob_d.id)])
+            self.store, self.store[self.tree_id], [(b"d", 0o100644, blob_d.id)]
+        )
         self.assertEqual(
-            new_tree[b'd'],
-            (33188, b'c59d9b6344f1af00e504ba698129f07a34bbed8d'))
+            new_tree[b"d"],
+            (33188, b"c59d9b6344f1af00e504ba698129f07a34bbed8d"),
+        )
 
     def test_add_blob_in_dir(self):
-        blob_d = make_object(Blob, data=b'd')
+        blob_d = make_object(Blob, data=b"d")
         new_tree = commit_tree_changes(
-                self.store, self.store[self.tree_id], [
-                    (b'e/f/d', 0o100644, blob_d.id)])
+            self.store,
+            self.store[self.tree_id],
+            [(b"e/f/d", 0o100644, blob_d.id)],
+        )
         self.assertEqual(
-            new_tree.items(), [
-                TreeEntry(path=b'a', mode=stat.S_IFREG | 0o100644,
-                          sha=self.blob_a.id),
-                TreeEntry(path=b'ad', mode=stat.S_IFDIR,
-                          sha=b'0e2ce2cd7725ff4817791be31ccd6e627e801f4a'),
-                TreeEntry(path=b'c', mode=stat.S_IFREG | 0o100644,
-                          sha=self.blob_c.id),
-                TreeEntry(path=b'e', mode=stat.S_IFDIR,
-                          sha=b'6ab344e288724ac2fb38704728b8896e367ed108')
-                ])
-        e_tree = self.store[new_tree[b'e'][1]]
+            new_tree.items(),
+            [
+                TreeEntry(path=b"a", mode=stat.S_IFREG | 0o100644, sha=self.blob_a.id),
+                TreeEntry(
+                    path=b"ad",
+                    mode=stat.S_IFDIR,
+                    sha=b"0e2ce2cd7725ff4817791be31ccd6e627e801f4a",
+                ),
+                TreeEntry(path=b"c", mode=stat.S_IFREG | 0o100644, sha=self.blob_c.id),
+                TreeEntry(
+                    path=b"e",
+                    mode=stat.S_IFDIR,
+                    sha=b"6ab344e288724ac2fb38704728b8896e367ed108",
+                ),
+            ],
+        )
+        e_tree = self.store[new_tree[b"e"][1]]
         self.assertEqual(
-            e_tree.items(), [
-                TreeEntry(path=b'f', mode=stat.S_IFDIR,
-                          sha=b'24d2c94d8af232b15a0978c006bf61ef4479a0a5')
-                ])
-        f_tree = self.store[e_tree[b'f'][1]]
+            e_tree.items(),
+            [
+                TreeEntry(
+                    path=b"f",
+                    mode=stat.S_IFDIR,
+                    sha=b"24d2c94d8af232b15a0978c006bf61ef4479a0a5",
+                )
+            ],
+        )
+        f_tree = self.store[e_tree[b"f"][1]]
         self.assertEqual(
-            f_tree.items(), [
-                TreeEntry(path=b'd', mode=stat.S_IFREG | 0o100644,
-                          sha=blob_d.id)
-                ])
+            f_tree.items(),
+            [TreeEntry(path=b"d", mode=stat.S_IFREG | 0o100644, sha=blob_d.id)],
+        )
 
     def test_delete_blob(self):
         new_tree = commit_tree_changes(
-                self.store, self.store[self.tree_id], [
-                    (b'ad/bd/c', None, None)])
-        self.assertEqual(set(new_tree), {b'a', b'ad', b'c'})
-        ad_tree = self.store[new_tree[b'ad'][1]]
-        self.assertEqual(set(ad_tree), {b'b', b'c'})
+            self.store, self.store[self.tree_id], [(b"ad/bd/c", None, None)]
+        )
+        self.assertEqual(set(new_tree), {b"a", b"ad", b"c"})
+        ad_tree = self.store[new_tree[b"ad"][1]]
+        self.assertEqual(set(ad_tree), {b"b", b"c"})
 
 
 class TestReadPacksFile(TestCase):
-
     def test_read_packs(self):
-        self.assertEqual(["pack-1.pack"], list(read_packs_file(BytesIO(b"""P pack-1.pack
-"""))))
+        self.assertEqual(
+            ["pack-1.pack"],
+            list(
+                read_packs_file(
+                    BytesIO(
+                        b"""P pack-1.pack
+"""
+                    )
+                )
+            ),
+        )

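The new test_determine_wants_all_depth and test_get_depth cases above cover the depth-aware negotiation added in this release: depth=0 adds no deepening, DEPTH_INFINITE always re-requests (bypassing the depth computation), and intermediate depths consult _get_depth, which counts parent hops via a get_parents callback. A hedged sketch of the two cases that never touch _get_depth, assuming the 0.20.23 API shown in this diff:

    from dulwich.object_store import MemoryObjectStore
    from dulwich.objects import Blob
    from dulwich.protocol import DEPTH_INFINITE

    store = MemoryObjectStore()
    blob = Blob.from_string(b"yummy data")
    store.add_object(blob)
    refs = {b"refs/heads/foo": blob.id}

    # depth=0 adds no deepening, so a locally present ref is not wanted;
    # DEPTH_INFINITE re-requests it to unshallow fully.
    assert store.determine_wants_all(refs, depth=0) == []
    assert store.determine_wants_all(refs, depth=DEPTH_INFINITE) == [blob.id]
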
File diff suppressed because it is too large
+ 316 - 284
dulwich/tests/test_objects.py


+ 78 - 56
dulwich/tests/test_objectspec.py

@@ -25,7 +25,7 @@
 
 from dulwich.objects import (
     Blob,
-    )
+)
 from dulwich.objectspec import (
     parse_object,
     parse_commit,
@@ -35,14 +35,14 @@ from dulwich.objectspec import (
     parse_reftuple,
     parse_reftuples,
     parse_tree,
-    )
+)
 from dulwich.repo import MemoryRepo
 from dulwich.tests import (
     TestCase,
-    )
+)
 from dulwich.tests.utils import (
     build_commit_graph,
-    )
+)
 
 
 class ParseObjectTests(TestCase):
@@ -68,8 +68,7 @@ class ParseCommitRangeTests(TestCase):
 
     def test_commit_by_sha(self):
         r = MemoryRepo()
-        c1, c2, c3 = build_commit_graph(
-                r.object_store, [[1], [2, 1], [3, 1, 2]])
+        c1, c2, c3 = build_commit_graph(r.object_store, [[1], [2, 1], [3, 1, 2]])
         self.assertEqual([c1], list(parse_commit_range(r, c1.id)))
 
 
@@ -92,44 +91,50 @@ class ParseCommitTests(TestCase):
 
 
 class ParseRefTests(TestCase):
-
     def test_nonexistent(self):
         r = {}
         self.assertRaises(KeyError, parse_ref, r, b"thisdoesnotexist")
 
     def test_ambiguous_ref(self):
-        r = {b"ambig1": 'bla',
-             b"refs/ambig1": 'bla',
-             b"refs/tags/ambig1": 'bla',
-             b"refs/heads/ambig1": 'bla',
-             b"refs/remotes/ambig1": 'bla',
-             b"refs/remotes/ambig1/HEAD": "bla"}
+        r = {
+            b"ambig1": "bla",
+            b"refs/ambig1": "bla",
+            b"refs/tags/ambig1": "bla",
+            b"refs/heads/ambig1": "bla",
+            b"refs/remotes/ambig1": "bla",
+            b"refs/remotes/ambig1/HEAD": "bla",
+        }
         self.assertEqual(b"ambig1", parse_ref(r, b"ambig1"))
 
     def test_ambiguous_ref2(self):
-        r = {b"refs/ambig2": 'bla',
-             b"refs/tags/ambig2": 'bla',
-             b"refs/heads/ambig2": 'bla',
-             b"refs/remotes/ambig2": 'bla',
-             b"refs/remotes/ambig2/HEAD": "bla"}
+        r = {
+            b"refs/ambig2": "bla",
+            b"refs/tags/ambig2": "bla",
+            b"refs/heads/ambig2": "bla",
+            b"refs/remotes/ambig2": "bla",
+            b"refs/remotes/ambig2/HEAD": "bla",
+        }
         self.assertEqual(b"refs/ambig2", parse_ref(r, b"ambig2"))
 
     def test_ambiguous_tag(self):
-        r = {b"refs/tags/ambig3": 'bla',
-             b"refs/heads/ambig3": 'bla',
-             b"refs/remotes/ambig3": 'bla',
-             b"refs/remotes/ambig3/HEAD": "bla"}
+        r = {
+            b"refs/tags/ambig3": "bla",
+            b"refs/heads/ambig3": "bla",
+            b"refs/remotes/ambig3": "bla",
+            b"refs/remotes/ambig3/HEAD": "bla",
+        }
         self.assertEqual(b"refs/tags/ambig3", parse_ref(r, b"ambig3"))
 
     def test_ambiguous_head(self):
-        r = {b"refs/heads/ambig4": 'bla',
-             b"refs/remotes/ambig4": 'bla',
-             b"refs/remotes/ambig4/HEAD": "bla"}
+        r = {
+            b"refs/heads/ambig4": "bla",
+            b"refs/remotes/ambig4": "bla",
+            b"refs/remotes/ambig4/HEAD": "bla",
+        }
         self.assertEqual(b"refs/heads/ambig4", parse_ref(r, b"ambig4"))
 
     def test_ambiguous_remote(self):
-        r = {b"refs/remotes/ambig5": 'bla',
-             b"refs/remotes/ambig5/HEAD": "bla"}
+        r = {b"refs/remotes/ambig5": "bla", b"refs/remotes/ambig5/HEAD": "bla"}
         self.assertEqual(b"refs/remotes/ambig5", parse_ref(r, b"ambig5"))
 
     def test_ambiguous_remote_head(self):
@@ -150,7 +155,6 @@ class ParseRefTests(TestCase):
 
 
 class ParseRefsTests(TestCase):
-
     def test_nonexistent(self):
         r = {}
         self.assertRaises(KeyError, parse_refs, r, [b"thisdoesnotexist"])
@@ -165,62 +169,81 @@ class ParseRefsTests(TestCase):
 
 
 class ParseReftupleTests(TestCase):
-
     def test_nonexistent(self):
         r = {}
         self.assertRaises(KeyError, parse_reftuple, r, r, b"thisdoesnotexist")
 
     def test_head(self):
         r = {b"refs/heads/foo": "bla"}
-        self.assertEqual((b"refs/heads/foo", b"refs/heads/foo", False),
-                         parse_reftuple(r, r, b"foo"))
-        self.assertEqual((b"refs/heads/foo", b"refs/heads/foo", True),
-                         parse_reftuple(r, r, b"+foo"))
-        self.assertEqual((b"refs/heads/foo", b"refs/heads/foo", True),
-                         parse_reftuple(r, {}, b"+foo"))
-        self.assertEqual((b"refs/heads/foo", b"refs/heads/foo", True),
-                         parse_reftuple(r, {}, b"foo", True))
+        self.assertEqual(
+            (b"refs/heads/foo", b"refs/heads/foo", False),
+            parse_reftuple(r, r, b"foo"),
+        )
+        self.assertEqual(
+            (b"refs/heads/foo", b"refs/heads/foo", True),
+            parse_reftuple(r, r, b"+foo"),
+        )
+        self.assertEqual(
+            (b"refs/heads/foo", b"refs/heads/foo", True),
+            parse_reftuple(r, {}, b"+foo"),
+        )
+        self.assertEqual(
+            (b"refs/heads/foo", b"refs/heads/foo", True),
+            parse_reftuple(r, {}, b"foo", True),
+        )
 
     def test_full(self):
         r = {b"refs/heads/foo": "bla"}
-        self.assertEqual((b"refs/heads/foo", b"refs/heads/foo", False),
-                         parse_reftuple(r, r, b"refs/heads/foo"))
+        self.assertEqual(
+            (b"refs/heads/foo", b"refs/heads/foo", False),
+            parse_reftuple(r, r, b"refs/heads/foo"),
+        )
 
     def test_no_left_ref(self):
         r = {b"refs/heads/foo": "bla"}
-        self.assertEqual((None, b"refs/heads/foo", False),
-                         parse_reftuple(r, r, b":refs/heads/foo"))
+        self.assertEqual(
+            (None, b"refs/heads/foo", False),
+            parse_reftuple(r, r, b":refs/heads/foo"),
+        )
 
     def test_no_right_ref(self):
         r = {b"refs/heads/foo": "bla"}
-        self.assertEqual((b"refs/heads/foo", None, False),
-                         parse_reftuple(r, r, b"refs/heads/foo:"))
+        self.assertEqual(
+            (b"refs/heads/foo", None, False),
+            parse_reftuple(r, r, b"refs/heads/foo:"),
+        )
 
     def test_default_with_string(self):
         r = {b"refs/heads/foo": "bla"}
-        self.assertEqual((b"refs/heads/foo", b"refs/heads/foo", False),
-                         parse_reftuple(r, r, "foo"))
+        self.assertEqual(
+            (b"refs/heads/foo", b"refs/heads/foo", False),
+            parse_reftuple(r, r, "foo"),
+        )
 
 
 class ParseReftuplesTests(TestCase):
-
     def test_nonexistent(self):
         r = {}
-        self.assertRaises(KeyError, parse_reftuples, r, r,
-                          [b"thisdoesnotexist"])
+        self.assertRaises(KeyError, parse_reftuples, r, r, [b"thisdoesnotexist"])
 
     def test_head(self):
         r = {b"refs/heads/foo": "bla"}
-        self.assertEqual([(b"refs/heads/foo", b"refs/heads/foo", False)],
-                         parse_reftuples(r, r, [b"foo"]))
+        self.assertEqual(
+            [(b"refs/heads/foo", b"refs/heads/foo", False)],
+            parse_reftuples(r, r, [b"foo"]),
+        )
 
     def test_full(self):
         r = {b"refs/heads/foo": "bla"}
-        self.assertEqual([(b"refs/heads/foo", b"refs/heads/foo", False)],
-                         parse_reftuples(r, r, b"refs/heads/foo"))
+        self.assertEqual(
+            [(b"refs/heads/foo", b"refs/heads/foo", False)],
+            parse_reftuples(r, r, b"refs/heads/foo"),
+        )
         r = {b"refs/heads/foo": "bla"}
-        self.assertEqual([(b"refs/heads/foo", b"refs/heads/foo", True)],
-                         parse_reftuples(r, r, b"refs/heads/foo", True))
+        self.assertEqual(
+            [(b"refs/heads/foo", b"refs/heads/foo", True)],
+            parse_reftuples(r, r, b"refs/heads/foo", True),
+        )
 
 
 class ParseTreeTests(TestCase):
@@ -232,7 +255,6 @@ class ParseTreeTests(TestCase):
 
     def test_from_commit(self):
         r = MemoryRepo()
-        c1, c2, c3 = build_commit_graph(
-                r.object_store, [[1], [2, 1], [3, 1, 2]])
+        c1, c2, c3 = build_commit_graph(r.object_store, [[1], [2, 1], [3, 1, 2]])
         self.assertEqual(r[c1.tree], parse_tree(r, c1.id))
         self.assertEqual(r[c1.tree], parse_tree(r, c1.tree))

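The ParseReftupleTests above pin down dulwich's refspec shorthand: a bare name expands to refs/heads/..., a leading "+" (or force=True) sets the force flag, and an empty side of a src:dst pair comes back as None. A small sketch mirroring those assertions, assuming the dulwich.objectspec API from this diff:

    from dulwich.objectspec import parse_reftuple

    refs = {b"refs/heads/foo": "bla"}
    assert parse_reftuple(refs, refs, b"foo") == (
        b"refs/heads/foo", b"refs/heads/foo", False)
    assert parse_reftuple(refs, refs, b"+foo") == (
        b"refs/heads/foo", b"refs/heads/foo", True)
    assert parse_reftuple(refs, refs, b":refs/heads/foo") == (
        None, b"refs/heads/foo", False)
    assert parse_reftuple(refs, refs, b"refs/heads/foo:") == (
        b"refs/heads/foo", None, False)
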
File diff suppressed because it is too large
+ 313 - 250
dulwich/tests/test_pack.py


+ 311 - 216
dulwich/tests/test_patch.py

@@ -27,10 +27,10 @@ from dulwich.objects import (
     Commit,
     S_IFGITLINK,
     Tree,
-    )
+)
 from dulwich.object_store import (
     MemoryObjectStore,
-    )
+)
 from dulwich.patch import (
     get_summary,
     git_am_patch_split,
@@ -38,15 +38,14 @@ from dulwich.patch import (
     write_commit_patch,
     write_object_diff,
     write_tree_diff,
-    )
+)
 from dulwich.tests import (
     SkipTest,
     TestCase,
-    )
+)
 
 
 class WriteCommitPatchTests(TestCase):
-
     def test_simple_bytesio(self):
         f = BytesIO()
         c = Commit()
@@ -58,26 +57,28 @@ class WriteCommitPatchTests(TestCase):
         write_commit_patch(f, c, b"CONTENTS", (1, 1), version="custom")
         f.seek(0)
         lines = f.readlines()
-        self.assertTrue(lines[0].startswith(
-                    b"From 0b0d34d1b5b596c928adc9a727a4b9e03d025298"))
+        self.assertTrue(
+            lines[0].startswith(b"From 0b0d34d1b5b596c928adc9a727a4b9e03d025298")
+        )
         self.assertEqual(lines[1], b"From: Jelmer <jelmer@samba.org>\n")
         self.assertTrue(lines[2].startswith(b"Date: "))
-        self.assertEqual([
-            b"Subject: [PATCH 1/1] This is the first line\n",
-            b"And this is the second line.\n",
-            b"\n",
-            b"\n",
-            b"---\n"], lines[3:8])
-        self.assertEqual([
-            b"CONTENTS-- \n",
-            b"custom\n"], lines[-2:])
+        self.assertEqual(
+            [
+                b"Subject: [PATCH 1/1] This is the first line\n",
+                b"And this is the second line.\n",
+                b"\n",
+                b"\n",
+                b"---\n",
+            ],
+            lines[3:8],
+        )
+        self.assertEqual([b"CONTENTS-- \n", b"custom\n"], lines[-2:])
         if len(lines) >= 12:
             # diffstat may not be present
             self.assertEqual(lines[8], b" 0 files changed\n")
 
 
 class ReadGitAmPatch(TestCase):
-
     def test_extract_string(self):
         text = b"""\
 From ff643aae102d8870cac88e8f007e70f58f3a7363 Mon Sep 17 00:00:00 2001
@@ -93,17 +94,21 @@ Subject: [PATCH 1/2] Remove executable bit from prey.ico (triggers a warning).
 -- 
 1.7.0.4
 """  # noqa: W291
-        c, diff, version = git_am_patch_split(
-                StringIO(text.decode("utf-8")), "utf-8")
+        c, diff, version = git_am_patch_split(StringIO(text.decode("utf-8")), "utf-8")
         self.assertEqual(b"Jelmer Vernooij <jelmer@samba.org>", c.committer)
         self.assertEqual(b"Jelmer Vernooij <jelmer@samba.org>", c.author)
-        self.assertEqual(b"Remove executable bit from prey.ico "
-                         b"(triggers a warning).\n", c.message)
-        self.assertEqual(b""" pixmaps/prey.ico |  Bin 9662 -> 9662 bytes
+        self.assertEqual(
+            b"Remove executable bit from prey.ico " b"(triggers a warning).\n",
+            c.message,
+        )
+        self.assertEqual(
+            b""" pixmaps/prey.ico |  Bin 9662 -> 9662 bytes
  1 files changed, 0 insertions(+), 0 deletions(-)
  mode change 100755 => 100644 pixmaps/prey.ico
 
-""", diff)
+""",
+            diff,
+        )
         self.assertEqual(b"1.7.0.4", version)
 
     def test_extract_bytes(self):
@@ -124,13 +129,18 @@ Subject: [PATCH 1/2] Remove executable bit from prey.ico (triggers a warning).
         c, diff, version = git_am_patch_split(BytesIO(text))
         self.assertEqual(b"Jelmer Vernooij <jelmer@samba.org>", c.committer)
         self.assertEqual(b"Jelmer Vernooij <jelmer@samba.org>", c.author)
-        self.assertEqual(b"Remove executable bit from prey.ico "
-                         b"(triggers a warning).\n", c.message)
-        self.assertEqual(b""" pixmaps/prey.ico |  Bin 9662 -> 9662 bytes
+        self.assertEqual(
+            b"Remove executable bit from prey.ico " b"(triggers a warning).\n",
+            c.message,
+        )
+        self.assertEqual(
+            b""" pixmaps/prey.ico |  Bin 9662 -> 9662 bytes
  1 files changed, 0 insertions(+), 0 deletions(-)
  mode change 100755 => 100644 pixmaps/prey.ico
 
-""", diff)
+""",
+            diff,
+        )
         self.assertEqual(b"1.7.0.4", version)
 
     def test_extract_spaces(self):
@@ -152,13 +162,16 @@ Subject:  [Dulwich-users] [PATCH] Added unit tests for
 1.7.0.4
 """  # noqa: W291
         c, diff, version = git_am_patch_split(BytesIO(text), "utf-8")
-        self.assertEqual(b'''\
+        self.assertEqual(
+            b"""\
 Added unit tests for dulwich.object_store.tree_lookup_path.
 
 * dulwich/tests/test_object_store.py
   (TreeLookupPathTests): This test case contains a few tests that ensure the
    tree_lookup_path function works as expected.
-''', c.message)
+""",
+            c.message,
+        )
 
     def test_extract_pseudo_from_header(self):
         text = b"""From ff643aae102d8870cac88e8f007e70f58f3a7363 Mon Sep 17 00:00:00 2001
@@ -182,13 +195,16 @@ From: Jelmer Vernooij <jelmer@debian.org>
 """  # noqa: W291
         c, diff, version = git_am_patch_split(BytesIO(text), "utf-8")
         self.assertEqual(b"Jelmer Vernooij <jelmer@debian.org>", c.author)
-        self.assertEqual(b'''\
+        self.assertEqual(
+            b"""\
 Added unit tests for dulwich.object_store.tree_lookup_path.
 
 * dulwich/tests/test_object_store.py
   (TreeLookupPathTests): This test case contains a few tests that ensure the
    tree_lookup_path function works as expected.
-''', c.message)
+""",
+            c.message,
+        )
 
     def test_extract_no_version_tail(self):
         text = b"""\
@@ -211,8 +227,8 @@ From: Jelmer Vernooij <jelmer@debian.org>
 
     def test_extract_mercurial(self):
         raise SkipTest(
-                "git_am_patch_split doesn't handle Mercurial patches "
-                "properly yet")
+            "git_am_patch_split doesn't handle Mercurial patches " "properly yet"
+        )
         expected_diff = """\
 diff --git a/dulwich/tests/test_patch.py b/dulwich/tests/test_patch.py
 --- a/dulwich/tests/test_patch.py
@@ -227,7 +243,8 @@ diff --git a/dulwich/tests/test_patch.py b/dulwich/tests/test_patch.py
  
  class DiffTests(TestCase):
 """  # noqa: W291,W293
-        text = """\
+        text = (
+            """\
 From dulwich-users-bounces+jelmer=samba.org@lists.launchpad.net \
 Mon Nov 29 00:58:18 2010
 Date: Sun, 28 Nov 2010 17:57:27 -0600
@@ -246,7 +263,9 @@ Post to     : dulwich-users@lists.launchpad.net
 Unsubscribe : https://launchpad.net/~dulwich-users
 More help   : https://help.launchpad.net/ListHelp
 
-""" % expected_diff  # noqa: W291
+"""
+            % expected_diff
+        )  # noqa: W291
         c, diff, version = git_am_patch_split(BytesIO(text))
         self.assertEqual(expected_diff, diff)
         self.assertEqual(None, version)
@@ -258,50 +277,65 @@ class DiffTests(TestCase):
     def test_blob_diff(self):
         f = BytesIO()
         write_blob_diff(
-            f, (b"foo.txt", 0o644, Blob.from_string(b"old\nsame\n")),
-            (b"bar.txt", 0o644, Blob.from_string(b"new\nsame\n")))
-        self.assertEqual([
-            b"diff --git a/foo.txt b/bar.txt",
-            b"index 3b0f961..a116b51 644",
-            b"--- a/foo.txt",
-            b"+++ b/bar.txt",
-            b"@@ -1,2 +1,2 @@",
-            b"-old",
-            b"+new",
-            b" same"
-            ], f.getvalue().splitlines())
+            f,
+            (b"foo.txt", 0o644, Blob.from_string(b"old\nsame\n")),
+            (b"bar.txt", 0o644, Blob.from_string(b"new\nsame\n")),
+        )
+        self.assertEqual(
+            [
+                b"diff --git a/foo.txt b/bar.txt",
+                b"index 3b0f961..a116b51 644",
+                b"--- a/foo.txt",
+                b"+++ b/bar.txt",
+                b"@@ -1,2 +1,2 @@",
+                b"-old",
+                b"+new",
+                b" same",
+            ],
+            f.getvalue().splitlines(),
+        )
 
     def test_blob_add(self):
         f = BytesIO()
         write_blob_diff(
-            f, (None, None, None),
-            (b"bar.txt", 0o644, Blob.from_string(b"new\nsame\n")))
-        self.assertEqual([
-             b'diff --git a/bar.txt b/bar.txt',
-             b'new file mode 644',
-             b'index 0000000..a116b51',
-             b'--- /dev/null',
-             b'+++ b/bar.txt',
-             b'@@ -0,0 +1,2 @@',
-             b'+new',
-             b'+same'
-            ], f.getvalue().splitlines())
+            f,
+            (None, None, None),
+            (b"bar.txt", 0o644, Blob.from_string(b"new\nsame\n")),
+        )
+        self.assertEqual(
+            [
+                b"diff --git a/bar.txt b/bar.txt",
+                b"new file mode 644",
+                b"index 0000000..a116b51",
+                b"--- /dev/null",
+                b"+++ b/bar.txt",
+                b"@@ -0,0 +1,2 @@",
+                b"+new",
+                b"+same",
+            ],
+            f.getvalue().splitlines(),
+        )
 
     def test_blob_remove(self):
         f = BytesIO()
         write_blob_diff(
-            f, (b"bar.txt", 0o644, Blob.from_string(b"new\nsame\n")),
-            (None, None, None))
-        self.assertEqual([
-            b'diff --git a/bar.txt b/bar.txt',
-            b'deleted file mode 644',
-            b'index a116b51..0000000',
-            b'--- a/bar.txt',
-            b'+++ /dev/null',
-            b'@@ -1,2 +0,0 @@',
-            b'-new',
-            b'-same'
-            ], f.getvalue().splitlines())
+            f,
+            (b"bar.txt", 0o644, Blob.from_string(b"new\nsame\n")),
+            (None, None, None),
+        )
+        self.assertEqual(
+            [
+                b"diff --git a/bar.txt b/bar.txt",
+                b"deleted file mode 644",
+                b"index a116b51..0000000",
+                b"--- a/bar.txt",
+                b"+++ /dev/null",
+                b"@@ -1,2 +0,0 @@",
+                b"-new",
+                b"-same",
+            ],
+            f.getvalue().splitlines(),
+        )
 
     def test_tree_diff(self):
         f = BytesIO()
@@ -319,54 +353,78 @@ class DiffTests(TestCase):
         tree2.add(b"added.txt", 0o644, added.id)
         tree2.add(b"changed.txt", 0o644, changed2.id)
         tree2.add(b"unchanged.txt", 0o644, changed1.id)
-        store.add_objects([(o, None) for o in [
-            tree1, tree2, added, removed, changed1, changed2, unchanged]])
+        store.add_objects(
+            [
+                (o, None)
+                for o in [
+                    tree1,
+                    tree2,
+                    added,
+                    removed,
+                    changed1,
+                    changed2,
+                    unchanged,
+                ]
+            ]
+        )
         write_tree_diff(f, store, tree1.id, tree2.id)
-        self.assertEqual([
-            b'diff --git a/added.txt b/added.txt',
-            b'new file mode 644',
-            b'index 0000000..76d4bb8',
-            b'--- /dev/null',
-            b'+++ b/added.txt',
-            b'@@ -0,0 +1 @@',
-            b'+add',
-            b'diff --git a/changed.txt b/changed.txt',
-            b'index bf84e48..1be2436 644',
-            b'--- a/changed.txt',
-            b'+++ b/changed.txt',
-            b'@@ -1,2 +1,2 @@',
-            b' unchanged',
-            b'-removed',
-            b'+added',
-            b'diff --git a/removed.txt b/removed.txt',
-            b'deleted file mode 644',
-            b'index 2c3f0b3..0000000',
-            b'--- a/removed.txt',
-            b'+++ /dev/null',
-            b'@@ -1 +0,0 @@',
-            b'-removed',
-            ], f.getvalue().splitlines())
+        self.assertEqual(
+            [
+                b"diff --git a/added.txt b/added.txt",
+                b"new file mode 644",
+                b"index 0000000..76d4bb8",
+                b"--- /dev/null",
+                b"+++ b/added.txt",
+                b"@@ -0,0 +1 @@",
+                b"+add",
+                b"diff --git a/changed.txt b/changed.txt",
+                b"index bf84e48..1be2436 644",
+                b"--- a/changed.txt",
+                b"+++ b/changed.txt",
+                b"@@ -1,2 +1,2 @@",
+                b" unchanged",
+                b"-removed",
+                b"+added",
+                b"diff --git a/removed.txt b/removed.txt",
+                b"deleted file mode 644",
+                b"index 2c3f0b3..0000000",
+                b"--- a/removed.txt",
+                b"+++ /dev/null",
+                b"@@ -1 +0,0 @@",
+                b"-removed",
+            ],
+            f.getvalue().splitlines(),
+        )
 
     def test_tree_diff_submodule(self):
         f = BytesIO()
         store = MemoryObjectStore()
         tree1 = Tree()
-        tree1.add(b"asubmodule", S_IFGITLINK,
-                  b"06d0bdd9e2e20377b3180e4986b14c8549b393e4")
+        tree1.add(
+            b"asubmodule",
+            S_IFGITLINK,
+            b"06d0bdd9e2e20377b3180e4986b14c8549b393e4",
+        )
         tree2 = Tree()
-        tree2.add(b"asubmodule", S_IFGITLINK,
-                  b"cc975646af69f279396d4d5e1379ac6af80ee637")
+        tree2.add(
+            b"asubmodule",
+            S_IFGITLINK,
+            b"cc975646af69f279396d4d5e1379ac6af80ee637",
+        )
         store.add_objects([(o, None) for o in [tree1, tree2]])
         write_tree_diff(f, store, tree1.id, tree2.id)
-        self.assertEqual([
-            b'diff --git a/asubmodule b/asubmodule',
-            b'index 06d0bdd..cc97564 160000',
-            b'--- a/asubmodule',
-            b'+++ b/asubmodule',
-            b'@@ -1 +1 @@',
-            b'-Subproject commit 06d0bdd9e2e20377b3180e4986b14c8549b393e4',
-            b'+Subproject commit cc975646af69f279396d4d5e1379ac6af80ee637',
-            ], f.getvalue().splitlines())
+        self.assertEqual(
+            [
+                b"diff --git a/asubmodule b/asubmodule",
+                b"index 06d0bdd..cc97564 160000",
+                b"--- a/asubmodule",
+                b"+++ b/asubmodule",
+                b"@@ -1 +1 @@",
+                b"-Subproject commit 06d0bdd9e2e20377b3180e4986b14c8549b393e4",
+                b"+Subproject commit cc975646af69f279396d4d5e1379ac6af80ee637",
+            ],
+            f.getvalue().splitlines(),
+        )
 
     def test_object_diff_blob(self):
         f = BytesIO()
@@ -374,54 +432,62 @@ class DiffTests(TestCase):
         b2 = Blob.from_string(b"new\nsame\n")
         store = MemoryObjectStore()
         store.add_objects([(b1, None), (b2, None)])
-        write_object_diff(f, store, (b"foo.txt", 0o644, b1.id),
-                                    (b"bar.txt", 0o644, b2.id))
-        self.assertEqual([
-            b"diff --git a/foo.txt b/bar.txt",
-            b"index 3b0f961..a116b51 644",
-            b"--- a/foo.txt",
-            b"+++ b/bar.txt",
-            b"@@ -1,2 +1,2 @@",
-            b"-old",
-            b"+new",
-            b" same"
-            ], f.getvalue().splitlines())
+        write_object_diff(
+            f, store, (b"foo.txt", 0o644, b1.id), (b"bar.txt", 0o644, b2.id)
+        )
+        self.assertEqual(
+            [
+                b"diff --git a/foo.txt b/bar.txt",
+                b"index 3b0f961..a116b51 644",
+                b"--- a/foo.txt",
+                b"+++ b/bar.txt",
+                b"@@ -1,2 +1,2 @@",
+                b"-old",
+                b"+new",
+                b" same",
+            ],
+            f.getvalue().splitlines(),
+        )
 
     def test_object_diff_add_blob(self):
         f = BytesIO()
         store = MemoryObjectStore()
         b2 = Blob.from_string(b"new\nsame\n")
         store.add_object(b2)
-        write_object_diff(f, store, (None, None, None),
-                                    (b"bar.txt", 0o644, b2.id))
-        self.assertEqual([
-             b'diff --git a/bar.txt b/bar.txt',
-             b'new file mode 644',
-             b'index 0000000..a116b51',
-             b'--- /dev/null',
-             b'+++ b/bar.txt',
-             b'@@ -0,0 +1,2 @@',
-             b'+new',
-             b'+same'
-            ], f.getvalue().splitlines())
+        write_object_diff(f, store, (None, None, None), (b"bar.txt", 0o644, b2.id))
+        self.assertEqual(
+            [
+                b"diff --git a/bar.txt b/bar.txt",
+                b"new file mode 644",
+                b"index 0000000..a116b51",
+                b"--- /dev/null",
+                b"+++ b/bar.txt",
+                b"@@ -0,0 +1,2 @@",
+                b"+new",
+                b"+same",
+            ],
+            f.getvalue().splitlines(),
+        )
 
     def test_object_diff_remove_blob(self):
         f = BytesIO()
         b1 = Blob.from_string(b"new\nsame\n")
         store = MemoryObjectStore()
         store.add_object(b1)
-        write_object_diff(f, store, (b"bar.txt", 0o644, b1.id),
-                                    (None, None, None))
-        self.assertEqual([
-            b'diff --git a/bar.txt b/bar.txt',
-            b'deleted file mode 644',
-            b'index a116b51..0000000',
-            b'--- a/bar.txt',
-            b'+++ /dev/null',
-            b'@@ -1,2 +0,0 @@',
-            b'-new',
-            b'-same'
-            ], f.getvalue().splitlines())
+        write_object_diff(f, store, (b"bar.txt", 0o644, b1.id), (None, None, None))
+        self.assertEqual(
+            [
+                b"diff --git a/bar.txt b/bar.txt",
+                b"deleted file mode 644",
+                b"index a116b51..0000000",
+                b"--- a/bar.txt",
+                b"+++ /dev/null",
+                b"@@ -1,2 +0,0 @@",
+                b"-new",
+                b"-same",
+            ],
+            f.getvalue().splitlines(),
+        )
 
     def test_object_diff_bin_blob_force(self):
         f = BytesIO()
@@ -430,33 +496,42 @@ class DiffTests(TestCase):
             b"\x89\x50\x4e\x47\x0d\x0a\x1a\x0a"
             b"\x00\x00\x00\x0d\x49\x48\x44\x52"
             b"\x00\x00\x01\xd5\x00\x00\x00\x9f"
-            b"\x08\x04\x00\x00\x00\x05\x04\x8b")
+            b"\x08\x04\x00\x00\x00\x05\x04\x8b"
+        )
         b2 = Blob.from_string(
             b"\x89\x50\x4e\x47\x0d\x0a\x1a\x0a"
             b"\x00\x00\x00\x0d\x49\x48\x44\x52"
             b"\x00\x00\x01\xd5\x00\x00\x00\x9f"
-            b"\x08\x03\x00\x00\x00\x98\xd3\xb3")
+            b"\x08\x03\x00\x00\x00\x98\xd3\xb3"
+        )
         store = MemoryObjectStore()
         store.add_objects([(b1, None), (b2, None)])
         write_object_diff(
-            f, store, (b'foo.png', 0o644, b1.id),
-            (b'bar.png', 0o644, b2.id), diff_binary=True)
-        self.assertEqual([
-            b'diff --git a/foo.png b/bar.png',
-            b'index f73e47d..06364b7 644',
-            b'--- a/foo.png',
-            b'+++ b/bar.png',
-            b'@@ -1,4 +1,4 @@',
-            b' \x89PNG',
-            b' \x1a',
-            b' \x00\x00\x00',
-            b'-IHDR\x00\x00\x01\xd5\x00\x00\x00'
-            b'\x9f\x08\x04\x00\x00\x00\x05\x04\x8b',
-            b'\\ No newline at end of file',
-            b'+IHDR\x00\x00\x01\xd5\x00\x00\x00\x9f'
-            b'\x08\x03\x00\x00\x00\x98\xd3\xb3',
-            b'\\ No newline at end of file'
-            ], f.getvalue().splitlines())
+            f,
+            store,
+            (b"foo.png", 0o644, b1.id),
+            (b"bar.png", 0o644, b2.id),
+            diff_binary=True,
+        )
+        self.assertEqual(
+            [
+                b"diff --git a/foo.png b/bar.png",
+                b"index f73e47d..06364b7 644",
+                b"--- a/foo.png",
+                b"+++ b/bar.png",
+                b"@@ -1,4 +1,4 @@",
+                b" \x89PNG",
+                b" \x1a",
+                b" \x00\x00\x00",
+                b"-IHDR\x00\x00\x01\xd5\x00\x00\x00"
+                b"\x9f\x08\x04\x00\x00\x00\x05\x04\x8b",
+                b"\\ No newline at end of file",
+                b"+IHDR\x00\x00\x01\xd5\x00\x00\x00\x9f"
+                b"\x08\x03\x00\x00\x00\x98\xd3\xb3",
+                b"\\ No newline at end of file",
+            ],
+            f.getvalue().splitlines(),
+        )
 
     def test_object_diff_bin_blob(self):
         f = BytesIO()
@@ -465,57 +540,69 @@ class DiffTests(TestCase):
             b"\x89\x50\x4e\x47\x0d\x0a\x1a\x0a"
             b"\x00\x00\x00\x0d\x49\x48\x44\x52"
             b"\x00\x00\x01\xd5\x00\x00\x00\x9f"
-            b"\x08\x04\x00\x00\x00\x05\x04\x8b")
+            b"\x08\x04\x00\x00\x00\x05\x04\x8b"
+        )
         b2 = Blob.from_string(
             b"\x89\x50\x4e\x47\x0d\x0a\x1a\x0a"
             b"\x00\x00\x00\x0d\x49\x48\x44\x52"
             b"\x00\x00\x01\xd5\x00\x00\x00\x9f"
-            b"\x08\x03\x00\x00\x00\x98\xd3\xb3")
+            b"\x08\x03\x00\x00\x00\x98\xd3\xb3"
+        )
         store = MemoryObjectStore()
         store.add_objects([(b1, None), (b2, None)])
-        write_object_diff(f, store, (b'foo.png', 0o644, b1.id),
-                                    (b'bar.png', 0o644, b2.id))
-        self.assertEqual([
-            b'diff --git a/foo.png b/bar.png',
-            b'index f73e47d..06364b7 644',
-            b'Binary files a/foo.png and b/bar.png differ'
-            ], f.getvalue().splitlines())
+        write_object_diff(
+            f, store, (b"foo.png", 0o644, b1.id), (b"bar.png", 0o644, b2.id)
+        )
+        self.assertEqual(
+            [
+                b"diff --git a/foo.png b/bar.png",
+                b"index f73e47d..06364b7 644",
+                b"Binary files a/foo.png and b/bar.png differ",
+            ],
+            f.getvalue().splitlines(),
+        )
 
     def test_object_diff_add_bin_blob(self):
         f = BytesIO()
         b2 = Blob.from_string(
-            b'\x89\x50\x4e\x47\x0d\x0a\x1a\x0a'
-            b'\x00\x00\x00\x0d\x49\x48\x44\x52'
-            b'\x00\x00\x01\xd5\x00\x00\x00\x9f'
-            b'\x08\x03\x00\x00\x00\x98\xd3\xb3')
+            b"\x89\x50\x4e\x47\x0d\x0a\x1a\x0a"
+            b"\x00\x00\x00\x0d\x49\x48\x44\x52"
+            b"\x00\x00\x01\xd5\x00\x00\x00\x9f"
+            b"\x08\x03\x00\x00\x00\x98\xd3\xb3"
+        )
         store = MemoryObjectStore()
         store.add_object(b2)
-        write_object_diff(f, store, (None, None, None),
-                                    (b'bar.png', 0o644, b2.id))
-        self.assertEqual([
-            b'diff --git a/bar.png b/bar.png',
-            b'new file mode 644',
-            b'index 0000000..06364b7',
-            b'Binary files /dev/null and b/bar.png differ'
-            ], f.getvalue().splitlines())
+        write_object_diff(f, store, (None, None, None), (b"bar.png", 0o644, b2.id))
+        self.assertEqual(
+            [
+                b"diff --git a/bar.png b/bar.png",
+                b"new file mode 644",
+                b"index 0000000..06364b7",
+                b"Binary files /dev/null and b/bar.png differ",
+            ],
+            f.getvalue().splitlines(),
+        )
 
     def test_object_diff_remove_bin_blob(self):
         f = BytesIO()
         b1 = Blob.from_string(
-            b'\x89\x50\x4e\x47\x0d\x0a\x1a\x0a'
-            b'\x00\x00\x00\x0d\x49\x48\x44\x52'
-            b'\x00\x00\x01\xd5\x00\x00\x00\x9f'
-            b'\x08\x04\x00\x00\x00\x05\x04\x8b')
+            b"\x89\x50\x4e\x47\x0d\x0a\x1a\x0a"
+            b"\x00\x00\x00\x0d\x49\x48\x44\x52"
+            b"\x00\x00\x01\xd5\x00\x00\x00\x9f"
+            b"\x08\x04\x00\x00\x00\x05\x04\x8b"
+        )
         store = MemoryObjectStore()
         store.add_object(b1)
-        write_object_diff(f, store, (b'foo.png', 0o644, b1.id),
-                                    (None, None, None))
-        self.assertEqual([
-            b'diff --git a/foo.png b/foo.png',
-            b'deleted file mode 644',
-            b'index f73e47d..0000000',
-            b'Binary files a/foo.png and /dev/null differ'
-            ], f.getvalue().splitlines())
+        write_object_diff(f, store, (b"foo.png", 0o644, b1.id), (None, None, None))
+        self.assertEqual(
+            [
+                b"diff --git a/foo.png b/foo.png",
+                b"deleted file mode 644",
+                b"index f73e47d..0000000",
+                b"Binary files a/foo.png and /dev/null differ",
+            ],
+            f.getvalue().splitlines(),
+        )
 
     def test_object_diff_kind_change(self):
         f = BytesIO()
@@ -523,25 +610,33 @@ class DiffTests(TestCase):
         store = MemoryObjectStore()
         store.add_object(b1)
         write_object_diff(
-            f, store, (b"bar.txt", 0o644, b1.id),
-            (b"bar.txt", 0o160000,
-                b"06d0bdd9e2e20377b3180e4986b14c8549b393e4"))
-        self.assertEqual([
-            b'diff --git a/bar.txt b/bar.txt',
-            b'old file mode 644',
-            b'new file mode 160000',
-            b'index a116b51..06d0bdd 160000',
-            b'--- a/bar.txt',
-            b'+++ b/bar.txt',
-            b'@@ -1,2 +1 @@',
-            b'-new',
-            b'-same',
-            b'+Subproject commit 06d0bdd9e2e20377b3180e4986b14c8549b393e4',
-            ], f.getvalue().splitlines())
+            f,
+            store,
+            (b"bar.txt", 0o644, b1.id),
+            (
+                b"bar.txt",
+                0o160000,
+                b"06d0bdd9e2e20377b3180e4986b14c8549b393e4",
+            ),
+        )
+        self.assertEqual(
+            [
+                b"diff --git a/bar.txt b/bar.txt",
+                b"old file mode 644",
+                b"new file mode 160000",
+                b"index a116b51..06d0bdd 160000",
+                b"--- a/bar.txt",
+                b"+++ b/bar.txt",
+                b"@@ -1,2 +1 @@",
+                b"-new",
+                b"-same",
+                b"+Subproject commit 06d0bdd9e2e20377b3180e4986b14c8549b393e4",
+            ],
+            f.getvalue().splitlines(),
+        )
 
 
 class GetSummaryTests(TestCase):
-
     def test_simple(self):
         c = Commit()
         c.committer = c.author = b"Jelmer <jelmer@samba.org>"
@@ -549,4 +644,4 @@ class GetSummaryTests(TestCase):
         c.commit_timezone = c.author_timezone = 0
         c.message = b"This is the first line\nAnd this is the second line.\n"
         c.tree = Tree().id
-        self.assertEqual('This-is-the-first-line', get_summary(c))
+        self.assertEqual("This-is-the-first-line", get_summary(c))
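
The GetSummaryTests hunk above pins down get_summary's behaviour: it takes the first line of the commit message and replaces spaces with dashes. A minimal sketch of that round trip (assuming only the Commit setup shown in the test; not part of the suite):

    from dulwich.objects import Commit, Tree
    from dulwich.patch import get_summary

    c = Commit()
    c.committer = c.author = b"Jelmer <jelmer@samba.org>"
    c.commit_time = c.author_time = 0
    c.commit_timezone = c.author_timezone = 0
    c.message = b"This is the first line\nAnd this is the second line.\n"
    c.tree = Tree().id
    # Only the first message line is used; spaces become dashes.
    assert get_summary(c) == "This-is-the-first-line"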

File diff too large to display
+ 470 - 201
dulwich/tests/test_porcelain.py


+ 84 - 84
dulwich/tests/test_protocol.py

@@ -25,7 +25,7 @@ from io import BytesIO
 
 from dulwich.errors import (
     HangupException,
-    )
+)
 from dulwich.protocol import (
     GitProtocolError,
     PktLineParser,
@@ -38,27 +38,26 @@ from dulwich.protocol import (
     MULTI_ACK,
     MULTI_ACK_DETAILED,
     BufferedPktLineWriter,
-    )
+)
 from dulwich.tests import TestCase
 
 
 class BaseProtocolTests(object):
-
     def test_write_pkt_line_none(self):
         self.proto.write_pkt_line(None)
-        self.assertEqual(self.rout.getvalue(), b'0000')
+        self.assertEqual(self.rout.getvalue(), b"0000")
 
     def test_write_pkt_line(self):
-        self.proto.write_pkt_line(b'bla')
-        self.assertEqual(self.rout.getvalue(), b'0007bla')
+        self.proto.write_pkt_line(b"bla")
+        self.assertEqual(self.rout.getvalue(), b"0007bla")
 
     def test_read_pkt_line(self):
-        self.rin.write(b'0008cmd ')
+        self.rin.write(b"0008cmd ")
         self.rin.seek(0)
-        self.assertEqual(b'cmd ', self.proto.read_pkt_line())
+        self.assertEqual(b"cmd ", self.proto.read_pkt_line())
 
     def test_eof(self):
-        self.rin.write(b'0000')
+        self.rin.write(b"0000")
         self.rin.seek(0)
         self.assertFalse(self.proto.eof())
         self.assertEqual(None, self.proto.read_pkt_line())
@@ -66,51 +65,50 @@ class BaseProtocolTests(object):
         self.assertRaises(HangupException, self.proto.read_pkt_line)
 
     def test_unread_pkt_line(self):
-        self.rin.write(b'0007foo0000')
+        self.rin.write(b"0007foo0000")
         self.rin.seek(0)
-        self.assertEqual(b'foo', self.proto.read_pkt_line())
-        self.proto.unread_pkt_line(b'bar')
-        self.assertEqual(b'bar', self.proto.read_pkt_line())
+        self.assertEqual(b"foo", self.proto.read_pkt_line())
+        self.proto.unread_pkt_line(b"bar")
+        self.assertEqual(b"bar", self.proto.read_pkt_line())
         self.assertEqual(None, self.proto.read_pkt_line())
-        self.proto.unread_pkt_line(b'baz1')
-        self.assertRaises(ValueError, self.proto.unread_pkt_line, b'baz2')
+        self.proto.unread_pkt_line(b"baz1")
+        self.assertRaises(ValueError, self.proto.unread_pkt_line, b"baz2")
 
     def test_read_pkt_seq(self):
-        self.rin.write(b'0008cmd 0005l0000')
+        self.rin.write(b"0008cmd 0005l0000")
         self.rin.seek(0)
-        self.assertEqual([b'cmd ', b'l'], list(self.proto.read_pkt_seq()))
+        self.assertEqual([b"cmd ", b"l"], list(self.proto.read_pkt_seq()))
 
     def test_read_pkt_line_none(self):
-        self.rin.write(b'0000')
+        self.rin.write(b"0000")
         self.rin.seek(0)
         self.assertEqual(None, self.proto.read_pkt_line())
 
     def test_read_pkt_line_wrong_size(self):
-        self.rin.write(b'0100too short')
+        self.rin.write(b"0100too short")
         self.rin.seek(0)
         self.assertRaises(GitProtocolError, self.proto.read_pkt_line)
 
     def test_write_sideband(self):
-        self.proto.write_sideband(3, b'bloe')
-        self.assertEqual(self.rout.getvalue(), b'0009\x03bloe')
+        self.proto.write_sideband(3, b"bloe")
+        self.assertEqual(self.rout.getvalue(), b"0009\x03bloe")
 
     def test_send_cmd(self):
-        self.proto.send_cmd(b'fetch', b'a', b'b')
-        self.assertEqual(self.rout.getvalue(), b'000efetch a\x00b\x00')
+        self.proto.send_cmd(b"fetch", b"a", b"b")
+        self.assertEqual(self.rout.getvalue(), b"000efetch a\x00b\x00")
 
     def test_read_cmd(self):
-        self.rin.write(b'0012cmd arg1\x00arg2\x00')
+        self.rin.write(b"0012cmd arg1\x00arg2\x00")
         self.rin.seek(0)
-        self.assertEqual((b'cmd', [b'arg1', b'arg2']), self.proto.read_cmd())
+        self.assertEqual((b"cmd", [b"arg1", b"arg2"]), self.proto.read_cmd())
 
     def test_read_cmd_noend0(self):
-        self.rin.write(b'0011cmd arg1\x00arg2')
+        self.rin.write(b"0011cmd arg1\x00arg2")
         self.rin.seek(0)
         self.assertRaises(AssertionError, self.proto.read_cmd)
 
 
 class ProtocolTests(BaseProtocolTests, TestCase):
-
     def setUp(self):
         TestCase.setUp(self)
         self.rout = BytesIO()
@@ -128,9 +126,8 @@ class ReceivableBytesIO(BytesIO):
     def recv(self, size):
         # fail fast if no bytes are available; in a real socket, this would
         # block forever
-        if (self.tell() == len(self.getvalue())
-                and not self.allow_read_past_eof):
-            raise GitProtocolError('Blocking read past end of socket')
+        if self.tell() == len(self.getvalue()) and not self.allow_read_past_eof:
+            raise GitProtocolError("Blocking read past end of socket")
         if size == 1:
             return self.read(1)
         # calls shouldn't return quite as much as asked for
@@ -138,7 +135,6 @@ class ReceivableBytesIO(BytesIO):
 
 
 class ReceivableProtocolTests(BaseProtocolTests, TestCase):
-
     def setUp(self):
         TestCase.setUp(self)
         self.rout = BytesIO()
@@ -154,10 +150,10 @@ class ReceivableProtocolTests(BaseProtocolTests, TestCase):
         BaseProtocolTests.test_eof(self)
 
     def test_recv(self):
-        all_data = b'1234567' * 10  # not a multiple of bufsize
+        all_data = b"1234567" * 10  # not a multiple of bufsize
         self.rin.write(all_data)
         self.rin.seek(0)
-        data = b''
+        data = b""
         # We ask for 8 bytes each time and actually read 7, so it should take
         # exactly 10 iterations.
         for _ in range(10):
@@ -167,28 +163,28 @@ class ReceivableProtocolTests(BaseProtocolTests, TestCase):
         self.assertEqual(all_data, data)
 
     def test_recv_read(self):
-        all_data = b'1234567'  # recv exactly in one call
+        all_data = b"1234567"  # recv exactly in one call
         self.rin.write(all_data)
         self.rin.seek(0)
-        self.assertEqual(b'1234', self.proto.recv(4))
-        self.assertEqual(b'567', self.proto.read(3))
+        self.assertEqual(b"1234", self.proto.recv(4))
+        self.assertEqual(b"567", self.proto.read(3))
         self.assertRaises(GitProtocolError, self.proto.recv, 10)
 
     def test_read_recv(self):
-        all_data = b'12345678abcdefg'
+        all_data = b"12345678abcdefg"
         self.rin.write(all_data)
         self.rin.seek(0)
-        self.assertEqual(b'1234', self.proto.read(4))
-        self.assertEqual(b'5678abc', self.proto.recv(8))
-        self.assertEqual(b'defg', self.proto.read(4))
+        self.assertEqual(b"1234", self.proto.read(4))
+        self.assertEqual(b"5678abc", self.proto.recv(8))
+        self.assertEqual(b"defg", self.proto.read(4))
         self.assertRaises(GitProtocolError, self.proto.recv, 10)
 
     def test_mixed(self):
         # arbitrary non-repeating string
-        all_data = b','.join(str(i).encode('ascii') for i in range(100))
+        all_data = b",".join(str(i).encode("ascii") for i in range(100))
         self.rin.write(all_data)
         self.rin.seek(0)
-        data = b''
+        data = b""
 
         for i in range(1, 100):
             data += self.proto.recv(i)
@@ -209,41 +205,46 @@ class ReceivableProtocolTests(BaseProtocolTests, TestCase):
 
 
 class CapabilitiesTestCase(TestCase):
-
     def test_plain(self):
-        self.assertEqual((b'bla', []), extract_capabilities(b'bla'))
+        self.assertEqual((b"bla", []), extract_capabilities(b"bla"))
 
     def test_caps(self):
-        self.assertEqual((b'bla', [b'la']), extract_capabilities(b'bla\0la'))
-        self.assertEqual((b'bla', [b'la']), extract_capabilities(b'bla\0la\n'))
-        self.assertEqual((b'bla', [b'la', b'la']),
-                         extract_capabilities(b'bla\0la la'))
+        self.assertEqual((b"bla", [b"la"]), extract_capabilities(b"bla\0la"))
+        self.assertEqual((b"bla", [b"la"]), extract_capabilities(b"bla\0la\n"))
+        self.assertEqual((b"bla", [b"la", b"la"]), extract_capabilities(b"bla\0la la"))
 
     def test_plain_want_line(self):
-        self.assertEqual((b'want bla', []),
-                         extract_want_line_capabilities(b'want bla'))
+        self.assertEqual((b"want bla", []), extract_want_line_capabilities(b"want bla"))
 
     def test_caps_want_line(self):
-        self.assertEqual((b'want bla', [b'la']),
-                         extract_want_line_capabilities(b'want bla la'))
-        self.assertEqual((b'want bla', [b'la']),
-                         extract_want_line_capabilities(b'want bla la\n'))
-        self.assertEqual((b'want bla', [b'la', b'la']),
-                         extract_want_line_capabilities(b'want bla la la'))
+        self.assertEqual(
+            (b"want bla", [b"la"]),
+            extract_want_line_capabilities(b"want bla la"),
+        )
+        self.assertEqual(
+            (b"want bla", [b"la"]),
+            extract_want_line_capabilities(b"want bla la\n"),
+        )
+        self.assertEqual(
+            (b"want bla", [b"la", b"la"]),
+            extract_want_line_capabilities(b"want bla la la"),
+        )
 
     def test_ack_type(self):
-        self.assertEqual(SINGLE_ACK, ack_type([b'foo', b'bar']))
-        self.assertEqual(MULTI_ACK, ack_type([b'foo', b'bar', b'multi_ack']))
-        self.assertEqual(MULTI_ACK_DETAILED,
-                         ack_type([b'foo', b'bar', b'multi_ack_detailed']))
+        self.assertEqual(SINGLE_ACK, ack_type([b"foo", b"bar"]))
+        self.assertEqual(MULTI_ACK, ack_type([b"foo", b"bar", b"multi_ack"]))
+        self.assertEqual(
+            MULTI_ACK_DETAILED,
+            ack_type([b"foo", b"bar", b"multi_ack_detailed"]),
+        )
         # choose detailed when both present
-        self.assertEqual(MULTI_ACK_DETAILED,
-                         ack_type([b'foo', b'bar', b'multi_ack',
-                                   b'multi_ack_detailed']))
+        self.assertEqual(
+            MULTI_ACK_DETAILED,
+            ack_type([b"foo", b"bar", b"multi_ack", b"multi_ack_detailed"]),
+        )
 
 
 class BufferedPktLineWriterTests(TestCase):
-
     def setUp(self):
         TestCase.setUp(self)
         self._output = BytesIO()
@@ -257,48 +258,47 @@ class BufferedPktLineWriterTests(TestCase):
         self._output.truncate()
 
     def test_write(self):
-        self._writer.write(b'foo')
-        self.assertOutputEquals(b'')
+        self._writer.write(b"foo")
+        self.assertOutputEquals(b"")
         self._writer.flush()
-        self.assertOutputEquals(b'0007foo')
+        self.assertOutputEquals(b"0007foo")
 
     def test_write_none(self):
         self._writer.write(None)
-        self.assertOutputEquals(b'')
+        self.assertOutputEquals(b"")
         self._writer.flush()
-        self.assertOutputEquals(b'0000')
+        self.assertOutputEquals(b"0000")
 
     def test_flush_empty(self):
         self._writer.flush()
-        self.assertOutputEquals(b'')
+        self.assertOutputEquals(b"")
 
     def test_write_multiple(self):
-        self._writer.write(b'foo')
-        self._writer.write(b'bar')
-        self.assertOutputEquals(b'')
+        self._writer.write(b"foo")
+        self._writer.write(b"bar")
+        self.assertOutputEquals(b"")
         self._writer.flush()
-        self.assertOutputEquals(b'0007foo0007bar')
+        self.assertOutputEquals(b"0007foo0007bar")
 
     def test_write_across_boundary(self):
-        self._writer.write(b'foo')
-        self._writer.write(b'barbaz')
-        self.assertOutputEquals(b'0007foo000abarba')
+        self._writer.write(b"foo")
+        self._writer.write(b"barbaz")
+        self.assertOutputEquals(b"0007foo000abarba")
         self._truncate()
         self._writer.flush()
-        self.assertOutputEquals(b'z')
+        self.assertOutputEquals(b"z")
 
     def test_write_to_boundary(self):
-        self._writer.write(b'foo')
-        self._writer.write(b'barba')
-        self.assertOutputEquals(b'0007foo0009barba')
+        self._writer.write(b"foo")
+        self._writer.write(b"barba")
+        self.assertOutputEquals(b"0007foo0009barba")
         self._truncate()
-        self._writer.write(b'z')
+        self._writer.write(b"z")
         self._writer.flush()
-        self.assertOutputEquals(b'0005z')
+        self.assertOutputEquals(b"0005z")
 
 
 class PktLineParserTests(TestCase):
-
     def test_none(self):
         pktlines = []
         parser = PktLineParser(pktlines.append)
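
The framing exercised throughout these tests is git's pkt-line format: every packet is prefixed with a four-hex-digit length that counts the header itself, and b"0000" is the flush packet that ends a sequence. A small sketch, assuming only the Protocol construction used in ProtocolTests above:

    from io import BytesIO
    from dulwich.protocol import Protocol

    rin, rout = BytesIO(b"0008cmd 0000"), BytesIO()
    proto = Protocol(rin.read, rout.write)
    proto.write_pkt_line(b"bla")   # 4 (header) + 3 (payload) -> b"0007bla"
    proto.write_pkt_line(None)     # flush packet -> b"0000"
    assert rout.getvalue() == b"0007bla0000"
    assert proto.read_pkt_line() == b"cmd "  # length header is stripped on read
    assert proto.read_pkt_line() is None     # b"0000" signals end of sequence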

+ 102 - 28
dulwich/tests/test_reflog.py

@@ -21,52 +21,126 @@
 
 """Tests for dulwich.reflog."""
 
+from io import BytesIO
 
+from dulwich.objects import ZERO_SHA
 from dulwich.reflog import (
+    drop_reflog_entry,
     format_reflog_line,
     parse_reflog_line,
-    )
+    read_reflog,
+)
 
 from dulwich.tests import (
     TestCase,
-    )
+)
 
 
 class ReflogLineTests(TestCase):
-
     def test_format(self):
         self.assertEqual(
-            b'0000000000000000000000000000000000000000 '
-            b'49030649db3dfec5a9bc03e5dde4255a14499f16 Jelmer Vernooij '
-            b'<jelmer@jelmer.uk> 1446552482 +0000	'
-            b'clone: from git://jelmer.uk/samba',
+            b"0000000000000000000000000000000000000000 "
+            b"49030649db3dfec5a9bc03e5dde4255a14499f16 Jelmer Vernooij "
+            b"<jelmer@jelmer.uk> 1446552482 +0000	"
+            b"clone: from git://jelmer.uk/samba",
             format_reflog_line(
-                b'0000000000000000000000000000000000000000',
-                b'49030649db3dfec5a9bc03e5dde4255a14499f16',
-                b'Jelmer Vernooij <jelmer@jelmer.uk>',
-                1446552482, 0, b'clone: from git://jelmer.uk/samba'))
+                b"0000000000000000000000000000000000000000",
+                b"49030649db3dfec5a9bc03e5dde4255a14499f16",
+                b"Jelmer Vernooij <jelmer@jelmer.uk>",
+                1446552482,
+                0,
+                b"clone: from git://jelmer.uk/samba",
+            ),
+        )
 
         self.assertEqual(
-            b'0000000000000000000000000000000000000000 '
-            b'49030649db3dfec5a9bc03e5dde4255a14499f16 Jelmer Vernooij '
-            b'<jelmer@jelmer.uk> 1446552482 +0000	'
-            b'clone: from git://jelmer.uk/samba',
+            b"0000000000000000000000000000000000000000 "
+            b"49030649db3dfec5a9bc03e5dde4255a14499f16 Jelmer Vernooij "
+            b"<jelmer@jelmer.uk> 1446552482 +0000	"
+            b"clone: from git://jelmer.uk/samba",
             format_reflog_line(
                 None,
-                b'49030649db3dfec5a9bc03e5dde4255a14499f16',
-                b'Jelmer Vernooij <jelmer@jelmer.uk>',
-                1446552482, 0, b'clone: from git://jelmer.uk/samba'))
+                b"49030649db3dfec5a9bc03e5dde4255a14499f16",
+                b"Jelmer Vernooij <jelmer@jelmer.uk>",
+                1446552482,
+                0,
+                b"clone: from git://jelmer.uk/samba",
+            ),
+        )
 
     def test_parse(self):
         reflog_line = (
-                 b'0000000000000000000000000000000000000000 '
-                 b'49030649db3dfec5a9bc03e5dde4255a14499f16 Jelmer Vernooij '
-                 b'<jelmer@jelmer.uk> 1446552482 +0000	'
-                 b'clone: from git://jelmer.uk/samba'
-                 )
+            b"0000000000000000000000000000000000000000 "
+            b"49030649db3dfec5a9bc03e5dde4255a14499f16 Jelmer Vernooij "
+            b"<jelmer@jelmer.uk> 1446552482 +0000	"
+            b"clone: from git://jelmer.uk/samba"
+        )
         self.assertEqual(
-                (b'0000000000000000000000000000000000000000',
-                 b'49030649db3dfec5a9bc03e5dde4255a14499f16',
-                 b'Jelmer Vernooij <jelmer@jelmer.uk>',
-                 1446552482, 0, b'clone: from git://jelmer.uk/samba'),
-                parse_reflog_line(reflog_line))
+            (
+                b"0000000000000000000000000000000000000000",
+                b"49030649db3dfec5a9bc03e5dde4255a14499f16",
+                b"Jelmer Vernooij <jelmer@jelmer.uk>",
+                1446552482,
+                0,
+                b"clone: from git://jelmer.uk/samba",
+            ),
+            parse_reflog_line(reflog_line),
+        )
+
+
+_TEST_REFLOG = (
+    b"0000000000000000000000000000000000000000 "
+    b"49030649db3dfec5a9bc03e5dde4255a14499f16 Jelmer Vernooij "
+    b"<jelmer@jelmer.uk> 1446552482 +0000	"
+    b"clone: from git://jelmer.uk/samba\n"
+    b"49030649db3dfec5a9bc03e5dde4255a14499f16 "
+    b"42d06bd4b77fed026b154d16493e5deab78f02ec Jelmer Vernooij "
+    b"<jelmer@jelmer.uk> 1446552483 +0000	"
+    b"clone: from git://jelmer.uk/samba\n"
+    b"42d06bd4b77fed026b154d16493e5deab78f02ec "
+    b"df6800012397fb85c56e7418dd4eb9405dee075c Jelmer Vernooij "
+    b"<jelmer@jelmer.uk> 1446552484 +0000	"
+    b"clone: from git://jelmer.uk/samba\n"
+)
+
+
+class ReflogDropTests(TestCase):
+    def setUp(self):
+        TestCase.setUp(self)
+        self.f = BytesIO(_TEST_REFLOG)
+        self.original_log = list(read_reflog(self.f))
+        self.f.seek(0)
+
+    def _read_log(self):
+        self.f.seek(0)
+        return list(read_reflog(self.f))
+
+    def test_invalid(self):
+        self.assertRaises(ValueError, drop_reflog_entry, self.f, -1)
+
+    def test_drop_entry(self):
+        drop_reflog_entry(self.f, 0)
+        log = self._read_log()
+        self.assertEqual(len(log), 2)
+        self.assertEqual(self.original_log[0:2], log)
+
+        self.f.seek(0)
+        drop_reflog_entry(self.f, 1)
+        log = self._read_log()
+        self.assertEqual(len(log), 1)
+        self.assertEqual(self.original_log[1], log[0])
+
+    def test_drop_entry_with_rewrite(self):
+        drop_reflog_entry(self.f, 1, True)
+        log = self._read_log()
+        self.assertEqual(len(log), 2)
+        self.assertEqual(self.original_log[0], log[0])
+        self.assertEqual(self.original_log[0].new_sha, log[1].old_sha)
+        self.assertEqual(self.original_log[2].new_sha, log[1].new_sha)
+
+        self.f.seek(0)
+        drop_reflog_entry(self.f, 1, True)
+        log = self._read_log()
+        self.assertEqual(len(log), 1)
+        self.assertEqual(ZERO_SHA, log[0].old_sha)
+        self.assertEqual(self.original_log[2].new_sha, log[0].new_sha)

File diff too large to display
+ 432 - 354
dulwich/tests/test_refs.py


File diff too large to display
+ 416 - 322
dulwich/tests/test_repository.py


File diff too large to display
+ 297 - 256
dulwich/tests/test_server.py


+ 29 - 23
dulwich/tests/test_utils.py

@@ -22,21 +22,20 @@
 
 from dulwich.object_store import (
     MemoryObjectStore,
-    )
+)
 from dulwich.objects import (
     Blob,
-    )
+)
 from dulwich.tests import (
     TestCase,
-    )
+)
 from dulwich.tests.utils import (
     make_object,
     build_commit_graph,
-    )
+)
 
 
 class BuildCommitGraphTest(TestCase):
-
     def setUp(self):
         super(BuildCommitGraphTest, self).setUp()
         self.store = MemoryObjectStore()
@@ -52,35 +51,42 @@ class BuildCommitGraphTest(TestCase):
         self.assertTrue(c2.commit_time > c1.commit_time)
 
     def test_merge(self):
-        c1, c2, c3, c4 = build_commit_graph(self.store,
-                                            [[1], [2, 1], [3, 1], [4, 2, 3]])
+        c1, c2, c3, c4 = build_commit_graph(
+            self.store, [[1], [2, 1], [3, 1], [4, 2, 3]]
+        )
         self.assertEqual([c2.id, c3.id], c4.parents)
         self.assertTrue(c4.commit_time > c2.commit_time)
         self.assertTrue(c4.commit_time > c3.commit_time)
 
     def test_missing_parent(self):
-        self.assertRaises(ValueError, build_commit_graph, self.store,
-                          [[1], [3, 2], [2, 1]])
+        self.assertRaises(
+            ValueError, build_commit_graph, self.store, [[1], [3, 2], [2, 1]]
+        )
 
     def test_trees(self):
-        a1 = make_object(Blob, data=b'aaa1')
-        a2 = make_object(Blob, data=b'aaa2')
-        c1, c2 = build_commit_graph(self.store, [[1], [2, 1]],
-                                    trees={1: [(b'a', a1)],
-                                           2: [(b'a', a2, 0o100644)]})
-        self.assertEqual((0o100644, a1.id), self.store[c1.tree][b'a'])
-        self.assertEqual((0o100644, a2.id), self.store[c2.tree][b'a'])
+        a1 = make_object(Blob, data=b"aaa1")
+        a2 = make_object(Blob, data=b"aaa2")
+        c1, c2 = build_commit_graph(
+            self.store,
+            [[1], [2, 1]],
+            trees={1: [(b"a", a1)], 2: [(b"a", a2, 0o100644)]},
+        )
+        self.assertEqual((0o100644, a1.id), self.store[c1.tree][b"a"])
+        self.assertEqual((0o100644, a2.id), self.store[c2.tree][b"a"])
 
     def test_attrs(self):
-        c1, c2 = build_commit_graph(self.store, [[1], [2, 1]],
-                                    attrs={1: {'message': b'Hooray!'}})
-        self.assertEqual(b'Hooray!', c1.message)
-        self.assertEqual(b'Commit 2', c2.message)
+        c1, c2 = build_commit_graph(
+            self.store, [[1], [2, 1]], attrs={1: {"message": b"Hooray!"}}
+        )
+        self.assertEqual(b"Hooray!", c1.message)
+        self.assertEqual(b"Commit 2", c2.message)
 
     def test_commit_time(self):
-        c1, c2, c3 = build_commit_graph(self.store, [[1], [2, 1], [3, 2]],
-                                        attrs={1: {'commit_time': 124},
-                                               2: {'commit_time': 123}})
+        c1, c2, c3 = build_commit_graph(
+            self.store,
+            [[1], [2, 1], [3, 2]],
+            attrs={1: {"commit_time": 124}, 2: {"commit_time": 123}},
+        )
         self.assertEqual(124, c1.commit_time)
         self.assertEqual(123, c2.commit_time)
         self.assertTrue(c2.commit_time < c1.commit_time < c3.commit_time)
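
As a usage note, build_commit_graph takes a commit spec of [number, parent, parent, ...] lists plus optional trees= and attrs= overrides keyed by commit number. A minimal sketch against the in-memory store used in these tests:

    from dulwich.object_store import MemoryObjectStore
    from dulwich.objects import Blob
    from dulwich.tests.utils import build_commit_graph, make_object

    store = MemoryObjectStore()
    blob = make_object(Blob, data=b"contents")
    # c1 is a root commit; c2 has c1 as its only parent and gets a tree.
    c1, c2 = build_commit_graph(store, [[1], [2, 1]], trees={2: [(b"a", blob)]})
    assert c2.parents == [c1.id]
    assert store[c2.tree][b"a"] == (0o100644, blob.id)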

+ 220 - 171
dulwich/tests/test_walk.py

@@ -22,7 +22,7 @@
 
 from itertools import (
     permutations,
-    )
+)
 from unittest import expectedFailure
 
 from dulwich.diff_tree import (
@@ -30,41 +30,37 @@ from dulwich.diff_tree import (
     CHANGE_RENAME,
     TreeChange,
     RenameDetector,
-    )
+)
 from dulwich.errors import (
     MissingCommitError,
-    )
+)
 from dulwich.object_store import (
     MemoryObjectStore,
-    )
+)
 from dulwich.objects import (
     Commit,
     Blob,
-    )
-from dulwich.walk import (
-    ORDER_TOPO,
-    WalkEntry,
-    Walker,
-    _topo_reorder
-    )
+)
+from dulwich.walk import ORDER_TOPO, WalkEntry, Walker, _topo_reorder
 from dulwich.tests import TestCase
 from dulwich.tests.utils import (
     F,
     make_object,
     make_tag,
     build_commit_graph,
-    )
+)
 
 
 class TestWalkEntry(object):
-
     def __init__(self, commit, changes):
         self.commit = commit
         self.changes = changes
 
     def __repr__(self):
-        return '<TestWalkEntry commit=%s, changes=%r>' % (
-          self.commit.id, self.changes)
+        return "<TestWalkEntry commit=%s, changes=%r>" % (
+            self.commit.id,
+            self.changes,
+        )
 
     def __eq__(self, other):
         if not isinstance(other, WalkEntry) or self.commit != other.commit:
@@ -75,18 +71,16 @@ class TestWalkEntry(object):
 
 
 class WalkerTest(TestCase):
-
     def setUp(self):
         super(WalkerTest, self).setUp()
         self.store = MemoryObjectStore()
 
     def make_commits(self, commit_spec, **kwargs):
-        times = kwargs.pop('times', [])
-        attrs = kwargs.pop('attrs', {})
+        times = kwargs.pop("times", [])
+        attrs = kwargs.pop("attrs", {})
         for i, t in enumerate(times):
-            attrs.setdefault(i + 1, {})['commit_time'] = t
-        return build_commit_graph(self.store, commit_spec, attrs=attrs,
-                                  **kwargs)
+            attrs.setdefault(i + 1, {})["commit_time"] = t
+        return build_commit_graph(self.store, commit_spec, attrs=attrs, **kwargs)
 
     def make_linear_commits(self, num_commits, **kwargs):
         commit_spec = []
@@ -192,164 +186,210 @@ class WalkerTest(TestCase):
 
     def test_reverse_after_max_entries(self):
         c1, c2, c3 = self.make_linear_commits(3)
-        self.assertWalkYields([c1, c2, c3], [c3.id], max_entries=3,
-                              reverse=True)
+        self.assertWalkYields([c1, c2, c3], [c3.id], max_entries=3, reverse=True)
         self.assertWalkYields([c2, c3], [c3.id], max_entries=2, reverse=True)
         self.assertWalkYields([c3], [c3.id], max_entries=1, reverse=True)
 
     def test_changes_one_parent(self):
-        blob_a1 = make_object(Blob, data=b'a1')
-        blob_a2 = make_object(Blob, data=b'a2')
-        blob_b2 = make_object(Blob, data=b'b2')
+        blob_a1 = make_object(Blob, data=b"a1")
+        blob_a2 = make_object(Blob, data=b"a2")
+        blob_b2 = make_object(Blob, data=b"b2")
         c1, c2 = self.make_linear_commits(
-            2, trees={1: [(b'a', blob_a1)],
-                      2: [(b'a', blob_a2), (b'b', blob_b2)]})
-        e1 = TestWalkEntry(c1, [TreeChange.add((b'a', F, blob_a1.id))])
+            2,
+            trees={
+                1: [(b"a", blob_a1)],
+                2: [(b"a", blob_a2), (b"b", blob_b2)],
+            },
+        )
+        e1 = TestWalkEntry(c1, [TreeChange.add((b"a", F, blob_a1.id))])
         e2 = TestWalkEntry(
-                c2,
-                [TreeChange(CHANGE_MODIFY, (b'a', F, blob_a1.id),
-                                           (b'a', F, blob_a2.id)),
-                 TreeChange.add((b'b', F, blob_b2.id))])
+            c2,
+            [
+                TreeChange(CHANGE_MODIFY, (b"a", F, blob_a1.id), (b"a", F, blob_a2.id)),
+                TreeChange.add((b"b", F, blob_b2.id)),
+            ],
+        )
         self.assertWalkYields([e2, e1], [c2.id])
 
     def test_changes_multiple_parents(self):
-        blob_a1 = make_object(Blob, data=b'a1')
-        blob_b2 = make_object(Blob, data=b'b2')
-        blob_a3 = make_object(Blob, data=b'a3')
+        blob_a1 = make_object(Blob, data=b"a1")
+        blob_b2 = make_object(Blob, data=b"b2")
+        blob_a3 = make_object(Blob, data=b"a3")
         c1, c2, c3 = self.make_commits(
             [[1], [2], [3, 1, 2]],
-            trees={1: [(b'a', blob_a1)], 2: [(b'b', blob_b2)],
-                   3: [(b'a', blob_a3), (b'b', blob_b2)]})
+            trees={
+                1: [(b"a", blob_a1)],
+                2: [(b"b", blob_b2)],
+                3: [(b"a", blob_a3), (b"b", blob_b2)],
+            },
+        )
         # a is a modify/add conflict and b is not conflicted.
-        changes = [[
-                TreeChange(CHANGE_MODIFY,
-                           (b'a', F, blob_a1.id), (b'a', F, blob_a3.id)),
-                TreeChange.add((b'a', F, blob_a3.id)),
-        ]]
-        self.assertWalkYields([TestWalkEntry(c3, changes)], [c3.id],
-                              exclude=[c1.id, c2.id])
+        changes = [
+            [
+                TreeChange(CHANGE_MODIFY, (b"a", F, blob_a1.id), (b"a", F, blob_a3.id)),
+                TreeChange.add((b"a", F, blob_a3.id)),
+            ]
+        ]
+        self.assertWalkYields(
+            [TestWalkEntry(c3, changes)], [c3.id], exclude=[c1.id, c2.id]
+        )
 
     def test_path_matches(self):
-        walker = Walker(None, [], paths=[b'foo', b'bar', b'baz/quux'])
-        self.assertTrue(walker._path_matches(b'foo'))
-        self.assertTrue(walker._path_matches(b'foo/a'))
-        self.assertTrue(walker._path_matches(b'foo/a/b'))
-        self.assertTrue(walker._path_matches(b'bar'))
-        self.assertTrue(walker._path_matches(b'baz/quux'))
-        self.assertTrue(walker._path_matches(b'baz/quux/a'))
+        walker = Walker(None, [], paths=[b"foo", b"bar", b"baz/quux"])
+        self.assertTrue(walker._path_matches(b"foo"))
+        self.assertTrue(walker._path_matches(b"foo/a"))
+        self.assertTrue(walker._path_matches(b"foo/a/b"))
+        self.assertTrue(walker._path_matches(b"bar"))
+        self.assertTrue(walker._path_matches(b"baz/quux"))
+        self.assertTrue(walker._path_matches(b"baz/quux/a"))
 
         self.assertFalse(walker._path_matches(None))
-        self.assertFalse(walker._path_matches(b'oops'))
-        self.assertFalse(walker._path_matches(b'fool'))
-        self.assertFalse(walker._path_matches(b'baz'))
-        self.assertFalse(walker._path_matches(b'baz/quu'))
+        self.assertFalse(walker._path_matches(b"oops"))
+        self.assertFalse(walker._path_matches(b"fool"))
+        self.assertFalse(walker._path_matches(b"baz"))
+        self.assertFalse(walker._path_matches(b"baz/quu"))
 
     def test_paths(self):
-        blob_a1 = make_object(Blob, data=b'a1')
-        blob_b2 = make_object(Blob, data=b'b2')
-        blob_a3 = make_object(Blob, data=b'a3')
-        blob_b3 = make_object(Blob, data=b'b3')
+        blob_a1 = make_object(Blob, data=b"a1")
+        blob_b2 = make_object(Blob, data=b"b2")
+        blob_a3 = make_object(Blob, data=b"a3")
+        blob_b3 = make_object(Blob, data=b"b3")
         c1, c2, c3 = self.make_linear_commits(
-            3, trees={1: [(b'a', blob_a1)],
-                      2: [(b'a', blob_a1), (b'x/b', blob_b2)],
-                      3: [(b'a', blob_a3), (b'x/b', blob_b3)]})
+            3,
+            trees={
+                1: [(b"a", blob_a1)],
+                2: [(b"a", blob_a1), (b"x/b", blob_b2)],
+                3: [(b"a", blob_a3), (b"x/b", blob_b3)],
+            },
+        )
 
         self.assertWalkYields([c3, c2, c1], [c3.id])
-        self.assertWalkYields([c3, c1], [c3.id], paths=[b'a'])
-        self.assertWalkYields([c3, c2], [c3.id], paths=[b'x/b'])
+        self.assertWalkYields([c3, c1], [c3.id], paths=[b"a"])
+        self.assertWalkYields([c3, c2], [c3.id], paths=[b"x/b"])
 
         # All changes are included, not just for requested paths.
         changes = [
-            TreeChange(CHANGE_MODIFY, (b'a', F, blob_a1.id),
-                       (b'a', F, blob_a3.id)),
-            TreeChange(CHANGE_MODIFY, (b'x/b', F, blob_b2.id),
-                       (b'x/b', F, blob_b3.id)),
+            TreeChange(CHANGE_MODIFY, (b"a", F, blob_a1.id), (b"a", F, blob_a3.id)),
+            TreeChange(CHANGE_MODIFY, (b"x/b", F, blob_b2.id), (b"x/b", F, blob_b3.id)),
         ]
-        self.assertWalkYields([TestWalkEntry(c3, changes)], [c3.id],
-                              max_entries=1, paths=[b'a'])
+        self.assertWalkYields(
+            [TestWalkEntry(c3, changes)], [c3.id], max_entries=1, paths=[b"a"]
+        )
 
     def test_paths_subtree(self):
-        blob_a = make_object(Blob, data=b'a')
-        blob_b = make_object(Blob, data=b'b')
+        blob_a = make_object(Blob, data=b"a")
+        blob_b = make_object(Blob, data=b"b")
         c1, c2, c3 = self.make_linear_commits(
-            3, trees={1: [(b'x/a', blob_a)],
-                      2: [(b'b', blob_b), (b'x/a', blob_a)],
-                      3: [(b'b', blob_b), (b'x/a', blob_a), (b'x/b', blob_b)]})
-        self.assertWalkYields([c2], [c3.id], paths=[b'b'])
-        self.assertWalkYields([c3, c1], [c3.id], paths=[b'x'])
+            3,
+            trees={
+                1: [(b"x/a", blob_a)],
+                2: [(b"b", blob_b), (b"x/a", blob_a)],
+                3: [(b"b", blob_b), (b"x/a", blob_a), (b"x/b", blob_b)],
+            },
+        )
+        self.assertWalkYields([c2], [c3.id], paths=[b"b"])
+        self.assertWalkYields([c3, c1], [c3.id], paths=[b"x"])
 
     def test_paths_max_entries(self):
-        blob_a = make_object(Blob, data=b'a')
-        blob_b = make_object(Blob, data=b'b')
+        blob_a = make_object(Blob, data=b"a")
+        blob_b = make_object(Blob, data=b"b")
         c1, c2 = self.make_linear_commits(
-            2, trees={1: [(b'a', blob_a)],
-                      2: [(b'a', blob_a), (b'b', blob_b)]})
-        self.assertWalkYields([c2], [c2.id], paths=[b'b'], max_entries=1)
-        self.assertWalkYields([c1], [c1.id], paths=[b'a'], max_entries=1)
+            2, trees={1: [(b"a", blob_a)], 2: [(b"a", blob_a), (b"b", blob_b)]}
+        )
+        self.assertWalkYields([c2], [c2.id], paths=[b"b"], max_entries=1)
+        self.assertWalkYields([c1], [c1.id], paths=[b"a"], max_entries=1)
 
     def test_paths_merge(self):
-        blob_a1 = make_object(Blob, data=b'a1')
-        blob_a2 = make_object(Blob, data=b'a2')
-        blob_a3 = make_object(Blob, data=b'a3')
+        blob_a1 = make_object(Blob, data=b"a1")
+        blob_a2 = make_object(Blob, data=b"a2")
+        blob_a3 = make_object(Blob, data=b"a3")
         x1, y2, m3, m4 = self.make_commits(
             [[1], [2], [3, 1, 2], [4, 1, 2]],
-            trees={1: [(b'a', blob_a1)],
-                   2: [(b'a', blob_a2)],
-                   3: [(b'a', blob_a3)],
-                   4: [(b'a', blob_a1)]})  # Non-conflicting
-        self.assertWalkYields([m3, y2, x1], [m3.id], paths=[b'a'])
-        self.assertWalkYields([y2, x1], [m4.id], paths=[b'a'])
+            trees={
+                1: [(b"a", blob_a1)],
+                2: [(b"a", blob_a2)],
+                3: [(b"a", blob_a3)],
+                4: [(b"a", blob_a1)],
+            },
+        )  # Non-conflicting
+        self.assertWalkYields([m3, y2, x1], [m3.id], paths=[b"a"])
+        self.assertWalkYields([y2, x1], [m4.id], paths=[b"a"])
 
     def test_changes_with_renames(self):
-        blob = make_object(Blob, data=b'blob')
+        blob = make_object(Blob, data=b"blob")
         c1, c2 = self.make_linear_commits(
-            2, trees={1: [(b'a', blob)], 2: [(b'b', blob)]})
-        entry_a = (b'a', F, blob.id)
-        entry_b = (b'b', F, blob.id)
-        changes_without_renames = [TreeChange.delete(entry_a),
-                                   TreeChange.add(entry_b)]
+            2, trees={1: [(b"a", blob)], 2: [(b"b", blob)]}
+        )
+        entry_a = (b"a", F, blob.id)
+        entry_b = (b"b", F, blob.id)
+        changes_without_renames = [
+            TreeChange.delete(entry_a),
+            TreeChange.add(entry_b),
+        ]
         changes_with_renames = [TreeChange(CHANGE_RENAME, entry_a, entry_b)]
         self.assertWalkYields(
-          [TestWalkEntry(c2, changes_without_renames)], [c2.id], max_entries=1)
+            [TestWalkEntry(c2, changes_without_renames)],
+            [c2.id],
+            max_entries=1,
+        )
         detector = RenameDetector(self.store)
         self.assertWalkYields(
-          [TestWalkEntry(c2, changes_with_renames)], [c2.id], max_entries=1,
-          rename_detector=detector)
+            [TestWalkEntry(c2, changes_with_renames)],
+            [c2.id],
+            max_entries=1,
+            rename_detector=detector,
+        )
 
     def test_follow_rename(self):
-        blob = make_object(Blob, data=b'blob')
-        names = [b'a', b'a', b'b', b'b', b'c', b'c']
+        blob = make_object(Blob, data=b"blob")
+        names = [b"a", b"a", b"b", b"b", b"c", b"c"]
 
-        trees = dict((i + 1, [(n, blob, F)]) for i, n in enumerate(names))
+        trees = {i + 1: [(n, blob, F)] for i, n in enumerate(names)}
         c1, c2, c3, c4, c5, c6 = self.make_linear_commits(6, trees=trees)
-        self.assertWalkYields([c5], [c6.id], paths=[b'c'])
+        self.assertWalkYields([c5], [c6.id], paths=[b"c"])
 
         def e(n):
             return (n, F, blob.id)
+
         self.assertWalkYields(
-            [TestWalkEntry(c5, [TreeChange(CHANGE_RENAME, e(b'b'), e(b'c'))]),
-             TestWalkEntry(c3, [TreeChange(CHANGE_RENAME, e(b'a'), e(b'b'))]),
-             TestWalkEntry(c1, [TreeChange.add(e(b'a'))])],
-            [c6.id], paths=[b'c'], follow=True)
+            [
+                TestWalkEntry(c5, [TreeChange(CHANGE_RENAME, e(b"b"), e(b"c"))]),
+                TestWalkEntry(c3, [TreeChange(CHANGE_RENAME, e(b"a"), e(b"b"))]),
+                TestWalkEntry(c1, [TreeChange.add(e(b"a"))]),
+            ],
+            [c6.id],
+            paths=[b"c"],
+            follow=True,
+        )
 
     def test_follow_rename_remove_path(self):
-        blob = make_object(Blob, data=b'blob')
+        blob = make_object(Blob, data=b"blob")
         _, _, _, c4, c5, c6 = self.make_linear_commits(
-            6, trees={1: [(b'a', blob), (b'c', blob)],
-                      2: [],
-                      3: [],
-                      4: [(b'b', blob)],
-                      5: [(b'a', blob)],
-                      6: [(b'c', blob)]})
+            6,
+            trees={
+                1: [(b"a", blob), (b"c", blob)],
+                2: [],
+                3: [],
+                4: [(b"b", blob)],
+                5: [(b"a", blob)],
+                6: [(b"c", blob)],
+            },
+        )
 
         def e(n):
             return (n, F, blob.id)
+
         # Once the path changes to b, we aren't interested in a or c anymore.
         self.assertWalkYields(
-            [TestWalkEntry(c6, [TreeChange(CHANGE_RENAME, e(b'a'), e(b'c'))]),
-             TestWalkEntry(c5, [TreeChange(CHANGE_RENAME, e(b'b'), e(b'a'))]),
-             TestWalkEntry(c4, [TreeChange.add(e(b'b'))])],
-            [c6.id], paths=[b'c'], follow=True)
+            [
+                TestWalkEntry(c6, [TreeChange(CHANGE_RENAME, e(b"a"), e(b"c"))]),
+                TestWalkEntry(c5, [TreeChange(CHANGE_RENAME, e(b"b"), e(b"a"))]),
+                TestWalkEntry(c4, [TreeChange.add(e(b"b"))]),
+            ],
+            [c6.id],
+            paths=[b"c"],
+            follow=True,
+        )
 
     def test_since(self):
         c1, c2, c3 = self.make_linear_commits(3)
@@ -385,8 +425,7 @@ class WalkerTest(TestCase):
         self.assertWalkYields([c2], [c3.id], since=50, until=150)
 
     def test_since_over_scan(self):
-        commits = self.make_linear_commits(
-          11, times=[9, 0, 1, 2, 3, 4, 5, 8, 6, 7, 9])
+        commits = self.make_linear_commits(11, times=[9, 0, 1, 2, 3, 4, 5, 8, 6, 7, 9])
         c8, _, c10, c11 = commits[-4:]
         del self.store[commits[0].id]
         # c9 is older than we want to walk, but is out of order with its
@@ -434,8 +473,8 @@ class WalkerTest(TestCase):
 
     def test_out_of_order_children(self):
         c1, c2, c3, c4, c5 = self.make_commits(
-          [[1], [2, 1], [3, 2], [4, 1], [5, 3, 4]],
-          times=[2, 1, 3, 4, 5])
+            [[1], [2, 1], [3, 2], [4, 1], [5, 3, 4]], times=[2, 1, 3, 4, 5]
+        )
         self.assertWalkYields([c5, c4, c3, c1, c2], [c5.id])
         self.assertWalkYields([c5, c4, c3, c2, c1], [c5.id], order=ORDER_TOPO)
 
@@ -446,8 +485,9 @@ class WalkerTest(TestCase):
         #    \-y3--y4-/--y5
         # Due to skew, y5 is the oldest commit.
         c1, x2, y3, y4, y5, m6 = self.make_commits(
-          [[1], [2, 1], [3, 1], [4, 3], [5, 4], [6, 2, 4]],
-          times=[2, 3, 4, 5, 1, 6])
+            [[1], [2, 1], [3, 1], [4, 3], [5, 4], [6, 2, 4]],
+            times=[2, 3, 4, 5, 1, 6],
+        )
         self.assertWalkYields([m6, y4, y3, x2, c1], [m6.id])
         # Ensure that c1..y4 get excluded even though they're popped from the
         # priority queue long before y5.
@@ -459,18 +499,16 @@ class WalkerTest(TestCase):
 
 
 class WalkEntryTest(TestCase):
-
     def setUp(self):
         super(WalkEntryTest, self).setUp()
         self.store = MemoryObjectStore()
 
     def make_commits(self, commit_spec, **kwargs):
-        times = kwargs.pop('times', [])
-        attrs = kwargs.pop('attrs', {})
+        times = kwargs.pop("times", [])
+        attrs = kwargs.pop("attrs", {})
         for i, t in enumerate(times):
-            attrs.setdefault(i + 1, {})['commit_time'] = t
-        return build_commit_graph(self.store, commit_spec, attrs=attrs,
-                                  **kwargs)
+            attrs.setdefault(i + 1, {})["commit_time"] = t
+        return build_commit_graph(self.store, commit_spec, attrs=attrs, **kwargs)
 
     def make_linear_commits(self, num_commits, **kwargs):
         commit_spec = []
@@ -483,11 +521,11 @@ class WalkEntryTest(TestCase):
 
     def test_all_changes(self):
         # Construct a commit with 2 files in different subdirectories.
-        blob_a = make_object(Blob, data=b'a')
-        blob_b = make_object(Blob, data=b'b')
+        blob_a = make_object(Blob, data=b"a")
+        blob_b = make_object(Blob, data=b"b")
         c1 = self.make_linear_commits(
             1,
-            trees={1: [(b'x/a', blob_a), (b'y/b', blob_b)]},
+            trees={1: [(b"x/a", blob_a), (b"y/b", blob_b)]},
         )[0]
 
         # Get the WalkEntry for the commit.
@@ -496,24 +534,26 @@ class WalkEntryTest(TestCase):
         changes = walker_entry.changes()
 
         # Compare the changes with the expected values.
-        entry_a = (b'x/a', F, blob_a.id)
-        entry_b = (b'y/b', F, blob_b.id)
+        entry_a = (b"x/a", F, blob_a.id)
+        entry_b = (b"y/b", F, blob_b.id)
         self.assertEqual(
-            [TreeChange.add(entry_a),
-             TreeChange.add(entry_b)],
+            [TreeChange.add(entry_a), TreeChange.add(entry_b)],
             changes,
         )
 
     def test_all_with_merge(self):
-        blob_a = make_object(Blob, data=b'a')
-        blob_a2 = make_object(Blob, data=b'a2')
-        blob_b = make_object(Blob, data=b'b')
-        blob_b2 = make_object(Blob, data=b'b2')
+        blob_a = make_object(Blob, data=b"a")
+        blob_a2 = make_object(Blob, data=b"a2")
+        blob_b = make_object(Blob, data=b"b")
+        blob_b2 = make_object(Blob, data=b"b2")
         x1, y2, m3 = self.make_commits(
             [[1], [2], [3, 1, 2]],
-            trees={1: [(b'x/a', blob_a)],
-                   2: [(b'y/b', blob_b)],
-                   3: [(b'x/a', blob_a2), (b'y/b', blob_b2)]})
+            trees={
+                1: [(b"x/a", blob_a)],
+                2: [(b"y/b", blob_b)],
+                3: [(b"x/a", blob_a2), (b"y/b", blob_b2)],
+            },
+        )
 
         # Get the WalkEntry for the merge commit.
         walker = Walker(self.store, m3.id)
@@ -523,60 +563,69 @@ class WalkEntryTest(TestCase):
         changes = walker_entry.changes()
         self.assertEqual(2, len(changes))
 
-        entry_a = (b'x/a', F, blob_a.id)
-        entry_a2 = (b'x/a', F, blob_a2.id)
-        entry_b = (b'y/b', F, blob_b.id)
-        entry_b2 = (b'y/b', F, blob_b2.id)
+        entry_a = (b"x/a", F, blob_a.id)
+        entry_a2 = (b"x/a", F, blob_a2.id)
+        entry_b = (b"y/b", F, blob_b.id)
+        entry_b2 = (b"y/b", F, blob_b2.id)
         self.assertEqual(
-                [[TreeChange(CHANGE_MODIFY, entry_a, entry_a2),
-                  TreeChange.add(entry_a2)],
-                 [TreeChange.add(entry_b2),
-                  TreeChange(CHANGE_MODIFY, entry_b, entry_b2)]],
-                changes,
+            [
+                [
+                    TreeChange(CHANGE_MODIFY, entry_a, entry_a2),
+                    TreeChange.add(entry_a2),
+                ],
+                [
+                    TreeChange.add(entry_b2),
+                    TreeChange(CHANGE_MODIFY, entry_b, entry_b2),
+                ],
+            ],
+            changes,
         )
 
     def test_filter_changes(self):
         # Construct a commit with 2 files in different subdirectories.
-        blob_a = make_object(Blob, data=b'a')
-        blob_b = make_object(Blob, data=b'b')
+        blob_a = make_object(Blob, data=b"a")
+        blob_b = make_object(Blob, data=b"b")
         c1 = self.make_linear_commits(
             1,
-            trees={1: [(b'x/a', blob_a), (b'y/b', blob_b)]},
+            trees={1: [(b"x/a", blob_a), (b"y/b", blob_b)]},
         )[0]
 
         # Get the WalkEntry for the commit.
         walker = Walker(self.store, c1.id)
         walker_entry = list(walker)[0]
-        changes = walker_entry.changes(path_prefix=b'x')
+        changes = walker_entry.changes(path_prefix=b"x")
 
         # Compare the changes with the expected values.
-        entry_a = (b'a', F, blob_a.id)
+        entry_a = (b"a", F, blob_a.id)
         self.assertEqual(
             [TreeChange.add(entry_a)],
             changes,
         )
 
     def test_filter_with_merge(self):
-        blob_a = make_object(Blob, data=b'a')
-        blob_a2 = make_object(Blob, data=b'a2')
-        blob_b = make_object(Blob, data=b'b')
-        blob_b2 = make_object(Blob, data=b'b2')
+        blob_a = make_object(Blob, data=b"a")
+        blob_a2 = make_object(Blob, data=b"a2")
+        blob_b = make_object(Blob, data=b"b")
+        blob_b2 = make_object(Blob, data=b"b2")
         x1, y2, m3 = self.make_commits(
             [[1], [2], [3, 1, 2]],
-            trees={1: [(b'x/a', blob_a)],
-                   2: [(b'y/b', blob_b)],
-                   3: [(b'x/a', blob_a2), (b'y/b', blob_b2)]})
+            trees={
+                1: [(b"x/a", blob_a)],
+                2: [(b"y/b", blob_b)],
+                3: [(b"x/a", blob_a2), (b"y/b", blob_b2)],
+            },
+        )
 
         # Get the WalkEntry for the merge commit.
         walker = Walker(self.store, m3.id)
         entries = list(walker)
         walker_entry = entries[0]
         self.assertEqual(walker_entry.commit.id, m3.id)
-        changes = walker_entry.changes(b'x')
+        changes = walker_entry.changes(b"x")
         self.assertEqual(1, len(changes))
 
-        entry_a = (b'a', F, blob_a.id)
-        entry_a2 = (b'a', F, blob_a2.id)
+        entry_a = (b"a", F, blob_a.id)
+        entry_a2 = (b"a", F, blob_a2.id)
         self.assertEqual(
             [[TreeChange(CHANGE_MODIFY, entry_a, entry_a2)]],
             changes,
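
For orientation, Walker iterates commits newest-first from a set of included heads, with the exclude=, paths=, max_entries= and follow= filters exercised above. A minimal sketch reusing the test helpers (an illustration, not part of the suite):

    from dulwich.object_store import MemoryObjectStore
    from dulwich.objects import Blob
    from dulwich.tests.utils import build_commit_graph, make_object
    from dulwich.walk import Walker

    store = MemoryObjectStore()
    blob = make_object(Blob, data=b"x")
    c1, c2 = build_commit_graph(
        store,
        [[1], [2, 1]],
        trees={1: [(b"a", blob)], 2: [(b"a", blob), (b"b", blob)]},
    )
    # Only c2 touches path b"b", so it is the only entry yielded.
    entries = list(Walker(store, [c2.id], paths=[b"b"]))
    assert [e.commit.id for e in entries] == [c2.id]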

+ 187 - 170
dulwich/tests/test_web.py

@@ -28,20 +28,20 @@ from typing import Type
 
 from dulwich.object_store import (
     MemoryObjectStore,
-    )
+)
 from dulwich.objects import (
     Blob,
-    )
+)
 from dulwich.repo import (
     BaseRepo,
     MemoryRepo,
-    )
+)
 from dulwich.server import (
     DictBackend,
-    )
+)
 from dulwich.tests import (
     TestCase,
-    )
+)
 from dulwich.web import (
     HTTP_OK,
     HTTP_NOT_FOUND,
@@ -59,16 +59,17 @@ from dulwich.web import (
     _LengthLimitedFile,
     HTTPGitRequest,
     HTTPGitApplication,
-    )
+)
 
 from dulwich.tests.utils import (
     make_object,
     make_tag,
-    )
+)
 
 
 class MinimalistWSGIInputStream(object):
     """WSGI input stream with no 'seek()' and 'tell()' methods."""
+
     def __init__(self, data):
         self.data = data
         self.pos = 0
@@ -77,13 +78,14 @@ class MinimalistWSGIInputStream(object):
         start = self.pos
         end = self.pos + howmuch
         if start >= len(self.data):
-            return ''
+            return ""
         self.pos = end
         return self.data[start:end]
 
 
 class MinimalistWSGIInputStream2(MinimalistWSGIInputStream):
     """WSGI input stream with no *working* 'seek()' and 'tell()' methods."""
+
     def seek(self, pos):
         raise NotImplementedError
 
@@ -113,8 +115,9 @@ class WebTestCase(TestCase):
     def setUp(self):
         super(WebTestCase, self).setUp()
         self._environ = {}
-        self._req = self._req_class(self._environ, self._start_response,
-                                    handlers=self._handlers())
+        self._req = self._req_class(
+            self._environ, self._start_response, handlers=self._handlers()
+        )
         self._status = None
         self._headers = []
         self._output = BytesIO()
@@ -128,7 +131,7 @@ class WebTestCase(TestCase):
         return None
 
     def assertContentTypeEquals(self, expected):
-        self.assertTrue(('Content-Type', expected) in self._headers)
+        self.assertTrue(("Content-Type", expected) in self._headers)
 
 
 def _test_backend(objects, refs=None, named_files=None):
@@ -139,31 +142,29 @@ def _test_backend(objects, refs=None, named_files=None):
     repo = MemoryRepo.init_bare(objects, refs)
     for path, contents in named_files.items():
         repo._put_named_file(path, contents)
-    return DictBackend({'/': repo})
+    return DictBackend({"/": repo})
 
 
 class DumbHandlersTestCase(WebTestCase):
-
     def test_send_file_not_found(self):
-        list(send_file(self._req, None, 'text/plain'))
+        list(send_file(self._req, None, "text/plain"))
         self.assertEqual(HTTP_NOT_FOUND, self._status)
 
     def test_send_file(self):
-        f = BytesIO(b'foobar')
-        output = b''.join(send_file(self._req, f, 'some/thing'))
-        self.assertEqual(b'foobar', output)
+        f = BytesIO(b"foobar")
+        output = b"".join(send_file(self._req, f, "some/thing"))
+        self.assertEqual(b"foobar", output)
         self.assertEqual(HTTP_OK, self._status)
-        self.assertContentTypeEquals('some/thing')
+        self.assertContentTypeEquals("some/thing")
         self.assertTrue(f.closed)
 
     def test_send_file_buffered(self):
         bufsize = 10240
-        xs = b'x' * bufsize
+        xs = b"x" * bufsize
         f = BytesIO(2 * xs)
-        self.assertEqual([xs, xs],
-                         list(send_file(self._req, f, 'some/thing')))
+        self.assertEqual([xs, xs], list(send_file(self._req, f, "some/thing")))
         self.assertEqual(HTTP_OK, self._status)
-        self.assertContentTypeEquals('some/thing')
+        self.assertContentTypeEquals("some/thing")
         self.assertTrue(f.closed)
 
     def test_send_file_error(self):
@@ -179,122 +180,123 @@ class DumbHandlersTestCase(WebTestCase):
                 self.closed = True
 
         f = TestFile(IOError)
-        list(send_file(self._req, f, 'some/thing'))
+        list(send_file(self._req, f, "some/thing"))
         self.assertEqual(HTTP_ERROR, self._status)
         self.assertTrue(f.closed)
         self.assertFalse(self._req.cached)
 
         # non-IOErrors are reraised
         f = TestFile(AttributeError)
-        self.assertRaises(AttributeError, list,
-                          send_file(self._req, f, 'some/thing'))
+        self.assertRaises(AttributeError, list, send_file(self._req, f, "some/thing"))
         self.assertTrue(f.closed)
         self.assertFalse(self._req.cached)
 
     def test_get_text_file(self):
-        backend = _test_backend([], named_files={'description': b'foo'})
-        mat = re.search('.*', 'description')
-        output = b''.join(get_text_file(self._req, backend, mat))
-        self.assertEqual(b'foo', output)
+        backend = _test_backend([], named_files={"description": b"foo"})
+        mat = re.search(".*", "description")
+        output = b"".join(get_text_file(self._req, backend, mat))
+        self.assertEqual(b"foo", output)
         self.assertEqual(HTTP_OK, self._status)
-        self.assertContentTypeEquals('text/plain')
+        self.assertContentTypeEquals("text/plain")
         self.assertFalse(self._req.cached)
 
     def test_get_loose_object(self):
-        blob = make_object(Blob, data=b'foo')
+        blob = make_object(Blob, data=b"foo")
         backend = _test_backend([blob])
-        mat = re.search('^(..)(.{38})$', blob.id.decode('ascii'))
-        output = b''.join(get_loose_object(self._req, backend, mat))
+        mat = re.search("^(..)(.{38})$", blob.id.decode("ascii"))
+        output = b"".join(get_loose_object(self._req, backend, mat))
         self.assertEqual(blob.as_legacy_object(), output)
         self.assertEqual(HTTP_OK, self._status)
-        self.assertContentTypeEquals('application/x-git-loose-object')
+        self.assertContentTypeEquals("application/x-git-loose-object")
         self.assertTrue(self._req.cached)
 
     def test_get_loose_object_missing(self):
-        mat = re.search('^(..)(.{38})$', '1' * 40)
+        mat = re.search("^(..)(.{38})$", "1" * 40)
         list(get_loose_object(self._req, _test_backend([]), mat))
         self.assertEqual(HTTP_NOT_FOUND, self._status)
 
     def test_get_loose_object_error(self):
-        blob = make_object(Blob, data=b'foo')
+        blob = make_object(Blob, data=b"foo")
         backend = _test_backend([blob])
-        mat = re.search('^(..)(.{38})$', blob.id.decode('ascii'))
+        mat = re.search("^(..)(.{38})$", blob.id.decode("ascii"))
 
         def as_legacy_object_error(self):
             raise IOError
 
-        self.addCleanup(
-            setattr, Blob, 'as_legacy_object', Blob.as_legacy_object)
+        self.addCleanup(setattr, Blob, "as_legacy_object", Blob.as_legacy_object)
         Blob.as_legacy_object = as_legacy_object_error
         list(get_loose_object(self._req, backend, mat))
         self.assertEqual(HTTP_ERROR, self._status)
 
     def test_get_pack_file(self):
-        pack_name = os.path.join(
-            'objects', 'pack', 'pack-%s.pack' % ('1' * 40))
-        backend = _test_backend([], named_files={pack_name: b'pack contents'})
-        mat = re.search('.*', pack_name)
-        output = b''.join(get_pack_file(self._req, backend, mat))
-        self.assertEqual(b'pack contents', output)
+        pack_name = os.path.join("objects", "pack", "pack-%s.pack" % ("1" * 40))
+        backend = _test_backend([], named_files={pack_name: b"pack contents"})
+        mat = re.search(".*", pack_name)
+        output = b"".join(get_pack_file(self._req, backend, mat))
+        self.assertEqual(b"pack contents", output)
         self.assertEqual(HTTP_OK, self._status)
-        self.assertContentTypeEquals('application/x-git-packed-objects')
+        self.assertContentTypeEquals("application/x-git-packed-objects")
         self.assertTrue(self._req.cached)
 
     def test_get_idx_file(self):
-        idx_name = os.path.join('objects', 'pack', 'pack-%s.idx' % ('1' * 40))
-        backend = _test_backend([], named_files={idx_name: b'idx contents'})
-        mat = re.search('.*', idx_name)
-        output = b''.join(get_idx_file(self._req, backend, mat))
-        self.assertEqual(b'idx contents', output)
+        idx_name = os.path.join("objects", "pack", "pack-%s.idx" % ("1" * 40))
+        backend = _test_backend([], named_files={idx_name: b"idx contents"})
+        mat = re.search(".*", idx_name)
+        output = b"".join(get_idx_file(self._req, backend, mat))
+        self.assertEqual(b"idx contents", output)
         self.assertEqual(HTTP_OK, self._status)
-        self.assertContentTypeEquals('application/x-git-packed-objects-toc')
+        self.assertContentTypeEquals("application/x-git-packed-objects-toc")
         self.assertTrue(self._req.cached)
 
     def test_get_info_refs(self):
-        self._environ['QUERY_STRING'] = ''
+        self._environ["QUERY_STRING"] = ""
 
-        blob1 = make_object(Blob, data=b'1')
-        blob2 = make_object(Blob, data=b'2')
-        blob3 = make_object(Blob, data=b'3')
+        blob1 = make_object(Blob, data=b"1")
+        blob2 = make_object(Blob, data=b"2")
+        blob3 = make_object(Blob, data=b"3")
 
-        tag1 = make_tag(blob2, name=b'tag-tag')
+        tag1 = make_tag(blob2, name=b"tag-tag")
 
         objects = [blob1, blob2, blob3, tag1]
         refs = {
-          b'HEAD': b'000',
-          b'refs/heads/master': blob1.id,
-          b'refs/tags/tag-tag': tag1.id,
-          b'refs/tags/blob-tag': blob3.id,
-          }
+            b"HEAD": b"000",
+            b"refs/heads/master": blob1.id,
+            b"refs/tags/tag-tag": tag1.id,
+            b"refs/tags/blob-tag": blob3.id,
+        }
         backend = _test_backend(objects, refs=refs)
 
-        mat = re.search('.*', '//info/refs')
-        self.assertEqual([blob1.id + b'\trefs/heads/master\n',
-                          blob3.id + b'\trefs/tags/blob-tag\n',
-                          tag1.id + b'\trefs/tags/tag-tag\n',
-                          blob2.id + b'\trefs/tags/tag-tag^{}\n'],
-                         list(get_info_refs(self._req, backend, mat)))
+        mat = re.search(".*", "//info/refs")
+        self.assertEqual(
+            [
+                blob1.id + b"\trefs/heads/master\n",
+                blob3.id + b"\trefs/tags/blob-tag\n",
+                tag1.id + b"\trefs/tags/tag-tag\n",
+                blob2.id + b"\trefs/tags/tag-tag^{}\n",
+            ],
+            list(get_info_refs(self._req, backend, mat)),
+        )
         self.assertEqual(HTTP_OK, self._status)
-        self.assertContentTypeEquals('text/plain')
+        self.assertContentTypeEquals("text/plain")
         self.assertFalse(self._req.cached)
 
     def test_get_info_refs_not_found(self):
-        self._environ['QUERY_STRING'] = ''
+        self._environ["QUERY_STRING"] = ""
 
         objects = []
         refs = {}
         backend = _test_backend(objects, refs=refs)
 
-        mat = re.search('info/refs', '/foo/info/refs')
+        mat = re.search("info/refs", "/foo/info/refs")
         self.assertEqual(
-            [b'No git repository was found at /foo'],
-            list(get_info_refs(self._req, backend, mat)))
+            [b"No git repository was found at /foo"],
+            list(get_info_refs(self._req, backend, mat)),
+        )
         self.assertEqual(HTTP_NOT_FOUND, self._status)
-        self.assertContentTypeEquals('text/plain')
+        self.assertContentTypeEquals("text/plain")
 
     def test_get_info_packs(self):
         class TestPackData(object):
-
             def __init__(self, sha):
                 self.filename = "pack-%s.pack" % sha
 
@@ -312,61 +314,66 @@ class DumbHandlersTestCase(WebTestCase):
 
         store = TestObjectStore()
         repo = BaseRepo(store, None)
-        backend = DictBackend({'/': repo})
-        mat = re.search('.*', '//info/packs')
-        output = b''.join(get_info_packs(self._req, backend, mat))
-        expected = b''.join(
-            [(b'P pack-' + s + b'.pack\n')
-             for s in [b'1' * 40, b'2' * 40, b'3' * 40]])
+        backend = DictBackend({"/": repo})
+        mat = re.search(".*", "//info/packs")
+        output = b"".join(get_info_packs(self._req, backend, mat))
+        expected = b"".join(
+            [(b"P pack-" + s + b".pack\n") for s in [b"1" * 40, b"2" * 40, b"3" * 40]]
+        )
         self.assertEqual(expected, output)
         self.assertEqual(HTTP_OK, self._status)
-        self.assertContentTypeEquals('text/plain')
+        self.assertContentTypeEquals("text/plain")
         self.assertFalse(self._req.cached)
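
The expected listing in test_get_info_refs above follows the dumb-protocol info/refs layout: one `<sha>\t<refname>\n` line per ref, sorted by name, with HEAD omitted and an extra `<refname>^{}` line giving an annotated tag's peeled target. A hedged sketch of a generator producing that shape (hypothetical helper, not dulwich's implementation):

    # Hypothetical renderer for the dumb-protocol info/refs file.
    def render_info_refs(refs, peeled):
        # refs: {refname: sha}; peeled: {tag refname: peeled target sha}
        for name in sorted(refs):
            if name == b"HEAD":  # HEAD is not advertised here
                continue
            yield refs[name] + b"\t" + name + b"\n"
            if name in peeled:
                yield peeled[name] + b"\t" + name + b"^{}\n"
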
 
 
 class SmartHandlersTestCase(WebTestCase):
-
     class _TestUploadPackHandler(object):
-        def __init__(self, backend, args, proto, stateless_rpc=None,
-                     advertise_refs=False):
+        def __init__(
+            self,
+            backend,
+            args,
+            proto,
+            stateless_rpc=None,
+            advertise_refs=False,
+        ):
             self.args = args
             self.proto = proto
             self.stateless_rpc = stateless_rpc
             self.advertise_refs = advertise_refs
 
         def handle(self):
-            self.proto.write(b'handled input: ' + self.proto.recv(1024))
+            self.proto.write(b"handled input: " + self.proto.recv(1024))
 
     def _make_handler(self, *args, **kwargs):
         self._handler = self._TestUploadPackHandler(*args, **kwargs)
         return self._handler
 
     def _handlers(self):
-        return {b'git-upload-pack': self._make_handler}
+        return {b"git-upload-pack": self._make_handler}
 
     def test_handle_service_request_unknown(self):
-        mat = re.search('.*', '/git-evil-handler')
-        content = list(handle_service_request(self._req, 'backend', mat))
+        mat = re.search(".*", "/git-evil-handler")
+        content = list(handle_service_request(self._req, "backend", mat))
         self.assertEqual(HTTP_FORBIDDEN, self._status)
-        self.assertFalse(b'git-evil-handler' in b"".join(content))
+        self.assertFalse(b"git-evil-handler" in b"".join(content))
         self.assertFalse(self._req.cached)
 
     def _run_handle_service_request(self, content_length=None):
-        self._environ['wsgi.input'] = BytesIO(b'foo')
+        self._environ["wsgi.input"] = BytesIO(b"foo")
         if content_length is not None:
-            self._environ['CONTENT_LENGTH'] = content_length
-        mat = re.search('.*', '/git-upload-pack')
+            self._environ["CONTENT_LENGTH"] = content_length
+        mat = re.search(".*", "/git-upload-pack")
 
         class Backend(object):
             def open_repository(self, path):
                 return None
-        handler_output = b''.join(
-          handle_service_request(self._req, Backend(), mat))
+
+        handler_output = b"".join(handle_service_request(self._req, Backend(), mat))
         write_output = self._output.getvalue()
         # Ensure all output was written via the write callback.
-        self.assertEqual(b'', handler_output)
-        self.assertEqual(b'handled input: foo', write_output)
-        self.assertContentTypeEquals('application/x-git-upload-pack-result')
+        self.assertEqual(b"", handler_output)
+        self.assertEqual(b"handled input: foo", write_output)
+        self.assertContentTypeEquals("application/x-git-upload-pack-result")
         self.assertFalse(self._handler.advertise_refs)
         self.assertTrue(self._handler.stateless_rpc)
         self.assertFalse(self._req.cached)
@@ -375,42 +382,46 @@ class SmartHandlersTestCase(WebTestCase):
         self._run_handle_service_request()
 
     def test_handle_service_request_with_length(self):
-        self._run_handle_service_request(content_length='3')
+        self._run_handle_service_request(content_length="3")
 
     def test_handle_service_request_empty_length(self):
-        self._run_handle_service_request(content_length='')
+        self._run_handle_service_request(content_length="")
 
     def test_get_info_refs_unknown(self):
-        self._environ['QUERY_STRING'] = 'service=git-evil-handler'
+        self._environ["QUERY_STRING"] = "service=git-evil-handler"
 
         class Backend(object):
             def open_repository(self, url):
                 return None
 
-        mat = re.search('.*', '/git-evil-pack')
+        mat = re.search(".*", "/git-evil-pack")
         content = list(get_info_refs(self._req, Backend(), mat))
-        self.assertFalse(b'git-evil-handler' in b"".join(content))
+        self.assertFalse(b"git-evil-handler" in b"".join(content))
         self.assertEqual(HTTP_FORBIDDEN, self._status)
         self.assertFalse(self._req.cached)
 
     def test_get_info_refs(self):
-        self._environ['wsgi.input'] = BytesIO(b'foo')
-        self._environ['QUERY_STRING'] = 'service=git-upload-pack'
+        self._environ["wsgi.input"] = BytesIO(b"foo")
+        self._environ["QUERY_STRING"] = "service=git-upload-pack"
 
         class Backend(object):
-
             def open_repository(self, url):
                 return None
 
-        mat = re.search('.*', '/git-upload-pack')
-        handler_output = b''.join(get_info_refs(self._req, Backend(), mat))
+        mat = re.search(".*", "/git-upload-pack")
+        handler_output = b"".join(get_info_refs(self._req, Backend(), mat))
         write_output = self._output.getvalue()
-        self.assertEqual((b'001e# service=git-upload-pack\n'
-                          b'0000'
-                          # input is ignored by the handler
-                          b'handled input: '), write_output)
+        self.assertEqual(
+            (
+                b"001e# service=git-upload-pack\n"
+                b"0000"
+                # input is ignored by the handler
+                b"handled input: "
+            ),
+            write_output,
+        )
         # Ensure all output was written via the write callback.
-        self.assertEqual(b'', handler_output)
+        self.assertEqual(b"", handler_output)
         self.assertTrue(self._handler.advertise_refs)
         self.assertTrue(self._handler.stateless_rpc)
         self.assertFalse(self._req.cached)
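
The `b"001e# service=git-upload-pack\n"` prefix asserted above is a git pkt-line: four hex digits giving the total length (payload plus the four length bytes themselves), followed by the payload; `b"0000"` is the flush packet that ends the service announcement. A quick self-check of that arithmetic:

    # pkt-line framing as used by the smart protocol; 0x001e == 26 + 4.
    def pkt_line(payload):
        return ("%04x" % (len(payload) + 4)).encode("ascii") + payload

    assert pkt_line(b"# service=git-upload-pack\n") == b"001e# service=git-upload-pack\n"
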
@@ -418,19 +429,19 @@ class SmartHandlersTestCase(WebTestCase):
 
 class LengthLimitedFileTestCase(TestCase):
     def test_no_cutoff(self):
-        f = _LengthLimitedFile(BytesIO(b'foobar'), 1024)
-        self.assertEqual(b'foobar', f.read())
+        f = _LengthLimitedFile(BytesIO(b"foobar"), 1024)
+        self.assertEqual(b"foobar", f.read())
 
     def test_cutoff(self):
-        f = _LengthLimitedFile(BytesIO(b'foobar'), 3)
-        self.assertEqual(b'foo', f.read())
-        self.assertEqual(b'', f.read())
+        f = _LengthLimitedFile(BytesIO(b"foobar"), 3)
+        self.assertEqual(b"foo", f.read())
+        self.assertEqual(b"", f.read())
 
     def test_multiple_reads(self):
-        f = _LengthLimitedFile(BytesIO(b'foobar'), 3)
-        self.assertEqual(b'fo', f.read(2))
-        self.assertEqual(b'o', f.read(2))
-        self.assertEqual(b'', f.read())
+        f = _LengthLimitedFile(BytesIO(b"foobar"), 3)
+        self.assertEqual(b"fo", f.read(2))
+        self.assertEqual(b"o", f.read(2))
+        self.assertEqual(b"", f.read())
 
 
 class HTTPGitRequestTestCase(WebTestCase):
@@ -440,19 +451,17 @@ class HTTPGitRequestTestCase(WebTestCase):
 
     def test_not_found(self):
         self._req.cache_forever()  # cache headers should be discarded
-        message = 'Something not found'
-        self.assertEqual(message.encode('ascii'), self._req.not_found(message))
+        message = "Something not found"
+        self.assertEqual(message.encode("ascii"), self._req.not_found(message))
         self.assertEqual(HTTP_NOT_FOUND, self._status)
-        self.assertEqual(set([('Content-Type', 'text/plain')]),
-                         set(self._headers))
+        self.assertEqual(set([("Content-Type", "text/plain")]), set(self._headers))
 
     def test_forbidden(self):
         self._req.cache_forever()  # cache headers should be discarded
-        message = 'Something not found'
-        self.assertEqual(message.encode('ascii'), self._req.forbidden(message))
+        message = "Something not found"
+        self.assertEqual(message.encode("ascii"), self._req.forbidden(message))
         self.assertEqual(HTTP_FORBIDDEN, self._status)
-        self.assertEqual(set([('Content-Type', 'text/plain')]),
-                         set(self._headers))
+        self.assertEqual(set([("Content-Type", "text/plain")]), set(self._headers))
 
     def test_respond_ok(self):
         self._req.respond()
@@ -461,70 +470,77 @@ class HTTPGitRequestTestCase(WebTestCase):
 
     def test_respond(self):
         self._req.nocache()
-        self._req.respond(status=402, content_type='some/type',
-                          headers=[('X-Foo', 'foo'), ('X-Bar', 'bar')])
-        self.assertEqual(set([
-          ('X-Foo', 'foo'),
-          ('X-Bar', 'bar'),
-          ('Content-Type', 'some/type'),
-          ('Expires', 'Fri, 01 Jan 1980 00:00:00 GMT'),
-          ('Pragma', 'no-cache'),
-          ('Cache-Control', 'no-cache, max-age=0, must-revalidate'),
-          ]), set(self._headers))
+        self._req.respond(
+            status=402,
+            content_type="some/type",
+            headers=[("X-Foo", "foo"), ("X-Bar", "bar")],
+        )
+        self.assertEqual(
+            set(
+                [
+                    ("X-Foo", "foo"),
+                    ("X-Bar", "bar"),
+                    ("Content-Type", "some/type"),
+                    ("Expires", "Fri, 01 Jan 1980 00:00:00 GMT"),
+                    ("Pragma", "no-cache"),
+                    ("Cache-Control", "no-cache, max-age=0, must-revalidate"),
+                ]
+            ),
+            set(self._headers),
+        )
         self.assertEqual(402, self._status)
 
 
 class HTTPGitApplicationTestCase(TestCase):
-
     def setUp(self):
         super(HTTPGitApplicationTestCase, self).setUp()
-        self._app = HTTPGitApplication('backend')
+        self._app = HTTPGitApplication("backend")
 
         self._environ = {
-            'PATH_INFO': '/foo',
-            'REQUEST_METHOD': 'GET',
+            "PATH_INFO": "/foo",
+            "REQUEST_METHOD": "GET",
         }
 
     def _test_handler(self, req, backend, mat):
         # tests interface used by all handlers
         self.assertEqual(self._environ, req.environ)
-        self.assertEqual('backend', backend)
-        self.assertEqual('/foo', mat.group(0))
-        return 'output'
+        self.assertEqual("backend", backend)
+        self.assertEqual("/foo", mat.group(0))
+        return "output"
 
     def _add_handler(self, app):
-        req = self._environ['REQUEST_METHOD']
+        req = self._environ["REQUEST_METHOD"]
         app.services = {
-          (req, re.compile('/foo$')): self._test_handler,
+            (req, re.compile("/foo$")): self._test_handler,
         }
 
     def test_call(self):
         self._add_handler(self._app)
-        self.assertEqual('output', self._app(self._environ, None))
+        self.assertEqual("output", self._app(self._environ, None))
 
     def test_fallback_app(self):
         def test_app(environ, start_response):
-            return 'output'
+            return "output"
 
-        app = HTTPGitApplication('backend', fallback_app=test_app)
-        self.assertEqual('output', app(self._environ, None))
+        app = HTTPGitApplication("backend", fallback_app=test_app)
+        self.assertEqual("output", app(self._environ, None))
 
 
 class GunzipTestCase(HTTPGitApplicationTestCase):
     __doc__ = """TestCase for testing the GunzipFilter, ensuring the wsgi.input
     is correctly decompressed and headers are corrected.
     """
-    example_text = __doc__.encode('ascii')
+    example_text = __doc__.encode("ascii")
 
     def setUp(self):
         super(GunzipTestCase, self).setUp()
         self._app = GunzipFilter(self._app)
-        self._environ['HTTP_CONTENT_ENCODING'] = 'gzip'
-        self._environ['REQUEST_METHOD'] = 'POST'
+        self._environ["HTTP_CONTENT_ENCODING"] = "gzip"
+        self._environ["REQUEST_METHOD"] = "POST"
 
     def _get_zstream(self, text):
         zstream = BytesIO()
-        zfile = gzip.GzipFile(fileobj=zstream, mode='w')
+        zfile = gzip.GzipFile(fileobj=zstream, mode="w")
         zfile.write(text)
         zfile.close()
         zlength = zstream.tell()
@@ -534,22 +550,19 @@ class GunzipTestCase(HTTPGitApplicationTestCase):
     def _test_call(self, orig, zstream, zlength):
         self._add_handler(self._app.app)
         self.assertLess(zlength, len(orig))
-        self.assertEqual(self._environ['HTTP_CONTENT_ENCODING'], 'gzip')
-        self._environ['CONTENT_LENGTH'] = zlength
-        self._environ['wsgi.input'] = zstream
+        self.assertEqual(self._environ["HTTP_CONTENT_ENCODING"], "gzip")
+        self._environ["CONTENT_LENGTH"] = zlength
+        self._environ["wsgi.input"] = zstream
         self._app(self._environ, None)
-        buf = self._environ['wsgi.input']
+        buf = self._environ["wsgi.input"]
         self.assertIsNot(buf, zstream)
         buf.seek(0)
         self.assertEqual(orig, buf.read())
-        self.assertIs(None, self._environ.get('CONTENT_LENGTH'))
-        self.assertNotIn('HTTP_CONTENT_ENCODING', self._environ)
+        self.assertIs(None, self._environ.get("CONTENT_LENGTH"))
+        self.assertNotIn("HTTP_CONTENT_ENCODING", self._environ)
 
     def test_call(self):
-        self._test_call(
-            self.example_text,
-            *self._get_zstream(self.example_text)
-        )
+        self._test_call(self.example_text, *self._get_zstream(self.example_text))
 
     def test_call_no_seek(self):
         """
@@ -560,7 +573,9 @@ class GunzipTestCase(HTTPGitApplicationTestCase):
         zstream, zlength = self._get_zstream(self.example_text)
         self._test_call(
             self.example_text,
-            MinimalistWSGIInputStream(zstream.read()), zlength)
+            MinimalistWSGIInputStream(zstream.read()),
+            zlength,
+        )
 
     def test_call_no_working_seek(self):
         """
@@ -569,5 +584,7 @@ class GunzipTestCase(HTTPGitApplicationTestCase):
         """
         zstream, zlength = self._get_zstream(self.example_text)
         self._test_call(
-                self.example_text,
-                MinimalistWSGIInputStream2(zstream.read()), zlength)
+            self.example_text,
+            MinimalistWSGIInputStream2(zstream.read()),
+            zlength,
+        )
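
Taken together, the assertions in `_test_call` describe what the filter must do before delegating: inflate `wsgi.input` into a fresh, seekable buffer and drop the `CONTENT_LENGTH` and `HTTP_CONTENT_ENCODING` keys that no longer describe the body. A sketch under those assumptions only (not `GunzipFilter` itself):

    import gzip
    from io import BytesIO

    def gunzip_environ(environ):
        # Only rewrite the request when the body is actually gzip-encoded.
        if environ.get("HTTP_CONTENT_ENCODING") == "gzip":
            data = gzip.GzipFile(fileobj=environ["wsgi.input"], mode="rb").read()
            environ["wsgi.input"] = BytesIO(data)  # seekable replacement buffer
            environ.pop("CONTENT_LENGTH", None)    # stale: described the gzipped size
            del environ["HTTP_CONTENT_ENCODING"]   # body is no longer encoded
        return environ
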
