
Import upstream version 0.20.50

Jelmer Vernooij 2 years ago
parent
commit
b6d8d6d904
53 changed files with 1285 additions and 719 deletions
  1. .github/workflows/disperse.yml (+24 -0)
  2. .github/workflows/docs.yml (+23 -0)
  3. .github/workflows/pythontest.yml (+2 -2)
  4. .github/workflows/pythonwheels.yml (+19 -58)
  5. CONTRIBUTING.rst (+2 -2)
  6. Makefile (+1 -1)
  7. NEWS (+26 -0)
  8. PKG-INFO (+4 -8)
  9. disperse.conf (+3 -7)
  10. dulwich.egg-info/PKG-INFO (+4 -8)
  11. dulwich.egg-info/SOURCES.txt (+6 -2)
  12. dulwich.egg-info/entry_points.txt (+0 -1)
  13. dulwich/__init__.py (+1 -1)
  14. dulwich/archive.py (+1 -1)
  15. dulwich/bundle.py (+1 -1)
  16. dulwich/cli.py (+34 -12)
  17. dulwich/client.py (+17 -16)
  18. dulwich/config.py (+15 -4)
  19. dulwich/contrib/swift.py (+3 -2)
  20. dulwich/credentials.py (+89 -0)
  21. dulwich/file.py (+2 -2)
  22. dulwich/hooks.py (+1 -1)
  23. dulwich/index.py (+101 -47)
  24. dulwich/lfs.py (+2 -2)
  25. dulwich/object_store.py (+14 -13)
  26. dulwich/objects.py (+75 -65)
  27. dulwich/pack.py (+257 -161)
  28. dulwich/porcelain.py (+59 -27)
  29. dulwich/protocol.py (+36 -40)
  30. dulwich/refs.py (+10 -8)
  31. dulwich/repo.py (+87 -44)
  32. dulwich/server.py (+33 -38)
  33. dulwich/tests/__init__.py (+1 -0)
  34. dulwich/tests/compat/test_utils.py (+1 -1)
  35. dulwich/tests/compat/utils.py (+14 -7)
  36. dulwich/tests/test_archive.py (+12 -0)
  37. dulwich/tests/test_client.py (+2 -2)
  38. dulwich/tests/test_config.py (+13 -0)
  39. dulwich/tests/test_credentials.py (+75 -0)
  40. dulwich/tests/test_fastexport.py (+4 -4)
  41. dulwich/tests/test_index.py (+0 -4)
  42. dulwich/tests/test_object_store.py (+2 -2)
  43. dulwich/tests/test_objects.py (+34 -1)
  44. dulwich/tests/test_pack.py (+13 -12)
  45. dulwich/tests/test_porcelain.py (+4 -0)
  46. dulwich/tests/test_refs.py (+12 -2)
  47. dulwich/tests/test_repository.py (+37 -1)
  48. dulwich/tests/test_server.py (+2 -1)
  49. dulwich/tests/utils.py (+7 -7)
  50. dulwich/walk.py (+11 -8)
  51. dulwich/web.py (+21 -13)
  52. setup.cfg (+56 -0)
  53. setup.py (+12 -80)

+ 24 - 0
.github/workflows/disperse.yml

@@ -0,0 +1,24 @@
+---
+name: Disperse configuration
+
+"on":
+  - push
+
+jobs:
+  build:
+
+    runs-on: ubuntu-latest
+
+    steps:
+      - uses: actions/checkout@v2
+      - name: Set up Python
+        uses: actions/setup-python@v2
+      - name: Install dependencies
+        run: |
+          sudo apt install protobuf-compiler
+      - name: Install disperse
+        run: |
+          pip install git+https://github.com/jelmer/disperse
+      - name: Validate disperse.conf
+        run: |
+          PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python disperse validate .

+ 23 - 0
.github/workflows/docs.yml

@@ -0,0 +1,23 @@
+
+name: API Docs
+
+on:
+  push:
+  pull_request:
+  schedule:
+    - cron: "0 6 * * *" # Daily 6AM UTC build
+
+jobs:
+  test:
+    runs-on: ubuntu-latest
+
+    steps:
+      - uses: actions/checkout@v2
+      - name: Set up Python
+        uses: actions/setup-python@v2
+      - name: Install pydoctor
+        run: |
+          sudo apt-get update && sudo apt -y install -y pydoctor python3-pip
+          pip3 install pydoctor
+      - name: Generate docs
+        run: make apidocs

+ 2 - 2
.github/workflows/pythonpackage.yml → .github/workflows/pythontest.yml

@@ -1,4 +1,4 @@
-name: Python package
+name: Python tests
 
 on:
   push:
@@ -7,7 +7,7 @@ on:
     - cron: "0 6 * * *" # Daily 6AM UTC build
 
 jobs:
-  build:
+  test:
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:

+ 19 - 58
.github/workflows/pythonwheels.yml

@@ -2,7 +2,6 @@ name: Build Python Wheels
 
 on:
   push:
-  pull_request:
   schedule:
     - cron: "0 6 * * *" # Daily 6AM UTC build
 
@@ -11,25 +10,12 @@ jobs:
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
-        os: [macos-latest, windows-latest]
-        python-version: ["3.6", "3.7", "3.8", "3.9", "3.10", "3.11.0-rc - 3.11"]
-        architecture: ["x64", "x86"]
-        include:
-          - os: ubuntu-latest
-            python-version: "3.x"
-          # path encoding
-        exclude:
-          - os: macos-latest
-            architecture: "x86"
+        os: [ubuntu-latest, macos-latest, windows-latest]
       fail-fast: true
 
     steps:
-      - uses: actions/checkout@v2
-      - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v2
-        with:
-          python-version: ${{ matrix.python-version }}
-          architecture: ${{ matrix.architecture }}
+      - uses: actions/checkout@v3
+      - uses: actions/setup-python@v3
      - name: Install native dependencies (Ubuntu)
        run: sudo apt-get update && sudo apt-get install -y libgpgme-dev libgpg-error-dev
        if: "matrix.os == 'ubuntu-latest'"
@@ -39,48 +25,25 @@ jobs:
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
-          pip install setuptools wheel fastimport paramiko urllib3
+          pip install setuptools wheel fastimport paramiko urllib3 cibuildwheel==2.9.0
      - name: Install gpg on supported platforms
        run: pip install -U gpg
-        if: "matrix.os != 'windows-latest' && matrix.python-version != 'pypy3'"
+        if: "matrix.os != 'windows-latest'"
      - name: Run test suite
-        run: |
-          python -m unittest dulwich.tests.test_suite
-      - name: Build
-        run: |
-          python setup.py sdist bdist_wheel
-        if: "matrix.os != 'ubuntu-latest'"
-      - uses: docker/setup-qemu-action@v1
-        name: Set up QEMU
+        run: python -m unittest dulwich.tests.test_suite
+      - name: Set up QEMU
+        uses: docker/setup-qemu-action@v1
        if: "matrix.os == 'ubuntu-latest'"
-      - name: Build (Linux aarch64)
-        uses: RalfG/python-wheels-manylinux-build@v0.5.0-manylinux2014_aarch64
-        with:
-          python-versions: "cp36-cp36m cp37-cp37m cp38-cp38 cp39-cp39 cp310-cp310 cp311-cp311"
-        if: "matrix.os == 'ubuntu-latest'"
-      - name: Build (Linux)
-        uses: RalfG/python-wheels-manylinux-build@v0.5.0
-        with:
-          python-versions: "cp36-cp36m cp37-cp37m cp38-cp38 cp39-cp39 cp310-cp310 cp311-cp311"
+      - name: Build wheels
+        run: python -m cibuildwheel --output-dir wheelhouse
        env:
-          # Temporary fix for LD_LIBRARY_PATH issue. See
-          # https://github.com/RalfG/python-wheels-manylinux-build/issues/26
-          LD_LIBRARY_PATH: /usr/local/lib:${{ env.LD_LIBRARY_PATH }}
-        if: "matrix.os == 'ubuntu-latest'"
-      - name: Upload wheels (Linux)
-        uses: actions/upload-artifact@v2
-        # Only include *manylinux* wheels; the other wheels files are built but
-        # rejected by pip.
-        if: "matrix.os == 'ubuntu-latest'"
-        with:
-          name: dist
-          path: dist/*manylinux*.whl
-      - name: Upload wheels (non-Linux)
-        uses: actions/upload-artifact@v2
+          CIBW_ARCHS_LINUX: x86_64 aarch64
+          CIBW_ARCHS_MACOS: x86_64 arm64 universal2
+          CIBW_ARCHS_WINDOWS: AMD64 x86
+      - name: Upload wheels
+        uses: actions/upload-artifact@v3
        with:
-          name: dist
-          path: dist/*.whl
-        if: "matrix.os != 'ubuntu-latest'"
+          path: ./wheelhouse/*.whl
 
   publish:
     runs-on: ubuntu-latest
@@ -88,10 +51,8 @@ jobs:
     needs: build
     if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/dulwich-')
     steps:
-      - name: Set up Python
-        uses: actions/setup-python@v2
-        with:
-          python-version: "3.x"
+      - uses: actions/setup-python@v3
+
      - name: Install twine
        run: |
          python -m pip install --upgrade pip
@@ -102,4 +63,4 @@ jobs:
        env:
          TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
          TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
-        run: twine upload dist/*.whl
+        run: twine upload artifact/*.whl

+ 2 - 2
CONTRIBUTING.rst

@@ -9,7 +9,7 @@ New functionality and bug fixes should be accompanied by matching unit tests.
 
 Coding style
 ------------
-Where possible, please follow PEP8 with regard to coding style.
+Where possible, please follow PEP8 with regard to coding style. Run flake8.
 
 Furthermore, triple-quotes should always be """, single quotes are ' unless
 using " would result in less escaping within the string.
@@ -26,7 +26,7 @@ will run the tests using unittest.
 ::
    $ make check
 
-Tox configuration is also present as well as a Travis configuration file.
+Tox configuration is also present.
 
 String Types
 ------------

+ 1 - 1
Makefile

@@ -74,4 +74,4 @@ coverage-html: coverage
 .PHONY: apidocs
 
 apidocs:
-	pydoctor --intersphinx http://urllib3.readthedocs.org/en/latest/objects.inv --intersphinx http://docs.python.org/3/objects.inv --docformat=google dulwich --project-url=https://www.dulwich.io/
+	pydoctor --intersphinx http://urllib3.readthedocs.org/en/latest/objects.inv --intersphinx http://docs.python.org/3/objects.inv --docformat=google dulwich --project-url=https://www.dulwich.io/ --project-name=dulwich

+ 26 - 0
NEWS

@@ -1,3 +1,29 @@
+0.20.50	2022-10-30
+
+ * Fix Repo.reset_index.
+   Previously, it instead took the union with the given tree.
+   (Christian Sattler, #1072)
+
+ * Add -b argument to ``dulwich clone``.
+   (Jelmer Vernooij)
+
+ * On Windows, provide a hint about developer mode
+   when creating symlinks fails due to a permission
+   error. (Jelmer Vernooij, #1005)
+
+ * Add new ``ObjectID`` type in ``dulwich.objects``,
+   currently just an alias for ``bytes``.
+   (Jelmer Vernooij)
+
+ * Support repository format version 1.
+   (Jelmer Vernooij, #1056)
+
+ * Support \r\n line endings with continuations when parsing
+   configuration files.  (Jelmer Vernooij)
+
+ * Fix handling of SymrefLoop in RefsContainer.__setitem__.
+   (Dominic Davis-Foster, Jelmer Vernooij)
+
 0.20.46	2022-09-06
 
  * Apply insteadOf to rsync-style location strings
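
The -b option described in the NEWS entry above maps to the branch keyword of porcelain.clone (see the dulwich/cli.py diff below). A minimal sketch, assuming dulwich 0.20.50, a reachable source URL, and an illustrative target directory:

    from dulwich import porcelain

    # Clone and check out "develop" instead of the branch the remote HEAD
    # points at; target path and branch name are illustrative.
    repo = porcelain.clone(
        "https://github.com/dulwich/dulwich",
        target="dulwich-dev",
        branch=b"develop",
    )
    print(repo.path)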

+ 4 - 8
PKG-INFO

@@ -1,16 +1,15 @@
 Metadata-Version: 2.1
 Name: dulwich
-Version: 0.20.46
+Version: 0.20.50
 Summary: Python Git Library
 Home-page: https://www.dulwich.io/
 Author: Jelmer Vernooij
 Author-email: jelmer@jelmer.uk
 License: Apachev2 or later or GPLv2
-Project-URL: Bug Tracker, https://github.com/dulwich/dulwich/issues
 Project-URL: Repository, https://www.dulwich.io/code/
 Project-URL: GitHub, https://github.com/dulwich/dulwich
-Keywords: git vcs
-Platform: UNKNOWN
+Project-URL: Bug Tracker, https://github.com/dulwich/dulwich/issues
+Keywords: vcs,git
 Classifier: Development Status :: 4 - Beta
 Classifier: License :: OSI Approved :: Apache Software License
 Classifier: Programming Language :: Python :: 3.6
@@ -27,10 +26,9 @@ Classifier: Topic :: Software Development :: Version Control
 Requires-Python: >=3.6
 Provides-Extra: fastimport
 Provides-Extra: https
-Provides-Extra: paramiko
 Provides-Extra: pgp
+Provides-Extra: paramiko
 License-File: COPYING
-License-File: AUTHORS
 
 Dulwich
 =======
@@ -123,5 +121,3 @@ Supported versions of Python
 
 At the moment, Dulwich supports (and is tested on) CPython 3.6 and later and
 Pypy.
-
-

+ 3 - 7
releaser.conf → disperse.conf

@@ -1,16 +1,12 @@
-# See https://github.com/jelmer/releaser
+# See https://github.com/jelmer/disperse
 news_file: "NEWS"
 timeout_days: 5
 tag_name: "dulwich-$VERSION"
 verify_command: "flake8 && make check"
-github_url: "https://github.com/dulwich/dulwich"
-update_version {
-  path: "setup.py"
-  match: "^dulwich_version_string = '(.*)'$"
-  new_line: "dulwich_version_string = '$VERSION'"
-}
 update_version {
   path: "dulwich/__init__.py"
   match: "^__version__ = \((.*)\)$"
   new_line: "__version__ = $TUPLED_VERSION"
 }
+# Dulwich' CI builds wheels, which is really slow
+ci_timeout: 7200

+ 4 - 8
dulwich.egg-info/PKG-INFO

@@ -1,16 +1,15 @@
 Metadata-Version: 2.1
 Name: dulwich
-Version: 0.20.46
+Version: 0.20.50
 Summary: Python Git Library
 Home-page: https://www.dulwich.io/
 Author: Jelmer Vernooij
 Author-email: jelmer@jelmer.uk
 License: Apachev2 or later or GPLv2
-Project-URL: Bug Tracker, https://github.com/dulwich/dulwich/issues
 Project-URL: Repository, https://www.dulwich.io/code/
 Project-URL: GitHub, https://github.com/dulwich/dulwich
-Keywords: git vcs
-Platform: UNKNOWN
+Project-URL: Bug Tracker, https://github.com/dulwich/dulwich/issues
+Keywords: vcs,git
 Classifier: Development Status :: 4 - Beta
 Classifier: License :: OSI Approved :: Apache Software License
 Classifier: Programming Language :: Python :: 3.6
@@ -27,10 +26,9 @@ Classifier: Topic :: Software Development :: Version Control
 Requires-Python: >=3.6
 Provides-Extra: fastimport
 Provides-Extra: https
-Provides-Extra: paramiko
 Provides-Extra: pgp
+Provides-Extra: paramiko
 License-File: COPYING
-License-File: AUTHORS
 
 Dulwich
 =======
@@ -123,5 +121,3 @@ Supported versions of Python
 
 At the moment, Dulwich supports (and is tested on) CPython 3.6 and later and
 Pypy.
-
-

+ 6 - 2
dulwich.egg-info/SOURCES.txt

@@ -16,15 +16,17 @@ README.rst
 README.swift.rst
 SECURITY.md
 TODO
+disperse.conf
 dulwich.cfg
-releaser.conf
 requirements.txt
 setup.cfg
 setup.py
 status.yaml
 tox.ini
 .github/FUNDING.yml
-.github/workflows/pythonpackage.yml
+.github/workflows/disperse.yml
+.github/workflows/docs.yml
+.github/workflows/pythontest.yml
 .github/workflows/pythonwheels.yml
 bin/dul-receive-pack
 bin/dul-upload-pack
@@ -61,6 +63,7 @@ dulwich/bundle.py
 dulwich/cli.py
 dulwich/client.py
 dulwich/config.py
+dulwich/credentials.py
 dulwich/diff_tree.py
 dulwich/errors.py
 dulwich/fastexport.py
@@ -128,6 +131,7 @@ dulwich/tests/test_blackbox.py
 dulwich/tests/test_bundle.py
 dulwich/tests/test_client.py
 dulwich/tests/test_config.py
+dulwich/tests/test_credentials.py
 dulwich/tests/test_diff_tree.py
 dulwich/tests/test_fastexport.py
 dulwich/tests/test_file.py

+ 0 - 1
dulwich.egg-info/entry_points.txt

@@ -1,3 +1,2 @@
 [console_scripts]
 dulwich = dulwich.cli:main
-

+ 1 - 1
dulwich/__init__.py

@@ -22,4 +22,4 @@
 
 """Python implementation of the Git file formats and protocols."""
 
-__version__ = (0, 20, 46)
+__version__ = (0, 20, 50)

+ 1 - 1
dulwich/archive.py

@@ -110,7 +110,7 @@ def tar_stream(store, tree, mtime, prefix=b"", format=""):
 
             info = tarfile.TarInfo()
             # tarfile only works with ascii.
-            info.name = entry_abspath.decode("ascii")
+            info.name = entry_abspath.decode('utf-8', 'surrogateescape')
             info.size = blob.raw_length()
             info.mode = entry.mode
             info.mtime = mtime
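
For context on the decode change above: the 'surrogateescape' handler lets arbitrary path bytes survive the round-trip through tarfile's str-only API, where a strict ASCII decode would raise on any non-ASCII name. A self-contained sketch:

    # Path bytes that are not valid UTF-8 (e.g. latin-1 "café.txt").
    name = b"caf\xe9.txt"
    s = name.decode("utf-8", "surrogateescape")          # no exception raised
    assert s.encode("utf-8", "surrogateescape") == name  # lossless round-trip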

+ 1 - 1
dulwich/bundle.py

@@ -119,4 +119,4 @@ def write_bundle(f, bundle):
     for ref, obj_id in bundle.references.items():
         f.write(b"%s %s\n" % (obj_id, ref))
     f.write(b"\n")
-    write_pack_data(f, records=bundle.pack_data)
+    write_pack_data(f.write, records=bundle.pack_data)
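
write_pack_data (and write_pack_objects, per the dulwich/client.py and dulwich/object_store.py hunks further down) now takes a write callable rather than a file-like object, which is why the call above passes f.write. A hedged sketch of the new calling convention, assuming objects_iter is an iterable of (object, path) pairs as used elsewhere in dulwich:

    from dulwich.pack import write_pack_objects

    with open("out.pack", "wb") as f:
        write_pack_objects(f.write, objects_iter)  # pass f.write, not f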

+ 34 - 12
dulwich/cli.py

@@ -34,7 +34,7 @@ from getopt import getopt
 import argparse
 import optparse
 import signal
-from typing import Dict, Type
+from typing import Dict, Type, Optional
 
 from dulwich import porcelain
 from dulwich.client import get_transport_and_path
@@ -247,6 +247,10 @@ class cmd_clone(Command):
         parser.add_option(
             "--depth", dest="depth", type=int, help="Depth at which to fetch"
         )
+        parser.add_option(
+            "-b", "--branch", dest="branch", type=str,
+            help=("Check out branch instead of branch pointed to by remote "
+                  "HEAD"))
         options, args = parser.parse_args(args)
 
         if args == []:
@@ -259,7 +263,8 @@ class cmd_clone(Command):
         else:
             target = None
 
-        porcelain.clone(source, target, bare=options.bare, depth=options.depth)
+        porcelain.clone(source, target, bare=options.bare, depth=options.depth,
+                        branch=options.branch)
 
 
 class cmd_commit(Command):
@@ -321,14 +326,6 @@ class cmd_rev_list(Command):
         porcelain.rev_list(".", args)
 
 
-class cmd_submodule(Command):
-    def run(self, args):
-        parser = optparse.OptionParser()
-        options, args = parser.parse_args(args)
-        for path, sha in porcelain.submodule_list("."):
-            sys.stdout.write(' %s %s\n' % (sha, path))
-
-
 class cmd_tag(Command):
     def run(self, args):
         parser = optparse.OptionParser()
@@ -581,10 +578,11 @@ class cmd_remote_add(Command):
 
 class SuperCommand(Command):
 
-    subcommands = {}  # type: Dict[str, Type[Command]]
+    subcommands: Dict[str, Type[Command]] = {}
+    default_command: Optional[Type[Command]] = None
 
     def run(self, args):
-        if not args:
+        if not args and not self.default_command:
             print("Supported subcommands: %s" % ", ".join(self.subcommands.keys()))
             return False
         cmd = args[0]
@@ -603,6 +601,30 @@ class cmd_remote(SuperCommand):
     }
 
 
+class cmd_submodule_list(Command):
+    def run(self, argv):
+        parser = argparse.ArgumentParser()
+        parser.parse_args(argv)
+        for path, sha in porcelain.submodule_list("."):
+            sys.stdout.write(' %s %s\n' % (sha, path))
+
+
+class cmd_submodule_init(Command):
+    def run(self, argv):
+        parser = argparse.ArgumentParser()
+        parser.parse_args(argv)
+        porcelain.submodule_init(".")
+
+
+class cmd_submodule(SuperCommand):
+
+    subcommands = {
+        "init": cmd_submodule_init,
+    }
+
+    default_command = cmd_submodule_init
+
+
 class cmd_check_ignore(Command):
     def run(self, args):
         parser = optparse.OptionParser()
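
The cmd_submodule SuperCommand above dispatches to subcommands and, via the new default_command attribute, falls back to init when invoked bare. The underlying porcelain calls can also be used directly; a sketch assuming a repository in the current directory with a .gitmodules file:

    from dulwich import porcelain

    porcelain.submodule_init(".")   # register submodules from .gitmodules
    for path, sha in porcelain.submodule_list("."):
        print(sha, path)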

+ 17 - 16
dulwich/client.py

@@ -111,7 +111,6 @@ from dulwich.protocol import (
     SIDE_BAND_CHANNEL_FATAL,
     PktLineParser,
     Protocol,
-    ProtocolFile,
     TCP_GIT_PORT,
     ZERO_SHA,
     extract_capabilities,
@@ -511,8 +510,9 @@ def _read_side_band64k_data(pkt_seq, channel_callbacks):
         pkt = pkt[1:]
         try:
             cb = channel_callbacks[channel]
-        except KeyError:
-            raise AssertionError("Invalid sideband channel %d" % channel)
+        except KeyError as exc:
+            raise AssertionError(
+                "Invalid sideband channel %d" % channel) from exc
         else:
             if cb is not None:
                 cb(pkt)
@@ -1053,8 +1053,8 @@ class TraditionalGitClient(GitClient):
         with proto:
             try:
                 old_refs, server_capabilities = read_pkt_refs(proto.read_pkt_seq())
-            except HangupException:
-                raise _remote_error_from_stderr(stderr)
+            except HangupException as exc:
+                raise _remote_error_from_stderr(stderr) from exc
             (
                 negotiated_capabilities,
                 agent,
@@ -1147,8 +1147,8 @@ class TraditionalGitClient(GitClient):
         with proto:
             try:
                 refs, server_capabilities = read_pkt_refs(proto.read_pkt_seq())
-            except HangupException:
-                raise _remote_error_from_stderr(stderr)
+            except HangupException as exc:
+                raise _remote_error_from_stderr(stderr) from exc
             (
                 negotiated_capabilities,
                 symrefs,
@@ -1196,8 +1196,8 @@ class TraditionalGitClient(GitClient):
         with proto:
             try:
                 refs, _ = read_pkt_refs(proto.read_pkt_seq())
-            except HangupException:
-                raise _remote_error_from_stderr(stderr)
+            except HangupException as exc:
+                raise _remote_error_from_stderr(stderr) from exc
             proto.write_pkt_line(None)
             return refs
 
@@ -1225,8 +1225,8 @@ class TraditionalGitClient(GitClient):
             proto.write_pkt_line(None)
             try:
                 pkt = proto.read_pkt_line()
-            except HangupException:
-                raise _remote_error_from_stderr(stderr)
+            except HangupException as exc:
+                raise _remote_error_from_stderr(stderr) from exc
             if pkt == b"NACK\n" or pkt == b"NACK":
                 return
             elif pkt == b"ACK\n" or pkt == b"ACK":
@@ -1397,7 +1397,8 @@ class SubprocessGitClient(TraditionalGitClient):
 class LocalGitClient(GitClient):
     """Git Client that just uses a local Repo."""
 
-    def __init__(self, thin_packs=True, report_activity=None, config=None):
+    def __init__(self, thin_packs=True, report_activity=None,
+                 config: Optional[Config] = None):
         """Create a new LocalGitClient instance.
 
         Args:
@@ -1543,8 +1544,7 @@ class LocalGitClient(GitClient):
             # Note that the client still expects a 0-object pack in most cases.
             if objects_iter is None:
                 return FetchPackResult(None, symrefs, agent)
-            protocol = ProtocolFile(None, pack_data)
-            write_pack_objects(protocol, objects_iter)
+            write_pack_objects(pack_data, objects_iter)
             return FetchPackResult(r.get_refs(), symrefs, agent)
 
     def get_refs(self, path):
@@ -1949,8 +1949,9 @@ class AbstractHttpGitClient(GitClient):
                 # The first line should mention the service
                 try:
                     [pkt] = list(proto.read_pkt_seq())
-                except ValueError:
-                    raise GitProtocolError("unexpected number of packets received")
+                except ValueError as exc:
+                    raise GitProtocolError(
+                        "unexpected number of packets received") from exc
                 if pkt.rstrip(b"\n") != (b"# service=" + service):
                     raise GitProtocolError(
                         "unexpected first line %r from smart server" % pkt

+ 15 - 4
dulwich/config.py

@@ -411,15 +411,15 @@ def _parse_string(value: bytes) -> bytes:
             i += 1
             try:
                 v = _ESCAPE_TABLE[value[i]]
-            except IndexError:
+            except IndexError as exc:
                 raise ValueError(
                     "escape character in %r at %d before end of string" % (value, i)
-                )
-            except KeyError:
+                ) from exc
+            except KeyError as exc:
                 raise ValueError(
                     "escape character followed by unknown character "
                     "%s at %d in %r" % (value[i], i, value)
-                )
+                ) from exc
             if whitespace:
                 ret.extend(whitespace)
                 whitespace = bytearray()
@@ -447,6 +447,7 @@ def _parse_string(value: bytes) -> bytes:
 def _escape_value(value: bytes) -> bytes:
     """Escape a value."""
     value = value.replace(b"\\", b"\\\\")
+    value = value.replace(b"\r", b"\\r")
     value = value.replace(b"\n", b"\\n")
     value = value.replace(b"\t", b"\\t")
     value = value.replace(b'"', b'\\"')
@@ -565,6 +566,8 @@ class ConfigFile(ConfigDict):
                     raise ValueError("invalid variable name %r" % setting)
                 if value.endswith(b"\\\n"):
                     continuation = value[:-2]
+                elif value.endswith(b"\\\r\n"):
+                    continuation = value[:-3]
                 else:
                     continuation = None
                     value = _parse_string(value)
@@ -573,6 +576,8 @@ class ConfigFile(ConfigDict):
             else:  # continuation line
                 if line.endswith(b"\\\n"):
                     continuation += line[:-2]
+                elif line.endswith(b"\\\r\n"):
+                    continuation += line[:-3]
                 else:
                     continuation += line
                     value = _parse_string(continuation)
@@ -750,6 +755,12 @@ class StackedConfig(Config):
                     yield section
 
 
+def read_submodules(path: str) -> Iterator[Tuple[bytes, bytes, bytes]]:
+    """read a .gitmodules file."""
+    cfg = ConfigFile.from_path(path)
+    return parse_submodules(cfg)
+
+
 def parse_submodules(config: ConfigFile) -> Iterator[Tuple[bytes, bytes, bytes]]:
     """Parse a gitmodules GitConfig file, returning submodules.
 
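The two continuation branches added above make backslash continuations work when lines end in CRLF, matching the NEWS entry about \r\n line endings. A sketch, assuming the parsed value joins the same way it does with bare LF endings:

    from io import BytesIO
    from dulwich.config import ConfigFile

    # A config written with CRLF line endings and a backslash continuation.
    raw = b'[user]\r\n\tname = Jelmer \\\r\nVernooij\r\n'
    cf = ConfigFile.from_file(BytesIO(raw))
    print(cf.get((b"user",), b"name"))  # expected: b'Jelmer Vernooij'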

+ 3 - 2
dulwich/contrib/swift.py

@@ -170,8 +170,9 @@ def load_conf(path=None, file=None):
     if not path:
         try:
             confpath = os.environ["DULWICH_SWIFT_CFG"]
-        except KeyError:
-            raise Exception("You need to specify a configuration file")
+        except KeyError as exc:
+            raise Exception(
+                "You need to specify a configuration file") from exc
     else:
         confpath = path
     if not os.path.isfile(confpath):

+ 89 - 0
dulwich/credentials.py

@@ -0,0 +1,89 @@
+# credentials.py -- support for git credential helpers
+
+# Copyright (C) 2022 Daniele Trifirò <daniele@iterative.ai>
+#
+# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
+# General Public License as public by the Free Software Foundation; version 2.0
+# or (at your option) any later version. You can redistribute it and/or
+# modify it under the terms of either of these two licenses.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# You should have received a copy of the licenses; if not, see
+# <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
+# and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
+# License, Version 2.0.
+#
+
+"""Support for git credential helpers
+
+https://git-scm.com/book/en/v2/Git-Tools-Credential-Storage
+
+"""
+import sys
+from typing import Iterator, Optional
+from urllib.parse import ParseResult, urlparse
+
+from dulwich.config import ConfigDict, SectionLike
+
+
+def match_urls(url: ParseResult, url_prefix: ParseResult) -> bool:
+    base_match = (
+        url.scheme == url_prefix.scheme
+        and url.hostname == url_prefix.hostname
+        and url.port == url_prefix.port
+    )
+    user_match = url.username == url_prefix.username if url_prefix.username else True
+    path_match = url.path.rstrip("/").startswith(url_prefix.path.rstrip())
+    return base_match and user_match and path_match
+
+
+def match_partial_url(valid_url: ParseResult, partial_url: str) -> bool:
+    """matches a parsed url with a partial url (no scheme/netloc)"""
+    if "://" not in partial_url:
+        parsed = urlparse("scheme://" + partial_url)
+    else:
+        parsed = urlparse(partial_url)
+        if valid_url.scheme != parsed.scheme:
+            return False
+
+    if any(
+        (
+            (parsed.hostname and valid_url.hostname != parsed.hostname),
+            (parsed.username and valid_url.username != parsed.username),
+            (parsed.port and valid_url.port != parsed.port),
+            (parsed.path and parsed.path.rstrip("/") != valid_url.path.rstrip("/")),
+        ),
+    ):
+        return False
+
+    return True
+
+
+def urlmatch_credential_sections(
+    config: ConfigDict, url: Optional[str]
+) -> Iterator[SectionLike]:
+    """Returns credential sections from the config which match the given URL"""
+    encoding = config.encoding or sys.getdefaultencoding()
+    parsed_url = urlparse(url or "")
+    for config_section in config.sections():
+        if config_section[0] != b"credential":
+            continue
+
+        if len(config_section) < 2:
+            yield config_section
+            continue
+
+        config_url = config_section[1].decode(encoding)
+        parsed_config_url = urlparse(config_url)
+        if parsed_config_url.scheme and parsed_config_url.netloc:
+            is_match = match_urls(parsed_url, parsed_config_url)
+        else:
+            is_match = match_partial_url(parsed_url, config_url)
+
+        if is_match:
+            yield config_section

+ 2 - 2
dulwich/file.py

@@ -152,8 +152,8 @@ class _GitFile(object):
                 os.O_RDWR | os.O_CREAT | os.O_EXCL | getattr(os, "O_BINARY", 0),
                 mask,
             )
-        except FileExistsError:
-            raise FileLocked(filename, self._lockfilename)
+        except FileExistsError as exc:
+            raise FileLocked(filename, self._lockfilename) from exc
         self._file = os.fdopen(fd, mode, bufsize)
         self._closed = False
 

+ 1 - 1
dulwich/hooks.py

@@ -200,4 +200,4 @@ class PostReceiveShellHook(ShellHook):
                 raise HookError(err_msg.decode('utf-8', 'backslashreplace'))
             return out_data
         except OSError as err:
-            raise HookError(repr(err))
+            raise HookError(repr(err)) from err

+ 101 - 47
dulwich/index.py

@@ -36,6 +36,7 @@ from typing import (
     Iterable,
     Iterable,
     Iterator,
     Iterator,
     Tuple,
     Tuple,
+    Union,
 )
 )
 
 
 if TYPE_CHECKING:
 if TYPE_CHECKING:
@@ -49,6 +50,7 @@ from dulwich.objects import (
     Tree,
     Tree,
     hex_to_sha,
     hex_to_sha,
     sha_to_hex,
     sha_to_hex,
+    ObjectID,
 )
 )
 from dulwich.pack import (
 from dulwich.pack import (
     SHA1Reader,
     SHA1Reader,
@@ -95,7 +97,7 @@ EXTENDED_FLAG_INTEND_TO_ADD = 0x2000
 DEFAULT_VERSION = 2
 DEFAULT_VERSION = 2
 
 
 
 
-def pathsplit(path):
+def pathsplit(path: bytes) -> Tuple[bytes, bytes]:
     """Split a /-delimited path into a directory part and a basename.
     """Split a /-delimited path into a directory part and a basename.
 
 
     Args:
     Args:
@@ -194,7 +196,7 @@ def read_cache_entry(f, version: int) -> Tuple[str, IndexEntry]:
         ))
         ))
 
 
 
 
-def write_cache_entry(f, name, entry, version):
+def write_cache_entry(f, name: bytes, entry: IndexEntry, version: int) -> None:
     """Write an index entry to a file.
     """Write an index entry to a file.
 
 
     Args:
     Args:
@@ -230,18 +232,26 @@ def write_cache_entry(f, name, entry, version):
         f.write(b"\0" * ((beginoffset + real_size) - f.tell()))
         f.write(b"\0" * ((beginoffset + real_size) - f.tell()))
 
 
 
 
+class UnsupportedIndexFormat(Exception):
+    """An unsupported index format was encountered."""
+
+    def __init__(self, version):
+        self.index_format_version = version
+
+
 def read_index(f: BinaryIO):
 def read_index(f: BinaryIO):
     """Read an index file, yielding the individual entries."""
     """Read an index file, yielding the individual entries."""
     header = f.read(4)
     header = f.read(4)
     if header != b"DIRC":
     if header != b"DIRC":
         raise AssertionError("Invalid index file header: %r" % header)
         raise AssertionError("Invalid index file header: %r" % header)
     (version, num_entries) = struct.unpack(b">LL", f.read(4 * 2))
     (version, num_entries) = struct.unpack(b">LL", f.read(4 * 2))
-    assert version in (1, 2, 3), "index version is %r" % version
+    if version not in (1, 2, 3):
+        raise UnsupportedIndexFormat(version)
     for i in range(num_entries):
     for i in range(num_entries):
         yield read_cache_entry(f, version)
         yield read_cache_entry(f, version)
 
 
 
 
-def read_index_dict(f):
+def read_index_dict(f) -> Dict[bytes, IndexEntry]:
     """Read an index file and return it as a dictionary.
     """Read an index file and return it as a dictionary.
 
 
     Args:
     Args:
@@ -306,17 +316,19 @@ def cleanup_mode(mode: int) -> int:
 class Index(object):
 class Index(object):
     """A Git Index file."""
     """A Git Index file."""
 
 
-    def __init__(self, filename):
-        """Open an index file.
+    def __init__(self, filename: Union[bytes, str], read=True):
+        """Create an index object associated with the given filename.
 
 
         Args:
         Args:
           filename: Path to the index file
           filename: Path to the index file
+          read: Whether to initialize the index from the given file, should it exist.
         """
         """
         self._filename = filename
         self._filename = filename
         # TODO(jelmer): Store the version returned by read_index
         # TODO(jelmer): Store the version returned by read_index
         self._version = None
         self._version = None
         self.clear()
         self.clear()
-        self.read()
+        if read:
+            self.read()
 
 
     @property
     @property
     def path(self):
     def path(self):
@@ -383,27 +395,28 @@ class Index(object):
         """Remove all contents from this index."""
         """Remove all contents from this index."""
         self._byname = {}
         self._byname = {}
 
 
-    def __setitem__(self, name, x):
+    def __setitem__(self, name: bytes, x: IndexEntry):
         assert isinstance(name, bytes)
         assert isinstance(name, bytes)
         assert len(x) == len(IndexEntry._fields)
         assert len(x) == len(IndexEntry._fields)
         # Remove the old entry if any
         # Remove the old entry if any
         self._byname[name] = IndexEntry(*x)
         self._byname[name] = IndexEntry(*x)
 
 
-    def __delitem__(self, name):
+    def __delitem__(self, name: bytes):
         assert isinstance(name, bytes)
         assert isinstance(name, bytes)
         del self._byname[name]
         del self._byname[name]
 
 
-    def iteritems(self):
+    def iteritems(self) -> Iterator[Tuple[bytes, IndexEntry]]:
         return self._byname.items()
         return self._byname.items()
 
 
-    def items(self):
+    def items(self) -> Iterator[Tuple[bytes, IndexEntry]]:
         return self._byname.items()
         return self._byname.items()
 
 
-    def update(self, entries):
+    def update(self, entries: Dict[bytes, IndexEntry]):
         for name, value in entries.items():
         for name, value in entries.items():
             self[name] = value
             self[name] = value
 
 
-    def changes_from_tree(self, object_store, tree, want_unchanged=False):
+    def changes_from_tree(
+            self, object_store, tree: ObjectID, want_unchanged: bool = False):
         """Find the differences between the contents of this index and a tree.
         """Find the differences between the contents of this index and a tree.
 
 
         Args:
         Args:
@@ -573,8 +586,36 @@ def index_entry_from_stat(
     )
     )
 
 
 
 
+if sys.platform == 'win32':
+    # On Windows, creating symlinks either requires administrator privileges
+    # or developer mode. Raise a more helpful error when we're unable to
+    # create symlinks
+
+    # https://github.com/jelmer/dulwich/issues/1005
+
+    class WindowsSymlinkPermissionError(PermissionError):
+
+        def __init__(self, errno, msg, filename):
+            super(PermissionError, self).__init__(
+                errno, "Unable to create symlink; "
+                "do you have developer mode enabled? %s" % msg,
+                filename)
+
+    def symlink(src, dst, target_is_directory=False, *, dir_fd=None):
+        try:
+            return os.symlink(
+                src, dst, target_is_directory=target_is_directory,
+                dir_fd=dir_fd)
+        except PermissionError as e:
+            raise WindowsSymlinkPermissionError(
+                e.errno, e.strerror, e.filename) from e
+else:
+    symlink = os.symlink
+
+
 def build_file_from_blob(
 def build_file_from_blob(
-    blob, mode, target_path, honor_filemode=True, tree_encoding="utf-8"
+        blob: Blob, mode: int, target_path: bytes, *, honor_filemode=True,
+        tree_encoding="utf-8", symlink_fn=None
 ):
 ):
     """Build a file or symlink on disk based on a Git object.
     """Build a file or symlink on disk based on a Git object.
 
 
@@ -584,6 +625,7 @@ def build_file_from_blob(
       target_path: Path to write to
       target_path: Path to write to
       honor_filemode: An optional flag to honor core.filemode setting in
       honor_filemode: An optional flag to honor core.filemode setting in
         config file, default is core.filemode=True, change executable bit
         config file, default is core.filemode=True, change executable bit
+      symlink: Function to use for creating symlinks
     Returns: stat object for the file
     Returns: stat object for the file
     """
     """
     try:
     try:
@@ -592,14 +634,13 @@ def build_file_from_blob(
         oldstat = None
         oldstat = None
     contents = blob.as_raw_string()
     contents = blob.as_raw_string()
     if stat.S_ISLNK(mode):
     if stat.S_ISLNK(mode):
-        # FIXME: This will fail on Windows. What should we do instead?
         if oldstat:
         if oldstat:
             os.unlink(target_path)
             os.unlink(target_path)
         if sys.platform == "win32":
         if sys.platform == "win32":
             # os.readlink on Python3 on Windows requires a unicode string.
             # os.readlink on Python3 on Windows requires a unicode string.
-            contents = contents.decode(tree_encoding)
-            target_path = target_path.decode(tree_encoding)
-        os.symlink(contents, target_path)
+            contents = contents.decode(tree_encoding)  # type: ignore
+            target_path = target_path.decode(tree_encoding)  # type: ignore
+        (symlink_fn or symlink)(contents, target_path)
     else:
     else:
         if oldstat is not None and oldstat.st_size == len(contents):
         if oldstat is not None and oldstat.st_size == len(contents):
             with open(target_path, "rb") as f:
             with open(target_path, "rb") as f:
@@ -619,11 +660,11 @@ def build_file_from_blob(
 INVALID_DOTNAMES = (b".git", b".", b"..", b"")
 INVALID_DOTNAMES = (b".git", b".", b"..", b"")
 
 
 
 
-def validate_path_element_default(element):
+def validate_path_element_default(element: bytes) -> bool:
     return element.lower() not in INVALID_DOTNAMES
     return element.lower() not in INVALID_DOTNAMES
 
 
 
 
-def validate_path_element_ntfs(element):
+def validate_path_element_ntfs(element: bytes) -> bool:
     stripped = element.rstrip(b". ").lower()
     stripped = element.rstrip(b". ").lower()
     if stripped in INVALID_DOTNAMES:
     if stripped in INVALID_DOTNAMES:
         return False
         return False
@@ -632,7 +673,8 @@ def validate_path_element_ntfs(element):
     return True
     return True
 
 
 
 
-def validate_path(path, element_validator=validate_path_element_default):
+def validate_path(path: bytes,
+                  element_validator=validate_path_element_default) -> bool:
     """Default path validator that just checks for .git/."""
     """Default path validator that just checks for .git/."""
     parts = path.split(b"/")
     parts = path.split(b"/")
     for p in parts:
     for p in parts:
@@ -643,12 +685,13 @@ def validate_path(path, element_validator=validate_path_element_default):
 
 
 
 
 def build_index_from_tree(
 def build_index_from_tree(
-    root_path,
-    index_path,
-    object_store,
-    tree_id,
-    honor_filemode=True,
+    root_path: Union[str, bytes],
+    index_path: Union[str, bytes],
+    object_store: "BaseObjectStore",
+    tree_id: bytes,
+    honor_filemode: bool = True,
     validate_path_element=validate_path_element_default,
     validate_path_element=validate_path_element_default,
+    symlink_fn=None
 ):
 ):
     """Generate and materialize index from a tree
     """Generate and materialize index from a tree
 
 
@@ -665,8 +708,7 @@ def build_index_from_tree(
     Note: existing index is wiped and contents are not merged
     Note: existing index is wiped and contents are not merged
         in a working dir. Suitable only for fresh clones.
         in a working dir. Suitable only for fresh clones.
     """
     """
-
-    index = Index(index_path)
+    index = Index(index_path, read=False)
     if not isinstance(root_path, bytes):
     if not isinstance(root_path, bytes):
         root_path = os.fsencode(root_path)
         root_path = os.fsencode(root_path)
 
 
@@ -687,7 +729,9 @@ def build_index_from_tree(
         else:
         else:
             obj = object_store[entry.sha]
             obj = object_store[entry.sha]
             st = build_file_from_blob(
             st = build_file_from_blob(
-                obj, entry.mode, full_path, honor_filemode=honor_filemode
+                obj, entry.mode, full_path,
+                honor_filemode=honor_filemode,
+                symlink_fn=symlink_fn,
             )
             )
 
 
         # Add file to index
         # Add file to index
@@ -713,7 +757,8 @@ def build_index_from_tree(
     index.write()
     index.write()
 
 
 
 
-def blob_from_path_and_mode(fs_path, mode, tree_encoding="utf-8"):
+def blob_from_path_and_mode(fs_path: bytes, mode: int,
+                            tree_encoding="utf-8"):
     """Create a blob from a path and a stat object.
     """Create a blob from a path and a stat object.
 
 
     Args:
     Args:
@@ -726,8 +771,7 @@ def blob_from_path_and_mode(fs_path, mode, tree_encoding="utf-8"):
     if stat.S_ISLNK(mode):
     if stat.S_ISLNK(mode):
         if sys.platform == "win32":
         if sys.platform == "win32":
             # os.readlink on Python3 on Windows requires a unicode string.
             # os.readlink on Python3 on Windows requires a unicode string.
-            fs_path = os.fsdecode(fs_path)
-            blob.data = os.readlink(fs_path).encode(tree_encoding)
+            blob.data = os.readlink(os.fsdecode(fs_path)).encode(tree_encoding)
         else:
         else:
             blob.data = os.readlink(fs_path)
             blob.data = os.readlink(fs_path)
     else:
     else:
@@ -736,7 +780,7 @@ def blob_from_path_and_mode(fs_path, mode, tree_encoding="utf-8"):
     return blob
     return blob
 
 
 
 
-def blob_from_path_and_stat(fs_path, st, tree_encoding="utf-8"):
+def blob_from_path_and_stat(fs_path: bytes, st, tree_encoding="utf-8"):
     """Create a blob from a path and a stat object.
     """Create a blob from a path and a stat object.
 
 
     Args:
     Args:
@@ -747,7 +791,7 @@ def blob_from_path_and_stat(fs_path, st, tree_encoding="utf-8"):
     return blob_from_path_and_mode(fs_path, st.st_mode, tree_encoding)
     return blob_from_path_and_mode(fs_path, st.st_mode, tree_encoding)
 
 
 
 
-def read_submodule_head(path):
+def read_submodule_head(path: Union[str, bytes]) -> Optional[bytes]:
     """Read the head commit of a submodule.
     """Read the head commit of a submodule.
 
 
     Args:
     Args:
@@ -771,7 +815,7 @@ def read_submodule_head(path):
         return None
         return None
 
 
 
 
-def _has_directory_changed(tree_path, entry):
+def _has_directory_changed(tree_path: bytes, entry):
     """Check if a directory has changed after getting an error.
     """Check if a directory has changed after getting an error.
 
 
     When handling an error trying to create a blob from a path, call this
     When handling an error trying to create a blob from a path, call this
@@ -796,7 +840,9 @@ def _has_directory_changed(tree_path, entry):
     return False
     return False
 
 
 
 
-def get_unstaged_changes(index: Index, root_path, filter_blob_callback=None):
+def get_unstaged_changes(
+        index: Index, root_path: Union[str, bytes],
+        filter_blob_callback=None):
     """Walk through an index and check for differences against working tree.
     """Walk through an index and check for differences against working tree.
 
 
     Args:
     Args:
@@ -836,7 +882,7 @@ def get_unstaged_changes(index: Index, root_path, filter_blob_callback=None):
 os_sep_bytes = os.sep.encode("ascii")
 os_sep_bytes = os.sep.encode("ascii")
 
 
 
 
-def _tree_to_fs_path(root_path, tree_path: bytes):
+def _tree_to_fs_path(root_path: bytes, tree_path: bytes):
     """Convert a git tree path to a file system path.
     """Convert a git tree path to a file system path.
 
 
     Args:
     Args:
@@ -853,7 +899,7 @@ def _tree_to_fs_path(root_path, tree_path: bytes):
     return os.path.join(root_path, sep_corrected_path)
     return os.path.join(root_path, sep_corrected_path)
 
 
 
 
-def _fs_to_tree_path(fs_path):
+def _fs_to_tree_path(fs_path: Union[str, bytes]) -> bytes:
     """Convert a file system path to a git tree path.
     """Convert a file system path to a git tree path.
 
 
     Args:
     Args:
@@ -872,7 +918,7 @@ def _fs_to_tree_path(fs_path):
     return tree_path
     return tree_path
 
 
 
 
-def index_entry_from_directory(st, path):
+def index_entry_from_directory(st, path: bytes) -> Optional[IndexEntry]:
     if os.path.exists(os.path.join(path, b".git")):
     if os.path.exists(os.path.join(path, b".git")):
         head = read_submodule_head(path)
         head = read_submodule_head(path)
         if head is None:
         if head is None:
@@ -881,7 +927,9 @@ def index_entry_from_directory(st, path):
     return None
     return None
 
 
 
 
-def index_entry_from_path(path, object_store=None):
+def index_entry_from_path(
+        path: bytes, object_store: Optional["BaseObjectStore"] = None
+) -> Optional[IndexEntry]:
     """Create an index from a filesystem path.
     """Create an index from a filesystem path.
 
 
     This returns an index value for files, symlinks
     This returns an index value for files, symlinks
@@ -909,8 +957,9 @@ def index_entry_from_path(path, object_store=None):
 
 
 
 
 def iter_fresh_entries(
 def iter_fresh_entries(
-    paths, root_path, object_store: Optional["BaseObjectStore"] = None
-):
+    paths: Iterable[bytes], root_path: bytes,
+    object_store: Optional["BaseObjectStore"] = None
+) -> Iterator[Tuple[bytes, Optional[IndexEntry]]]:
     """Iterate over current versions of index entries on disk.
     """Iterate over current versions of index entries on disk.
 
 
     Args:
     Args:
@@ -928,7 +977,10 @@ def iter_fresh_entries(
         yield path, entry
         yield path, entry
 
 
 
 
-def iter_fresh_objects(paths, root_path, include_deleted=False, object_store=None):
+def iter_fresh_objects(
+        paths: Iterable[bytes], root_path: bytes, include_deleted=False,
+        object_store=None) -> Iterator[
+            Tuple[bytes, Optional[bytes], Optional[int]]]:
     """Iterate over versions of objects on disk referenced by index.
     """Iterate over versions of objects on disk referenced by index.
 
 
     Args:
     Args:
@@ -938,7 +990,8 @@ def iter_fresh_objects(paths, root_path, include_deleted=False, object_store=Non
       object_store: Optional object store to report new items to
       object_store: Optional object store to report new items to
     Returns: Iterator over path, sha, mode
     Returns: Iterator over path, sha, mode
     """
     """
-    for path, entry in iter_fresh_entries(paths, root_path, object_store=object_store):
+    for path, entry in iter_fresh_entries(
+            paths, root_path, object_store=object_store):
         if entry is None:
         if entry is None:
             if include_deleted:
             if include_deleted:
                 yield path, None, None
                 yield path, None, None
@@ -947,7 +1000,7 @@ def iter_fresh_objects(paths, root_path, include_deleted=False, object_store=Non
             yield path, entry.sha, cleanup_mode(entry.mode)
             yield path, entry.sha, cleanup_mode(entry.mode)
 
 
 
 
-def refresh_index(index, root_path):
+def refresh_index(index: Index, root_path: bytes):
     """Refresh the contents of an index.
     """Refresh the contents of an index.
 
 
     This is the equivalent to running 'git commit -a'.
     This is the equivalent to running 'git commit -a'.
@@ -957,7 +1010,8 @@ def refresh_index(index, root_path):
       root_path: Root filesystem path
     """
     for path, entry in iter_fresh_entries(index, root_path):
-        index[path] = path
+        if entry:
+            index[path] = entry


 class locked_index(object):
@@ -965,7 +1019,7 @@ class locked_index(object):

     Works as a context manager.
     """
-    def __init__(self, path):
+    def __init__(self, path: Union[bytes, str]):
         self._path = path

     def __enter__(self):

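A minimal usage sketch of the refresh_index() fix above (the old code assigned the path back into the index; it now stores the fresh entry and skips deleted files). Repo/API usage assumes dulwich 0.20.x:

    from dulwich.repo import Repo
    from dulwich.index import refresh_index

    repo = Repo(".")
    index = repo.open_index()
    refresh_index(index, repo.path.encode("utf-8"))
    index.write()
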
+ 2 - 2
dulwich/lfs.py

@@ -51,8 +51,8 @@ class LFSStore(object):
         """Open an object by sha."""
         try:
             return open(self._sha_path(sha), "rb")
-        except FileNotFoundError:
-            raise KeyError(sha) from exc
+        except FileNotFoundError as exc:
+            raise KeyError(sha) from exc

     def write_object(self, chunks):
         """Write an object.

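The pattern adopted here (and throughout this release) chains the original exception; a small sketch of what "raise ... from exc" preserves:

    try:
        open("/no/such/object", "rb")
    except FileNotFoundError as exc:
        try:
            raise KeyError("sha") from exc
        except KeyError as err:
            assert err.__cause__ is exc
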
+ 14 - 13
dulwich/object_store.py

@@ -38,6 +38,7 @@ from dulwich.errors import (
 )
 from dulwich.file import GitFile
 from dulwich.objects import (
+    ObjectID,
     Commit,
     ShaFile,
     Tag,
@@ -67,7 +68,7 @@ from dulwich.pack import (
     PackStreamCopier,
 )
 from dulwich.protocol import DEPTH_INFINITE
-from dulwich.refs import ANNOTATED_TAG_SUFFIX
+from dulwich.refs import ANNOTATED_TAG_SUFFIX, Ref

 INFODIR = "info"
 PACKDIR = "pack"
@@ -83,9 +84,9 @@ class BaseObjectStore(object):

     def determine_wants_all(
         self,
-        refs: Dict[bytes, bytes],
+        refs: Dict[Ref, ObjectID],
         depth: Optional[int] = None
-    ) -> List[bytes]:
+    ) -> List[ObjectID]:
         def _want_deepen(sha):
             if not depth:
                 return False
@@ -139,7 +140,7 @@ class BaseObjectStore(object):
         """
         raise NotImplementedError(self.get_raw)

-    def __getitem__(self, sha):
+    def __getitem__(self, sha: ObjectID):
         """Obtain an object by SHA1."""
         type_num, uncomp = self.get_raw(sha)
         return ShaFile.from_raw_string(type_num, uncomp, sha=sha)
@@ -179,7 +180,7 @@ class BaseObjectStore(object):
         f, commit, abort = self.add_pack()
         try:
             write_pack_data(
-                f,
+                f.write,
                 count,
                 pack_data,
                 progress,
@@ -780,7 +781,7 @@ class DiskObjectStore(PackBasedObjectStore):

         # Update the header with the new number of objects.
         f.seek(0)
-        write_pack_header(f, len(entries) + len(indexer.ext_refs()))
+        write_pack_header(f.write, len(entries) + len(indexer.ext_refs()))

         # Must flush before reading (http://bugs.python.org/issue3207)
         f.flush()
@@ -797,7 +798,7 @@ class DiskObjectStore(PackBasedObjectStore):
             type_num, data = self.get_raw(ext_sha)
             offset = f.tell()
             crc32 = write_pack_object(
-                f,
+                f.write,
                 type_num,
                 data,
                 sha=new_sha,
@@ -984,7 +985,7 @@ class MemoryObjectStore(BaseObjectStore):
         """List with pack objects."""
         return []

-    def get_raw(self, name):
+    def get_raw(self, name: ObjectID):
         """Obtain the raw text for an object.

         Args:
@@ -994,10 +995,10 @@ class MemoryObjectStore(BaseObjectStore):
         obj = self[self._to_hexsha(name)]
         return obj.type_num, obj.as_raw_string()

-    def __getitem__(self, name):
+    def __getitem__(self, name: ObjectID):
         return self._data[self._to_hexsha(name)].copy()

-    def __delitem__(self, name):
+    def __delitem__(self, name: ObjectID):
         """Delete an object from this store, for testing only."""
         del self._data[self._to_hexsha(name)]

@@ -1047,7 +1048,7 @@ class MemoryObjectStore(BaseObjectStore):

         # Update the header with the new number of objects.
         f.seek(0)
-        write_pack_header(f, len(entries) + len(indexer.ext_refs()))
+        write_pack_header(f.write, len(entries) + len(indexer.ext_refs()))

         # Rescan the rest of the pack, computing the SHA with the new header.
         new_sha = compute_file_sha(f, end_ofs=-20)
@@ -1056,7 +1057,7 @@ class MemoryObjectStore(BaseObjectStore):
         for ext_sha in indexer.ext_refs():
             assert len(ext_sha) == 20
             type_num, data = self.get_raw(ext_sha)
-            write_pack_object(f, type_num, data, sha=new_sha)
+            write_pack_object(f.write, type_num, data, sha=new_sha)
         pack_sha = new_sha.digest()
         f.write(pack_sha)

@@ -1333,7 +1334,7 @@ class MissingObjectFinder(object):
         if not leaf:
             o = self.object_store[sha]
             if isinstance(o, Commit):
-                self.add_todo([(o.tree, "", False)])
+                self.add_todo([(o.tree, b"", False)])
             elif isinstance(o, Tree):
                 self.add_todo(
                     [

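A hedged sketch of the write-callable convention adopted across this file: write_pack_objects() and friends now receive a write function such as f.write rather than a file object (dulwich 0.20.x API):

    import io
    from dulwich.objects import Blob
    from dulwich.pack import write_pack_objects

    blob = Blob.from_string(b"hello world")
    buf = io.BytesIO()
    entries, checksum = write_pack_objects(buf.write, [(blob, None)])
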
+ 75 - 65
dulwich/objects.py

@@ -33,6 +33,8 @@ from typing import (
     Iterable,
     Union,
     Type,
+    Iterator,
+    List,
 )
 import zlib
 from hashlib import sha1
@@ -75,6 +77,9 @@ MAX_TIME = 9223372036854775807  # (2**63) - 1 - signed long int max
 BEGIN_PGP_SIGNATURE = b"-----BEGIN PGP SIGNATURE-----"


+ObjectID = bytes
+
+
 class EmptyFileException(FileFormatException):
     """An unexpectedly empty file was encountered."""

@@ -111,7 +116,7 @@ def hex_to_sha(hex):
     except TypeError as exc:
         if not isinstance(hex, bytes):
             raise
-        raise ValueError(exc.args[0])
+        raise ValueError(exc.args[0]) from exc


 def valid_hexsha(hex):
@@ -153,7 +158,10 @@ def filename_to_hex(filename):

 def object_header(num_type: int, length: int) -> bytes:
     """Return an object header for the given numeric type and text length."""
-    return object_class(num_type).type_name + b" " + str(length).encode("ascii") + b"\0"
+    cls = object_class(num_type)
+    if cls is None:
+        raise AssertionError("unsupported class type num: %d" % num_type)
+    return cls.type_name + b" " + str(length).encode("ascii") + b"\0"


 def serializable_property(name: str, docstring: Optional[str] = None):
@@ -169,7 +177,7 @@ def serializable_property(name: str, docstring: Optional[str] = None):
     return property(get, set, doc=docstring)


-def object_class(type):
+def object_class(type: Union[bytes, int]) -> Optional[Type["ShaFile"]]:
     """Get the object class corresponding to the given type.

     Args:
@@ -193,7 +201,7 @@ def check_hexsha(hex, error_msg):
         raise ObjectFormatException("%s %s" % (error_msg, hex))


-def check_identity(identity, error_msg):
+def check_identity(identity: bytes, error_msg: str) -> None:
     """Check if the specified identity is valid.

     This will raise an exception if the identity is not valid.
@@ -202,16 +210,16 @@ def check_identity(identity, error_msg):
       identity: Identity string
       error_msg: Error message to use in exception
     """
-    email_start = identity.find(b"<")
-    email_end = identity.find(b">")
-    if (
-        email_start < 0
-        or email_end < 0
-        or email_end <= email_start
-        or identity.find(b"<", email_start + 1) >= 0
-        or identity.find(b">", email_end + 1) >= 0
-        or not identity.endswith(b">")
-    ):
+    email_start = identity.find(b'<')
+    email_end = identity.find(b'>')
+    if not all([
+        email_start >= 1,
+        identity[email_start - 1] == b' '[0],
+        identity.find(b'<', email_start + 1) == -1,
+        email_end == len(identity) - 1,
+        b'\0' not in identity,
+        b'\n' not in identity,
+    ]):
         raise ObjectFormatException(error_msg)


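An illustrative sketch of the stricter check_identity() above: the email must be bracketed exactly once, preceded by a space, must terminate the string, and the identity may not contain NUL or newline bytes:

    from dulwich.errors import ObjectFormatException
    from dulwich.objects import check_identity

    check_identity(b"Jane Doe <jane@example.com>", "bad identity")  # passes
    try:
        check_identity(b"Jane <j@e.org> trailing", "bad identity")
    except ObjectFormatException:
        pass  # email no longer terminates the identity
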
@@ -261,11 +269,13 @@ class ShaFile(object):

     __slots__ = ("_chunked_text", "_sha", "_needs_serialization")

-    type_name = None  # type: bytes
-    type_num = None  # type: int
+    _needs_serialization: bool
+    type_name: bytes
+    type_num: int
+    _chunked_text: Optional[List[bytes]]

     @staticmethod
-    def _parse_legacy_object_header(magic, f):
+    def _parse_legacy_object_header(magic, f) -> "ShaFile":
         """Parse a legacy object, creating it but not reading the file."""
         bufsize = 1024
         decomp = zlib.decompressobj()
@@ -282,14 +292,15 @@ class ShaFile(object):
         type_name, size = header.split(b" ", 1)
         try:
             int(size)  # sanity check
-        except ValueError as e:
-            raise ObjectFormatException("Object size not an integer: %s" % e)
+        except ValueError as exc:
+            raise ObjectFormatException(
+                "Object size not an integer: %s" % exc) from exc
         obj_class = object_class(type_name)
         if not obj_class:
-            raise ObjectFormatException("Not a known type: %s" % type_name)
+            raise ObjectFormatException("Not a known type: %s" % type_name.decode('ascii'))
         return obj_class()

-    def _parse_legacy_object(self, map):
+    def _parse_legacy_object(self, map) -> None:
         """Parse a legacy object, setting the raw string."""
         text = _decompress(map)
         header_end = text.find(b"\0")
@@ -297,7 +308,8 @@ class ShaFile(object):
             raise ObjectFormatException("Invalid object header, no \\0")
         self.set_raw_string(text[header_end + 1 :])

-    def as_legacy_object_chunks(self, compression_level=-1):
+    def as_legacy_object_chunks(
+            self, compression_level: int = -1) -> Iterator[bytes]:
         """Return chunks representing the object in the experimental format.

         Returns: List of strings
@@ -308,13 +320,13 @@ class ShaFile(object):
             yield compobj.compress(chunk)
         yield compobj.flush()

-    def as_legacy_object(self, compression_level=-1):
+    def as_legacy_object(self, compression_level: int = -1) -> bytes:
         """Return string representing the object in the experimental format."""
         return b"".join(
             self.as_legacy_object_chunks(compression_level=compression_level)
         )

-    def as_raw_chunks(self):
+    def as_raw_chunks(self) -> List[bytes]:
         """Return chunks with serialization of the object.

         Returns: List of strings, not necessarily one per line
@@ -323,16 +335,16 @@ class ShaFile(object):
             self._sha = None
             self._chunked_text = self._serialize()
             self._needs_serialization = False
-        return self._chunked_text
+        return self._chunked_text  # type: ignore

-    def as_raw_string(self):
+    def as_raw_string(self) -> bytes:
         """Return raw string with serialization of the object.

         Returns: String object
         """
         return b"".join(self.as_raw_chunks())

-    def __bytes__(self):
+    def __bytes__(self) -> bytes:
         """Return raw string serialization of this object."""
         return self.as_raw_string()

@@ -340,24 +352,27 @@ class ShaFile(object):
         """Return unique hash for this object."""
         return hash(self.id)

-    def as_pretty_string(self):
+    def as_pretty_string(self) -> bytes:
         """Return a string representing this object, fit for display."""
         return self.as_raw_string()

-    def set_raw_string(self, text, sha=None):
+    def set_raw_string(
+            self, text: bytes, sha: Optional[ObjectID] = None) -> None:
         """Set the contents of this object from a serialized string."""
         if not isinstance(text, bytes):
             raise TypeError("Expected bytes for text, got %r" % text)
         self.set_raw_chunks([text], sha)

-    def set_raw_chunks(self, chunks, sha=None):
+    def set_raw_chunks(
+            self, chunks: List[bytes],
+            sha: Optional[ObjectID] = None) -> None:
         """Set the contents of this object from a list of chunks."""
         self._chunked_text = chunks
         self._deserialize(chunks)
         if sha is None:
             self._sha = None
         else:
-            self._sha = FixedSha(sha)
+            self._sha = FixedSha(sha)  # type: ignore
         self._needs_serialization = False

     @staticmethod
@@ -369,7 +384,7 @@ class ShaFile(object):
             raise ObjectFormatException("Not a known type %d" % num_type)
         return obj_class()

-    def _parse_object(self, map):
+    def _parse_object(self, map) -> None:
         """Parse a new style object, setting self._text."""
         # skip type and size; type must have already been determined, and
         # we trust zlib to fail if it's otherwise corrupted
@@ -382,7 +397,7 @@ class ShaFile(object):
         self.set_raw_string(_decompress(raw))

     @classmethod
-    def _is_legacy_object(cls, magic):
+    def _is_legacy_object(cls, magic: bytes) -> bool:
         b0 = ord(magic[0:1])
         b1 = ord(magic[1:2])
         word = (b0 << 8) + b1
@@ -427,8 +442,8 @@ class ShaFile(object):
             obj = cls._parse_file(f)
             obj._sha = None
             return obj
-        except (IndexError, ValueError):
-            raise ObjectFormatException("invalid object header")
+        except (IndexError, ValueError) as exc:
+            raise ObjectFormatException("invalid object header") from exc

     @staticmethod
     def from_raw_string(type_num, string, sha=None):
@@ -444,7 +459,9 @@ class ShaFile(object):
         return obj

     @staticmethod
-    def from_raw_chunks(type_num, chunks, sha=None):
+    def from_raw_chunks(
+            type_num: int, chunks: List[bytes],
+            sha: Optional[ObjectID] = None):
         """Creates an object of the indicated type from the raw chunks given.

         Args:
@@ -452,7 +469,10 @@ class ShaFile(object):
           chunks: An iterable of the raw uncompressed contents.
           sha: Optional known sha for the object
         """
-        obj = object_class(type_num)()
+        cls = object_class(type_num)
+        if cls is None:
+            raise AssertionError("unsupported class type num: %d" % type_num)
+        obj = cls()
         obj.set_raw_chunks(chunks, sha)
         return obj

@@ -476,7 +496,7 @@ class ShaFile(object):
         if getattr(self, member, None) is None:
             raise ObjectFormatException(error_msg)

-    def check(self):
+    def check(self) -> None:
         """Check this object for internal consistency.

         Raises:
@@ -493,15 +513,15 @@ class ShaFile(object):
             self._deserialize(self.as_raw_chunks())
             self._sha = None
             new_sha = self.id
-        except Exception as e:
-            raise ObjectFormatException(e)
+        except Exception as exc:
+            raise ObjectFormatException(exc) from exc
         if old_sha != new_sha:
             raise ChecksumMismatch(new_sha, old_sha)

     def _header(self):
-        return object_header(self.type, self.raw_length())
+        return object_header(self.type_num, self.raw_length())

-    def raw_length(self):
+    def raw_length(self) -> int:
         """Returns the length of the raw string of this object."""
         ret = 0
         for chunk in self.as_raw_chunks():
@@ -521,25 +541,14 @@ class ShaFile(object):

     def copy(self):
         """Create a new copy of this SHA1 object from its raw string"""
-        obj_class = object_class(self.get_type())
-        return obj_class.from_raw_string(self.get_type(), self.as_raw_string(), self.id)
+        obj_class = object_class(self.type_num)
+        return obj_class.from_raw_string(self.type_num, self.as_raw_string(), self.id)

     @property
     def id(self):
         """The hex SHA of this object."""
         return self.sha().hexdigest().encode("ascii")

-    def get_type(self):
-        """Return the type number for this object class."""
-        return self.type_num
-
-    def set_type(self, type):
-        """Set the type number for this object class."""
-        self.type_num = type
-
-    # DEPRECATED: use type_num or type_name as needed.
-    type = property(get_type, set_type)
-
     def __repr__(self):
         return "<%s %s>" % (self.__class__.__name__, self.id)

@@ -620,7 +629,7 @@ class Blob(ShaFile):
         """
         super(Blob, self).check()

-    def splitlines(self):
+    def splitlines(self) -> List[bytes]:
         """Return list of lines in this blob.

         This preserves the original line endings.
@@ -648,7 +657,7 @@ class Blob(ShaFile):
         return ret


-def _parse_message(chunks):
+def _parse_message(chunks: Iterable[bytes]):
     """Parse a message with a list of fields and a body.

     Args:
@@ -659,7 +668,7 @@ def _parse_message(chunks):
     """
     f = BytesIO(b"".join(chunks))
     k = None
-    v = ""
+    v = b""
     eof = False

     def _strip_last_newline(value):
@@ -942,8 +951,9 @@ def parse_tree(text, strict=False):
             raise ObjectFormatException("Invalid mode '%s'" % mode_text)
         try:
             mode = int(mode_text, 8)
-        except ValueError:
-            raise ObjectFormatException("Invalid mode '%s'" % mode_text)
+        except ValueError as exc:
+            raise ObjectFormatException(
+                "Invalid mode '%s'" % mode_text) from exc
         name_end = text.index(b"\0", mode_end)
         name = text[mode_end + 1 : name_end]
         count = name_end + 21
@@ -1114,8 +1124,8 @@ class Tree(ShaFile):
         """Grab the entries in the tree"""
         try:
             parsed_entries = parse_tree(b"".join(chunks))
-        except ValueError as e:
-            raise ObjectFormatException(e)
+        except ValueError as exc:
+            raise ObjectFormatException(exc) from exc
         # TODO: list comprehension is for efficiency in the common (small)
         # case; if memory efficiency in the large case is a concern, use a
         # genexp.
@@ -1255,8 +1265,8 @@ def parse_time_entry(value):
         timetext, timezonetext = rest.rsplit(b" ", 1)
         time = int(timetext)
         timezone, timezone_neg_utc = parse_timezone(timezonetext)
-    except ValueError as e:
-        raise ObjectFormatException(e)
+    except ValueError as exc:
+        raise ObjectFormatException(exc) from exc
     return person, time, (timezone, timezone_neg_utc)


@@ -1594,7 +1604,7 @@ OBJECT_CLASSES = (
     Tag,
 )

-_TYPE_MAP = {}  # type: Dict[Union[bytes, int], Type[ShaFile]]
+_TYPE_MAP: Dict[Union[bytes, int], Type[ShaFile]] = {}

 for cls in OBJECT_CLASSES:
     _TYPE_MAP[cls.type_name] = cls

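For reference, object_class() above resolves either a type name or a type number through _TYPE_MAP and returns None for unknown types; a quick sketch:

    from dulwich.objects import object_class, Blob

    assert object_class(b"blob") is Blob
    assert object_class(Blob.type_num) is Blob
    assert object_class(99) is None
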
+ 257 - 161
dulwich/pack.py

@@ -39,13 +39,18 @@ from io import BytesIO, UnsupportedOperation
 from collections import (
     deque,
 )
-import difflib
+try:
+    from cdifflib import CSequenceMatcher as SequenceMatcher
+except ModuleNotFoundError:
+    from difflib import SequenceMatcher
 import struct

 from itertools import chain

 import os
 import sys
+from typing import Optional, Callable, Tuple, List
+import warnings

 from hashlib import sha1
 from os import (
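
The cdifflib fallback above is the usual optional-dependency import idiom (cdifflib is assumed to be a drop-in C reimplementation); a sketch of checking which matcher was picked up:

    try:
        from cdifflib import CSequenceMatcher as SequenceMatcher
    except ModuleNotFoundError:
        from difflib import SequenceMatcher

    print(SequenceMatcher.__module__)  # 'cdifflib' when installed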
@@ -396,15 +401,7 @@ class PackIndex(object):
         lives at within the corresponding pack file. If the pack file doesn't
         have the object then None will be returned.
         """
-        if len(sha) == 40:
-            sha = hex_to_sha(sha)
-        try:
-            return self._object_index(sha)
-        except ValueError:
-            closed = getattr(self._contents, "closed", None)
-            if closed in (None, True):
-                raise PackFileDisappeared(self)
-            raise
+        raise NotImplementedError(self.object_index)

     def object_sha1(self, index):
         """Return the SHA1 corresponding to the index in the pack file."""
@@ -459,7 +456,9 @@ class MemoryPackIndex(PackIndex):
     def __len__(self):
         return len(self._entries)

-    def _object_index(self, sha):
+    def object_index(self, sha):
+        if len(sha) == 40:
+            sha = hex_to_sha(sha)
         return self._by_sha[sha][0]

     def object_sha1(self, index):
@@ -484,6 +483,8 @@ class FilePackIndex(PackIndex):
     present.
     """

+    _fan_out_table: List[int]
+
     def __init__(self, filename, file=None, contents=None, size=None):
         """Create a pack index object.

@@ -595,6 +596,23 @@ class FilePackIndex(PackIndex):
         """
         return bytes(self._contents[-20:])

+    def object_index(self, sha):
+        """Return the index in to the corresponding packfile for the object.
+
+        Given the name of an object it will return the offset that object
+        lives at within the corresponding pack file. If the pack file doesn't
+        have the object then None will be returned.
+        """
+        if len(sha) == 40:
+            sha = hex_to_sha(sha)
+        try:
+            return self._object_index(sha)
+        except ValueError as exc:
+            closed = getattr(self._contents, "closed", None)
+            if closed in (None, True):
+                raise PackFileDisappeared(self) from exc
+            raise
+
     def _object_index(self, sha):
         """See object_index.

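object_index() thus moved from the PackIndex base class into the concrete index classes; a usage sketch (the .idx path is a placeholder):

    from dulwich.pack import load_pack_index

    idx = load_pack_index("pack-abc123.idx")
    offset = idx.object_index(b"1" * 40)  # hex or binary sha; KeyError if absent
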
@@ -1059,7 +1077,6 @@ class PackData(object):
         self._offset_cache = LRUSizeCache(
             1024 * 1024 * 20, compute_size=_compute_object_size
         )
-        self.pack = None

     @property
     def filename(self):
@@ -1122,65 +1139,6 @@ class PackData(object):
         """
         return compute_file_sha(self._file, end_ofs=-20).digest()

-    def get_ref(self, sha):
-        """Get the object for a ref SHA, only looking in this pack."""
-        # TODO: cache these results
-        if self.pack is None:
-            raise KeyError(sha)
-        try:
-            offset = self.pack.index.object_index(sha)
-        except KeyError:
-            offset = None
-        if offset:
-            type, obj = self.get_object_at(offset)
-        elif self.pack is not None and self.pack.resolve_ext_ref:
-            type, obj = self.pack.resolve_ext_ref(sha)
-        else:
-            raise KeyError(sha)
-        return offset, type, obj
-
-    def resolve_object(self, offset, type, obj, get_ref=None):
-        """Resolve an object, possibly resolving deltas when necessary.
-
-        Returns: Tuple with object type and contents.
-        """
-        # Walk down the delta chain, building a stack of deltas to reach
-        # the requested object.
-        base_offset = offset
-        base_type = type
-        base_obj = obj
-        delta_stack = []
-        while base_type in DELTA_TYPES:
-            prev_offset = base_offset
-            if get_ref is None:
-                get_ref = self.get_ref
-            if base_type == OFS_DELTA:
-                (delta_offset, delta) = base_obj
-                # TODO: clean up asserts and replace with nicer error messages
-                base_offset = base_offset - delta_offset
-                base_type, base_obj = self.get_object_at(base_offset)
-                assert isinstance(base_type, int)
-            elif base_type == REF_DELTA:
-                (basename, delta) = base_obj
-                assert isinstance(basename, bytes) and len(basename) == 20
-                base_offset, base_type, base_obj = get_ref(basename)
-                assert isinstance(base_type, int)
-            delta_stack.append((prev_offset, base_type, delta))
-
-        # Now grab the base object (mustn't be a delta) and apply the
-        # deltas all the way up the stack.
-        chunks = base_obj
-        for prev_offset, delta_type, delta in reversed(delta_stack):
-            chunks = apply_delta(chunks, delta)
-            # TODO(dborowitz): This can result in poor performance if
-            # large base objects are separated from deltas in the pack.
-            # We should reorganize so that we apply deltas to all
-            # objects in a chain one after the other to optimize cache
-            # performance.
-            if prev_offset is not None:
-                self._offset_cache[prev_offset] = base_type, chunks
-        return base_type, chunks
-
     def iterobjects(self, progress=None, compute_crc32=True):
         self._file.seek(self._header_size)
         for i in range(1, self._num_objects + 1):
@@ -1215,7 +1173,7 @@ class PackData(object):
             # Back up over unused data.
             self._file.seek(-len(unused), SEEK_CUR)

-    def iterentries(self, progress=None):
+    def iterentries(self, progress=None, resolve_ext_ref=None):
         """Yield entries summarizing the contents of this pack.

         Args:
@@ -1224,25 +1182,24 @@ class PackData(object):
         Returns: iterator of tuples with (sha, offset, crc32)
         """
         num_objects = self._num_objects
-        resolve_ext_ref = self.pack.resolve_ext_ref if self.pack is not None else None
         indexer = PackIndexer.for_pack_data(self, resolve_ext_ref=resolve_ext_ref)
         for i, result in enumerate(indexer):
             if progress is not None:
                 progress(i, num_objects)
             yield result

-    def sorted_entries(self, progress=None):
+    def sorted_entries(self, progress=None, resolve_ext_ref=None):
         """Return entries in this pack, sorted by SHA.

         Args:
           progress: Progress function, called with current and total
             object count
-        Returns: List of tuples with (sha, offset, crc32)
+        Returns: Iterator of tuples with (sha, offset, crc32)
         """
-        ret = sorted(self.iterentries(progress=progress))
-        return ret
+        return sorted(self.iterentries(
+            progress=progress, resolve_ext_ref=resolve_ext_ref))

-    def create_index_v1(self, filename, progress=None):
+    def create_index_v1(self, filename, progress=None, resolve_ext_ref=None):
         """Create a version 1 file for this data file.

         Args:
@@ -1250,11 +1207,12 @@ class PackData(object):
           progress: Progress report function
         Returns: Checksum of index file
         """
-        entries = self.sorted_entries(progress=progress)
+        entries = self.sorted_entries(
+            progress=progress, resolve_ext_ref=resolve_ext_ref)
         with GitFile(filename, "wb") as f:
             return write_pack_index_v1(f, entries, self.calculate_checksum())

-    def create_index_v2(self, filename, progress=None):
+    def create_index_v2(self, filename, progress=None, resolve_ext_ref=None):
         """Create a version 2 index file for this data file.

         Args:
@@ -1262,11 +1220,12 @@ class PackData(object):
           progress: Progress report function
         Returns: Checksum of index file
         """
-        entries = self.sorted_entries(progress=progress)
+        entries = self.sorted_entries(
+            progress=progress, resolve_ext_ref=resolve_ext_ref)
         with GitFile(filename, "wb") as f:
             return write_pack_index_v2(f, entries, self.calculate_checksum())

-    def create_index(self, filename, progress=None, version=2):
+    def create_index(self, filename, progress=None, version=2, resolve_ext_ref=None):
         """Create an  index file for this data file.

         Args:
@@ -1275,9 +1234,11 @@ class PackData(object):
         Returns: Checksum of index file
         """
         if version == 1:
-            return self.create_index_v1(filename, progress)
+            return self.create_index_v1(
+                filename, progress, resolve_ext_ref=resolve_ext_ref)
         elif version == 2:
-            return self.create_index_v2(filename, progress)
+            return self.create_index_v2(
+                filename, progress, resolve_ext_ref=resolve_ext_ref)
         else:
             raise ValueError("unknown index format %d" % version)

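A hedged sketch of the reworked index creation: with PackData.pack gone, resolve_ext_ref is now threaded through explicitly (paths below are placeholders):

    from dulwich.pack import PackData

    data = PackData("pack-abc123.pack")
    data.create_index("pack-abc123.idx", version=2, resolve_ext_ref=None)
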
@@ -1383,10 +1344,8 @@ class DeltaChainIterator(object):

     def _walk_all_chains(self):
         for offset, type_num in self._full_ofs:
-            for result in self._follow_chain(offset, type_num, None):
-                yield result
-        for result in self._walk_ref_chains():
-            yield result
+            yield from self._follow_chain(offset, type_num, None)
+        yield from self._walk_ref_chains()
         assert not self._pending_ofs

     def _ensure_no_pending(self):
@@ -1411,8 +1370,7 @@ class DeltaChainIterator(object):
             self._ext_refs.append(base_sha)
             self._pending_ref.pop(base_sha)
             for new_offset in pending:
-                for result in self._follow_chain(new_offset, type_num, chunks):
-                    yield result
+                yield from self._follow_chain(new_offset, type_num, chunks)

         self._ensure_no_pending()

@@ -1563,28 +1521,50 @@ def pack_object_header(type_num, delta_base, size):
     return bytearray(header)


-def write_pack_object(f, type, object, sha=None, compression_level=-1):
-    """Write pack object to a file.
+def pack_object_chunks(type, object, compression_level=-1):
+    """Generate chunks for a pack object.

     Args:
-      f: File to write to
       type: Numeric type of the object
       object: Object to write
       compression_level: the zlib compression level
-    Returns: Tuple with offset at which the object was written, and crc32
+    Returns: Chunks
     """
     if type in DELTA_TYPES:
         delta_base, object = object
     else:
         delta_base = None
-    header = bytes(pack_object_header(type, delta_base, len(object)))
-    comp_data = zlib.compress(object, compression_level)
+    if isinstance(object, bytes):
+        object = [object]
+    yield bytes(pack_object_header(type, delta_base, sum(map(len, object))))
+    compressor = zlib.compressobj(level=compression_level)
+    for data in object:
+        yield compressor.compress(data)
+    yield compressor.flush()
+
+
+def write_pack_object(write, type, object, sha=None, compression_level=-1):
+    """Write pack object to a file.
+
+    Args:
+      write: Write function to use
+      type: Numeric type of the object
+      object: Object to write
+      compression_level: the zlib compression level
+    Returns: Tuple with offset at which the object was written, and crc32
+    """
+    if hasattr(write, 'write'):
+        warnings.warn(
+            'write_pack_object() now takes a write rather than file argument',
+            DeprecationWarning, stacklevel=2)
+        write = write.write
     crc32 = 0
-    for data in (header, comp_data):
-        f.write(data)
+    for chunk in pack_object_chunks(
+            type, object, compression_level=compression_level):
+        write(chunk)
         if sha is not None:
-            sha.update(data)
+            sha.update(chunk)
-        crc32 = binascii.crc32(data, crc32)
+        crc32 = binascii.crc32(chunk, crc32)
     return crc32 & 0xFFFFFFFF


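A sketch of the new write-based API: pack_object_chunks() yields the encoded object, and write_pack_object() accepts a write callable (a file object still works, with a DeprecationWarning):

    import io
    from dulwich.objects import Blob
    from dulwich.pack import pack_object_chunks, write_pack_object

    buf = io.BytesIO()
    crc32 = write_pack_object(buf.write, Blob.type_num, b"hello")
    assert buf.getvalue() == b"".join(pack_object_chunks(Blob.type_num, b"hello"))
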
@@ -1607,7 +1587,7 @@ def write_pack(
     """
     with GitFile(filename + ".pack", "wb") as f:
         entries, data_sum = write_pack_objects(
-            f,
+            f.write,
             objects,
             delta_window_size=delta_window_size,
             deltify=deltify,
@@ -1618,11 +1598,22 @@ def write_pack(
         return data_sum, write_pack_index_v2(f, entries, data_sum)


-def write_pack_header(f, num_objects):
+def pack_header_chunks(num_objects):
+    """Yield chunks for a pack header."""
+    yield b"PACK"  # Pack header
+    yield struct.pack(b">L", 2)  # Pack version
+    yield struct.pack(b">L", num_objects)  # Number of objects in pack
+
+
+def write_pack_header(write, num_objects):
     """Write a pack header for the given number of objects."""
-    f.write(b"PACK")  # Pack header
-    f.write(struct.pack(b">L", 2))  # Pack version
-    f.write(struct.pack(b">L", num_objects))  # Number of objects in pack
+    if hasattr(write, 'write'):
+        write = write.write
+        warnings.warn(
+            'write_pack_header() now takes a write rather than file argument',
+            DeprecationWarning, stacklevel=2)
+    for chunk in pack_header_chunks(num_objects):
+        write(chunk)


 def deltify_pack_objects(objects, window_size=None):
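
pack_header_chunks() yields the fixed 12-byte header; a quick check of the framing:

    from dulwich.pack import pack_header_chunks

    header = b"".join(pack_header_chunks(1))
    assert header == b"PACK" + b"\x00\x00\x00\x02" + b"\x00\x00\x00\x01"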
@@ -1647,18 +1638,26 @@ def deltify_pack_objects(objects, window_size=None):
     possible_bases = deque()

     for type_num, path, neg_length, o in magic:
-        raw = o.as_raw_string()
+        raw = o.as_raw_chunks()
         winner = raw
+        winner_len = sum(map(len, winner))
         winner_base = None
-        for base in possible_bases:
-            if base.type_num != type_num:
+        for base_id, base_type_num, base in possible_bases:
+            if base_type_num != type_num:
                 continue
-            delta = create_delta(base.as_raw_string(), raw)
-            if len(delta) < len(winner):
-                winner_base = base.sha().digest()
+            delta_len = 0
+            delta = []
+            for chunk in create_delta(base, raw):
+                delta_len += len(chunk)
+                if delta_len >= winner_len:
+                    break
+                delta.append(chunk)
+            else:
+                winner_base = base_id
                 winner = delta
+                winner_len = sum(map(len, winner))
         yield type_num, o.sha().digest(), winner_base, winner
-        possible_bases.appendleft(o)
+        possible_bases.appendleft((o.sha().digest(), type_num, raw))
         while len(possible_bases) > window_size:
             possible_bases.pop()

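A usage sketch of the reworked deltification above, which now keeps (sha, type_num, chunks) candidates and abandons a delta as soon as it outgrows the current winner:

    from dulwich.objects import Blob
    from dulwich.pack import deltify_pack_objects

    b1 = Blob.from_string(b"x" * 200)
    b2 = Blob.from_string(b"x" * 180 + b"y")
    for type_num, sha, base_sha, chunks in deltify_pack_objects(
            [(b1, None), (b2, None)], window_size=10):
        print(type_num, sha.hex(), base_sha is not None)
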
@@ -1674,19 +1673,19 @@ def pack_objects_to_data(objects):
     return (
         count,
         (
-            (o.type_num, o.sha().digest(), None, o.as_raw_string())
+            (o.type_num, o.sha().digest(), None, o.as_raw_chunks())
             for (o, path) in objects
         ),
     )


 def write_pack_objects(
-    f, objects, delta_window_size=None, deltify=None, compression_level=-1
+    write, objects, delta_window_size=None, deltify=None, compression_level=-1
 ):
     """Write a new pack data file.

     Args:
-      f: File to write to
+      write: write function to use
       objects: Iterable of (object, path) tuples to write. Should provide
          __len__
       delta_window_size: Sliding window size for searching for deltas;
@@ -1695,6 +1694,12 @@ def write_pack_objects(
       compression_level: the zlib compression level to use
     Returns: Dict mapping id -> (offset, crc32 checksum), pack checksum
     """
+    if hasattr(write, 'write'):
+        warnings.warn(
+            'write_pack_objects() now takes a write rather than file argument',
+            DeprecationWarning, stacklevel=2)
+        write = write.write
+
     if deltify is None:
         # PERFORMANCE/TODO(jelmer): This should be enabled but is *much* too
         # slow at the moment.
@@ -1706,7 +1711,7 @@ def write_pack_objects(
         pack_contents_count, pack_contents = pack_objects_to_data(objects)

     return write_pack_data(
-        f,
+        write,
         pack_contents_count,
         pack_contents,
         compression_level=compression_level,
@@ -1740,11 +1745,11 @@ class PackChunkGenerator(object):
         # Write the pack
         if num_records is None:
             num_records = len(records)
-        f = BytesIO()
-        write_pack_header(f, num_records)
-        self.cs.update(f.getvalue())
-        yield f.getvalue()
-        offset = f.tell()
+        offset = 0
+        for chunk in pack_header_chunks(num_records):
+            yield chunk
+            self.cs.update(chunk)
+            offset += len(chunk)
         actual_num_records = 0
         for i, (type_num, object_id, delta_base, raw) in enumerate(records):
             if progress is not None:
@@ -1758,13 +1763,16 @@ class PackChunkGenerator(object):
                 else:
                     type_num = OFS_DELTA
                     raw = (offset - base_offset, raw)
-            f = BytesIO()
-            crc32 = write_pack_object(f, type_num, raw, compression_level=compression_level)
-            self.cs.update(f.getvalue())
-            yield f.getvalue()
+            crc32 = 0
+            object_size = 0
+            for chunk in pack_object_chunks(type_num, raw, compression_level=compression_level):
+                yield chunk
+                crc32 = binascii.crc32(chunk, crc32)
+                self.cs.update(chunk)
+                object_size += len(chunk)
             actual_num_records += 1
             self.entries[object_id] = (offset, crc32)
-            offset += f.tell()
+            offset += object_size
         if actual_num_records != num_records:
             raise AssertionError(
                 'actual records written differs: %d != %d' % (
@@ -1773,22 +1781,27 @@ class PackChunkGenerator(object):
         yield self.cs.digest()


-def write_pack_data(f, num_records=None, records=None, progress=None, compression_level=-1):
+def write_pack_data(write, num_records=None, records=None, progress=None, compression_level=-1):
     """Write a new pack data file.

     Args:
-      f: File to write to
+      write: Write function to use
       num_records: Number of records (defaults to len(records) if None)
       records: Iterator over type_num, object_id, delta_base, raw
       progress: Function to report progress to
       compression_level: the zlib compression level
     Returns: Dict mapping id -> (offset, crc32 checksum), pack checksum
     """
+    if hasattr(write, 'write'):
+        warnings.warn(
+            'write_pack_data() now takes a write rather than file argument',
+            DeprecationWarning, stacklevel=2)
+        write = write.write
     chunk_generator = PackChunkGenerator(
         num_records=num_records, records=records, progress=progress,
         compression_level=compression_level)
     for chunk in chunk_generator:
-        f.write(chunk)
+        write(chunk)
     return chunk_generator.entries, chunk_generator.sha1digest()


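PackChunkGenerator now streams header and object chunks instead of staging them in a BytesIO; write_pack_data() simply feeds those chunks to the write callable. A minimal sketch:

    import io
    from dulwich.objects import Blob
    from dulwich.pack import write_pack_data

    blob = Blob.from_string(b"spam")
    records = [(blob.type_num, blob.sha().digest(), None, blob.as_raw_chunks())]
    buf = io.BytesIO()
    entries, pack_sha = write_pack_data(buf.write, len(records), iter(records))
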
@@ -1819,7 +1832,7 @@ def write_pack_index_v1(f, entries, pack_checksum):
     return f.write_sha()


-def _delta_encode_size(size):
+def _delta_encode_size(size) -> bytes:
     ret = bytearray()
     c = size & 0x7F
     size >>= 7
@@ -1828,7 +1841,7 @@ def _delta_encode_size(size):
         c = size & 0x7F
         size >>= 7
     ret.append(c)
-    return ret
+    return bytes(ret)


 # The length of delta compression copy operations in version 2 packs is limited
@@ -1838,17 +1851,16 @@ _MAX_COPY_LEN = 0xFFFF


 def _encode_copy_operation(start, length):
-    scratch = []
-    op = 0x80
+    scratch = bytearray([0x80])
     for i in range(4):
         if start & 0xFF << i * 8:
             scratch.append((start >> i * 8) & 0xFF)
-            op |= 1 << i
+            scratch[0] |= 1 << i
     for i in range(2):
         if length & 0xFF << i * 8:
             scratch.append((length >> i * 8) & 0xFF)
-            op |= 1 << (4 + i)
-    return bytearray([op] + scratch)
+            scratch[0] |= 1 << (4 + i)
+    return bytes(scratch)


 def create_delta(base_buf, target_buf):
@@ -1858,14 +1870,17 @@ def create_delta(base_buf, target_buf):
       base_buf: Base buffer
       target_buf: Target buffer
     """
+    if isinstance(base_buf, list):
+        base_buf = b''.join(base_buf)
+    if isinstance(target_buf, list):
+        target_buf = b''.join(target_buf)
     assert isinstance(base_buf, bytes)
     assert isinstance(target_buf, bytes)
-    out_buf = bytearray()
     # write delta header
-    out_buf += _delta_encode_size(len(base_buf))
-    out_buf += _delta_encode_size(len(target_buf))
+    yield _delta_encode_size(len(base_buf))
+    yield _delta_encode_size(len(target_buf))
     # write out delta opcodes
-    seq = difflib.SequenceMatcher(a=base_buf, b=target_buf)
+    seq = SequenceMatcher(isjunk=None, a=base_buf, b=target_buf)
     for opcode, i1, i2, j1, j2 in seq.get_opcodes():
         # Git patch opcodes don't care about deletes!
         # if opcode == 'replace' or opcode == 'delete':
@@ -1877,7 +1892,7 @@ def create_delta(base_buf, target_buf):
             copy_len = i2 - i1
             while copy_len > 0:
                 to_copy = min(copy_len, _MAX_COPY_LEN)
-                out_buf += _encode_copy_operation(copy_start, to_copy)
+                yield _encode_copy_operation(copy_start, to_copy)
                 copy_start += to_copy
                 copy_len -= to_copy
         if opcode == "replace" or opcode == "insert":
@@ -1886,13 +1901,12 @@ def create_delta(base_buf, target_buf):
             s = j2 - j1
             o = j1
             while s > 127:
-                out_buf.append(127)
-                out_buf += bytearray(target_buf[o : o + 127])
+                yield bytes([127])
+                yield memoryview(target_buf)[o:o + 127]
                 s -= 127
                 o += 127
-            out_buf.append(s)
-            out_buf += bytearray(target_buf[o : o + s])
-    return bytes(out_buf)
+            yield bytes([s])
+            yield memoryview(target_buf)[o:o + s]


 def apply_delta(src_buf, delta):
@@ -2007,10 +2021,25 @@ def write_pack_index_v2(f, entries, pack_checksum):
 write_pack_index = write_pack_index_v2
 write_pack_index = write_pack_index_v2
 
 
 
 
+class _PackTupleIterable(object):
+    """Helper for Pack.pack_tuples."""
+
+    def __init__(self, iterobjects, length):
+        self._iterobjects = iterobjects
+        self._length = length
+
+    def __len__(self):
+        return self._length
+
+    def __iter__(self):
+        return ((o, None) for o in self._iterobjects())
+
+
 class Pack(object):
 class Pack(object):
     """A Git pack object."""
     """A Git pack object."""
 
 
-    def __init__(self, basename, resolve_ext_ref=None):
+    def __init__(self, basename, resolve_ext_ref: Optional[
+            Callable[[bytes], Tuple[int, UnpackedObject]]] = None):
         self._basename = basename
         self._basename = basename
         self._data = None
         self._data = None
         self._idx = None
         self._idx = None
@@ -2034,7 +2063,6 @@ class Pack(object):
         """Create a new pack object from pack data and index objects."""
         """Create a new pack object from pack data and index objects."""
         ret = cls("")
         ret = cls("")
         ret._data = data
         ret._data = data
-        ret._data.pack = ret
         ret._data_load = None
         ret._data_load = None
         ret._idx = idx
         ret._idx = idx
         ret._idx_load = None
         ret._idx_load = None
@@ -2050,7 +2078,6 @@ class Pack(object):
         """The pack data object being used."""
         """The pack data object being used."""
         if self._data is None:
         if self._data is None:
             self._data = self._data_load()
             self._data = self._data_load()
-            self._data.pack = self
             self.check_length_and_checksum()
             self.check_length_and_checksum()
         return self._data
         return self._data
 
 
@@ -2142,7 +2169,7 @@ class Pack(object):
     def get_raw(self, sha1):
     def get_raw(self, sha1):
         offset = self.index.object_index(sha1)
         offset = self.index.object_index(sha1)
         obj_type, obj = self.data.get_object_at(offset)
         obj_type, obj = self.data.get_object_at(offset)
-        type_num, chunks = self.data.resolve_object(offset, obj_type, obj)
+        type_num, chunks = self.resolve_object(offset, obj_type, obj)
         return type_num, b"".join(chunks)
         return type_num, b"".join(chunks)
 
 
     def __getitem__(self, sha1):
     def __getitem__(self, sha1):
@@ -2163,17 +2190,7 @@ class Pack(object):
             and provides __len__
             and provides __len__
         """
         """
 
 
-        class PackTupleIterable(object):
-            def __init__(self, pack):
-                self.pack = pack
-
-            def __len__(self):
-                return len(self.pack)
-
-            def __iter__(self):
-                return ((o, None) for o in self.pack.iterobjects())
-
-        return PackTupleIterable(self)
+        return _PackTupleIterable(self.iterobjects, len(self))
 
 
     def keep(self, msg=None):
     def keep(self, msg=None):
         """Add a .keep file for the pack, preventing git from garbage collecting it.
         """Add a .keep file for the pack, preventing git from garbage collecting it.
@@ -2190,6 +2207,85 @@ class Pack(object):
                 keepfile.write(b"\n")
                 keepfile.write(b"\n")
         return keepfile_name
         return keepfile_name
 
 
+    def get_ref(self, sha) -> Tuple[int, int, UnpackedObject]:
+        """Get the object for a ref SHA, only looking in this pack."""
+        # TODO: cache these results
+        try:
+            offset = self.index.object_index(sha)
+        except KeyError:
+            offset = None
+        if offset:
+            type, obj = self.data.get_object_at(offset)
+        elif self.resolve_ext_ref:
+            type, obj = self.resolve_ext_ref(sha)
+        else:
+            raise KeyError(sha)
+        return offset, type, obj
+
+    def resolve_object(self, offset, type, obj, get_ref=None):
+        """Resolve an object, possibly resolving deltas when necessary.
+
+        Returns: Tuple with object type and contents.
+        """
+        # Walk down the delta chain, building a stack of deltas to reach
+        # the requested object.
+        base_offset = offset
+        base_type = type
+        base_obj = obj
+        delta_stack = []
+        while base_type in DELTA_TYPES:
+            prev_offset = base_offset
+            if get_ref is None:
+                get_ref = self.get_ref
+            if base_type == OFS_DELTA:
+                (delta_offset, delta) = base_obj
+                # TODO: clean up asserts and replace with nicer error messages
+                base_offset = base_offset - delta_offset
+                base_type, base_obj = self.data.get_object_at(base_offset)
+                assert isinstance(base_type, int)
+            elif base_type == REF_DELTA:
+                (basename, delta) = base_obj
+                assert isinstance(basename, bytes) and len(basename) == 20
+                base_offset, base_type, base_obj = get_ref(basename)
+                assert isinstance(base_type, int)
+            delta_stack.append((prev_offset, base_type, delta))
+
+        # Now grab the base object (mustn't be a delta) and apply the
+        # deltas all the way up the stack.
+        chunks = base_obj
+        for prev_offset, delta_type, delta in reversed(delta_stack):
+            chunks = apply_delta(chunks, delta)
+            # TODO(dborowitz): This can result in poor performance if
+            # large base objects are separated from deltas in the pack.
+            # We should reorganize so that we apply deltas to all
+            # objects in a chain one after the other to optimize cache
+            # performance.
+            if prev_offset is not None:
+                self.data._offset_cache[prev_offset] = base_type, chunks
+        return base_type, chunks
+
+    def entries(self, progress=None):
+        """Yield entries summarizing the contents of this pack.
+
+        Args:
+          progress: Progress function, called with current and total
+            object count.
+        Returns: iterator of tuples with (sha, offset, crc32)
+        """
+        return self.data.iterentries(
+            progress=progress, resolve_ext_ref=self.resolve_ext_ref)
+
+    def sorted_entries(self, progress=None):
+        """Return entries in this pack, sorted by SHA.
+
+        Args:
+          progress: Progress function, called with current and total
+            object count
+        Returns: Iterator of tuples with (sha, offset, crc32)
+        """
+        return self.data.sorted_entries(
+            progress=progress, resolve_ext_ref=self.resolve_ext_ref)
+
 
 
 try:
 try:
     from dulwich._pack import (  # type: ignore # noqa: F811
     from dulwich._pack import (  # type: ignore # noqa: F811
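
With this refactoring create_delta() became a generator of byte chunks, and delta resolution moved from PackData onto Pack itself. A minimal round-trip sketch (the sample buffers below are made up, not from the diff):

    from dulwich.pack import apply_delta, create_delta

    base = b"the quick brown fox jumps over the lazy dog"
    target = b"the quick brown fox leaps over the lazy cat"

    # create_delta() now yields chunks instead of returning one bytes object,
    # so callers that need the whole delta join the chunks themselves.
    delta = b"".join(create_delta(base, target))

    # apply_delta() returns a list of chunks as well.
    assert b"".join(apply_delta(base, delta)) == target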

+ 59 - 27
dulwich/porcelain.py

@@ -43,6 +43,8 @@ Currently implemented:
  * remote{_add}
  * receive-pack
  * reset
+ * submodule_add
+ * submodule_init
  * submodule_list
  * rev-list
  * tag{_create,_delete,_list}
@@ -87,8 +89,10 @@ from dulwich.client import (
     get_transport_and_path,
 )
 from dulwich.config import (
+    Config,
     ConfigFile,
     StackedConfig,
+    read_submodules,
 )
 from dulwich.diff_tree import (
     CHANGE_ADD,
@@ -182,9 +186,8 @@ DEFAULT_ENCODING = "utf-8"
 class Error(Exception):
     """Porcelain-based error. """
 
-    def __init__(self, msg, inner=None):
+    def __init__(self, msg):
         super(Error, self).__init__(msg)
-        self.inner = inner
 
 
 class RemoteExists(Error):
@@ -197,10 +200,13 @@ class TimezoneFormatError(Error):
 
 def parse_timezone_format(tz_str):
     """Parse given string and attempt to return a timezone offset.
+
     Different formats are considered in the following order:
-        - Git internal format: <unix timestamp> <timezone offset>
-        - RFC 2822: e.g. Mon, 20 Nov 1995 19:12:08 -0500
-        - ISO 8601: e.g. 1995-11-20T19:12:08-0500
+
+     - Git internal format: <unix timestamp> <timezone offset>
+     - RFC 2822: e.g. Mon, 20 Nov 1995 19:12:08 -0500
+     - ISO 8601: e.g. 1995-11-20T19:12:08-0500
+
     Args:
       tz_str: datetime string
     Returns: Timezone offset as integer
@@ -330,6 +336,10 @@ def path_to_tree_path(repopath, path, tree_encoding=DEFAULT_ENCODING):
 class DivergedBranches(Error):
     """Branches have diverged and fast-forward is not possible."""
 
+    def __init__(self, current_sha, new_sha):
+        self.current_sha = current_sha
+        self.new_sha = new_sha
+
 
 def check_diverged(repo, current_sha, new_sha):
     """Check if updating to a sha can be done with fast forwarding.
@@ -487,10 +497,10 @@ def clone(
     checkout=None,
     errstream=default_bytes_err_stream,
     outstream=None,
-    origin="origin",
-    depth=None,
-    branch=None,
-    config=None,
+    origin: Optional[str] = "origin",
+    depth: Optional[int] = None,
+    branch: Optional[Union[str, bytes]] = None,
+    config: Optional[Config] = None,
     **kwargs
 ):
     """Clone a local or remote git repository.
@@ -530,6 +540,9 @@ def clone(
     if target is None:
         target = source.split("/")[-1]
 
+    if isinstance(branch, str):
+        branch = branch.encode(DEFAULT_ENCODING)
+
     mkdir = not os.path.exists(target)
 
     (client, path) = get_transport_and_path(
@@ -663,8 +676,8 @@ def remove(repo=".", paths=None, cached=False):
             tree_path = path_to_tree_path(r.path, p)
             try:
                 index_sha = index[tree_path].sha
-            except KeyError:
-                raise Error("%s did not match any files" % p)
+            except KeyError as exc:
+                raise Error("%s did not match any files" % p) from exc
 
             if not cached:
                 try:
@@ -986,6 +999,21 @@ def submodule_add(repo, url, path=None, name=None):
         config.write_to_path()
 
 
+def submodule_init(repo):
+    """Initialize submodules.
+
+    Args:
+      repo: Path to repository
+    """
+    with open_repo_closing(repo) as r:
+        config = r.get_config()
+        gitmodules_path = os.path.join(r.path, '.gitmodules')
+        for path, url, name in read_submodules(gitmodules_path):
+            config.set((b'submodule', name), b'active', True)
+            config.set((b'submodule', name), b'url', url)
+        config.write_to_path()
+
+
 def submodule_list(repo):
     """List submodules.
 
@@ -1008,6 +1036,7 @@ def tag_create(
     tag_time=None,
     tag_timezone=None,
     sign=False,
+    encoding=DEFAULT_ENCODING
 ):
     """Creates a tag in git via dulwich calls:
 
@@ -1035,7 +1064,7 @@ def tag_create(
                 # TODO(jelmer): Don't use repo private method.
                 author = r._get_user_identity(r.get_config_stack())
             tag_obj.tagger = author
-            tag_obj.message = message + "\n".encode()
+            tag_obj.message = message + "\n".encode(encoding)
             tag_obj.name = tag
             tag_obj.object = (type(object), object.id)
             if tag_time is None:
@@ -1173,8 +1202,10 @@ def push(
                 else:
                     try:
                         localsha = r.refs[lh]
-                    except KeyError:
-                        raise Error("No valid ref %s in local repository" % lh)
+                    except KeyError as exc:
+                        raise Error(
+                            "No valid ref %s in local repository" % lh
+                        ) from exc
                     if not force_ref and rh in refs:
                         check_diverged(r, refs[rh], localsha)
                     new_refs[rh] = localsha
@@ -1190,11 +1221,10 @@ def push(
                 generate_pack_data=r.generate_pack_data,
                 progress=errstream.write,
             )
-        except SendPackError as e:
+        except SendPackError as exc:
             raise Error(
-                "Push to " + remote_location + " failed -> " + e.args[0].decode(),
-                inner=e,
-            )
+                "Push to " + remote_location + " failed -> " + exc.args[0].decode(),
+            ) from exc
         else:
             errstream.write(
                 b"Push to " + remote_location.encode(err_encoding) + b" successful.\n"
@@ -1259,11 +1289,12 @@ def pull(
             if not force_ref and rh in r.refs:
                 try:
                     check_diverged(r, r.refs.follow(rh)[1], fetch_result.refs[lh])
-                except DivergedBranches:
+                except DivergedBranches as exc:
                     if fast_forward:
                         raise
                     else:
-                        raise NotImplementedError("merge is not yet supported")
+                        raise NotImplementedError(
+                            "merge is not yet supported") from exc
             r.refs[rh] = fetch_result.refs[lh]
         if selected_refs:
             r[b"HEAD"] = fetch_result.refs[selected_refs[0][1]]
@@ -1683,7 +1714,7 @@ def fetch(
     return fetch_result
 
 
-def ls_remote(remote, config=None, **kwargs):
+def ls_remote(remote, config: Optional[Config] = None, **kwargs):
     """List the refs in a remote.
 
     Args:
@@ -1721,7 +1752,7 @@ def pack_objects(repo, object_ids, packf, idxf, delta_window_size=None):
     """
     with open_repo_closing(repo) as r:
         entries, data_sum = write_pack_objects(
-            packf,
+            packf.write,
             r.object_store.iter_shas((oid, None) for oid in object_ids),
             delta_window_size=delta_window_size,
         )
@@ -1849,7 +1880,8 @@ def update_head(repo, target, detached=False, new_branch=None):
             r.refs.set_symbolic_ref(b"HEAD", to_set)
 
 
-def reset_file(repo, file_path: str, target: bytes = b'HEAD'):
+def reset_file(repo, file_path: str, target: bytes = b'HEAD',
+               symlink_fn=None):
     """Reset the file to specific commit or branch.
 
     Args:
@@ -1858,13 +1890,13 @@ def reset_file(repo, file_path: str, target: bytes = b'HEAD'):
       target: branch or commit or b'HEAD' to reset
     """
     tree = parse_tree(repo, treeish=target)
-    file_path = _fs_to_tree_path(file_path)
+    tree_path = _fs_to_tree_path(file_path)
 
-    file_entry = tree.lookup_path(repo.object_store.__getitem__, file_path)
-    full_path = os.path.join(repo.path.encode(), file_path)
+    file_entry = tree.lookup_path(repo.object_store.__getitem__, tree_path)
+    full_path = os.path.join(os.fsencode(repo.path), tree_path)
     blob = repo.object_store[file_entry[1]]
     mode = file_entry[0]
-    build_file_from_blob(blob, mode, full_path)
+    build_file_from_blob(blob, mode, full_path, symlink_fn=symlink_fn)
 
 
 def check_mailmap(repo, contact):
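
The new submodule_init() complements submodule_add() from the same module; a small usage sketch (the checkout path and URL below are placeholders):

    from dulwich import porcelain

    repo_path = "/path/to/checkout"  # placeholder
    porcelain.submodule_add(repo_path, "https://example.com/lib.git", "vendor/lib")
    # Copy the 'url' keys from .gitmodules into .git/config and mark each
    # submodule as active:
    porcelain.submodule_init(repo_path)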

+ 36 - 40
dulwich/protocol.py

@@ -109,6 +109,8 @@ KNOWN_RECEIVE_CAPABILITIES = set(
 
 
 DEPTH_INFINITE = 0x7FFFFFFF
 DEPTH_INFINITE = 0x7FFFFFFF
 
 
+NAK_LINE = b"NAK\n"
+
 
 
 def agent_string():
 def agent_string():
     return ("dulwich/%d.%d.%d" % dulwich.__version__).encode("ascii")
     return ("dulwich/%d.%d.%d" % dulwich.__version__).encode("ascii")
@@ -145,20 +147,6 @@ COMMAND_WANT = b"want"
 COMMAND_HAVE = b"have"
 COMMAND_HAVE = b"have"
 
 
 
 
-class ProtocolFile(object):
-    """A dummy file for network ops that expect file-like objects."""
-
-    def __init__(self, read, write):
-        self.read = read
-        self.write = write
-
-    def tell(self):
-        pass
-
-    def close(self):
-        pass
-
-
 def format_cmd_pkt(cmd, *args):
 def format_cmd_pkt(cmd, *args):
     return cmd + b" " + b"".join([(a + b"\0") for a in args])
     return cmd + b" " + b"".join([(a + b"\0") for a in args])
 
 
@@ -238,10 +226,10 @@ class Protocol(object):
             if self.report_activity:
             if self.report_activity:
                 self.report_activity(size, "read")
                 self.report_activity(size, "read")
             pkt_contents = read(size - 4)
             pkt_contents = read(size - 4)
-        except ConnectionResetError:
-            raise HangupException()
-        except socket.error as e:
-            raise GitProtocolError(e)
+        except ConnectionResetError as exc:
+            raise HangupException() from exc
+        except socket.error as exc:
+            raise GitProtocolError(exc) from exc
         else:
         else:
             if len(pkt_contents) + 4 != size:
             if len(pkt_contents) + 4 != size:
                 raise GitProtocolError(
                 raise GitProtocolError(
@@ -303,28 +291,8 @@ class Protocol(object):
             self.write(line)
             self.write(line)
             if self.report_activity:
             if self.report_activity:
                 self.report_activity(len(line), "write")
                 self.report_activity(len(line), "write")
-        except socket.error as e:
-            raise GitProtocolError(e)
-
-    def write_file(self):
-        """Return a writable file-like object for this protocol."""
-
-        class ProtocolFile(object):
-            def __init__(self, proto):
-                self._proto = proto
-                self._offset = 0
-
-            def write(self, data):
-                self._proto.write(data)
-                self._offset += len(data)
-
-            def tell(self):
-                return self._offset
-
-            def close(self):
-                pass
-
-        return ProtocolFile(self)
+        except socket.error as exc:
+            raise GitProtocolError(exc) from exc
 
 
     def write_sideband(self, channel, blob):
     def write_sideband(self, channel, blob):
         """Write multiplexed data to the sideband.
         """Write multiplexed data to the sideband.
@@ -585,3 +553,31 @@ class PktLineParser(object):
     def get_tail(self):
     def get_tail(self):
         """Read back any unused data."""
         """Read back any unused data."""
         return self._readahead.getvalue()
         return self._readahead.getvalue()
+
+
+def format_capability_line(capabilities):
+    return b"".join([b" " + c for c in capabilities])
+
+
+def format_ref_line(ref, sha, capabilities=None):
+    if capabilities is None:
+        return sha + b" " + ref + b"\n"
+    else:
+        return (
+            sha + b" " + ref + b"\0"
+            + format_capability_line(capabilities)
+            + b"\n")
+
+
+def format_shallow_line(sha):
+    return COMMAND_SHALLOW + b" " + sha
+
+
+def format_unshallow_line(sha):
+    return COMMAND_UNSHALLOW + b" " + sha
+
+
+def format_ack_line(sha, ack_type=b""):
+    if ack_type:
+        ack_type = b" " + ack_type
+    return b"ACK " + sha + ack_type + b"\n"

+ 10 - 8
dulwich/refs.py

@@ -34,12 +34,14 @@ from dulwich.objects import (
     valid_hexsha,
     ZERO_SHA,
     Tag,
+    ObjectID,
 )
 from dulwich.file import (
     GitFile,
     ensure_dir_exists,
 )
 
+Ref = bytes
 
 HEADREF = b"HEAD"
 SYMREF = b"ref: "
@@ -69,7 +71,7 @@ def parse_symref_value(contents):
     raise ValueError(contents)
 
 
-def check_ref_format(refname):
+def check_ref_format(refname: Ref):
     """Check if a refname is correctly formatted.
 
     Implements all the same rules as git-check-ref-format[1].
@@ -166,8 +168,8 @@ class RefsContainer(object):
 
     def import_refs(
         self,
-        base: bytes,
-        other: Dict[bytes, bytes],
+        base: Ref,
+        other: Dict[Ref, ObjectID],
         committer: Optional[bytes] = None,
         timestamp: Optional[bytes] = None,
         timezone: Optional[bytes] = None,
@@ -455,8 +457,8 @@ class DictRefsContainer(RefsContainer):
 
     def set_symbolic_ref(
         self,
-        name,
-        other,
+        name: Ref,
+        other: Ref,
         committer=None,
         timestamp=None,
         timezone=None,
@@ -507,8 +509,8 @@ class DictRefsContainer(RefsContainer):
 
     def add_if_new(
         self,
-        name: bytes,
-        ref: bytes,
+        name: Ref,
+        ref: ObjectID,
         committer=None,
         timestamp=None,
         timezone=None,
@@ -835,7 +837,7 @@ class DiskRefsContainer(RefsContainer):
         try:
             realnames, _ = self.follow(name)
             realname = realnames[-1]
-        except (KeyError, IndexError):
+        except (KeyError, IndexError, SymrefLoop):
             realname = name
         filename = self.refpath(realname)
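
Ref and ObjectID are plain aliases for bytes, so the new annotations change nothing at runtime; check_ref_format() still takes and returns what it did:

    from dulwich.refs import check_ref_format

    assert check_ref_format(b"refs/heads/main")
    assert not check_ref_format(b"refs/heads/..bad")  # ".." is forbidden
    assert not check_ref_format(b"HEAD")              # needs a "/" component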
 
 

+ 87 - 44
dulwich/repo.py

@@ -33,7 +33,18 @@ import os
 import sys
 import stat
 import time
-from typing import Optional, Tuple, TYPE_CHECKING, List, Dict, Union, Iterable
+from typing import (
+    Optional,
+    BinaryIO,
+    Callable,
+    Tuple,
+    TYPE_CHECKING,
+    List,
+    Dict,
+    Union,
+    Iterable,
+    Set
+)
 
 if TYPE_CHECKING:
     # There are no circular imports here, but we try to defer imports as long
@@ -70,6 +81,7 @@ from dulwich.objects import (
     ShaFile,
     Tag,
     Tree,
+    ObjectID,
 )
 from dulwich.pack import (
     pack_objects_to_data,
@@ -86,6 +98,7 @@ from dulwich.hooks import (
 from dulwich.line_ending import BlobNormalizer, TreeBlobNormalizer
 
 from dulwich.refs import (  # noqa: F401
+    Ref,
     ANNOTATED_TAG_SUFFIX,
     LOCAL_BRANCH_PREFIX,
     LOCAL_TAG_PREFIX,
@@ -136,31 +149,39 @@ class InvalidUserIdentity(Exception):
         self.identity = identity
 
 
+# TODO(jelmer): Cache?
 def _get_default_identity() -> Tuple[str, str]:
-    import getpass
     import socket
 
-    username = getpass.getuser()
+    for name in ('LOGNAME', 'USER', 'LNAME', 'USERNAME'):
+        username = os.environ.get(name)
+        if username:
+            break
+    else:
+        username = None
+
     try:
         import pwd
     except ImportError:
         fullname = None
     else:
         try:
-            gecos = pwd.getpwnam(username).pw_gecos  # type: ignore
-        except (KeyError, AttributeError):
+            entry = pwd.getpwuid(os.getuid())  # type: ignore
+        except KeyError:
             fullname = None
         else:
-            if gecos:
-                fullname = gecos.split(",")[0]
+            if getattr(entry, 'gecos', None):
+                fullname = entry.pw_gecos.split(",")[0]
             else:
                 fullname = None
+            if username is None:
+                username = entry.pw_name
     if not fullname:
         fullname = username
     email = os.environ.get("EMAIL")
     if email is None:
         email = "{}@{}".format(username, socket.gethostname())
-    return (fullname, email)
+    return (fullname, email)  # type: ignore
 
 
 def get_user_identity(config: "StackedConfig", kind: Optional[str] = None) -> bytes:
@@ -223,10 +244,12 @@ def check_user_identity(identity):
     """
     try:
         fst, snd = identity.split(b" <", 1)
-    except ValueError:
-        raise InvalidUserIdentity(identity)
+    except ValueError as exc:
+        raise InvalidUserIdentity(identity) from exc
     if b">" not in snd:
         raise InvalidUserIdentity(identity)
+    if b'\0' in identity or b'\n' in identity:
+        raise InvalidUserIdentity(identity)
 
 
 def parse_graftpoints(
@@ -327,6 +350,9 @@ class ParentsProvider(object):
 class BaseRepo(object):
     """Base class for a git repository.
 
+    This base class is meant to be used for Repository implementations that e.g.
+    work on top of a different transport than a standard filesystem path.
+
     Attributes:
       object_store: Dictionary-like object for accessing
         the objects
@@ -376,7 +402,7 @@ class BaseRepo(object):
         self._put_named_file("config", f.getvalue())
         self._put_named_file(os.path.join("info", "exclude"), b"")
 
-    def get_named_file(self, path):
+    def get_named_file(self, path: str) -> Optional[BinaryIO]:
         """Get a file from the control dir with a specific name.
 
         Although the filename should be interpreted as a filename relative to
@@ -389,7 +415,7 @@ class BaseRepo(object):
         """
         raise NotImplementedError(self.get_named_file)
 
-    def _put_named_file(self, path, contents):
+    def _put_named_file(self, path: str, contents: bytes):
         """Write a file to the control dir with the given name and contents.
 
         Args:
@@ -398,11 +424,11 @@ class BaseRepo(object):
         """
         raise NotImplementedError(self._put_named_file)
 
-    def _del_named_file(self, path):
+    def _del_named_file(self, path: str):
         """Delete a file in the control directory with the given name."""
         raise NotImplementedError(self._del_named_file)
 
-    def open_index(self):
+    def open_index(self) -> "Index":
         """Open the index for this repository.
 
         Raises:
@@ -549,7 +575,9 @@ class BaseRepo(object):
             )
         )
 
-    def generate_pack_data(self, have, want, progress=None, ofs_delta=None):
+    def generate_pack_data(self, have: List[ObjectID], want: List[ObjectID],
+                           progress: Optional[Callable[[str], None]] = None,
+                           ofs_delta: Optional[bool] = None):
         """Generate pack data objects for a set of wants/haves.
 
         Args:
@@ -566,7 +594,8 @@ class BaseRepo(object):
             ofs_delta=ofs_delta,
         )
 
-    def get_graph_walker(self, heads=None):
+    def get_graph_walker(
+            self, heads: List[ObjectID] = None) -> ObjectStoreGraphWalker:
         """Retrieve a graph walker.
 
         A graph walker is used by a remote repository (or proxy)
@@ -627,7 +656,7 @@ class BaseRepo(object):
         """
         return self.object_store[sha]
 
-    def parents_provider(self):
+    def parents_provider(self) -> ParentsProvider:
         return ParentsProvider(
             self.object_store,
             grafts=self._graftpoints,
@@ -647,7 +676,7 @@ class BaseRepo(object):
         """
         return self.parents_provider().get_parents(sha, commit)
 
-    def get_config(self):
+    def get_config(self) -> "ConfigFile":
         """Retrieve the config object.
 
         Returns: `ConfigFile` object for the ``.git/config`` file.
@@ -684,7 +713,7 @@ class BaseRepo(object):
         backends = [self.get_config()] + StackedConfig.default_backends()
         return StackedConfig(backends, writable=backends[0])
 
-    def get_shallow(self):
+    def get_shallow(self) -> Set[ObjectID]:
         """Get the set of shallow commits.
 
         Returns: Set of shallow commits.
@@ -714,7 +743,7 @@ class BaseRepo(object):
         else:
             self._del_named_file("shallow")
 
-    def get_peeled(self, ref):
+    def get_peeled(self, ref: Ref) -> ObjectID:
         """Get the peeled value of a ref.
 
         Args:
@@ -728,7 +757,8 @@ class BaseRepo(object):
             return cached
         return self.object_store.peel_sha(self.refs[ref]).id
 
-    def get_walker(self, include=None, *args, **kwargs):
+    def get_walker(self, include: Optional[List[bytes]] = None,
+                   *args, **kwargs):
         """Obtain a walker for this repository.
 
         Args:
@@ -758,14 +788,12 @@ class BaseRepo(object):
 
         if include is None:
             include = [self.head()]
-        if isinstance(include, str):
-            include = [include]
 
         kwargs["get_parents"] = lambda commit: self.get_parents(commit.id, commit)
 
         return Walker(self.object_store, include, *args, **kwargs)
 
-    def __getitem__(self, name):
+    def __getitem__(self, name: Union[ObjectID, Ref]):
         """Retrieve a Git object by SHA1 or ref.
 
         Args:
@@ -785,8 +813,8 @@ class BaseRepo(object):
                 pass
         try:
             return self.object_store[self.refs[name]]
-        except RefFormatError:
-            raise KeyError(name)
+        except RefFormatError as exc:
+            raise KeyError(name) from exc
 
     def __contains__(self, name: bytes) -> bool:
         """Check if a specific Git object or ref is present.
@@ -864,19 +892,19 @@ class BaseRepo(object):
 
     def do_commit(  # noqa: C901
         self,
-        message=None,
-        committer=None,
-        author=None,
+        message: Optional[bytes] = None,
+        committer: Optional[bytes] = None,
+        author: Optional[bytes] = None,
         commit_timestamp=None,
         commit_timezone=None,
         author_timestamp=None,
         author_timezone=None,
-        tree=None,
-        encoding=None,
-        ref=b"HEAD",
-        merge_heads=None,
-        no_verify=False,
-        sign=False,
+        tree: Optional[ObjectID] = None,
+        encoding: Optional[bytes] = None,
+        ref: Ref = b"HEAD",
+        merge_heads: Optional[List[ObjectID]] = None,
+        no_verify: bool = False,
+        sign: bool = False,
     ):
         """Create a new commit.
 
@@ -911,8 +939,8 @@ class BaseRepo(object):
         try:
             if not no_verify:
                 self.hooks["pre-commit"].execute()
-        except HookError as e:
-            raise CommitError(e)
+        except HookError as exc:
+            raise CommitError(exc) from exc
         except KeyError:  # no hook defined, silent fallthrough
             pass
 
@@ -969,8 +997,8 @@ class BaseRepo(object):
                 c.message = self.hooks["commit-msg"].execute(message)
                 if c.message is None:
                     c.message = message
-        except HookError as e:
-            raise CommitError(e)
+        except HookError as exc:
+            raise CommitError(exc) from exc
         except KeyError:  # no hook defined, message not modified
             c.message = message
 
@@ -1014,7 +1042,7 @@ class BaseRepo(object):
             if not ok:
                 # Fail if the atomic compare-and-swap failed, leaving the
                 # commit and all its objects as garbage.
-                raise CommitError("%s changed during commit" % (ref,))
+                raise CommitError(f"{ref!r} changed during commit")
 
         self._del_named_file("MERGE_HEAD")
 
@@ -1050,6 +1078,13 @@ class UnsupportedVersion(Exception):
         self.version = version
 
 
+class UnsupportedExtension(Exception):
+    """Unsupported repository extension."""
+
+    def __init__(self, extension):
+        self.extension = extension
+
+
 class Repo(BaseRepo):
     """A git repository backed by local disk.
 
@@ -1078,6 +1113,7 @@ class Repo(BaseRepo):
         object_store: Optional[BaseObjectStore] = None,
         bare: Optional[bool] = None
     ) -> None:
+        self.symlink_fn = None
         hidden_path = os.path.join(root, CONTROLDIR)
         if bare is None:
             if (os.path.isfile(hidden_path)
@@ -1125,8 +1161,12 @@ class Repo(BaseRepo):
         except KeyError:
             format_version = 0
 
-        if format_version != 0:
+        if format_version not in (0, 1):
             raise UnsupportedVersion(format_version)
+
+        for extension in config.items((b"extensions", )):
+            raise UnsupportedExtension(extension)
+
         if object_store is None:
             object_store = DiskObjectStore.from_config(
                 os.path.join(self.commondir(), OBJECTDIR), config
@@ -1395,8 +1435,10 @@ class Repo(BaseRepo):
                 try:
                     del index[tree_path]
                     continue
-                except KeyError:
-                    raise KeyError("file '%s' not in index" % (tree_path.decode()))
+                except KeyError as exc:
+                    raise KeyError(
+                        "file '%s' not in index" % (tree_path.decode())
+                    ) from exc
 
             st = None
             try:
@@ -1514,7 +1556,7 @@ class Repo(BaseRepo):
             raise
         return target
 
-    def reset_index(self, tree: Optional[Tree] = None):
+    def reset_index(self, tree: Optional[bytes] = None):
        """Reset the index back to a specific tree.
 
         Args:
@@ -1545,6 +1587,7 @@ class Repo(BaseRepo):
             tree,
             honor_filemode=honor_filemode,
             validate_path_element=validate_path_element,
+            symlink_fn=self.symlink_fn,
         )
 
     def get_config(self) -> "ConfigFile":
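
check_user_identity() now also rejects identities containing NUL bytes or newlines, which closes a header-injection hole in commit objects. A short sketch of the behaviour:

    from dulwich.repo import InvalidUserIdentity, check_user_identity

    check_user_identity(b"Jane Doe <jane@example.com>")  # passes silently
    for bad in (b"missing brackets",
                b"Eve <eve@example.com>\ninjected",
                b"Nul <nul@example.com>\x00"):
        try:
            check_user_identity(bad)
        except InvalidUserIdentity:
            pass  # all three are rejected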

+ 33 - 38
dulwich/server.py

@@ -96,7 +96,6 @@ from dulwich.protocol import (
     MULTI_ACK,
     MULTI_ACK_DETAILED,
     Protocol,
-    ProtocolFile,
     ReceivableProtocol,
     SIDE_BAND_CHANNEL_DATA,
     SIDE_BAND_CHANNEL_PROGRESS,
@@ -108,6 +107,11 @@ from dulwich.protocol import (
     extract_capabilities,
     extract_want_line_capabilities,
     symref_capabilities,
+    format_ref_line,
+    format_shallow_line,
+    format_unshallow_line,
+    format_ack_line,
+    NAK_LINE,
 )
 from dulwich.refs import (
     ANNOTATED_TAG_SUFFIX,
@@ -189,10 +193,10 @@ class DictBackend(Backend):
         logger.debug("Opening repository at %s", path)
         try:
             return self.repos[path]
-        except KeyError:
+        except KeyError as exc:
             raise NotGitRepository(
                 "No git repository was found at %(path)s" % dict(path=path)
-            )
+            ) from exc
 
 
 class FileSystemBackend(Backend):
@@ -215,7 +219,7 @@ class FileSystemBackend(Backend):
 class Handler(object):
     """Smart protocol command handler base class."""
 
-    def __init__(self, backend, proto, stateless_rpc=None):
+    def __init__(self, backend, proto, stateless_rpc=False):
         self.backend = backend
         self.proto = proto
         self.stateless_rpc = stateless_rpc
@@ -227,17 +231,12 @@ class Handler(object):
 class PackHandler(Handler):
     """Protocol handler for packs."""
 
-    def __init__(self, backend, proto, stateless_rpc=None):
+    def __init__(self, backend, proto, stateless_rpc=False):
         super(PackHandler, self).__init__(backend, proto, stateless_rpc)
         self._client_capabilities = None
         # Flags needed for the no-done capability
         self._done_received = False
 
-    @classmethod
-    def capability_line(cls, capabilities):
-        logger.info("Sending capabilities: %s", capabilities)
-        return b"".join([b" " + c for c in capabilities])
-
     @classmethod
     def capabilities(cls) -> Iterable[bytes]:
         raise NotImplementedError(cls.capabilities)
@@ -289,7 +288,7 @@ class PackHandler(Handler):
 class UploadPackHandler(PackHandler):
     """Protocol handler for uploading a pack to the client."""
 
-    def __init__(self, backend, args, proto, stateless_rpc=None, advertise_refs=False):
+    def __init__(self, backend, args, proto, stateless_rpc=False, advertise_refs=False):
         super(UploadPackHandler, self).__init__(
             backend, proto, stateless_rpc=stateless_rpc
         )
@@ -409,7 +408,7 @@ class UploadPackHandler(PackHandler):
         self.progress(
             ("counting objects: %d, done.\n" % len(objects_iter)).encode("ascii")
         )
-        write_pack_objects(ProtocolFile(None, write), objects_iter)
+        write_pack_objects(write, objects_iter)
         # we are done
         self.proto.write_pkt_line(None)
 
@@ -602,17 +601,19 @@ class _ProtocolGraphWalker(object):
                     # TODO(jelmer): Integrate with Repo.fetch_objects refs
                     # logic.
                     continue
-                line = sha + b" " + ref
-                if not i:
-                    line += b"\x00" + self.handler.capability_line(
+                if i == 0:
+                    logger.info(
+                        "Sending capabilities: %s", self.handler.capabilities())
+                    line = format_ref_line(
+                        ref, sha,
                         self.handler.capabilities()
-                        + symref_capabilities(symrefs.items())
-                    )
-                self.proto.write_pkt_line(line + b"\n")
+                        + symref_capabilities(symrefs.items()))
+                else:
+                    line = format_ref_line(ref, sha)
+                self.proto.write_pkt_line(line)
                 if peeled_sha != sha:
                     self.proto.write_pkt_line(
-                        peeled_sha + b" " + ref + ANNOTATED_TAG_SUFFIX + b"\n"
-                    )
+                        format_ref_line(ref + ANNOTATED_TAG_SUFFIX, peeled_sha))
 
             # i'm done..
             self.proto.write_pkt_line(None)
@@ -706,9 +707,9 @@ class _ProtocolGraphWalker(object):
         unshallow = self.unshallow = not_shallow & self.client_shallow
 
         for sha in sorted(new_shallow):
-            self.proto.write_pkt_line(COMMAND_SHALLOW + b" " + sha)
+            self.proto.write_pkt_line(format_shallow_line(sha))
         for sha in sorted(unshallow):
-            self.proto.write_pkt_line(COMMAND_UNSHALLOW + b" " + sha)
+            self.proto.write_pkt_line(format_unshallow_line(sha))
 
         self.proto.write_pkt_line(None)
 
@@ -717,12 +718,10 @@ class _ProtocolGraphWalker(object):
         self.handler.notify_done()
 
     def send_ack(self, sha, ack_type=b""):
-        if ack_type:
-            ack_type = b" " + ack_type
-        self.proto.write_pkt_line(b"ACK " + sha + ack_type + b"\n")
+        self.proto.write_pkt_line(format_ack_line(sha, ack_type))
 
     def send_nak(self):
-        self.proto.write_pkt_line(b"NAK\n")
+        self.proto.write_pkt_line(NAK_LINE)
 
     def handle_done(self, done_required, done_received):
         # Delegate this to the implementation.
@@ -924,7 +923,7 @@ class MultiAckDetailedGraphWalkerImpl(object):
 class ReceivePackHandler(PackHandler):
     """Protocol handler for downloading a pack from the client."""
 
-    def __init__(self, backend, args, proto, stateless_rpc=None, advertise_refs=False):
+    def __init__(self, backend, args, proto, stateless_rpc=False, advertise_refs=False):
         super(ReceivePackHandler, self).__init__(
             backend, proto, stateless_rpc=stateless_rpc
         )
@@ -1047,19 +1046,15 @@ class ReceivePackHandler(PackHandler):
 
             if not refs:
                 refs = [(CAPABILITIES_REF, ZERO_SHA)]
+            logger.info(
+                "Sending capabilities: %s", self.capabilities())
             self.proto.write_pkt_line(
-                refs[0][1]
-                + b" "
-                + refs[0][0]
-                + b"\0"
-                + self.capability_line(
-                    self.capabilities() + symref_capabilities(symrefs)
-                )
-                + b"\n"
-            )
+                format_ref_line(
+                    refs[0][0], refs[0][1],
+                    self.capabilities() + symref_capabilities(symrefs)))
             for i in range(1, len(refs)):
                 ref = refs[i]
-                self.proto.write_pkt_line(ref[1] + b" " + ref[0] + b"\n")
+                self.proto.write_pkt_line(format_ref_line(ref[0], ref[1]))
 
             self.proto.write_pkt_line(None)
             if self.advertise_refs:
@@ -1092,7 +1087,7 @@ class ReceivePackHandler(PackHandler):
 
 
 class UploadArchiveHandler(Handler):
-    def __init__(self, backend, args, proto, stateless_rpc=None):
+    def __init__(self, backend, args, proto, stateless_rpc=False):
         super(UploadArchiveHandler, self).__init__(backend, proto, stateless_rpc)
         self.repo = backend.open_repository(args[0])
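
With the exception chaining above, the original KeyError stays attached to the NotGitRepository error. A minimal sketch using an in-memory repository:

    from dulwich.errors import NotGitRepository
    from dulwich.repo import MemoryRepo
    from dulwich.server import DictBackend

    backend = DictBackend({"/": MemoryRepo()})
    backend.open_repository("/")  # returns the MemoryRepo
    try:
        backend.open_repository("/missing")
    except NotGitRepository as err:
        assert isinstance(err.__cause__, KeyError)  # chained via "from exc"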
 
 

+ 1 - 0
dulwich/tests/__init__.py

@@ -117,6 +117,7 @@ def self_test_suite():
         "bundle",
         "client",
         "config",
+        "credentials",
         "diff_tree",
         "fastexport",
         "file",

+ 1 - 1
dulwich/tests/compat/test_utils.py

@@ -35,7 +35,7 @@ class GitVersionTests(TestCase):
 
         def run_git(args, **unused_kwargs):
             self.assertEqual(["--version"], args)
-            return 0, self._version_str
+            return 0, self._version_str, ''
 
         utils.run_git = run_git
 
 

+ 14 - 7
dulwich/tests/compat/utils.py

@@ -117,7 +117,8 @@ def require_git_version(required_version, git_path=_DEFAULT_GIT):
 
 
 def run_git(
-    args, git_path=_DEFAULT_GIT, input=None, capture_stdout=False, **popen_kwargs
+    args, git_path=_DEFAULT_GIT, input=None, capture_stdout=False,
+    capture_stderr=False, **popen_kwargs
 ):
     """Run a git command.
 
@@ -131,8 +132,9 @@ def run_git(
       capture_stdout: Whether to capture and return stdout.
       popen_kwargs: Additional kwargs for subprocess.Popen;
         stdin/stdout args are ignored.
-    Returns: A tuple of (returncode, stdout contents). If capture_stdout is
-        False, None will be returned as stdout contents.
+    Returns: A tuple of (returncode, stdout contents, stderr contents).
+        If capture_stdout is False, None will be returned as stdout contents.
+        If capture_stderr is False, None will be returned as stderr contents.
     Raises:
      OSError: if the git executable was not found.
     """
@@ -147,21 +149,26 @@ def run_git(
         popen_kwargs["stdout"] = subprocess.PIPE
     else:
         popen_kwargs.pop("stdout", None)
+    if capture_stderr:
+        popen_kwargs["stderr"] = subprocess.PIPE
+    else:
+        popen_kwargs.pop("stderr", None)
     p = subprocess.Popen(args, env=env, **popen_kwargs)
     stdout, stderr = p.communicate(input=input)
-    return (p.returncode, stdout)
+    return (p.returncode, stdout, stderr)
 
 
 def run_git_or_fail(args, git_path=_DEFAULT_GIT, input=None, **popen_kwargs):
     """Run a git command, capture stdout/stderr, and fail if git fails."""
     if "stderr" not in popen_kwargs:
         popen_kwargs["stderr"] = subprocess.STDOUT
-    returncode, stdout = run_git(
-        args, git_path=git_path, input=input, capture_stdout=True, **popen_kwargs
+    returncode, stdout, stderr = run_git(
+        args, git_path=git_path, input=input, capture_stdout=True,
+        capture_stderr=True, **popen_kwargs
     )
     if returncode != 0:
         raise AssertionError(
-            "git with args %r failed with %d: %r" % (args, returncode, stdout)
+            "git with args %r failed with %d: stdout=%r stderr=%r" % (args, returncode, stdout, stderr)
         )
     return stdout
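
Callers now unpack a 3-tuple; the stderr element is None unless capture_stderr=True is passed. A sketch (assumes a git binary on $PATH):

    from dulwich.tests.compat.utils import run_git

    returncode, stdout, stderr = run_git(
        ["--version"], capture_stdout=True, capture_stderr=True)
    assert returncode == 0 and stdout.startswith(b"git version")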
 
 

+ 12 - 0
dulwich/tests/test_archive.py

@@ -73,6 +73,18 @@ class ArchiveTests(TestCase):
         self.addCleanup(tf.close)
         self.assertEqual(["somename"], tf.getnames())
 
+    def test_unicode(self):
+        store = MemoryObjectStore()
+        b1 = Blob.from_string(b"somedata")
+        store.add_object(b1)
+        t1 = Tree()
+        t1.add("ő".encode('utf-8'), 0o100644, b1.id)
+        store.add_object(t1)
+        stream = b"".join(tar_stream(store, t1, mtime=0))
+        tf = tarfile.TarFile(fileobj=BytesIO(stream))
+        self.addCleanup(tf.close)
+        self.assertEqual(["ő"], tf.getnames())
+
     def test_prefix(self):
         stream = self._get_example_tar_stream(mtime=0, prefix=b"blah")
         tf = tarfile.TarFile(fileobj=stream)

+ 2 - 2
dulwich/tests/test_client.py

@@ -339,7 +339,7 @@ class GitClientTests(TestCase):
             return 0, []
 
         f = BytesIO()
-        write_pack_objects(f, {})
+        write_pack_objects(f.write, {})
         self.client.send_pack("/", update_refs, generate_pack_data)
         self.assertEqual(
             self.rout.getvalue(),
@@ -384,7 +384,7 @@ class GitClientTests(TestCase):
             )
 
         f = BytesIO()
-        write_pack_data(f, *generate_pack_data(None, None))
+        write_pack_data(f.write, *generate_pack_data(None, None))
         self.client.send_pack(b"/", update_refs, generate_pack_data)
         self.assertEqual(
             self.rout.getvalue(),
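
These pack writers now take a write callable rather than a file-like object, so any byte sink can be passed directly. A small sketch under that assumption (the blob content is illustrative):

    from io import BytesIO

    from dulwich.objects import Blob
    from dulwich.pack import write_pack_objects

    buf = BytesIO()
    blob = Blob.from_string(b"hello")
    # Pass the bound write method, not the file object itself.
    write_pack_objects(buf.write, [(blob, None)])
    assert buf.getvalue().startswith(b"PACK")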

+ 13 - 0
dulwich/tests/test_config.py

@@ -196,6 +196,19 @@ class ConfigFileTests(TestCase):
         cf = self.from_file(b"[branch.foo] foo = bar\n")
         self.assertEqual(b"bar", cf.get((b"branch", b"foo"), b"foo"))
 
+    def test_quoted_newlines_windows(self):
+        cf = self.from_file(
+            b"[alias]\r\n"
+            b"c = '!f() { \\\r\n"
+            b" printf '[git commit -m \\\"%s\\\"]\\n' \\\"$*\\\" && \\\r\n"
+            b" git commit -m \\\"$*\\\"; \\\r\n"
+            b" }; f'\r\n")
+        self.assertEqual(list(cf.sections()), [(b'alias', )])
+        self.assertEqual(
+            b'\'!f() { printf \'[git commit -m "%s"]\n\' '
+            b'"$*" && git commit -m "$*"',
+            cf.get((b"alias", ), b"c"))
+
     def test_quoted(self):
         cf = self.from_file(
             b"""[gui]

+ 75 - 0
dulwich/tests/test_credentials.py

@@ -0,0 +1,75 @@
+# test_credentials.py -- tests for credentials.py
+
+# Copyright (C) 2022 Daniele Trifirò <daniele@iterative.ai>
+#
+# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
+# General Public License as public by the Free Software Foundation; version 2.0
+# or (at your option) any later version. You can redistribute it and/or
+# modify it under the terms of either of these two licenses.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# You should have received a copy of the licenses; if not, see
+# <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
+# and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
+# License, Version 2.0.
+#
+
+from urllib.parse import urlparse
+
+from dulwich.config import ConfigDict
+from dulwich.credentials import (
+    match_partial_url,
+    match_urls,
+    urlmatch_credential_sections,
+)
+from dulwich.tests import TestCase
+
+
+class TestCredentialHelpersUtils(TestCase):
+
+    def test_match_urls(self):
+        url = urlparse("https://github.com/jelmer/dulwich/")
+        url_1 = urlparse("https://github.com/jelmer/dulwich")
+        url_2 = urlparse("https://github.com/jelmer")
+        url_3 = urlparse("https://github.com")
+        self.assertTrue(match_urls(url, url_1))
+        self.assertTrue(match_urls(url, url_2))
+        self.assertTrue(match_urls(url, url_3))
+
+        non_matching = urlparse("https://git.sr.ht/")
+        self.assertFalse(match_urls(url, non_matching))
+
+    def test_match_partial_url(self):
+        url = urlparse("https://github.com/jelmer/dulwich/")
+        self.assertTrue(match_partial_url(url, "github.com"))
+        self.assertFalse(match_partial_url(url, "github.com/jelmer/"))
+        self.assertTrue(match_partial_url(url, "github.com/jelmer/dulwich"))
+        self.assertFalse(match_partial_url(url, "github.com/jel"))
+        self.assertFalse(match_partial_url(url, "github.com/jel/"))
+
+    def test_urlmatch_credential_sections(self):
+        config = ConfigDict()
+        config.set((b"credential", "https://github.com"), b"helper", "foo")
+        config.set((b"credential", "git.sr.ht"), b"helper", "foo")
+        config.set(b"credential", b"helper", "bar")
+
+        self.assertEqual(
+            list(urlmatch_credential_sections(config, "https://github.com")), [
+                (b"credential", b"https://github.com"),
+                (b"credential",),
+            ])
+
+        self.assertEqual(
+            list(urlmatch_credential_sections(config, "https://git.sr.ht")), [
+                (b"credential", b"git.sr.ht"),
+                (b"credential",),
+            ])
+
+        self.assertEqual(
+            list(urlmatch_credential_sections(config, "missing_url")), [
+                (b"credential",)])
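
The new dulwich.credentials module mirrors git's credential.<url>.* section matching. A short usage sketch (the helper values are illustrative):

    from dulwich.config import ConfigDict
    from dulwich.credentials import urlmatch_credential_sections

    config = ConfigDict()
    config.set((b"credential", "https://github.com"), b"helper", "store")
    config.set(b"credential", b"helper", "cache")

    # Sections are yielded most-specific first; the bare (b"credential",)
    # section always matches as a fallback.
    for section in urlmatch_credential_sections(config, "https://github.com"):
        print(section, config.get(section, b"helper"))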

+ 4 - 4
dulwich/tests/test_fastexport.py

@@ -52,8 +52,8 @@ class GitFastExporterTests(TestCase):
         self.stream = BytesIO()
         try:
             from dulwich.fastexport import GitFastExporter
-        except ImportError:
-            raise SkipTest("python-fastimport not available")
+        except ImportError as exc:
+            raise SkipTest("python-fastimport not available") from exc
         self.fastexporter = GitFastExporter(self.stream, self.store)
 
     def test_emit_blob(self):
@@ -100,8 +100,8 @@ class GitImportProcessorTests(TestCase):
         self.repo = MemoryRepo()
         try:
             from dulwich.fastexport import GitImportProcessor
-        except ImportError:
-            raise SkipTest("python-fastimport not available")
+        except ImportError as exc:
+            raise SkipTest("python-fastimport not available") from exc
         self.processor = GitImportProcessor(self.repo)
 
     def test_reset_handler(self):

+ 0 - 4
dulwich/tests/test_index.py

@@ -71,10 +71,6 @@ def can_symlink():
         # Platforms other than Windows should allow symlinks without issues.
         return True
 
-    if not hasattr(os, "symlink"):
-        # Older Python versions do not have `os.symlink` on Windows.
-        return False
-
     test_source = tempfile.mkdtemp()
     test_target = test_source + "can_symlink"
     try:

+ 2 - 2
dulwich/tests/test_object_store.py

@@ -292,7 +292,7 @@ class MemoryObjectStoreTests(ObjectStoreTests, TestCase):
         f, commit, abort = o.add_pack()
         try:
             b = make_object(Blob, data=b"more yummy data")
-            write_pack_objects(f, [(b, None)])
+            write_pack_objects(f.write, [(b, None)])
         except BaseException:
             abort()
             raise
@@ -525,7 +525,7 @@ class DiskObjectStoreTests(PackBasedObjectStoreTests, TestCase):
         f, commit, abort = o.add_pack()
         try:
             b = make_object(Blob, data=b"more yummy data")
-            write_pack_objects(f, [(b, None)])
+            write_pack_objects(f.write, [(b, None)])
         except BaseException:
             abort()
             raise

+ 34 - 1
dulwich/tests/test_objects.py

@@ -1203,7 +1203,10 @@ class CheckTests(TestCase):
             b"Dave Borowitz <dborowitz@google.com>",
             b"Dave Borowitz <dborowitz@google.com>",
             "failed to check good identity",
             "failed to check good identity",
         )
         )
-        check_identity(b"<dborowitz@google.com>", "failed to check good identity")
+        check_identity(b" <dborowitz@google.com>", "failed to check good identity")
+        self.assertRaises(
+            ObjectFormatException, check_identity, b'<dborowitz@google.com>', 'no space before email'
+        )
         self.assertRaises(
         self.assertRaises(
             ObjectFormatException, check_identity, b"Dave Borowitz", "no email"
             ObjectFormatException, check_identity, b"Dave Borowitz", "no email"
         )
         )
@@ -1237,6 +1240,36 @@ class CheckTests(TestCase):
             b"Dave Borowitz <dborowitz@google.com>xxx",
             b"Dave Borowitz <dborowitz@google.com>xxx",
             "trailing characters",
             "trailing characters",
         )
         )
+        self.assertRaises(
+            ObjectFormatException,
+            check_identity,
+            b"Dave Borowitz <dborowitz@google.com>xxx",
+            "trailing characters",
+        )
+        self.assertRaises(
+            ObjectFormatException,
+            check_identity,
+            b'Dave<Borowitz <dborowitz@google.com>',
+            'reserved byte in name',
+        )
+        self.assertRaises(
+            ObjectFormatException,
+            check_identity,
+            b'Dave>Borowitz <dborowitz@google.com>',
+            'reserved byte in name',
+        )
+        self.assertRaises(
+            ObjectFormatException,
+            check_identity,
+            b'Dave\0Borowitz <dborowitz@google.com>',
+            'null byte',
+        )
+        self.assertRaises(
+            ObjectFormatException,
+            check_identity,
+            b'Dave\nBorowitz <dborowitz@google.com>',
+            'newline byte',
+        )
 
 
 
 
 class TimezoneTests(TestCase):
 class TimezoneTests(TestCase):

+ 13 - 12
dulwich/tests/test_pack.py

@@ -186,7 +186,8 @@ class TestPackDeltas(TestCase):
 
     def _test_roundtrip(self, base, target):
         self.assertEqual(
-            target, b"".join(apply_delta(base, create_delta(base, target)))
+            target,
+            b"".join(apply_delta(base, list(create_delta(base, target))))
         )
 
     def test_nochange(self):
@@ -498,7 +499,7 @@ class TestPack(PackTests):
 
             data._file.seek(12)
             bad_file = BytesIO()
-            write_pack_header(bad_file, 9999)
+            write_pack_header(bad_file.write, 9999)
             bad_file.write(data._file.read())
             bad_file = BytesIO(bad_file.getvalue())
             bad_data = PackData("", file=bad_file)
@@ -562,8 +563,8 @@ class TestThinPack(PackTests):
         # Index the new pack.
         with self.make_pack(True) as pack:
             with PackData(pack._data_path) as data:
-                data.pack = pack
-                data.create_index(self.pack_prefix + ".idx")
+                data.create_index(
+                    self.pack_prefix + ".idx", resolve_ext_ref=pack.resolve_ext_ref)
 
         del self.store[self.blobs[b"bar"].id]
 
@@ -618,14 +619,14 @@ class TestThinPack(PackTests):
 class WritePackTests(TestCase):
     def test_write_pack_header(self):
         f = BytesIO()
-        write_pack_header(f, 42)
+        write_pack_header(f.write, 42)
         self.assertEqual(b"PACK\x00\x00\x00\x02\x00\x00\x00*", f.getvalue())
 
     def test_write_pack_object(self):
         f = BytesIO()
         f.write(b"header")
         offset = f.tell()
-        crc32 = write_pack_object(f, Blob.type_num, b"blob")
+        crc32 = write_pack_object(f.write, Blob.type_num, b"blob")
         self.assertEqual(crc32, zlib.crc32(f.getvalue()[6:]) & 0xFFFFFFFF)
 
         f.write(b"x")  # unpack_object needs extra trailing data.
@@ -643,7 +644,7 @@ class WritePackTests(TestCase):
         offset = f.tell()
         sha_a = sha1(b"foo")
         sha_b = sha_a.copy()
-        write_pack_object(f, Blob.type_num, b"blob", sha=sha_a)
+        write_pack_object(f.write, Blob.type_num, b"blob", sha=sha_a)
         self.assertNotEqual(sha_a.digest(), sha_b.digest())
         sha_b.update(f.getvalue()[offset:])
         self.assertEqual(sha_a.digest(), sha_b.digest())
@@ -654,7 +655,7 @@ class WritePackTests(TestCase):
         offset = f.tell()
         sha_a = sha1(b"foo")
         sha_b = sha_a.copy()
-        write_pack_object(f, Blob.type_num, b"blob", sha=sha_a, compression_level=6)
+        write_pack_object(f.write, Blob.type_num, b"blob", sha=sha_a, compression_level=6)
         self.assertNotEqual(sha_a.digest(), sha_b.digest())
         sha_b.update(f.getvalue()[offset:])
         self.assertEqual(sha_a.digest(), sha_b.digest())
@@ -873,17 +874,17 @@ class DeltifyTests(TestCase):
     def test_single(self):
         b = Blob.from_string(b"foo")
         self.assertEqual(
-            [(b.type_num, b.sha().digest(), None, b.as_raw_string())],
+            [(b.type_num, b.sha().digest(), None, b.as_raw_chunks())],
             list(deltify_pack_objects([(b, b"")])),
         )
 
     def test_simple_delta(self):
         b1 = Blob.from_string(b"a" * 101)
         b2 = Blob.from_string(b"a" * 100)
-        delta = create_delta(b1.as_raw_string(), b2.as_raw_string())
+        delta = list(create_delta(b1.as_raw_chunks(), b2.as_raw_chunks()))
         self.assertEqual(
             [
-                (b1.type_num, b1.sha().digest(), None, b1.as_raw_string()),
+                (b1.type_num, b1.sha().digest(), None, b1.as_raw_chunks()),
                 (b2.type_num, b2.sha().digest(), b1.sha().digest(), delta),
             ],
             list(deltify_pack_objects([(b1, b""), (b2, b"")])),
@@ -927,7 +928,7 @@ class TestPackStreamReader(TestCase):
             unpacked_delta.delta_base,
         )
         delta = create_delta(b"blob", b"blob1")
-        self.assertEqual(delta, b"".join(unpacked_delta.decomp_chunks))
+        self.assertEqual(b''.join(delta), b"".join(unpacked_delta.decomp_chunks))
         self.assertEqual(entries[1][4], unpacked_delta.crc32)
 
     def test_read_objects_buffered(self):
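
As the updated assertions show, create_delta now yields delta chunks instead of returning a single bytestring, so callers materialize it with list() or b"".join(). A roundtrip sketch:

    from dulwich.pack import apply_delta, create_delta

    base = b"a" * 101
    target = b"a" * 100
    # Materialize the generator before reusing it; joining the chunks
    # gives the raw delta bytes.
    delta_chunks = list(create_delta(base, target))
    assert b"".join(apply_delta(base, delta_chunks)) == target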

+ 4 - 0
dulwich/tests/test_porcelain.py

@@ -1625,6 +1625,10 @@ class SubmoduleTests(PorcelainTestCase):
 \tpath = bar
 """, f.read())
 
+    def test_init(self):
+        porcelain.submodule_add(self.repo, "../bar.git", "bar")
+        porcelain.submodule_init(self.repo)
+
 
 class PushTests(PorcelainTestCase):
     def test_simple(self):

+ 12 - 2
dulwich/tests/test_refs.py

@@ -521,6 +521,14 @@ class DiskRefsContainerTests(RefsContainerTests, TestCase):
         )
         self.assertRaises(SymrefLoop, self._refs.follow, b"refs/heads/loop")
 
+    def test_set_overwrite_loop(self):
+        self.assertRaises(SymrefLoop, self._refs.follow, b"refs/heads/loop")
+        self._refs[b'refs/heads/loop'] = (
+            b"42d06bd4b77fed026b154d16493e5deab78f02ec")
+        self.assertEqual(
+            ([b'refs/heads/loop'], b'42d06bd4b77fed026b154d16493e5deab78f02ec'),
+            self._refs.follow(b"refs/heads/loop"))
+
     def test_delitem(self):
         RefsContainerTests.test_delitem(self)
         ref_file = os.path.join(self._refs.path, b"refs", b"heads", b"master")
@@ -633,8 +641,10 @@ class DiskRefsContainerTests(RefsContainerTests, TestCase):
     def test_non_ascii(self):
         try:
             encoded_ref = os.fsencode(u"refs/tags/schön")
-        except UnicodeEncodeError:
-            raise SkipTest("filesystem encoding doesn't support special character")
+        except UnicodeEncodeError as exc:
+            raise SkipTest(
+                "filesystem encoding doesn't support special character"
+            ) from exc
         p = os.path.join(os.fsencode(self._repo.path), encoded_ref)
         with open(p, "w") as f:
             f.write("00" * 20)

+ 37 - 1
dulwich/tests/test_repository.py

@@ -44,6 +44,7 @@ from dulwich.repo import (
     MemoryRepo,
     check_user_identity,
     UnsupportedVersion,
+    UnsupportedExtension,
 )
 from dulwich.tests import (
     TestCase,
@@ -1081,13 +1082,28 @@ class BuildRepoRootTests(TestCase):
         r = Repo(self._repo_dir)
         self.assertEqual(r.object_store.loose_compression_level, 4)
 
-    def test_repositoryformatversion(self):
+    def test_repositoryformatversion_unsupported(self):
         r = self._repo
         c = r.get_config()
         c.set(("core",), "repositoryformatversion", "2")
         c.write_to_path()
         self.assertRaises(UnsupportedVersion, Repo, self._repo_dir)
 
+    def test_repositoryformatversion_1(self):
+        r = self._repo
+        c = r.get_config()
+        c.set(("core",), "repositoryformatversion", "1")
+        c.write_to_path()
+        Repo(self._repo_dir)
+
+    def test_repositoryformatversion_1_extension(self):
+        r = self._repo
+        c = r.get_config()
+        c.set(("core",), "repositoryformatversion", "1")
+        c.set(("extensions", ), "worktreeconfig", True)
+        c.write_to_path()
+        self.assertRaises(UnsupportedExtension, Repo, self._repo_dir)
+
     def test_commit_encoding_from_config(self):
         r = self._repo
         c = r.get_config()
@@ -1388,6 +1404,7 @@ class BuildRepoRootTests(TestCase):
         porcelain.add(self._repo, paths=[full_path])
         self._repo.unstage([file])
         status = list(porcelain.status(self._repo))
+
         self.assertEqual([{'add': [], 'delete': [], 'modify': []}, [b'foo'], []], status)
 
     def test_unstage_remove_file(self):
@@ -1407,6 +1424,19 @@ class BuildRepoRootTests(TestCase):
         status = list(porcelain.status(self._repo))
         self.assertEqual([{'add': [], 'delete': [], 'modify': []}, [b'foo'], []], status)
 
+    def test_reset_index(self):
+        r = self._repo
+        with open(os.path.join(r.path, 'a'), 'wb') as f:
+            f.write(b'changed')
+        with open(os.path.join(r.path, 'b'), 'wb') as f:
+            f.write(b'added')
+        r.stage(['a', 'b'])
+        status = list(porcelain.status(self._repo))
+        self.assertEqual([{'add': [b'b'], 'delete': [], 'modify': [b'a']}, [], []], status)
+        r.reset_index()
+        status = list(porcelain.status(self._repo))
+        self.assertEqual([{'add': [], 'delete': [], 'modify': []}, [], ['b']], status)
+
     @skipIf(
         sys.platform in ("win32", "darwin"),
         "tries to implicitly decode as utf8",
@@ -1471,3 +1501,9 @@ class CheckUserIdentityTests(TestCase):
         self.assertRaises(
             InvalidUserIdentity, check_user_identity, b"Fullname >order<>"
         )
+        self.assertRaises(
+            InvalidUserIdentity, check_user_identity, b'Contains\0null byte <>'
+        )
+        self.assertRaises(
+            InvalidUserIdentity, check_user_identity, b'Contains\nnewline byte <>'
+        )

+ 2 - 1
dulwich/tests/test_server.py

@@ -64,6 +64,7 @@ from dulwich.tests.utils import (
 )
 from dulwich.protocol import (
     ZERO_SHA,
+    format_capability_line,
 )
 
 ONE = b"1" * 40
@@ -131,7 +132,7 @@ class HandlerTestCase(TestCase):
     def test_capability_line(self):
         self.assertEqual(
             b" cap1 cap2 cap3",
-            self._handler.capability_line([b"cap1", b"cap2", b"cap3"]),
+            format_capability_line([b"cap1", b"cap2", b"cap3"]),
         )
 
     def test_set_client_capabilities(self):

+ 7 - 7
dulwich/tests/utils.py

@@ -230,7 +230,7 @@ def build_pack(f, objects_spec, store=None):
     """
     """
     sf = SHA1Writer(f)
     sf = SHA1Writer(f)
     num_objects = len(objects_spec)
     num_objects = len(objects_spec)
-    write_pack_header(sf, num_objects)
+    write_pack_header(sf.write, num_objects)
 
 
     full_objects = {}
     full_objects = {}
     offsets = {}
     offsets = {}
@@ -260,7 +260,7 @@ def build_pack(f, objects_spec, store=None):
             base_index, data = obj
             base_index, data = obj
             base = offset - offsets[base_index]
             base = offset - offsets[base_index]
             _, base_data, _ = full_objects[base_index]
             _, base_data, _ = full_objects[base_index]
-            obj = (base, create_delta(base_data, data))
+            obj = (base, list(create_delta(base_data, data)))
         elif type_num == REF_DELTA:
         elif type_num == REF_DELTA:
             base_ref, data = obj
             base_ref, data = obj
             if isinstance(base_ref, int):
             if isinstance(base_ref, int):
@@ -268,9 +268,9 @@ def build_pack(f, objects_spec, store=None):
             else:
             else:
                 base_type_num, base_data = store.get_raw(base_ref)
                 base_type_num, base_data = store.get_raw(base_ref)
                 base = obj_sha(base_type_num, base_data)
                 base = obj_sha(base_type_num, base_data)
-            obj = (base, create_delta(base_data, data))
+            obj = (base, list(create_delta(base_data, data)))
 
 
-        crc32 = write_pack_object(sf, type_num, obj)
+        crc32 = write_pack_object(sf.write, type_num, obj)
         offsets[i] = offset
         offsets[i] = offset
         crc32s[i] = crc32
         crc32s[i] = crc32
 
 
@@ -328,9 +328,9 @@ def build_commit_graph(object_store, commit_spec, trees=None, attrs=None):
         commit_num = commit[0]
         commit_num = commit[0]
         try:
         try:
             parent_ids = [nums[pn] for pn in commit[1:]]
             parent_ids = [nums[pn] for pn in commit[1:]]
-        except KeyError as e:
-            (missing_parent,) = e.args
-            raise ValueError("Unknown parent %i" % missing_parent)
+        except KeyError as exc:
+            (missing_parent,) = exc.args
+            raise ValueError("Unknown parent %i" % missing_parent) from exc
 
 
         blobs = []
         blobs = []
         for entry in trees.get(commit_num, []):
         for entry in trees.get(commit_num, []):

+ 11 - 8
dulwich/walk.py

@@ -24,6 +24,7 @@
 import collections
 import heapq
 from itertools import chain
+from typing import List, Tuple, Set
 
 from dulwich.diff_tree import (
     RENAME_CHANGE_TYPES,
@@ -35,7 +36,9 @@ from dulwich.errors import (
     MissingCommitError,
 )
 from dulwich.objects import (
+    Commit,
     Tag,
+    ObjectID,
 )
 
 ORDER_DATE = "date"
@@ -128,15 +131,15 @@ class _CommitTimeQueue(object):
 class _CommitTimeQueue(object):
     """Priority queue of WalkEntry objects by commit time."""
 
-    def __init__(self, walker):
+    def __init__(self, walker: "Walker"):
         self._walker = walker
         self._store = walker.store
         self._get_parents = walker.get_parents
         self._excluded = walker.excluded
-        self._pq = []
-        self._pq_set = set()
-        self._seen = set()
-        self._done = set()
+        self._pq: List[Tuple[int, Commit]] = []
+        self._pq_set: Set[ObjectID] = set()
+        self._seen: Set[ObjectID] = set()
+        self._done: Set[ObjectID] = set()
         self._min_time = walker.since
         self._last = None
         self._extra_commits_left = _MAX_EXTRA_COMMITS
@@ -145,11 +148,11 @@ class _CommitTimeQueue(object):
         for commit_id in chain(walker.include, walker.excluded):
             self._push(commit_id)
 
-    def _push(self, object_id):
+    def _push(self, object_id: bytes):
         try:
             obj = self._store[object_id]
-        except KeyError:
-            raise MissingCommitError(object_id)
+        except KeyError as exc:
+            raise MissingCommitError(object_id) from exc
         if isinstance(obj, Tag):
             self._push(obj.object[1])
             return

+ 21 - 13
dulwich/web.py

@@ -67,6 +67,23 @@ HTTP_FORBIDDEN = "403 Forbidden"
 HTTP_ERROR = "500 Internal Server Error"
 
 
+NO_CACHE_HEADERS = [
+    ("Expires", "Fri, 01 Jan 1980 00:00:00 GMT"),
+    ("Pragma", "no-cache"),
+    ("Cache-Control", "no-cache, max-age=0, must-revalidate"),
+]
+
+
+def cache_forever_headers(now=None):
+    if now is None:
+        now = time.time()
+    return [
+        ("Date", date_time_string(now)),
+        ("Expires", date_time_string(now + 31536000)),
+        ("Cache-Control", "public, max-age=31536000"),
+    ]
+
+
 def date_time_string(timestamp: Optional[float] = None) -> str:
     # From BaseHTTPRequestHandler.date_time_string in BaseHTTPServer.py in the
     # Python 2.6.5 standard library, following modifications:
@@ -216,7 +233,7 @@ def get_info_refs(req, backend, mat):
             backend,
             [url_prefix(mat)],
             proto,
-            stateless_rpc=req,
+            stateless_rpc=True,
             advertise_refs=True,
         )
         handler.proto.write_pkt_line(b"# service=" + service.encode("ascii") + b"\n")
@@ -311,7 +328,7 @@ def handle_service_request(req, backend, mat):
     proto = ReceivableProtocol(read, write)
     # TODO(jelmer): Find a way to pass in repo, rather than having handler_cls
     # reopen.
-    handler = handler_cls(backend, [url_prefix(mat)], proto, stateless_rpc=req)
+    handler = handler_cls(backend, [url_prefix(mat)], proto, stateless_rpc=True)
     handler.handle()
 
 
@@ -372,20 +389,11 @@ class HTTPGitRequest(object):
 
     def nocache(self) -> None:
         """Set the response to never be cached by the client."""
-        self._cache_headers = [
-            ("Expires", "Fri, 01 Jan 1980 00:00:00 GMT"),
-            ("Pragma", "no-cache"),
-            ("Cache-Control", "no-cache, max-age=0, must-revalidate"),
-        ]
+        self._cache_headers = NO_CACHE_HEADERS
 
     def cache_forever(self) -> None:
         """Set the response to be cached forever by the client."""
-        now = time.time()
-        self._cache_headers = [
-            ("Date", date_time_string(now)),
-            ("Expires", date_time_string(now + 31536000)),
-            ("Cache-Control", "public, max-age=31536000"),
-        ]
+        self._cache_headers = cache_forever_headers()
 
 
 class HTTPGitApplication(object):
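
With this refactoring the cache headers can be reused outside HTTPGitRequest. A short sketch (the fixed timestamp is illustrative; omit now to use the current time):

    from dulwich.web import NO_CACHE_HEADERS, cache_forever_headers

    # Deterministic Date/Expires pair for timestamp 0.
    for name, value in cache_forever_headers(now=0):
        print("%s: %s" % (name, value))

    # The no-cache variant is now a module-level constant.
    print(NO_CACHE_HEADERS)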

+ 56 - 0
setup.cfg

@@ -1,6 +1,62 @@
 [mypy]
 ignore_missing_imports = True
 
+[metadata]
+name = dulwich
+version = attr:dulwich.__version__
+description = Python Git Library
+long_description = file:README.rst
+url = https://www.dulwich.io/
+author = Jelmer Vernooij
+author_email = jelmer@jelmer.uk
+license = Apachev2 or later or GPLv2
+keywords = vcs, git
+classifiers = 
+	Development Status :: 4 - Beta
+	License :: OSI Approved :: Apache Software License
+	Programming Language :: Python :: 3.6
+	Programming Language :: Python :: 3.7
+	Programming Language :: Python :: 3.8
+	Programming Language :: Python :: 3.9
+	Programming Language :: Python :: 3.10
+	Programming Language :: Python :: 3.11
+	Programming Language :: Python :: Implementation :: CPython
+	Programming Language :: Python :: Implementation :: PyPy
+	Operating System :: POSIX
+	Operating System :: Microsoft :: Windows
+	Topic :: Software Development :: Version Control
+project_urls = 
+	Repository=https://www.dulwich.io/code/
+	GitHub=https://github.com/dulwich/dulwich
+	Bug Tracker=https://github.com/dulwich/dulwich/issues
+license_files = COPYING
+
+[options.extras_require]
+fastimport = fastimport
+https = urllib3>=1.24.1
+pgp = gpg
+paramiko = paramiko
+
+[options.entry_points]
+console_scripts = 
+	dulwich = dulwich.cli:main
+
+[options]
+python_requires = >=3.6
+packages = 
+	dulwich
+	dulwich.cloud
+	dulwich.tests
+	dulwich.tests.compat
+	dulwich.contrib
+include_package_data = True
+install_requires = 
+	urllib3>=1.25
+zip_safe = False
+scripts = 
+	bin/dul-receive-pack
+	bin/dul-upload-pack
+
 [egg_info]
 tag_build = 
 tag_date = 0

+ 12 - 80
setup.py

@@ -3,8 +3,7 @@
 # Setup file for dulwich
 # Copyright (C) 2008-2022 Jelmer Vernooij <jelmer@jelmer.uk>
 
-from setuptools import setup, Extension, Distribution
-import io
+from setuptools import setup, Extension
 import os
 import sys
 
@@ -15,25 +14,6 @@ if sys.version_info < (3, 6):
         'For 2.7 support, please install a version prior to 0.20')
 
 
-dulwich_version_string = '0.20.46'
-
-
-class DulwichDistribution(Distribution):
-
-    def is_pure(self):
-        if self.pure:
-            return True
-
-    def has_ext_modules(self):
-        return not self.pure
-
-    global_options = Distribution.global_options + [
-        ('pure', None, "use pure Python code instead of C "
-                       "extensions (slower on CPython)")]
-
-    pure = False
-
-
 if sys.platform == 'darwin' and os.path.exists('/usr/bin/xcodebuild'):
 if sys.platform == 'darwin' and os.path.exists('/usr/bin/xcodebuild'):
     # XCode 4.0 dropped support for ppc architecture, which is hardcoded in
     # distutils.sysconfig
         'gevent', 'geventhttpclient', 'setuptools>=17.1'])
         'gevent', 'geventhttpclient', 'setuptools>=17.1'])
 
 
-    Extension('dulwich._objects', ['dulwich/_objects.c']),
-    Extension('dulwich._pack', ['dulwich/_pack.c']),
-    Extension('dulwich._diff_tree', ['dulwich/_diff_tree.c']),
-]
-
-scripts = ['bin/dul-receive-pack', 'bin/dul-upload-pack']
+optional = os.environ.get('CIBUILDWHEEL', '0') != '1'
 
 
 
 
-with io.open(os.path.join(os.path.dirname(__file__), "README.rst"),
-             encoding="utf-8") as f:
-    description = f.read()
+ext_modules = [
+    Extension('dulwich._objects', ['dulwich/_objects.c'],
+              optional=optional),
+    Extension('dulwich._pack', ['dulwich/_pack.c'],
+              optional=optional),
+    Extension('dulwich._diff_tree', ['dulwich/_diff_tree.c'],
+              optional=optional),
+]
 
 
-setup(name='dulwich',
-      author="Jelmer Vernooij",
-      author_email="jelmer@jelmer.uk",
-      url="https://www.dulwich.io/",
-      long_description=description,
-      description="Python Git Library",
-      version=dulwich_version_string,
-      license='Apachev2 or later or GPLv2',
-      project_urls={
-          "Bug Tracker": "https://github.com/dulwich/dulwich/issues",
-          "Repository": "https://www.dulwich.io/code/",
-          "GitHub": "https://github.com/dulwich/dulwich",
-      },
-      keywords="git vcs",
-      packages=['dulwich', 'dulwich.cloud', 'dulwich.tests',
-                'dulwich.tests.compat', 'dulwich.contrib'],
-      package_data={'': ['../docs/tutorial/*.txt', 'py.typed']},
-      scripts=scripts,
+setup(package_data={'': ['../docs/tutorial/*.txt', 'py.typed']},
       ext_modules=ext_modules,
-      zip_safe=False,
-      distclass=DulwichDistribution,  # type: ignore
-      install_requires=['urllib3>=1.25'],
-      include_package_data=True,
-      test_suite='dulwich.tests.test_suite',
-      tests_require=tests_require,
-      entry_points={
-          "console_scripts": ["dulwich=dulwich.cli:main"]
-      },
-      python_requires='>=3.6',
-      classifiers=[
-          'Development Status :: 4 - Beta',
-          'License :: OSI Approved :: Apache Software License',
-          'Programming Language :: Python :: 3.6',
-          'Programming Language :: Python :: 3.7',
-          'Programming Language :: Python :: 3.8',
-          'Programming Language :: Python :: 3.9',
-          'Programming Language :: Python :: 3.10',
-          'Programming Language :: Python :: 3.11',
-          'Programming Language :: Python :: Implementation :: CPython',
-          'Programming Language :: Python :: Implementation :: PyPy',
-          'Operating System :: POSIX',
-          'Operating System :: Microsoft :: Windows',
-          'Topic :: Software Development :: Version Control',
-      ],
-      extras_require={
-          'fastimport': ['fastimport'],
-          'https': ['urllib3>=1.24.1'],
-          'pgp': ['gpg'],
-          'paramiko': ['paramiko'],
-      })
+      tests_require=tests_require)