2
0
Jelmer Vernooij 3 жил өмнө
parent
commit
1c94849ee7
100 өөрчлөгдсөн 11115 нэмэгдсэн , 7620 устгасан
  1. 12 0
      .deepsource.toml
  2. 5 0
      .flake8
  3. 9 0
      .github/workflows/pythonpackage.yml
  4. 17 0
      .github/workflows/pythonpublish.yml
  5. 1 0
      AUTHORS
  6. 86 0
      NEWS
  7. 6 3
      PKG-INFO
  8. 4 1
      README.rst
  9. 3 2
      debian/changelog
  10. 3 2
      docs/tutorial/file-format.txt
  11. 4 1
      docs/tutorial/remote.txt
  12. 6 3
      dulwich.egg-info/PKG-INFO
  13. 10 1
      dulwich.egg-info/SOURCES.txt
  14. 1 1
      dulwich/__init__.py
  15. 10 9
      dulwich/archive.py
  16. 22 23
      dulwich/bundle.py
  17. 181 159
      dulwich/cli.py
  18. 322 183
      dulwich/client.py
  19. 0 0
      dulwich/cloud/__init__.py
  20. 82 0
      dulwich/cloud/gcs.py
  21. 115 64
      dulwich/config.py
  22. 5 4
      dulwich/contrib/__init__.py
  23. 45 36
      dulwich/contrib/diffstat.py
  24. 18 12
      dulwich/contrib/paramiko_vendor.py
  25. 11 11
      dulwich/contrib/release_robot.py
  26. 221 188
      dulwich/contrib/swift.py
  27. 189 0
      dulwich/contrib/test_paramiko_vendor.py
  28. 28 21
      dulwich/contrib/test_release_robot.py
  29. 215 185
      dulwich/contrib/test_swift.py
  30. 102 99
      dulwich/contrib/test_swift_smoke.py
  31. 74 53
      dulwich/diff_tree.py
  32. 22 19
      dulwich/errors.py
  33. 39 24
      dulwich/fastexport.py
  34. 41 20
      dulwich/file.py
  35. 2 2
      dulwich/graph.py
  36. 23 19
      dulwich/greenthreads.py
  37. 38 28
      dulwich/hooks.py
  38. 86 85
      dulwich/ignore.py
  39. 216 115
      dulwich/index.py
  40. 7 8
      dulwich/lfs.py
  41. 40 12
      dulwich/line_ending.py
  42. 6 3
      dulwich/log_utils.py
  43. 35 26
      dulwich/lru_cache.py
  44. 6 5
      dulwich/mailmap.py
  45. 292 115
      dulwich/object_store.py
  46. 330 213
      dulwich/objects.py
  47. 3 4
      dulwich/objectspec.py
  48. 230 189
      dulwich/pack.py
  49. 100 69
      dulwich/patch.py
  50. 320 186
      dulwich/porcelain.py
  51. 84 75
      dulwich/protocol.py
  52. 88 13
      dulwich/reflog.py
  53. 295 142
      dulwich/refs.py
  54. 278 191
      dulwich/repo.py
  55. 209 157
      dulwich/server.py
  56. 37 17
      dulwich/stash.py
  57. 70 56
      dulwich/tests/__init__.py
  58. 10 9
      dulwich/tests/compat/__init__.py
  59. 148 94
      dulwich/tests/compat/server_utils.py
  60. 220 176
      dulwich/tests/compat/test_client.py
  61. 55 39
      dulwich/tests/compat/test_pack.py
  62. 6 6
      dulwich/tests/compat/test_patch.py
  63. 101 0
      dulwich/tests/compat/test_porcelain.py
  64. 41 43
      dulwich/tests/compat/test_repository.py
  65. 14 17
      dulwich/tests/compat/test_server.py
  66. 12 14
      dulwich/tests/compat/test_utils.py
  67. 36 34
      dulwich/tests/compat/test_web.py
  68. 44 35
      dulwich/tests/compat/utils.py
  69. 15 18
      dulwich/tests/test_archive.py
  70. 9 9
      dulwich/tests/test_blackbox.py
  71. 7 8
      dulwich/tests/test_bundle.py
  72. 312 286
      dulwich/tests/test_client.py
  73. 172 101
      dulwich/tests/test_config.py
  74. 787 584
      dulwich/tests/test_diff_tree.py
  75. 112 56
      dulwich/tests/test_fastexport.py
  76. 62 64
      dulwich/tests/test_file.py
  77. 67 65
      dulwich/tests/test_grafts.py
  78. 76 77
      dulwich/tests/test_graph.py
  79. 18 15
      dulwich/tests/test_greenthreads.py
  80. 51 34
      dulwich/tests/test_hooks.py
  81. 126 115
      dulwich/tests/test_ignore.py
  82. 296 221
      dulwich/tests/test_index.py
  83. 3 5
      dulwich/tests/test_lfs.py
  84. 4 13
      dulwich/tests/test_line_ending.py
  85. 89 88
      dulwich/tests/test_lru_cache.py
  86. 51 40
      dulwich/tests/test_mailmap.py
  87. 142 85
      dulwich/tests/test_missing_obj_finder.py
  88. 257 170
      dulwich/tests/test_object_store.py
  89. 316 284
      dulwich/tests/test_objects.py
  90. 78 56
      dulwich/tests/test_objectspec.py
  91. 313 250
      dulwich/tests/test_pack.py
  92. 311 216
      dulwich/tests/test_patch.py
  93. 470 201
      dulwich/tests/test_porcelain.py
  94. 84 84
      dulwich/tests/test_protocol.py
  95. 102 28
      dulwich/tests/test_reflog.py
  96. 432 354
      dulwich/tests/test_refs.py
  97. 416 322
      dulwich/tests/test_repository.py
  98. 297 256
      dulwich/tests/test_server.py
  99. 29 23
      dulwich/tests/test_utils.py
  100. 220 171
      dulwich/tests/test_walk.py

+ 12 - 0
.deepsource.toml

@@ -0,0 +1,12 @@
+version = 1
+
+test_patterns = ["dulwich/**test_*.py"]
+
+exclude_patterns = ["examples/**"]
+
+[[analyzers]]
+name = "python"
+enabled = true
+
+  [analyzers.meta]
+  runtime_version = "3.x.x"

+ 5 - 0
.flake8

@@ -0,0 +1,5 @@
+[flake8]
+extend-ignore = E203, E266, E501, W293, W291
+max-line-length = 88
+max-complexity = 18
+select = B,C,E,F,W,T4,B9

+ 9 - 0
.github/workflows/pythonpackage.yml

@@ -31,10 +31,19 @@ jobs:
       uses: actions/setup-python@v2
       uses: actions/setup-python@v2
       with:
       with:
         python-version: ${{ matrix.python-version }}
         python-version: ${{ matrix.python-version }}
+    - name: Install native dependencies (Ubuntu)
+      run: sudo apt-get update && sudo apt-get install -y libgpgme-dev libgpg-error-dev
+      if: "matrix.os == 'ubuntu-latest'"
+    - name: Install native dependencies (MacOS)
+      run: brew install swig gpgme
+      if: "matrix.os == 'macos-latest'"
     - name: Install dependencies
     - name: Install dependencies
       run: |
       run: |
         python -m pip install --upgrade pip
         python -m pip install --upgrade pip
         pip install -U pip coverage codecov flake8 fastimport
         pip install -U pip coverage codecov flake8 fastimport
+    - name: Install gpg on supported platforms
+      run: pip install -U gpg
+      if: "matrix.os != 'windows-latest' && matrix.python-version != 'pypy3'"
     - name: Install mypy
     - name: Install mypy
       run: |
       run: |
         pip install -U mypy
         pip install -U mypy

+ 17 - 0
.github/workflows/pythonpublish.yml

@@ -30,10 +30,19 @@ jobs:
       uses: actions/setup-python@v2
       uses: actions/setup-python@v2
       with:
       with:
         python-version: ${{ matrix.python-version }}
         python-version: ${{ matrix.python-version }}
+    - name: Install native dependencies (Ubuntu)
+      run: sudo apt-get update && sudo apt-get install -y libgpgme-dev libgpg-error-dev
+      if: "matrix.os == 'ubuntu-latest'"
+    - name: Install native dependencies (MacOS)
+      run: brew install swig gpgme
+      if: "matrix.os == 'macos-latest'"
     - name: Install dependencies
     - name: Install dependencies
       run: |
       run: |
         python -m pip install --upgrade pip
         python -m pip install --upgrade pip
         pip install setuptools wheel twine fastimport
         pip install setuptools wheel twine fastimport
+    - name: Install gpg on supported platforms
+      run: pip install -U gpg
+      if: "matrix.os != 'windows-latest' && matrix.python-version != 'pypy3'"
     - name: Run test suite
     - name: Run test suite
       run: |
       run: |
         python -m unittest dulwich.tests.test_suite
         python -m unittest dulwich.tests.test_suite
@@ -41,6 +50,14 @@ jobs:
       run: |
       run: |
         python setup.py sdist bdist_wheel
         python setup.py sdist bdist_wheel
       if: "matrix.os != 'ubuntu-latest'"
       if: "matrix.os != 'ubuntu-latest'"
+    - uses: docker/setup-qemu-action@v1
+      name: Set up QEMU
+      if: "matrix.os == 'ubuntu-latest'"
+    - name: Build and publish (Linux aarch64)
+      uses: RalfG/python-wheels-manylinux-build@v0.3.3-manylinux2014_aarch64
+      with:
+        python-versions: 'cp36-cp36m cp37-cp37m cp38-cp38 cp39-cp39'
+      if: "matrix.os == 'ubuntu-latest'"
     - name: Build and publish (Linux)
     - name: Build and publish (Linux)
       uses: RalfG/python-wheels-manylinux-build@v0.3.1
       uses: RalfG/python-wheels-manylinux-build@v0.3.1
       with:
       with:

+ 1 - 0
AUTHORS

@@ -150,5 +150,6 @@ Antoine Lambert <anlambert@softwareheritage.org>
 Lane Barlow <lane.barlow@gmail.com>
 Lane Barlow <lane.barlow@gmail.com>
 Manuel Jacob <me@manueljacob.de>
 Manuel Jacob <me@manueljacob.de>
 Brecht Machiels <brecht@mos6581.org>
 Brecht Machiels <brecht@mos6581.org>
+Peter Rowlands <peter@pmrowla.com>
 
 
 If you contributed but are missing from this list, please send me an e-mail.
 If you contributed but are missing from this list, please send me an e-mail.

+ 86 - 0
NEWS

@@ -1,3 +1,88 @@
+0.20.23	2021-05-24
+
+ * Fix installation of GPG during package publishing.
+   (Ruslan Kuprieiev)
+
+0.20.22	2021-05-24
+
+ * Prevent removal of refs directory when the last ref is
+   deleted. (Jelmer Vernooij)
+
+ * Fix filename: MERGE_HEADS => MERGE_HEAD.
+   (Jelmer Vernooij, #861)
+
+ * For ignored directories, porcelain.add and porcelain.status now only return
+   the path to directory itself in the list of ignored paths. Previously, paths
+   for all files within the directory would also be included in the list.
+   (Peter Rowlands, #853)
+
+ * Provide depth argument to ``determine_wants``.
+   (Peter Rowlands)
+
+ * Various tag signature handling improvements.
+   (Daniel Murphy)
+
+ * Add separate Tag.verify().  (Peter Rowlands)
+
+ * Add support for version 3 index files. (Jelmer Vernooij)
+
+ * Fix autocrlf=input handling. (Peter Rowlands, Boris Feld)
+
+ * Attempt to find C Git global config on Windows.
+   (Peter Rowlands)
+
+ API CHANGES
+
+ * The APIs for writing and reading individual index entries have changed
+   to handle lists of (name, entry) tuples rather than tuples.
+
+0.20.21	2021-03-20
+
+ * Add basic support for a GcsObjectStore that stores
+   pack files in gcs. (Jelmer Vernooij)
+
+ * In porcelain.push, default to local active branch.
+   (Jelmer Vernooij, #846)
+
+ * Support fetching symrefs.
+   (Jelmer Vernooij, #485, #847)
+
+ * Add aarch64 wheel building.
+   (odidev, Jelmer Vernooij)
+
+0.20.20	2021-03-03
+
+ * Implement ``Stash.drop``. (Peter Rowlands)
+
+ * Support untracked symlinks to paths outside the
+   repository. (Peter Rowlands, #842)
+
+0.20.19	2021-02-11
+
+ * Fix handling of negative matches in nested gitignores.
+   (Corentin Hembise, #836)
+
+0.20.18	2021-02-04
+
+ * Fix formatting in setup.py. (Jelmer Vernooij)
+
+ * Add release configuration. (Jelmer Vernooij)
+
+0.20.17	2021-02-04
+
+ * credentials: ignore end-of-line character. (Georges Racinet)
+
+ * Fix failure in get_untracked_paths when the repository contains symlinks.
+   (#830, #793, mattseddon)
+
+ * docs: Clarify that Git objects are created on `git add`.
+   (Utku Gultopu)
+
+0.20.16	2021-01-16
+
+ * Add flag to only attempt to fetch ignored untracked files when specifically requested.
+   (Matt Seddon)
+
 0.20.15	2020-12-23
 0.20.15	2020-12-23
 
 
  * Add some functions for parsing and writing bundles.
  * Add some functions for parsing and writing bundles.
@@ -994,6 +1079,7 @@
 
 
   * In dulwich.index.build_index_from_tree, by default
   * In dulwich.index.build_index_from_tree, by default
     refuse to create entries that start with .git/.
     refuse to create entries that start with .git/.
+    (Jelmer Vernooij, CVE-2014-9706)
 
 
   * Fix running of testsuite when installed.
   * Fix running of testsuite when installed.
     (Jelmer Vernooij, #223)
     (Jelmer Vernooij, #223)

+ 6 - 3
PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Metadata-Version: 2.1
 Name: dulwich
 Name: dulwich
-Version: 0.20.15
+Version: 0.20.23
 Summary: Python Git Library
 Summary: Python Git Library
 Home-page: https://www.dulwich.io/
 Home-page: https://www.dulwich.io/
 Author: Jelmer Vernooij
 Author: Jelmer Vernooij
@@ -9,7 +9,10 @@ License: Apachev2 or later or GPLv2
 Project-URL: Bug Tracker, https://github.com/dulwich/dulwich/issues
 Project-URL: Bug Tracker, https://github.com/dulwich/dulwich/issues
 Project-URL: Repository, https://www.dulwich.io/code/
 Project-URL: Repository, https://www.dulwich.io/code/
 Project-URL: GitHub, https://github.com/dulwich/dulwich
 Project-URL: GitHub, https://github.com/dulwich/dulwich
-Description: This is the Dulwich project.
+Description: Dulwich
+        =======
+        
+        This is the Dulwich project.
         
         
         It aims to provide an interface to git repos (both local and remote) that
         It aims to provide an interface to git repos (both local and remote) that
         doesn't call out to git directly but instead uses pure Python.
         doesn't call out to git directly but instead uses pure Python.
@@ -80,7 +83,7 @@ Description: This is the Dulwich project.
         Help
         Help
         ----
         ----
         
         
-        There is a *#dulwich* IRC channel on the `Freenode <https://www.freenode.net/>`_, and
+        There is a *#dulwich* IRC channel on the `OFTC <https://www.oftc.net/>`_, and
         `dulwich-announce <https://groups.google.com/forum/#!forum/dulwich-announce>`_
         `dulwich-announce <https://groups.google.com/forum/#!forum/dulwich-announce>`_
         and `dulwich-discuss <https://groups.google.com/forum/#!forum/dulwich-discuss>`_
         and `dulwich-discuss <https://groups.google.com/forum/#!forum/dulwich-discuss>`_
         mailing lists.
         mailing lists.

+ 4 - 1
README.rst

@@ -1,3 +1,6 @@
+Dulwich
+=======
+
 This is the Dulwich project.
 This is the Dulwich project.
 
 
 It aims to provide an interface to git repos (both local and remote) that
 It aims to provide an interface to git repos (both local and remote) that
@@ -69,7 +72,7 @@ doc``. It can also be found `on the web <https://www.dulwich.io/docs/>`_.
 Help
 Help
 ----
 ----
 
 
-There is a *#dulwich* IRC channel on the `Freenode <https://www.freenode.net/>`_, and
+There is a *#dulwich* IRC channel on the `OFTC <https://www.oftc.net/>`_, and
 `dulwich-announce <https://groups.google.com/forum/#!forum/dulwich-announce>`_
 `dulwich-announce <https://groups.google.com/forum/#!forum/dulwich-announce>`_
 and `dulwich-discuss <https://groups.google.com/forum/#!forum/dulwich-discuss>`_
 and `dulwich-discuss <https://groups.google.com/forum/#!forum/dulwich-discuss>`_
 mailing lists.
 mailing lists.

+ 3 - 2
debian/changelog

@@ -1,9 +1,10 @@
-dulwich (0.20.15-2) UNRELEASED; urgency=medium
+dulwich (0.20.23-1) UNRELEASED; urgency=medium
 
 
   * Update watch file format version to 4.
   * Update watch file format version to 4.
   * Fix watch file.
   * Fix watch file.
+  * New upstream release.
 
 
- -- Jelmer Vernooij <jelmer@debian.org>  Wed, 03 Mar 2021 17:14:44 -0000
+ -- Jelmer Vernooij <jelmer@debian.org>  Sun, 30 May 2021 17:32:21 -0000
 
 
 dulwich (0.20.15-1) unstable; urgency=medium
 dulwich (0.20.15-1) unstable; urgency=medium
 
 

+ 3 - 2
docs/tutorial/file-format.txt

@@ -69,8 +69,9 @@ A blob file looks like this::
 
 
   blob <content length><NUL><content>
   blob <content length><NUL><content>
 
 
-If you change a single line, another blob will be generated by Git at commit
-time. This is how Git can fastly checkout any version in time.
+If you change a single line, another blob will be generated by Git each time you
+successfully run ``git add``. This is how Git can fastly checkout any version in
+time.
 
 
 On the opposite, several identical files with different filenames generate
 On the opposite, several identical files with different filenames generate
 only one blob. That's mostly how renames are so cheap and efficient in Git.
 only one blob. That's mostly how renames are so cheap and efficient in Git.

+ 4 - 1
docs/tutorial/remote.txt

@@ -41,10 +41,13 @@ The client object can then be used to retrieve a pack. The ``fetch_pack``
 method takes a ``determine_wants`` callback argument, which allows the
 method takes a ``determine_wants`` callback argument, which allows the
 client to determine which objects it wants to end up with::
 client to determine which objects it wants to end up with::
 
 
-   >>> def determine_wants(refs):
+   >>> def determine_wants(refs, depth=None):
    ...    # retrieve all objects
    ...    # retrieve all objects
    ...    return refs.values()
    ...    return refs.values()
 
 
+Note that the ``depth`` keyword argument will contain an optional requested
+shallow fetch depth.
+
 Another required object is a "graph walker", which is used to determine
 Another required object is a "graph walker", which is used to determine
 which objects that the client already has should not be sent again
 which objects that the client already has should not be sent again
 by the server. Here in the tutorial we'll just use a dummy graph walker
 by the server. Here in the tutorial we'll just use a dummy graph walker

+ 6 - 3
dulwich.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Metadata-Version: 2.1
 Name: dulwich
 Name: dulwich
-Version: 0.20.15
+Version: 0.20.23
 Summary: Python Git Library
 Summary: Python Git Library
 Home-page: https://www.dulwich.io/
 Home-page: https://www.dulwich.io/
 Author: Jelmer Vernooij
 Author: Jelmer Vernooij
@@ -9,7 +9,10 @@ License: Apachev2 or later or GPLv2
 Project-URL: Bug Tracker, https://github.com/dulwich/dulwich/issues
 Project-URL: Bug Tracker, https://github.com/dulwich/dulwich/issues
 Project-URL: Repository, https://www.dulwich.io/code/
 Project-URL: Repository, https://www.dulwich.io/code/
 Project-URL: GitHub, https://github.com/dulwich/dulwich
 Project-URL: GitHub, https://github.com/dulwich/dulwich
-Description: This is the Dulwich project.
+Description: Dulwich
+        =======
+        
+        This is the Dulwich project.
         
         
         It aims to provide an interface to git repos (both local and remote) that
         It aims to provide an interface to git repos (both local and remote) that
         doesn't call out to git directly but instead uses pure Python.
         doesn't call out to git directly but instead uses pure Python.
@@ -80,7 +83,7 @@ Description: This is the Dulwich project.
         Help
         Help
         ----
         ----
         
         
-        There is a *#dulwich* IRC channel on the `Freenode <https://www.freenode.net/>`_, and
+        There is a *#dulwich* IRC channel on the `OFTC <https://www.oftc.net/>`_, and
         `dulwich-announce <https://groups.google.com/forum/#!forum/dulwich-announce>`_
         `dulwich-announce <https://groups.google.com/forum/#!forum/dulwich-announce>`_
         and `dulwich-discuss <https://groups.google.com/forum/#!forum/dulwich-discuss>`_
         and `dulwich-discuss <https://groups.google.com/forum/#!forum/dulwich-discuss>`_
         mailing lists.
         mailing lists.

+ 10 - 1
dulwich.egg-info/SOURCES.txt

@@ -1,4 +1,6 @@
 .coveragerc
 .coveragerc
+.deepsource.toml
+.flake8
 .gitignore
 .gitignore
 .mailmap
 .mailmap
 .testr.conf
 .testr.conf
@@ -15,6 +17,7 @@ SECURITY.md
 TODO
 TODO
 build.cmd
 build.cmd
 dulwich.cfg
 dulwich.cfg
+releaser.conf
 requirements.txt
 requirements.txt
 setup.cfg
 setup.cfg
 setup.py
 setup.py
@@ -92,12 +95,15 @@ dulwich.egg-info/dependency_links.txt
 dulwich.egg-info/entry_points.txt
 dulwich.egg-info/entry_points.txt
 dulwich.egg-info/requires.txt
 dulwich.egg-info/requires.txt
 dulwich.egg-info/top_level.txt
 dulwich.egg-info/top_level.txt
+dulwich/cloud/__init__.py
+dulwich/cloud/gcs.py
 dulwich/contrib/README.md
 dulwich/contrib/README.md
 dulwich/contrib/__init__.py
 dulwich/contrib/__init__.py
 dulwich/contrib/diffstat.py
 dulwich/contrib/diffstat.py
 dulwich/contrib/paramiko_vendor.py
 dulwich/contrib/paramiko_vendor.py
 dulwich/contrib/release_robot.py
 dulwich/contrib/release_robot.py
 dulwich/contrib/swift.py
 dulwich/contrib/swift.py
+dulwich/contrib/test_paramiko_vendor.py
 dulwich/contrib/test_release_robot.py
 dulwich/contrib/test_release_robot.py
 dulwich/contrib/test_swift.py
 dulwich/contrib/test_swift.py
 dulwich/contrib/test_swift_smoke.py
 dulwich/contrib/test_swift_smoke.py
@@ -142,6 +148,7 @@ dulwich/tests/compat/server_utils.py
 dulwich/tests/compat/test_client.py
 dulwich/tests/compat/test_client.py
 dulwich/tests/compat/test_pack.py
 dulwich/tests/compat/test_pack.py
 dulwich/tests/compat/test_patch.py
 dulwich/tests/compat/test_patch.py
+dulwich/tests/compat/test_porcelain.py
 dulwich/tests/compat/test_repository.py
 dulwich/tests/compat/test_repository.py
 dulwich/tests/compat/test_server.py
 dulwich/tests/compat/test_server.py
 dulwich/tests/compat/test_utils.py
 dulwich/tests/compat/test_utils.py
@@ -229,5 +236,7 @@ dulwich/tests/data/trees/70/c190eb48fa8bbb50ddc692a17b44cb781af7f6
 examples/clone.py
 examples/clone.py
 examples/config.py
 examples/config.py
 examples/diff.py
 examples/diff.py
+examples/gcs.py
 examples/latest_change.py
 examples/latest_change.py
-examples/memoryrepo.py
+examples/memoryrepo.py
+examples/rename-branch.py

+ 1 - 1
dulwich/__init__.py

@@ -22,4 +22,4 @@
 
 
 """Python implementation of the Git file formats and protocols."""
 """Python implementation of the Git file formats and protocols."""
 
 
-__version__ = (0, 20, 15)
+__version__ = (0, 20, 23)

+ 10 - 9
dulwich/archive.py

@@ -42,20 +42,21 @@ class ChunkedBytesIO(object):
         BytesIO(b''.join(list_of_bytestrings)) =~= ChunkedBytesIO(
         BytesIO(b''.join(list_of_bytestrings)) =~= ChunkedBytesIO(
             list_of_bytestrings)
             list_of_bytestrings)
     """
     """
+
     def __init__(self, contents):
     def __init__(self, contents):
         self.contents = contents
         self.contents = contents
         self.pos = (0, 0)
         self.pos = (0, 0)
 
 
     def read(self, maxbytes=None):
     def read(self, maxbytes=None):
         if maxbytes < 0:
         if maxbytes < 0:
-            maxbytes = float('inf')
+            maxbytes = float("inf")
 
 
         buf = []
         buf = []
         chunk, cursor = self.pos
         chunk, cursor = self.pos
 
 
         while chunk < len(self.contents):
         while chunk < len(self.contents):
             if maxbytes < len(self.contents[chunk]) - cursor:
             if maxbytes < len(self.contents[chunk]) - cursor:
-                buf.append(self.contents[chunk][cursor:cursor+maxbytes])
+                buf.append(self.contents[chunk][cursor : cursor + maxbytes])
                 cursor += maxbytes
                 cursor += maxbytes
                 self.pos = (chunk, cursor)
                 self.pos = (chunk, cursor)
                 break
                 break
@@ -65,10 +66,10 @@ class ChunkedBytesIO(object):
                 chunk += 1
                 chunk += 1
                 cursor = 0
                 cursor = 0
                 self.pos = (chunk, cursor)
                 self.pos = (chunk, cursor)
-        return b''.join(buf)
+        return b"".join(buf)
 
 
 
 
-def tar_stream(store, tree, mtime, prefix=b'', format=''):
+def tar_stream(store, tree, mtime, prefix=b"", format=""):
     """Generate a tar stream for the contents of a Git tree.
     """Generate a tar stream for the contents of a Git tree.
 
 
     Returns a generator that lazily assembles a .tar.gz archive, yielding it in
     Returns a generator that lazily assembles a .tar.gz archive, yielding it in
@@ -86,16 +87,16 @@ def tar_stream(store, tree, mtime, prefix=b'', format=''):
     """
     """
     buf = BytesIO()
     buf = BytesIO()
     with closing(tarfile.open(None, "w:%s" % format, buf)) as tar:
     with closing(tarfile.open(None, "w:%s" % format, buf)) as tar:
-        if format == 'gz':
+        if format == "gz":
             # Manually correct the gzip header file modification time so that
             # Manually correct the gzip header file modification time so that
             # archives created from the same Git tree are always identical.
             # archives created from the same Git tree are always identical.
             # The gzip header file modification time is not currenctly
             # The gzip header file modification time is not currenctly
             # accessible from the tarfile API, see:
             # accessible from the tarfile API, see:
             # https://bugs.python.org/issue31526
             # https://bugs.python.org/issue31526
             buf.seek(0)
             buf.seek(0)
-            assert buf.read(2) == b'\x1f\x8b', 'Invalid gzip header'
+            assert buf.read(2) == b"\x1f\x8b", "Invalid gzip header"
             buf.seek(4)
             buf.seek(4)
-            buf.write(struct.pack('<L', mtime))
+            buf.write(struct.pack("<L", mtime))
             buf.seek(0, SEEK_END)
             buf.seek(0, SEEK_END)
 
 
         for entry_abspath, entry in _walk_tree(store, tree, prefix):
         for entry_abspath, entry in _walk_tree(store, tree, prefix):
@@ -109,7 +110,7 @@ def tar_stream(store, tree, mtime, prefix=b'', format=''):
 
 
             info = tarfile.TarInfo()
             info = tarfile.TarInfo()
             # tarfile only works with ascii.
             # tarfile only works with ascii.
-            info.name = entry_abspath.decode('ascii')
+            info.name = entry_abspath.decode("ascii")
             info.size = blob.raw_length()
             info.size = blob.raw_length()
             info.mode = entry.mode
             info.mode = entry.mode
             info.mtime = mtime
             info.mtime = mtime
@@ -121,7 +122,7 @@ def tar_stream(store, tree, mtime, prefix=b'', format=''):
     yield buf.getvalue()
     yield buf.getvalue()
 
 
 
 
-def _walk_tree(store, tree, root=b''):
+def _walk_tree(store, tree, root=b""):
     """Recursively walk a dulwich Tree, yielding tuples of
     """Recursively walk a dulwich Tree, yielding tuples of
     (absolute path, TreeEntry) along the way.
     (absolute path, TreeEntry) along the way.
     """
     """

+ 22 - 23
dulwich/bundle.py

@@ -56,23 +56,23 @@ def _read_bundle(f, version):
     references = {}
     references = {}
     line = f.readline()
     line = f.readline()
     if version >= 3:
     if version >= 3:
-        while line.startswith(b'@'):
-            line = line[1:].rstrip(b'\n')
+        while line.startswith(b"@"):
+            line = line[1:].rstrip(b"\n")
             try:
             try:
-                key, value = line.split(b'=', 1)
+                key, value = line.split(b"=", 1)
             except ValueError:
             except ValueError:
                 key = line
                 key = line
                 value = None
                 value = None
             else:
             else:
-                value = value.decode('utf-8')
-            capabilities[key.decode('utf-8')] = value
+                value = value.decode("utf-8")
+            capabilities[key.decode("utf-8")] = value
             line = f.readline()
             line = f.readline()
-    while line.startswith(b'-'):
-        (obj_id, comment) = line[1:].rstrip(b'\n').split(b' ', 1)
-        prerequisites.append((obj_id, comment.decode('utf-8')))
+    while line.startswith(b"-"):
+        (obj_id, comment) = line[1:].rstrip(b"\n").split(b" ", 1)
+        prerequisites.append((obj_id, comment.decode("utf-8")))
         line = f.readline()
         line = f.readline()
-    while line != b'\n':
-        (obj_id, ref) = line.rstrip(b'\n').split(b' ', 1)
+    while line != b"\n":
+        (obj_id, ref) = line.rstrip(b"\n").split(b" ", 1)
         references[ref] = obj_id
         references[ref] = obj_id
         line = f.readline()
         line = f.readline()
     pack_data = PackData.from_file(f)
     pack_data = PackData.from_file(f)
@@ -88,12 +88,11 @@ def _read_bundle(f, version):
 def read_bundle(f):
 def read_bundle(f):
     """Read a bundle file."""
     """Read a bundle file."""
     firstline = f.readline()
     firstline = f.readline()
-    if firstline == b'# v2 git bundle\n':
+    if firstline == b"# v2 git bundle\n":
         return _read_bundle(f, 2)
         return _read_bundle(f, 2)
-    if firstline == b'# v3 git bundle\n':
+    if firstline == b"# v3 git bundle\n":
         return _read_bundle(f, 3)
         return _read_bundle(f, 3)
-    raise AssertionError(
-        'unsupported bundle format header: %r' % firstline)
+    raise AssertionError("unsupported bundle format header: %r" % firstline)
 
 
 
 
 def write_bundle(f, bundle):
 def write_bundle(f, bundle):
@@ -104,20 +103,20 @@ def write_bundle(f, bundle):
         else:
         else:
             version = 2
             version = 2
     if version == 2:
     if version == 2:
-        f.write(b'# v2 git bundle\n')
+        f.write(b"# v2 git bundle\n")
     elif version == 3:
     elif version == 3:
-        f.write(b'# v3 git bundle\n')
+        f.write(b"# v3 git bundle\n")
     else:
     else:
-        raise AssertionError('unknown version %d' % version)
+        raise AssertionError("unknown version %d" % version)
     if version == 3:
     if version == 3:
         for key, value in bundle.capabilities.items():
         for key, value in bundle.capabilities.items():
-            f.write(b'@' + key.encode('utf-8'))
+            f.write(b"@" + key.encode("utf-8"))
             if value is not None:
             if value is not None:
-                f.write(b'=' + value.encode('utf-8'))
-            f.write(b'\n')
+                f.write(b"=" + value.encode("utf-8"))
+            f.write(b"\n")
     for (obj_id, comment) in bundle.prerequisites:
     for (obj_id, comment) in bundle.prerequisites:
-        f.write(b'-%s %s\n' % (obj_id, comment.encode('utf-8')))
+        f.write(b"-%s %s\n" % (obj_id, comment.encode("utf-8")))
     for ref, obj_id in bundle.references.items():
     for ref, obj_id in bundle.references.items():
-        f.write(b'%s %s\n' % (obj_id, ref))
-    f.write(b'\n')
+        f.write(b"%s %s\n" % (obj_id, ref))
+    f.write(b"\n")
     write_pack_data(f, len(bundle.pack_data), iter(bundle.pack_data))
     write_pack_data(f, len(bundle.pack_data), iter(bundle.pack_data))

+ 181 - 159
dulwich/cli.py

@@ -51,6 +51,7 @@ def signal_int(signal, frame):
 
 
 def signal_quit(signal, frame):
 def signal_quit(signal, frame):
     import pdb
     import pdb
+
     pdb.set_trace()
     pdb.set_trace()
 
 
 
 
@@ -63,57 +64,65 @@ class Command(object):
 
 
 
 
 class cmd_archive(Command):
 class cmd_archive(Command):
-
     def run(self, args):
     def run(self, args):
-        parser = optparse.OptionParser()
-        parser.add_option("--remote", type=str,
-                          help="Retrieve archive from specified remote repo")
-        options, args = parser.parse_args(args)
-        committish = args.pop(0)
-        if options.remote:
-            client, path = get_transport_and_path(options.remote)
+        parser = argparse.ArgumentParser()
+        parser.add_argument(
+            "--remote",
+            type=str,
+            help="Retrieve archive from specified remote repo",
+        )
+        parser.add_argument('committish', type=str, nargs='?')
+        args = parser.parse_args(args)
+        if args.remote:
+            client, path = get_transport_and_path(args.remote)
             client.archive(
             client.archive(
-                path, committish, sys.stdout.write,
-                write_error=sys.stderr.write)
+                path,
+                args.committish,
+                sys.stdout.write,
+                write_error=sys.stderr.write,
+            )
         else:
         else:
             porcelain.archive(
             porcelain.archive(
-                '.', committish, outstream=sys.stdout,
-                errstream=sys.stderr)
+                ".", args.committish, outstream=sys.stdout.buffer,
+                errstream=sys.stderr
+            )
 
 
 
 
 class cmd_add(Command):
 class cmd_add(Command):
-
-    def run(self, args):
-        opts, args = getopt(args, "", [])
+    def run(self, argv):
+        parser = argparse.ArgumentParser()
+        args = parser.parse_args(argv)
 
 
         porcelain.add(".", paths=args)
         porcelain.add(".", paths=args)
 
 
 
 
 class cmd_rm(Command):
 class cmd_rm(Command):
-
-    def run(self, args):
-        opts, args = getopt(args, "", [])
+    def run(self, argv):
+        parser = argparse.ArgumentParser()
+        args = parser.parse_args(argv)
 
 
         porcelain.rm(".", paths=args)
         porcelain.rm(".", paths=args)
 
 
 
 
 class cmd_fetch_pack(Command):
 class cmd_fetch_pack(Command):
-
-    def run(self, args):
-        opts, args = getopt(args, "", ["all"])
-        opts = dict(opts)
-        client, path = get_transport_and_path(args.pop(0))
+    def run(self, argv):
+        parser = argparse.ArgumentParser()
+        parser.add_argument('--all', action='store_true')
+        parser.add_argument('location', nargs='?', type=str)
+        args = parser.parse_args(argv)
+        client, path = get_transport_and_path(args.location)
         r = Repo(".")
         r = Repo(".")
-        if "--all" in opts:
+        if args.all:
             determine_wants = r.object_store.determine_wants_all
             determine_wants = r.object_store.determine_wants_all
         else:
         else:
-            def determine_wants(x):
+
+            def determine_wants(x, **kwargs):
                 return [y for y in args if y not in r.object_store]
                 return [y for y in args if y not in r.object_store]
+
         client.fetch(path, r, determine_wants)
         client.fetch(path, r, determine_wants)
 
 
 
 
 class cmd_fetch(Command):
 class cmd_fetch(Command):
-
     def run(self, args):
     def run(self, args):
         opts, args = getopt(args, "", [])
         opts, args = getopt(args, "", [])
         opts = dict(opts)
         opts = dict(opts)
@@ -126,32 +135,40 @@ class cmd_fetch(Command):
 
 
 
 
 class cmd_fsck(Command):
 class cmd_fsck(Command):
-
     def run(self, args):
     def run(self, args):
         opts, args = getopt(args, "", [])
         opts, args = getopt(args, "", [])
         opts = dict(opts)
         opts = dict(opts)
-        for (obj, msg) in porcelain.fsck('.'):
+        for (obj, msg) in porcelain.fsck("."):
             print("%s: %s" % (obj, msg))
             print("%s: %s" % (obj, msg))
 
 
 
 
 class cmd_log(Command):
 class cmd_log(Command):
-
     def run(self, args):
     def run(self, args):
         parser = optparse.OptionParser()
         parser = optparse.OptionParser()
-        parser.add_option("--reverse", dest="reverse", action="store_true",
-                          help="Reverse order in which entries are printed")
-        parser.add_option("--name-status", dest="name_status",
-                          action="store_true",
-                          help="Print name/status for each changed file")
+        parser.add_option(
+            "--reverse",
+            dest="reverse",
+            action="store_true",
+            help="Reverse order in which entries are printed",
+        )
+        parser.add_option(
+            "--name-status",
+            dest="name_status",
+            action="store_true",
+            help="Print name/status for each changed file",
+        )
         options, args = parser.parse_args(args)
         options, args = parser.parse_args(args)
 
 
-        porcelain.log(".", paths=args, reverse=options.reverse,
-                      name_status=options.name_status,
-                      outstream=sys.stdout)
+        porcelain.log(
+            ".",
+            paths=args,
+            reverse=options.reverse,
+            name_status=options.name_status,
+            outstream=sys.stdout,
+        )
 
 
 
 
 class cmd_diff(Command):
 class cmd_diff(Command):
-
     def run(self, args):
     def run(self, args):
         opts, args = getopt(args, "", [])
         opts, args = getopt(args, "", [])
 
 
@@ -163,12 +180,10 @@ class cmd_diff(Command):
         commit_id = args[0]
         commit_id = args[0]
         commit = r[commit_id]
         commit = r[commit_id]
         parent_commit = r[commit.parents[0]]
         parent_commit = r[commit.parents[0]]
-        write_tree_diff(
-            sys.stdout, r.object_store, parent_commit.tree, commit.tree)
+        write_tree_diff(sys.stdout, r.object_store, parent_commit.tree, commit.tree)
 
 
 
 
 class cmd_dump_pack(Command):
 class cmd_dump_pack(Command):
-
     def run(self, args):
     def run(self, args):
         opts, args = getopt(args, "", [])
         opts, args = getopt(args, "", [])
 
 
@@ -193,7 +208,6 @@ class cmd_dump_pack(Command):
 
 
 
 
 class cmd_dump_index(Command):
 class cmd_dump_index(Command):
-
     def run(self, args):
     def run(self, args):
         opts, args = getopt(args, "", [])
         opts, args = getopt(args, "", [])
 
 
@@ -209,7 +223,6 @@ class cmd_dump_index(Command):
 
 
 
 
 class cmd_init(Command):
 class cmd_init(Command):
-
     def run(self, args):
     def run(self, args):
         opts, args = getopt(args, "", ["bare"])
         opts, args = getopt(args, "", ["bare"])
         opts = dict(opts)
         opts = dict(opts)
@@ -223,14 +236,17 @@ class cmd_init(Command):
 
 
 
 
 class cmd_clone(Command):
 class cmd_clone(Command):
-
     def run(self, args):
     def run(self, args):
         parser = optparse.OptionParser()
         parser = optparse.OptionParser()
-        parser.add_option("--bare", dest="bare",
-                          help="Whether to create a bare repository.",
-                          action="store_true")
-        parser.add_option("--depth", dest="depth",
-                          type=int, help="Depth at which to fetch")
+        parser.add_option(
+            "--bare",
+            dest="bare",
+            help="Whether to create a bare repository.",
+            action="store_true",
+        )
+        parser.add_option(
+            "--depth", dest="depth", type=int, help="Depth at which to fetch"
+        )
         options, args = parser.parse_args(args)
         options, args = parser.parse_args(args)
 
 
         if args == []:
         if args == []:
@@ -247,7 +263,6 @@ class cmd_clone(Command):
 
 
 
 
 class cmd_commit(Command):
 class cmd_commit(Command):
-
     def run(self, args):
     def run(self, args):
         opts, args = getopt(args, "", ["message"])
         opts, args = getopt(args, "", ["message"])
         opts = dict(opts)
         opts = dict(opts)
@@ -255,7 +270,6 @@ class cmd_commit(Command):
 
 
 
 
 class cmd_commit_tree(Command):
 class cmd_commit_tree(Command):
-
     def run(self, args):
     def run(self, args):
         opts, args = getopt(args, "", ["message"])
         opts, args = getopt(args, "", ["message"])
         if args == []:
         if args == []:
@@ -266,13 +280,11 @@ class cmd_commit_tree(Command):
 
 
 
 
 class cmd_update_server_info(Command):
 class cmd_update_server_info(Command):
-
     def run(self, args):
     def run(self, args):
         porcelain.update_server_info(".")
         porcelain.update_server_info(".")
 
 
 
 
 class cmd_symbolic_ref(Command):
 class cmd_symbolic_ref(Command):
-
     def run(self, args):
     def run(self, args):
         opts, args = getopt(args, "", ["ref-name", "force"])
         opts, args = getopt(args, "", ["ref-name", "force"])
         if not args:
         if not args:
@@ -280,18 +292,18 @@ class cmd_symbolic_ref(Command):
             sys.exit(1)
             sys.exit(1)
 
 
         ref_name = args.pop(0)
         ref_name = args.pop(0)
-        porcelain.symbolic_ref(".", ref_name=ref_name, force='--force' in args)
+        porcelain.symbolic_ref(".", ref_name=ref_name, force="--force" in args)
 
 
 
 
 class cmd_show(Command):
 class cmd_show(Command):
-
-    def run(self, args):
-        opts, args = getopt(args, "", [])
-        porcelain.show(".", args)
+    def run(self, argv):
+        parser = argparse.ArgumentParser()
+        parser.add_argument('objectish', type=str, nargs='*')
+        args = parser.parse_args(argv)
+        porcelain.show(".", args.objectish or None)
 
 
 
 
 class cmd_diff_tree(Command):
 class cmd_diff_tree(Command):
-
     def run(self, args):
     def run(self, args):
         opts, args = getopt(args, "", [])
         opts, args = getopt(args, "", [])
         if len(args) < 2:
         if len(args) < 2:
@@ -301,41 +313,40 @@ class cmd_diff_tree(Command):
 
 
 
 
 class cmd_rev_list(Command):
 class cmd_rev_list(Command):
-
     def run(self, args):
     def run(self, args):
         opts, args = getopt(args, "", [])
         opts, args = getopt(args, "", [])
         if len(args) < 1:
         if len(args) < 1:
-            print('Usage: dulwich rev-list COMMITID...')
+            print("Usage: dulwich rev-list COMMITID...")
             sys.exit(1)
             sys.exit(1)
-        porcelain.rev_list('.', args)
+        porcelain.rev_list(".", args)
 
 
 
 
 class cmd_tag(Command):
 class cmd_tag(Command):
-
     def run(self, args):
     def run(self, args):
         parser = optparse.OptionParser()
         parser = optparse.OptionParser()
         parser.add_option(
         parser.add_option(
-            "-a", "--annotated", help="Create an annotated tag.",
-            action="store_true")
+            "-a",
+            "--annotated",
+            help="Create an annotated tag.",
+            action="store_true",
+        )
         parser.add_option(
         parser.add_option(
-            "-s", "--sign", help="Sign the annotated tag.",
-            action="store_true")
+            "-s", "--sign", help="Sign the annotated tag.", action="store_true"
+        )
         options, args = parser.parse_args(args)
         options, args = parser.parse_args(args)
         porcelain.tag_create(
         porcelain.tag_create(
-            '.', args[0], annotated=options.annotated,
-            sign=options.sign)
+            ".", args[0], annotated=options.annotated, sign=options.sign
+        )
 
 
 
 
 class cmd_repack(Command):
 class cmd_repack(Command):
-
     def run(self, args):
     def run(self, args):
         opts, args = getopt(args, "", [])
         opts, args = getopt(args, "", [])
         opts = dict(opts)
         opts = dict(opts)
-        porcelain.repack('.')
+        porcelain.repack(".")
 
 
 
 
 class cmd_reset(Command):
 class cmd_reset(Command):
-
     def run(self, args):
     def run(self, args):
         opts, args = getopt(args, "", ["hard", "soft", "mixed"])
         opts, args = getopt(args, "", ["hard", "soft", "mixed"])
         opts = dict(opts)
         opts = dict(opts)
@@ -346,110 +357,122 @@ class cmd_reset(Command):
             mode = "soft"
             mode = "soft"
         elif "--mixed" in opts:
         elif "--mixed" in opts:
             mode = "mixed"
             mode = "mixed"
-        porcelain.reset('.', mode=mode, *args)
+        porcelain.reset(".", mode=mode, *args)
 
 
 
 
 class cmd_daemon(Command):
 class cmd_daemon(Command):
-
     def run(self, args):
     def run(self, args):
         from dulwich import log_utils
         from dulwich import log_utils
         from dulwich.protocol import TCP_GIT_PORT
         from dulwich.protocol import TCP_GIT_PORT
+
         parser = optparse.OptionParser()
         parser = optparse.OptionParser()
-        parser.add_option("-l", "--listen_address", dest="listen_address",
-                          default="localhost",
-                          help="Binding IP address.")
-        parser.add_option("-p", "--port", dest="port", type=int,
-                          default=TCP_GIT_PORT,
-                          help="Binding TCP port.")
+        parser.add_option(
+            "-l",
+            "--listen_address",
+            dest="listen_address",
+            default="localhost",
+            help="Binding IP address.",
+        )
+        parser.add_option(
+            "-p",
+            "--port",
+            dest="port",
+            type=int,
+            default=TCP_GIT_PORT,
+            help="Binding TCP port.",
+        )
         options, args = parser.parse_args(args)
         options, args = parser.parse_args(args)
 
 
         log_utils.default_logging_config()
         log_utils.default_logging_config()
         if len(args) >= 1:
         if len(args) >= 1:
             gitdir = args[0]
             gitdir = args[0]
         else:
         else:
-            gitdir = '.'
-        from dulwich import porcelain
-        porcelain.daemon(gitdir, address=options.listen_address,
-                         port=options.port)
+            gitdir = "."
 
 
+        porcelain.daemon(gitdir, address=options.listen_address, port=options.port)
 
 
-class cmd_web_daemon(Command):
 
 
+class cmd_web_daemon(Command):
     def run(self, args):
     def run(self, args):
         from dulwich import log_utils
         from dulwich import log_utils
+
         parser = optparse.OptionParser()
         parser = optparse.OptionParser()
-        parser.add_option("-l", "--listen_address", dest="listen_address",
-                          default="",
-                          help="Binding IP address.")
-        parser.add_option("-p", "--port", dest="port", type=int,
-                          default=8000,
-                          help="Binding TCP port.")
+        parser.add_option(
+            "-l",
+            "--listen_address",
+            dest="listen_address",
+            default="",
+            help="Binding IP address.",
+        )
+        parser.add_option(
+            "-p",
+            "--port",
+            dest="port",
+            type=int,
+            default=8000,
+            help="Binding TCP port.",
+        )
         options, args = parser.parse_args(args)
         options, args = parser.parse_args(args)
 
 
         log_utils.default_logging_config()
         log_utils.default_logging_config()
         if len(args) >= 1:
         if len(args) >= 1:
             gitdir = args[0]
             gitdir = args[0]
         else:
         else:
-            gitdir = '.'
-        from dulwich import porcelain
-        porcelain.web_daemon(gitdir, address=options.listen_address,
-                             port=options.port)
+            gitdir = "."
 
 
+        porcelain.web_daemon(gitdir, address=options.listen_address, port=options.port)
 
 
-class cmd_write_tree(Command):
 
 
+class cmd_write_tree(Command):
     def run(self, args):
     def run(self, args):
         parser = optparse.OptionParser()
         parser = optparse.OptionParser()
         options, args = parser.parse_args(args)
         options, args = parser.parse_args(args)
-        sys.stdout.write('%s\n' % porcelain.write_tree('.'))
+        sys.stdout.write("%s\n" % porcelain.write_tree("."))
 
 
 
 
 class cmd_receive_pack(Command):
 class cmd_receive_pack(Command):
-
     def run(self, args):
     def run(self, args):
         parser = optparse.OptionParser()
         parser = optparse.OptionParser()
         options, args = parser.parse_args(args)
         options, args = parser.parse_args(args)
         if len(args) >= 1:
         if len(args) >= 1:
             gitdir = args[0]
             gitdir = args[0]
         else:
         else:
-            gitdir = '.'
+            gitdir = "."
         porcelain.receive_pack(gitdir)
         porcelain.receive_pack(gitdir)
 
 
 
 
 class cmd_upload_pack(Command):
 class cmd_upload_pack(Command):
-
     def run(self, args):
     def run(self, args):
         parser = optparse.OptionParser()
         parser = optparse.OptionParser()
         options, args = parser.parse_args(args)
         options, args = parser.parse_args(args)
         if len(args) >= 1:
         if len(args) >= 1:
             gitdir = args[0]
             gitdir = args[0]
         else:
         else:
-            gitdir = '.'
+            gitdir = "."
         porcelain.upload_pack(gitdir)
         porcelain.upload_pack(gitdir)
 
 
 
 
 class cmd_status(Command):
 class cmd_status(Command):
-
     def run(self, args):
     def run(self, args):
         parser = optparse.OptionParser()
         parser = optparse.OptionParser()
         options, args = parser.parse_args(args)
         options, args = parser.parse_args(args)
         if len(args) >= 1:
         if len(args) >= 1:
             gitdir = args[0]
             gitdir = args[0]
         else:
         else:
-            gitdir = '.'
+            gitdir = "."
         status = porcelain.status(gitdir)
         status = porcelain.status(gitdir)
         if any(names for (kind, names) in status.staged.items()):
         if any(names for (kind, names) in status.staged.items()):
             sys.stdout.write("Changes to be committed:\n\n")
             sys.stdout.write("Changes to be committed:\n\n")
             for kind, names in status.staged.items():
             for kind, names in status.staged.items():
                 for name in names:
                 for name in names:
-                    sys.stdout.write("\t%s: %s\n" % (
-                        kind, name.decode(sys.getfilesystemencoding())))
+                    sys.stdout.write(
+                        "\t%s: %s\n" % (kind, name.decode(sys.getfilesystemencoding()))
+                    )
             sys.stdout.write("\n")
             sys.stdout.write("\n")
         if status.unstaged:
         if status.unstaged:
             sys.stdout.write("Changes not staged for commit:\n\n")
             sys.stdout.write("Changes not staged for commit:\n\n")
             for name in status.unstaged:
             for name in status.unstaged:
-                sys.stdout.write(
-                    "\t%s\n" % name.decode(sys.getfilesystemencoding()))
+                sys.stdout.write("\t%s\n" % name.decode(sys.getfilesystemencoding()))
             sys.stdout.write("\n")
             sys.stdout.write("\n")
         if status.untracked:
         if status.untracked:
             sys.stdout.write("Untracked files:\n\n")
             sys.stdout.write("Untracked files:\n\n")
@@ -459,11 +482,10 @@ class cmd_status(Command):
 
 
 
 
 class cmd_ls_remote(Command):
 class cmd_ls_remote(Command):
-
     def run(self, args):
     def run(self, args):
-        opts, args = getopt(args, '', [])
+        opts, args = getopt(args, "", [])
         if len(args) < 1:
         if len(args) < 1:
-            print('Usage: dulwich ls-remote URL')
+            print("Usage: dulwich ls-remote URL")
             sys.exit(1)
             sys.exit(1)
         refs = porcelain.ls_remote(args[0])
         refs = porcelain.ls_remote(args[0])
         for ref in sorted(refs):
         for ref in sorted(refs):
@@ -471,48 +493,52 @@ class cmd_ls_remote(Command):
 
 
 
 
 class cmd_ls_tree(Command):
 class cmd_ls_tree(Command):
-
     def run(self, args):
     def run(self, args):
         parser = optparse.OptionParser()
         parser = optparse.OptionParser()
-        parser.add_option("-r", "--recursive", action="store_true",
-                          help="Recusively list tree contents.")
-        parser.add_option("--name-only", action="store_true",
-                          help="Only display name.")
+        parser.add_option(
+            "-r",
+            "--recursive",
+            action="store_true",
+            help="Recusively list tree contents.",
+        )
+        parser.add_option("--name-only", action="store_true", help="Only display name.")
         options, args = parser.parse_args(args)
         options, args = parser.parse_args(args)
         try:
         try:
             treeish = args.pop(0)
             treeish = args.pop(0)
         except IndexError:
         except IndexError:
             treeish = None
             treeish = None
         porcelain.ls_tree(
         porcelain.ls_tree(
-            '.', treeish, outstream=sys.stdout, recursive=options.recursive,
-            name_only=options.name_only)
+            ".",
+            treeish,
+            outstream=sys.stdout,
+            recursive=options.recursive,
+            name_only=options.name_only,
+        )
 
 
 
 
 class cmd_pack_objects(Command):
 class cmd_pack_objects(Command):
-
     def run(self, args):
     def run(self, args):
-        opts, args = getopt(args, '', ['stdout'])
+        opts, args = getopt(args, "", ["stdout"])
         opts = dict(opts)
         opts = dict(opts)
-        if len(args) < 1 and '--stdout' not in args:
-            print('Usage: dulwich pack-objects basename')
+        if len(args) < 1 and "--stdout" not in args:
+            print("Usage: dulwich pack-objects basename")
             sys.exit(1)
             sys.exit(1)
         object_ids = [line.strip() for line in sys.stdin.readlines()]
         object_ids = [line.strip() for line in sys.stdin.readlines()]
         basename = args[0]
         basename = args[0]
-        if '--stdout' in opts:
-            packf = getattr(sys.stdout, 'buffer', sys.stdout)
+        if "--stdout" in opts:
+            packf = getattr(sys.stdout, "buffer", sys.stdout)
             idxf = None
             idxf = None
             close = []
             close = []
         else:
         else:
-            packf = open(basename + '.pack', 'w')
-            idxf = open(basename + '.idx', 'w')
+            packf = open(basename + ".pack", "w")
+            idxf = open(basename + ".idx", "w")
             close = [packf, idxf]
             close = [packf, idxf]
-        porcelain.pack_objects('.', object_ids, packf, idxf)
+        porcelain.pack_objects(".", object_ids, packf, idxf)
         for f in close:
         for f in close:
             f.close()
             f.close()
 
 
 
 
 class cmd_pull(Command):
 class cmd_pull(Command):
-
     def run(self, args):
     def run(self, args):
         parser = optparse.OptionParser()
         parser = optparse.OptionParser()
         options, args = parser.parse_args(args)
         options, args = parser.parse_args(args)
@@ -520,7 +546,7 @@ class cmd_pull(Command):
             from_location = args[0]
             from_location = args[0]
         except IndexError:
         except IndexError:
             from_location = None
             from_location = None
-        porcelain.pull('.', from_location)
+        porcelain.pull(".", from_location)
 
 
 
 
 class cmd_push(Command):
 class cmd_push(Command):
@@ -534,11 +560,10 @@ class cmd_push(Command):
 
 
 
 
 class cmd_remote_add(Command):
 class cmd_remote_add(Command):
-
     def run(self, args):
     def run(self, args):
         parser = optparse.OptionParser()
         parser = optparse.OptionParser()
         options, args = parser.parse_args(args)
         options, args = parser.parse_args(args)
-        porcelain.remote_add('.', args[0], args[1])
+        porcelain.remote_add(".", args[0], args[1])
 
 
 
 
 class SuperCommand(Command):
 class SuperCommand(Command):
@@ -547,14 +572,13 @@ class SuperCommand(Command):
 
 
     def run(self, args):
     def run(self, args):
         if not args:
         if not args:
-            print("Supported subcommands: %s" %
-                  ', '.join(self.subcommands.keys()))
+            print("Supported subcommands: %s" % ", ".join(self.subcommands.keys()))
             return False
             return False
         cmd = args[0]
         cmd = args[0]
         try:
         try:
             cmd_kls = self.subcommands[cmd]
             cmd_kls = self.subcommands[cmd]
         except KeyError:
         except KeyError:
-            print('No such subcommand: %s' % args[0])
+            print("No such subcommand: %s" % args[0])
             return False
             return False
         return cmd_kls().run(args[1:])
         return cmd_kls().run(args[1:])
 
 
@@ -567,51 +591,46 @@ class cmd_remote(SuperCommand):
 
 
 
 
 class cmd_check_ignore(Command):
 class cmd_check_ignore(Command):
-
     def run(self, args):
     def run(self, args):
         parser = optparse.OptionParser()
         parser = optparse.OptionParser()
         options, args = parser.parse_args(args)
         options, args = parser.parse_args(args)
         ret = 1
         ret = 1
-        for path in porcelain.check_ignore('.', args):
+        for path in porcelain.check_ignore(".", args):
             print(path)
             print(path)
             ret = 0
             ret = 0
         return ret
         return ret
 
 
 
 
 class cmd_check_mailmap(Command):
 class cmd_check_mailmap(Command):
-
     def run(self, args):
     def run(self, args):
         parser = optparse.OptionParser()
         parser = optparse.OptionParser()
         options, args = parser.parse_args(args)
         options, args = parser.parse_args(args)
         for arg in args:
         for arg in args:
-            canonical_identity = porcelain.check_mailmap('.', arg)
+            canonical_identity = porcelain.check_mailmap(".", arg)
             print(canonical_identity)
             print(canonical_identity)
 
 
 
 
 class cmd_stash_list(Command):
 class cmd_stash_list(Command):
-
     def run(self, args):
     def run(self, args):
         parser = optparse.OptionParser()
         parser = optparse.OptionParser()
         options, args = parser.parse_args(args)
         options, args = parser.parse_args(args)
-        for i, entry in porcelain.stash_list('.'):
-            print("stash@{%d}: %s" % (i, entry.message.rstrip('\n')))
+        for i, entry in porcelain.stash_list("."):
+            print("stash@{%d}: %s" % (i, entry.message.rstrip("\n")))
 
 
 
 
 class cmd_stash_push(Command):
 class cmd_stash_push(Command):
-
     def run(self, args):
     def run(self, args):
         parser = optparse.OptionParser()
         parser = optparse.OptionParser()
         options, args = parser.parse_args(args)
         options, args = parser.parse_args(args)
-        porcelain.stash_push('.')
+        porcelain.stash_push(".")
         print("Saved working directory and index state")
         print("Saved working directory and index state")
 
 
 
 
 class cmd_stash_pop(Command):
 class cmd_stash_pop(Command):
-
     def run(self, args):
     def run(self, args):
         parser = optparse.OptionParser()
         parser = optparse.OptionParser()
         options, args = parser.parse_args(args)
         options, args = parser.parse_args(args)
-        porcelain.stash_pop('.')
+        porcelain.stash_pop(".")
         print("Restrored working directory and index state")
         print("Restrored working directory and index state")
 
 
 
 
@@ -625,42 +644,45 @@ class cmd_stash(SuperCommand):
 
 
 
 
 class cmd_ls_files(Command):
 class cmd_ls_files(Command):
-
     def run(self, args):
     def run(self, args):
         parser = optparse.OptionParser()
         parser = optparse.OptionParser()
         options, args = parser.parse_args(args)
         options, args = parser.parse_args(args)
-        for name in porcelain.ls_files('.'):
+        for name in porcelain.ls_files("."):
             print(name)
             print(name)
 
 
 
 
 class cmd_describe(Command):
 class cmd_describe(Command):
-
     def run(self, args):
     def run(self, args):
         parser = optparse.OptionParser()
         parser = optparse.OptionParser()
         options, args = parser.parse_args(args)
         options, args = parser.parse_args(args)
-        print(porcelain.describe('.'))
+        print(porcelain.describe("."))
 
 
 
 
 class cmd_help(Command):
 class cmd_help(Command):
-
     def run(self, args):
     def run(self, args):
         parser = optparse.OptionParser()
         parser = optparse.OptionParser()
-        parser.add_option("-a", "--all", dest="all",
-                          action="store_true",
-                          help="List all commands.")
+        parser.add_option(
+            "-a",
+            "--all",
+            dest="all",
+            action="store_true",
+            help="List all commands.",
+        )
         options, args = parser.parse_args(args)
         options, args = parser.parse_args(args)
 
 
         if options.all:
         if options.all:
-            print('Available commands:')
+            print("Available commands:")
             for cmd in sorted(commands):
             for cmd in sorted(commands):
-                print('  %s' % cmd)
+                print("  %s" % cmd)
         else:
         else:
-            print("""\
+            print(
+                """\
 The dulwich command line tool is currently a very basic frontend for the
 The dulwich command line tool is currently a very basic frontend for the
 Dulwich python module. For full functionality, please see the API reference.
 Dulwich python module. For full functionality, please see the API reference.
 
 
 For a list of supported commands, see 'dulwich help -a'.
 For a list of supported commands, see 'dulwich help -a'.
-""")
+"""
+            )
 
 
 
 
 commands = {
 commands = {
@@ -704,7 +726,7 @@ commands = {
     "upload-pack": cmd_upload_pack,
     "upload-pack": cmd_upload_pack,
     "web-daemon": cmd_web_daemon,
     "web-daemon": cmd_web_daemon,
     "write-tree": cmd_write_tree,
     "write-tree": cmd_write_tree,
-    }
+}
 
 
 
 
 def main(argv=None):
 def main(argv=None):
@@ -725,8 +747,8 @@ def main(argv=None):
     return cmd_kls().run(argv[1:])
     return cmd_kls().run(argv[1:])
 
 
 
 
-if __name__ == '__main__':
-    if 'DULWICH_PDB' in os.environ and getattr(signal, 'SIGQUIT', None):
+if __name__ == "__main__":
+    if "DULWICH_PDB" in os.environ and getattr(signal, "SIGQUIT", None):
         signal.signal(signal.SIGQUIT, signal_quit)  # type: ignore
         signal.signal(signal.SIGQUIT, signal_quit)  # type: ignore
     signal.signal(signal.SIGINT, signal_int)
     signal.signal(signal.SIGINT, signal_int)
 
 

Файлын зөрүү хэтэрхий том тул дарагдсан байна
+ 322 - 183
dulwich/client.py


+ 0 - 0
dulwich/cloud/__init__.py


+ 82 - 0
dulwich/cloud/gcs.py

@@ -0,0 +1,82 @@
+# object_store.py -- Object store for git objects
+# Copyright (C) 2021 Jelmer Vernooij <jelmer@jelmer.uk>
+#
+# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
+# General Public License as public by the Free Software Foundation; version 2.0
+# or (at your option) any later version. You can redistribute it and/or
+# modify it under the terms of either of these two licenses.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# You should have received a copy of the licenses; if not, see
+# <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
+# and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
+# License, Version 2.0.
+#
+
+
+"""Storage of repositories on GCS."""
+
+import posixpath
+import tempfile
+
+from ..object_store import BucketBasedObjectStore
+from ..pack import PackData, Pack, load_pack_index_file
+
+
+# TODO(jelmer): For performance, read ranges?
+
+
+class GcsObjectStore(BucketBasedObjectStore):
+
+    def __init__(self, bucket, subpath=''):
+        super(GcsObjectStore, self).__init__()
+        self.bucket = bucket
+        self.subpath = subpath
+
+    def __repr__(self):
+        return "%s(%r, subpath=%r)" % (
+            type(self).__name__, self.bucket, self.subpath)
+
+    def _remove_pack(self, name):
+        self.bucket.delete_blobs([
+            posixpath.join(self.subpath, name) + '.' + ext
+            for ext in ['pack', 'idx']])
+
+    def _iter_pack_names(self):
+        packs = {}
+        for blob in self.bucket.list_blobs(prefix=self.subpath):
+            name, ext = posixpath.splitext(posixpath.basename(blob.name))
+            packs.setdefault(name, set()).add(ext)
+        for name, exts in packs.items():
+            if exts == set(['.pack', '.idx']):
+                yield name
+
+    def _load_pack_data(self, name):
+        b = self.bucket.blob(posixpath.join(self.subpath, name + '.pack'))
+        f = tempfile.SpooledTemporaryFile()
+        b.download_to_file(f)
+        f.seek(0)
+        return PackData(name + '.pack', f)
+
+    def _load_pack_index(self, name):
+        b = self.bucket.blob(posixpath.join(self.subpath, name + '.idx'))
+        f = tempfile.SpooledTemporaryFile()
+        b.download_to_file(f)
+        f.seek(0)
+        return load_pack_index_file(name + '.idx', f)
+
+    def _get_pack(self, name):
+        return Pack.from_lazy_objects(
+            lambda: self._load_pack_data(name),
+            lambda: self._load_pack_index(name))
+
+    def _upload_pack(self, basename, pack_file, index_file):
+        idxblob = self.bucket.blob(posixpath.join(self.subpath, basename + '.idx'))
+        datablob = self.bucket.blob(posixpath.join(self.subpath, basename + '.pack'))
+        idxblob.upload_from_file(index_file)
+        datablob.upload_from_file(pack_file)

+ 115 - 64
dulwich/config.py

@@ -33,17 +33,18 @@ from typing import BinaryIO, Tuple, Optional
 
 
 from collections import (
 from collections import (
     OrderedDict,
     OrderedDict,
-    )
+)
+
 try:
 try:
     from collections.abc import (
     from collections.abc import (
         Iterable,
         Iterable,
         MutableMapping,
         MutableMapping,
-        )
+    )
 except ImportError:  # python < 3.7
 except ImportError:  # python < 3.7
     from collections import (
     from collections import (
         Iterable,
         Iterable,
         MutableMapping,
         MutableMapping,
-        )
+    )
 
 
 from dulwich.file import GitFile
 from dulwich.file import GitFile
 
 
@@ -56,15 +57,12 @@ def lower_key(key):
         return key.lower()
         return key.lower()
 
 
     if isinstance(key, Iterable):
     if isinstance(key, Iterable):
-        return type(key)(
-            map(lower_key, key)
-        )
+        return type(key)(map(lower_key, key))
 
 
     return key
     return key
 
 
 
 
 class CaseInsensitiveDict(OrderedDict):
 class CaseInsensitiveDict(OrderedDict):
-
     @classmethod
     @classmethod
     def make(cls, dict_in=None):
     def make(cls, dict_in=None):
 
 
@@ -87,7 +85,7 @@ class CaseInsensitiveDict(OrderedDict):
     def __setitem__(self, key, value, **kwargs):
     def __setitem__(self, key, value, **kwargs):
         key = lower_key(key)
         key = lower_key(key)
 
 
-        super(CaseInsensitiveDict, self).__setitem__(key, value,  **kwargs)
+        super(CaseInsensitiveDict, self).__setitem__(key, value, **kwargs)
 
 
     def __getitem__(self, item):
     def __getitem__(self, item):
         key = lower_key(item)
         key = lower_key(item)
@@ -188,7 +186,7 @@ class Config(object):
         Returns:
         Returns:
           boolean indicating whether the section exists
           boolean indicating whether the section exists
         """
         """
-        return (name in self.itersections())
+        return name in self.itersections()
 
 
 
 
 class ConfigDict(Config, MutableMapping):
 class ConfigDict(Config, MutableMapping):
@@ -205,9 +203,7 @@ class ConfigDict(Config, MutableMapping):
         return "%s(%r)" % (self.__class__.__name__, self._values)
         return "%s(%r)" % (self.__class__.__name__, self._values)
 
 
     def __eq__(self, other):
     def __eq__(self, other):
-        return (
-            isinstance(other, self.__class__) and
-            other._values == self._values)
+        return isinstance(other, self.__class__) and other._values == self._values
 
 
     def __getitem__(self, key):
     def __getitem__(self, key):
         return self._values.__getitem__(key)
         return self._values.__getitem__(key)
@@ -234,13 +230,16 @@ class ConfigDict(Config, MutableMapping):
 
 
     def _check_section_and_name(self, section, name):
     def _check_section_and_name(self, section, name):
         if not isinstance(section, tuple):
         if not isinstance(section, tuple):
-            section = (section, )
-
-        section = tuple([
-            subsection.encode(self.encoding)
-            if not isinstance(subsection, bytes) else subsection
-            for subsection in section
-            ])
+            section = (section,)
+
+        section = tuple(
+            [
+                subsection.encode(self.encoding)
+                if not isinstance(subsection, bytes)
+                else subsection
+                for subsection in section
+            ]
+        )
 
 
         if not isinstance(name, bytes):
         if not isinstance(name, bytes):
             name = name.encode(self.encoding)
             name = name.encode(self.encoding)
@@ -274,11 +273,13 @@ class ConfigDict(Config, MutableMapping):
 
 
 
 
 def _format_string(value):
 def _format_string(value):
-    if (value.startswith(b" ") or
-            value.startswith(b"\t") or
-            value.endswith(b" ") or
-            b'#' in value or
-            value.endswith(b"\t")):
+    if (
+        value.startswith(b" ")
+        or value.startswith(b"\t")
+        or value.endswith(b" ")
+        or b"#" in value
+        or value.endswith(b"\t")
+    ):
         return b'"' + _escape_value(value) + b'"'
         return b'"' + _escape_value(value) + b'"'
     else:
     else:
         return _escape_value(value)
         return _escape_value(value)
@@ -286,11 +287,11 @@ def _format_string(value):
 
 
 _ESCAPE_TABLE = {
 _ESCAPE_TABLE = {
     ord(b"\\"): ord(b"\\"),
     ord(b"\\"): ord(b"\\"),
-    ord(b"\""): ord(b"\""),
+    ord(b'"'): ord(b'"'),
     ord(b"n"): ord(b"\n"),
     ord(b"n"): ord(b"\n"),
     ord(b"t"): ord(b"\t"),
     ord(b"t"): ord(b"\t"),
     ord(b"b"): ord(b"\b"),
     ord(b"b"): ord(b"\b"),
-    }
+}
 _COMMENT_CHARS = [ord(b"#"), ord(b";")]
 _COMMENT_CHARS = [ord(b"#"), ord(b";")]
 _WHITESPACE_CHARS = [ord(b"\t"), ord(b" ")]
 _WHITESPACE_CHARS = [ord(b"\t"), ord(b" ")]
 
 
@@ -309,18 +310,19 @@ def _parse_string(value):
                 v = _ESCAPE_TABLE[value[i]]
                 v = _ESCAPE_TABLE[value[i]]
             except IndexError:
             except IndexError:
                 raise ValueError(
                 raise ValueError(
-                    "escape character in %r at %d before end of string" %
-                    (value, i))
+                    "escape character in %r at %d before end of string" % (value, i)
+                )
             except KeyError:
             except KeyError:
                 raise ValueError(
                 raise ValueError(
                     "escape character followed by unknown character "
                     "escape character followed by unknown character "
-                    "%s at %d in %r" % (value[i], i, value))
+                    "%s at %d in %r" % (value[i], i, value)
+                )
             if whitespace:
             if whitespace:
                 ret.extend(whitespace)
                 ret.extend(whitespace)
                 whitespace = bytearray()
                 whitespace = bytearray()
             ret.append(v)
             ret.append(v)
-        elif c == ord(b"\""):
-            in_quotes = (not in_quotes)
+        elif c == ord(b'"'):
+            in_quotes = not in_quotes
         elif c in _COMMENT_CHARS and not in_quotes:
         elif c in _COMMENT_CHARS and not in_quotes:
             # the rest of the line is a comment
             # the rest of the line is a comment
             break
             break
@@ -344,22 +346,22 @@ def _escape_value(value):
     value = value.replace(b"\\", b"\\\\")
     value = value.replace(b"\\", b"\\\\")
     value = value.replace(b"\n", b"\\n")
     value = value.replace(b"\n", b"\\n")
     value = value.replace(b"\t", b"\\t")
     value = value.replace(b"\t", b"\\t")
-    value = value.replace(b"\"", b"\\\"")
+    value = value.replace(b'"', b'\\"')
     return value
     return value
 
 
 
 
 def _check_variable_name(name):
 def _check_variable_name(name):
     for i in range(len(name)):
     for i in range(len(name)):
-        c = name[i:i+1]
-        if not c.isalnum() and c != b'-':
+        c = name[i : i + 1]
+        if not c.isalnum() and c != b"-":
             return False
             return False
     return True
     return True
 
 
 
 
 def _check_section_name(name):
 def _check_section_name(name):
     for i in range(len(name)):
     for i in range(len(name)):
-        c = name[i:i+1]
-        if not c.isalnum() and c not in (b'-', b'.'):
+        c = name[i : i + 1]
+        if not c.isalnum() and c not in (b"-", b"."):
             return False
             return False
     return True
     return True
 
 
@@ -379,15 +381,14 @@ def _strip_comments(line):
 
 
 
 
 class ConfigFile(ConfigDict):
 class ConfigFile(ConfigDict):
-    """A Git configuration file, like .git/config or ~/.gitconfig.
-    """
+    """A Git configuration file, like .git/config or ~/.gitconfig."""
 
 
     def __init__(self, values=None, encoding=None):
     def __init__(self, values=None, encoding=None):
         super(ConfigFile, self).__init__(values=values, encoding=encoding)
         super(ConfigFile, self).__init__(values=values, encoding=encoding)
         self.path = None
         self.path = None
 
 
     @classmethod
     @classmethod
-    def from_file(cls, f: BinaryIO) -> 'ConfigFile':
+    def from_file(cls, f: BinaryIO) -> "ConfigFile":
         """Read configuration from a file-like object."""
         """Read configuration from a file-like object."""
         ret = cls()
         ret = cls()
         section = None  # type: Optional[Tuple[bytes, ...]]
         section = None  # type: Optional[Tuple[bytes, ...]]
@@ -404,26 +405,23 @@ class ConfigFile(ConfigDict):
                     except ValueError:
                     except ValueError:
                         raise ValueError("expected trailing ]")
                         raise ValueError("expected trailing ]")
                     pts = line[1:last].split(b" ", 1)
                     pts = line[1:last].split(b" ", 1)
-                    line = line[last+1:]
+                    line = line[last + 1 :]
                     if len(pts) == 2:
                     if len(pts) == 2:
-                        if pts[1][:1] != b"\"" or pts[1][-1:] != b"\"":
-                            raise ValueError(
-                                "Invalid subsection %r" % pts[1])
+                        if pts[1][:1] != b'"' or pts[1][-1:] != b'"':
+                            raise ValueError("Invalid subsection %r" % pts[1])
                         else:
                         else:
                             pts[1] = pts[1][1:-1]
                             pts[1] = pts[1][1:-1]
                         if not _check_section_name(pts[0]):
                         if not _check_section_name(pts[0]):
-                            raise ValueError("invalid section name %r" %
-                                             pts[0])
+                            raise ValueError("invalid section name %r" % pts[0])
                         section = (pts[0], pts[1])
                         section = (pts[0], pts[1])
                     else:
                     else:
                         if not _check_section_name(pts[0]):
                         if not _check_section_name(pts[0]):
-                            raise ValueError(
-                                "invalid section name %r" % pts[0])
+                            raise ValueError("invalid section name %r" % pts[0])
                         pts = pts[0].split(b".", 1)
                         pts = pts[0].split(b".", 1)
                         if len(pts) == 2:
                         if len(pts) == 2:
                             section = (pts[0], pts[1])
                             section = (pts[0], pts[1])
                         else:
                         else:
-                            section = (pts[0], )
+                            section = (pts[0],)
                     ret._values.setdefault(section)
                     ret._values.setdefault(section)
                 if _strip_comments(line).strip() == b"":
                 if _strip_comments(line).strip() == b"":
                     continue
                     continue
@@ -456,9 +454,9 @@ class ConfigFile(ConfigDict):
         return ret
         return ret
 
 
     @classmethod
     @classmethod
-    def from_path(cls, path) -> 'ConfigFile':
+    def from_path(cls, path) -> "ConfigFile":
         """Read configuration from a file on disk."""
         """Read configuration from a file on disk."""
-        with GitFile(path, 'rb') as f:
+        with GitFile(path, "rb") as f:
             ret = cls.from_file(f)
             ret = cls.from_file(f)
             ret.path = path
             ret.path = path
             return ret
             return ret
@@ -467,7 +465,7 @@ class ConfigFile(ConfigDict):
         """Write configuration to a file on disk."""
         """Write configuration to a file on disk."""
         if path is None:
         if path is None:
             path = self.path
             path = self.path
-        with GitFile(path, 'wb') as f:
+        with GitFile(path, "wb") as f:
             self.write_to_file(f)
             self.write_to_file(f)
 
 
     def write_to_file(self, f: BinaryIO) -> None:
     def write_to_file(self, f: BinaryIO) -> None:
@@ -476,13 +474,12 @@ class ConfigFile(ConfigDict):
             try:
             try:
                 section_name, subsection_name = section
                 section_name, subsection_name = section
             except ValueError:
             except ValueError:
-                (section_name, ) = section
+                (section_name,) = section
                 subsection_name = None
                 subsection_name = None
             if subsection_name is None:
             if subsection_name is None:
                 f.write(b"[" + section_name + b"]\n")
                 f.write(b"[" + section_name + b"]\n")
             else:
             else:
-                f.write(b"[" + section_name +
-                        b" \"" + subsection_name + b"\"]\n")
+                f.write(b"[" + section_name + b' "' + subsection_name + b'"]\n')
             for key, value in values.items():
             for key, value in values.items():
                 if value is True:
                 if value is True:
                     value = b"true"
                     value = b"true"
@@ -495,11 +492,63 @@ class ConfigFile(ConfigDict):
 
 
 def get_xdg_config_home_path(*path_segments):
 def get_xdg_config_home_path(*path_segments):
     xdg_config_home = os.environ.get(
     xdg_config_home = os.environ.get(
-        "XDG_CONFIG_HOME", os.path.expanduser("~/.config/"),
+        "XDG_CONFIG_HOME",
+        os.path.expanduser("~/.config/"),
     )
     )
     return os.path.join(xdg_config_home, *path_segments)
     return os.path.join(xdg_config_home, *path_segments)
 
 
 
 
+def _find_git_in_win_path():
+    for exe in ("git.exe", "git.cmd"):
+        for path in os.environ.get("PATH", "").split(";"):
+            if os.path.exists(os.path.join(path, exe)):
+                # exe path is .../Git/bin/git.exe or .../Git/cmd/git.exe
+                git_dir, _bin_dir = os.path.split(path)
+                yield git_dir
+                break
+
+
+def _find_git_in_win_reg():
+    import platform
+    import winreg
+
+    if platform.machine() == "AMD64":
+        subkey = (
+            "SOFTWARE\\Wow6432Node\\Microsoft\\Windows\\"
+            "CurrentVersion\\Uninstall\\Git_is1"
+        )
+    else:
+        subkey = (
+            "SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\"
+            "Uninstall\\Git_is1"
+        )
+
+    for key in (winreg.HKEY_CURRENT_USER, winreg.HKEY_LOCAL_MACHINE):
+        try:
+            with winreg.OpenKey(key, subkey) as k:
+                val, typ = winreg.QueryValueEx(k, "InstallLocation")
+                if typ == winreg.REG_SZ:
+                    yield val
+        except OSError:
+            pass
+
+
+# There is no set standard for system config dirs on windows. We try the
+# following:
+#   - %PROGRAMDATA%/Git/config - (deprecated) Windows config dir per CGit docs
+#   - %PROGRAMFILES%/Git/etc/gitconfig - Git for Windows (msysgit) config dir
+#     Used if CGit installation (Git/bin/git.exe) is found in PATH in the
+#     system registry
+def get_win_system_paths():
+    if "PROGRAMDATA" in os.environ:
+        yield os.path.join(os.environ["PROGRAMDATA"], "Git", "config")
+
+    for git_dir in _find_git_in_win_path():
+        yield os.path.join(git_dir, "etc", "gitconfig")
+    for git_dir in _find_git_in_win_reg():
+        yield os.path.join(git_dir, "etc", "gitconfig")
+
+
 class StackedConfig(Config):
 class StackedConfig(Config):
     """Configuration which reads from multiple config files.."""
     """Configuration which reads from multiple config files.."""
 
 
@@ -526,6 +575,8 @@ class StackedConfig(Config):
 
 
         if "GIT_CONFIG_NOSYSTEM" not in os.environ:
         if "GIT_CONFIG_NOSYSTEM" not in os.environ:
             paths.append("/etc/gitconfig")
             paths.append("/etc/gitconfig")
+            if sys.platform == "win32":
+                paths.extend(get_win_system_paths())
 
 
         backends = []
         backends = []
         for path in paths:
         for path in paths:
@@ -538,7 +589,7 @@ class StackedConfig(Config):
 
 
     def get(self, section, name):
     def get(self, section, name):
         if not isinstance(section, tuple):
         if not isinstance(section, tuple):
-            section = (section, )
+            section = (section,)
         for backend in self.backends:
         for backend in self.backends:
             try:
             try:
                 return backend.get(section, name)
                 return backend.get(section, name)
@@ -555,15 +606,15 @@ class StackedConfig(Config):
 def parse_submodules(config):
 def parse_submodules(config):
     """Parse a gitmodules GitConfig file, returning submodules.
     """Parse a gitmodules GitConfig file, returning submodules.
 
 
-   Args:
-     config: A `ConfigFile`
-   Returns:
-     list of tuples (submodule path, url, name),
-       where name is quoted part of the section's name.
+    Args:
+      config: A `ConfigFile`
+    Returns:
+      list of tuples (submodule path, url, name),
+        where name is quoted part of the section's name.
     """
     """
     for section in config.keys():
     for section in config.keys():
         section_kind, section_name = section
         section_kind, section_name = section
-        if section_kind == b'submodule':
-            sm_path = config.get(section, b'path')
-            sm_url = config.get(section, b'url')
+        if section_kind == b"submodule":
+            sm_path = config.get(section, b"path")
+            sm_url = config.get(section, b"url")
             yield (sm_path, sm_url, section_name)
             yield (sm_path, sm_url, section_name)

+ 5 - 4
dulwich/contrib/__init__.py

@@ -21,10 +21,11 @@
 
 
 def test_suite():
 def test_suite():
     import unittest
     import unittest
+
     names = [
     names = [
-        'release_robot',
-        'swift',
-        ]
-    module_names = ['dulwich.contrib.test_' + name for name in names]
+        "release_robot",
+        "swift",
+    ]
+    module_names = ["dulwich.contrib.test_" + name for name in names]
     loader = unittest.TestLoader()
     loader = unittest.TestLoader()
     return loader.loadTestsFromNames(module_names)
     return loader.loadTestsFromNames(module_names)

+ 45 - 36
dulwich/contrib/diffstat.py

@@ -39,16 +39,16 @@ import re
 # only needs to detect git style diffs as this is for
 # only needs to detect git style diffs as this is for
 # use with dulwich
 # use with dulwich
 
 
-_git_header_name = re.compile(br'diff --git a/(.*) b/(.*)')
+_git_header_name = re.compile(br"diff --git a/(.*) b/(.*)")
 
 
-_GIT_HEADER_START = b'diff --git a/'
-_GIT_BINARY_START = b'Binary file'
-_GIT_RENAMEFROM_START = b'rename from'
-_GIT_RENAMETO_START = b'rename to'
-_GIT_CHUNK_START = b'@@'
-_GIT_ADDED_START = b'+'
-_GIT_DELETED_START = b'-'
-_GIT_UNCHANGED_START = b' '
+_GIT_HEADER_START = b"diff --git a/"
+_GIT_BINARY_START = b"Binary file"
+_GIT_RENAMEFROM_START = b"rename from"
+_GIT_RENAMETO_START = b"rename to"
+_GIT_CHUNK_START = b"@@"
+_GIT_ADDED_START = b"+"
+_GIT_DELETED_START = b"-"
+_GIT_UNCHANGED_START = b" "
 
 
 # emulate original full Patch class by just extracting
 # emulate original full Patch class by just extracting
 # filename and minimal chunk added/deleted information to
 # filename and minimal chunk added/deleted information to
@@ -89,9 +89,8 @@ def _parse_patch(lines):
         elif line.startswith(_GIT_RENAMEFROM_START) and in_git_header:
         elif line.startswith(_GIT_RENAMEFROM_START) and in_git_header:
             currentfile = line[12:]
             currentfile = line[12:]
         elif line.startswith(_GIT_RENAMETO_START) and in_git_header:
         elif line.startswith(_GIT_RENAMETO_START) and in_git_header:
-            currentfile += b' => %s' % line[10:]
-        elif line.startswith(_GIT_CHUNK_START) and \
-                (in_patch_chunk or in_git_header):
+            currentfile += b" => %s" % line[10:]
+        elif line.startswith(_GIT_CHUNK_START) and (in_patch_chunk or in_git_header):
             in_patch_chunk = True
             in_patch_chunk = True
             in_git_header = False
             in_git_header = False
         elif line.startswith(_GIT_ADDED_START) and in_patch_chunk:
         elif line.startswith(_GIT_ADDED_START) and in_patch_chunk:
@@ -130,8 +129,8 @@ def diffstat(lines, max_width=80):
         insert.append(i)
         insert.append(i)
         delete.append(d)
         delete.append(d)
         namelen = max(namelen, len(filename))
         namelen = max(namelen, len(filename))
-        maxdiff = max(maxdiff, i+d)
-    output = b''
+        maxdiff = max(maxdiff, i + d)
+    output = b""
     statlen = len(str(maxdiff))  # stats column width
     statlen = len(str(maxdiff))  # stats column width
     for i, n in enumerate(names):
     for i, n in enumerate(names):
         binaryfile = nametypes[i]
         binaryfile = nametypes[i]
@@ -139,16 +138,21 @@ def diffstat(lines, max_width=80):
         # note b'%d' % namelen is not supported until Python 3.5
         # note b'%d' % namelen is not supported until Python 3.5
         # To convert an int to a format width specifier for byte
         # To convert an int to a format width specifier for byte
         # strings use str(namelen).encode('ascii')
         # strings use str(namelen).encode('ascii')
-        format = b' %-' + str(namelen).encode('ascii') + \
-            b's | %' + str(statlen).encode('ascii') + b's %s\n'
-        binformat = b' %-' + str(namelen).encode('ascii') + b's | %s\n'
+        format = (
+            b" %-"
+            + str(namelen).encode("ascii")
+            + b"s | %"
+            + str(statlen).encode("ascii")
+            + b"s %s\n"
+        )
+        binformat = b" %-" + str(namelen).encode("ascii") + b"s | %s\n"
         if not binaryfile:
         if not binaryfile:
-            hist = b''
+            hist = b""
             # -- calculating histogram --
             # -- calculating histogram --
-            width = len(format % (b'', b'', b''))
+            width = len(format % (b"", b"", b""))
             histwidth = max(2, max_width - width)
             histwidth = max(2, max_width - width)
             if maxdiff < histwidth:
             if maxdiff < histwidth:
-                hist = b'+'*insert[i] + b'-'*delete[i]
+                hist = b"+" * insert[i] + b"-" * delete[i]
             else:
             else:
                 iratio = (float(insert[i]) / maxdiff) * histwidth
                 iratio = (float(insert[i]) / maxdiff) * histwidth
                 dratio = (float(delete[i]) / maxdiff) * histwidth
                 dratio = (float(delete[i]) / maxdiff) * histwidth
@@ -165,15 +169,20 @@ def diffstat(lines, max_width=80):
                     dwidth = int(dratio)
                     dwidth = int(dratio)
                     if dwidth == 0 and 0 < dratio < 1:
                     if dwidth == 0 and 0 < dratio < 1:
                         dwidth = 1
                         dwidth = 1
-                hist = b'+'*int(iwidth) + b'-'*int(dwidth)
-            output += (format % (bytes(names[i]),
-                                 str(insert[i] + delete[i]).encode('ascii'),
-                                 hist))
+                hist = b"+" * int(iwidth) + b"-" * int(dwidth)
+            output += format % (
+                bytes(names[i]),
+                str(insert[i] + delete[i]).encode("ascii"),
+                hist,
+            )
         else:
         else:
-            output += (binformat % (bytes(names[i]), b'Bin'))
+            output += binformat % (bytes(names[i]), b"Bin")
 
 
-    output += (b' %d files changed, %d insertions(+), %d deletions(-)'
-               % (len(names), sum(insert), sum(delete)))
+    output += b" %d files changed, %d insertions(+), %d deletions(-)" % (
+        len(names),
+        sum(insert),
+        sum(delete),
+    )
     return output
     return output
 
 
 
 
@@ -182,12 +191,12 @@ def main():
     # allow diffstat.py to also be used from the comand line
     # allow diffstat.py to also be used from the comand line
     if len(sys.argv) > 1:
     if len(sys.argv) > 1:
         diffpath = argv[1]
         diffpath = argv[1]
-        data = b''
-        with open(diffpath, 'rb') as f:
+        data = b""
+        with open(diffpath, "rb") as f:
             data = f.read()
             data = f.read()
-        lines = data.split(b'\n')
+        lines = data.split(b"\n")
         result = diffstat(lines)
         result = diffstat(lines)
-        print(result.decode('utf-8'))
+        print(result.decode("utf-8"))
         return 0
         return 0
 
 
     # if no path argument to a diff file is passed in, run
     # if no path argument to a diff file is passed in, run
@@ -314,7 +323,7 @@ index 3b41fd80..64914c78 100644
  2. open Sigil.app to the normal nearly blank template epub it generates when opened
  2. open Sigil.app to the normal nearly blank template epub it generates when opened
  3. use Plugins->Manage Plugins menu and make sure the "Use Bundled Python" checkbox is checked
  3. use Plugins->Manage Plugins menu and make sure the "Use Bundled Python" checkbox is checked
  4. use the "Add Plugin" button to navigate to and add testplugin.zip and then hit "Okay" to exit the Manage Plugins Dialog
  4. use the "Add Plugin" button to navigate to and add testplugin.zip and then hit "Okay" to exit the Manage Plugins Dialog
-"""     # noqa: E501 W293
+"""  # noqa: E501 W293
 
 
     testoutput = b""" docs/qt512.7_remove_bad_workaround.patch            | 15 ++++++++++++
     testoutput = b""" docs/qt512.7_remove_bad_workaround.patch            | 15 ++++++++++++
  docs/testplugin_v017.zip                            | Bin
  docs/testplugin_v017.zip                            | Bin
@@ -324,17 +333,17 @@ index 3b41fd80..64914c78 100644
  5 files changed, 16 insertions(+), 27 deletions(-)"""  # noqa: W291
  5 files changed, 16 insertions(+), 27 deletions(-)"""  # noqa: W291
 
 
     # return 0 on success otherwise return -1
     # return 0 on success otherwise return -1
-    result = diffstat(selftest.split(b'\n'))
+    result = diffstat(selftest.split(b"\n"))
     if result == testoutput:
     if result == testoutput:
         print("self test passed")
         print("self test passed")
         return 0
         return 0
     print("self test failed")
     print("self test failed")
     print("Received:")
     print("Received:")
-    print(result.decode('utf-8'))
+    print(result.decode("utf-8"))
     print("Expected:")
     print("Expected:")
-    print(testoutput.decode('utf-8'))
+    print(testoutput.decode("utf-8"))
     return -1
     return -1
 
 
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     sys.exit(main())
     sys.exit(main())

+ 18 - 12
dulwich/contrib/paramiko_vendor.py

@@ -35,7 +35,6 @@ import paramiko.client
 
 
 
 
 class _ParamikoWrapper(object):
 class _ParamikoWrapper(object):
-
     def __init__(self, client, channel):
     def __init__(self, client, channel):
         self.client = client
         self.client = client
         self.channel = channel
         self.channel = channel
@@ -59,7 +58,7 @@ class _ParamikoWrapper(object):
 
 
         # Closed socket
         # Closed socket
         if not data:
         if not data:
-            return b''
+            return b""
 
 
         # Read more if needed
         # Read more if needed
         if n and data_len < n:
         if n and data_len < n:
@@ -77,25 +76,32 @@ class ParamikoSSHVendor(object):
     def __init__(self, **kwargs):
     def __init__(self, **kwargs):
         self.kwargs = kwargs
         self.kwargs = kwargs
 
 
-    def run_command(self, host, command,
-                    username=None, port=None,
-                    password=None, pkey=None,
-                    key_filename=None, **kwargs):
+    def run_command(
+        self,
+        host,
+        command,
+        username=None,
+        port=None,
+        password=None,
+        pkey=None,
+        key_filename=None,
+        **kwargs
+    ):
 
 
         client = paramiko.SSHClient()
         client = paramiko.SSHClient()
 
 
-        connection_kwargs = {'hostname': host}
+        connection_kwargs = {"hostname": host}
         connection_kwargs.update(self.kwargs)
         connection_kwargs.update(self.kwargs)
         if username:
         if username:
-            connection_kwargs['username'] = username
+            connection_kwargs["username"] = username
         if port:
         if port:
-            connection_kwargs['port'] = port
+            connection_kwargs["port"] = port
         if password:
         if password:
-            connection_kwargs['password'] = password
+            connection_kwargs["password"] = password
         if pkey:
         if pkey:
-            connection_kwargs['pkey'] = pkey
+            connection_kwargs["pkey"] = pkey
         if key_filename:
         if key_filename:
-            connection_kwargs['key_filename'] = key_filename
+            connection_kwargs["key_filename"] = key_filename
         connection_kwargs.update(kwargs)
         connection_kwargs.update(kwargs)
 
 
         policy = paramiko.client.MissingHostKeyPolicy()
         policy = paramiko.client.MissingHostKeyPolicy()

+ 11 - 11
dulwich/contrib/release_robot.py

@@ -52,8 +52,8 @@ import time
 from dulwich.repo import Repo
 from dulwich.repo import Repo
 
 
 # CONSTANTS
 # CONSTANTS
-PROJDIR = '.'
-PATTERN = r'[ a-zA-Z_\-]*([\d\.]+[\-\w\.]*)'
+PROJDIR = "."
+PATTERN = r"[ a-zA-Z_\-]*([\d\.]+[\-\w\.]*)"
 
 
 
 
 def get_recent_tags(projdir=PROJDIR):
 def get_recent_tags(projdir=PROJDIR):
@@ -74,15 +74,15 @@ def get_recent_tags(projdir=PROJDIR):
         tags = {}  # empty dictionary to hold tags, commits and datetimes
         tags = {}  # empty dictionary to hold tags, commits and datetimes
         # iterate over refs in repository
         # iterate over refs in repository
         for key, value in refs.items():
         for key, value in refs.items():
-            key = key.decode('utf-8')  # compatible with Python-3
+            key = key.decode("utf-8")  # compatible with Python-3
             obj = project.get_object(value)  # dulwich object from SHA-1
             obj = project.get_object(value)  # dulwich object from SHA-1
             # don't just check if object is "tag" b/c it could be a "commit"
             # don't just check if object is "tag" b/c it could be a "commit"
             # instead check if "tags" is in the ref-name
             # instead check if "tags" is in the ref-name
-            if u'tags' not in key:
+            if u"tags" not in key:
                 # skip ref if not a tag
                 # skip ref if not a tag
                 continue
                 continue
             # strip the leading text from refs to get "tag name"
             # strip the leading text from refs to get "tag name"
-            _, tag = key.rsplit(u'/', 1)
+            _, tag = key.rsplit(u"/", 1)
             # check if tag object is "commit" or "tag" pointing to a "commit"
             # check if tag object is "commit" or "tag" pointing to a "commit"
             try:
             try:
                 commit = obj.object  # a tuple (commit class, commit id)
                 commit = obj.object  # a tuple (commit class, commit id)
@@ -92,8 +92,8 @@ def get_recent_tags(projdir=PROJDIR):
             else:
             else:
                 tag_meta = (
                 tag_meta = (
                     datetime.datetime(*time.gmtime(obj.tag_time)[:6]),
                     datetime.datetime(*time.gmtime(obj.tag_time)[:6]),
-                    obj.id.decode('utf-8'),
-                    obj.name.decode('utf-8')
+                    obj.id.decode("utf-8"),
+                    obj.name.decode("utf-8"),
                 )  # compatible with Python-3
                 )  # compatible with Python-3
                 commit = project.get_object(commit[1])  # commit object
                 commit = project.get_object(commit[1])  # commit object
             # get tag commit datetime, but dulwich returns seconds since
             # get tag commit datetime, but dulwich returns seconds since
@@ -101,9 +101,9 @@ def get_recent_tags(projdir=PROJDIR):
             # timetuple then convert to datetime
             # timetuple then convert to datetime
             tags[tag] = [
             tags[tag] = [
                 datetime.datetime(*time.gmtime(commit.commit_time)[:6]),
                 datetime.datetime(*time.gmtime(commit.commit_time)[:6]),
-                commit.id.decode('utf-8'),
-                commit.author.decode('utf-8'),
-                tag_meta
+                commit.id.decode("utf-8"),
+                commit.author.decode("utf-8"),
+                tag_meta,
             ]  # compatible with Python-3
             ]  # compatible with Python-3
 
 
     # return list of tags sorted by their datetimes from newest to oldest
     # return list of tags sorted by their datetimes from newest to oldest
@@ -139,7 +139,7 @@ def get_current_version(projdir=PROJDIR, pattern=PATTERN, logger=None):
     return current_version
     return current_version
 
 
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     if len(sys.argv) > 1:
     if len(sys.argv) > 1:
         _PROJDIR = sys.argv[1]
         _PROJDIR = sys.argv[1]
     else:
     else:

+ 221 - 188
dulwich/contrib/swift.py

@@ -41,7 +41,7 @@ from geventhttpclient import HTTPClient
 from dulwich.greenthreads import (
 from dulwich.greenthreads import (
     GreenThreadsMissingObjectFinder,
     GreenThreadsMissingObjectFinder,
     GreenThreadsObjectStoreIterator,
     GreenThreadsObjectStoreIterator,
-    )
+)
 
 
 from dulwich.lru_cache import LRUSizeCache
 from dulwich.lru_cache import LRUSizeCache
 from dulwich.objects import (
 from dulwich.objects import (
@@ -50,12 +50,12 @@ from dulwich.objects import (
     Tree,
     Tree,
     Tag,
     Tag,
     S_ISGITLINK,
     S_ISGITLINK,
-    )
+)
 from dulwich.object_store import (
 from dulwich.object_store import (
     PackBasedObjectStore,
     PackBasedObjectStore,
     PACKDIR,
     PACKDIR,
     INFODIR,
     INFODIR,
-    )
+)
 from dulwich.pack import (
 from dulwich.pack import (
     PackData,
     PackData,
     Pack,
     Pack,
@@ -70,21 +70,21 @@ from dulwich.pack import (
     _compute_object_size,
     _compute_object_size,
     unpack_object,
     unpack_object,
     write_pack_object,
     write_pack_object,
-    )
+)
 from dulwich.protocol import TCP_GIT_PORT
 from dulwich.protocol import TCP_GIT_PORT
 from dulwich.refs import (
 from dulwich.refs import (
     InfoRefsContainer,
     InfoRefsContainer,
     read_info_refs,
     read_info_refs,
     write_info_refs,
     write_info_refs,
-    )
+)
 from dulwich.repo import (
 from dulwich.repo import (
     BaseRepo,
     BaseRepo,
     OBJECTDIR,
     OBJECTDIR,
-    )
+)
 from dulwich.server import (
 from dulwich.server import (
     Backend,
     Backend,
     TCPGitServer,
     TCPGitServer,
-    )
+)
 
 
 import json
 import json
 
 
@@ -120,9 +120,8 @@ cache_length = 20
 
 
 
 
 class PackInfoObjectStoreIterator(GreenThreadsObjectStoreIterator):
 class PackInfoObjectStoreIterator(GreenThreadsObjectStoreIterator):
-
     def __len__(self):
     def __len__(self):
-        while len(self.finder.objects_to_send):
+        while self.finder.objects_to_send:
             for _ in range(0, len(self.finder.objects_to_send)):
             for _ in range(0, len(self.finder.objects_to_send)):
                 sha = self.finder.next()
                 sha = self.finder.next()
                 self._shas.append(sha)
                 self._shas.append(sha)
@@ -130,7 +129,6 @@ class PackInfoObjectStoreIterator(GreenThreadsObjectStoreIterator):
 
 
 
 
 class PackInfoMissingObjectFinder(GreenThreadsMissingObjectFinder):
 class PackInfoMissingObjectFinder(GreenThreadsMissingObjectFinder):
-
     def next(self):
     def next(self):
         while True:
         while True:
             if not self.objects_to_send:
             if not self.objects_to_send:
@@ -171,7 +169,7 @@ def load_conf(path=None, file=None):
     confpath = None
     confpath = None
     if not path:
     if not path:
         try:
         try:
-            confpath = os.environ['DULWICH_SWIFT_CFG']
+            confpath = os.environ["DULWICH_SWIFT_CFG"]
         except KeyError:
         except KeyError:
             raise Exception("You need to specify a configuration file")
             raise Exception("You need to specify a configuration file")
     else:
     else:
@@ -203,8 +201,11 @@ def pack_info_create(pack_data, pack_index):
             info[obj.id] = (obj.type_num, obj.parents, obj.tree)
             info[obj.id] = (obj.type_num, obj.parents, obj.tree)
         # Tree
         # Tree
         elif obj.type_num == Tree.type_num:
         elif obj.type_num == Tree.type_num:
-            shas = [(s, n, not stat.S_ISDIR(m)) for
-                    n, m, s in obj.items() if not S_ISGITLINK(m)]
+            shas = [
+                (s, n, not stat.S_ISDIR(m))
+                for n, m, s in obj.items()
+                if not S_ISGITLINK(m)
+            ]
             info[obj.id] = (obj.type_num, shas)
             info[obj.id] = (obj.type_num, shas)
         # Blob
         # Blob
         elif obj.type_num == Blob.type_num:
         elif obj.type_num == Blob.type_num:
@@ -233,11 +234,10 @@ class SwiftException(Exception):
 
 
 
 
 class SwiftConnector(object):
 class SwiftConnector(object):
-    """A Connector to swift that manage authentication and errors catching
-    """
+    """A Connector to swift that manage authentication and errors catching"""
 
 
     def __init__(self, root, conf):
     def __init__(self, root, conf):
-        """ Initialize a SwiftConnector
+        """Initialize a SwiftConnector
 
 
         Args:
         Args:
           root: The swift container that will act as Git bare repository
           root: The swift container that will act as Git bare repository
@@ -246,18 +246,15 @@ class SwiftConnector(object):
         self.conf = conf
         self.conf = conf
         self.auth_ver = self.conf.get("swift", "auth_ver")
         self.auth_ver = self.conf.get("swift", "auth_ver")
         if self.auth_ver not in ["1", "2"]:
         if self.auth_ver not in ["1", "2"]:
-            raise NotImplementedError(
-                "Wrong authentication version use either 1 or 2")
+            raise NotImplementedError("Wrong authentication version use either 1 or 2")
         self.auth_url = self.conf.get("swift", "auth_url")
         self.auth_url = self.conf.get("swift", "auth_url")
         self.user = self.conf.get("swift", "username")
         self.user = self.conf.get("swift", "username")
         self.password = self.conf.get("swift", "password")
         self.password = self.conf.get("swift", "password")
-        self.concurrency = self.conf.getint('swift', 'concurrency') or 10
-        self.http_timeout = self.conf.getint('swift', 'http_timeout') or 20
-        self.http_pool_length = \
-            self.conf.getint('swift', 'http_pool_length') or 10
+        self.concurrency = self.conf.getint("swift", "concurrency") or 10
+        self.http_timeout = self.conf.getint("swift", "http_timeout") or 20
+        self.http_pool_length = self.conf.getint("swift", "http_pool_length") or 10
         self.region_name = self.conf.get("swift", "region_name") or "RegionOne"
         self.region_name = self.conf.get("swift", "region_name") or "RegionOne"
-        self.endpoint_type = \
-            self.conf.get("swift", "endpoint_type") or "internalURL"
+        self.endpoint_type = self.conf.get("swift", "endpoint_type") or "internalURL"
         self.cache_length = self.conf.getint("swift", "cache_length") or 20
         self.cache_length = self.conf.getint("swift", "cache_length") or 20
         self.chunk_length = self.conf.getint("swift", "chunk_length") or 12228
         self.chunk_length = self.conf.getint("swift", "chunk_length") or 12228
         self.root = root
         self.root = root
@@ -267,16 +264,18 @@ class SwiftConnector(object):
         else:
         else:
             self.storage_url, self.token = self.swift_auth_v2()
             self.storage_url, self.token = self.swift_auth_v2()
 
 
-        token_header = {'X-Auth-Token': str(self.token)}
-        self.httpclient = \
-            HTTPClient.from_url(str(self.storage_url),
-                                concurrency=self.http_pool_length,
-                                block_size=block_size,
-                                connection_timeout=self.http_timeout,
-                                network_timeout=self.http_timeout,
-                                headers=token_header)
-        self.base_path = str(posixpath.join(
-                urlparse.urlparse(self.storage_url).path, self.root))
+        token_header = {"X-Auth-Token": str(self.token)}
+        self.httpclient = HTTPClient.from_url(
+            str(self.storage_url),
+            concurrency=self.http_pool_length,
+            block_size=block_size,
+            connection_timeout=self.http_timeout,
+            network_timeout=self.http_timeout,
+            headers=token_header,
+        )
+        self.base_path = str(
+            posixpath.join(urlparse.urlparse(self.storage_url).path, self.root)
+        )
 
 
     def swift_auth_v1(self):
     def swift_auth_v1(self):
         self.user = self.user.replace(";", ":")
         self.user = self.user.replace(";", ":")
@@ -284,62 +283,68 @@ class SwiftConnector(object):
             self.auth_url,
             self.auth_url,
             connection_timeout=self.http_timeout,
             connection_timeout=self.http_timeout,
             network_timeout=self.http_timeout,
             network_timeout=self.http_timeout,
-            )
-        headers = {'X-Auth-User': self.user,
-                   'X-Auth-Key': self.password}
+        )
+        headers = {"X-Auth-User": self.user, "X-Auth-Key": self.password}
         path = urlparse.urlparse(self.auth_url).path
         path = urlparse.urlparse(self.auth_url).path
 
 
-        ret = auth_httpclient.request('GET', path, headers=headers)
+        ret = auth_httpclient.request("GET", path, headers=headers)
 
 
         # Should do something with redirections (301 in my case)
         # Should do something with redirections (301 in my case)
 
 
         if ret.status_code < 200 or ret.status_code >= 300:
         if ret.status_code < 200 or ret.status_code >= 300:
-            raise SwiftException('AUTH v1.0 request failed on ' +
-                                 '%s with error code %s (%s)'
-                                 % (str(auth_httpclient.get_base_url()) +
-                                    path, ret.status_code,
-                                    str(ret.items())))
-        storage_url = ret['X-Storage-Url']
-        token = ret['X-Auth-Token']
+            raise SwiftException(
+                "AUTH v1.0 request failed on "
+                + "%s with error code %s (%s)"
+                % (
+                    str(auth_httpclient.get_base_url()) + path,
+                    ret.status_code,
+                    str(ret.items()),
+                )
+            )
+        storage_url = ret["X-Storage-Url"]
+        token = ret["X-Auth-Token"]
         return storage_url, token
         return storage_url, token
 
 
     def swift_auth_v2(self):
     def swift_auth_v2(self):
-        self.tenant, self.user = self.user.split(';')
+        self.tenant, self.user = self.user.split(";")
         auth_dict = {}
         auth_dict = {}
-        auth_dict['auth'] = {'passwordCredentials':
-                             {
-                                 'username': self.user,
-                                 'password': self.password,
-                             },
-                             'tenantName': self.tenant}
+        auth_dict["auth"] = {
+            "passwordCredentials": {
+                "username": self.user,
+                "password": self.password,
+            },
+            "tenantName": self.tenant,
+        }
         auth_json = json.dumps(auth_dict)
         auth_json = json.dumps(auth_dict)
-        headers = {'Content-Type': 'application/json'}
+        headers = {"Content-Type": "application/json"}
         auth_httpclient = HTTPClient.from_url(
         auth_httpclient = HTTPClient.from_url(
             self.auth_url,
             self.auth_url,
             connection_timeout=self.http_timeout,
             connection_timeout=self.http_timeout,
             network_timeout=self.http_timeout,
             network_timeout=self.http_timeout,
-            )
+        )
         path = urlparse.urlparse(self.auth_url).path
         path = urlparse.urlparse(self.auth_url).path
-        if not path.endswith('tokens'):
-            path = posixpath.join(path, 'tokens')
-        ret = auth_httpclient.request('POST', path,
-                                      body=auth_json,
-                                      headers=headers)
+        if not path.endswith("tokens"):
+            path = posixpath.join(path, "tokens")
+        ret = auth_httpclient.request("POST", path, body=auth_json, headers=headers)
 
 
         if ret.status_code < 200 or ret.status_code >= 300:
         if ret.status_code < 200 or ret.status_code >= 300:
-            raise SwiftException('AUTH v2.0 request failed on ' +
-                                 '%s with error code %s (%s)'
-                                 % (str(auth_httpclient.get_base_url()) +
-                                    path, ret.status_code,
-                                    str(ret.items())))
+            raise SwiftException(
+                "AUTH v2.0 request failed on "
+                + "%s with error code %s (%s)"
+                % (
+                    str(auth_httpclient.get_base_url()) + path,
+                    ret.status_code,
+                    str(ret.items()),
+                )
+            )
         auth_ret_json = json.loads(ret.read())
         auth_ret_json = json.loads(ret.read())
-        token = auth_ret_json['access']['token']['id']
-        catalogs = auth_ret_json['access']['serviceCatalog']
-        object_store = [o_store for o_store in catalogs if
-                        o_store['type'] == 'object-store'][0]
-        endpoints = object_store['endpoints']
-        endpoint = [endp for endp in endpoints if
-                    endp["region"] == self.region_name][0]
+        token = auth_ret_json["access"]["token"]["id"]
+        catalogs = auth_ret_json["access"]["serviceCatalog"]
+        object_store = [
+            o_store for o_store in catalogs if o_store["type"] == "object-store"
+        ][0]
+        endpoints = object_store["endpoints"]
+        endpoint = [endp for endp in endpoints if endp["region"] == self.region_name][0]
         return endpoint[self.endpoint_type], token
         return endpoint[self.endpoint_type], token
 
 
     def test_root_exists(self):
     def test_root_exists(self):
@@ -347,12 +352,13 @@ class SwiftConnector(object):
 
 
         Returns: True if exist or None it not
         Returns: True if exist or None it not
         """
         """
-        ret = self.httpclient.request('HEAD', self.base_path)
+        ret = self.httpclient.request("HEAD", self.base_path)
         if ret.status_code == 404:
         if ret.status_code == 404:
             return None
             return None
         if ret.status_code < 200 or ret.status_code > 300:
         if ret.status_code < 200 or ret.status_code > 300:
-            raise SwiftException('HEAD request failed with error code %s'
-                                 % ret.status_code)
+            raise SwiftException(
+                "HEAD request failed with error code %s" % ret.status_code
+            )
         return True
         return True
 
 
     def create_root(self):
     def create_root(self):
@@ -362,10 +368,11 @@ class SwiftConnector(object):
           SwiftException: if unable to create
           SwiftException: if unable to create
         """
         """
         if not self.test_root_exists():
         if not self.test_root_exists():
-            ret = self.httpclient.request('PUT', self.base_path)
+            ret = self.httpclient.request("PUT", self.base_path)
             if ret.status_code < 200 or ret.status_code > 300:
             if ret.status_code < 200 or ret.status_code > 300:
-                raise SwiftException('PUT request failed with error code %s'
-                                     % ret.status_code)
+                raise SwiftException(
+                    "PUT request failed with error code %s" % ret.status_code
+                )
 
 
     def get_container_objects(self):
     def get_container_objects(self):
         """Retrieve objects list in a container
         """Retrieve objects list in a container
@@ -373,14 +380,15 @@ class SwiftConnector(object):
         Returns: A list of dict that describe objects
         Returns: A list of dict that describe objects
                  or None if container does not exist
                  or None if container does not exist
         """
         """
-        qs = '?format=json'
+        qs = "?format=json"
         path = self.base_path + qs
         path = self.base_path + qs
-        ret = self.httpclient.request('GET', path)
+        ret = self.httpclient.request("GET", path)
         if ret.status_code == 404:
         if ret.status_code == 404:
             return None
             return None
         if ret.status_code < 200 or ret.status_code > 300:
         if ret.status_code < 200 or ret.status_code > 300:
-            raise SwiftException('GET request failed with error code %s'
-                                 % ret.status_code)
+            raise SwiftException(
+                "GET request failed with error code %s" % ret.status_code
+            )
         content = ret.read()
         content = ret.read()
         return json.loads(content)
         return json.loads(content)
 
 
@@ -392,13 +400,14 @@ class SwiftConnector(object):
         Returns:
         Returns:
           A dict that describe the object or None if object does not exist
           A dict that describe the object or None if object does not exist
         """
         """
-        path = self.base_path + '/' + name
-        ret = self.httpclient.request('HEAD', path)
+        path = self.base_path + "/" + name
+        ret = self.httpclient.request("HEAD", path)
         if ret.status_code == 404:
         if ret.status_code == 404:
             return None
             return None
         if ret.status_code < 200 or ret.status_code > 300:
         if ret.status_code < 200 or ret.status_code > 300:
-            raise SwiftException('HEAD request failed with error code %s'
-                                 % ret.status_code)
+            raise SwiftException(
+                "HEAD request failed with error code %s" % ret.status_code
+            )
         resp_headers = {}
         resp_headers = {}
         for header, value in ret.items():
         for header, value in ret.items():
             resp_headers[header.lower()] = value
             resp_headers[header.lower()] = value
@@ -415,13 +424,11 @@ class SwiftConnector(object):
         """
         """
         content.seek(0)
         content.seek(0)
         data = content.read()
         data = content.read()
-        path = self.base_path + '/' + name
-        headers = {'Content-Length': str(len(data))}
+        path = self.base_path + "/" + name
+        headers = {"Content-Length": str(len(data))}
 
 
         def _send():
         def _send():
-            ret = self.httpclient.request('PUT', path,
-                                          body=data,
-                                          headers=headers)
+            ret = self.httpclient.request("PUT", path, body=data, headers=headers)
             return ret
             return ret
 
 
         try:
         try:
@@ -432,8 +439,9 @@ class SwiftConnector(object):
             ret = _send()
             ret = _send()
 
 
         if ret.status_code < 200 or ret.status_code > 300:
         if ret.status_code < 200 or ret.status_code > 300:
-            raise SwiftException('PUT request failed with error code %s'
-                                 % ret.status_code)
+            raise SwiftException(
+                "PUT request failed with error code %s" % ret.status_code
+            )
 
 
     def get_object(self, name, range=None):
     def get_object(self, name, range=None):
         """Retrieve an object
         """Retrieve an object
@@ -447,14 +455,15 @@ class SwiftConnector(object):
         """
         """
         headers = {}
         headers = {}
         if range:
         if range:
-            headers['Range'] = 'bytes=%s' % range
-        path = self.base_path + '/' + name
-        ret = self.httpclient.request('GET', path, headers=headers)
+            headers["Range"] = "bytes=%s" % range
+        path = self.base_path + "/" + name
+        ret = self.httpclient.request("GET", path, headers=headers)
         if ret.status_code == 404:
         if ret.status_code == 404:
             return None
             return None
         if ret.status_code < 200 or ret.status_code > 300:
         if ret.status_code < 200 or ret.status_code > 300:
-            raise SwiftException('GET request failed with error code %s'
-                                 % ret.status_code)
+            raise SwiftException(
+                "GET request failed with error code %s" % ret.status_code
+            )
         content = ret.read()
         content = ret.read()
 
 
         if range:
         if range:
@@ -469,11 +478,12 @@ class SwiftConnector(object):
         Raises:
         Raises:
           SwiftException: if unable to delete
           SwiftException: if unable to delete
         """
         """
-        path = self.base_path + '/' + name
-        ret = self.httpclient.request('DELETE', path)
+        path = self.base_path + "/" + name
+        ret = self.httpclient.request("DELETE", path)
         if ret.status_code < 200 or ret.status_code > 300:
         if ret.status_code < 200 or ret.status_code > 300:
-            raise SwiftException('DELETE request failed with error code %s'
-                                 % ret.status_code)
+            raise SwiftException(
+                "DELETE request failed with error code %s" % ret.status_code
+            )
 
 
     def del_root(self):
     def del_root(self):
         """Delete the root container by removing container content
         """Delete the root container by removing container content
@@ -482,11 +492,12 @@ class SwiftConnector(object):
           SwiftException: if unable to delete
           SwiftException: if unable to delete
         """
         """
         for obj in self.get_container_objects():
         for obj in self.get_container_objects():
-            self.del_object(obj['name'])
-        ret = self.httpclient.request('DELETE', self.base_path)
+            self.del_object(obj["name"])
+        ret = self.httpclient.request("DELETE", self.base_path)
         if ret.status_code < 200 or ret.status_code > 300:
         if ret.status_code < 200 or ret.status_code > 300:
-            raise SwiftException('DELETE request failed with error code %s'
-                                 % ret.status_code)
+            raise SwiftException(
+                "DELETE request failed with error code %s" % ret.status_code
+            )
 
 
 
 
 class SwiftPackReader(object):
 class SwiftPackReader(object):
@@ -512,7 +523,7 @@ class SwiftPackReader(object):
         self.pack_length = pack_length
         self.pack_length = pack_length
         self.offset = 0
         self.offset = 0
         self.base_offset = 0
         self.base_offset = 0
-        self.buff = b''
+        self.buff = b""
         self.buff_length = self.scon.chunk_length
         self.buff_length = self.scon.chunk_length
 
 
     def _read(self, more=False):
     def _read(self, more=False):
@@ -531,16 +542,16 @@ class SwiftPackReader(object):
         Returns:
         Returns:
           a bytestring
           a bytestring
         """
         """
-        end = self.offset+length
+        end = self.offset + length
         if self.base_offset + end > self.pack_length:
         if self.base_offset + end > self.pack_length:
-            data = self.buff[self.offset:]
+            data = self.buff[self.offset :]
             self.offset = end
             self.offset = end
             return data
             return data
         if end > len(self.buff):
         if end > len(self.buff):
             # Need to read more from swift
             # Need to read more from swift
             self._read(more=True)
             self._read(more=True)
             return self.read(length)
             return self.read(length)
-        data = self.buff[self.offset:end]
+        data = self.buff[self.offset : end]
         self.offset = end
         self.offset = end
         return data
         return data
 
 
@@ -570,7 +581,7 @@ class SwiftPackData(PackData):
     """
     """
 
 
     def __init__(self, scon, filename):
     def __init__(self, scon, filename):
-        """ Initialize a SwiftPackReader
+        """Initialize a SwiftPackReader
 
 
         Args:
         Args:
           scon: a `SwiftConnector` instance
           scon: a `SwiftConnector` instance
@@ -580,27 +591,26 @@ class SwiftPackData(PackData):
         self._filename = filename
         self._filename = filename
         self._header_size = 12
         self._header_size = 12
         headers = self.scon.get_object_stat(self._filename)
         headers = self.scon.get_object_stat(self._filename)
-        self.pack_length = int(headers['content-length'])
-        pack_reader = SwiftPackReader(self.scon, self._filename,
-                                      self.pack_length)
+        self.pack_length = int(headers["content-length"])
+        pack_reader = SwiftPackReader(self.scon, self._filename, self.pack_length)
         (version, self._num_objects) = read_pack_header(pack_reader.read)
         (version, self._num_objects) = read_pack_header(pack_reader.read)
-        self._offset_cache = LRUSizeCache(1024*1024*self.scon.cache_length,
-                                          compute_size=_compute_object_size)
+        self._offset_cache = LRUSizeCache(
+            1024 * 1024 * self.scon.cache_length,
+            compute_size=_compute_object_size,
+        )
         self.pack = None
         self.pack = None
 
 
     def get_object_at(self, offset):
     def get_object_at(self, offset):
         if offset in self._offset_cache:
         if offset in self._offset_cache:
             return self._offset_cache[offset]
             return self._offset_cache[offset]
         assert offset >= self._header_size
         assert offset >= self._header_size
-        pack_reader = SwiftPackReader(self.scon, self._filename,
-                                      self.pack_length)
+        pack_reader = SwiftPackReader(self.scon, self._filename, self.pack_length)
         pack_reader.seek(offset)
         pack_reader.seek(offset)
         unpacked, _ = unpack_object(pack_reader.read)
         unpacked, _ = unpack_object(pack_reader.read)
         return (unpacked.pack_type_num, unpacked._obj())
         return (unpacked.pack_type_num, unpacked._obj())
 
 
     def get_stored_checksum(self):
     def get_stored_checksum(self):
-        pack_reader = SwiftPackReader(self.scon, self._filename,
-                                      self.pack_length)
+        pack_reader = SwiftPackReader(self.scon, self._filename, self.pack_length)
         return pack_reader.read_checksum()
         return pack_reader.read_checksum()
 
 
     def close(self):
     def close(self):
@@ -616,15 +626,13 @@ class SwiftPack(Pack):
     """
     """
 
 
     def __init__(self, *args, **kwargs):
     def __init__(self, *args, **kwargs):
-        self.scon = kwargs['scon']
-        del kwargs['scon']
+        self.scon = kwargs["scon"]
+        del kwargs["scon"]
         super(SwiftPack, self).__init__(*args, **kwargs)
         super(SwiftPack, self).__init__(*args, **kwargs)
-        self._pack_info_path = self._basename + '.info'
+        self._pack_info_path = self._basename + ".info"
         self._pack_info = None
         self._pack_info = None
-        self._pack_info_load = lambda: load_pack_info(self._pack_info_path,
-                                                      self.scon)
-        self._idx_load = lambda: swift_load_pack_index(self.scon,
-                                                       self._idx_path)
+        self._pack_info_load = lambda: load_pack_info(self._pack_info_path, self.scon)
+        self._idx_load = lambda: swift_load_pack_index(self.scon, self._idx_path)
         self._data_load = lambda: SwiftPackData(self.scon, self._data_path)
         self._data_load = lambda: SwiftPackData(self.scon, self._data_path)
 
 
     @property
     @property
@@ -641,6 +649,7 @@ class SwiftObjectStore(PackBasedObjectStore):
     Allow to manage a bare Git repository from Openstack Swift.
     Allow to manage a bare Git repository from Openstack Swift.
     This object store only supports pack files and not loose objects.
     This object store only supports pack files and not loose objects.
     """
     """
+
     def __init__(self, scon):
     def __init__(self, scon):
         """Open a Swift object store.
         """Open a Swift object store.
 
 
@@ -655,8 +664,11 @@ class SwiftObjectStore(PackBasedObjectStore):
 
 
     def _update_pack_cache(self):
     def _update_pack_cache(self):
         objects = self.scon.get_container_objects()
         objects = self.scon.get_container_objects()
-        pack_files = [o['name'].replace(".pack", "")
-                      for o in objects if o['name'].endswith(".pack")]
+        pack_files = [
+            o["name"].replace(".pack", "")
+            for o in objects
+            if o["name"].endswith(".pack")
+        ]
         ret = []
         ret = []
         for basename in pack_files:
         for basename in pack_files:
             pack = SwiftPack(basename, scon=self.scon)
             pack = SwiftPack(basename, scon=self.scon)
@@ -665,8 +677,7 @@ class SwiftObjectStore(PackBasedObjectStore):
         return ret
         return ret
 
 
     def _iter_loose_objects(self):
     def _iter_loose_objects(self):
-        """Loose objects are not supported by this repository
-        """
+        """Loose objects are not supported by this repository"""
         return []
         return []
 
 
     def iter_shas(self, finder):
     def iter_shas(self, finder):
@@ -676,11 +687,10 @@ class SwiftObjectStore(PackBasedObjectStore):
                  instance if gevent is enabled
                  instance if gevent is enabled
         """
         """
         shas = iter(finder.next, None)
         shas = iter(finder.next, None)
-        return PackInfoObjectStoreIterator(
-            self, shas, finder, self.scon.concurrency)
+        return PackInfoObjectStoreIterator(self, shas, finder, self.scon.concurrency)
 
 
     def find_missing_objects(self, *args, **kwargs):
     def find_missing_objects(self, *args, **kwargs):
-        kwargs['concurrency'] = self.scon.concurrency
+        kwargs["concurrency"] = self.scon.concurrency
         return PackInfoMissingObjectFinder(self, *args, **kwargs)
         return PackInfoMissingObjectFinder(self, *args, **kwargs)
 
 
     def pack_info_get(self, sha):
     def pack_info_get(self, sha):
@@ -725,11 +735,11 @@ class SwiftObjectStore(PackBasedObjectStore):
             f.seek(0)
             f.seek(0)
             pack = PackData(file=f, filename="")
             pack = PackData(file=f, filename="")
             entries = pack.sorted_entries()
             entries = pack.sorted_entries()
-            if len(entries):
-                basename = posixpath.join(self.pack_dir,
-                                          "pack-%s" %
-                                          iter_sha1(entry[0] for
-                                                    entry in entries))
+            if entries:
+                basename = posixpath.join(
+                    self.pack_dir,
+                    "pack-%s" % iter_sha1(entry[0] for entry in entries),
+                )
                 index = BytesIO()
                 index = BytesIO()
                 write_pack_index_v2(index, entries, pack.get_stored_checksum())
                 write_pack_index_v2(index, entries, pack.get_stored_checksum())
                 self.scon.put_object(basename + ".pack", f)
                 self.scon.put_object(basename + ".pack", f)
@@ -745,10 +755,15 @@ class SwiftObjectStore(PackBasedObjectStore):
 
 
         def abort():
         def abort():
             pass
             pass
+
         return f, commit, abort
         return f, commit, abort
 
 
     def add_object(self, obj):
     def add_object(self, obj):
-        self.add_objects([(obj, None), ])
+        self.add_objects(
+            [
+                (obj, None),
+            ]
+        )
 
 
     def _pack_cache_stale(self):
     def _pack_cache_stale(self):
         return False
         return False
@@ -762,12 +777,11 @@ class SwiftObjectStore(PackBasedObjectStore):
         Read it from a stream and complete it in a temporary file.
         Read it from a stream and complete it in a temporary file.
         Then the pack and the corresponding index file are uploaded to Swift.
         Then the pack and the corresponding index file are uploaded to Swift.
         """
         """
-        fd, path = tempfile.mkstemp(prefix='tmp_pack_')
-        f = os.fdopen(fd, 'w+b')
+        fd, path = tempfile.mkstemp(prefix="tmp_pack_")
+        f = os.fdopen(fd, "w+b")
         try:
         try:
             indexer = PackIndexer(f, resolve_ext_ref=self.get_raw)
             indexer = PackIndexer(f, resolve_ext_ref=self.get_raw)
-            copier = PackStreamCopier(read_all, read_some, f,
-                                      delta_iter=indexer)
+            copier = PackStreamCopier(read_all, read_some, f, delta_iter=indexer)
             copier.verify()
             copier.verify()
             return self._complete_thin_pack(f, path, copier, indexer)
             return self._complete_thin_pack(f, path, copier, indexer)
         finally:
         finally:
@@ -805,11 +819,12 @@ class SwiftObjectStore(PackBasedObjectStore):
         entries.sort()
         entries.sort()
         pack_base_name = posixpath.join(
         pack_base_name = posixpath.join(
             self.pack_dir,
             self.pack_dir,
-            'pack-' + os.fsdecode(iter_sha1(e[0] for e in entries)))
-        self.scon.put_object(pack_base_name + '.pack', f)
+            "pack-" + os.fsdecode(iter_sha1(e[0] for e in entries)),
+        )
+        self.scon.put_object(pack_base_name + ".pack", f)
 
 
         # Write the index.
         # Write the index.
-        filename = pack_base_name + '.idx'
+        filename = pack_base_name + ".idx"
         index_file = BytesIO()
         index_file = BytesIO()
         write_pack_index_v2(index_file, entries, pack_sha)
         write_pack_index_v2(index_file, entries, pack_sha)
         self.scon.put_object(filename, index_file)
         self.scon.put_object(filename, index_file)
@@ -818,12 +833,12 @@ class SwiftObjectStore(PackBasedObjectStore):
         f.seek(0)
         f.seek(0)
         pack_data = PackData(filename="", file=f)
         pack_data = PackData(filename="", file=f)
         index_file.seek(0)
         index_file.seek(0)
-        pack_index = load_pack_index_file('', index_file)
+        pack_index = load_pack_index_file("", index_file)
         serialized_pack_info = pack_info_create(pack_data, pack_index)
         serialized_pack_info = pack_info_create(pack_data, pack_index)
         f.close()
         f.close()
         index_file.close()
         index_file.close()
         pack_info_file = BytesIO(serialized_pack_info)
         pack_info_file = BytesIO(serialized_pack_info)
-        filename = pack_base_name + '.info'
+        filename = pack_base_name + ".info"
         self.scon.put_object(filename, pack_info_file)
         self.scon.put_object(filename, pack_info_file)
         pack_info_file.close()
         pack_info_file.close()
 
 
@@ -835,16 +850,15 @@ class SwiftObjectStore(PackBasedObjectStore):
 
 
 
 
 class SwiftInfoRefsContainer(InfoRefsContainer):
 class SwiftInfoRefsContainer(InfoRefsContainer):
-    """Manage references in info/refs object.
-    """
+    """Manage references in info/refs object."""
 
 
     def __init__(self, scon, store):
     def __init__(self, scon, store):
         self.scon = scon
         self.scon = scon
-        self.filename = 'info/refs'
+        self.filename = "info/refs"
         self.store = store
         self.store = store
         f = self.scon.get_object(self.filename)
         f = self.scon.get_object(self.filename)
         if not f:
         if not f:
-            f = BytesIO(b'')
+            f = BytesIO(b"")
         super(SwiftInfoRefsContainer, self).__init__(f)
         super(SwiftInfoRefsContainer, self).__init__(f)
 
 
     def _load_check_ref(self, name, old_ref):
     def _load_check_ref(self, name, old_ref):
@@ -864,9 +878,8 @@ class SwiftInfoRefsContainer(InfoRefsContainer):
         self.scon.put_object(self.filename, f)
         self.scon.put_object(self.filename, f)
 
 
     def set_if_equals(self, name, old_ref, new_ref):
     def set_if_equals(self, name, old_ref, new_ref):
-        """Set a refname to new_ref only if it currently equals old_ref.
-        """
-        if name == 'HEAD':
+        """Set a refname to new_ref only if it currently equals old_ref."""
+        if name == "HEAD":
             return True
             return True
         refs = self._load_check_ref(name, old_ref)
         refs = self._load_check_ref(name, old_ref)
         if not isinstance(refs, dict):
         if not isinstance(refs, dict):
@@ -877,9 +890,8 @@ class SwiftInfoRefsContainer(InfoRefsContainer):
         return True
         return True
 
 
     def remove_if_equals(self, name, old_ref):
     def remove_if_equals(self, name, old_ref):
-        """Remove a refname only if it currently equals old_ref.
-        """
-        if name == 'HEAD':
+        """Remove a refname only if it currently equals old_ref."""
+        if name == "HEAD":
             return True
             return True
         refs = self._load_check_ref(name, old_ref)
         refs = self._load_check_ref(name, old_ref)
         if not isinstance(refs, dict):
         if not isinstance(refs, dict):
@@ -891,14 +903,13 @@ class SwiftInfoRefsContainer(InfoRefsContainer):
 
 
     def allkeys(self):
     def allkeys(self):
         try:
         try:
-            self._refs['HEAD'] = self._refs['refs/heads/master']
+            self._refs["HEAD"] = self._refs["refs/heads/master"]
         except KeyError:
         except KeyError:
             pass
             pass
         return self._refs.keys()
         return self._refs.keys()
 
 
 
 
 class SwiftRepo(BaseRepo):
 class SwiftRepo(BaseRepo):
-
     def __init__(self, root, conf):
     def __init__(self, root, conf):
         """Init a Git bare Repository on top of a Swift container.
         """Init a Git bare Repository on top of a Swift container.
 
 
@@ -910,15 +921,15 @@ class SwiftRepo(BaseRepo):
           root: The container which contains the bare repo
           root: The container which contains the bare repo
           conf: A ConfigParser object
           conf: A ConfigParser object
         """
         """
-        self.root = root.lstrip('/')
+        self.root = root.lstrip("/")
         self.conf = conf
         self.conf = conf
         self.scon = SwiftConnector(self.root, self.conf)
         self.scon = SwiftConnector(self.root, self.conf)
         objects = self.scon.get_container_objects()
         objects = self.scon.get_container_objects()
         if not objects:
         if not objects:
-            raise Exception('There is not any GIT repo here : %s' % self.root)
-        objects = [o['name'].split('/')[0] for o in objects]
+            raise Exception("There is not any GIT repo here : %s" % self.root)
+        objects = [o["name"].split("/")[0] for o in objects]
         if OBJECTDIR not in objects:
         if OBJECTDIR not in objects:
-            raise Exception('This repository (%s) is not bare.' % self.root)
+            raise Exception("This repository (%s) is not bare." % self.root)
         self.bare = True
         self.bare = True
         self._controldir = self.root
         self._controldir = self.root
         object_store = SwiftObjectStore(self.scon)
         object_store = SwiftObjectStore(self.scon)
@@ -954,66 +965,89 @@ class SwiftRepo(BaseRepo):
           a `SwiftRepo` instance
           a `SwiftRepo` instance
         """
         """
         scon.create_root()
         scon.create_root()
-        for obj in [posixpath.join(OBJECTDIR, PACKDIR),
-                    posixpath.join(INFODIR, 'refs')]:
-            scon.put_object(obj, BytesIO(b''))
+        for obj in [
+            posixpath.join(OBJECTDIR, PACKDIR),
+            posixpath.join(INFODIR, "refs"),
+        ]:
+            scon.put_object(obj, BytesIO(b""))
         ret = cls(scon.root, conf)
         ret = cls(scon.root, conf)
         ret._init_files(True)
         ret._init_files(True)
         return ret
         return ret
 
 
 
 
 class SwiftSystemBackend(Backend):
 class SwiftSystemBackend(Backend):
-
     def __init__(self, logger, conf):
     def __init__(self, logger, conf):
         self.conf = conf
         self.conf = conf
         self.logger = logger
         self.logger = logger
 
 
     def open_repository(self, path):
     def open_repository(self, path):
-        self.logger.info('opening repository at %s', path)
+        self.logger.info("opening repository at %s", path)
         return SwiftRepo(path, self.conf)
         return SwiftRepo(path, self.conf)
 
 
 
 
 def cmd_daemon(args):
 def cmd_daemon(args):
     """Entry point for starting a TCP git server."""
     """Entry point for starting a TCP git server."""
     import optparse
     import optparse
+
     parser = optparse.OptionParser()
     parser = optparse.OptionParser()
-    parser.add_option("-l", "--listen_address", dest="listen_address",
-                      default="127.0.0.1",
-                      help="Binding IP address.")
-    parser.add_option("-p", "--port", dest="port", type=int,
-                      default=TCP_GIT_PORT,
-                      help="Binding TCP port.")
-    parser.add_option("-c", "--swift_config", dest="swift_config",
-                      default="",
-                      help="Path to the configuration file for Swift backend.")
+    parser.add_option(
+        "-l",
+        "--listen_address",
+        dest="listen_address",
+        default="127.0.0.1",
+        help="Binding IP address.",
+    )
+    parser.add_option(
+        "-p",
+        "--port",
+        dest="port",
+        type=int,
+        default=TCP_GIT_PORT,
+        help="Binding TCP port.",
+    )
+    parser.add_option(
+        "-c",
+        "--swift_config",
+        dest="swift_config",
+        default="",
+        help="Path to the configuration file for Swift backend.",
+    )
     options, args = parser.parse_args(args)
     options, args = parser.parse_args(args)
 
 
     try:
     try:
         import gevent
         import gevent
         import geventhttpclient  # noqa: F401
         import geventhttpclient  # noqa: F401
     except ImportError:
     except ImportError:
-        print("gevent and geventhttpclient libraries are mandatory "
-              " for use the Swift backend.")
+        print(
+            "gevent and geventhttpclient libraries are mandatory "
+            " for use the Swift backend."
+        )
         sys.exit(1)
         sys.exit(1)
     import gevent.monkey
     import gevent.monkey
+
     gevent.monkey.patch_socket()
     gevent.monkey.patch_socket()
     from dulwich import log_utils
     from dulwich import log_utils
+
     logger = log_utils.getLogger(__name__)
     logger = log_utils.getLogger(__name__)
     conf = load_conf(options.swift_config)
     conf = load_conf(options.swift_config)
     backend = SwiftSystemBackend(logger, conf)
     backend = SwiftSystemBackend(logger, conf)
 
 
     log_utils.default_logging_config()
     log_utils.default_logging_config()
-    server = TCPGitServer(backend, options.listen_address,
-                          port=options.port)
+    server = TCPGitServer(backend, options.listen_address, port=options.port)
     server.serve_forever()
     server.serve_forever()
 
 
 
 
 def cmd_init(args):
 def cmd_init(args):
     import optparse
     import optparse
+
     parser = optparse.OptionParser()
     parser = optparse.OptionParser()
-    parser.add_option("-c", "--swift_config", dest="swift_config",
-                      default="",
-                      help="Path to the configuration file for Swift backend.")
+    parser.add_option(
+        "-c",
+        "--swift_config",
+        dest="swift_config",
+        default="",
+        help="Path to the configuration file for Swift backend.",
+    )
     options, args = parser.parse_args(args)
     options, args = parser.parse_args(args)
 
 
     conf = load_conf(options.swift_config)
     conf = load_conf(options.swift_config)
@@ -1031,8 +1065,7 @@ def main(argv=sys.argv):
     }
     }
 
 
     if len(sys.argv) < 2:
     if len(sys.argv) < 2:
-        print("Usage: %s <%s> [OPTIONS...]" % (
-                sys.argv[0], "|".join(commands.keys())))
+        print("Usage: %s <%s> [OPTIONS...]" % (sys.argv[0], "|".join(commands.keys())))
         sys.exit(1)
         sys.exit(1)
 
 
     cmd = sys.argv[1]
     cmd = sys.argv[1]
@@ -1042,5 +1075,5 @@ def main(argv=sys.argv):
     commands[cmd](sys.argv[2:])
     commands[cmd](sys.argv[2:])
 
 
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
     main()

+ 189 - 0
dulwich/contrib/test_paramiko_vendor.py

@@ -0,0 +1,189 @@
+# test_paramiko_vendor.py
+#
+# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
+# General Public License as public by the Free Software Foundation; version 2.0
+# or (at your option) any later version. You can redistribute it and/or
+# modify it under the terms of either of these two licenses.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# You should have received a copy of the licenses; if not, see
+# <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
+# and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
+# License, Version 2.0.
+#
+
+"""Tests for paramiko_vendor."""
+
+import socket
+import paramiko
+import threading
+
+from dulwich.tests import TestCase
+from dulwich.contrib.paramiko_vendor import ParamikoSSHVendor
+
+try:
+    from StringIO import StringIO
+except ImportError:
+    from io import StringIO
+
+
+USER = 'testuser'
+PASSWORD = 'test'
+SERVER_KEY = """\
+-----BEGIN RSA PRIVATE KEY-----
+MIIEpAIBAAKCAQEAy/L1sSYAzxsMprtNXW4u/1jGXXkQmQ2xtmKVlR+RlIL3a1BH
+bzTpPlZyjltAAwzIP8XRh0iJFKz5y3zSQChhX47ZGN0NvQsVct8R+YwsUonwfAJ+
+JN0KBKKvC8fPHlzqBr3gX+ZxqsFH934tQ6wdQPH5eQWtdM8L826lMsH1737uyTGk
++mCSDjL3c6EzY83g7qhkJU2R4qbi6ne01FaWADzG8sOzXnHT+xpxtk8TTT8yCVUY
+MmBNsSoA/ka3iWz70ghB+6Xb0WpFJZXWq1oYovviPAfZGZSrxBZMxsWMye70SdLl
+TqsBEt0+miIcm9s0fvjWvQuhaHX6mZs5VO4r5QIDAQABAoIBAGYqeYWaYgFdrYLA
+hUrubUCg+g3NHdFuGL4iuIgRXl4lFUh+2KoOuWDu8Uf60iA1AQNhV0sLvQ/Mbv3O
+s4xMLisuZfaclctDiCUZNenqnDFkxEF7BjH1QJV94W5nU4wEQ3/JEmM4D2zYkfKb
+FJW33JeyH6TOgUvohDYYEU1R+J9V8qA243p+ui1uVtNI6Pb0TXJnG5y9Ny4vkSWH
+Fi0QoMPR1r9xJ4SEearGzA/crb4SmmDTKhGSoMsT3d5ATieLmwcS66xWz8w4oFGJ
+yzDq24s4Fp9ccNjMf/xR8XRiekJv835gjEqwF9IXyvgOaq6XJ1iCqGPFDKa25nui
+JnEstOkCgYEA/ZXk7aIanvdeJlTqpX578sJfCnrXLydzE8emk1b7+5mrzGxQ4/pM
+PBQs2f8glT3t0O0mRX9NoRqnwrid88/b+cY4NCOICFZeasX336/gYQxyVeRLJS6Z
+hnGEQqry8qS7PdKAyeHMNmZFrUh4EiHiObymEfQS+mkRUObn0cGBTw8CgYEAzeQU
+D2baec1DawjppKaRynAvWjp+9ry1lZx9unryKVRwjRjkEpw+b3/+hdaF1IvsVSce
+cNj+6W2guZ2tyHuPhZ64/4SJVyE2hKDSKD4xTb2nVjsMeN0bLD2UWXC9mwbx8nWa
+2tmtUZ7a/okQb2cSdosJinRewLNqXIsBXamT1csCgYEA0cXb2RCOQQ6U3dTFPx4A
+3vMXuA2iUKmrsqMoEx6T2LBow/Sefdkik1iFOdipVYwjXP+w9zC2QR1Rxez/DR/X
+8ymceNUjxPHdrSoTQQG29dFcC92MpDeGXQcuyA+uZjcLhbrLOzYEvsOfxBb87NMG
+14hNQPDNekTMREafYo9WrtUCgYAREK54+FVzcwf7fymedA/xb4r9N4v+d3W1iNsC
+8d3Qfyc1CrMct8aVB07ZWQaOr2pPRIbJY7L9NhD0UZVt4I/sy1MaGqonhqE2LP4+
+R6legDG2e/50ph7yc8gwAaA1kUXMiuLi8Nfkw/3yyvmJwklNegi4aRzRbA2Mzhi2
+4q9WMQKBgQCb0JNyxHG4pvLWCF/j0Sm1FfvrpnqSv5678n1j4GX7Ka/TubOK1Y4K
+U+Oib7dKa/zQMWehVFNTayrsq6bKVZ6q7zG+IHiRLw4wjeAxREFH6WUjDrn9vl2l
+D48DKbBuBwuVOJWyq3qbfgJXojscgNQklrsPdXVhDwOF0dYxP89HnA=="""
+CLIENT_KEY = """\
+-----BEGIN RSA PRIVATE KEY-----
+MIIEpAIBAAKCAQEAxvREKSElPOm/0z/nPO+j5rk2tjdgGcGc7We1QZ6TRXYLu7nN
+GeEFIL4p8N1i6dmB+Eydt7xqCU79MWD6Yy4prFe1+/K1wCDUxIbFMxqQcX5zjJzd
+i8j8PbcaUlVhP/OkjtkSxrXaGDO1BzfdV4iEBtTV/2l3zmLKJlt3jnOHLczP24CB
+DTQKp3rKshbRefzot9Y+wnaK692RsYgsyo9YEP0GyWKG9topCHk13r46J6vGLeuj
+ryUKqmbLJkzbJbIcEqwTDo5iHaCVqaMr5Hrb8BdMucSseqZQJsXSd+9tdRcIblUQ
+38kZjmFMm4SFbruJcpZCNM2wNSZPIRX+3eiwNwIDAQABAoIBAHSacOBSJsr+jIi5
+KUOTh9IPtzswVUiDKwARCjB9Sf8p4lKR4N1L/n9kNJyQhApeikgGT2GCMftmqgoo
+tlculQoHFgemBlOmak0MV8NNzF5YKEy/GzF0CDH7gJfEpoyetVFrdA+2QS5yD6U9
+XqKQxiBi2VEqdScmyyeT8AwzNYTnPeH/DOEcnbdRjqiy/CD79F49CQ1lX1Fuqm0K
+I7BivBH1xo/rVnUP4F+IzocDqoga+Pjdj0LTXIgJlHQDSbhsQqWujWQDDuKb+MAw
+sNK4Zf8ErV3j1PyA7f/M5LLq6zgstkW4qikDHo4SpZX8kFOO8tjqb7kujj7XqeaB
+CxqrOTECgYEA73uWkrohcmDJ4KqbuL3tbExSCOUiaIV+sT1eGPNi7GCmXD4eW5Z4
+75v2IHymW83lORSu/DrQ6sKr1nkuRpqr2iBzRmQpl/H+wahIhBXlnJ25uUjDsuPO
+1Pq2LcmyD+jTxVnmbSe/q7O09gZQw3I6H4+BMHmpbf8tC97lqimzpJ0CgYEA1K0W
+ZL70Xtn9quyHvbtae/BW07NZnxvUg4UaVIAL9Zu34JyplJzyzbIjrmlDbv6aRogH
+/KtuG9tfbf55K/jjqNORiuRtzt1hUN1ye4dyW7tHx2/7lXdlqtyK40rQl8P0kqf8
+zaS6BqjnobgSdSpg32rWoL/pcBHPdJCJEgQ8zeMCgYEA0/PK8TOhNIzrP1dgGSKn
+hkkJ9etuB5nW5mEM7gJDFDf6JPupfJ/xiwe6z0fjKK9S57EhqgUYMB55XYnE5iIw
+ZQ6BV9SAZ4V7VsRs4dJLdNC3tn/rDGHJBgCaym2PlbsX6rvFT+h1IC8dwv0V79Ui
+Ehq9WTzkMoE8yhvNokvkPZUCgYEAgBAFxv5xGdh79ftdtXLmhnDvZ6S8l6Fjcxqo
+Ay/jg66Tp43OU226iv/0mmZKM8Dd1xC8dnon4GBVc19jSYYiWBulrRPlx0Xo/o+K
+CzZBN1lrXH1i6dqufpc0jq8TMf/N+q1q/c1uMupsKCY1/xVYpc+ok71b7J7c49zQ
+nOeuUW8CgYA9Infooy65FTgbzca0c9kbCUBmcAPQ2ItH3JcPKWPQTDuV62HcT00o
+fZdIV47Nez1W5Clk191RMy8TXuqI54kocciUWpThc6j44hz49oUueb8U4bLcEHzA
+WxtWBWHwxfSmqgTXilEA3ALJp0kNolLnEttnhENwJpZHlqtes0ZA4w==
+-----END RSA PRIVATE KEY-----"""
+
+
+class Server(paramiko.ServerInterface):
+    """http://docs.paramiko.org/en/2.4/api/server.html"""
+    def __init__(self, commands, *args, **kwargs):
+        super(Server, self).__init__(*args, **kwargs)
+        self.commands = commands
+
+    def check_channel_exec_request(self, channel, command):
+        self.commands.append(command)
+        return True
+
+    def check_auth_password(self, username, password):
+        if username == USER and password == PASSWORD:
+            return paramiko.AUTH_SUCCESSFUL
+        return paramiko.AUTH_FAILED
+
+    def check_auth_publickey(self, username, key):
+        pubkey = paramiko.RSAKey.from_private_key(StringIO(CLIENT_KEY))
+        if username == USER and key == pubkey:
+            return paramiko.AUTH_SUCCESSFUL
+        return paramiko.AUTH_FAILED
+
+    def check_channel_request(self, kind, chanid):
+        if kind == "session":
+            return paramiko.OPEN_SUCCEEDED
+        return paramiko.OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED
+
+    def get_allowed_auths(self, username):
+        return "password,publickey"
+
+
+class ParamikoSSHVendorTests(TestCase):
+    def setUp(self):
+        self.commands = []
+        socket.setdefaulttimeout(10)
+        self.addCleanup(socket.setdefaulttimeout, None)
+        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        self.socket.bind(('127.0.0.1', 0))
+        self.socket.listen(5)
+        self.addCleanup(self.socket.close)
+        self.port = self.socket.getsockname()[1]
+        self.thread = threading.Thread(target=self._run)
+        self.thread.start()
+
+    def tearDown(self):
+        pass
+
+    def _run(self):
+        try:
+            conn, addr = self.socket.accept()
+        except socket.error:
+            return False
+        self.transport = paramiko.Transport(conn)
+        self.addCleanup(self.transport.close)
+        host_key = paramiko.RSAKey.from_private_key(StringIO(SERVER_KEY))
+        self.transport.add_server_key(host_key)
+        server = Server(self.commands)
+        self.transport.start_server(server=server)
+
+    def test_run_command_password(self):
+        vendor = ParamikoSSHVendor(allow_agent=False, look_for_keys=False,)
+        vendor.run_command(
+            '127.0.0.1', 'test_run_command_password',
+            username=USER, port=self.port, password=PASSWORD)
+
+        self.assertIn(b'test_run_command_password', self.commands)
+
+    def test_run_command_with_privkey(self):
+        key = paramiko.RSAKey.from_private_key(StringIO(CLIENT_KEY))
+
+        vendor = ParamikoSSHVendor(allow_agent=False, look_for_keys=False,)
+        vendor.run_command(
+            '127.0.0.1', 'test_run_command_with_privkey',
+            username=USER, port=self.port, pkey=key)
+
+        self.assertIn(b'test_run_command_with_privkey', self.commands)
+
+    def test_run_command_data_transfer(self):
+        vendor = ParamikoSSHVendor(allow_agent=False, look_for_keys=False,)
+        con = vendor.run_command(
+            '127.0.0.1', 'test_run_command_data_transfer',
+            username=USER, port=self.port, password=PASSWORD)
+
+        self.assertIn(b'test_run_command_data_transfer', self.commands)
+
+        channel = self.transport.accept(5)
+        channel.send(b'stdout\n')
+        channel.send_stderr(b'stderr\n')
+        channel.close()
+
+        # Fixme: it's return false
+        # self.assertTrue(con.can_read())
+
+        self.assertEqual(b'stdout\n', con.read(4096))
+
+        # Fixme: it's return empty string
+        # self.assertEqual(b'stderr\n', con.read_stderr(4096))

+ 28 - 21
dulwich/contrib/test_release_robot.py

@@ -44,10 +44,17 @@ class TagPatternTests(unittest.TestCase):
     def test_tag_pattern(self):
     def test_tag_pattern(self):
         """test tag patterns"""
         """test tag patterns"""
         test_cases = {
         test_cases = {
-            '0.3': '0.3', 'v0.3': '0.3', 'release0.3': '0.3',
-            'Release-0.3': '0.3', 'v0.3rc1': '0.3rc1', 'v0.3-rc1': '0.3-rc1',
-            'v0.3-rc.1': '0.3-rc.1', 'version 0.3': '0.3',
-            'version_0.3_rc_1': '0.3_rc_1', 'v1': '1', '0.3rc1': '0.3rc1'
+            "0.3": "0.3",
+            "v0.3": "0.3",
+            "release0.3": "0.3",
+            "Release-0.3": "0.3",
+            "v0.3rc1": "0.3rc1",
+            "v0.3-rc1": "0.3-rc1",
+            "v0.3-rc.1": "0.3-rc.1",
+            "version 0.3": "0.3",
+            "version_0.3_rc_1": "0.3_rc_1",
+            "v1": "1",
+            "0.3rc1": "0.3rc1",
         }
         }
         for testcase, version in test_cases.items():
         for testcase, version in test_cases.items():
             matches = re.match(release_robot.PATTERN, testcase)
             matches = re.match(release_robot.PATTERN, testcase)
@@ -58,12 +65,12 @@ class GetRecentTagsTest(unittest.TestCase):
     """test get recent tags"""
     """test get recent tags"""
 
 
     # Git repo for dulwich project
     # Git repo for dulwich project
-    test_repo = os.path.join(BASEDIR, 'dulwich_test_repo.zip')
+    test_repo = os.path.join(BASEDIR, "dulwich_test_repo.zip")
     committer = b"Mark Mikofski <mark.mikofski@sunpowercorp.com>"
     committer = b"Mark Mikofski <mark.mikofski@sunpowercorp.com>"
-    test_tags = [b'v0.1a', b'v0.1']
+    test_tags = [b"v0.1a", b"v0.1"]
     tag_test_data = {
     tag_test_data = {
-        test_tags[0]: [1484788003, b'3' * 40, None],
-        test_tags[1]: [1484788314, b'1' * 40, (1484788401, b'2' * 40)]
+        test_tags[0]: [1484788003, b"3" * 40, None],
+        test_tags[1]: [1484788314, b"1" * 40, (1484788401, b"2" * 40)],
     }
     }
 
 
     @classmethod
     @classmethod
@@ -75,20 +82,20 @@ class GetRecentTagsTest(unittest.TestCase):
         cls.c1 = make_commit(
         cls.c1 = make_commit(
             id=cls.tag_test_data[cls.test_tags[0]][1],
             id=cls.tag_test_data[cls.test_tags[0]][1],
             commit_time=cls.tag_test_data[cls.test_tags[0]][0],
             commit_time=cls.tag_test_data[cls.test_tags[0]][0],
-            message=b'unannotated tag',
-            author=cls.committer
+            message=b"unannotated tag",
+            author=cls.committer,
         )
         )
         obj_store.add_object(cls.c1)
         obj_store.add_object(cls.c1)
         # tag 1: unannotated
         # tag 1: unannotated
         cls.t1 = cls.test_tags[0]
         cls.t1 = cls.test_tags[0]
-        cls.repo[b'refs/tags/' + cls.t1] = cls.c1.id  # add unannotated tag
+        cls.repo[b"refs/tags/" + cls.t1] = cls.c1.id  # add unannotated tag
         # commit 2 ('2017-01-19T01:11:54')
         # commit 2 ('2017-01-19T01:11:54')
         cls.c2 = make_commit(
         cls.c2 = make_commit(
             id=cls.tag_test_data[cls.test_tags[1]][1],
             id=cls.tag_test_data[cls.test_tags[1]][1],
             commit_time=cls.tag_test_data[cls.test_tags[1]][0],
             commit_time=cls.tag_test_data[cls.test_tags[1]][0],
-            message=b'annotated tag',
+            message=b"annotated tag",
             parents=[cls.c1.id],
             parents=[cls.c1.id],
-            author=cls.committer
+            author=cls.committer,
         )
         )
         obj_store.add_object(cls.c2)
         obj_store.add_object(cls.c2)
         # tag 2: annotated ('2017-01-19T01:13:21')
         # tag 2: annotated ('2017-01-19T01:13:21')
@@ -96,11 +103,11 @@ class GetRecentTagsTest(unittest.TestCase):
             cls.c2,
             cls.c2,
             id=cls.tag_test_data[cls.test_tags[1]][2][1],
             id=cls.tag_test_data[cls.test_tags[1]][2][1],
             name=cls.test_tags[1],
             name=cls.test_tags[1],
-            tag_time=cls.tag_test_data[cls.test_tags[1]][2][0]
+            tag_time=cls.tag_test_data[cls.test_tags[1]][2][0],
         )
         )
         obj_store.add_object(cls.t2)
         obj_store.add_object(cls.t2)
-        cls.repo[b'refs/heads/master'] = cls.c2.id
-        cls.repo[b'refs/tags/' + cls.t2.name] = cls.t2.id  # add annotated tag
+        cls.repo[b"refs/heads/master"] = cls.c2.id
+        cls.repo[b"refs/tags/" + cls.t2.name] = cls.t2.id  # add annotated tag
 
 
     @classmethod
     @classmethod
     def tearDownClass(cls):
     def tearDownClass(cls):
@@ -111,17 +118,17 @@ class GetRecentTagsTest(unittest.TestCase):
         """test get recent tags"""
         """test get recent tags"""
         tags = release_robot.get_recent_tags(self.projdir)  # get test tags
         tags = release_robot.get_recent_tags(self.projdir)  # get test tags
         for tag, metadata in tags:
         for tag, metadata in tags:
-            tag = tag.encode('utf-8')
+            tag = tag.encode("utf-8")
             test_data = self.tag_test_data[tag]  # test data tag
             test_data = self.tag_test_data[tag]  # test data tag
             # test commit date, id and author name
             # test commit date, id and author name
             self.assertEqual(metadata[0], gmtime_to_datetime(test_data[0]))
             self.assertEqual(metadata[0], gmtime_to_datetime(test_data[0]))
-            self.assertEqual(metadata[1].encode('utf-8'), test_data[1])
-            self.assertEqual(metadata[2].encode('utf-8'), self.committer)
+            self.assertEqual(metadata[1].encode("utf-8"), test_data[1])
+            self.assertEqual(metadata[2].encode("utf-8"), self.committer)
             # skip unannotated tags
             # skip unannotated tags
             tag_obj = test_data[2]
             tag_obj = test_data[2]
             if not tag_obj:
             if not tag_obj:
                 continue
                 continue
             # tag date, id and name
             # tag date, id and name
             self.assertEqual(metadata[3][0], gmtime_to_datetime(tag_obj[0]))
             self.assertEqual(metadata[3][0], gmtime_to_datetime(tag_obj[0]))
-            self.assertEqual(metadata[3][1].encode('utf-8'), tag_obj[1])
-            self.assertEqual(metadata[3][2].encode('utf-8'), tag)
+            self.assertEqual(metadata[3][1].encode("utf-8"), tag_obj[1])
+            self.assertEqual(metadata[3][2].encode("utf-8"), tag)

+ 215 - 185
dulwich/contrib/test_swift.py

@@ -31,17 +31,17 @@ from unittest import skipIf
 
 
 from dulwich.tests import (
 from dulwich.tests import (
     TestCase,
     TestCase,
-    )
+)
 from dulwich.tests.test_object_store import (
 from dulwich.tests.test_object_store import (
     ObjectStoreTests,
     ObjectStoreTests,
-    )
+)
 from dulwich.objects import (
 from dulwich.objects import (
     Blob,
     Blob,
     Commit,
     Commit,
     Tree,
     Tree,
     Tag,
     Tag,
     parse_timezone,
     parse_timezone,
-    )
+)
 
 
 import json
 import json
 
 
@@ -82,25 +82,24 @@ http_pool_length = %(http_pool_length)s
 http_timeout = %(http_timeout)s
 http_timeout = %(http_timeout)s
 """
 """
 
 
-def_config_file = {'version_str': 'v1.0',
-                   'version_int': 1,
-                   'concurrency': 1,
-                   'chunk_length': 12228,
-                   'cache_length': 1,
-                   'region_name': 'test',
-                   'endpoint_type': 'internalURL',
-                   'http_pool_length': 1,
-                   'http_timeout': 1}
+def_config_file = {
+    "version_str": "v1.0",
+    "version_int": 1,
+    "concurrency": 1,
+    "chunk_length": 12228,
+    "cache_length": 1,
+    "region_name": "test",
+    "endpoint_type": "internalURL",
+    "http_pool_length": 1,
+    "http_timeout": 1,
+}
 
 
 
 
 def create_swift_connector(store={}):
 def create_swift_connector(store={}):
-    return lambda root, conf: FakeSwiftConnector(root,
-                                                 conf=conf,
-                                                 store=store)
+    return lambda root, conf: FakeSwiftConnector(root, conf=conf, store=store)
 
 
 
 
 class Response(object):
 class Response(object):
-
     def __init__(self, headers={}, status=200, content=None):
     def __init__(self, headers={}, status=200, content=None):
         self.headers = headers
         self.headers = headers
         self.status_code = status
         self.status_code = status
@@ -117,40 +116,46 @@ class Response(object):
 
 
 
 
 def fake_auth_request_v1(*args, **kwargs):
 def fake_auth_request_v1(*args, **kwargs):
-    ret = Response({'X-Storage-Url':
-                    'http://127.0.0.1:8080/v1.0/AUTH_fakeuser',
-                    'X-Auth-Token': '12' * 10},
-                   200)
+    ret = Response(
+        {
+            "X-Storage-Url": "http://127.0.0.1:8080/v1.0/AUTH_fakeuser",
+            "X-Auth-Token": "12" * 10,
+        },
+        200,
+    )
     return ret
     return ret
 
 
 
 
 def fake_auth_request_v1_error(*args, **kwargs):
 def fake_auth_request_v1_error(*args, **kwargs):
-    ret = Response({},
-                   401)
+    ret = Response({}, 401)
     return ret
     return ret
 
 
 
 
 def fake_auth_request_v2(*args, **kwargs):
 def fake_auth_request_v2(*args, **kwargs):
-    s_url = 'http://127.0.0.1:8080/v1.0/AUTH_fakeuser'
-    resp = {'access': {'token': {'id': '12' * 10},
-                       'serviceCatalog':
-                       [
-                           {'type': 'object-store',
-                            'endpoints': [{'region': 'test',
-                                          'internalURL': s_url,
-                                           },
-                                          ]
-                            },
-                       ]
-                       }
-            }
+    s_url = "http://127.0.0.1:8080/v1.0/AUTH_fakeuser"
+    resp = {
+        "access": {
+            "token": {"id": "12" * 10},
+            "serviceCatalog": [
+                {
+                    "type": "object-store",
+                    "endpoints": [
+                        {
+                            "region": "test",
+                            "internalURL": s_url,
+                        },
+                    ],
+                },
+            ],
+        }
+    }
     ret = Response(status=200, content=json.dumps(resp))
     ret = Response(status=200, content=json.dumps(resp))
     return ret
     return ret
 
 
 
 
-def create_commit(data, marker=b'Default', blob=None):
+def create_commit(data, marker=b"Default", blob=None):
     if not blob:
     if not blob:
-        blob = Blob.from_string(b'The blob content ' + marker)
+        blob = Blob.from_string(b"The blob content " + marker)
     tree = Tree()
     tree = Tree()
     tree.add(b"thefile_" + marker, 0o100644, blob.id)
     tree.add(b"thefile_" + marker, 0o100644, blob.id)
     cmt = Commit()
     cmt = Commit()
@@ -160,7 +165,7 @@ def create_commit(data, marker=b'Default', blob=None):
     cmt.tree = tree.id
     cmt.tree = tree.id
     author = b"John Doe " + marker + b" <john@doe.net>"
     author = b"John Doe " + marker + b" <john@doe.net>"
     cmt.author = cmt.committer = author
     cmt.author = cmt.committer = author
-    tz = parse_timezone(b'-0200')[0]
+    tz = parse_timezone(b"-0200")[0]
     cmt.commit_time = cmt.author_time = int(time())
     cmt.commit_time = cmt.author_time = int(time())
     cmt.commit_timezone = cmt.author_timezone = tz
     cmt.commit_timezone = cmt.author_timezone = tz
     cmt.encoding = b"UTF-8"
     cmt.encoding = b"UTF-8"
@@ -168,14 +173,14 @@ def create_commit(data, marker=b'Default', blob=None):
     tag = Tag()
     tag = Tag()
     tag.tagger = b"john@doe.net"
     tag.tagger = b"john@doe.net"
     tag.message = b"Annotated tag"
     tag.message = b"Annotated tag"
-    tag.tag_timezone = parse_timezone(b'-0200')[0]
+    tag.tag_timezone = parse_timezone(b"-0200")[0]
     tag.tag_time = cmt.author_time
     tag.tag_time = cmt.author_time
     tag.object = (Commit, cmt.id)
     tag.object = (Commit, cmt.id)
     tag.name = b"v_" + marker + b"_0.1"
     tag.name = b"v_" + marker + b"_0.1"
     return blob, tree, tag, cmt
     return blob, tree, tag, cmt
 
 
 
 
-def create_commits(length=1, marker=b'Default'):
+def create_commits(length=1, marker=b"Default"):
     data = []
     data = []
     for i in range(0, length):
     for i in range(0, length):
         _marker = ("%s_%s" % (marker, i)).encode()
         _marker = ("%s_%s" % (marker, i)).encode()
@@ -186,7 +191,6 @@ def create_commits(length=1, marker=b'Default'):
 
 
 @skipIf(missing_libs, skipmsg)
 @skipIf(missing_libs, skipmsg)
 class FakeSwiftConnector(object):
 class FakeSwiftConnector(object):
-
     def __init__(self, root, conf, store=None):
     def __init__(self, root, conf, store=None):
         if store:
         if store:
             self.store = store
             self.store = store
@@ -200,7 +204,7 @@ class FakeSwiftConnector(object):
 
 
     def put_object(self, name, content):
     def put_object(self, name, content):
         name = posixpath.join(self.root, name)
         name = posixpath.join(self.root, name)
-        if hasattr(content, 'seek'):
+        if hasattr(content, "seek"):
             content.seek(0)
             content.seek(0)
             content = content.read()
             content = content.read()
         self.store[name] = content
         self.store[name] = content
@@ -213,96 +217,99 @@ class FakeSwiftConnector(object):
             except KeyError:
             except KeyError:
                 return None
                 return None
         else:
         else:
-            l, r = range.split('-')
+            l, r = range.split("-")
             try:
             try:
                 if not l:
                 if not l:
                     r = -int(r)
                     r = -int(r)
                     return self.store[name][r:]
                     return self.store[name][r:]
                 else:
                 else:
-                    return self.store[name][int(l):int(r)]
+                    return self.store[name][int(l) : int(r)]
             except KeyError:
             except KeyError:
                 return None
                 return None
 
 
     def get_container_objects(self):
     def get_container_objects(self):
-        return [{'name': k.replace(self.root + '/', '')}
-                for k in self.store]
+        return [{"name": k.replace(self.root + "/", "")} for k in self.store]
 
 
     def create_root(self):
     def create_root(self):
         if self.root in self.store.keys():
         if self.root in self.store.keys():
             pass
             pass
         else:
         else:
-            self.store[self.root] = ''
+            self.store[self.root] = ""
 
 
     def get_object_stat(self, name):
     def get_object_stat(self, name):
         name = posixpath.join(self.root, name)
         name = posixpath.join(self.root, name)
         if name not in self.store:
         if name not in self.store:
             return None
             return None
-        return {'content-length': len(self.store[name])}
+        return {"content-length": len(self.store[name])}
 
 
 
 
 @skipIf(missing_libs, skipmsg)
 @skipIf(missing_libs, skipmsg)
 class TestSwiftRepo(TestCase):
 class TestSwiftRepo(TestCase):
-
     def setUp(self):
     def setUp(self):
         super(TestSwiftRepo, self).setUp()
         super(TestSwiftRepo, self).setUp()
-        self.conf = swift.load_conf(file=StringIO(config_file %
-                                                  def_config_file))
+        self.conf = swift.load_conf(file=StringIO(config_file % def_config_file))
 
 
     def test_init(self):
     def test_init(self):
-        store = {'fakerepo/objects/pack': ''}
-        with patch('dulwich.contrib.swift.SwiftConnector',
-                   new_callable=create_swift_connector,
-                   store=store):
-            swift.SwiftRepo('fakerepo', conf=self.conf)
+        store = {"fakerepo/objects/pack": ""}
+        with patch(
+            "dulwich.contrib.swift.SwiftConnector",
+            new_callable=create_swift_connector,
+            store=store,
+        ):
+            swift.SwiftRepo("fakerepo", conf=self.conf)
 
 
     def test_init_no_data(self):
     def test_init_no_data(self):
-        with patch('dulwich.contrib.swift.SwiftConnector',
-                   new_callable=create_swift_connector):
-            self.assertRaises(Exception, swift.SwiftRepo,
-                              'fakerepo', self.conf)
+        with patch(
+            "dulwich.contrib.swift.SwiftConnector",
+            new_callable=create_swift_connector,
+        ):
+            self.assertRaises(Exception, swift.SwiftRepo, "fakerepo", self.conf)
 
 
     def test_init_bad_data(self):
     def test_init_bad_data(self):
-        store = {'fakerepo/.git/objects/pack': ''}
-        with patch('dulwich.contrib.swift.SwiftConnector',
-                   new_callable=create_swift_connector,
-                   store=store):
-            self.assertRaises(Exception, swift.SwiftRepo,
-                              'fakerepo', self.conf)
+        store = {"fakerepo/.git/objects/pack": ""}
+        with patch(
+            "dulwich.contrib.swift.SwiftConnector",
+            new_callable=create_swift_connector,
+            store=store,
+        ):
+            self.assertRaises(Exception, swift.SwiftRepo, "fakerepo", self.conf)
 
 
     def test_put_named_file(self):
     def test_put_named_file(self):
-        store = {'fakerepo/objects/pack': ''}
-        with patch('dulwich.contrib.swift.SwiftConnector',
-                   new_callable=create_swift_connector,
-                   store=store):
-            repo = swift.SwiftRepo('fakerepo', conf=self.conf)
-            desc = b'Fake repo'
-            repo._put_named_file('description', desc)
-        self.assertEqual(repo.scon.store['fakerepo/description'],
-                         desc)
+        store = {"fakerepo/objects/pack": ""}
+        with patch(
+            "dulwich.contrib.swift.SwiftConnector",
+            new_callable=create_swift_connector,
+            store=store,
+        ):
+            repo = swift.SwiftRepo("fakerepo", conf=self.conf)
+            desc = b"Fake repo"
+            repo._put_named_file("description", desc)
+        self.assertEqual(repo.scon.store["fakerepo/description"], desc)
 
 
     def test_init_bare(self):
     def test_init_bare(self):
-        fsc = FakeSwiftConnector('fakeroot', conf=self.conf)
-        with patch('dulwich.contrib.swift.SwiftConnector',
-                   new_callable=create_swift_connector,
-                   store=fsc.store):
+        fsc = FakeSwiftConnector("fakeroot", conf=self.conf)
+        with patch(
+            "dulwich.contrib.swift.SwiftConnector",
+            new_callable=create_swift_connector,
+            store=fsc.store,
+        ):
             swift.SwiftRepo.init_bare(fsc, conf=self.conf)
             swift.SwiftRepo.init_bare(fsc, conf=self.conf)
-        self.assertIn('fakeroot/objects/pack', fsc.store)
-        self.assertIn('fakeroot/info/refs', fsc.store)
-        self.assertIn('fakeroot/description', fsc.store)
+        self.assertIn("fakeroot/objects/pack", fsc.store)
+        self.assertIn("fakeroot/info/refs", fsc.store)
+        self.assertIn("fakeroot/description", fsc.store)
 
 
 
 
 @skipIf(missing_libs, skipmsg)
 @skipIf(missing_libs, skipmsg)
 class TestSwiftInfoRefsContainer(TestCase):
 class TestSwiftInfoRefsContainer(TestCase):
-
     def setUp(self):
     def setUp(self):
         super(TestSwiftInfoRefsContainer, self).setUp()
         super(TestSwiftInfoRefsContainer, self).setUp()
         content = (
         content = (
             b"22effb216e3a82f97da599b8885a6cadb488b4c5\trefs/heads/master\n"
             b"22effb216e3a82f97da599b8885a6cadb488b4c5\trefs/heads/master\n"
-            b"cca703b0e1399008b53a1a236d6b4584737649e4\trefs/heads/dev")
-        self.store = {'fakerepo/info/refs': content}
-        self.conf = swift.load_conf(file=StringIO(config_file %
-                                                  def_config_file))
-        self.fsc = FakeSwiftConnector('fakerepo', conf=self.conf)
+            b"cca703b0e1399008b53a1a236d6b4584737649e4\trefs/heads/dev"
+        )
+        self.store = {"fakerepo/info/refs": content}
+        self.conf = swift.load_conf(file=StringIO(config_file % def_config_file))
+        self.fsc = FakeSwiftConnector("fakerepo", conf=self.conf)
         self.object_store = {}
         self.object_store = {}
 
 
     def test_init(self):
     def test_init(self):
@@ -311,160 +318,183 @@ class TestSwiftInfoRefsContainer(TestCase):
         self.assertEqual(len(irc._refs), 0)
         self.assertEqual(len(irc._refs), 0)
         self.fsc.store = self.store
         self.fsc.store = self.store
         irc = swift.SwiftInfoRefsContainer(self.fsc, self.object_store)
         irc = swift.SwiftInfoRefsContainer(self.fsc, self.object_store)
-        self.assertIn(b'refs/heads/dev', irc.allkeys())
-        self.assertIn(b'refs/heads/master', irc.allkeys())
+        self.assertIn(b"refs/heads/dev", irc.allkeys())
+        self.assertIn(b"refs/heads/master", irc.allkeys())
 
 
     def test_set_if_equals(self):
     def test_set_if_equals(self):
         self.fsc.store = self.store
         self.fsc.store = self.store
         irc = swift.SwiftInfoRefsContainer(self.fsc, self.object_store)
         irc = swift.SwiftInfoRefsContainer(self.fsc, self.object_store)
-        irc.set_if_equals(b'refs/heads/dev',
-                          b"cca703b0e1399008b53a1a236d6b4584737649e4", b'1'*40)
-        self.assertEqual(irc[b'refs/heads/dev'], b'1'*40)
+        irc.set_if_equals(
+            b"refs/heads/dev",
+            b"cca703b0e1399008b53a1a236d6b4584737649e4",
+            b"1" * 40,
+        )
+        self.assertEqual(irc[b"refs/heads/dev"], b"1" * 40)
 
 
     def test_remove_if_equals(self):
     def test_remove_if_equals(self):
         self.fsc.store = self.store
         self.fsc.store = self.store
         irc = swift.SwiftInfoRefsContainer(self.fsc, self.object_store)
         irc = swift.SwiftInfoRefsContainer(self.fsc, self.object_store)
-        irc.remove_if_equals(b'refs/heads/dev',
-                             b"cca703b0e1399008b53a1a236d6b4584737649e4")
-        self.assertNotIn(b'refs/heads/dev', irc.allkeys())
+        irc.remove_if_equals(
+            b"refs/heads/dev", b"cca703b0e1399008b53a1a236d6b4584737649e4"
+        )
+        self.assertNotIn(b"refs/heads/dev", irc.allkeys())
 
 
 
 
 @skipIf(missing_libs, skipmsg)
 @skipIf(missing_libs, skipmsg)
 class TestSwiftConnector(TestCase):
 class TestSwiftConnector(TestCase):
-
     def setUp(self):
     def setUp(self):
         super(TestSwiftConnector, self).setUp()
         super(TestSwiftConnector, self).setUp()
-        self.conf = swift.load_conf(file=StringIO(config_file %
-                                                  def_config_file))
-        with patch('geventhttpclient.HTTPClient.request',
-                   fake_auth_request_v1):
-            self.conn = swift.SwiftConnector('fakerepo', conf=self.conf)
+        self.conf = swift.load_conf(file=StringIO(config_file % def_config_file))
+        with patch("geventhttpclient.HTTPClient.request", fake_auth_request_v1):
+            self.conn = swift.SwiftConnector("fakerepo", conf=self.conf)
 
 
     def test_init_connector(self):
     def test_init_connector(self):
-        self.assertEqual(self.conn.auth_ver, '1')
-        self.assertEqual(self.conn.auth_url,
-                         'http://127.0.0.1:8080/auth/v1.0')
-        self.assertEqual(self.conn.user, 'test:tester')
-        self.assertEqual(self.conn.password, 'testing')
-        self.assertEqual(self.conn.root, 'fakerepo')
-        self.assertEqual(self.conn.storage_url,
-                         'http://127.0.0.1:8080/v1.0/AUTH_fakeuser')
-        self.assertEqual(self.conn.token, '12' * 10)
+        self.assertEqual(self.conn.auth_ver, "1")
+        self.assertEqual(self.conn.auth_url, "http://127.0.0.1:8080/auth/v1.0")
+        self.assertEqual(self.conn.user, "test:tester")
+        self.assertEqual(self.conn.password, "testing")
+        self.assertEqual(self.conn.root, "fakerepo")
+        self.assertEqual(
+            self.conn.storage_url, "http://127.0.0.1:8080/v1.0/AUTH_fakeuser"
+        )
+        self.assertEqual(self.conn.token, "12" * 10)
         self.assertEqual(self.conn.http_timeout, 1)
         self.assertEqual(self.conn.http_timeout, 1)
         self.assertEqual(self.conn.http_pool_length, 1)
         self.assertEqual(self.conn.http_pool_length, 1)
         self.assertEqual(self.conn.concurrency, 1)
         self.assertEqual(self.conn.concurrency, 1)
-        self.conf.set('swift', 'auth_ver', '2')
-        self.conf.set('swift', 'auth_url', 'http://127.0.0.1:8080/auth/v2.0')
-        with patch('geventhttpclient.HTTPClient.request',
-                   fake_auth_request_v2):
-            conn = swift.SwiftConnector('fakerepo', conf=self.conf)
-        self.assertEqual(conn.user, 'tester')
-        self.assertEqual(conn.tenant, 'test')
-        self.conf.set('swift', 'auth_ver', '1')
-        self.conf.set('swift', 'auth_url', 'http://127.0.0.1:8080/auth/v1.0')
-        with patch('geventhttpclient.HTTPClient.request',
-                   fake_auth_request_v1_error):
-            self.assertRaises(swift.SwiftException,
-                              lambda: swift.SwiftConnector('fakerepo',
-                                                           conf=self.conf))
+        self.conf.set("swift", "auth_ver", "2")
+        self.conf.set("swift", "auth_url", "http://127.0.0.1:8080/auth/v2.0")
+        with patch("geventhttpclient.HTTPClient.request", fake_auth_request_v2):
+            conn = swift.SwiftConnector("fakerepo", conf=self.conf)
+        self.assertEqual(conn.user, "tester")
+        self.assertEqual(conn.tenant, "test")
+        self.conf.set("swift", "auth_ver", "1")
+        self.conf.set("swift", "auth_url", "http://127.0.0.1:8080/auth/v1.0")
+        with patch("geventhttpclient.HTTPClient.request", fake_auth_request_v1_error):
+            self.assertRaises(
+                swift.SwiftException,
+                lambda: swift.SwiftConnector("fakerepo", conf=self.conf),
+            )
 
 
     def test_root_exists(self):
     def test_root_exists(self):
-        with patch('geventhttpclient.HTTPClient.request',
-                   lambda *args: Response()):
+        with patch("geventhttpclient.HTTPClient.request", lambda *args: Response()):
             self.assertEqual(self.conn.test_root_exists(), True)
             self.assertEqual(self.conn.test_root_exists(), True)
 
 
     def test_root_not_exists(self):
     def test_root_not_exists(self):
-        with patch('geventhttpclient.HTTPClient.request',
-                   lambda *args: Response(status=404)):
+        with patch(
+            "geventhttpclient.HTTPClient.request",
+            lambda *args: Response(status=404),
+        ):
             self.assertEqual(self.conn.test_root_exists(), None)
             self.assertEqual(self.conn.test_root_exists(), None)
 
 
     def test_create_root(self):
     def test_create_root(self):
-        with patch('dulwich.contrib.swift.SwiftConnector.test_root_exists',
-                   lambda *args: None):
-            with patch('geventhttpclient.HTTPClient.request',
-                       lambda *args: Response()):
+        with patch(
+            "dulwich.contrib.swift.SwiftConnector.test_root_exists",
+            lambda *args: None,
+        ):
+            with patch("geventhttpclient.HTTPClient.request", lambda *args: Response()):
                 self.assertEqual(self.conn.create_root(), None)
                 self.assertEqual(self.conn.create_root(), None)
 
 
     def test_create_root_fails(self):
     def test_create_root_fails(self):
-        with patch('dulwich.contrib.swift.SwiftConnector.test_root_exists',
-                   lambda *args: None):
-            with patch('geventhttpclient.HTTPClient.request',
-                       lambda *args: Response(status=404)):
-                self.assertRaises(swift.SwiftException,
-                                  lambda: self.conn.create_root())
+        with patch(
+            "dulwich.contrib.swift.SwiftConnector.test_root_exists",
+            lambda *args: None,
+        ):
+            with patch(
+                "geventhttpclient.HTTPClient.request",
+                lambda *args: Response(status=404),
+            ):
+                self.assertRaises(swift.SwiftException, self.conn.create_root)
 
 
     def test_get_container_objects(self):
     def test_get_container_objects(self):
-        with patch('geventhttpclient.HTTPClient.request',
-                   lambda *args: Response(content=json.dumps(
-                       (({'name': 'a'}, {'name': 'b'}))))):
+        with patch(
+            "geventhttpclient.HTTPClient.request",
+            lambda *args: Response(
+                content=json.dumps((({"name": "a"}, {"name": "b"})))
+            ),
+        ):
             self.assertEqual(len(self.conn.get_container_objects()), 2)
             self.assertEqual(len(self.conn.get_container_objects()), 2)
 
 
     def test_get_container_objects_fails(self):
     def test_get_container_objects_fails(self):
-        with patch('geventhttpclient.HTTPClient.request',
-                   lambda *args: Response(status=404)):
+        with patch(
+            "geventhttpclient.HTTPClient.request",
+            lambda *args: Response(status=404),
+        ):
             self.assertEqual(self.conn.get_container_objects(), None)
             self.assertEqual(self.conn.get_container_objects(), None)
 
 
     def test_get_object_stat(self):
     def test_get_object_stat(self):
-        with patch('geventhttpclient.HTTPClient.request',
-                   lambda *args: Response(headers={'content-length': '10'})):
-            self.assertEqual(self.conn.get_object_stat('a')['content-length'],
-                             '10')
+        with patch(
+            "geventhttpclient.HTTPClient.request",
+            lambda *args: Response(headers={"content-length": "10"}),
+        ):
+            self.assertEqual(self.conn.get_object_stat("a")["content-length"], "10")
 
 
     def test_get_object_stat_fails(self):
     def test_get_object_stat_fails(self):
-        with patch('geventhttpclient.HTTPClient.request',
-                   lambda *args: Response(status=404)):
-            self.assertEqual(self.conn.get_object_stat('a'), None)
+        with patch(
+            "geventhttpclient.HTTPClient.request",
+            lambda *args: Response(status=404),
+        ):
+            self.assertEqual(self.conn.get_object_stat("a"), None)
 
 
     def test_put_object(self):
     def test_put_object(self):
-        with patch('geventhttpclient.HTTPClient.request',
-                   lambda *args, **kwargs: Response()):
-            self.assertEqual(self.conn.put_object('a', BytesIO(b'content')),
-                             None)
+        with patch(
+            "geventhttpclient.HTTPClient.request",
+            lambda *args, **kwargs: Response(),
+        ):
+            self.assertEqual(self.conn.put_object("a", BytesIO(b"content")), None)
 
 
     def test_put_object_fails(self):
     def test_put_object_fails(self):
-        with patch('geventhttpclient.HTTPClient.request',
-                   lambda *args, **kwargs: Response(status=400)):
-            self.assertRaises(swift.SwiftException,
-                              lambda: self.conn.put_object(
-                                  'a', BytesIO(b'content')))
+        with patch(
+            "geventhttpclient.HTTPClient.request",
+            lambda *args, **kwargs: Response(status=400),
+        ):
+            self.assertRaises(
+                swift.SwiftException,
+                lambda: self.conn.put_object("a", BytesIO(b"content")),
+            )
 
 
     def test_get_object(self):
     def test_get_object(self):
-        with patch('geventhttpclient.HTTPClient.request',
-                   lambda *args, **kwargs: Response(content=b'content')):
-            self.assertEqual(self.conn.get_object('a').read(), b'content')
-        with patch('geventhttpclient.HTTPClient.request',
-                   lambda *args, **kwargs: Response(content=b'content')):
-            self.assertEqual(
-                    self.conn.get_object('a', range='0-6'),
-                    b'content')
+        with patch(
+            "geventhttpclient.HTTPClient.request",
+            lambda *args, **kwargs: Response(content=b"content"),
+        ):
+            self.assertEqual(self.conn.get_object("a").read(), b"content")
+        with patch(
+            "geventhttpclient.HTTPClient.request",
+            lambda *args, **kwargs: Response(content=b"content"),
+        ):
+            self.assertEqual(self.conn.get_object("a", range="0-6"), b"content")
 
 
     def test_get_object_fails(self):
     def test_get_object_fails(self):
-        with patch('geventhttpclient.HTTPClient.request',
-                   lambda *args, **kwargs: Response(status=404)):
-            self.assertEqual(self.conn.get_object('a'), None)
+        with patch(
+            "geventhttpclient.HTTPClient.request",
+            lambda *args, **kwargs: Response(status=404),
+        ):
+            self.assertEqual(self.conn.get_object("a"), None)
 
 
     def test_del_object(self):
     def test_del_object(self):
-        with patch('geventhttpclient.HTTPClient.request',
-                   lambda *args: Response()):
-            self.assertEqual(self.conn.del_object('a'), None)
+        with patch("geventhttpclient.HTTPClient.request", lambda *args: Response()):
+            self.assertEqual(self.conn.del_object("a"), None)
 
 
     def test_del_root(self):
     def test_del_root(self):
-        with patch('dulwich.contrib.swift.SwiftConnector.del_object',
-                   lambda *args: None):
-            with patch('dulwich.contrib.swift.SwiftConnector.'
-                       'get_container_objects',
-                       lambda *args: ({'name': 'a'}, {'name': 'b'})):
-                with patch('geventhttpclient.HTTPClient.request',
-                           lambda *args: Response()):
+        with patch(
+            "dulwich.contrib.swift.SwiftConnector.del_object",
+            lambda *args: None,
+        ):
+            with patch(
+                "dulwich.contrib.swift.SwiftConnector." "get_container_objects",
+                lambda *args: ({"name": "a"}, {"name": "b"}),
+            ):
+                with patch(
+                    "geventhttpclient.HTTPClient.request",
+                    lambda *args: Response(),
+                ):
                     self.assertEqual(self.conn.del_root(), None)
                     self.assertEqual(self.conn.del_root(), None)
 
 
 
 
 @skipIf(missing_libs, skipmsg)
 @skipIf(missing_libs, skipmsg)
 class SwiftObjectStoreTests(ObjectStoreTests, TestCase):
 class SwiftObjectStoreTests(ObjectStoreTests, TestCase):
-
     def setUp(self):
     def setUp(self):
         TestCase.setUp(self)
         TestCase.setUp(self)
-        conf = swift.load_conf(file=StringIO(config_file %
-                               def_config_file))
-        fsc = FakeSwiftConnector('fakerepo', conf=conf)
+        conf = swift.load_conf(file=StringIO(config_file % def_config_file))
+        fsc = FakeSwiftConnector("fakerepo", conf=conf)
         self.store = swift.SwiftObjectStore(fsc)
         self.store = swift.SwiftObjectStore(fsc)

+ 102 - 99
dulwich/contrib/test_swift_smoke.py

@@ -40,6 +40,7 @@ import shutil
 
 
 import gevent
 import gevent
 from gevent import monkey
 from gevent import monkey
+
 monkey.patch_all()
 monkey.patch_all()
 
 
 from dulwich import (  # noqa:E402
 from dulwich import (  # noqa:E402
@@ -48,21 +49,19 @@ from dulwich import (  # noqa:E402
     index,
     index,
     client,
     client,
     objects,
     objects,
-    )
+)
 from dulwich.contrib import swift  # noqa:E402
 from dulwich.contrib import swift  # noqa:E402
 
 
 
 
-class DulwichServer():
-    """Start the TCPGitServer with Swift backend
-    """
+class DulwichServer:
+    """Start the TCPGitServer with Swift backend"""
+
     def __init__(self, backend, port):
     def __init__(self, backend, port):
         self.port = port
         self.port = port
         self.backend = backend
         self.backend = backend
 
 
     def run(self):
     def run(self):
-        self.server = server.TCPGitServer(self.backend,
-                                          'localhost',
-                                          port=self.port)
+        self.server = server.TCPGitServer(self.backend, "localhost", port=self.port)
         self.job = gevent.spawn(self.server.serve_forever)
         self.job = gevent.spawn(self.server.serve_forever)
 
 
     def stop(self):
     def stop(self):
@@ -71,19 +70,17 @@ class DulwichServer():
 
 
 
 
 class SwiftSystemBackend(server.Backend):
 class SwiftSystemBackend(server.Backend):
-
     def open_repository(self, path):
     def open_repository(self, path):
         return swift.SwiftRepo(path, conf=swift.load_conf())
         return swift.SwiftRepo(path, conf=swift.load_conf())
 
 
 
 
 class SwiftRepoSmokeTest(unittest.TestCase):
 class SwiftRepoSmokeTest(unittest.TestCase):
-
     @classmethod
     @classmethod
     def setUpClass(cls):
     def setUpClass(cls):
         cls.backend = SwiftSystemBackend()
         cls.backend = SwiftSystemBackend()
         cls.port = 9148
         cls.port = 9148
-        cls.server_address = 'localhost'
-        cls.fakerepo = 'fakerepo'
+        cls.server_address = "localhost"
+        cls.fakerepo = "fakerepo"
         cls.th_server = DulwichServer(cls.backend, cls.port)
         cls.th_server = DulwichServer(cls.backend, cls.port)
         cls.th_server.run()
         cls.th_server.run()
         cls.conf = swift.load_conf()
         cls.conf = swift.load_conf()
@@ -116,103 +113,103 @@ class SwiftRepoSmokeTest(unittest.TestCase):
         swift.SwiftRepo.init_bare(self.scon, self.conf)
         swift.SwiftRepo.init_bare(self.scon, self.conf)
         self.assertTrue(self.scon.test_root_exists())
         self.assertTrue(self.scon.test_root_exists())
         obj = self.scon.get_container_objects()
         obj = self.scon.get_container_objects()
-        filtered = [o for o in obj if o['name'] == 'info/refs'
-                    or o['name'] == 'objects/pack']
+        filtered = [
+            o for o in obj if o["name"] == "info/refs" or o["name"] == "objects/pack"
+        ]
         self.assertEqual(len(filtered), 2)
         self.assertEqual(len(filtered), 2)
 
 
     def test_clone_bare(self):
     def test_clone_bare(self):
         local_repo = repo.Repo.init(self.temp_d, mkdir=True)
         local_repo = repo.Repo.init(self.temp_d, mkdir=True)
         swift.SwiftRepo.init_bare(self.scon, self.conf)
         swift.SwiftRepo.init_bare(self.scon, self.conf)
-        tcp_client = client.TCPGitClient(self.server_address,
-                                         port=self.port)
+        tcp_client = client.TCPGitClient(self.server_address, port=self.port)
         remote_refs = tcp_client.fetch(self.fakerepo, local_repo)
         remote_refs = tcp_client.fetch(self.fakerepo, local_repo)
         # The remote repo is empty (no refs retreived)
         # The remote repo is empty (no refs retreived)
         self.assertEqual(remote_refs, None)
         self.assertEqual(remote_refs, None)
 
 
     def test_push_commit(self):
     def test_push_commit(self):
-        def determine_wants(*args):
+        def determine_wants(*args, **kwargs):
             return {"refs/heads/master": local_repo.refs["HEAD"]}
             return {"refs/heads/master": local_repo.refs["HEAD"]}
 
 
         local_repo = repo.Repo.init(self.temp_d, mkdir=True)
         local_repo = repo.Repo.init(self.temp_d, mkdir=True)
         # Nothing in the staging area
         # Nothing in the staging area
-        local_repo.do_commit('Test commit', 'fbo@localhost')
-        sha = local_repo.refs.read_loose_ref('refs/heads/master')
+        local_repo.do_commit("Test commit", "fbo@localhost")
+        sha = local_repo.refs.read_loose_ref("refs/heads/master")
         swift.SwiftRepo.init_bare(self.scon, self.conf)
         swift.SwiftRepo.init_bare(self.scon, self.conf)
-        tcp_client = client.TCPGitClient(self.server_address,
-                                         port=self.port)
-        tcp_client.send_pack(self.fakerepo, determine_wants,
-                             local_repo.generate_pack_data)
+        tcp_client = client.TCPGitClient(self.server_address, port=self.port)
+        tcp_client.send_pack(
+            self.fakerepo, determine_wants, local_repo.generate_pack_data
+        )
         swift_repo = swift.SwiftRepo("fakerepo", self.conf)
         swift_repo = swift.SwiftRepo("fakerepo", self.conf)
-        remote_sha = swift_repo.refs.read_loose_ref('refs/heads/master')
+        remote_sha = swift_repo.refs.read_loose_ref("refs/heads/master")
         self.assertEqual(sha, remote_sha)
         self.assertEqual(sha, remote_sha)
 
 
     def test_push_branch(self):
     def test_push_branch(self):
-        def determine_wants(*args):
-            return {"refs/heads/mybranch":
-                    local_repo.refs["refs/heads/mybranch"]}
+        def determine_wants(*args, **kwargs):
+            return {"refs/heads/mybranch": local_repo.refs["refs/heads/mybranch"]}
 
 
         local_repo = repo.Repo.init(self.temp_d, mkdir=True)
         local_repo = repo.Repo.init(self.temp_d, mkdir=True)
         # Nothing in the staging area
         # Nothing in the staging area
-        local_repo.do_commit('Test commit', 'fbo@localhost',
-                             ref='refs/heads/mybranch')
-        sha = local_repo.refs.read_loose_ref('refs/heads/mybranch')
+        local_repo.do_commit("Test commit", "fbo@localhost", ref="refs/heads/mybranch")
+        sha = local_repo.refs.read_loose_ref("refs/heads/mybranch")
         swift.SwiftRepo.init_bare(self.scon, self.conf)
         swift.SwiftRepo.init_bare(self.scon, self.conf)
-        tcp_client = client.TCPGitClient(self.server_address,
-                                         port=self.port)
-        tcp_client.send_pack("/fakerepo", determine_wants,
-                             local_repo.generate_pack_data)
+        tcp_client = client.TCPGitClient(self.server_address, port=self.port)
+        tcp_client.send_pack(
+            "/fakerepo", determine_wants, local_repo.generate_pack_data
+        )
         swift_repo = swift.SwiftRepo(self.fakerepo, self.conf)
         swift_repo = swift.SwiftRepo(self.fakerepo, self.conf)
-        remote_sha = swift_repo.refs.read_loose_ref('refs/heads/mybranch')
+        remote_sha = swift_repo.refs.read_loose_ref("refs/heads/mybranch")
         self.assertEqual(sha, remote_sha)
         self.assertEqual(sha, remote_sha)
 
 
     def test_push_multiple_branch(self):
     def test_push_multiple_branch(self):
-        def determine_wants(*args):
-            return {"refs/heads/mybranch":
-                    local_repo.refs["refs/heads/mybranch"],
-                    "refs/heads/master":
-                    local_repo.refs["refs/heads/master"],
-                    "refs/heads/pullr-108":
-                    local_repo.refs["refs/heads/pullr-108"]}
+        def determine_wants(*args, **kwargs):
+            return {
+                "refs/heads/mybranch": local_repo.refs["refs/heads/mybranch"],
+                "refs/heads/master": local_repo.refs["refs/heads/master"],
+                "refs/heads/pullr-108": local_repo.refs["refs/heads/pullr-108"],
+            }
 
 
         local_repo = repo.Repo.init(self.temp_d, mkdir=True)
         local_repo = repo.Repo.init(self.temp_d, mkdir=True)
         # Nothing in the staging area
         # Nothing in the staging area
         local_shas = {}
         local_shas = {}
         remote_shas = {}
         remote_shas = {}
-        for branch in ('master', 'mybranch', 'pullr-108'):
+        for branch in ("master", "mybranch", "pullr-108"):
             local_shas[branch] = local_repo.do_commit(
             local_shas[branch] = local_repo.do_commit(
-                'Test commit %s' % branch, 'fbo@localhost',
-                ref='refs/heads/%s' % branch)
+                "Test commit %s" % branch,
+                "fbo@localhost",
+                ref="refs/heads/%s" % branch,
+            )
         swift.SwiftRepo.init_bare(self.scon, self.conf)
         swift.SwiftRepo.init_bare(self.scon, self.conf)
-        tcp_client = client.TCPGitClient(self.server_address,
-                                         port=self.port)
-        tcp_client.send_pack(self.fakerepo, determine_wants,
-                             local_repo.generate_pack_data)
+        tcp_client = client.TCPGitClient(self.server_address, port=self.port)
+        tcp_client.send_pack(
+            self.fakerepo, determine_wants, local_repo.generate_pack_data
+        )
         swift_repo = swift.SwiftRepo("fakerepo", self.conf)
         swift_repo = swift.SwiftRepo("fakerepo", self.conf)
-        for branch in ('master', 'mybranch', 'pullr-108'):
+        for branch in ("master", "mybranch", "pullr-108"):
             remote_shas[branch] = swift_repo.refs.read_loose_ref(
             remote_shas[branch] = swift_repo.refs.read_loose_ref(
-                'refs/heads/%s' % branch)
+                "refs/heads/%s" % branch
+            )
         self.assertDictEqual(local_shas, remote_shas)
         self.assertDictEqual(local_shas, remote_shas)
 
 
     def test_push_data_branch(self):
     def test_push_data_branch(self):
-        def determine_wants(*args):
+        def determine_wants(*args, **kwargs):
             return {"refs/heads/master": local_repo.refs["HEAD"]}
             return {"refs/heads/master": local_repo.refs["HEAD"]}
+
         local_repo = repo.Repo.init(self.temp_d, mkdir=True)
         local_repo = repo.Repo.init(self.temp_d, mkdir=True)
         os.mkdir(os.path.join(self.temp_d, "dir"))
         os.mkdir(os.path.join(self.temp_d, "dir"))
-        files = ('testfile', 'testfile2', 'dir/testfile3')
+        files = ("testfile", "testfile2", "dir/testfile3")
         i = 0
         i = 0
         for f in files:
         for f in files:
-            open(os.path.join(self.temp_d, f), 'w').write("DATA %s" % i)
+            open(os.path.join(self.temp_d, f), "w").write("DATA %s" % i)
             i += 1
             i += 1
         local_repo.stage(files)
         local_repo.stage(files)
-        local_repo.do_commit('Test commit', 'fbo@localhost',
-                             ref='refs/heads/master')
+        local_repo.do_commit("Test commit", "fbo@localhost", ref="refs/heads/master")
         swift.SwiftRepo.init_bare(self.scon, self.conf)
         swift.SwiftRepo.init_bare(self.scon, self.conf)
-        tcp_client = client.TCPGitClient(self.server_address,
-                                         port=self.port)
-        tcp_client.send_pack(self.fakerepo, determine_wants,
-                             local_repo.generate_pack_data)
+        tcp_client = client.TCPGitClient(self.server_address, port=self.port)
+        tcp_client.send_pack(
+            self.fakerepo, determine_wants, local_repo.generate_pack_data
+        )
         swift_repo = swift.SwiftRepo("fakerepo", self.conf)
         swift_repo = swift.SwiftRepo("fakerepo", self.conf)
-        commit_sha = swift_repo.refs.read_loose_ref('refs/heads/master')
+        commit_sha = swift_repo.refs.read_loose_ref("refs/heads/master")
         otype, data = swift_repo.object_store.get_raw(commit_sha)
         otype, data = swift_repo.object_store.get_raw(commit_sha)
         commit = objects.ShaFile.from_raw_string(otype, data)
         commit = objects.ShaFile.from_raw_string(otype, data)
         otype, data = swift_repo.object_store.get_raw(commit._tree)
         otype, data = swift_repo.object_store.get_raw(commit._tree)
@@ -222,8 +219,8 @@ class SwiftRepoSmokeTest(unittest.TestCase):
         for tree_entry in objs:
         for tree_entry in objs:
             objs_.append(swift_repo.object_store.get_raw(tree_entry.sha))
             objs_.append(swift_repo.object_store.get_raw(tree_entry.sha))
         # Blob
         # Blob
-        self.assertEqual(objs_[1][1], 'DATA 0')
-        self.assertEqual(objs_[2][1], 'DATA 1')
+        self.assertEqual(objs_[1][1], "DATA 0")
+        self.assertEqual(objs_[2][1], "DATA 1")
         # Tree
         # Tree
         self.assertEqual(objs_[0][0], 2)
         self.assertEqual(objs_[0][0], 2)
 
 
@@ -231,80 +228,86 @@ class SwiftRepoSmokeTest(unittest.TestCase):
         self.test_push_data_branch()
         self.test_push_data_branch()
         shutil.rmtree(self.temp_d)
         shutil.rmtree(self.temp_d)
         local_repo = repo.Repo.init(self.temp_d, mkdir=True)
         local_repo = repo.Repo.init(self.temp_d, mkdir=True)
-        tcp_client = client.TCPGitClient(self.server_address,
-                                         port=self.port)
+        tcp_client = client.TCPGitClient(self.server_address, port=self.port)
         remote_refs = tcp_client.fetch(self.fakerepo, local_repo)
         remote_refs = tcp_client.fetch(self.fakerepo, local_repo)
-        files = (os.path.join(self.temp_d, 'testfile'),
-                 os.path.join(self.temp_d, 'testfile2'))
+        files = (
+            os.path.join(self.temp_d, "testfile"),
+            os.path.join(self.temp_d, "testfile2"),
+        )
         local_repo["HEAD"] = remote_refs["refs/heads/master"]
         local_repo["HEAD"] = remote_refs["refs/heads/master"]
         indexfile = local_repo.index_path()
         indexfile = local_repo.index_path()
         tree = local_repo["HEAD"].tree
         tree = local_repo["HEAD"].tree
-        index.build_index_from_tree(local_repo.path, indexfile,
-                                    local_repo.object_store, tree)
+        index.build_index_from_tree(
+            local_repo.path, indexfile, local_repo.object_store, tree
+        )
         for f in files:
         for f in files:
             self.assertEqual(os.path.isfile(f), True)
             self.assertEqual(os.path.isfile(f), True)
 
 
-        def determine_wants(*args):
+        def determine_wants(*args, **kwargs):
             return {"refs/heads/master": local_repo.refs["HEAD"]}
             return {"refs/heads/master": local_repo.refs["HEAD"]}
+
         os.mkdir(os.path.join(self.temp_d, "test"))
         os.mkdir(os.path.join(self.temp_d, "test"))
-        files = ('testfile11', 'testfile22', 'test/testfile33')
+        files = ("testfile11", "testfile22", "test/testfile33")
         i = 0
         i = 0
         for f in files:
         for f in files:
-            open(os.path.join(self.temp_d, f), 'w').write("DATA %s" % i)
+            open(os.path.join(self.temp_d, f), "w").write("DATA %s" % i)
             i += 1
             i += 1
         local_repo.stage(files)
         local_repo.stage(files)
-        local_repo.do_commit('Test commit', 'fbo@localhost',
-                             ref='refs/heads/master')
-        tcp_client.send_pack("/fakerepo", determine_wants,
-                             local_repo.generate_pack_data)
+        local_repo.do_commit("Test commit", "fbo@localhost", ref="refs/heads/master")
+        tcp_client.send_pack(
+            "/fakerepo", determine_wants, local_repo.generate_pack_data
+        )
 
 
     def test_push_remove_branch(self):
     def test_push_remove_branch(self):
-        def determine_wants(*args):
-            return {"refs/heads/pullr-108": objects.ZERO_SHA,
-                    "refs/heads/master":
-                    local_repo.refs['refs/heads/master'],
-                    "refs/heads/mybranch":
-                    local_repo.refs['refs/heads/mybranch'],
-                    }
+        def determine_wants(*args, **kwargs):
+            return {
+                "refs/heads/pullr-108": objects.ZERO_SHA,
+                "refs/heads/master": local_repo.refs["refs/heads/master"],
+                "refs/heads/mybranch": local_repo.refs["refs/heads/mybranch"],
+            }
+
         self.test_push_multiple_branch()
         self.test_push_multiple_branch()
         local_repo = repo.Repo(self.temp_d)
         local_repo = repo.Repo(self.temp_d)
-        tcp_client = client.TCPGitClient(self.server_address,
-                                         port=self.port)
-        tcp_client.send_pack(self.fakerepo, determine_wants,
-                             local_repo.generate_pack_data)
+        tcp_client = client.TCPGitClient(self.server_address, port=self.port)
+        tcp_client.send_pack(
+            self.fakerepo, determine_wants, local_repo.generate_pack_data
+        )
         swift_repo = swift.SwiftRepo("fakerepo", self.conf)
         swift_repo = swift.SwiftRepo("fakerepo", self.conf)
-        self.assertNotIn('refs/heads/pullr-108', swift_repo.refs.allkeys())
+        self.assertNotIn("refs/heads/pullr-108", swift_repo.refs.allkeys())
 
 
     def test_push_annotated_tag(self):
     def test_push_annotated_tag(self):
-        def determine_wants(*args):
-            return {"refs/heads/master": local_repo.refs["HEAD"],
-                    "refs/tags/v1.0": local_repo.refs["refs/tags/v1.0"]}
+        def determine_wants(*args, **kwargs):
+            return {
+                "refs/heads/master": local_repo.refs["HEAD"],
+                "refs/tags/v1.0": local_repo.refs["refs/tags/v1.0"],
+            }
+
         local_repo = repo.Repo.init(self.temp_d, mkdir=True)
         local_repo = repo.Repo.init(self.temp_d, mkdir=True)
         # Nothing in the staging area
         # Nothing in the staging area
-        sha = local_repo.do_commit('Test commit', 'fbo@localhost')
+        sha = local_repo.do_commit("Test commit", "fbo@localhost")
         otype, data = local_repo.object_store.get_raw(sha)
         otype, data = local_repo.object_store.get_raw(sha)
         commit = objects.ShaFile.from_raw_string(otype, data)
         commit = objects.ShaFile.from_raw_string(otype, data)
         tag = objects.Tag()
         tag = objects.Tag()
         tag.tagger = "fbo@localhost"
         tag.tagger = "fbo@localhost"
         tag.message = "Annotated tag"
         tag.message = "Annotated tag"
-        tag.tag_timezone = objects.parse_timezone('-0200')[0]
+        tag.tag_timezone = objects.parse_timezone("-0200")[0]
         tag.tag_time = commit.author_time
         tag.tag_time = commit.author_time
         tag.object = (objects.Commit, commit.id)
         tag.object = (objects.Commit, commit.id)
         tag.name = "v0.1"
         tag.name = "v0.1"
         local_repo.object_store.add_object(tag)
         local_repo.object_store.add_object(tag)
-        local_repo.refs['refs/tags/v1.0'] = tag.id
+        local_repo.refs["refs/tags/v1.0"] = tag.id
         swift.SwiftRepo.init_bare(self.scon, self.conf)
         swift.SwiftRepo.init_bare(self.scon, self.conf)
-        tcp_client = client.TCPGitClient(self.server_address,
-                                         port=self.port)
-        tcp_client.send_pack(self.fakerepo, determine_wants,
-                             local_repo.generate_pack_data)
+        tcp_client = client.TCPGitClient(self.server_address, port=self.port)
+        tcp_client.send_pack(
+            self.fakerepo, determine_wants, local_repo.generate_pack_data
+        )
         swift_repo = swift.SwiftRepo(self.fakerepo, self.conf)
         swift_repo = swift.SwiftRepo(self.fakerepo, self.conf)
-        tag_sha = swift_repo.refs.read_loose_ref('refs/tags/v1.0')
+        tag_sha = swift_repo.refs.read_loose_ref("refs/tags/v1.0")
         otype, data = swift_repo.object_store.get_raw(tag_sha)
         otype, data = swift_repo.object_store.get_raw(tag_sha)
         rtag = objects.ShaFile.from_raw_string(otype, data)
         rtag = objects.ShaFile.from_raw_string(otype, data)
         self.assertEqual(rtag.object[1], commit.id)
         self.assertEqual(rtag.object[1], commit.id)
         self.assertEqual(rtag.id, tag.id)
         self.assertEqual(rtag.id, tag.id)
 
 
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
     unittest.main()

+ 74 - 53
dulwich/diff_tree.py

@@ -23,7 +23,7 @@
 from collections import (
 from collections import (
     defaultdict,
     defaultdict,
     namedtuple,
     namedtuple,
-    )
+)
 
 
 from io import BytesIO
 from io import BytesIO
 from itertools import chain
 from itertools import chain
@@ -32,16 +32,16 @@ import stat
 from dulwich.objects import (
 from dulwich.objects import (
     S_ISGITLINK,
     S_ISGITLINK,
     TreeEntry,
     TreeEntry,
-    )
+)
 
 
 
 
 # TreeChange type constants.
 # TreeChange type constants.
-CHANGE_ADD = 'add'
-CHANGE_MODIFY = 'modify'
-CHANGE_DELETE = 'delete'
-CHANGE_RENAME = 'rename'
-CHANGE_COPY = 'copy'
-CHANGE_UNCHANGED = 'unchanged'
+CHANGE_ADD = "add"
+CHANGE_MODIFY = "modify"
+CHANGE_DELETE = "delete"
+CHANGE_RENAME = "rename"
+CHANGE_COPY = "copy"
+CHANGE_UNCHANGED = "unchanged"
 
 
 RENAME_CHANGE_TYPES = (CHANGE_RENAME, CHANGE_COPY)
 RENAME_CHANGE_TYPES = (CHANGE_RENAME, CHANGE_COPY)
 
 
@@ -53,7 +53,7 @@ MAX_FILES = 200
 REWRITE_THRESHOLD = None
 REWRITE_THRESHOLD = None
 
 
 
 
-class TreeChange(namedtuple('TreeChange', ['type', 'old', 'new'])):
+class TreeChange(namedtuple("TreeChange", ["type", "old", "new"])):
     """Named tuple a single change between two trees."""
     """Named tuple a single change between two trees."""
 
 
     @classmethod
     @classmethod
@@ -142,7 +142,7 @@ def walk_trees(store, tree1_id, tree2_id, prune_identical=False):
     # case.
     # case.
     mode1 = tree1_id and stat.S_IFDIR or None
     mode1 = tree1_id and stat.S_IFDIR or None
     mode2 = tree2_id and stat.S_IFDIR or None
     mode2 = tree2_id and stat.S_IFDIR or None
-    todo = [(TreeEntry(b'', mode1, tree1_id), TreeEntry(b'', mode2, tree2_id))]
+    todo = [(TreeEntry(b"", mode1, tree1_id), TreeEntry(b"", mode2, tree2_id))]
     while todo:
     while todo:
         entry1, entry2 = todo.pop()
         entry1, entry2 = todo.pop()
         is_tree1 = _is_tree(entry1)
         is_tree1 = _is_tree(entry1)
@@ -163,9 +163,15 @@ def _skip_tree(entry, include_trees):
     return entry
     return entry
 
 
 
 
-def tree_changes(store, tree1_id, tree2_id, want_unchanged=False,
-                 rename_detector=None, include_trees=False,
-                 change_type_same=False):
+def tree_changes(
+    store,
+    tree1_id,
+    tree2_id,
+    want_unchanged=False,
+    rename_detector=None,
+    include_trees=False,
+    change_type_same=False,
+):
     """Find the differences between the contents of two trees.
     """Find the differences between the contents of two trees.
 
 
     Args:
     Args:
@@ -182,16 +188,19 @@ def tree_changes(store, tree1_id, tree2_id, want_unchanged=False,
       Iterator over TreeChange instances for each change between the
       Iterator over TreeChange instances for each change between the
         source and target tree.
         source and target tree.
     """
     """
-    if (rename_detector is not None and tree1_id is not None and
-            tree2_id is not None):
+    if rename_detector is not None and tree1_id is not None and tree2_id is not None:
         for change in rename_detector.changes_with_renames(
         for change in rename_detector.changes_with_renames(
-                tree1_id, tree2_id, want_unchanged=want_unchanged,
-                include_trees=include_trees):
+            tree1_id,
+            tree2_id,
+            want_unchanged=want_unchanged,
+            include_trees=include_trees,
+        ):
             yield change
             yield change
         return
         return
 
 
-    entries = walk_trees(store, tree1_id, tree2_id,
-                         prune_identical=(not want_unchanged))
+    entries = walk_trees(
+        store, tree1_id, tree2_id, prune_identical=(not want_unchanged)
+    )
     for entry1, entry2 in entries:
     for entry1, entry2 in entries:
         if entry1 == entry2 and not want_unchanged:
         if entry1 == entry2 and not want_unchanged:
             continue
             continue
@@ -201,8 +210,10 @@ def tree_changes(store, tree1_id, tree2_id, want_unchanged=False,
         entry2 = _skip_tree(entry2, include_trees)
         entry2 = _skip_tree(entry2, include_trees)
 
 
         if entry1 != _NULL_ENTRY and entry2 != _NULL_ENTRY:
         if entry1 != _NULL_ENTRY and entry2 != _NULL_ENTRY:
-            if (stat.S_IFMT(entry1.mode) != stat.S_IFMT(entry2.mode)
-                    and not change_type_same):
+            if (
+                stat.S_IFMT(entry1.mode) != stat.S_IFMT(entry2.mode)
+                and not change_type_same
+            ):
                 # File type changed: report as delete/add.
                 # File type changed: report as delete/add.
                 yield TreeChange.delete(entry1)
                 yield TreeChange.delete(entry1)
                 entry1 = _NULL_ENTRY
                 entry1 = _NULL_ENTRY
@@ -232,8 +243,7 @@ def _all_same(seq, key):
     return _all_eq(seq[1:], key, key(seq[0]))
     return _all_eq(seq[1:], key, key(seq[0]))
 
 
 
 
-def tree_changes_for_merge(store, parent_tree_ids, tree_id,
-                           rename_detector=None):
+def tree_changes_for_merge(store, parent_tree_ids, tree_id, rename_detector=None):
     """Get the tree changes for a merge tree relative to all its parents.
     """Get the tree changes for a merge tree relative to all its parents.
 
 
     Args:
     Args:
@@ -254,9 +264,10 @@ def tree_changes_for_merge(store, parent_tree_ids, tree_id,
       in the merge tree is not found in any of the parents, or in the case of
       in the merge tree is not found in any of the parents, or in the case of
       deletes, if not all of the old SHAs match.
       deletes, if not all of the old SHAs match.
     """
     """
-    all_parent_changes = [tree_changes(store, t, tree_id,
-                                       rename_detector=rename_detector)
-                          for t in parent_tree_ids]
+    all_parent_changes = [
+        tree_changes(store, t, tree_id, rename_detector=rename_detector)
+        for t in parent_tree_ids
+    ]
     num_parents = len(parent_tree_ids)
     num_parents = len(parent_tree_ids)
     changes_by_path = defaultdict(lambda: [None] * num_parents)
     changes_by_path = defaultdict(lambda: [None] * num_parents)
 
 
@@ -315,10 +326,10 @@ def _count_blocks(obj):
     block_getvalue = block.getvalue
     block_getvalue = block.getvalue
 
 
     for c in chain.from_iterable(obj.as_raw_chunks()):
     for c in chain.from_iterable(obj.as_raw_chunks()):
-        c = c.to_bytes(1, 'big')
+        c = c.to_bytes(1, "big")
         block_write(c)
         block_write(c)
         n += 1
         n += 1
-        if c == b'\n' or n == _BLOCK_SIZE:
+        if c == b"\n" or n == _BLOCK_SIZE:
             value = block_getvalue()
             value = block_getvalue()
             block_counts[hash(value)] += len(value)
             block_counts[hash(value)] += len(value)
             block_seek(0)
             block_seek(0)
@@ -392,10 +403,14 @@ def _tree_change_key(entry):
 class RenameDetector(object):
 class RenameDetector(object):
     """Object for handling rename detection between two trees."""
     """Object for handling rename detection between two trees."""
 
 
-    def __init__(self, store, rename_threshold=RENAME_THRESHOLD,
-                 max_files=MAX_FILES,
-                 rewrite_threshold=REWRITE_THRESHOLD,
-                 find_copies_harder=False):
+    def __init__(
+        self,
+        store,
+        rename_threshold=RENAME_THRESHOLD,
+        max_files=MAX_FILES,
+        rewrite_threshold=REWRITE_THRESHOLD,
+        find_copies_harder=False,
+    ):
         """Initialize the rename detector.
         """Initialize the rename detector.
 
 
         Args:
         Args:
@@ -426,8 +441,11 @@ class RenameDetector(object):
         self._changes = []
         self._changes = []
 
 
     def _should_split(self, change):
     def _should_split(self, change):
-        if (self._rewrite_threshold is None or change.type != CHANGE_MODIFY or
-                change.old.sha == change.new.sha):
+        if (
+            self._rewrite_threshold is None
+            or change.type != CHANGE_MODIFY
+            or change.old.sha == change.new.sha
+        ):
             return False
             return False
         old_obj = self._store[change.old.sha]
         old_obj = self._store[change.old.sha]
         new_obj = self._store[change.new.sha]
         new_obj = self._store[change.new.sha]
@@ -441,8 +459,9 @@ class RenameDetector(object):
         elif self._should_split(change):
         elif self._should_split(change):
             self._deletes.append(TreeChange.delete(change.old))
             self._deletes.append(TreeChange.delete(change.old))
             self._adds.append(TreeChange.add(change.new))
             self._adds.append(TreeChange.add(change.new))
-        elif ((self._find_copies_harder and change.type == CHANGE_UNCHANGED)
-              or change.type == CHANGE_MODIFY):
+        elif (
+            self._find_copies_harder and change.type == CHANGE_UNCHANGED
+        ) or change.type == CHANGE_MODIFY:
             # Treat all modifies as potential deletes for rename detection,
             # Treat all modifies as potential deletes for rename detection,
             # but don't split them (to avoid spurious renames). Setting
             # but don't split them (to avoid spurious renames). Setting
             # find_copies_harder means we treat unchanged the same as
             # find_copies_harder means we treat unchanged the same as
@@ -453,15 +472,18 @@ class RenameDetector(object):
 
 
     def _collect_changes(self, tree1_id, tree2_id):
     def _collect_changes(self, tree1_id, tree2_id):
         want_unchanged = self._find_copies_harder or self._want_unchanged
         want_unchanged = self._find_copies_harder or self._want_unchanged
-        for change in tree_changes(self._store, tree1_id, tree2_id,
-                                   want_unchanged=want_unchanged,
-                                   include_trees=self._include_trees):
+        for change in tree_changes(
+            self._store,
+            tree1_id,
+            tree2_id,
+            want_unchanged=want_unchanged,
+            include_trees=self._include_trees,
+        ):
             self._add_change(change)
             self._add_change(change)
 
 
     def _prune(self, add_paths, delete_paths):
     def _prune(self, add_paths, delete_paths):
         self._adds = [a for a in self._adds if a.new.path not in add_paths]
         self._adds = [a for a in self._adds if a.new.path not in add_paths]
-        self._deletes = [d for d in self._deletes
-                         if d.old.path not in delete_paths]
+        self._deletes = [d for d in self._deletes if d.old.path not in delete_paths]
 
 
     def _find_exact_renames(self):
     def _find_exact_renames(self):
         add_map = defaultdict(list)
         add_map = defaultdict(list)
@@ -534,8 +556,7 @@ class RenameDetector(object):
                 if stat.S_IFMT(delete.old.mode) != stat.S_IFMT(add.new.mode):
                 if stat.S_IFMT(delete.old.mode) != stat.S_IFMT(add.new.mode):
                     continue
                     continue
                 new_obj = self._store[add.new.sha]
                 new_obj = self._store[add.new.sha]
-                score = _similarity_score(old_obj, new_obj,
-                                          block_cache=block_cache)
+                score = _similarity_score(old_obj, new_obj, block_cache=block_cache)
                 if score > self._rename_threshold:
                 if score > self._rename_threshold:
                     new_type = self._rename_type(check_paths, delete, add)
                     new_type = self._rename_type(check_paths, delete, add)
                     rename = TreeChange(new_type, delete.old, add.new)
                     rename = TreeChange(new_type, delete.old, add.new)
@@ -570,17 +591,17 @@ class RenameDetector(object):
             return
             return
 
 
         modifies = {}
         modifies = {}
-        delete_map = dict((d.old.path, d) for d in self._deletes)
+        delete_map = {d.old.path: d for d in self._deletes}
         for add in self._adds:
         for add in self._adds:
             path = add.new.path
             path = add.new.path
             delete = delete_map.get(path)
             delete = delete_map.get(path)
-            if (delete is not None and
-                    stat.S_IFMT(delete.old.mode) == stat.S_IFMT(add.new.mode)):
+            if delete is not None and stat.S_IFMT(delete.old.mode) == stat.S_IFMT(
+                add.new.mode
+            ):
                 modifies[path] = TreeChange(CHANGE_MODIFY, delete.old, add.new)
                 modifies[path] = TreeChange(CHANGE_MODIFY, delete.old, add.new)
 
 
         self._adds = [a for a in self._adds if a.new.path not in modifies]
         self._adds = [a for a in self._adds if a.new.path not in modifies]
-        self._deletes = [a for a in self._deletes if a.new.path not in
-                         modifies]
+        self._deletes = [a for a in self._deletes if a.new.path not in modifies]
         self._changes += modifies.values()
         self._changes += modifies.values()
 
 
     def _sorted_changes(self):
     def _sorted_changes(self):
@@ -594,11 +615,11 @@ class RenameDetector(object):
     def _prune_unchanged(self):
     def _prune_unchanged(self):
         if self._want_unchanged:
         if self._want_unchanged:
             return
             return
-        self._deletes = [
-            d for d in self._deletes if d.type != CHANGE_UNCHANGED]
+        self._deletes = [d for d in self._deletes if d.type != CHANGE_UNCHANGED]
 
 
-    def changes_with_renames(self, tree1_id, tree2_id, want_unchanged=False,
-                             include_trees=False):
+    def changes_with_renames(
+        self, tree1_id, tree2_id, want_unchanged=False, include_trees=False
+    ):
         """Iterate TreeChanges between two tree SHAs, with rename detection."""
         """Iterate TreeChanges between two tree SHAs, with rename detection."""
         self._reset()
         self._reset()
         self._want_unchanged = want_unchanged
         self._want_unchanged = want_unchanged
@@ -622,6 +643,6 @@ try:
         _is_tree,
         _is_tree,
         _merge_entries,
         _merge_entries,
         _count_blocks,
         _count_blocks,
-        )
+    )
 except ImportError:
 except ImportError:
     pass
     pass

+ 22 - 19
dulwich/errors.py

@@ -37,12 +37,14 @@ class ChecksumMismatch(Exception):
         self.extra = extra
         self.extra = extra
         if self.extra is None:
         if self.extra is None:
             Exception.__init__(
             Exception.__init__(
-                self, "Checksum mismatch: Expected %s, got %s" %
-                (expected, got))
+                self,
+                "Checksum mismatch: Expected %s, got %s" % (expected, got),
+            )
         else:
         else:
             Exception.__init__(
             Exception.__init__(
-                self, "Checksum mismatch: Expected %s, got %s; %s" %
-                (expected, got, extra))
+                self,
+                "Checksum mismatch: Expected %s, got %s; %s" % (expected, got, extra),
+            )
 
 
 
 
 class WrongObjectException(Exception):
 class WrongObjectException(Exception):
@@ -61,25 +63,25 @@ class WrongObjectException(Exception):
 class NotCommitError(WrongObjectException):
 class NotCommitError(WrongObjectException):
     """Indicates that the sha requested does not point to a commit."""
     """Indicates that the sha requested does not point to a commit."""
 
 
-    type_name = 'commit'
+    type_name = "commit"
 
 
 
 
 class NotTreeError(WrongObjectException):
 class NotTreeError(WrongObjectException):
     """Indicates that the sha requested does not point to a tree."""
     """Indicates that the sha requested does not point to a tree."""
 
 
-    type_name = 'tree'
+    type_name = "tree"
 
 
 
 
 class NotTagError(WrongObjectException):
 class NotTagError(WrongObjectException):
     """Indicates that the sha requested does not point to a tag."""
     """Indicates that the sha requested does not point to a tag."""
 
 
-    type_name = 'tag'
+    type_name = "tag"
 
 
 
 
 class NotBlobError(WrongObjectException):
 class NotBlobError(WrongObjectException):
     """Indicates that the sha requested does not point to a blob."""
     """Indicates that the sha requested does not point to a blob."""
 
 
-    type_name = 'blob'
+    type_name = "blob"
 
 
 
 
 class MissingCommitError(Exception):
 class MissingCommitError(Exception):
@@ -132,7 +134,7 @@ class UpdateRefsError(GitProtocolError):
     """The server reported errors updating refs."""
     """The server reported errors updating refs."""
 
 
     def __init__(self, *args, **kwargs):
     def __init__(self, *args, **kwargs):
-        self.ref_status = kwargs.pop('ref_status')
+        self.ref_status = kwargs.pop("ref_status")
         super(UpdateRefsError, self).__init__(*args, **kwargs)
         super(UpdateRefsError, self).__init__(*args, **kwargs)
 
 
 
 
@@ -142,18 +144,18 @@ class HangupException(GitProtocolError):
     def __init__(self, stderr_lines=None):
     def __init__(self, stderr_lines=None):
         if stderr_lines:
         if stderr_lines:
             super(HangupException, self).__init__(
             super(HangupException, self).__init__(
-                '\n'.join(
-                    [line.decode('utf-8', 'surrogateescape')
-                     for line in stderr_lines]))
+                "\n".join(
+                    [line.decode("utf-8", "surrogateescape") for line in stderr_lines]
+                )
+            )
         else:
         else:
             super(HangupException, self).__init__(
             super(HangupException, self).__init__(
-                "The remote server unexpectedly closed the connection.")
+                "The remote server unexpectedly closed the connection."
+            )
         self.stderr_lines = stderr_lines
         self.stderr_lines = stderr_lines
 
 
     def __eq__(self, other):
     def __eq__(self, other):
-        return (
-            isinstance(self, type(other)) and
-            self.stderr_lines == other.stderr_lines)
+        return isinstance(self, type(other)) and self.stderr_lines == other.stderr_lines
 
 
 
 
 class UnexpectedCommandError(GitProtocolError):
 class UnexpectedCommandError(GitProtocolError):
@@ -161,11 +163,12 @@ class UnexpectedCommandError(GitProtocolError):
 
 
     def __init__(self, command):
     def __init__(self, command):
         if command is None:
         if command is None:
-            command = 'flush-pkt'
+            command = "flush-pkt"
         else:
         else:
-            command = 'command %s' % command
+            command = "command %s" % command
         super(UnexpectedCommandError, self).__init__(
         super(UnexpectedCommandError, self).__init__(
-            'Protocol got unexpected %s' % command)
+            "Protocol got unexpected %s" % command
+        )
 
 
 
 
 class FileFormatException(Exception):
 class FileFormatException(Exception):

+ 39 - 24
dulwich/fastexport.py

@@ -23,19 +23,19 @@
 
 
 from dulwich.index import (
 from dulwich.index import (
     commit_tree,
     commit_tree,
-    )
+)
 from dulwich.objects import (
 from dulwich.objects import (
     Blob,
     Blob,
     Commit,
     Commit,
     Tag,
     Tag,
     ZERO_SHA,
     ZERO_SHA,
-    )
+)
 from fastimport import (  # noqa: E402
 from fastimport import (  # noqa: E402
     commands,
     commands,
     errors as fastimport_errors,
     errors as fastimport_errors,
     parser,
     parser,
     processor,
     processor,
-    )
+)
 
 
 import stat  # noqa: E402
 import stat  # noqa: E402
 
 
@@ -59,7 +59,7 @@ class GitFastExporter(object):
 
 
     def _allocate_marker(self):
     def _allocate_marker(self):
         self._marker_idx += 1
         self._marker_idx += 1
-        return ("%d" % (self._marker_idx,)).encode('ascii')
+        return ("%d" % (self._marker_idx,)).encode("ascii")
 
 
     def _export_blob(self, blob):
     def _export_blob(self, blob):
         marker = self._allocate_marker()
         marker = self._allocate_marker()
@@ -72,9 +72,11 @@ class GitFastExporter(object):
         return marker
         return marker
 
 
     def _iter_files(self, base_tree, new_tree):
     def _iter_files(self, base_tree, new_tree):
-        for ((old_path, new_path), (old_mode, new_mode),
-             (old_hexsha, new_hexsha)) in \
-                self.store.tree_changes(base_tree, new_tree):
+        for (
+            (old_path, new_path),
+            (old_mode, new_mode),
+            (old_hexsha, new_hexsha),
+        ) in self.store.tree_changes(base_tree, new_tree):
             if new_path is None:
             if new_path is None:
                 yield commands.FileDeleteCommand(old_path)
                 yield commands.FileDeleteCommand(old_path)
                 continue
                 continue
@@ -84,7 +86,7 @@ class GitFastExporter(object):
             if old_path != new_path and old_path is not None:
             if old_path != new_path and old_path is not None:
                 yield commands.FileRenameCommand(old_path, new_path)
                 yield commands.FileRenameCommand(old_path, new_path)
             if old_mode != new_mode or old_hexsha != new_hexsha:
             if old_mode != new_mode or old_hexsha != new_hexsha:
-                prefixed_marker = b':' + marker
+                prefixed_marker = b":" + marker
                 yield commands.FileModifyCommand(
                 yield commands.FileModifyCommand(
                     new_path, new_mode, prefixed_marker, None
                     new_path, new_mode, prefixed_marker, None
                 )
                 )
@@ -101,11 +103,20 @@ class GitFastExporter(object):
         author, author_email = split_email(commit.author)
         author, author_email = split_email(commit.author)
         committer, committer_email = split_email(commit.committer)
         committer, committer_email = split_email(commit.committer)
         cmd = commands.CommitCommand(
         cmd = commands.CommitCommand(
-            ref, marker,
+            ref,
+            marker,
             (author, author_email, commit.author_time, commit.author_timezone),
             (author, author_email, commit.author_time, commit.author_timezone),
-            (committer, committer_email, commit.commit_time,
-                commit.commit_timezone),
-            commit.message, from_, merges, file_cmds)
+            (
+                committer,
+                committer_email,
+                commit.commit_time,
+                commit.commit_timezone,
+            ),
+            commit.message,
+            from_,
+            merges,
+            file_cmds,
+        )
         return (cmd, marker)
         return (cmd, marker)
 
 
     def emit_commit(self, commit, ref, base_tree=None):
     def emit_commit(self, commit, ref, base_tree=None):
@@ -115,9 +126,8 @@ class GitFastExporter(object):
 
 
 
 
 class GitImportProcessor(processor.ImportProcessor):
 class GitImportProcessor(processor.ImportProcessor):
-    """An import processor that imports into a Git repository using Dulwich.
+    """An import processor that imports into a Git repository using Dulwich."""
 
 
-    """
     # FIXME: Batch creation of objects?
     # FIXME: Batch creation of objects?
 
 
     def __init__(self, repo, params=None, verbose=False, outf=None):
     def __init__(self, repo, params=None, verbose=False, outf=None):
@@ -156,8 +166,12 @@ class GitImportProcessor(processor.ImportProcessor):
         else:
         else:
             author = cmd.committer
             author = cmd.committer
         (author_name, author_email, author_timestamp, author_timezone) = author
         (author_name, author_email, author_timestamp, author_timezone) = author
-        (committer_name, committer_email, commit_timestamp,
-            commit_timezone) = cmd.committer
+        (
+            committer_name,
+            committer_email,
+            commit_timestamp,
+            commit_timezone,
+        ) = cmd.committer
         commit.author = author_name + b" <" + author_email + b">"
         commit.author = author_name + b" <" + author_email + b">"
         commit.author_timezone = author_timezone
         commit.author_timezone = author_timezone
         commit.author_time = int(author_timestamp)
         commit.author_time = int(author_timestamp)
@@ -181,11 +195,9 @@ class GitImportProcessor(processor.ImportProcessor):
             elif filecmd.name == b"filedelete":
             elif filecmd.name == b"filedelete":
                 del self._contents[filecmd.path]
                 del self._contents[filecmd.path]
             elif filecmd.name == b"filecopy":
             elif filecmd.name == b"filecopy":
-                self._contents[filecmd.dest_path] = self._contents[
-                    filecmd.src_path]
+                self._contents[filecmd.dest_path] = self._contents[filecmd.src_path]
             elif filecmd.name == b"filerename":
             elif filecmd.name == b"filerename":
-                self._contents[filecmd.new_path] = self._contents[
-                    filecmd.old_path]
+                self._contents[filecmd.new_path] = self._contents[filecmd.old_path]
                 del self._contents[filecmd.old_path]
                 del self._contents[filecmd.old_path]
             elif filecmd.name == b"filedeleteall":
             elif filecmd.name == b"filedeleteall":
                 self._contents = {}
                 self._contents = {}
@@ -193,8 +205,8 @@ class GitImportProcessor(processor.ImportProcessor):
                 raise Exception("Command %s not supported" % filecmd.name)
                 raise Exception("Command %s not supported" % filecmd.name)
         commit.tree = commit_tree(
         commit.tree = commit_tree(
             self.repo.object_store,
             self.repo.object_store,
-            ((path, hexsha, mode) for (path, (mode, hexsha)) in
-                self._contents.items()))
+            ((path, hexsha, mode) for (path, (mode, hexsha)) in self._contents.items()),
+        )
         if self.last_commit != ZERO_SHA:
         if self.last_commit != ZERO_SHA:
             commit.parents.append(self.last_commit)
             commit.parents.append(self.last_commit)
         for merge in cmd.merges:
         for merge in cmd.merges:
@@ -216,8 +228,11 @@ class GitImportProcessor(processor.ImportProcessor):
         self.last_commit = commit_id
         self.last_commit = commit_id
         if commit_id != ZERO_SHA:
         if commit_id != ZERO_SHA:
             tree_id = self.repo[commit_id].tree
             tree_id = self.repo[commit_id].tree
-            for (path, mode, hexsha) in (
-                    self.repo.object_store.iter_tree_contents(tree_id)):
+            for (
+                path,
+                mode,
+                hexsha,
+            ) in self.repo.object_store.iter_tree_contents(tree_id):
                 self._contents[path] = (mode, hexsha)
                 self._contents[path] = (mode, hexsha)
 
 
     def reset_handler(self, cmd):
     def reset_handler(self, cmd):

+ 41 - 20
dulwich/file.py

@@ -44,6 +44,7 @@ def _fancy_rename(oldname, newname):
 
 
     # Defer the tempfile import since it pulls in a lot of other things.
     # Defer the tempfile import since it pulls in a lot of other things.
     import tempfile
     import tempfile
+
     # destination file exists
     # destination file exists
     try:
     try:
         (fd, tmpfile) = tempfile.mkstemp(".tmp", prefix=oldname, dir=".")
         (fd, tmpfile) = tempfile.mkstemp(".tmp", prefix=oldname, dir=".")
@@ -56,7 +57,7 @@ def _fancy_rename(oldname, newname):
     try:
     try:
         os.rename(newname, tmpfile)
         os.rename(newname, tmpfile)
     except OSError:
     except OSError:
-        raise   # no rename occurred
+        raise  # no rename occurred
     try:
     try:
         os.rename(oldname, newname)
         os.rename(oldname, newname)
     except OSError:
     except OSError:
@@ -65,7 +66,7 @@ def _fancy_rename(oldname, newname):
     os.remove(tmpfile)
     os.remove(tmpfile)
 
 
 
 
-def GitFile(filename, mode='rb', bufsize=-1):
+def GitFile(filename, mode="rb", bufsize=-1):
     """Create a file object that obeys the git file locking protocol.
     """Create a file object that obeys the git file locking protocol.
 
 
     Returns: a builtin file object or a _GitFile object
     Returns: a builtin file object or a _GitFile object
@@ -77,13 +78,13 @@ def GitFile(filename, mode='rb', bufsize=-1):
     the fact that opening a file for write does not actually open the file you
     the fact that opening a file for write does not actually open the file you
     request.
     request.
     """
     """
-    if 'a' in mode:
-        raise IOError('append mode not supported for Git files')
-    if '+' in mode:
-        raise IOError('read/write mode not supported for Git files')
-    if 'b' not in mode:
-        raise IOError('text mode not supported for Git files')
-    if 'w' in mode:
+    if "a" in mode:
+        raise IOError("append mode not supported for Git files")
+    if "+" in mode:
+        raise IOError("read/write mode not supported for Git files")
+    if "b" not in mode:
+        raise IOError("text mode not supported for Git files")
+    if "w" in mode:
         return _GitFile(filename, mode, bufsize)
         return _GitFile(filename, mode, bufsize)
     else:
     else:
         return io.open(filename, mode, bufsize)
         return io.open(filename, mode, bufsize)
@@ -109,23 +110,43 @@ class _GitFile(object):
         released. Typically this will happen in a finally block.
         released. Typically this will happen in a finally block.
     """
     """
 
 
-    PROXY_PROPERTIES = set(['closed', 'encoding', 'errors', 'mode', 'name',
-                            'newlines', 'softspace'])
-    PROXY_METHODS = ('__iter__', 'flush', 'fileno', 'isatty', 'read',
-                     'readline', 'readlines', 'seek', 'tell',
-                     'truncate', 'write', 'writelines')
+    PROXY_PROPERTIES = set(
+        [
+            "closed",
+            "encoding",
+            "errors",
+            "mode",
+            "name",
+            "newlines",
+            "softspace",
+        ]
+    )
+    PROXY_METHODS = (
+        "__iter__",
+        "flush",
+        "fileno",
+        "isatty",
+        "read",
+        "readline",
+        "readlines",
+        "seek",
+        "tell",
+        "truncate",
+        "write",
+        "writelines",
+    )
 
 
     def __init__(self, filename, mode, bufsize):
     def __init__(self, filename, mode, bufsize):
         self._filename = filename
         self._filename = filename
         if isinstance(self._filename, bytes):
         if isinstance(self._filename, bytes):
-            self._lockfilename = self._filename + b'.lock'
+            self._lockfilename = self._filename + b".lock"
         else:
         else:
-            self._lockfilename = self._filename + '.lock'
+            self._lockfilename = self._filename + ".lock"
         try:
         try:
             fd = os.open(
             fd = os.open(
                 self._lockfilename,
                 self._lockfilename,
-                os.O_RDWR | os.O_CREAT | os.O_EXCL |
-                getattr(os, "O_BINARY", 0))
+                os.O_RDWR | os.O_CREAT | os.O_EXCL | getattr(os, "O_BINARY", 0),
+            )
         except FileExistsError:
         except FileExistsError:
             raise FileLocked(filename, self._lockfilename)
             raise FileLocked(filename, self._lockfilename)
         self._file = os.fdopen(fd, mode, bufsize)
         self._file = os.fdopen(fd, mode, bufsize)
@@ -166,10 +187,10 @@ class _GitFile(object):
         os.fsync(self._file.fileno())
         os.fsync(self._file.fileno())
         self._file.close()
         self._file.close()
         try:
         try:
-            if getattr(os, 'replace', None) is not None:
+            if getattr(os, "replace", None) is not None:
                 os.replace(self._lockfilename, self._filename)
                 os.replace(self._lockfilename, self._filename)
             else:
             else:
-                if sys.platform != 'win32':
+                if sys.platform != "win32":
                     os.rename(self._lockfilename, self._filename)
                     os.rename(self._lockfilename, self._filename)
                 else:
                 else:
                     # Windows versions prior to Vista don't support atomic
                     # Windows versions prior to Vista don't support atomic

+ 2 - 2
dulwich/graph.py

@@ -33,8 +33,8 @@ def _find_lcas(lookup_parents, c1, c2s):
     # Flags to Record State
     # Flags to Record State
     _ANC_OF_1 = 1  # ancestor of commit 1
     _ANC_OF_1 = 1  # ancestor of commit 1
     _ANC_OF_2 = 2  # ancestor of commit 2
     _ANC_OF_2 = 2  # ancestor of commit 2
-    _DNC = 4       # Do Not Consider
-    _LCA = 8       # potential LCA
+    _DNC = 4  # Do Not Consider
+    _LCA = 8  # potential LCA
 
 
     def _has_candidates(wlst, cstates):
     def _has_candidates(wlst, cstates):
         for cmt in wlst:
         for cmt in wlst:

+ 23 - 19
dulwich/greenthreads.py

@@ -28,16 +28,15 @@ from gevent import pool
 from dulwich.objects import (
 from dulwich.objects import (
     Commit,
     Commit,
     Tag,
     Tag,
-    )
+)
 from dulwich.object_store import (
 from dulwich.object_store import (
     MissingObjectFinder,
     MissingObjectFinder,
     _collect_filetree_revs,
     _collect_filetree_revs,
     ObjectStoreIterator,
     ObjectStoreIterator,
-    )
+)
 
 
 
 
-def _split_commits_and_tags(obj_store, lst,
-                            ignore_unknown=False, pool=None):
+def _split_commits_and_tags(obj_store, lst, ignore_unknown=False, pool=None):
     """Split object id list into two list with commit SHA1s and tag SHA1s.
     """Split object id list into two list with commit SHA1s and tag SHA1s.
 
 
     Same implementation as object_store._split_commits_and_tags
     Same implementation as object_store._split_commits_and_tags
@@ -59,7 +58,8 @@ def _split_commits_and_tags(obj_store, lst,
                 tags.add(sha)
                 tags.add(sha)
                 commits.add(o.object[1])
                 commits.add(o.object[1])
             else:
             else:
-                raise KeyError('Not a commit or a tag: %s' % sha)
+                raise KeyError("Not a commit or a tag: %s" % sha)
+
     jobs = [pool.spawn(find_commit_type, s) for s in lst]
     jobs = [pool.spawn(find_commit_type, s) for s in lst]
     gevent.joinall(jobs)
     gevent.joinall(jobs)
     return (commits, tags)
     return (commits, tags)
@@ -71,10 +71,17 @@ class GreenThreadsMissingObjectFinder(MissingObjectFinder):
     Same implementation as object_store.MissingObjectFinder
     Same implementation as object_store.MissingObjectFinder
     except we use gevent to parallelize object retrieval.
     except we use gevent to parallelize object retrieval.
     """
     """
-    def __init__(self, object_store, haves, wants,
-                 progress=None, get_tagged=None,
-                 concurrency=1, get_parents=None):
 
 
+    def __init__(
+        self,
+        object_store,
+        haves,
+        wants,
+        progress=None,
+        get_tagged=None,
+        concurrency=1,
+        get_parents=None,
+    ):
         def collect_tree_sha(sha):
         def collect_tree_sha(sha):
             self.sha_done.add(sha)
             self.sha_done.add(sha)
             cmt = object_store[sha]
             cmt = object_store[sha]
@@ -83,15 +90,12 @@ class GreenThreadsMissingObjectFinder(MissingObjectFinder):
         self.object_store = object_store
         self.object_store = object_store
         p = pool.Pool(size=concurrency)
         p = pool.Pool(size=concurrency)
 
 
-        have_commits, have_tags = \
-            _split_commits_and_tags(object_store, haves,
-                                    True, p)
-        want_commits, want_tags = \
-            _split_commits_and_tags(object_store, wants,
-                                    False, p)
+        have_commits, have_tags = _split_commits_and_tags(object_store, haves, True, p)
+        want_commits, want_tags = _split_commits_and_tags(object_store, wants, False, p)
         all_ancestors = object_store._collect_ancestors(have_commits)[0]
         all_ancestors = object_store._collect_ancestors(have_commits)[0]
-        missing_commits, common_commits = \
-            object_store._collect_ancestors(want_commits, all_ancestors)
+        missing_commits, common_commits = object_store._collect_ancestors(
+            want_commits, all_ancestors
+        )
 
 
         self.sha_done = set()
         self.sha_done = set()
         jobs = [p.spawn(collect_tree_sha, c) for c in common_commits]
         jobs = [p.spawn(collect_tree_sha, c) for c in common_commits]
@@ -114,6 +118,7 @@ class GreenThreadsObjectStoreIterator(ObjectStoreIterator):
     Same implementation as object_store.ObjectStoreIterator
     Same implementation as object_store.ObjectStoreIterator
     except we use gevent to parallelize object retrieval.
     except we use gevent to parallelize object retrieval.
     """
     """
+
     def __init__(self, store, shas, finder, concurrency=1):
     def __init__(self, store, shas, finder, concurrency=1):
         self.finder = finder
         self.finder = finder
         self.p = pool.Pool(size=concurrency)
         self.p = pool.Pool(size=concurrency)
@@ -124,14 +129,13 @@ class GreenThreadsObjectStoreIterator(ObjectStoreIterator):
         return self.store[sha], path
         return self.store[sha], path
 
 
     def __iter__(self):
     def __iter__(self):
-        for sha, path in self.p.imap_unordered(self.retrieve,
-                                               self.itershas()):
+        for sha, path in self.p.imap_unordered(self.retrieve, self.itershas()):
             yield sha, path
             yield sha, path
 
 
     def __len__(self):
     def __len__(self):
         if len(self._shas) > 0:
         if len(self._shas) > 0:
             return len(self._shas)
             return len(self._shas)
-        while len(self.finder.objects_to_send):
+        while self.finder.objects_to_send:
             jobs = []
             jobs = []
             for _ in range(0, len(self.finder.objects_to_send)):
             for _ in range(0, len(self.finder.objects_to_send)):
                 jobs.append(self.p.spawn(self.finder.next))
                 jobs.append(self.p.spawn(self.finder.next))

+ 38 - 28
dulwich/hooks.py

@@ -52,9 +52,15 @@ class ShellHook(Hook):
     [0] http://www.kernel.org/pub/software/scm/git/docs/githooks.html
     [0] http://www.kernel.org/pub/software/scm/git/docs/githooks.html
     """
     """
 
 
-    def __init__(self, name, path, numparam,
-                 pre_exec_callback=None, post_exec_callback=None,
-                 cwd=None):
+    def __init__(
+        self,
+        name,
+        path,
+        numparam,
+        pre_exec_callback=None,
+        post_exec_callback=None,
+        cwd=None,
+    ):
         """Setup shell hook definition
         """Setup shell hook definition
 
 
         Args:
         Args:
@@ -84,24 +90,27 @@ class ShellHook(Hook):
         """Execute the hook with given args"""
         """Execute the hook with given args"""
 
 
         if len(args) != self.numparam:
         if len(args) != self.numparam:
-            raise HookError("Hook %s executed with wrong number of args. \
+            raise HookError(
+                "Hook %s executed with wrong number of args. \
                             Expected %d. Saw %d. args: %s"
                             Expected %d. Saw %d. args: %s"
-                            % (self.name, self.numparam, len(args), args))
+                % (self.name, self.numparam, len(args), args)
+            )
 
 
-        if (self.pre_exec_callback is not None):
+        if self.pre_exec_callback is not None:
             args = self.pre_exec_callback(*args)
             args = self.pre_exec_callback(*args)
 
 
         try:
         try:
             ret = subprocess.call([self.filepath] + list(args), cwd=self.cwd)
             ret = subprocess.call([self.filepath] + list(args), cwd=self.cwd)
             if ret != 0:
             if ret != 0:
-                if (self.post_exec_callback is not None):
+                if self.post_exec_callback is not None:
                     self.post_exec_callback(0, *args)
                     self.post_exec_callback(0, *args)
-                raise HookError("Hook %s exited with non-zero status %d"
-                                % (self.name, ret))
-            if (self.post_exec_callback is not None):
+                raise HookError(
+                    "Hook %s exited with non-zero status %d" % (self.name, ret)
+                )
+            if self.post_exec_callback is not None:
                 return self.post_exec_callback(1, *args)
                 return self.post_exec_callback(1, *args)
         except OSError:  # no file. silent failure.
         except OSError:  # no file. silent failure.
-            if (self.post_exec_callback is not None):
+            if self.post_exec_callback is not None:
                 self.post_exec_callback(0, *args)
                 self.post_exec_callback(0, *args)
 
 
 
 
@@ -109,18 +118,18 @@ class PreCommitShellHook(ShellHook):
     """pre-commit shell hook"""
     """pre-commit shell hook"""
 
 
     def __init__(self, controldir):
     def __init__(self, controldir):
-        filepath = os.path.join(controldir, 'hooks', 'pre-commit')
+        filepath = os.path.join(controldir, "hooks", "pre-commit")
 
 
-        ShellHook.__init__(self, 'pre-commit', filepath, 0, cwd=controldir)
+        ShellHook.__init__(self, "pre-commit", filepath, 0, cwd=controldir)
 
 
 
 
 class PostCommitShellHook(ShellHook):
 class PostCommitShellHook(ShellHook):
     """post-commit shell hook"""
     """post-commit shell hook"""
 
 
     def __init__(self, controldir):
     def __init__(self, controldir):
-        filepath = os.path.join(controldir, 'hooks', 'post-commit')
+        filepath = os.path.join(controldir, "hooks", "post-commit")
 
 
-        ShellHook.__init__(self, 'post-commit', filepath, 0, cwd=controldir)
+        ShellHook.__init__(self, "post-commit", filepath, 0, cwd=controldir)
 
 
 
 
 class CommitMsgShellHook(ShellHook):
 class CommitMsgShellHook(ShellHook):
@@ -133,27 +142,29 @@ class CommitMsgShellHook(ShellHook):
     """
     """
 
 
     def __init__(self, controldir):
     def __init__(self, controldir):
-        filepath = os.path.join(controldir, 'hooks', 'commit-msg')
+        filepath = os.path.join(controldir, "hooks", "commit-msg")
 
 
         def prepare_msg(*args):
         def prepare_msg(*args):
             import tempfile
             import tempfile
+
             (fd, path) = tempfile.mkstemp()
             (fd, path) = tempfile.mkstemp()
 
 
-            with os.fdopen(fd, 'wb') as f:
+            with os.fdopen(fd, "wb") as f:
                 f.write(args[0])
                 f.write(args[0])
 
 
             return (path,)
             return (path,)
 
 
         def clean_msg(success, *args):
         def clean_msg(success, *args):
             if success:
             if success:
-                with open(args[0], 'rb') as f:
+                with open(args[0], "rb") as f:
                     new_msg = f.read()
                     new_msg = f.read()
                 os.unlink(args[0])
                 os.unlink(args[0])
                 return new_msg
                 return new_msg
             os.unlink(args[0])
             os.unlink(args[0])
 
 
-        ShellHook.__init__(self, 'commit-msg', filepath, 1,
-                           prepare_msg, clean_msg, controldir)
+        ShellHook.__init__(
+            self, "commit-msg", filepath, 1, prepare_msg, clean_msg, controldir
+        )
 
 
 
 
 class PostReceiveShellHook(ShellHook):
 class PostReceiveShellHook(ShellHook):
@@ -161,8 +172,8 @@ class PostReceiveShellHook(ShellHook):
 
 
     def __init__(self, controldir):
     def __init__(self, controldir):
         self.controldir = controldir
         self.controldir = controldir
-        filepath = os.path.join(controldir, 'hooks', 'post-receive')
-        ShellHook.__init__(self, 'post-receive', filepath, 0)
+        filepath = os.path.join(controldir, "hooks", "post-receive")
+        ShellHook.__init__(self, "post-receive", filepath, 0)
 
 
     def execute(self, client_refs):
     def execute(self, client_refs):
         # do nothing if the script doesn't exist
         # do nothing if the script doesn't exist
@@ -171,26 +182,25 @@ class PostReceiveShellHook(ShellHook):
 
 
         try:
         try:
             env = os.environ.copy()
             env = os.environ.copy()
-            env['GIT_DIR'] = self.controldir
+            env["GIT_DIR"] = self.controldir
 
 
             p = subprocess.Popen(
             p = subprocess.Popen(
                 self.filepath,
                 self.filepath,
                 stdin=subprocess.PIPE,
                 stdin=subprocess.PIPE,
                 stdout=subprocess.PIPE,
                 stdout=subprocess.PIPE,
                 stderr=subprocess.PIPE,
                 stderr=subprocess.PIPE,
-                env=env
+                env=env,
             )
             )
 
 
             # client_refs is a list of (oldsha, newsha, ref)
             # client_refs is a list of (oldsha, newsha, ref)
-            in_data = '\n'.join([' '.join(ref) for ref in client_refs])
+            in_data = "\n".join([" ".join(ref) for ref in client_refs])
 
 
             out_data, err_data = p.communicate(in_data)
             out_data, err_data = p.communicate(in_data)
 
 
             if (p.returncode != 0) or err_data:
             if (p.returncode != 0) or err_data:
-                err_fmt = "post-receive exit code: %d\n" \
-                    + "stdout:\n%s\nstderr:\n%s"
+                err_fmt = "post-receive exit code: %d\n" + "stdout:\n%s\nstderr:\n%s"
                 err_msg = err_fmt % (p.returncode, out_data, err_data)
                 err_msg = err_fmt % (p.returncode, out_data, err_data)
-                raise HookError(err_msg)
+                raise HookError(err_msg.decode('utf-8', 'backslashreplace'))
             return out_data
             return out_data
         except OSError as err:
         except OSError as err:
             raise HookError(repr(err))
             raise HookError(repr(err))

+ 86 - 85
dulwich/ignore.py

@@ -32,7 +32,7 @@ from typing import (
     TYPE_CHECKING,
     TYPE_CHECKING,
     Dict,
     Dict,
     Union,
     Union,
-    )
+)
 
 
 if TYPE_CHECKING:
 if TYPE_CHECKING:
     from dulwich.repo import Repo
     from dulwich.repo import Repo
@@ -42,34 +42,34 @@ from dulwich.config import get_xdg_config_home_path, Config
 
 
 def _translate_segment(segment: bytes) -> bytes:
 def _translate_segment(segment: bytes) -> bytes:
     if segment == b"*":
     if segment == b"*":
-        return b'[^/]+'
+        return b"[^/]+"
     res = b""
     res = b""
     i, n = 0, len(segment)
     i, n = 0, len(segment)
     while i < n:
     while i < n:
-        c = segment[i:i+1]
-        i = i+1
-        if c == b'*':
-            res += b'[^/]*'
-        elif c == b'?':
-            res += b'[^/]'
-        elif c == b'[':
+        c = segment[i : i + 1]
+        i = i + 1
+        if c == b"*":
+            res += b"[^/]*"
+        elif c == b"?":
+            res += b"[^/]"
+        elif c == b"[":
             j = i
             j = i
-            if j < n and segment[j:j+1] == b'!':
-                j = j+1
-            if j < n and segment[j:j+1] == b']':
-                j = j+1
-            while j < n and segment[j:j+1] != b']':
-                j = j+1
+            if j < n and segment[j : j + 1] == b"!":
+                j = j + 1
+            if j < n and segment[j : j + 1] == b"]":
+                j = j + 1
+            while j < n and segment[j : j + 1] != b"]":
+                j = j + 1
             if j >= n:
             if j >= n:
-                res += b'\\['
+                res += b"\\["
             else:
             else:
-                stuff = segment[i:j].replace(b'\\', b'\\\\')
-                i = j+1
-                if stuff.startswith(b'!'):
-                    stuff = b'^' + stuff[1:]
-                elif stuff.startswith(b'^'):
-                    stuff = b'\\' + stuff
-                res += b'[' + stuff + b']'
+                stuff = segment[i:j].replace(b"\\", b"\\\\")
+                i = j + 1
+                if stuff.startswith(b"!"):
+                    stuff = b"^" + stuff[1:]
+                elif stuff.startswith(b"^"):
+                    stuff = b"\\" + stuff
+                res += b"[" + stuff + b"]"
         else:
         else:
             res += re.escape(c)
             res += re.escape(c)
     return res
     return res
@@ -84,32 +84,31 @@ def translate(pat: bytes) -> bytes:
     to cope with features in Git ignore patterns.
     to cope with features in Git ignore patterns.
     """
     """
 
 
-    res = b'(?ms)'
+    res = b"(?ms)"
 
 
-    if b'/' not in pat[:-1]:
+    if b"/" not in pat[:-1]:
         # If there's no slash, this is a filename-based match
         # If there's no slash, this is a filename-based match
-        res += b'(.*/)?'
+        res += b"(.*/)?"
 
 
-    if pat.startswith(b'**/'):
+    if pat.startswith(b"**/"):
         # Leading **/
         # Leading **/
         pat = pat[2:]
         pat = pat[2:]
-        res += b'(.*/)?'
+        res += b"(.*/)?"
 
 
-    if pat.startswith(b'/'):
+    if pat.startswith(b"/"):
         pat = pat[1:]
         pat = pat[1:]
 
 
-    for i, segment in enumerate(pat.split(b'/')):
-        if segment == b'**':
-            res += b'(/.*)?'
+    for i, segment in enumerate(pat.split(b"/")):
+        if segment == b"**":
+            res += b"(/.*)?"
             continue
             continue
         else:
         else:
-            res += ((re.escape(b'/') if i > 0 else b'') +
-                    _translate_segment(segment))
+            res += (re.escape(b"/") if i > 0 else b"") + _translate_segment(segment)
 
 
-    if not pat.endswith(b'/'):
-        res += b'/?'
+    if not pat.endswith(b"/"):
+        res += b"/?"
 
 
-    return res + b'\\Z'
+    return res + b"\\Z"
 
 
 
 
 def read_ignore_patterns(f: BinaryIO) -> Iterable[bytes]:
 def read_ignore_patterns(f: BinaryIO) -> Iterable[bytes]:
@@ -127,20 +126,19 @@ def read_ignore_patterns(f: BinaryIO) -> Iterable[bytes]:
         if not line:
         if not line:
             continue
             continue
 
 
-        if line.startswith(b'#'):
+        if line.startswith(b"#"):
             # Comment
             # Comment
             continue
             continue
 
 
         # Trailing spaces are ignored unless they are quoted with a backslash.
         # Trailing spaces are ignored unless they are quoted with a backslash.
-        while line.endswith(b' ') and not line.endswith(b'\\ '):
+        while line.endswith(b" ") and not line.endswith(b"\\ "):
             line = line[:-1]
             line = line[:-1]
-        line = line.replace(b'\\ ', b' ')
+        line = line.replace(b"\\ ", b" ")
 
 
         yield line
         yield line
 
 
 
 
-def match_pattern(
-        path: bytes, pattern: bytes, ignorecase: bool = False) -> bool:
+def match_pattern(path: bytes, pattern: bytes, ignorecase: bool = False) -> bool:
     """Match a gitignore-style pattern against a path.
     """Match a gitignore-style pattern against a path.
 
 
     Args:
     Args:
@@ -159,11 +157,11 @@ class Pattern(object):
     def __init__(self, pattern: bytes, ignorecase: bool = False):
     def __init__(self, pattern: bytes, ignorecase: bool = False):
         self.pattern = pattern
         self.pattern = pattern
         self.ignorecase = ignorecase
         self.ignorecase = ignorecase
-        if pattern[0:1] == b'!':
+        if pattern[0:1] == b"!":
             self.is_exclude = False
             self.is_exclude = False
             pattern = pattern[1:]
             pattern = pattern[1:]
         else:
         else:
-            if pattern[0:1] == b'\\':
+            if pattern[0:1] == b"\\":
                 pattern = pattern[1:]
                 pattern = pattern[1:]
             self.is_exclude = True
             self.is_exclude = True
         flags = 0
         flags = 0
@@ -178,13 +176,18 @@ class Pattern(object):
         return os.fsdecode(self.pattern)
         return os.fsdecode(self.pattern)
 
 
     def __eq__(self, other: object) -> bool:
     def __eq__(self, other: object) -> bool:
-        return (isinstance(other, type(self)) and
-                self.pattern == other.pattern and
-                self.ignorecase == other.ignorecase)
+        return (
+            isinstance(other, type(self))
+            and self.pattern == other.pattern
+            and self.ignorecase == other.ignorecase
+        )
 
 
     def __repr__(self) -> str:
     def __repr__(self) -> str:
         return "%s(%r, %r)" % (
         return "%s(%r, %r)" % (
-            type(self).__name__, self.pattern, self.ignorecase)
+            type(self).__name__,
+            self.pattern,
+            self.ignorecase,
+        )
 
 
     def match(self, path: bytes) -> bool:
     def match(self, path: bytes) -> bool:
         """Try to match a path against this ignore pattern.
         """Try to match a path against this ignore pattern.
@@ -197,9 +200,7 @@ class Pattern(object):
 
 
 
 
 class IgnoreFilter(object):
 class IgnoreFilter(object):
-
-    def __init__(self, patterns: Iterable[bytes], ignorecase: bool = False,
-                 path=None):
+    def __init__(self, patterns: Iterable[bytes], ignorecase: bool = False, path=None):
         self._patterns = []  # type: List[Pattern]
         self._patterns = []  # type: List[Pattern]
         self._ignorecase = ignorecase
         self._ignorecase = ignorecase
         self._path = path
         self._path = path
@@ -238,15 +239,14 @@ class IgnoreFilter(object):
         return status
         return status
 
 
     @classmethod
     @classmethod
-    def from_path(cls, path, ignorecase: bool = False) -> 'IgnoreFilter':
-        with open(path, 'rb') as f:
+    def from_path(cls, path, ignorecase: bool = False) -> "IgnoreFilter":
+        with open(path, "rb") as f:
             return cls(read_ignore_patterns(f), ignorecase, path=path)
             return cls(read_ignore_patterns(f), ignorecase, path=path)
 
 
     def __repr__(self) -> str:
     def __repr__(self) -> str:
-        path = getattr(self, '_path', None)
+        path = getattr(self, "_path", None)
         if path is not None:
         if path is not None:
-            return "%s.from_path(%r)" % (
-                type(self).__name__, path)
+            return "%s.from_path(%r)" % (type(self).__name__, path)
         else:
         else:
             return "<%s>" % (type(self).__name__)
             return "<%s>" % (type(self).__name__)
 
 
@@ -283,19 +283,22 @@ def default_user_ignore_filter_path(config: Config) -> str:
       Path to a global ignore file
       Path to a global ignore file
     """
     """
     try:
     try:
-        return config.get((b'core', ), b'excludesFile')
+        return config.get((b"core",), b"excludesFile")
     except KeyError:
     except KeyError:
         pass
         pass
 
 
-    return get_xdg_config_home_path('git', 'ignore')
+    return get_xdg_config_home_path("git", "ignore")
 
 
 
 
 class IgnoreFilterManager(object):
 class IgnoreFilterManager(object):
     """Ignore file manager."""
     """Ignore file manager."""
 
 
     def __init__(
     def __init__(
-            self, top_path: str, global_filters: List[IgnoreFilter],
-            ignorecase: bool):
+        self,
+        top_path: str,
+        global_filters: List[IgnoreFilter],
+        ignorecase: bool,
+    ):
         self._path_filters = {}  # type: Dict[str, Optional[IgnoreFilter]]
         self._path_filters = {}  # type: Dict[str, Optional[IgnoreFilter]]
         self._top_path = top_path
         self._top_path = top_path
         self._global_filters = global_filters
         self._global_filters = global_filters
@@ -303,9 +306,11 @@ class IgnoreFilterManager(object):
 
 
     def __repr__(self) -> str:
     def __repr__(self) -> str:
         return "%s(%s, %r, %r)" % (
         return "%s(%s, %r, %r)" % (
-            type(self).__name__, self._top_path,
+            type(self).__name__,
+            self._top_path,
             self._global_filters,
             self._global_filters,
-            self._ignorecase)
+            self._ignorecase,
+        )
 
 
     def _load_path(self, path: str) -> Optional[IgnoreFilter]:
     def _load_path(self, path: str) -> Optional[IgnoreFilter]:
         try:
         try:
@@ -313,10 +318,9 @@ class IgnoreFilterManager(object):
         except KeyError:
         except KeyError:
             pass
             pass
 
 
-        p = os.path.join(self._top_path, path, '.gitignore')
+        p = os.path.join(self._top_path, path, ".gitignore")
         try:
         try:
-            self._path_filters[path] = IgnoreFilter.from_path(
-                p, self._ignorecase)
+            self._path_filters[path] = IgnoreFilter.from_path(p, self._ignorecase)
         except IOError:
         except IOError:
             self._path_filters[path] = None
             self._path_filters[path] = None
         return self._path_filters[path]
         return self._path_filters[path]
@@ -324,34 +328,31 @@ class IgnoreFilterManager(object):
     def find_matching(self, path: str) -> Iterable[Pattern]:
     def find_matching(self, path: str) -> Iterable[Pattern]:
         """Find matching patterns for path.
         """Find matching patterns for path.
 
 
-        Stops after the first ignore file with matches.
-
         Args:
         Args:
           path: Path to check
           path: Path to check
         Returns:
         Returns:
           Iterator over Pattern instances
           Iterator over Pattern instances
         """
         """
         if os.path.isabs(path):
         if os.path.isabs(path):
-            raise ValueError('%s is an absolute path' % path)
+            raise ValueError("%s is an absolute path" % path)
         filters = [(0, f) for f in self._global_filters]
         filters = [(0, f) for f in self._global_filters]
-        if os.path.sep != '/':
-            path = path.replace(os.path.sep, '/')
-        parts = path.split('/')
-        for i in range(len(parts)+1):
-            dirname = '/'.join(parts[:i])
+        if os.path.sep != "/":
+            path = path.replace(os.path.sep, "/")
+        parts = path.split("/")
+        matches = []
+        for i in range(len(parts) + 1):
+            dirname = "/".join(parts[:i])
             for s, f in filters:
             for s, f in filters:
-                relpath = '/'.join(parts[s:i])
+                relpath = "/".join(parts[s:i])
                 if i < len(parts):
                 if i < len(parts):
                     # Paths leading up to the final part are all directories,
                     # Paths leading up to the final part are all directories,
                     # so need a trailing slash.
                     # so need a trailing slash.
-                    relpath += '/'
-                matches = list(f.find_matching(relpath))
-                if matches:
-                    return iter(matches)
+                    relpath += "/"
+                matches += list(f.find_matching(relpath))
             ignore_filter = self._load_path(dirname)
             ignore_filter = self._load_path(dirname)
             if ignore_filter is not None:
             if ignore_filter is not None:
                 filters.insert(0, (i, ignore_filter))
                 filters.insert(0, (i, ignore_filter))
-        return iter([])
+        return iter(matches)
 
 
     def is_ignored(self, path: str) -> Optional[bool]:
     def is_ignored(self, path: str) -> Optional[bool]:
         """Check whether a path is explicitly included or excluded in ignores.
         """Check whether a path is explicitly included or excluded in ignores.
@@ -368,7 +369,7 @@ class IgnoreFilterManager(object):
         return None
         return None
 
 
     @classmethod
     @classmethod
-    def from_repo(cls, repo: 'Repo') -> 'IgnoreFilterManager':
+    def from_repo(cls, repo: "Repo") -> "IgnoreFilterManager":
         """Create a IgnoreFilterManager from a repository.
         """Create a IgnoreFilterManager from a repository.
 
 
         Args:
         Args:
@@ -378,13 +379,13 @@ class IgnoreFilterManager(object):
         """
         """
         global_filters = []
         global_filters = []
         for p in [
         for p in [
-                os.path.join(repo.controldir(), 'info', 'exclude'),
-                default_user_ignore_filter_path(repo.get_config_stack())]:
+            os.path.join(repo.controldir(), "info", "exclude"),
+            default_user_ignore_filter_path(repo.get_config_stack()),
+        ]:
             try:
             try:
-                global_filters.append(
-                    IgnoreFilter.from_path(os.path.expanduser(p)))
+                global_filters.append(IgnoreFilter.from_path(os.path.expanduser(p)))
             except IOError:
             except IOError:
                 pass
                 pass
         config = repo.get_config_stack()
         config = repo.get_config_stack()
-        ignorecase = config.get_boolean((b'core'), (b'ignorecase'), False)
+        ignorecase = config.get_boolean((b"core"), (b"ignorecase"), False)
         return cls(repo.path, global_filters, ignorecase)
         return cls(repo.path, global_filters, ignorecase)

+ 216 - 115
dulwich/index.py

@@ -36,7 +36,7 @@ from typing import (
     Iterable,
     Iterable,
     Iterator,
     Iterator,
     Tuple,
     Tuple,
-    )
+)
 
 
 if TYPE_CHECKING:
 if TYPE_CHECKING:
     from dulwich.object_store import BaseObjectStore
     from dulwich.object_store import BaseObjectStore
@@ -49,24 +49,49 @@ from dulwich.objects import (
     Tree,
     Tree,
     hex_to_sha,
     hex_to_sha,
     sha_to_hex,
     sha_to_hex,
-    )
+)
 from dulwich.pack import (
 from dulwich.pack import (
     SHA1Reader,
     SHA1Reader,
     SHA1Writer,
     SHA1Writer,
-    )
+)
 
 
 
 
+# TODO(jelmer): Switch to dataclass?
 IndexEntry = collections.namedtuple(
 IndexEntry = collections.namedtuple(
-    'IndexEntry', [
-        'ctime', 'mtime', 'dev', 'ino', 'mode', 'uid', 'gid', 'size', 'sha',
-        'flags'])
-
-
+    "IndexEntry",
+    [
+        "ctime",
+        "mtime",
+        "dev",
+        "ino",
+        "mode",
+        "uid",
+        "gid",
+        "size",
+        "sha",
+        "flags",
+        "extended_flags",
+    ],
+)
+
+
+# 2-bit stage (during merge)
 FLAG_STAGEMASK = 0x3000
 FLAG_STAGEMASK = 0x3000
+
+# assume-valid
 FLAG_VALID = 0x8000
 FLAG_VALID = 0x8000
+
+# extended flag (must be zero in version 2)
 FLAG_EXTENDED = 0x4000
 FLAG_EXTENDED = 0x4000
 
 
 
 
+# used by sparse checkout
+EXTENDED_FLAG_SKIP_WORKTREE = 0x4000
+
+# used by "git add -N"
+EXTENDED_FLAG_INTEND_TO_ADD = 0x2000
+
+
 DEFAULT_VERSION = 2
 DEFAULT_VERSION = 2
 
 
 
 
@@ -87,9 +112,7 @@ def pathsplit(path):
 
 
 
 
 def pathjoin(*args):
 def pathjoin(*args):
-    """Join a /-delimited path.
-
-    """
+    """Join a /-delimited path."""
     return b"/".join([p for p in args if p])
     return b"/".join([p for p in args if p])
 
 
 
 
@@ -121,57 +144,101 @@ def write_cache_time(f, t):
     f.write(struct.pack(">LL", *t))
     f.write(struct.pack(">LL", *t))
 
 
 
 
-def read_cache_entry(f):
+def read_cache_entry(f, version: int) -> Tuple[str, IndexEntry]:
     """Read an entry from a cache file.
     """Read an entry from a cache file.
 
 
     Args:
     Args:
       f: File-like object to read from
       f: File-like object to read from
     Returns:
     Returns:
-      tuple with: device, inode, mode, uid, gid, size, sha, flags
+      tuple with: name, IndexEntry
     """
     """
     beginoffset = f.tell()
     beginoffset = f.tell()
     ctime = read_cache_time(f)
     ctime = read_cache_time(f)
     mtime = read_cache_time(f)
     mtime = read_cache_time(f)
-    (dev, ino, mode, uid, gid, size, sha, flags, ) = \
-        struct.unpack(">LLLLLL20sH", f.read(20 + 4 * 6 + 2))
-    name = f.read((flags & 0x0fff))
+    (
+        dev,
+        ino,
+        mode,
+        uid,
+        gid,
+        size,
+        sha,
+        flags,
+    ) = struct.unpack(">LLLLLL20sH", f.read(20 + 4 * 6 + 2))
+    if flags & FLAG_EXTENDED:
+        if version < 3:
+            raise AssertionError(
+                'extended flag set in index with version < 3')
+        (extended_flags, ) = struct.unpack(">H", f.read(2))
+    else:
+        extended_flags = 0
+    name = f.read((flags & 0x0FFF))
     # Padding:
     # Padding:
-    real_size = ((f.tell() - beginoffset + 8) & ~7)
-    f.read((beginoffset + real_size) - f.tell())
-    return (name, ctime, mtime, dev, ino, mode, uid, gid, size,
-            sha_to_hex(sha), flags & ~0x0fff)
-
-
-def write_cache_entry(f, entry):
+    if version < 4:
+        real_size = (f.tell() - beginoffset + 8) & ~7
+        f.read((beginoffset + real_size) - f.tell())
+    return (
+        name,
+        IndexEntry(
+            ctime,
+            mtime,
+            dev,
+            ino,
+            mode,
+            uid,
+            gid,
+            size,
+            sha_to_hex(sha),
+            flags & ~0x0FFF,
+            extended_flags,
+        ))
+
+
+def write_cache_entry(f, name, entry, version):
     """Write an index entry to a file.
     """Write an index entry to a file.
 
 
     Args:
     Args:
       f: File object
       f: File object
-      entry: Entry to write, tuple with:
-        (name, ctime, mtime, dev, ino, mode, uid, gid, size, sha, flags)
+      entry: IndexEntry to write, tuple with:
     """
     """
     beginoffset = f.tell()
     beginoffset = f.tell()
-    (name, ctime, mtime, dev, ino, mode, uid, gid, size, sha, flags) = entry
-    write_cache_time(f, ctime)
-    write_cache_time(f, mtime)
-    flags = len(name) | (flags & ~0x0fff)
-    f.write(struct.pack(
-            b'>LLLLLL20sH', dev & 0xFFFFFFFF, ino & 0xFFFFFFFF,
-            mode, uid, gid, size, hex_to_sha(sha), flags))
+    write_cache_time(f, entry.ctime)
+    write_cache_time(f, entry.mtime)
+    flags = len(name) | (entry.flags & ~0x0FFF)
+    if entry.extended_flags:
+        flags |= FLAG_EXTENDED
+    if flags & FLAG_EXTENDED and version is not None and version < 3:
+        raise AssertionError('unable to use extended flags in version < 3')
+    f.write(
+        struct.pack(
+            b">LLLLLL20sH",
+            entry.dev & 0xFFFFFFFF,
+            entry.ino & 0xFFFFFFFF,
+            entry.mode,
+            entry.uid,
+            entry.gid,
+            entry.size,
+            hex_to_sha(entry.sha),
+            flags,
+        )
+    )
+    if flags & FLAG_EXTENDED:
+        f.write(struct.pack(b">H", entry.extended_flags))
     f.write(name)
     f.write(name)
-    real_size = ((f.tell() - beginoffset + 8) & ~7)
-    f.write(b'\0' * ((beginoffset + real_size) - f.tell()))
+    if version < 4:
+        real_size = (f.tell() - beginoffset + 8) & ~7
+        f.write(b"\0" * ((beginoffset + real_size) - f.tell()))
 
 
 
 
 def read_index(f: BinaryIO):
 def read_index(f: BinaryIO):
     """Read an index file, yielding the individual entries."""
     """Read an index file, yielding the individual entries."""
     header = f.read(4)
     header = f.read(4)
-    if header != b'DIRC':
+    if header != b"DIRC":
         raise AssertionError("Invalid index file header: %r" % header)
         raise AssertionError("Invalid index file header: %r" % header)
-    (version, num_entries) = struct.unpack(b'>LL', f.read(4 * 2))
-    assert version in (1, 2)
+    (version, num_entries) = struct.unpack(b">LL", f.read(4 * 2))
+    assert version in (1, 2, 3), "index version is %r" % version
     for i in range(num_entries):
     for i in range(num_entries):
-        yield read_cache_entry(f)
+        yield read_cache_entry(f, version)
 
 
 
 
 def read_index_dict(f):
 def read_index_dict(f):
@@ -181,14 +248,12 @@ def read_index_dict(f):
       f: File object to read from
       f: File object to read from
     """
     """
     ret = {}
     ret = {}
-    for x in read_index(f):
-        ret[x[0]] = IndexEntry(*x[1:])
+    for name, entry in read_index(f):
+        ret[name] = entry
     return ret
     return ret
 
 
 
 
-def write_index(
-        f: BinaryIO,
-        entries: List[Any], version: Optional[int] = None):
+def write_index(f: BinaryIO, entries: List[Tuple[bytes, IndexEntry]], version: Optional[int] = None):
     """Write an index file.
     """Write an index file.
 
 
     Args:
     Args:
@@ -198,21 +263,21 @@ def write_index(
     """
     """
     if version is None:
     if version is None:
         version = DEFAULT_VERSION
         version = DEFAULT_VERSION
-    f.write(b'DIRC')
-    f.write(struct.pack(b'>LL', version, len(entries)))
-    for x in entries:
-        write_cache_entry(f, x)
+    f.write(b"DIRC")
+    f.write(struct.pack(b">LL", version, len(entries)))
+    for name, entry in entries:
+        write_cache_entry(f, name, entry, version)
 
 
 
 
 def write_index_dict(
 def write_index_dict(
-        f: BinaryIO, entries: Dict[bytes, IndexEntry],
-        version: Optional[int] = None) -> None:
-    """Write an index file based on the contents of a dictionary.
-
-    """
+    f: BinaryIO,
+    entries: Dict[bytes, IndexEntry],
+    version: Optional[int] = None,
+) -> None:
+    """Write an index file based on the contents of a dictionary."""
     entries_list = []
     entries_list = []
     for name in sorted(entries):
     for name in sorted(entries):
-        entries_list.append((name,) + tuple(entries[name]))
+        entries_list.append((name, entries[name]))
     write_index(f, entries_list, version=version)
     write_index(f, entries_list, version=version)
 
 
 
 
@@ -262,7 +327,7 @@ class Index(object):
 
 
     def write(self) -> None:
     def write(self) -> None:
         """Write current contents of index to disk."""
         """Write current contents of index to disk."""
-        f = GitFile(self._filename, 'wb')
+        f = GitFile(self._filename, "wb")
         try:
         try:
             f = SHA1Writer(f)
             f = SHA1Writer(f)
             write_index_dict(f, self._byname, version=self._version)
             write_index_dict(f, self._byname, version=self._version)
@@ -273,13 +338,13 @@ class Index(object):
         """Read current contents of index from disk."""
         """Read current contents of index from disk."""
         if not os.path.exists(self._filename):
         if not os.path.exists(self._filename):
             return
             return
-        f = GitFile(self._filename, 'rb')
+        f = GitFile(self._filename, "rb")
         try:
         try:
             f = SHA1Reader(f)
             f = SHA1Reader(f)
-            for x in read_index(f):
-                self[x[0]] = IndexEntry(*x[1:])
+            for name, entry in read_index(f):
+                self[name] = entry
             # FIXME: Additional data?
             # FIXME: Additional data?
-            f.read(os.path.getsize(self._filename)-f.tell()-20)
+            f.read(os.path.getsize(self._filename) - f.tell() - 20)
             f.check_sha()
             f.check_sha()
         finally:
         finally:
             f.close()
             f.close()
@@ -316,7 +381,8 @@ class Index(object):
 
 
     def iterblobs(self):
     def iterblobs(self):
         import warnings
         import warnings
-        warnings.warn('Use iterobjects() instead.', PendingDeprecationWarning)
+
+        warnings.warn("Use iterobjects() instead.", PendingDeprecationWarning)
         return self.iterobjects()
         return self.iterobjects()
 
 
     def clear(self):
     def clear(self):
@@ -325,7 +391,7 @@ class Index(object):
 
 
     def __setitem__(self, name, x):
     def __setitem__(self, name, x):
         assert isinstance(name, bytes)
         assert isinstance(name, bytes)
-        assert len(x) == 10
+        assert len(x) == len(IndexEntry._fields)
         # Remove the old entry if any
         # Remove the old entry if any
         self._byname[name] = IndexEntry(*x)
         self._byname[name] = IndexEntry(*x)
 
 
@@ -353,12 +419,18 @@ class Index(object):
         Returns: Iterator over tuples with (oldpath, newpath), (oldmode,
         Returns: Iterator over tuples with (oldpath, newpath), (oldmode,
             newmode), (oldsha, newsha)
             newmode), (oldsha, newsha)
         """
         """
+
         def lookup_entry(path):
         def lookup_entry(path):
             entry = self[path]
             entry = self[path]
             return entry.sha, cleanup_mode(entry.mode)
             return entry.sha, cleanup_mode(entry.mode)
+
         for (name, mode, sha) in changes_from_tree(
         for (name, mode, sha) in changes_from_tree(
-                self._byname.keys(), lookup_entry, object_store, tree,
-                want_unchanged=want_unchanged):
+            self._byname.keys(),
+            lookup_entry,
+            object_store,
+            tree,
+            want_unchanged=want_unchanged,
+        ):
             yield (name, mode, sha)
             yield (name, mode, sha)
 
 
     def commit(self, object_store):
     def commit(self, object_store):
@@ -373,8 +445,8 @@ class Index(object):
 
 
 
 
 def commit_tree(
 def commit_tree(
-        object_store: 'BaseObjectStore',
-        blobs: Iterable[Tuple[bytes, bytes, int]]) -> bytes:
+    object_store: "BaseObjectStore", blobs: Iterable[Tuple[bytes, bytes, int]]
+) -> bytes:
     """Commit a new tree.
     """Commit a new tree.
 
 
     Args:
     Args:
@@ -383,7 +455,7 @@ def commit_tree(
     Returns:
     Returns:
       SHA1 of the created tree.
       SHA1 of the created tree.
     """
     """
-    trees = {b'': {}}  # type: Dict[bytes, Any]
+    trees = {b"": {}}  # type: Dict[bytes, Any]
 
 
     def add_tree(path):
     def add_tree(path):
         if path in trees:
         if path in trees:
@@ -412,10 +484,11 @@ def commit_tree(
             tree.add(basename, mode, sha)
             tree.add(basename, mode, sha)
         object_store.add_object(tree)
         object_store.add_object(tree)
         return tree.id
         return tree.id
-    return build_tree(b'')
+
+    return build_tree(b"")
 
 
 
 
-def commit_index(object_store: 'BaseObjectStore', index: Index) -> bytes:
+def commit_index(object_store: "BaseObjectStore", index: Index) -> bytes:
     """Create a new tree from an index.
     """Create a new tree from an index.
 
 
     Args:
     Args:
@@ -428,14 +501,18 @@ def commit_index(object_store: 'BaseObjectStore', index: Index) -> bytes:
 
 
 
 
 def changes_from_tree(
 def changes_from_tree(
-        names: Iterable[bytes],
-        lookup_entry: Callable[[bytes], Tuple[bytes, int]],
-        object_store: 'BaseObjectStore', tree: Optional[bytes],
-        want_unchanged=False) -> Iterable[
-            Tuple[
-                Tuple[Optional[bytes], Optional[bytes]],
-                Tuple[Optional[int], Optional[int]],
-                Tuple[Optional[bytes], Optional[bytes]]]]:
+    names: Iterable[bytes],
+    lookup_entry: Callable[[bytes], Tuple[bytes, int]],
+    object_store: "BaseObjectStore",
+    tree: Optional[bytes],
+    want_unchanged=False,
+) -> Iterable[
+    Tuple[
+        Tuple[Optional[bytes], Optional[bytes]],
+        Tuple[Optional[int], Optional[int]],
+        Tuple[Optional[bytes], Optional[bytes]],
+    ]
+]:
     """Find the differences between the contents of a tree and
     """Find the differences between the contents of a tree and
     a working copy.
     a working copy.
 
 
@@ -460,7 +537,7 @@ def changes_from_tree(
                 yield ((name, None), (mode, None), (sha, None))
                 yield ((name, None), (mode, None), (sha, None))
             else:
             else:
                 other_names.remove(name)
                 other_names.remove(name)
-                if (want_unchanged or other_sha != sha or other_mode != mode):
+                if want_unchanged or other_sha != sha or other_mode != mode:
                     yield ((name, name), (mode, other_mode), (sha, other_sha))
                     yield ((name, name), (mode, other_mode), (sha, other_sha))
 
 
     # Mention added files
     # Mention added files
@@ -474,8 +551,9 @@ def changes_from_tree(
 
 
 
 
 def index_entry_from_stat(
 def index_entry_from_stat(
-        stat_val, hex_sha: bytes, flags: int,
-        mode: Optional[int] = None):
+    stat_val, hex_sha: bytes, flags: int, mode: Optional[int] = None,
+    extended_flags: Optional[int] = None
+):
     """Create a new index entry from a stat value.
     """Create a new index entry from a stat value.
 
 
     Args:
     Args:
@@ -487,13 +565,23 @@ def index_entry_from_stat(
         mode = cleanup_mode(stat_val.st_mode)
         mode = cleanup_mode(stat_val.st_mode)
 
 
     return IndexEntry(
     return IndexEntry(
-            stat_val.st_ctime, stat_val.st_mtime, stat_val.st_dev,
-            stat_val.st_ino, mode, stat_val.st_uid,
-            stat_val.st_gid, stat_val.st_size, hex_sha, flags)
+        stat_val.st_ctime,
+        stat_val.st_mtime,
+        stat_val.st_dev,
+        stat_val.st_ino,
+        mode,
+        stat_val.st_uid,
+        stat_val.st_gid,
+        stat_val.st_size,
+        hex_sha,
+        flags,
+        extended_flags
+    )
 
 
 
 
-def build_file_from_blob(blob, mode, target_path, honor_filemode=True,
-                         tree_encoding='utf-8'):
+def build_file_from_blob(
+    blob, mode, target_path, honor_filemode=True, tree_encoding="utf-8"
+):
     """Build a file or symlink on disk based on a Git object.
     """Build a file or symlink on disk based on a Git object.
 
 
     Args:
     Args:
@@ -513,18 +601,18 @@ def build_file_from_blob(blob, mode, target_path, honor_filemode=True,
         # FIXME: This will fail on Windows. What should we do instead?
         # FIXME: This will fail on Windows. What should we do instead?
         if oldstat:
         if oldstat:
             os.unlink(target_path)
             os.unlink(target_path)
-        if sys.platform == 'win32':
+        if sys.platform == "win32":
             # os.readlink on Python3 on Windows requires a unicode string.
             # os.readlink on Python3 on Windows requires a unicode string.
             contents = contents.decode(tree_encoding)
             contents = contents.decode(tree_encoding)
             target_path = target_path.decode(tree_encoding)
             target_path = target_path.decode(tree_encoding)
         os.symlink(contents, target_path)
         os.symlink(contents, target_path)
     else:
     else:
         if oldstat is not None and oldstat.st_size == len(contents):
         if oldstat is not None and oldstat.st_size == len(contents):
-            with open(target_path, 'rb') as f:
+            with open(target_path, "rb") as f:
                 if f.read() == contents:
                 if f.read() == contents:
                     return oldstat
                     return oldstat
 
 
-        with open(target_path, 'wb') as f:
+        with open(target_path, "wb") as f:
             # Write out file
             # Write out file
             f.write(contents)
             f.write(contents)
 
 
@@ -560,9 +648,14 @@ def validate_path(path, element_validator=validate_path_element_default):
         return True
         return True
 
 
 
 
-def build_index_from_tree(root_path, index_path, object_store, tree_id,
-                          honor_filemode=True,
-                          validate_path_element=validate_path_element_default):
+def build_index_from_tree(
+    root_path,
+    index_path,
+    object_store,
+    tree_id,
+    honor_filemode=True,
+    validate_path_element=validate_path_element_default,
+):
     """Generate and materialize index from a tree
     """Generate and materialize index from a tree
 
 
     Args:
     Args:
@@ -600,23 +693,33 @@ def build_index_from_tree(root_path, index_path, object_store, tree_id,
         else:
         else:
             obj = object_store[entry.sha]
             obj = object_store[entry.sha]
             st = build_file_from_blob(
             st = build_file_from_blob(
-                obj, entry.mode, full_path, honor_filemode=honor_filemode)
+                obj, entry.mode, full_path, honor_filemode=honor_filemode
+            )
 
 
         # Add file to index
         # Add file to index
         if not honor_filemode or S_ISGITLINK(entry.mode):
         if not honor_filemode or S_ISGITLINK(entry.mode):
             # we can not use tuple slicing to build a new tuple,
             # we can not use tuple slicing to build a new tuple,
             # because on windows that will convert the times to
             # because on windows that will convert the times to
             # longs, which causes errors further along
             # longs, which causes errors further along
-            st_tuple = (entry.mode, st.st_ino, st.st_dev, st.st_nlink,
-                        st.st_uid, st.st_gid, st.st_size, st.st_atime,
-                        st.st_mtime, st.st_ctime)
+            st_tuple = (
+                entry.mode,
+                st.st_ino,
+                st.st_dev,
+                st.st_nlink,
+                st.st_uid,
+                st.st_gid,
+                st.st_size,
+                st.st_atime,
+                st.st_mtime,
+                st.st_ctime,
+            )
             st = st.__class__(st_tuple)
             st = st.__class__(st_tuple)
         index[entry.path] = index_entry_from_stat(st, entry.sha, 0)
         index[entry.path] = index_entry_from_stat(st, entry.sha, 0)
 
 
     index.write()
     index.write()
 
 
 
 
-def blob_from_path_and_mode(fs_path, mode, tree_encoding='utf-8'):
+def blob_from_path_and_mode(fs_path, mode, tree_encoding="utf-8"):
     """Create a blob from a path and a stat object.
     """Create a blob from a path and a stat object.
 
 
     Args:
     Args:
@@ -627,19 +730,19 @@ def blob_from_path_and_mode(fs_path, mode, tree_encoding='utf-8'):
     assert isinstance(fs_path, bytes)
     assert isinstance(fs_path, bytes)
     blob = Blob()
     blob = Blob()
     if stat.S_ISLNK(mode):
     if stat.S_ISLNK(mode):
-        if sys.platform == 'win32':
+        if sys.platform == "win32":
             # os.readlink on Python3 on Windows requires a unicode string.
             # os.readlink on Python3 on Windows requires a unicode string.
             fs_path = os.fsdecode(fs_path)
             fs_path = os.fsdecode(fs_path)
             blob.data = os.readlink(fs_path).encode(tree_encoding)
             blob.data = os.readlink(fs_path).encode(tree_encoding)
         else:
         else:
             blob.data = os.readlink(fs_path)
             blob.data = os.readlink(fs_path)
     else:
     else:
-        with open(fs_path, 'rb') as f:
+        with open(fs_path, "rb") as f:
             blob.data = f.read()
             blob.data = f.read()
     return blob
     return blob
 
 
 
 
-def blob_from_path_and_stat(fs_path, st, tree_encoding='utf-8'):
+def blob_from_path_and_stat(fs_path, st, tree_encoding="utf-8"):
     """Create a blob from a path and a stat object.
     """Create a blob from a path and a stat object.
 
 
     Args:
     Args:
@@ -659,6 +762,7 @@ def read_submodule_head(path):
     """
     """
     from dulwich.errors import NotGitRepository
     from dulwich.errors import NotGitRepository
     from dulwich.repo import Repo
     from dulwich.repo import Repo
+
     # Repo currently expects a "str", so decode if necessary.
     # Repo currently expects a "str", so decode if necessary.
     # TODO(jelmer): Perhaps move this into Repo() ?
     # TODO(jelmer): Perhaps move this into Repo() ?
     if not isinstance(path, str):
     if not isinstance(path, str):
@@ -686,7 +790,7 @@ def _has_directory_changed(tree_path, entry):
     otherwise or if the path is not a directory.
     otherwise or if the path is not a directory.
     """
     """
     # This is actually a directory
     # This is actually a directory
-    if os.path.exists(os.path.join(tree_path, b'.git')):
+    if os.path.exists(os.path.join(tree_path, b".git")):
         # Submodule
         # Submodule
         head = read_submodule_head(tree_path)
         head = read_submodule_head(tree_path)
         if entry.sha != head:
         if entry.sha != head:
@@ -735,7 +839,7 @@ def get_unstaged_changes(index: Index, root_path, filter_blob_callback=None):
                 yield tree_path
                 yield tree_path
 
 
 
 
-os_sep_bytes = os.sep.encode('ascii')
+os_sep_bytes = os.sep.encode("ascii")
 
 
 
 
 def _tree_to_fs_path(root_path, tree_path: bytes):
 def _tree_to_fs_path(root_path, tree_path: bytes):
@@ -748,8 +852,8 @@ def _tree_to_fs_path(root_path, tree_path: bytes):
     Returns: File system path.
     Returns: File system path.
     """
     """
     assert isinstance(tree_path, bytes)
     assert isinstance(tree_path, bytes)
-    if os_sep_bytes != b'/':
-        sep_corrected_path = tree_path.replace(b'/', os_sep_bytes)
+    if os_sep_bytes != b"/":
+        sep_corrected_path = tree_path.replace(b"/", os_sep_bytes)
     else:
     else:
         sep_corrected_path = tree_path
         sep_corrected_path = tree_path
     return os.path.join(root_path, sep_corrected_path)
     return os.path.join(root_path, sep_corrected_path)
@@ -767,8 +871,8 @@ def _fs_to_tree_path(fs_path):
         fs_path_bytes = os.fsencode(fs_path)
         fs_path_bytes = os.fsencode(fs_path)
     else:
     else:
         fs_path_bytes = fs_path
         fs_path_bytes = fs_path
-    if os_sep_bytes != b'/':
-        tree_path = fs_path_bytes.replace(os_sep_bytes, b'/')
+    if os_sep_bytes != b"/":
+        tree_path = fs_path_bytes.replace(os_sep_bytes, b"/")
     else:
     else:
         tree_path = fs_path_bytes
         tree_path = fs_path_bytes
     return tree_path
     return tree_path
@@ -790,12 +894,11 @@ def index_entry_from_path(path, object_store=None):
     assert isinstance(path, bytes)
     assert isinstance(path, bytes)
     st = os.lstat(path)
     st = os.lstat(path)
     if stat.S_ISDIR(st.st_mode):
     if stat.S_ISDIR(st.st_mode):
-        if os.path.exists(os.path.join(path, b'.git')):
+        if os.path.exists(os.path.join(path, b".git")):
             head = read_submodule_head(path)
             head = read_submodule_head(path)
             if head is None:
             if head is None:
                 return None
                 return None
-            return index_entry_from_stat(
-                st, head, 0, mode=S_IFGITLINK)
+            return index_entry_from_stat(st, head, 0, mode=S_IFGITLINK)
         return None
         return None
 
 
     if stat.S_ISREG(st.st_mode) or stat.S_ISLNK(st.st_mode):
     if stat.S_ISREG(st.st_mode) or stat.S_ISLNK(st.st_mode):
@@ -808,7 +911,8 @@ def index_entry_from_path(path, object_store=None):
 
 
 
 
 def iter_fresh_entries(
 def iter_fresh_entries(
-        paths, root_path, object_store: Optional['BaseObjectStore'] = None):
+    paths, root_path, object_store: Optional["BaseObjectStore"] = None
+):
     """Iterate over current versions of index entries on disk.
     """Iterate over current versions of index entries on disk.
 
 
     Args:
     Args:
@@ -839,18 +943,16 @@ def iter_fresh_blobs(index, root_path):
     Returns: Iterator over path, sha, mode
     Returns: Iterator over path, sha, mode
     """
     """
     import warnings
     import warnings
-    warnings.warn(PendingDeprecationWarning,
-                  "Use iter_fresh_objects instead.")
-    for entry in iter_fresh_objects(
-            index, root_path, include_deleted=True):
+
+    warnings.warn(PendingDeprecationWarning, "Use iter_fresh_objects instead.")
+    for entry in iter_fresh_objects(index, root_path, include_deleted=True):
         if entry[1] is None:
         if entry[1] is None:
             del index[entry[0]]
             del index[entry[0]]
         else:
         else:
             yield entry
             yield entry
 
 
 
 
-def iter_fresh_objects(paths, root_path, include_deleted=False,
-                       object_store=None):
+def iter_fresh_objects(paths, root_path, include_deleted=False, object_store=None):
     """Iterate over versions of objecs on disk referenced by index.
     """Iterate over versions of objecs on disk referenced by index.
 
 
     Args:
     Args:
@@ -860,8 +962,7 @@ def iter_fresh_objects(paths, root_path, include_deleted=False,
       object_store: Optional object store to report new items to
       object_store: Optional object store to report new items to
     Returns: Iterator over path, sha, mode
     Returns: Iterator over path, sha, mode
     """
     """
-    for path, entry in iter_fresh_entries(paths, root_path,
-                                          object_store=object_store):
+    for path, entry in iter_fresh_entries(paths, root_path, object_store=object_store):
         if entry is None:
         if entry is None:
             if include_deleted:
             if include_deleted:
                 yield path, None, None
                 yield path, None, None

+ 7 - 8
dulwich/lfs.py

@@ -33,24 +33,24 @@ class LFSStore(object):
     def create(cls, lfs_dir):
     def create(cls, lfs_dir):
         if not os.path.isdir(lfs_dir):
         if not os.path.isdir(lfs_dir):
             os.mkdir(lfs_dir)
             os.mkdir(lfs_dir)
-        os.mkdir(os.path.join(lfs_dir, 'tmp'))
-        os.mkdir(os.path.join(lfs_dir, 'objects'))
+        os.mkdir(os.path.join(lfs_dir, "tmp"))
+        os.mkdir(os.path.join(lfs_dir, "objects"))
         return cls(lfs_dir)
         return cls(lfs_dir)
 
 
     @classmethod
     @classmethod
     def from_repo(cls, repo, create=False):
     def from_repo(cls, repo, create=False):
-        lfs_dir = os.path.join(repo.controldir, 'lfs')
+        lfs_dir = os.path.join(repo.controldir, "lfs")
         if create:
         if create:
             return cls.create(lfs_dir)
             return cls.create(lfs_dir)
         return cls(lfs_dir)
         return cls(lfs_dir)
 
 
     def _sha_path(self, sha):
     def _sha_path(self, sha):
-        return os.path.join(self.path, 'objects', sha[0:2], sha[2:4], sha)
+        return os.path.join(self.path, "objects", sha[0:2], sha[2:4], sha)
 
 
     def open_object(self, sha):
     def open_object(self, sha):
         """Open an object by sha."""
         """Open an object by sha."""
         try:
         try:
-            return open(self._sha_path(sha), 'rb')
+            return open(self._sha_path(sha), "rb")
         except FileNotFoundError:
         except FileNotFoundError:
             raise KeyError(sha)
             raise KeyError(sha)
 
 
@@ -60,9 +60,8 @@ class LFSStore(object):
         Returns: object SHA
         Returns: object SHA
         """
         """
         sha = hashlib.sha256()
         sha = hashlib.sha256()
-        tmpdir = os.path.join(self.path, 'tmp')
-        with tempfile.NamedTemporaryFile(
-                dir=tmpdir, mode='wb', delete=False) as f:
+        tmpdir = os.path.join(self.path, "tmp")
+        with tempfile.NamedTemporaryFile(dir=tmpdir, mode="wb", delete=False) as f:
             for chunk in chunks:
             for chunk in chunks:
                 sha.update(chunk)
                 sha.update(chunk)
                 f.write(chunk)
                 f.write(chunk)

+ 40 - 12
dulwich/line_ending.py

@@ -31,6 +31,16 @@ The normalization is a two-fold process that happens at two moments:
   when doing a `git add` call. We call this process the write filter in this
   when doing a `git add` call. We call this process the write filter in this
   module.
   module.
 
 
+Note that when checking status (getting unstaged changes), whether or not
+normalization is done on write depends on whether or not the file in the
+working dir has also been normalized on read:
+
+- For autocrlf=true all files are always normalized on both read and write.
+- For autocrlf=input files are only normalized on write if they are newly
+  "added". Since files which are already committed are not normalized on
+  checkout into the working tree, they are also left alone when staging
+  modifications into the index.
+
 One thing to know is that Git does line-ending normalization only on text
 One thing to know is that Git does line-ending normalization only on text
 files. How does Git know that a file is text? We can either mark a file as a
 files. How does Git know that a file is text? We can either mark a file as a
 text file, a binary file or ask Git to automatically decides. Git has an
 text file, a binary file or ask Git to automatically decides. Git has an
@@ -156,8 +166,7 @@ def convert_lf_to_crlf(text_hunk):
 
 
 
 
 def get_checkout_filter(core_eol, core_autocrlf, git_attributes):
 def get_checkout_filter(core_eol, core_autocrlf, git_attributes):
-    """ Returns the correct checkout filter based on the passed arguments
-    """
+    """Returns the correct checkout filter based on the passed arguments"""
     # TODO this function should process the git_attributes for the path and if
     # TODO this function should process the git_attributes for the path and if
     # the text attribute is not defined, fallback on the
     # the text attribute is not defined, fallback on the
     # get_checkout_filter_autocrlf function with the autocrlf value
     # get_checkout_filter_autocrlf function with the autocrlf value
@@ -165,8 +174,7 @@ def get_checkout_filter(core_eol, core_autocrlf, git_attributes):
 
 
 
 
 def get_checkin_filter(core_eol, core_autocrlf, git_attributes):
 def get_checkin_filter(core_eol, core_autocrlf, git_attributes):
-    """ Returns the correct checkin filter based on the passed arguments
-    """
+    """Returns the correct checkin filter based on the passed arguments"""
     # TODO this function should process the git_attributes for the path and if
     # TODO this function should process the git_attributes for the path and if
     # the text attribute is not defined, fallback on the
     # the text attribute is not defined, fallback on the
     # get_checkin_filter_autocrlf function with the autocrlf value
     # get_checkin_filter_autocrlf function with the autocrlf value
@@ -174,7 +182,7 @@ def get_checkin_filter(core_eol, core_autocrlf, git_attributes):
 
 
 
 
 def get_checkout_filter_autocrlf(core_autocrlf):
 def get_checkout_filter_autocrlf(core_autocrlf):
-    """ Returns the correct checkout filter base on autocrlf value
+    """Returns the correct checkout filter base on autocrlf value
 
 
     Args:
     Args:
       core_autocrlf: The bytes configuration value of core.autocrlf.
       core_autocrlf: The bytes configuration value of core.autocrlf.
@@ -190,7 +198,7 @@ def get_checkout_filter_autocrlf(core_autocrlf):
 
 
 
 
 def get_checkin_filter_autocrlf(core_autocrlf):
 def get_checkin_filter_autocrlf(core_autocrlf):
-    """ Returns the correct checkin filter base on autocrlf value
+    """Returns the correct checkin filter base on autocrlf value
 
 
     Args:
     Args:
       core_autocrlf: The bytes configuration value of core.autocrlf.
       core_autocrlf: The bytes configuration value of core.autocrlf.
@@ -207,7 +215,7 @@ def get_checkin_filter_autocrlf(core_autocrlf):
 
 
 
 
 class BlobNormalizer(object):
 class BlobNormalizer(object):
-    """ An object to store computation result of which filter to apply based
+    """An object to store computation result of which filter to apply based
     on configuration, gitattributes, path and operation (checkin or checkout)
     on configuration, gitattributes, path and operation (checkin or checkout)
     """
     """
 
 
@@ -234,8 +242,7 @@ class BlobNormalizer(object):
         )
         )
 
 
     def checkin_normalize(self, blob, tree_path):
     def checkin_normalize(self, blob, tree_path):
-        """ Normalize a blob during a checkin operation
-        """
+        """Normalize a blob during a checkin operation"""
         if self.fallback_write_filter is not None:
         if self.fallback_write_filter is not None:
             return normalize_blob(
             return normalize_blob(
                 blob, self.fallback_write_filter, binary_detection=True
                 blob, self.fallback_write_filter, binary_detection=True
@@ -244,8 +251,7 @@ class BlobNormalizer(object):
         return blob
         return blob
 
 
     def checkout_normalize(self, blob, tree_path):
     def checkout_normalize(self, blob, tree_path):
-        """ Normalize a blob during a checkout operation
-        """
+        """Normalize a blob during a checkout operation"""
         if self.fallback_read_filter is not None:
         if self.fallback_read_filter is not None:
             return normalize_blob(
             return normalize_blob(
                 blob, self.fallback_read_filter, binary_detection=True
                 blob, self.fallback_read_filter, binary_detection=True
@@ -255,7 +261,7 @@ class BlobNormalizer(object):
 
 
 
 
 def normalize_blob(blob, conversion, binary_detection):
 def normalize_blob(blob, conversion, binary_detection):
-    """ Takes a blob as input returns either the original blob if
+    """Takes a blob as input returns either the original blob if
     binary_detection is True and the blob content looks like binary, else
     binary_detection is True and the blob content looks like binary, else
     return a new blob with converted data
     return a new blob with converted data
     """
     """
@@ -276,3 +282,25 @@ def normalize_blob(blob, conversion, binary_detection):
     new_blob.data = converted_data
     new_blob.data = converted_data
 
 
     return new_blob
     return new_blob
+
+
+class TreeBlobNormalizer(BlobNormalizer):
+    def __init__(self, config_stack, git_attributes, object_store, tree=None):
+        super().__init__(config_stack, git_attributes)
+        if tree:
+            self.existing_paths = {
+                name
+                for name, _, _ in object_store.iter_tree_contents(tree)
+            }
+        else:
+            self.existing_paths = set()
+
+    def checkin_normalize(self, blob, tree_path):
+        # Existing files should only be normalized on checkin if it was
+        # previously normalized on checkout
+        if (
+            self.fallback_read_filter is not None
+            or tree_path not in self.existing_paths
+        ):
+            return super().checkin_normalize(blob, tree_path)
+        return blob

+ 6 - 3
dulwich/log_utils.py

@@ -49,15 +49,18 @@ class _NullHandler(logging.Handler):
 
 
 
 
 _NULL_HANDLER = _NullHandler()
 _NULL_HANDLER = _NullHandler()
-_DULWICH_LOGGER = getLogger('dulwich')
+_DULWICH_LOGGER = getLogger("dulwich")
 _DULWICH_LOGGER.addHandler(_NULL_HANDLER)
 _DULWICH_LOGGER.addHandler(_NULL_HANDLER)
 
 
 
 
 def default_logging_config():
 def default_logging_config():
     """Set up the default Dulwich loggers."""
     """Set up the default Dulwich loggers."""
     remove_null_handler()
     remove_null_handler()
-    logging.basicConfig(level=logging.INFO, stream=sys.stderr,
-                        format='%(asctime)s %(levelname)s: %(message)s')
+    logging.basicConfig(
+        level=logging.INFO,
+        stream=sys.stderr,
+        format="%(asctime)s %(levelname)s: %(message)s",
+    )
 
 
 
 
 def remove_null_handler():
 def remove_null_handler():

+ 35 - 26
dulwich/lru_cache.py

@@ -26,7 +26,7 @@ _null_key = object()
 class _LRUNode(object):
 class _LRUNode(object):
     """This maintains the linked-list which is the lru internals."""
     """This maintains the linked-list which is the lru internals."""
 
 
-    __slots__ = ('prev', 'next_key', 'key', 'value', 'cleanup', 'size')
+    __slots__ = ("prev", "next_key", "key", "value", "cleanup", "size")
 
 
     def __init__(self, key, value, cleanup=None):
     def __init__(self, key, value, cleanup=None):
         self.prev = None
         self.prev = None
@@ -44,8 +44,12 @@ class _LRUNode(object):
             prev_key = None
             prev_key = None
         else:
         else:
             prev_key = self.prev.key
             prev_key = self.prev.key
-        return '%s(%r n:%r p:%r)' % (self.__class__.__name__, self.key,
-                                     self.next_key, prev_key)
+        return "%s(%r n:%r p:%r)" % (
+            self.__class__.__name__,
+            self.key,
+            self.next_key,
+            prev_key,
+        )
 
 
     def run_cleanup(self):
     def run_cleanup(self):
         if self.cleanup is not None:
         if self.cleanup is not None:
@@ -108,29 +112,35 @@ class LRUCache(object):
         node = self._most_recently_used
         node = self._most_recently_used
         if node is not None:
         if node is not None:
             if node.prev is not None:
             if node.prev is not None:
-                raise AssertionError('the _most_recently_used entry is not'
-                                     ' supposed to have a previous entry'
-                                     ' %s' % (node,))
+                raise AssertionError(
+                    "the _most_recently_used entry is not"
+                    " supposed to have a previous entry"
+                    " %s" % (node,)
+                )
         while node is not None:
         while node is not None:
             if node.next_key is _null_key:
             if node.next_key is _null_key:
                 if node is not self._least_recently_used:
                 if node is not self._least_recently_used:
-                    raise AssertionError('only the last node should have'
-                                         ' no next value: %s' % (node,))
+                    raise AssertionError(
+                        "only the last node should have" " no next value: %s" % (node,)
+                    )
                 node_next = None
                 node_next = None
             else:
             else:
                 node_next = self._cache[node.next_key]
                 node_next = self._cache[node.next_key]
                 if node_next.prev is not node:
                 if node_next.prev is not node:
-                    raise AssertionError('inconsistency found, node.next.prev'
-                                         ' != node: %s' % (node,))
+                    raise AssertionError(
+                        "inconsistency found, node.next.prev" " != node: %s" % (node,)
+                    )
             if node.prev is None:
             if node.prev is None:
                 if node is not self._most_recently_used:
                 if node is not self._most_recently_used:
-                    raise AssertionError('only the _most_recently_used should'
-                                         ' not have a previous node: %s'
-                                         % (node,))
+                    raise AssertionError(
+                        "only the _most_recently_used should"
+                        " not have a previous node: %s" % (node,)
+                    )
             else:
             else:
                 if node.prev.next_key != node.key:
                 if node.prev.next_key != node.key:
-                    raise AssertionError('inconsistency found, node.prev.next'
-                                         ' != node: %s' % (node,))
+                    raise AssertionError(
+                        "inconsistency found, node.prev.next" " != node: %s" % (node,)
+                    )
             yield node
             yield node
             node = node_next
             node = node_next
 
 
@@ -147,7 +157,7 @@ class LRUCache(object):
                         'value' should be cleaned up.
                         'value' should be cleaned up.
         """
         """
         if key is _null_key:
         if key is _null_key:
-            raise ValueError('cannot use _null_key as a key')
+            raise ValueError("cannot use _null_key as a key")
         if key in self._cache:
         if key in self._cache:
             node = self._cache[key]
             node = self._cache[key]
             node.run_cleanup()
             node.run_cleanup()
@@ -186,7 +196,7 @@ class LRUCache(object):
 
 
     def items(self):
     def items(self):
         """Get the key:value pairs as a dict."""
         """Get the key:value pairs as a dict."""
-        return dict((k, n.value) for k, n in self._cache.items())
+        return {k: n.value for k, n in self._cache.items()}
 
 
     def cleanup(self):
     def cleanup(self):
         """Clear the cache until it shrinks to the requested size.
         """Clear the cache until it shrinks to the requested size.
@@ -262,16 +272,14 @@ class LRUCache(object):
 
 
     def resize(self, max_cache, after_cleanup_count=None):
     def resize(self, max_cache, after_cleanup_count=None):
         """Change the number of entries that will be cached."""
         """Change the number of entries that will be cached."""
-        self._update_max_cache(max_cache,
-                               after_cleanup_count=after_cleanup_count)
+        self._update_max_cache(max_cache, after_cleanup_count=after_cleanup_count)
 
 
     def _update_max_cache(self, max_cache, after_cleanup_count=None):
     def _update_max_cache(self, max_cache, after_cleanup_count=None):
         self._max_cache = max_cache
         self._max_cache = max_cache
         if after_cleanup_count is None:
         if after_cleanup_count is None:
             self._after_cleanup_count = self._max_cache * 8 / 10
             self._after_cleanup_count = self._max_cache * 8 / 10
         else:
         else:
-            self._after_cleanup_count = min(after_cleanup_count,
-                                            self._max_cache)
+            self._after_cleanup_count = min(after_cleanup_count, self._max_cache)
         self.cleanup()
         self.cleanup()
 
 
 
 
@@ -285,8 +293,9 @@ class LRUSizeCache(LRUCache):
     defaults to len() if not supplied.
     defaults to len() if not supplied.
     """
     """
 
 
-    def __init__(self, max_size=1024*1024, after_cleanup_size=None,
-                 compute_size=None):
+    def __init__(
+        self, max_size=1024 * 1024, after_cleanup_size=None, compute_size=None
+    ):
         """Create a new LRUSizeCache.
         """Create a new LRUSizeCache.
 
 
         Args:
         Args:
@@ -306,7 +315,7 @@ class LRUSizeCache(LRUCache):
         if compute_size is None:
         if compute_size is None:
             self._compute_size = len
             self._compute_size = len
         self._update_max_size(max_size, after_cleanup_size=after_cleanup_size)
         self._update_max_size(max_size, after_cleanup_size=after_cleanup_size)
-        LRUCache.__init__(self, max_cache=max(int(max_size/512), 1))
+        LRUCache.__init__(self, max_cache=max(int(max_size / 512), 1))
 
 
     def add(self, key, value, cleanup=None):
     def add(self, key, value, cleanup=None):
         """Add a new value to the cache.
         """Add a new value to the cache.
@@ -321,7 +330,7 @@ class LRUSizeCache(LRUCache):
                         'value' should be cleaned up.
                         'value' should be cleaned up.
         """
         """
         if key is _null_key:
         if key is _null_key:
-            raise ValueError('cannot use _null_key as a key')
+            raise ValueError("cannot use _null_key as a key")
         node = self._cache.get(key, None)
         node = self._cache.get(key, None)
         value_len = self._compute_size(value)
         value_len = self._compute_size(value)
         if value_len >= self._after_cleanup_size:
         if value_len >= self._after_cleanup_size:
@@ -363,7 +372,7 @@ class LRUSizeCache(LRUCache):
     def resize(self, max_size, after_cleanup_size=None):
     def resize(self, max_size, after_cleanup_size=None):
         """Change the number of bytes that will be cached."""
         """Change the number of bytes that will be cached."""
         self._update_max_size(max_size, after_cleanup_size=after_cleanup_size)
         self._update_max_size(max_size, after_cleanup_size=after_cleanup_size)
-        max_cache = max(int(max_size/512), 1)
+        max_cache = max(int(max_size / 512), 1)
         self._update_max_cache(max_cache)
         self._update_max_cache(max_cache)
 
 
     def _update_max_size(self, max_size, after_cleanup_size=None):
     def _update_max_size(self, max_size, after_cleanup_size=None):

+ 6 - 5
dulwich/mailmap.py

@@ -44,11 +44,11 @@ def read_mailmap(f):
     """
     """
     for line in f:
     for line in f:
         # Remove comments
         # Remove comments
-        line = line.split(b'#')[0]
+        line = line.split(b"#")[0]
         line = line.strip()
         line = line.strip()
         if not line:
         if not line:
             continue
             continue
-        (canonical_identity, from_identity) = line.split(b'>', 1)
+        (canonical_identity, from_identity) = line.split(b">", 1)
         canonical_identity += b">"
         canonical_identity += b">"
         if from_identity.strip():
         if from_identity.strip():
             parsed_from_identity = parse_identity(from_identity)
             parsed_from_identity = parse_identity(from_identity)
@@ -99,8 +99,9 @@ class Mailmap(object):
             canonical_identity = self._table.get(query)
             canonical_identity = self._table.get(query)
             if canonical_identity is not None:
             if canonical_identity is not None:
                 identity = (
                 identity = (
-                        canonical_identity[0] or identity[0],
-                        canonical_identity[1] or identity[1])
+                    canonical_identity[0] or identity[0],
+                    canonical_identity[1] or identity[1],
+                )
                 break
                 break
         if was_tuple:
         if was_tuple:
             return identity
             return identity
@@ -109,5 +110,5 @@ class Mailmap(object):
 
 
     @classmethod
     @classmethod
     def from_path(cls, path):
     def from_path(cls, path):
-        with open(path, 'rb') as f:
+        with open(path, "rb") as f:
             return cls(read_mailmap(f))
             return cls(read_mailmap(f))

+ 292 - 115
dulwich/object_store.py

@@ -30,10 +30,10 @@ import sys
 from dulwich.diff_tree import (
 from dulwich.diff_tree import (
     tree_changes,
     tree_changes,
     walk_trees,
     walk_trees,
-    )
+)
 from dulwich.errors import (
 from dulwich.errors import (
     NotTreeError,
     NotTreeError,
-    )
+)
 from dulwich.file import GitFile
 from dulwich.file import GitFile
 from dulwich.objects import (
 from dulwich.objects import (
     Commit,
     Commit,
@@ -47,12 +47,13 @@ from dulwich.objects import (
     S_ISGITLINK,
     S_ISGITLINK,
     object_class,
     object_class,
     valid_hexsha,
     valid_hexsha,
-    )
+)
 from dulwich.pack import (
 from dulwich.pack import (
     Pack,
     Pack,
     PackData,
     PackData,
     PackInflater,
     PackInflater,
     PackFileDisappeared,
     PackFileDisappeared,
+    load_pack_index_file,
     iter_sha1,
     iter_sha1,
     pack_objects_to_data,
     pack_objects_to_data,
     write_pack_header,
     write_pack_header,
@@ -62,21 +63,32 @@ from dulwich.pack import (
     compute_file_sha,
     compute_file_sha,
     PackIndexer,
     PackIndexer,
     PackStreamCopier,
     PackStreamCopier,
-    )
+)
+from dulwich.protocol import DEPTH_INFINITE
 from dulwich.refs import ANNOTATED_TAG_SUFFIX
 from dulwich.refs import ANNOTATED_TAG_SUFFIX
 
 
-INFODIR = 'info'
-PACKDIR = 'pack'
+INFODIR = "info"
+PACKDIR = "pack"
 
 
 
 
 class BaseObjectStore(object):
 class BaseObjectStore(object):
     """Object store interface."""
     """Object store interface."""
 
 
-    def determine_wants_all(self, refs):
-        return [sha for (ref, sha) in refs.items()
-                if sha not in self and
-                not ref.endswith(ANNOTATED_TAG_SUFFIX) and
-                not sha == ZERO_SHA]
+    def determine_wants_all(self, refs, depth=None):
+        def _want_deepen(sha):
+            if not depth:
+                return False
+            if depth == DEPTH_INFINITE:
+                return True
+            return depth > self._get_depth(sha)
+
+        return [
+            sha
+            for (ref, sha) in refs.items()
+            if (sha not in self or _want_deepen(sha))
+            and not ref.endswith(ANNOTATED_TAG_SUFFIX)
+            and not sha == ZERO_SHA
+        ]
 
 
     def iter_shas(self, shas):
     def iter_shas(self, shas):
         """Iterate over the objects for the specified shas.
         """Iterate over the objects for the specified shas.
@@ -126,9 +138,7 @@ class BaseObjectStore(object):
         raise NotImplementedError(self.__iter__)
         raise NotImplementedError(self.__iter__)
 
 
     def add_object(self, obj):
     def add_object(self, obj):
-        """Add a single object to this object store.
-
-        """
+        """Add a single object to this object store."""
         raise NotImplementedError(self.add_object)
         raise NotImplementedError(self.add_object)
 
 
     def add_objects(self, objects, progress=None):
     def add_objects(self, objects, progress=None):
@@ -152,17 +162,27 @@ class BaseObjectStore(object):
         f, commit, abort = self.add_pack()
         f, commit, abort = self.add_pack()
         try:
         try:
             write_pack_data(
             write_pack_data(
-                f, count, pack_data, progress,
-                compression_level=self.pack_compression_level)
+                f,
+                count,
+                pack_data,
+                progress,
+                compression_level=self.pack_compression_level,
+            )
         except BaseException:
         except BaseException:
             abort()
             abort()
             raise
             raise
         else:
         else:
             return commit()
             return commit()
 
 
-    def tree_changes(self, source, target, want_unchanged=False,
-                     include_trees=False, change_type_same=False,
-                     rename_detector=None):
+    def tree_changes(
+        self,
+        source,
+        target,
+        want_unchanged=False,
+        include_trees=False,
+        change_type_same=False,
+        rename_detector=None,
+    ):
         """Find the differences between the contents of two trees
         """Find the differences between the contents of two trees
 
 
         Args:
         Args:
@@ -175,14 +195,20 @@ class BaseObjectStore(object):
         Returns: Iterator over tuples with
         Returns: Iterator over tuples with
             (oldpath, newpath), (oldmode, newmode), (oldsha, newsha)
             (oldpath, newpath), (oldmode, newmode), (oldsha, newsha)
         """
         """
-        for change in tree_changes(self, source, target,
-                                   want_unchanged=want_unchanged,
-                                   include_trees=include_trees,
-                                   change_type_same=change_type_same,
-                                   rename_detector=rename_detector):
-            yield ((change.old.path, change.new.path),
-                   (change.old.mode, change.new.mode),
-                   (change.old.sha, change.new.sha))
+        for change in tree_changes(
+            self,
+            source,
+            target,
+            want_unchanged=want_unchanged,
+            include_trees=include_trees,
+            change_type_same=change_type_same,
+            rename_detector=rename_detector,
+        ):
+            yield (
+                (change.old.path, change.new.path),
+                (change.old.mode, change.new.mode),
+                (change.old.sha, change.new.sha),
+            )
 
 
     def iter_tree_contents(self, tree_id, include_trees=False):
     def iter_tree_contents(self, tree_id, include_trees=False):
         """Iterate the contents of a tree and all subtrees.
         """Iterate the contents of a tree and all subtrees.
@@ -196,14 +222,21 @@ class BaseObjectStore(object):
             tree.
             tree.
         """
         """
         for entry, _ in walk_trees(self, tree_id, None):
         for entry, _ in walk_trees(self, tree_id, None):
-            if ((entry.mode is not None and
-                 not stat.S_ISDIR(entry.mode)) or include_trees):
+            if (
+                entry.mode is not None and not stat.S_ISDIR(entry.mode)
+            ) or include_trees:
                 yield entry
                 yield entry
 
 
-    def find_missing_objects(self, haves, wants, shallow=None, progress=None,
-                             get_tagged=None,
-                             get_parents=lambda commit: commit.parents,
-                             depth=None):
+    def find_missing_objects(
+        self,
+        haves,
+        wants,
+        shallow=None,
+        progress=None,
+        get_tagged=None,
+        get_parents=lambda commit: commit.parents,
+        depth=None,
+    ):
         """Find the missing objects required for a set of revisions.
         """Find the missing objects required for a set of revisions.
 
 
         Args:
         Args:
@@ -218,8 +251,15 @@ class BaseObjectStore(object):
             commit.
             commit.
         Returns: Iterator over (sha, path) pairs.
         Returns: Iterator over (sha, path) pairs.
         """
         """
-        finder = MissingObjectFinder(self, haves, wants, shallow, progress,
-                                     get_tagged, get_parents=get_parents)
+        finder = MissingObjectFinder(
+            self,
+            haves,
+            wants,
+            shallow,
+            progress,
+            get_tagged,
+            get_parents=get_parents,
+        )
         return iter(finder.next, None)
         return iter(finder.next, None)
 
 
     def find_common_revisions(self, graphwalker):
     def find_common_revisions(self, graphwalker):
@@ -250,8 +290,9 @@ class BaseObjectStore(object):
         missing = self.find_missing_objects(have, want, shallow, progress)
         missing = self.find_missing_objects(have, want, shallow, progress)
         return self.iter_shas(missing)
         return self.iter_shas(missing)
 
 
-    def generate_pack_data(self, have, want, shallow=None, progress=None,
-                           ofs_delta=True):
+    def generate_pack_data(
+        self, have, want, shallow=None, progress=None, ofs_delta=True
+    ):
         """Generate pack data objects for a set of wants/haves.
         """Generate pack data objects for a set of wants/haves.
 
 
         Args:
         Args:
@@ -263,7 +304,8 @@ class BaseObjectStore(object):
         """
         """
         # TODO(jelmer): More efficient implementation
         # TODO(jelmer): More efficient implementation
         return pack_objects_to_data(
         return pack_objects_to_data(
-            self.generate_pack_contents(have, want, shallow, progress))
+            self.generate_pack_contents(have, want, shallow, progress)
+        )
 
 
     def peel_sha(self, sha):
     def peel_sha(self, sha):
         """Peel all tags from a SHA.
         """Peel all tags from a SHA.
@@ -281,8 +323,13 @@ class BaseObjectStore(object):
             obj = self[sha]
             obj = self[sha]
         return obj
         return obj
 
 
-    def _collect_ancestors(self, heads, common=set(), shallow=set(),
-                           get_parents=lambda commit: commit.parents):
+    def _collect_ancestors(
+        self,
+        heads,
+        common=set(),
+        shallow=set(),
+        get_parents=lambda commit: commit.parents,
+    ):
         """Collect all ancestors of heads up to (excluding) those in common.
         """Collect all ancestors of heads up to (excluding) those in common.
 
 
         Args:
         Args:
@@ -311,13 +358,42 @@ class BaseObjectStore(object):
                 queue.extend(get_parents(cmt))
                 queue.extend(get_parents(cmt))
         return (commits, bases)
         return (commits, bases)
 
 
+    def _get_depth(
+        self, head, get_parents=lambda commit: commit.parents, max_depth=None,
+    ):
+        """Return the current available depth for the given head.
+        For commits with multiple parents, the largest possible depth will be
+        returned.
+
+        Args:
+            head: commit to start from
+            get_parents: optional function for getting the parents of a commit
+            max_depth: maximum depth to search
+        """
+        if head not in self:
+            return 0
+        current_depth = 1
+        queue = [(head, current_depth)]
+        while queue and (max_depth is None or current_depth < max_depth):
+            e, depth = queue.pop(0)
+            current_depth = max(current_depth, depth)
+            cmt = self[e]
+            if isinstance(cmt, Tag):
+                _cls, sha = cmt.object
+                cmt = self[sha]
+            queue.extend(
+                (parent, depth + 1)
+                for parent in get_parents(cmt)
+                if parent in self
+            )
+        return current_depth
+
     def close(self):
     def close(self):
         """Close any files opened by this object store."""
         """Close any files opened by this object store."""
         # Default implementation is a NO-OP
         # Default implementation is a NO-OP
 
 
 
 
 class PackBasedObjectStore(BaseObjectStore):
 class PackBasedObjectStore(BaseObjectStore):
-
     def __init__(self, pack_compression_level=-1):
     def __init__(self, pack_compression_level=-1):
         self._pack_cache = {}
         self._pack_cache = {}
         self.pack_compression_level = pack_compression_level
         self.pack_compression_level = pack_compression_level
@@ -352,9 +428,7 @@ class PackBasedObjectStore(BaseObjectStore):
         return False
         return False
 
 
     def _add_cached_pack(self, base_name, pack):
     def _add_cached_pack(self, base_name, pack):
-        """Add a newly appeared pack to the cache by path.
-
-        """
+        """Add a newly appeared pack to the cache by path."""
         prev_pack = self._pack_cache.get(base_name)
         prev_pack = self._pack_cache.get(base_name)
         if prev_pack is not pack:
         if prev_pack is not pack:
             self._pack_cache[base_name] = pack
             self._pack_cache[base_name] = pack
@@ -380,8 +454,7 @@ class PackBasedObjectStore(BaseObjectStore):
     @property
     @property
     def packs(self):
     def packs(self):
         """List with pack objects."""
         """List with pack objects."""
-        return (
-            list(self._iter_cached_packs()) + list(self._update_pack_cache()))
+        return list(self._iter_cached_packs()) + list(self._update_pack_cache())
 
 
     def _iter_alternate_objects(self):
     def _iter_alternate_objects(self):
         """Iterate over the SHAs of all the objects in alternate stores."""
         """Iterate over the SHAs of all the objects in alternate stores."""
@@ -480,7 +553,7 @@ class PackBasedObjectStore(BaseObjectStore):
             sha = name
             sha = name
             hexsha = None
             hexsha = None
         else:
         else:
-            raise AssertionError("Invalid object name %r" % (name, ))
+            raise AssertionError("Invalid object name %r" % (name,))
         for pack in self._iter_cached_packs():
         for pack in self._iter_cached_packs():
             try:
             try:
                 return pack.get_raw(sha)
                 return pack.get_raw(sha)
@@ -513,16 +586,13 @@ class PackBasedObjectStore(BaseObjectStore):
             __len__.
             __len__.
         Returns: Pack object of the objects written.
         Returns: Pack object of the objects written.
         """
         """
-        return self.add_pack_data(
-                *pack_objects_to_data(objects),
-                progress=progress)
+        return self.add_pack_data(*pack_objects_to_data(objects), progress=progress)
 
 
 
 
 class DiskObjectStore(PackBasedObjectStore):
 class DiskObjectStore(PackBasedObjectStore):
     """Git-style object store that exists on disk."""
     """Git-style object store that exists on disk."""
 
 
-    def __init__(self, path, loose_compression_level=-1,
-                 pack_compression_level=-1):
+    def __init__(self, path, loose_compression_level=-1, pack_compression_level=-1):
         """Open an object store.
         """Open an object store.
 
 
         Args:
         Args:
@@ -531,7 +601,8 @@ class DiskObjectStore(PackBasedObjectStore):
           pack_compression_level: zlib compression level for pack objects
           pack_compression_level: zlib compression level for pack objects
         """
         """
         super(DiskObjectStore, self).__init__(
         super(DiskObjectStore, self).__init__(
-            pack_compression_level=pack_compression_level)
+            pack_compression_level=pack_compression_level
+        )
         self.path = path
         self.path = path
         self.pack_dir = os.path.join(self.path, PACKDIR)
         self.pack_dir = os.path.join(self.path, PACKDIR)
         self._alternates = None
         self._alternates = None
@@ -544,18 +615,21 @@ class DiskObjectStore(PackBasedObjectStore):
     @classmethod
     @classmethod
     def from_config(cls, path, config):
     def from_config(cls, path, config):
         try:
         try:
-            default_compression_level = int(config.get(
-                (b'core', ), b'compression').decode())
+            default_compression_level = int(
+                config.get((b"core",), b"compression").decode()
+            )
         except KeyError:
         except KeyError:
             default_compression_level = -1
             default_compression_level = -1
         try:
         try:
-            loose_compression_level = int(config.get(
-                (b'core', ), b'looseCompression').decode())
+            loose_compression_level = int(
+                config.get((b"core",), b"looseCompression").decode()
+            )
         except KeyError:
         except KeyError:
             loose_compression_level = default_compression_level
             loose_compression_level = default_compression_level
         try:
         try:
-            pack_compression_level = int(config.get(
-                (b'core', ), 'packCompression').decode())
+            pack_compression_level = int(
+                config.get((b"core",), "packCompression").decode()
+            )
         except KeyError:
         except KeyError:
             pack_compression_level = default_compression_level
             pack_compression_level = default_compression_level
         return cls(path, loose_compression_level, pack_compression_level)
         return cls(path, loose_compression_level, pack_compression_level)
@@ -571,7 +645,7 @@ class DiskObjectStore(PackBasedObjectStore):
 
 
     def _read_alternate_paths(self):
     def _read_alternate_paths(self):
         try:
         try:
-            f = GitFile(os.path.join(self.path, INFODIR, "alternates"), 'rb')
+            f = GitFile(os.path.join(self.path, INFODIR, "alternates"), "rb")
         except FileNotFoundError:
         except FileNotFoundError:
             return
             return
         with f:
         with f:
@@ -582,20 +656,18 @@ class DiskObjectStore(PackBasedObjectStore):
                 if os.path.isabs(line):
                 if os.path.isabs(line):
                     yield os.fsdecode(line)
                     yield os.fsdecode(line)
                 else:
                 else:
-                    yield os.fsdecode(os.path.join(os.fsencode(self.path),
-                                                   line))
+                    yield os.fsdecode(os.path.join(os.fsencode(self.path), line))
 
 
     def add_alternate_path(self, path):
     def add_alternate_path(self, path):
-        """Add an alternate path to this object store.
-        """
+        """Add an alternate path to this object store."""
         try:
         try:
             os.mkdir(os.path.join(self.path, INFODIR))
             os.mkdir(os.path.join(self.path, INFODIR))
         except FileExistsError:
         except FileExistsError:
             pass
             pass
         alternates_path = os.path.join(self.path, INFODIR, "alternates")
         alternates_path = os.path.join(self.path, INFODIR, "alternates")
-        with GitFile(alternates_path, 'wb') as f:
+        with GitFile(alternates_path, "wb") as f:
             try:
             try:
-                orig_f = open(alternates_path, 'rb')
+                orig_f = open(alternates_path, "rb")
             except FileNotFoundError:
             except FileNotFoundError:
                 pass
                 pass
             else:
             else:
@@ -621,7 +693,7 @@ class DiskObjectStore(PackBasedObjectStore):
                 # fully written)
                 # fully written)
                 idx_name = os.path.splitext(name)[0] + ".idx"
                 idx_name = os.path.splitext(name)[0] + ".idx"
                 if idx_name in pack_dir_contents:
                 if idx_name in pack_dir_contents:
-                    pack_name = name[:-len(".pack")]
+                    pack_name = name[: -len(".pack")]
                     pack_files.add(pack_name)
                     pack_files.add(pack_name)
 
 
         # Open newly appeared pack files
         # Open newly appeared pack files
@@ -645,7 +717,7 @@ class DiskObjectStore(PackBasedObjectStore):
             if len(base) != 2:
             if len(base) != 2:
                 continue
                 continue
             for rest in os.listdir(os.path.join(self.path, base)):
             for rest in os.listdir(os.path.join(self.path, base)):
-                sha = os.fsencode(base+rest)
+                sha = os.fsencode(base + rest)
                 if not valid_hexsha(sha):
                 if not valid_hexsha(sha):
                     continue
                     continue
                 yield sha
                 yield sha
@@ -672,7 +744,7 @@ class DiskObjectStore(PackBasedObjectStore):
     def _get_pack_basepath(self, entries):
     def _get_pack_basepath(self, entries):
         suffix = iter_sha1(entry[0] for entry in entries)
         suffix = iter_sha1(entry[0] for entry in entries)
         # TODO: Handle self.pack_dir being bytes
         # TODO: Handle self.pack_dir being bytes
-        suffix = suffix.decode('ascii')
+        suffix = suffix.decode("ascii")
         return os.path.join(self.pack_dir, "pack-" + suffix)
         return os.path.join(self.pack_dir, "pack-" + suffix)
 
 
     def _complete_thin_pack(self, f, path, copier, indexer):
     def _complete_thin_pack(self, f, path, copier, indexer):
@@ -708,8 +780,12 @@ class DiskObjectStore(PackBasedObjectStore):
             type_num, data = self.get_raw(ext_sha)
             type_num, data = self.get_raw(ext_sha)
             offset = f.tell()
             offset = f.tell()
             crc32 = write_pack_object(
             crc32 = write_pack_object(
-                f, type_num, data, sha=new_sha,
-                compression_level=self.pack_compression_level)
+                f,
+                type_num,
+                data,
+                sha=new_sha,
+                compression_level=self.pack_compression_level,
+            )
             entries.append((ext_sha, offset, crc32))
             entries.append((ext_sha, offset, crc32))
         pack_sha = new_sha.digest()
         pack_sha = new_sha.digest()
         f.write(pack_sha)
         f.write(pack_sha)
@@ -718,8 +794,8 @@ class DiskObjectStore(PackBasedObjectStore):
         # Move the pack in.
         # Move the pack in.
         entries.sort()
         entries.sort()
         pack_base_name = self._get_pack_basepath(entries)
         pack_base_name = self._get_pack_basepath(entries)
-        target_pack = pack_base_name + '.pack'
-        if sys.platform == 'win32':
+        target_pack = pack_base_name + ".pack"
+        if sys.platform == "win32":
             # Windows might have the target pack file lingering. Attempt
             # Windows might have the target pack file lingering. Attempt
             # removal, silently passing if the target does not exist.
             # removal, silently passing if the target does not exist.
             try:
             try:
@@ -729,7 +805,7 @@ class DiskObjectStore(PackBasedObjectStore):
         os.rename(path, target_pack)
         os.rename(path, target_pack)
 
 
         # Write the index.
         # Write the index.
-        index_file = GitFile(pack_base_name + '.idx', 'wb')
+        index_file = GitFile(pack_base_name + ".idx", "wb")
         try:
         try:
             write_pack_index_v2(index_file, entries, pack_sha)
             write_pack_index_v2(index_file, entries, pack_sha)
             index_file.close()
             index_file.close()
@@ -758,11 +834,11 @@ class DiskObjectStore(PackBasedObjectStore):
             objects/pack directory.
             objects/pack directory.
         """
         """
         import tempfile
         import tempfile
-        fd, path = tempfile.mkstemp(dir=self.path, prefix='tmp_pack_')
-        with os.fdopen(fd, 'w+b') as f:
+
+        fd, path = tempfile.mkstemp(dir=self.path, prefix="tmp_pack_")
+        with os.fdopen(fd, "w+b") as f:
             indexer = PackIndexer(f, resolve_ext_ref=self.get_raw)
             indexer = PackIndexer(f, resolve_ext_ref=self.get_raw)
-            copier = PackStreamCopier(read_all, read_some, f,
-                                      delta_iter=indexer)
+            copier = PackStreamCopier(read_all, read_some, f, delta_iter=indexer)
             copier.verify()
             copier.verify()
             return self._complete_thin_pack(f, path, copier, indexer)
             return self._complete_thin_pack(f, path, copier, indexer)
 
 
@@ -785,8 +861,8 @@ class DiskObjectStore(PackBasedObjectStore):
         for pack in self.packs:
         for pack in self.packs:
             if pack._basename == basename:
             if pack._basename == basename:
                 return pack
                 return pack
-        target_pack = basename + '.pack'
-        if sys.platform == 'win32':
+        target_pack = basename + ".pack"
+        if sys.platform == "win32":
             # Windows might have the target pack file lingering. Attempt
             # Windows might have the target pack file lingering. Attempt
             # removal, silently passing if the target does not exist.
             # removal, silently passing if the target does not exist.
             try:
             try:
@@ -806,8 +882,9 @@ class DiskObjectStore(PackBasedObjectStore):
             function.
             function.
         """
         """
         import tempfile
         import tempfile
+
         fd, path = tempfile.mkstemp(dir=self.pack_dir, suffix=".pack")
         fd, path = tempfile.mkstemp(dir=self.pack_dir, suffix=".pack")
-        f = os.fdopen(fd, 'wb')
+        f = os.fdopen(fd, "wb")
 
 
         def commit():
         def commit():
             f.flush()
             f.flush()
@@ -822,6 +899,7 @@ class DiskObjectStore(PackBasedObjectStore):
         def abort():
         def abort():
             f.close()
             f.close()
             os.remove(path)
             os.remove(path)
+
         return f, commit, abort
         return f, commit, abort
 
 
     def add_object(self, obj):
     def add_object(self, obj):
@@ -838,9 +916,10 @@ class DiskObjectStore(PackBasedObjectStore):
             pass
             pass
         if os.path.exists(path):
         if os.path.exists(path):
             return  # Already there, no need to write again
             return  # Already there, no need to write again
-        with GitFile(path, 'wb') as f:
-            f.write(obj.as_legacy_object(
-                compression_level=self.loose_compression_level))
+        with GitFile(path, "wb") as f:
+            f.write(
+                obj.as_legacy_object(compression_level=self.loose_compression_level)
+            )
 
 
     @classmethod
     @classmethod
     def init(cls, path):
     def init(cls, path):
@@ -904,9 +983,7 @@ class MemoryObjectStore(BaseObjectStore):
         del self._data[self._to_hexsha(name)]
         del self._data[self._to_hexsha(name)]
 
 
     def add_object(self, obj):
     def add_object(self, obj):
-        """Add a single object to this object store.
-
-        """
+        """Add a single object to this object store."""
         self._data[obj.id] = obj.copy()
         self._data[obj.id] = obj.copy()
 
 
     def add_objects(self, objects, progress=None):
     def add_objects(self, objects, progress=None):
@@ -937,6 +1014,7 @@ class MemoryObjectStore(BaseObjectStore):
 
 
         def abort():
         def abort():
             pass
             pass
+
         return f, commit, abort
         return f, commit, abort
 
 
     def _complete_thin_pack(self, f, indexer):
     def _complete_thin_pack(self, f, indexer):
@@ -959,8 +1037,7 @@ class MemoryObjectStore(BaseObjectStore):
         for ext_sha in indexer.ext_refs():
         for ext_sha in indexer.ext_refs():
             assert len(ext_sha) == 20
             assert len(ext_sha) == 20
             type_num, data = self.get_raw(ext_sha)
             type_num, data = self.get_raw(ext_sha)
-            write_pack_object(
-                f, type_num, data, sha=new_sha)
+            write_pack_object(f, type_num, data, sha=new_sha)
         pack_sha = new_sha.digest()
         pack_sha = new_sha.digest()
         f.write(pack_sha)
         f.write(pack_sha)
 
 
@@ -980,8 +1057,7 @@ class MemoryObjectStore(BaseObjectStore):
         f, commit, abort = self.add_pack()
         f, commit, abort = self.add_pack()
         try:
         try:
             indexer = PackIndexer(f, resolve_ext_ref=self.get_raw)
             indexer = PackIndexer(f, resolve_ext_ref=self.get_raw)
-            copier = PackStreamCopier(read_all, read_some, f,
-                                      delta_iter=indexer)
+            copier = PackStreamCopier(read_all, read_some, f, delta_iter=indexer)
             copier.verify()
             copier.verify()
             self._complete_thin_pack(f, indexer)
             self._complete_thin_pack(f, indexer)
         except BaseException:
         except BaseException:
@@ -1059,7 +1135,8 @@ class ObjectStoreIterator(ObjectIterator):
 
 
     def empty(self):
     def empty(self):
         import warnings
         import warnings
-        warnings.warn('Use bool() instead.', DeprecationWarning)
+
+        warnings.warn("Use bool() instead.", DeprecationWarning)
         return self._empty()
         return self._empty()
 
 
     def _empty(self):
     def _empty(self):
@@ -1138,7 +1215,8 @@ def _split_commits_and_tags(obj_store, lst, ignore_unknown=False):
                 tags.add(e)
                 tags.add(e)
                 tagged = o.object[1]
                 tagged = o.object[1]
                 c, t, o = _split_commits_and_tags(
                 c, t, o = _split_commits_and_tags(
-                    obj_store, [tagged], ignore_unknown=ignore_unknown)
+                    obj_store, [tagged], ignore_unknown=ignore_unknown
+                )
                 commits |= c
                 commits |= c
                 tags |= t
                 tags |= t
                 others |= o
                 others |= o
@@ -1162,8 +1240,16 @@ class MissingObjectFinder(object):
       tagged: dict of pointed-to sha -> tag sha for including tags
       tagged: dict of pointed-to sha -> tag sha for including tags
     """
     """
 
 
-    def __init__(self, object_store, haves, wants, shallow=None, progress=None,
-                 get_tagged=None, get_parents=lambda commit: commit.parents):
+    def __init__(
+        self,
+        object_store,
+        haves,
+        wants,
+        shallow=None,
+        progress=None,
+        get_tagged=None,
+        get_parents=lambda commit: commit.parents,
+    ):
         self.object_store = object_store
         self.object_store = object_store
         if shallow is None:
         if shallow is None:
             shallow = set()
             shallow = set()
@@ -1173,20 +1259,26 @@ class MissingObjectFinder(object):
         # and such SHAs would get filtered out by _split_commits_and_tags,
         # and such SHAs would get filtered out by _split_commits_and_tags,
         # wants shall list only known SHAs, and otherwise
         # wants shall list only known SHAs, and otherwise
         # _split_commits_and_tags fails with KeyError
         # _split_commits_and_tags fails with KeyError
-        have_commits, have_tags, have_others = (
-            _split_commits_and_tags(object_store, haves, True))
-        want_commits, want_tags, want_others = (
-            _split_commits_and_tags(object_store, wants, False))
+        have_commits, have_tags, have_others = _split_commits_and_tags(
+            object_store, haves, True
+        )
+        want_commits, want_tags, want_others = _split_commits_and_tags(
+            object_store, wants, False
+        )
         # all_ancestors is a set of commits that shall not be sent
         # all_ancestors is a set of commits that shall not be sent
         # (complete repository up to 'haves')
         # (complete repository up to 'haves')
         all_ancestors = object_store._collect_ancestors(
         all_ancestors = object_store._collect_ancestors(
-            have_commits, shallow=shallow, get_parents=self._get_parents)[0]
+            have_commits, shallow=shallow, get_parents=self._get_parents
+        )[0]
         # all_missing - complete set of commits between haves and wants
         # all_missing - complete set of commits between haves and wants
         # common - commits from all_ancestors we hit into while
         # common - commits from all_ancestors we hit into while
         # traversing parent hierarchy of wants
         # traversing parent hierarchy of wants
         missing_commits, common_commits = object_store._collect_ancestors(
         missing_commits, common_commits = object_store._collect_ancestors(
-            want_commits, all_ancestors, shallow=shallow,
-            get_parents=self._get_parents)
+            want_commits,
+            all_ancestors,
+            shallow=shallow,
+            get_parents=self._get_parents,
+        )
         self.sha_done = set()
         self.sha_done = set()
         # Now, fill sha_done with commits and revisions of
         # Now, fill sha_done with commits and revisions of
         # files and directories known to be both locally
         # files and directories known to be both locally
@@ -1216,8 +1308,7 @@ class MissingObjectFinder(object):
         self._tagged = get_tagged and get_tagged() or {}
         self._tagged = get_tagged and get_tagged() or {}
 
 
     def add_todo(self, entries):
     def add_todo(self, entries):
-        self.objects_to_send.update([e for e in entries
-                                     if not e[0] in self.sha_done])
+        self.objects_to_send.update([e for e in entries if not e[0] in self.sha_done])
 
 
     def next(self):
     def next(self):
         while True:
         while True:
@@ -1231,16 +1322,19 @@ class MissingObjectFinder(object):
             if isinstance(o, Commit):
             if isinstance(o, Commit):
                 self.add_todo([(o.tree, "", False)])
                 self.add_todo([(o.tree, "", False)])
             elif isinstance(o, Tree):
             elif isinstance(o, Tree):
-                self.add_todo([(s, n, not stat.S_ISDIR(m))
-                               for n, m, s in o.iteritems()
-                               if not S_ISGITLINK(m)])
+                self.add_todo(
+                    [
+                        (s, n, not stat.S_ISDIR(m))
+                        for n, m, s in o.iteritems()
+                        if not S_ISGITLINK(m)
+                    ]
+                )
             elif isinstance(o, Tag):
             elif isinstance(o, Tag):
                 self.add_todo([(o.object[1], None, False)])
                 self.add_todo([(o.object[1], None, False)])
         if sha in self._tagged:
         if sha in self._tagged:
             self.add_todo([(self._tagged[sha], None, True)])
             self.add_todo([(self._tagged[sha], None, True)])
         self.sha_done.add(sha)
         self.sha_done.add(sha)
-        self.progress(("counting objects: %d\r" %
-                       len(self.sha_done)).encode('ascii'))
+        self.progress(("counting objects: %d\r" % len(self.sha_done)).encode("ascii"))
         return (sha, name)
         return (sha, name)
 
 
     __next__ = next
     __next__ = next
@@ -1297,10 +1391,12 @@ class ObjectStoreGraphWalker(object):
         """Iterate over ancestors of heads in the target."""
         """Iterate over ancestors of heads in the target."""
         if self.heads:
         if self.heads:
             ret = self.heads.pop()
             ret = self.heads.pop()
-            ps = self.get_parents(ret)
+            try:
+                ps = self.get_parents(ret)
+            except KeyError:
+                return None
             self.parents[ret] = ps
             self.parents[ret] = ps
-            self.heads.update(
-                [p for p in ps if p not in self.parents])
+            self.heads.update([p for p in ps if p not in self.parents])
             return ret
             return ret
         return None
         return None
 
 
@@ -1333,15 +1429,14 @@ def commit_tree_changes(object_store, tree, changes):
     nested_changes = {}
     nested_changes = {}
     for (path, new_mode, new_sha) in changes:
     for (path, new_mode, new_sha) in changes:
         try:
         try:
-            (dirname, subpath) = path.split(b'/', 1)
+            (dirname, subpath) = path.split(b"/", 1)
         except ValueError:
         except ValueError:
             if new_sha is None:
             if new_sha is None:
                 del tree[path]
                 del tree[path]
             else:
             else:
                 tree[path] = (new_mode, new_sha)
                 tree[path] = (new_mode, new_sha)
         else:
         else:
-            nested_changes.setdefault(dirname, []).append(
-                (subpath, new_mode, new_sha))
+            nested_changes.setdefault(dirname, []).append((subpath, new_mode, new_sha))
     for name, subchanges in nested_changes.items():
     for name, subchanges in nested_changes.items():
         try:
         try:
             orig_subtree = object_store[tree[name][1]]
             orig_subtree = object_store[tree[name][1]]
@@ -1418,3 +1513,85 @@ def read_packs_file(f):
         if kind != b"P":
         if kind != b"P":
             continue
             continue
         yield os.fsdecode(name)
         yield os.fsdecode(name)
+
+
+class BucketBasedObjectStore(PackBasedObjectStore):
+    """Object store implementation that uses a bucket store like S3 as backend.
+    """
+
+    def _iter_loose_objects(self):
+        """Iterate over the SHAs of all loose objects."""
+        return iter([])
+
+    def _get_loose_object(self, sha):
+        return None
+
+    def _remove_loose_object(self, sha):
+        # Doesn't exist..
+        pass
+
+    def _remove_pack(self, name):
+        raise NotImplementedError(self._remove_pack)
+
+    def _iter_pack_names(self):
+        raise NotImplementedError(self._iter_pack_names)
+
+    def _get_pack(self, name):
+        raise NotImplementedError(self._get_pack)
+
+    def _update_pack_cache(self):
+        pack_files = set(self._iter_pack_names())
+
+        # Open newly appeared pack files
+        new_packs = []
+        for f in pack_files:
+            if f not in self._pack_cache:
+                pack = self._get_pack(f)
+                new_packs.append(pack)
+                self._pack_cache[f] = pack
+        # Remove disappeared pack files
+        for f in set(self._pack_cache) - pack_files:
+            self._pack_cache.pop(f).close()
+        return new_packs
+
+    def _upload_pack(self, basename, pack_file, index_file):
+        raise NotImplementedError
+
+    def add_pack(self):
+        """Add a new pack to this object store.
+
+        Returns: Fileobject to write to, a commit function to
+            call when the pack is finished and an abort
+            function.
+        """
+        import tempfile
+
+        pf = tempfile.SpooledTemporaryFile()
+
+        def commit():
+            if pf.tell() == 0:
+                pf.close()
+                return None
+
+            pf.seek(0)
+            p = PackData(pf.name, pf)
+            entries = p.sorted_entries()
+            basename = iter_sha1(entry[0] for entry in entries).decode('ascii')
+            idxf = tempfile.SpooledTemporaryFile()
+            checksum = p.get_stored_checksum()
+            write_pack_index_v2(idxf, entries, checksum)
+            idxf.seek(0)
+            idx = load_pack_index_file(basename + '.idx', idxf)
+            for pack in self.packs:
+                if pack.get_stored_checksum() == p.get_stored_checksum():
+                    p.close()
+                    idx.close()
+                    return pack
+            pf.seek(0)
+            idxf.seek(0)
+            self._upload_pack(basename, pf, idxf)
+            final_pack = Pack.from_objects(p, idx)
+            self._add_cached_pack(basename, final_pack)
+            return final_pack
+
+        return pf, commit, pf.close

Файлын зөрүү хэтэрхий том тул дарагдсан байна
+ 330 - 213
dulwich/objects.py


+ 3 - 4
dulwich/objectspec.py

@@ -25,7 +25,7 @@ from typing import Union, List, Tuple
 
 
 def to_bytes(text):
 def to_bytes(text):
     if getattr(text, "encode", None) is not None:
     if getattr(text, "encode", None) is not None:
-        text = text.encode('ascii')
+        text = text.encode("ascii")
     return text
     return text
 
 
 
 
@@ -77,7 +77,7 @@ def parse_ref(container, refspec):
         b"refs/tags/" + refspec,
         b"refs/tags/" + refspec,
         b"refs/heads/" + refspec,
         b"refs/heads/" + refspec,
         b"refs/remotes/" + refspec,
         b"refs/remotes/" + refspec,
-        b"refs/remotes/" + refspec + b"/HEAD"
+        b"refs/remotes/" + refspec + b"/HEAD",
     ]
     ]
     for ref in possible_refs:
     for ref in possible_refs:
         if ref in container:
         if ref in container:
@@ -140,8 +140,7 @@ def parse_reftuples(
     ret = []
     ret = []
     # TODO: Support * in refspecs
     # TODO: Support * in refspecs
     for refspec in refspecs:
     for refspec in refspecs:
-        ret.append(parse_reftuple(
-            lh_container, rh_container, refspec, force=force))
+        ret.append(parse_reftuple(lh_container, rh_container, refspec, force=force))
     return ret
     return ret
 
 
 
 

Файлын зөрүү хэтэрхий том тул дарагдсан байна
+ 230 - 189
dulwich/pack.py


+ 100 - 69
dulwich/patch.py

@@ -32,13 +32,12 @@ from dulwich.objects import (
     Blob,
     Blob,
     Commit,
     Commit,
     S_ISGITLINK,
     S_ISGITLINK,
-    )
+)
 
 
 FIRST_FEW_BYTES = 8000
 FIRST_FEW_BYTES = 8000
 
 
 
 
-def write_commit_patch(f, commit, contents, progress, version=None,
-                       encoding=None):
+def write_commit_patch(f, commit, contents, progress, version=None, encoding=None):
     """Write a individual file patch.
     """Write a individual file patch.
 
 
     Args:
     Args:
@@ -51,19 +50,30 @@ def write_commit_patch(f, commit, contents, progress, version=None,
     if isinstance(contents, str):
     if isinstance(contents, str):
         contents = contents.encode(encoding)
         contents = contents.encode(encoding)
     (num, total) = progress
     (num, total) = progress
-    f.write(b"From " + commit.id + b" " +
-            time.ctime(commit.commit_time).encode(encoding) + b"\n")
+    f.write(
+        b"From "
+        + commit.id
+        + b" "
+        + time.ctime(commit.commit_time).encode(encoding)
+        + b"\n"
+    )
     f.write(b"From: " + commit.author + b"\n")
     f.write(b"From: " + commit.author + b"\n")
-    f.write(b"Date: " +
-            time.strftime("%a, %d %b %Y %H:%M:%S %Z").encode(encoding) + b"\n")
-    f.write(("Subject: [PATCH %d/%d] " % (num, total)).encode(encoding) +
-            commit.message + b"\n")
+    f.write(
+        b"Date: " + time.strftime("%a, %d %b %Y %H:%M:%S %Z").encode(encoding) + b"\n"
+    )
+    f.write(
+        ("Subject: [PATCH %d/%d] " % (num, total)).encode(encoding)
+        + commit.message
+        + b"\n"
+    )
     f.write(b"\n")
     f.write(b"\n")
     f.write(b"---\n")
     f.write(b"---\n")
     try:
     try:
         import subprocess
         import subprocess
-        p = subprocess.Popen(["diffstat"], stdout=subprocess.PIPE,
-                             stdin=subprocess.PIPE)
+
+        p = subprocess.Popen(
+            ["diffstat"], stdout=subprocess.PIPE, stdin=subprocess.PIPE
+        )
     except (ImportError, OSError):
     except (ImportError, OSError):
         pass  # diffstat not available?
         pass  # diffstat not available?
     else:
     else:
@@ -74,6 +84,7 @@ def write_commit_patch(f, commit, contents, progress, version=None,
     f.write(b"-- \n")
     f.write(b"-- \n")
     if version is None:
     if version is None:
         from dulwich import __version__ as dulwich_version
         from dulwich import __version__ as dulwich_version
+
         f.write(b"Dulwich %d.%d.%d\n" % dulwich_version)
         f.write(b"Dulwich %d.%d.%d\n" % dulwich_version)
     else:
     else:
         f.write(version.encode(encoding) + b"\n")
         f.write(version.encode(encoding) + b"\n")
@@ -86,7 +97,7 @@ def get_summary(commit):
       commit: Commit
       commit: Commit
     Returns: Summary string
     Returns: Summary string
     """
     """
-    decoded = commit.message.decode(errors='replace')
+    decoded = commit.message.decode(errors="replace")
     return decoded.splitlines()[0].replace(" ", "-")
     return decoded.splitlines()[0].replace(" ", "-")
 
 
 
 
@@ -97,15 +108,24 @@ def _format_range_unified(start, stop):
     beginning = start + 1  # lines start numbering with one
     beginning = start + 1  # lines start numbering with one
     length = stop - start
     length = stop - start
     if length == 1:
     if length == 1:
-        return '{}'.format(beginning)
+        return "{}".format(beginning)
     if not length:
     if not length:
         beginning -= 1  # empty ranges begin at line just before the range
         beginning -= 1  # empty ranges begin at line just before the range
-    return '{},{}'.format(beginning, length)
-
-
-def unified_diff(a, b, fromfile='', tofile='', fromfiledate='',
-                 tofiledate='', n=3, lineterm='\n', tree_encoding='utf-8',
-                 output_encoding='utf-8'):
+    return "{},{}".format(beginning, length)
+
+
+def unified_diff(
+    a,
+    b,
+    fromfile="",
+    tofile="",
+    fromfiledate="",
+    tofiledate="",
+    n=3,
+    lineterm="\n",
+    tree_encoding="utf-8",
+    output_encoding="utf-8",
+):
     """difflib.unified_diff that can detect "No newline at end of file" as
     """difflib.unified_diff that can detect "No newline at end of file" as
     original "git diff" does.
     original "git diff" does.
 
 
@@ -115,43 +135,37 @@ def unified_diff(a, b, fromfile='', tofile='', fromfiledate='',
     for group in SequenceMatcher(None, a, b).get_grouped_opcodes(n):
     for group in SequenceMatcher(None, a, b).get_grouped_opcodes(n):
         if not started:
         if not started:
             started = True
             started = True
-            fromdate = '\t{}'.format(fromfiledate) if fromfiledate else ''
-            todate = '\t{}'.format(tofiledate) if tofiledate else ''
-            yield '--- {}{}{}'.format(
-                fromfile.decode(tree_encoding),
-                fromdate,
-                lineterm
-                ).encode(output_encoding)
-            yield '+++ {}{}{}'.format(
-                tofile.decode(tree_encoding),
-                todate,
-                lineterm
-                ).encode(output_encoding)
+            fromdate = "\t{}".format(fromfiledate) if fromfiledate else ""
+            todate = "\t{}".format(tofiledate) if tofiledate else ""
+            yield "--- {}{}{}".format(
+                fromfile.decode(tree_encoding), fromdate, lineterm
+            ).encode(output_encoding)
+            yield "+++ {}{}{}".format(
+                tofile.decode(tree_encoding), todate, lineterm
+            ).encode(output_encoding)
 
 
         first, last = group[0], group[-1]
         first, last = group[0], group[-1]
         file1_range = _format_range_unified(first[1], last[2])
         file1_range = _format_range_unified(first[1], last[2])
         file2_range = _format_range_unified(first[3], last[4])
         file2_range = _format_range_unified(first[3], last[4])
-        yield '@@ -{} +{} @@{}'.format(
-            file1_range,
-            file2_range,
-            lineterm
-             ).encode(output_encoding)
+        yield "@@ -{} +{} @@{}".format(file1_range, file2_range, lineterm).encode(
+            output_encoding
+        )
 
 
         for tag, i1, i2, j1, j2 in group:
         for tag, i1, i2, j1, j2 in group:
-            if tag == 'equal':
+            if tag == "equal":
                 for line in a[i1:i2]:
                 for line in a[i1:i2]:
-                    yield b' ' + line
+                    yield b" " + line
                 continue
                 continue
-            if tag in ('replace', 'delete'):
+            if tag in ("replace", "delete"):
                 for line in a[i1:i2]:
                 for line in a[i1:i2]:
-                    if not line[-1:] == b'\n':
-                        line += b'\n\\ No newline at end of file\n'
-                    yield b'-' + line
-            if tag in ('replace', 'insert'):
+                    if not line[-1:] == b"\n":
+                        line += b"\n\\ No newline at end of file\n"
+                    yield b"-" + line
+            if tag in ("replace", "insert"):
                 for line in b[j1:j2]:
                 for line in b[j1:j2]:
-                    if not line[-1:] == b'\n':
-                        line += b'\n\\ No newline at end of file\n'
-                    yield b'+' + line
+                    if not line[-1:] == b"\n":
+                        line += b"\n\\ No newline at end of file\n"
+                    yield b"+" + line
 
 
 
 
 def is_binary(content):
 def is_binary(content):
@@ -160,7 +174,7 @@ def is_binary(content):
     Args:
     Args:
       content: Bytestring to check for binary content
       content: Bytestring to check for binary content
     """
     """
-    return b'\0' in content[:FIRST_FEW_BYTES]
+    return b"\0" in content[:FIRST_FEW_BYTES]
 
 
 
 
 def shortid(hexsha):
 def shortid(hexsha):
@@ -197,7 +211,7 @@ def write_object_diff(f, store, old_file, new_file, diff_binary=False):
 
 
     def content(mode, hexsha):
     def content(mode, hexsha):
         if hexsha is None:
         if hexsha is None:
-            return Blob.from_string(b'')
+            return Blob.from_string(b"")
         elif S_ISGITLINK(mode):
         elif S_ISGITLINK(mode):
             return Blob.from_string(b"Subproject commit " + hexsha + b"\n")
             return Blob.from_string(b"Subproject commit " + hexsha + b"\n")
         else:
         else:
@@ -208,12 +222,13 @@ def write_object_diff(f, store, old_file, new_file, diff_binary=False):
             return []
             return []
         else:
         else:
             return content.splitlines()
             return content.splitlines()
-    f.writelines(gen_diff_header(
-        (old_path, new_path), (old_mode, new_mode), (old_id, new_id)))
+
+    f.writelines(
+        gen_diff_header((old_path, new_path), (old_mode, new_mode), (old_id, new_id))
+    )
     old_content = content(old_mode, old_id)
     old_content = content(old_mode, old_id)
     new_content = content(new_mode, new_id)
     new_content = content(new_mode, new_id)
-    if not diff_binary and (
-            is_binary(old_content.data) or is_binary(new_content.data)):
+    if not diff_binary and (is_binary(old_content.data) or is_binary(new_content.data)):
         binary_diff = (
         binary_diff = (
             b"Binary files "
             b"Binary files "
             + patched_old_path
             + patched_old_path
@@ -223,8 +238,14 @@ def write_object_diff(f, store, old_file, new_file, diff_binary=False):
         )
         )
         f.write(binary_diff)
         f.write(binary_diff)
     else:
     else:
-        f.writelines(unified_diff(lines(old_content), lines(new_content),
-                     patched_old_path, patched_new_path))
+        f.writelines(
+            unified_diff(
+                lines(old_content),
+                lines(new_content),
+                patched_old_path,
+                patched_new_path,
+            )
+        )
 
 
 
 
 # TODO(jelmer): Support writing unicode, rather than bytes.
 # TODO(jelmer): Support writing unicode, rather than bytes.
@@ -250,13 +271,13 @@ def gen_diff_header(paths, modes, shas):
     if old_mode != new_mode:
     if old_mode != new_mode:
         if new_mode is not None:
         if new_mode is not None:
             if old_mode is not None:
             if old_mode is not None:
-                yield ("old file mode %o\n" % old_mode).encode('ascii')
-            yield ("new file mode %o\n" % new_mode).encode('ascii')
+                yield ("old file mode %o\n" % old_mode).encode("ascii")
+            yield ("new file mode %o\n" % new_mode).encode("ascii")
         else:
         else:
-            yield ("deleted file mode %o\n" % old_mode).encode('ascii')
+            yield ("deleted file mode %o\n" % old_mode).encode("ascii")
     yield b"index " + shortid(old_sha) + b".." + shortid(new_sha)
     yield b"index " + shortid(old_sha) + b".." + shortid(new_sha)
     if new_mode is not None and old_mode is not None:
     if new_mode is not None and old_mode is not None:
-        yield (" %o" % new_mode).encode('ascii')
+        yield (" %o" % new_mode).encode("ascii")
     yield b"\n"
     yield b"\n"
 
 
 
 
@@ -281,13 +302,19 @@ def write_blob_diff(f, old_file, new_file):
             return blob.splitlines()
             return blob.splitlines()
         else:
         else:
             return []
             return []
-    f.writelines(gen_diff_header(
-        (old_path, new_path), (old_mode, new_mode),
-        (getattr(old_blob, "id", None), getattr(new_blob, "id", None))))
+
+    f.writelines(
+        gen_diff_header(
+            (old_path, new_path),
+            (old_mode, new_mode),
+            (getattr(old_blob, "id", None), getattr(new_blob, "id", None)),
+        )
+    )
     old_contents = lines(old_blob)
     old_contents = lines(old_blob)
     new_contents = lines(new_blob)
     new_contents = lines(new_blob)
-    f.writelines(unified_diff(old_contents, new_contents,
-                 patched_old_path, patched_new_path))
+    f.writelines(
+        unified_diff(old_contents, new_contents, patched_old_path, patched_new_path)
+    )
 
 
 
 
 def write_tree_diff(f, store, old_tree, new_tree, diff_binary=False):
 def write_tree_diff(f, store, old_tree, new_tree, diff_binary=False):
@@ -302,8 +329,13 @@ def write_tree_diff(f, store, old_tree, new_tree, diff_binary=False):
     """
     """
     changes = store.tree_changes(old_tree, new_tree)
     changes = store.tree_changes(old_tree, new_tree)
     for (oldpath, newpath), (oldmode, newmode), (oldsha, newsha) in changes:
     for (oldpath, newpath), (oldmode, newmode), (oldsha, newsha) in changes:
-        write_object_diff(f, store, (oldpath, oldmode, oldsha),
-                          (newpath, newmode, newsha), diff_binary=diff_binary)
+        write_object_diff(
+            f,
+            store,
+            (oldpath, oldmode, oldsha),
+            (newpath, newmode, newsha),
+            diff_binary=diff_binary,
+        )
 
 
 
 
 def git_am_patch_split(f, encoding=None):
 def git_am_patch_split(f, encoding=None):
@@ -317,8 +349,7 @@ def git_am_patch_split(f, encoding=None):
     encoding = encoding or getattr(f, "encoding", "ascii")
     encoding = encoding or getattr(f, "encoding", "ascii")
     encoding = encoding or "ascii"
     encoding = encoding or "ascii"
     contents = f.read()
     contents = f.read()
-    if (isinstance(contents, bytes) and
-            getattr(email.parser, "BytesParser", None)):
+    if isinstance(contents, bytes) and getattr(email.parser, "BytesParser", None):
         parser = email.parser.BytesParser()
         parser = email.parser.BytesParser()
         msg = parser.parsebytes(contents)
         msg = parser.parsebytes(contents)
     else:
     else:
@@ -344,7 +375,7 @@ def parse_patch_message(msg, encoding=None):
         subject = msg["subject"]
         subject = msg["subject"]
     else:
     else:
         close = msg["subject"].index("] ", patch_tag_start)
         close = msg["subject"].index("] ", patch_tag_start)
-        subject = msg["subject"][close+2:]
+        subject = msg["subject"][close + 2 :]
     c.message = (subject.replace("\n", "") + "\n").encode(encoding)
     c.message = (subject.replace("\n", "") + "\n").encode(encoding)
     first = True
     first = True
 
 
@@ -357,7 +388,7 @@ def parse_patch_message(msg, encoding=None):
             break
             break
         if first:
         if first:
             if line.startswith(b"From: "):
             if line.startswith(b"From: "):
-                c.author = line[len(b"From: "):].rstrip()
+                c.author = line[len(b"From: ") :].rstrip()
             else:
             else:
                 c.message += b"\n" + line
                 c.message += b"\n" + line
             first = False
             first = False

Файлын зөрүү хэтэрхий том тул дарагдсан байна
+ 320 - 186
dulwich/porcelain.py


+ 84 - 75
dulwich/protocol.py

@@ -24,14 +24,14 @@
 from io import BytesIO
 from io import BytesIO
 from os import (
 from os import (
     SEEK_END,
     SEEK_END,
-    )
+)
 import socket
 import socket
 
 
 import dulwich
 import dulwich
 from dulwich.errors import (
 from dulwich.errors import (
     HangupException,
     HangupException,
     GitProtocolError,
     GitProtocolError,
-    )
+)
 
 
 TCP_GIT_PORT = 9418
 TCP_GIT_PORT = 9418
 
 
@@ -48,77 +48,86 @@ SIDE_BAND_CHANNEL_PROGRESS = 2
 # fatal error message just before stream aborts
 # fatal error message just before stream aborts
 SIDE_BAND_CHANNEL_FATAL = 3
 SIDE_BAND_CHANNEL_FATAL = 3
 
 
-CAPABILITY_ATOMIC = b'atomic'
-CAPABILITY_DEEPEN_SINCE = b'deepen-since'
-CAPABILITY_DEEPEN_NOT = b'deepen-not'
-CAPABILITY_DEEPEN_RELATIVE = b'deepen-relative'
-CAPABILITY_DELETE_REFS = b'delete-refs'
-CAPABILITY_INCLUDE_TAG = b'include-tag'
-CAPABILITY_MULTI_ACK = b'multi_ack'
-CAPABILITY_MULTI_ACK_DETAILED = b'multi_ack_detailed'
-CAPABILITY_NO_DONE = b'no-done'
-CAPABILITY_NO_PROGRESS = b'no-progress'
-CAPABILITY_OFS_DELTA = b'ofs-delta'
-CAPABILITY_QUIET = b'quiet'
-CAPABILITY_REPORT_STATUS = b'report-status'
-CAPABILITY_SHALLOW = b'shallow'
-CAPABILITY_SIDE_BAND = b'side-band'
-CAPABILITY_SIDE_BAND_64K = b'side-band-64k'
-CAPABILITY_THIN_PACK = b'thin-pack'
-CAPABILITY_AGENT = b'agent'
-CAPABILITY_SYMREF = b'symref'
-CAPABILITY_ALLOW_TIP_SHA1_IN_WANT = b'allow-tip-sha1-in-want'
-CAPABILITY_ALLOW_REACHABLE_SHA1_IN_WANT = b'allow-reachable-sha1-in-want'
+CAPABILITY_ATOMIC = b"atomic"
+CAPABILITY_DEEPEN_SINCE = b"deepen-since"
+CAPABILITY_DEEPEN_NOT = b"deepen-not"
+CAPABILITY_DEEPEN_RELATIVE = b"deepen-relative"
+CAPABILITY_DELETE_REFS = b"delete-refs"
+CAPABILITY_INCLUDE_TAG = b"include-tag"
+CAPABILITY_MULTI_ACK = b"multi_ack"
+CAPABILITY_MULTI_ACK_DETAILED = b"multi_ack_detailed"
+CAPABILITY_NO_DONE = b"no-done"
+CAPABILITY_NO_PROGRESS = b"no-progress"
+CAPABILITY_OFS_DELTA = b"ofs-delta"
+CAPABILITY_QUIET = b"quiet"
+CAPABILITY_REPORT_STATUS = b"report-status"
+CAPABILITY_SHALLOW = b"shallow"
+CAPABILITY_SIDE_BAND = b"side-band"
+CAPABILITY_SIDE_BAND_64K = b"side-band-64k"
+CAPABILITY_THIN_PACK = b"thin-pack"
+CAPABILITY_AGENT = b"agent"
+CAPABILITY_SYMREF = b"symref"
+CAPABILITY_ALLOW_TIP_SHA1_IN_WANT = b"allow-tip-sha1-in-want"
+CAPABILITY_ALLOW_REACHABLE_SHA1_IN_WANT = b"allow-reachable-sha1-in-want"
 
 
 # Magic ref that is used to attach capabilities to when
 # Magic ref that is used to attach capabilities to when
 # there are no refs. Should always be ste to ZERO_SHA.
 # there are no refs. Should always be ste to ZERO_SHA.
-CAPABILITIES_REF = b'capabilities^{}'
+CAPABILITIES_REF = b"capabilities^{}"
 
 
 COMMON_CAPABILITIES = [
 COMMON_CAPABILITIES = [
     CAPABILITY_OFS_DELTA,
     CAPABILITY_OFS_DELTA,
     CAPABILITY_SIDE_BAND,
     CAPABILITY_SIDE_BAND,
     CAPABILITY_SIDE_BAND_64K,
     CAPABILITY_SIDE_BAND_64K,
     CAPABILITY_AGENT,
     CAPABILITY_AGENT,
-    CAPABILITY_NO_PROGRESS]
-KNOWN_UPLOAD_CAPABILITIES = set(COMMON_CAPABILITIES + [
-    CAPABILITY_THIN_PACK,
-    CAPABILITY_MULTI_ACK,
-    CAPABILITY_MULTI_ACK_DETAILED,
-    CAPABILITY_INCLUDE_TAG,
-    CAPABILITY_DEEPEN_SINCE,
-    CAPABILITY_SYMREF,
-    CAPABILITY_SHALLOW,
-    CAPABILITY_DEEPEN_NOT,
-    CAPABILITY_DEEPEN_RELATIVE,
-    CAPABILITY_ALLOW_TIP_SHA1_IN_WANT,
-    CAPABILITY_ALLOW_REACHABLE_SHA1_IN_WANT,
-    ])
-KNOWN_RECEIVE_CAPABILITIES = set(COMMON_CAPABILITIES + [
-    CAPABILITY_REPORT_STATUS,
-    CAPABILITY_DELETE_REFS,
-    CAPABILITY_QUIET,
-    CAPABILITY_ATOMIC,
-    ])
+    CAPABILITY_NO_PROGRESS,
+]
+KNOWN_UPLOAD_CAPABILITIES = set(
+    COMMON_CAPABILITIES
+    + [
+        CAPABILITY_THIN_PACK,
+        CAPABILITY_MULTI_ACK,
+        CAPABILITY_MULTI_ACK_DETAILED,
+        CAPABILITY_INCLUDE_TAG,
+        CAPABILITY_DEEPEN_SINCE,
+        CAPABILITY_SYMREF,
+        CAPABILITY_SHALLOW,
+        CAPABILITY_DEEPEN_NOT,
+        CAPABILITY_DEEPEN_RELATIVE,
+        CAPABILITY_ALLOW_TIP_SHA1_IN_WANT,
+        CAPABILITY_ALLOW_REACHABLE_SHA1_IN_WANT,
+    ]
+)
+KNOWN_RECEIVE_CAPABILITIES = set(
+    COMMON_CAPABILITIES
+    + [
+        CAPABILITY_REPORT_STATUS,
+        CAPABILITY_DELETE_REFS,
+        CAPABILITY_QUIET,
+        CAPABILITY_ATOMIC,
+    ]
+)
+
+DEPTH_INFINITE = 0x7FFFFFFF
 
 
 
 
 def agent_string():
 def agent_string():
-    return ('dulwich/%d.%d.%d' % dulwich.__version__).encode('ascii')
+    return ("dulwich/%d.%d.%d" % dulwich.__version__).encode("ascii")
 
 
 
 
 def capability_agent():
 def capability_agent():
-    return CAPABILITY_AGENT + b'=' + agent_string()
+    return CAPABILITY_AGENT + b"=" + agent_string()
 
 
 
 
 def capability_symref(from_ref, to_ref):
 def capability_symref(from_ref, to_ref):
-    return CAPABILITY_SYMREF + b'=' + from_ref + b':' + to_ref
+    return CAPABILITY_SYMREF + b"=" + from_ref + b":" + to_ref
 
 
 
 
 def extract_capability_names(capabilities):
 def extract_capability_names(capabilities):
-    return set(parse_capability(c)[0] for c in capabilities)
+    return {parse_capability(c)[0] for c in capabilities}
 
 
 
 
 def parse_capability(capability):
 def parse_capability(capability):
-    parts = capability.split(b'=', 1)
+    parts = capability.split(b"=", 1)
     if len(parts) == 1:
     if len(parts) == 1:
         return (parts[0], None)
         return (parts[0], None)
     return tuple(parts)
     return tuple(parts)
@@ -128,12 +137,12 @@ def symref_capabilities(symrefs):
     return [capability_symref(*k) for k in symrefs]
     return [capability_symref(*k) for k in symrefs]
 
 
 
 
-COMMAND_DEEPEN = b'deepen'
-COMMAND_SHALLOW = b'shallow'
-COMMAND_UNSHALLOW = b'unshallow'
-COMMAND_DONE = b'done'
-COMMAND_WANT = b'want'
-COMMAND_HAVE = b'have'
+COMMAND_DEEPEN = b"deepen"
+COMMAND_SHALLOW = b"shallow"
+COMMAND_UNSHALLOW = b"unshallow"
+COMMAND_DONE = b"done"
+COMMAND_WANT = b"want"
+COMMAND_HAVE = b"have"
 
 
 
 
 class ProtocolFile(object):
 class ProtocolFile(object):
@@ -156,7 +165,7 @@ def format_cmd_pkt(cmd, *args):
 
 
 def parse_cmd_pkt(line):
 def parse_cmd_pkt(line):
     splice_at = line.find(b" ")
     splice_at = line.find(b" ")
-    cmd, args = line[:splice_at], line[splice_at+1:]
+    cmd, args = line[:splice_at], line[splice_at + 1 :]
     assert args[-1:] == b"\x00"
     assert args[-1:] == b"\x00"
     return cmd, args[:-1].split(b"\0")
     return cmd, args[:-1].split(b"\0")
 
 
@@ -170,8 +179,8 @@ def pkt_line(data):
         None, returns the flush-pkt ('0000').
         None, returns the flush-pkt ('0000').
     """
     """
     if data is None:
     if data is None:
-        return b'0000'
-    return ('%04x' % (len(data) + 4)).encode('ascii') + data
+        return b"0000"
+    return ("%04x" % (len(data) + 4)).encode("ascii") + data
 
 
 
 
 class Protocol(object):
 class Protocol(object):
@@ -224,18 +233,19 @@ class Protocol(object):
             size = int(sizestr, 16)
             size = int(sizestr, 16)
             if size == 0:
             if size == 0:
                 if self.report_activity:
                 if self.report_activity:
-                    self.report_activity(4, 'read')
+                    self.report_activity(4, "read")
                 return None
                 return None
             if self.report_activity:
             if self.report_activity:
-                self.report_activity(size, 'read')
-            pkt_contents = read(size-4)
+                self.report_activity(size, "read")
+            pkt_contents = read(size - 4)
         except socket.error as e:
         except socket.error as e:
             raise GitProtocolError(e)
             raise GitProtocolError(e)
         else:
         else:
             if len(pkt_contents) + 4 != size:
             if len(pkt_contents) + 4 != size:
                 raise GitProtocolError(
                 raise GitProtocolError(
-                    'Length of pkt read %04x does not match length prefix %04x'
-                    % (len(pkt_contents) + 4, size))
+                    "Length of pkt read %04x does not match length prefix %04x"
+                    % (len(pkt_contents) + 4, size)
+                )
             return pkt_contents
             return pkt_contents
 
 
     def eof(self):
     def eof(self):
@@ -265,7 +275,7 @@ class Protocol(object):
           ValueError: If more than one pkt-line is unread.
           ValueError: If more than one pkt-line is unread.
         """
         """
         if self._readahead is not None:
         if self._readahead is not None:
-            raise ValueError('Attempted to unread multiple pkt-lines.')
+            raise ValueError("Attempted to unread multiple pkt-lines.")
         self._readahead = BytesIO(pkt_line(data))
         self._readahead = BytesIO(pkt_line(data))
 
 
     def read_pkt_seq(self):
     def read_pkt_seq(self):
@@ -290,7 +300,7 @@ class Protocol(object):
             line = pkt_line(line)
             line = pkt_line(line)
             self.write(line)
             self.write(line)
             if self.report_activity:
             if self.report_activity:
-                self.report_activity(len(line), 'write')
+                self.report_activity(len(line), "write")
         except socket.error as e:
         except socket.error as e:
             raise GitProtocolError(e)
             raise GitProtocolError(e)
 
 
@@ -298,7 +308,6 @@ class Protocol(object):
         """Return a writable file-like object for this protocol."""
         """Return a writable file-like object for this protocol."""
 
 
         class ProtocolFile(object):
         class ProtocolFile(object):
-
             def __init__(self, proto):
             def __init__(self, proto):
                 self._proto = proto
                 self._proto = proto
                 self._offset = 0
                 self._offset = 0
@@ -366,10 +375,12 @@ class ReceivableProtocol(Protocol):
     will still block until at least one byte is read.
     will still block until at least one byte is read.
     """
     """
 
 
-    def __init__(self, recv, write, close=None, report_activity=None,
-                 rbufsize=_RBUFSIZE):
+    def __init__(
+        self, recv, write, close=None, report_activity=None, rbufsize=_RBUFSIZE
+    ):
         super(ReceivableProtocol, self).__init__(
         super(ReceivableProtocol, self).__init__(
-                self.read, write, close=close, report_activity=report_activity)
+            self.read, write, close=close, report_activity=report_activity
+        )
         self._recv = recv
         self._recv = recv
         self._rbuf = BytesIO()
         self._rbuf = BytesIO()
         self._rbufsize = rbufsize
         self._rbufsize = rbufsize
@@ -492,9 +503,9 @@ def extract_want_line_capabilities(text):
 
 
 def ack_type(capabilities):
 def ack_type(capabilities):
     """Extract the ack type from a capabilities list."""
     """Extract the ack type from a capabilities list."""
-    if b'multi_ack_detailed' in capabilities:
+    if b"multi_ack_detailed" in capabilities:
         return MULTI_ACK_DETAILED
         return MULTI_ACK_DETAILED
-    elif b'multi_ack' in capabilities:
+    elif b"multi_ack" in capabilities:
         return MULTI_ACK
         return MULTI_ACK
     return SINGLE_ACK
     return SINGLE_ACK
 
 
@@ -544,16 +555,14 @@ class BufferedPktLineWriter(object):
 
 
 
 
 class PktLineParser(object):
 class PktLineParser(object):
-    """Packet line parser that hands completed packets off to a callback.
-    """
+    """Packet line parser that hands completed packets off to a callback."""
 
 
     def __init__(self, handle_pkt):
     def __init__(self, handle_pkt):
         self.handle_pkt = handle_pkt
         self.handle_pkt = handle_pkt
         self._readahead = BytesIO()
         self._readahead = BytesIO()
 
 
     def parse(self, data):
     def parse(self, data):
-        """Parse a fragment of data and call back for any completed packets.
-        """
+        """Parse a fragment of data and call back for any completed packets."""
         self._readahead.write(data)
         self._readahead.write(data)
         buf = self._readahead.getvalue()
         buf = self._readahead.getvalue()
         if len(buf) < 4:
         if len(buf) < 4:

+ 88 - 13
dulwich/reflog.py

@@ -27,15 +27,15 @@ from dulwich.objects import (
     format_timezone,
     format_timezone,
     parse_timezone,
     parse_timezone,
     ZERO_SHA,
     ZERO_SHA,
-    )
+)
 
 
 Entry = collections.namedtuple(
 Entry = collections.namedtuple(
-    'Entry', ['old_sha', 'new_sha', 'committer', 'timestamp', 'timezone',
-              'message'])
+    "Entry",
+    ["old_sha", "new_sha", "committer", "timestamp", "timezone", "message"],
+)
 
 
 
 
-def format_reflog_line(old_sha, new_sha, committer, timestamp, timezone,
-                       message):
+def format_reflog_line(old_sha, new_sha, committer, timestamp, timezone, message):
     """Generate a single reflog line.
     """Generate a single reflog line.
 
 
     Args:
     Args:
@@ -48,9 +48,19 @@ def format_reflog_line(old_sha, new_sha, committer, timestamp, timezone,
     """
     """
     if old_sha is None:
     if old_sha is None:
         old_sha = ZERO_SHA
         old_sha = ZERO_SHA
-    return (old_sha + b' ' + new_sha + b' ' + committer + b' ' +
-            str(int(timestamp)).encode('ascii') + b' ' +
-            format_timezone(timezone) + b'\t' + message)
+    return (
+        old_sha
+        + b" "
+        + new_sha
+        + b" "
+        + committer
+        + b" "
+        + str(int(timestamp)).encode("ascii")
+        + b" "
+        + format_timezone(timezone)
+        + b"\t"
+        + message
+    )
 
 
 
 
 def parse_reflog_line(line):
 def parse_reflog_line(line):
@@ -61,11 +71,17 @@ def parse_reflog_line(line):
     Returns: Tuple of (old_sha, new_sha, committer, timestamp, timezone,
     Returns: Tuple of (old_sha, new_sha, committer, timestamp, timezone,
         message)
         message)
     """
     """
-    (begin, message) = line.split(b'\t', 1)
-    (old_sha, new_sha, rest) = begin.split(b' ', 2)
-    (committer, timestamp_str, timezone_str) = rest.rsplit(b' ', 2)
-    return Entry(old_sha, new_sha, committer, int(timestamp_str),
-                 parse_timezone(timezone_str)[0], message)
+    (begin, message) = line.split(b"\t", 1)
+    (old_sha, new_sha, rest) = begin.split(b" ", 2)
+    (committer, timestamp_str, timezone_str) = rest.rsplit(b" ", 2)
+    return Entry(
+        old_sha,
+        new_sha,
+        committer,
+        int(timestamp_str),
+        parse_timezone(timezone_str)[0],
+        message,
+    )
 
 
 
 
 def read_reflog(f):
 def read_reflog(f):
@@ -77,3 +93,62 @@ def read_reflog(f):
     """
     """
     for line in f:
     for line in f:
         yield parse_reflog_line(line)
         yield parse_reflog_line(line)
+
+
+def drop_reflog_entry(f, index, rewrite=False):
+    """Drop the specified reflog entry.
+
+    Args:
+        f: File-like object
+        index: Reflog entry index (in Git reflog reverse 0-indexed order)
+        rewrite: If a reflog entry's predecessor is removed, set its
+            old SHA to the new SHA of the entry that now precedes it
+    """
+    if index < 0:
+        raise ValueError("Invalid reflog index %d" % index)
+
+    log = []
+    offset = f.tell()
+    for line in f:
+        log.append((offset, parse_reflog_line(line)))
+        offset = f.tell()
+
+    inverse_index = len(log) - index - 1
+    write_offset = log[inverse_index][0]
+    f.seek(write_offset)
+
+    if index == 0:
+        f.truncate()
+        return
+
+    del log[inverse_index]
+    if rewrite and index > 0 and log:
+        if inverse_index == 0:
+            previous_new = ZERO_SHA
+        else:
+            previous_new = log[inverse_index - 1][1].new_sha
+        offset, entry = log[inverse_index]
+        log[inverse_index] = (
+            offset,
+            Entry(
+                previous_new,
+                entry.new_sha,
+                entry.committer,
+                entry.timestamp,
+                entry.timezone,
+                entry.message,
+            ),
+        )
+
+    for _, entry in log[inverse_index:]:
+        f.write(
+            format_reflog_line(
+                entry.old_sha,
+                entry.new_sha,
+                entry.committer,
+                entry.timestamp,
+                entry.timezone,
+                entry.message,
+            )
+        )
+    f.truncate()

+ 295 - 142
dulwich/refs.py

@@ -27,23 +27,23 @@ import os
 from dulwich.errors import (
 from dulwich.errors import (
     PackedRefsException,
     PackedRefsException,
     RefFormatError,
     RefFormatError,
-    )
+)
 from dulwich.objects import (
 from dulwich.objects import (
     git_line,
     git_line,
     valid_hexsha,
     valid_hexsha,
     ZERO_SHA,
     ZERO_SHA,
-    )
+)
 from dulwich.file import (
 from dulwich.file import (
     GitFile,
     GitFile,
     ensure_dir_exists,
     ensure_dir_exists,
-    )
+)
 
 
 
 
-SYMREF = b'ref: '
-LOCAL_BRANCH_PREFIX = b'refs/heads/'
-LOCAL_TAG_PREFIX = b'refs/tags/'
-BAD_REF_CHARS = set(b'\177 ~^:?*[')
-ANNOTATED_TAG_SUFFIX = b'^{}'
+SYMREF = b"ref: "
+LOCAL_BRANCH_PREFIX = b"refs/heads/"
+LOCAL_TAG_PREFIX = b"refs/tags/"
+BAD_REF_CHARS = set(b"\177 ~^:?*[")
+ANNOTATED_TAG_SUFFIX = b"^{}"
 
 
 
 
 def parse_symref_value(contents):
 def parse_symref_value(contents):
@@ -54,7 +54,7 @@ def parse_symref_value(contents):
     Returns: Destination
     Returns: Destination
     """
     """
     if contents.startswith(SYMREF):
     if contents.startswith(SYMREF):
-        return contents[len(SYMREF):].rstrip(b'\r\n')
+        return contents[len(SYMREF) :].rstrip(b"\r\n")
     raise ValueError(contents)
     raise ValueError(contents)
 
 
 
 
@@ -72,22 +72,22 @@ def check_ref_format(refname):
     """
     """
     # These could be combined into one big expression, but are listed
     # These could be combined into one big expression, but are listed
     # separately to parallel [1].
     # separately to parallel [1].
-    if b'/.' in refname or refname.startswith(b'.'):
+    if b"/." in refname or refname.startswith(b"."):
         return False
         return False
-    if b'/' not in refname:
+    if b"/" not in refname:
         return False
         return False
-    if b'..' in refname:
+    if b".." in refname:
         return False
         return False
     for i, c in enumerate(refname):
     for i, c in enumerate(refname):
-        if ord(refname[i:i+1]) < 0o40 or c in BAD_REF_CHARS:
+        if ord(refname[i : i + 1]) < 0o40 or c in BAD_REF_CHARS:
             return False
             return False
-    if refname[-1] in b'/.':
+    if refname[-1] in b"/.":
         return False
         return False
-    if refname.endswith(b'.lock'):
+    if refname.endswith(b".lock"):
         return False
         return False
-    if b'@{' in refname:
+    if b"@{" in refname:
         return False
         return False
-    if b'\\' in refname:
+    if b"\\" in refname:
         return False
         return False
     return True
     return True
 
 
@@ -98,17 +98,31 @@ class RefsContainer(object):
     def __init__(self, logger=None):
     def __init__(self, logger=None):
         self._logger = logger
         self._logger = logger
 
 
-    def _log(self, ref, old_sha, new_sha, committer=None, timestamp=None,
-             timezone=None, message=None):
+    def _log(
+        self,
+        ref,
+        old_sha,
+        new_sha,
+        committer=None,
+        timestamp=None,
+        timezone=None,
+        message=None,
+    ):
         if self._logger is None:
         if self._logger is None:
             return
             return
         if message is None:
         if message is None:
             return
             return
-        self._logger(ref, old_sha, new_sha, committer, timestamp,
-                     timezone, message)
-
-    def set_symbolic_ref(self, name, other, committer=None, timestamp=None,
-                         timezone=None, message=None):
+        self._logger(ref, old_sha, new_sha, committer, timestamp, timezone, message)
+
+    def set_symbolic_ref(
+        self,
+        name,
+        other,
+        committer=None,
+        timestamp=None,
+        timezone=None,
+        message=None,
+    ):
         """Make a ref point at another ref.
         """Make a ref point at another ref.
 
 
         Args:
         Args:
@@ -139,8 +153,16 @@ class RefsContainer(object):
         """
         """
         return None
         return None
 
 
-    def import_refs(self, base, other, committer=None, timestamp=None,
-                    timezone=None, message=None, prune=False):
+    def import_refs(
+        self,
+        base,
+        other,
+        committer=None,
+        timestamp=None,
+        timezone=None,
+        message=None,
+        prune=False,
+    ):
         if prune:
         if prune:
             to_delete = set(self.subkeys(base))
             to_delete = set(self.subkeys(base))
         else:
         else:
@@ -149,16 +171,16 @@ class RefsContainer(object):
             if value is None:
             if value is None:
                 to_delete.add(name)
                 to_delete.add(name)
             else:
             else:
-                self.set_if_equals(b'/'.join((base, name)), None, value,
-                                   message=message)
+                self.set_if_equals(
+                    b"/".join((base, name)), None, value, message=message
+                )
             if to_delete:
             if to_delete:
                 try:
                 try:
                     to_delete.remove(name)
                     to_delete.remove(name)
                 except KeyError:
                 except KeyError:
                     pass
                     pass
         for ref in to_delete:
         for ref in to_delete:
-            self.remove_if_equals(
-                b'/'.join((base, ref)), None, message=message)
+            self.remove_if_equals(b"/".join((base, ref)), None, message=message)
 
 
     def allkeys(self):
     def allkeys(self):
         """All refs present in this container."""
         """All refs present in this container."""
@@ -196,18 +218,16 @@ class RefsContainer(object):
         return keys
         return keys
 
 
     def as_dict(self, base=None):
     def as_dict(self, base=None):
-        """Return the contents of this container as a dictionary.
-
-        """
+        """Return the contents of this container as a dictionary."""
         ret = {}
         ret = {}
         keys = self.keys(base)
         keys = self.keys(base)
         if base is None:
         if base is None:
-            base = b''
+            base = b""
         else:
         else:
-            base = base.rstrip(b'/')
+            base = base.rstrip(b"/")
         for key in keys:
         for key in keys:
             try:
             try:
-                ret[key] = self[(base + b'/' + key).strip(b'/')]
+                ret[key] = self[(base + b"/" + key).strip(b"/")]
             except KeyError:
             except KeyError:
                 continue  # Unable to resolve
                 continue  # Unable to resolve
 
 
@@ -226,9 +246,9 @@ class RefsContainer(object):
         Raises:
         Raises:
           KeyError: if a refname is not HEAD or is otherwise not valid.
           KeyError: if a refname is not HEAD or is otherwise not valid.
         """
         """
-        if name in (b'HEAD', b'refs/stash'):
+        if name in (b"HEAD", b"refs/stash"):
             return
             return
-        if not name.startswith(b'refs/') or not check_ref_format(name[5:]):
+        if not name.startswith(b"refs/") or not check_ref_format(name[5:]):
             raise RefFormatError(name)
             raise RefFormatError(name)
 
 
     def read_ref(self, refname):
     def read_ref(self, refname):
@@ -264,7 +284,7 @@ class RefsContainer(object):
         depth = 0
         depth = 0
         refnames = []
         refnames = []
         while contents.startswith(SYMREF):
         while contents.startswith(SYMREF):
-            refname = contents[len(SYMREF):]
+            refname = contents[len(SYMREF) :]
             refnames.append(refname)
             refnames.append(refname)
             contents = self.read_ref(refname)
             contents = self.read_ref(refname)
             if not contents:
             if not contents:
@@ -276,9 +296,11 @@ class RefsContainer(object):
 
 
     def _follow(self, name):
     def _follow(self, name):
         import warnings
         import warnings
+
         warnings.warn(
         warnings.warn(
-            "RefsContainer._follow is deprecated. Use RefsContainer.follow "
-            "instead.", DeprecationWarning)
+            "RefsContainer._follow is deprecated. Use RefsContainer.follow " "instead.",
+            DeprecationWarning,
+        )
         refnames, contents = self.follow(name)
         refnames, contents = self.follow(name)
         if not refnames:
         if not refnames:
             return (None, contents)
             return (None, contents)
@@ -299,8 +321,16 @@ class RefsContainer(object):
             raise KeyError(name)
             raise KeyError(name)
         return sha
         return sha
 
 
-    def set_if_equals(self, name, old_ref, new_ref, committer=None,
-                      timestamp=None, timezone=None, message=None):
+    def set_if_equals(
+        self,
+        name,
+        old_ref,
+        new_ref,
+        committer=None,
+        timestamp=None,
+        timezone=None,
+        message=None,
+    ):
         """Set a refname to new_ref only if it currently equals old_ref.
         """Set a refname to new_ref only if it currently equals old_ref.
 
 
         This method follows all symbolic references if applicable for the
         This method follows all symbolic references if applicable for the
@@ -343,8 +373,15 @@ class RefsContainer(object):
         """
         """
         self.set_if_equals(name, None, ref)
         self.set_if_equals(name, None, ref)
 
 
-    def remove_if_equals(self, name, old_ref, committer=None,
-                         timestamp=None, timezone=None, message=None):
+    def remove_if_equals(
+        self,
+        name,
+        old_ref,
+        committer=None,
+        timestamp=None,
+        timezone=None,
+        message=None,
+    ):
         """Remove a refname only if it currently equals old_ref.
         """Remove a refname only if it currently equals old_ref.
 
 
         This method does not follow symbolic references, even if applicable for
         This method does not follow symbolic references, even if applicable for
@@ -399,12 +436,12 @@ class RefsContainer(object):
 
 
 
 
 class _DictRefsWatcher(object):
 class _DictRefsWatcher(object):
-
     def __init__(self, refs):
     def __init__(self, refs):
         self._refs = refs
         self._refs = refs
 
 
     def __enter__(self):
     def __enter__(self):
         from queue import Queue
         from queue import Queue
+
         self.queue = Queue()
         self.queue = Queue()
         self._refs._watchers.add(self)
         self._refs._watchers.add(self)
         return self
         return self
@@ -449,17 +486,39 @@ class DictRefsContainer(RefsContainer):
     def watch(self):
     def watch(self):
         return _DictRefsWatcher(self)
         return _DictRefsWatcher(self)
 
 
-    def set_symbolic_ref(self, name, other, committer=None,
-                         timestamp=None, timezone=None, message=None):
+    def set_symbolic_ref(
+        self,
+        name,
+        other,
+        committer=None,
+        timestamp=None,
+        timezone=None,
+        message=None,
+    ):
         old = self.follow(name)[-1]
         old = self.follow(name)[-1]
         new = SYMREF + other
         new = SYMREF + other
         self._refs[name] = new
         self._refs[name] = new
         self._notify(name, new)
         self._notify(name, new)
-        self._log(name, old, new, committer=committer, timestamp=timestamp,
-                  timezone=timezone, message=message)
-
-    def set_if_equals(self, name, old_ref, new_ref, committer=None,
-                      timestamp=None, timezone=None, message=None):
+        self._log(
+            name,
+            old,
+            new,
+            committer=committer,
+            timestamp=timestamp,
+            timezone=timezone,
+            message=message,
+        )
+
+    def set_if_equals(
+        self,
+        name,
+        old_ref,
+        new_ref,
+        committer=None,
+        timestamp=None,
+        timezone=None,
+        message=None,
+    ):
         if old_ref is not None and self._refs.get(name, ZERO_SHA) != old_ref:
         if old_ref is not None and self._refs.get(name, ZERO_SHA) != old_ref:
             return False
             return False
         realnames, _ = self.follow(name)
         realnames, _ = self.follow(name)
@@ -468,22 +527,50 @@ class DictRefsContainer(RefsContainer):
             old = self._refs.get(realname)
             old = self._refs.get(realname)
             self._refs[realname] = new_ref
             self._refs[realname] = new_ref
             self._notify(realname, new_ref)
             self._notify(realname, new_ref)
-            self._log(realname, old, new_ref, committer=committer,
-                      timestamp=timestamp, timezone=timezone, message=message)
+            self._log(
+                realname,
+                old,
+                new_ref,
+                committer=committer,
+                timestamp=timestamp,
+                timezone=timezone,
+                message=message,
+            )
         return True
         return True
 
 
-    def add_if_new(self, name, ref, committer=None, timestamp=None,
-                   timezone=None, message=None):
+    def add_if_new(
+        self,
+        name,
+        ref,
+        committer=None,
+        timestamp=None,
+        timezone=None,
+        message=None,
+    ):
         if name in self._refs:
         if name in self._refs:
             return False
             return False
         self._refs[name] = ref
         self._refs[name] = ref
         self._notify(name, ref)
         self._notify(name, ref)
-        self._log(name, None, ref, committer=committer, timestamp=timestamp,
-                  timezone=timezone, message=message)
+        self._log(
+            name,
+            None,
+            ref,
+            committer=committer,
+            timestamp=timestamp,
+            timezone=timezone,
+            message=message,
+        )
         return True
         return True
 
 
-    def remove_if_equals(self, name, old_ref, committer=None, timestamp=None,
-                         timezone=None, message=None):
+    def remove_if_equals(
+        self,
+        name,
+        old_ref,
+        committer=None,
+        timestamp=None,
+        timezone=None,
+        message=None,
+    ):
         if old_ref is not None and self._refs.get(name, ZERO_SHA) != old_ref:
         if old_ref is not None and self._refs.get(name, ZERO_SHA) != old_ref:
             return False
             return False
         try:
         try:
@@ -492,8 +579,15 @@ class DictRefsContainer(RefsContainer):
             pass
             pass
         else:
         else:
             self._notify(name, None)
             self._notify(name, None)
-            self._log(name, old, None, committer=committer,
-                      timestamp=timestamp, timezone=timezone, message=message)
+            self._log(
+                name,
+                old,
+                None,
+                committer=committer,
+                timestamp=timestamp,
+                timezone=timezone,
+                message=message,
+            )
         return True
         return True
 
 
     def get_peeled(self, name):
     def get_peeled(self, name):
@@ -518,7 +612,7 @@ class InfoRefsContainer(RefsContainer):
         self._refs = {}
         self._refs = {}
         self._peeled = {}
         self._peeled = {}
         for line in f.readlines():
         for line in f.readlines():
-            sha, name = line.rstrip(b'\n').split(b'\t')
+            sha, name = line.rstrip(b"\n").split(b"\t")
             if name.endswith(ANNOTATED_TAG_SUFFIX):
             if name.endswith(ANNOTATED_TAG_SUFFIX):
                 name = name[:-3]
                 name = name[:-3]
                 if not check_ref_format(name):
                 if not check_ref_format(name):
@@ -546,32 +640,35 @@ class InfoRefsContainer(RefsContainer):
 
 
 
 
 class _InotifyRefsWatcher(object):
 class _InotifyRefsWatcher(object):
-
     def __init__(self, path):
     def __init__(self, path):
         import pyinotify
         import pyinotify
         from queue import Queue
         from queue import Queue
+
         self.path = os.fsdecode(path)
         self.path = os.fsdecode(path)
         self.manager = pyinotify.WatchManager()
         self.manager = pyinotify.WatchManager()
         self.manager.add_watch(
         self.manager.add_watch(
-            self.path, pyinotify.IN_DELETE |
-            pyinotify.IN_CLOSE_WRITE | pyinotify.IN_MOVED_TO, rec=True,
-            auto_add=True)
+            self.path,
+            pyinotify.IN_DELETE | pyinotify.IN_CLOSE_WRITE | pyinotify.IN_MOVED_TO,
+            rec=True,
+            auto_add=True,
+        )
 
 
         self.notifier = pyinotify.ThreadedNotifier(
         self.notifier = pyinotify.ThreadedNotifier(
-            self.manager, default_proc_fun=self._notify)
+            self.manager, default_proc_fun=self._notify
+        )
         self.queue = Queue()
         self.queue = Queue()
 
 
     def _notify(self, event):
     def _notify(self, event):
         if event.dir:
         if event.dir:
             return
             return
-        if event.pathname.endswith('.lock'):
+        if event.pathname.endswith(".lock"):
             return
             return
         ref = os.fsencode(os.path.relpath(event.pathname, self.path))
         ref = os.fsencode(os.path.relpath(event.pathname, self.path))
-        if event.maskname == 'IN_DELETE':
+        if event.maskname == "IN_DELETE":
             self.queue.put_nowait((ref, None))
             self.queue.put_nowait((ref, None))
-        elif event.maskname in ('IN_CLOSE_WRITE', 'IN_MOVED_TO'):
-            with open(event.pathname, 'rb') as f:
-                sha = f.readline().rstrip(b'\n\r')
+        elif event.maskname in ("IN_CLOSE_WRITE", "IN_MOVED_TO"):
+            with open(event.pathname, "rb") as f:
+                sha = f.readline().rstrip(b"\n\r")
                 self.queue.put_nowait((ref, sha))
                 self.queue.put_nowait((ref, sha))
 
 
     def __next__(self):
     def __next__(self):
@@ -591,12 +688,12 @@ class DiskRefsContainer(RefsContainer):
 
 
     def __init__(self, path, worktree_path=None, logger=None):
     def __init__(self, path, worktree_path=None, logger=None):
         super(DiskRefsContainer, self).__init__(logger=logger)
         super(DiskRefsContainer, self).__init__(logger=logger)
-        if getattr(path, 'encode', None) is not None:
+        if getattr(path, "encode", None) is not None:
             path = os.fsencode(path)
             path = os.fsencode(path)
         self.path = path
         self.path = path
         if worktree_path is None:
         if worktree_path is None:
             worktree_path = path
             worktree_path = path
-        if getattr(worktree_path, 'encode', None) is not None:
+        if getattr(worktree_path, "encode", None) is not None:
             worktree_path = os.fsencode(worktree_path)
             worktree_path = os.fsencode(worktree_path)
         self.worktree_path = worktree_path
         self.worktree_path = worktree_path
         self._packed_refs = None
         self._packed_refs = None
@@ -609,30 +706,30 @@ class DiskRefsContainer(RefsContainer):
         subkeys = set()
         subkeys = set()
         path = self.refpath(base)
         path = self.refpath(base)
         for root, unused_dirs, files in os.walk(path):
         for root, unused_dirs, files in os.walk(path):
-            dir = root[len(path):]
-            if os.path.sep != '/':
+            dir = root[len(path) :]
+            if os.path.sep != "/":
                 dir = dir.replace(os.fsencode(os.path.sep), b"/")
                 dir = dir.replace(os.fsencode(os.path.sep), b"/")
-            dir = dir.strip(b'/')
+            dir = dir.strip(b"/")
             for filename in files:
             for filename in files:
                 refname = b"/".join(([dir] if dir else []) + [filename])
                 refname = b"/".join(([dir] if dir else []) + [filename])
                 # check_ref_format requires at least one /, so we prepend the
                 # check_ref_format requires at least one /, so we prepend the
                 # base before calling it.
                 # base before calling it.
-                if check_ref_format(base + b'/' + refname):
+                if check_ref_format(base + b"/" + refname):
                     subkeys.add(refname)
                     subkeys.add(refname)
         for key in self.get_packed_refs():
         for key in self.get_packed_refs():
             if key.startswith(base):
             if key.startswith(base):
-                subkeys.add(key[len(base):].strip(b'/'))
+                subkeys.add(key[len(base) :].strip(b"/"))
         return subkeys
         return subkeys
 
 
     def allkeys(self):
     def allkeys(self):
         allkeys = set()
         allkeys = set()
-        if os.path.exists(self.refpath(b'HEAD')):
-            allkeys.add(b'HEAD')
-        path = self.refpath(b'')
-        refspath = self.refpath(b'refs')
+        if os.path.exists(self.refpath(b"HEAD")):
+            allkeys.add(b"HEAD")
+        path = self.refpath(b"")
+        refspath = self.refpath(b"refs")
         for root, unused_dirs, files in os.walk(refspath):
         for root, unused_dirs, files in os.walk(refspath):
-            dir = root[len(path):]
-            if os.path.sep != '/':
+            dir = root[len(path) :]
+            if os.path.sep != "/":
                 dir = dir.replace(os.fsencode(os.path.sep), b"/")
                 dir = dir.replace(os.fsencode(os.path.sep), b"/")
             for filename in files:
             for filename in files:
                 refname = b"/".join([dir, filename])
                 refname = b"/".join([dir, filename])
@@ -642,14 +739,12 @@ class DiskRefsContainer(RefsContainer):
         return allkeys
         return allkeys
 
 
     def refpath(self, name):
     def refpath(self, name):
-        """Return the disk path of a ref.
-
-        """
+        """Return the disk path of a ref."""
         if os.path.sep != "/":
         if os.path.sep != "/":
             name = name.replace(b"/", os.fsencode(os.path.sep))
             name = name.replace(b"/", os.fsencode(os.path.sep))
         # TODO: as the 'HEAD' reference is working tree specific, it
         # TODO: as the 'HEAD' reference is working tree specific, it
         # should actually not be a part of RefsContainer
         # should actually not be a part of RefsContainer
-        if name == b'HEAD':
+        if name == b"HEAD":
             return os.path.join(self.worktree_path, name)
             return os.path.join(self.worktree_path, name)
         else:
         else:
             return os.path.join(self.path, name)
             return os.path.join(self.path, name)
@@ -668,15 +763,14 @@ class DiskRefsContainer(RefsContainer):
             # None if and only if _packed_refs is also None.
             # None if and only if _packed_refs is also None.
             self._packed_refs = {}
             self._packed_refs = {}
             self._peeled_refs = {}
             self._peeled_refs = {}
-            path = os.path.join(self.path, b'packed-refs')
+            path = os.path.join(self.path, b"packed-refs")
             try:
             try:
-                f = GitFile(path, 'rb')
+                f = GitFile(path, "rb")
             except FileNotFoundError:
             except FileNotFoundError:
                 return {}
                 return {}
             with f:
             with f:
                 first_line = next(iter(f)).rstrip()
                 first_line = next(iter(f)).rstrip()
-                if (first_line.startswith(b'# pack-refs') and b' peeled' in
-                        first_line):
+                if first_line.startswith(b"# pack-refs") and b" peeled" in first_line:
                     for sha, name, peeled in read_packed_refs_with_peeled(f):
                     for sha, name, peeled in read_packed_refs_with_peeled(f):
                         self._packed_refs[name] = sha
                         self._packed_refs[name] = sha
                         if peeled:
                         if peeled:
@@ -721,11 +815,11 @@ class DiskRefsContainer(RefsContainer):
         """
         """
         filename = self.refpath(name)
         filename = self.refpath(name)
         try:
         try:
-            with GitFile(filename, 'rb') as f:
+            with GitFile(filename, "rb") as f:
                 header = f.read(len(SYMREF))
                 header = f.read(len(SYMREF))
                 if header == SYMREF:
                 if header == SYMREF:
                     # Read only the first line
                     # Read only the first line
-                    return header + next(iter(f)).rstrip(b'\r\n')
+                    return header + next(iter(f)).rstrip(b"\r\n")
                 else:
                 else:
                     # Read only the first 40 bytes
                     # Read only the first 40 bytes
                     return header + f.read(40 - len(SYMREF))
                     return header + f.read(40 - len(SYMREF))
@@ -735,9 +829,9 @@ class DiskRefsContainer(RefsContainer):
     def _remove_packed_ref(self, name):
     def _remove_packed_ref(self, name):
         if self._packed_refs is None:
         if self._packed_refs is None:
             return
             return
-        filename = os.path.join(self.path, b'packed-refs')
+        filename = os.path.join(self.path, b"packed-refs")
         # reread cached refs from disk, while holding the lock
         # reread cached refs from disk, while holding the lock
-        f = GitFile(filename, 'wb')
+        f = GitFile(filename, "wb")
         try:
         try:
             self._packed_refs = None
             self._packed_refs = None
             self.get_packed_refs()
             self.get_packed_refs()
@@ -753,8 +847,15 @@ class DiskRefsContainer(RefsContainer):
         finally:
         finally:
             f.abort()
             f.abort()
 
 
-    def set_symbolic_ref(self, name, other, committer=None, timestamp=None,
-                         timezone=None, message=None):
+    def set_symbolic_ref(
+        self,
+        name,
+        other,
+        committer=None,
+        timestamp=None,
+        timezone=None,
+        message=None,
+    ):
         """Make a ref point at another ref.
         """Make a ref point at another ref.
 
 
         Args:
         Args:
@@ -765,21 +866,35 @@ class DiskRefsContainer(RefsContainer):
         self._check_refname(name)
         self._check_refname(name)
         self._check_refname(other)
         self._check_refname(other)
         filename = self.refpath(name)
         filename = self.refpath(name)
-        f = GitFile(filename, 'wb')
+        f = GitFile(filename, "wb")
         try:
         try:
-            f.write(SYMREF + other + b'\n')
+            f.write(SYMREF + other + b"\n")
             sha = self.follow(name)[-1]
             sha = self.follow(name)[-1]
-            self._log(name, sha, sha, committer=committer,
-                      timestamp=timestamp, timezone=timezone,
-                      message=message)
+            self._log(
+                name,
+                sha,
+                sha,
+                committer=committer,
+                timestamp=timestamp,
+                timezone=timezone,
+                message=message,
+            )
         except BaseException:
         except BaseException:
             f.abort()
             f.abort()
             raise
             raise
         else:
         else:
             f.close()
             f.close()
 
 
-    def set_if_equals(self, name, old_ref, new_ref, committer=None,
-                      timestamp=None, timezone=None, message=None):
+    def set_if_equals(
+        self,
+        name,
+        old_ref,
+        new_ref,
+        committer=None,
+        timestamp=None,
+        timezone=None,
+        message=None,
+    ):
         """Set a refname to new_ref only if it currently equals old_ref.
         """Set a refname to new_ref only if it currently equals old_ref.
 
 
         This method follows all symbolic references, and can be used to perform
         This method follows all symbolic references, and can be used to perform
@@ -810,14 +925,13 @@ class DiskRefsContainer(RefsContainer):
             probe_ref = os.path.dirname(probe_ref)
             probe_ref = os.path.dirname(probe_ref)
 
 
         ensure_dir_exists(os.path.dirname(filename))
         ensure_dir_exists(os.path.dirname(filename))
-        with GitFile(filename, 'wb') as f:
+        with GitFile(filename, "wb") as f:
             if old_ref is not None:
             if old_ref is not None:
                 try:
                 try:
                     # read again while holding the lock
                     # read again while holding the lock
                     orig_ref = self.read_loose_ref(realname)
                     orig_ref = self.read_loose_ref(realname)
                     if orig_ref is None:
                     if orig_ref is None:
-                        orig_ref = self.get_packed_refs().get(
-                                realname, ZERO_SHA)
+                        orig_ref = self.get_packed_refs().get(realname, ZERO_SHA)
                     if orig_ref != old_ref:
                     if orig_ref != old_ref:
                         f.abort()
                         f.abort()
                         return False
                         return False
@@ -825,16 +939,30 @@ class DiskRefsContainer(RefsContainer):
                     f.abort()
                     f.abort()
                     raise
                     raise
             try:
             try:
-                f.write(new_ref + b'\n')
+                f.write(new_ref + b"\n")
             except (OSError, IOError):
             except (OSError, IOError):
                 f.abort()
                 f.abort()
                 raise
                 raise
-            self._log(realname, old_ref, new_ref, committer=committer,
-                      timestamp=timestamp, timezone=timezone, message=message)
+            self._log(
+                realname,
+                old_ref,
+                new_ref,
+                committer=committer,
+                timestamp=timestamp,
+                timezone=timezone,
+                message=message,
+            )
         return True
         return True
 
 
-    def add_if_new(self, name, ref, committer=None, timestamp=None,
-                   timezone=None, message=None):
+    def add_if_new(
+        self,
+        name,
+        ref,
+        committer=None,
+        timestamp=None,
+        timezone=None,
+        message=None,
+    ):
         """Add a new reference only if it does not already exist.
         """Add a new reference only if it does not already exist.
 
 
         This method follows symrefs, and only ensures that the last ref in the
         This method follows symrefs, and only ensures that the last ref in the
@@ -856,23 +984,36 @@ class DiskRefsContainer(RefsContainer):
         self._check_refname(realname)
         self._check_refname(realname)
         filename = self.refpath(realname)
         filename = self.refpath(realname)
         ensure_dir_exists(os.path.dirname(filename))
         ensure_dir_exists(os.path.dirname(filename))
-        with GitFile(filename, 'wb') as f:
+        with GitFile(filename, "wb") as f:
             if os.path.exists(filename) or name in self.get_packed_refs():
             if os.path.exists(filename) or name in self.get_packed_refs():
                 f.abort()
                 f.abort()
                 return False
                 return False
             try:
             try:
-                f.write(ref + b'\n')
+                f.write(ref + b"\n")
             except (OSError, IOError):
             except (OSError, IOError):
                 f.abort()
                 f.abort()
                 raise
                 raise
             else:
             else:
-                self._log(name, None, ref, committer=committer,
-                          timestamp=timestamp, timezone=timezone,
-                          message=message)
+                self._log(
+                    name,
+                    None,
+                    ref,
+                    committer=committer,
+                    timestamp=timestamp,
+                    timezone=timezone,
+                    message=message,
+                )
         return True
         return True
 
 
-    def remove_if_equals(self, name, old_ref, committer=None, timestamp=None,
-                         timezone=None, message=None):
+    def remove_if_equals(
+        self,
+        name,
+        old_ref,
+        committer=None,
+        timestamp=None,
+        timezone=None,
+        message=None,
+    ):
         """Remove a refname only if it currently equals old_ref.
         """Remove a refname only if it currently equals old_ref.
 
 
         This method does not follow symbolic references. It can be used to
         This method does not follow symbolic references. It can be used to
@@ -888,7 +1029,7 @@ class DiskRefsContainer(RefsContainer):
         self._check_refname(name)
         self._check_refname(name)
         filename = self.refpath(name)
         filename = self.refpath(name)
         ensure_dir_exists(os.path.dirname(filename))
         ensure_dir_exists(os.path.dirname(filename))
-        f = GitFile(filename, 'wb')
+        f = GitFile(filename, "wb")
         try:
         try:
             if old_ref is not None:
             if old_ref is not None:
                 orig_ref = self.read_loose_ref(name)
                 orig_ref = self.read_loose_ref(name)
@@ -904,8 +1045,15 @@ class DiskRefsContainer(RefsContainer):
                 pass  # may only be packed
                 pass  # may only be packed
 
 
             self._remove_packed_ref(name)
             self._remove_packed_ref(name)
-            self._log(name, old_ref, None, committer=committer,
-                      timestamp=timestamp, timezone=timezone, message=message)
+            self._log(
+                name,
+                old_ref,
+                None,
+                committer=committer,
+                timestamp=timestamp,
+                timezone=timezone,
+                message=message,
+            )
         finally:
         finally:
             # never write, we just wanted the lock
             # never write, we just wanted the lock
             f.abort()
             f.abort()
@@ -916,10 +1064,12 @@ class DiskRefsContainer(RefsContainer):
         parent = name
         parent = name
         while True:
         while True:
             try:
             try:
-                parent, _ = parent.rsplit(b'/', 1)
+                parent, _ = parent.rsplit(b"/", 1)
             except ValueError:
             except ValueError:
                 break
                 break
 
 
+            if parent == b'refs':
+                break
             parent_filename = self.refpath(parent)
             parent_filename = self.refpath(parent)
             try:
             try:
                 os.rmdir(parent_filename)
                 os.rmdir(parent_filename)
@@ -934,12 +1084,13 @@ class DiskRefsContainer(RefsContainer):
 
 
     def watch(self):
     def watch(self):
         import pyinotify  # noqa: F401
         import pyinotify  # noqa: F401
+
         return _InotifyRefsWatcher(self.path)
         return _InotifyRefsWatcher(self.path)
 
 
 
 
 def _split_ref_line(line):
 def _split_ref_line(line):
     """Split a single ref line into a tuple of SHA1 and name."""
     """Split a single ref line into a tuple of SHA1 and name."""
-    fields = line.rstrip(b'\n\r').split(b' ')
+    fields = line.rstrip(b"\n\r").split(b" ")
     if len(fields) != 2:
     if len(fields) != 2:
         raise PackedRefsException("invalid ref line %r" % line)
         raise PackedRefsException("invalid ref line %r" % line)
     sha, name = fields
     sha, name = fields
@@ -958,12 +1109,11 @@ def read_packed_refs(f):
     Returns: Iterator over tuples with SHA1s and ref names.
     Returns: Iterator over tuples with SHA1s and ref names.
     """
     """
     for line in f:
     for line in f:
-        if line.startswith(b'#'):
+        if line.startswith(b"#"):
             # Comment
             # Comment
             continue
             continue
-        if line.startswith(b'^'):
-            raise PackedRefsException(
-              "found peeled ref in packed-refs without peeled")
+        if line.startswith(b"^"):
+            raise PackedRefsException("found peeled ref in packed-refs without peeled")
         yield _split_ref_line(line)
         yield _split_ref_line(line)
 
 
 
 
@@ -978,10 +1128,10 @@ def read_packed_refs_with_peeled(f):
     """
     """
     last = None
     last = None
     for line in f:
     for line in f:
-        if line[0] == b'#':
+        if line[0] == b"#":
             continue
             continue
-        line = line.rstrip(b'\r\n')
-        if line.startswith(b'^'):
+        line = line.rstrip(b"\r\n")
+        if line.startswith(b"^"):
             if not last:
             if not last:
                 raise PackedRefsException("unexpected peeled ref line")
                 raise PackedRefsException("unexpected peeled ref line")
             if not valid_hexsha(line[1:]):
             if not valid_hexsha(line[1:]):
@@ -1010,11 +1160,11 @@ def write_packed_refs(f, packed_refs, peeled_refs=None):
     if peeled_refs is None:
     if peeled_refs is None:
         peeled_refs = {}
         peeled_refs = {}
     else:
     else:
-        f.write(b'# pack-refs with: peeled\n')
+        f.write(b"# pack-refs with: peeled\n")
     for refname in sorted(packed_refs.keys()):
     for refname in sorted(packed_refs.keys()):
         f.write(git_line(packed_refs[refname], refname))
         f.write(git_line(packed_refs[refname], refname))
         if refname in peeled_refs:
         if refname in peeled_refs:
-            f.write(b'^' + peeled_refs[refname] + b'\n')
+            f.write(b"^" + peeled_refs[refname] + b"\n")
 
 
 
 
 def read_info_refs(f):
 def read_info_refs(f):
@@ -1030,16 +1180,16 @@ def write_info_refs(refs, store):
     for name, sha in sorted(refs.items()):
     for name, sha in sorted(refs.items()):
         # get_refs() includes HEAD as a special case, but we don't want to
         # get_refs() includes HEAD as a special case, but we don't want to
         # advertise it
         # advertise it
-        if name == b'HEAD':
+        if name == b"HEAD":
             continue
             continue
         try:
         try:
             o = store[sha]
             o = store[sha]
         except KeyError:
         except KeyError:
             continue
             continue
         peeled = store.peel_sha(sha)
         peeled = store.peel_sha(sha)
-        yield o.id + b'\t' + name + b'\n'
+        yield o.id + b"\t" + name + b"\n"
         if o.id != peeled.id:
         if o.id != peeled.id:
-            yield peeled.id + b'\t' + name + ANNOTATED_TAG_SUFFIX + b'\n'
+            yield peeled.id + b"\t" + name + ANNOTATED_TAG_SUFFIX + b"\n"
 
 
 
 
 def is_local_branch(x):
 def is_local_branch(x):
@@ -1048,5 +1198,8 @@ def is_local_branch(x):
 
 
 def strip_peeled_refs(refs):
 def strip_peeled_refs(refs):
     """Remove all peeled refs"""
     """Remove all peeled refs"""
-    return {ref: sha for (ref, sha) in refs.items()
-            if not ref.endswith(ANNOTATED_TAG_SUFFIX)}
+    return {
+        ref: sha
+        for (ref, sha) in refs.items()
+        if not ref.endswith(ANNOTATED_TAG_SUFFIX)
+    }

Файлын зөрүү хэтэрхий том тул дарагдсан байна
+ 278 - 191
dulwich/repo.py


+ 209 - 157
dulwich/server.py

@@ -61,15 +61,15 @@ from dulwich.errors import (
     NotGitRepository,
     NotGitRepository,
     UnexpectedCommandError,
     UnexpectedCommandError,
     ObjectFormatException,
     ObjectFormatException,
-    )
+)
 from dulwich import log_utils
 from dulwich import log_utils
 from dulwich.objects import (
 from dulwich.objects import (
     Commit,
     Commit,
     valid_hexsha,
     valid_hexsha,
-    )
+)
 from dulwich.pack import (
 from dulwich.pack import (
     write_pack_objects,
     write_pack_objects,
-    )
+)
 from dulwich.protocol import (  # noqa: F401
 from dulwich.protocol import (  # noqa: F401
     BufferedPktLineWriter,
     BufferedPktLineWriter,
     capability_agent,
     capability_agent,
@@ -108,15 +108,15 @@ from dulwich.protocol import (  # noqa: F401
     extract_capabilities,
     extract_capabilities,
     extract_want_line_capabilities,
     extract_want_line_capabilities,
     symref_capabilities,
     symref_capabilities,
-    )
+)
 from dulwich.refs import (
 from dulwich.refs import (
     ANNOTATED_TAG_SUFFIX,
     ANNOTATED_TAG_SUFFIX,
     write_info_refs,
     write_info_refs,
-    )
+)
 from dulwich.repo import (
 from dulwich.repo import (
     BaseRepo,
     BaseRepo,
     Repo,
     Repo,
-    )
+)
 
 
 
 
 logger = log_utils.getLogger(__name__)
 logger = log_utils.getLogger(__name__)
@@ -167,8 +167,7 @@ class BackendRepo(object):
         """
         """
         return None
         return None
 
 
-    def fetch_objects(self, determine_wants, graph_walker, progress,
-                      get_tagged=None):
+    def fetch_objects(self, determine_wants, graph_walker, progress, get_tagged=None):
         """
         """
         Yield the objects required for a list of commits.
         Yield the objects required for a list of commits.
 
 
@@ -187,7 +186,7 @@ class DictBackend(Backend):
         self.repos = repos
         self.repos = repos
 
 
     def open_repository(self, path: str) -> BaseRepo:
     def open_repository(self, path: str) -> BaseRepo:
-        logger.debug('Opening repository at %s', path)
+        logger.debug("Opening repository at %s", path)
         try:
         try:
             return self.repos[path]
             return self.repos[path]
         except KeyError:
         except KeyError:
@@ -201,18 +200,15 @@ class FileSystemBackend(Backend):
 
 
     def __init__(self, root=os.sep):
     def __init__(self, root=os.sep):
         super(FileSystemBackend, self).__init__()
         super(FileSystemBackend, self).__init__()
-        self.root = (os.path.abspath(root) + os.sep).replace(
-                os.sep * 2, os.sep)
+        self.root = (os.path.abspath(root) + os.sep).replace(os.sep * 2, os.sep)
 
 
     def open_repository(self, path):
     def open_repository(self, path):
-        logger.debug('opening repository at %s', path)
+        logger.debug("opening repository at %s", path)
         abspath = os.path.abspath(os.path.join(self.root, path)) + os.sep
         abspath = os.path.abspath(os.path.join(self.root, path)) + os.sep
         normcase_abspath = os.path.normcase(abspath)
         normcase_abspath = os.path.normcase(abspath)
         normcase_root = os.path.normcase(self.root)
         normcase_root = os.path.normcase(self.root)
         if not normcase_abspath.startswith(normcase_root):
         if not normcase_abspath.startswith(normcase_root):
-            raise NotGitRepository(
-                    "Path %r not inside root %r" %
-                    (path, self.root))
+            raise NotGitRepository("Path %r not inside root %r" % (path, self.root))
         return Repo(abspath)
         return Repo(abspath)
 
 
 
 
@@ -239,7 +235,7 @@ class PackHandler(Handler):
 
 
     @classmethod
     @classmethod
     def capability_line(cls, capabilities):
     def capability_line(cls, capabilities):
-        logger.info('Sending capabilities: %s', capabilities)
+        logger.info("Sending capabilities: %s", capabilities)
         return b"".join([b" " + c for c in capabilities])
         return b"".join([b" " + c for c in capabilities])
 
 
     @classmethod
     @classmethod
@@ -248,9 +244,13 @@ class PackHandler(Handler):
 
 
     @classmethod
     @classmethod
     def innocuous_capabilities(cls) -> Iterable[bytes]:
     def innocuous_capabilities(cls) -> Iterable[bytes]:
-        return [CAPABILITY_INCLUDE_TAG, CAPABILITY_THIN_PACK,
-                CAPABILITY_NO_PROGRESS, CAPABILITY_OFS_DELTA,
-                capability_agent()]
+        return [
+            CAPABILITY_INCLUDE_TAG,
+            CAPABILITY_THIN_PACK,
+            CAPABILITY_NO_PROGRESS,
+            CAPABILITY_OFS_DELTA,
+            capability_agent(),
+        ]
 
 
     @classmethod
     @classmethod
     def required_capabilities(cls) -> Iterable[bytes]:
     def required_capabilities(cls) -> Iterable[bytes]:
@@ -261,22 +261,25 @@ class PackHandler(Handler):
         allowable_caps = set(self.innocuous_capabilities())
         allowable_caps = set(self.innocuous_capabilities())
         allowable_caps.update(self.capabilities())
         allowable_caps.update(self.capabilities())
         for cap in caps:
         for cap in caps:
-            if cap.startswith(CAPABILITY_AGENT + b'='):
+            if cap.startswith(CAPABILITY_AGENT + b"="):
                 continue
                 continue
             if cap not in allowable_caps:
             if cap not in allowable_caps:
-                raise GitProtocolError('Client asked for capability %r that '
-                                       'was not advertised.' % cap)
+                raise GitProtocolError(
+                    "Client asked for capability %r that " "was not advertised." % cap
+                )
         for cap in self.required_capabilities():
         for cap in self.required_capabilities():
             if cap not in caps:
             if cap not in caps:
-                raise GitProtocolError('Client does not support required '
-                                       'capability %r.' % cap)
+                raise GitProtocolError(
+                    "Client does not support required " "capability %r." % cap
+                )
         self._client_capabilities = set(caps)
         self._client_capabilities = set(caps)
-        logger.info('Client capabilities: %s', caps)
+        logger.info("Client capabilities: %s", caps)
 
 
     def has_capability(self, cap: bytes) -> bool:
     def has_capability(self, cap: bytes) -> bool:
         if self._client_capabilities is None:
         if self._client_capabilities is None:
-            raise GitProtocolError('Server attempted to access capability %r '
-                                   'before asking client' % cap)
+            raise GitProtocolError(
+                "Server attempted to access capability %r " "before asking client" % cap
+            )
         return cap in self._client_capabilities
         return cap in self._client_capabilities
 
 
     def notify_done(self) -> None:
     def notify_done(self) -> None:
@@ -286,10 +289,10 @@ class PackHandler(Handler):
 class UploadPackHandler(PackHandler):
 class UploadPackHandler(PackHandler):
     """Protocol handler for uploading a pack to the client."""
     """Protocol handler for uploading a pack to the client."""
 
 
-    def __init__(self, backend, args, proto, stateless_rpc=None,
-                 advertise_refs=False):
+    def __init__(self, backend, args, proto, stateless_rpc=None, advertise_refs=False):
         super(UploadPackHandler, self).__init__(
         super(UploadPackHandler, self).__init__(
-                backend, proto, stateless_rpc=stateless_rpc)
+            backend, proto, stateless_rpc=stateless_rpc
+        )
         self.repo = backend.open_repository(args[0])
         self.repo = backend.open_repository(args[0])
         self._graph_walker = None
         self._graph_walker = None
         self.advertise_refs = advertise_refs
         self.advertise_refs = advertise_refs
@@ -300,19 +303,28 @@ class UploadPackHandler(PackHandler):
 
 
     @classmethod
     @classmethod
     def capabilities(cls):
     def capabilities(cls):
-        return [CAPABILITY_MULTI_ACK_DETAILED, CAPABILITY_MULTI_ACK,
-                CAPABILITY_SIDE_BAND_64K, CAPABILITY_THIN_PACK,
-                CAPABILITY_OFS_DELTA, CAPABILITY_NO_PROGRESS,
-                CAPABILITY_INCLUDE_TAG, CAPABILITY_SHALLOW, CAPABILITY_NO_DONE]
+        return [
+            CAPABILITY_MULTI_ACK_DETAILED,
+            CAPABILITY_MULTI_ACK,
+            CAPABILITY_SIDE_BAND_64K,
+            CAPABILITY_THIN_PACK,
+            CAPABILITY_OFS_DELTA,
+            CAPABILITY_NO_PROGRESS,
+            CAPABILITY_INCLUDE_TAG,
+            CAPABILITY_SHALLOW,
+            CAPABILITY_NO_DONE,
+        ]
 
 
     @classmethod
     @classmethod
     def required_capabilities(cls):
     def required_capabilities(cls):
-        return (CAPABILITY_SIDE_BAND_64K, CAPABILITY_THIN_PACK,
-                CAPABILITY_OFS_DELTA)
+        return (
+            CAPABILITY_SIDE_BAND_64K,
+            CAPABILITY_THIN_PACK,
+            CAPABILITY_OFS_DELTA,
+        )
 
 
     def progress(self, message):
     def progress(self, message):
-        if (self.has_capability(CAPABILITY_NO_PROGRESS) or
-                self._processing_have_lines):
+        if self.has_capability(CAPABILITY_NO_PROGRESS) or self._processing_have_lines:
             return
             return
         self.proto.write_sideband(SIDE_BAND_CHANNEL_PROGRESS, message)
         self.proto.write_sideband(SIDE_BAND_CHANNEL_PROGRESS, message)
 
 
@@ -353,17 +365,23 @@ class UploadPackHandler(PackHandler):
             return self.proto.write_sideband(SIDE_BAND_CHANNEL_DATA, x)
             return self.proto.write_sideband(SIDE_BAND_CHANNEL_DATA, x)
 
 
         graph_walker = _ProtocolGraphWalker(
         graph_walker = _ProtocolGraphWalker(
-                self, self.repo.object_store, self.repo.get_peeled,
-                self.repo.refs.get_symrefs)
+            self,
+            self.repo.object_store,
+            self.repo.get_peeled,
+            self.repo.refs.get_symrefs,
+        )
         wants = []
         wants = []
 
 
-        def wants_wrapper(refs):
-            wants.extend(graph_walker.determine_wants(refs))
+        def wants_wrapper(refs, **kwargs):
+            wants.extend(graph_walker.determine_wants(refs, **kwargs))
             return wants
             return wants
 
 
         objects_iter = self.repo.fetch_objects(
         objects_iter = self.repo.fetch_objects(
-            wants_wrapper, graph_walker, self.progress,
-            get_tagged=self.get_tagged)
+            wants_wrapper,
+            graph_walker,
+            self.progress,
+            get_tagged=self.get_tagged,
+        )
 
 
         # Note the fact that client is only processing responses related
         # Note the fact that client is only processing responses related
         # to the have lines it sent, and any other data (including side-
         # to the have lines it sent, and any other data (including side-
@@ -384,13 +402,13 @@ class UploadPackHandler(PackHandler):
         self._processing_have_lines = False
         self._processing_have_lines = False
 
 
         if not graph_walker.handle_done(
         if not graph_walker.handle_done(
-                not self.has_capability(CAPABILITY_NO_DONE),
-                self._done_received):
+            not self.has_capability(CAPABILITY_NO_DONE), self._done_received
+        ):
             return
             return
 
 
         self.progress(
         self.progress(
-                ("counting objects: %d, done.\n" % len(objects_iter)).encode(
-                    'ascii'))
+            ("counting objects: %d, done.\n" % len(objects_iter)).encode("ascii")
+        )
         write_pack_objects(ProtocolFile(None, write), objects_iter)
         write_pack_objects(ProtocolFile(None, write), objects_iter)
         # we are done
         # we are done
         self.proto.write_pkt_line(None)
         self.proto.write_pkt_line(None)
@@ -418,21 +436,25 @@ def _split_proto_line(line, allowed):
     if not line:
     if not line:
         fields = [None]
         fields = [None]
     else:
     else:
-        fields = line.rstrip(b'\n').split(b' ', 1)
+        fields = line.rstrip(b"\n").split(b" ", 1)
     command = fields[0]
     command = fields[0]
     if allowed is not None and command not in allowed:
     if allowed is not None and command not in allowed:
         raise UnexpectedCommandError(command)
         raise UnexpectedCommandError(command)
     if len(fields) == 1 and command in (COMMAND_DONE, None):
     if len(fields) == 1 and command in (COMMAND_DONE, None):
         return (command, None)
         return (command, None)
     elif len(fields) == 2:
     elif len(fields) == 2:
-        if command in (COMMAND_WANT, COMMAND_HAVE, COMMAND_SHALLOW,
-                       COMMAND_UNSHALLOW):
+        if command in (
+            COMMAND_WANT,
+            COMMAND_HAVE,
+            COMMAND_SHALLOW,
+            COMMAND_UNSHALLOW,
+        ):
             if not valid_hexsha(fields[1]):
             if not valid_hexsha(fields[1]):
                 raise GitProtocolError("Invalid sha")
                 raise GitProtocolError("Invalid sha")
             return tuple(fields)
             return tuple(fields)
         elif command == COMMAND_DEEPEN:
         elif command == COMMAND_DEEPEN:
             return command, int(fields[1])
             return command, int(fields[1])
-    raise GitProtocolError('Received invalid line from client: %r' % line)
+    raise GitProtocolError("Received invalid line from client: %r" % line)
 
 
 
 
 def _find_shallow(store, heads, depth):
 def _find_shallow(store, heads, depth):
@@ -533,6 +555,7 @@ class _ProtocolGraphWalker(object):
     call to set_ack_type() is required to set up the implementation, before
     call to set_ack_type() is required to set up the implementation, before
     any calls to next() or ack() are made.
     any calls to next() or ack() are made.
     """
     """
+
     def __init__(self, handler, object_store, get_peeled, get_symrefs):
     def __init__(self, handler, object_store, get_peeled, get_symrefs):
         self.handler = handler
         self.handler = handler
         self.store = object_store
         self.store = object_store
@@ -550,7 +573,7 @@ class _ProtocolGraphWalker(object):
         self._cache_index = 0
         self._cache_index = 0
         self._impl = None
         self._impl = None
 
 
-    def determine_wants(self, heads):
+    def determine_wants(self, heads, depth=None):
         """Determine the wants for a set of heads.
         """Determine the wants for a set of heads.
 
 
         The given heads are advertised to the client, who then specifies which
         The given heads are advertised to the client, who then specifies which
@@ -579,16 +602,17 @@ class _ProtocolGraphWalker(object):
                     # TODO(jelmer): Integrate with Repo.fetch_objects refs
                     # TODO(jelmer): Integrate with Repo.fetch_objects refs
                     # logic.
                     # logic.
                     continue
                     continue
-                line = sha + b' ' + ref
+                line = sha + b" " + ref
                 if not i:
                 if not i:
-                    line += (b'\x00' +
-                             self.handler.capability_line(
-                                 self.handler.capabilities() +
-                                 symref_capabilities(symrefs.items())))
-                self.proto.write_pkt_line(line + b'\n')
+                    line += b"\x00" + self.handler.capability_line(
+                        self.handler.capabilities()
+                        + symref_capabilities(symrefs.items())
+                    )
+                self.proto.write_pkt_line(line + b"\n")
                 if peeled_sha != sha:
                 if peeled_sha != sha:
                     self.proto.write_pkt_line(
                     self.proto.write_pkt_line(
-                        peeled_sha + b' ' + ref + ANNOTATED_TAG_SUFFIX + b'\n')
+                        peeled_sha + b" " + ref + ANNOTATED_TAG_SUFFIX + b"\n"
+                    )
 
 
             # i'm done..
             # i'm done..
             self.proto.write_pkt_line(None)
             self.proto.write_pkt_line(None)
@@ -609,8 +633,7 @@ class _ProtocolGraphWalker(object):
         want_revs = []
         want_revs = []
         while command == COMMAND_WANT:
         while command == COMMAND_WANT:
             if sha not in values:
             if sha not in values:
-                raise GitProtocolError(
-                  'Client wants invalid object %s' % sha)
+                raise GitProtocolError("Client wants invalid object %s" % sha)
             want_revs.append(sha)
             want_revs.append(sha)
             command, sha = self.read_proto_line(allowed)
             command, sha = self.read_proto_line(allowed)
 
 
@@ -630,8 +653,8 @@ class _ProtocolGraphWalker(object):
 
 
     def unread_proto_line(self, command, value):
     def unread_proto_line(self, command, value):
         if isinstance(value, int):
         if isinstance(value, int):
-            value = str(value).encode('ascii')
-        self.proto.unread_pkt_line(command + b' ' + value)
+            value = str(value).encode("ascii")
+        self.proto.unread_pkt_line(command + b" " + value)
 
 
     def ack(self, have_ref):
     def ack(self, have_ref):
         if len(have_ref) != 40:
         if len(have_ref) != 40:
@@ -667,8 +690,7 @@ class _ProtocolGraphWalker(object):
 
 
     def _handle_shallow_request(self, wants):
     def _handle_shallow_request(self, wants):
         while True:
         while True:
-            command, val = self.read_proto_line(
-                    (COMMAND_DEEPEN, COMMAND_SHALLOW))
+            command, val = self.read_proto_line((COMMAND_DEEPEN, COMMAND_SHALLOW))
             if command == COMMAND_DEEPEN:
             if command == COMMAND_DEEPEN:
                 depth = val
                 depth = val
                 break
                 break
@@ -684,9 +706,9 @@ class _ProtocolGraphWalker(object):
         unshallow = self.unshallow = not_shallow & self.client_shallow
         unshallow = self.unshallow = not_shallow & self.client_shallow
 
 
         for sha in sorted(new_shallow):
         for sha in sorted(new_shallow):
-            self.proto.write_pkt_line(COMMAND_SHALLOW + b' ' + sha)
+            self.proto.write_pkt_line(COMMAND_SHALLOW + b" " + sha)
         for sha in sorted(unshallow):
         for sha in sorted(unshallow):
-            self.proto.write_pkt_line(COMMAND_UNSHALLOW + b' ' + sha)
+            self.proto.write_pkt_line(COMMAND_UNSHALLOW + b" " + sha)
 
 
         self.proto.write_pkt_line(None)
         self.proto.write_pkt_line(None)
 
 
@@ -694,13 +716,13 @@ class _ProtocolGraphWalker(object):
         # relay the message down to the handler.
         # relay the message down to the handler.
         self.handler.notify_done()
         self.handler.notify_done()
 
 
-    def send_ack(self, sha, ack_type=b''):
+    def send_ack(self, sha, ack_type=b""):
         if ack_type:
         if ack_type:
-            ack_type = b' ' + ack_type
-        self.proto.write_pkt_line(b'ACK ' + sha + ack_type + b'\n')
+            ack_type = b" " + ack_type
+        self.proto.write_pkt_line(b"ACK " + sha + ack_type + b"\n")
 
 
     def send_nak(self):
     def send_nak(self):
-        self.proto.write_pkt_line(b'NAK\n')
+        self.proto.write_pkt_line(b"NAK\n")
 
 
     def handle_done(self, done_required, done_received):
     def handle_done(self, done_required, done_received):
         # Delegate this to the implementation.
         # Delegate this to the implementation.
@@ -721,10 +743,10 @@ class _ProtocolGraphWalker(object):
 
 
     def set_ack_type(self, ack_type):
     def set_ack_type(self, ack_type):
         impl_classes = {
         impl_classes = {
-          MULTI_ACK: MultiAckGraphWalkerImpl,
-          MULTI_ACK_DETAILED: MultiAckDetailedGraphWalkerImpl,
-          SINGLE_ACK: SingleAckGraphWalkerImpl,
-          }
+            MULTI_ACK: MultiAckGraphWalkerImpl,
+            MULTI_ACK_DETAILED: MultiAckDetailedGraphWalkerImpl,
+            SINGLE_ACK: SingleAckGraphWalkerImpl,
+        }
         self._impl = impl_classes[ack_type](self)
         self._impl = impl_classes[ack_type](self)
 
 
 
 
@@ -786,7 +808,7 @@ class MultiAckGraphWalkerImpl(object):
     def ack(self, have_ref):
     def ack(self, have_ref):
         self._common.append(have_ref)
         self._common.append(have_ref)
         if not self._found_base:
         if not self._found_base:
-            self.walker.send_ack(have_ref, b'continue')
+            self.walker.send_ack(have_ref, b"continue")
             if self.walker.all_wants_satisfied(self._common):
             if self.walker.all_wants_satisfied(self._common):
                 self._found_base = True
                 self._found_base = True
         # else we blind ack within next
         # else we blind ack within next
@@ -805,7 +827,7 @@ class MultiAckGraphWalkerImpl(object):
             elif command == COMMAND_HAVE:
             elif command == COMMAND_HAVE:
                 if self._found_base:
                 if self._found_base:
                     # blind ack
                     # blind ack
-                    self.walker.send_ack(sha, b'continue')
+                    self.walker.send_ack(sha, b"continue")
                 return sha
                 return sha
 
 
     __next__ = next
     __next__ = next
@@ -844,14 +866,14 @@ class MultiAckDetailedGraphWalkerImpl(object):
     def ack(self, have_ref):
     def ack(self, have_ref):
         # Should only be called iff have_ref is common
         # Should only be called iff have_ref is common
         self._common.append(have_ref)
         self._common.append(have_ref)
-        self.walker.send_ack(have_ref, b'common')
+        self.walker.send_ack(have_ref, b"common")
 
 
     def next(self):
     def next(self):
         while True:
         while True:
             command, sha = self.walker.read_proto_line(_GRAPH_WALKER_COMMANDS)
             command, sha = self.walker.read_proto_line(_GRAPH_WALKER_COMMANDS)
             if command is None:
             if command is None:
                 if self.walker.all_wants_satisfied(self._common):
                 if self.walker.all_wants_satisfied(self._common):
-                    self.walker.send_ack(self._common[-1], b'ready')
+                    self.walker.send_ack(self._common[-1], b"ready")
                 self.walker.send_nak()
                 self.walker.send_nak()
                 if self.walker.stateless_rpc:
                 if self.walker.stateless_rpc:
                     # The HTTP version of this request a flush-pkt always
                     # The HTTP version of this request a flush-pkt always
@@ -902,25 +924,37 @@ class MultiAckDetailedGraphWalkerImpl(object):
 class ReceivePackHandler(PackHandler):
 class ReceivePackHandler(PackHandler):
     """Protocol handler for downloading a pack from the client."""
     """Protocol handler for downloading a pack from the client."""
 
 
-    def __init__(self, backend, args, proto, stateless_rpc=None,
-                 advertise_refs=False):
+    def __init__(self, backend, args, proto, stateless_rpc=None, advertise_refs=False):
         super(ReceivePackHandler, self).__init__(
         super(ReceivePackHandler, self).__init__(
-                backend, proto, stateless_rpc=stateless_rpc)
+            backend, proto, stateless_rpc=stateless_rpc
+        )
         self.repo = backend.open_repository(args[0])
         self.repo = backend.open_repository(args[0])
         self.advertise_refs = advertise_refs
         self.advertise_refs = advertise_refs
 
 
     @classmethod
     @classmethod
     def capabilities(cls) -> Iterable[bytes]:
     def capabilities(cls) -> Iterable[bytes]:
-        return [CAPABILITY_REPORT_STATUS, CAPABILITY_DELETE_REFS,
-                CAPABILITY_QUIET, CAPABILITY_OFS_DELTA,
-                CAPABILITY_SIDE_BAND_64K, CAPABILITY_NO_DONE]
+        return [
+            CAPABILITY_REPORT_STATUS,
+            CAPABILITY_DELETE_REFS,
+            CAPABILITY_QUIET,
+            CAPABILITY_OFS_DELTA,
+            CAPABILITY_SIDE_BAND_64K,
+            CAPABILITY_NO_DONE,
+        ]
 
 
     def _apply_pack(
     def _apply_pack(
-            self, refs: List[Tuple[bytes, bytes, bytes]]
-            ) -> List[Tuple[bytes, bytes]]:
-        all_exceptions = (IOError, OSError, ChecksumMismatch, ApplyDeltaError,
-                          AssertionError, socket.error, zlib.error,
-                          ObjectFormatException)
+        self, refs: List[Tuple[bytes, bytes, bytes]]
+    ) -> List[Tuple[bytes, bytes]]:
+        all_exceptions = (
+            IOError,
+            OSError,
+            ChecksumMismatch,
+            ApplyDeltaError,
+            AssertionError,
+            socket.error,
+            zlib.error,
+            ObjectFormatException,
+        )
         status = []
         status = []
         will_send_pack = False
         will_send_pack = False
 
 
@@ -934,36 +968,36 @@ class ReceivePackHandler(PackHandler):
             try:
             try:
                 recv = getattr(self.proto, "recv", None)
                 recv = getattr(self.proto, "recv", None)
                 self.repo.object_store.add_thin_pack(self.proto.read, recv)
                 self.repo.object_store.add_thin_pack(self.proto.read, recv)
-                status.append((b'unpack', b'ok'))
+                status.append((b"unpack", b"ok"))
             except all_exceptions as e:
             except all_exceptions as e:
-                status.append(
-                    (b'unpack', str(e).replace('\n', '').encode('utf-8')))
+                status.append((b"unpack", str(e).replace("\n", "").encode("utf-8")))
                 # The pack may still have been moved in, but it may contain
                 # The pack may still have been moved in, but it may contain
                 # broken objects. We trust a later GC to clean it up.
                 # broken objects. We trust a later GC to clean it up.
         else:
         else:
             # The git protocol want to find a status entry related to unpack
             # The git protocol want to find a status entry related to unpack
             # process even if no pack data has been sent.
             # process even if no pack data has been sent.
-            status.append((b'unpack', b'ok'))
+            status.append((b"unpack", b"ok"))
 
 
         for oldsha, sha, ref in refs:
         for oldsha, sha, ref in refs:
-            ref_status = b'ok'
+            ref_status = b"ok"
             try:
             try:
                 if sha == ZERO_SHA:
                 if sha == ZERO_SHA:
                     if CAPABILITY_DELETE_REFS not in self.capabilities():
                     if CAPABILITY_DELETE_REFS not in self.capabilities():
                         raise GitProtocolError(
                         raise GitProtocolError(
-                          'Attempted to delete refs without delete-refs '
-                          'capability.')
+                            "Attempted to delete refs without delete-refs "
+                            "capability."
+                        )
                     try:
                     try:
                         self.repo.refs.remove_if_equals(ref, oldsha)
                         self.repo.refs.remove_if_equals(ref, oldsha)
                     except all_exceptions:
                     except all_exceptions:
-                        ref_status = b'failed to delete'
+                        ref_status = b"failed to delete"
                 else:
                 else:
                     try:
                     try:
                         self.repo.refs.set_if_equals(ref, oldsha, sha)
                         self.repo.refs.set_if_equals(ref, oldsha, sha)
                     except all_exceptions:
                     except all_exceptions:
-                        ref_status = b'failed to write'
+                        ref_status = b"failed to write"
             except KeyError:
             except KeyError:
-                ref_status = b'bad ref'
+                ref_status = b"bad ref"
             status.append((ref, ref_status))
             status.append((ref, ref_status))
 
 
         return status
         return status
@@ -971,12 +1005,14 @@ class ReceivePackHandler(PackHandler):
     def _report_status(self, status: List[Tuple[bytes, bytes]]) -> None:
     def _report_status(self, status: List[Tuple[bytes, bytes]]) -> None:
         if self.has_capability(CAPABILITY_SIDE_BAND_64K):
         if self.has_capability(CAPABILITY_SIDE_BAND_64K):
             writer = BufferedPktLineWriter(
             writer = BufferedPktLineWriter(
-              lambda d: self.proto.write_sideband(SIDE_BAND_CHANNEL_DATA, d))
+                lambda d: self.proto.write_sideband(SIDE_BAND_CHANNEL_DATA, d)
+            )
             write = writer.write
             write = writer.write
 
 
             def flush():
             def flush():
                 writer.flush()
                 writer.flush()
                 self.proto.write_pkt_line(None)
                 self.proto.write_pkt_line(None)
+
         else:
         else:
             write = self.proto.write_pkt_line
             write = self.proto.write_pkt_line
 
 
@@ -984,17 +1020,17 @@ class ReceivePackHandler(PackHandler):
                 pass
                 pass
 
 
         for name, msg in status:
         for name, msg in status:
-            if name == b'unpack':
-                write(b'unpack ' + msg + b'\n')
-            elif msg == b'ok':
-                write(b'ok ' + name + b'\n')
+            if name == b"unpack":
+                write(b"unpack " + msg + b"\n")
+            elif msg == b"ok":
+                write(b"ok " + name + b"\n")
             else:
             else:
-                write(b'ng ' + name + b' ' + msg + b'\n')
+                write(b"ng " + name + b" " + msg + b"\n")
         write(None)
         write(None)
         flush()
         flush()
 
 
     def _on_post_receive(self, client_refs):
     def _on_post_receive(self, client_refs):
-        hook = self.repo.hooks.get('post-receive', None)
+        hook = self.repo.hooks.get("post-receive", None)
         if not hook:
         if not hook:
             return
             return
         try:
         try:
@@ -1002,7 +1038,7 @@ class ReceivePackHandler(PackHandler):
             if output:
             if output:
                 self.proto.write_sideband(SIDE_BAND_CHANNEL_PROGRESS, output)
                 self.proto.write_sideband(SIDE_BAND_CHANNEL_PROGRESS, output)
         except HookError as err:
         except HookError as err:
-            self.proto.write_sideband(SIDE_BAND_CHANNEL_FATAL, repr(err))
+            self.proto.write_sideband(SIDE_BAND_CHANNEL_FATAL, str(err).encode('utf-8'))
 
 
     def handle(self) -> None:
     def handle(self) -> None:
         if self.advertise_refs or not self.stateless_rpc:
         if self.advertise_refs or not self.stateless_rpc:
@@ -1012,12 +1048,18 @@ class ReceivePackHandler(PackHandler):
             if not refs:
             if not refs:
                 refs = [(CAPABILITIES_REF, ZERO_SHA)]
                 refs = [(CAPABILITIES_REF, ZERO_SHA)]
             self.proto.write_pkt_line(
             self.proto.write_pkt_line(
-              refs[0][1] + b' ' + refs[0][0] + b'\0' +
-              self.capability_line(
-                  self.capabilities() + symref_capabilities(symrefs)) + b'\n')
+                refs[0][1]
+                + b" "
+                + refs[0][0]
+                + b"\0"
+                + self.capability_line(
+                    self.capabilities() + symref_capabilities(symrefs)
+                )
+                + b"\n"
+            )
             for i in range(1, len(refs)):
             for i in range(1, len(refs)):
                 ref = refs[i]
                 ref = refs[i]
-                self.proto.write_pkt_line(ref[1] + b' ' + ref[0] + b'\n')
+                self.proto.write_pkt_line(ref[1] + b" " + ref[0] + b"\n")
 
 
             self.proto.write_pkt_line(None)
             self.proto.write_pkt_line(None)
             if self.advertise_refs:
             if self.advertise_refs:
@@ -1050,55 +1092,54 @@ class ReceivePackHandler(PackHandler):
 
 
 
 
 class UploadArchiveHandler(Handler):
 class UploadArchiveHandler(Handler):
-
     def __init__(self, backend, args, proto, stateless_rpc=None):
     def __init__(self, backend, args, proto, stateless_rpc=None):
-        super(UploadArchiveHandler, self).__init__(
-            backend, proto, stateless_rpc)
+        super(UploadArchiveHandler, self).__init__(backend, proto, stateless_rpc)
         self.repo = backend.open_repository(args[0])
         self.repo = backend.open_repository(args[0])
 
 
     def handle(self):
     def handle(self):
         def write(x):
         def write(x):
             return self.proto.write_sideband(SIDE_BAND_CHANNEL_DATA, x)
             return self.proto.write_sideband(SIDE_BAND_CHANNEL_DATA, x)
+
         arguments = []
         arguments = []
         for pkt in self.proto.read_pkt_seq():
         for pkt in self.proto.read_pkt_seq():
-            (key, value) = pkt.split(b' ', 1)
-            if key != b'argument':
-                raise GitProtocolError('unknown command %s' % key)
-            arguments.append(value.rstrip(b'\n'))
-        prefix = b''
-        format = 'tar'
+            (key, value) = pkt.split(b" ", 1)
+            if key != b"argument":
+                raise GitProtocolError("unknown command %s" % key)
+            arguments.append(value.rstrip(b"\n"))
+        prefix = b""
+        format = "tar"
         i = 0
         i = 0
         store = self.repo.object_store
         store = self.repo.object_store
         while i < len(arguments):
         while i < len(arguments):
             argument = arguments[i]
             argument = arguments[i]
-            if argument == b'--prefix':
+            if argument == b"--prefix":
                 i += 1
                 i += 1
                 prefix = arguments[i]
                 prefix = arguments[i]
-            elif argument == b'--format':
+            elif argument == b"--format":
                 i += 1
                 i += 1
-                format = arguments[i].decode('ascii')
+                format = arguments[i].decode("ascii")
             else:
             else:
                 commit_sha = self.repo.refs[argument]
                 commit_sha = self.repo.refs[argument]
                 tree = store[store[commit_sha].tree]
                 tree = store[store[commit_sha].tree]
             i += 1
             i += 1
-        self.proto.write_pkt_line(b'ACK')
+        self.proto.write_pkt_line(b"ACK")
         self.proto.write_pkt_line(None)
         self.proto.write_pkt_line(None)
         for chunk in tar_stream(
         for chunk in tar_stream(
-                store, tree, mtime=time.time(), prefix=prefix, format=format):
+            store, tree, mtime=time.time(), prefix=prefix, format=format
+        ):
             write(chunk)
             write(chunk)
         self.proto.write_pkt_line(None)
         self.proto.write_pkt_line(None)
 
 
 
 
 # Default handler classes for git services.
 # Default handler classes for git services.
 DEFAULT_HANDLERS = {
 DEFAULT_HANDLERS = {
-  b'git-upload-pack': UploadPackHandler,
-  b'git-receive-pack': ReceivePackHandler,
-  b'git-upload-archive': UploadArchiveHandler,
+    b"git-upload-pack": UploadPackHandler,
+    b"git-receive-pack": ReceivePackHandler,
+    b"git-upload-archive": UploadArchiveHandler,
 }
 }
 
 
 
 
 class TCPGitRequestHandler(socketserver.StreamRequestHandler):
 class TCPGitRequestHandler(socketserver.StreamRequestHandler):
-
     def __init__(self, handlers, *args, **kwargs):
     def __init__(self, handlers, *args, **kwargs):
         self.handlers = handlers
         self.handlers = handlers
         socketserver.StreamRequestHandler.__init__(self, *args, **kwargs)
         socketserver.StreamRequestHandler.__init__(self, *args, **kwargs)
@@ -1106,11 +1147,11 @@ class TCPGitRequestHandler(socketserver.StreamRequestHandler):
     def handle(self):
     def handle(self):
         proto = ReceivableProtocol(self.connection.recv, self.wfile.write)
         proto = ReceivableProtocol(self.connection.recv, self.wfile.write)
         command, args = proto.read_cmd()
         command, args = proto.read_cmd()
-        logger.info('Handling %s request, args=%s', command, args)
+        logger.info("Handling %s request, args=%s", command, args)
 
 
         cls = self.handlers.get(command, None)
         cls = self.handlers.get(command, None)
         if not callable(cls):
         if not callable(cls):
-            raise GitProtocolError('Invalid service %s' % command)
+            raise GitProtocolError("Invalid service %s" % command)
         h = cls(self.server.backend, args, proto)
         h = cls(self.server.backend, args, proto)
         h.handle()
         h.handle()
 
 
@@ -1128,45 +1169,56 @@ class TCPGitServer(socketserver.TCPServer):
         if handlers is not None:
         if handlers is not None:
             self.handlers.update(handlers)
             self.handlers.update(handlers)
         self.backend = backend
         self.backend = backend
-        logger.info('Listening for TCP connections on %s:%d',
-                    listen_addr, port)
-        socketserver.TCPServer.__init__(self, (listen_addr, port),
-                                        self._make_handler)
+        logger.info("Listening for TCP connections on %s:%d", listen_addr, port)
+        socketserver.TCPServer.__init__(self, (listen_addr, port), self._make_handler)
 
 
     def verify_request(self, request, client_address):
     def verify_request(self, request, client_address):
-        logger.info('Handling request from %s', client_address)
+        logger.info("Handling request from %s", client_address)
         return True
         return True
 
 
     def handle_error(self, request, client_address):
     def handle_error(self, request, client_address):
-        logger.exception('Exception happened during processing of request '
-                         'from %s', client_address)
+        logger.exception(
+            "Exception happened during processing of request " "from %s",
+            client_address,
+        )
 
 
 
 
 def main(argv=sys.argv):
 def main(argv=sys.argv):
     """Entry point for starting a TCP git server."""
     """Entry point for starting a TCP git server."""
     import optparse
     import optparse
+
     parser = optparse.OptionParser()
     parser = optparse.OptionParser()
-    parser.add_option("-l", "--listen_address", dest="listen_address",
-                      default="localhost",
-                      help="Binding IP address.")
-    parser.add_option("-p", "--port", dest="port", type=int,
-                      default=TCP_GIT_PORT,
-                      help="Binding TCP port.")
+    parser.add_option(
+        "-l",
+        "--listen_address",
+        dest="listen_address",
+        default="localhost",
+        help="Binding IP address.",
+    )
+    parser.add_option(
+        "-p",
+        "--port",
+        dest="port",
+        type=int,
+        default=TCP_GIT_PORT,
+        help="Binding TCP port.",
+    )
     options, args = parser.parse_args(argv)
     options, args = parser.parse_args(argv)
 
 
     log_utils.default_logging_config()
     log_utils.default_logging_config()
     if len(args) > 1:
     if len(args) > 1:
         gitdir = args[1]
         gitdir = args[1]
     else:
     else:
-        gitdir = '.'
+        gitdir = "."
     # TODO(jelmer): Support git-daemon-export-ok and --export-all.
     # TODO(jelmer): Support git-daemon-export-ok and --export-all.
     backend = FileSystemBackend(gitdir)
     backend = FileSystemBackend(gitdir)
     server = TCPGitServer(backend, options.listen_address, options.port)
     server = TCPGitServer(backend, options.listen_address, options.port)
     server.serve_forever()
     server.serve_forever()
 
 
 
 
-def serve_command(handler_cls, argv=sys.argv, backend=None, inf=sys.stdin,
-                  outf=sys.stdout):
+def serve_command(
+    handler_cls, argv=sys.argv, backend=None, inf=sys.stdin, outf=sys.stdout
+):
     """Serve a single command.
     """Serve a single command.
 
 
     This is mostly useful for the implementation of commands used by e.g.
     This is mostly useful for the implementation of commands used by e.g.
@@ -1186,6 +1238,7 @@ def serve_command(handler_cls, argv=sys.argv, backend=None, inf=sys.stdin,
     def send_fn(data):
     def send_fn(data):
         outf.write(data)
         outf.write(data)
         outf.flush()
         outf.flush()
+
     proto = Protocol(inf.read, send_fn)
     proto = Protocol(inf.read, send_fn)
     handler = handler_cls(backend, argv[1:], proto)
     handler = handler_cls(backend, argv[1:], proto)
     # FIXME: Catch exceptions and write a single-line summary to outf.
     # FIXME: Catch exceptions and write a single-line summary to outf.
@@ -1202,9 +1255,7 @@ def generate_info_refs(repo):
 def generate_objects_info_packs(repo):
 def generate_objects_info_packs(repo):
     """Generate an index for for packs."""
     """Generate an index for for packs."""
     for pack in repo.object_store.packs:
     for pack in repo.object_store.packs:
-        yield (
-            b'P ' + os.fsencode(pack.data.filename) +
-            b'\n')
+        yield (b"P " + os.fsencode(pack.data.filename) + b"\n")
 
 
 
 
 def update_server_info(repo):
 def update_server_info(repo):
@@ -1214,13 +1265,14 @@ def update_server_info(repo):
     similar to "git update-server-info".
     similar to "git update-server-info".
     """
     """
     repo._put_named_file(
     repo._put_named_file(
-        os.path.join('info', 'refs'),
-        b"".join(generate_info_refs(repo)))
+        os.path.join("info", "refs"), b"".join(generate_info_refs(repo))
+    )
 
 
     repo._put_named_file(
     repo._put_named_file(
-        os.path.join('objects', 'info', 'packs'),
-        b"".join(generate_objects_info_packs(repo)))
+        os.path.join("objects", "info", "packs"),
+        b"".join(generate_objects_info_packs(repo)),
+    )
 
 
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
     main()

+ 37 - 17
dulwich/stash.py

@@ -28,8 +28,8 @@ from dulwich.file import GitFile
 from dulwich.index import (
 from dulwich.index import (
     commit_tree,
     commit_tree,
     iter_fresh_objects,
     iter_fresh_objects,
-    )
-from dulwich.reflog import read_reflog
+)
+from dulwich.reflog import drop_reflog_entry, read_reflog
 
 
 
 
 DEFAULT_STASH_REF = b"refs/stash"
 DEFAULT_STASH_REF = b"refs/stash"
@@ -45,11 +45,15 @@ class Stash(object):
         self._ref = ref
         self._ref = ref
         self._repo = repo
         self._repo = repo
 
 
+    @property
+    def _reflog_path(self):
+        return os.path.join(
+            self._repo.commondir(), "logs", os.fsdecode(self._ref)
+        )
+
     def stashes(self):
     def stashes(self):
-        reflog_path = os.path.join(
-            self._repo.commondir(), 'logs', os.fsdecode(self._ref))
         try:
         try:
-            with GitFile(reflog_path, 'rb') as f:
+            with GitFile(self._reflog_path, "rb") as f:
                 return reversed(list(read_reflog(f)))
                 return reversed(list(read_reflog(f)))
         except FileNotFoundError:
         except FileNotFoundError:
             return []
             return []
@@ -61,10 +65,17 @@ class Stash(object):
 
 
     def drop(self, index):
     def drop(self, index):
         """Drop entry with specified index."""
         """Drop entry with specified index."""
-        raise NotImplementedError(self.drop)
+        with open(self._reflog_path, "rb+") as f:
+            drop_reflog_entry(f, index, rewrite=True)
+        if len(self) == 0:
+            os.remove(self._reflog_path)
+            del self._repo.refs[self._ref]
+            return
+        if index == 0:
+            self._repo.refs[self._ref] = self[0].new_sha
 
 
     def pop(self, index):
     def pop(self, index):
-        raise NotImplementedError(self.drop)
+        raise NotImplementedError(self.pop)
 
 
     def push(self, committer=None, author=None, message=None):
     def push(self, committer=None, author=None, message=None):
         """Create a new stash.
         """Create a new stash.
@@ -77,24 +88,30 @@ class Stash(object):
         # First, create the index commit.
         # First, create the index commit.
         commit_kwargs = {}
         commit_kwargs = {}
         if committer is not None:
         if committer is not None:
-            commit_kwargs['committer'] = committer
+            commit_kwargs["committer"] = committer
         if author is not None:
         if author is not None:
-            commit_kwargs['author'] = author
+            commit_kwargs["author"] = author
 
 
         index = self._repo.open_index()
         index = self._repo.open_index()
         index_tree_id = index.commit(self._repo.object_store)
         index_tree_id = index.commit(self._repo.object_store)
         index_commit_id = self._repo.do_commit(
         index_commit_id = self._repo.do_commit(
-            ref=None, tree=index_tree_id,
+            ref=None,
+            tree=index_tree_id,
             message=b"Index stash",
             message=b"Index stash",
             merge_heads=[self._repo.head()],
             merge_heads=[self._repo.head()],
-            **commit_kwargs)
+            no_verify=True,
+            **commit_kwargs
+        )
 
 
         # Then, the working tree one.
         # Then, the working tree one.
         stash_tree_id = commit_tree(
         stash_tree_id = commit_tree(
-                self._repo.object_store,
-                iter_fresh_objects(
-                    index, os.fsencode(self._repo.path),
-                    object_store=self._repo.object_store))
+            self._repo.object_store,
+            iter_fresh_objects(
+                index,
+                os.fsencode(self._repo.path),
+                object_store=self._repo.object_store,
+            ),
+        )
 
 
         if message is None:
         if message is None:
             message = b"A stash on " + self._repo.head()
             message = b"A stash on " + self._repo.head()
@@ -103,10 +120,13 @@ class Stash(object):
         self._repo.refs[self._ref] = self._repo.head()
         self._repo.refs[self._ref] = self._repo.head()
 
 
         cid = self._repo.do_commit(
         cid = self._repo.do_commit(
-            ref=self._ref, tree=stash_tree_id,
+            ref=self._ref,
+            tree=stash_tree_id,
             message=message,
             message=message,
             merge_heads=[index_commit_id],
             merge_heads=[index_commit_id],
-            **commit_kwargs)
+            no_verify=True,
+            **commit_kwargs
+        )
 
 
         return cid
         return cid
 
 

+ 70 - 56
dulwich/tests/__init__.py

@@ -35,15 +35,15 @@ from unittest import (  # noqa: F401
     TestCase as _TestCase,
     TestCase as _TestCase,
     skipIf,
     skipIf,
     expectedFailure,
     expectedFailure,
-    )
+)
 
 
 
 
 class TestCase(_TestCase):
 class TestCase(_TestCase):
-
     def setUp(self):
     def setUp(self):
         super(TestCase, self).setUp()
         super(TestCase, self).setUp()
         self._old_home = os.environ.get("HOME")
         self._old_home = os.environ.get("HOME")
         os.environ["HOME"] = "/nonexistant"
         os.environ["HOME"] = "/nonexistant"
+        os.environ["GIT_CONFIG_NOSYSTEM"] = "1"
 
 
     def tearDown(self):
     def tearDown(self):
         super(TestCase, self).tearDown()
         super(TestCase, self).tearDown()
@@ -57,9 +57,11 @@ class BlackboxTestCase(TestCase):
     """Blackbox testing."""
     """Blackbox testing."""
 
 
     # TODO(jelmer): Include more possible binary paths.
     # TODO(jelmer): Include more possible binary paths.
-    bin_directories = [os.path.abspath(os.path.join(
-            os.path.dirname(__file__), "..", "..", "bin")), '/usr/bin',
-            '/usr/local/bin']
+    bin_directories = [
+        os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "bin")),
+        "/usr/bin",
+        "/usr/local/bin",
+    ]
 
 
     def bin_path(self, name):
     def bin_path(self, name):
         """Determine the full path of a binary.
         """Determine the full path of a binary.
@@ -92,50 +94,52 @@ class BlackboxTestCase(TestCase):
         # Save us from all that headache and call python with the bin script.
         # Save us from all that headache and call python with the bin script.
         argv = [sys.executable, self.bin_path(name)] + args
         argv = [sys.executable, self.bin_path(name)] + args
         return subprocess.Popen(
         return subprocess.Popen(
-                argv,
-                stdout=subprocess.PIPE,
-                stdin=subprocess.PIPE, stderr=subprocess.PIPE,
-                env=env)
+            argv,
+            stdout=subprocess.PIPE,
+            stdin=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            env=env,
+        )
 
 
 
 
 def self_test_suite():
 def self_test_suite():
     names = [
     names = [
-        'archive',
-        'blackbox',
-        'bundle',
-        'client',
-        'config',
-        'diff_tree',
-        'fastexport',
-        'file',
-        'grafts',
-        'graph',
-        'greenthreads',
-        'hooks',
-        'ignore',
-        'index',
-        'lfs',
-        'line_ending',
-        'lru_cache',
-        'mailmap',
-        'objects',
-        'objectspec',
-        'object_store',
-        'missing_obj_finder',
-        'pack',
-        'patch',
-        'porcelain',
-        'protocol',
-        'reflog',
-        'refs',
-        'repository',
-        'server',
-        'stash',
-        'utils',
-        'walk',
-        'web',
-        ]
-    module_names = ['dulwich.tests.test_' + name for name in names]
+        "archive",
+        "blackbox",
+        "bundle",
+        "client",
+        "config",
+        "diff_tree",
+        "fastexport",
+        "file",
+        "grafts",
+        "graph",
+        "greenthreads",
+        "hooks",
+        "ignore",
+        "index",
+        "lfs",
+        "line_ending",
+        "lru_cache",
+        "mailmap",
+        "objects",
+        "objectspec",
+        "object_store",
+        "missing_obj_finder",
+        "pack",
+        "patch",
+        "porcelain",
+        "protocol",
+        "reflog",
+        "refs",
+        "repository",
+        "server",
+        "stash",
+        "utils",
+        "walk",
+        "web",
+    ]
+    module_names = ["dulwich.tests.test_" + name for name in names]
     loader = unittest.TestLoader()
     loader = unittest.TestLoader()
     return loader.loadTestsFromNames(module_names)
     return loader.loadTestsFromNames(module_names)
 
 
@@ -148,28 +152,34 @@ def tutorial_test_suite():
     import dulwich.repo  # noqa: F401
     import dulwich.repo  # noqa: F401
     import dulwich.server  # noqa: F401
     import dulwich.server  # noqa: F401
     import dulwich.patch  # noqa: F401
     import dulwich.patch  # noqa: F401
+
     tutorial = [
     tutorial = [
-        'introduction',
-        'file-format',
-        'repo',
-        'object-store',
-        'remote',
-        'conclusion',
-        ]
+        "introduction",
+        "file-format",
+        "repo",
+        "object-store",
+        "remote",
+        "conclusion",
+    ]
     tutorial_files = ["../../docs/tutorial/%s.txt" % name for name in tutorial]
     tutorial_files = ["../../docs/tutorial/%s.txt" % name for name in tutorial]
 
 
     def setup(test):
     def setup(test):
         test.__old_cwd = os.getcwd()
         test.__old_cwd = os.getcwd()
         test.tempdir = tempfile.mkdtemp()
         test.tempdir = tempfile.mkdtemp()
-        test.globs.update({'tempdir': test.tempdir})
+        test.globs.update({"tempdir": test.tempdir})
         os.chdir(test.tempdir)
         os.chdir(test.tempdir)
 
 
     def teardown(test):
     def teardown(test):
         os.chdir(test.__old_cwd)
         os.chdir(test.__old_cwd)
         shutil.rmtree(test.tempdir)
         shutil.rmtree(test.tempdir)
+
     return doctest.DocFileSuite(
     return doctest.DocFileSuite(
-            module_relative=True, package='dulwich.tests',
-            setUp=setup, tearDown=teardown, *tutorial_files)
+        module_relative=True,
+        package="dulwich.tests",
+        setUp=setup,
+        tearDown=teardown,
+        *tutorial_files
+    )
 
 
 
 
 def nocompat_test_suite():
 def nocompat_test_suite():
@@ -177,6 +187,7 @@ def nocompat_test_suite():
     result.addTests(self_test_suite())
     result.addTests(self_test_suite())
     result.addTests(tutorial_test_suite())
     result.addTests(tutorial_test_suite())
     from dulwich.contrib import test_suite as contrib_test_suite
     from dulwich.contrib import test_suite as contrib_test_suite
+
     result.addTests(contrib_test_suite())
     result.addTests(contrib_test_suite())
     return result
     return result
 
 
@@ -184,6 +195,7 @@ def nocompat_test_suite():
 def compat_test_suite():
 def compat_test_suite():
     result = unittest.TestSuite()
     result = unittest.TestSuite()
     from dulwich.tests.compat import test_suite as compat_test_suite
     from dulwich.tests.compat import test_suite as compat_test_suite
+
     result.addTests(compat_test_suite())
     result.addTests(compat_test_suite())
     return result
     return result
 
 
@@ -191,10 +203,12 @@ def compat_test_suite():
 def test_suite():
 def test_suite():
     result = unittest.TestSuite()
     result = unittest.TestSuite()
     result.addTests(self_test_suite())
     result.addTests(self_test_suite())
-    if sys.platform != 'win32':
+    if sys.platform != "win32":
         result.addTests(tutorial_test_suite())
         result.addTests(tutorial_test_suite())
     from dulwich.tests.compat import test_suite as compat_test_suite
     from dulwich.tests.compat import test_suite as compat_test_suite
+
     result.addTests(compat_test_suite())
     result.addTests(compat_test_suite())
     from dulwich.contrib import test_suite as contrib_test_suite
     from dulwich.contrib import test_suite as contrib_test_suite
+
     result.addTests(contrib_test_suite())
     result.addTests(contrib_test_suite())
     return result
     return result

+ 10 - 9
dulwich/tests/compat/__init__.py

@@ -25,15 +25,16 @@ import unittest
 
 
 def test_suite():
 def test_suite():
     names = [
     names = [
-        'client',
-        'pack',
-        'patch',
-        'repository',
-        'server',
-        'utils',
-        'web',
-        ]
-    module_names = ['dulwich.tests.compat.test_' + name for name in names]
+        "client",
+        "pack",
+        "patch",
+        "porcelain",
+        "repository",
+        "server",
+        "utils",
+        "web",
+    ]
+    module_names = ["dulwich.tests.compat.test_" + name for name in names]
     result = unittest.TestSuite()
     result = unittest.TestSuite()
     loader = unittest.TestLoader()
     loader = unittest.TestLoader()
     suite = loader.loadTestsFromNames(module_names)
     suite = loader.loadTestsFromNames(module_names)

+ 148 - 94
dulwich/tests/compat/server_utils.py

@@ -30,16 +30,16 @@ from dulwich.repo import Repo
 from dulwich.objects import hex_to_sha
 from dulwich.objects import hex_to_sha
 from dulwich.protocol import (
 from dulwich.protocol import (
     CAPABILITY_SIDE_BAND_64K,
     CAPABILITY_SIDE_BAND_64K,
-    )
+)
 from dulwich.server import (
 from dulwich.server import (
     ReceivePackHandler,
     ReceivePackHandler,
-    )
+)
 from dulwich.tests.utils import (
 from dulwich.tests.utils import (
     tear_down_repo,
     tear_down_repo,
-    )
+)
 from dulwich.tests.compat.utils import (
 from dulwich.tests.compat.utils import (
     run_git_or_fail,
     run_git_or_fail,
-    )
+)
 from dulwich.tests.compat.utils import require_git_version
 from dulwich.tests.compat.utils import require_git_version
 
 
 
 
@@ -56,7 +56,7 @@ class _StubRepo(object):
 
 
 
 
 def _get_shallow(repo):
 def _get_shallow(repo):
-    shallow_file = repo.get_named_file('shallow')
+    shallow_file = repo.get_named_file("shallow")
     if not shallow_file:
     if not shallow_file:
         return []
         return []
     shallows = []
     shallows = []
@@ -76,70 +76,80 @@ class ServerTests(object):
     Does not inherit from TestCase so tests are not automatically run.
     Does not inherit from TestCase so tests are not automatically run.
     """
     """
 
 
-    min_single_branch_version = (1, 7, 10,)
+    min_single_branch_version = (
+        1,
+        7,
+        10,
+    )
 
 
     def import_repos(self):
     def import_repos(self):
-        self._old_repo = self.import_repo('server_old.export')
-        self._new_repo = self.import_repo('server_new.export')
+        self._old_repo = self.import_repo("server_old.export")
+        self._new_repo = self.import_repo("server_new.export")
 
 
     def url(self, port):
     def url(self, port):
-        return '%s://localhost:%s/' % (self.protocol, port)
+        return "%s://localhost:%s/" % (self.protocol, port)
 
 
     def branch_args(self, branches=None):
     def branch_args(self, branches=None):
         if branches is None:
         if branches is None:
-            branches = ['master', 'branch']
-        return ['%s:%s' % (b, b) for b in branches]
+            branches = ["master", "branch"]
+        return ["%s:%s" % (b, b) for b in branches]
 
 
     def test_push_to_dulwich(self):
     def test_push_to_dulwich(self):
         self.import_repos()
         self.import_repos()
         self.assertReposNotEqual(self._old_repo, self._new_repo)
         self.assertReposNotEqual(self._old_repo, self._new_repo)
         port = self._start_server(self._old_repo)
         port = self._start_server(self._old_repo)
 
 
-        run_git_or_fail(['push', self.url(port)] + self.branch_args(),
-                        cwd=self._new_repo.path)
+        run_git_or_fail(
+            ["push", self.url(port)] + self.branch_args(),
+            cwd=self._new_repo.path,
+        )
         self.assertReposEqual(self._old_repo, self._new_repo)
         self.assertReposEqual(self._old_repo, self._new_repo)
 
 
     def test_push_to_dulwich_no_op(self):
     def test_push_to_dulwich_no_op(self):
-        self._old_repo = self.import_repo('server_old.export')
-        self._new_repo = self.import_repo('server_old.export')
+        self._old_repo = self.import_repo("server_old.export")
+        self._new_repo = self.import_repo("server_old.export")
         self.assertReposEqual(self._old_repo, self._new_repo)
         self.assertReposEqual(self._old_repo, self._new_repo)
         port = self._start_server(self._old_repo)
         port = self._start_server(self._old_repo)
 
 
-        run_git_or_fail(['push', self.url(port)] + self.branch_args(),
-                        cwd=self._new_repo.path)
+        run_git_or_fail(
+            ["push", self.url(port)] + self.branch_args(),
+            cwd=self._new_repo.path,
+        )
         self.assertReposEqual(self._old_repo, self._new_repo)
         self.assertReposEqual(self._old_repo, self._new_repo)
 
 
     def test_push_to_dulwich_remove_branch(self):
     def test_push_to_dulwich_remove_branch(self):
-        self._old_repo = self.import_repo('server_old.export')
-        self._new_repo = self.import_repo('server_old.export')
+        self._old_repo = self.import_repo("server_old.export")
+        self._new_repo = self.import_repo("server_old.export")
         self.assertReposEqual(self._old_repo, self._new_repo)
         self.assertReposEqual(self._old_repo, self._new_repo)
         port = self._start_server(self._old_repo)
         port = self._start_server(self._old_repo)
 
 
-        run_git_or_fail(['push', self.url(port), ":master"],
-                        cwd=self._new_repo.path)
+        run_git_or_fail(["push", self.url(port), ":master"], cwd=self._new_repo.path)
 
 
-        self.assertEqual(
-            list(self._old_repo.get_refs().keys()), [b"refs/heads/branch"])
+        self.assertEqual(list(self._old_repo.get_refs().keys()), [b"refs/heads/branch"])
 
 
     def test_fetch_from_dulwich(self):
     def test_fetch_from_dulwich(self):
         self.import_repos()
         self.import_repos()
         self.assertReposNotEqual(self._old_repo, self._new_repo)
         self.assertReposNotEqual(self._old_repo, self._new_repo)
         port = self._start_server(self._new_repo)
         port = self._start_server(self._new_repo)
 
 
-        run_git_or_fail(['fetch', self.url(port)] + self.branch_args(),
-                        cwd=self._old_repo.path)
+        run_git_or_fail(
+            ["fetch", self.url(port)] + self.branch_args(),
+            cwd=self._old_repo.path,
+        )
         # flush the pack cache so any new packs are picked up
         # flush the pack cache so any new packs are picked up
         self._old_repo.object_store._pack_cache_time = 0
         self._old_repo.object_store._pack_cache_time = 0
         self.assertReposEqual(self._old_repo, self._new_repo)
         self.assertReposEqual(self._old_repo, self._new_repo)
 
 
     def test_fetch_from_dulwich_no_op(self):
     def test_fetch_from_dulwich_no_op(self):
-        self._old_repo = self.import_repo('server_old.export')
-        self._new_repo = self.import_repo('server_old.export')
+        self._old_repo = self.import_repo("server_old.export")
+        self._new_repo = self.import_repo("server_old.export")
         self.assertReposEqual(self._old_repo, self._new_repo)
         self.assertReposEqual(self._old_repo, self._new_repo)
         port = self._start_server(self._new_repo)
         port = self._start_server(self._new_repo)
 
 
-        run_git_or_fail(['fetch', self.url(port)] + self.branch_args(),
-                        cwd=self._old_repo.path)
+        run_git_or_fail(
+            ["fetch", self.url(port)] + self.branch_args(),
+            cwd=self._old_repo.path,
+        )
         # flush the pack cache so any new packs are picked up
         # flush the pack cache so any new packs are picked up
         self._old_repo.object_store._pack_cache_time = 0
         self._old_repo.object_store._pack_cache_time = 0
         self.assertReposEqual(self._old_repo, self._new_repo)
         self.assertReposEqual(self._old_repo, self._new_repo)
@@ -152,146 +162,185 @@ class ServerTests(object):
 
 
         new_repo_base_dir = tempfile.mkdtemp()
         new_repo_base_dir = tempfile.mkdtemp()
         self.addCleanup(shutil.rmtree, new_repo_base_dir)
         self.addCleanup(shutil.rmtree, new_repo_base_dir)
-        new_repo_dir = os.path.join(new_repo_base_dir, 'empty_new')
-        run_git_or_fail(['clone', self.url(port), new_repo_dir],
-                        cwd=new_repo_base_dir)
+        new_repo_dir = os.path.join(new_repo_base_dir, "empty_new")
+        run_git_or_fail(["clone", self.url(port), new_repo_dir], cwd=new_repo_base_dir)
         new_repo = Repo(new_repo_dir)
         new_repo = Repo(new_repo_dir)
         self.assertReposEqual(self._old_repo, new_repo)
         self.assertReposEqual(self._old_repo, new_repo)
 
 
     def test_lsremote_from_dulwich(self):
     def test_lsremote_from_dulwich(self):
-        self._repo = self.import_repo('server_old.export')
+        self._repo = self.import_repo("server_old.export")
         port = self._start_server(self._repo)
         port = self._start_server(self._repo)
-        o = run_git_or_fail(['ls-remote', self.url(port)])
-        self.assertEqual(len(o.split(b'\n')), 4)
+        o = run_git_or_fail(["ls-remote", self.url(port)])
+        self.assertEqual(len(o.split(b"\n")), 4)
 
 
     def test_new_shallow_clone_from_dulwich(self):
     def test_new_shallow_clone_from_dulwich(self):
         require_git_version(self.min_single_branch_version)
         require_git_version(self.min_single_branch_version)
-        self._source_repo = self.import_repo('server_new.export')
-        self._stub_repo = _StubRepo('shallow')
+        self._source_repo = self.import_repo("server_new.export")
+        self._stub_repo = _StubRepo("shallow")
         self.addCleanup(tear_down_repo, self._stub_repo)
         self.addCleanup(tear_down_repo, self._stub_repo)
         port = self._start_server(self._source_repo)
         port = self._start_server(self._source_repo)
 
 
         # Fetch at depth 1
         # Fetch at depth 1
         run_git_or_fail(
         run_git_or_fail(
-            ['clone', '--mirror', '--depth=1', '--no-single-branch',
-             self.url(port), self._stub_repo.path])
+            [
+                "clone",
+                "--mirror",
+                "--depth=1",
+                "--no-single-branch",
+                self.url(port),
+                self._stub_repo.path,
+            ]
+        )
         clone = self._stub_repo = Repo(self._stub_repo.path)
         clone = self._stub_repo = Repo(self._stub_repo.path)
-        expected_shallow = [b'35e0b59e187dd72a0af294aedffc213eaa4d03ff',
-                            b'514dc6d3fbfe77361bcaef320c4d21b72bc10be9']
+        expected_shallow = [
+            b"35e0b59e187dd72a0af294aedffc213eaa4d03ff",
+            b"514dc6d3fbfe77361bcaef320c4d21b72bc10be9",
+        ]
         self.assertEqual(expected_shallow, _get_shallow(clone))
         self.assertEqual(expected_shallow, _get_shallow(clone))
         self.assertReposNotEqual(clone, self._source_repo)
         self.assertReposNotEqual(clone, self._source_repo)
 
 
     def test_shallow_clone_from_git_is_identical(self):
     def test_shallow_clone_from_git_is_identical(self):
         require_git_version(self.min_single_branch_version)
         require_git_version(self.min_single_branch_version)
-        self._source_repo = self.import_repo('server_new.export')
-        self._stub_repo_git = _StubRepo('shallow-git')
+        self._source_repo = self.import_repo("server_new.export")
+        self._stub_repo_git = _StubRepo("shallow-git")
         self.addCleanup(tear_down_repo, self._stub_repo_git)
         self.addCleanup(tear_down_repo, self._stub_repo_git)
-        self._stub_repo_dw = _StubRepo('shallow-dw')
+        self._stub_repo_dw = _StubRepo("shallow-dw")
         self.addCleanup(tear_down_repo, self._stub_repo_dw)
         self.addCleanup(tear_down_repo, self._stub_repo_dw)
 
 
         # shallow clone using stock git, then using dulwich
         # shallow clone using stock git, then using dulwich
         run_git_or_fail(
         run_git_or_fail(
-            ['clone', '--mirror', '--depth=1', '--no-single-branch',
-             'file://' + self._source_repo.path, self._stub_repo_git.path])
+            [
+                "clone",
+                "--mirror",
+                "--depth=1",
+                "--no-single-branch",
+                "file://" + self._source_repo.path,
+                self._stub_repo_git.path,
+            ]
+        )
 
 
         port = self._start_server(self._source_repo)
         port = self._start_server(self._source_repo)
         run_git_or_fail(
         run_git_or_fail(
-            ['clone', '--mirror', '--depth=1', '--no-single-branch',
-             self.url(port), self._stub_repo_dw.path])
+            [
+                "clone",
+                "--mirror",
+                "--depth=1",
+                "--no-single-branch",
+                self.url(port),
+                self._stub_repo_dw.path,
+            ]
+        )
 
 
         # compare the two clones; they should be equal
         # compare the two clones; they should be equal
-        self.assertReposEqual(Repo(self._stub_repo_git.path),
-                              Repo(self._stub_repo_dw.path))
+        self.assertReposEqual(
+            Repo(self._stub_repo_git.path), Repo(self._stub_repo_dw.path)
+        )
 
 
     def test_fetch_same_depth_into_shallow_clone_from_dulwich(self):
     def test_fetch_same_depth_into_shallow_clone_from_dulwich(self):
         require_git_version(self.min_single_branch_version)
         require_git_version(self.min_single_branch_version)
-        self._source_repo = self.import_repo('server_new.export')
-        self._stub_repo = _StubRepo('shallow')
+        self._source_repo = self.import_repo("server_new.export")
+        self._stub_repo = _StubRepo("shallow")
         self.addCleanup(tear_down_repo, self._stub_repo)
         self.addCleanup(tear_down_repo, self._stub_repo)
         port = self._start_server(self._source_repo)
         port = self._start_server(self._source_repo)
 
 
         # Fetch at depth 2
         # Fetch at depth 2
         run_git_or_fail(
         run_git_or_fail(
-            ['clone', '--mirror', '--depth=2', '--no-single-branch',
-             self.url(port), self._stub_repo.path])
+            [
+                "clone",
+                "--mirror",
+                "--depth=2",
+                "--no-single-branch",
+                self.url(port),
+                self._stub_repo.path,
+            ]
+        )
         clone = self._stub_repo = Repo(self._stub_repo.path)
         clone = self._stub_repo = Repo(self._stub_repo.path)
 
 
         # Fetching at the same depth is a no-op.
         # Fetching at the same depth is a no-op.
         run_git_or_fail(
         run_git_or_fail(
-          ['fetch', '--depth=2', self.url(port)] + self.branch_args(),
-          cwd=self._stub_repo.path)
-        expected_shallow = [b'94de09a530df27ac3bb613aaecdd539e0a0655e1',
-                            b'da5cd81e1883c62a25bb37c4d1f8ad965b29bf8d']
+            ["fetch", "--depth=2", self.url(port)] + self.branch_args(),
+            cwd=self._stub_repo.path,
+        )
+        expected_shallow = [
+            b"94de09a530df27ac3bb613aaecdd539e0a0655e1",
+            b"da5cd81e1883c62a25bb37c4d1f8ad965b29bf8d",
+        ]
         self.assertEqual(expected_shallow, _get_shallow(clone))
         self.assertEqual(expected_shallow, _get_shallow(clone))
         self.assertReposNotEqual(clone, self._source_repo)
         self.assertReposNotEqual(clone, self._source_repo)
 
 
     def test_fetch_full_depth_into_shallow_clone_from_dulwich(self):
     def test_fetch_full_depth_into_shallow_clone_from_dulwich(self):
         require_git_version(self.min_single_branch_version)
         require_git_version(self.min_single_branch_version)
-        self._source_repo = self.import_repo('server_new.export')
-        self._stub_repo = _StubRepo('shallow')
+        self._source_repo = self.import_repo("server_new.export")
+        self._stub_repo = _StubRepo("shallow")
         self.addCleanup(tear_down_repo, self._stub_repo)
         self.addCleanup(tear_down_repo, self._stub_repo)
         port = self._start_server(self._source_repo)
         port = self._start_server(self._source_repo)
 
 
         # Fetch at depth 2
         # Fetch at depth 2
         run_git_or_fail(
         run_git_or_fail(
-            ['clone', '--mirror', '--depth=2', '--no-single-branch',
-             self.url(port), self._stub_repo.path])
+            [
+                "clone",
+                "--mirror",
+                "--depth=2",
+                "--no-single-branch",
+                self.url(port),
+                self._stub_repo.path,
+            ]
+        )
         clone = self._stub_repo = Repo(self._stub_repo.path)
         clone = self._stub_repo = Repo(self._stub_repo.path)
 
 
         # Fetching at the same depth is a no-op.
         # Fetching at the same depth is a no-op.
         run_git_or_fail(
         run_git_or_fail(
-          ['fetch', '--depth=2', self.url(port)] + self.branch_args(),
-          cwd=self._stub_repo.path)
+            ["fetch", "--depth=2", self.url(port)] + self.branch_args(),
+            cwd=self._stub_repo.path,
+        )
 
 
         # The whole repo only has depth 4, so it should equal server_new.
         # The whole repo only has depth 4, so it should equal server_new.
         run_git_or_fail(
         run_git_or_fail(
-          ['fetch', '--depth=4', self.url(port)] + self.branch_args(),
-          cwd=self._stub_repo.path)
+            ["fetch", "--depth=4", self.url(port)] + self.branch_args(),
+            cwd=self._stub_repo.path,
+        )
         self.assertEqual([], _get_shallow(clone))
         self.assertEqual([], _get_shallow(clone))
         self.assertReposEqual(clone, self._source_repo)
         self.assertReposEqual(clone, self._source_repo)
 
 
     def test_fetch_from_dulwich_issue_88_standard(self):
     def test_fetch_from_dulwich_issue_88_standard(self):
         # Basically an integration test to see that the ACK/NAK
         # Basically an integration test to see that the ACK/NAK
         # generation works on repos with common head.
         # generation works on repos with common head.
-        self._source_repo = self.import_repo(
-            'issue88_expect_ack_nak_server.export')
-        self._client_repo = self.import_repo(
-            'issue88_expect_ack_nak_client.export')
+        self._source_repo = self.import_repo("issue88_expect_ack_nak_server.export")
+        self._client_repo = self.import_repo("issue88_expect_ack_nak_client.export")
         port = self._start_server(self._source_repo)
         port = self._start_server(self._source_repo)
 
 
-        run_git_or_fail(['fetch', self.url(port), 'master'],
-                        cwd=self._client_repo.path)
+        run_git_or_fail(["fetch", self.url(port), "master"], cwd=self._client_repo.path)
         self.assertObjectStoreEqual(
         self.assertObjectStoreEqual(
-            self._source_repo.object_store,
-            self._client_repo.object_store)
+            self._source_repo.object_store, self._client_repo.object_store
+        )
 
 
     def test_fetch_from_dulwich_issue_88_alternative(self):
     def test_fetch_from_dulwich_issue_88_alternative(self):
         # likewise, but the case where the two repos have no common parent
         # likewise, but the case where the two repos have no common parent
-        self._source_repo = self.import_repo(
-            'issue88_expect_ack_nak_other.export')
-        self._client_repo = self.import_repo(
-            'issue88_expect_ack_nak_client.export')
+        self._source_repo = self.import_repo("issue88_expect_ack_nak_other.export")
+        self._client_repo = self.import_repo("issue88_expect_ack_nak_client.export")
         port = self._start_server(self._source_repo)
         port = self._start_server(self._source_repo)
 
 
         self.assertRaises(
         self.assertRaises(
-            KeyError, self._client_repo.get_object,
-            b'02a14da1fc1fc13389bbf32f0af7d8899f2b2323')
-        run_git_or_fail(['fetch', self.url(port), 'master'],
-                        cwd=self._client_repo.path)
-        self.assertEqual(b'commit', self._client_repo.get_object(
-            b'02a14da1fc1fc13389bbf32f0af7d8899f2b2323').type_name)
+            KeyError,
+            self._client_repo.get_object,
+            b"02a14da1fc1fc13389bbf32f0af7d8899f2b2323",
+        )
+        run_git_or_fail(["fetch", self.url(port), "master"], cwd=self._client_repo.path)
+        self.assertEqual(
+            b"commit",
+            self._client_repo.get_object(
+                b"02a14da1fc1fc13389bbf32f0af7d8899f2b2323"
+            ).type_name,
+        )
 
 
     def test_push_to_dulwich_issue_88_standard(self):
     def test_push_to_dulwich_issue_88_standard(self):
         # Same thing, but we reverse the role of the server/client
         # Same thing, but we reverse the role of the server/client
         # and do a push instead.
         # and do a push instead.
-        self._source_repo = self.import_repo(
-            'issue88_expect_ack_nak_client.export')
-        self._client_repo = self.import_repo(
-            'issue88_expect_ack_nak_server.export')
+        self._source_repo = self.import_repo("issue88_expect_ack_nak_client.export")
+        self._client_repo = self.import_repo("issue88_expect_ack_nak_server.export")
         port = self._start_server(self._source_repo)
         port = self._start_server(self._source_repo)
 
 
-        run_git_or_fail(['push', self.url(port), 'master'],
-                        cwd=self._client_repo.path)
+        run_git_or_fail(["push", self.url(port), "master"], cwd=self._client_repo.path)
         self.assertReposEqual(self._source_repo, self._client_repo)
         self.assertReposEqual(self._source_repo, self._client_repo)
 
 
 
 
@@ -303,12 +352,17 @@ class NoSideBand64kReceivePackHandler(ReceivePackHandler):
 
 
     @classmethod
     @classmethod
     def capabilities(cls):
     def capabilities(cls):
-        return [c for c in ReceivePackHandler.capabilities()
-                if c != CAPABILITY_SIDE_BAND_64K]
+        return [
+            c
+            for c in ReceivePackHandler.capabilities()
+            if c != CAPABILITY_SIDE_BAND_64K
+        ]
 
 
 
 
 def ignore_error(error):
 def ignore_error(error):
     """Check whether this error is safe to ignore."""
     """Check whether this error is safe to ignore."""
     (e_type, e_value, e_tb) = error
     (e_type, e_value, e_tb) = error
-    return (issubclass(e_type, socket.error) and
-            e_value[0] in (errno.ECONNRESET, errno.EPIPE))
+    return issubclass(e_type, socket.error) and e_value[0] in (
+        errno.ECONNRESET,
+        errno.EPIPE,
+    )

+ 220 - 176
dulwich/tests/compat/test_client.py

@@ -43,11 +43,11 @@ from dulwich import (
     protocol,
     protocol,
     objects,
     objects,
     repo,
     repo,
-    )
+)
 from dulwich.tests import (
 from dulwich.tests import (
     SkipTest,
     SkipTest,
     expectedFailure,
     expectedFailure,
-    )
+)
 from dulwich.tests.compat.utils import (
 from dulwich.tests.compat.utils import (
     CompatTestCase,
     CompatTestCase,
     check_for_daemon,
     check_for_daemon,
@@ -55,10 +55,10 @@ from dulwich.tests.compat.utils import (
     rmtree_ro,
     rmtree_ro,
     run_git_or_fail,
     run_git_or_fail,
     _DEFAULT_GIT,
     _DEFAULT_GIT,
-    )
+)
 
 
 
 
-if sys.platform == 'win32':
+if sys.platform == "win32":
     import ctypes
     import ctypes
 
 
 
 
@@ -67,17 +67,18 @@ class DulwichClientTestBase(object):
 
 
     def setUp(self):
     def setUp(self):
         self.gitroot = os.path.dirname(
         self.gitroot = os.path.dirname(
-                import_repo_to_dir('server_new.export').rstrip(os.sep))
-        self.dest = os.path.join(self.gitroot, 'dest')
+            import_repo_to_dir("server_new.export").rstrip(os.sep)
+        )
+        self.dest = os.path.join(self.gitroot, "dest")
         file.ensure_dir_exists(self.dest)
         file.ensure_dir_exists(self.dest)
-        run_git_or_fail(['init', '--quiet', '--bare'], cwd=self.dest)
+        run_git_or_fail(["init", "--quiet", "--bare"], cwd=self.dest)
 
 
     def tearDown(self):
     def tearDown(self):
         rmtree_ro(self.gitroot)
         rmtree_ro(self.gitroot)
 
 
     def assertDestEqualsSrc(self):
     def assertDestEqualsSrc(self):
-        repo_dir = os.path.join(self.gitroot, 'server_new.export')
-        dest_repo_dir = os.path.join(self.gitroot, 'dest')
+        repo_dir = os.path.join(self.gitroot, "server_new.export")
+        dest_repo_dir = os.path.join(self.gitroot, "dest")
         with repo.Repo(repo_dir) as src:
         with repo.Repo(repo_dir) as src:
             with repo.Repo(dest_repo_dir) as dest:
             with repo.Repo(dest_repo_dir) as dest:
                 self.assertReposEqual(src, dest)
                 self.assertReposEqual(src, dest)
@@ -90,12 +91,15 @@ class DulwichClientTestBase(object):
 
 
     def _do_send_pack(self):
     def _do_send_pack(self):
         c = self._client()
         c = self._client()
-        srcpath = os.path.join(self.gitroot, 'server_new.export')
+        srcpath = os.path.join(self.gitroot, "server_new.export")
         with repo.Repo(srcpath) as src:
         with repo.Repo(srcpath) as src:
             sendrefs = dict(src.get_refs())
             sendrefs = dict(src.get_refs())
-            del sendrefs[b'HEAD']
-            c.send_pack(self._build_path('/dest'), lambda _: sendrefs,
-                        src.generate_pack_data)
+            del sendrefs[b"HEAD"]
+            c.send_pack(
+                self._build_path("/dest"),
+                lambda _: sendrefs,
+                src.generate_pack_data,
+            )
 
 
     def test_send_pack(self):
     def test_send_pack(self):
         self._do_send_pack()
         self._do_send_pack()
@@ -111,157 +115,175 @@ class DulwichClientTestBase(object):
     def _add_file(repo, tree_id, filename, contents):
     def _add_file(repo, tree_id, filename, contents):
         tree = repo[tree_id]
         tree = repo[tree_id]
         blob = objects.Blob()
         blob = objects.Blob()
-        blob.data = contents.encode('utf-8')
+        blob.data = contents.encode("utf-8")
         repo.object_store.add_object(blob)
         repo.object_store.add_object(blob)
-        tree.add(filename.encode('utf-8'), stat.S_IFREG | 0o644, blob.id)
+        tree.add(filename.encode("utf-8"), stat.S_IFREG | 0o644, blob.id)
         repo.object_store.add_object(tree)
         repo.object_store.add_object(tree)
         return tree.id
         return tree.id
 
 
     def test_send_pack_from_shallow_clone(self):
     def test_send_pack_from_shallow_clone(self):
         c = self._client()
         c = self._client()
-        server_new_path = os.path.join(self.gitroot, 'server_new.export')
-        run_git_or_fail(['config', 'http.uploadpack', 'true'],
-                        cwd=server_new_path)
-        run_git_or_fail(['config', 'http.receivepack', 'true'],
-                        cwd=server_new_path)
-        remote_path = self._build_path('/server_new.export')
+        server_new_path = os.path.join(self.gitroot, "server_new.export")
+        run_git_or_fail(["config", "http.uploadpack", "true"], cwd=server_new_path)
+        run_git_or_fail(["config", "http.receivepack", "true"], cwd=server_new_path)
+        remote_path = self._build_path("/server_new.export")
         with repo.Repo(self.dest) as local:
         with repo.Repo(self.dest) as local:
             result = c.fetch(remote_path, local, depth=1)
             result = c.fetch(remote_path, local, depth=1)
             for r in result.refs.items():
             for r in result.refs.items():
                 local.refs.set_if_equals(r[0], None, r[1])
                 local.refs.set_if_equals(r[0], None, r[1])
             tree_id = local[local.head()].tree
             tree_id = local[local.head()].tree
-            for filename, contents in [('bar', 'bar contents'),
-                                       ('zop', 'zop contents')]:
+            for filename, contents in [
+                ("bar", "bar contents"),
+                ("zop", "zop contents"),
+            ]:
                 tree_id = self._add_file(local, tree_id, filename, contents)
                 tree_id = self._add_file(local, tree_id, filename, contents)
                 commit_id = local.do_commit(
                 commit_id = local.do_commit(
-                    message=b"add " + filename.encode('utf-8'),
+                    message=b"add " + filename.encode("utf-8"),
                     committer=b"Joe Example <joe@example.com>",
                     committer=b"Joe Example <joe@example.com>",
-                    tree=tree_id)
+                    tree=tree_id,
+                )
             sendrefs = dict(local.get_refs())
             sendrefs = dict(local.get_refs())
-            del sendrefs[b'HEAD']
-            c.send_pack(remote_path, lambda _: sendrefs,
-                        local.generate_pack_data)
+            del sendrefs[b"HEAD"]
+            c.send_pack(remote_path, lambda _: sendrefs, local.generate_pack_data)
         with repo.Repo(server_new_path) as remote:
         with repo.Repo(server_new_path) as remote:
             self.assertEqual(remote.head(), commit_id)
             self.assertEqual(remote.head(), commit_id)
 
 
     def test_send_without_report_status(self):
     def test_send_without_report_status(self):
         c = self._client()
         c = self._client()
-        c._send_capabilities.remove(b'report-status')
-        srcpath = os.path.join(self.gitroot, 'server_new.export')
+        c._send_capabilities.remove(b"report-status")
+        srcpath = os.path.join(self.gitroot, "server_new.export")
         with repo.Repo(srcpath) as src:
         with repo.Repo(srcpath) as src:
             sendrefs = dict(src.get_refs())
             sendrefs = dict(src.get_refs())
-            del sendrefs[b'HEAD']
-            c.send_pack(self._build_path('/dest'), lambda _: sendrefs,
-                        src.generate_pack_data)
+            del sendrefs[b"HEAD"]
+            c.send_pack(
+                self._build_path("/dest"),
+                lambda _: sendrefs,
+                src.generate_pack_data,
+            )
             self.assertDestEqualsSrc()
             self.assertDestEqualsSrc()
 
 
     def make_dummy_commit(self, dest):
     def make_dummy_commit(self, dest):
-        b = objects.Blob.from_string(b'hi')
+        b = objects.Blob.from_string(b"hi")
         dest.object_store.add_object(b)
         dest.object_store.add_object(b)
-        t = index.commit_tree(dest.object_store, [(b'hi', b.id, 0o100644)])
+        t = index.commit_tree(dest.object_store, [(b"hi", b.id, 0o100644)])
         c = objects.Commit()
         c = objects.Commit()
-        c.author = c.committer = b'Foo Bar <foo@example.com>'
+        c.author = c.committer = b"Foo Bar <foo@example.com>"
         c.author_time = c.commit_time = 0
         c.author_time = c.commit_time = 0
         c.author_timezone = c.commit_timezone = 0
         c.author_timezone = c.commit_timezone = 0
-        c.message = b'hi'
+        c.message = b"hi"
         c.tree = t
         c.tree = t
         dest.object_store.add_object(c)
         dest.object_store.add_object(c)
         return c.id
         return c.id
 
 
     def disable_ff_and_make_dummy_commit(self):
     def disable_ff_and_make_dummy_commit(self):
         # disable non-fast-forward pushes to the server
         # disable non-fast-forward pushes to the server
-        dest = repo.Repo(os.path.join(self.gitroot, 'dest'))
-        run_git_or_fail(['config', 'receive.denyNonFastForwards', 'true'],
-                        cwd=dest.path)
+        dest = repo.Repo(os.path.join(self.gitroot, "dest"))
+        run_git_or_fail(
+            ["config", "receive.denyNonFastForwards", "true"], cwd=dest.path
+        )
         commit_id = self.make_dummy_commit(dest)
         commit_id = self.make_dummy_commit(dest)
         return dest, commit_id
         return dest, commit_id
 
 
     def compute_send(self, src):
     def compute_send(self, src):
         sendrefs = dict(src.get_refs())
         sendrefs = dict(src.get_refs())
-        del sendrefs[b'HEAD']
+        del sendrefs[b"HEAD"]
         return sendrefs, src.generate_pack_data
         return sendrefs, src.generate_pack_data
 
 
     def test_send_pack_one_error(self):
     def test_send_pack_one_error(self):
         dest, dummy_commit = self.disable_ff_and_make_dummy_commit()
         dest, dummy_commit = self.disable_ff_and_make_dummy_commit()
-        dest.refs[b'refs/heads/master'] = dummy_commit
-        repo_dir = os.path.join(self.gitroot, 'server_new.export')
+        dest.refs[b"refs/heads/master"] = dummy_commit
+        repo_dir = os.path.join(self.gitroot, "server_new.export")
         with repo.Repo(repo_dir) as src:
         with repo.Repo(repo_dir) as src:
             sendrefs, gen_pack = self.compute_send(src)
             sendrefs, gen_pack = self.compute_send(src)
             c = self._client()
             c = self._client()
             result = c.send_pack(
             result = c.send_pack(
-                self._build_path('/dest'), lambda _: sendrefs, gen_pack)
-            self.assertEqual({b'refs/heads/branch': None,
-                              b'refs/heads/master': 'non-fast-forward'},
-                             result.ref_status)
+                self._build_path("/dest"), lambda _: sendrefs, gen_pack
+            )
+            self.assertEqual(
+                {
+                    b"refs/heads/branch": None,
+                    b"refs/heads/master": "non-fast-forward",
+                },
+                result.ref_status,
+            )
 
 
     def test_send_pack_multiple_errors(self):
     def test_send_pack_multiple_errors(self):
         dest, dummy = self.disable_ff_and_make_dummy_commit()
         dest, dummy = self.disable_ff_and_make_dummy_commit()
         # set up for two non-ff errors
         # set up for two non-ff errors
-        branch, master = b'refs/heads/branch', b'refs/heads/master'
+        branch, master = b"refs/heads/branch", b"refs/heads/master"
         dest.refs[branch] = dest.refs[master] = dummy
         dest.refs[branch] = dest.refs[master] = dummy
-        repo_dir = os.path.join(self.gitroot, 'server_new.export')
+        repo_dir = os.path.join(self.gitroot, "server_new.export")
         with repo.Repo(repo_dir) as src:
         with repo.Repo(repo_dir) as src:
             sendrefs, gen_pack = self.compute_send(src)
             sendrefs, gen_pack = self.compute_send(src)
             c = self._client()
             c = self._client()
             result = c.send_pack(
             result = c.send_pack(
-                self._build_path('/dest'), lambda _: sendrefs, gen_pack)
-            self.assertEqual({branch: 'non-fast-forward',
-                              master: 'non-fast-forward'},
-                             result.ref_status)
+                self._build_path("/dest"), lambda _: sendrefs, gen_pack
+            )
+            self.assertEqual(
+                {branch: "non-fast-forward", master: "non-fast-forward"},
+                result.ref_status,
+            )
 
 
     def test_archive(self):
     def test_archive(self):
         c = self._client()
         c = self._client()
         f = BytesIO()
         f = BytesIO()
-        c.archive(self._build_path('/server_new.export'), b'HEAD', f.write)
+        c.archive(self._build_path("/server_new.export"), b"HEAD", f.write)
         f.seek(0)
         f.seek(0)
         tf = tarfile.open(fileobj=f)
         tf = tarfile.open(fileobj=f)
-        self.assertEqual(['baz', 'foo'], tf.getnames())
+        self.assertEqual(["baz", "foo"], tf.getnames())
 
 
     def test_fetch_pack(self):
     def test_fetch_pack(self):
         c = self._client()
         c = self._client()
-        with repo.Repo(os.path.join(self.gitroot, 'dest')) as dest:
-            result = c.fetch(self._build_path('/server_new.export'), dest)
+        with repo.Repo(os.path.join(self.gitroot, "dest")) as dest:
+            result = c.fetch(self._build_path("/server_new.export"), dest)
             for r in result.refs.items():
             for r in result.refs.items():
                 dest.refs.set_if_equals(r[0], None, r[1])
                 dest.refs.set_if_equals(r[0], None, r[1])
             self.assertDestEqualsSrc()
             self.assertDestEqualsSrc()
 
 
     def test_fetch_pack_depth(self):
     def test_fetch_pack_depth(self):
         c = self._client()
         c = self._client()
-        with repo.Repo(os.path.join(self.gitroot, 'dest')) as dest:
-            result = c.fetch(self._build_path('/server_new.export'), dest,
-                             depth=1)
+        with repo.Repo(os.path.join(self.gitroot, "dest")) as dest:
+            result = c.fetch(self._build_path("/server_new.export"), dest, depth=1)
             for r in result.refs.items():
             for r in result.refs.items():
                 dest.refs.set_if_equals(r[0], None, r[1])
                 dest.refs.set_if_equals(r[0], None, r[1])
             self.assertEqual(
             self.assertEqual(
-                    dest.get_shallow(),
-                    set([b'35e0b59e187dd72a0af294aedffc213eaa4d03ff',
-                         b'514dc6d3fbfe77361bcaef320c4d21b72bc10be9']))
+                dest.get_shallow(),
+                set(
+                    [
+                        b"35e0b59e187dd72a0af294aedffc213eaa4d03ff",
+                        b"514dc6d3fbfe77361bcaef320c4d21b72bc10be9",
+                    ]
+                ),
+            )
 
 
     def test_repeat(self):
     def test_repeat(self):
         c = self._client()
         c = self._client()
-        with repo.Repo(os.path.join(self.gitroot, 'dest')) as dest:
-            result = c.fetch(self._build_path('/server_new.export'), dest)
+        with repo.Repo(os.path.join(self.gitroot, "dest")) as dest:
+            result = c.fetch(self._build_path("/server_new.export"), dest)
             for r in result.refs.items():
             for r in result.refs.items():
                 dest.refs.set_if_equals(r[0], None, r[1])
                 dest.refs.set_if_equals(r[0], None, r[1])
             self.assertDestEqualsSrc()
             self.assertDestEqualsSrc()
-            result = c.fetch(self._build_path('/server_new.export'), dest)
+            result = c.fetch(self._build_path("/server_new.export"), dest)
             for r in result.refs.items():
             for r in result.refs.items():
                 dest.refs.set_if_equals(r[0], None, r[1])
                 dest.refs.set_if_equals(r[0], None, r[1])
             self.assertDestEqualsSrc()
             self.assertDestEqualsSrc()
 
 
     def test_fetch_empty_pack(self):
     def test_fetch_empty_pack(self):
         c = self._client()
         c = self._client()
-        with repo.Repo(os.path.join(self.gitroot, 'dest')) as dest:
-            result = c.fetch(self._build_path('/server_new.export'), dest)
+        with repo.Repo(os.path.join(self.gitroot, "dest")) as dest:
+            result = c.fetch(self._build_path("/server_new.export"), dest)
             for r in result.refs.items():
             for r in result.refs.items():
                 dest.refs.set_if_equals(r[0], None, r[1])
                 dest.refs.set_if_equals(r[0], None, r[1])
             self.assertDestEqualsSrc()
             self.assertDestEqualsSrc()
 
 
-            def dw(refs):
+            def dw(refs, **kwargs):
                 return list(refs.values())
                 return list(refs.values())
+
             result = c.fetch(
             result = c.fetch(
-                self._build_path('/server_new.export'), dest,
-                determine_wants=dw)
+                self._build_path("/server_new.export"),
+                dest,
+                determine_wants=dw,
+            )
             for r in result.refs.items():
             for r in result.refs.items():
                 dest.refs.set_if_equals(r[0], None, r[1])
                 dest.refs.set_if_equals(r[0], None, r[1])
             self.assertDestEqualsSrc()
             self.assertDestEqualsSrc()
@@ -269,20 +291,20 @@ class DulwichClientTestBase(object):
     def test_incremental_fetch_pack(self):
     def test_incremental_fetch_pack(self):
         self.test_fetch_pack()
         self.test_fetch_pack()
         dest, dummy = self.disable_ff_and_make_dummy_commit()
         dest, dummy = self.disable_ff_and_make_dummy_commit()
-        dest.refs[b'refs/heads/master'] = dummy
+        dest.refs[b"refs/heads/master"] = dummy
         c = self._client()
         c = self._client()
-        repo_dir = os.path.join(self.gitroot, 'server_new.export')
+        repo_dir = os.path.join(self.gitroot, "server_new.export")
         with repo.Repo(repo_dir) as dest:
         with repo.Repo(repo_dir) as dest:
-            result = c.fetch(self._build_path('/dest'), dest)
+            result = c.fetch(self._build_path("/dest"), dest)
             for r in result.refs.items():
             for r in result.refs.items():
                 dest.refs.set_if_equals(r[0], None, r[1])
                 dest.refs.set_if_equals(r[0], None, r[1])
             self.assertDestEqualsSrc()
             self.assertDestEqualsSrc()
 
 
     def test_fetch_pack_no_side_band_64k(self):
     def test_fetch_pack_no_side_band_64k(self):
         c = self._client()
         c = self._client()
-        c._fetch_capabilities.remove(b'side-band-64k')
-        with repo.Repo(os.path.join(self.gitroot, 'dest')) as dest:
-            result = c.fetch(self._build_path('/server_new.export'), dest)
+        c._fetch_capabilities.remove(b"side-band-64k")
+        with repo.Repo(os.path.join(self.gitroot, "dest")) as dest:
+            result = c.fetch(self._build_path("/server_new.export"), dest)
             for r in result.refs.items():
             for r in result.refs.items():
                 dest.refs.set_if_equals(r[0], None, r[1])
                 dest.refs.set_if_equals(r[0], None, r[1])
             self.assertDestEqualsSrc()
             self.assertDestEqualsSrc()
@@ -291,84 +313,96 @@ class DulwichClientTestBase(object):
         # zero sha1s are already present on the client, and should
         # zero sha1s are already present on the client, and should
         # be ignored
         # be ignored
         c = self._client()
         c = self._client()
-        with repo.Repo(os.path.join(self.gitroot, 'dest')) as dest:
+        with repo.Repo(os.path.join(self.gitroot, "dest")) as dest:
             result = c.fetch(
             result = c.fetch(
-                self._build_path('/server_new.export'), dest,
-                lambda refs: [protocol.ZERO_SHA])
+                self._build_path("/server_new.export"),
+                dest,
+                lambda refs, **kwargs: [protocol.ZERO_SHA],
+            )
             for r in result.refs.items():
             for r in result.refs.items():
                 dest.refs.set_if_equals(r[0], None, r[1])
                 dest.refs.set_if_equals(r[0], None, r[1])
 
 
     def test_send_remove_branch(self):
     def test_send_remove_branch(self):
-        with repo.Repo(os.path.join(self.gitroot, 'dest')) as dest:
+        with repo.Repo(os.path.join(self.gitroot, "dest")) as dest:
             dummy_commit = self.make_dummy_commit(dest)
             dummy_commit = self.make_dummy_commit(dest)
-            dest.refs[b'refs/heads/master'] = dummy_commit
-            dest.refs[b'refs/heads/abranch'] = dummy_commit
+            dest.refs[b"refs/heads/master"] = dummy_commit
+            dest.refs[b"refs/heads/abranch"] = dummy_commit
             sendrefs = dict(dest.refs)
             sendrefs = dict(dest.refs)
-            sendrefs[b'refs/heads/abranch'] = b"00" * 20
-            del sendrefs[b'HEAD']
+            sendrefs[b"refs/heads/abranch"] = b"00" * 20
+            del sendrefs[b"HEAD"]
 
 
             def gen_pack(have, want, ofs_delta=False):
             def gen_pack(have, want, ofs_delta=False):
                 return 0, []
                 return 0, []
+
             c = self._client()
             c = self._client()
             self.assertEqual(dest.refs[b"refs/heads/abranch"], dummy_commit)
             self.assertEqual(dest.refs[b"refs/heads/abranch"], dummy_commit)
-            c.send_pack(
-                self._build_path('/dest'), lambda _: sendrefs, gen_pack)
+            c.send_pack(self._build_path("/dest"), lambda _: sendrefs, gen_pack)
             self.assertFalse(b"refs/heads/abranch" in dest.refs)
             self.assertFalse(b"refs/heads/abranch" in dest.refs)
 
 
     def test_send_new_branch_empty_pack(self):
     def test_send_new_branch_empty_pack(self):
-        with repo.Repo(os.path.join(self.gitroot, 'dest')) as dest:
+        with repo.Repo(os.path.join(self.gitroot, "dest")) as dest:
             dummy_commit = self.make_dummy_commit(dest)
             dummy_commit = self.make_dummy_commit(dest)
-            dest.refs[b'refs/heads/master'] = dummy_commit
-            dest.refs[b'refs/heads/abranch'] = dummy_commit
-            sendrefs = {b'refs/heads/bbranch': dummy_commit}
+            dest.refs[b"refs/heads/master"] = dummy_commit
+            dest.refs[b"refs/heads/abranch"] = dummy_commit
+            sendrefs = {b"refs/heads/bbranch": dummy_commit}
 
 
             def gen_pack(have, want, ofs_delta=False):
             def gen_pack(have, want, ofs_delta=False):
                 return 0, []
                 return 0, []
+
             c = self._client()
             c = self._client()
             self.assertEqual(dest.refs[b"refs/heads/abranch"], dummy_commit)
             self.assertEqual(dest.refs[b"refs/heads/abranch"], dummy_commit)
-            c.send_pack(
-                self._build_path('/dest'), lambda _: sendrefs, gen_pack)
+            c.send_pack(self._build_path("/dest"), lambda _: sendrefs, gen_pack)
             self.assertEqual(dummy_commit, dest.refs[b"refs/heads/abranch"])
             self.assertEqual(dummy_commit, dest.refs[b"refs/heads/abranch"])
 
 
     def test_get_refs(self):
     def test_get_refs(self):
         c = self._client()
         c = self._client()
-        refs = c.get_refs(self._build_path('/server_new.export'))
+        refs = c.get_refs(self._build_path("/server_new.export"))
 
 
-        repo_dir = os.path.join(self.gitroot, 'server_new.export')
+        repo_dir = os.path.join(self.gitroot, "server_new.export")
         with repo.Repo(repo_dir) as dest:
         with repo.Repo(repo_dir) as dest:
             self.assertDictEqual(dest.refs.as_dict(), refs)
             self.assertDictEqual(dest.refs.as_dict(), refs)
 
 
 
 
 class DulwichTCPClientTest(CompatTestCase, DulwichClientTestBase):
 class DulwichTCPClientTest(CompatTestCase, DulwichClientTestBase):
-
     def setUp(self):
     def setUp(self):
         CompatTestCase.setUp(self)
         CompatTestCase.setUp(self)
         DulwichClientTestBase.setUp(self)
         DulwichClientTestBase.setUp(self)
         if check_for_daemon(limit=1):
         if check_for_daemon(limit=1):
-            raise SkipTest('git-daemon was already running on port %s' %
-                           protocol.TCP_GIT_PORT)
-        fd, self.pidfile = tempfile.mkstemp(prefix='dulwich-test-git-client',
-                                            suffix=".pid")
+            raise SkipTest(
+                "git-daemon was already running on port %s" % protocol.TCP_GIT_PORT
+            )
+        fd, self.pidfile = tempfile.mkstemp(
+            prefix="dulwich-test-git-client", suffix=".pid"
+        )
         os.fdopen(fd).close()
         os.fdopen(fd).close()
-        args = [_DEFAULT_GIT, 'daemon', '--verbose', '--export-all',
-                '--pid-file=%s' % self.pidfile,
-                '--base-path=%s' % self.gitroot,
-                '--enable=receive-pack', '--enable=upload-archive',
-                '--listen=localhost', '--reuseaddr',
-                self.gitroot]
+        args = [
+            _DEFAULT_GIT,
+            "daemon",
+            "--verbose",
+            "--export-all",
+            "--pid-file=%s" % self.pidfile,
+            "--base-path=%s" % self.gitroot,
+            "--enable=receive-pack",
+            "--enable=upload-archive",
+            "--listen=localhost",
+            "--reuseaddr",
+            self.gitroot,
+        ]
         self.process = subprocess.Popen(
         self.process = subprocess.Popen(
-            args, cwd=self.gitroot,
-            stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+            args,
+            cwd=self.gitroot,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+        )
         if not check_for_daemon():
         if not check_for_daemon():
-            raise SkipTest('git-daemon failed to start')
+            raise SkipTest("git-daemon failed to start")
 
 
     def tearDown(self):
     def tearDown(self):
         with open(self.pidfile) as f:
         with open(self.pidfile) as f:
             pid = int(f.read().strip())
             pid = int(f.read().strip())
-        if sys.platform == 'win32':
+        if sys.platform == "win32":
             PROCESS_TERMINATE = 1
             PROCESS_TERMINATE = 1
-            handle = ctypes.windll.kernel32.OpenProcess(
-                PROCESS_TERMINATE, False, pid)
+            handle = ctypes.windll.kernel32.OpenProcess(PROCESS_TERMINATE, False, pid)
             ctypes.windll.kernel32.TerminateProcess(handle, -1)
             ctypes.windll.kernel32.TerminateProcess(handle, -1)
             ctypes.windll.kernel32.CloseHandle(handle)
             ctypes.windll.kernel32.CloseHandle(handle)
         else:
         else:
@@ -384,32 +418,42 @@ class DulwichTCPClientTest(CompatTestCase, DulwichClientTestBase):
         CompatTestCase.tearDown(self)
         CompatTestCase.tearDown(self)
 
 
     def _client(self):
     def _client(self):
-        return client.TCPGitClient('localhost')
+        return client.TCPGitClient("localhost")
 
 
     def _build_path(self, path):
     def _build_path(self, path):
         return path
         return path
 
 
-    if sys.platform == 'win32':
+    if sys.platform == "win32":
+
         @expectedFailure
         @expectedFailure
         def test_fetch_pack_no_side_band_64k(self):
         def test_fetch_pack_no_side_band_64k(self):
             DulwichClientTestBase.test_fetch_pack_no_side_band_64k(self)
             DulwichClientTestBase.test_fetch_pack_no_side_band_64k(self)
 
 
 
 
 class TestSSHVendor(object):
 class TestSSHVendor(object):
-
     @staticmethod
     @staticmethod
-    def run_command(host, command, username=None, port=None,
-                    password=None, key_filename=None):
-        cmd, path = command.split(' ')
-        cmd = cmd.split('-', 1)
+    def run_command(
+        host,
+        command,
+        username=None,
+        port=None,
+        password=None,
+        key_filename=None,
+    ):
+        cmd, path = command.split(" ")
+        cmd = cmd.split("-", 1)
         path = path.replace("'", "")
         path = path.replace("'", "")
-        p = subprocess.Popen(cmd + [path], bufsize=0, stdin=subprocess.PIPE,
-                             stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+        p = subprocess.Popen(
+            cmd + [path],
+            bufsize=0,
+            stdin=subprocess.PIPE,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+        )
         return client.SubprocessWrapper(p)
         return client.SubprocessWrapper(p)
 
 
 
 
 class DulwichMockSSHClientTest(CompatTestCase, DulwichClientTestBase):
 class DulwichMockSSHClientTest(CompatTestCase, DulwichClientTestBase):
-
     def setUp(self):
     def setUp(self):
         CompatTestCase.setUp(self)
         CompatTestCase.setUp(self)
         DulwichClientTestBase.setUp(self)
         DulwichClientTestBase.setUp(self)
@@ -422,14 +466,13 @@ class DulwichMockSSHClientTest(CompatTestCase, DulwichClientTestBase):
         client.get_ssh_vendor = self.real_vendor
         client.get_ssh_vendor = self.real_vendor
 
 
     def _client(self):
     def _client(self):
-        return client.SSHGitClient('localhost')
+        return client.SSHGitClient("localhost")
 
 
     def _build_path(self, path):
     def _build_path(self, path):
         return self.gitroot + path
         return self.gitroot + path
 
 
 
 
 class DulwichSubprocessClientTest(CompatTestCase, DulwichClientTestBase):
 class DulwichSubprocessClientTest(CompatTestCase, DulwichClientTestBase):
-
     def setUp(self):
     def setUp(self):
         CompatTestCase.setUp(self)
         CompatTestCase.setUp(self)
         DulwichClientTestBase.setUp(self)
         DulwichClientTestBase.setUp(self)
@@ -461,11 +504,11 @@ class GitHTTPRequestHandler(http.server.SimpleHTTPRequestHandler):
     def send_head(self):
     def send_head(self):
         return self.run_backend()
         return self.run_backend()
 
 
-    def log_request(self, code='-', size='-'):
+    def log_request(self, code="-", size="-"):
         # Let's be quiet, the test suite is noisy enough already
         # Let's be quiet, the test suite is noisy enough already
         pass
         pass
 
 
-    def run_backend(self):
+    def run_backend(self):  # noqa: C901
         """Call out to git http-backend."""
         """Call out to git http-backend."""
         # Based on CGIHTTPServer.CGIHTTPRequestHandler.run_cgi:
         # Based on CGIHTTPServer.CGIHTTPRequestHandler.run_cgi:
         # Copyright (c) 2001-2010 Python Software Foundation;
         # Copyright (c) 2001-2010 Python Software Foundation;
@@ -473,83 +516,88 @@ class GitHTTPRequestHandler(http.server.SimpleHTTPRequestHandler):
         # Licensed under the Python Software Foundation License.
         # Licensed under the Python Software Foundation License.
         rest = self.path
         rest = self.path
         # find an explicit query string, if present.
         # find an explicit query string, if present.
-        i = rest.rfind('?')
+        i = rest.rfind("?")
         if i >= 0:
         if i >= 0:
-            rest, query = rest[:i], rest[i+1:]
+            rest, query = rest[:i], rest[i + 1 :]
         else:
         else:
-            query = ''
+            query = ""
 
 
         env = copy.deepcopy(os.environ)
         env = copy.deepcopy(os.environ)
-        env['SERVER_SOFTWARE'] = self.version_string()
-        env['SERVER_NAME'] = self.server.server_name
-        env['GATEWAY_INTERFACE'] = 'CGI/1.1'
-        env['SERVER_PROTOCOL'] = self.protocol_version
-        env['SERVER_PORT'] = str(self.server.server_port)
-        env['GIT_PROJECT_ROOT'] = self.server.root_path
+        env["SERVER_SOFTWARE"] = self.version_string()
+        env["SERVER_NAME"] = self.server.server_name
+        env["GATEWAY_INTERFACE"] = "CGI/1.1"
+        env["SERVER_PROTOCOL"] = self.protocol_version
+        env["SERVER_PORT"] = str(self.server.server_port)
+        env["GIT_PROJECT_ROOT"] = self.server.root_path
         env["GIT_HTTP_EXPORT_ALL"] = "1"
         env["GIT_HTTP_EXPORT_ALL"] = "1"
-        env['REQUEST_METHOD'] = self.command
+        env["REQUEST_METHOD"] = self.command
         uqrest = unquote(rest)
         uqrest = unquote(rest)
-        env['PATH_INFO'] = uqrest
-        env['SCRIPT_NAME'] = "/"
+        env["PATH_INFO"] = uqrest
+        env["SCRIPT_NAME"] = "/"
         if query:
         if query:
-            env['QUERY_STRING'] = query
+            env["QUERY_STRING"] = query
         host = self.address_string()
         host = self.address_string()
         if host != self.client_address[0]:
         if host != self.client_address[0]:
-            env['REMOTE_HOST'] = host
-        env['REMOTE_ADDR'] = self.client_address[0]
+            env["REMOTE_HOST"] = host
+        env["REMOTE_ADDR"] = self.client_address[0]
         authorization = self.headers.get("authorization")
         authorization = self.headers.get("authorization")
         if authorization:
         if authorization:
             authorization = authorization.split()
             authorization = authorization.split()
             if len(authorization) == 2:
             if len(authorization) == 2:
                 import base64
                 import base64
                 import binascii
                 import binascii
-                env['AUTH_TYPE'] = authorization[0]
+
+                env["AUTH_TYPE"] = authorization[0]
                 if authorization[0].lower() == "basic":
                 if authorization[0].lower() == "basic":
                     try:
                     try:
                         authorization = base64.decodestring(authorization[1])
                         authorization = base64.decodestring(authorization[1])
                     except binascii.Error:
                     except binascii.Error:
                         pass
                         pass
                     else:
                     else:
-                        authorization = authorization.split(':')
+                        authorization = authorization.split(":")
                         if len(authorization) == 2:
                         if len(authorization) == 2:
-                            env['REMOTE_USER'] = authorization[0]
+                            env["REMOTE_USER"] = authorization[0]
         # XXX REMOTE_IDENT
         # XXX REMOTE_IDENT
-        content_type = self.headers.get('content-type')
+        content_type = self.headers.get("content-type")
         if content_type:
         if content_type:
-            env['CONTENT_TYPE'] = content_type
-        length = self.headers.get('content-length')
+            env["CONTENT_TYPE"] = content_type
+        length = self.headers.get("content-length")
         if length:
         if length:
-            env['CONTENT_LENGTH'] = length
-        referer = self.headers.get('referer')
+            env["CONTENT_LENGTH"] = length
+        referer = self.headers.get("referer")
         if referer:
         if referer:
-            env['HTTP_REFERER'] = referer
+            env["HTTP_REFERER"] = referer
         accept = []
         accept = []
-        for line in self.headers.getallmatchingheaders('accept'):
+        for line in self.headers.getallmatchingheaders("accept"):
             if line[:1] in "\t\n\r ":
             if line[:1] in "\t\n\r ":
                 accept.append(line.strip())
                 accept.append(line.strip())
             else:
             else:
-                accept = accept + line[7:].split(',')
-        env['HTTP_ACCEPT'] = ','.join(accept)
-        ua = self.headers.get('user-agent')
+                accept = accept + line[7:].split(",")
+        env["HTTP_ACCEPT"] = ",".join(accept)
+        ua = self.headers.get("user-agent")
         if ua:
         if ua:
-            env['HTTP_USER_AGENT'] = ua
-        co = self.headers.get('cookie')
+            env["HTTP_USER_AGENT"] = ua
+        co = self.headers.get("cookie")
         if co:
         if co:
-            env['HTTP_COOKIE'] = co
+            env["HTTP_COOKIE"] = co
         # XXX Other HTTP_* headers
         # XXX Other HTTP_* headers
         # Since we're setting the env in the parent, provide empty
         # Since we're setting the env in the parent, provide empty
         # values to override previously set values
         # values to override previously set values
-        for k in ('QUERY_STRING', 'REMOTE_HOST', 'CONTENT_LENGTH',
-                  'HTTP_USER_AGENT', 'HTTP_COOKIE', 'HTTP_REFERER'):
+        for k in (
+            "QUERY_STRING",
+            "REMOTE_HOST",
+            "CONTENT_LENGTH",
+            "HTTP_USER_AGENT",
+            "HTTP_COOKIE",
+            "HTTP_REFERER",
+        ):
             env.setdefault(k, "")
             env.setdefault(k, "")
 
 
         self.wfile.write(b"HTTP/1.1 200 Script output follows\r\n")
         self.wfile.write(b"HTTP/1.1 200 Script output follows\r\n")
-        self.wfile.write(
-            ("Server: %s\r\n" % self.server.server_name).encode('ascii'))
-        self.wfile.write(
-            ("Date: %s\r\n" % self.date_time_string()).encode('ascii'))
+        self.wfile.write(("Server: %s\r\n" % self.server.server_name).encode("ascii"))
+        self.wfile.write(("Date: %s\r\n" % self.date_time_string()).encode("ascii"))
 
 
-        decoded_query = query.replace('+', ' ')
+        decoded_query = query.replace("+", " ")
 
 
         try:
         try:
             nbytes = int(length)
             nbytes = int(length)
@@ -559,16 +607,15 @@ class GitHTTPRequestHandler(http.server.SimpleHTTPRequestHandler):
             data = self.rfile.read(nbytes)
             data = self.rfile.read(nbytes)
         else:
         else:
             data = None
             data = None
-            env['CONTENT_LENGTH'] = '0'
+            env["CONTENT_LENGTH"] = "0"
         # throw away additional data [see bug #427345]
         # throw away additional data [see bug #427345]
         while select.select([self.rfile._sock], [], [], 0)[0]:
         while select.select([self.rfile._sock], [], [], 0)[0]:
             if not self.rfile._sock.recv(1):
             if not self.rfile._sock.recv(1):
                 break
                 break
-        args = ['http-backend']
-        if '=' not in decoded_query:
+        args = ["http-backend"]
+        if "=" not in decoded_query:
             args.append(decoded_query)
             args.append(decoded_query)
-        stdout = run_git_or_fail(
-            args, input=data, env=env, stderr=subprocess.PIPE)
+        stdout = run_git_or_fail(args, input=data, env=env, stderr=subprocess.PIPE)
         self.wfile.write(stdout)
         self.wfile.write(stdout)
 
 
 
 
@@ -577,13 +624,12 @@ class HTTPGitServer(http.server.HTTPServer):
     allow_reuse_address = True
     allow_reuse_address = True
 
 
     def __init__(self, server_address, root_path):
     def __init__(self, server_address, root_path):
-        http.server.HTTPServer.__init__(
-            self, server_address, GitHTTPRequestHandler)
+        http.server.HTTPServer.__init__(self, server_address, GitHTTPRequestHandler)
         self.root_path = root_path
         self.root_path = root_path
         self.server_name = "localhost"
         self.server_name = "localhost"
 
 
     def get_url(self):
     def get_url(self):
-        return 'http://%s:%s/' % (self.server_name, self.server_port)
+        return "http://%s:%s/" % (self.server_name, self.server_port)
 
 
 
 
 class DulwichHttpClientTest(CompatTestCase, DulwichClientTestBase):
 class DulwichHttpClientTest(CompatTestCase, DulwichClientTestBase):
@@ -596,10 +642,8 @@ class DulwichHttpClientTest(CompatTestCase, DulwichClientTestBase):
         self._httpd = HTTPGitServer(("localhost", 0), self.gitroot)
         self._httpd = HTTPGitServer(("localhost", 0), self.gitroot)
         self.addCleanup(self._httpd.shutdown)
         self.addCleanup(self._httpd.shutdown)
         threading.Thread(target=self._httpd.serve_forever).start()
         threading.Thread(target=self._httpd.serve_forever).start()
-        run_git_or_fail(['config', 'http.uploadpack', 'true'],
-                        cwd=self.dest)
-        run_git_or_fail(['config', 'http.receivepack', 'true'],
-                        cwd=self.dest)
+        run_git_or_fail(["config", "http.uploadpack", "true"], cwd=self.dest)
+        run_git_or_fail(["config", "http.receivepack", "true"], cwd=self.dest)
 
 
     def tearDown(self):
     def tearDown(self):
         DulwichClientTestBase.tearDown(self)
         DulwichClientTestBase.tearDown(self)

+ 55 - 39
dulwich/tests/compat/test_pack.py

@@ -29,24 +29,24 @@ import tempfile
 
 
 from dulwich.pack import (
 from dulwich.pack import (
     write_pack,
     write_pack,
-    )
+)
 from dulwich.objects import (
 from dulwich.objects import (
     Blob,
     Blob,
-    )
+)
 from dulwich.tests import (
 from dulwich.tests import (
     SkipTest,
     SkipTest,
-    )
+)
 from dulwich.tests.test_pack import (
 from dulwich.tests.test_pack import (
     a_sha,
     a_sha,
     pack1_sha,
     pack1_sha,
     PackTests,
     PackTests,
-    )
+)
 from dulwich.tests.compat.utils import (
 from dulwich.tests.compat.utils import (
     require_git_version,
     require_git_version,
     run_git_or_fail,
     run_git_or_fail,
-    )
+)
 
 
-_NON_DELTA_RE = re.compile(b'non delta: (?P<non_delta>\\d+) objects')
+_NON_DELTA_RE = re.compile(b"non delta: (?P<non_delta>\\d+) objects")
 
 
 
 
 def _git_verify_pack_object_list(output):
 def _git_verify_pack_object_list(output):
@@ -75,28 +75,32 @@ class TestPack(PackTests):
             self.assertSucceeds(origpack.index.check)
             self.assertSucceeds(origpack.index.check)
             pack_path = os.path.join(self._tempdir, "Elch")
             pack_path = os.path.join(self._tempdir, "Elch")
             write_pack(pack_path, origpack.pack_tuples())
             write_pack(pack_path, origpack.pack_tuples())
-            output = run_git_or_fail(['verify-pack', '-v', pack_path])
-            orig_shas = set(o.id for o in origpack.iterobjects())
+            output = run_git_or_fail(["verify-pack", "-v", pack_path])
+            orig_shas = {o.id for o in origpack.iterobjects()}
             self.assertEqual(orig_shas, _git_verify_pack_object_list(output))
             self.assertEqual(orig_shas, _git_verify_pack_object_list(output))
 
 
     def test_deltas_work(self):
     def test_deltas_work(self):
         with self.get_pack(pack1_sha) as orig_pack:
         with self.get_pack(pack1_sha) as orig_pack:
             orig_blob = orig_pack[a_sha]
             orig_blob = orig_pack[a_sha]
             new_blob = Blob()
             new_blob = Blob()
-            new_blob.data = orig_blob.data + b'x'
+            new_blob.data = orig_blob.data + b"x"
             all_to_pack = list(orig_pack.pack_tuples()) + [(new_blob, None)]
             all_to_pack = list(orig_pack.pack_tuples()) + [(new_blob, None)]
-        pack_path = os.path.join(self._tempdir, 'pack_with_deltas')
+        pack_path = os.path.join(self._tempdir, "pack_with_deltas")
         write_pack(pack_path, all_to_pack, deltify=True)
         write_pack(pack_path, all_to_pack, deltify=True)
-        output = run_git_or_fail(['verify-pack', '-v', pack_path])
-        self.assertEqual(set(x[0].id for x in all_to_pack),
-                         _git_verify_pack_object_list(output))
+        output = run_git_or_fail(["verify-pack", "-v", pack_path])
+        self.assertEqual(
+            {x[0].id for x in all_to_pack},
+            _git_verify_pack_object_list(output),
+        )
         # We specifically made a new blob that should be a delta
         # We specifically made a new blob that should be a delta
         # against the blob a_sha, so make sure we really got only 3
         # against the blob a_sha, so make sure we really got only 3
         # non-delta objects:
         # non-delta objects:
-        got_non_delta = int(_NON_DELTA_RE.search(output).group('non_delta'))
+        got_non_delta = int(_NON_DELTA_RE.search(output).group("non_delta"))
         self.assertEqual(
         self.assertEqual(
-            3, got_non_delta,
-            'Expected 3 non-delta objects, got %d' % got_non_delta)
+            3,
+            got_non_delta,
+            "Expected 3 non-delta objects, got %d" % got_non_delta,
+        )
 
 
     def test_delta_medium_object(self):
     def test_delta_medium_object(self):
         # This tests an object set that will have a copy operation
         # This tests an object set that will have a copy operation
@@ -104,26 +108,32 @@ class TestPack(PackTests):
         with self.get_pack(pack1_sha) as orig_pack:
         with self.get_pack(pack1_sha) as orig_pack:
             orig_blob = orig_pack[a_sha]
             orig_blob = orig_pack[a_sha]
             new_blob = Blob()
             new_blob = Blob()
-            new_blob.data = orig_blob.data + (b'x' * 2 ** 20)
+            new_blob.data = orig_blob.data + (b"x" * 2 ** 20)
             new_blob_2 = Blob()
             new_blob_2 = Blob()
-            new_blob_2.data = new_blob.data + b'y'
-            all_to_pack = list(orig_pack.pack_tuples()) + [(new_blob, None),
-                                                           (new_blob_2, None)]
-        pack_path = os.path.join(self._tempdir, 'pack_with_deltas')
+            new_blob_2.data = new_blob.data + b"y"
+            all_to_pack = list(orig_pack.pack_tuples()) + [
+                (new_blob, None),
+                (new_blob_2, None),
+            ]
+        pack_path = os.path.join(self._tempdir, "pack_with_deltas")
         write_pack(pack_path, all_to_pack, deltify=True)
         write_pack(pack_path, all_to_pack, deltify=True)
-        output = run_git_or_fail(['verify-pack', '-v', pack_path])
-        self.assertEqual(set(x[0].id for x in all_to_pack),
-                         _git_verify_pack_object_list(output))
+        output = run_git_or_fail(["verify-pack", "-v", pack_path])
+        self.assertEqual(
+            {x[0].id for x in all_to_pack},
+            _git_verify_pack_object_list(output),
+        )
         # We specifically made a new blob that should be a delta
         # We specifically made a new blob that should be a delta
         # against the blob a_sha, so make sure we really got only 3
         # against the blob a_sha, so make sure we really got only 3
         # non-delta objects:
         # non-delta objects:
-        got_non_delta = int(_NON_DELTA_RE.search(output).group('non_delta'))
+        got_non_delta = int(_NON_DELTA_RE.search(output).group("non_delta"))
         self.assertEqual(
         self.assertEqual(
-            3, got_non_delta,
-            'Expected 3 non-delta objects, got %d' % got_non_delta)
+            3,
+            got_non_delta,
+            "Expected 3 non-delta objects, got %d" % got_non_delta,
+        )
         # We expect one object to have a delta chain length of two
         # We expect one object to have a delta chain length of two
         # (new_blob_2), so let's verify that actually happens:
         # (new_blob_2), so let's verify that actually happens:
-        self.assertIn(b'chain length = 2', output)
+        self.assertIn(b"chain length = 2", output)
 
 
     # This test is SUPER slow: over 80 seconds on a 2012-era
     # This test is SUPER slow: over 80 seconds on a 2012-era
     # laptop. This is because SequenceMatcher is worst-case quadratic
     # laptop. This is because SequenceMatcher is worst-case quadratic
@@ -134,23 +144,29 @@ class TestPack(PackTests):
         # This tests an object set that will have a copy operation
         # This tests an object set that will have a copy operation
         # 2**25 in size. This is a copy large enough that it requires
         # 2**25 in size. This is a copy large enough that it requires
         # two copy operations in git's binary delta format.
         # two copy operations in git's binary delta format.
-        raise SkipTest('skipping slow, large test')
+        raise SkipTest("skipping slow, large test")
         with self.get_pack(pack1_sha) as orig_pack:
         with self.get_pack(pack1_sha) as orig_pack:
             new_blob = Blob()
             new_blob = Blob()
-            new_blob.data = 'big blob' + ('x' * 2 ** 25)
+            new_blob.data = "big blob" + ("x" * 2 ** 25)
             new_blob_2 = Blob()
             new_blob_2 = Blob()
-            new_blob_2.data = new_blob.data + 'y'
-            all_to_pack = list(orig_pack.pack_tuples()) + [(new_blob, None),
-                                                           (new_blob_2, None)]
+            new_blob_2.data = new_blob.data + "y"
+            all_to_pack = list(orig_pack.pack_tuples()) + [
+                (new_blob, None),
+                (new_blob_2, None),
+            ]
         pack_path = os.path.join(self._tempdir, "pack_with_deltas")
         pack_path = os.path.join(self._tempdir, "pack_with_deltas")
         write_pack(pack_path, all_to_pack, deltify=True)
         write_pack(pack_path, all_to_pack, deltify=True)
-        output = run_git_or_fail(['verify-pack', '-v', pack_path])
-        self.assertEqual(set(x[0].id for x in all_to_pack),
-                         _git_verify_pack_object_list(output))
+        output = run_git_or_fail(["verify-pack", "-v", pack_path])
+        self.assertEqual(
+            {x[0].id for x in all_to_pack},
+            _git_verify_pack_object_list(output),
+        )
         # We specifically made a new blob that should be a delta
         # We specifically made a new blob that should be a delta
         # against the blob a_sha, so make sure we really got only 4
         # against the blob a_sha, so make sure we really got only 4
         # non-delta objects:
         # non-delta objects:
-        got_non_delta = int(_NON_DELTA_RE.search(output).group('non_delta'))
+        got_non_delta = int(_NON_DELTA_RE.search(output).group("non_delta"))
         self.assertEqual(
         self.assertEqual(
-            4, got_non_delta,
-            'Expected 4 non-delta objects, got %d' % got_non_delta)
+            4,
+            got_non_delta,
+            "Expected 4 non-delta objects, got %d" % got_non_delta,
+        )

+ 6 - 6
dulwich/tests/compat/test_patch.py

@@ -27,20 +27,19 @@ import tempfile
 from dulwich import porcelain
 from dulwich import porcelain
 from dulwich.repo import (
 from dulwich.repo import (
     Repo,
     Repo,
-    )
+)
 from dulwich.tests.compat.utils import (
 from dulwich.tests.compat.utils import (
     CompatTestCase,
     CompatTestCase,
     run_git_or_fail,
     run_git_or_fail,
-    )
+)
 
 
 
 
 class CompatPatchTestCase(CompatTestCase):
 class CompatPatchTestCase(CompatTestCase):
-
     def setUp(self):
     def setUp(self):
         super(CompatPatchTestCase, self).setUp()
         super(CompatPatchTestCase, self).setUp()
         self.test_dir = tempfile.mkdtemp()
         self.test_dir = tempfile.mkdtemp()
         self.addCleanup(shutil.rmtree, self.test_dir)
         self.addCleanup(shutil.rmtree, self.test_dir)
-        self.repo_path = os.path.join(self.test_dir, 'repo')
+        self.repo_path = os.path.join(self.test_dir, "repo")
         self.repo = Repo.init(self.repo_path, mkdir=True)
         self.repo = Repo.init(self.repo_path, mkdir=True)
         self.addCleanup(self.repo.close)
         self.addCleanup(self.repo.close)
 
 
@@ -82,8 +81,9 @@ class CompatPatchTestCase(CompatTestCase):
         second_tree = self.repo[second_commit].tree
         second_tree = self.repo[second_commit].tree
 
 
         outstream = BytesIO()
         outstream = BytesIO()
-        porcelain.diff_tree(self.repo.path, first_tree, second_tree,
-                            outstream=outstream)
+        porcelain.diff_tree(
+            self.repo.path, first_tree, second_tree, outstream=outstream
+        )
 
 
         # Save it on disk
         # Save it on disk
         patch_path = os.path.join(self.test_dir, "patch.patch")
         patch_path = os.path.join(self.test_dir, "patch.patch")

+ 101 - 0
dulwich/tests/compat/test_porcelain.py

@@ -0,0 +1,101 @@
+# test_porcelain .py -- Tests for dulwich.porcelain/CGit compatibility
+# Copyright (C) 2010 Google, Inc.
+#
+# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
+# General Public License as public by the Free Software Foundation; version 2.0
+# or (at your option) any later version. You can redistribute it and/or
+# modify it under the terms of either of these two licenses.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# You should have received a copy of the licenses; if not, see
+# <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
+# and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
+# License, Version 2.0.
+#
+
+"""Compatibility tests for dulwich.porcelain."""
+
+import os
+import platform
+import sys
+from unittest import skipIf
+
+from dulwich import porcelain
+from dulwich.tests.utils import (
+    build_commit_graph,
+)
+from dulwich.tests.compat.utils import (
+    run_git_or_fail,
+    CompatTestCase,
+)
+from dulwich.tests.test_porcelain import (
+    PorcelainGpgTestCase,
+)
+
+
+@skipIf(platform.python_implementation() == "PyPy" or sys.platform == "win32", "gpgme not easily available or supported on Windows and PyPy")
+class TagCreateSignTestCase(PorcelainGpgTestCase, CompatTestCase):
+    def setUp(self):
+        super(TagCreateSignTestCase, self).setUp()
+
+    def test_sign(self):
+        # Test that dulwich signatures can be verified by CGit
+        c1, c2, c3 = build_commit_graph(
+            self.repo.object_store, [[1], [2, 1], [3, 1, 2]]
+        )
+        self.repo.refs[b"HEAD"] = c3.id
+        cfg = self.repo.get_config()
+        cfg.set(("user",), "signingKey", PorcelainGpgTestCase.DEFAULT_KEY_ID)
+        self.import_default_key()
+
+        porcelain.tag_create(
+            self.repo.path,
+            b"tryme",
+            b"foo <foo@bar.com>",
+            b"bar",
+            annotated=True,
+            sign=True,
+        )
+
+        run_git_or_fail(
+            [
+                "--git-dir={}".format(self.repo.controldir()),
+                "tag",
+                "-v",
+                "tryme"
+            ],
+            env={'GNUPGHOME': os.environ['GNUPGHOME']},
+        )
+
+    def test_verify(self):
+        # Test that CGit signatures can be verified by dulwich
+        c1, c2, c3 = build_commit_graph(
+            self.repo.object_store, [[1], [2, 1], [3, 1, 2]]
+        )
+        self.repo.refs[b"HEAD"] = c3.id
+        self.import_default_key()
+
+        run_git_or_fail(
+            [
+                "--git-dir={}".format(self.repo.controldir()),
+                "tag",
+                "-u",
+                PorcelainGpgTestCase.DEFAULT_KEY_ID,
+                "-m",
+                "foo",
+                "verifyme",
+            ],
+            env={
+                'GNUPGHOME': os.environ['GNUPGHOME'],
+                'GIT_COMMITTER_NAME': 'Joe Example',
+                'GIT_COMMITTER_EMAIL': 'joe@example.com',
+                },
+        )
+        tag = self.repo[b"refs/tags/verifyme"]
+        self.assertNotEqual(tag.signature, None)
+        tag.verify()

+ 41 - 43
dulwich/tests/compat/test_repository.py

@@ -28,17 +28,17 @@ import tempfile
 
 
 from dulwich.objects import (
 from dulwich.objects import (
     hex_to_sha,
     hex_to_sha,
-    )
+)
 from dulwich.repo import (
 from dulwich.repo import (
     check_ref_format,
     check_ref_format,
     Repo,
     Repo,
-    )
+)
 from dulwich.tests.compat.utils import (
 from dulwich.tests.compat.utils import (
     require_git_version,
     require_git_version,
     rmtree_ro,
     rmtree_ro,
     run_git_or_fail,
     run_git_or_fail,
     CompatTestCase,
     CompatTestCase,
-    )
+)
 
 
 
 
 class ObjectStoreTestCase(CompatTestCase):
 class ObjectStoreTestCase(CompatTestCase):
@@ -46,7 +46,7 @@ class ObjectStoreTestCase(CompatTestCase):
 
 
     def setUp(self):
     def setUp(self):
         super(ObjectStoreTestCase, self).setUp()
         super(ObjectStoreTestCase, self).setUp()
-        self._repo = self.import_repo('server_new.export')
+        self._repo = self.import_repo("server_new.export")
 
 
     def _run_git(self, args):
     def _run_git(self, args):
         return run_git_or_fail(args, cwd=self._repo.path)
         return run_git_or_fail(args, cwd=self._repo.path)
@@ -54,7 +54,7 @@ class ObjectStoreTestCase(CompatTestCase):
     def _parse_refs(self, output):
     def _parse_refs(self, output):
         refs = {}
         refs = {}
         for line in BytesIO(output):
         for line in BytesIO(output):
-            fields = line.rstrip(b'\n').split(b' ')
+            fields = line.rstrip(b"\n").split(b" ")
             self.assertEqual(3, len(fields))
             self.assertEqual(3, len(fields))
             refname, type_name, sha = fields
             refname, type_name, sha = fields
             check_ref_format(refname[5:])
             check_ref_format(refname[5:])
@@ -63,26 +63,27 @@ class ObjectStoreTestCase(CompatTestCase):
         return refs
         return refs
 
 
     def _parse_objects(self, output):
     def _parse_objects(self, output):
-        return set(s.rstrip(b'\n').split(b' ')[0] for s in BytesIO(output))
+        return {s.rstrip(b"\n").split(b" ")[0] for s in BytesIO(output)}
 
 
     def test_bare(self):
     def test_bare(self):
         self.assertTrue(self._repo.bare)
         self.assertTrue(self._repo.bare)
-        self.assertFalse(os.path.exists(os.path.join(self._repo.path, '.git')))
+        self.assertFalse(os.path.exists(os.path.join(self._repo.path, ".git")))
 
 
     def test_head(self):
     def test_head(self):
-        output = self._run_git(['rev-parse', 'HEAD'])
-        head_sha = output.rstrip(b'\n')
+        output = self._run_git(["rev-parse", "HEAD"])
+        head_sha = output.rstrip(b"\n")
         hex_to_sha(head_sha)
         hex_to_sha(head_sha)
-        self.assertEqual(head_sha, self._repo.refs[b'HEAD'])
+        self.assertEqual(head_sha, self._repo.refs[b"HEAD"])
 
 
     def test_refs(self):
     def test_refs(self):
         output = self._run_git(
         output = self._run_git(
-          ['for-each-ref', '--format=%(refname) %(objecttype) %(objectname)'])
+            ["for-each-ref", "--format=%(refname) %(objecttype) %(objectname)"]
+        )
         expected_refs = self._parse_refs(output)
         expected_refs = self._parse_refs(output)
 
 
         actual_refs = {}
         actual_refs = {}
         for refname, sha in self._repo.refs.as_dict().items():
         for refname, sha in self._repo.refs.as_dict().items():
-            if refname == b'HEAD':
+            if refname == b"HEAD":
                 continue  # handled in test_head
                 continue  # handled in test_head
             obj = self._repo[sha]
             obj = self._repo[sha]
             self.assertEqual(sha, obj.id)
             self.assertEqual(sha, obj.id)
@@ -92,12 +93,11 @@ class ObjectStoreTestCase(CompatTestCase):
     # TODO(dborowitz): peeled ref tests
     # TODO(dborowitz): peeled ref tests
 
 
     def _get_loose_shas(self):
     def _get_loose_shas(self):
-        output = self._run_git(
-            ['rev-list', '--all', '--objects', '--unpacked'])
+        output = self._run_git(["rev-list", "--all", "--objects", "--unpacked"])
         return self._parse_objects(output)
         return self._parse_objects(output)
 
 
     def _get_all_shas(self):
     def _get_all_shas(self):
-        output = self._run_git(['rev-list', '--all', '--objects'])
+        output = self._run_git(["rev-list", "--all", "--objects"])
         return self._parse_objects(output)
         return self._parse_objects(output)
 
 
     def assertShasMatch(self, expected_shas, actual_shas_iter):
     def assertShasMatch(self, expected_shas, actual_shas_iter):
@@ -112,14 +112,14 @@ class ObjectStoreTestCase(CompatTestCase):
         # TODO(dborowitz): This is currently not very useful since
         # TODO(dborowitz): This is currently not very useful since
         # fast-imported repos only contained packed objects.
         # fast-imported repos only contained packed objects.
         expected_shas = self._get_loose_shas()
         expected_shas = self._get_loose_shas()
-        self.assertShasMatch(expected_shas,
-                             self._repo.object_store._iter_loose_objects())
+        self.assertShasMatch(
+            expected_shas, self._repo.object_store._iter_loose_objects()
+        )
 
 
     def test_packed_objects(self):
     def test_packed_objects(self):
         expected_shas = self._get_all_shas() - self._get_loose_shas()
         expected_shas = self._get_all_shas() - self._get_loose_shas()
         self.assertShasMatch(
         self.assertShasMatch(
-            expected_shas,
-            chain.from_iterable(self._repo.object_store.packs)
+            expected_shas, chain.from_iterable(self._repo.object_store.packs)
         )
         )
 
 
     def test_all_objects(self):
     def test_all_objects(self):
@@ -142,15 +142,13 @@ class WorkingTreeTestCase(ObjectStoreTestCase):
         Returns: The path to the new working tree.
         Returns: The path to the new working tree.
         """
         """
         temp_dir = tempfile.mkdtemp()
         temp_dir = tempfile.mkdtemp()
-        run_git_or_fail(['worktree', 'add', temp_dir, branch],
-                        cwd=repo_dir)
+        run_git_or_fail(["worktree", "add", temp_dir, branch], cwd=repo_dir)
         self.addCleanup(rmtree_ro, temp_dir)
         self.addCleanup(rmtree_ro, temp_dir)
         return temp_dir
         return temp_dir
 
 
     def setUp(self):
     def setUp(self):
         super(WorkingTreeTestCase, self).setUp()
         super(WorkingTreeTestCase, self).setUp()
-        self._worktree_path = self.create_new_worktree(
-            self._repo.path, 'branch')
+        self._worktree_path = self.create_new_worktree(self._repo.path, "branch")
         self._worktree_repo = Repo(self._worktree_path)
         self._worktree_repo = Repo(self._worktree_path)
         self.addCleanup(self._worktree_repo.close)
         self.addCleanup(self._worktree_repo.close)
         self._mainworktree_repo = self._repo
         self._mainworktree_repo = self._repo
@@ -159,42 +157,40 @@ class WorkingTreeTestCase(ObjectStoreTestCase):
 
 
     def test_refs(self):
     def test_refs(self):
         super(WorkingTreeTestCase, self).test_refs()
         super(WorkingTreeTestCase, self).test_refs()
-        self.assertEqual(self._mainworktree_repo.refs.allkeys(),
-                         self._repo.refs.allkeys())
+        self.assertEqual(
+            self._mainworktree_repo.refs.allkeys(), self._repo.refs.allkeys()
+        )
 
 
     def test_head_equality(self):
     def test_head_equality(self):
-        self.assertNotEqual(self._repo.refs[b'HEAD'],
-                            self._mainworktree_repo.refs[b'HEAD'])
+        self.assertNotEqual(
+            self._repo.refs[b"HEAD"], self._mainworktree_repo.refs[b"HEAD"]
+        )
 
 
     def test_bare(self):
     def test_bare(self):
         self.assertFalse(self._repo.bare)
         self.assertFalse(self._repo.bare)
-        self.assertTrue(os.path.isfile(os.path.join(self._repo.path, '.git')))
+        self.assertTrue(os.path.isfile(os.path.join(self._repo.path, ".git")))
 
 
     def _parse_worktree_list(self, output):
     def _parse_worktree_list(self, output):
         worktrees = []
         worktrees = []
         for line in BytesIO(output):
         for line in BytesIO(output):
-            fields = line.rstrip(b'\n').split()
+            fields = line.rstrip(b"\n").split()
             worktrees.append(tuple(f.decode() for f in fields))
             worktrees.append(tuple(f.decode() for f in fields))
         return worktrees
         return worktrees
 
 
     def test_git_worktree_list(self):
     def test_git_worktree_list(self):
         # 'git worktree list' was introduced in 2.7.0
         # 'git worktree list' was introduced in 2.7.0
         require_git_version((2, 7, 0))
         require_git_version((2, 7, 0))
-        output = run_git_or_fail(['worktree', 'list'], cwd=self._repo.path)
+        output = run_git_or_fail(["worktree", "list"], cwd=self._repo.path)
         worktrees = self._parse_worktree_list(output)
         worktrees = self._parse_worktree_list(output)
         self.assertEqual(len(worktrees), self._number_of_working_tree)
         self.assertEqual(len(worktrees), self._number_of_working_tree)
-        self.assertEqual(worktrees[0][1], '(bare)')
-        self.assertTrue(
-            os.path.samefile(worktrees[0][0], self._mainworktree_repo.path))
+        self.assertEqual(worktrees[0][1], "(bare)")
+        self.assertTrue(os.path.samefile(worktrees[0][0], self._mainworktree_repo.path))
 
 
-        output = run_git_or_fail(
-            ['worktree', 'list'], cwd=self._mainworktree_repo.path)
+        output = run_git_or_fail(["worktree", "list"], cwd=self._mainworktree_repo.path)
         worktrees = self._parse_worktree_list(output)
         worktrees = self._parse_worktree_list(output)
         self.assertEqual(len(worktrees), self._number_of_working_tree)
         self.assertEqual(len(worktrees), self._number_of_working_tree)
-        self.assertEqual(worktrees[0][1], '(bare)')
-        self.assertTrue(os.path.samefile(
-            worktrees[0][0],
-            self._mainworktree_repo.path))
+        self.assertEqual(worktrees[0][1], "(bare)")
+        self.assertTrue(os.path.samefile(worktrees[0][0], self._mainworktree_repo.path))
 
 
 
 
 class InitNewWorkingDirectoryTestCase(WorkingTreeTestCase):
 class InitNewWorkingDirectoryTestCase(WorkingTreeTestCase):
@@ -208,14 +204,16 @@ class InitNewWorkingDirectoryTestCase(WorkingTreeTestCase):
         worktree_repo_path = tempfile.mkdtemp()
         worktree_repo_path = tempfile.mkdtemp()
         self.addCleanup(rmtree_ro, worktree_repo_path)
         self.addCleanup(rmtree_ro, worktree_repo_path)
         self._repo = Repo._init_new_working_directory(
         self._repo = Repo._init_new_working_directory(
-            worktree_repo_path, self._mainworktree_repo)
+            worktree_repo_path, self._mainworktree_repo
+        )
         self.addCleanup(self._repo.close)
         self.addCleanup(self._repo.close)
         self._number_of_working_tree = 3
         self._number_of_working_tree = 3
 
 
     def test_head_equality(self):
     def test_head_equality(self):
-        self.assertEqual(self._repo.refs[b'HEAD'],
-                         self._mainworktree_repo.refs[b'HEAD'])
+        self.assertEqual(
+            self._repo.refs[b"HEAD"], self._mainworktree_repo.refs[b"HEAD"]
+        )
 
 
     def test_bare(self):
     def test_bare(self):
         self.assertFalse(self._repo.bare)
         self.assertFalse(self._repo.bare)
-        self.assertTrue(os.path.isfile(os.path.join(self._repo.path, '.git')))
+        self.assertTrue(os.path.isfile(os.path.join(self._repo.path, ".git")))

+ 14 - 17
dulwich/tests/compat/test_server.py

@@ -32,40 +32,38 @@ import sys
 from dulwich.server import (
 from dulwich.server import (
     DictBackend,
     DictBackend,
     TCPGitServer,
     TCPGitServer,
-    )
+)
 from dulwich.tests import skipIf
 from dulwich.tests import skipIf
 from dulwich.tests.compat.server_utils import (
 from dulwich.tests.compat.server_utils import (
     ServerTests,
     ServerTests,
     NoSideBand64kReceivePackHandler,
     NoSideBand64kReceivePackHandler,
-    )
+)
 from dulwich.tests.compat.utils import (
 from dulwich.tests.compat.utils import (
     CompatTestCase,
     CompatTestCase,
     require_git_version,
     require_git_version,
-    )
+)
 
 
 
 
-@skipIf(sys.platform == 'win32',
-        'Broken on windows, with very long fail time.')
+@skipIf(sys.platform == "win32", "Broken on windows, with very long fail time.")
 class GitServerTestCase(ServerTests, CompatTestCase):
 class GitServerTestCase(ServerTests, CompatTestCase):
     """Tests for client/server compatibility.
     """Tests for client/server compatibility.
 
 
     This server test case does not use side-band-64k in git-receive-pack.
     This server test case does not use side-band-64k in git-receive-pack.
     """
     """
 
 
-    protocol = 'git'
+    protocol = "git"
 
 
     def _handlers(self):
     def _handlers(self):
-        return {b'git-receive-pack': NoSideBand64kReceivePackHandler}
+        return {b"git-receive-pack": NoSideBand64kReceivePackHandler}
 
 
     def _check_server(self, dul_server):
     def _check_server(self, dul_server):
-        receive_pack_handler_cls = dul_server.handlers[b'git-receive-pack']
+        receive_pack_handler_cls = dul_server.handlers[b"git-receive-pack"]
         caps = receive_pack_handler_cls.capabilities()
         caps = receive_pack_handler_cls.capabilities()
-        self.assertFalse(b'side-band-64k' in caps)
+        self.assertFalse(b"side-band-64k" in caps)
 
 
     def _start_server(self, repo):
     def _start_server(self, repo):
-        backend = DictBackend({b'/': repo})
-        dul_server = TCPGitServer(backend, b'localhost', 0,
-                                  handlers=self._handlers())
+        backend = DictBackend({b"/": repo})
+        dul_server = TCPGitServer(backend, b"localhost", 0, handlers=self._handlers())
         self._check_server(dul_server)
         self._check_server(dul_server)
         self.addCleanup(dul_server.shutdown)
         self.addCleanup(dul_server.shutdown)
         self.addCleanup(dul_server.server_close)
         self.addCleanup(dul_server.server_close)
@@ -75,8 +73,7 @@ class GitServerTestCase(ServerTests, CompatTestCase):
         return port
         return port
 
 
 
 
-@skipIf(sys.platform == 'win32',
-        'Broken on windows, with very long fail time.')
+@skipIf(sys.platform == "win32", "Broken on windows, with very long fail time.")
 class GitServerSideBand64kTestCase(GitServerTestCase):
 class GitServerSideBand64kTestCase(GitServerTestCase):
     """Tests for client/server compatibility with side-band-64k support."""
     """Tests for client/server compatibility with side-band-64k support."""
 
 
@@ -88,13 +85,13 @@ class GitServerSideBand64kTestCase(GitServerTestCase):
         # side-band-64k is broken in the windows client.
         # side-band-64k is broken in the windows client.
         # https://github.com/msysgit/git/issues/101
         # https://github.com/msysgit/git/issues/101
         # Fix has landed for the 1.9.3 release.
         # Fix has landed for the 1.9.3 release.
-        if os.name == 'nt':
+        if os.name == "nt":
             require_git_version((1, 9, 3))
             require_git_version((1, 9, 3))
 
 
     def _handlers(self):
     def _handlers(self):
         return None  # default handlers include side-band-64k
         return None  # default handlers include side-band-64k
 
 
     def _check_server(self, server):
     def _check_server(self, server):
-        receive_pack_handler_cls = server.handlers[b'git-receive-pack']
+        receive_pack_handler_cls = server.handlers[b"git-receive-pack"]
         caps = receive_pack_handler_cls.capabilities()
         caps = receive_pack_handler_cls.capabilities()
-        self.assertTrue(b'side-band-64k' in caps)
+        self.assertTrue(b"side-band-64k" in caps)

+ 12 - 14
dulwich/tests/compat/test_utils.py

@@ -23,20 +23,20 @@
 from dulwich.tests import (
 from dulwich.tests import (
     SkipTest,
     SkipTest,
     TestCase,
     TestCase,
-    )
+)
 from dulwich.tests.compat import utils
 from dulwich.tests.compat import utils
 
 
 
 
 class GitVersionTests(TestCase):
 class GitVersionTests(TestCase):
-
     def setUp(self):
     def setUp(self):
         super(GitVersionTests, self).setUp()
         super(GitVersionTests, self).setUp()
         self._orig_run_git = utils.run_git
         self._orig_run_git = utils.run_git
         self._version_str = None  # tests can override to set stub version
         self._version_str = None  # tests can override to set stub version
 
 
         def run_git(args, **unused_kwargs):
         def run_git(args, **unused_kwargs):
-            self.assertEqual(['--version'], args)
+            self.assertEqual(["--version"], args)
             return 0, self._version_str
             return 0, self._version_str
+
         utils.run_git = run_git
         utils.run_git = run_git
 
 
     def tearDown(self):
     def tearDown(self):
@@ -44,19 +44,19 @@ class GitVersionTests(TestCase):
         utils.run_git = self._orig_run_git
         utils.run_git = self._orig_run_git
 
 
     def test_git_version_none(self):
     def test_git_version_none(self):
-        self._version_str = b'not a git version'
+        self._version_str = b"not a git version"
         self.assertEqual(None, utils.git_version())
         self.assertEqual(None, utils.git_version())
 
 
     def test_git_version_3(self):
     def test_git_version_3(self):
-        self._version_str = b'git version 1.6.6'
+        self._version_str = b"git version 1.6.6"
         self.assertEqual((1, 6, 6, 0), utils.git_version())
         self.assertEqual((1, 6, 6, 0), utils.git_version())
 
 
     def test_git_version_4(self):
     def test_git_version_4(self):
-        self._version_str = b'git version 1.7.0.2'
+        self._version_str = b"git version 1.7.0.2"
         self.assertEqual((1, 7, 0, 2), utils.git_version())
         self.assertEqual((1, 7, 0, 2), utils.git_version())
 
 
     def test_git_version_extra(self):
     def test_git_version_extra(self):
-        self._version_str = b'git version 1.7.0.3.295.gd8fa2'
+        self._version_str = b"git version 1.7.0.3.295.gd8fa2"
         self.assertEqual((1, 7, 0, 3), utils.git_version())
         self.assertEqual((1, 7, 0, 3), utils.git_version())
 
 
     def assertRequireSucceeds(self, required_version):
     def assertRequireSucceeds(self, required_version):
@@ -66,22 +66,20 @@ class GitVersionTests(TestCase):
             self.fail()
             self.fail()
 
 
     def assertRequireFails(self, required_version):
     def assertRequireFails(self, required_version):
-        self.assertRaises(SkipTest, utils.require_git_version,
-                          required_version)
+        self.assertRaises(SkipTest, utils.require_git_version, required_version)
 
 
     def test_require_git_version(self):
     def test_require_git_version(self):
         try:
         try:
-            self._version_str = b'git version 1.6.6'
+            self._version_str = b"git version 1.6.6"
             self.assertRequireSucceeds((1, 6, 6))
             self.assertRequireSucceeds((1, 6, 6))
             self.assertRequireSucceeds((1, 6, 6, 0))
             self.assertRequireSucceeds((1, 6, 6, 0))
             self.assertRequireSucceeds((1, 6, 5))
             self.assertRequireSucceeds((1, 6, 5))
             self.assertRequireSucceeds((1, 6, 5, 99))
             self.assertRequireSucceeds((1, 6, 5, 99))
             self.assertRequireFails((1, 7, 0))
             self.assertRequireFails((1, 7, 0))
             self.assertRequireFails((1, 7, 0, 2))
             self.assertRequireFails((1, 7, 0, 2))
-            self.assertRaises(ValueError, utils.require_git_version,
-                              (1, 6, 6, 0, 0))
+            self.assertRaises(ValueError, utils.require_git_version, (1, 6, 6, 0, 0))
 
 
-            self._version_str = b'git version 1.7.0.2'
+            self._version_str = b"git version 1.7.0.2"
             self.assertRequireSucceeds((1, 6, 6))
             self.assertRequireSucceeds((1, 6, 6))
             self.assertRequireSucceeds((1, 6, 6, 0))
             self.assertRequireSucceeds((1, 6, 6, 0))
             self.assertRequireSucceeds((1, 7, 0))
             self.assertRequireSucceeds((1, 7, 0))
@@ -90,4 +88,4 @@ class GitVersionTests(TestCase):
             self.assertRequireFails((1, 7, 1))
             self.assertRequireFails((1, 7, 1))
         except SkipTest as e:
         except SkipTest as e:
             # This test is designed to catch all SkipTest exceptions.
             # This test is designed to catch all SkipTest exceptions.
-            self.fail('Test unexpectedly skipped: %s' % e)
+            self.fail("Test unexpectedly skipped: %s" % e)

+ 36 - 34
dulwich/tests/compat/test_web.py

@@ -34,29 +34,28 @@ from dulwich.server import (
     DictBackend,
     DictBackend,
     UploadPackHandler,
     UploadPackHandler,
     ReceivePackHandler,
     ReceivePackHandler,
-    )
+)
 from dulwich.tests import (
 from dulwich.tests import (
     SkipTest,
     SkipTest,
     skipIf,
     skipIf,
-    )
+)
 from dulwich.web import (
 from dulwich.web import (
     make_wsgi_chain,
     make_wsgi_chain,
     HTTPGitApplication,
     HTTPGitApplication,
     WSGIRequestHandlerLogger,
     WSGIRequestHandlerLogger,
     WSGIServerLogger,
     WSGIServerLogger,
-    )
+)
 
 
 from dulwich.tests.compat.server_utils import (
 from dulwich.tests.compat.server_utils import (
     ServerTests,
     ServerTests,
     NoSideBand64kReceivePackHandler,
     NoSideBand64kReceivePackHandler,
-    )
+)
 from dulwich.tests.compat.utils import (
 from dulwich.tests.compat.utils import (
     CompatTestCase,
     CompatTestCase,
-    )
+)
 
 
 
 
-@skipIf(sys.platform == 'win32',
-        'Broken on windows, with very long fail time.')
+@skipIf(sys.platform == "win32", "Broken on windows, with very long fail time.")
 class WebTests(ServerTests):
 class WebTests(ServerTests):
     """Base tests for web server tests.
     """Base tests for web server tests.
 
 
@@ -64,14 +63,18 @@ class WebTests(ServerTests):
     TestCase so tests are not automatically run.
     TestCase so tests are not automatically run.
     """
     """
 
 
-    protocol = 'http'
+    protocol = "http"
 
 
     def _start_server(self, repo):
     def _start_server(self, repo):
-        backend = DictBackend({'/': repo})
+        backend = DictBackend({"/": repo})
         app = self._make_app(backend)
         app = self._make_app(backend)
         dul_server = simple_server.make_server(
         dul_server = simple_server.make_server(
-          'localhost', 0, app, server_class=WSGIServerLogger,
-          handler_class=WSGIRequestHandlerLogger)
+            "localhost",
+            0,
+            app,
+            server_class=WSGIServerLogger,
+            handler_class=WSGIRequestHandlerLogger,
+        )
         self.addCleanup(dul_server.shutdown)
         self.addCleanup(dul_server.shutdown)
         self.addCleanup(dul_server.server_close)
         self.addCleanup(dul_server.server_close)
         threading.Thread(target=dul_server.serve_forever).start()
         threading.Thread(target=dul_server.serve_forever).start()
@@ -80,8 +83,7 @@ class WebTests(ServerTests):
         return port
         return port
 
 
 
 
-@skipIf(sys.platform == 'win32',
-        'Broken on windows, with very long fail time.')
+@skipIf(sys.platform == "win32", "Broken on windows, with very long fail time.")
 class SmartWebTestCase(WebTests, CompatTestCase):
 class SmartWebTestCase(WebTests, CompatTestCase):
     """Test cases for smart HTTP server.
     """Test cases for smart HTTP server.
 
 
@@ -91,12 +93,12 @@ class SmartWebTestCase(WebTests, CompatTestCase):
     min_git_version = (1, 6, 6)  # type: Tuple[int, ...]
     min_git_version = (1, 6, 6)  # type: Tuple[int, ...]
 
 
     def _handlers(self):
     def _handlers(self):
-        return {b'git-receive-pack': NoSideBand64kReceivePackHandler}
+        return {b"git-receive-pack": NoSideBand64kReceivePackHandler}
 
 
     def _check_app(self, app):
     def _check_app(self, app):
-        receive_pack_handler_cls = app.handlers[b'git-receive-pack']
+        receive_pack_handler_cls = app.handlers[b"git-receive-pack"]
         caps = receive_pack_handler_cls.capabilities()
         caps = receive_pack_handler_cls.capabilities()
-        self.assertNotIn(b'side-band-64k', caps)
+        self.assertNotIn(b"side-band-64k", caps)
 
 
     def _make_app(self, backend):
     def _make_app(self, backend):
         app = make_wsgi_chain(backend, handlers=self._handlers())
         app = make_wsgi_chain(backend, handlers=self._handlers())
@@ -113,16 +115,17 @@ def patch_capabilities(handler, caps_removed):
     # removed, and return the original classmethod for restoration.
     # removed, and return the original classmethod for restoration.
     original_capabilities = handler.capabilities
     original_capabilities = handler.capabilities
     filtered_capabilities = [
     filtered_capabilities = [
-        i for i in original_capabilities() if i not in caps_removed]
+        i for i in original_capabilities() if i not in caps_removed
+    ]
 
 
     def capabilities(cls):
     def capabilities(cls):
         return filtered_capabilities
         return filtered_capabilities
+
     handler.capabilities = classmethod(capabilities)
     handler.capabilities = classmethod(capabilities)
     return original_capabilities
     return original_capabilities
 
 
 
 
-@skipIf(sys.platform == 'win32',
-        'Broken on windows, with very long fail time.')
+@skipIf(sys.platform == "win32", "Broken on windows, with very long fail time.")
 class SmartWebSideBand64kTestCase(SmartWebTestCase):
 class SmartWebSideBand64kTestCase(SmartWebTestCase):
     """Test cases for smart HTTP server with side-band-64k support."""
     """Test cases for smart HTTP server with side-band-64k support."""
 
 
@@ -143,10 +146,10 @@ class SmartWebSideBand64kTestCase(SmartWebTestCase):
         return None  # default handlers include side-band-64k
         return None  # default handlers include side-band-64k
 
 
     def _check_app(self, app):
     def _check_app(self, app):
-        receive_pack_handler_cls = app.handlers[b'git-receive-pack']
+        receive_pack_handler_cls = app.handlers[b"git-receive-pack"]
         caps = receive_pack_handler_cls.capabilities()
         caps = receive_pack_handler_cls.capabilities()
-        self.assertIn(b'side-band-64k', caps)
-        self.assertNotIn(b'no-done', caps)
+        self.assertIn(b"side-band-64k", caps)
+        self.assertNotIn(b"no-done", caps)
 
 
 
 
 class SmartWebSideBand64kNoDoneTestCase(SmartWebTestCase):
 class SmartWebSideBand64kNoDoneTestCase(SmartWebTestCase):
@@ -161,14 +164,13 @@ class SmartWebSideBand64kNoDoneTestCase(SmartWebTestCase):
         return None  # default handlers include side-band-64k
         return None  # default handlers include side-band-64k
 
 
     def _check_app(self, app):
     def _check_app(self, app):
-        receive_pack_handler_cls = app.handlers[b'git-receive-pack']
+        receive_pack_handler_cls = app.handlers[b"git-receive-pack"]
         caps = receive_pack_handler_cls.capabilities()
         caps = receive_pack_handler_cls.capabilities()
-        self.assertIn(b'side-band-64k', caps)
-        self.assertIn(b'no-done', caps)
+        self.assertIn(b"side-band-64k", caps)
+        self.assertIn(b"no-done", caps)
 
 
 
 
-@skipIf(sys.platform == 'win32',
-        'Broken on windows, with very long fail time.')
+@skipIf(sys.platform == "win32", "Broken on windows, with very long fail time.")
 class DumbWebTestCase(WebTests, CompatTestCase):
 class DumbWebTestCase(WebTests, CompatTestCase):
     """Test cases for dumb HTTP server."""
     """Test cases for dumb HTTP server."""
 
 
@@ -177,31 +179,31 @@ class DumbWebTestCase(WebTests, CompatTestCase):
 
 
     def test_push_to_dulwich(self):
     def test_push_to_dulwich(self):
         # Note: remove this if dulwich implements dumb web pushing.
         # Note: remove this if dulwich implements dumb web pushing.
-        raise SkipTest('Dumb web pushing not supported.')
+        raise SkipTest("Dumb web pushing not supported.")
 
 
     def test_push_to_dulwich_remove_branch(self):
     def test_push_to_dulwich_remove_branch(self):
         # Note: remove this if dumb pushing is supported
         # Note: remove this if dumb pushing is supported
-        raise SkipTest('Dumb web pushing not supported.')
+        raise SkipTest("Dumb web pushing not supported.")
 
 
     def test_new_shallow_clone_from_dulwich(self):
     def test_new_shallow_clone_from_dulwich(self):
         # Note: remove this if C git and dulwich implement dumb web shallow
         # Note: remove this if C git and dulwich implement dumb web shallow
         # clones.
         # clones.
-        raise SkipTest('Dumb web shallow cloning not supported.')
+        raise SkipTest("Dumb web shallow cloning not supported.")
 
 
     def test_shallow_clone_from_git_is_identical(self):
     def test_shallow_clone_from_git_is_identical(self):
         # Note: remove this if C git and dulwich implement dumb web shallow
         # Note: remove this if C git and dulwich implement dumb web shallow
         # clones.
         # clones.
-        raise SkipTest('Dumb web shallow cloning not supported.')
+        raise SkipTest("Dumb web shallow cloning not supported.")
 
 
     def test_fetch_same_depth_into_shallow_clone_from_dulwich(self):
     def test_fetch_same_depth_into_shallow_clone_from_dulwich(self):
         # Note: remove this if C git and dulwich implement dumb web shallow
         # Note: remove this if C git and dulwich implement dumb web shallow
         # clones.
         # clones.
-        raise SkipTest('Dumb web shallow cloning not supported.')
+        raise SkipTest("Dumb web shallow cloning not supported.")
 
 
     def test_fetch_full_depth_into_shallow_clone_from_dulwich(self):
     def test_fetch_full_depth_into_shallow_clone_from_dulwich(self):
         # Note: remove this if C git and dulwich implement dumb web shallow
         # Note: remove this if C git and dulwich implement dumb web shallow
         # clones.
         # clones.
-        raise SkipTest('Dumb web shallow cloning not supported.')
+        raise SkipTest("Dumb web shallow cloning not supported.")
 
 
     def test_push_to_dulwich_issue_88_standard(self):
     def test_push_to_dulwich_issue_88_standard(self):
-        raise SkipTest('Dumb web pushing not supported.')
+        raise SkipTest("Dumb web pushing not supported.")

+ 44 - 35
dulwich/tests/compat/utils.py

@@ -38,12 +38,13 @@ from dulwich.protocol import TCP_GIT_PORT
 from dulwich.tests import (
 from dulwich.tests import (
     SkipTest,
     SkipTest,
     TestCase,
     TestCase,
-    )
+)
 
 
-_DEFAULT_GIT = 'git'
+_DEFAULT_GIT = "git"
 _VERSION_LEN = 4
 _VERSION_LEN = 4
-_REPOS_DATA_DIR = os.path.abspath(os.path.join(
-    os.path.dirname(__file__), os.pardir, 'data', 'repos'))
+_REPOS_DATA_DIR = os.path.abspath(
+    os.path.join(os.path.dirname(__file__), os.pardir, "data", "repos")
+)
 
 
 
 
 def git_version(git_path=_DEFAULT_GIT):
 def git_version(git_path=_DEFAULT_GIT):
@@ -56,14 +57,14 @@ def git_version(git_path=_DEFAULT_GIT):
         None if no git installation was found.
         None if no git installation was found.
     """
     """
     try:
     try:
-        output = run_git_or_fail(['--version'], git_path=git_path)
+        output = run_git_or_fail(["--version"], git_path=git_path)
     except OSError:
     except OSError:
         return None
         return None
-    version_prefix = b'git version '
+    version_prefix = b"git version "
     if not output.startswith(version_prefix):
     if not output.startswith(version_prefix):
         return None
         return None
 
 
-    parts = output[len(version_prefix):].split(b'.')
+    parts = output[len(version_prefix) :].split(b".")
     nums = []
     nums = []
     for part in parts:
     for part in parts:
         try:
         try:
@@ -90,12 +91,15 @@ def require_git_version(required_version, git_path=_DEFAULT_GIT):
     """
     """
     found_version = git_version(git_path=git_path)
     found_version = git_version(git_path=git_path)
     if found_version is None:
     if found_version is None:
-        raise SkipTest('Test requires git >= %s, but c git not found' %
-                       (required_version, ))
+        raise SkipTest(
+            "Test requires git >= %s, but c git not found" % (required_version,)
+        )
 
 
     if len(required_version) > _VERSION_LEN:
     if len(required_version) > _VERSION_LEN:
-        raise ValueError('Invalid version tuple %s, expected %i parts' %
-                         (required_version, _VERSION_LEN))
+        raise ValueError(
+            "Invalid version tuple %s, expected %i parts"
+            % (required_version, _VERSION_LEN)
+        )
 
 
     required_version = list(required_version)
     required_version = list(required_version)
     while len(found_version) < len(required_version):
     while len(found_version) < len(required_version):
@@ -103,14 +107,16 @@ def require_git_version(required_version, git_path=_DEFAULT_GIT):
     required_version = tuple(required_version)
     required_version = tuple(required_version)
 
 
     if found_version < required_version:
     if found_version < required_version:
-        required_version = '.'.join(map(str, required_version))
-        found_version = '.'.join(map(str, found_version))
-        raise SkipTest('Test requires git >= %s, found %s' %
-                       (required_version, found_version))
+        required_version = ".".join(map(str, required_version))
+        found_version = ".".join(map(str, found_version))
+        raise SkipTest(
+            "Test requires git >= %s, found %s" % (required_version, found_version)
+        )
 
 
 
 
-def run_git(args, git_path=_DEFAULT_GIT, input=None, capture_stdout=False,
-            **popen_kwargs):
+def run_git(
+    args, git_path=_DEFAULT_GIT, input=None, capture_stdout=False, **popen_kwargs
+):
     """Run a git command.
     """Run a git command.
 
 
     Input is piped from the input parameter and output is sent to the standard
     Input is piped from the input parameter and output is sent to the standard
@@ -129,15 +135,15 @@ def run_git(args, git_path=_DEFAULT_GIT, input=None, capture_stdout=False,
       OSError: if the git executable was not found.
       OSError: if the git executable was not found.
     """
     """
 
 
-    env = popen_kwargs.pop('env', {})
-    env['LC_ALL'] = env['LANG'] = 'C'
+    env = popen_kwargs.pop("env", {})
+    env["LC_ALL"] = env["LANG"] = "C"
 
 
     args = [git_path] + args
     args = [git_path] + args
-    popen_kwargs['stdin'] = subprocess.PIPE
+    popen_kwargs["stdin"] = subprocess.PIPE
     if capture_stdout:
     if capture_stdout:
-        popen_kwargs['stdout'] = subprocess.PIPE
+        popen_kwargs["stdout"] = subprocess.PIPE
     else:
     else:
-        popen_kwargs.pop('stdout', None)
+        popen_kwargs.pop("stdout", None)
     p = subprocess.Popen(args, env=env, **popen_kwargs)
     p = subprocess.Popen(args, env=env, **popen_kwargs)
     stdout, stderr = p.communicate(input=input)
     stdout, stderr = p.communicate(input=input)
     return (p.returncode, stdout)
     return (p.returncode, stdout)
@@ -145,13 +151,15 @@ def run_git(args, git_path=_DEFAULT_GIT, input=None, capture_stdout=False,
 
 
 def run_git_or_fail(args, git_path=_DEFAULT_GIT, input=None, **popen_kwargs):
 def run_git_or_fail(args, git_path=_DEFAULT_GIT, input=None, **popen_kwargs):
     """Run a git command, capture stdout/stderr, and fail if git fails."""
     """Run a git command, capture stdout/stderr, and fail if git fails."""
-    if 'stderr' not in popen_kwargs:
-        popen_kwargs['stderr'] = subprocess.STDOUT
-    returncode, stdout = run_git(args, git_path=git_path, input=input,
-                                 capture_stdout=True, **popen_kwargs)
+    if "stderr" not in popen_kwargs:
+        popen_kwargs["stderr"] = subprocess.STDOUT
+    returncode, stdout = run_git(
+        args, git_path=git_path, input=input, capture_stdout=True, **popen_kwargs
+    )
     if returncode != 0:
     if returncode != 0:
-        raise AssertionError("git with args %r failed with %d: %r" % (
-            args, returncode, stdout))
+        raise AssertionError(
+            "git with args %r failed with %d: %r" % (args, returncode, stdout)
+        )
     return stdout
     return stdout
 
 
 
 
@@ -169,10 +177,9 @@ def import_repo_to_dir(name):
     temp_dir = tempfile.mkdtemp()
     temp_dir = tempfile.mkdtemp()
     export_path = os.path.join(_REPOS_DATA_DIR, name)
     export_path = os.path.join(_REPOS_DATA_DIR, name)
     temp_repo_dir = os.path.join(temp_dir, name)
     temp_repo_dir = os.path.join(temp_dir, name)
-    export_file = open(export_path, 'rb')
-    run_git_or_fail(['init', '--quiet', '--bare', temp_repo_dir])
-    run_git_or_fail(['fast-import'], input=export_file.read(),
-                    cwd=temp_repo_dir)
+    export_file = open(export_path, "rb")
+    run_git_or_fail(["init", "--quiet", "--bare", temp_repo_dir])
+    run_git_or_fail(["fast-import"], input=export_file.read(), cwd=temp_repo_dir)
     export_file.close()
     export_file.close()
     return temp_repo_dir
     return temp_repo_dir
 
 
@@ -195,12 +202,12 @@ def check_for_daemon(limit=10, delay=0.1, timeout=0.1, port=TCP_GIT_PORT):
         s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
         s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
         s.settimeout(delay)
         s.settimeout(delay)
         try:
         try:
-            s.connect(('localhost', port))
+            s.connect(("localhost", port))
             return True
             return True
         except socket.timeout:
         except socket.timeout:
             pass
             pass
         except socket.error as e:
         except socket.error as e:
-            if getattr(e, 'errno', False) and e.errno != errno.ECONNREFUSED:
+            if getattr(e, "errno", False) and e.errno != errno.ECONNREFUSED:
                 raise
                 raise
             elif e.args[0] != errno.ECONNREFUSED:
             elif e.args[0] != errno.ECONNREFUSED:
                 raise
                 raise
@@ -251,11 +258,13 @@ class CompatTestCase(TestCase):
         def cleanup():
         def cleanup():
             repo.close()
             repo.close()
             rmtree_ro(os.path.dirname(path.rstrip(os.sep)))
             rmtree_ro(os.path.dirname(path.rstrip(os.sep)))
+
         self.addCleanup(cleanup)
         self.addCleanup(cleanup)
         return repo
         return repo
 
 
 
 
-if sys.platform == 'win32':
+if sys.platform == "win32":
+
     def remove_ro(action, name, exc):
     def remove_ro(action, name, exc):
         os.chmod(name, stat.S_IWRITE)
         os.chmod(name, stat.S_IWRITE)
         os.remove(name)
         os.remove(name)

+ 15 - 18
dulwich/tests/test_archive.py

@@ -28,31 +28,30 @@ from unittest import skipUnless
 from dulwich.archive import tar_stream
 from dulwich.archive import tar_stream
 from dulwich.object_store import (
 from dulwich.object_store import (
     MemoryObjectStore,
     MemoryObjectStore,
-    )
+)
 from dulwich.objects import (
 from dulwich.objects import (
     Blob,
     Blob,
     Tree,
     Tree,
-    )
+)
 from dulwich.tests import (
 from dulwich.tests import (
     TestCase,
     TestCase,
-    )
+)
 from dulwich.tests.utils import (
 from dulwich.tests.utils import (
     build_commit_graph,
     build_commit_graph,
-    )
+)
 
 
 try:
 try:
     from unittest.mock import patch
     from unittest.mock import patch
 except ImportError:
 except ImportError:
-    patch = None   # type: ignore
+    patch = None  # type: ignore
 
 
 
 
 class ArchiveTests(TestCase):
 class ArchiveTests(TestCase):
-
     def test_empty(self):
     def test_empty(self):
         store = MemoryObjectStore()
         store = MemoryObjectStore()
         c1, c2, c3 = build_commit_graph(store, [[1], [2, 1], [3, 1, 2]])
         c1, c2, c3 = build_commit_graph(store, [[1], [2, 1], [3, 1, 2]])
         tree = store[c3.tree]
         tree = store[c3.tree]
-        stream = b''.join(tar_stream(store, tree, 10))
+        stream = b"".join(tar_stream(store, tree, 10))
         out = BytesIO(stream)
         out = BytesIO(stream)
         tf = tarfile.TarFile(fileobj=out)
         tf = tarfile.TarFile(fileobj=out)
         self.addCleanup(tf.close)
         self.addCleanup(tf.close)
@@ -65,8 +64,7 @@ class ArchiveTests(TestCase):
         t1 = Tree()
         t1 = Tree()
         t1.add(b"somename", 0o100644, b1.id)
         t1.add(b"somename", 0o100644, b1.id)
         store.add_object(t1)
         store.add_object(t1)
-        stream = b''.join(
-            tar_stream(store, t1, *tar_stream_args, **tar_stream_kwargs))
+        stream = b"".join(tar_stream(store, t1, *tar_stream_args, **tar_stream_kwargs))
         return BytesIO(stream)
         return BytesIO(stream)
 
 
     def test_simple(self):
     def test_simple(self):
@@ -76,27 +74,26 @@ class ArchiveTests(TestCase):
         self.assertEqual(["somename"], tf.getnames())
         self.assertEqual(["somename"], tf.getnames())
 
 
     def test_prefix(self):
     def test_prefix(self):
-        stream = self._get_example_tar_stream(mtime=0, prefix=b'blah')
+        stream = self._get_example_tar_stream(mtime=0, prefix=b"blah")
         tf = tarfile.TarFile(fileobj=stream)
         tf = tarfile.TarFile(fileobj=stream)
         self.addCleanup(tf.close)
         self.addCleanup(tf.close)
         self.assertEqual(["blah/somename"], tf.getnames())
         self.assertEqual(["blah/somename"], tf.getnames())
 
 
     def test_gzip_mtime(self):
     def test_gzip_mtime(self):
-        stream = self._get_example_tar_stream(mtime=1234, format='gz')
-        expected_mtime = struct.pack('<L', 1234)
+        stream = self._get_example_tar_stream(mtime=1234, format="gz")
+        expected_mtime = struct.pack("<L", 1234)
         self.assertEqual(stream.getvalue()[4:8], expected_mtime)
         self.assertEqual(stream.getvalue()[4:8], expected_mtime)
 
 
     @skipUnless(patch, "Required mock.patch")
     @skipUnless(patch, "Required mock.patch")
     def test_same_file(self):
     def test_same_file(self):
         contents = [None, None]
         contents = [None, None]
-        for format in ['', 'gz', 'bz2']:
+        for format in ["", "gz", "bz2"]:
             for i in [0, 1]:
             for i in [0, 1]:
-                with patch('time.time', return_value=i):
-                    stream = self._get_example_tar_stream(
-                        mtime=0, format=format)
+                with patch("time.time", return_value=i):
+                    stream = self._get_example_tar_stream(mtime=0, format=format)
                     contents[i] = stream.getvalue()
                     contents[i] = stream.getvalue()
             self.assertEqual(
             self.assertEqual(
                 contents[0],
                 contents[0],
                 contents[1],
                 contents[1],
-                "Different file contents for format %r" % format
-                )
+                "Different file contents for format %r" % format,
+            )

+ 9 - 9
dulwich/tests/test_blackbox.py

@@ -25,10 +25,10 @@ import shutil
 
 
 from dulwich.repo import (
 from dulwich.repo import (
     Repo,
     Repo,
-    )
+)
 from dulwich.tests import (
 from dulwich.tests import (
     BlackboxTestCase,
     BlackboxTestCase,
-    )
+)
 
 
 
 
 class GitReceivePackTests(BlackboxTestCase):
 class GitReceivePackTests(BlackboxTestCase):
@@ -43,16 +43,16 @@ class GitReceivePackTests(BlackboxTestCase):
     def test_basic(self):
     def test_basic(self):
         process = self.run_command("dul-receive-pack", [self.path])
         process = self.run_command("dul-receive-pack", [self.path])
         (stdout, stderr) = process.communicate(b"0000")
         (stdout, stderr) = process.communicate(b"0000")
-        self.assertEqual(b'0000', stdout[-4:])
+        self.assertEqual(b"0000", stdout[-4:])
         self.assertEqual(0, process.returncode)
         self.assertEqual(0, process.returncode)
 
 
     def test_missing_arg(self):
     def test_missing_arg(self):
         process = self.run_command("dul-receive-pack", [])
         process = self.run_command("dul-receive-pack", [])
         (stdout, stderr) = process.communicate()
         (stdout, stderr) = process.communicate()
         self.assertEqual(
         self.assertEqual(
-            [b'usage: dul-receive-pack <git-dir>'],
-            stderr.splitlines()[-1:])
-        self.assertEqual(b'', stdout)
+            [b"usage: dul-receive-pack <git-dir>"], stderr.splitlines()[-1:]
+        )
+        self.assertEqual(b"", stdout)
         self.assertEqual(1, process.returncode)
         self.assertEqual(1, process.returncode)
 
 
 
 
@@ -69,7 +69,7 @@ class GitUploadPackTests(BlackboxTestCase):
         process = self.run_command("dul-upload-pack", [])
         process = self.run_command("dul-upload-pack", [])
         (stdout, stderr) = process.communicate()
         (stdout, stderr) = process.communicate()
         self.assertEqual(
         self.assertEqual(
-            [b'usage: dul-upload-pack <git-dir>'],
-            stderr.splitlines()[-1:])
-        self.assertEqual(b'', stdout)
+            [b"usage: dul-upload-pack <git-dir>"], stderr.splitlines()[-1:]
+        )
+        self.assertEqual(b"", stdout)
         self.assertEqual(1, process.returncode)
         self.assertEqual(1, process.returncode)

+ 7 - 8
dulwich/tests/test_bundle.py

@@ -25,28 +25,27 @@ import tempfile
 
 
 from dulwich.tests import (
 from dulwich.tests import (
     TestCase,
     TestCase,
-    )
+)
 
 
 from dulwich.bundle import (
 from dulwich.bundle import (
     Bundle,
     Bundle,
     read_bundle,
     read_bundle,
     write_bundle,
     write_bundle,
-    )
+)
 
 
 
 
 class BundleTests(TestCase):
 class BundleTests(TestCase):
-
     def test_roundtrip_bundle(self):
     def test_roundtrip_bundle(self):
         origbundle = Bundle()
         origbundle = Bundle()
         origbundle.version = 3
         origbundle.version = 3
-        origbundle.capabilities = {'foo': None}
-        origbundle.references = {b'refs/heads/master': b'ab' * 20}
-        origbundle.prerequisites = [(b'cc' * 20, 'comment')]
+        origbundle.capabilities = {"foo": None}
+        origbundle.references = {b"refs/heads/master": b"ab" * 20}
+        origbundle.prerequisites = [(b"cc" * 20, "comment")]
         with tempfile.TemporaryDirectory() as td:
         with tempfile.TemporaryDirectory() as td:
-            with open(os.path.join(td, 'foo'), 'wb') as f:
+            with open(os.path.join(td, "foo"), "wb") as f:
                 write_bundle(f, origbundle)
                 write_bundle(f, origbundle)
 
 
-            with open(os.path.join(td, 'foo'), 'rb') as f:
+            with open(os.path.join(td, "foo"), "rb") as f:
                 newbundle = read_bundle(f)
                 newbundle = read_bundle(f)
 
 
                 self.assertEqual(origbundle, newbundle)
                 self.assertEqual(origbundle, newbundle)

Файлын зөрүү хэтэрхий том тул дарагдсан байна
+ 312 - 286
dulwich/tests/test_client.py


+ 172 - 101
dulwich/tests/test_config.py

@@ -20,7 +20,12 @@
 
 
 """Tests for reading and writing configuration files."""
 """Tests for reading and writing configuration files."""
 
 
+import os
+import sys
 from io import BytesIO
 from io import BytesIO
+from unittest import skipIf
+from unittest.mock import patch
+
 from dulwich.config import (
 from dulwich.config import (
     ConfigDict,
     ConfigDict,
     ConfigFile,
     ConfigFile,
@@ -31,14 +36,13 @@ from dulwich.config import (
     _escape_value,
     _escape_value,
     _parse_string,
     _parse_string,
     parse_submodules,
     parse_submodules,
-    )
+)
 from dulwich.tests import (
 from dulwich.tests import (
     TestCase,
     TestCase,
-    )
+)
 
 
 
 
 class ConfigFileTests(TestCase):
 class ConfigFileTests(TestCase):
-
     def from_file(self, text):
     def from_file(self, text):
         return ConfigFile.from_file(BytesIO(text))
         return ConfigFile.from_file(BytesIO(text))
 
 
@@ -49,17 +53,27 @@ class ConfigFileTests(TestCase):
         self.assertEqual(ConfigFile(), ConfigFile())
         self.assertEqual(ConfigFile(), ConfigFile())
 
 
     def test_default_config(self):
     def test_default_config(self):
-        cf = self.from_file(b"""[core]
+        cf = self.from_file(
+            b"""[core]
 \trepositoryformatversion = 0
 \trepositoryformatversion = 0
 \tfilemode = true
 \tfilemode = true
 \tbare = false
 \tbare = false
 \tlogallrefupdates = true
 \tlogallrefupdates = true
-""")
-        self.assertEqual(ConfigFile({(b"core", ): {
-            b"repositoryformatversion": b"0",
-            b"filemode": b"true",
-            b"bare": b"false",
-            b"logallrefupdates": b"true"}}), cf)
+"""
+        )
+        self.assertEqual(
+            ConfigFile(
+                {
+                    (b"core",): {
+                        b"repositoryformatversion": b"0",
+                        b"filemode": b"true",
+                        b"bare": b"false",
+                        b"logallrefupdates": b"true",
+                    }
+                }
+            ),
+            cf,
+        )
 
 
     def test_from_file_empty(self):
     def test_from_file_empty(self):
         cf = self.from_file(b"")
         cf = self.from_file(b"")
@@ -67,81 +81,71 @@ class ConfigFileTests(TestCase):
 
 
     def test_empty_line_before_section(self):
     def test_empty_line_before_section(self):
         cf = self.from_file(b"\n[section]\n")
         cf = self.from_file(b"\n[section]\n")
-        self.assertEqual(ConfigFile({(b"section", ): {}}), cf)
+        self.assertEqual(ConfigFile({(b"section",): {}}), cf)
 
 
     def test_comment_before_section(self):
     def test_comment_before_section(self):
         cf = self.from_file(b"# foo\n[section]\n")
         cf = self.from_file(b"# foo\n[section]\n")
-        self.assertEqual(ConfigFile({(b"section", ): {}}), cf)
+        self.assertEqual(ConfigFile({(b"section",): {}}), cf)
 
 
     def test_comment_after_section(self):
     def test_comment_after_section(self):
         cf = self.from_file(b"[section] # foo\n")
         cf = self.from_file(b"[section] # foo\n")
-        self.assertEqual(ConfigFile({(b"section", ): {}}), cf)
+        self.assertEqual(ConfigFile({(b"section",): {}}), cf)
 
 
     def test_comment_after_variable(self):
     def test_comment_after_variable(self):
         cf = self.from_file(b"[section]\nbar= foo # a comment\n")
         cf = self.from_file(b"[section]\nbar= foo # a comment\n")
-        self.assertEqual(ConfigFile({(b"section", ): {b"bar": b"foo"}}), cf)
+        self.assertEqual(ConfigFile({(b"section",): {b"bar": b"foo"}}), cf)
 
 
     def test_comment_character_within_value_string(self):
     def test_comment_character_within_value_string(self):
-        cf = self.from_file(b"[section]\nbar= \"foo#bar\"\n")
-        self.assertEqual(
-            ConfigFile({(b"section", ): {b"bar": b"foo#bar"}}), cf)
+        cf = self.from_file(b'[section]\nbar= "foo#bar"\n')
+        self.assertEqual(ConfigFile({(b"section",): {b"bar": b"foo#bar"}}), cf)
 
 
     def test_comment_character_within_section_string(self):
     def test_comment_character_within_section_string(self):
-        cf = self.from_file(b"[branch \"foo#bar\"] # a comment\nbar= foo\n")
-        self.assertEqual(
-            ConfigFile({(b"branch", b"foo#bar"): {b"bar": b"foo"}}), cf)
+        cf = self.from_file(b'[branch "foo#bar"] # a comment\nbar= foo\n')
+        self.assertEqual(ConfigFile({(b"branch", b"foo#bar"): {b"bar": b"foo"}}), cf)
 
 
     def test_from_file_section(self):
     def test_from_file_section(self):
         cf = self.from_file(b"[core]\nfoo = bar\n")
         cf = self.from_file(b"[core]\nfoo = bar\n")
-        self.assertEqual(b"bar", cf.get((b"core", ), b"foo"))
+        self.assertEqual(b"bar", cf.get((b"core",), b"foo"))
         self.assertEqual(b"bar", cf.get((b"core", b"foo"), b"foo"))
         self.assertEqual(b"bar", cf.get((b"core", b"foo"), b"foo"))
 
 
     def test_from_file_section_case_insensitive_lower(self):
     def test_from_file_section_case_insensitive_lower(self):
         cf = self.from_file(b"[cOre]\nfOo = bar\n")
         cf = self.from_file(b"[cOre]\nfOo = bar\n")
-        self.assertEqual(b"bar", cf.get((b"core", ), b"foo"))
+        self.assertEqual(b"bar", cf.get((b"core",), b"foo"))
         self.assertEqual(b"bar", cf.get((b"core", b"foo"), b"foo"))
         self.assertEqual(b"bar", cf.get((b"core", b"foo"), b"foo"))
 
 
     def test_from_file_section_case_insensitive_mixed(self):
     def test_from_file_section_case_insensitive_mixed(self):
         cf = self.from_file(b"[cOre]\nfOo = bar\n")
         cf = self.from_file(b"[cOre]\nfOo = bar\n")
-        self.assertEqual(b"bar", cf.get((b"core", ), b"fOo"))
+        self.assertEqual(b"bar", cf.get((b"core",), b"fOo"))
         self.assertEqual(b"bar", cf.get((b"cOre", b"fOo"), b"fOo"))
         self.assertEqual(b"bar", cf.get((b"cOre", b"fOo"), b"fOo"))
 
 
     def test_from_file_with_mixed_quoted(self):
     def test_from_file_with_mixed_quoted(self):
-        cf = self.from_file(b"[core]\nfoo = \"bar\"la\n")
-        self.assertEqual(b"barla", cf.get((b"core", ), b"foo"))
+        cf = self.from_file(b'[core]\nfoo = "bar"la\n')
+        self.assertEqual(b"barla", cf.get((b"core",), b"foo"))
 
 
     def test_from_file_section_with_open_brackets(self):
     def test_from_file_section_with_open_brackets(self):
         self.assertRaises(ValueError, self.from_file, b"[core\nfoo = bar\n")
         self.assertRaises(ValueError, self.from_file, b"[core\nfoo = bar\n")
 
 
     def test_from_file_value_with_open_quoted(self):
     def test_from_file_value_with_open_quoted(self):
-        self.assertRaises(ValueError, self.from_file, b"[core]\nfoo = \"bar\n")
+        self.assertRaises(ValueError, self.from_file, b'[core]\nfoo = "bar\n')
 
 
     def test_from_file_with_quotes(self):
     def test_from_file_with_quotes(self):
-        cf = self.from_file(
-            b"[core]\n"
-            b'foo = " bar"\n')
-        self.assertEqual(b" bar", cf.get((b"core", ), b"foo"))
+        cf = self.from_file(b"[core]\n" b'foo = " bar"\n')
+        self.assertEqual(b" bar", cf.get((b"core",), b"foo"))
 
 
     def test_from_file_with_interrupted_line(self):
     def test_from_file_with_interrupted_line(self):
-        cf = self.from_file(
-            b"[core]\n"
-            b'foo = bar\\\n'
-            b' la\n')
-        self.assertEqual(b"barla", cf.get((b"core", ), b"foo"))
+        cf = self.from_file(b"[core]\n" b"foo = bar\\\n" b" la\n")
+        self.assertEqual(b"barla", cf.get((b"core",), b"foo"))
 
 
     def test_from_file_with_boolean_setting(self):
     def test_from_file_with_boolean_setting(self):
-        cf = self.from_file(
-            b"[core]\n"
-            b'foo\n')
-        self.assertEqual(b"true", cf.get((b"core", ), b"foo"))
+        cf = self.from_file(b"[core]\n" b"foo\n")
+        self.assertEqual(b"true", cf.get((b"core",), b"foo"))
 
 
     def test_from_file_subsection(self):
     def test_from_file_subsection(self):
-        cf = self.from_file(b"[branch \"foo\"]\nfoo = bar\n")
+        cf = self.from_file(b'[branch "foo"]\nfoo = bar\n')
         self.assertEqual(b"bar", cf.get((b"branch", b"foo"), b"foo"))
         self.assertEqual(b"bar", cf.get((b"branch", b"foo"), b"foo"))
 
 
     def test_from_file_subsection_invalid(self):
     def test_from_file_subsection_invalid(self):
-        self.assertRaises(
-                ValueError, self.from_file, b"[branch \"foo]\nfoo = bar\n")
+        self.assertRaises(ValueError, self.from_file, b'[branch "foo]\nfoo = bar\n')
 
 
     def test_from_file_subsection_not_quoted(self):
     def test_from_file_subsection_not_quoted(self):
         cf = self.from_file(b"[branch.foo]\nfoo = bar\n")
         cf = self.from_file(b"[branch.foo]\nfoo = bar\n")
@@ -155,7 +159,7 @@ class ConfigFileTests(TestCase):
 
 
     def test_write_to_file_section(self):
     def test_write_to_file_section(self):
         c = ConfigFile()
         c = ConfigFile()
-        c.set((b"core", ), b"foo", b"bar")
+        c.set((b"core",), b"foo", b"bar")
         f = BytesIO()
         f = BytesIO()
         c.write_to_file(f)
         c.write_to_file(f)
         self.assertEqual(b"[core]\n\tfoo = bar\n", f.getvalue())
         self.assertEqual(b"[core]\n\tfoo = bar\n", f.getvalue())
@@ -165,100 +169,161 @@ class ConfigFileTests(TestCase):
         c.set((b"branch", b"blie"), b"foo", b"bar")
         c.set((b"branch", b"blie"), b"foo", b"bar")
         f = BytesIO()
         f = BytesIO()
         c.write_to_file(f)
         c.write_to_file(f)
-        self.assertEqual(b"[branch \"blie\"]\n\tfoo = bar\n", f.getvalue())
+        self.assertEqual(b'[branch "blie"]\n\tfoo = bar\n', f.getvalue())
 
 
     def test_same_line(self):
     def test_same_line(self):
         cf = self.from_file(b"[branch.foo] foo = bar\n")
         cf = self.from_file(b"[branch.foo] foo = bar\n")
         self.assertEqual(b"bar", cf.get((b"branch", b"foo"), b"foo"))
         self.assertEqual(b"bar", cf.get((b"branch", b"foo"), b"foo"))
 
 
     def test_quoted(self):
     def test_quoted(self):
-        cf = self.from_file(b"""[gui]
+        cf = self.from_file(
+            b"""[gui]
 \tfontdiff = -family \\\"Ubuntu Mono\\\" -size 11 -overstrike 0
 \tfontdiff = -family \\\"Ubuntu Mono\\\" -size 11 -overstrike 0
-""")
-        self.assertEqual(ConfigFile({(b'gui', ): {
-            b'fontdiff': b'-family "Ubuntu Mono" -size 11 -overstrike 0',
-        }}), cf)
+"""
+        )
+        self.assertEqual(
+            ConfigFile(
+                {
+                    (b"gui",): {
+                        b"fontdiff": b'-family "Ubuntu Mono" -size 11 -overstrike 0',
+                    }
+                }
+            ),
+            cf,
+        )
 
 
     def test_quoted_multiline(self):
     def test_quoted_multiline(self):
-        cf = self.from_file(b"""[alias]
+        cf = self.from_file(
+            b"""[alias]
 who = \"!who() {\\
 who = \"!who() {\\
   git log --no-merges --pretty=format:'%an - %ae' $@ | uniq -c | sort -rn;\\
   git log --no-merges --pretty=format:'%an - %ae' $@ | uniq -c | sort -rn;\\
 };\\
 };\\
 who\"
 who\"
-""")
-        self.assertEqual(ConfigFile({(b'alias', ): {
-            b'who': (b"!who() {git log --no-merges --pretty=format:'%an - "
-                     b"%ae' $@ | uniq -c | sort -rn;};who")
-            }}), cf)
+"""
+        )
+        self.assertEqual(
+            ConfigFile(
+                {
+                    (b"alias",): {
+                        b"who": (
+                            b"!who() {git log --no-merges --pretty=format:'%an - "
+                            b"%ae' $@ | uniq -c | sort -rn;};who"
+                        )
+                    }
+                }
+            ),
+            cf,
+        )
 
 
     def test_set_hash_gets_quoted(self):
     def test_set_hash_gets_quoted(self):
         c = ConfigFile()
         c = ConfigFile()
         c.set(b"xandikos", b"color", b"#665544")
         c.set(b"xandikos", b"color", b"#665544")
         f = BytesIO()
         f = BytesIO()
         c.write_to_file(f)
         c.write_to_file(f)
-        self.assertEqual(b"[xandikos]\n\tcolor = \"#665544\"\n", f.getvalue())
+        self.assertEqual(b'[xandikos]\n\tcolor = "#665544"\n', f.getvalue())
 
 
 
 
 class ConfigDictTests(TestCase):
 class ConfigDictTests(TestCase):
-
     def test_get_set(self):
     def test_get_set(self):
         cd = ConfigDict()
         cd = ConfigDict()
         self.assertRaises(KeyError, cd.get, b"foo", b"core")
         self.assertRaises(KeyError, cd.get, b"foo", b"core")
-        cd.set((b"core", ), b"foo", b"bla")
-        self.assertEqual(b"bla", cd.get((b"core", ), b"foo"))
-        cd.set((b"core", ), b"foo", b"bloe")
-        self.assertEqual(b"bloe", cd.get((b"core", ), b"foo"))
+        cd.set((b"core",), b"foo", b"bla")
+        self.assertEqual(b"bla", cd.get((b"core",), b"foo"))
+        cd.set((b"core",), b"foo", b"bloe")
+        self.assertEqual(b"bloe", cd.get((b"core",), b"foo"))
 
 
     def test_get_boolean(self):
     def test_get_boolean(self):
         cd = ConfigDict()
         cd = ConfigDict()
-        cd.set((b"core", ), b"foo", b"true")
-        self.assertTrue(cd.get_boolean((b"core", ), b"foo"))
-        cd.set((b"core", ), b"foo", b"false")
-        self.assertFalse(cd.get_boolean((b"core", ), b"foo"))
-        cd.set((b"core", ), b"foo", b"invalid")
-        self.assertRaises(ValueError, cd.get_boolean, (b"core", ), b"foo")
+        cd.set((b"core",), b"foo", b"true")
+        self.assertTrue(cd.get_boolean((b"core",), b"foo"))
+        cd.set((b"core",), b"foo", b"false")
+        self.assertFalse(cd.get_boolean((b"core",), b"foo"))
+        cd.set((b"core",), b"foo", b"invalid")
+        self.assertRaises(ValueError, cd.get_boolean, (b"core",), b"foo")
 
 
     def test_dict(self):
     def test_dict(self):
         cd = ConfigDict()
         cd = ConfigDict()
-        cd.set((b"core", ), b"foo", b"bla")
-        cd.set((b"core2", ), b"foo", b"bloe")
+        cd.set((b"core",), b"foo", b"bla")
+        cd.set((b"core2",), b"foo", b"bloe")
 
 
-        self.assertEqual([(b"core", ), (b"core2", )], list(cd.keys()))
-        self.assertEqual(cd[(b"core", )], {b'foo': b'bla'})
+        self.assertEqual([(b"core",), (b"core2",)], list(cd.keys()))
+        self.assertEqual(cd[(b"core",)], {b"foo": b"bla"})
 
 
-        cd[b'a'] = b'b'
-        self.assertEqual(cd[b'a'], b'b')
+        cd[b"a"] = b"b"
+        self.assertEqual(cd[b"a"], b"b")
 
 
     def test_iteritems(self):
     def test_iteritems(self):
         cd = ConfigDict()
         cd = ConfigDict()
-        cd.set((b"core", ), b"foo", b"bla")
-        cd.set((b"core2", ), b"foo", b"bloe")
+        cd.set((b"core",), b"foo", b"bla")
+        cd.set((b"core2",), b"foo", b"bloe")
 
 
-        self.assertEqual(
-            [(b'foo', b'bla')],
-            list(cd.iteritems((b"core", ))))
+        self.assertEqual([(b"foo", b"bla")], list(cd.iteritems((b"core",))))
 
 
     def test_iteritems_nonexistant(self):
     def test_iteritems_nonexistant(self):
         cd = ConfigDict()
         cd = ConfigDict()
-        cd.set((b"core2", ), b"foo", b"bloe")
+        cd.set((b"core2",), b"foo", b"bloe")
 
 
-        self.assertEqual([], list(cd.iteritems((b"core", ))))
+        self.assertEqual([], list(cd.iteritems((b"core",))))
 
 
     def test_itersections(self):
     def test_itersections(self):
         cd = ConfigDict()
         cd = ConfigDict()
-        cd.set((b"core2", ), b"foo", b"bloe")
+        cd.set((b"core2",), b"foo", b"bloe")
 
 
-        self.assertEqual([(b"core2", )], list(cd.itersections()))
+        self.assertEqual([(b"core2",)], list(cd.itersections()))
 
 
 
 
 class StackedConfigTests(TestCase):
 class StackedConfigTests(TestCase):
+    def setUp(self):
+        super(StackedConfigTests, self).setUp()
+        self._old_path = os.environ.get("PATH")
+
+    def tearDown(self):
+        super(StackedConfigTests, self).tearDown()
+        os.environ["PATH"] = self._old_path
 
 
     def test_default_backends(self):
     def test_default_backends(self):
         StackedConfig.default_backends()
         StackedConfig.default_backends()
 
 
+    @skipIf(sys.platform != "win32", "Windows specfic config location.")
+    def test_windows_config_from_path(self):
+        from dulwich.config import get_win_system_paths
+
+        install_dir = os.path.join("C:", "foo", "Git")
+        os.environ["PATH"] = os.path.join(install_dir, "cmd")
+        with patch("os.path.exists", return_value=True):
+            paths = set(get_win_system_paths())
+        self.assertEqual(
+            {
+                os.path.join(os.environ.get("PROGRAMDATA"), "Git", "config"),
+                os.path.join(install_dir, "etc", "gitconfig"),
+            },
+            paths,
+        )
+
+    @skipIf(sys.platform != "win32", "Windows specfic config location.")
+    def test_windows_config_from_reg(self):
+        import winreg
+
+        from dulwich.config import get_win_system_paths
+
+        del os.environ["PATH"]
+        install_dir = os.path.join("C:", "foo", "Git")
+        with patch("winreg.OpenKey"):
+            with patch(
+                "winreg.QueryValueEx",
+                return_value=(install_dir, winreg.REG_SZ),
+            ):
+                paths = set(get_win_system_paths())
+        self.assertEqual(
+            {
+                os.path.join(os.environ.get("PROGRAMDATA"), "Git", "config"),
+                os.path.join(install_dir, "etc", "gitconfig"),
+            },
+            paths,
+        )
 
 
-class EscapeValueTests(TestCase):
 
 
+class EscapeValueTests(TestCase):
     def test_nothing(self):
     def test_nothing(self):
         self.assertEqual(b"foo", _escape_value(b"foo"))
         self.assertEqual(b"foo", _escape_value(b"foo"))
 
 
@@ -270,28 +335,26 @@ class EscapeValueTests(TestCase):
 
 
 
 
 class FormatStringTests(TestCase):
 class FormatStringTests(TestCase):
-
     def test_quoted(self):
     def test_quoted(self):
         self.assertEqual(b'" foo"', _format_string(b" foo"))
         self.assertEqual(b'" foo"', _format_string(b" foo"))
         self.assertEqual(b'"\\tfoo"', _format_string(b"\tfoo"))
         self.assertEqual(b'"\\tfoo"', _format_string(b"\tfoo"))
 
 
     def test_not_quoted(self):
     def test_not_quoted(self):
-        self.assertEqual(b'foo', _format_string(b"foo"))
-        self.assertEqual(b'foo bar', _format_string(b"foo bar"))
+        self.assertEqual(b"foo", _format_string(b"foo"))
+        self.assertEqual(b"foo bar", _format_string(b"foo bar"))
 
 
 
 
 class ParseStringTests(TestCase):
 class ParseStringTests(TestCase):
-
     def test_quoted(self):
     def test_quoted(self):
-        self.assertEqual(b' foo', _parse_string(b'" foo"'))
-        self.assertEqual(b'\tfoo', _parse_string(b'"\\tfoo"'))
+        self.assertEqual(b" foo", _parse_string(b'" foo"'))
+        self.assertEqual(b"\tfoo", _parse_string(b'"\\tfoo"'))
 
 
     def test_not_quoted(self):
     def test_not_quoted(self):
-        self.assertEqual(b'foo', _parse_string(b"foo"))
-        self.assertEqual(b'foo bar', _parse_string(b"foo bar"))
+        self.assertEqual(b"foo", _parse_string(b"foo"))
+        self.assertEqual(b"foo bar", _parse_string(b"foo bar"))
 
 
     def test_nothing(self):
     def test_nothing(self):
-        self.assertEqual(b"", _parse_string(b''))
+        self.assertEqual(b"", _parse_string(b""))
 
 
     def test_tab(self):
     def test_tab(self):
         self.assertEqual(b"\tbar\t", _parse_string(b"\\tbar\\t"))
         self.assertEqual(b"\tbar\t", _parse_string(b"\\tbar\\t"))
@@ -300,11 +363,10 @@ class ParseStringTests(TestCase):
         self.assertEqual(b"\nbar\t", _parse_string(b"\\nbar\\t\t"))
         self.assertEqual(b"\nbar\t", _parse_string(b"\\nbar\\t\t"))
 
 
     def test_quote(self):
     def test_quote(self):
-        self.assertEqual(b"\"foo\"", _parse_string(b"\\\"foo\\\""))
+        self.assertEqual(b'"foo"', _parse_string(b'\\"foo\\"'))
 
 
 
 
 class CheckVariableNameTests(TestCase):
 class CheckVariableNameTests(TestCase):
-
     def test_invalid(self):
     def test_invalid(self):
         self.assertFalse(_check_variable_name(b"foo "))
         self.assertFalse(_check_variable_name(b"foo "))
         self.assertFalse(_check_variable_name(b"bar,bar"))
         self.assertFalse(_check_variable_name(b"bar,bar"))
@@ -317,7 +379,6 @@ class CheckVariableNameTests(TestCase):
 
 
 
 
 class CheckSectionNameTests(TestCase):
 class CheckSectionNameTests(TestCase):
-
     def test_invalid(self):
     def test_invalid(self):
         self.assertFalse(_check_section_name(b"foo "))
         self.assertFalse(_check_section_name(b"foo "))
         self.assertFalse(_check_section_name(b"bar,bar"))
         self.assertFalse(_check_section_name(b"bar,bar"))
@@ -330,14 +391,24 @@ class CheckSectionNameTests(TestCase):
 
 
 
 
 class SubmodulesTests(TestCase):
 class SubmodulesTests(TestCase):
-
     def testSubmodules(self):
     def testSubmodules(self):
-        cf = ConfigFile.from_file(BytesIO(b"""\
+        cf = ConfigFile.from_file(
+            BytesIO(
+                b"""\
 [submodule "core/lib"]
 [submodule "core/lib"]
 \tpath = core/lib
 \tpath = core/lib
 \turl = https://github.com/phhusson/QuasselC.git
 \turl = https://github.com/phhusson/QuasselC.git
-"""))
+"""
+            )
+        )
         got = list(parse_submodules(cf))
         got = list(parse_submodules(cf))
-        self.assertEqual([
-            (b'core/lib', b'https://github.com/phhusson/QuasselC.git',
-             b'core/lib')], got)
+        self.assertEqual(
+            [
+                (
+                    b"core/lib",
+                    b"https://github.com/phhusson/QuasselC.git",
+                    b"core/lib",
+                )
+            ],
+            got,
+        )

+ 787 - 584
dulwich/tests/test_diff_tree.py

@@ -37,33 +37,32 @@ from dulwich.diff_tree import (
     _tree_change_key,
     _tree_change_key,
     RenameDetector,
     RenameDetector,
     _is_tree,
     _is_tree,
-    _is_tree_py
-    )
+    _is_tree_py,
+)
 from dulwich.index import (
 from dulwich.index import (
     commit_tree,
     commit_tree,
-    )
+)
 from dulwich.object_store import (
 from dulwich.object_store import (
     MemoryObjectStore,
     MemoryObjectStore,
-    )
+)
 from dulwich.objects import (
 from dulwich.objects import (
     ShaFile,
     ShaFile,
     Blob,
     Blob,
     TreeEntry,
     TreeEntry,
     Tree,
     Tree,
-    )
+)
 from dulwich.tests import (
 from dulwich.tests import (
     TestCase,
     TestCase,
-    )
+)
 from dulwich.tests.utils import (
 from dulwich.tests.utils import (
     F,
     F,
     make_object,
     make_object,
     functest_builder,
     functest_builder,
     ext_functest_builder,
     ext_functest_builder,
-    )
+)
 
 
 
 
 class DiffTestCase(TestCase):
 class DiffTestCase(TestCase):
-
     def setUp(self):
     def setUp(self):
         super(DiffTestCase, self).setUp()
         super(DiffTestCase, self).setUp()
         self.store = MemoryObjectStore()
         self.store = MemoryObjectStore()
@@ -87,7 +86,6 @@ class DiffTestCase(TestCase):
 
 
 
 
 class TreeChangesTest(DiffTestCase):
 class TreeChangesTest(DiffTestCase):
-
     def setUp(self):
     def setUp(self):
         super(TreeChangesTest, self).setUp()
         super(TreeChangesTest, self).setUp()
         self.detector = RenameDetector(self.store)
         self.detector = RenameDetector(self.store)
@@ -95,62 +93,74 @@ class TreeChangesTest(DiffTestCase):
     def assertMergeFails(self, merge_entries, name, mode, sha):
     def assertMergeFails(self, merge_entries, name, mode, sha):
         t = Tree()
         t = Tree()
         t[name] = (mode, sha)
         t[name] = (mode, sha)
-        self.assertRaises((TypeError, ValueError), merge_entries, '', t, t)
+        self.assertRaises((TypeError, ValueError), merge_entries, "", t, t)
 
 
     def _do_test_merge_entries(self, merge_entries):
     def _do_test_merge_entries(self, merge_entries):
-        blob_a1 = make_object(Blob, data=b'a1')
-        blob_a2 = make_object(Blob, data=b'a2')
-        blob_b1 = make_object(Blob, data=b'b1')
-        blob_c2 = make_object(Blob, data=b'c2')
-        tree1 = self.commit_tree([(b'a', blob_a1, 0o100644),
-                                  (b'b', blob_b1, 0o100755)])
-        tree2 = self.commit_tree([(b'a', blob_a2, 0o100644),
-                                  (b'c', blob_c2, 0o100755)])
-
-        self.assertEqual([], merge_entries(b'', self.empty_tree,
-                                           self.empty_tree))
+        blob_a1 = make_object(Blob, data=b"a1")
+        blob_a2 = make_object(Blob, data=b"a2")
+        blob_b1 = make_object(Blob, data=b"b1")
+        blob_c2 = make_object(Blob, data=b"c2")
+        tree1 = self.commit_tree([(b"a", blob_a1, 0o100644), (b"b", blob_b1, 0o100755)])
+        tree2 = self.commit_tree([(b"a", blob_a2, 0o100644), (b"c", blob_c2, 0o100755)])
+
+        self.assertEqual([], merge_entries(b"", self.empty_tree, self.empty_tree))
         self.assertEqual(
         self.assertEqual(
-            [((None, None, None), (b'a', 0o100644, blob_a1.id)),
-             ((None, None, None), (b'b', 0o100755, blob_b1.id)), ],
-            merge_entries(b'', self.empty_tree, tree1))
+            [
+                ((None, None, None), (b"a", 0o100644, blob_a1.id)),
+                ((None, None, None), (b"b", 0o100755, blob_b1.id)),
+            ],
+            merge_entries(b"", self.empty_tree, tree1),
+        )
         self.assertEqual(
         self.assertEqual(
-            [((None, None, None), (b'x/a', 0o100644, blob_a1.id)),
-             ((None, None, None), (b'x/b', 0o100755, blob_b1.id)), ],
-            merge_entries(b'x', self.empty_tree, tree1))
+            [
+                ((None, None, None), (b"x/a", 0o100644, blob_a1.id)),
+                ((None, None, None), (b"x/b", 0o100755, blob_b1.id)),
+            ],
+            merge_entries(b"x", self.empty_tree, tree1),
+        )
 
 
         self.assertEqual(
         self.assertEqual(
-            [((b'a', 0o100644, blob_a2.id), (None, None, None)),
-             ((b'c', 0o100755, blob_c2.id), (None, None, None)), ],
-            merge_entries(b'', tree2, self.empty_tree))
+            [
+                ((b"a", 0o100644, blob_a2.id), (None, None, None)),
+                ((b"c", 0o100755, blob_c2.id), (None, None, None)),
+            ],
+            merge_entries(b"", tree2, self.empty_tree),
+        )
 
 
         self.assertEqual(
         self.assertEqual(
-            [((b'a', 0o100644, blob_a1.id), (b'a', 0o100644, blob_a2.id)),
-             ((b'b', 0o100755, blob_b1.id), (None, None, None)),
-             ((None, None, None), (b'c', 0o100755, blob_c2.id)), ],
-            merge_entries(b'', tree1, tree2))
+            [
+                ((b"a", 0o100644, blob_a1.id), (b"a", 0o100644, blob_a2.id)),
+                ((b"b", 0o100755, blob_b1.id), (None, None, None)),
+                ((None, None, None), (b"c", 0o100755, blob_c2.id)),
+            ],
+            merge_entries(b"", tree1, tree2),
+        )
 
 
         self.assertEqual(
         self.assertEqual(
-            [((b'a', 0o100644, blob_a2.id), (b'a', 0o100644, blob_a1.id)),
-             ((None, None, None), (b'b', 0o100755, blob_b1.id)),
-             ((b'c', 0o100755, blob_c2.id), (None, None, None)), ],
-            merge_entries(b'', tree2, tree1))
-
-        self.assertMergeFails(merge_entries, 0xdeadbeef, 0o100644, '1' * 40)
-        self.assertMergeFails(merge_entries, b'a', b'deadbeef', '1' * 40)
-        self.assertMergeFails(merge_entries, b'a', 0o100644, 0xdeadbeef)
-
-    test_merge_entries = functest_builder(_do_test_merge_entries,
-                                          _merge_entries_py)
-    test_merge_entries_extension = ext_functest_builder(_do_test_merge_entries,
-                                                        _merge_entries)
+            [
+                ((b"a", 0o100644, blob_a2.id), (b"a", 0o100644, blob_a1.id)),
+                ((None, None, None), (b"b", 0o100755, blob_b1.id)),
+                ((b"c", 0o100755, blob_c2.id), (None, None, None)),
+            ],
+            merge_entries(b"", tree2, tree1),
+        )
+
+        self.assertMergeFails(merge_entries, 0xDEADBEEF, 0o100644, "1" * 40)
+        self.assertMergeFails(merge_entries, b"a", b"deadbeef", "1" * 40)
+        self.assertMergeFails(merge_entries, b"a", 0o100644, 0xDEADBEEF)
+
+    test_merge_entries = functest_builder(_do_test_merge_entries, _merge_entries_py)
+    test_merge_entries_extension = ext_functest_builder(
+        _do_test_merge_entries, _merge_entries
+    )
 
 
     def _do_test_is_tree(self, is_tree):
     def _do_test_is_tree(self, is_tree):
         self.assertFalse(is_tree(TreeEntry(None, None, None)))
         self.assertFalse(is_tree(TreeEntry(None, None, None)))
-        self.assertFalse(is_tree(TreeEntry(b'a', 0o100644, b'a' * 40)))
-        self.assertFalse(is_tree(TreeEntry(b'a', 0o100755, b'a' * 40)))
-        self.assertFalse(is_tree(TreeEntry(b'a', 0o120000, b'a' * 40)))
-        self.assertTrue(is_tree(TreeEntry(b'a', 0o040000, b'a' * 40)))
-        self.assertRaises(TypeError, is_tree, TreeEntry(b'a', b'x', b'a' * 40))
+        self.assertFalse(is_tree(TreeEntry(b"a", 0o100644, b"a" * 40)))
+        self.assertFalse(is_tree(TreeEntry(b"a", 0o100755, b"a" * 40)))
+        self.assertFalse(is_tree(TreeEntry(b"a", 0o120000, b"a" * 40)))
+        self.assertTrue(is_tree(TreeEntry(b"a", 0o040000, b"a" * 40)))
+        self.assertRaises(TypeError, is_tree, TreeEntry(b"a", b"x", b"a" * 40))
         self.assertRaises(AttributeError, is_tree, 1234)
         self.assertRaises(AttributeError, is_tree, 1234)
 
 
     test_is_tree = functest_builder(_do_test_is_tree, _is_tree_py)
     test_is_tree = functest_builder(_do_test_is_tree, _is_tree_py)
@@ -166,243 +176,334 @@ class TreeChangesTest(DiffTestCase):
         self.assertChangesEqual([], self.empty_tree, self.empty_tree)
         self.assertChangesEqual([], self.empty_tree, self.empty_tree)
 
 
     def test_tree_changes_no_changes(self):
     def test_tree_changes_no_changes(self):
-        blob = make_object(Blob, data=b'blob')
-        tree = self.commit_tree([(b'a', blob), (b'b/c', blob)])
+        blob = make_object(Blob, data=b"blob")
+        tree = self.commit_tree([(b"a", blob), (b"b/c", blob)])
         self.assertChangesEqual([], self.empty_tree, self.empty_tree)
         self.assertChangesEqual([], self.empty_tree, self.empty_tree)
         self.assertChangesEqual([], tree, tree)
         self.assertChangesEqual([], tree, tree)
         self.assertChangesEqual(
         self.assertChangesEqual(
-            [TreeChange(CHANGE_UNCHANGED, (b'a', F, blob.id),
-                        (b'a', F, blob.id)),
-             TreeChange(CHANGE_UNCHANGED, (b'b/c', F, blob.id),
-                        (b'b/c', F, blob.id))],
-            tree, tree, want_unchanged=True)
+            [
+                TreeChange(CHANGE_UNCHANGED, (b"a", F, blob.id), (b"a", F, blob.id)),
+                TreeChange(
+                    CHANGE_UNCHANGED,
+                    (b"b/c", F, blob.id),
+                    (b"b/c", F, blob.id),
+                ),
+            ],
+            tree,
+            tree,
+            want_unchanged=True,
+        )
 
 
     def test_tree_changes_add_delete(self):
     def test_tree_changes_add_delete(self):
-        blob_a = make_object(Blob, data=b'a')
-        blob_b = make_object(Blob, data=b'b')
-        tree = self.commit_tree([(b'a', blob_a, 0o100644),
-                                 (b'x/b', blob_b, 0o100755)])
+        blob_a = make_object(Blob, data=b"a")
+        blob_b = make_object(Blob, data=b"b")
+        tree = self.commit_tree([(b"a", blob_a, 0o100644), (b"x/b", blob_b, 0o100755)])
         self.assertChangesEqual(
         self.assertChangesEqual(
-            [TreeChange.add((b'a', 0o100644, blob_a.id)),
-             TreeChange.add((b'x/b', 0o100755, blob_b.id))],
-            self.empty_tree, tree)
+            [
+                TreeChange.add((b"a", 0o100644, blob_a.id)),
+                TreeChange.add((b"x/b", 0o100755, blob_b.id)),
+            ],
+            self.empty_tree,
+            tree,
+        )
         self.assertChangesEqual(
         self.assertChangesEqual(
-            [TreeChange.delete((b'a', 0o100644, blob_a.id)),
-             TreeChange.delete((b'x/b', 0o100755, blob_b.id))],
-            tree, self.empty_tree)
+            [
+                TreeChange.delete((b"a", 0o100644, blob_a.id)),
+                TreeChange.delete((b"x/b", 0o100755, blob_b.id)),
+            ],
+            tree,
+            self.empty_tree,
+        )
 
 
     def test_tree_changes_modify_contents(self):
     def test_tree_changes_modify_contents(self):
-        blob_a1 = make_object(Blob, data=b'a1')
-        blob_a2 = make_object(Blob, data=b'a2')
-        tree1 = self.commit_tree([(b'a', blob_a1)])
-        tree2 = self.commit_tree([(b'a', blob_a2)])
+        blob_a1 = make_object(Blob, data=b"a1")
+        blob_a2 = make_object(Blob, data=b"a2")
+        tree1 = self.commit_tree([(b"a", blob_a1)])
+        tree2 = self.commit_tree([(b"a", blob_a2)])
         self.assertChangesEqual(
         self.assertChangesEqual(
-            [TreeChange(CHANGE_MODIFY, (b'a', F, blob_a1.id),
-                        (b'a', F, blob_a2.id))],
-            tree1, tree2)
+            [TreeChange(CHANGE_MODIFY, (b"a", F, blob_a1.id), (b"a", F, blob_a2.id))],
+            tree1,
+            tree2,
+        )
 
 
     def test_tree_changes_modify_mode(self):
     def test_tree_changes_modify_mode(self):
-        blob_a = make_object(Blob, data=b'a')
-        tree1 = self.commit_tree([(b'a', blob_a, 0o100644)])
-        tree2 = self.commit_tree([(b'a', blob_a, 0o100755)])
+        blob_a = make_object(Blob, data=b"a")
+        tree1 = self.commit_tree([(b"a", blob_a, 0o100644)])
+        tree2 = self.commit_tree([(b"a", blob_a, 0o100755)])
         self.assertChangesEqual(
         self.assertChangesEqual(
-            [TreeChange(CHANGE_MODIFY, (b'a', 0o100644, blob_a.id),
-                        (b'a', 0o100755, blob_a.id))],
-            tree1, tree2)
+            [
+                TreeChange(
+                    CHANGE_MODIFY,
+                    (b"a", 0o100644, blob_a.id),
+                    (b"a", 0o100755, blob_a.id),
+                )
+            ],
+            tree1,
+            tree2,
+        )
 
 
     def test_tree_changes_change_type(self):
     def test_tree_changes_change_type(self):
-        blob_a1 = make_object(Blob, data=b'a')
-        blob_a2 = make_object(Blob, data=b'/foo/bar')
-        tree1 = self.commit_tree([(b'a', blob_a1, 0o100644)])
-        tree2 = self.commit_tree([(b'a', blob_a2, 0o120000)])
+        blob_a1 = make_object(Blob, data=b"a")
+        blob_a2 = make_object(Blob, data=b"/foo/bar")
+        tree1 = self.commit_tree([(b"a", blob_a1, 0o100644)])
+        tree2 = self.commit_tree([(b"a", blob_a2, 0o120000)])
         self.assertChangesEqual(
         self.assertChangesEqual(
-            [TreeChange.delete((b'a', 0o100644, blob_a1.id)),
-             TreeChange.add((b'a', 0o120000, blob_a2.id))],
-            tree1, tree2)
+            [
+                TreeChange.delete((b"a", 0o100644, blob_a1.id)),
+                TreeChange.add((b"a", 0o120000, blob_a2.id)),
+            ],
+            tree1,
+            tree2,
+        )
 
 
     def test_tree_changes_change_type_same(self):
     def test_tree_changes_change_type_same(self):
-        blob_a1 = make_object(Blob, data=b'a')
-        blob_a2 = make_object(Blob, data=b'/foo/bar')
-        tree1 = self.commit_tree([(b'a', blob_a1, 0o100644)])
-        tree2 = self.commit_tree([(b'a', blob_a2, 0o120000)])
+        blob_a1 = make_object(Blob, data=b"a")
+        blob_a2 = make_object(Blob, data=b"/foo/bar")
+        tree1 = self.commit_tree([(b"a", blob_a1, 0o100644)])
+        tree2 = self.commit_tree([(b"a", blob_a2, 0o120000)])
         self.assertChangesEqual(
         self.assertChangesEqual(
-            [TreeChange(CHANGE_MODIFY, (b'a', 0o100644, blob_a1.id),
-                        (b'a', 0o120000, blob_a2.id))],
-            tree1, tree2, change_type_same=True)
+            [
+                TreeChange(
+                    CHANGE_MODIFY,
+                    (b"a", 0o100644, blob_a1.id),
+                    (b"a", 0o120000, blob_a2.id),
+                )
+            ],
+            tree1,
+            tree2,
+            change_type_same=True,
+        )
 
 
     def test_tree_changes_to_tree(self):
     def test_tree_changes_to_tree(self):
-        blob_a = make_object(Blob, data=b'a')
-        blob_x = make_object(Blob, data=b'x')
-        tree1 = self.commit_tree([(b'a', blob_a)])
-        tree2 = self.commit_tree([(b'a/x', blob_x)])
+        blob_a = make_object(Blob, data=b"a")
+        blob_x = make_object(Blob, data=b"x")
+        tree1 = self.commit_tree([(b"a", blob_a)])
+        tree2 = self.commit_tree([(b"a/x", blob_x)])
         self.assertChangesEqual(
         self.assertChangesEqual(
-            [TreeChange.delete((b'a', F, blob_a.id)),
-             TreeChange.add((b'a/x', F, blob_x.id))],
-            tree1, tree2)
+            [
+                TreeChange.delete((b"a", F, blob_a.id)),
+                TreeChange.add((b"a/x", F, blob_x.id)),
+            ],
+            tree1,
+            tree2,
+        )
 
 
     def test_tree_changes_complex(self):
     def test_tree_changes_complex(self):
-        blob_a_1 = make_object(Blob, data=b'a1_1')
-        blob_bx1_1 = make_object(Blob, data=b'bx1_1')
-        blob_bx2_1 = make_object(Blob, data=b'bx2_1')
-        blob_by1_1 = make_object(Blob, data=b'by1_1')
-        blob_by2_1 = make_object(Blob, data=b'by2_1')
-        tree1 = self.commit_tree([
-            (b'a', blob_a_1),
-            (b'b/x/1', blob_bx1_1),
-            (b'b/x/2', blob_bx2_1),
-            (b'b/y/1', blob_by1_1),
-            (b'b/y/2', blob_by2_1),
-        ])
-
-        blob_a_2 = make_object(Blob, data=b'a1_2')
+        blob_a_1 = make_object(Blob, data=b"a1_1")
+        blob_bx1_1 = make_object(Blob, data=b"bx1_1")
+        blob_bx2_1 = make_object(Blob, data=b"bx2_1")
+        blob_by1_1 = make_object(Blob, data=b"by1_1")
+        blob_by2_1 = make_object(Blob, data=b"by2_1")
+        tree1 = self.commit_tree(
+            [
+                (b"a", blob_a_1),
+                (b"b/x/1", blob_bx1_1),
+                (b"b/x/2", blob_bx2_1),
+                (b"b/y/1", blob_by1_1),
+                (b"b/y/2", blob_by2_1),
+            ]
+        )
+
+        blob_a_2 = make_object(Blob, data=b"a1_2")
         blob_bx1_2 = blob_bx1_1
         blob_bx1_2 = blob_bx1_1
-        blob_by_2 = make_object(Blob, data=b'by_2')
-        blob_c_2 = make_object(Blob, data=b'c_2')
-        tree2 = self.commit_tree([
-            (b'a', blob_a_2),
-            (b'b/x/1', blob_bx1_2),
-            (b'b/y', blob_by_2),
-            (b'c', blob_c_2),
-        ])
+        blob_by_2 = make_object(Blob, data=b"by_2")
+        blob_c_2 = make_object(Blob, data=b"c_2")
+        tree2 = self.commit_tree(
+            [
+                (b"a", blob_a_2),
+                (b"b/x/1", blob_bx1_2),
+                (b"b/y", blob_by_2),
+                (b"c", blob_c_2),
+            ]
+        )
 
 
         self.assertChangesEqual(
         self.assertChangesEqual(
-            [TreeChange(CHANGE_MODIFY, (b'a', F, blob_a_1.id),
-                        (b'a', F, blob_a_2.id)),
-             TreeChange.delete((b'b/x/2', F, blob_bx2_1.id)),
-             TreeChange.add((b'b/y', F, blob_by_2.id)),
-             TreeChange.delete((b'b/y/1', F, blob_by1_1.id)),
-             TreeChange.delete((b'b/y/2', F, blob_by2_1.id)),
-             TreeChange.add((b'c', F, blob_c_2.id))],
-            tree1, tree2)
+            [
+                TreeChange(
+                    CHANGE_MODIFY,
+                    (b"a", F, blob_a_1.id),
+                    (b"a", F, blob_a_2.id),
+                ),
+                TreeChange.delete((b"b/x/2", F, blob_bx2_1.id)),
+                TreeChange.add((b"b/y", F, blob_by_2.id)),
+                TreeChange.delete((b"b/y/1", F, blob_by1_1.id)),
+                TreeChange.delete((b"b/y/2", F, blob_by2_1.id)),
+                TreeChange.add((b"c", F, blob_c_2.id)),
+            ],
+            tree1,
+            tree2,
+        )
 
 
     def test_tree_changes_name_order(self):
     def test_tree_changes_name_order(self):
-        blob = make_object(Blob, data=b'a')
-        tree1 = self.commit_tree([(b'a', blob), (b'a.', blob), (b'a..', blob)])
+        blob = make_object(Blob, data=b"a")
+        tree1 = self.commit_tree([(b"a", blob), (b"a.", blob), (b"a..", blob)])
         # Tree order is the reverse of this, so if we used tree order, 'a..'
         # Tree order is the reverse of this, so if we used tree order, 'a..'
         # would not be merged.
         # would not be merged.
-        tree2 = self.commit_tree(
-                [(b'a/x', blob), (b'a./x', blob), (b'a..', blob)])
+        tree2 = self.commit_tree([(b"a/x", blob), (b"a./x", blob), (b"a..", blob)])
 
 
         self.assertChangesEqual(
         self.assertChangesEqual(
-            [TreeChange.delete((b'a', F, blob.id)),
-             TreeChange.add((b'a/x', F, blob.id)),
-             TreeChange.delete((b'a.', F, blob.id)),
-             TreeChange.add((b'a./x', F, blob.id))],
-            tree1, tree2)
+            [
+                TreeChange.delete((b"a", F, blob.id)),
+                TreeChange.add((b"a/x", F, blob.id)),
+                TreeChange.delete((b"a.", F, blob.id)),
+                TreeChange.add((b"a./x", F, blob.id)),
+            ],
+            tree1,
+            tree2,
+        )
 
 
     def test_tree_changes_prune(self):
     def test_tree_changes_prune(self):
-        blob_a1 = make_object(Blob, data=b'a1')
-        blob_a2 = make_object(Blob, data=b'a2')
-        blob_x = make_object(Blob, data=b'x')
-        tree1 = self.commit_tree([(b'a', blob_a1), (b'b/x', blob_x)])
-        tree2 = self.commit_tree([(b'a', blob_a2), (b'b/x', blob_x)])
+        blob_a1 = make_object(Blob, data=b"a1")
+        blob_a2 = make_object(Blob, data=b"a2")
+        blob_x = make_object(Blob, data=b"x")
+        tree1 = self.commit_tree([(b"a", blob_a1), (b"b/x", blob_x)])
+        tree2 = self.commit_tree([(b"a", blob_a2), (b"b/x", blob_x)])
         # Remove identical items so lookups will fail unless we prune.
         # Remove identical items so lookups will fail unless we prune.
-        subtree = self.store[tree1[b'b'][1]]
+        subtree = self.store[tree1[b"b"][1]]
         for entry in subtree.items():
         for entry in subtree.items():
             del self.store[entry.sha]
             del self.store[entry.sha]
         del self.store[subtree.id]
         del self.store[subtree.id]
 
 
         self.assertChangesEqual(
         self.assertChangesEqual(
-            [TreeChange(CHANGE_MODIFY, (b'a', F, blob_a1.id),
-                        (b'a', F, blob_a2.id))],
-            tree1, tree2)
+            [TreeChange(CHANGE_MODIFY, (b"a", F, blob_a1.id), (b"a", F, blob_a2.id))],
+            tree1,
+            tree2,
+        )
 
 
     def test_tree_changes_rename_detector(self):
     def test_tree_changes_rename_detector(self):
-        blob_a1 = make_object(Blob, data=b'a\nb\nc\nd\n')
-        blob_a2 = make_object(Blob, data=b'a\nb\nc\ne\n')
-        blob_b = make_object(Blob, data=b'b')
-        tree1 = self.commit_tree([(b'a', blob_a1), (b'b', blob_b)])
-        tree2 = self.commit_tree([(b'c', blob_a2), (b'b', blob_b)])
+        blob_a1 = make_object(Blob, data=b"a\nb\nc\nd\n")
+        blob_a2 = make_object(Blob, data=b"a\nb\nc\ne\n")
+        blob_b = make_object(Blob, data=b"b")
+        tree1 = self.commit_tree([(b"a", blob_a1), (b"b", blob_b)])
+        tree2 = self.commit_tree([(b"c", blob_a2), (b"b", blob_b)])
         detector = RenameDetector(self.store)
         detector = RenameDetector(self.store)
 
 
         self.assertChangesEqual(
         self.assertChangesEqual(
-            [TreeChange.delete((b'a', F, blob_a1.id)),
-             TreeChange.add((b'c', F, blob_a2.id))],
-            tree1, tree2)
+            [
+                TreeChange.delete((b"a", F, blob_a1.id)),
+                TreeChange.add((b"c", F, blob_a2.id)),
+            ],
+            tree1,
+            tree2,
+        )
         self.assertChangesEqual(
         self.assertChangesEqual(
-            [TreeChange.delete((b'a', F, blob_a1.id)),
-             TreeChange(CHANGE_UNCHANGED, (b'b', F, blob_b.id),
-                        (b'b', F, blob_b.id)),
-             TreeChange.add((b'c', F, blob_a2.id))],
-            tree1, tree2, want_unchanged=True)
+            [
+                TreeChange.delete((b"a", F, blob_a1.id)),
+                TreeChange(
+                    CHANGE_UNCHANGED,
+                    (b"b", F, blob_b.id),
+                    (b"b", F, blob_b.id),
+                ),
+                TreeChange.add((b"c", F, blob_a2.id)),
+            ],
+            tree1,
+            tree2,
+            want_unchanged=True,
+        )
         self.assertChangesEqual(
         self.assertChangesEqual(
-            [TreeChange(CHANGE_RENAME, (b'a', F, blob_a1.id),
-                        (b'c', F, blob_a2.id))],
-            tree1, tree2, rename_detector=detector)
+            [TreeChange(CHANGE_RENAME, (b"a", F, blob_a1.id), (b"c", F, blob_a2.id))],
+            tree1,
+            tree2,
+            rename_detector=detector,
+        )
         self.assertChangesEqual(
         self.assertChangesEqual(
-            [TreeChange(CHANGE_RENAME, (b'a', F, blob_a1.id),
-                        (b'c', F, blob_a2.id)),
-             TreeChange(CHANGE_UNCHANGED, (b'b', F, blob_b.id),
-                        (b'b', F, blob_b.id))],
-            tree1, tree2, rename_detector=detector, want_unchanged=True)
-
-    def assertChangesForMergeEqual(self, expected, parent_trees, merge_tree,
-                                   **kwargs):
+            [
+                TreeChange(CHANGE_RENAME, (b"a", F, blob_a1.id), (b"c", F, blob_a2.id)),
+                TreeChange(
+                    CHANGE_UNCHANGED,
+                    (b"b", F, blob_b.id),
+                    (b"b", F, blob_b.id),
+                ),
+            ],
+            tree1,
+            tree2,
+            rename_detector=detector,
+            want_unchanged=True,
+        )
+
+    def assertChangesForMergeEqual(self, expected, parent_trees, merge_tree, **kwargs):
         parent_tree_ids = [t.id for t in parent_trees]
         parent_tree_ids = [t.id for t in parent_trees]
-        actual = list(tree_changes_for_merge(
-          self.store, parent_tree_ids, merge_tree.id, **kwargs))
+        actual = list(
+            tree_changes_for_merge(self.store, parent_tree_ids, merge_tree.id, **kwargs)
+        )
         self.assertEqual(expected, actual)
         self.assertEqual(expected, actual)
 
 
         parent_tree_ids.reverse()
         parent_tree_ids.reverse()
         expected = [list(reversed(cs)) for cs in expected]
         expected = [list(reversed(cs)) for cs in expected]
-        actual = list(tree_changes_for_merge(
-          self.store, parent_tree_ids, merge_tree.id, **kwargs))
+        actual = list(
+            tree_changes_for_merge(self.store, parent_tree_ids, merge_tree.id, **kwargs)
+        )
         self.assertEqual(expected, actual)
         self.assertEqual(expected, actual)
 
 
     def test_tree_changes_for_merge_add_no_conflict(self):
     def test_tree_changes_for_merge_add_no_conflict(self):
-        blob = make_object(Blob, data=b'blob')
+        blob = make_object(Blob, data=b"blob")
         parent1 = self.commit_tree([])
         parent1 = self.commit_tree([])
-        parent2 = merge = self.commit_tree([(b'a', blob)])
+        parent2 = merge = self.commit_tree([(b"a", blob)])
         self.assertChangesForMergeEqual([], [parent1, parent2], merge)
         self.assertChangesForMergeEqual([], [parent1, parent2], merge)
         self.assertChangesForMergeEqual([], [parent2, parent2], merge)
         self.assertChangesForMergeEqual([], [parent2, parent2], merge)
 
 
     def test_tree_changes_for_merge_add_modify_conflict(self):
     def test_tree_changes_for_merge_add_modify_conflict(self):
-        blob1 = make_object(Blob, data=b'1')
-        blob2 = make_object(Blob, data=b'2')
+        blob1 = make_object(Blob, data=b"1")
+        blob2 = make_object(Blob, data=b"2")
         parent1 = self.commit_tree([])
         parent1 = self.commit_tree([])
-        parent2 = self.commit_tree([(b'a', blob1)])
-        merge = self.commit_tree([(b'a', blob2)])
+        parent2 = self.commit_tree([(b"a", blob1)])
+        merge = self.commit_tree([(b"a", blob2)])
         self.assertChangesForMergeEqual(
         self.assertChangesForMergeEqual(
-            [[TreeChange.add((b'a', F, blob2.id)),
-              TreeChange(CHANGE_MODIFY, (b'a', F, blob1.id),
-                         (b'a', F, blob2.id))]],
-            [parent1, parent2], merge)
+            [
+                [
+                    TreeChange.add((b"a", F, blob2.id)),
+                    TreeChange(CHANGE_MODIFY, (b"a", F, blob1.id), (b"a", F, blob2.id)),
+                ]
+            ],
+            [parent1, parent2],
+            merge,
+        )
 
 
     def test_tree_changes_for_merge_modify_modify_conflict(self):
     def test_tree_changes_for_merge_modify_modify_conflict(self):
-        blob1 = make_object(Blob, data=b'1')
-        blob2 = make_object(Blob, data=b'2')
-        blob3 = make_object(Blob, data=b'3')
-        parent1 = self.commit_tree([(b'a', blob1)])
-        parent2 = self.commit_tree([(b'a', blob2)])
-        merge = self.commit_tree([(b'a', blob3)])
+        blob1 = make_object(Blob, data=b"1")
+        blob2 = make_object(Blob, data=b"2")
+        blob3 = make_object(Blob, data=b"3")
+        parent1 = self.commit_tree([(b"a", blob1)])
+        parent2 = self.commit_tree([(b"a", blob2)])
+        merge = self.commit_tree([(b"a", blob3)])
         self.assertChangesForMergeEqual(
         self.assertChangesForMergeEqual(
-            [[TreeChange(CHANGE_MODIFY, (b'a', F, blob1.id),
-                         (b'a', F, blob3.id)),
-              TreeChange(CHANGE_MODIFY, (b'a', F, blob2.id),
-                         (b'a', F, blob3.id))]],
-            [parent1, parent2], merge)
+            [
+                [
+                    TreeChange(CHANGE_MODIFY, (b"a", F, blob1.id), (b"a", F, blob3.id)),
+                    TreeChange(CHANGE_MODIFY, (b"a", F, blob2.id), (b"a", F, blob3.id)),
+                ]
+            ],
+            [parent1, parent2],
+            merge,
+        )
 
 
     def test_tree_changes_for_merge_modify_no_conflict(self):
     def test_tree_changes_for_merge_modify_no_conflict(self):
-        blob1 = make_object(Blob, data=b'1')
-        blob2 = make_object(Blob, data=b'2')
-        parent1 = self.commit_tree([(b'a', blob1)])
-        parent2 = merge = self.commit_tree([(b'a', blob2)])
+        blob1 = make_object(Blob, data=b"1")
+        blob2 = make_object(Blob, data=b"2")
+        parent1 = self.commit_tree([(b"a", blob1)])
+        parent2 = merge = self.commit_tree([(b"a", blob2)])
         self.assertChangesForMergeEqual([], [parent1, parent2], merge)
         self.assertChangesForMergeEqual([], [parent1, parent2], merge)
 
 
     def test_tree_changes_for_merge_delete_delete_conflict(self):
     def test_tree_changes_for_merge_delete_delete_conflict(self):
-        blob1 = make_object(Blob, data=b'1')
-        blob2 = make_object(Blob, data=b'2')
-        parent1 = self.commit_tree([(b'a', blob1)])
-        parent2 = self.commit_tree([(b'a', blob2)])
+        blob1 = make_object(Blob, data=b"1")
+        blob2 = make_object(Blob, data=b"2")
+        parent1 = self.commit_tree([(b"a", blob1)])
+        parent2 = self.commit_tree([(b"a", blob2)])
         merge = self.commit_tree([])
         merge = self.commit_tree([])
         self.assertChangesForMergeEqual(
         self.assertChangesForMergeEqual(
-            [[TreeChange.delete((b'a', F, blob1.id)),
-              TreeChange.delete((b'a', F, blob2.id))]],
-            [parent1, parent2], merge)
+            [
+                [
+                    TreeChange.delete((b"a", F, blob1.id)),
+                    TreeChange.delete((b"a", F, blob2.id)),
+                ]
+            ],
+            [parent1, parent2],
+            merge,
+        )
 
 
     def test_tree_changes_for_merge_delete_no_conflict(self):
     def test_tree_changes_for_merge_delete_no_conflict(self):
-        blob = make_object(Blob, data=b'blob')
-        has = self.commit_tree([(b'a', blob)])
+        blob = make_object(Blob, data=b"blob")
+        has = self.commit_tree([(b"a", blob)])
         doesnt_have = self.commit_tree([])
         doesnt_have = self.commit_tree([])
         self.assertChangesForMergeEqual([], [has, has], doesnt_have)
         self.assertChangesForMergeEqual([], [has, has], doesnt_have)
         self.assertChangesForMergeEqual([], [has, doesnt_have], doesnt_have)
         self.assertChangesForMergeEqual([], [has, doesnt_have], doesnt_have)
@@ -410,7 +511,7 @@ class TreeChangesTest(DiffTestCase):
     def test_tree_changes_for_merge_octopus_no_conflict(self):
     def test_tree_changes_for_merge_octopus_no_conflict(self):
         r = list(range(5))
         r = list(range(5))
         blobs = [make_object(Blob, data=bytes(i)) for i in r]
         blobs = [make_object(Blob, data=bytes(i)) for i in r]
-        parents = [self.commit_tree([(b'a', blobs[i])]) for i in r]
+        parents = [self.commit_tree([(b"a", blobs[i])]) for i in r]
         for i in r:
         for i in r:
             # Take the SHA from each of the parents.
             # Take the SHA from each of the parents.
             self.assertChangesForMergeEqual([], parents, parents[i])
             self.assertChangesForMergeEqual([], parents, parents[i])
@@ -421,134 +522,168 @@ class TreeChangesTest(DiffTestCase):
         # defined, so test it anyway.
         # defined, so test it anyway.
         r = list(range(5))
         r = list(range(5))
         parent_blobs = [make_object(Blob, data=bytes(i)) for i in r]
         parent_blobs = [make_object(Blob, data=bytes(i)) for i in r]
-        merge_blob = make_object(Blob, data=b'merge')
-        parents = [self.commit_tree([(b'a', parent_blobs[i])]) for i in r]
-        merge = self.commit_tree([(b'a', merge_blob)])
-        expected = [[TreeChange(CHANGE_MODIFY, (b'a', F, parent_blobs[i].id),
-                                (b'a', F, merge_blob.id)) for i in r]]
+        merge_blob = make_object(Blob, data=b"merge")
+        parents = [self.commit_tree([(b"a", parent_blobs[i])]) for i in r]
+        merge = self.commit_tree([(b"a", merge_blob)])
+        expected = [
+            [
+                TreeChange(
+                    CHANGE_MODIFY,
+                    (b"a", F, parent_blobs[i].id),
+                    (b"a", F, merge_blob.id),
+                )
+                for i in r
+            ]
+        ]
         self.assertChangesForMergeEqual(expected, parents, merge)
         self.assertChangesForMergeEqual(expected, parents, merge)
 
 
     def test_tree_changes_for_merge_octopus_delete(self):
     def test_tree_changes_for_merge_octopus_delete(self):
-        blob1 = make_object(Blob, data=b'1')
-        blob2 = make_object(Blob, data=b'3')
-        parent1 = self.commit_tree([(b'a', blob1)])
-        parent2 = self.commit_tree([(b'a', blob2)])
+        blob1 = make_object(Blob, data=b"1")
+        blob2 = make_object(Blob, data=b"3")
+        parent1 = self.commit_tree([(b"a", blob1)])
+        parent2 = self.commit_tree([(b"a", blob2)])
         parent3 = merge = self.commit_tree([])
         parent3 = merge = self.commit_tree([])
         self.assertChangesForMergeEqual([], [parent1, parent1, parent1], merge)
         self.assertChangesForMergeEqual([], [parent1, parent1, parent1], merge)
         self.assertChangesForMergeEqual([], [parent1, parent1, parent3], merge)
         self.assertChangesForMergeEqual([], [parent1, parent1, parent3], merge)
         self.assertChangesForMergeEqual([], [parent1, parent3, parent3], merge)
         self.assertChangesForMergeEqual([], [parent1, parent3, parent3], merge)
         self.assertChangesForMergeEqual(
         self.assertChangesForMergeEqual(
-            [[TreeChange.delete((b'a', F, blob1.id)),
-              TreeChange.delete((b'a', F, blob2.id)),
-              None]],
-            [parent1, parent2, parent3], merge)
+            [
+                [
+                    TreeChange.delete((b"a", F, blob1.id)),
+                    TreeChange.delete((b"a", F, blob2.id)),
+                    None,
+                ]
+            ],
+            [parent1, parent2, parent3],
+            merge,
+        )
 
 
     def test_tree_changes_for_merge_add_add_same_conflict(self):
     def test_tree_changes_for_merge_add_add_same_conflict(self):
-        blob = make_object(Blob, data=b'a\nb\nc\nd\n')
-        parent1 = self.commit_tree([(b'a', blob)])
+        blob = make_object(Blob, data=b"a\nb\nc\nd\n")
+        parent1 = self.commit_tree([(b"a", blob)])
         parent2 = self.commit_tree([])
         parent2 = self.commit_tree([])
-        merge = self.commit_tree([(b'b', blob)])
-        add = TreeChange.add((b'b', F, blob.id))
-        self.assertChangesForMergeEqual(
-                [[add, add]], [parent1, parent2], merge)
+        merge = self.commit_tree([(b"b", blob)])
+        add = TreeChange.add((b"b", F, blob.id))
+        self.assertChangesForMergeEqual([[add, add]], [parent1, parent2], merge)
 
 
     def test_tree_changes_for_merge_add_exact_rename_conflict(self):
     def test_tree_changes_for_merge_add_exact_rename_conflict(self):
-        blob = make_object(Blob, data=b'a\nb\nc\nd\n')
-        parent1 = self.commit_tree([(b'a', blob)])
+        blob = make_object(Blob, data=b"a\nb\nc\nd\n")
+        parent1 = self.commit_tree([(b"a", blob)])
         parent2 = self.commit_tree([])
         parent2 = self.commit_tree([])
-        merge = self.commit_tree([(b'b', blob)])
+        merge = self.commit_tree([(b"b", blob)])
         self.assertChangesForMergeEqual(
         self.assertChangesForMergeEqual(
-            [[TreeChange(CHANGE_RENAME, (b'a', F, blob.id),
-                         (b'b', F, blob.id)),
-              TreeChange.add((b'b', F, blob.id))]],
-            [parent1, parent2], merge, rename_detector=self.detector)
+            [
+                [
+                    TreeChange(CHANGE_RENAME, (b"a", F, blob.id), (b"b", F, blob.id)),
+                    TreeChange.add((b"b", F, blob.id)),
+                ]
+            ],
+            [parent1, parent2],
+            merge,
+            rename_detector=self.detector,
+        )
 
 
     def test_tree_changes_for_merge_add_content_rename_conflict(self):
     def test_tree_changes_for_merge_add_content_rename_conflict(self):
-        blob1 = make_object(Blob, data=b'a\nb\nc\nd\n')
-        blob2 = make_object(Blob, data=b'a\nb\nc\ne\n')
-        parent1 = self.commit_tree([(b'a', blob1)])
+        blob1 = make_object(Blob, data=b"a\nb\nc\nd\n")
+        blob2 = make_object(Blob, data=b"a\nb\nc\ne\n")
+        parent1 = self.commit_tree([(b"a", blob1)])
         parent2 = self.commit_tree([])
         parent2 = self.commit_tree([])
-        merge = self.commit_tree([(b'b', blob2)])
+        merge = self.commit_tree([(b"b", blob2)])
         self.assertChangesForMergeEqual(
         self.assertChangesForMergeEqual(
-            [[TreeChange(CHANGE_RENAME, (b'a', F, blob1.id),
-                         (b'b', F, blob2.id)),
-              TreeChange.add((b'b', F, blob2.id))]],
-            [parent1, parent2], merge, rename_detector=self.detector)
+            [
+                [
+                    TreeChange(CHANGE_RENAME, (b"a", F, blob1.id), (b"b", F, blob2.id)),
+                    TreeChange.add((b"b", F, blob2.id)),
+                ]
+            ],
+            [parent1, parent2],
+            merge,
+            rename_detector=self.detector,
+        )
 
 
     def test_tree_changes_for_merge_modify_rename_conflict(self):
     def test_tree_changes_for_merge_modify_rename_conflict(self):
-        blob1 = make_object(Blob, data=b'a\nb\nc\nd\n')
-        blob2 = make_object(Blob, data=b'a\nb\nc\ne\n')
-        parent1 = self.commit_tree([(b'a', blob1)])
-        parent2 = self.commit_tree([(b'b', blob1)])
-        merge = self.commit_tree([(b'b', blob2)])
+        blob1 = make_object(Blob, data=b"a\nb\nc\nd\n")
+        blob2 = make_object(Blob, data=b"a\nb\nc\ne\n")
+        parent1 = self.commit_tree([(b"a", blob1)])
+        parent2 = self.commit_tree([(b"b", blob1)])
+        merge = self.commit_tree([(b"b", blob2)])
         self.assertChangesForMergeEqual(
         self.assertChangesForMergeEqual(
-            [[TreeChange(CHANGE_RENAME, (b'a', F, blob1.id),
-                         (b'b', F, blob2.id)),
-              TreeChange(CHANGE_MODIFY, (b'b', F, blob1.id),
-                         (b'b', F, blob2.id))]],
-            [parent1, parent2], merge, rename_detector=self.detector)
+            [
+                [
+                    TreeChange(CHANGE_RENAME, (b"a", F, blob1.id), (b"b", F, blob2.id)),
+                    TreeChange(CHANGE_MODIFY, (b"b", F, blob1.id), (b"b", F, blob2.id)),
+                ]
+            ],
+            [parent1, parent2],
+            merge,
+            rename_detector=self.detector,
+        )
 
 
 
 
 class RenameDetectionTest(DiffTestCase):
 class RenameDetectionTest(DiffTestCase):
-
     def _do_test_count_blocks(self, count_blocks):
     def _do_test_count_blocks(self, count_blocks):
-        blob = make_object(Blob, data=b'a\nb\na\n')
-        self.assertBlockCountEqual({b'a\n': 4, b'b\n': 2}, count_blocks(blob))
+        blob = make_object(Blob, data=b"a\nb\na\n")
+        self.assertBlockCountEqual({b"a\n": 4, b"b\n": 2}, count_blocks(blob))
 
 
-    test_count_blocks = functest_builder(_do_test_count_blocks,
-                                         _count_blocks_py)
-    test_count_blocks_extension = ext_functest_builder(_do_test_count_blocks,
-                                                       _count_blocks)
+    test_count_blocks = functest_builder(_do_test_count_blocks, _count_blocks_py)
+    test_count_blocks_extension = ext_functest_builder(
+        _do_test_count_blocks, _count_blocks
+    )
 
 
     def _do_test_count_blocks_no_newline(self, count_blocks):
     def _do_test_count_blocks_no_newline(self, count_blocks):
-        blob = make_object(Blob, data=b'a\na')
-        self.assertBlockCountEqual({b'a\n': 2, b'a': 1}, _count_blocks(blob))
+        blob = make_object(Blob, data=b"a\na")
+        self.assertBlockCountEqual({b"a\n": 2, b"a": 1}, _count_blocks(blob))
 
 
     test_count_blocks_no_newline = functest_builder(
     test_count_blocks_no_newline = functest_builder(
-        _do_test_count_blocks_no_newline, _count_blocks_py)
+        _do_test_count_blocks_no_newline, _count_blocks_py
+    )
     test_count_blocks_no_newline_extension = ext_functest_builder(
     test_count_blocks_no_newline_extension = ext_functest_builder(
-        _do_test_count_blocks_no_newline, _count_blocks)
+        _do_test_count_blocks_no_newline, _count_blocks
+    )
 
 
     def assertBlockCountEqual(self, expected, got):
     def assertBlockCountEqual(self, expected, got):
         self.assertEqual(
         self.assertEqual(
-            {(hash(l) & 0xffffffff): c for (l, c) in expected.items()},
-            {(h & 0xffffffff): c for (h, c) in got.items()})
+            {(hash(l) & 0xFFFFFFFF): c for (l, c) in expected.items()},
+            {(h & 0xFFFFFFFF): c for (h, c) in got.items()},
+        )
 
 
     def _do_test_count_blocks_chunks(self, count_blocks):
     def _do_test_count_blocks_chunks(self, count_blocks):
-        blob = ShaFile.from_raw_chunks(Blob.type_num, [b'a\nb', b'\na\n'])
-        self.assertBlockCountEqual({b'a\n': 4, b'b\n': 2}, _count_blocks(blob))
+        blob = ShaFile.from_raw_chunks(Blob.type_num, [b"a\nb", b"\na\n"])
+        self.assertBlockCountEqual({b"a\n": 4, b"b\n": 2}, _count_blocks(blob))
 
 
-    test_count_blocks_chunks = functest_builder(_do_test_count_blocks_chunks,
-                                                _count_blocks_py)
+    test_count_blocks_chunks = functest_builder(
+        _do_test_count_blocks_chunks, _count_blocks_py
+    )
     test_count_blocks_chunks_extension = ext_functest_builder(
     test_count_blocks_chunks_extension = ext_functest_builder(
-        _do_test_count_blocks_chunks, _count_blocks)
+        _do_test_count_blocks_chunks, _count_blocks
+    )
 
 
     def _do_test_count_blocks_long_lines(self, count_blocks):
     def _do_test_count_blocks_long_lines(self, count_blocks):
-        a = b'a' * 64
-        data = a + b'xxx\ny\n' + a + b'zzz\n'
+        a = b"a" * 64
+        data = a + b"xxx\ny\n" + a + b"zzz\n"
         blob = make_object(Blob, data=data)
         blob = make_object(Blob, data=data)
         self.assertBlockCountEqual(
         self.assertBlockCountEqual(
-            {b'a' * 64: 128,
-             b'xxx\n': 4,
-             b'y\n': 2,
-             b'zzz\n': 4},
-            _count_blocks(blob))
+            {b"a" * 64: 128, b"xxx\n": 4, b"y\n": 2, b"zzz\n": 4},
+            _count_blocks(blob),
+        )
 
 
     test_count_blocks_long_lines = functest_builder(
     test_count_blocks_long_lines = functest_builder(
-        _do_test_count_blocks_long_lines, _count_blocks_py)
+        _do_test_count_blocks_long_lines, _count_blocks_py
+    )
     test_count_blocks_long_lines_extension = ext_functest_builder(
     test_count_blocks_long_lines_extension = ext_functest_builder(
-        _do_test_count_blocks_long_lines, _count_blocks)
+        _do_test_count_blocks_long_lines, _count_blocks
+    )
 
 
     def assertSimilar(self, expected_score, blob1, blob2):
     def assertSimilar(self, expected_score, blob1, blob2):
         self.assertEqual(expected_score, _similarity_score(blob1, blob2))
         self.assertEqual(expected_score, _similarity_score(blob1, blob2))
         self.assertEqual(expected_score, _similarity_score(blob2, blob1))
         self.assertEqual(expected_score, _similarity_score(blob2, blob1))
 
 
     def test_similarity_score(self):
     def test_similarity_score(self):
-        blob0 = make_object(Blob, data=b'')
-        blob1 = make_object(Blob, data=b'ab\ncd\ncd\n')
-        blob2 = make_object(Blob, data=b'ab\n')
-        blob3 = make_object(Blob, data=b'cd\n')
-        blob4 = make_object(Blob, data=b'cd\ncd\n')
+        blob0 = make_object(Blob, data=b"")
+        blob1 = make_object(Blob, data=b"ab\ncd\ncd\n")
+        blob2 = make_object(Blob, data=b"ab\n")
+        blob3 = make_object(Blob, data=b"cd\n")
+        blob4 = make_object(Blob, data=b"cd\ncd\n")
 
 
         self.assertSimilar(100, blob0, blob0)
         self.assertSimilar(100, blob0, blob0)
         self.assertSimilar(0, blob0, blob1)
         self.assertSimilar(0, blob0, blob1)
@@ -559,396 +694,464 @@ class RenameDetectionTest(DiffTestCase):
         self.assertSimilar(50, blob3, blob4)
         self.assertSimilar(50, blob3, blob4)
 
 
     def test_similarity_score_cache(self):
     def test_similarity_score_cache(self):
-        blob1 = make_object(Blob, data=b'ab\ncd\n')
-        blob2 = make_object(Blob, data=b'ab\n')
+        blob1 = make_object(Blob, data=b"ab\ncd\n")
+        blob2 = make_object(Blob, data=b"ab\n")
 
 
         block_cache = {}
         block_cache = {}
-        self.assertEqual(
-            50, _similarity_score(blob1, blob2, block_cache=block_cache))
+        self.assertEqual(50, _similarity_score(blob1, blob2, block_cache=block_cache))
         self.assertEqual(set([blob1.id, blob2.id]), set(block_cache))
         self.assertEqual(set([blob1.id, blob2.id]), set(block_cache))
 
 
         def fail_chunks():
         def fail_chunks():
-            self.fail('Unexpected call to as_raw_chunks()')
+            self.fail("Unexpected call to as_raw_chunks()")
 
 
         blob1.as_raw_chunks = blob2.as_raw_chunks = fail_chunks
         blob1.as_raw_chunks = blob2.as_raw_chunks = fail_chunks
         blob1.raw_length = lambda: 6
         blob1.raw_length = lambda: 6
         blob2.raw_length = lambda: 3
         blob2.raw_length = lambda: 3
-        self.assertEqual(
-            50, _similarity_score(blob1, blob2, block_cache=block_cache))
+        self.assertEqual(50, _similarity_score(blob1, blob2, block_cache=block_cache))
 
 
     def test_tree_entry_sort(self):
     def test_tree_entry_sort(self):
-        sha = 'abcd' * 10
+        sha = "abcd" * 10
         expected_entries = [
         expected_entries = [
-            TreeChange.add(TreeEntry(b'aaa', F, sha)),
-            TreeChange(CHANGE_COPY, TreeEntry(b'bbb', F, sha),
-                       TreeEntry(b'aab', F, sha)),
-            TreeChange(CHANGE_MODIFY, TreeEntry(b'bbb', F, sha),
-                       TreeEntry(b'bbb', F, b'dabc' * 10)),
-            TreeChange(CHANGE_RENAME, TreeEntry(b'bbc', F, sha),
-                       TreeEntry(b'ddd', F, sha)),
-            TreeChange.delete(TreeEntry(b'ccc', F, sha)),
+            TreeChange.add(TreeEntry(b"aaa", F, sha)),
+            TreeChange(
+                CHANGE_COPY,
+                TreeEntry(b"bbb", F, sha),
+                TreeEntry(b"aab", F, sha),
+            ),
+            TreeChange(
+                CHANGE_MODIFY,
+                TreeEntry(b"bbb", F, sha),
+                TreeEntry(b"bbb", F, b"dabc" * 10),
+            ),
+            TreeChange(
+                CHANGE_RENAME,
+                TreeEntry(b"bbc", F, sha),
+                TreeEntry(b"ddd", F, sha),
+            ),
+            TreeChange.delete(TreeEntry(b"ccc", F, sha)),
         ]
         ]
 
 
         for perm in permutations(expected_entries):
         for perm in permutations(expected_entries):
-            self.assertEqual(expected_entries,
-                             sorted(perm, key=_tree_change_key))
+            self.assertEqual(expected_entries, sorted(perm, key=_tree_change_key))
 
 
     def detect_renames(self, tree1, tree2, want_unchanged=False, **kwargs):
     def detect_renames(self, tree1, tree2, want_unchanged=False, **kwargs):
         detector = RenameDetector(self.store, **kwargs)
         detector = RenameDetector(self.store, **kwargs)
-        return detector.changes_with_renames(tree1.id, tree2.id,
-                                             want_unchanged=want_unchanged)
+        return detector.changes_with_renames(
+            tree1.id, tree2.id, want_unchanged=want_unchanged
+        )
 
 
     def test_no_renames(self):
     def test_no_renames(self):
-        blob1 = make_object(Blob, data=b'a\nb\nc\nd\n')
-        blob2 = make_object(Blob, data=b'a\nb\ne\nf\n')
-        blob3 = make_object(Blob, data=b'a\nb\ng\nh\n')
-        tree1 = self.commit_tree([(b'a', blob1), (b'b', blob2)])
-        tree2 = self.commit_tree([(b'a', blob1), (b'b', blob3)])
+        blob1 = make_object(Blob, data=b"a\nb\nc\nd\n")
+        blob2 = make_object(Blob, data=b"a\nb\ne\nf\n")
+        blob3 = make_object(Blob, data=b"a\nb\ng\nh\n")
+        tree1 = self.commit_tree([(b"a", blob1), (b"b", blob2)])
+        tree2 = self.commit_tree([(b"a", blob1), (b"b", blob3)])
         self.assertEqual(
         self.assertEqual(
-            [TreeChange(CHANGE_MODIFY, (b'b', F, blob2.id),
-                        (b'b', F, blob3.id))],
-            self.detect_renames(tree1, tree2))
+            [TreeChange(CHANGE_MODIFY, (b"b", F, blob2.id), (b"b", F, blob3.id))],
+            self.detect_renames(tree1, tree2),
+        )
 
 
     def test_exact_rename_one_to_one(self):
     def test_exact_rename_one_to_one(self):
-        blob1 = make_object(Blob, data=b'1')
-        blob2 = make_object(Blob, data=b'2')
-        tree1 = self.commit_tree([(b'a', blob1), (b'b', blob2)])
-        tree2 = self.commit_tree([(b'c', blob1), (b'd', blob2)])
+        blob1 = make_object(Blob, data=b"1")
+        blob2 = make_object(Blob, data=b"2")
+        tree1 = self.commit_tree([(b"a", blob1), (b"b", blob2)])
+        tree2 = self.commit_tree([(b"c", blob1), (b"d", blob2)])
         self.assertEqual(
         self.assertEqual(
-                [TreeChange(CHANGE_RENAME, (b'a', F, blob1.id),
-                            (b'c', F, blob1.id)),
-                 TreeChange(CHANGE_RENAME, (b'b', F, blob2.id),
-                            (b'd', F, blob2.id))],
-                self.detect_renames(tree1, tree2))
+            [
+                TreeChange(CHANGE_RENAME, (b"a", F, blob1.id), (b"c", F, blob1.id)),
+                TreeChange(CHANGE_RENAME, (b"b", F, blob2.id), (b"d", F, blob2.id)),
+            ],
+            self.detect_renames(tree1, tree2),
+        )
 
 
     def test_exact_rename_split_different_type(self):
     def test_exact_rename_split_different_type(self):
-        blob = make_object(Blob, data=b'/foo')
-        tree1 = self.commit_tree([(b'a', blob, 0o100644)])
-        tree2 = self.commit_tree([(b'a', blob, 0o120000)])
+        blob = make_object(Blob, data=b"/foo")
+        tree1 = self.commit_tree([(b"a", blob, 0o100644)])
+        tree2 = self.commit_tree([(b"a", blob, 0o120000)])
         self.assertEqual(
         self.assertEqual(
-            [TreeChange.add((b'a', 0o120000, blob.id)),
-             TreeChange.delete((b'a', 0o100644, blob.id))],
-            self.detect_renames(tree1, tree2))
+            [
+                TreeChange.add((b"a", 0o120000, blob.id)),
+                TreeChange.delete((b"a", 0o100644, blob.id)),
+            ],
+            self.detect_renames(tree1, tree2),
+        )
 
 
     def test_exact_rename_and_different_type(self):
     def test_exact_rename_and_different_type(self):
-        blob1 = make_object(Blob, data=b'1')
-        blob2 = make_object(Blob, data=b'2')
-        tree1 = self.commit_tree([(b'a', blob1)])
-        tree2 = self.commit_tree([(b'a', blob2, 0o120000), (b'b', blob1)])
+        blob1 = make_object(Blob, data=b"1")
+        blob2 = make_object(Blob, data=b"2")
+        tree1 = self.commit_tree([(b"a", blob1)])
+        tree2 = self.commit_tree([(b"a", blob2, 0o120000), (b"b", blob1)])
         self.assertEqual(
         self.assertEqual(
-                [TreeChange.add((b'a', 0o120000, blob2.id)),
-                 TreeChange(CHANGE_RENAME, (b'a', F, blob1.id),
-                            (b'b', F, blob1.id))],
-                self.detect_renames(tree1, tree2))
+            [
+                TreeChange.add((b"a", 0o120000, blob2.id)),
+                TreeChange(CHANGE_RENAME, (b"a", F, blob1.id), (b"b", F, blob1.id)),
+            ],
+            self.detect_renames(tree1, tree2),
+        )
 
 
     def test_exact_rename_one_to_many(self):
     def test_exact_rename_one_to_many(self):
-        blob = make_object(Blob, data=b'1')
-        tree1 = self.commit_tree([(b'a', blob)])
-        tree2 = self.commit_tree([(b'b', blob), (b'c', blob)])
+        blob = make_object(Blob, data=b"1")
+        tree1 = self.commit_tree([(b"a", blob)])
+        tree2 = self.commit_tree([(b"b", blob), (b"c", blob)])
         self.assertEqual(
         self.assertEqual(
-            [TreeChange(CHANGE_RENAME, (b'a', F, blob.id), (b'b', F, blob.id)),
-             TreeChange(CHANGE_COPY, (b'a', F, blob.id), (b'c', F, blob.id))],
-            self.detect_renames(tree1, tree2))
+            [
+                TreeChange(CHANGE_RENAME, (b"a", F, blob.id), (b"b", F, blob.id)),
+                TreeChange(CHANGE_COPY, (b"a", F, blob.id), (b"c", F, blob.id)),
+            ],
+            self.detect_renames(tree1, tree2),
+        )
 
 
     def test_exact_rename_many_to_one(self):
     def test_exact_rename_many_to_one(self):
-        blob = make_object(Blob, data=b'1')
-        tree1 = self.commit_tree([(b'a', blob), (b'b', blob)])
-        tree2 = self.commit_tree([(b'c', blob)])
+        blob = make_object(Blob, data=b"1")
+        tree1 = self.commit_tree([(b"a", blob), (b"b", blob)])
+        tree2 = self.commit_tree([(b"c", blob)])
         self.assertEqual(
         self.assertEqual(
-            [TreeChange(CHANGE_RENAME, (b'a', F, blob.id), (b'c', F, blob.id)),
-             TreeChange.delete((b'b', F, blob.id))],
-            self.detect_renames(tree1, tree2))
+            [
+                TreeChange(CHANGE_RENAME, (b"a", F, blob.id), (b"c", F, blob.id)),
+                TreeChange.delete((b"b", F, blob.id)),
+            ],
+            self.detect_renames(tree1, tree2),
+        )
 
 
     def test_exact_rename_many_to_many(self):
     def test_exact_rename_many_to_many(self):
-        blob = make_object(Blob, data=b'1')
-        tree1 = self.commit_tree([(b'a', blob), (b'b', blob)])
-        tree2 = self.commit_tree([(b'c', blob), (b'd', blob), (b'e', blob)])
-        self.assertEqual(
-                [TreeChange(CHANGE_RENAME, (b'a', F, blob.id),
-                            (b'c', F, blob.id)),
-                 TreeChange(CHANGE_COPY, (b'a', F, blob.id),
-                            (b'e', F, blob.id)),
-                 TreeChange(CHANGE_RENAME, (b'b', F, blob.id),
-                            (b'd', F, blob.id))],
-                self.detect_renames(tree1, tree2))
+        blob = make_object(Blob, data=b"1")
+        tree1 = self.commit_tree([(b"a", blob), (b"b", blob)])
+        tree2 = self.commit_tree([(b"c", blob), (b"d", blob), (b"e", blob)])
+        self.assertEqual(
+            [
+                TreeChange(CHANGE_RENAME, (b"a", F, blob.id), (b"c", F, blob.id)),
+                TreeChange(CHANGE_COPY, (b"a", F, blob.id), (b"e", F, blob.id)),
+                TreeChange(CHANGE_RENAME, (b"b", F, blob.id), (b"d", F, blob.id)),
+            ],
+            self.detect_renames(tree1, tree2),
+        )
 
 
     def test_exact_copy_modify(self):
     def test_exact_copy_modify(self):
-        blob1 = make_object(Blob, data=b'a\nb\nc\nd\n')
-        blob2 = make_object(Blob, data=b'a\nb\nc\ne\n')
-        tree1 = self.commit_tree([(b'a', blob1)])
-        tree2 = self.commit_tree([(b'a', blob2), (b'b', blob1)])
+        blob1 = make_object(Blob, data=b"a\nb\nc\nd\n")
+        blob2 = make_object(Blob, data=b"a\nb\nc\ne\n")
+        tree1 = self.commit_tree([(b"a", blob1)])
+        tree2 = self.commit_tree([(b"a", blob2), (b"b", blob1)])
         self.assertEqual(
         self.assertEqual(
-            [TreeChange(CHANGE_MODIFY, (b'a', F, blob1.id),
-                        (b'a', F, blob2.id)),
-             TreeChange(CHANGE_COPY, (b'a', F, blob1.id),
-                        (b'b', F, blob1.id))],
-            self.detect_renames(tree1, tree2))
+            [
+                TreeChange(CHANGE_MODIFY, (b"a", F, blob1.id), (b"a", F, blob2.id)),
+                TreeChange(CHANGE_COPY, (b"a", F, blob1.id), (b"b", F, blob1.id)),
+            ],
+            self.detect_renames(tree1, tree2),
+        )
 
 
     def test_exact_copy_change_mode(self):
     def test_exact_copy_change_mode(self):
-        blob = make_object(Blob, data=b'a\nb\nc\nd\n')
-        tree1 = self.commit_tree([(b'a', blob)])
-        tree2 = self.commit_tree([(b'a', blob, 0o100755), (b'b', blob)])
+        blob = make_object(Blob, data=b"a\nb\nc\nd\n")
+        tree1 = self.commit_tree([(b"a", blob)])
+        tree2 = self.commit_tree([(b"a", blob, 0o100755), (b"b", blob)])
         self.assertEqual(
         self.assertEqual(
-            [TreeChange(CHANGE_MODIFY, (b'a', F, blob.id),
-                        (b'a', 0o100755, blob.id)),
-             TreeChange(CHANGE_COPY, (b'a', F, blob.id), (b'b', F, blob.id))],
-            self.detect_renames(tree1, tree2))
+            [
+                TreeChange(
+                    CHANGE_MODIFY,
+                    (b"a", F, blob.id),
+                    (b"a", 0o100755, blob.id),
+                ),
+                TreeChange(CHANGE_COPY, (b"a", F, blob.id), (b"b", F, blob.id)),
+            ],
+            self.detect_renames(tree1, tree2),
+        )
 
 
     def test_rename_threshold(self):
     def test_rename_threshold(self):
-        blob1 = make_object(Blob, data=b'a\nb\nc\n')
-        blob2 = make_object(Blob, data=b'a\nb\nd\n')
-        tree1 = self.commit_tree([(b'a', blob1)])
-        tree2 = self.commit_tree([(b'b', blob2)])
+        blob1 = make_object(Blob, data=b"a\nb\nc\n")
+        blob2 = make_object(Blob, data=b"a\nb\nd\n")
+        tree1 = self.commit_tree([(b"a", blob1)])
+        tree2 = self.commit_tree([(b"b", blob2)])
         self.assertEqual(
         self.assertEqual(
-            [TreeChange(CHANGE_RENAME, (b'a', F, blob1.id),
-                        (b'b', F, blob2.id))],
-            self.detect_renames(tree1, tree2, rename_threshold=50))
+            [TreeChange(CHANGE_RENAME, (b"a", F, blob1.id), (b"b", F, blob2.id))],
+            self.detect_renames(tree1, tree2, rename_threshold=50),
+        )
         self.assertEqual(
         self.assertEqual(
-            [TreeChange.delete((b'a', F, blob1.id)),
-             TreeChange.add((b'b', F, blob2.id))],
-            self.detect_renames(tree1, tree2, rename_threshold=75))
+            [
+                TreeChange.delete((b"a", F, blob1.id)),
+                TreeChange.add((b"b", F, blob2.id)),
+            ],
+            self.detect_renames(tree1, tree2, rename_threshold=75),
+        )
 
 
     def test_content_rename_max_files(self):
     def test_content_rename_max_files(self):
-        blob1 = make_object(Blob, data=b'a\nb\nc\nd')
-        blob4 = make_object(Blob, data=b'a\nb\nc\ne\n')
-        blob2 = make_object(Blob, data=b'e\nf\ng\nh\n')
-        blob3 = make_object(Blob, data=b'e\nf\ng\ni\n')
-        tree1 = self.commit_tree([(b'a', blob1), (b'b', blob2)])
-        tree2 = self.commit_tree([(b'c', blob3), (b'd', blob4)])
-        self.assertEqual(
-            [TreeChange(CHANGE_RENAME, (b'a', F, blob1.id),
-                        (b'd', F, blob4.id)),
-             TreeChange(CHANGE_RENAME, (b'b', F, blob2.id),
-                        (b'c', F, blob3.id))],
-            self.detect_renames(tree1, tree2))
-        self.assertEqual(
-            [TreeChange.delete((b'a', F, blob1.id)),
-             TreeChange.delete((b'b', F, blob2.id)),
-             TreeChange.add((b'c', F, blob3.id)),
-             TreeChange.add((b'd', F, blob4.id))],
-            self.detect_renames(tree1, tree2, max_files=1))
+        blob1 = make_object(Blob, data=b"a\nb\nc\nd")
+        blob4 = make_object(Blob, data=b"a\nb\nc\ne\n")
+        blob2 = make_object(Blob, data=b"e\nf\ng\nh\n")
+        blob3 = make_object(Blob, data=b"e\nf\ng\ni\n")
+        tree1 = self.commit_tree([(b"a", blob1), (b"b", blob2)])
+        tree2 = self.commit_tree([(b"c", blob3), (b"d", blob4)])
+        self.assertEqual(
+            [
+                TreeChange(CHANGE_RENAME, (b"a", F, blob1.id), (b"d", F, blob4.id)),
+                TreeChange(CHANGE_RENAME, (b"b", F, blob2.id), (b"c", F, blob3.id)),
+            ],
+            self.detect_renames(tree1, tree2),
+        )
+        self.assertEqual(
+            [
+                TreeChange.delete((b"a", F, blob1.id)),
+                TreeChange.delete((b"b", F, blob2.id)),
+                TreeChange.add((b"c", F, blob3.id)),
+                TreeChange.add((b"d", F, blob4.id)),
+            ],
+            self.detect_renames(tree1, tree2, max_files=1),
+        )
 
 
     def test_content_rename_one_to_one(self):
     def test_content_rename_one_to_one(self):
-        b11 = make_object(Blob, data=b'a\nb\nc\nd\n')
-        b12 = make_object(Blob, data=b'a\nb\nc\ne\n')
-        b21 = make_object(Blob, data=b'e\nf\ng\n\nh')
-        b22 = make_object(Blob, data=b'e\nf\ng\n\ni')
-        tree1 = self.commit_tree([(b'a', b11), (b'b', b21)])
-        tree2 = self.commit_tree([(b'c', b12), (b'd', b22)])
+        b11 = make_object(Blob, data=b"a\nb\nc\nd\n")
+        b12 = make_object(Blob, data=b"a\nb\nc\ne\n")
+        b21 = make_object(Blob, data=b"e\nf\ng\n\nh")
+        b22 = make_object(Blob, data=b"e\nf\ng\n\ni")
+        tree1 = self.commit_tree([(b"a", b11), (b"b", b21)])
+        tree2 = self.commit_tree([(b"c", b12), (b"d", b22)])
         self.assertEqual(
         self.assertEqual(
-            [TreeChange(CHANGE_RENAME, (b'a', F, b11.id), (b'c', F, b12.id)),
-             TreeChange(CHANGE_RENAME, (b'b', F, b21.id), (b'd', F, b22.id))],
-            self.detect_renames(tree1, tree2))
+            [
+                TreeChange(CHANGE_RENAME, (b"a", F, b11.id), (b"c", F, b12.id)),
+                TreeChange(CHANGE_RENAME, (b"b", F, b21.id), (b"d", F, b22.id)),
+            ],
+            self.detect_renames(tree1, tree2),
+        )
 
 
     def test_content_rename_one_to_one_ordering(self):
     def test_content_rename_one_to_one_ordering(self):
-        blob1 = make_object(Blob, data=b'a\nb\nc\nd\ne\nf\n')
-        blob2 = make_object(Blob, data=b'a\nb\nc\nd\ng\nh\n')
+        blob1 = make_object(Blob, data=b"a\nb\nc\nd\ne\nf\n")
+        blob2 = make_object(Blob, data=b"a\nb\nc\nd\ng\nh\n")
         # 6/10 match to blob1, 8/10 match to blob2
         # 6/10 match to blob1, 8/10 match to blob2
-        blob3 = make_object(Blob, data=b'a\nb\nc\nd\ng\ni\n')
-        tree1 = self.commit_tree([(b'a', blob1), (b'b', blob2)])
-        tree2 = self.commit_tree([(b'c', blob3)])
+        blob3 = make_object(Blob, data=b"a\nb\nc\nd\ng\ni\n")
+        tree1 = self.commit_tree([(b"a", blob1), (b"b", blob2)])
+        tree2 = self.commit_tree([(b"c", blob3)])
         self.assertEqual(
         self.assertEqual(
-            [TreeChange.delete((b'a', F, blob1.id)),
-             TreeChange(CHANGE_RENAME, (b'b', F, blob2.id),
-                        (b'c', F, blob3.id))],
-            self.detect_renames(tree1, tree2))
-
-        tree3 = self.commit_tree([(b'a', blob2), (b'b', blob1)])
-        tree4 = self.commit_tree([(b'c', blob3)])
+            [
+                TreeChange.delete((b"a", F, blob1.id)),
+                TreeChange(CHANGE_RENAME, (b"b", F, blob2.id), (b"c", F, blob3.id)),
+            ],
+            self.detect_renames(tree1, tree2),
+        )
+
+        tree3 = self.commit_tree([(b"a", blob2), (b"b", blob1)])
+        tree4 = self.commit_tree([(b"c", blob3)])
         self.assertEqual(
         self.assertEqual(
-            [TreeChange(CHANGE_RENAME, (b'a', F, blob2.id),
-                        (b'c', F, blob3.id)),
-             TreeChange.delete((b'b', F, blob1.id))],
-            self.detect_renames(tree3, tree4))
+            [
+                TreeChange(CHANGE_RENAME, (b"a", F, blob2.id), (b"c", F, blob3.id)),
+                TreeChange.delete((b"b", F, blob1.id)),
+            ],
+            self.detect_renames(tree3, tree4),
+        )
 
 
     def test_content_rename_one_to_many(self):
     def test_content_rename_one_to_many(self):
-        blob1 = make_object(Blob, data=b'aa\nb\nc\nd\ne\n')
-        blob2 = make_object(Blob, data=b'ab\nb\nc\nd\ne\n')  # 8/11 match
-        blob3 = make_object(Blob, data=b'aa\nb\nc\nd\nf\n')  # 9/11 match
-        tree1 = self.commit_tree([(b'a', blob1)])
-        tree2 = self.commit_tree([(b'b', blob2), (b'c', blob3)])
+        blob1 = make_object(Blob, data=b"aa\nb\nc\nd\ne\n")
+        blob2 = make_object(Blob, data=b"ab\nb\nc\nd\ne\n")  # 8/11 match
+        blob3 = make_object(Blob, data=b"aa\nb\nc\nd\nf\n")  # 9/11 match
+        tree1 = self.commit_tree([(b"a", blob1)])
+        tree2 = self.commit_tree([(b"b", blob2), (b"c", blob3)])
         self.assertEqual(
         self.assertEqual(
-            [TreeChange(CHANGE_COPY, (b'a', F, blob1.id), (b'b', F, blob2.id)),
-             TreeChange(CHANGE_RENAME, (b'a', F, blob1.id),
-                        (b'c', F, blob3.id))],
-            self.detect_renames(tree1, tree2))
+            [
+                TreeChange(CHANGE_COPY, (b"a", F, blob1.id), (b"b", F, blob2.id)),
+                TreeChange(CHANGE_RENAME, (b"a", F, blob1.id), (b"c", F, blob3.id)),
+            ],
+            self.detect_renames(tree1, tree2),
+        )
 
 
     def test_content_rename_many_to_one(self):
     def test_content_rename_many_to_one(self):
-        blob1 = make_object(Blob, data=b'a\nb\nc\nd\n')
-        blob2 = make_object(Blob, data=b'a\nb\nc\ne\n')
-        blob3 = make_object(Blob, data=b'a\nb\nc\nf\n')
-        tree1 = self.commit_tree([(b'a', blob1), (b'b', blob2)])
-        tree2 = self.commit_tree([(b'c', blob3)])
+        blob1 = make_object(Blob, data=b"a\nb\nc\nd\n")
+        blob2 = make_object(Blob, data=b"a\nb\nc\ne\n")
+        blob3 = make_object(Blob, data=b"a\nb\nc\nf\n")
+        tree1 = self.commit_tree([(b"a", blob1), (b"b", blob2)])
+        tree2 = self.commit_tree([(b"c", blob3)])
         self.assertEqual(
         self.assertEqual(
-            [TreeChange(CHANGE_RENAME, (b'a', F, blob1.id),
-                        (b'c', F, blob3.id)),
-             TreeChange.delete((b'b', F, blob2.id))],
-            self.detect_renames(tree1, tree2))
+            [
+                TreeChange(CHANGE_RENAME, (b"a", F, blob1.id), (b"c", F, blob3.id)),
+                TreeChange.delete((b"b", F, blob2.id)),
+            ],
+            self.detect_renames(tree1, tree2),
+        )
 
 
     def test_content_rename_many_to_many(self):
     def test_content_rename_many_to_many(self):
-        blob1 = make_object(Blob, data=b'a\nb\nc\nd\n')
-        blob2 = make_object(Blob, data=b'a\nb\nc\ne\n')
-        blob3 = make_object(Blob, data=b'a\nb\nc\nf\n')
-        blob4 = make_object(Blob, data=b'a\nb\nc\ng\n')
-        tree1 = self.commit_tree([(b'a', blob1), (b'b', blob2)])
-        tree2 = self.commit_tree([(b'c', blob3), (b'd', blob4)])
+        blob1 = make_object(Blob, data=b"a\nb\nc\nd\n")
+        blob2 = make_object(Blob, data=b"a\nb\nc\ne\n")
+        blob3 = make_object(Blob, data=b"a\nb\nc\nf\n")
+        blob4 = make_object(Blob, data=b"a\nb\nc\ng\n")
+        tree1 = self.commit_tree([(b"a", blob1), (b"b", blob2)])
+        tree2 = self.commit_tree([(b"c", blob3), (b"d", blob4)])
         # TODO(dborowitz): Distribute renames rather than greedily choosing
         # TODO(dborowitz): Distribute renames rather than greedily choosing
         # copies.
         # copies.
         self.assertEqual(
         self.assertEqual(
-            [TreeChange(CHANGE_RENAME, (b'a', F, blob1.id),
-                        (b'c', F, blob3.id)),
-             TreeChange(CHANGE_COPY, (b'a', F, blob1.id), (b'd', F, blob4.id)),
-             TreeChange.delete((b'b', F, blob2.id))],
-            self.detect_renames(tree1, tree2))
+            [
+                TreeChange(CHANGE_RENAME, (b"a", F, blob1.id), (b"c", F, blob3.id)),
+                TreeChange(CHANGE_COPY, (b"a", F, blob1.id), (b"d", F, blob4.id)),
+                TreeChange.delete((b"b", F, blob2.id)),
+            ],
+            self.detect_renames(tree1, tree2),
+        )
 
 
     def test_content_rename_with_more_deletions(self):
     def test_content_rename_with_more_deletions(self):
-        blob1 = make_object(Blob, data=b'')
-        tree1 = self.commit_tree([(b'a', blob1), (b'b', blob1), (b'c', blob1),
-                                  (b'd', blob1)])
-        tree2 = self.commit_tree([(b'e', blob1), (b'f', blob1), (b'g', blob1)])
+        blob1 = make_object(Blob, data=b"")
+        tree1 = self.commit_tree(
+            [(b"a", blob1), (b"b", blob1), (b"c", blob1), (b"d", blob1)]
+        )
+        tree2 = self.commit_tree([(b"e", blob1), (b"f", blob1), (b"g", blob1)])
         self.maxDiff = None
         self.maxDiff = None
         self.assertEqual(
         self.assertEqual(
-          [TreeChange(CHANGE_RENAME, (b'a', F, blob1.id), (b'e', F, blob1.id)),
-           TreeChange(CHANGE_RENAME, (b'b', F, blob1.id), (b'f', F, blob1.id)),
-           TreeChange(CHANGE_RENAME, (b'c', F, blob1.id), (b'g', F, blob1.id)),
-           TreeChange.delete((b'd', F, blob1.id))],
-          self.detect_renames(tree1, tree2))
+            [
+                TreeChange(CHANGE_RENAME, (b"a", F, blob1.id), (b"e", F, blob1.id)),
+                TreeChange(CHANGE_RENAME, (b"b", F, blob1.id), (b"f", F, blob1.id)),
+                TreeChange(CHANGE_RENAME, (b"c", F, blob1.id), (b"g", F, blob1.id)),
+                TreeChange.delete((b"d", F, blob1.id)),
+            ],
+            self.detect_renames(tree1, tree2),
+        )
 
 
     def test_content_rename_gitlink(self):
     def test_content_rename_gitlink(self):
-        blob1 = make_object(Blob, data=b'blob1')
-        blob2 = make_object(Blob, data=b'blob2')
-        link1 = b'1' * 40
-        link2 = b'2' * 40
-        tree1 = self.commit_tree([(b'a', blob1), (b'b', link1, 0o160000)])
-        tree2 = self.commit_tree([(b'c', blob2), (b'd', link2, 0o160000)])
-        self.assertEqual(
-            [TreeChange.delete((b'a', 0o100644, blob1.id)),
-             TreeChange.delete((b'b', 0o160000, link1)),
-             TreeChange.add((b'c', 0o100644, blob2.id)),
-             TreeChange.add((b'd', 0o160000, link2))],
-            self.detect_renames(tree1, tree2))
+        blob1 = make_object(Blob, data=b"blob1")
+        blob2 = make_object(Blob, data=b"blob2")
+        link1 = b"1" * 40
+        link2 = b"2" * 40
+        tree1 = self.commit_tree([(b"a", blob1), (b"b", link1, 0o160000)])
+        tree2 = self.commit_tree([(b"c", blob2), (b"d", link2, 0o160000)])
+        self.assertEqual(
+            [
+                TreeChange.delete((b"a", 0o100644, blob1.id)),
+                TreeChange.delete((b"b", 0o160000, link1)),
+                TreeChange.add((b"c", 0o100644, blob2.id)),
+                TreeChange.add((b"d", 0o160000, link2)),
+            ],
+            self.detect_renames(tree1, tree2),
+        )
 
 
     def test_exact_rename_swap(self):
     def test_exact_rename_swap(self):
-        blob1 = make_object(Blob, data=b'1')
-        blob2 = make_object(Blob, data=b'2')
-        tree1 = self.commit_tree([(b'a', blob1), (b'b', blob2)])
-        tree2 = self.commit_tree([(b'a', blob2), (b'b', blob1)])
-        self.assertEqual(
-            [TreeChange(CHANGE_MODIFY, (b'a', F, blob1.id),
-                        (b'a', F, blob2.id)),
-             TreeChange(CHANGE_MODIFY, (b'b', F, blob2.id),
-                        (b'b', F, blob1.id))],
-            self.detect_renames(tree1, tree2))
-        self.assertEqual(
-            [TreeChange(CHANGE_RENAME, (b'a', F, blob1.id),
-                        (b'b', F, blob1.id)),
-             TreeChange(CHANGE_RENAME, (b'b', F, blob2.id),
-                        (b'a', F, blob2.id))],
-            self.detect_renames(tree1, tree2, rewrite_threshold=50))
+        blob1 = make_object(Blob, data=b"1")
+        blob2 = make_object(Blob, data=b"2")
+        tree1 = self.commit_tree([(b"a", blob1), (b"b", blob2)])
+        tree2 = self.commit_tree([(b"a", blob2), (b"b", blob1)])
+        self.assertEqual(
+            [
+                TreeChange(CHANGE_MODIFY, (b"a", F, blob1.id), (b"a", F, blob2.id)),
+                TreeChange(CHANGE_MODIFY, (b"b", F, blob2.id), (b"b", F, blob1.id)),
+            ],
+            self.detect_renames(tree1, tree2),
+        )
+        self.assertEqual(
+            [
+                TreeChange(CHANGE_RENAME, (b"a", F, blob1.id), (b"b", F, blob1.id)),
+                TreeChange(CHANGE_RENAME, (b"b", F, blob2.id), (b"a", F, blob2.id)),
+            ],
+            self.detect_renames(tree1, tree2, rewrite_threshold=50),
+        )
 
 
     def test_content_rename_swap(self):
     def test_content_rename_swap(self):
-        blob1 = make_object(Blob, data=b'a\nb\nc\nd\n')
-        blob2 = make_object(Blob, data=b'e\nf\ng\nh\n')
-        blob3 = make_object(Blob, data=b'a\nb\nc\ne\n')
-        blob4 = make_object(Blob, data=b'e\nf\ng\ni\n')
-        tree1 = self.commit_tree([(b'a', blob1), (b'b', blob2)])
-        tree2 = self.commit_tree([(b'a', blob4), (b'b', blob3)])
-        self.assertEqual(
-            [TreeChange(CHANGE_RENAME, (b'a', F, blob1.id),
-                        (b'b', F, blob3.id)),
-             TreeChange(CHANGE_RENAME, (b'b', F, blob2.id),
-                        (b'a', F, blob4.id))],
-            self.detect_renames(tree1, tree2, rewrite_threshold=60))
+        blob1 = make_object(Blob, data=b"a\nb\nc\nd\n")
+        blob2 = make_object(Blob, data=b"e\nf\ng\nh\n")
+        blob3 = make_object(Blob, data=b"a\nb\nc\ne\n")
+        blob4 = make_object(Blob, data=b"e\nf\ng\ni\n")
+        tree1 = self.commit_tree([(b"a", blob1), (b"b", blob2)])
+        tree2 = self.commit_tree([(b"a", blob4), (b"b", blob3)])
+        self.assertEqual(
+            [
+                TreeChange(CHANGE_RENAME, (b"a", F, blob1.id), (b"b", F, blob3.id)),
+                TreeChange(CHANGE_RENAME, (b"b", F, blob2.id), (b"a", F, blob4.id)),
+            ],
+            self.detect_renames(tree1, tree2, rewrite_threshold=60),
+        )
 
 
     def test_rewrite_threshold(self):
     def test_rewrite_threshold(self):
-        blob1 = make_object(Blob, data=b'a\nb\nc\nd\n')
-        blob2 = make_object(Blob, data=b'a\nb\nc\ne\n')
-        blob3 = make_object(Blob, data=b'a\nb\nf\ng\n')
+        blob1 = make_object(Blob, data=b"a\nb\nc\nd\n")
+        blob2 = make_object(Blob, data=b"a\nb\nc\ne\n")
+        blob3 = make_object(Blob, data=b"a\nb\nf\ng\n")
 
 
-        tree1 = self.commit_tree([(b'a', blob1)])
-        tree2 = self.commit_tree([(b'a', blob3), (b'b', blob2)])
+        tree1 = self.commit_tree([(b"a", blob1)])
+        tree2 = self.commit_tree([(b"a", blob3), (b"b", blob2)])
 
 
         no_renames = [
         no_renames = [
-            TreeChange(CHANGE_MODIFY, (b'a', F, blob1.id),
-                       (b'a', F, blob3.id)),
-            TreeChange(CHANGE_COPY, (b'a', F, blob1.id), (b'b', F, blob2.id))]
-        self.assertEqual(
-            no_renames, self.detect_renames(tree1, tree2))
+            TreeChange(CHANGE_MODIFY, (b"a", F, blob1.id), (b"a", F, blob3.id)),
+            TreeChange(CHANGE_COPY, (b"a", F, blob1.id), (b"b", F, blob2.id)),
+        ]
+        self.assertEqual(no_renames, self.detect_renames(tree1, tree2))
         self.assertEqual(
         self.assertEqual(
-            no_renames, self.detect_renames(
-                tree1, tree2, rewrite_threshold=40))
+            no_renames, self.detect_renames(tree1, tree2, rewrite_threshold=40)
+        )
         self.assertEqual(
         self.assertEqual(
-            [TreeChange.add((b'a', F, blob3.id)),
-             TreeChange(CHANGE_RENAME, (b'a', F, blob1.id),
-                        (b'b', F, blob2.id))],
-            self.detect_renames(tree1, tree2, rewrite_threshold=80))
+            [
+                TreeChange.add((b"a", F, blob3.id)),
+                TreeChange(CHANGE_RENAME, (b"a", F, blob1.id), (b"b", F, blob2.id)),
+            ],
+            self.detect_renames(tree1, tree2, rewrite_threshold=80),
+        )
 
 
     def test_find_copies_harder_exact(self):
     def test_find_copies_harder_exact(self):
-        blob = make_object(Blob, data=b'blob')
-        tree1 = self.commit_tree([(b'a', blob)])
-        tree2 = self.commit_tree([(b'a', blob), (b'b', blob)])
-        self.assertEqual([TreeChange.add((b'b', F, blob.id))],
-                         self.detect_renames(tree1, tree2))
+        blob = make_object(Blob, data=b"blob")
+        tree1 = self.commit_tree([(b"a", blob)])
+        tree2 = self.commit_tree([(b"a", blob), (b"b", blob)])
+        self.assertEqual(
+            [TreeChange.add((b"b", F, blob.id))],
+            self.detect_renames(tree1, tree2),
+        )
         self.assertEqual(
         self.assertEqual(
-            [TreeChange(CHANGE_COPY, (b'a', F, blob.id), (b'b', F, blob.id))],
-            self.detect_renames(tree1, tree2, find_copies_harder=True))
+            [TreeChange(CHANGE_COPY, (b"a", F, blob.id), (b"b", F, blob.id))],
+            self.detect_renames(tree1, tree2, find_copies_harder=True),
+        )
 
 
     def test_find_copies_harder_content(self):
     def test_find_copies_harder_content(self):
-        blob1 = make_object(Blob, data=b'a\nb\nc\nd\n')
-        blob2 = make_object(Blob, data=b'a\nb\nc\ne\n')
-        tree1 = self.commit_tree([(b'a', blob1)])
-        tree2 = self.commit_tree([(b'a', blob1), (b'b', blob2)])
-        self.assertEqual([TreeChange.add((b'b', F, blob2.id))],
-                         self.detect_renames(tree1, tree2))
+        blob1 = make_object(Blob, data=b"a\nb\nc\nd\n")
+        blob2 = make_object(Blob, data=b"a\nb\nc\ne\n")
+        tree1 = self.commit_tree([(b"a", blob1)])
+        tree2 = self.commit_tree([(b"a", blob1), (b"b", blob2)])
         self.assertEqual(
         self.assertEqual(
-            [TreeChange(CHANGE_COPY, (b'a', F, blob1.id),
-                        (b'b', F, blob2.id))],
-            self.detect_renames(tree1, tree2, find_copies_harder=True))
+            [TreeChange.add((b"b", F, blob2.id))],
+            self.detect_renames(tree1, tree2),
+        )
+        self.assertEqual(
+            [TreeChange(CHANGE_COPY, (b"a", F, blob1.id), (b"b", F, blob2.id))],
+            self.detect_renames(tree1, tree2, find_copies_harder=True),
+        )
 
 
     def test_find_copies_harder_with_rewrites(self):
     def test_find_copies_harder_with_rewrites(self):
-        blob_a1 = make_object(Blob, data=b'a\nb\nc\nd\n')
-        blob_a2 = make_object(Blob, data=b'f\ng\nh\ni\n')
-        blob_b2 = make_object(Blob, data=b'a\nb\nc\ne\n')
-        tree1 = self.commit_tree([(b'a', blob_a1)])
-        tree2 = self.commit_tree([(b'a', blob_a2), (b'b', blob_b2)])
-        self.assertEqual(
-            [TreeChange(CHANGE_MODIFY, (b'a', F, blob_a1.id),
-                        (b'a', F, blob_a2.id)),
-             TreeChange(CHANGE_COPY, (b'a', F, blob_a1.id),
-                        (b'b', F, blob_b2.id))],
-            self.detect_renames(tree1, tree2, find_copies_harder=True))
-        self.assertEqual(
-            [TreeChange.add((b'a', F, blob_a2.id)),
-             TreeChange(CHANGE_RENAME, (b'a', F, blob_a1.id),
-                        (b'b', F, blob_b2.id))],
-            self.detect_renames(tree1, tree2, rewrite_threshold=50,
-                                find_copies_harder=True))
+        blob_a1 = make_object(Blob, data=b"a\nb\nc\nd\n")
+        blob_a2 = make_object(Blob, data=b"f\ng\nh\ni\n")
+        blob_b2 = make_object(Blob, data=b"a\nb\nc\ne\n")
+        tree1 = self.commit_tree([(b"a", blob_a1)])
+        tree2 = self.commit_tree([(b"a", blob_a2), (b"b", blob_b2)])
+        self.assertEqual(
+            [
+                TreeChange(CHANGE_MODIFY, (b"a", F, blob_a1.id), (b"a", F, blob_a2.id)),
+                TreeChange(CHANGE_COPY, (b"a", F, blob_a1.id), (b"b", F, blob_b2.id)),
+            ],
+            self.detect_renames(tree1, tree2, find_copies_harder=True),
+        )
+        self.assertEqual(
+            [
+                TreeChange.add((b"a", F, blob_a2.id)),
+                TreeChange(CHANGE_RENAME, (b"a", F, blob_a1.id), (b"b", F, blob_b2.id)),
+            ],
+            self.detect_renames(
+                tree1, tree2, rewrite_threshold=50, find_copies_harder=True
+            ),
+        )
 
 
     def test_reuse_detector(self):
     def test_reuse_detector(self):
-        blob = make_object(Blob, data=b'blob')
-        tree1 = self.commit_tree([(b'a', blob)])
-        tree2 = self.commit_tree([(b'b', blob)])
+        blob = make_object(Blob, data=b"blob")
+        tree1 = self.commit_tree([(b"a", blob)])
+        tree2 = self.commit_tree([(b"b", blob)])
         detector = RenameDetector(self.store)
         detector = RenameDetector(self.store)
-        changes = [TreeChange(CHANGE_RENAME, (b'a', F, blob.id),
-                              (b'b', F, blob.id))]
-        self.assertEqual(changes,
-                         detector.changes_with_renames(tree1.id, tree2.id))
-        self.assertEqual(changes,
-                         detector.changes_with_renames(tree1.id, tree2.id))
+        changes = [TreeChange(CHANGE_RENAME, (b"a", F, blob.id), (b"b", F, blob.id))]
+        self.assertEqual(changes, detector.changes_with_renames(tree1.id, tree2.id))
+        self.assertEqual(changes, detector.changes_with_renames(tree1.id, tree2.id))
 
 
     def test_want_unchanged(self):
     def test_want_unchanged(self):
-        blob_a1 = make_object(Blob, data=b'a\nb\nc\nd\n')
-        blob_b = make_object(Blob, data=b'b')
-        blob_c2 = make_object(Blob, data=b'a\nb\nc\ne\n')
-        tree1 = self.commit_tree([(b'a', blob_a1), (b'b', blob_b)])
-        tree2 = self.commit_tree([(b'c', blob_c2), (b'b', blob_b)])
-        self.assertEqual(
-            [TreeChange(CHANGE_RENAME, (b'a', F, blob_a1.id),
-                        (b'c', F, blob_c2.id))],
-            self.detect_renames(tree1, tree2))
-        self.assertEqual(
-            [TreeChange(CHANGE_RENAME, (b'a', F, blob_a1.id),
-                        (b'c', F, blob_c2.id)),
-             TreeChange(CHANGE_UNCHANGED, (b'b', F, blob_b.id),
-                        (b'b', F, blob_b.id))],
-            self.detect_renames(tree1, tree2, want_unchanged=True))
+        blob_a1 = make_object(Blob, data=b"a\nb\nc\nd\n")
+        blob_b = make_object(Blob, data=b"b")
+        blob_c2 = make_object(Blob, data=b"a\nb\nc\ne\n")
+        tree1 = self.commit_tree([(b"a", blob_a1), (b"b", blob_b)])
+        tree2 = self.commit_tree([(b"c", blob_c2), (b"b", blob_b)])
+        self.assertEqual(
+            [TreeChange(CHANGE_RENAME, (b"a", F, blob_a1.id), (b"c", F, blob_c2.id))],
+            self.detect_renames(tree1, tree2),
+        )
+        self.assertEqual(
+            [
+                TreeChange(CHANGE_RENAME, (b"a", F, blob_a1.id), (b"c", F, blob_c2.id)),
+                TreeChange(
+                    CHANGE_UNCHANGED,
+                    (b"b", F, blob_b.id),
+                    (b"b", F, blob_b.id),
+                ),
+            ],
+            self.detect_renames(tree1, tree2, want_unchanged=True),
+        )

+ 112 - 56
dulwich/tests/test_fastexport.py

@@ -24,23 +24,23 @@ import stat
 
 
 from dulwich.object_store import (
 from dulwich.object_store import (
     MemoryObjectStore,
     MemoryObjectStore,
-    )
+)
 from dulwich.objects import (
 from dulwich.objects import (
     Blob,
     Blob,
     Commit,
     Commit,
     Tree,
     Tree,
     ZERO_SHA,
     ZERO_SHA,
-    )
+)
 from dulwich.repo import (
 from dulwich.repo import (
     MemoryRepo,
     MemoryRepo,
-    )
+)
 from dulwich.tests import (
 from dulwich.tests import (
     SkipTest,
     SkipTest,
     TestCase,
     TestCase,
-    )
+)
 from dulwich.tests.utils import (
 from dulwich.tests.utils import (
     build_commit_graph,
     build_commit_graph,
-    )
+)
 
 
 
 
 class GitFastExporterTests(TestCase):
 class GitFastExporterTests(TestCase):
@@ -60,8 +60,7 @@ class GitFastExporterTests(TestCase):
         b = Blob()
         b = Blob()
         b.data = b"fooBAR"
         b.data = b"fooBAR"
         self.fastexporter.emit_blob(b)
         self.fastexporter.emit_blob(b)
-        self.assertEqual(b'blob\nmark :1\ndata 6\nfooBAR\n',
-                         self.stream.getvalue())
+        self.assertEqual(b"blob\nmark :1\ndata 6\nfooBAR\n", self.stream.getvalue())
 
 
     def test_emit_commit(self):
     def test_emit_commit(self):
         b = Blob()
         b = Blob()
@@ -76,7 +75,8 @@ class GitFastExporterTests(TestCase):
         c.tree = t.id
         c.tree = t.id
         self.store.add_objects([(b, None), (t, None), (c, None)])
         self.store.add_objects([(b, None), (t, None), (c, None)])
         self.fastexporter.emit_commit(c, b"refs/heads/master")
         self.fastexporter.emit_commit(c, b"refs/heads/master")
-        self.assertEqual(b"""blob
+        self.assertEqual(
+            b"""blob
 mark :1
 mark :1
 data 3
 data 3
 FOO
 FOO
@@ -87,7 +87,9 @@ committer Jelmer <jelmer@host> 1271345553 +0000
 data 3
 data 3
 msg
 msg
 M 644 :1 foo
 M 644 :1 foo
-""", self.stream.getvalue())
+""",
+            self.stream.getvalue(),
+        )
 
 
 
 
 class GitImportProcessorTests(TestCase):
 class GitImportProcessorTests(TestCase):
@@ -104,6 +106,7 @@ class GitImportProcessorTests(TestCase):
 
 
     def test_reset_handler(self):
     def test_reset_handler(self):
         from fastimport import commands
         from fastimport import commands
+
         [c1] = build_commit_graph(self.repo.object_store, [[1]])
         [c1] = build_commit_graph(self.repo.object_store, [[1]])
         cmd = commands.ResetCommand(b"refs/heads/foo", c1.id)
         cmd = commands.ResetCommand(b"refs/heads/foo", c1.id)
         self.processor.reset_handler(cmd)
         self.processor.reset_handler(cmd)
@@ -112,14 +115,16 @@ class GitImportProcessorTests(TestCase):
 
 
     def test_reset_handler_marker(self):
     def test_reset_handler_marker(self):
         from fastimport import commands
         from fastimport import commands
+
         [c1, c2] = build_commit_graph(self.repo.object_store, [[1], [2]])
         [c1, c2] = build_commit_graph(self.repo.object_store, [[1], [2]])
-        self.processor.markers[b'10'] = c1.id
-        cmd = commands.ResetCommand(b"refs/heads/foo", b':10')
+        self.processor.markers[b"10"] = c1.id
+        cmd = commands.ResetCommand(b"refs/heads/foo", b":10")
         self.processor.reset_handler(cmd)
         self.processor.reset_handler(cmd)
         self.assertEqual(c1.id, self.repo.get_refs()[b"refs/heads/foo"])
         self.assertEqual(c1.id, self.repo.get_refs()[b"refs/heads/foo"])
 
 
     def test_reset_handler_default(self):
     def test_reset_handler_default(self):
         from fastimport import commands
         from fastimport import commands
+
         [c1, c2] = build_commit_graph(self.repo.object_store, [[1], [2]])
         [c1, c2] = build_commit_graph(self.repo.object_store, [[1], [2]])
         cmd = commands.ResetCommand(b"refs/heads/foo", None)
         cmd = commands.ResetCommand(b"refs/heads/foo", None)
         self.processor.reset_handler(cmd)
         self.processor.reset_handler(cmd)
@@ -127,11 +132,17 @@ class GitImportProcessorTests(TestCase):
 
 
     def test_commit_handler(self):
     def test_commit_handler(self):
         from fastimport import commands
         from fastimport import commands
+
         cmd = commands.CommitCommand(
         cmd = commands.CommitCommand(
-                b"refs/heads/foo",  b"mrkr",
-                (b"Jelmer", b"jelmer@samba.org", 432432432.0, 3600),
-                (b"Jelmer", b"jelmer@samba.org", 432432432.0, 3600),
-                b"FOO", None, [], [])
+            b"refs/heads/foo",
+            b"mrkr",
+            (b"Jelmer", b"jelmer@samba.org", 432432432.0, 3600),
+            (b"Jelmer", b"jelmer@samba.org", 432432432.0, 3600),
+            b"FOO",
+            None,
+            [],
+            [],
+        )
         self.processor.commit_handler(cmd)
         self.processor.commit_handler(cmd)
         commit = self.repo[self.processor.last_commit]
         commit = self.repo[self.processor.last_commit]
         self.assertEqual(b"Jelmer <jelmer@samba.org>", commit.author)
         self.assertEqual(b"Jelmer <jelmer@samba.org>", commit.author)
@@ -146,16 +157,21 @@ class GitImportProcessorTests(TestCase):
 
 
     def test_commit_handler_markers(self):
     def test_commit_handler_markers(self):
         from fastimport import commands
         from fastimport import commands
-        [c1, c2, c3] = build_commit_graph(self.repo.object_store,
-                                          [[1], [2], [3]])
-        self.processor.markers[b'10'] = c1.id
-        self.processor.markers[b'42'] = c2.id
-        self.processor.markers[b'98'] = c3.id
+
+        [c1, c2, c3] = build_commit_graph(self.repo.object_store, [[1], [2], [3]])
+        self.processor.markers[b"10"] = c1.id
+        self.processor.markers[b"42"] = c2.id
+        self.processor.markers[b"98"] = c3.id
         cmd = commands.CommitCommand(
         cmd = commands.CommitCommand(
-                b"refs/heads/foo",  b"mrkr",
-                (b"Jelmer", b"jelmer@samba.org", 432432432.0, 3600),
-                (b"Jelmer", b"jelmer@samba.org", 432432432.0, 3600),
-                b"FOO", b':10', [b':42', b':98'], [])
+            b"refs/heads/foo",
+            b"mrkr",
+            (b"Jelmer", b"jelmer@samba.org", 432432432.0, 3600),
+            (b"Jelmer", b"jelmer@samba.org", 432432432.0, 3600),
+            b"FOO",
+            b":10",
+            [b":42", b":98"],
+            [],
+        )
         self.processor.commit_handler(cmd)
         self.processor.commit_handler(cmd)
         commit = self.repo[self.processor.last_commit]
         commit = self.repo[self.processor.last_commit]
         self.assertEqual(c1.id, commit.parents[0])
         self.assertEqual(c1.id, commit.parents[0])
@@ -163,7 +179,9 @@ class GitImportProcessorTests(TestCase):
         self.assertEqual(c3.id, commit.parents[2])
         self.assertEqual(c3.id, commit.parents[2])
 
 
     def test_import_stream(self):
     def test_import_stream(self):
-        markers = self.processor.import_stream(BytesIO(b"""blob
+        markers = self.processor.import_stream(
+            BytesIO(
+                b"""blob
 mark :1
 mark :1
 data 11
 data 11
 text for a
 text for a
@@ -175,37 +193,50 @@ data 20
 <The commit message>
 <The commit message>
 M 100644 :1 a
 M 100644 :1 a
 
 
-"""))
+"""
+            )
+        )
         self.assertEqual(2, len(markers))
         self.assertEqual(2, len(markers))
         self.assertTrue(isinstance(self.repo[markers[b"1"]], Blob))
         self.assertTrue(isinstance(self.repo[markers[b"1"]], Blob))
         self.assertTrue(isinstance(self.repo[markers[b"2"]], Commit))
         self.assertTrue(isinstance(self.repo[markers[b"2"]], Commit))
 
 
     def test_file_add(self):
     def test_file_add(self):
         from fastimport import commands
         from fastimport import commands
+
         cmd = commands.BlobCommand(b"23", b"data")
         cmd = commands.BlobCommand(b"23", b"data")
         self.processor.blob_handler(cmd)
         self.processor.blob_handler(cmd)
         cmd = commands.CommitCommand(
         cmd = commands.CommitCommand(
-                b"refs/heads/foo", b"mrkr",
-                (b"Jelmer", b"jelmer@samba.org", 432432432.0, 3600),
-                (b"Jelmer", b"jelmer@samba.org", 432432432.0, 3600),
-                b"FOO", None, [],
-                [commands.FileModifyCommand(b"path", 0o100644, b":23", None)])
+            b"refs/heads/foo",
+            b"mrkr",
+            (b"Jelmer", b"jelmer@samba.org", 432432432.0, 3600),
+            (b"Jelmer", b"jelmer@samba.org", 432432432.0, 3600),
+            b"FOO",
+            None,
+            [],
+            [commands.FileModifyCommand(b"path", 0o100644, b":23", None)],
+        )
         self.processor.commit_handler(cmd)
         self.processor.commit_handler(cmd)
         commit = self.repo[self.processor.last_commit]
         commit = self.repo[self.processor.last_commit]
-        self.assertEqual([
-            (b'path', 0o100644, b'6320cd248dd8aeaab759d5871f8781b5c0505172')],
-            self.repo[commit.tree].items())
+        self.assertEqual(
+            [(b"path", 0o100644, b"6320cd248dd8aeaab759d5871f8781b5c0505172")],
+            self.repo[commit.tree].items(),
+        )
 
 
     def simple_commit(self):
     def simple_commit(self):
         from fastimport import commands
         from fastimport import commands
+
         cmd = commands.BlobCommand(b"23", b"data")
         cmd = commands.BlobCommand(b"23", b"data")
         self.processor.blob_handler(cmd)
         self.processor.blob_handler(cmd)
         cmd = commands.CommitCommand(
         cmd = commands.CommitCommand(
-                b"refs/heads/foo", b"mrkr",
-                (b"Jelmer", b"jelmer@samba.org", 432432432.0, 3600),
-                (b"Jelmer", b"jelmer@samba.org", 432432432.0, 3600),
-                b"FOO", None, [],
-                [commands.FileModifyCommand(b"path", 0o100644, b":23", None)])
+            b"refs/heads/foo",
+            b"mrkr",
+            (b"Jelmer", b"jelmer@samba.org", 432432432.0, 3600),
+            (b"Jelmer", b"jelmer@samba.org", 432432432.0, 3600),
+            b"FOO",
+            None,
+            [],
+            [commands.FileModifyCommand(b"path", 0o100644, b":23", None)],
+        )
         self.processor.commit_handler(cmd)
         self.processor.commit_handler(cmd)
         commit = self.repo[self.processor.last_commit]
         commit = self.repo[self.processor.last_commit]
         return commit
         return commit
@@ -218,44 +249,69 @@ M 100644 :1 a
         Returns: The created commit object
         Returns: The created commit object
         """
         """
         from fastimport import commands
         from fastimport import commands
+
         cmd = commands.CommitCommand(
         cmd = commands.CommitCommand(
-                b"refs/heads/foo", b"mrkr",
-                (b"Jelmer", b"jelmer@samba.org", 432432432.0, 3600),
-                (b"Jelmer", b"jelmer@samba.org", 432432432.0, 3600),
-                b"FOO", None, [], file_cmds)
+            b"refs/heads/foo",
+            b"mrkr",
+            (b"Jelmer", b"jelmer@samba.org", 432432432.0, 3600),
+            (b"Jelmer", b"jelmer@samba.org", 432432432.0, 3600),
+            b"FOO",
+            None,
+            [],
+            file_cmds,
+        )
         self.processor.commit_handler(cmd)
         self.processor.commit_handler(cmd)
         return self.repo[self.processor.last_commit]
         return self.repo[self.processor.last_commit]
 
 
     def test_file_copy(self):
     def test_file_copy(self):
         from fastimport import commands
         from fastimport import commands
+
         self.simple_commit()
         self.simple_commit()
-        commit = self.make_file_commit(
-                [commands.FileCopyCommand(b"path", b"new_path")])
-        self.assertEqual([
-                (b'new_path', 0o100644,
-                 b'6320cd248dd8aeaab759d5871f8781b5c0505172'),
-                (b'path', 0o100644,
-                 b'6320cd248dd8aeaab759d5871f8781b5c0505172'),
-                ], self.repo[commit.tree].items())
+        commit = self.make_file_commit([commands.FileCopyCommand(b"path", b"new_path")])
+        self.assertEqual(
+            [
+                (
+                    b"new_path",
+                    0o100644,
+                    b"6320cd248dd8aeaab759d5871f8781b5c0505172",
+                ),
+                (
+                    b"path",
+                    0o100644,
+                    b"6320cd248dd8aeaab759d5871f8781b5c0505172",
+                ),
+            ],
+            self.repo[commit.tree].items(),
+        )
 
 
     def test_file_move(self):
     def test_file_move(self):
         from fastimport import commands
         from fastimport import commands
+
         self.simple_commit()
         self.simple_commit()
         commit = self.make_file_commit(
         commit = self.make_file_commit(
-                [commands.FileRenameCommand(b"path", b"new_path")])
-        self.assertEqual([
-                (b'new_path', 0o100644,
-                 b'6320cd248dd8aeaab759d5871f8781b5c0505172'),
-                ], self.repo[commit.tree].items())
+            [commands.FileRenameCommand(b"path", b"new_path")]
+        )
+        self.assertEqual(
+            [
+                (
+                    b"new_path",
+                    0o100644,
+                    b"6320cd248dd8aeaab759d5871f8781b5c0505172",
+                ),
+            ],
+            self.repo[commit.tree].items(),
+        )
 
 
     def test_file_delete(self):
     def test_file_delete(self):
         from fastimport import commands
         from fastimport import commands
+
         self.simple_commit()
         self.simple_commit()
         commit = self.make_file_commit([commands.FileDeleteCommand(b"path")])
         commit = self.make_file_commit([commands.FileDeleteCommand(b"path")])
         self.assertEqual([], self.repo[commit.tree].items())
         self.assertEqual([], self.repo[commit.tree].items())
 
 
     def test_file_deleteall(self):
     def test_file_deleteall(self):
         from fastimport import commands
         from fastimport import commands
+
         self.simple_commit()
         self.simple_commit()
         commit = self.make_file_commit([commands.FileDeleteAllCommand()])
         commit = self.make_file_commit([commands.FileDeleteAllCommand()])
         self.assertEqual([], self.repo[commit.tree].items())
         self.assertEqual([], self.repo[commit.tree].items())

+ 62 - 64
dulwich/tests/test_file.py

@@ -28,17 +28,16 @@ from dulwich.file import FileLocked, GitFile, _fancy_rename
 from dulwich.tests import (
 from dulwich.tests import (
     SkipTest,
     SkipTest,
     TestCase,
     TestCase,
-    )
+)
 
 
 
 
 class FancyRenameTests(TestCase):
 class FancyRenameTests(TestCase):
-
     def setUp(self):
     def setUp(self):
         super(FancyRenameTests, self).setUp()
         super(FancyRenameTests, self).setUp()
         self._tempdir = tempfile.mkdtemp()
         self._tempdir = tempfile.mkdtemp()
-        self.foo = self.path('foo')
-        self.bar = self.path('bar')
-        self.create(self.foo, b'foo contents')
+        self.foo = self.path("foo")
+        self.bar = self.path("bar")
+        self.create(self.foo, b"foo contents")
 
 
     def tearDown(self):
     def tearDown(self):
         shutil.rmtree(self._tempdir)
         shutil.rmtree(self._tempdir)
@@ -48,7 +47,7 @@ class FancyRenameTests(TestCase):
         return os.path.join(self._tempdir, filename)
         return os.path.join(self._tempdir, filename)
 
 
     def create(self, path, contents):
     def create(self, path, contents):
-        f = open(path, 'wb')
+        f = open(path, "wb")
         f.write(contents)
         f.write(contents)
         f.close()
         f.close()
 
 
@@ -57,44 +56,43 @@ class FancyRenameTests(TestCase):
         _fancy_rename(self.foo, self.bar)
         _fancy_rename(self.foo, self.bar)
         self.assertFalse(os.path.exists(self.foo))
         self.assertFalse(os.path.exists(self.foo))
 
 
-        new_f = open(self.bar, 'rb')
-        self.assertEqual(b'foo contents', new_f.read())
+        new_f = open(self.bar, "rb")
+        self.assertEqual(b"foo contents", new_f.read())
         new_f.close()
         new_f.close()
 
 
     def test_dest_exists(self):
     def test_dest_exists(self):
-        self.create(self.bar, b'bar contents')
+        self.create(self.bar, b"bar contents")
         _fancy_rename(self.foo, self.bar)
         _fancy_rename(self.foo, self.bar)
         self.assertFalse(os.path.exists(self.foo))
         self.assertFalse(os.path.exists(self.foo))
 
 
-        new_f = open(self.bar, 'rb')
-        self.assertEqual(b'foo contents', new_f.read())
+        new_f = open(self.bar, "rb")
+        self.assertEqual(b"foo contents", new_f.read())
         new_f.close()
         new_f.close()
 
 
     def test_dest_opened(self):
     def test_dest_opened(self):
         if sys.platform != "win32":
         if sys.platform != "win32":
             raise SkipTest("platform allows overwriting open files")
             raise SkipTest("platform allows overwriting open files")
-        self.create(self.bar, b'bar contents')
-        dest_f = open(self.bar, 'rb')
+        self.create(self.bar, b"bar contents")
+        dest_f = open(self.bar, "rb")
         self.assertRaises(OSError, _fancy_rename, self.foo, self.bar)
         self.assertRaises(OSError, _fancy_rename, self.foo, self.bar)
         dest_f.close()
         dest_f.close()
-        self.assertTrue(os.path.exists(self.path('foo')))
+        self.assertTrue(os.path.exists(self.path("foo")))
 
 
-        new_f = open(self.foo, 'rb')
-        self.assertEqual(b'foo contents', new_f.read())
+        new_f = open(self.foo, "rb")
+        self.assertEqual(b"foo contents", new_f.read())
         new_f.close()
         new_f.close()
 
 
-        new_f = open(self.bar, 'rb')
-        self.assertEqual(b'bar contents', new_f.read())
+        new_f = open(self.bar, "rb")
+        self.assertEqual(b"bar contents", new_f.read())
         new_f.close()
         new_f.close()
 
 
 
 
 class GitFileTests(TestCase):
 class GitFileTests(TestCase):
-
     def setUp(self):
     def setUp(self):
         super(GitFileTests, self).setUp()
         super(GitFileTests, self).setUp()
         self._tempdir = tempfile.mkdtemp()
         self._tempdir = tempfile.mkdtemp()
-        f = open(self.path('foo'), 'wb')
-        f.write(b'foo contents')
+        f = open(self.path("foo"), "wb")
+        f.write(b"foo contents")
         f.close()
         f.close()
 
 
     def tearDown(self):
     def tearDown(self):
@@ -105,98 +103,98 @@ class GitFileTests(TestCase):
         return os.path.join(self._tempdir, filename)
         return os.path.join(self._tempdir, filename)
 
 
     def test_invalid(self):
     def test_invalid(self):
-        foo = self.path('foo')
-        self.assertRaises(IOError, GitFile, foo, mode='r')
-        self.assertRaises(IOError, GitFile, foo, mode='ab')
-        self.assertRaises(IOError, GitFile, foo, mode='r+b')
-        self.assertRaises(IOError, GitFile, foo, mode='w+b')
-        self.assertRaises(IOError, GitFile, foo, mode='a+bU')
+        foo = self.path("foo")
+        self.assertRaises(IOError, GitFile, foo, mode="r")
+        self.assertRaises(IOError, GitFile, foo, mode="ab")
+        self.assertRaises(IOError, GitFile, foo, mode="r+b")
+        self.assertRaises(IOError, GitFile, foo, mode="w+b")
+        self.assertRaises(IOError, GitFile, foo, mode="a+bU")
 
 
     def test_readonly(self):
     def test_readonly(self):
-        f = GitFile(self.path('foo'), 'rb')
+        f = GitFile(self.path("foo"), "rb")
         self.assertTrue(isinstance(f, io.IOBase))
         self.assertTrue(isinstance(f, io.IOBase))
-        self.assertEqual(b'foo contents', f.read())
-        self.assertEqual(b'', f.read())
+        self.assertEqual(b"foo contents", f.read())
+        self.assertEqual(b"", f.read())
         f.seek(4)
         f.seek(4)
-        self.assertEqual(b'contents', f.read())
+        self.assertEqual(b"contents", f.read())
         f.close()
         f.close()
 
 
     def test_default_mode(self):
     def test_default_mode(self):
-        f = GitFile(self.path('foo'))
-        self.assertEqual(b'foo contents', f.read())
+        f = GitFile(self.path("foo"))
+        self.assertEqual(b"foo contents", f.read())
         f.close()
         f.close()
 
 
     def test_write(self):
     def test_write(self):
-        foo = self.path('foo')
-        foo_lock = '%s.lock' % foo
+        foo = self.path("foo")
+        foo_lock = "%s.lock" % foo
 
 
-        orig_f = open(foo, 'rb')
-        self.assertEqual(orig_f.read(), b'foo contents')
+        orig_f = open(foo, "rb")
+        self.assertEqual(orig_f.read(), b"foo contents")
         orig_f.close()
         orig_f.close()
 
 
         self.assertFalse(os.path.exists(foo_lock))
         self.assertFalse(os.path.exists(foo_lock))
-        f = GitFile(foo, 'wb')
+        f = GitFile(foo, "wb")
         self.assertFalse(f.closed)
         self.assertFalse(f.closed)
-        self.assertRaises(AttributeError, getattr, f, 'not_a_file_property')
+        self.assertRaises(AttributeError, getattr, f, "not_a_file_property")
 
 
         self.assertTrue(os.path.exists(foo_lock))
         self.assertTrue(os.path.exists(foo_lock))
-        f.write(b'new stuff')
+        f.write(b"new stuff")
         f.seek(4)
         f.seek(4)
-        f.write(b'contents')
+        f.write(b"contents")
         f.close()
         f.close()
         self.assertFalse(os.path.exists(foo_lock))
         self.assertFalse(os.path.exists(foo_lock))
 
 
-        new_f = open(foo, 'rb')
-        self.assertEqual(b'new contents', new_f.read())
+        new_f = open(foo, "rb")
+        self.assertEqual(b"new contents", new_f.read())
         new_f.close()
         new_f.close()
 
 
     def test_open_twice(self):
     def test_open_twice(self):
-        foo = self.path('foo')
-        f1 = GitFile(foo, 'wb')
-        f1.write(b'new')
+        foo = self.path("foo")
+        f1 = GitFile(foo, "wb")
+        f1.write(b"new")
         try:
         try:
-            f2 = GitFile(foo, 'wb')
+            f2 = GitFile(foo, "wb")
             self.fail()
             self.fail()
         except FileLocked:
         except FileLocked:
             pass
             pass
         else:
         else:
             f2.close()
             f2.close()
-        f1.write(b' contents')
+        f1.write(b" contents")
         f1.close()
         f1.close()
 
 
         # Ensure trying to open twice doesn't affect original.
         # Ensure trying to open twice doesn't affect original.
-        f = open(foo, 'rb')
-        self.assertEqual(b'new contents', f.read())
+        f = open(foo, "rb")
+        self.assertEqual(b"new contents", f.read())
         f.close()
         f.close()
 
 
     def test_abort(self):
     def test_abort(self):
-        foo = self.path('foo')
-        foo_lock = '%s.lock' % foo
+        foo = self.path("foo")
+        foo_lock = "%s.lock" % foo
 
 
-        orig_f = open(foo, 'rb')
-        self.assertEqual(orig_f.read(), b'foo contents')
+        orig_f = open(foo, "rb")
+        self.assertEqual(orig_f.read(), b"foo contents")
         orig_f.close()
         orig_f.close()
 
 
-        f = GitFile(foo, 'wb')
-        f.write(b'new contents')
+        f = GitFile(foo, "wb")
+        f.write(b"new contents")
         f.abort()
         f.abort()
         self.assertTrue(f.closed)
         self.assertTrue(f.closed)
         self.assertFalse(os.path.exists(foo_lock))
         self.assertFalse(os.path.exists(foo_lock))
 
 
-        new_orig_f = open(foo, 'rb')
-        self.assertEqual(new_orig_f.read(), b'foo contents')
+        new_orig_f = open(foo, "rb")
+        self.assertEqual(new_orig_f.read(), b"foo contents")
         new_orig_f.close()
         new_orig_f.close()
 
 
     def test_abort_close(self):
     def test_abort_close(self):
-        foo = self.path('foo')
-        f = GitFile(foo, 'wb')
+        foo = self.path("foo")
+        f = GitFile(foo, "wb")
         f.abort()
         f.abort()
         try:
         try:
             f.close()
             f.close()
         except (IOError, OSError):
         except (IOError, OSError):
             self.fail()
             self.fail()
 
 
-        f = GitFile(foo, 'wb')
+        f = GitFile(foo, "wb")
         f.close()
         f.close()
         try:
         try:
             f.abort()
             f.abort()
@@ -204,11 +202,11 @@ class GitFileTests(TestCase):
             self.fail()
             self.fail()
 
 
     def test_abort_close_removed(self):
     def test_abort_close_removed(self):
-        foo = self.path('foo')
-        f = GitFile(foo, 'wb')
+        foo = self.path("foo")
+        f = GitFile(foo, "wb")
 
 
         f._file.close()
         f._file.close()
-        os.remove(foo+".lock")
+        os.remove(foo + ".lock")
 
 
         f.abort()
         f.abort()
         self.assertTrue(f._closed)
         self.assertTrue(f._closed)

+ 67 - 65
dulwich/tests/test_grafts.py

@@ -27,7 +27,7 @@ from dulwich.errors import ObjectFormatException
 from dulwich.tests import TestCase
 from dulwich.tests import TestCase
 from dulwich.objects import (
 from dulwich.objects import (
     Tree,
     Tree,
-    )
+)
 from dulwich.repo import (
 from dulwich.repo import (
     parse_graftpoints,
     parse_graftpoints,
     serialize_graftpoints,
     serialize_graftpoints,
@@ -37,11 +37,10 @@ from dulwich.repo import (
 
 
 
 
 def makesha(digit):
 def makesha(digit):
-    return (str(digit).encode('ascii') * 40)[:40]
+    return (str(digit).encode("ascii") * 40)[:40]
 
 
 
 
 class GraftParserTests(TestCase):
 class GraftParserTests(TestCase):
-
     def assertParse(self, expected, graftpoints):
     def assertParse(self, expected, graftpoints):
         self.assertEqual(expected, parse_graftpoints(iter(graftpoints)))
         self.assertEqual(expected, parse_graftpoints(iter(graftpoints)))
 
 
@@ -52,49 +51,60 @@ class GraftParserTests(TestCase):
         self.assertParse({makesha(0): []}, [makesha(0)])
         self.assertParse({makesha(0): []}, [makesha(0)])
 
 
     def test_parents(self):
     def test_parents(self):
-        self.assertParse({makesha(0): [makesha(1), makesha(2)]},
-                         [b' '.join([makesha(0), makesha(1), makesha(2)])])
+        self.assertParse(
+            {makesha(0): [makesha(1), makesha(2)]},
+            [b" ".join([makesha(0), makesha(1), makesha(2)])],
+        )
 
 
     def test_multiple_hybrid(self):
     def test_multiple_hybrid(self):
         self.assertParse(
         self.assertParse(
-            {makesha(0): [],
-             makesha(1): [makesha(2)],
-             makesha(3): [makesha(4), makesha(5)]},
-            [makesha(0),
-             b' '.join([makesha(1), makesha(2)]),
-             b' '.join([makesha(3), makesha(4), makesha(5)])])
+            {
+                makesha(0): [],
+                makesha(1): [makesha(2)],
+                makesha(3): [makesha(4), makesha(5)],
+            },
+            [
+                makesha(0),
+                b" ".join([makesha(1), makesha(2)]),
+                b" ".join([makesha(3), makesha(4), makesha(5)]),
+            ],
+        )
 
 
 
 
 class GraftSerializerTests(TestCase):
 class GraftSerializerTests(TestCase):
-
     def assertSerialize(self, expected, graftpoints):
     def assertSerialize(self, expected, graftpoints):
-        self.assertEqual(
-            sorted(expected),
-            sorted(serialize_graftpoints(graftpoints)))
+        self.assertEqual(sorted(expected), sorted(serialize_graftpoints(graftpoints)))
 
 
     def test_no_grafts(self):
     def test_no_grafts(self):
-        self.assertSerialize(b'', {})
+        self.assertSerialize(b"", {})
 
 
     def test_no_parents(self):
     def test_no_parents(self):
         self.assertSerialize(makesha(0), {makesha(0): []})
         self.assertSerialize(makesha(0), {makesha(0): []})
 
 
     def test_parents(self):
     def test_parents(self):
-        self.assertSerialize(b' '.join([makesha(0), makesha(1), makesha(2)]),
-                             {makesha(0): [makesha(1), makesha(2)]})
+        self.assertSerialize(
+            b" ".join([makesha(0), makesha(1), makesha(2)]),
+            {makesha(0): [makesha(1), makesha(2)]},
+        )
 
 
     def test_multiple_hybrid(self):
     def test_multiple_hybrid(self):
         self.assertSerialize(
         self.assertSerialize(
-            b'\n'.join([
-                makesha(0),
-                b' '.join([makesha(1), makesha(2)]),
-                b' '.join([makesha(3), makesha(4), makesha(5)])]),
-            {makesha(0): [],
-             makesha(1): [makesha(2)],
-             makesha(3): [makesha(4), makesha(5)]})
+            b"\n".join(
+                [
+                    makesha(0),
+                    b" ".join([makesha(1), makesha(2)]),
+                    b" ".join([makesha(3), makesha(4), makesha(5)]),
+                ]
+            ),
+            {
+                makesha(0): [],
+                makesha(1): [makesha(2)],
+                makesha(3): [makesha(4), makesha(5)],
+            },
+        )
 
 
 
 
 class GraftsInRepositoryBase(object):
 class GraftsInRepositoryBase(object):
-
     def tearDown(self):
     def tearDown(self):
         super(GraftsInRepositoryBase, self).tearDown()
         super(GraftsInRepositoryBase, self).tearDown()
 
 
@@ -112,33 +122,31 @@ class GraftsInRepositoryBase(object):
     def test_no_parents_graft(self):
     def test_no_parents_graft(self):
         r = self.get_repo_with_grafts({self._repo.head(): []})
         r = self.get_repo_with_grafts({self._repo.head(): []})
 
 
-        self.assertEqual([e.commit.id for e in r.get_walker()],
-                         [r.head()])
+        self.assertEqual([e.commit.id for e in r.get_walker()], [r.head()])
 
 
     def test_existing_parent_graft(self):
     def test_existing_parent_graft(self):
         r = self.get_repo_with_grafts({self._shas[-1]: [self._shas[0]]})
         r = self.get_repo_with_grafts({self._shas[-1]: [self._shas[0]]})
 
 
-        self.assertEqual([e.commit.id for e in r.get_walker()],
-                         [self._shas[-1], self._shas[0]])
+        self.assertEqual(
+            [e.commit.id for e in r.get_walker()],
+            [self._shas[-1], self._shas[0]],
+        )
 
 
     def test_remove_graft(self):
     def test_remove_graft(self):
         r = self.get_repo_with_grafts({self._repo.head(): []})
         r = self.get_repo_with_grafts({self._repo.head(): []})
         r._remove_graftpoints([self._repo.head()])
         r._remove_graftpoints([self._repo.head()])
 
 
-        self.assertEqual([e.commit.id for e in r.get_walker()],
-                         self._shas[::-1])
+        self.assertEqual([e.commit.id for e in r.get_walker()], self._shas[::-1])
 
 
     def test_object_store_fail_invalid_parents(self):
     def test_object_store_fail_invalid_parents(self):
         r = self._repo
         r = self._repo
 
 
         self.assertRaises(
         self.assertRaises(
-            ObjectFormatException,
-            r._add_graftpoints,
-            {self._shas[-1]: ['1']})
+            ObjectFormatException, r._add_graftpoints, {self._shas[-1]: ["1"]}
+        )
 
 
 
 
 class GraftsInRepoTests(GraftsInRepositoryBase, TestCase):
 class GraftsInRepoTests(GraftsInRepositoryBase, TestCase):
-
     def setUp(self):
     def setUp(self):
         super(GraftsInRepoTests, self).setUp()
         super(GraftsInRepoTests, self).setUp()
         self._repo_dir = os.path.join(tempfile.mkdtemp())
         self._repo_dir = os.path.join(tempfile.mkdtemp())
@@ -148,24 +156,21 @@ class GraftsInRepoTests(GraftsInRepositoryBase, TestCase):
         self._shas = []
         self._shas = []
 
 
         commit_kwargs = {
         commit_kwargs = {
-            'committer': b'Test Committer <test@nodomain.com>',
-            'author': b'Test Author <test@nodomain.com>',
-            'commit_timestamp': 12395,
-            'commit_timezone': 0,
-            'author_timestamp': 12395,
-            'author_timezone': 0,
+            "committer": b"Test Committer <test@nodomain.com>",
+            "author": b"Test Author <test@nodomain.com>",
+            "commit_timestamp": 12395,
+            "commit_timezone": 0,
+            "author_timestamp": 12395,
+            "author_timezone": 0,
         }
         }
 
 
-        self._shas.append(r.do_commit(
-            b'empty commit', **commit_kwargs))
-        self._shas.append(r.do_commit(
-            b'empty commit', **commit_kwargs))
-        self._shas.append(r.do_commit(
-            b'empty commit', **commit_kwargs))
+        self._shas.append(r.do_commit(b"empty commit", **commit_kwargs))
+        self._shas.append(r.do_commit(b"empty commit", **commit_kwargs))
+        self._shas.append(r.do_commit(b"empty commit", **commit_kwargs))
 
 
     def test_init_with_empty_info_grafts(self):
     def test_init_with_empty_info_grafts(self):
         r = self._repo
         r = self._repo
-        r._put_named_file(os.path.join('info', 'grafts'), b'')
+        r._put_named_file(os.path.join("info", "grafts"), b"")
 
 
         r = Repo(self._repo_dir)
         r = Repo(self._repo_dir)
         self.assertEqual({}, r._graftpoints)
         self.assertEqual({}, r._graftpoints)
@@ -173,15 +178,15 @@ class GraftsInRepoTests(GraftsInRepositoryBase, TestCase):
     def test_init_with_info_grafts(self):
     def test_init_with_info_grafts(self):
         r = self._repo
         r = self._repo
         r._put_named_file(
         r._put_named_file(
-            os.path.join('info', 'grafts'),
-            self._shas[-1] + b' ' + self._shas[0])
+            os.path.join("info", "grafts"),
+            self._shas[-1] + b" " + self._shas[0],
+        )
 
 
         r = Repo(self._repo_dir)
         r = Repo(self._repo_dir)
         self.assertEqual({self._shas[-1]: [self._shas[0]]}, r._graftpoints)
         self.assertEqual({self._shas[-1]: [self._shas[0]]}, r._graftpoints)
 
 
 
 
 class GraftsInMemoryRepoTests(GraftsInRepositoryBase, TestCase):
 class GraftsInMemoryRepoTests(GraftsInRepositoryBase, TestCase):
-
     def setUp(self):
     def setUp(self):
         super(GraftsInMemoryRepoTests, self).setUp()
         super(GraftsInMemoryRepoTests, self).setUp()
         r = self._repo = MemoryRepo()
         r = self._repo = MemoryRepo()
@@ -191,18 +196,15 @@ class GraftsInMemoryRepoTests(GraftsInRepositoryBase, TestCase):
         tree = Tree()
         tree = Tree()
 
 
         commit_kwargs = {
         commit_kwargs = {
-            'committer': b'Test Committer <test@nodomain.com>',
-            'author': b'Test Author <test@nodomain.com>',
-            'commit_timestamp': 12395,
-            'commit_timezone': 0,
-            'author_timestamp': 12395,
-            'author_timezone': 0,
-            'tree': tree.id
+            "committer": b"Test Committer <test@nodomain.com>",
+            "author": b"Test Author <test@nodomain.com>",
+            "commit_timestamp": 12395,
+            "commit_timezone": 0,
+            "author_timestamp": 12395,
+            "author_timezone": 0,
+            "tree": tree.id,
         }
         }
 
 
-        self._shas.append(r.do_commit(
-            b'empty commit', **commit_kwargs))
-        self._shas.append(r.do_commit(
-            b'empty commit', **commit_kwargs))
-        self._shas.append(r.do_commit(
-            b'empty commit', **commit_kwargs))
+        self._shas.append(r.do_commit(b"empty commit", **commit_kwargs))
+        self._shas.append(r.do_commit(b"empty commit", **commit_kwargs))
+        self._shas.append(r.do_commit(b"empty commit", **commit_kwargs))

+ 76 - 77
dulwich/tests/test_graph.py

@@ -29,11 +29,11 @@ from dulwich.graph import _find_lcas, can_fast_forward
 
 
 
 
 class FindMergeBaseTests(TestCase):
 class FindMergeBaseTests(TestCase):
-
     @staticmethod
     @staticmethod
     def run_test(dag, inputs):
     def run_test(dag, inputs):
         def lookup_parents(commit_id):
         def lookup_parents(commit_id):
             return dag[commit_id]
             return dag[commit_id]
+
         c1 = inputs[0]
         c1 = inputs[0]
         c2s = inputs[1:]
         c2s = inputs[1:]
         return set(_find_lcas(lookup_parents, c1, c2s))
         return set(_find_lcas(lookup_parents, c1, c2s))
@@ -41,125 +41,125 @@ class FindMergeBaseTests(TestCase):
     def test_multiple_lca(self):
     def test_multiple_lca(self):
         # two lowest common ancestors
         # two lowest common ancestors
         graph = {
         graph = {
-            '5': ['1', '2'],
-            '4': ['3', '1'],
-            '3': ['2'],
-            '2': ['0'],
-            '1': [],
-            '0': []
+            "5": ["1", "2"],
+            "4": ["3", "1"],
+            "3": ["2"],
+            "2": ["0"],
+            "1": [],
+            "0": [],
         }
         }
-        self.assertEqual(self.run_test(graph, ['4', '5']), set(['1', '2']))
+        self.assertEqual(self.run_test(graph, ["4", "5"]), set(["1", "2"]))
 
 
     def test_no_common_ancestor(self):
     def test_no_common_ancestor(self):
         # no common ancestor
         # no common ancestor
         graph = {
         graph = {
-            '4': ['2'],
-            '3': ['1'],
-            '2': [],
-            '1': ['0'],
-            '0': [],
+            "4": ["2"],
+            "3": ["1"],
+            "2": [],
+            "1": ["0"],
+            "0": [],
         }
         }
-        self.assertEqual(self.run_test(graph, ['4', '3']), set([]))
+        self.assertEqual(self.run_test(graph, ["4", "3"]), set([]))
 
 
     def test_ancestor(self):
     def test_ancestor(self):
         # ancestor
         # ancestor
         graph = {
         graph = {
-            'G': ['D', 'F'],
-            'F': ['E'],
-            'D': ['C'],
-            'C': ['B'],
-            'E': ['B'],
-            'B': ['A'],
-            'A': []
+            "G": ["D", "F"],
+            "F": ["E"],
+            "D": ["C"],
+            "C": ["B"],
+            "E": ["B"],
+            "B": ["A"],
+            "A": [],
         }
         }
-        self.assertEqual(self.run_test(graph, ['D', 'C']), set(['C']))
+        self.assertEqual(self.run_test(graph, ["D", "C"]), set(["C"]))
 
 
     def test_direct_parent(self):
     def test_direct_parent(self):
         # parent
         # parent
         graph = {
         graph = {
-            'G': ['D', 'F'],
-            'F': ['E'],
-            'D': ['C'],
-            'C': ['B'],
-            'E': ['B'],
-            'B': ['A'],
-            'A': []
+            "G": ["D", "F"],
+            "F": ["E"],
+            "D": ["C"],
+            "C": ["B"],
+            "E": ["B"],
+            "B": ["A"],
+            "A": [],
         }
         }
-        self.assertEqual(self.run_test(graph, ['G', 'D']), set(['D']))
+        self.assertEqual(self.run_test(graph, ["G", "D"]), set(["D"]))
 
 
     def test_another_crossover(self):
     def test_another_crossover(self):
         # Another cross over
         # Another cross over
         graph = {
         graph = {
-            'G': ['D', 'F'],
-            'F': ['E', 'C'],
-            'D': ['C', 'E'],
-            'C': ['B'],
-            'E': ['B'],
-            'B': ['A'],
-            'A': []
+            "G": ["D", "F"],
+            "F": ["E", "C"],
+            "D": ["C", "E"],
+            "C": ["B"],
+            "E": ["B"],
+            "B": ["A"],
+            "A": [],
         }
         }
-        self.assertEqual(self.run_test(graph, ['D', 'F']), set(['E', 'C']))
+        self.assertEqual(self.run_test(graph, ["D", "F"]), set(["E", "C"]))
 
 
     def test_three_way_merge_lca(self):
     def test_three_way_merge_lca(self):
         # three way merge commit straight from git docs
         # three way merge commit straight from git docs
         graph = {
         graph = {
-            'C': ['C1'],
-            'C1': ['C2'],
-            'C2': ['C3'],
-            'C3': ['C4'],
-            'C4': ['2'],
-            'B': ['B1'],
-            'B1': ['B2'],
-            'B2': ['B3'],
-            'B3': ['1'],
-            'A': ['A1'],
-            'A1': ['A2'],
-            'A2': ['A3'],
-            'A3': ['1'],
-            '1': ['2'],
-            '2': [],
+            "C": ["C1"],
+            "C1": ["C2"],
+            "C2": ["C3"],
+            "C3": ["C4"],
+            "C4": ["2"],
+            "B": ["B1"],
+            "B1": ["B2"],
+            "B2": ["B3"],
+            "B3": ["1"],
+            "A": ["A1"],
+            "A1": ["A2"],
+            "A2": ["A3"],
+            "A3": ["1"],
+            "1": ["2"],
+            "2": [],
         }
         }
         # assumes a theoretical merge M exists that merges B and C first
         # assumes a theoretical merge M exists that merges B and C first
         # which actually means find the first LCA from either of B OR C with A
         # which actually means find the first LCA from either of B OR C with A
-        self.assertEqual(self.run_test(graph, ['A', 'B', 'C']), set(['1']))
+        self.assertEqual(self.run_test(graph, ["A", "B", "C"]), set(["1"]))
 
 
     def test_octopus(self):
     def test_octopus(self):
         # octopus algorithm test
         # octopus algorithm test
         # test straight from git docs of A, B, and C
         # test straight from git docs of A, B, and C
         # but this time use octopus to find lcas of A, B, and C simultaneously
         # but this time use octopus to find lcas of A, B, and C simultaneously
         graph = {
         graph = {
-            'C': ['C1'],
-            'C1': ['C2'],
-            'C2': ['C3'],
-            'C3': ['C4'],
-            'C4': ['2'],
-            'B': ['B1'],
-            'B1': ['B2'],
-            'B2': ['B3'],
-            'B3': ['1'],
-            'A': ['A1'],
-            'A1': ['A2'],
-            'A2': ['A3'],
-            'A3': ['1'],
-            '1': ['2'],
-            '2': [],
+            "C": ["C1"],
+            "C1": ["C2"],
+            "C2": ["C3"],
+            "C3": ["C4"],
+            "C4": ["2"],
+            "B": ["B1"],
+            "B1": ["B2"],
+            "B2": ["B3"],
+            "B3": ["1"],
+            "A": ["A1"],
+            "A1": ["A2"],
+            "A2": ["A3"],
+            "A3": ["1"],
+            "1": ["2"],
+            "2": [],
         }
         }
 
 
         def lookup_parents(cid):
         def lookup_parents(cid):
             return graph[cid]
             return graph[cid]
-        lcas = ['A']
-        others = ['B', 'C']
+
+        lcas = ["A"]
+        others = ["B", "C"]
         for cmt in others:
         for cmt in others:
             next_lcas = []
             next_lcas = []
             for ca in lcas:
             for ca in lcas:
                 res = _find_lcas(lookup_parents, cmt, [ca])
                 res = _find_lcas(lookup_parents, cmt, [ca])
                 next_lcas.extend(res)
                 next_lcas.extend(res)
             lcas = next_lcas[:]
             lcas = next_lcas[:]
-        self.assertEqual(set(lcas), set(['2']))
+        self.assertEqual(set(lcas), set(["2"]))
 
 
 
 
 class CanFastForwardTests(TestCase):
 class CanFastForwardTests(TestCase):
-
     def test_ff(self):
     def test_ff(self):
         r = MemoryRepo()
         r = MemoryRepo()
         base = make_commit()
         base = make_commit()
@@ -175,10 +175,9 @@ class CanFastForwardTests(TestCase):
         r = MemoryRepo()
         r = MemoryRepo()
         base = make_commit()
         base = make_commit()
         c1 = make_commit(parents=[base.id])
         c1 = make_commit(parents=[base.id])
-        c2a = make_commit(parents=[c1.id], message=b'2a')
-        c2b = make_commit(parents=[c1.id], message=b'2b')
-        r.object_store.add_objects(
-            [(base, None), (c1, None), (c2a, None), (c2b, None)])
+        c2a = make_commit(parents=[c1.id], message=b"2a")
+        c2b = make_commit(parents=[c1.id], message=b"2b")
+        r.object_store.add_objects([(base, None), (c1, None), (c2a, None), (c2b, None)])
         self.assertTrue(can_fast_forward(r, c1.id, c2a.id))
         self.assertTrue(can_fast_forward(r, c1.id, c2a.id))
         self.assertTrue(can_fast_forward(r, c1.id, c2b.id))
         self.assertTrue(can_fast_forward(r, c1.id, c2b.id))
         self.assertFalse(can_fast_forward(r, c2a.id, c2b.id))
         self.assertFalse(can_fast_forward(r, c2a.id, c2b.id))

+ 18 - 15
dulwich/tests/test_greenthreads.py

@@ -25,20 +25,21 @@ import time
 from dulwich.tests import (
 from dulwich.tests import (
     skipIf,
     skipIf,
     TestCase,
     TestCase,
-    )
+)
 from dulwich.object_store import (
 from dulwich.object_store import (
     MemoryObjectStore,
     MemoryObjectStore,
     MissingObjectFinder,
     MissingObjectFinder,
-    )
+)
 from dulwich.objects import (
 from dulwich.objects import (
     Commit,
     Commit,
     Blob,
     Blob,
     Tree,
     Tree,
     parse_timezone,
     parse_timezone,
-    )
+)
 
 
 try:
 try:
     import gevent  # noqa: F401
     import gevent  # noqa: F401
+
     gevent_support = True
     gevent_support = True
 except ImportError:
 except ImportError:
     gevent_support = False
     gevent_support = False
@@ -53,14 +54,14 @@ skipmsg = "Gevent library is not installed"
 
 
 
 
 def create_commit(marker=None):
 def create_commit(marker=None):
-    blob = Blob.from_string(b'The blob content ' + marker)
+    blob = Blob.from_string(b"The blob content " + marker)
     tree = Tree()
     tree = Tree()
     tree.add(b"thefile " + marker, 0o100644, blob.id)
     tree.add(b"thefile " + marker, 0o100644, blob.id)
     cmt = Commit()
     cmt = Commit()
     cmt.tree = tree.id
     cmt.tree = tree.id
     cmt.author = cmt.committer = b"John Doe <john@doe.net>"
     cmt.author = cmt.committer = b"John Doe <john@doe.net>"
     cmt.message = marker
     cmt.message = marker
-    tz = parse_timezone(b'-0200')[0]
+    tz = parse_timezone(b"-0200")[0]
     cmt.commit_time = cmt.author_time = int(time.time())
     cmt.commit_time = cmt.author_time = int(time.time())
     cmt.commit_timezone = cmt.author_timezone = tz
     cmt.commit_timezone = cmt.author_timezone = tz
     return cmt, tree, blob
     return cmt, tree, blob
@@ -69,7 +70,7 @@ def create_commit(marker=None):
 def init_store(store, count=1):
 def init_store(store, count=1):
     ret = []
     ret = []
     for i in range(0, count):
     for i in range(0, count):
-        objs = create_commit(marker=("%d" % i).encode('ascii'))
+        objs = create_commit(marker=("%d" % i).encode("ascii"))
         for obj in objs:
         for obj in objs:
             ret.append(obj)
             ret.append(obj)
             store.add_object(obj)
             store.add_object(obj)
@@ -78,7 +79,6 @@ def init_store(store, count=1):
 
 
 @skipIf(not gevent_support, skipmsg)
 @skipIf(not gevent_support, skipmsg)
 class TestGreenThreadsObjectStoreIterator(TestCase):
 class TestGreenThreadsObjectStoreIterator(TestCase):
-
     def setUp(self):
     def setUp(self):
         super(TestGreenThreadsObjectStoreIterator, self).setUp()
         super(TestGreenThreadsObjectStoreIterator, self).setUp()
         self.store = MemoryObjectStore()
         self.store = MemoryObjectStore()
@@ -89,20 +89,23 @@ class TestGreenThreadsObjectStoreIterator(TestCase):
         wants = [sha.id for sha in self.objs if isinstance(sha, Commit)]
         wants = [sha.id for sha in self.objs if isinstance(sha, Commit)]
         finder = MissingObjectFinder(self.store, (), wants)
         finder = MissingObjectFinder(self.store, (), wants)
         iterator = GreenThreadsObjectStoreIterator(
         iterator = GreenThreadsObjectStoreIterator(
-                self.store, iter(finder.next, None), finder)
+            self.store, iter(finder.next, None), finder
+        )
         # One commit refers one tree and one blob
         # One commit refers one tree and one blob
         self.assertEqual(len(iterator), self.cmt_amount * 3)
         self.assertEqual(len(iterator), self.cmt_amount * 3)
-        haves = wants[0:self.cmt_amount-1]
+        haves = wants[0 : self.cmt_amount - 1]
         finder = MissingObjectFinder(self.store, haves, wants)
         finder = MissingObjectFinder(self.store, haves, wants)
         iterator = GreenThreadsObjectStoreIterator(
         iterator = GreenThreadsObjectStoreIterator(
-            self.store, iter(finder.next, None), finder)
+            self.store, iter(finder.next, None), finder
+        )
         self.assertEqual(len(iterator), 3)
         self.assertEqual(len(iterator), 3)
 
 
     def test_iter(self):
     def test_iter(self):
         wants = [sha.id for sha in self.objs if isinstance(sha, Commit)]
         wants = [sha.id for sha in self.objs if isinstance(sha, Commit)]
         finder = MissingObjectFinder(self.store, (), wants)
         finder = MissingObjectFinder(self.store, (), wants)
         iterator = GreenThreadsObjectStoreIterator(
         iterator = GreenThreadsObjectStoreIterator(
-            self.store, iter(finder.next, None), finder)
+            self.store, iter(finder.next, None), finder
+        )
         objs = []
         objs = []
         for sha, path in iterator:
         for sha, path in iterator:
             self.assertIn(sha, self.objs)
             self.assertIn(sha, self.objs)
@@ -112,7 +115,6 @@ class TestGreenThreadsObjectStoreIterator(TestCase):
 
 
 @skipIf(not gevent_support, skipmsg)
 @skipIf(not gevent_support, skipmsg)
 class TestGreenThreadsMissingObjectFinder(TestCase):
 class TestGreenThreadsMissingObjectFinder(TestCase):
-
     def setUp(self):
     def setUp(self):
         super(TestGreenThreadsMissingObjectFinder, self).setUp()
         super(TestGreenThreadsMissingObjectFinder, self).setUp()
         self.store = MemoryObjectStore()
         self.store = MemoryObjectStore()
@@ -126,7 +128,8 @@ class TestGreenThreadsMissingObjectFinder(TestCase):
         self.assertEqual(len(finder.objects_to_send), self.cmt_amount)
         self.assertEqual(len(finder.objects_to_send), self.cmt_amount)
 
 
         finder = GreenThreadsMissingObjectFinder(
         finder = GreenThreadsMissingObjectFinder(
-            self.store, wants[0:int(self.cmt_amount/2)], wants)
+            self.store, wants[0 : int(self.cmt_amount / 2)], wants
+        )
         # sha_done will contains commit id and sha of blob refered in tree
         # sha_done will contains commit id and sha of blob refered in tree
-        self.assertEqual(len(finder.sha_done), (self.cmt_amount/2)*2)
-        self.assertEqual(len(finder.objects_to_send), self.cmt_amount/2)
+        self.assertEqual(len(finder.sha_done), (self.cmt_amount / 2) * 2)
+        self.assertEqual(len(finder.objects_to_send), self.cmt_amount / 2)

+ 51 - 34
dulwich/tests/test_hooks.py

@@ -37,16 +37,15 @@ from dulwich.tests import TestCase
 
 
 
 
 class ShellHookTests(TestCase):
 class ShellHookTests(TestCase):
-
     def setUp(self):
     def setUp(self):
         super(ShellHookTests, self).setUp()
         super(ShellHookTests, self).setUp()
-        if os.name != 'posix':
-            self.skipTest('shell hook tests requires POSIX shell')
-        self.assertTrue(os.path.exists('/bin/sh'))
+        if os.name != "posix":
+            self.skipTest("shell hook tests requires POSIX shell")
+        self.assertTrue(os.path.exists("/bin/sh"))
 
 
     def test_hook_pre_commit(self):
     def test_hook_pre_commit(self):
         repo_dir = os.path.join(tempfile.mkdtemp())
         repo_dir = os.path.join(tempfile.mkdtemp())
-        os.mkdir(os.path.join(repo_dir, 'hooks'))
+        os.mkdir(os.path.join(repo_dir, "hooks"))
         self.addCleanup(shutil.rmtree, repo_dir)
         self.addCleanup(shutil.rmtree, repo_dir)
 
 
         pre_commit_fail = """#!/bin/sh
         pre_commit_fail = """#!/bin/sh
@@ -56,34 +55,40 @@ exit 1
         pre_commit_success = """#!/bin/sh
         pre_commit_success = """#!/bin/sh
 exit 0
 exit 0
 """
 """
-        pre_commit_cwd = """#!/bin/sh
-if [ "$(pwd)" != '""" + repo_dir + """' ]; then
-    echo "Expected path '""" + repo_dir + """', got '$(pwd)'"
+        pre_commit_cwd = (
+            """#!/bin/sh
+if [ "$(pwd)" != '"""
+            + repo_dir
+            + """' ]; then
+    echo "Expected path '"""
+            + repo_dir
+            + """', got '$(pwd)'"
     exit 1
     exit 1
 fi
 fi
 
 
 exit 0
 exit 0
 """
 """
+        )
 
 
-        pre_commit = os.path.join(repo_dir, 'hooks', 'pre-commit')
+        pre_commit = os.path.join(repo_dir, "hooks", "pre-commit")
         hook = PreCommitShellHook(repo_dir)
         hook = PreCommitShellHook(repo_dir)
 
 
-        with open(pre_commit, 'w') as f:
+        with open(pre_commit, "w") as f:
             f.write(pre_commit_fail)
             f.write(pre_commit_fail)
         os.chmod(pre_commit, stat.S_IREAD | stat.S_IWRITE | stat.S_IEXEC)
         os.chmod(pre_commit, stat.S_IREAD | stat.S_IWRITE | stat.S_IEXEC)
 
 
         self.assertRaises(errors.HookError, hook.execute)
         self.assertRaises(errors.HookError, hook.execute)
 
 
-        if sys.platform != 'darwin':
+        if sys.platform != "darwin":
             # Don't bother running this test on darwin since path
             # Don't bother running this test on darwin since path
             # canonicalization messages with our simple string comparison.
             # canonicalization messages with our simple string comparison.
-            with open(pre_commit, 'w') as f:
+            with open(pre_commit, "w") as f:
                 f.write(pre_commit_cwd)
                 f.write(pre_commit_cwd)
             os.chmod(pre_commit, stat.S_IREAD | stat.S_IWRITE | stat.S_IEXEC)
             os.chmod(pre_commit, stat.S_IREAD | stat.S_IWRITE | stat.S_IEXEC)
 
 
             hook.execute()
             hook.execute()
 
 
-        with open(pre_commit, 'w') as f:
+        with open(pre_commit, "w") as f:
             f.write(pre_commit_success)
             f.write(pre_commit_success)
         os.chmod(pre_commit, stat.S_IREAD | stat.S_IWRITE | stat.S_IEXEC)
         os.chmod(pre_commit, stat.S_IREAD | stat.S_IWRITE | stat.S_IEXEC)
 
 
@@ -92,7 +97,7 @@ exit 0
     def test_hook_commit_msg(self):
     def test_hook_commit_msg(self):
 
 
         repo_dir = os.path.join(tempfile.mkdtemp())
         repo_dir = os.path.join(tempfile.mkdtemp())
-        os.mkdir(os.path.join(repo_dir, 'hooks'))
+        os.mkdir(os.path.join(repo_dir, "hooks"))
         self.addCleanup(shutil.rmtree, repo_dir)
         self.addCleanup(shutil.rmtree, repo_dir)
 
 
         commit_msg_fail = """#!/bin/sh
         commit_msg_fail = """#!/bin/sh
@@ -103,32 +108,36 @@ exit 1
 exit 0
 exit 0
 """
 """
 
 
-        commit_msg_cwd = """#!/bin/sh
-if [ "$(pwd)" = '""" + repo_dir + "' ]; then exit 0; else exit 1; fi\n"
+        commit_msg_cwd = (
+            """#!/bin/sh
+if [ "$(pwd)" = '"""
+            + repo_dir
+            + "' ]; then exit 0; else exit 1; fi\n"
+        )
 
 
-        commit_msg = os.path.join(repo_dir, 'hooks', 'commit-msg')
+        commit_msg = os.path.join(repo_dir, "hooks", "commit-msg")
         hook = CommitMsgShellHook(repo_dir)
         hook = CommitMsgShellHook(repo_dir)
 
 
-        with open(commit_msg, 'w') as f:
+        with open(commit_msg, "w") as f:
             f.write(commit_msg_fail)
             f.write(commit_msg_fail)
         os.chmod(commit_msg, stat.S_IREAD | stat.S_IWRITE | stat.S_IEXEC)
         os.chmod(commit_msg, stat.S_IREAD | stat.S_IWRITE | stat.S_IEXEC)
 
 
-        self.assertRaises(errors.HookError, hook.execute, b'failed commit')
+        self.assertRaises(errors.HookError, hook.execute, b"failed commit")
 
 
-        if sys.platform != 'darwin':
+        if sys.platform != "darwin":
             # Don't bother running this test on darwin since path
             # Don't bother running this test on darwin since path
             # canonicalization messages with our simple string comparison.
             # canonicalization messages with our simple string comparison.
-            with open(commit_msg, 'w') as f:
+            with open(commit_msg, "w") as f:
                 f.write(commit_msg_cwd)
                 f.write(commit_msg_cwd)
             os.chmod(commit_msg, stat.S_IREAD | stat.S_IWRITE | stat.S_IEXEC)
             os.chmod(commit_msg, stat.S_IREAD | stat.S_IWRITE | stat.S_IEXEC)
 
 
-            hook.execute(b'cwd test commit')
+            hook.execute(b"cwd test commit")
 
 
-        with open(commit_msg, 'w') as f:
+        with open(commit_msg, "w") as f:
             f.write(commit_msg_success)
             f.write(commit_msg_success)
         os.chmod(commit_msg, stat.S_IREAD | stat.S_IWRITE | stat.S_IEXEC)
         os.chmod(commit_msg, stat.S_IREAD | stat.S_IWRITE | stat.S_IEXEC)
 
 
-        hook.execute(b'empty commit')
+        hook.execute(b"empty commit")
 
 
     def test_hook_post_commit(self):
     def test_hook_post_commit(self):
 
 
@@ -136,38 +145,46 @@ if [ "$(pwd)" = '""" + repo_dir + "' ]; then exit 0; else exit 1; fi\n"
         os.close(fd)
         os.close(fd)
 
 
         repo_dir = os.path.join(tempfile.mkdtemp())
         repo_dir = os.path.join(tempfile.mkdtemp())
-        os.mkdir(os.path.join(repo_dir, 'hooks'))
+        os.mkdir(os.path.join(repo_dir, "hooks"))
         self.addCleanup(shutil.rmtree, repo_dir)
         self.addCleanup(shutil.rmtree, repo_dir)
 
 
-        post_commit_success = """#!/bin/sh
-rm """ + path + "\n"
+        post_commit_success = (
+            """#!/bin/sh
+rm """
+            + path
+            + "\n"
+        )
 
 
         post_commit_fail = """#!/bin/sh
         post_commit_fail = """#!/bin/sh
 exit 1
 exit 1
 """
 """
 
 
-        post_commit_cwd = """#!/bin/sh
-if [ "$(pwd)" = '""" + repo_dir + "' ]; then exit 0; else exit 1; fi\n"
+        post_commit_cwd = (
+            """#!/bin/sh
+if [ "$(pwd)" = '"""
+            + repo_dir
+            + "' ]; then exit 0; else exit 1; fi\n"
+        )
 
 
-        post_commit = os.path.join(repo_dir, 'hooks', 'post-commit')
+        post_commit = os.path.join(repo_dir, "hooks", "post-commit")
         hook = PostCommitShellHook(repo_dir)
         hook = PostCommitShellHook(repo_dir)
 
 
-        with open(post_commit, 'w') as f:
+        with open(post_commit, "w") as f:
             f.write(post_commit_fail)
             f.write(post_commit_fail)
         os.chmod(post_commit, stat.S_IREAD | stat.S_IWRITE | stat.S_IEXEC)
         os.chmod(post_commit, stat.S_IREAD | stat.S_IWRITE | stat.S_IEXEC)
 
 
         self.assertRaises(errors.HookError, hook.execute)
         self.assertRaises(errors.HookError, hook.execute)
 
 
-        if sys.platform != 'darwin':
+        if sys.platform != "darwin":
             # Don't bother running this test on darwin since path
             # Don't bother running this test on darwin since path
             # canonicalization messages with our simple string comparison.
             # canonicalization messages with our simple string comparison.
-            with open(post_commit, 'w') as f:
+            with open(post_commit, "w") as f:
                 f.write(post_commit_cwd)
                 f.write(post_commit_cwd)
             os.chmod(post_commit, stat.S_IREAD | stat.S_IWRITE | stat.S_IEXEC)
             os.chmod(post_commit, stat.S_IREAD | stat.S_IWRITE | stat.S_IEXEC)
 
 
             hook.execute()
             hook.execute()
 
 
-        with open(post_commit, 'w') as f:
+        with open(post_commit, "w") as f:
             f.write(post_commit_success)
             f.write(post_commit_success)
         os.chmod(post_commit, stat.S_IREAD | stat.S_IWRITE | stat.S_IEXEC)
         os.chmod(post_commit, stat.S_IREAD | stat.S_IWRITE | stat.S_IEXEC)
 
 

+ 126 - 115
dulwich/tests/test_ignore.py

@@ -35,7 +35,7 @@ from dulwich.ignore import (
     match_pattern,
     match_pattern,
     read_ignore_patterns,
     read_ignore_patterns,
     translate,
     translate,
-    )
+)
 from dulwich.repo import Repo
 from dulwich.repo import Repo
 
 
 
 
@@ -65,44 +65,45 @@ NEGATIVE_MATCH_TESTS = [
     (b"foo/foo.c", b"/*.c"),
     (b"foo/foo.c", b"/*.c"),
     (b"foo/bar/", b"/bar/"),
     (b"foo/bar/", b"/bar/"),
     (b"foo/bar/", b"foo/bar/*"),
     (b"foo/bar/", b"foo/bar/*"),
-    (b"foo/bar", b"foo?bar")
+    (b"foo/bar", b"foo?bar"),
 ]
 ]
 
 
 
 
 TRANSLATE_TESTS = [
 TRANSLATE_TESTS = [
-    (b"*.c", b'(?ms)(.*/)?[^/]*\\.c/?\\Z'),
-    (b"foo.c", b'(?ms)(.*/)?foo\\.c/?\\Z'),
-    (b"/*.c", b'(?ms)[^/]*\\.c/?\\Z'),
-    (b"/foo.c", b'(?ms)foo\\.c/?\\Z'),
-    (b"foo.c", b'(?ms)(.*/)?foo\\.c/?\\Z'),
-    (b"foo.[ch]", b'(?ms)(.*/)?foo\\.[ch]/?\\Z'),
-    (b"bar/", b'(?ms)(.*/)?bar\\/\\Z'),
-    (b"foo/**", b'(?ms)foo(/.*)?/?\\Z'),
-    (b"foo/**/blie.c", b'(?ms)foo(/.*)?\\/blie\\.c/?\\Z'),
-    (b"**/bla.c", b'(?ms)(.*/)?bla\\.c/?\\Z'),
-    (b"foo/**/bar", b'(?ms)foo(/.*)?\\/bar/?\\Z'),
-    (b"foo/bar/*", b'(?ms)foo\\/bar\\/[^/]+/?\\Z'),
+    (b"*.c", b"(?ms)(.*/)?[^/]*\\.c/?\\Z"),
+    (b"foo.c", b"(?ms)(.*/)?foo\\.c/?\\Z"),
+    (b"/*.c", b"(?ms)[^/]*\\.c/?\\Z"),
+    (b"/foo.c", b"(?ms)foo\\.c/?\\Z"),
+    (b"foo.c", b"(?ms)(.*/)?foo\\.c/?\\Z"),
+    (b"foo.[ch]", b"(?ms)(.*/)?foo\\.[ch]/?\\Z"),
+    (b"bar/", b"(?ms)(.*/)?bar\\/\\Z"),
+    (b"foo/**", b"(?ms)foo(/.*)?/?\\Z"),
+    (b"foo/**/blie.c", b"(?ms)foo(/.*)?\\/blie\\.c/?\\Z"),
+    (b"**/bla.c", b"(?ms)(.*/)?bla\\.c/?\\Z"),
+    (b"foo/**/bar", b"(?ms)foo(/.*)?\\/bar/?\\Z"),
+    (b"foo/bar/*", b"(?ms)foo\\/bar\\/[^/]+/?\\Z"),
 ]
 ]
 
 
 
 
 class TranslateTests(TestCase):
 class TranslateTests(TestCase):
-
     def test_translate(self):
     def test_translate(self):
         for (pattern, regex) in TRANSLATE_TESTS:
         for (pattern, regex) in TRANSLATE_TESTS:
-            if re.escape(b'/') == b'/':
+            if re.escape(b"/") == b"/":
                 # Slash is no longer escaped in Python3.7, so undo the escaping
                 # Slash is no longer escaped in Python3.7, so undo the escaping
                 # in the expected return value..
                 # in the expected return value..
-                regex = regex.replace(b'\\/', b'/')
+                regex = regex.replace(b"\\/", b"/")
             self.assertEqual(
             self.assertEqual(
-                regex, translate(pattern),
-                "orig pattern: %r, regex: %r, expected: %r" %
-                (pattern, translate(pattern), regex))
+                regex,
+                translate(pattern),
+                "orig pattern: %r, regex: %r, expected: %r"
+                % (pattern, translate(pattern), regex),
+            )
 
 
 
 
 class ReadIgnorePatterns(TestCase):
 class ReadIgnorePatterns(TestCase):
-
     def test_read_file(self):
     def test_read_file(self):
-        f = BytesIO(b"""
+        f = BytesIO(
+            b"""
 # a comment
 # a comment
 
 
 # and an empty line:
 # and an empty line:
@@ -111,151 +112,161 @@ class ReadIgnorePatterns(TestCase):
 !negative
 !negative
 with trailing whitespace 
 with trailing whitespace 
 with escaped trailing whitespace\\ 
 with escaped trailing whitespace\\ 
-""")  # noqa: W291
-        self.assertEqual(list(read_ignore_patterns(f)), [
-            b'\\#not a comment',
-            b'!negative',
-            b'with trailing whitespace',
-            b'with escaped trailing whitespace '
-        ])
+"""
+        )  # noqa: W291
+        self.assertEqual(
+            list(read_ignore_patterns(f)),
+            [
+                b"\\#not a comment",
+                b"!negative",
+                b"with trailing whitespace",
+                b"with escaped trailing whitespace ",
+            ],
+        )
 
 
 
 
 class MatchPatternTests(TestCase):
 class MatchPatternTests(TestCase):
-
     def test_matches(self):
     def test_matches(self):
         for (path, pattern) in POSITIVE_MATCH_TESTS:
         for (path, pattern) in POSITIVE_MATCH_TESTS:
             self.assertTrue(
             self.assertTrue(
                 match_pattern(path, pattern),
                 match_pattern(path, pattern),
-                "path: %r, pattern: %r" % (path, pattern))
+                "path: %r, pattern: %r" % (path, pattern),
+            )
 
 
     def test_no_matches(self):
     def test_no_matches(self):
         for (path, pattern) in NEGATIVE_MATCH_TESTS:
         for (path, pattern) in NEGATIVE_MATCH_TESTS:
             self.assertFalse(
             self.assertFalse(
                 match_pattern(path, pattern),
                 match_pattern(path, pattern),
-                "path: %r, pattern: %r" % (path, pattern))
+                "path: %r, pattern: %r" % (path, pattern),
+            )
 
 
 
 
 class IgnoreFilterTests(TestCase):
 class IgnoreFilterTests(TestCase):
-
     def test_included(self):
     def test_included(self):
-        filter = IgnoreFilter([b'a.c', b'b.c'])
-        self.assertTrue(filter.is_ignored(b'a.c'))
-        self.assertIs(None, filter.is_ignored(b'c.c'))
-        self.assertEqual(
-            [Pattern(b'a.c')],
-            list(filter.find_matching(b'a.c')))
-        self.assertEqual(
-            [],
-            list(filter.find_matching(b'c.c')))
+        filter = IgnoreFilter([b"a.c", b"b.c"])
+        self.assertTrue(filter.is_ignored(b"a.c"))
+        self.assertIs(None, filter.is_ignored(b"c.c"))
+        self.assertEqual([Pattern(b"a.c")], list(filter.find_matching(b"a.c")))
+        self.assertEqual([], list(filter.find_matching(b"c.c")))
 
 
     def test_included_ignorecase(self):
     def test_included_ignorecase(self):
-        filter = IgnoreFilter([b'a.c', b'b.c'], ignorecase=False)
-        self.assertTrue(filter.is_ignored(b'a.c'))
-        self.assertFalse(filter.is_ignored(b'A.c'))
-        filter = IgnoreFilter([b'a.c', b'b.c'], ignorecase=True)
-        self.assertTrue(filter.is_ignored(b'a.c'))
-        self.assertTrue(filter.is_ignored(b'A.c'))
-        self.assertTrue(filter.is_ignored(b'A.C'))
+        filter = IgnoreFilter([b"a.c", b"b.c"], ignorecase=False)
+        self.assertTrue(filter.is_ignored(b"a.c"))
+        self.assertFalse(filter.is_ignored(b"A.c"))
+        filter = IgnoreFilter([b"a.c", b"b.c"], ignorecase=True)
+        self.assertTrue(filter.is_ignored(b"a.c"))
+        self.assertTrue(filter.is_ignored(b"A.c"))
+        self.assertTrue(filter.is_ignored(b"A.C"))
 
 
     def test_excluded(self):
     def test_excluded(self):
-        filter = IgnoreFilter([b'a.c', b'b.c', b'!c.c'])
-        self.assertFalse(filter.is_ignored(b'c.c'))
-        self.assertIs(None, filter.is_ignored(b'd.c'))
-        self.assertEqual(
-            [Pattern(b'!c.c')],
-            list(filter.find_matching(b'c.c')))
-        self.assertEqual([], list(filter.find_matching(b'd.c')))
+        filter = IgnoreFilter([b"a.c", b"b.c", b"!c.c"])
+        self.assertFalse(filter.is_ignored(b"c.c"))
+        self.assertIs(None, filter.is_ignored(b"d.c"))
+        self.assertEqual([Pattern(b"!c.c")], list(filter.find_matching(b"c.c")))
+        self.assertEqual([], list(filter.find_matching(b"d.c")))
 
 
     def test_include_exclude_include(self):
     def test_include_exclude_include(self):
-        filter = IgnoreFilter([b'a.c', b'!a.c', b'a.c'])
-        self.assertTrue(filter.is_ignored(b'a.c'))
+        filter = IgnoreFilter([b"a.c", b"!a.c", b"a.c"])
+        self.assertTrue(filter.is_ignored(b"a.c"))
         self.assertEqual(
         self.assertEqual(
-            [Pattern(b'a.c'), Pattern(b'!a.c'), Pattern(b'a.c')],
-            list(filter.find_matching(b'a.c')))
+            [Pattern(b"a.c"), Pattern(b"!a.c"), Pattern(b"a.c")],
+            list(filter.find_matching(b"a.c")),
+        )
 
 
     def test_manpage(self):
     def test_manpage(self):
         # A specific example from the gitignore manpage
         # A specific example from the gitignore manpage
-        filter = IgnoreFilter([
-            b'/*',
-            b'!/foo',
-            b'/foo/*',
-            b'!/foo/bar'])
-        self.assertTrue(filter.is_ignored(b'a.c'))
-        self.assertTrue(filter.is_ignored(b'foo/blie'))
-        self.assertFalse(filter.is_ignored(b'foo'))
-        self.assertFalse(filter.is_ignored(b'foo/bar'))
-        self.assertFalse(filter.is_ignored(b'foo/bar/'))
-        self.assertFalse(filter.is_ignored(b'foo/bar/bloe'))
+        filter = IgnoreFilter([b"/*", b"!/foo", b"/foo/*", b"!/foo/bar"])
+        self.assertTrue(filter.is_ignored(b"a.c"))
+        self.assertTrue(filter.is_ignored(b"foo/blie"))
+        self.assertFalse(filter.is_ignored(b"foo"))
+        self.assertFalse(filter.is_ignored(b"foo/bar"))
+        self.assertFalse(filter.is_ignored(b"foo/bar/"))
+        self.assertFalse(filter.is_ignored(b"foo/bar/bloe"))
 
 
 
 
 class IgnoreFilterStackTests(TestCase):
 class IgnoreFilterStackTests(TestCase):
-
     def test_stack_first(self):
     def test_stack_first(self):
-        filter1 = IgnoreFilter([b'[a].c', b'[b].c', b'![d].c'])
-        filter2 = IgnoreFilter([b'[a].c', b'![b],c', b'[c].c', b'[d].c'])
+        filter1 = IgnoreFilter([b"[a].c", b"[b].c", b"![d].c"])
+        filter2 = IgnoreFilter([b"[a].c", b"![b],c", b"[c].c", b"[d].c"])
         stack = IgnoreFilterStack([filter1, filter2])
         stack = IgnoreFilterStack([filter1, filter2])
-        self.assertIs(True, stack.is_ignored(b'a.c'))
-        self.assertIs(True, stack.is_ignored(b'b.c'))
-        self.assertIs(True, stack.is_ignored(b'c.c'))
-        self.assertIs(False, stack.is_ignored(b'd.c'))
-        self.assertIs(None, stack.is_ignored(b'e.c'))
+        self.assertIs(True, stack.is_ignored(b"a.c"))
+        self.assertIs(True, stack.is_ignored(b"b.c"))
+        self.assertIs(True, stack.is_ignored(b"c.c"))
+        self.assertIs(False, stack.is_ignored(b"d.c"))
+        self.assertIs(None, stack.is_ignored(b"e.c"))
 
 
 
 
 class IgnoreFilterManagerTests(TestCase):
 class IgnoreFilterManagerTests(TestCase):
-
     def test_load_ignore(self):
     def test_load_ignore(self):
         tmp_dir = tempfile.mkdtemp()
         tmp_dir = tempfile.mkdtemp()
         self.addCleanup(shutil.rmtree, tmp_dir)
         self.addCleanup(shutil.rmtree, tmp_dir)
         repo = Repo.init(tmp_dir)
         repo = Repo.init(tmp_dir)
+        with open(os.path.join(repo.path, ".gitignore"), "wb") as f:
+            f.write(b"/foo/bar\n")
+            f.write(b"/dir2\n")
+            f.write(b"/dir3/\n")
+        os.mkdir(os.path.join(repo.path, "dir"))
+        with open(os.path.join(repo.path, "dir", ".gitignore"), "wb") as f:
+            f.write(b"/blie\n")
+        with open(os.path.join(repo.path, "dir", "blie"), "wb") as f:
+            f.write(b"IGNORED")
+        p = os.path.join(repo.controldir(), "info", "exclude")
+        with open(p, "wb") as f:
+            f.write(b"/excluded\n")
+        m = IgnoreFilterManager.from_repo(repo)
+        self.assertTrue(m.is_ignored("dir/blie"))
+        self.assertIs(None, m.is_ignored(os.path.join("dir", "bloe")))
+        self.assertIs(None, m.is_ignored("dir"))
+        self.assertTrue(m.is_ignored(os.path.join("foo", "bar")))
+        self.assertTrue(m.is_ignored(os.path.join("excluded")))
+        self.assertTrue(m.is_ignored(os.path.join("dir2", "fileinignoreddir")))
+        self.assertFalse(m.is_ignored("dir3"))
+        self.assertTrue(m.is_ignored("dir3/"))
+        self.assertTrue(m.is_ignored("dir3/bla"))
+
+    def test_nested_gitignores(self):
+        tmp_dir = tempfile.mkdtemp()
+        self.addCleanup(shutil.rmtree, tmp_dir)
+        repo = Repo.init(tmp_dir)
+
         with open(os.path.join(repo.path, '.gitignore'), 'wb') as f:
         with open(os.path.join(repo.path, '.gitignore'), 'wb') as f:
-            f.write(b'/foo/bar\n')
-            f.write(b'/dir2\n')
-            f.write(b'/dir3/\n')
-        os.mkdir(os.path.join(repo.path, 'dir'))
-        with open(os.path.join(repo.path, 'dir', '.gitignore'), 'wb') as f:
-            f.write(b'/blie\n')
-        with open(os.path.join(repo.path, 'dir', 'blie'), 'wb') as f:
+            f.write(b'/*\n')
+            f.write(b'!/foo\n')
+
+        os.mkdir(os.path.join(repo.path, 'foo'))
+        with open(os.path.join(repo.path, 'foo', '.gitignore'), 'wb') as f:
+            f.write(b'/bar\n')
+
+        with open(os.path.join(repo.path, 'foo', 'bar'), 'wb') as f:
             f.write(b'IGNORED')
             f.write(b'IGNORED')
-        p = os.path.join(repo.controldir(), 'info', 'exclude')
-        with open(p, 'wb') as f:
-            f.write(b'/excluded\n')
+        
         m = IgnoreFilterManager.from_repo(repo)
         m = IgnoreFilterManager.from_repo(repo)
-        self.assertTrue(m.is_ignored('dir/blie'))
-        self.assertIs(None,
-                      m.is_ignored(os.path.join('dir', 'bloe')))
-        self.assertIs(None, m.is_ignored('dir'))
-        self.assertTrue(m.is_ignored(os.path.join('foo', 'bar')))
-        self.assertTrue(m.is_ignored(os.path.join('excluded')))
-        self.assertTrue(m.is_ignored(os.path.join(
-            'dir2', 'fileinignoreddir')))
-        self.assertFalse(m.is_ignored('dir3'))
-        self.assertTrue(m.is_ignored('dir3/'))
-        self.assertTrue(m.is_ignored('dir3/bla'))
+        self.assertTrue(m.is_ignored('foo/bar'))
 
 
     def test_load_ignore_ignorecase(self):
     def test_load_ignore_ignorecase(self):
         tmp_dir = tempfile.mkdtemp()
         tmp_dir = tempfile.mkdtemp()
         self.addCleanup(shutil.rmtree, tmp_dir)
         self.addCleanup(shutil.rmtree, tmp_dir)
         repo = Repo.init(tmp_dir)
         repo = Repo.init(tmp_dir)
         config = repo.get_config()
         config = repo.get_config()
-        config.set(b'core', b'ignorecase', True)
+        config.set(b"core", b"ignorecase", True)
         config.write_to_path()
         config.write_to_path()
-        with open(os.path.join(repo.path, '.gitignore'), 'wb') as f:
-            f.write(b'/foo/bar\n')
-            f.write(b'/dir\n')
+        with open(os.path.join(repo.path, ".gitignore"), "wb") as f:
+            f.write(b"/foo/bar\n")
+            f.write(b"/dir\n")
         m = IgnoreFilterManager.from_repo(repo)
         m = IgnoreFilterManager.from_repo(repo)
-        self.assertTrue(m.is_ignored(os.path.join('dir', 'blie')))
-        self.assertTrue(m.is_ignored(os.path.join('DIR', 'blie')))
+        self.assertTrue(m.is_ignored(os.path.join("dir", "blie")))
+        self.assertTrue(m.is_ignored(os.path.join("DIR", "blie")))
 
 
     def test_ignored_contents(self):
     def test_ignored_contents(self):
         tmp_dir = tempfile.mkdtemp()
         tmp_dir = tempfile.mkdtemp()
         self.addCleanup(shutil.rmtree, tmp_dir)
         self.addCleanup(shutil.rmtree, tmp_dir)
         repo = Repo.init(tmp_dir)
         repo = Repo.init(tmp_dir)
-        with open(os.path.join(repo.path, '.gitignore'), 'wb') as f:
-            f.write(b'a/*\n')
-            f.write(b'!a/*.txt\n')
+        with open(os.path.join(repo.path, ".gitignore"), "wb") as f:
+            f.write(b"a/*\n")
+            f.write(b"!a/*.txt\n")
         m = IgnoreFilterManager.from_repo(repo)
         m = IgnoreFilterManager.from_repo(repo)
-        os.mkdir(os.path.join(repo.path, 'a'))
-        self.assertIs(None, m.is_ignored('a'))
-        self.assertIs(None, m.is_ignored('a/'))
-        self.assertFalse(m.is_ignored('a/b.txt'))
-        self.assertTrue(m.is_ignored('a/c.dat'))
+        os.mkdir(os.path.join(repo.path, "a"))
+        self.assertIs(None, m.is_ignored("a"))
+        self.assertIs(None, m.is_ignored("a/"))
+        self.assertFalse(m.is_ignored("a/b.txt"))
+        self.assertTrue(m.is_ignored("a/c.dat"))

+ 296 - 221
dulwich/tests/test_index.py

@@ -48,38 +48,39 @@ from dulwich.index import (
     write_index_dict,
     write_index_dict,
     _tree_to_fs_path,
     _tree_to_fs_path,
     _fs_to_tree_path,
     _fs_to_tree_path,
-    )
+    IndexEntry,
+)
 from dulwich.object_store import (
 from dulwich.object_store import (
     MemoryObjectStore,
     MemoryObjectStore,
-    )
+)
 from dulwich.objects import (
 from dulwich.objects import (
     Blob,
     Blob,
     Commit,
     Commit,
     Tree,
     Tree,
     S_IFGITLINK,
     S_IFGITLINK,
-    )
+)
 from dulwich.repo import Repo
 from dulwich.repo import Repo
 from dulwich.tests import (
 from dulwich.tests import (
     TestCase,
     TestCase,
     skipIf,
     skipIf,
-    )
+)
 from dulwich.tests.utils import (
 from dulwich.tests.utils import (
     setup_warning_catcher,
     setup_warning_catcher,
-    )
+)
 
 
 
 
 def can_symlink():
 def can_symlink():
     """Return whether running process can create symlinks."""
     """Return whether running process can create symlinks."""
-    if sys.platform != 'win32':
+    if sys.platform != "win32":
         # Platforms other than Windows should allow symlinks without issues.
         # Platforms other than Windows should allow symlinks without issues.
         return True
         return True
 
 
-    if not hasattr(os, 'symlink'):
+    if not hasattr(os, "symlink"):
         # Older Python versions do not have `os.symlink` on Windows.
         # Older Python versions do not have `os.symlink` on Windows.
         return False
         return False
 
 
     test_source = tempfile.mkdtemp()
     test_source = tempfile.mkdtemp()
-    test_target = test_source + 'can_symlink'
+    test_target = test_source + "can_symlink"
     try:
     try:
         os.symlink(test_source, test_target)
         os.symlink(test_source, test_target)
     except (NotImplementedError, OSError):
     except (NotImplementedError, OSError):
@@ -89,24 +90,24 @@ def can_symlink():
 
 
 class IndexTestCase(TestCase):
 class IndexTestCase(TestCase):
 
 
-    datadir = os.path.join(os.path.dirname(__file__), 'data/indexes')
+    datadir = os.path.join(os.path.dirname(__file__), "data/indexes")
 
 
     def get_simple_index(self, name):
     def get_simple_index(self, name):
         return Index(os.path.join(self.datadir, name))
         return Index(os.path.join(self.datadir, name))
 
 
 
 
 class SimpleIndexTestCase(IndexTestCase):
 class SimpleIndexTestCase(IndexTestCase):
-
     def test_len(self):
     def test_len(self):
         self.assertEqual(1, len(self.get_simple_index("index")))
         self.assertEqual(1, len(self.get_simple_index("index")))
 
 
     def test_iter(self):
     def test_iter(self):
-        self.assertEqual([b'bla'], list(self.get_simple_index("index")))
+        self.assertEqual([b"bla"], list(self.get_simple_index("index")))
 
 
     def test_iterobjects(self):
     def test_iterobjects(self):
         self.assertEqual(
         self.assertEqual(
-                [(b'bla', b'e69de29bb2d1d6434b8b29ae775ad8c2e48c5391', 33188)],
-                list(self.get_simple_index("index").iterobjects()))
+            [(b"bla", b"e69de29bb2d1d6434b8b29ae775ad8c2e48c5391", 33188)],
+            list(self.get_simple_index("index").iterobjects()),
+        )
 
 
     def test_iterblobs(self):
     def test_iterblobs(self):
         warnings.simplefilter("always", UserWarning)
         warnings.simplefilter("always", UserWarning)
@@ -115,26 +116,36 @@ class SimpleIndexTestCase(IndexTestCase):
         self.addCleanup(restore_warnings)
         self.addCleanup(restore_warnings)
 
 
         self.assertEqual(
         self.assertEqual(
-                [(b'bla', b'e69de29bb2d1d6434b8b29ae775ad8c2e48c5391', 33188)],
-                list(self.get_simple_index("index").iterblobs()))
+            [(b"bla", b"e69de29bb2d1d6434b8b29ae775ad8c2e48c5391", 33188)],
+            list(self.get_simple_index("index").iterblobs()),
+        )
 
 
-        expected_warning = PendingDeprecationWarning(
-            'Use iterobjects() instead.')
+        expected_warning = PendingDeprecationWarning("Use iterobjects() instead.")
         for w in warnings_list:
         for w in warnings_list:
-            if (type(w) == type(expected_warning) and
-                    w.args == expected_warning.args):
+            if type(w) == type(expected_warning) and w.args == expected_warning.args:
                 break
                 break
         else:
         else:
             raise AssertionError(
             raise AssertionError(
-                'Expected warning %r not in %r' %
-                (expected_warning, warnings_list))
+                "Expected warning %r not in %r" % (expected_warning, warnings_list)
+            )
 
 
     def test_getitem(self):
     def test_getitem(self):
         self.assertEqual(
         self.assertEqual(
-                ((1230680220, 0), (1230680220, 0), 2050, 3761020,
-                 33188, 1000, 1000, 0,
-                 b'e69de29bb2d1d6434b8b29ae775ad8c2e48c5391', 0),
-                self.get_simple_index("index")[b"bla"])
+            (
+                (1230680220, 0),
+                (1230680220, 0),
+                2050,
+                3761020,
+                33188,
+                1000,
+                1000,
+                0,
+                b"e69de29bb2d1d6434b8b29ae775ad8c2e48c5391",
+                0,
+                0,
+            ),
+            self.get_simple_index("index")[b"bla"],
+        )
 
 
     def test_empty(self):
     def test_empty(self):
         i = self.get_simple_index("notanindex")
         i = self.get_simple_index("notanindex")
@@ -146,12 +157,11 @@ class SimpleIndexTestCase(IndexTestCase):
         changes = list(i.changes_from_tree(MemoryObjectStore(), None))
         changes = list(i.changes_from_tree(MemoryObjectStore(), None))
         self.assertEqual(1, len(changes))
         self.assertEqual(1, len(changes))
         (oldname, newname), (oldmode, newmode), (oldsha, newsha) = changes[0]
         (oldname, newname), (oldmode, newmode), (oldsha, newsha) = changes[0]
-        self.assertEqual(b'bla', newname)
-        self.assertEqual(b'e69de29bb2d1d6434b8b29ae775ad8c2e48c5391', newsha)
+        self.assertEqual(b"bla", newname)
+        self.assertEqual(b"e69de29bb2d1d6434b8b29ae775ad8c2e48c5391", newsha)
 
 
 
 
 class SimpleIndexWriterTestCase(IndexTestCase):
 class SimpleIndexWriterTestCase(IndexTestCase):
-
     def setUp(self):
     def setUp(self):
         IndexTestCase.setUp(self)
         IndexTestCase.setUp(self)
         self.tempdir = tempfile.mkdtemp()
         self.tempdir = tempfile.mkdtemp()
@@ -161,19 +171,32 @@ class SimpleIndexWriterTestCase(IndexTestCase):
         shutil.rmtree(self.tempdir)
         shutil.rmtree(self.tempdir)
 
 
     def test_simple_write(self):
     def test_simple_write(self):
-        entries = [(b'barbla', (1230680220, 0), (1230680220, 0), 2050, 3761020,
-                    33188, 1000, 1000, 0,
-                    b'e69de29bb2d1d6434b8b29ae775ad8c2e48c5391', 0)]
-        filename = os.path.join(self.tempdir, 'test-simple-write-index')
-        with open(filename, 'wb+') as x:
+        entries = [
+            (
+                b"barbla",
+                IndexEntry(
+                    (1230680220, 0),
+                    (1230680220, 0),
+                    2050,
+                    3761020,
+                    33188,
+                    1000,
+                    1000,
+                    0,
+                    b"e69de29bb2d1d6434b8b29ae775ad8c2e48c5391",
+                    0,
+                    0)
+            )
+        ]
+        filename = os.path.join(self.tempdir, "test-simple-write-index")
+        with open(filename, "wb+") as x:
             write_index(x, entries)
             write_index(x, entries)
 
 
-        with open(filename, 'rb') as x:
+        with open(filename, "rb") as x:
             self.assertEqual(entries, list(read_index(x)))
             self.assertEqual(entries, list(read_index(x)))
 
 
 
 
 class ReadIndexDictTests(IndexTestCase):
 class ReadIndexDictTests(IndexTestCase):
-
     def setUp(self):
     def setUp(self):
         IndexTestCase.setUp(self)
         IndexTestCase.setUp(self)
         self.tempdir = tempfile.mkdtemp()
         self.tempdir = tempfile.mkdtemp()
@@ -184,20 +207,29 @@ class ReadIndexDictTests(IndexTestCase):
 
 
     def test_simple_write(self):
     def test_simple_write(self):
         entries = {
         entries = {
-                b'barbla':
-                ((1230680220, 0), (1230680220, 0), 2050, 3761020, 33188,
-                 1000, 1000, 0,
-                 b'e69de29bb2d1d6434b8b29ae775ad8c2e48c5391', 0)}
-        filename = os.path.join(self.tempdir, 'test-simple-write-index')
-        with open(filename, 'wb+') as x:
+            b"barbla": IndexEntry(
+                (1230680220, 0),
+                (1230680220, 0),
+                2050,
+                3761020,
+                33188,
+                1000,
+                1000,
+                0,
+                b"e69de29bb2d1d6434b8b29ae775ad8c2e48c5391",
+                0,
+                0,
+            )
+        }
+        filename = os.path.join(self.tempdir, "test-simple-write-index")
+        with open(filename, "wb+") as x:
             write_index_dict(x, entries)
             write_index_dict(x, entries)
 
 
-        with open(filename, 'rb') as x:
+        with open(filename, "rb") as x:
             self.assertEqual(entries, read_index_dict(x))
             self.assertEqual(entries, read_index_dict(x))
 
 
 
 
 class CommitTreeTests(TestCase):
 class CommitTreeTests(TestCase):
-
     def setUp(self):
     def setUp(self):
         super(CommitTreeTests, self).setUp()
         super(CommitTreeTests, self).setUp()
         self.store = MemoryObjectStore()
         self.store = MemoryObjectStore()
@@ -223,14 +255,12 @@ class CommitTreeTests(TestCase):
         self.assertEqual(dirid, b"c1a1deb9788150829579a8b4efa6311e7b638650")
         self.assertEqual(dirid, b"c1a1deb9788150829579a8b4efa6311e7b638650")
         self.assertEqual((stat.S_IFDIR, dirid), self.store[rootid][b"bla"])
         self.assertEqual((stat.S_IFDIR, dirid), self.store[rootid][b"bla"])
         self.assertEqual((stat.S_IFREG, blob.id), self.store[dirid][b"bar"])
         self.assertEqual((stat.S_IFREG, blob.id), self.store[dirid][b"bar"])
-        self.assertEqual(set([rootid, dirid, blob.id]),
-                         set(self.store._data.keys()))
+        self.assertEqual(set([rootid, dirid, blob.id]), set(self.store._data.keys()))
 
 
 
 
 class CleanupModeTests(TestCase):
 class CleanupModeTests(TestCase):
-
     def assertModeEqual(self, expected, got):
     def assertModeEqual(self, expected, got):
-        self.assertEqual(expected, got, '%o != %o' % (expected, got))
+        self.assertEqual(expected, got, "%o != %o" % (expected, got))
 
 
     def test_file(self):
     def test_file(self):
         self.assertModeEqual(0o100644, cleanup_mode(0o100000))
         self.assertModeEqual(0o100644, cleanup_mode(0o100000))
@@ -250,7 +280,6 @@ class CleanupModeTests(TestCase):
 
 
 
 
 class WriteCacheTimeTests(TestCase):
 class WriteCacheTimeTests(TestCase):
-
     def test_write_string(self):
     def test_write_string(self):
         f = BytesIO()
         f = BytesIO()
         self.assertRaises(TypeError, write_cache_time, f, "foo")
         self.assertRaises(TypeError, write_cache_time, f, "foo")
@@ -272,46 +301,74 @@ class WriteCacheTimeTests(TestCase):
 
 
 
 
 class IndexEntryFromStatTests(TestCase):
 class IndexEntryFromStatTests(TestCase):
-
     def test_simple(self):
     def test_simple(self):
         st = os.stat_result(
         st = os.stat_result(
-                (16877, 131078, 64769, 154, 1000, 1000, 12288,
-                 1323629595, 1324180496, 1324180496))
+            (
+                16877,
+                131078,
+                64769,
+                154,
+                1000,
+                1000,
+                12288,
+                1323629595,
+                1324180496,
+                1324180496,
+            )
+        )
         entry = index_entry_from_stat(st, "22" * 20, 0)
         entry = index_entry_from_stat(st, "22" * 20, 0)
-        self.assertEqual(entry, (
-            1324180496,
-            1324180496,
-            64769,
-            131078,
-            16384,
-            1000,
-            1000,
-            12288,
-            '2222222222222222222222222222222222222222',
-            0))
+        self.assertEqual(
+            entry,
+            IndexEntry(
+                1324180496,
+                1324180496,
+                64769,
+                131078,
+                16384,
+                1000,
+                1000,
+                12288,
+                "2222222222222222222222222222222222222222",
+                0,
+                None,
+            ),
+        )
 
 
     def test_override_mode(self):
     def test_override_mode(self):
         st = os.stat_result(
         st = os.stat_result(
-                (stat.S_IFREG + 0o644, 131078, 64769,
-                 154, 1000, 1000, 12288,
-                 1323629595, 1324180496, 1324180496))
-        entry = index_entry_from_stat(
-            st, "22" * 20, 0, mode=stat.S_IFREG + 0o755)
-        self.assertEqual(entry, (
-            1324180496,
-            1324180496,
-            64769,
-            131078,
-            33261,
-            1000,
-            1000,
-            12288,
-            '2222222222222222222222222222222222222222',
-            0))
+            (
+                stat.S_IFREG + 0o644,
+                131078,
+                64769,
+                154,
+                1000,
+                1000,
+                12288,
+                1323629595,
+                1324180496,
+                1324180496,
+            )
+        )
+        entry = index_entry_from_stat(st, "22" * 20, 0, mode=stat.S_IFREG + 0o755)
+        self.assertEqual(
+            entry,
+            IndexEntry(
+                1324180496,
+                1324180496,
+                64769,
+                131078,
+                33261,
+                1000,
+                1000,
+                12288,
+                "2222222222222222222222222222222222222222",
+                0,
+                None,
+            ),
+        )
 
 
 
 
 class BuildIndexTests(TestCase):
 class BuildIndexTests(TestCase):
-
     def assertReasonableIndexEntry(self, index_entry, mode, filesize, sha):
     def assertReasonableIndexEntry(self, index_entry, mode, filesize, sha):
         self.assertEqual(index_entry[4], mode)  # mode
         self.assertEqual(index_entry[4], mode)  # mode
         self.assertEqual(index_entry[7], filesize)  # filesize
         self.assertEqual(index_entry[7], filesize)  # filesize
@@ -321,7 +378,7 @@ class BuildIndexTests(TestCase):
         if symlink:
         if symlink:
             self.assertEqual(os.readlink(path), contents)
             self.assertEqual(os.readlink(path), contents)
         else:
         else:
-            with open(path, 'rb') as f:
+            with open(path, "rb") as f:
                 self.assertEqual(f.read(), contents)
                 self.assertEqual(f.read(), contents)
 
 
     def test_empty(self):
     def test_empty(self):
@@ -332,15 +389,15 @@ class BuildIndexTests(TestCase):
             repo.object_store.add_object(tree)
             repo.object_store.add_object(tree)
 
 
             build_index_from_tree(
             build_index_from_tree(
-                    repo.path, repo.index_path(),
-                    repo.object_store, tree.id)
+                repo.path, repo.index_path(), repo.object_store, tree.id
+            )
 
 
             # Verify index entries
             # Verify index entries
             index = repo.open_index()
             index = repo.open_index()
             self.assertEqual(len(index), 0)
             self.assertEqual(len(index), 0)
 
 
             # Verify no files
             # Verify no files
-            self.assertEqual(['.git'], os.listdir(repo.path))
+            self.assertEqual([".git"], os.listdir(repo.path))
 
 
     def test_git_dir(self):
     def test_git_dir(self):
         repo_dir = tempfile.mkdtemp()
         repo_dir = tempfile.mkdtemp()
@@ -348,33 +405,34 @@ class BuildIndexTests(TestCase):
         with Repo.init(repo_dir) as repo:
         with Repo.init(repo_dir) as repo:
 
 
             # Populate repo
             # Populate repo
-            filea = Blob.from_string(b'file a')
-            filee = Blob.from_string(b'd')
+            filea = Blob.from_string(b"file a")
+            filee = Blob.from_string(b"d")
 
 
             tree = Tree()
             tree = Tree()
-            tree[b'.git/a'] = (stat.S_IFREG | 0o644, filea.id)
-            tree[b'c/e'] = (stat.S_IFREG | 0o644, filee.id)
+            tree[b".git/a"] = (stat.S_IFREG | 0o644, filea.id)
+            tree[b"c/e"] = (stat.S_IFREG | 0o644, filee.id)
 
 
-            repo.object_store.add_objects(
-                    [(o, None) for o in [filea, filee, tree]])
+            repo.object_store.add_objects([(o, None) for o in [filea, filee, tree]])
 
 
             build_index_from_tree(
             build_index_from_tree(
-                repo.path, repo.index_path(), repo.object_store, tree.id)
+                repo.path, repo.index_path(), repo.object_store, tree.id
+            )
 
 
             # Verify index entries
             # Verify index entries
             index = repo.open_index()
             index = repo.open_index()
             self.assertEqual(len(index), 1)
             self.assertEqual(len(index), 1)
 
 
             # filea
             # filea
-            apath = os.path.join(repo.path, '.git', 'a')
+            apath = os.path.join(repo.path, ".git", "a")
             self.assertFalse(os.path.exists(apath))
             self.assertFalse(os.path.exists(apath))
 
 
             # filee
             # filee
-            epath = os.path.join(repo.path, 'c', 'e')
+            epath = os.path.join(repo.path, "c", "e")
             self.assertTrue(os.path.exists(epath))
             self.assertTrue(os.path.exists(epath))
             self.assertReasonableIndexEntry(
             self.assertReasonableIndexEntry(
-                index[b'c/e'], stat.S_IFREG | 0o644, 1, filee.id)
-            self.assertFileContents(epath, b'd')
+                index[b"c/e"], stat.S_IFREG | 0o644, 1, filee.id
+            )
+            self.assertFileContents(epath, b"d")
 
 
     def test_nonempty(self):
     def test_nonempty(self):
         repo_dir = tempfile.mkdtemp()
         repo_dir = tempfile.mkdtemp()
@@ -382,122 +440,130 @@ class BuildIndexTests(TestCase):
         with Repo.init(repo_dir) as repo:
         with Repo.init(repo_dir) as repo:
 
 
             # Populate repo
             # Populate repo
-            filea = Blob.from_string(b'file a')
-            fileb = Blob.from_string(b'file b')
-            filed = Blob.from_string(b'file d')
+            filea = Blob.from_string(b"file a")
+            fileb = Blob.from_string(b"file b")
+            filed = Blob.from_string(b"file d")
 
 
             tree = Tree()
             tree = Tree()
-            tree[b'a'] = (stat.S_IFREG | 0o644, filea.id)
-            tree[b'b'] = (stat.S_IFREG | 0o644, fileb.id)
-            tree[b'c/d'] = (stat.S_IFREG | 0o644, filed.id)
+            tree[b"a"] = (stat.S_IFREG | 0o644, filea.id)
+            tree[b"b"] = (stat.S_IFREG | 0o644, fileb.id)
+            tree[b"c/d"] = (stat.S_IFREG | 0o644, filed.id)
 
 
             repo.object_store.add_objects(
             repo.object_store.add_objects(
-                [(o, None) for o in [filea, fileb, filed, tree]])
+                [(o, None) for o in [filea, fileb, filed, tree]]
+            )
 
 
             build_index_from_tree(
             build_index_from_tree(
-                repo.path, repo.index_path(), repo.object_store, tree.id)
+                repo.path, repo.index_path(), repo.object_store, tree.id
+            )
 
 
             # Verify index entries
             # Verify index entries
             index = repo.open_index()
             index = repo.open_index()
             self.assertEqual(len(index), 3)
             self.assertEqual(len(index), 3)
 
 
             # filea
             # filea
-            apath = os.path.join(repo.path, 'a')
+            apath = os.path.join(repo.path, "a")
             self.assertTrue(os.path.exists(apath))
             self.assertTrue(os.path.exists(apath))
             self.assertReasonableIndexEntry(
             self.assertReasonableIndexEntry(
-                    index[b'a'], stat.S_IFREG | 0o644, 6, filea.id)
-            self.assertFileContents(apath, b'file a')
+                index[b"a"], stat.S_IFREG | 0o644, 6, filea.id
+            )
+            self.assertFileContents(apath, b"file a")
 
 
             # fileb
             # fileb
-            bpath = os.path.join(repo.path, 'b')
+            bpath = os.path.join(repo.path, "b")
             self.assertTrue(os.path.exists(bpath))
             self.assertTrue(os.path.exists(bpath))
             self.assertReasonableIndexEntry(
             self.assertReasonableIndexEntry(
-                    index[b'b'], stat.S_IFREG | 0o644, 6, fileb.id)
-            self.assertFileContents(bpath, b'file b')
+                index[b"b"], stat.S_IFREG | 0o644, 6, fileb.id
+            )
+            self.assertFileContents(bpath, b"file b")
 
 
             # filed
             # filed
-            dpath = os.path.join(repo.path, 'c', 'd')
+            dpath = os.path.join(repo.path, "c", "d")
             self.assertTrue(os.path.exists(dpath))
             self.assertTrue(os.path.exists(dpath))
             self.assertReasonableIndexEntry(
             self.assertReasonableIndexEntry(
-                    index[b'c/d'], stat.S_IFREG | 0o644, 6, filed.id)
-            self.assertFileContents(dpath, b'file d')
+                index[b"c/d"], stat.S_IFREG | 0o644, 6, filed.id
+            )
+            self.assertFileContents(dpath, b"file d")
 
 
             # Verify no extra files
             # Verify no extra files
-            self.assertEqual(
-                    ['.git', 'a', 'b', 'c'], sorted(os.listdir(repo.path)))
-            self.assertEqual(
-                    ['d'], sorted(os.listdir(os.path.join(repo.path, 'c'))))
+            self.assertEqual([".git", "a", "b", "c"], sorted(os.listdir(repo.path)))
+            self.assertEqual(["d"], sorted(os.listdir(os.path.join(repo.path, "c"))))
 
 
-    @skipIf(not getattr(os, 'sync', None), 'Requires sync support')
+    @skipIf(not getattr(os, "sync", None), "Requires sync support")
     def test_norewrite(self):
     def test_norewrite(self):
         repo_dir = tempfile.mkdtemp()
         repo_dir = tempfile.mkdtemp()
         self.addCleanup(shutil.rmtree, repo_dir)
         self.addCleanup(shutil.rmtree, repo_dir)
         with Repo.init(repo_dir) as repo:
         with Repo.init(repo_dir) as repo:
             # Populate repo
             # Populate repo
-            filea = Blob.from_string(b'file a')
-            filea_path = os.path.join(repo_dir, 'a')
+            filea = Blob.from_string(b"file a")
+            filea_path = os.path.join(repo_dir, "a")
             tree = Tree()
             tree = Tree()
-            tree[b'a'] = (stat.S_IFREG | 0o644, filea.id)
+            tree[b"a"] = (stat.S_IFREG | 0o644, filea.id)
 
 
             repo.object_store.add_objects([(o, None) for o in [filea, tree]])
             repo.object_store.add_objects([(o, None) for o in [filea, tree]])
 
 
             # First Write
             # First Write
-            build_index_from_tree(repo.path, repo.index_path(),
-                                  repo.object_store, tree.id)
+            build_index_from_tree(
+                repo.path, repo.index_path(), repo.object_store, tree.id
+            )
             # Use sync as metadata can be cached on some FS
             # Use sync as metadata can be cached on some FS
             os.sync()
             os.sync()
             mtime = os.stat(filea_path).st_mtime
             mtime = os.stat(filea_path).st_mtime
 
 
             # Test Rewrite
             # Test Rewrite
-            build_index_from_tree(repo.path, repo.index_path(),
-                                  repo.object_store, tree.id)
+            build_index_from_tree(
+                repo.path, repo.index_path(), repo.object_store, tree.id
+            )
             os.sync()
             os.sync()
             self.assertEqual(mtime, os.stat(filea_path).st_mtime)
             self.assertEqual(mtime, os.stat(filea_path).st_mtime)
 
 
             # Modify content
             # Modify content
-            with open(filea_path, 'wb') as fh:
-                fh.write(b'test a')
+            with open(filea_path, "wb") as fh:
+                fh.write(b"test a")
             os.sync()
             os.sync()
             mtime = os.stat(filea_path).st_mtime
             mtime = os.stat(filea_path).st_mtime
 
 
             # Test rewrite
             # Test rewrite
-            build_index_from_tree(repo.path, repo.index_path(),
-                                  repo.object_store, tree.id)
+            build_index_from_tree(
+                repo.path, repo.index_path(), repo.object_store, tree.id
+            )
             os.sync()
             os.sync()
-            with open(filea_path, 'rb') as fh:
-                self.assertEqual(b'file a', fh.read())
+            with open(filea_path, "rb") as fh:
+                self.assertEqual(b"file a", fh.read())
 
 
-    @skipIf(not can_symlink(), 'Requires symlink support')
+    @skipIf(not can_symlink(), "Requires symlink support")
     def test_symlink(self):
     def test_symlink(self):
         repo_dir = tempfile.mkdtemp()
         repo_dir = tempfile.mkdtemp()
         self.addCleanup(shutil.rmtree, repo_dir)
         self.addCleanup(shutil.rmtree, repo_dir)
         with Repo.init(repo_dir) as repo:
         with Repo.init(repo_dir) as repo:
 
 
             # Populate repo
             # Populate repo
-            filed = Blob.from_string(b'file d')
-            filee = Blob.from_string(b'd')
+            filed = Blob.from_string(b"file d")
+            filee = Blob.from_string(b"d")
 
 
             tree = Tree()
             tree = Tree()
-            tree[b'c/d'] = (stat.S_IFREG | 0o644, filed.id)
-            tree[b'c/e'] = (stat.S_IFLNK, filee.id)  # symlink
+            tree[b"c/d"] = (stat.S_IFREG | 0o644, filed.id)
+            tree[b"c/e"] = (stat.S_IFLNK, filee.id)  # symlink
 
 
-            repo.object_store.add_objects(
-                    [(o, None) for o in [filed, filee, tree]])
+            repo.object_store.add_objects([(o, None) for o in [filed, filee, tree]])
 
 
             build_index_from_tree(
             build_index_from_tree(
-                    repo.path, repo.index_path(), repo.object_store, tree.id)
+                repo.path, repo.index_path(), repo.object_store, tree.id
+            )
 
 
             # Verify index entries
             # Verify index entries
             index = repo.open_index()
             index = repo.open_index()
 
 
             # symlink to d
             # symlink to d
-            epath = os.path.join(repo.path, 'c', 'e')
+            epath = os.path.join(repo.path, "c", "e")
             self.assertTrue(os.path.exists(epath))
             self.assertTrue(os.path.exists(epath))
             self.assertReasonableIndexEntry(
             self.assertReasonableIndexEntry(
-                index[b'c/e'], stat.S_IFLNK,
-                0 if sys.platform == 'win32' else 1,
-                filee.id)
-            self.assertFileContents(epath, 'd', symlink=True)
+                index[b"c/e"],
+                stat.S_IFLNK,
+                0 if sys.platform == "win32" else 1,
+                filee.id,
+            )
+            self.assertFileContents(epath, "d", symlink=True)
 
 
     def test_no_decode_encode(self):
     def test_no_decode_encode(self):
         repo_dir = tempfile.mkdtemp()
         repo_dir = tempfile.mkdtemp()
@@ -506,33 +572,32 @@ class BuildIndexTests(TestCase):
         with Repo.init(repo_dir) as repo:
         with Repo.init(repo_dir) as repo:
 
 
             # Populate repo
             # Populate repo
-            file = Blob.from_string(b'foo')
+            file = Blob.from_string(b"foo")
 
 
             tree = Tree()
             tree = Tree()
-            latin1_name = u'À'.encode('latin1')
+            latin1_name = u"À".encode("latin1")
             latin1_path = os.path.join(repo_dir_bytes, latin1_name)
             latin1_path = os.path.join(repo_dir_bytes, latin1_name)
-            utf8_name = u'À'.encode('utf8')
+            utf8_name = u"À".encode("utf8")
             utf8_path = os.path.join(repo_dir_bytes, utf8_name)
             utf8_path = os.path.join(repo_dir_bytes, utf8_name)
             tree[latin1_name] = (stat.S_IFREG | 0o644, file.id)
             tree[latin1_name] = (stat.S_IFREG | 0o644, file.id)
             tree[utf8_name] = (stat.S_IFREG | 0o644, file.id)
             tree[utf8_name] = (stat.S_IFREG | 0o644, file.id)
 
 
-            repo.object_store.add_objects(
-                [(o, None) for o in [file, tree]])
+            repo.object_store.add_objects([(o, None) for o in [file, tree]])
 
 
             try:
             try:
                 build_index_from_tree(
                 build_index_from_tree(
-                    repo.path, repo.index_path(),
-                    repo.object_store, tree.id)
+                    repo.path, repo.index_path(), repo.object_store, tree.id
+                )
             except OSError as e:
             except OSError as e:
-                if e.errno == 92 and sys.platform == 'darwin':
+                if e.errno == 92 and sys.platform == "darwin":
                     # Our filename isn't supported by the platform :(
                     # Our filename isn't supported by the platform :(
-                    self.skipTest('can not write filename %r' % e.filename)
+                    self.skipTest("can not write filename %r" % e.filename)
                 else:
                 else:
                     raise
                     raise
             except UnicodeDecodeError:
             except UnicodeDecodeError:
                 # This happens e.g. with python3.6 on Windows.
                 # This happens e.g. with python3.6 on Windows.
                 # It implicitly decodes using utf8, which doesn't work.
                 # It implicitly decodes using utf8, which doesn't work.
-                self.skipTest('can not implicitly convert as utf8')
+                self.skipTest("can not implicitly convert as utf8")
 
 
             # Verify index entries
             # Verify index entries
             index = repo.open_index()
             index = repo.open_index()
@@ -547,86 +612,85 @@ class BuildIndexTests(TestCase):
         repo_dir = tempfile.mkdtemp()
         repo_dir = tempfile.mkdtemp()
         self.addCleanup(shutil.rmtree, repo_dir)
         self.addCleanup(shutil.rmtree, repo_dir)
         with Repo.init(repo_dir) as repo:
         with Repo.init(repo_dir) as repo:
-            filea = Blob.from_string(b'file alalala')
+            filea = Blob.from_string(b"file alalala")
 
 
             subtree = Tree()
             subtree = Tree()
-            subtree[b'a'] = (stat.S_IFREG | 0o644, filea.id)
+            subtree[b"a"] = (stat.S_IFREG | 0o644, filea.id)
 
 
             c = Commit()
             c = Commit()
             c.tree = subtree.id
             c.tree = subtree.id
-            c.committer = c.author = b'Somebody <somebody@example.com>'
+            c.committer = c.author = b"Somebody <somebody@example.com>"
             c.commit_time = c.author_time = 42342
             c.commit_time = c.author_time = 42342
             c.commit_timezone = c.author_timezone = 0
             c.commit_timezone = c.author_timezone = 0
             c.parents = []
             c.parents = []
-            c.message = b'Subcommit'
+            c.message = b"Subcommit"
 
 
             tree = Tree()
             tree = Tree()
-            tree[b'c'] = (S_IFGITLINK, c.id)
+            tree[b"c"] = (S_IFGITLINK, c.id)
 
 
-            repo.object_store.add_objects(
-                [(o, None) for o in [tree]])
+            repo.object_store.add_objects([(o, None) for o in [tree]])
 
 
             build_index_from_tree(
             build_index_from_tree(
-                    repo.path, repo.index_path(), repo.object_store, tree.id)
+                repo.path, repo.index_path(), repo.object_store, tree.id
+            )
 
 
             # Verify index entries
             # Verify index entries
             index = repo.open_index()
             index = repo.open_index()
             self.assertEqual(len(index), 1)
             self.assertEqual(len(index), 1)
 
 
             # filea
             # filea
-            apath = os.path.join(repo.path, 'c/a')
+            apath = os.path.join(repo.path, "c/a")
             self.assertFalse(os.path.exists(apath))
             self.assertFalse(os.path.exists(apath))
 
 
             # dir c
             # dir c
-            cpath = os.path.join(repo.path, 'c')
+            cpath = os.path.join(repo.path, "c")
             self.assertTrue(os.path.isdir(cpath))
             self.assertTrue(os.path.isdir(cpath))
-            self.assertEqual(index[b'c'][4], S_IFGITLINK)  # mode
-            self.assertEqual(index[b'c'][8], c.id)  # sha
+            self.assertEqual(index[b"c"][4], S_IFGITLINK)  # mode
+            self.assertEqual(index[b"c"][8], c.id)  # sha
 
 
     def test_git_submodule_exists(self):
     def test_git_submodule_exists(self):
         repo_dir = tempfile.mkdtemp()
         repo_dir = tempfile.mkdtemp()
         self.addCleanup(shutil.rmtree, repo_dir)
         self.addCleanup(shutil.rmtree, repo_dir)
         with Repo.init(repo_dir) as repo:
         with Repo.init(repo_dir) as repo:
-            filea = Blob.from_string(b'file alalala')
+            filea = Blob.from_string(b"file alalala")
 
 
             subtree = Tree()
             subtree = Tree()
-            subtree[b'a'] = (stat.S_IFREG | 0o644, filea.id)
+            subtree[b"a"] = (stat.S_IFREG | 0o644, filea.id)
 
 
             c = Commit()
             c = Commit()
             c.tree = subtree.id
             c.tree = subtree.id
-            c.committer = c.author = b'Somebody <somebody@example.com>'
+            c.committer = c.author = b"Somebody <somebody@example.com>"
             c.commit_time = c.author_time = 42342
             c.commit_time = c.author_time = 42342
             c.commit_timezone = c.author_timezone = 0
             c.commit_timezone = c.author_timezone = 0
             c.parents = []
             c.parents = []
-            c.message = b'Subcommit'
+            c.message = b"Subcommit"
 
 
             tree = Tree()
             tree = Tree()
-            tree[b'c'] = (S_IFGITLINK, c.id)
+            tree[b"c"] = (S_IFGITLINK, c.id)
 
 
-            os.mkdir(os.path.join(repo_dir, 'c'))
-            repo.object_store.add_objects(
-                [(o, None) for o in [tree]])
+            os.mkdir(os.path.join(repo_dir, "c"))
+            repo.object_store.add_objects([(o, None) for o in [tree]])
 
 
             build_index_from_tree(
             build_index_from_tree(
-                    repo.path, repo.index_path(), repo.object_store, tree.id)
+                repo.path, repo.index_path(), repo.object_store, tree.id
+            )
 
 
             # Verify index entries
             # Verify index entries
             index = repo.open_index()
             index = repo.open_index()
             self.assertEqual(len(index), 1)
             self.assertEqual(len(index), 1)
 
 
             # filea
             # filea
-            apath = os.path.join(repo.path, 'c/a')
+            apath = os.path.join(repo.path, "c/a")
             self.assertFalse(os.path.exists(apath))
             self.assertFalse(os.path.exists(apath))
 
 
             # dir c
             # dir c
-            cpath = os.path.join(repo.path, 'c')
+            cpath = os.path.join(repo.path, "c")
             self.assertTrue(os.path.isdir(cpath))
             self.assertTrue(os.path.isdir(cpath))
-            self.assertEqual(index[b'c'][4], S_IFGITLINK)  # mode
-            self.assertEqual(index[b'c'][8], c.id)  # sha
+            self.assertEqual(index[b"c"][4], S_IFGITLINK)  # mode
+            self.assertEqual(index[b"c"][8], c.id)  # sha
 
 
 
 
 class GetUnstagedChangesTests(TestCase):
 class GetUnstagedChangesTests(TestCase):
-
     def test_get_unstaged_changes(self):
     def test_get_unstaged_changes(self):
         """Unit test for get_unstaged_changes."""
         """Unit test for get_unstaged_changes."""
 
 
@@ -635,27 +699,30 @@ class GetUnstagedChangesTests(TestCase):
         with Repo.init(repo_dir) as repo:
         with Repo.init(repo_dir) as repo:
 
 
             # Commit a dummy file then modify it
             # Commit a dummy file then modify it
-            foo1_fullpath = os.path.join(repo_dir, 'foo1')
-            with open(foo1_fullpath, 'wb') as f:
-                f.write(b'origstuff')
+            foo1_fullpath = os.path.join(repo_dir, "foo1")
+            with open(foo1_fullpath, "wb") as f:
+                f.write(b"origstuff")
 
 
-            foo2_fullpath = os.path.join(repo_dir, 'foo2')
-            with open(foo2_fullpath, 'wb') as f:
-                f.write(b'origstuff')
+            foo2_fullpath = os.path.join(repo_dir, "foo2")
+            with open(foo2_fullpath, "wb") as f:
+                f.write(b"origstuff")
 
 
-            repo.stage(['foo1', 'foo2'])
-            repo.do_commit(b'test status', author=b'author <email>',
-                           committer=b'committer <email>')
+            repo.stage(["foo1", "foo2"])
+            repo.do_commit(
+                b"test status",
+                author=b"author <email>",
+                committer=b"committer <email>",
+            )
 
 
-            with open(foo1_fullpath, 'wb') as f:
-                f.write(b'newstuff')
+            with open(foo1_fullpath, "wb") as f:
+                f.write(b"newstuff")
 
 
             # modify access and modify time of path
             # modify access and modify time of path
             os.utime(foo1_fullpath, (0, 0))
             os.utime(foo1_fullpath, (0, 0))
 
 
             changes = get_unstaged_changes(repo.open_index(), repo_dir)
             changes = get_unstaged_changes(repo.open_index(), repo_dir)
 
 
-            self.assertEqual(list(changes), [b'foo1'])
+            self.assertEqual(list(changes), [b"foo1"])
 
 
     def test_get_unstaged_deleted_changes(self):
     def test_get_unstaged_deleted_changes(self):
         """Unit test for get_unstaged_changes."""
         """Unit test for get_unstaged_changes."""
@@ -665,19 +732,22 @@ class GetUnstagedChangesTests(TestCase):
         with Repo.init(repo_dir) as repo:
         with Repo.init(repo_dir) as repo:
 
 
             # Commit a dummy file then remove it
             # Commit a dummy file then remove it
-            foo1_fullpath = os.path.join(repo_dir, 'foo1')
-            with open(foo1_fullpath, 'wb') as f:
-                f.write(b'origstuff')
+            foo1_fullpath = os.path.join(repo_dir, "foo1")
+            with open(foo1_fullpath, "wb") as f:
+                f.write(b"origstuff")
 
 
-            repo.stage(['foo1'])
-            repo.do_commit(b'test status', author=b'author <email>',
-                           committer=b'committer <email>')
+            repo.stage(["foo1"])
+            repo.do_commit(
+                b"test status",
+                author=b"author <email>",
+                committer=b"committer <email>",
+            )
 
 
             os.unlink(foo1_fullpath)
             os.unlink(foo1_fullpath)
 
 
             changes = get_unstaged_changes(repo.open_index(), repo_dir)
             changes = get_unstaged_changes(repo.open_index(), repo_dir)
 
 
-            self.assertEqual(list(changes), [b'foo1'])
+            self.assertEqual(list(changes), [b"foo1"])
 
 
     def test_get_unstaged_changes_removed_replaced_by_directory(self):
     def test_get_unstaged_changes_removed_replaced_by_directory(self):
         """Unit test for get_unstaged_changes."""
         """Unit test for get_unstaged_changes."""
@@ -687,22 +757,25 @@ class GetUnstagedChangesTests(TestCase):
         with Repo.init(repo_dir) as repo:
         with Repo.init(repo_dir) as repo:
 
 
             # Commit a dummy file then modify it
             # Commit a dummy file then modify it
-            foo1_fullpath = os.path.join(repo_dir, 'foo1')
-            with open(foo1_fullpath, 'wb') as f:
-                f.write(b'origstuff')
+            foo1_fullpath = os.path.join(repo_dir, "foo1")
+            with open(foo1_fullpath, "wb") as f:
+                f.write(b"origstuff")
 
 
-            repo.stage(['foo1'])
-            repo.do_commit(b'test status', author=b'author <email>',
-                           committer=b'committer <email>')
+            repo.stage(["foo1"])
+            repo.do_commit(
+                b"test status",
+                author=b"author <email>",
+                committer=b"committer <email>",
+            )
 
 
             os.remove(foo1_fullpath)
             os.remove(foo1_fullpath)
             os.mkdir(foo1_fullpath)
             os.mkdir(foo1_fullpath)
 
 
             changes = get_unstaged_changes(repo.open_index(), repo_dir)
             changes = get_unstaged_changes(repo.open_index(), repo_dir)
 
 
-            self.assertEqual(list(changes), [b'foo1'])
+            self.assertEqual(list(changes), [b"foo1"])
 
 
-    @skipIf(not can_symlink(), 'Requires symlink support')
+    @skipIf(not can_symlink(), "Requires symlink support")
     def test_get_unstaged_changes_removed_replaced_by_link(self):
     def test_get_unstaged_changes_removed_replaced_by_link(self):
         """Unit test for get_unstaged_changes."""
         """Unit test for get_unstaged_changes."""
 
 
@@ -711,24 +784,26 @@ class GetUnstagedChangesTests(TestCase):
         with Repo.init(repo_dir) as repo:
         with Repo.init(repo_dir) as repo:
 
 
             # Commit a dummy file then modify it
             # Commit a dummy file then modify it
-            foo1_fullpath = os.path.join(repo_dir, 'foo1')
-            with open(foo1_fullpath, 'wb') as f:
-                f.write(b'origstuff')
+            foo1_fullpath = os.path.join(repo_dir, "foo1")
+            with open(foo1_fullpath, "wb") as f:
+                f.write(b"origstuff")
 
 
-            repo.stage(['foo1'])
-            repo.do_commit(b'test status', author=b'author <email>',
-                           committer=b'committer <email>')
+            repo.stage(["foo1"])
+            repo.do_commit(
+                b"test status",
+                author=b"author <email>",
+                committer=b"committer <email>",
+            )
 
 
             os.remove(foo1_fullpath)
             os.remove(foo1_fullpath)
             os.symlink(os.path.dirname(foo1_fullpath), foo1_fullpath)
             os.symlink(os.path.dirname(foo1_fullpath), foo1_fullpath)
 
 
             changes = get_unstaged_changes(repo.open_index(), repo_dir)
             changes = get_unstaged_changes(repo.open_index(), repo_dir)
 
 
-            self.assertEqual(list(changes), [b'foo1'])
+            self.assertEqual(list(changes), [b"foo1"])
 
 
 
 
 class TestValidatePathElement(TestCase):
 class TestValidatePathElement(TestCase):
-
     def test_default(self):
     def test_default(self):
         self.assertTrue(validate_path_element_default(b"bla"))
         self.assertTrue(validate_path_element_default(b"bla"))
         self.assertTrue(validate_path_element_default(b".bla"))
         self.assertTrue(validate_path_element_default(b".bla"))
@@ -747,20 +822,20 @@ class TestValidatePathElement(TestCase):
 
 
 
 
 class TestTreeFSPathConversion(TestCase):
 class TestTreeFSPathConversion(TestCase):
-
     def test_tree_to_fs_path(self):
     def test_tree_to_fs_path(self):
-        tree_path = u'délwíçh/foo'.encode('utf8')
-        fs_path = _tree_to_fs_path(b'/prefix/path', tree_path)
+        tree_path = u"délwíçh/foo".encode("utf8")
+        fs_path = _tree_to_fs_path(b"/prefix/path", tree_path)
         self.assertEqual(
         self.assertEqual(
             fs_path,
             fs_path,
-            os.fsencode(os.path.join(u'/prefix/path', u'délwíçh', u'foo')))
+            os.fsencode(os.path.join(u"/prefix/path", u"délwíçh", u"foo")),
+        )
 
 
     def test_fs_to_tree_path_str(self):
     def test_fs_to_tree_path_str(self):
-        fs_path = os.path.join(os.path.join(u'délwíçh', u'foo'))
+        fs_path = os.path.join(os.path.join(u"délwíçh", u"foo"))
         tree_path = _fs_to_tree_path(fs_path)
         tree_path = _fs_to_tree_path(fs_path)
-        self.assertEqual(tree_path, u'délwíçh/foo'.encode('utf-8'))
+        self.assertEqual(tree_path, u"délwíçh/foo".encode("utf-8"))
 
 
     def test_fs_to_tree_path_bytes(self):
     def test_fs_to_tree_path_bytes(self):
-        fs_path = os.path.join(os.fsencode(os.path.join(u'délwíçh', u'foo')))
+        fs_path = os.path.join(os.fsencode(os.path.join(u"délwíçh", u"foo")))
         tree_path = _fs_to_tree_path(fs_path)
         tree_path = _fs_to_tree_path(fs_path)
-        self.assertEqual(tree_path, u'délwíçh/foo'.encode('utf-8'))
+        self.assertEqual(tree_path, u"délwíçh/foo".encode("utf-8"))

+ 3 - 5
dulwich/tests/test_lfs.py

@@ -27,7 +27,6 @@ import tempfile
 
 
 
 
 class LFSTests(TestCase):
 class LFSTests(TestCase):
-
     def setUp(self):
     def setUp(self):
         super(LFSTests, self).setUp()
         super(LFSTests, self).setUp()
         self.test_dir = tempfile.mkdtemp()
         self.test_dir = tempfile.mkdtemp()
@@ -35,10 +34,9 @@ class LFSTests(TestCase):
         self.lfs = LFSStore.create(self.test_dir)
         self.lfs = LFSStore.create(self.test_dir)
 
 
     def test_create(self):
     def test_create(self):
-        sha = self.lfs.write_object([b'a', b'b'])
+        sha = self.lfs.write_object([b"a", b"b"])
         with self.lfs.open_object(sha) as f:
         with self.lfs.open_object(sha) as f:
-            self.assertEqual(b'ab', f.read())
+            self.assertEqual(b"ab", f.read())
 
 
     def test_missing(self):
     def test_missing(self):
-        self.assertRaises(
-            KeyError, self.lfs.open_object, 'abcdeabcdeabcdeabcde')
+        self.assertRaises(KeyError, self.lfs.open_object, "abcdeabcdeabcdeabcde")

+ 4 - 13
dulwich/tests/test_line_ending.py

@@ -40,31 +40,22 @@ class LineEndingConversion(TestCase):
         self.assertEqual(convert_crlf_to_lf(b"foobar"), b"foobar")
         self.assertEqual(convert_crlf_to_lf(b"foobar"), b"foobar")
 
 
     def test_convert_crlf_to_lf(self):
     def test_convert_crlf_to_lf(self):
-        self.assertEqual(
-            convert_crlf_to_lf(b"line1\r\nline2"), b"line1\nline2"
-        )
+        self.assertEqual(convert_crlf_to_lf(b"line1\r\nline2"), b"line1\nline2")
 
 
     def test_convert_crlf_to_lf_mixed(self):
     def test_convert_crlf_to_lf_mixed(self):
-        self.assertEqual(
-            convert_crlf_to_lf(b"line1\r\n\nline2"), b"line1\n\nline2"
-        )
+        self.assertEqual(convert_crlf_to_lf(b"line1\r\n\nline2"), b"line1\n\nline2")
 
 
     def test_convert_lf_to_crlf_no_op(self):
     def test_convert_lf_to_crlf_no_op(self):
         self.assertEqual(convert_lf_to_crlf(b"foobar"), b"foobar")
         self.assertEqual(convert_lf_to_crlf(b"foobar"), b"foobar")
 
 
     def test_convert_lf_to_crlf(self):
     def test_convert_lf_to_crlf(self):
-        self.assertEqual(
-            convert_lf_to_crlf(b"line1\nline2"), b"line1\r\nline2"
-        )
+        self.assertEqual(convert_lf_to_crlf(b"line1\nline2"), b"line1\r\nline2")
 
 
     def test_convert_lf_to_crlf_mixed(self):
     def test_convert_lf_to_crlf_mixed(self):
-        self.assertEqual(
-            convert_lf_to_crlf(b"line1\r\n\nline2"), b"line1\r\n\r\nline2"
-        )
+        self.assertEqual(convert_lf_to_crlf(b"line1\r\n\nline2"), b"line1\r\n\r\nline2")
 
 
 
 
 class GetLineEndingAutocrlfFilters(TestCase):
 class GetLineEndingAutocrlfFilters(TestCase):
-
     def test_get_checkin_filter_autocrlf_default(self):
     def test_get_checkin_filter_autocrlf_default(self):
         checkin_filter = get_checkin_filter_autocrlf(b"false")
         checkin_filter = get_checkin_filter_autocrlf(b"false")
 
 

+ 89 - 88
dulwich/tests/test_lru_cache.py

@@ -21,10 +21,10 @@
 
 
 from dulwich import (
 from dulwich import (
     lru_cache,
     lru_cache,
-    )
+)
 from dulwich.tests import (
 from dulwich.tests import (
     TestCase,
     TestCase,
-    )
+)
 
 
 
 
 class TestLRUCache(TestCase):
 class TestLRUCache(TestCase):
@@ -43,13 +43,13 @@ class TestLRUCache(TestCase):
     def test_missing(self):
     def test_missing(self):
         cache = lru_cache.LRUCache(max_cache=10)
         cache = lru_cache.LRUCache(max_cache=10)
 
 
-        self.assertFalse('foo' in cache)
-        self.assertRaises(KeyError, cache.__getitem__, 'foo')
+        self.assertFalse("foo" in cache)
+        self.assertRaises(KeyError, cache.__getitem__, "foo")
 
 
-        cache['foo'] = 'bar'
-        self.assertEqual('bar', cache['foo'])
-        self.assertTrue('foo' in cache)
-        self.assertFalse('bar' in cache)
+        cache["foo"] = "bar"
+        self.assertEqual("bar", cache["foo"])
+        self.assertTrue("foo" in cache)
+        self.assertFalse("bar" in cache)
 
 
     def test_map_None(self):
     def test_map_None(self):
         # Make sure that we can properly map None as a key.
         # Make sure that we can properly map None as a key.
@@ -76,28 +76,28 @@ class TestLRUCache(TestCase):
         """Adding extra entries will pop out old ones."""
         """Adding extra entries will pop out old ones."""
         cache = lru_cache.LRUCache(max_cache=1, after_cleanup_count=1)
         cache = lru_cache.LRUCache(max_cache=1, after_cleanup_count=1)
 
 
-        cache['foo'] = 'bar'
+        cache["foo"] = "bar"
         # With a max cache of 1, adding 'baz' should pop out 'foo'
         # With a max cache of 1, adding 'baz' should pop out 'foo'
-        cache['baz'] = 'biz'
+        cache["baz"] = "biz"
 
 
-        self.assertFalse('foo' in cache)
-        self.assertTrue('baz' in cache)
+        self.assertFalse("foo" in cache)
+        self.assertTrue("baz" in cache)
 
 
-        self.assertEqual('biz', cache['baz'])
+        self.assertEqual("biz", cache["baz"])
 
 
     def test_by_usage(self):
     def test_by_usage(self):
         """Accessing entries bumps them up in priority."""
         """Accessing entries bumps them up in priority."""
         cache = lru_cache.LRUCache(max_cache=2)
         cache = lru_cache.LRUCache(max_cache=2)
 
 
-        cache['baz'] = 'biz'
-        cache['foo'] = 'bar'
+        cache["baz"] = "biz"
+        cache["foo"] = "bar"
 
 
-        self.assertEqual('biz', cache['baz'])
+        self.assertEqual("biz", cache["baz"])
 
 
         # This must kick out 'foo' because it was the last accessed
         # This must kick out 'foo' because it was the last accessed
-        cache['nub'] = 'in'
+        cache["nub"] = "in"
 
 
-        self.assertFalse('foo' in cache)
+        self.assertFalse("foo" in cache)
 
 
     def test_cleanup(self):
     def test_cleanup(self):
         """Test that we can use a cleanup function."""
         """Test that we can use a cleanup function."""
@@ -108,17 +108,16 @@ class TestLRUCache(TestCase):
 
 
         cache = lru_cache.LRUCache(max_cache=2, after_cleanup_count=2)
         cache = lru_cache.LRUCache(max_cache=2, after_cleanup_count=2)
 
 
-        cache.add('baz', '1', cleanup=cleanup_func)
-        cache.add('foo', '2', cleanup=cleanup_func)
-        cache.add('biz', '3', cleanup=cleanup_func)
+        cache.add("baz", "1", cleanup=cleanup_func)
+        cache.add("foo", "2", cleanup=cleanup_func)
+        cache.add("biz", "3", cleanup=cleanup_func)
 
 
-        self.assertEqual([('baz', '1')], cleanup_called)
+        self.assertEqual([("baz", "1")], cleanup_called)
 
 
         # 'foo' is now most recent, so final cleanup will call it last
         # 'foo' is now most recent, so final cleanup will call it last
-        cache['foo']
+        cache["foo"]
         cache.clear()
         cache.clear()
-        self.assertEqual([('baz', '1'), ('biz', '3'), ('foo', '2')],
-                         cleanup_called)
+        self.assertEqual([("baz", "1"), ("biz", "3"), ("foo", "2")], cleanup_called)
 
 
     def test_cleanup_on_replace(self):
     def test_cleanup_on_replace(self):
         """Replacing an object should cleanup the old value."""
         """Replacing an object should cleanup the old value."""
@@ -166,8 +165,10 @@ class TestLRUCache(TestCase):
 
 
         # We hit the max
         # We hit the max
         self.assertEqual(10, len(cache))
         self.assertEqual(10, len(cache))
-        self.assertEqual([11, 10, 9, 1, 8, 7, 6, 5, 4, 3],
-                         [n.key for n in cache._walk_lru()])
+        self.assertEqual(
+            [11, 10, 9, 1, 8, 7, 6, 5, 4, 3],
+            [n.key for n in cache._walk_lru()],
+        )
 
 
     def test_cleanup_shrinks_to_after_clean_count(self):
     def test_cleanup_shrinks_to_after_clean_count(self):
         cache = lru_cache.LRUCache(max_cache=5, after_cleanup_count=3)
         cache = lru_cache.LRUCache(max_cache=5, after_cleanup_count=3)
@@ -293,11 +294,10 @@ class TestLRUCache(TestCase):
 
 
 
 
 class TestLRUSizeCache(TestCase):
 class TestLRUSizeCache(TestCase):
-
     def test_basic_init(self):
     def test_basic_init(self):
         cache = lru_cache.LRUSizeCache()
         cache = lru_cache.LRUSizeCache()
         self.assertEqual(2048, cache._max_cache)
         self.assertEqual(2048, cache._max_cache)
-        self.assertEqual(int(cache._max_size*0.8), cache._after_cleanup_size)
+        self.assertEqual(int(cache._max_size * 0.8), cache._after_cleanup_size)
         self.assertEqual(0, cache._value_size)
         self.assertEqual(0, cache._value_size)
 
 
     def test_add__null_key(self):
     def test_add__null_key(self):
@@ -307,15 +307,15 @@ class TestLRUSizeCache(TestCase):
     def test_add_tracks_size(self):
     def test_add_tracks_size(self):
         cache = lru_cache.LRUSizeCache()
         cache = lru_cache.LRUSizeCache()
         self.assertEqual(0, cache._value_size)
         self.assertEqual(0, cache._value_size)
-        cache.add('my key', 'my value text')
+        cache.add("my key", "my value text")
         self.assertEqual(13, cache._value_size)
         self.assertEqual(13, cache._value_size)
 
 
     def test_remove_tracks_size(self):
     def test_remove_tracks_size(self):
         cache = lru_cache.LRUSizeCache()
         cache = lru_cache.LRUSizeCache()
         self.assertEqual(0, cache._value_size)
         self.assertEqual(0, cache._value_size)
-        cache.add('my key', 'my value text')
+        cache.add("my key", "my value text")
         self.assertEqual(13, cache._value_size)
         self.assertEqual(13, cache._value_size)
-        node = cache._cache['my key']
+        node = cache._cache["my key"]
         cache._remove_node(node)
         cache._remove_node(node)
         self.assertEqual(0, cache._value_size)
         self.assertEqual(0, cache._value_size)
 
 
@@ -324,21 +324,21 @@ class TestLRUSizeCache(TestCase):
         cache = lru_cache.LRUSizeCache(max_size=10, after_cleanup_size=5)
         cache = lru_cache.LRUSizeCache(max_size=10, after_cleanup_size=5)
         self.assertEqual(0, cache._value_size)
         self.assertEqual(0, cache._value_size)
         self.assertEqual({}, cache.items())
         self.assertEqual({}, cache.items())
-        cache.add('test', 'key')
+        cache.add("test", "key")
         self.assertEqual(3, cache._value_size)
         self.assertEqual(3, cache._value_size)
-        self.assertEqual({'test': 'key'}, cache.items())
-        cache.add('test2', 'key that is too big')
+        self.assertEqual({"test": "key"}, cache.items())
+        cache.add("test2", "key that is too big")
         self.assertEqual(3, cache._value_size)
         self.assertEqual(3, cache._value_size)
-        self.assertEqual({'test': 'key'}, cache.items())
+        self.assertEqual({"test": "key"}, cache.items())
         # If we would add a key, only to cleanup and remove all cached entries,
         # If we would add a key, only to cleanup and remove all cached entries,
         # then obviously that value should not be stored
         # then obviously that value should not be stored
-        cache.add('test3', 'bigkey')
+        cache.add("test3", "bigkey")
         self.assertEqual(3, cache._value_size)
         self.assertEqual(3, cache._value_size)
-        self.assertEqual({'test': 'key'}, cache.items())
+        self.assertEqual({"test": "key"}, cache.items())
 
 
-        cache.add('test4', 'bikey')
+        cache.add("test4", "bikey")
         self.assertEqual(3, cache._value_size)
         self.assertEqual(3, cache._value_size)
-        self.assertEqual({'test': 'key'}, cache.items())
+        self.assertEqual({"test": "key"}, cache.items())
 
 
     def test_no_add_over_size_cleanup(self):
     def test_no_add_over_size_cleanup(self):
         """If a large value is not cached, we will call cleanup right away."""
         """If a large value is not cached, we will call cleanup right away."""
@@ -350,63 +350,64 @@ class TestLRUSizeCache(TestCase):
         cache = lru_cache.LRUSizeCache(max_size=10, after_cleanup_size=5)
         cache = lru_cache.LRUSizeCache(max_size=10, after_cleanup_size=5)
         self.assertEqual(0, cache._value_size)
         self.assertEqual(0, cache._value_size)
         self.assertEqual({}, cache.items())
         self.assertEqual({}, cache.items())
-        cache.add('test', 'key that is too big', cleanup=cleanup)
+        cache.add("test", "key that is too big", cleanup=cleanup)
         # key was not added
         # key was not added
         self.assertEqual(0, cache._value_size)
         self.assertEqual(0, cache._value_size)
         self.assertEqual({}, cache.items())
         self.assertEqual({}, cache.items())
         # and cleanup was called
         # and cleanup was called
-        self.assertEqual([('test', 'key that is too big')], cleanup_calls)
+        self.assertEqual([("test", "key that is too big")], cleanup_calls)
 
 
     def test_adding_clears_cache_based_on_size(self):
     def test_adding_clears_cache_based_on_size(self):
         """The cache is cleared in LRU order until small enough"""
         """The cache is cleared in LRU order until small enough"""
         cache = lru_cache.LRUSizeCache(max_size=20)
         cache = lru_cache.LRUSizeCache(max_size=20)
-        cache.add('key1', 'value')  # 5 chars
-        cache.add('key2', 'value2')  # 6 chars
-        cache.add('key3', 'value23')  # 7 chars
-        self.assertEqual(5+6+7, cache._value_size)
-        cache['key2']  # reference key2 so it gets a newer reference time
-        cache.add('key4', 'value234')  # 8 chars, over limit
+        cache.add("key1", "value")  # 5 chars
+        cache.add("key2", "value2")  # 6 chars
+        cache.add("key3", "value23")  # 7 chars
+        self.assertEqual(5 + 6 + 7, cache._value_size)
+        cache["key2"]  # reference key2 so it gets a newer reference time
+        cache.add("key4", "value234")  # 8 chars, over limit
         # We have to remove 2 keys to get back under limit
         # We have to remove 2 keys to get back under limit
-        self.assertEqual(6+8, cache._value_size)
-        self.assertEqual({'key2': 'value2', 'key4': 'value234'},
-                         cache.items())
+        self.assertEqual(6 + 8, cache._value_size)
+        self.assertEqual({"key2": "value2", "key4": "value234"}, cache.items())
 
 
     def test_adding_clears_to_after_cleanup_size(self):
     def test_adding_clears_to_after_cleanup_size(self):
         cache = lru_cache.LRUSizeCache(max_size=20, after_cleanup_size=10)
         cache = lru_cache.LRUSizeCache(max_size=20, after_cleanup_size=10)
-        cache.add('key1', 'value')  # 5 chars
-        cache.add('key2', 'value2')  # 6 chars
-        cache.add('key3', 'value23')  # 7 chars
-        self.assertEqual(5+6+7, cache._value_size)
-        cache['key2']  # reference key2 so it gets a newer reference time
-        cache.add('key4', 'value234')  # 8 chars, over limit
+        cache.add("key1", "value")  # 5 chars
+        cache.add("key2", "value2")  # 6 chars
+        cache.add("key3", "value23")  # 7 chars
+        self.assertEqual(5 + 6 + 7, cache._value_size)
+        cache["key2"]  # reference key2 so it gets a newer reference time
+        cache.add("key4", "value234")  # 8 chars, over limit
         # We have to remove 3 keys to get back under limit
         # We have to remove 3 keys to get back under limit
         self.assertEqual(8, cache._value_size)
         self.assertEqual(8, cache._value_size)
-        self.assertEqual({'key4': 'value234'}, cache.items())
+        self.assertEqual({"key4": "value234"}, cache.items())
 
 
     def test_custom_sizes(self):
     def test_custom_sizes(self):
         def size_of_list(lst):
         def size_of_list(lst):
             return sum(len(x) for x in lst)
             return sum(len(x) for x in lst)
-        cache = lru_cache.LRUSizeCache(max_size=20, after_cleanup_size=10,
-                                       compute_size=size_of_list)
-
-        cache.add('key1', ['val', 'ue'])  # 5 chars
-        cache.add('key2', ['val', 'ue2'])  # 6 chars
-        cache.add('key3', ['val', 'ue23'])  # 7 chars
-        self.assertEqual(5+6+7, cache._value_size)
-        cache['key2']  # reference key2 so it gets a newer reference time
-        cache.add('key4', ['value', '234'])  # 8 chars, over limit
+
+        cache = lru_cache.LRUSizeCache(
+            max_size=20, after_cleanup_size=10, compute_size=size_of_list
+        )
+
+        cache.add("key1", ["val", "ue"])  # 5 chars
+        cache.add("key2", ["val", "ue2"])  # 6 chars
+        cache.add("key3", ["val", "ue23"])  # 7 chars
+        self.assertEqual(5 + 6 + 7, cache._value_size)
+        cache["key2"]  # reference key2 so it gets a newer reference time
+        cache.add("key4", ["value", "234"])  # 8 chars, over limit
         # We have to remove 3 keys to get back under limit
         # We have to remove 3 keys to get back under limit
         self.assertEqual(8, cache._value_size)
         self.assertEqual(8, cache._value_size)
-        self.assertEqual({'key4': ['value', '234']}, cache.items())
+        self.assertEqual({"key4": ["value", "234"]}, cache.items())
 
 
     def test_cleanup(self):
     def test_cleanup(self):
         cache = lru_cache.LRUSizeCache(max_size=20, after_cleanup_size=10)
         cache = lru_cache.LRUSizeCache(max_size=20, after_cleanup_size=10)
 
 
         # Add these in order
         # Add these in order
-        cache.add('key1', 'value')  # 5 chars
-        cache.add('key2', 'value2')  # 6 chars
-        cache.add('key3', 'value23')  # 7 chars
-        self.assertEqual(5+6+7, cache._value_size)
+        cache.add("key1", "value")  # 5 chars
+        cache.add("key2", "value2")  # 6 chars
+        cache.add("key3", "value23")  # 7 chars
+        self.assertEqual(5 + 6 + 7, cache._value_size)
 
 
         cache.cleanup()
         cache.cleanup()
         # Only the most recent fits after cleaning up
         # Only the most recent fits after cleaning up
@@ -415,40 +416,40 @@ class TestLRUSizeCache(TestCase):
     def test_keys(self):
     def test_keys(self):
         cache = lru_cache.LRUSizeCache(max_size=10)
         cache = lru_cache.LRUSizeCache(max_size=10)
 
 
-        cache[1] = 'a'
-        cache[2] = 'b'
-        cache[3] = 'cdef'
+        cache[1] = "a"
+        cache[2] = "b"
+        cache[3] = "cdef"
         self.assertEqual([1, 2, 3], sorted(cache.keys()))
         self.assertEqual([1, 2, 3], sorted(cache.keys()))
 
 
     def test_resize_smaller(self):
     def test_resize_smaller(self):
         cache = lru_cache.LRUSizeCache(max_size=10, after_cleanup_size=9)
         cache = lru_cache.LRUSizeCache(max_size=10, after_cleanup_size=9)
-        cache[1] = 'abc'
-        cache[2] = 'def'
-        cache[3] = 'ghi'
-        cache[4] = 'jkl'
+        cache[1] = "abc"
+        cache[2] = "def"
+        cache[3] = "ghi"
+        cache[4] = "jkl"
         # Triggers a cleanup
         # Triggers a cleanup
         self.assertEqual([2, 3, 4], sorted(cache.keys()))
         self.assertEqual([2, 3, 4], sorted(cache.keys()))
         # Resize should also cleanup again
         # Resize should also cleanup again
         cache.resize(max_size=6, after_cleanup_size=4)
         cache.resize(max_size=6, after_cleanup_size=4)
         self.assertEqual([4], sorted(cache.keys()))
         self.assertEqual([4], sorted(cache.keys()))
         # Adding should use the new max size
         # Adding should use the new max size
-        cache[5] = 'mno'
+        cache[5] = "mno"
         self.assertEqual([4, 5], sorted(cache.keys()))
         self.assertEqual([4, 5], sorted(cache.keys()))
-        cache[6] = 'pqr'
+        cache[6] = "pqr"
         self.assertEqual([6], sorted(cache.keys()))
         self.assertEqual([6], sorted(cache.keys()))
 
 
     def test_resize_larger(self):
     def test_resize_larger(self):
         cache = lru_cache.LRUSizeCache(max_size=10, after_cleanup_size=9)
         cache = lru_cache.LRUSizeCache(max_size=10, after_cleanup_size=9)
-        cache[1] = 'abc'
-        cache[2] = 'def'
-        cache[3] = 'ghi'
-        cache[4] = 'jkl'
+        cache[1] = "abc"
+        cache[2] = "def"
+        cache[3] = "ghi"
+        cache[4] = "jkl"
         # Triggers a cleanup
         # Triggers a cleanup
         self.assertEqual([2, 3, 4], sorted(cache.keys()))
         self.assertEqual([2, 3, 4], sorted(cache.keys()))
         cache.resize(max_size=15, after_cleanup_size=12)
         cache.resize(max_size=15, after_cleanup_size=12)
         self.assertEqual([2, 3, 4], sorted(cache.keys()))
         self.assertEqual([2, 3, 4], sorted(cache.keys()))
-        cache[5] = 'mno'
-        cache[6] = 'pqr'
+        cache[5] = "mno"
+        cache[6] = "pqr"
         self.assertEqual([2, 3, 4, 5, 6], sorted(cache.keys()))
         self.assertEqual([2, 3, 4, 5, 6], sorted(cache.keys()))
-        cache[7] = 'stu'
+        cache[7] = "stu"
         self.assertEqual([4, 5, 6, 7], sorted(cache.keys()))
         self.assertEqual([4, 5, 6, 7], sorted(cache.keys()))

+ 51 - 40
dulwich/tests/test_mailmap.py

@@ -28,9 +28,9 @@ from dulwich.mailmap import Mailmap, read_mailmap
 
 
 
 
 class ReadMailmapTests(TestCase):
 class ReadMailmapTests(TestCase):
-
     def test_read(self):
     def test_read(self):
-        b = BytesIO(b"""\
+        b = BytesIO(
+            b"""\
 Jane Doe         <jane@desktop.(none)>
 Jane Doe         <jane@desktop.(none)>
 Joe R. Developer <joe@example.com>
 Joe R. Developer <joe@example.com>
 # A comment
 # A comment
@@ -39,52 +39,63 @@ Some Dude <some@dude.xx>         nick1 <bugs@company.xx>
 Other Author <other@author.xx>   nick2 <bugs@company.xx>
 Other Author <other@author.xx>   nick2 <bugs@company.xx>
 Other Author <other@author.xx>         <nick2@company.xx>
 Other Author <other@author.xx>         <nick2@company.xx>
 Santa Claus <santa.claus@northpole.xx> <me@company.xx>
 Santa Claus <santa.claus@northpole.xx> <me@company.xx>
-""")
-        self.assertEqual([
-            ((b'Jane Doe', b'jane@desktop.(none)'), None),
-            ((b'Joe R. Developer', b'joe@example.com'), None),
-            ((None, b'cto@company.xx'), (None, b'cto@coompany.xx')),
-            ((b'Some Dude', b'some@dude.xx'), (b'nick1', b'bugs@company.xx')),
-            ((b'Other Author', b'other@author.xx'),
-                (b'nick2', b'bugs@company.xx')),
-            ((b'Other Author', b'other@author.xx'),
-                (None, b'nick2@company.xx')),
-            ((b'Santa Claus', b'santa.claus@northpole.xx'),
-                (None, b'me@company.xx'))],
-            list(read_mailmap(b)))
+"""
+        )
+        self.assertEqual(
+            [
+                ((b"Jane Doe", b"jane@desktop.(none)"), None),
+                ((b"Joe R. Developer", b"joe@example.com"), None),
+                ((None, b"cto@company.xx"), (None, b"cto@coompany.xx")),
+                (
+                    (b"Some Dude", b"some@dude.xx"),
+                    (b"nick1", b"bugs@company.xx"),
+                ),
+                (
+                    (b"Other Author", b"other@author.xx"),
+                    (b"nick2", b"bugs@company.xx"),
+                ),
+                (
+                    (b"Other Author", b"other@author.xx"),
+                    (None, b"nick2@company.xx"),
+                ),
+                (
+                    (b"Santa Claus", b"santa.claus@northpole.xx"),
+                    (None, b"me@company.xx"),
+                ),
+            ],
+            list(read_mailmap(b)),
+        )
 
 
 
 
 class MailmapTests(TestCase):
 class MailmapTests(TestCase):
-
     def test_lookup(self):
     def test_lookup(self):
         m = Mailmap()
         m = Mailmap()
-        m.add_entry((b'Jane Doe', b'jane@desktop.(none)'), (None, None))
-        m.add_entry((b'Joe R. Developer', b'joe@example.com'), None)
-        m.add_entry((None, b'cto@company.xx'), (None, b'cto@coompany.xx'))
+        m.add_entry((b"Jane Doe", b"jane@desktop.(none)"), (None, None))
+        m.add_entry((b"Joe R. Developer", b"joe@example.com"), None)
+        m.add_entry((None, b"cto@company.xx"), (None, b"cto@coompany.xx"))
+        m.add_entry((b"Some Dude", b"some@dude.xx"), (b"nick1", b"bugs@company.xx"))
         m.add_entry(
         m.add_entry(
-                (b'Some Dude', b'some@dude.xx'),
-                (b'nick1', b'bugs@company.xx'))
+            (b"Other Author", b"other@author.xx"),
+            (b"nick2", b"bugs@company.xx"),
+        )
+        m.add_entry((b"Other Author", b"other@author.xx"), (None, b"nick2@company.xx"))
         m.add_entry(
         m.add_entry(
-                (b'Other Author', b'other@author.xx'),
-                (b'nick2', b'bugs@company.xx'))
-        m.add_entry(
-                (b'Other Author', b'other@author.xx'),
-                (None, b'nick2@company.xx'))
-        m.add_entry(
-                (b'Santa Claus', b'santa.claus@northpole.xx'),
-                (None, b'me@company.xx'))
-        self.assertEqual(
-            b'Jane Doe <jane@desktop.(none)>',
-            m.lookup(b'Jane Doe <jane@desktop.(none)>'))
+            (b"Santa Claus", b"santa.claus@northpole.xx"),
+            (None, b"me@company.xx"),
+        )
         self.assertEqual(
         self.assertEqual(
-            b'Jane Doe <jane@desktop.(none)>',
-            m.lookup(b'Jane Doe <jane@example.com>'))
+            b"Jane Doe <jane@desktop.(none)>",
+            m.lookup(b"Jane Doe <jane@desktop.(none)>"),
+        )
         self.assertEqual(
         self.assertEqual(
-            b'Jane Doe <jane@desktop.(none)>',
-            m.lookup(b'Jane D. <jane@desktop.(none)>'))
+            b"Jane Doe <jane@desktop.(none)>",
+            m.lookup(b"Jane Doe <jane@example.com>"),
+        )
         self.assertEqual(
         self.assertEqual(
-            b'Some Dude <some@dude.xx>',
-            m.lookup(b'nick1 <bugs@company.xx>'))
+            b"Jane Doe <jane@desktop.(none)>",
+            m.lookup(b"Jane D. <jane@desktop.(none)>"),
+        )
         self.assertEqual(
         self.assertEqual(
-            b'CTO <cto@company.xx>',
-            m.lookup(b'CTO <cto@coompany.xx>'))
+            b"Some Dude <some@dude.xx>", m.lookup(b"nick1 <bugs@company.xx>")
+        )
+        self.assertEqual(b"CTO <cto@company.xx>", m.lookup(b"CTO <cto@coompany.xx>"))

+ 142 - 85
dulwich/tests/test_missing_obj_finder.py

@@ -20,57 +20,60 @@
 
 
 from dulwich.object_store import (
 from dulwich.object_store import (
     MemoryObjectStore,
     MemoryObjectStore,
-    )
+)
 from dulwich.objects import (
 from dulwich.objects import (
     Blob,
     Blob,
-    )
+)
 from dulwich.tests import TestCase
 from dulwich.tests import TestCase
 from dulwich.tests.utils import (
 from dulwich.tests.utils import (
     make_object,
     make_object,
     make_tag,
     make_tag,
     build_commit_graph,
     build_commit_graph,
-    )
+)
 
 
 
 
 class MissingObjectFinderTest(TestCase):
 class MissingObjectFinderTest(TestCase):
-
     def setUp(self):
     def setUp(self):
         super(MissingObjectFinderTest, self).setUp()
         super(MissingObjectFinderTest, self).setUp()
         self.store = MemoryObjectStore()
         self.store = MemoryObjectStore()
         self.commits = []
         self.commits = []
 
 
     def cmt(self, n):
     def cmt(self, n):
-        return self.commits[n-1]
+        return self.commits[n - 1]
 
 
     def assertMissingMatch(self, haves, wants, expected):
     def assertMissingMatch(self, haves, wants, expected):
         for sha, path in self.store.find_missing_objects(haves, wants, set()):
         for sha, path in self.store.find_missing_objects(haves, wants, set()):
             self.assertTrue(
             self.assertTrue(
-                    sha in expected,
-                    "(%s,%s) erroneously reported as missing" % (sha, path))
+                sha in expected,
+                "(%s,%s) erroneously reported as missing" % (sha, path),
+            )
             expected.remove(sha)
             expected.remove(sha)
 
 
         self.assertEqual(
         self.assertEqual(
-                len(expected), 0,
-                "some objects are not reported as missing: %s" % (expected, ))
+            len(expected),
+            0,
+            "some objects are not reported as missing: %s" % (expected,),
+        )
 
 
 
 
 class MOFLinearRepoTest(MissingObjectFinderTest):
 class MOFLinearRepoTest(MissingObjectFinderTest):
-
     def setUp(self):
     def setUp(self):
         super(MOFLinearRepoTest, self).setUp()
         super(MOFLinearRepoTest, self).setUp()
         # present in 1, removed in 3
         # present in 1, removed in 3
-        f1_1 = make_object(Blob, data=b'f1')
+        f1_1 = make_object(Blob, data=b"f1")
         # present in all revisions, changed in 2 and 3
         # present in all revisions, changed in 2 and 3
-        f2_1 = make_object(Blob, data=b'f2')
-        f2_2 = make_object(Blob, data=b'f2-changed')
-        f2_3 = make_object(Blob, data=b'f2-changed-again')
+        f2_1 = make_object(Blob, data=b"f2")
+        f2_2 = make_object(Blob, data=b"f2-changed")
+        f2_3 = make_object(Blob, data=b"f2-changed-again")
         # added in 2, left unmodified in 3
         # added in 2, left unmodified in 3
-        f3_2 = make_object(Blob, data=b'f3')
+        f3_2 = make_object(Blob, data=b"f3")
 
 
         commit_spec = [[1], [2, 1], [3, 2]]
         commit_spec = [[1], [2, 1], [3, 2]]
-        trees = {1: [(b'f1', f1_1), (b'f2', f2_1)],
-                 2: [(b'f1', f1_1), (b'f2', f2_2), (b'f3', f3_2)],
-                 3: [(b'f2', f2_3), (b'f3', f3_2)]}
+        trees = {
+            1: [(b"f1", f1_1), (b"f2", f2_1)],
+            2: [(b"f1", f1_1), (b"f2", f2_2), (b"f3", f3_2)],
+            3: [(b"f2", f2_3), (b"f3", f3_2)],
+        }
         # commit 1: f1 and f2
         # commit 1: f1 and f2
         # commit 2: f3 added, f2 changed. Missing shall report commit id and a
         # commit 2: f3 added, f2 changed. Missing shall report commit id and a
         # tree referenced by commit
         # tree referenced by commit
@@ -80,24 +83,23 @@ class MOFLinearRepoTest(MissingObjectFinderTest):
         self.missing_1_2 = [self.cmt(2).id, self.cmt(2).tree, f2_2.id, f3_2.id]
         self.missing_1_2 = [self.cmt(2).id, self.cmt(2).tree, f2_2.id, f3_2.id]
         self.missing_2_3 = [self.cmt(3).id, self.cmt(3).tree, f2_3.id]
         self.missing_2_3 = [self.cmt(3).id, self.cmt(3).tree, f2_3.id]
         self.missing_1_3 = [
         self.missing_1_3 = [
-            self.cmt(2).id, self.cmt(3).id,
-            self.cmt(2).tree, self.cmt(3).tree,
-            f2_2.id, f3_2.id, f2_3.id]
+            self.cmt(2).id,
+            self.cmt(3).id,
+            self.cmt(2).tree,
+            self.cmt(3).tree,
+            f2_2.id,
+            f3_2.id,
+            f2_3.id,
+        ]
 
 
     def test_1_to_2(self):
     def test_1_to_2(self):
-        self.assertMissingMatch(
-                [self.cmt(1).id], [self.cmt(2).id],
-                self.missing_1_2)
+        self.assertMissingMatch([self.cmt(1).id], [self.cmt(2).id], self.missing_1_2)
 
 
     def test_2_to_3(self):
     def test_2_to_3(self):
-        self.assertMissingMatch(
-                [self.cmt(2).id], [self.cmt(3).id],
-                self.missing_2_3)
+        self.assertMissingMatch([self.cmt(2).id], [self.cmt(3).id], self.missing_2_3)
 
 
     def test_1_to_3(self):
     def test_1_to_3(self):
-        self.assertMissingMatch(
-                [self.cmt(1).id], [self.cmt(3).id],
-                self.missing_1_3)
+        self.assertMissingMatch([self.cmt(1).id], [self.cmt(3).id], self.missing_1_3)
 
 
     def test_bogus_haves(self):
     def test_bogus_haves(self):
         """Ensure non-existent SHA in haves are tolerated"""
         """Ensure non-existent SHA in haves are tolerated"""
@@ -112,7 +114,8 @@ class MOFLinearRepoTest(MissingObjectFinderTest):
         haves = [self.cmt(1).id]
         haves = [self.cmt(1).id]
         wants = [self.cmt(3).id, bogus_sha]
         wants = [self.cmt(3).id, bogus_sha]
         self.assertRaises(
         self.assertRaises(
-                KeyError, self.store.find_missing_objects, haves, wants, set())
+            KeyError, self.store.find_missing_objects, haves, wants, set()
+        )
 
 
     def test_no_changes(self):
     def test_no_changes(self):
         self.assertMissingMatch([self.cmt(3).id], [self.cmt(3).id], [])
         self.assertMissingMatch([self.cmt(3).id], [self.cmt(3).id], [])
@@ -127,25 +130,27 @@ class MOFMergeForkRepoTest(MissingObjectFinderTest):
 
 
     def setUp(self):
     def setUp(self):
         super(MOFMergeForkRepoTest, self).setUp()
         super(MOFMergeForkRepoTest, self).setUp()
-        f1_1 = make_object(Blob, data=b'f1')
-        f1_2 = make_object(Blob, data=b'f1-2')
-        f1_4 = make_object(Blob, data=b'f1-4')
-        f1_7 = make_object(Blob, data=b'f1-2')  # same data as in rev 2
-        f2_1 = make_object(Blob, data=b'f2')
-        f2_3 = make_object(Blob, data=b'f2-3')
-        f3_3 = make_object(Blob, data=b'f3')
-        f3_5 = make_object(Blob, data=b'f3-5')
+        f1_1 = make_object(Blob, data=b"f1")
+        f1_2 = make_object(Blob, data=b"f1-2")
+        f1_4 = make_object(Blob, data=b"f1-4")
+        f1_7 = make_object(Blob, data=b"f1-2")  # same data as in rev 2
+        f2_1 = make_object(Blob, data=b"f2")
+        f2_3 = make_object(Blob, data=b"f2-3")
+        f3_3 = make_object(Blob, data=b"f3")
+        f3_5 = make_object(Blob, data=b"f3-5")
         commit_spec = [[1], [2, 1], [3, 2], [4, 2], [5, 3], [6, 3, 4], [7, 6]]
         commit_spec = [[1], [2, 1], [3, 2], [4, 2], [5, 3], [6, 3, 4], [7, 6]]
-        trees = {1: [(b'f1', f1_1), (b'f2', f2_1)],
-                 2: [(b'f1', f1_2), (b'f2', f2_1)],  # f1 changed
-                 # f3 added, f2 changed
-                 3: [(b'f1', f1_2), (b'f2', f2_3), (b'f3', f3_3)],
-                 4: [(b'f1', f1_4), (b'f2', f2_1)],  # f1 changed
-                 5: [(b'f1', f1_2), (b'f3', f3_5)],  # f2 removed, f3 changed
-                 # merged 3 and 4
-                 6: [(b'f1', f1_4), (b'f2', f2_3), (b'f3', f3_3)],
-                 # f1 changed to match rev2. f3 removed
-                 7: [(b'f1', f1_7), (b'f2', f2_3)]}
+        trees = {
+            1: [(b"f1", f1_1), (b"f2", f2_1)],
+            2: [(b"f1", f1_2), (b"f2", f2_1)],  # f1 changed
+            # f3 added, f2 changed
+            3: [(b"f1", f1_2), (b"f2", f2_3), (b"f3", f3_3)],
+            4: [(b"f1", f1_4), (b"f2", f2_1)],  # f1 changed
+            5: [(b"f1", f1_2), (b"f3", f3_5)],  # f2 removed, f3 changed
+            # merged 3 and 4
+            6: [(b"f1", f1_4), (b"f2", f2_3), (b"f3", f3_3)],
+            # f1 changed to match rev2. f3 removed
+            7: [(b"f1", f1_7), (b"f2", f2_3)],
+        }
         self.commits = build_commit_graph(self.store, commit_spec, trees)
         self.commits = build_commit_graph(self.store, commit_spec, trees)
 
 
         self.f1_2_id = f1_2.id
         self.f1_2_id = f1_2.id
@@ -164,54 +169,96 @@ class MOFMergeForkRepoTest(MissingObjectFinderTest):
         # doesn't record f1_2 was known prior to that, hence can't detect f1_7
         # doesn't record f1_2 was known prior to that, hence can't detect f1_7
         # is in fact f1_2 and shall not be reported)
         # is in fact f1_2 and shall not be reported)
         self.assertMissingMatch(
         self.assertMissingMatch(
-                [self.cmt(6).id], [self.cmt(7).id],
-                [self.cmt(7).id, self.cmt(7).tree, self.f1_7_id])
+            [self.cmt(6).id],
+            [self.cmt(7).id],
+            [self.cmt(7).id, self.cmt(7).tree, self.f1_7_id],
+        )
 
 
     def test_have4_want7(self):
     def test_have4_want7(self):
         # have 4, want 7. Shall not include rev5 as it is not in the tree
         # have 4, want 7. Shall not include rev5 as it is not in the tree
         # between 4 and 7 (well, it is, but its SHA's are irrelevant for 4..7
         # between 4 and 7 (well, it is, but its SHA's are irrelevant for 4..7
         # commit hierarchy)
         # commit hierarchy)
-        self.assertMissingMatch([self.cmt(4).id], [self.cmt(7).id], [
-            self.cmt(7).id, self.cmt(6).id, self.cmt(3).id,
-            self.cmt(7).tree, self.cmt(6).tree, self.cmt(3).tree,
-            self.f2_3_id, self.f3_3_id])
+        self.assertMissingMatch(
+            [self.cmt(4).id],
+            [self.cmt(7).id],
+            [
+                self.cmt(7).id,
+                self.cmt(6).id,
+                self.cmt(3).id,
+                self.cmt(7).tree,
+                self.cmt(6).tree,
+                self.cmt(3).tree,
+                self.f2_3_id,
+                self.f3_3_id,
+            ],
+        )
 
 
     def test_have1_want6(self):
     def test_have1_want6(self):
         # have 1, want 6. Shall not include rev5
         # have 1, want 6. Shall not include rev5
-        self.assertMissingMatch([self.cmt(1).id], [self.cmt(6).id], [
-            self.cmt(6).id, self.cmt(4).id, self.cmt(3).id, self.cmt(2).id,
-            self.cmt(6).tree, self.cmt(4).tree, self.cmt(3).tree,
-            self.cmt(2).tree, self.f1_2_id, self.f1_4_id, self.f2_3_id,
-            self.f3_3_id])
+        self.assertMissingMatch(
+            [self.cmt(1).id],
+            [self.cmt(6).id],
+            [
+                self.cmt(6).id,
+                self.cmt(4).id,
+                self.cmt(3).id,
+                self.cmt(2).id,
+                self.cmt(6).tree,
+                self.cmt(4).tree,
+                self.cmt(3).tree,
+                self.cmt(2).tree,
+                self.f1_2_id,
+                self.f1_4_id,
+                self.f2_3_id,
+                self.f3_3_id,
+            ],
+        )
 
 
     def test_have3_want6(self):
     def test_have3_want6(self):
         # have 3, want 7. Shall not report rev2 and its tree, because
         # have 3, want 7. Shall not report rev2 and its tree, because
         # haves(3) means has parents, i.e. rev2, too
         # haves(3) means has parents, i.e. rev2, too
         # BUT shall report any changes descending rev2 (excluding rev3)
         # BUT shall report any changes descending rev2 (excluding rev3)
-        # Shall NOT report f1_7 as it's techically == f1_2
-        self.assertMissingMatch([self.cmt(3).id], [self.cmt(7).id], [
-              self.cmt(7).id, self.cmt(6).id, self.cmt(4).id,
-              self.cmt(7).tree, self.cmt(6).tree, self.cmt(4).tree,
-              self.f1_4_id])
+        # Shall NOT report f1_7 as it's technically == f1_2
+        self.assertMissingMatch(
+            [self.cmt(3).id],
+            [self.cmt(7).id],
+            [
+                self.cmt(7).id,
+                self.cmt(6).id,
+                self.cmt(4).id,
+                self.cmt(7).tree,
+                self.cmt(6).tree,
+                self.cmt(4).tree,
+                self.f1_4_id,
+            ],
+        )
 
 
     def test_have5_want7(self):
     def test_have5_want7(self):
         # have 5, want 7. Common parent is rev2, hence children of rev2 from
         # have 5, want 7. Common parent is rev2, hence children of rev2 from
         # a descent line other than rev5 shall be reported
         # a descent line other than rev5 shall be reported
         # expects f1_4 from rev6. f3_5 is known in rev5;
         # expects f1_4 from rev6. f3_5 is known in rev5;
         # f1_7 shall be the same as f1_2 (known, too)
         # f1_7 shall be the same as f1_2 (known, too)
-        self.assertMissingMatch([self.cmt(5).id], [self.cmt(7).id], [
-              self.cmt(7).id, self.cmt(6).id, self.cmt(4).id,
-              self.cmt(7).tree, self.cmt(6).tree, self.cmt(4).tree,
-              self.f1_4_id])
+        self.assertMissingMatch(
+            [self.cmt(5).id],
+            [self.cmt(7).id],
+            [
+                self.cmt(7).id,
+                self.cmt(6).id,
+                self.cmt(4).id,
+                self.cmt(7).tree,
+                self.cmt(6).tree,
+                self.cmt(4).tree,
+                self.f1_4_id,
+            ],
+        )
 
 
 
 
 class MOFTagsTest(MissingObjectFinderTest):
 class MOFTagsTest(MissingObjectFinderTest):
-
     def setUp(self):
     def setUp(self):
         super(MOFTagsTest, self).setUp()
         super(MOFTagsTest, self).setUp()
-        f1_1 = make_object(Blob, data=b'f1')
+        f1_1 = make_object(Blob, data=b"f1")
         commit_spec = [[1]]
         commit_spec = [[1]]
-        trees = {1: [(b'f1', f1_1)]}
+        trees = {1: [(b"f1", f1_1)]}
         self.commits = build_commit_graph(self.store, commit_spec, trees)
         self.commits = build_commit_graph(self.store, commit_spec, trees)
 
 
         self._normal_tag = make_tag(self.cmt(1))
         self._normal_tag = make_tag(self.cmt(1))
@@ -234,28 +281,38 @@ class MOFTagsTest(MissingObjectFinderTest):
     def test_tagged_commit(self):
     def test_tagged_commit(self):
         # The user already has the tagged commit, all they want is the tag,
         # The user already has the tagged commit, all they want is the tag,
         # so send them only the tag object.
         # so send them only the tag object.
-        self.assertMissingMatch([self.cmt(1).id], [self._normal_tag.id],
-                                [self._normal_tag.id])
+        self.assertMissingMatch(
+            [self.cmt(1).id], [self._normal_tag.id], [self._normal_tag.id]
+        )
 
 
     # The remaining cases are unusual, but do happen in the wild.
     # The remaining cases are unusual, but do happen in the wild.
     def test_tagged_tag(self):
     def test_tagged_tag(self):
         # User already has tagged tag, send only tag of tag
         # User already has tagged tag, send only tag of tag
-        self.assertMissingMatch([self._normal_tag.id], [self._tag_of_tag.id],
-                                [self._tag_of_tag.id])
+        self.assertMissingMatch(
+            [self._normal_tag.id], [self._tag_of_tag.id], [self._tag_of_tag.id]
+        )
         # User needs both tags, but already has commit
         # User needs both tags, but already has commit
-        self.assertMissingMatch([self.cmt(1).id], [self._tag_of_tag.id],
-                                [self._normal_tag.id, self._tag_of_tag.id])
+        self.assertMissingMatch(
+            [self.cmt(1).id],
+            [self._tag_of_tag.id],
+            [self._normal_tag.id, self._tag_of_tag.id],
+        )
 
 
     def test_tagged_tree(self):
     def test_tagged_tree(self):
         self.assertMissingMatch(
         self.assertMissingMatch(
-            [], [self._tag_of_tree.id],
-            [self._tag_of_tree.id, self.cmt(1).tree, self.f1_1_id])
+            [],
+            [self._tag_of_tree.id],
+            [self._tag_of_tree.id, self.cmt(1).tree, self.f1_1_id],
+        )
 
 
     def test_tagged_blob(self):
     def test_tagged_blob(self):
-        self.assertMissingMatch([], [self._tag_of_blob.id],
-                                [self._tag_of_blob.id, self.f1_1_id])
+        self.assertMissingMatch(
+            [], [self._tag_of_blob.id], [self._tag_of_blob.id, self.f1_1_id]
+        )
 
 
     def test_tagged_tagged_blob(self):
     def test_tagged_tagged_blob(self):
-        self.assertMissingMatch([], [self._tag_of_tag_of_blob.id],
-                                [self._tag_of_tag_of_blob.id,
-                                 self._tag_of_blob.id, self.f1_1_id])
+        self.assertMissingMatch(
+            [],
+            [self._tag_of_tag_of_blob.id],
+            [self._tag_of_tag_of_blob.id, self._tag_of_blob.id, self.f1_1_id],
+        )

+ 257 - 170
dulwich/tests/test_object_store.py

@@ -23,6 +23,7 @@
 
 
 from contextlib import closing
 from contextlib import closing
 from io import BytesIO
 from io import BytesIO
+from unittest import skipUnless
 import os
 import os
 import shutil
 import shutil
 import stat
 import stat
@@ -30,17 +31,17 @@ import tempfile
 
 
 from dulwich.index import (
 from dulwich.index import (
     commit_tree,
     commit_tree,
-    )
+)
 from dulwich.errors import (
 from dulwich.errors import (
     NotTreeError,
     NotTreeError,
-    )
+)
 from dulwich.objects import (
 from dulwich.objects import (
     sha_to_hex,
     sha_to_hex,
     Blob,
     Blob,
     Tree,
     Tree,
     TreeEntry,
     TreeEntry,
     EmptyFileException,
     EmptyFileException,
-    )
+)
 from dulwich.object_store import (
 from dulwich.object_store import (
     DiskObjectStore,
     DiskObjectStore,
     MemoryObjectStore,
     MemoryObjectStore,
@@ -49,34 +50,83 @@ from dulwich.object_store import (
     commit_tree_changes,
     commit_tree_changes,
     read_packs_file,
     read_packs_file,
     tree_lookup_path,
     tree_lookup_path,
-    )
+)
 from dulwich.pack import (
 from dulwich.pack import (
     REF_DELTA,
     REF_DELTA,
     write_pack_objects,
     write_pack_objects,
-    )
+)
+from dulwich.protocol import DEPTH_INFINITE
 from dulwich.tests import (
 from dulwich.tests import (
     TestCase,
     TestCase,
-    )
+)
 from dulwich.tests.utils import (
 from dulwich.tests.utils import (
     make_object,
     make_object,
     make_tag,
     make_tag,
     build_pack,
     build_pack,
-    )
+)
+
+try:
+    from unittest.mock import patch
+except ImportError:
+    patch = None  # type: ignore
 
 
 
 
 testobject = make_object(Blob, data=b"yummy data")
 testobject = make_object(Blob, data=b"yummy data")
 
 
 
 
 class ObjectStoreTests(object):
 class ObjectStoreTests(object):
-
     def test_determine_wants_all(self):
     def test_determine_wants_all(self):
         self.assertEqual(
         self.assertEqual(
             [b"1" * 40],
             [b"1" * 40],
-            self.store.determine_wants_all({b"refs/heads/foo": b"1" * 40}))
+            self.store.determine_wants_all({b"refs/heads/foo": b"1" * 40}),
+        )
 
 
     def test_determine_wants_all_zero(self):
     def test_determine_wants_all_zero(self):
         self.assertEqual(
         self.assertEqual(
-            [], self.store.determine_wants_all({b"refs/heads/foo": b"0" * 40}))
+            [], self.store.determine_wants_all({b"refs/heads/foo": b"0" * 40})
+        )
+
+    @skipUnless(patch, "Required mock.patch")
+    def test_determine_wants_all_depth(self):
+        self.store.add_object(testobject)
+        refs = {b"refs/heads/foo": testobject.id}
+        with patch.object(self.store, "_get_depth", return_value=1) as m:
+            self.assertEqual(
+                [], self.store.determine_wants_all(refs, depth=0)
+            )
+            self.assertEqual(
+                [testobject.id],
+                self.store.determine_wants_all(refs, depth=DEPTH_INFINITE),
+            )
+            m.assert_not_called()
+
+            self.assertEqual(
+                [], self.store.determine_wants_all(refs, depth=1)
+            )
+            m.assert_called_with(testobject.id)
+            self.assertEqual(
+                [testobject.id], self.store.determine_wants_all(refs, depth=2)
+            )
+
+    def test_get_depth(self):
+        self.assertEqual(
+            0, self.store._get_depth(testobject.id)
+        )
+
+        self.store.add_object(testobject)
+        self.assertEqual(
+            1, self.store._get_depth(testobject.id, get_parents=lambda x: [])
+        )
+
+        parent = make_object(Blob, data=b"parent data")
+        self.store.add_object(parent)
+        self.assertEqual(
+            2,
+            self.store._get_depth(
+                testobject.id,
+                get_parents=lambda x: [parent.id] if x == testobject else [],
+            ),
+        )
 
 
     def test_iter(self):
     def test_iter(self):
         self.assertEqual([], list(self.store))
         self.assertEqual([], list(self.store))
@@ -99,11 +149,11 @@ class ObjectStoreTests(object):
         """Test if updating an existing stored object doesn't erase the
         """Test if updating an existing stored object doesn't erase the
         object from the store.
         object from the store.
         """
         """
-        test_object = make_object(Blob, data=b'data')
+        test_object = make_object(Blob, data=b"data")
 
 
         self.store.add_object(test_object)
         self.store.add_object(test_object)
         test_object_id = test_object.id
         test_object_id = test_object.id
-        test_object.data = test_object.data + b'update'
+        test_object.data = test_object.data + b"update"
         stored_test_object = self.store[test_object_id]
         stored_test_object = self.store[test_object_id]
 
 
         self.assertNotEqual(test_object.id, stored_test_object.id)
         self.assertNotEqual(test_object.id, stored_test_object.id)
@@ -125,69 +175,75 @@ class ObjectStoreTests(object):
         self.assertEqual(r, testobject)
         self.assertEqual(r, testobject)
 
 
     def test_tree_changes(self):
     def test_tree_changes(self):
-        blob_a1 = make_object(Blob, data=b'a1')
-        blob_a2 = make_object(Blob, data=b'a2')
-        blob_b = make_object(Blob, data=b'b')
+        blob_a1 = make_object(Blob, data=b"a1")
+        blob_a2 = make_object(Blob, data=b"a2")
+        blob_b = make_object(Blob, data=b"b")
         for blob in [blob_a1, blob_a2, blob_b]:
         for blob in [blob_a1, blob_a2, blob_b]:
             self.store.add_object(blob)
             self.store.add_object(blob)
 
 
-        blobs_1 = [(b'a', blob_a1.id, 0o100644), (b'b', blob_b.id, 0o100644)]
+        blobs_1 = [(b"a", blob_a1.id, 0o100644), (b"b", blob_b.id, 0o100644)]
         tree1_id = commit_tree(self.store, blobs_1)
         tree1_id = commit_tree(self.store, blobs_1)
-        blobs_2 = [(b'a', blob_a2.id, 0o100644), (b'b', blob_b.id, 0o100644)]
+        blobs_2 = [(b"a", blob_a2.id, 0o100644), (b"b", blob_b.id, 0o100644)]
         tree2_id = commit_tree(self.store, blobs_2)
         tree2_id = commit_tree(self.store, blobs_2)
-        change_a = ((b'a', b'a'), (0o100644, 0o100644),
-                    (blob_a1.id, blob_a2.id))
-        self.assertEqual([change_a],
-                         list(self.store.tree_changes(tree1_id, tree2_id)))
+        change_a = (
+            (b"a", b"a"),
+            (0o100644, 0o100644),
+            (blob_a1.id, blob_a2.id),
+        )
+        self.assertEqual([change_a], list(self.store.tree_changes(tree1_id, tree2_id)))
         self.assertEqual(
         self.assertEqual(
-            [change_a, ((b'b', b'b'), (0o100644, 0o100644),
-             (blob_b.id, blob_b.id))],
-            list(self.store.tree_changes(tree1_id, tree2_id,
-                 want_unchanged=True)))
+            [
+                change_a,
+                ((b"b", b"b"), (0o100644, 0o100644), (blob_b.id, blob_b.id)),
+            ],
+            list(self.store.tree_changes(tree1_id, tree2_id, want_unchanged=True)),
+        )
 
 
     def test_iter_tree_contents(self):
     def test_iter_tree_contents(self):
-        blob_a = make_object(Blob, data=b'a')
-        blob_b = make_object(Blob, data=b'b')
-        blob_c = make_object(Blob, data=b'c')
+        blob_a = make_object(Blob, data=b"a")
+        blob_b = make_object(Blob, data=b"b")
+        blob_c = make_object(Blob, data=b"c")
         for blob in [blob_a, blob_b, blob_c]:
         for blob in [blob_a, blob_b, blob_c]:
             self.store.add_object(blob)
             self.store.add_object(blob)
 
 
         blobs = [
         blobs = [
-            (b'a', blob_a.id, 0o100644),
-            (b'ad/b', blob_b.id, 0o100644),
-            (b'ad/bd/c', blob_c.id, 0o100755),
-            (b'ad/c', blob_c.id, 0o100644),
-            (b'c', blob_c.id, 0o100644),
+            (b"a", blob_a.id, 0o100644),
+            (b"ad/b", blob_b.id, 0o100644),
+            (b"ad/bd/c", blob_c.id, 0o100755),
+            (b"ad/c", blob_c.id, 0o100644),
+            (b"c", blob_c.id, 0o100644),
         ]
         ]
         tree_id = commit_tree(self.store, blobs)
         tree_id = commit_tree(self.store, blobs)
-        self.assertEqual([TreeEntry(p, m, h) for (p, h, m) in blobs],
-                         list(self.store.iter_tree_contents(tree_id)))
+        self.assertEqual(
+            [TreeEntry(p, m, h) for (p, h, m) in blobs],
+            list(self.store.iter_tree_contents(tree_id)),
+        )
 
 
     def test_iter_tree_contents_include_trees(self):
     def test_iter_tree_contents_include_trees(self):
-        blob_a = make_object(Blob, data=b'a')
-        blob_b = make_object(Blob, data=b'b')
-        blob_c = make_object(Blob, data=b'c')
+        blob_a = make_object(Blob, data=b"a")
+        blob_b = make_object(Blob, data=b"b")
+        blob_c = make_object(Blob, data=b"c")
         for blob in [blob_a, blob_b, blob_c]:
         for blob in [blob_a, blob_b, blob_c]:
             self.store.add_object(blob)
             self.store.add_object(blob)
 
 
         blobs = [
         blobs = [
-          (b'a', blob_a.id, 0o100644),
-          (b'ad/b', blob_b.id, 0o100644),
-          (b'ad/bd/c', blob_c.id, 0o100755),
-          ]
+            (b"a", blob_a.id, 0o100644),
+            (b"ad/b", blob_b.id, 0o100644),
+            (b"ad/bd/c", blob_c.id, 0o100755),
+        ]
         tree_id = commit_tree(self.store, blobs)
         tree_id = commit_tree(self.store, blobs)
         tree = self.store[tree_id]
         tree = self.store[tree_id]
-        tree_ad = self.store[tree[b'ad'][1]]
-        tree_bd = self.store[tree_ad[b'bd'][1]]
+        tree_ad = self.store[tree[b"ad"][1]]
+        tree_bd = self.store[tree_ad[b"bd"][1]]
 
 
         expected = [
         expected = [
-          TreeEntry(b'', 0o040000, tree_id),
-          TreeEntry(b'a', 0o100644, blob_a.id),
-          TreeEntry(b'ad', 0o040000, tree_ad.id),
-          TreeEntry(b'ad/b', 0o100644, blob_b.id),
-          TreeEntry(b'ad/bd', 0o040000, tree_bd.id),
-          TreeEntry(b'ad/bd/c', 0o100755, blob_c.id),
-          ]
+            TreeEntry(b"", 0o040000, tree_id),
+            TreeEntry(b"a", 0o100644, blob_a.id),
+            TreeEntry(b"ad", 0o040000, tree_ad.id),
+            TreeEntry(b"ad/b", 0o100644, blob_b.id),
+            TreeEntry(b"ad/bd", 0o040000, tree_bd.id),
+            TreeEntry(b"ad/bd/c", 0o100755, blob_c.id),
+        ]
         actual = self.store.iter_tree_contents(tree_id, include_trees=True)
         actual = self.store.iter_tree_contents(tree_id, include_trees=True)
         self.assertEqual(expected, list(actual))
         self.assertEqual(expected, list(actual))
 
 
@@ -198,16 +254,17 @@ class ObjectStoreTests(object):
 
 
     def test_peel_sha(self):
     def test_peel_sha(self):
         self.store.add_object(testobject)
         self.store.add_object(testobject)
-        tag1 = self.make_tag(b'1', testobject)
-        tag2 = self.make_tag(b'2', testobject)
-        tag3 = self.make_tag(b'3', testobject)
+        tag1 = self.make_tag(b"1", testobject)
+        tag2 = self.make_tag(b"2", testobject)
+        tag3 = self.make_tag(b"3", testobject)
         for obj in [testobject, tag1, tag2, tag3]:
         for obj in [testobject, tag1, tag2, tag3]:
             self.assertEqual(testobject, self.store.peel_sha(obj.id))
             self.assertEqual(testobject, self.store.peel_sha(obj.id))
 
 
     def test_get_raw(self):
     def test_get_raw(self):
         self.store.add_object(testobject)
         self.store.add_object(testobject)
-        self.assertEqual((Blob.type_num, b'yummy data'),
-                         self.store.get_raw(testobject.id))
+        self.assertEqual(
+            (Blob.type_num, b"yummy data"), self.store.get_raw(testobject.id)
+        )
 
 
     def test_close(self):
     def test_close(self):
         # For now, just check that close doesn't barf.
         # For now, just check that close doesn't barf.
@@ -216,7 +273,6 @@ class ObjectStoreTests(object):
 
 
 
 
 class OverlayObjectStoreTests(ObjectStoreTests, TestCase):
 class OverlayObjectStoreTests(ObjectStoreTests, TestCase):
-
     def setUp(self):
     def setUp(self):
         TestCase.setUp(self)
         TestCase.setUp(self)
         self.bases = [MemoryObjectStore(), MemoryObjectStore()]
         self.bases = [MemoryObjectStore(), MemoryObjectStore()]
@@ -224,7 +280,6 @@ class OverlayObjectStoreTests(ObjectStoreTests, TestCase):
 
 
 
 
 class MemoryObjectStoreTests(ObjectStoreTests, TestCase):
 class MemoryObjectStoreTests(ObjectStoreTests, TestCase):
-
     def setUp(self):
     def setUp(self):
         TestCase.setUp(self)
         TestCase.setUp(self)
         self.store = MemoryObjectStore()
         self.store = MemoryObjectStore()
@@ -248,17 +303,22 @@ class MemoryObjectStoreTests(ObjectStoreTests, TestCase):
 
 
     def test_add_thin_pack(self):
     def test_add_thin_pack(self):
         o = MemoryObjectStore()
         o = MemoryObjectStore()
-        blob = make_object(Blob, data=b'yummy data')
+        blob = make_object(Blob, data=b"yummy data")
         o.add_object(blob)
         o.add_object(blob)
 
 
         f = BytesIO()
         f = BytesIO()
-        entries = build_pack(f, [
-            (REF_DELTA, (blob.id, b'more yummy data')),
-            ], store=o)
+        entries = build_pack(
+            f,
+            [
+                (REF_DELTA, (blob.id, b"more yummy data")),
+            ],
+            store=o,
+        )
         o.add_thin_pack(f.read, None)
         o.add_thin_pack(f.read, None)
         packed_blob_sha = sha_to_hex(entries[0][3])
         packed_blob_sha = sha_to_hex(entries[0][3])
-        self.assertEqual((Blob.type_num, b'more yummy data'),
-                         o.get_raw(packed_blob_sha))
+        self.assertEqual(
+            (Blob.type_num, b"more yummy data"), o.get_raw(packed_blob_sha)
+        )
 
 
     def test_add_thin_pack_empty(self):
     def test_add_thin_pack_empty(self):
         o = MemoryObjectStore()
         o = MemoryObjectStore()
@@ -270,7 +330,6 @@ class MemoryObjectStoreTests(ObjectStoreTests, TestCase):
 
 
 
 
 class PackBasedObjectStoreTests(ObjectStoreTests):
 class PackBasedObjectStoreTests(ObjectStoreTests):
-
     def tearDown(self):
     def tearDown(self):
         for pack in self.store.packs:
         for pack in self.store.packs:
             pack.close()
             pack.close()
@@ -303,8 +362,7 @@ class PackBasedObjectStoreTests(ObjectStoreTests):
         b5 = make_object(Blob, data=b"and more data")
         b5 = make_object(Blob, data=b"and more data")
         b6 = make_object(Blob, data=b"and some more data")
         b6 = make_object(Blob, data=b"and some more data")
         self.store.add_objects([(b5, None), (b6, None)])
         self.store.add_objects([(b5, None), (b6, None)])
-        self.assertEqual({b1.id, b2.id, b3.id, b4.id, b5.id, b6.id},
-                         set(self.store))
+        self.assertEqual({b1.id, b2.id, b3.id, b4.id, b5.id, b6.id}, set(self.store))
         self.assertEqual(2, len(self.store.packs))
         self.assertEqual(2, len(self.store.packs))
         self.assertEqual(6, self.store.repack())
         self.assertEqual(6, self.store.repack())
         self.assertEqual(1, len(self.store.packs))
         self.assertEqual(1, len(self.store.packs))
@@ -331,7 +389,6 @@ class PackBasedObjectStoreTests(ObjectStoreTests):
 
 
 
 
 class DiskObjectStoreTests(PackBasedObjectStoreTests, TestCase):
 class DiskObjectStoreTests(PackBasedObjectStoreTests, TestCase):
-
     def setUp(self):
     def setUp(self):
         TestCase.setUp(self)
         TestCase.setUp(self)
         self.store_dir = tempfile.mkdtemp()
         self.store_dir = tempfile.mkdtemp()
@@ -345,8 +402,7 @@ class DiskObjectStoreTests(PackBasedObjectStoreTests, TestCase):
     def test_loose_compression_level(self):
     def test_loose_compression_level(self):
         alternate_dir = tempfile.mkdtemp()
         alternate_dir = tempfile.mkdtemp()
         self.addCleanup(shutil.rmtree, alternate_dir)
         self.addCleanup(shutil.rmtree, alternate_dir)
-        alternate_store = DiskObjectStore(
-            alternate_dir, loose_compression_level=6)
+        alternate_store = DiskObjectStore(alternate_dir, loose_compression_level=6)
         b2 = make_object(Blob, data=b"yummy data")
         b2 = make_object(Blob, data=b"yummy data")
         alternate_store.add_object(b2)
         alternate_store.add_object(b2)
 
 
@@ -365,14 +421,16 @@ class DiskObjectStoreTests(PackBasedObjectStoreTests, TestCase):
     def test_read_alternate_paths(self):
     def test_read_alternate_paths(self):
         store = DiskObjectStore(self.store_dir)
         store = DiskObjectStore(self.store_dir)
 
 
-        abs_path = os.path.abspath(os.path.normpath('/abspath'))
+        abs_path = os.path.abspath(os.path.normpath("/abspath"))
         # ensures in particular existence of the alternates file
         # ensures in particular existence of the alternates file
         store.add_alternate_path(abs_path)
         store.add_alternate_path(abs_path)
         self.assertEqual(set(store._read_alternate_paths()), {abs_path})
         self.assertEqual(set(store._read_alternate_paths()), {abs_path})
 
 
         store.add_alternate_path("relative-path")
         store.add_alternate_path("relative-path")
-        self.assertIn(os.path.join(store.path, "relative-path"),
-                      set(store._read_alternate_paths()))
+        self.assertIn(
+            os.path.join(store.path, "relative-path"),
+            set(store._read_alternate_paths()),
+        )
 
 
         # arguably, add_alternate_path() could strip comments.
         # arguably, add_alternate_path() could strip comments.
         # Meanwhile it's more convenient to use it than to import INFODIR
         # Meanwhile it's more convenient to use it than to import INFODIR
@@ -383,16 +441,17 @@ class DiskObjectStoreTests(PackBasedObjectStoreTests, TestCase):
     def test_corrupted_object_raise_exception(self):
     def test_corrupted_object_raise_exception(self):
         """Corrupted sha1 disk file should raise specific exception"""
         """Corrupted sha1 disk file should raise specific exception"""
         self.store.add_object(testobject)
         self.store.add_object(testobject)
-        self.assertEqual((Blob.type_num, b'yummy data'),
-                         self.store.get_raw(testobject.id))
+        self.assertEqual(
+            (Blob.type_num, b"yummy data"), self.store.get_raw(testobject.id)
+        )
         self.assertTrue(self.store.contains_loose(testobject.id))
         self.assertTrue(self.store.contains_loose(testobject.id))
         self.assertIsNotNone(self.store._get_loose_object(testobject.id))
         self.assertIsNotNone(self.store._get_loose_object(testobject.id))
 
 
         path = self.store._get_shafile_path(testobject.id)
         path = self.store._get_shafile_path(testobject.id)
-        with open(path, 'wb') as f:  # corrupt the file
-            f.write(b'')
+        with open(path, "wb") as f:  # corrupt the file
+            f.write(b"")
 
 
-        expected_error_msg = 'Corrupted empty file detected'
+        expected_error_msg = "Corrupted empty file detected"
         try:
         try:
             self.store.contains_loose(testobject.id)
             self.store.contains_loose(testobject.id)
         except EmptyFileException as e:
         except EmptyFileException as e:
@@ -404,13 +463,11 @@ class DiskObjectStoreTests(PackBasedObjectStoreTests, TestCase):
             self.assertEqual(str(e), expected_error_msg)
             self.assertEqual(str(e), expected_error_msg)
 
 
         # this does not change iteration on loose objects though
         # this does not change iteration on loose objects though
-        self.assertEqual([testobject.id],
-                         list(self.store._iter_loose_objects()))
+        self.assertEqual([testobject.id], list(self.store._iter_loose_objects()))
 
 
     def test_tempfile_in_loose_store(self):
     def test_tempfile_in_loose_store(self):
         self.store.add_object(testobject)
         self.store.add_object(testobject)
-        self.assertEqual([testobject.id],
-                         list(self.store._iter_loose_objects()))
+        self.assertEqual([testobject.id], list(self.store._iter_loose_objects()))
 
 
         # add temporary files to the loose store
         # add temporary files to the loose store
         for i in range(256):
         for i in range(256):
@@ -420,8 +477,7 @@ class DiskObjectStoreTests(PackBasedObjectStoreTests, TestCase):
             fd, n = tempfile.mkstemp(prefix="tmp_obj_", dir=dirname)
             fd, n = tempfile.mkstemp(prefix="tmp_obj_", dir=dirname)
             os.close(fd)
             os.close(fd)
 
 
-        self.assertEqual([testobject.id],
-                         list(self.store._iter_loose_objects()))
+        self.assertEqual([testobject.id], list(self.store._iter_loose_objects()))
 
 
     def test_add_alternate_path(self):
     def test_add_alternate_path(self):
         store = DiskObjectStore(self.store_dir)
         store = DiskObjectStore(self.store_dir)
@@ -430,8 +486,8 @@ class DiskObjectStoreTests(PackBasedObjectStoreTests, TestCase):
         self.assertEqual(["/foo/path"], list(store._read_alternate_paths()))
         self.assertEqual(["/foo/path"], list(store._read_alternate_paths()))
         store.add_alternate_path("/bar/path")
         store.add_alternate_path("/bar/path")
         self.assertEqual(
         self.assertEqual(
-            ["/foo/path", "/bar/path"],
-            list(store._read_alternate_paths()))
+            ["/foo/path", "/bar/path"], list(store._read_alternate_paths())
+        )
 
 
     def test_rel_alternative_path(self):
     def test_rel_alternative_path(self):
         alternate_dir = tempfile.mkdtemp()
         alternate_dir = tempfile.mkdtemp()
@@ -441,8 +497,7 @@ class DiskObjectStoreTests(PackBasedObjectStoreTests, TestCase):
         alternate_store.add_object(b2)
         alternate_store.add_object(b2)
         store = DiskObjectStore(self.store_dir)
         store = DiskObjectStore(self.store_dir)
         self.assertRaises(KeyError, store.__getitem__, b2.id)
         self.assertRaises(KeyError, store.__getitem__, b2.id)
-        store.add_alternate_path(
-            os.path.relpath(alternate_dir, self.store_dir))
+        store.add_alternate_path(os.path.relpath(alternate_dir, self.store_dir))
         self.assertEqual(list(alternate_store), list(store.alternates[0]))
         self.assertEqual(list(alternate_store), list(store.alternates[0]))
         self.assertIn(b2.id, store)
         self.assertIn(b2.id, store)
         self.assertEqual(b2, store[b2.id])
         self.assertEqual(b2, store[b2.id])
@@ -466,23 +521,28 @@ class DiskObjectStoreTests(PackBasedObjectStoreTests, TestCase):
     def test_add_thin_pack(self):
     def test_add_thin_pack(self):
         o = DiskObjectStore(self.store_dir)
         o = DiskObjectStore(self.store_dir)
         try:
         try:
-            blob = make_object(Blob, data=b'yummy data')
+            blob = make_object(Blob, data=b"yummy data")
             o.add_object(blob)
             o.add_object(blob)
 
 
             f = BytesIO()
             f = BytesIO()
-            entries = build_pack(f, [
-              (REF_DELTA, (blob.id, b'more yummy data')),
-              ], store=o)
+            entries = build_pack(
+                f,
+                [
+                    (REF_DELTA, (blob.id, b"more yummy data")),
+                ],
+                store=o,
+            )
 
 
             with o.add_thin_pack(f.read, None) as pack:
             with o.add_thin_pack(f.read, None) as pack:
                 packed_blob_sha = sha_to_hex(entries[0][3])
                 packed_blob_sha = sha_to_hex(entries[0][3])
                 pack.check_length_and_checksum()
                 pack.check_length_and_checksum()
-                self.assertEqual(
-                    sorted([blob.id, packed_blob_sha]), list(pack))
+                self.assertEqual(sorted([blob.id, packed_blob_sha]), list(pack))
                 self.assertTrue(o.contains_packed(packed_blob_sha))
                 self.assertTrue(o.contains_packed(packed_blob_sha))
                 self.assertTrue(o.contains_packed(blob.id))
                 self.assertTrue(o.contains_packed(blob.id))
-                self.assertEqual((Blob.type_num, b'more yummy data'),
-                                 o.get_raw(packed_blob_sha))
+                self.assertEqual(
+                    (Blob.type_num, b"more yummy data"),
+                    o.get_raw(packed_blob_sha),
+                )
         finally:
         finally:
             o.close()
             o.close()
 
 
@@ -495,58 +555,62 @@ class DiskObjectStoreTests(PackBasedObjectStoreTests, TestCase):
 
 
 
 
 class TreeLookupPathTests(TestCase):
 class TreeLookupPathTests(TestCase):
-
     def setUp(self):
     def setUp(self):
         TestCase.setUp(self)
         TestCase.setUp(self)
         self.store = MemoryObjectStore()
         self.store = MemoryObjectStore()
-        blob_a = make_object(Blob, data=b'a')
-        blob_b = make_object(Blob, data=b'b')
-        blob_c = make_object(Blob, data=b'c')
+        blob_a = make_object(Blob, data=b"a")
+        blob_b = make_object(Blob, data=b"b")
+        blob_c = make_object(Blob, data=b"c")
         for blob in [blob_a, blob_b, blob_c]:
         for blob in [blob_a, blob_b, blob_c]:
             self.store.add_object(blob)
             self.store.add_object(blob)
 
 
         blobs = [
         blobs = [
-          (b'a', blob_a.id, 0o100644),
-          (b'ad/b', blob_b.id, 0o100644),
-          (b'ad/bd/c', blob_c.id, 0o100755),
-          (b'ad/c', blob_c.id, 0o100644),
-          (b'c', blob_c.id, 0o100644),
-          ]
+            (b"a", blob_a.id, 0o100644),
+            (b"ad/b", blob_b.id, 0o100644),
+            (b"ad/bd/c", blob_c.id, 0o100755),
+            (b"ad/c", blob_c.id, 0o100644),
+            (b"c", blob_c.id, 0o100644),
+        ]
         self.tree_id = commit_tree(self.store, blobs)
         self.tree_id = commit_tree(self.store, blobs)
 
 
     def get_object(self, sha):
     def get_object(self, sha):
         return self.store[sha]
         return self.store[sha]
 
 
     def test_lookup_blob(self):
     def test_lookup_blob(self):
-        o_id = tree_lookup_path(self.get_object, self.tree_id, b'a')[1]
+        o_id = tree_lookup_path(self.get_object, self.tree_id, b"a")[1]
         self.assertTrue(isinstance(self.store[o_id], Blob))
         self.assertTrue(isinstance(self.store[o_id], Blob))
 
 
     def test_lookup_tree(self):
     def test_lookup_tree(self):
-        o_id = tree_lookup_path(self.get_object, self.tree_id, b'ad')[1]
+        o_id = tree_lookup_path(self.get_object, self.tree_id, b"ad")[1]
         self.assertTrue(isinstance(self.store[o_id], Tree))
         self.assertTrue(isinstance(self.store[o_id], Tree))
-        o_id = tree_lookup_path(self.get_object, self.tree_id, b'ad/bd')[1]
+        o_id = tree_lookup_path(self.get_object, self.tree_id, b"ad/bd")[1]
         self.assertTrue(isinstance(self.store[o_id], Tree))
         self.assertTrue(isinstance(self.store[o_id], Tree))
-        o_id = tree_lookup_path(self.get_object, self.tree_id, b'ad/bd/')[1]
+        o_id = tree_lookup_path(self.get_object, self.tree_id, b"ad/bd/")[1]
         self.assertTrue(isinstance(self.store[o_id], Tree))
         self.assertTrue(isinstance(self.store[o_id], Tree))
 
 
     def test_lookup_nonexistent(self):
     def test_lookup_nonexistent(self):
         self.assertRaises(
         self.assertRaises(
-            KeyError, tree_lookup_path, self.get_object, self.tree_id, b'j')
+            KeyError, tree_lookup_path, self.get_object, self.tree_id, b"j"
+        )
 
 
     def test_lookup_not_tree(self):
     def test_lookup_not_tree(self):
         self.assertRaises(
         self.assertRaises(
-            NotTreeError, tree_lookup_path, self.get_object, self.tree_id,
-            b'ad/b/j')
+            NotTreeError,
+            tree_lookup_path,
+            self.get_object,
+            self.tree_id,
+            b"ad/b/j",
+        )
 
 
 
 
 class ObjectStoreGraphWalkerTests(TestCase):
 class ObjectStoreGraphWalkerTests(TestCase):
-
     def get_walker(self, heads, parent_map):
     def get_walker(self, heads, parent_map):
         new_parent_map = dict(
         new_parent_map = dict(
-                [(k * 40, [(p * 40) for p in ps])
-                 for (k, ps) in parent_map.items()])
-        return ObjectStoreGraphWalker([x * 40 for x in heads],
-                                      new_parent_map.__getitem__)
+            [(k * 40, [(p * 40) for p in ps]) for (k, ps) in parent_map.items()]
+        )
+        return ObjectStoreGraphWalker(
+            [x * 40 for x in heads], new_parent_map.__getitem__
+        )
 
 
     def test_ack_invalid_value(self):
     def test_ack_invalid_value(self):
         gw = self.get_walker([], {})
         gw = self.get_walker([], {})
@@ -587,13 +651,16 @@ class ObjectStoreGraphWalkerTests(TestCase):
         # c  d
         # c  d
         # \ /
         # \ /
         #  e
         #  e
-        gw = self.get_walker([b"a", b"b"], {
+        gw = self.get_walker(
+            [b"a", b"b"],
+            {
                 b"a": [b"c"],
                 b"a": [b"c"],
                 b"b": [b"d"],
                 b"b": [b"d"],
                 b"c": [b"e"],
                 b"c": [b"e"],
                 b"d": [b"e"],
                 b"d": [b"e"],
                 b"e": [],
                 b"e": [],
-                })
+            },
+        )
         walk = []
         walk = []
         acked = False
         acked = False
         walk.append(next(gw))
         walk.append(next(gw))
@@ -612,86 +679,106 @@ class ObjectStoreGraphWalkerTests(TestCase):
         walk.append(next(gw))
         walk.append(next(gw))
         self.assertIs(None, next(gw))
         self.assertIs(None, next(gw))
 
 
-        self.assertEqual([b"a" * 40, b"b" * 40, b"c" * 40, b"d" * 40],
-                         sorted(walk))
+        self.assertEqual([b"a" * 40, b"b" * 40, b"c" * 40, b"d" * 40], sorted(walk))
         self.assertLess(walk.index(b"a" * 40), walk.index(b"c" * 40))
         self.assertLess(walk.index(b"a" * 40), walk.index(b"c" * 40))
         self.assertLess(walk.index(b"b" * 40), walk.index(b"d" * 40))
         self.assertLess(walk.index(b"b" * 40), walk.index(b"d" * 40))
 
 
 
 
 class CommitTreeChangesTests(TestCase):
 class CommitTreeChangesTests(TestCase):
-
     def setUp(self):
     def setUp(self):
         super(CommitTreeChangesTests, self).setUp()
         super(CommitTreeChangesTests, self).setUp()
         self.store = MemoryObjectStore()
         self.store = MemoryObjectStore()
-        self.blob_a = make_object(Blob, data=b'a')
-        self.blob_b = make_object(Blob, data=b'b')
-        self.blob_c = make_object(Blob, data=b'c')
+        self.blob_a = make_object(Blob, data=b"a")
+        self.blob_b = make_object(Blob, data=b"b")
+        self.blob_c = make_object(Blob, data=b"c")
         for blob in [self.blob_a, self.blob_b, self.blob_c]:
         for blob in [self.blob_a, self.blob_b, self.blob_c]:
             self.store.add_object(blob)
             self.store.add_object(blob)
 
 
         blobs = [
         blobs = [
-          (b'a', self.blob_a.id, 0o100644),
-          (b'ad/b', self.blob_b.id, 0o100644),
-          (b'ad/bd/c', self.blob_c.id, 0o100755),
-          (b'ad/c', self.blob_c.id, 0o100644),
-          (b'c', self.blob_c.id, 0o100644),
-          ]
+            (b"a", self.blob_a.id, 0o100644),
+            (b"ad/b", self.blob_b.id, 0o100644),
+            (b"ad/bd/c", self.blob_c.id, 0o100755),
+            (b"ad/c", self.blob_c.id, 0o100644),
+            (b"c", self.blob_c.id, 0o100644),
+        ]
         self.tree_id = commit_tree(self.store, blobs)
         self.tree_id = commit_tree(self.store, blobs)
 
 
     def test_no_changes(self):
     def test_no_changes(self):
         self.assertEqual(
         self.assertEqual(
-                self.store[self.tree_id],
-                commit_tree_changes(self.store, self.store[self.tree_id], []))
+            self.store[self.tree_id],
+            commit_tree_changes(self.store, self.store[self.tree_id], []),
+        )
 
 
     def test_add_blob(self):
     def test_add_blob(self):
-        blob_d = make_object(Blob, data=b'd')
+        blob_d = make_object(Blob, data=b"d")
         new_tree = commit_tree_changes(
         new_tree = commit_tree_changes(
-                self.store, self.store[self.tree_id], [
-                    (b'd', 0o100644, blob_d.id)])
+            self.store, self.store[self.tree_id], [(b"d", 0o100644, blob_d.id)]
+        )
         self.assertEqual(
         self.assertEqual(
-            new_tree[b'd'],
-            (33188, b'c59d9b6344f1af00e504ba698129f07a34bbed8d'))
+            new_tree[b"d"],
+            (33188, b"c59d9b6344f1af00e504ba698129f07a34bbed8d"),
+        )
 
 
     def test_add_blob_in_dir(self):
     def test_add_blob_in_dir(self):
-        blob_d = make_object(Blob, data=b'd')
+        blob_d = make_object(Blob, data=b"d")
         new_tree = commit_tree_changes(
         new_tree = commit_tree_changes(
-                self.store, self.store[self.tree_id], [
-                    (b'e/f/d', 0o100644, blob_d.id)])
+            self.store,
+            self.store[self.tree_id],
+            [(b"e/f/d", 0o100644, blob_d.id)],
+        )
         self.assertEqual(
         self.assertEqual(
-            new_tree.items(), [
-                TreeEntry(path=b'a', mode=stat.S_IFREG | 0o100644,
-                          sha=self.blob_a.id),
-                TreeEntry(path=b'ad', mode=stat.S_IFDIR,
-                          sha=b'0e2ce2cd7725ff4817791be31ccd6e627e801f4a'),
-                TreeEntry(path=b'c', mode=stat.S_IFREG | 0o100644,
-                          sha=self.blob_c.id),
-                TreeEntry(path=b'e', mode=stat.S_IFDIR,
-                          sha=b'6ab344e288724ac2fb38704728b8896e367ed108')
-                ])
-        e_tree = self.store[new_tree[b'e'][1]]
+            new_tree.items(),
+            [
+                TreeEntry(path=b"a", mode=stat.S_IFREG | 0o100644, sha=self.blob_a.id),
+                TreeEntry(
+                    path=b"ad",
+                    mode=stat.S_IFDIR,
+                    sha=b"0e2ce2cd7725ff4817791be31ccd6e627e801f4a",
+                ),
+                TreeEntry(path=b"c", mode=stat.S_IFREG | 0o100644, sha=self.blob_c.id),
+                TreeEntry(
+                    path=b"e",
+                    mode=stat.S_IFDIR,
+                    sha=b"6ab344e288724ac2fb38704728b8896e367ed108",
+                ),
+            ],
+        )
+        e_tree = self.store[new_tree[b"e"][1]]
         self.assertEqual(
         self.assertEqual(
-            e_tree.items(), [
-                TreeEntry(path=b'f', mode=stat.S_IFDIR,
-                          sha=b'24d2c94d8af232b15a0978c006bf61ef4479a0a5')
-                ])
-        f_tree = self.store[e_tree[b'f'][1]]
+            e_tree.items(),
+            [
+                TreeEntry(
+                    path=b"f",
+                    mode=stat.S_IFDIR,
+                    sha=b"24d2c94d8af232b15a0978c006bf61ef4479a0a5",
+                )
+            ],
+        )
+        f_tree = self.store[e_tree[b"f"][1]]
         self.assertEqual(
         self.assertEqual(
-            f_tree.items(), [
-                TreeEntry(path=b'd', mode=stat.S_IFREG | 0o100644,
-                          sha=blob_d.id)
-                ])
+            f_tree.items(),
+            [TreeEntry(path=b"d", mode=stat.S_IFREG | 0o100644, sha=blob_d.id)],
+        )
 
 
     def test_delete_blob(self):
     def test_delete_blob(self):
         new_tree = commit_tree_changes(
         new_tree = commit_tree_changes(
-                self.store, self.store[self.tree_id], [
-                    (b'ad/bd/c', None, None)])
-        self.assertEqual(set(new_tree), {b'a', b'ad', b'c'})
-        ad_tree = self.store[new_tree[b'ad'][1]]
-        self.assertEqual(set(ad_tree), {b'b', b'c'})
+            self.store, self.store[self.tree_id], [(b"ad/bd/c", None, None)]
+        )
+        self.assertEqual(set(new_tree), {b"a", b"ad", b"c"})
+        ad_tree = self.store[new_tree[b"ad"][1]]
+        self.assertEqual(set(ad_tree), {b"b", b"c"})
 
 
 
 
 class TestReadPacksFile(TestCase):
 class TestReadPacksFile(TestCase):
-
     def test_read_packs(self):
     def test_read_packs(self):
-        self.assertEqual(["pack-1.pack"], list(read_packs_file(BytesIO(b"""P pack-1.pack
-"""))))
+        self.assertEqual(
+            ["pack-1.pack"],
+            list(
+                read_packs_file(
+                    BytesIO(
+                        b"""P pack-1.pack
+"""
+                    )
+                )
+            ),
+        )

Файлын зөрүү хэтэрхий том тул дарагдсан байна
+ 316 - 284
dulwich/tests/test_objects.py


+ 78 - 56
dulwich/tests/test_objectspec.py

@@ -25,7 +25,7 @@
 
 
 from dulwich.objects import (
 from dulwich.objects import (
     Blob,
     Blob,
-    )
+)
 from dulwich.objectspec import (
 from dulwich.objectspec import (
     parse_object,
     parse_object,
     parse_commit,
     parse_commit,
@@ -35,14 +35,14 @@ from dulwich.objectspec import (
     parse_reftuple,
     parse_reftuple,
     parse_reftuples,
     parse_reftuples,
     parse_tree,
     parse_tree,
-    )
+)
 from dulwich.repo import MemoryRepo
 from dulwich.repo import MemoryRepo
 from dulwich.tests import (
 from dulwich.tests import (
     TestCase,
     TestCase,
-    )
+)
 from dulwich.tests.utils import (
 from dulwich.tests.utils import (
     build_commit_graph,
     build_commit_graph,
-    )
+)
 
 
 
 
 class ParseObjectTests(TestCase):
 class ParseObjectTests(TestCase):
@@ -68,8 +68,7 @@ class ParseCommitRangeTests(TestCase):
 
 
     def test_commit_by_sha(self):
     def test_commit_by_sha(self):
         r = MemoryRepo()
         r = MemoryRepo()
-        c1, c2, c3 = build_commit_graph(
-                r.object_store, [[1], [2, 1], [3, 1, 2]])
+        c1, c2, c3 = build_commit_graph(r.object_store, [[1], [2, 1], [3, 1, 2]])
         self.assertEqual([c1], list(parse_commit_range(r, c1.id)))
         self.assertEqual([c1], list(parse_commit_range(r, c1.id)))
 
 
 
 
@@ -92,44 +91,50 @@ class ParseCommitTests(TestCase):
 
 
 
 
 class ParseRefTests(TestCase):
 class ParseRefTests(TestCase):
-
     def test_nonexistent(self):
     def test_nonexistent(self):
         r = {}
         r = {}
         self.assertRaises(KeyError, parse_ref, r, b"thisdoesnotexist")
         self.assertRaises(KeyError, parse_ref, r, b"thisdoesnotexist")
 
 
     def test_ambiguous_ref(self):
     def test_ambiguous_ref(self):
-        r = {b"ambig1": 'bla',
-             b"refs/ambig1": 'bla',
-             b"refs/tags/ambig1": 'bla',
-             b"refs/heads/ambig1": 'bla',
-             b"refs/remotes/ambig1": 'bla',
-             b"refs/remotes/ambig1/HEAD": "bla"}
+        r = {
+            b"ambig1": "bla",
+            b"refs/ambig1": "bla",
+            b"refs/tags/ambig1": "bla",
+            b"refs/heads/ambig1": "bla",
+            b"refs/remotes/ambig1": "bla",
+            b"refs/remotes/ambig1/HEAD": "bla",
+        }
         self.assertEqual(b"ambig1", parse_ref(r, b"ambig1"))
         self.assertEqual(b"ambig1", parse_ref(r, b"ambig1"))
 
 
     def test_ambiguous_ref2(self):
     def test_ambiguous_ref2(self):
-        r = {b"refs/ambig2": 'bla',
-             b"refs/tags/ambig2": 'bla',
-             b"refs/heads/ambig2": 'bla',
-             b"refs/remotes/ambig2": 'bla',
-             b"refs/remotes/ambig2/HEAD": "bla"}
+        r = {
+            b"refs/ambig2": "bla",
+            b"refs/tags/ambig2": "bla",
+            b"refs/heads/ambig2": "bla",
+            b"refs/remotes/ambig2": "bla",
+            b"refs/remotes/ambig2/HEAD": "bla",
+        }
         self.assertEqual(b"refs/ambig2", parse_ref(r, b"ambig2"))
         self.assertEqual(b"refs/ambig2", parse_ref(r, b"ambig2"))
 
 
     def test_ambiguous_tag(self):
     def test_ambiguous_tag(self):
-        r = {b"refs/tags/ambig3": 'bla',
-             b"refs/heads/ambig3": 'bla',
-             b"refs/remotes/ambig3": 'bla',
-             b"refs/remotes/ambig3/HEAD": "bla"}
+        r = {
+            b"refs/tags/ambig3": "bla",
+            b"refs/heads/ambig3": "bla",
+            b"refs/remotes/ambig3": "bla",
+            b"refs/remotes/ambig3/HEAD": "bla",
+        }
         self.assertEqual(b"refs/tags/ambig3", parse_ref(r, b"ambig3"))
         self.assertEqual(b"refs/tags/ambig3", parse_ref(r, b"ambig3"))
 
 
     def test_ambiguous_head(self):
     def test_ambiguous_head(self):
-        r = {b"refs/heads/ambig4": 'bla',
-             b"refs/remotes/ambig4": 'bla',
-             b"refs/remotes/ambig4/HEAD": "bla"}
+        r = {
+            b"refs/heads/ambig4": "bla",
+            b"refs/remotes/ambig4": "bla",
+            b"refs/remotes/ambig4/HEAD": "bla",
+        }
         self.assertEqual(b"refs/heads/ambig4", parse_ref(r, b"ambig4"))
         self.assertEqual(b"refs/heads/ambig4", parse_ref(r, b"ambig4"))
 
 
     def test_ambiguous_remote(self):
     def test_ambiguous_remote(self):
-        r = {b"refs/remotes/ambig5": 'bla',
-             b"refs/remotes/ambig5/HEAD": "bla"}
+        r = {b"refs/remotes/ambig5": "bla", b"refs/remotes/ambig5/HEAD": "bla"}
         self.assertEqual(b"refs/remotes/ambig5", parse_ref(r, b"ambig5"))
         self.assertEqual(b"refs/remotes/ambig5", parse_ref(r, b"ambig5"))
 
 
     def test_ambiguous_remote_head(self):
     def test_ambiguous_remote_head(self):
@@ -150,7 +155,6 @@ class ParseRefTests(TestCase):
 
 
 
 
 class ParseRefsTests(TestCase):
 class ParseRefsTests(TestCase):
-
     def test_nonexistent(self):
     def test_nonexistent(self):
         r = {}
         r = {}
         self.assertRaises(KeyError, parse_refs, r, [b"thisdoesnotexist"])
         self.assertRaises(KeyError, parse_refs, r, [b"thisdoesnotexist"])
@@ -165,62 +169,81 @@ class ParseRefsTests(TestCase):
 
 
 
 
 class ParseReftupleTests(TestCase):
 class ParseReftupleTests(TestCase):
-
     def test_nonexistent(self):
     def test_nonexistent(self):
         r = {}
         r = {}
         self.assertRaises(KeyError, parse_reftuple, r, r, b"thisdoesnotexist")
         self.assertRaises(KeyError, parse_reftuple, r, r, b"thisdoesnotexist")
 
 
     def test_head(self):
     def test_head(self):
         r = {b"refs/heads/foo": "bla"}
         r = {b"refs/heads/foo": "bla"}
-        self.assertEqual((b"refs/heads/foo", b"refs/heads/foo", False),
-                         parse_reftuple(r, r, b"foo"))
-        self.assertEqual((b"refs/heads/foo", b"refs/heads/foo", True),
-                         parse_reftuple(r, r, b"+foo"))
-        self.assertEqual((b"refs/heads/foo", b"refs/heads/foo", True),
-                         parse_reftuple(r, {}, b"+foo"))
-        self.assertEqual((b"refs/heads/foo", b"refs/heads/foo", True),
-                         parse_reftuple(r, {}, b"foo", True))
+        self.assertEqual(
+            (b"refs/heads/foo", b"refs/heads/foo", False),
+            parse_reftuple(r, r, b"foo"),
+        )
+        self.assertEqual(
+            (b"refs/heads/foo", b"refs/heads/foo", True),
+            parse_reftuple(r, r, b"+foo"),
+        )
+        self.assertEqual(
+            (b"refs/heads/foo", b"refs/heads/foo", True),
+            parse_reftuple(r, {}, b"+foo"),
+        )
+        self.assertEqual(
+            (b"refs/heads/foo", b"refs/heads/foo", True),
+            parse_reftuple(r, {}, b"foo", True),
+        )
 
 
     def test_full(self):
     def test_full(self):
         r = {b"refs/heads/foo": "bla"}
         r = {b"refs/heads/foo": "bla"}
-        self.assertEqual((b"refs/heads/foo", b"refs/heads/foo", False),
-                         parse_reftuple(r, r, b"refs/heads/foo"))
+        self.assertEqual(
+            (b"refs/heads/foo", b"refs/heads/foo", False),
+            parse_reftuple(r, r, b"refs/heads/foo"),
+        )
 
 
     def test_no_left_ref(self):
     def test_no_left_ref(self):
         r = {b"refs/heads/foo": "bla"}
         r = {b"refs/heads/foo": "bla"}
-        self.assertEqual((None, b"refs/heads/foo", False),
-                         parse_reftuple(r, r, b":refs/heads/foo"))
+        self.assertEqual(
+            (None, b"refs/heads/foo", False),
+            parse_reftuple(r, r, b":refs/heads/foo"),
+        )
 
 
     def test_no_right_ref(self):
     def test_no_right_ref(self):
         r = {b"refs/heads/foo": "bla"}
         r = {b"refs/heads/foo": "bla"}
-        self.assertEqual((b"refs/heads/foo", None, False),
-                         parse_reftuple(r, r, b"refs/heads/foo:"))
+        self.assertEqual(
+            (b"refs/heads/foo", None, False),
+            parse_reftuple(r, r, b"refs/heads/foo:"),
+        )
 
 
     def test_default_with_string(self):
     def test_default_with_string(self):
         r = {b"refs/heads/foo": "bla"}
         r = {b"refs/heads/foo": "bla"}
-        self.assertEqual((b"refs/heads/foo", b"refs/heads/foo", False),
-                         parse_reftuple(r, r, "foo"))
+        self.assertEqual(
+            (b"refs/heads/foo", b"refs/heads/foo", False),
+            parse_reftuple(r, r, "foo"),
+        )
 
 
 
 
 class ParseReftuplesTests(TestCase):
 class ParseReftuplesTests(TestCase):
-
     def test_nonexistent(self):
     def test_nonexistent(self):
         r = {}
         r = {}
-        self.assertRaises(KeyError, parse_reftuples, r, r,
-                          [b"thisdoesnotexist"])
+        self.assertRaises(KeyError, parse_reftuples, r, r, [b"thisdoesnotexist"])
 
 
     def test_head(self):
     def test_head(self):
         r = {b"refs/heads/foo": "bla"}
         r = {b"refs/heads/foo": "bla"}
-        self.assertEqual([(b"refs/heads/foo", b"refs/heads/foo", False)],
-                         parse_reftuples(r, r, [b"foo"]))
+        self.assertEqual(
+            [(b"refs/heads/foo", b"refs/heads/foo", False)],
+            parse_reftuples(r, r, [b"foo"]),
+        )
 
 
     def test_full(self):
     def test_full(self):
         r = {b"refs/heads/foo": "bla"}
         r = {b"refs/heads/foo": "bla"}
-        self.assertEqual([(b"refs/heads/foo", b"refs/heads/foo", False)],
-                         parse_reftuples(r, r, b"refs/heads/foo"))
+        self.assertEqual(
+            [(b"refs/heads/foo", b"refs/heads/foo", False)],
+            parse_reftuples(r, r, b"refs/heads/foo"),
+        )
         r = {b"refs/heads/foo": "bla"}
         r = {b"refs/heads/foo": "bla"}
-        self.assertEqual([(b"refs/heads/foo", b"refs/heads/foo", True)],
-                         parse_reftuples(r, r, b"refs/heads/foo", True))
+        self.assertEqual(
+            [(b"refs/heads/foo", b"refs/heads/foo", True)],
+            parse_reftuples(r, r, b"refs/heads/foo", True),
+        )
 
 
 
 
 class ParseTreeTests(TestCase):
 class ParseTreeTests(TestCase):
@@ -232,7 +255,6 @@ class ParseTreeTests(TestCase):
 
 
     def test_from_commit(self):
     def test_from_commit(self):
         r = MemoryRepo()
         r = MemoryRepo()
-        c1, c2, c3 = build_commit_graph(
-                r.object_store, [[1], [2, 1], [3, 1, 2]])
+        c1, c2, c3 = build_commit_graph(r.object_store, [[1], [2, 1], [3, 1, 2]])
         self.assertEqual(r[c1.tree], parse_tree(r, c1.id))
         self.assertEqual(r[c1.tree], parse_tree(r, c1.id))
         self.assertEqual(r[c1.tree], parse_tree(r, c1.tree))
         self.assertEqual(r[c1.tree], parse_tree(r, c1.tree))

Файлын зөрүү хэтэрхий том тул дарагдсан байна
+ 313 - 250
dulwich/tests/test_pack.py


+ 311 - 216
dulwich/tests/test_patch.py

@@ -27,10 +27,10 @@ from dulwich.objects import (
     Commit,
     Commit,
     S_IFGITLINK,
     S_IFGITLINK,
     Tree,
     Tree,
-    )
+)
 from dulwich.object_store import (
 from dulwich.object_store import (
     MemoryObjectStore,
     MemoryObjectStore,
-    )
+)
 from dulwich.patch import (
 from dulwich.patch import (
     get_summary,
     get_summary,
     git_am_patch_split,
     git_am_patch_split,
@@ -38,15 +38,14 @@ from dulwich.patch import (
     write_commit_patch,
     write_commit_patch,
     write_object_diff,
     write_object_diff,
     write_tree_diff,
     write_tree_diff,
-    )
+)
 from dulwich.tests import (
 from dulwich.tests import (
     SkipTest,
     SkipTest,
     TestCase,
     TestCase,
-    )
+)
 
 
 
 
 class WriteCommitPatchTests(TestCase):
 class WriteCommitPatchTests(TestCase):
-
     def test_simple_bytesio(self):
     def test_simple_bytesio(self):
         f = BytesIO()
         f = BytesIO()
         c = Commit()
         c = Commit()
@@ -58,26 +57,28 @@ class WriteCommitPatchTests(TestCase):
         write_commit_patch(f, c, b"CONTENTS", (1, 1), version="custom")
         write_commit_patch(f, c, b"CONTENTS", (1, 1), version="custom")
         f.seek(0)
         f.seek(0)
         lines = f.readlines()
         lines = f.readlines()
-        self.assertTrue(lines[0].startswith(
-                    b"From 0b0d34d1b5b596c928adc9a727a4b9e03d025298"))
+        self.assertTrue(
+            lines[0].startswith(b"From 0b0d34d1b5b596c928adc9a727a4b9e03d025298")
+        )
         self.assertEqual(lines[1], b"From: Jelmer <jelmer@samba.org>\n")
         self.assertEqual(lines[1], b"From: Jelmer <jelmer@samba.org>\n")
         self.assertTrue(lines[2].startswith(b"Date: "))
         self.assertTrue(lines[2].startswith(b"Date: "))
-        self.assertEqual([
-            b"Subject: [PATCH 1/1] This is the first line\n",
-            b"And this is the second line.\n",
-            b"\n",
-            b"\n",
-            b"---\n"], lines[3:8])
-        self.assertEqual([
-            b"CONTENTS-- \n",
-            b"custom\n"], lines[-2:])
+        self.assertEqual(
+            [
+                b"Subject: [PATCH 1/1] This is the first line\n",
+                b"And this is the second line.\n",
+                b"\n",
+                b"\n",
+                b"---\n",
+            ],
+            lines[3:8],
+        )
+        self.assertEqual([b"CONTENTS-- \n", b"custom\n"], lines[-2:])
         if len(lines) >= 12:
         if len(lines) >= 12:
             # diffstat may not be present
             # diffstat may not be present
             self.assertEqual(lines[8], b" 0 files changed\n")
             self.assertEqual(lines[8], b" 0 files changed\n")
 
 
 
 
 class ReadGitAmPatch(TestCase):
 class ReadGitAmPatch(TestCase):
-
     def test_extract_string(self):
     def test_extract_string(self):
         text = b"""\
         text = b"""\
 From ff643aae102d8870cac88e8f007e70f58f3a7363 Mon Sep 17 00:00:00 2001
 From ff643aae102d8870cac88e8f007e70f58f3a7363 Mon Sep 17 00:00:00 2001
@@ -93,17 +94,21 @@ Subject: [PATCH 1/2] Remove executable bit from prey.ico (triggers a warning).
 -- 
 -- 
 1.7.0.4
 1.7.0.4
 """  # noqa: W291
 """  # noqa: W291
-        c, diff, version = git_am_patch_split(
-                StringIO(text.decode("utf-8")), "utf-8")
+        c, diff, version = git_am_patch_split(StringIO(text.decode("utf-8")), "utf-8")
         self.assertEqual(b"Jelmer Vernooij <jelmer@samba.org>", c.committer)
         self.assertEqual(b"Jelmer Vernooij <jelmer@samba.org>", c.committer)
         self.assertEqual(b"Jelmer Vernooij <jelmer@samba.org>", c.author)
         self.assertEqual(b"Jelmer Vernooij <jelmer@samba.org>", c.author)
-        self.assertEqual(b"Remove executable bit from prey.ico "
-                         b"(triggers a warning).\n", c.message)
-        self.assertEqual(b""" pixmaps/prey.ico |  Bin 9662 -> 9662 bytes
+        self.assertEqual(
+            b"Remove executable bit from prey.ico " b"(triggers a warning).\n",
+            c.message,
+        )
+        self.assertEqual(
+            b""" pixmaps/prey.ico |  Bin 9662 -> 9662 bytes
  1 files changed, 0 insertions(+), 0 deletions(-)
  1 files changed, 0 insertions(+), 0 deletions(-)
  mode change 100755 => 100644 pixmaps/prey.ico
  mode change 100755 => 100644 pixmaps/prey.ico
 
 
-""", diff)
+""",
+            diff,
+        )
         self.assertEqual(b"1.7.0.4", version)
         self.assertEqual(b"1.7.0.4", version)
 
 
     def test_extract_bytes(self):
     def test_extract_bytes(self):
@@ -124,13 +129,18 @@ Subject: [PATCH 1/2] Remove executable bit from prey.ico (triggers a warning).
         c, diff, version = git_am_patch_split(BytesIO(text))
         c, diff, version = git_am_patch_split(BytesIO(text))
         self.assertEqual(b"Jelmer Vernooij <jelmer@samba.org>", c.committer)
         self.assertEqual(b"Jelmer Vernooij <jelmer@samba.org>", c.committer)
         self.assertEqual(b"Jelmer Vernooij <jelmer@samba.org>", c.author)
         self.assertEqual(b"Jelmer Vernooij <jelmer@samba.org>", c.author)
-        self.assertEqual(b"Remove executable bit from prey.ico "
-                         b"(triggers a warning).\n", c.message)
-        self.assertEqual(b""" pixmaps/prey.ico |  Bin 9662 -> 9662 bytes
+        self.assertEqual(
+            b"Remove executable bit from prey.ico " b"(triggers a warning).\n",
+            c.message,
+        )
+        self.assertEqual(
+            b""" pixmaps/prey.ico |  Bin 9662 -> 9662 bytes
  1 files changed, 0 insertions(+), 0 deletions(-)
  1 files changed, 0 insertions(+), 0 deletions(-)
  mode change 100755 => 100644 pixmaps/prey.ico
  mode change 100755 => 100644 pixmaps/prey.ico
 
 
-""", diff)
+""",
+            diff,
+        )
         self.assertEqual(b"1.7.0.4", version)
         self.assertEqual(b"1.7.0.4", version)
 
 
     def test_extract_spaces(self):
     def test_extract_spaces(self):
@@ -152,13 +162,16 @@ Subject:  [Dulwich-users] [PATCH] Added unit tests for
 1.7.0.4
 1.7.0.4
 """  # noqa: W291
 """  # noqa: W291
         c, diff, version = git_am_patch_split(BytesIO(text), "utf-8")
         c, diff, version = git_am_patch_split(BytesIO(text), "utf-8")
-        self.assertEqual(b'''\
+        self.assertEqual(
+            b"""\
 Added unit tests for dulwich.object_store.tree_lookup_path.
 Added unit tests for dulwich.object_store.tree_lookup_path.
 
 
 * dulwich/tests/test_object_store.py
 * dulwich/tests/test_object_store.py
   (TreeLookupPathTests): This test case contains a few tests that ensure the
   (TreeLookupPathTests): This test case contains a few tests that ensure the
    tree_lookup_path function works as expected.
    tree_lookup_path function works as expected.
-''', c.message)
+""",
+            c.message,
+        )
 
 
     def test_extract_pseudo_from_header(self):
     def test_extract_pseudo_from_header(self):
         text = b"""From ff643aae102d8870cac88e8f007e70f58f3a7363 Mon Sep 17 00:00:00 2001
         text = b"""From ff643aae102d8870cac88e8f007e70f58f3a7363 Mon Sep 17 00:00:00 2001
@@ -182,13 +195,16 @@ From: Jelmer Vernooij <jelmer@debian.org>
 """  # noqa: W291
 """  # noqa: W291
         c, diff, version = git_am_patch_split(BytesIO(text), "utf-8")
         c, diff, version = git_am_patch_split(BytesIO(text), "utf-8")
         self.assertEqual(b"Jelmer Vernooij <jelmer@debian.org>", c.author)
         self.assertEqual(b"Jelmer Vernooij <jelmer@debian.org>", c.author)
-        self.assertEqual(b'''\
+        self.assertEqual(
+            b"""\
 Added unit tests for dulwich.object_store.tree_lookup_path.
 Added unit tests for dulwich.object_store.tree_lookup_path.
 
 
 * dulwich/tests/test_object_store.py
 * dulwich/tests/test_object_store.py
   (TreeLookupPathTests): This test case contains a few tests that ensure the
   (TreeLookupPathTests): This test case contains a few tests that ensure the
    tree_lookup_path function works as expected.
    tree_lookup_path function works as expected.
-''', c.message)
+""",
+            c.message,
+        )
 
 
     def test_extract_no_version_tail(self):
     def test_extract_no_version_tail(self):
         text = b"""\
         text = b"""\
@@ -211,8 +227,8 @@ From: Jelmer Vernooij <jelmer@debian.org>
 
 
     def test_extract_mercurial(self):
     def test_extract_mercurial(self):
         raise SkipTest(
         raise SkipTest(
-                "git_am_patch_split doesn't handle Mercurial patches "
-                "properly yet")
+            "git_am_patch_split doesn't handle Mercurial patches " "properly yet"
+        )
         expected_diff = """\
         expected_diff = """\
 diff --git a/dulwich/tests/test_patch.py b/dulwich/tests/test_patch.py
 diff --git a/dulwich/tests/test_patch.py b/dulwich/tests/test_patch.py
 --- a/dulwich/tests/test_patch.py
 --- a/dulwich/tests/test_patch.py
@@ -227,7 +243,8 @@ diff --git a/dulwich/tests/test_patch.py b/dulwich/tests/test_patch.py
  
  
  class DiffTests(TestCase):
  class DiffTests(TestCase):
 """  # noqa: W291,W293
 """  # noqa: W291,W293
-        text = """\
+        text = (
+            """\
 From dulwich-users-bounces+jelmer=samba.org@lists.launchpad.net \
 From dulwich-users-bounces+jelmer=samba.org@lists.launchpad.net \
 Mon Nov 29 00:58:18 2010
 Mon Nov 29 00:58:18 2010
 Date: Sun, 28 Nov 2010 17:57:27 -0600
 Date: Sun, 28 Nov 2010 17:57:27 -0600
@@ -246,7 +263,9 @@ Post to     : dulwich-users@lists.launchpad.net
 Unsubscribe : https://launchpad.net/~dulwich-users
 Unsubscribe : https://launchpad.net/~dulwich-users
 More help   : https://help.launchpad.net/ListHelp
 More help   : https://help.launchpad.net/ListHelp
 
 
-""" % expected_diff  # noqa: W291
+"""
+            % expected_diff
+        )  # noqa: W291
         c, diff, version = git_am_patch_split(BytesIO(text))
         c, diff, version = git_am_patch_split(BytesIO(text))
         self.assertEqual(expected_diff, diff)
         self.assertEqual(expected_diff, diff)
         self.assertEqual(None, version)
         self.assertEqual(None, version)
@@ -258,50 +277,65 @@ class DiffTests(TestCase):
     def test_blob_diff(self):
     def test_blob_diff(self):
         f = BytesIO()
         f = BytesIO()
         write_blob_diff(
         write_blob_diff(
-            f, (b"foo.txt", 0o644, Blob.from_string(b"old\nsame\n")),
-            (b"bar.txt", 0o644, Blob.from_string(b"new\nsame\n")))
-        self.assertEqual([
-            b"diff --git a/foo.txt b/bar.txt",
-            b"index 3b0f961..a116b51 644",
-            b"--- a/foo.txt",
-            b"+++ b/bar.txt",
-            b"@@ -1,2 +1,2 @@",
-            b"-old",
-            b"+new",
-            b" same"
-            ], f.getvalue().splitlines())
+            f,
+            (b"foo.txt", 0o644, Blob.from_string(b"old\nsame\n")),
+            (b"bar.txt", 0o644, Blob.from_string(b"new\nsame\n")),
+        )
+        self.assertEqual(
+            [
+                b"diff --git a/foo.txt b/bar.txt",
+                b"index 3b0f961..a116b51 644",
+                b"--- a/foo.txt",
+                b"+++ b/bar.txt",
+                b"@@ -1,2 +1,2 @@",
+                b"-old",
+                b"+new",
+                b" same",
+            ],
+            f.getvalue().splitlines(),
+        )
 
 
     def test_blob_add(self):
     def test_blob_add(self):
         f = BytesIO()
         f = BytesIO()
         write_blob_diff(
         write_blob_diff(
-            f, (None, None, None),
-            (b"bar.txt", 0o644, Blob.from_string(b"new\nsame\n")))
-        self.assertEqual([
-             b'diff --git a/bar.txt b/bar.txt',
-             b'new file mode 644',
-             b'index 0000000..a116b51',
-             b'--- /dev/null',
-             b'+++ b/bar.txt',
-             b'@@ -0,0 +1,2 @@',
-             b'+new',
-             b'+same'
-            ], f.getvalue().splitlines())
+            f,
+            (None, None, None),
+            (b"bar.txt", 0o644, Blob.from_string(b"new\nsame\n")),
+        )
+        self.assertEqual(
+            [
+                b"diff --git a/bar.txt b/bar.txt",
+                b"new file mode 644",
+                b"index 0000000..a116b51",
+                b"--- /dev/null",
+                b"+++ b/bar.txt",
+                b"@@ -0,0 +1,2 @@",
+                b"+new",
+                b"+same",
+            ],
+            f.getvalue().splitlines(),
+        )
 
 
     def test_blob_remove(self):
     def test_blob_remove(self):
         f = BytesIO()
         f = BytesIO()
         write_blob_diff(
         write_blob_diff(
-            f, (b"bar.txt", 0o644, Blob.from_string(b"new\nsame\n")),
-            (None, None, None))
-        self.assertEqual([
-            b'diff --git a/bar.txt b/bar.txt',
-            b'deleted file mode 644',
-            b'index a116b51..0000000',
-            b'--- a/bar.txt',
-            b'+++ /dev/null',
-            b'@@ -1,2 +0,0 @@',
-            b'-new',
-            b'-same'
-            ], f.getvalue().splitlines())
+            f,
+            (b"bar.txt", 0o644, Blob.from_string(b"new\nsame\n")),
+            (None, None, None),
+        )
+        self.assertEqual(
+            [
+                b"diff --git a/bar.txt b/bar.txt",
+                b"deleted file mode 644",
+                b"index a116b51..0000000",
+                b"--- a/bar.txt",
+                b"+++ /dev/null",
+                b"@@ -1,2 +0,0 @@",
+                b"-new",
+                b"-same",
+            ],
+            f.getvalue().splitlines(),
+        )
 
 
     def test_tree_diff(self):
     def test_tree_diff(self):
         f = BytesIO()
         f = BytesIO()
@@ -319,54 +353,78 @@ class DiffTests(TestCase):
         tree2.add(b"added.txt", 0o644, added.id)
         tree2.add(b"added.txt", 0o644, added.id)
         tree2.add(b"changed.txt", 0o644, changed2.id)
         tree2.add(b"changed.txt", 0o644, changed2.id)
         tree2.add(b"unchanged.txt", 0o644, changed1.id)
         tree2.add(b"unchanged.txt", 0o644, changed1.id)
-        store.add_objects([(o, None) for o in [
-            tree1, tree2, added, removed, changed1, changed2, unchanged]])
+        store.add_objects(
+            [
+                (o, None)
+                for o in [
+                    tree1,
+                    tree2,
+                    added,
+                    removed,
+                    changed1,
+                    changed2,
+                    unchanged,
+                ]
+            ]
+        )
         write_tree_diff(f, store, tree1.id, tree2.id)
         write_tree_diff(f, store, tree1.id, tree2.id)
-        self.assertEqual([
-            b'diff --git a/added.txt b/added.txt',
-            b'new file mode 644',
-            b'index 0000000..76d4bb8',
-            b'--- /dev/null',
-            b'+++ b/added.txt',
-            b'@@ -0,0 +1 @@',
-            b'+add',
-            b'diff --git a/changed.txt b/changed.txt',
-            b'index bf84e48..1be2436 644',
-            b'--- a/changed.txt',
-            b'+++ b/changed.txt',
-            b'@@ -1,2 +1,2 @@',
-            b' unchanged',
-            b'-removed',
-            b'+added',
-            b'diff --git a/removed.txt b/removed.txt',
-            b'deleted file mode 644',
-            b'index 2c3f0b3..0000000',
-            b'--- a/removed.txt',
-            b'+++ /dev/null',
-            b'@@ -1 +0,0 @@',
-            b'-removed',
-            ], f.getvalue().splitlines())
+        self.assertEqual(
+            [
+                b"diff --git a/added.txt b/added.txt",
+                b"new file mode 644",
+                b"index 0000000..76d4bb8",
+                b"--- /dev/null",
+                b"+++ b/added.txt",
+                b"@@ -0,0 +1 @@",
+                b"+add",
+                b"diff --git a/changed.txt b/changed.txt",
+                b"index bf84e48..1be2436 644",
+                b"--- a/changed.txt",
+                b"+++ b/changed.txt",
+                b"@@ -1,2 +1,2 @@",
+                b" unchanged",
+                b"-removed",
+                b"+added",
+                b"diff --git a/removed.txt b/removed.txt",
+                b"deleted file mode 644",
+                b"index 2c3f0b3..0000000",
+                b"--- a/removed.txt",
+                b"+++ /dev/null",
+                b"@@ -1 +0,0 @@",
+                b"-removed",
+            ],
+            f.getvalue().splitlines(),
+        )
 
 
     def test_tree_diff_submodule(self):
     def test_tree_diff_submodule(self):
         f = BytesIO()
         f = BytesIO()
         store = MemoryObjectStore()
         store = MemoryObjectStore()
         tree1 = Tree()
         tree1 = Tree()
-        tree1.add(b"asubmodule", S_IFGITLINK,
-                  b"06d0bdd9e2e20377b3180e4986b14c8549b393e4")
+        tree1.add(
+            b"asubmodule",
+            S_IFGITLINK,
+            b"06d0bdd9e2e20377b3180e4986b14c8549b393e4",
+        )
         tree2 = Tree()
         tree2 = Tree()
-        tree2.add(b"asubmodule", S_IFGITLINK,
-                  b"cc975646af69f279396d4d5e1379ac6af80ee637")
+        tree2.add(
+            b"asubmodule",
+            S_IFGITLINK,
+            b"cc975646af69f279396d4d5e1379ac6af80ee637",
+        )
         store.add_objects([(o, None) for o in [tree1, tree2]])
         store.add_objects([(o, None) for o in [tree1, tree2]])
         write_tree_diff(f, store, tree1.id, tree2.id)
         write_tree_diff(f, store, tree1.id, tree2.id)
-        self.assertEqual([
-            b'diff --git a/asubmodule b/asubmodule',
-            b'index 06d0bdd..cc97564 160000',
-            b'--- a/asubmodule',
-            b'+++ b/asubmodule',
-            b'@@ -1 +1 @@',
-            b'-Subproject commit 06d0bdd9e2e20377b3180e4986b14c8549b393e4',
-            b'+Subproject commit cc975646af69f279396d4d5e1379ac6af80ee637',
-            ], f.getvalue().splitlines())
+        self.assertEqual(
+            [
+                b"diff --git a/asubmodule b/asubmodule",
+                b"index 06d0bdd..cc97564 160000",
+                b"--- a/asubmodule",
+                b"+++ b/asubmodule",
+                b"@@ -1 +1 @@",
+                b"-Subproject commit 06d0bdd9e2e20377b3180e4986b14c8549b393e4",
+                b"+Subproject commit cc975646af69f279396d4d5e1379ac6af80ee637",
+            ],
+            f.getvalue().splitlines(),
+        )
 
 
     def test_object_diff_blob(self):
     def test_object_diff_blob(self):
         f = BytesIO()
         f = BytesIO()
@@ -374,54 +432,62 @@ class DiffTests(TestCase):
         b2 = Blob.from_string(b"new\nsame\n")
         b2 = Blob.from_string(b"new\nsame\n")
         store = MemoryObjectStore()
         store = MemoryObjectStore()
         store.add_objects([(b1, None), (b2, None)])
         store.add_objects([(b1, None), (b2, None)])
-        write_object_diff(f, store, (b"foo.txt", 0o644, b1.id),
-                                    (b"bar.txt", 0o644, b2.id))
-        self.assertEqual([
-            b"diff --git a/foo.txt b/bar.txt",
-            b"index 3b0f961..a116b51 644",
-            b"--- a/foo.txt",
-            b"+++ b/bar.txt",
-            b"@@ -1,2 +1,2 @@",
-            b"-old",
-            b"+new",
-            b" same"
-            ], f.getvalue().splitlines())
+        write_object_diff(
+            f, store, (b"foo.txt", 0o644, b1.id), (b"bar.txt", 0o644, b2.id)
+        )
+        self.assertEqual(
+            [
+                b"diff --git a/foo.txt b/bar.txt",
+                b"index 3b0f961..a116b51 644",
+                b"--- a/foo.txt",
+                b"+++ b/bar.txt",
+                b"@@ -1,2 +1,2 @@",
+                b"-old",
+                b"+new",
+                b" same",
+            ],
+            f.getvalue().splitlines(),
+        )
 
 
     def test_object_diff_add_blob(self):
     def test_object_diff_add_blob(self):
         f = BytesIO()
         f = BytesIO()
         store = MemoryObjectStore()
         store = MemoryObjectStore()
         b2 = Blob.from_string(b"new\nsame\n")
         b2 = Blob.from_string(b"new\nsame\n")
         store.add_object(b2)
         store.add_object(b2)
-        write_object_diff(f, store, (None, None, None),
-                                    (b"bar.txt", 0o644, b2.id))
-        self.assertEqual([
-             b'diff --git a/bar.txt b/bar.txt',
-             b'new file mode 644',
-             b'index 0000000..a116b51',
-             b'--- /dev/null',
-             b'+++ b/bar.txt',
-             b'@@ -0,0 +1,2 @@',
-             b'+new',
-             b'+same'
-            ], f.getvalue().splitlines())
+        write_object_diff(f, store, (None, None, None), (b"bar.txt", 0o644, b2.id))
+        self.assertEqual(
+            [
+                b"diff --git a/bar.txt b/bar.txt",
+                b"new file mode 644",
+                b"index 0000000..a116b51",
+                b"--- /dev/null",
+                b"+++ b/bar.txt",
+                b"@@ -0,0 +1,2 @@",
+                b"+new",
+                b"+same",
+            ],
+            f.getvalue().splitlines(),
+        )
 
 
     def test_object_diff_remove_blob(self):
     def test_object_diff_remove_blob(self):
         f = BytesIO()
         f = BytesIO()
         b1 = Blob.from_string(b"new\nsame\n")
         b1 = Blob.from_string(b"new\nsame\n")
         store = MemoryObjectStore()
         store = MemoryObjectStore()
         store.add_object(b1)
         store.add_object(b1)
-        write_object_diff(f, store, (b"bar.txt", 0o644, b1.id),
-                                    (None, None, None))
-        self.assertEqual([
-            b'diff --git a/bar.txt b/bar.txt',
-            b'deleted file mode 644',
-            b'index a116b51..0000000',
-            b'--- a/bar.txt',
-            b'+++ /dev/null',
-            b'@@ -1,2 +0,0 @@',
-            b'-new',
-            b'-same'
-            ], f.getvalue().splitlines())
+        write_object_diff(f, store, (b"bar.txt", 0o644, b1.id), (None, None, None))
+        self.assertEqual(
+            [
+                b"diff --git a/bar.txt b/bar.txt",
+                b"deleted file mode 644",
+                b"index a116b51..0000000",
+                b"--- a/bar.txt",
+                b"+++ /dev/null",
+                b"@@ -1,2 +0,0 @@",
+                b"-new",
+                b"-same",
+            ],
+            f.getvalue().splitlines(),
+        )
 
 
     def test_object_diff_bin_blob_force(self):
     def test_object_diff_bin_blob_force(self):
         f = BytesIO()
         f = BytesIO()
@@ -430,33 +496,42 @@ class DiffTests(TestCase):
             b"\x89\x50\x4e\x47\x0d\x0a\x1a\x0a"
             b"\x89\x50\x4e\x47\x0d\x0a\x1a\x0a"
             b"\x00\x00\x00\x0d\x49\x48\x44\x52"
             b"\x00\x00\x00\x0d\x49\x48\x44\x52"
             b"\x00\x00\x01\xd5\x00\x00\x00\x9f"
             b"\x00\x00\x01\xd5\x00\x00\x00\x9f"
-            b"\x08\x04\x00\x00\x00\x05\x04\x8b")
+            b"\x08\x04\x00\x00\x00\x05\x04\x8b"
+        )
         b2 = Blob.from_string(
         b2 = Blob.from_string(
             b"\x89\x50\x4e\x47\x0d\x0a\x1a\x0a"
             b"\x89\x50\x4e\x47\x0d\x0a\x1a\x0a"
             b"\x00\x00\x00\x0d\x49\x48\x44\x52"
             b"\x00\x00\x00\x0d\x49\x48\x44\x52"
             b"\x00\x00\x01\xd5\x00\x00\x00\x9f"
             b"\x00\x00\x01\xd5\x00\x00\x00\x9f"
-            b"\x08\x03\x00\x00\x00\x98\xd3\xb3")
+            b"\x08\x03\x00\x00\x00\x98\xd3\xb3"
+        )
         store = MemoryObjectStore()
         store = MemoryObjectStore()
         store.add_objects([(b1, None), (b2, None)])
         store.add_objects([(b1, None), (b2, None)])
         write_object_diff(
         write_object_diff(
-            f, store, (b'foo.png', 0o644, b1.id),
-            (b'bar.png', 0o644, b2.id), diff_binary=True)
-        self.assertEqual([
-            b'diff --git a/foo.png b/bar.png',
-            b'index f73e47d..06364b7 644',
-            b'--- a/foo.png',
-            b'+++ b/bar.png',
-            b'@@ -1,4 +1,4 @@',
-            b' \x89PNG',
-            b' \x1a',
-            b' \x00\x00\x00',
-            b'-IHDR\x00\x00\x01\xd5\x00\x00\x00'
-            b'\x9f\x08\x04\x00\x00\x00\x05\x04\x8b',
-            b'\\ No newline at end of file',
-            b'+IHDR\x00\x00\x01\xd5\x00\x00\x00\x9f'
-            b'\x08\x03\x00\x00\x00\x98\xd3\xb3',
-            b'\\ No newline at end of file'
-            ], f.getvalue().splitlines())
+            f,
+            store,
+            (b"foo.png", 0o644, b1.id),
+            (b"bar.png", 0o644, b2.id),
+            diff_binary=True,
+        )
+        self.assertEqual(
+            [
+                b"diff --git a/foo.png b/bar.png",
+                b"index f73e47d..06364b7 644",
+                b"--- a/foo.png",
+                b"+++ b/bar.png",
+                b"@@ -1,4 +1,4 @@",
+                b" \x89PNG",
+                b" \x1a",
+                b" \x00\x00\x00",
+                b"-IHDR\x00\x00\x01\xd5\x00\x00\x00"
+                b"\x9f\x08\x04\x00\x00\x00\x05\x04\x8b",
+                b"\\ No newline at end of file",
+                b"+IHDR\x00\x00\x01\xd5\x00\x00\x00\x9f"
+                b"\x08\x03\x00\x00\x00\x98\xd3\xb3",
+                b"\\ No newline at end of file",
+            ],
+            f.getvalue().splitlines(),
+        )
 
 
     def test_object_diff_bin_blob(self):
     def test_object_diff_bin_blob(self):
         f = BytesIO()
         f = BytesIO()
@@ -465,57 +540,69 @@ class DiffTests(TestCase):
             b"\x89\x50\x4e\x47\x0d\x0a\x1a\x0a"
             b"\x89\x50\x4e\x47\x0d\x0a\x1a\x0a"
             b"\x00\x00\x00\x0d\x49\x48\x44\x52"
             b"\x00\x00\x00\x0d\x49\x48\x44\x52"
             b"\x00\x00\x01\xd5\x00\x00\x00\x9f"
             b"\x00\x00\x01\xd5\x00\x00\x00\x9f"
-            b"\x08\x04\x00\x00\x00\x05\x04\x8b")
+            b"\x08\x04\x00\x00\x00\x05\x04\x8b"
+        )
         b2 = Blob.from_string(
         b2 = Blob.from_string(
             b"\x89\x50\x4e\x47\x0d\x0a\x1a\x0a"
             b"\x89\x50\x4e\x47\x0d\x0a\x1a\x0a"
             b"\x00\x00\x00\x0d\x49\x48\x44\x52"
             b"\x00\x00\x00\x0d\x49\x48\x44\x52"
             b"\x00\x00\x01\xd5\x00\x00\x00\x9f"
             b"\x00\x00\x01\xd5\x00\x00\x00\x9f"
-            b"\x08\x03\x00\x00\x00\x98\xd3\xb3")
+            b"\x08\x03\x00\x00\x00\x98\xd3\xb3"
+        )
         store = MemoryObjectStore()
         store = MemoryObjectStore()
         store.add_objects([(b1, None), (b2, None)])
         store.add_objects([(b1, None), (b2, None)])
-        write_object_diff(f, store, (b'foo.png', 0o644, b1.id),
-                                    (b'bar.png', 0o644, b2.id))
-        self.assertEqual([
-            b'diff --git a/foo.png b/bar.png',
-            b'index f73e47d..06364b7 644',
-            b'Binary files a/foo.png and b/bar.png differ'
-            ], f.getvalue().splitlines())
+        write_object_diff(
+            f, store, (b"foo.png", 0o644, b1.id), (b"bar.png", 0o644, b2.id)
+        )
+        self.assertEqual(
+            [
+                b"diff --git a/foo.png b/bar.png",
+                b"index f73e47d..06364b7 644",
+                b"Binary files a/foo.png and b/bar.png differ",
+            ],
+            f.getvalue().splitlines(),
+        )
 
 
     def test_object_diff_add_bin_blob(self):
     def test_object_diff_add_bin_blob(self):
         f = BytesIO()
         f = BytesIO()
         b2 = Blob.from_string(
         b2 = Blob.from_string(
-            b'\x89\x50\x4e\x47\x0d\x0a\x1a\x0a'
-            b'\x00\x00\x00\x0d\x49\x48\x44\x52'
-            b'\x00\x00\x01\xd5\x00\x00\x00\x9f'
-            b'\x08\x03\x00\x00\x00\x98\xd3\xb3')
+            b"\x89\x50\x4e\x47\x0d\x0a\x1a\x0a"
+            b"\x00\x00\x00\x0d\x49\x48\x44\x52"
+            b"\x00\x00\x01\xd5\x00\x00\x00\x9f"
+            b"\x08\x03\x00\x00\x00\x98\xd3\xb3"
+        )
         store = MemoryObjectStore()
         store = MemoryObjectStore()
         store.add_object(b2)
         store.add_object(b2)
-        write_object_diff(f, store, (None, None, None),
-                                    (b'bar.png', 0o644, b2.id))
-        self.assertEqual([
-            b'diff --git a/bar.png b/bar.png',
-            b'new file mode 644',
-            b'index 0000000..06364b7',
-            b'Binary files /dev/null and b/bar.png differ'
-            ], f.getvalue().splitlines())
+        write_object_diff(f, store, (None, None, None), (b"bar.png", 0o644, b2.id))
+        self.assertEqual(
+            [
+                b"diff --git a/bar.png b/bar.png",
+                b"new file mode 644",
+                b"index 0000000..06364b7",
+                b"Binary files /dev/null and b/bar.png differ",
+            ],
+            f.getvalue().splitlines(),
+        )
 
 
     def test_object_diff_remove_bin_blob(self):
     def test_object_diff_remove_bin_blob(self):
         f = BytesIO()
         f = BytesIO()
         b1 = Blob.from_string(
         b1 = Blob.from_string(
-            b'\x89\x50\x4e\x47\x0d\x0a\x1a\x0a'
-            b'\x00\x00\x00\x0d\x49\x48\x44\x52'
-            b'\x00\x00\x01\xd5\x00\x00\x00\x9f'
-            b'\x08\x04\x00\x00\x00\x05\x04\x8b')
+            b"\x89\x50\x4e\x47\x0d\x0a\x1a\x0a"
+            b"\x00\x00\x00\x0d\x49\x48\x44\x52"
+            b"\x00\x00\x01\xd5\x00\x00\x00\x9f"
+            b"\x08\x04\x00\x00\x00\x05\x04\x8b"
+        )
         store = MemoryObjectStore()
         store = MemoryObjectStore()
         store.add_object(b1)
         store.add_object(b1)
-        write_object_diff(f, store, (b'foo.png', 0o644, b1.id),
-                                    (None, None, None))
-        self.assertEqual([
-            b'diff --git a/foo.png b/foo.png',
-            b'deleted file mode 644',
-            b'index f73e47d..0000000',
-            b'Binary files a/foo.png and /dev/null differ'
-            ], f.getvalue().splitlines())
+        write_object_diff(f, store, (b"foo.png", 0o644, b1.id), (None, None, None))
+        self.assertEqual(
+            [
+                b"diff --git a/foo.png b/foo.png",
+                b"deleted file mode 644",
+                b"index f73e47d..0000000",
+                b"Binary files a/foo.png and /dev/null differ",
+            ],
+            f.getvalue().splitlines(),
+        )
 
 
     def test_object_diff_kind_change(self):
     def test_object_diff_kind_change(self):
         f = BytesIO()
         f = BytesIO()
@@ -523,25 +610,33 @@ class DiffTests(TestCase):
         store = MemoryObjectStore()
         store = MemoryObjectStore()
         store.add_object(b1)
         store.add_object(b1)
         write_object_diff(
         write_object_diff(
-            f, store, (b"bar.txt", 0o644, b1.id),
-            (b"bar.txt", 0o160000,
-                b"06d0bdd9e2e20377b3180e4986b14c8549b393e4"))
-        self.assertEqual([
-            b'diff --git a/bar.txt b/bar.txt',
-            b'old file mode 644',
-            b'new file mode 160000',
-            b'index a116b51..06d0bdd 160000',
-            b'--- a/bar.txt',
-            b'+++ b/bar.txt',
-            b'@@ -1,2 +1 @@',
-            b'-new',
-            b'-same',
-            b'+Subproject commit 06d0bdd9e2e20377b3180e4986b14c8549b393e4',
-            ], f.getvalue().splitlines())
+            f,
+            store,
+            (b"bar.txt", 0o644, b1.id),
+            (
+                b"bar.txt",
+                0o160000,
+                b"06d0bdd9e2e20377b3180e4986b14c8549b393e4",
+            ),
+        )
+        self.assertEqual(
+            [
+                b"diff --git a/bar.txt b/bar.txt",
+                b"old file mode 644",
+                b"new file mode 160000",
+                b"index a116b51..06d0bdd 160000",
+                b"--- a/bar.txt",
+                b"+++ b/bar.txt",
+                b"@@ -1,2 +1 @@",
+                b"-new",
+                b"-same",
+                b"+Subproject commit 06d0bdd9e2e20377b3180e4986b14c8549b393e4",
+            ],
+            f.getvalue().splitlines(),
+        )
 
 
 
 
 class GetSummaryTests(TestCase):
 class GetSummaryTests(TestCase):
-
     def test_simple(self):
     def test_simple(self):
         c = Commit()
         c = Commit()
         c.committer = c.author = b"Jelmer <jelmer@samba.org>"
         c.committer = c.author = b"Jelmer <jelmer@samba.org>"
@@ -549,4 +644,4 @@ class GetSummaryTests(TestCase):
         c.commit_timezone = c.author_timezone = 0
         c.commit_timezone = c.author_timezone = 0
         c.message = b"This is the first line\nAnd this is the second line.\n"
         c.message = b"This is the first line\nAnd this is the second line.\n"
         c.tree = Tree().id
         c.tree = Tree().id
-        self.assertEqual('This-is-the-first-line', get_summary(c))
+        self.assertEqual("This-is-the-first-line", get_summary(c))

Файлын зөрүү хэтэрхий том тул дарагдсан байна
+ 470 - 201
dulwich/tests/test_porcelain.py


+ 84 - 84
dulwich/tests/test_protocol.py

@@ -25,7 +25,7 @@ from io import BytesIO
 
 
 from dulwich.errors import (
 from dulwich.errors import (
     HangupException,
     HangupException,
-    )
+)
 from dulwich.protocol import (
 from dulwich.protocol import (
     GitProtocolError,
     GitProtocolError,
     PktLineParser,
     PktLineParser,
@@ -38,27 +38,26 @@ from dulwich.protocol import (
     MULTI_ACK,
     MULTI_ACK,
     MULTI_ACK_DETAILED,
     MULTI_ACK_DETAILED,
     BufferedPktLineWriter,
     BufferedPktLineWriter,
-    )
+)
 from dulwich.tests import TestCase
 from dulwich.tests import TestCase
 
 
 
 
 class BaseProtocolTests(object):
 class BaseProtocolTests(object):
-
     def test_write_pkt_line_none(self):
     def test_write_pkt_line_none(self):
         self.proto.write_pkt_line(None)
         self.proto.write_pkt_line(None)
-        self.assertEqual(self.rout.getvalue(), b'0000')
+        self.assertEqual(self.rout.getvalue(), b"0000")
 
 
     def test_write_pkt_line(self):
     def test_write_pkt_line(self):
-        self.proto.write_pkt_line(b'bla')
-        self.assertEqual(self.rout.getvalue(), b'0007bla')
+        self.proto.write_pkt_line(b"bla")
+        self.assertEqual(self.rout.getvalue(), b"0007bla")
 
 
     def test_read_pkt_line(self):
     def test_read_pkt_line(self):
-        self.rin.write(b'0008cmd ')
+        self.rin.write(b"0008cmd ")
         self.rin.seek(0)
         self.rin.seek(0)
-        self.assertEqual(b'cmd ', self.proto.read_pkt_line())
+        self.assertEqual(b"cmd ", self.proto.read_pkt_line())
 
 
     def test_eof(self):
     def test_eof(self):
-        self.rin.write(b'0000')
+        self.rin.write(b"0000")
         self.rin.seek(0)
         self.rin.seek(0)
         self.assertFalse(self.proto.eof())
         self.assertFalse(self.proto.eof())
         self.assertEqual(None, self.proto.read_pkt_line())
         self.assertEqual(None, self.proto.read_pkt_line())
@@ -66,51 +65,50 @@ class BaseProtocolTests(object):
         self.assertRaises(HangupException, self.proto.read_pkt_line)
         self.assertRaises(HangupException, self.proto.read_pkt_line)
 
 
     def test_unread_pkt_line(self):
     def test_unread_pkt_line(self):
-        self.rin.write(b'0007foo0000')
+        self.rin.write(b"0007foo0000")
         self.rin.seek(0)
         self.rin.seek(0)
-        self.assertEqual(b'foo', self.proto.read_pkt_line())
-        self.proto.unread_pkt_line(b'bar')
-        self.assertEqual(b'bar', self.proto.read_pkt_line())
+        self.assertEqual(b"foo", self.proto.read_pkt_line())
+        self.proto.unread_pkt_line(b"bar")
+        self.assertEqual(b"bar", self.proto.read_pkt_line())
         self.assertEqual(None, self.proto.read_pkt_line())
         self.assertEqual(None, self.proto.read_pkt_line())
-        self.proto.unread_pkt_line(b'baz1')
-        self.assertRaises(ValueError, self.proto.unread_pkt_line, b'baz2')
+        self.proto.unread_pkt_line(b"baz1")
+        self.assertRaises(ValueError, self.proto.unread_pkt_line, b"baz2")
 
 
     def test_read_pkt_seq(self):
     def test_read_pkt_seq(self):
-        self.rin.write(b'0008cmd 0005l0000')
+        self.rin.write(b"0008cmd 0005l0000")
         self.rin.seek(0)
         self.rin.seek(0)
-        self.assertEqual([b'cmd ', b'l'], list(self.proto.read_pkt_seq()))
+        self.assertEqual([b"cmd ", b"l"], list(self.proto.read_pkt_seq()))
 
 
     def test_read_pkt_line_none(self):
     def test_read_pkt_line_none(self):
-        self.rin.write(b'0000')
+        self.rin.write(b"0000")
         self.rin.seek(0)
         self.rin.seek(0)
         self.assertEqual(None, self.proto.read_pkt_line())
         self.assertEqual(None, self.proto.read_pkt_line())
 
 
     def test_read_pkt_line_wrong_size(self):
     def test_read_pkt_line_wrong_size(self):
-        self.rin.write(b'0100too short')
+        self.rin.write(b"0100too short")
         self.rin.seek(0)
         self.rin.seek(0)
         self.assertRaises(GitProtocolError, self.proto.read_pkt_line)
         self.assertRaises(GitProtocolError, self.proto.read_pkt_line)
 
 
     def test_write_sideband(self):
     def test_write_sideband(self):
-        self.proto.write_sideband(3, b'bloe')
-        self.assertEqual(self.rout.getvalue(), b'0009\x03bloe')
+        self.proto.write_sideband(3, b"bloe")
+        self.assertEqual(self.rout.getvalue(), b"0009\x03bloe")
 
 
     def test_send_cmd(self):
     def test_send_cmd(self):
-        self.proto.send_cmd(b'fetch', b'a', b'b')
-        self.assertEqual(self.rout.getvalue(), b'000efetch a\x00b\x00')
+        self.proto.send_cmd(b"fetch", b"a", b"b")
+        self.assertEqual(self.rout.getvalue(), b"000efetch a\x00b\x00")
 
 
     def test_read_cmd(self):
     def test_read_cmd(self):
-        self.rin.write(b'0012cmd arg1\x00arg2\x00')
+        self.rin.write(b"0012cmd arg1\x00arg2\x00")
         self.rin.seek(0)
         self.rin.seek(0)
-        self.assertEqual((b'cmd', [b'arg1', b'arg2']), self.proto.read_cmd())
+        self.assertEqual((b"cmd", [b"arg1", b"arg2"]), self.proto.read_cmd())
 
 
     def test_read_cmd_noend0(self):
     def test_read_cmd_noend0(self):
-        self.rin.write(b'0011cmd arg1\x00arg2')
+        self.rin.write(b"0011cmd arg1\x00arg2")
         self.rin.seek(0)
         self.rin.seek(0)
         self.assertRaises(AssertionError, self.proto.read_cmd)
         self.assertRaises(AssertionError, self.proto.read_cmd)
 
 
 
 
 class ProtocolTests(BaseProtocolTests, TestCase):
 class ProtocolTests(BaseProtocolTests, TestCase):
-
     def setUp(self):
     def setUp(self):
         TestCase.setUp(self)
         TestCase.setUp(self)
         self.rout = BytesIO()
         self.rout = BytesIO()
@@ -128,9 +126,8 @@ class ReceivableBytesIO(BytesIO):
     def recv(self, size):
     def recv(self, size):
         # fail fast if no bytes are available; in a real socket, this would
         # fail fast if no bytes are available; in a real socket, this would
         # block forever
         # block forever
-        if (self.tell() == len(self.getvalue())
-                and not self.allow_read_past_eof):
-            raise GitProtocolError('Blocking read past end of socket')
+        if self.tell() == len(self.getvalue()) and not self.allow_read_past_eof:
+            raise GitProtocolError("Blocking read past end of socket")
         if size == 1:
         if size == 1:
             return self.read(1)
             return self.read(1)
         # calls shouldn't return quite as much as asked for
         # calls shouldn't return quite as much as asked for
@@ -138,7 +135,6 @@ class ReceivableBytesIO(BytesIO):
 
 
 
 
 class ReceivableProtocolTests(BaseProtocolTests, TestCase):
 class ReceivableProtocolTests(BaseProtocolTests, TestCase):
-
     def setUp(self):
     def setUp(self):
         TestCase.setUp(self)
         TestCase.setUp(self)
         self.rout = BytesIO()
         self.rout = BytesIO()
@@ -154,10 +150,10 @@ class ReceivableProtocolTests(BaseProtocolTests, TestCase):
         BaseProtocolTests.test_eof(self)
         BaseProtocolTests.test_eof(self)
 
 
     def test_recv(self):
     def test_recv(self):
-        all_data = b'1234567' * 10  # not a multiple of bufsize
+        all_data = b"1234567" * 10  # not a multiple of bufsize
         self.rin.write(all_data)
         self.rin.write(all_data)
         self.rin.seek(0)
         self.rin.seek(0)
-        data = b''
+        data = b""
         # We ask for 8 bytes each time and actually read 7, so it should take
         # We ask for 8 bytes each time and actually read 7, so it should take
         # exactly 10 iterations.
         # exactly 10 iterations.
         for _ in range(10):
         for _ in range(10):
@@ -167,28 +163,28 @@ class ReceivableProtocolTests(BaseProtocolTests, TestCase):
         self.assertEqual(all_data, data)
         self.assertEqual(all_data, data)
 
 
     def test_recv_read(self):
     def test_recv_read(self):
-        all_data = b'1234567'  # recv exactly in one call
+        all_data = b"1234567"  # recv exactly in one call
         self.rin.write(all_data)
         self.rin.write(all_data)
         self.rin.seek(0)
         self.rin.seek(0)
-        self.assertEqual(b'1234', self.proto.recv(4))
-        self.assertEqual(b'567', self.proto.read(3))
+        self.assertEqual(b"1234", self.proto.recv(4))
+        self.assertEqual(b"567", self.proto.read(3))
         self.assertRaises(GitProtocolError, self.proto.recv, 10)
         self.assertRaises(GitProtocolError, self.proto.recv, 10)
 
 
     def test_read_recv(self):
     def test_read_recv(self):
-        all_data = b'12345678abcdefg'
+        all_data = b"12345678abcdefg"
         self.rin.write(all_data)
         self.rin.write(all_data)
         self.rin.seek(0)
         self.rin.seek(0)
-        self.assertEqual(b'1234', self.proto.read(4))
-        self.assertEqual(b'5678abc', self.proto.recv(8))
-        self.assertEqual(b'defg', self.proto.read(4))
+        self.assertEqual(b"1234", self.proto.read(4))
+        self.assertEqual(b"5678abc", self.proto.recv(8))
+        self.assertEqual(b"defg", self.proto.read(4))
         self.assertRaises(GitProtocolError, self.proto.recv, 10)
         self.assertRaises(GitProtocolError, self.proto.recv, 10)
 
 
     def test_mixed(self):
     def test_mixed(self):
         # arbitrary non-repeating string
         # arbitrary non-repeating string
-        all_data = b','.join(str(i).encode('ascii') for i in range(100))
+        all_data = b",".join(str(i).encode("ascii") for i in range(100))
         self.rin.write(all_data)
         self.rin.write(all_data)
         self.rin.seek(0)
         self.rin.seek(0)
-        data = b''
+        data = b""
 
 
         for i in range(1, 100):
         for i in range(1, 100):
             data += self.proto.recv(i)
             data += self.proto.recv(i)
@@ -209,41 +205,46 @@ class ReceivableProtocolTests(BaseProtocolTests, TestCase):
 
 
 
 
 class CapabilitiesTestCase(TestCase):
 class CapabilitiesTestCase(TestCase):
-
     def test_plain(self):
     def test_plain(self):
-        self.assertEqual((b'bla', []), extract_capabilities(b'bla'))
+        self.assertEqual((b"bla", []), extract_capabilities(b"bla"))
 
 
     def test_caps(self):
     def test_caps(self):
-        self.assertEqual((b'bla', [b'la']), extract_capabilities(b'bla\0la'))
-        self.assertEqual((b'bla', [b'la']), extract_capabilities(b'bla\0la\n'))
-        self.assertEqual((b'bla', [b'la', b'la']),
-                         extract_capabilities(b'bla\0la la'))
+        self.assertEqual((b"bla", [b"la"]), extract_capabilities(b"bla\0la"))
+        self.assertEqual((b"bla", [b"la"]), extract_capabilities(b"bla\0la\n"))
+        self.assertEqual((b"bla", [b"la", b"la"]), extract_capabilities(b"bla\0la la"))
 
 
     def test_plain_want_line(self):
     def test_plain_want_line(self):
-        self.assertEqual((b'want bla', []),
-                         extract_want_line_capabilities(b'want bla'))
+        self.assertEqual((b"want bla", []), extract_want_line_capabilities(b"want bla"))
 
 
     def test_caps_want_line(self):
     def test_caps_want_line(self):
-        self.assertEqual((b'want bla', [b'la']),
-                         extract_want_line_capabilities(b'want bla la'))
-        self.assertEqual((b'want bla', [b'la']),
-                         extract_want_line_capabilities(b'want bla la\n'))
-        self.assertEqual((b'want bla', [b'la', b'la']),
-                         extract_want_line_capabilities(b'want bla la la'))
+        self.assertEqual(
+            (b"want bla", [b"la"]),
+            extract_want_line_capabilities(b"want bla la"),
+        )
+        self.assertEqual(
+            (b"want bla", [b"la"]),
+            extract_want_line_capabilities(b"want bla la\n"),
+        )
+        self.assertEqual(
+            (b"want bla", [b"la", b"la"]),
+            extract_want_line_capabilities(b"want bla la la"),
+        )
 
 
     def test_ack_type(self):
     def test_ack_type(self):
-        self.assertEqual(SINGLE_ACK, ack_type([b'foo', b'bar']))
-        self.assertEqual(MULTI_ACK, ack_type([b'foo', b'bar', b'multi_ack']))
-        self.assertEqual(MULTI_ACK_DETAILED,
-                         ack_type([b'foo', b'bar', b'multi_ack_detailed']))
+        self.assertEqual(SINGLE_ACK, ack_type([b"foo", b"bar"]))
+        self.assertEqual(MULTI_ACK, ack_type([b"foo", b"bar", b"multi_ack"]))
+        self.assertEqual(
+            MULTI_ACK_DETAILED,
+            ack_type([b"foo", b"bar", b"multi_ack_detailed"]),
+        )
         # choose detailed when both present
         # choose detailed when both present
-        self.assertEqual(MULTI_ACK_DETAILED,
-                         ack_type([b'foo', b'bar', b'multi_ack',
-                                   b'multi_ack_detailed']))
+        self.assertEqual(
+            MULTI_ACK_DETAILED,
+            ack_type([b"foo", b"bar", b"multi_ack", b"multi_ack_detailed"]),
+        )
 
 
 
 
 class BufferedPktLineWriterTests(TestCase):
 class BufferedPktLineWriterTests(TestCase):
-
     def setUp(self):
     def setUp(self):
         TestCase.setUp(self)
         TestCase.setUp(self)
         self._output = BytesIO()
         self._output = BytesIO()
@@ -257,48 +258,47 @@ class BufferedPktLineWriterTests(TestCase):
         self._output.truncate()
         self._output.truncate()
 
 
     def test_write(self):
     def test_write(self):
-        self._writer.write(b'foo')
-        self.assertOutputEquals(b'')
+        self._writer.write(b"foo")
+        self.assertOutputEquals(b"")
         self._writer.flush()
         self._writer.flush()
-        self.assertOutputEquals(b'0007foo')
+        self.assertOutputEquals(b"0007foo")
 
 
     def test_write_none(self):
     def test_write_none(self):
         self._writer.write(None)
         self._writer.write(None)
-        self.assertOutputEquals(b'')
+        self.assertOutputEquals(b"")
         self._writer.flush()
         self._writer.flush()
-        self.assertOutputEquals(b'0000')
+        self.assertOutputEquals(b"0000")
 
 
     def test_flush_empty(self):
     def test_flush_empty(self):
         self._writer.flush()
         self._writer.flush()
-        self.assertOutputEquals(b'')
+        self.assertOutputEquals(b"")
 
 
     def test_write_multiple(self):
     def test_write_multiple(self):
-        self._writer.write(b'foo')
-        self._writer.write(b'bar')
-        self.assertOutputEquals(b'')
+        self._writer.write(b"foo")
+        self._writer.write(b"bar")
+        self.assertOutputEquals(b"")
         self._writer.flush()
         self._writer.flush()
-        self.assertOutputEquals(b'0007foo0007bar')
+        self.assertOutputEquals(b"0007foo0007bar")
 
 
     def test_write_across_boundary(self):
     def test_write_across_boundary(self):
-        self._writer.write(b'foo')
-        self._writer.write(b'barbaz')
-        self.assertOutputEquals(b'0007foo000abarba')
+        self._writer.write(b"foo")
+        self._writer.write(b"barbaz")
+        self.assertOutputEquals(b"0007foo000abarba")
         self._truncate()
         self._truncate()
         self._writer.flush()
         self._writer.flush()
-        self.assertOutputEquals(b'z')
+        self.assertOutputEquals(b"z")
 
 
     def test_write_to_boundary(self):
     def test_write_to_boundary(self):
-        self._writer.write(b'foo')
-        self._writer.write(b'barba')
-        self.assertOutputEquals(b'0007foo0009barba')
+        self._writer.write(b"foo")
+        self._writer.write(b"barba")
+        self.assertOutputEquals(b"0007foo0009barba")
         self._truncate()
         self._truncate()
-        self._writer.write(b'z')
+        self._writer.write(b"z")
         self._writer.flush()
         self._writer.flush()
-        self.assertOutputEquals(b'0005z')
+        self.assertOutputEquals(b"0005z")
 
 
 
 
 class PktLineParserTests(TestCase):
 class PktLineParserTests(TestCase):
-
     def test_none(self):
     def test_none(self):
         pktlines = []
         pktlines = []
         parser = PktLineParser(pktlines.append)
         parser = PktLineParser(pktlines.append)

+ 102 - 28
dulwich/tests/test_reflog.py

@@ -21,52 +21,126 @@
 
 
 """Tests for dulwich.reflog."""
 """Tests for dulwich.reflog."""
 
 
+from io import BytesIO
 
 
+from dulwich.objects import ZERO_SHA
 from dulwich.reflog import (
 from dulwich.reflog import (
+    drop_reflog_entry,
     format_reflog_line,
     format_reflog_line,
     parse_reflog_line,
     parse_reflog_line,
-    )
+    read_reflog,
+)
 
 
 from dulwich.tests import (
 from dulwich.tests import (
     TestCase,
     TestCase,
-    )
+)
 
 
 
 
 class ReflogLineTests(TestCase):
 class ReflogLineTests(TestCase):
-
     def test_format(self):
     def test_format(self):
         self.assertEqual(
         self.assertEqual(
-            b'0000000000000000000000000000000000000000 '
-            b'49030649db3dfec5a9bc03e5dde4255a14499f16 Jelmer Vernooij '
-            b'<jelmer@jelmer.uk> 1446552482 +0000	'
-            b'clone: from git://jelmer.uk/samba',
+            b"0000000000000000000000000000000000000000 "
+            b"49030649db3dfec5a9bc03e5dde4255a14499f16 Jelmer Vernooij "
+            b"<jelmer@jelmer.uk> 1446552482 +0000	"
+            b"clone: from git://jelmer.uk/samba",
             format_reflog_line(
             format_reflog_line(
-                b'0000000000000000000000000000000000000000',
-                b'49030649db3dfec5a9bc03e5dde4255a14499f16',
-                b'Jelmer Vernooij <jelmer@jelmer.uk>',
-                1446552482, 0, b'clone: from git://jelmer.uk/samba'))
+                b"0000000000000000000000000000000000000000",
+                b"49030649db3dfec5a9bc03e5dde4255a14499f16",
+                b"Jelmer Vernooij <jelmer@jelmer.uk>",
+                1446552482,
+                0,
+                b"clone: from git://jelmer.uk/samba",
+            ),
+        )
 
 
         self.assertEqual(
         self.assertEqual(
-            b'0000000000000000000000000000000000000000 '
-            b'49030649db3dfec5a9bc03e5dde4255a14499f16 Jelmer Vernooij '
-            b'<jelmer@jelmer.uk> 1446552482 +0000	'
-            b'clone: from git://jelmer.uk/samba',
+            b"0000000000000000000000000000000000000000 "
+            b"49030649db3dfec5a9bc03e5dde4255a14499f16 Jelmer Vernooij "
+            b"<jelmer@jelmer.uk> 1446552482 +0000	"
+            b"clone: from git://jelmer.uk/samba",
             format_reflog_line(
             format_reflog_line(
                 None,
                 None,
-                b'49030649db3dfec5a9bc03e5dde4255a14499f16',
-                b'Jelmer Vernooij <jelmer@jelmer.uk>',
-                1446552482, 0, b'clone: from git://jelmer.uk/samba'))
+                b"49030649db3dfec5a9bc03e5dde4255a14499f16",
+                b"Jelmer Vernooij <jelmer@jelmer.uk>",
+                1446552482,
+                0,
+                b"clone: from git://jelmer.uk/samba",
+            ),
+        )
 
 
     def test_parse(self):
     def test_parse(self):
         reflog_line = (
         reflog_line = (
-                 b'0000000000000000000000000000000000000000 '
-                 b'49030649db3dfec5a9bc03e5dde4255a14499f16 Jelmer Vernooij '
-                 b'<jelmer@jelmer.uk> 1446552482 +0000	'
-                 b'clone: from git://jelmer.uk/samba'
-                 )
+            b"0000000000000000000000000000000000000000 "
+            b"49030649db3dfec5a9bc03e5dde4255a14499f16 Jelmer Vernooij "
+            b"<jelmer@jelmer.uk> 1446552482 +0000	"
+            b"clone: from git://jelmer.uk/samba"
+        )
         self.assertEqual(
         self.assertEqual(
-                (b'0000000000000000000000000000000000000000',
-                 b'49030649db3dfec5a9bc03e5dde4255a14499f16',
-                 b'Jelmer Vernooij <jelmer@jelmer.uk>',
-                 1446552482, 0, b'clone: from git://jelmer.uk/samba'),
-                parse_reflog_line(reflog_line))
+            (
+                b"0000000000000000000000000000000000000000",
+                b"49030649db3dfec5a9bc03e5dde4255a14499f16",
+                b"Jelmer Vernooij <jelmer@jelmer.uk>",
+                1446552482,
+                0,
+                b"clone: from git://jelmer.uk/samba",
+            ),
+            parse_reflog_line(reflog_line),
+        )
+
+
+_TEST_REFLOG = (
+    b"0000000000000000000000000000000000000000 "
+    b"49030649db3dfec5a9bc03e5dde4255a14499f16 Jelmer Vernooij "
+    b"<jelmer@jelmer.uk> 1446552482 +0000	"
+    b"clone: from git://jelmer.uk/samba\n"
+    b"49030649db3dfec5a9bc03e5dde4255a14499f16 "
+    b"42d06bd4b77fed026b154d16493e5deab78f02ec Jelmer Vernooij "
+    b"<jelmer@jelmer.uk> 1446552483 +0000	"
+    b"clone: from git://jelmer.uk/samba\n"
+    b"42d06bd4b77fed026b154d16493e5deab78f02ec "
+    b"df6800012397fb85c56e7418dd4eb9405dee075c Jelmer Vernooij "
+    b"<jelmer@jelmer.uk> 1446552484 +0000	"
+    b"clone: from git://jelmer.uk/samba\n"
+)
+
+
+class ReflogDropTests(TestCase):
+    def setUp(self):
+        TestCase.setUp(self)
+        self.f = BytesIO(_TEST_REFLOG)
+        self.original_log = list(read_reflog(self.f))
+        self.f.seek(0)
+
+    def _read_log(self):
+        self.f.seek(0)
+        return list(read_reflog(self.f))
+
+    def test_invalid(self):
+        self.assertRaises(ValueError, drop_reflog_entry, self.f, -1)
+
+    def test_drop_entry(self):
+        drop_reflog_entry(self.f, 0)
+        log = self._read_log()
+        self.assertEqual(len(log), 2)
+        self.assertEqual(self.original_log[0:2], log)
+
+        self.f.seek(0)
+        drop_reflog_entry(self.f, 1)
+        log = self._read_log()
+        self.assertEqual(len(log), 1)
+        self.assertEqual(self.original_log[1], log[0])
+
+    def test_drop_entry_with_rewrite(self):
+        drop_reflog_entry(self.f, 1, True)
+        log = self._read_log()
+        self.assertEqual(len(log), 2)
+        self.assertEqual(self.original_log[0], log[0])
+        self.assertEqual(self.original_log[0].new_sha, log[1].old_sha)
+        self.assertEqual(self.original_log[2].new_sha, log[1].new_sha)
+
+        self.f.seek(0)
+        drop_reflog_entry(self.f, 1, True)
+        log = self._read_log()
+        self.assertEqual(len(log), 1)
+        self.assertEqual(ZERO_SHA, log[0].old_sha)
+        self.assertEqual(self.original_log[2].new_sha, log[0].new_sha)

Файлын зөрүү хэтэрхий том тул дарагдсан байна
+ 432 - 354
dulwich/tests/test_refs.py


Файлын зөрүү хэтэрхий том тул дарагдсан байна
+ 416 - 322
dulwich/tests/test_repository.py


Файлын зөрүү хэтэрхий том тул дарагдсан байна
+ 297 - 256
dulwich/tests/test_server.py


+ 29 - 23
dulwich/tests/test_utils.py

@@ -22,21 +22,20 @@
 
 
 from dulwich.object_store import (
 from dulwich.object_store import (
     MemoryObjectStore,
     MemoryObjectStore,
-    )
+)
 from dulwich.objects import (
 from dulwich.objects import (
     Blob,
     Blob,
-    )
+)
 from dulwich.tests import (
 from dulwich.tests import (
     TestCase,
     TestCase,
-    )
+)
 from dulwich.tests.utils import (
 from dulwich.tests.utils import (
     make_object,
     make_object,
     build_commit_graph,
     build_commit_graph,
-    )
+)
 
 
 
 
 class BuildCommitGraphTest(TestCase):
 class BuildCommitGraphTest(TestCase):
-
     def setUp(self):
     def setUp(self):
         super(BuildCommitGraphTest, self).setUp()
         super(BuildCommitGraphTest, self).setUp()
         self.store = MemoryObjectStore()
         self.store = MemoryObjectStore()
@@ -52,35 +51,42 @@ class BuildCommitGraphTest(TestCase):
         self.assertTrue(c2.commit_time > c1.commit_time)
         self.assertTrue(c2.commit_time > c1.commit_time)
 
 
     def test_merge(self):
     def test_merge(self):
-        c1, c2, c3, c4 = build_commit_graph(self.store,
-                                            [[1], [2, 1], [3, 1], [4, 2, 3]])
+        c1, c2, c3, c4 = build_commit_graph(
+            self.store, [[1], [2, 1], [3, 1], [4, 2, 3]]
+        )
         self.assertEqual([c2.id, c3.id], c4.parents)
         self.assertEqual([c2.id, c3.id], c4.parents)
         self.assertTrue(c4.commit_time > c2.commit_time)
         self.assertTrue(c4.commit_time > c2.commit_time)
         self.assertTrue(c4.commit_time > c3.commit_time)
         self.assertTrue(c4.commit_time > c3.commit_time)
 
 
     def test_missing_parent(self):
     def test_missing_parent(self):
-        self.assertRaises(ValueError, build_commit_graph, self.store,
-                          [[1], [3, 2], [2, 1]])
+        self.assertRaises(
+            ValueError, build_commit_graph, self.store, [[1], [3, 2], [2, 1]]
+        )
 
 
     def test_trees(self):
     def test_trees(self):
-        a1 = make_object(Blob, data=b'aaa1')
-        a2 = make_object(Blob, data=b'aaa2')
-        c1, c2 = build_commit_graph(self.store, [[1], [2, 1]],
-                                    trees={1: [(b'a', a1)],
-                                           2: [(b'a', a2, 0o100644)]})
-        self.assertEqual((0o100644, a1.id), self.store[c1.tree][b'a'])
-        self.assertEqual((0o100644, a2.id), self.store[c2.tree][b'a'])
+        a1 = make_object(Blob, data=b"aaa1")
+        a2 = make_object(Blob, data=b"aaa2")
+        c1, c2 = build_commit_graph(
+            self.store,
+            [[1], [2, 1]],
+            trees={1: [(b"a", a1)], 2: [(b"a", a2, 0o100644)]},
+        )
+        self.assertEqual((0o100644, a1.id), self.store[c1.tree][b"a"])
+        self.assertEqual((0o100644, a2.id), self.store[c2.tree][b"a"])
 
 
     def test_attrs(self):
     def test_attrs(self):
-        c1, c2 = build_commit_graph(self.store, [[1], [2, 1]],
-                                    attrs={1: {'message': b'Hooray!'}})
-        self.assertEqual(b'Hooray!', c1.message)
-        self.assertEqual(b'Commit 2', c2.message)
+        c1, c2 = build_commit_graph(
+            self.store, [[1], [2, 1]], attrs={1: {"message": b"Hooray!"}}
+        )
+        self.assertEqual(b"Hooray!", c1.message)
+        self.assertEqual(b"Commit 2", c2.message)
 
 
     def test_commit_time(self):
     def test_commit_time(self):
-        c1, c2, c3 = build_commit_graph(self.store, [[1], [2, 1], [3, 2]],
-                                        attrs={1: {'commit_time': 124},
-                                               2: {'commit_time': 123}})
+        c1, c2, c3 = build_commit_graph(
+            self.store,
+            [[1], [2, 1], [3, 2]],
+            attrs={1: {"commit_time": 124}, 2: {"commit_time": 123}},
+        )
         self.assertEqual(124, c1.commit_time)
         self.assertEqual(124, c1.commit_time)
         self.assertEqual(123, c2.commit_time)
         self.assertEqual(123, c2.commit_time)
         self.assertTrue(c2.commit_time < c1.commit_time < c3.commit_time)
         self.assertTrue(c2.commit_time < c1.commit_time < c3.commit_time)

+ 220 - 171
dulwich/tests/test_walk.py

@@ -22,7 +22,7 @@
 
 
 from itertools import (
 from itertools import (
     permutations,
     permutations,
-    )
+)
 from unittest import expectedFailure
 from unittest import expectedFailure
 
 
 from dulwich.diff_tree import (
 from dulwich.diff_tree import (
@@ -30,41 +30,37 @@ from dulwich.diff_tree import (
     CHANGE_RENAME,
     CHANGE_RENAME,
     TreeChange,
     TreeChange,
     RenameDetector,
     RenameDetector,
-    )
+)
 from dulwich.errors import (
 from dulwich.errors import (
     MissingCommitError,
     MissingCommitError,
-    )
+)
 from dulwich.object_store import (
 from dulwich.object_store import (
     MemoryObjectStore,
     MemoryObjectStore,
-    )
+)
 from dulwich.objects import (
 from dulwich.objects import (
     Commit,
     Commit,
     Blob,
     Blob,
-    )
-from dulwich.walk import (
-    ORDER_TOPO,
-    WalkEntry,
-    Walker,
-    _topo_reorder
-    )
+)
+from dulwich.walk import ORDER_TOPO, WalkEntry, Walker, _topo_reorder
 from dulwich.tests import TestCase
 from dulwich.tests import TestCase
 from dulwich.tests.utils import (
 from dulwich.tests.utils import (
     F,
     F,
     make_object,
     make_object,
     make_tag,
     make_tag,
     build_commit_graph,
     build_commit_graph,
-    )
+)
 
 
 
 
 class TestWalkEntry(object):
 class TestWalkEntry(object):
-
     def __init__(self, commit, changes):
     def __init__(self, commit, changes):
         self.commit = commit
         self.commit = commit
         self.changes = changes
         self.changes = changes
 
 
     def __repr__(self):
     def __repr__(self):
-        return '<TestWalkEntry commit=%s, changes=%r>' % (
-          self.commit.id, self.changes)
+        return "<TestWalkEntry commit=%s, changes=%r>" % (
+            self.commit.id,
+            self.changes,
+        )
 
 
     def __eq__(self, other):
     def __eq__(self, other):
         if not isinstance(other, WalkEntry) or self.commit != other.commit:
         if not isinstance(other, WalkEntry) or self.commit != other.commit:
@@ -75,18 +71,16 @@ class TestWalkEntry(object):
 
 
 
 
 class WalkerTest(TestCase):
 class WalkerTest(TestCase):
-
     def setUp(self):
     def setUp(self):
         super(WalkerTest, self).setUp()
         super(WalkerTest, self).setUp()
         self.store = MemoryObjectStore()
         self.store = MemoryObjectStore()
 
 
     def make_commits(self, commit_spec, **kwargs):
     def make_commits(self, commit_spec, **kwargs):
-        times = kwargs.pop('times', [])
-        attrs = kwargs.pop('attrs', {})
+        times = kwargs.pop("times", [])
+        attrs = kwargs.pop("attrs", {})
         for i, t in enumerate(times):
         for i, t in enumerate(times):
-            attrs.setdefault(i + 1, {})['commit_time'] = t
-        return build_commit_graph(self.store, commit_spec, attrs=attrs,
-                                  **kwargs)
+            attrs.setdefault(i + 1, {})["commit_time"] = t
+        return build_commit_graph(self.store, commit_spec, attrs=attrs, **kwargs)
 
 
     def make_linear_commits(self, num_commits, **kwargs):
     def make_linear_commits(self, num_commits, **kwargs):
         commit_spec = []
         commit_spec = []
@@ -192,164 +186,210 @@ class WalkerTest(TestCase):
 
 
     def test_reverse_after_max_entries(self):
     def test_reverse_after_max_entries(self):
         c1, c2, c3 = self.make_linear_commits(3)
         c1, c2, c3 = self.make_linear_commits(3)
-        self.assertWalkYields([c1, c2, c3], [c3.id], max_entries=3,
-                              reverse=True)
+        self.assertWalkYields([c1, c2, c3], [c3.id], max_entries=3, reverse=True)
         self.assertWalkYields([c2, c3], [c3.id], max_entries=2, reverse=True)
         self.assertWalkYields([c2, c3], [c3.id], max_entries=2, reverse=True)
         self.assertWalkYields([c3], [c3.id], max_entries=1, reverse=True)
         self.assertWalkYields([c3], [c3.id], max_entries=1, reverse=True)
 
 
     def test_changes_one_parent(self):
     def test_changes_one_parent(self):
-        blob_a1 = make_object(Blob, data=b'a1')
-        blob_a2 = make_object(Blob, data=b'a2')
-        blob_b2 = make_object(Blob, data=b'b2')
+        blob_a1 = make_object(Blob, data=b"a1")
+        blob_a2 = make_object(Blob, data=b"a2")
+        blob_b2 = make_object(Blob, data=b"b2")
         c1, c2 = self.make_linear_commits(
         c1, c2 = self.make_linear_commits(
-            2, trees={1: [(b'a', blob_a1)],
-                      2: [(b'a', blob_a2), (b'b', blob_b2)]})
-        e1 = TestWalkEntry(c1, [TreeChange.add((b'a', F, blob_a1.id))])
+            2,
+            trees={
+                1: [(b"a", blob_a1)],
+                2: [(b"a", blob_a2), (b"b", blob_b2)],
+            },
+        )
+        e1 = TestWalkEntry(c1, [TreeChange.add((b"a", F, blob_a1.id))])
         e2 = TestWalkEntry(
         e2 = TestWalkEntry(
-                c2,
-                [TreeChange(CHANGE_MODIFY, (b'a', F, blob_a1.id),
-                                           (b'a', F, blob_a2.id)),
-                 TreeChange.add((b'b', F, blob_b2.id))])
+            c2,
+            [
+                TreeChange(CHANGE_MODIFY, (b"a", F, blob_a1.id), (b"a", F, blob_a2.id)),
+                TreeChange.add((b"b", F, blob_b2.id)),
+            ],
+        )
         self.assertWalkYields([e2, e1], [c2.id])
         self.assertWalkYields([e2, e1], [c2.id])
 
 
     def test_changes_multiple_parents(self):
     def test_changes_multiple_parents(self):
-        blob_a1 = make_object(Blob, data=b'a1')
-        blob_b2 = make_object(Blob, data=b'b2')
-        blob_a3 = make_object(Blob, data=b'a3')
+        blob_a1 = make_object(Blob, data=b"a1")
+        blob_b2 = make_object(Blob, data=b"b2")
+        blob_a3 = make_object(Blob, data=b"a3")
         c1, c2, c3 = self.make_commits(
         c1, c2, c3 = self.make_commits(
             [[1], [2], [3, 1, 2]],
             [[1], [2], [3, 1, 2]],
-            trees={1: [(b'a', blob_a1)], 2: [(b'b', blob_b2)],
-                   3: [(b'a', blob_a3), (b'b', blob_b2)]})
+            trees={
+                1: [(b"a", blob_a1)],
+                2: [(b"b", blob_b2)],
+                3: [(b"a", blob_a3), (b"b", blob_b2)],
+            },
+        )
         # a is a modify/add conflict and b is not conflicted.
         # a is a modify/add conflict and b is not conflicted.
-        changes = [[
-                TreeChange(CHANGE_MODIFY,
-                           (b'a', F, blob_a1.id), (b'a', F, blob_a3.id)),
-                TreeChange.add((b'a', F, blob_a3.id)),
-        ]]
-        self.assertWalkYields([TestWalkEntry(c3, changes)], [c3.id],
-                              exclude=[c1.id, c2.id])
+        changes = [
+            [
+                TreeChange(CHANGE_MODIFY, (b"a", F, blob_a1.id), (b"a", F, blob_a3.id)),
+                TreeChange.add((b"a", F, blob_a3.id)),
+            ]
+        ]
+        self.assertWalkYields(
+            [TestWalkEntry(c3, changes)], [c3.id], exclude=[c1.id, c2.id]
+        )
 
 
     def test_path_matches(self):
     def test_path_matches(self):
-        walker = Walker(None, [], paths=[b'foo', b'bar', b'baz/quux'])
-        self.assertTrue(walker._path_matches(b'foo'))
-        self.assertTrue(walker._path_matches(b'foo/a'))
-        self.assertTrue(walker._path_matches(b'foo/a/b'))
-        self.assertTrue(walker._path_matches(b'bar'))
-        self.assertTrue(walker._path_matches(b'baz/quux'))
-        self.assertTrue(walker._path_matches(b'baz/quux/a'))
+        walker = Walker(None, [], paths=[b"foo", b"bar", b"baz/quux"])
+        self.assertTrue(walker._path_matches(b"foo"))
+        self.assertTrue(walker._path_matches(b"foo/a"))
+        self.assertTrue(walker._path_matches(b"foo/a/b"))
+        self.assertTrue(walker._path_matches(b"bar"))
+        self.assertTrue(walker._path_matches(b"baz/quux"))
+        self.assertTrue(walker._path_matches(b"baz/quux/a"))
 
 
         self.assertFalse(walker._path_matches(None))
         self.assertFalse(walker._path_matches(None))
-        self.assertFalse(walker._path_matches(b'oops'))
-        self.assertFalse(walker._path_matches(b'fool'))
-        self.assertFalse(walker._path_matches(b'baz'))
-        self.assertFalse(walker._path_matches(b'baz/quu'))
+        self.assertFalse(walker._path_matches(b"oops"))
+        self.assertFalse(walker._path_matches(b"fool"))
+        self.assertFalse(walker._path_matches(b"baz"))
+        self.assertFalse(walker._path_matches(b"baz/quu"))
 
 
     def test_paths(self):
     def test_paths(self):
-        blob_a1 = make_object(Blob, data=b'a1')
-        blob_b2 = make_object(Blob, data=b'b2')
-        blob_a3 = make_object(Blob, data=b'a3')
-        blob_b3 = make_object(Blob, data=b'b3')
+        blob_a1 = make_object(Blob, data=b"a1")
+        blob_b2 = make_object(Blob, data=b"b2")
+        blob_a3 = make_object(Blob, data=b"a3")
+        blob_b3 = make_object(Blob, data=b"b3")
         c1, c2, c3 = self.make_linear_commits(
         c1, c2, c3 = self.make_linear_commits(
-            3, trees={1: [(b'a', blob_a1)],
-                      2: [(b'a', blob_a1), (b'x/b', blob_b2)],
-                      3: [(b'a', blob_a3), (b'x/b', blob_b3)]})
+            3,
+            trees={
+                1: [(b"a", blob_a1)],
+                2: [(b"a", blob_a1), (b"x/b", blob_b2)],
+                3: [(b"a", blob_a3), (b"x/b", blob_b3)],
+            },
+        )
 
 
         self.assertWalkYields([c3, c2, c1], [c3.id])
         self.assertWalkYields([c3, c2, c1], [c3.id])
-        self.assertWalkYields([c3, c1], [c3.id], paths=[b'a'])
-        self.assertWalkYields([c3, c2], [c3.id], paths=[b'x/b'])
+        self.assertWalkYields([c3, c1], [c3.id], paths=[b"a"])
+        self.assertWalkYields([c3, c2], [c3.id], paths=[b"x/b"])
 
 
         # All changes are included, not just for requested paths.
         # All changes are included, not just for requested paths.
         changes = [
         changes = [
-            TreeChange(CHANGE_MODIFY, (b'a', F, blob_a1.id),
-                       (b'a', F, blob_a3.id)),
-            TreeChange(CHANGE_MODIFY, (b'x/b', F, blob_b2.id),
-                       (b'x/b', F, blob_b3.id)),
+            TreeChange(CHANGE_MODIFY, (b"a", F, blob_a1.id), (b"a", F, blob_a3.id)),
+            TreeChange(CHANGE_MODIFY, (b"x/b", F, blob_b2.id), (b"x/b", F, blob_b3.id)),
         ]
         ]
-        self.assertWalkYields([TestWalkEntry(c3, changes)], [c3.id],
-                              max_entries=1, paths=[b'a'])
+        self.assertWalkYields(
+            [TestWalkEntry(c3, changes)], [c3.id], max_entries=1, paths=[b"a"]
+        )
 
 
     def test_paths_subtree(self):
     def test_paths_subtree(self):
-        blob_a = make_object(Blob, data=b'a')
-        blob_b = make_object(Blob, data=b'b')
+        blob_a = make_object(Blob, data=b"a")
+        blob_b = make_object(Blob, data=b"b")
         c1, c2, c3 = self.make_linear_commits(
         c1, c2, c3 = self.make_linear_commits(
-            3, trees={1: [(b'x/a', blob_a)],
-                      2: [(b'b', blob_b), (b'x/a', blob_a)],
-                      3: [(b'b', blob_b), (b'x/a', blob_a), (b'x/b', blob_b)]})
-        self.assertWalkYields([c2], [c3.id], paths=[b'b'])
-        self.assertWalkYields([c3, c1], [c3.id], paths=[b'x'])
+            3,
+            trees={
+                1: [(b"x/a", blob_a)],
+                2: [(b"b", blob_b), (b"x/a", blob_a)],
+                3: [(b"b", blob_b), (b"x/a", blob_a), (b"x/b", blob_b)],
+            },
+        )
+        self.assertWalkYields([c2], [c3.id], paths=[b"b"])
+        self.assertWalkYields([c3, c1], [c3.id], paths=[b"x"])
 
 
     def test_paths_max_entries(self):
     def test_paths_max_entries(self):
-        blob_a = make_object(Blob, data=b'a')
-        blob_b = make_object(Blob, data=b'b')
+        blob_a = make_object(Blob, data=b"a")
+        blob_b = make_object(Blob, data=b"b")
         c1, c2 = self.make_linear_commits(
         c1, c2 = self.make_linear_commits(
-            2, trees={1: [(b'a', blob_a)],
-                      2: [(b'a', blob_a), (b'b', blob_b)]})
-        self.assertWalkYields([c2], [c2.id], paths=[b'b'], max_entries=1)
-        self.assertWalkYields([c1], [c1.id], paths=[b'a'], max_entries=1)
+            2, trees={1: [(b"a", blob_a)], 2: [(b"a", blob_a), (b"b", blob_b)]}
+        )
+        self.assertWalkYields([c2], [c2.id], paths=[b"b"], max_entries=1)
+        self.assertWalkYields([c1], [c1.id], paths=[b"a"], max_entries=1)
 
 
     def test_paths_merge(self):
     def test_paths_merge(self):
-        blob_a1 = make_object(Blob, data=b'a1')
-        blob_a2 = make_object(Blob, data=b'a2')
-        blob_a3 = make_object(Blob, data=b'a3')
+        blob_a1 = make_object(Blob, data=b"a1")
+        blob_a2 = make_object(Blob, data=b"a2")
+        blob_a3 = make_object(Blob, data=b"a3")
         x1, y2, m3, m4 = self.make_commits(
         x1, y2, m3, m4 = self.make_commits(
             [[1], [2], [3, 1, 2], [4, 1, 2]],
             [[1], [2], [3, 1, 2], [4, 1, 2]],
-            trees={1: [(b'a', blob_a1)],
-                   2: [(b'a', blob_a2)],
-                   3: [(b'a', blob_a3)],
-                   4: [(b'a', blob_a1)]})  # Non-conflicting
-        self.assertWalkYields([m3, y2, x1], [m3.id], paths=[b'a'])
-        self.assertWalkYields([y2, x1], [m4.id], paths=[b'a'])
+            trees={
+                1: [(b"a", blob_a1)],
+                2: [(b"a", blob_a2)],
+                3: [(b"a", blob_a3)],
+                4: [(b"a", blob_a1)],
+            },
+        )  # Non-conflicting
+        self.assertWalkYields([m3, y2, x1], [m3.id], paths=[b"a"])
+        self.assertWalkYields([y2, x1], [m4.id], paths=[b"a"])
 
 
     def test_changes_with_renames(self):
     def test_changes_with_renames(self):
-        blob = make_object(Blob, data=b'blob')
+        blob = make_object(Blob, data=b"blob")
         c1, c2 = self.make_linear_commits(
         c1, c2 = self.make_linear_commits(
-            2, trees={1: [(b'a', blob)], 2: [(b'b', blob)]})
-        entry_a = (b'a', F, blob.id)
-        entry_b = (b'b', F, blob.id)
-        changes_without_renames = [TreeChange.delete(entry_a),
-                                   TreeChange.add(entry_b)]
+            2, trees={1: [(b"a", blob)], 2: [(b"b", blob)]}
+        )
+        entry_a = (b"a", F, blob.id)
+        entry_b = (b"b", F, blob.id)
+        changes_without_renames = [
+            TreeChange.delete(entry_a),
+            TreeChange.add(entry_b),
+        ]
         changes_with_renames = [TreeChange(CHANGE_RENAME, entry_a, entry_b)]
         changes_with_renames = [TreeChange(CHANGE_RENAME, entry_a, entry_b)]
         self.assertWalkYields(
         self.assertWalkYields(
-          [TestWalkEntry(c2, changes_without_renames)], [c2.id], max_entries=1)
+            [TestWalkEntry(c2, changes_without_renames)],
+            [c2.id],
+            max_entries=1,
+        )
         detector = RenameDetector(self.store)
         detector = RenameDetector(self.store)
         self.assertWalkYields(
         self.assertWalkYields(
-          [TestWalkEntry(c2, changes_with_renames)], [c2.id], max_entries=1,
-          rename_detector=detector)
+            [TestWalkEntry(c2, changes_with_renames)],
+            [c2.id],
+            max_entries=1,
+            rename_detector=detector,
+        )
 
 
     def test_follow_rename(self):
     def test_follow_rename(self):
-        blob = make_object(Blob, data=b'blob')
-        names = [b'a', b'a', b'b', b'b', b'c', b'c']
+        blob = make_object(Blob, data=b"blob")
+        names = [b"a", b"a", b"b", b"b", b"c", b"c"]
 
 
-        trees = dict((i + 1, [(n, blob, F)]) for i, n in enumerate(names))
+        trees = {i + 1: [(n, blob, F)] for i, n in enumerate(names)}
         c1, c2, c3, c4, c5, c6 = self.make_linear_commits(6, trees=trees)
         c1, c2, c3, c4, c5, c6 = self.make_linear_commits(6, trees=trees)
-        self.assertWalkYields([c5], [c6.id], paths=[b'c'])
+        self.assertWalkYields([c5], [c6.id], paths=[b"c"])
 
 
         def e(n):
         def e(n):
             return (n, F, blob.id)
             return (n, F, blob.id)
+
         self.assertWalkYields(
         self.assertWalkYields(
-            [TestWalkEntry(c5, [TreeChange(CHANGE_RENAME, e(b'b'), e(b'c'))]),
-             TestWalkEntry(c3, [TreeChange(CHANGE_RENAME, e(b'a'), e(b'b'))]),
-             TestWalkEntry(c1, [TreeChange.add(e(b'a'))])],
-            [c6.id], paths=[b'c'], follow=True)
+            [
+                TestWalkEntry(c5, [TreeChange(CHANGE_RENAME, e(b"b"), e(b"c"))]),
+                TestWalkEntry(c3, [TreeChange(CHANGE_RENAME, e(b"a"), e(b"b"))]),
+                TestWalkEntry(c1, [TreeChange.add(e(b"a"))]),
+            ],
+            [c6.id],
+            paths=[b"c"],
+            follow=True,
+        )
 
 
     def test_follow_rename_remove_path(self):
     def test_follow_rename_remove_path(self):
-        blob = make_object(Blob, data=b'blob')
+        blob = make_object(Blob, data=b"blob")
         _, _, _, c4, c5, c6 = self.make_linear_commits(
         _, _, _, c4, c5, c6 = self.make_linear_commits(
-            6, trees={1: [(b'a', blob), (b'c', blob)],
-                      2: [],
-                      3: [],
-                      4: [(b'b', blob)],
-                      5: [(b'a', blob)],
-                      6: [(b'c', blob)]})
+            6,
+            trees={
+                1: [(b"a", blob), (b"c", blob)],
+                2: [],
+                3: [],
+                4: [(b"b", blob)],
+                5: [(b"a", blob)],
+                6: [(b"c", blob)],
+            },
+        )
 
 
         def e(n):
         def e(n):
             return (n, F, blob.id)
             return (n, F, blob.id)
+
         # Once the path changes to b, we aren't interested in a or c anymore.
         # Once the path changes to b, we aren't interested in a or c anymore.
         self.assertWalkYields(
         self.assertWalkYields(
-            [TestWalkEntry(c6, [TreeChange(CHANGE_RENAME, e(b'a'), e(b'c'))]),
-             TestWalkEntry(c5, [TreeChange(CHANGE_RENAME, e(b'b'), e(b'a'))]),
-             TestWalkEntry(c4, [TreeChange.add(e(b'b'))])],
-            [c6.id], paths=[b'c'], follow=True)
+            [
+                TestWalkEntry(c6, [TreeChange(CHANGE_RENAME, e(b"a"), e(b"c"))]),
+                TestWalkEntry(c5, [TreeChange(CHANGE_RENAME, e(b"b"), e(b"a"))]),
+                TestWalkEntry(c4, [TreeChange.add(e(b"b"))]),
+            ],
+            [c6.id],
+            paths=[b"c"],
+            follow=True,
+        )
 
 
     def test_since(self):
     def test_since(self):
         c1, c2, c3 = self.make_linear_commits(3)
         c1, c2, c3 = self.make_linear_commits(3)
@@ -385,8 +425,7 @@ class WalkerTest(TestCase):
         self.assertWalkYields([c2], [c3.id], since=50, until=150)
         self.assertWalkYields([c2], [c3.id], since=50, until=150)
 
 
     def test_since_over_scan(self):
     def test_since_over_scan(self):
-        commits = self.make_linear_commits(
-          11, times=[9, 0, 1, 2, 3, 4, 5, 8, 6, 7, 9])
+        commits = self.make_linear_commits(11, times=[9, 0, 1, 2, 3, 4, 5, 8, 6, 7, 9])
         c8, _, c10, c11 = commits[-4:]
         c8, _, c10, c11 = commits[-4:]
         del self.store[commits[0].id]
         del self.store[commits[0].id]
         # c9 is older than we want to walk, but is out of order with its
         # c9 is older than we want to walk, but is out of order with its
@@ -434,8 +473,8 @@ class WalkerTest(TestCase):
 
 
     def test_out_of_order_children(self):
     def test_out_of_order_children(self):
         c1, c2, c3, c4, c5 = self.make_commits(
         c1, c2, c3, c4, c5 = self.make_commits(
-          [[1], [2, 1], [3, 2], [4, 1], [5, 3, 4]],
-          times=[2, 1, 3, 4, 5])
+            [[1], [2, 1], [3, 2], [4, 1], [5, 3, 4]], times=[2, 1, 3, 4, 5]
+        )
         self.assertWalkYields([c5, c4, c3, c1, c2], [c5.id])
         self.assertWalkYields([c5, c4, c3, c1, c2], [c5.id])
         self.assertWalkYields([c5, c4, c3, c2, c1], [c5.id], order=ORDER_TOPO)
         self.assertWalkYields([c5, c4, c3, c2, c1], [c5.id], order=ORDER_TOPO)
 
 
@@ -446,8 +485,9 @@ class WalkerTest(TestCase):
         #    \-y3--y4-/--y5
         #    \-y3--y4-/--y5
         # Due to skew, y5 is the oldest commit.
         # Due to skew, y5 is the oldest commit.
         c1, x2, y3, y4, y5, m6 = self.make_commits(
         c1, x2, y3, y4, y5, m6 = self.make_commits(
-          [[1], [2, 1], [3, 1], [4, 3], [5, 4], [6, 2, 4]],
-          times=[2, 3, 4, 5, 1, 6])
+            [[1], [2, 1], [3, 1], [4, 3], [5, 4], [6, 2, 4]],
+            times=[2, 3, 4, 5, 1, 6],
+        )
         self.assertWalkYields([m6, y4, y3, x2, c1], [m6.id])
         self.assertWalkYields([m6, y4, y3, x2, c1], [m6.id])
         # Ensure that c1..y4 get excluded even though they're popped from the
         # Ensure that c1..y4 get excluded even though they're popped from the
         # priority queue long before y5.
         # priority queue long before y5.
@@ -459,18 +499,16 @@ class WalkerTest(TestCase):
 
 
 
 
 class WalkEntryTest(TestCase):
 class WalkEntryTest(TestCase):
-
     def setUp(self):
     def setUp(self):
         super(WalkEntryTest, self).setUp()
         super(WalkEntryTest, self).setUp()
         self.store = MemoryObjectStore()
         self.store = MemoryObjectStore()
 
 
     def make_commits(self, commit_spec, **kwargs):
     def make_commits(self, commit_spec, **kwargs):
-        times = kwargs.pop('times', [])
-        attrs = kwargs.pop('attrs', {})
+        times = kwargs.pop("times", [])
+        attrs = kwargs.pop("attrs", {})
         for i, t in enumerate(times):
         for i, t in enumerate(times):
-            attrs.setdefault(i + 1, {})['commit_time'] = t
-        return build_commit_graph(self.store, commit_spec, attrs=attrs,
-                                  **kwargs)
+            attrs.setdefault(i + 1, {})["commit_time"] = t
+        return build_commit_graph(self.store, commit_spec, attrs=attrs, **kwargs)
 
 
     def make_linear_commits(self, num_commits, **kwargs):
     def make_linear_commits(self, num_commits, **kwargs):
         commit_spec = []
         commit_spec = []
@@ -483,11 +521,11 @@ class WalkEntryTest(TestCase):
 
 
     def test_all_changes(self):
     def test_all_changes(self):
         # Construct a commit with 2 files in different subdirectories.
         # Construct a commit with 2 files in different subdirectories.
-        blob_a = make_object(Blob, data=b'a')
-        blob_b = make_object(Blob, data=b'b')
+        blob_a = make_object(Blob, data=b"a")
+        blob_b = make_object(Blob, data=b"b")
         c1 = self.make_linear_commits(
         c1 = self.make_linear_commits(
             1,
             1,
-            trees={1: [(b'x/a', blob_a), (b'y/b', blob_b)]},
+            trees={1: [(b"x/a", blob_a), (b"y/b", blob_b)]},
         )[0]
         )[0]
 
 
         # Get the WalkEntry for the commit.
         # Get the WalkEntry for the commit.
@@ -496,24 +534,26 @@ class WalkEntryTest(TestCase):
         changes = walker_entry.changes()
         changes = walker_entry.changes()
 
 
         # Compare the changes with the expected values.
         # Compare the changes with the expected values.
-        entry_a = (b'x/a', F, blob_a.id)
-        entry_b = (b'y/b', F, blob_b.id)
+        entry_a = (b"x/a", F, blob_a.id)
+        entry_b = (b"y/b", F, blob_b.id)
         self.assertEqual(
         self.assertEqual(
-            [TreeChange.add(entry_a),
-             TreeChange.add(entry_b)],
+            [TreeChange.add(entry_a), TreeChange.add(entry_b)],
             changes,
             changes,
         )
         )
 
 
     def test_all_with_merge(self):
     def test_all_with_merge(self):
-        blob_a = make_object(Blob, data=b'a')
-        blob_a2 = make_object(Blob, data=b'a2')
-        blob_b = make_object(Blob, data=b'b')
-        blob_b2 = make_object(Blob, data=b'b2')
+        blob_a = make_object(Blob, data=b"a")
+        blob_a2 = make_object(Blob, data=b"a2")
+        blob_b = make_object(Blob, data=b"b")
+        blob_b2 = make_object(Blob, data=b"b2")
         x1, y2, m3 = self.make_commits(
         x1, y2, m3 = self.make_commits(
             [[1], [2], [3, 1, 2]],
             [[1], [2], [3, 1, 2]],
-            trees={1: [(b'x/a', blob_a)],
-                   2: [(b'y/b', blob_b)],
-                   3: [(b'x/a', blob_a2), (b'y/b', blob_b2)]})
+            trees={
+                1: [(b"x/a", blob_a)],
+                2: [(b"y/b", blob_b)],
+                3: [(b"x/a", blob_a2), (b"y/b", blob_b2)],
+            },
+        )
 
 
         # Get the WalkEntry for the merge commit.
         # Get the WalkEntry for the merge commit.
         walker = Walker(self.store, m3.id)
         walker = Walker(self.store, m3.id)
@@ -523,60 +563,69 @@ class WalkEntryTest(TestCase):
         changes = walker_entry.changes()
         changes = walker_entry.changes()
         self.assertEqual(2, len(changes))
         self.assertEqual(2, len(changes))
 
 
-        entry_a = (b'x/a', F, blob_a.id)
-        entry_a2 = (b'x/a', F, blob_a2.id)
-        entry_b = (b'y/b', F, blob_b.id)
-        entry_b2 = (b'y/b', F, blob_b2.id)
+        entry_a = (b"x/a", F, blob_a.id)
+        entry_a2 = (b"x/a", F, blob_a2.id)
+        entry_b = (b"y/b", F, blob_b.id)
+        entry_b2 = (b"y/b", F, blob_b2.id)
         self.assertEqual(
         self.assertEqual(
-                [[TreeChange(CHANGE_MODIFY, entry_a, entry_a2),
-                  TreeChange.add(entry_a2)],
-                 [TreeChange.add(entry_b2),
-                  TreeChange(CHANGE_MODIFY, entry_b, entry_b2)]],
-                changes,
+            [
+                [
+                    TreeChange(CHANGE_MODIFY, entry_a, entry_a2),
+                    TreeChange.add(entry_a2),
+                ],
+                [
+                    TreeChange.add(entry_b2),
+                    TreeChange(CHANGE_MODIFY, entry_b, entry_b2),
+                ],
+            ],
+            changes,
         )
         )
 
 
     def test_filter_changes(self):
     def test_filter_changes(self):
         # Construct a commit with 2 files in different subdirectories.
         # Construct a commit with 2 files in different subdirectories.
-        blob_a = make_object(Blob, data=b'a')
-        blob_b = make_object(Blob, data=b'b')
+        blob_a = make_object(Blob, data=b"a")
+        blob_b = make_object(Blob, data=b"b")
         c1 = self.make_linear_commits(
         c1 = self.make_linear_commits(
             1,
             1,
-            trees={1: [(b'x/a', blob_a), (b'y/b', blob_b)]},
+            trees={1: [(b"x/a", blob_a), (b"y/b", blob_b)]},
         )[0]
         )[0]
 
 
         # Get the WalkEntry for the commit.
         # Get the WalkEntry for the commit.
         walker = Walker(self.store, c1.id)
         walker = Walker(self.store, c1.id)
         walker_entry = list(walker)[0]
         walker_entry = list(walker)[0]
-        changes = walker_entry.changes(path_prefix=b'x')
+        changes = walker_entry.changes(path_prefix=b"x")
 
 
         # Compare the changes with the expected values.
         # Compare the changes with the expected values.
-        entry_a = (b'a', F, blob_a.id)
+        entry_a = (b"a", F, blob_a.id)
         self.assertEqual(
         self.assertEqual(
             [TreeChange.add(entry_a)],
             [TreeChange.add(entry_a)],
             changes,
             changes,
         )
         )
 
 
     def test_filter_with_merge(self):
     def test_filter_with_merge(self):
-        blob_a = make_object(Blob, data=b'a')
-        blob_a2 = make_object(Blob, data=b'a2')
-        blob_b = make_object(Blob, data=b'b')
-        blob_b2 = make_object(Blob, data=b'b2')
+        blob_a = make_object(Blob, data=b"a")
+        blob_a2 = make_object(Blob, data=b"a2")
+        blob_b = make_object(Blob, data=b"b")
+        blob_b2 = make_object(Blob, data=b"b2")
         x1, y2, m3 = self.make_commits(
         x1, y2, m3 = self.make_commits(
             [[1], [2], [3, 1, 2]],
             [[1], [2], [3, 1, 2]],
-            trees={1: [(b'x/a', blob_a)],
-                   2: [(b'y/b', blob_b)],
-                   3: [(b'x/a', blob_a2), (b'y/b', blob_b2)]})
+            trees={
+                1: [(b"x/a", blob_a)],
+                2: [(b"y/b", blob_b)],
+                3: [(b"x/a", blob_a2), (b"y/b", blob_b2)],
+            },
+        )
 
 
         # Get the WalkEntry for the merge commit.
         # Get the WalkEntry for the merge commit.
         walker = Walker(self.store, m3.id)
         walker = Walker(self.store, m3.id)
         entries = list(walker)
         entries = list(walker)
         walker_entry = entries[0]
         walker_entry = entries[0]
         self.assertEqual(walker_entry.commit.id, m3.id)
         self.assertEqual(walker_entry.commit.id, m3.id)
-        changes = walker_entry.changes(b'x')
+        changes = walker_entry.changes(b"x")
         self.assertEqual(1, len(changes))
         self.assertEqual(1, len(changes))
 
 
-        entry_a = (b'a', F, blob_a.id)
-        entry_a2 = (b'a', F, blob_a2.id)
+        entry_a = (b"a", F, blob_a.id)
+        entry_a2 = (b"a", F, blob_a2.id)
         self.assertEqual(
         self.assertEqual(
             [[TreeChange(CHANGE_MODIFY, entry_a, entry_a2)]],
             [[TreeChange(CHANGE_MODIFY, entry_a, entry_a2)]],
             changes,
             changes,

Энэ ялгаанд хэт олон файл өөрчлөгдсөн тул зарим файлыг харуулаагүй болно