Przeglądaj źródła

Convert repo to python3.

Jelmer Vernooij 10 lat temu
rodzic
commit
17d644de8e

+ 7 - 1
dulwich/config.py

@@ -347,7 +347,13 @@ class ConfigFile(ConfigDict):
             else:
                 f.write(b"[" + section_name + b" \"" + subsection_name + b"\"]\n")
             for key, value in values.items():
-                f.write(b"\t" + key + b" = " + _escape_value(value) + b"\n")
+                if value is True:
+                    value = b"true"
+                elif value is False:
+                    value = b"false"
+                else:
+                    value = _escape_value(value)
+                f.write(b"\t" + key + b" = " + value + b"\n")
 
 
 class StackedConfig(Config):

+ 2 - 2
dulwich/index.py

@@ -455,7 +455,7 @@ def validate_path_element_ntfs(element):
 
 def validate_path(path, element_validator=validate_path_element_default):
     """Default path validator that just checks for .git/."""
-    parts = path.split("/")
+    parts = path.split(b"/")
     for p in parts:
         if not element_validator(p):
             return False
@@ -486,7 +486,7 @@ def build_index_from_tree(prefix, index_path, object_store, tree_id,
     for entry in object_store.iter_tree_contents(tree_id):
         if not validate_path(entry.path):
             continue
-        full_path = os.path.join(prefix, entry.path)
+        full_path = os.path.join(prefix, entry.path.decode(sys.getfilesystemencoding()))
 
         if not os.path.exists(os.path.dirname(full_path)):
             os.makedirs(os.path.dirname(full_path))

+ 12 - 7
dulwich/object_store.py

@@ -76,7 +76,7 @@ class BaseObjectStore(object):
 
     def determine_wants_all(self, refs):
         return [sha for (ref, sha) in iteritems(refs)
-                if not sha in self and not ref.endswith("^{}") and
+                if not sha in self and not ref.endswith(b"^{}") and
                    not sha == ZERO_SHA]
 
     def iter_shas(self, shas):
@@ -483,6 +483,7 @@ class DiskObjectStore(PackBasedObjectStore):
         self._pack_cache_time = os.stat(self.pack_dir).st_mtime
         pack_files = set()
         for name in pack_dir_contents:
+            assert type(name) is str
             # TODO: verify that idx exists first
             if name.startswith("pack-") and name.endswith(".pack"):
                 pack_files.add(name[:-len(".pack")])
@@ -526,6 +527,12 @@ class DiskObjectStore(PackBasedObjectStore):
     def _remove_loose_object(self, sha):
         os.remove(self._get_shafile_path(sha))
 
+    def _get_pack_basepath(self, entries):
+        suffix = iter_sha1(entry[0] for entry in entries)
+        # TODO: Handle self.pack_dir being bytes
+        suffix = suffix.decode('ascii')
+        return os.path.join(self.pack_dir, "pack-" + suffix)
+
     def _complete_thin_pack(self, f, path, copier, indexer):
         """Move a specific file containing a pack into the pack directory.
 
@@ -565,8 +572,7 @@ class DiskObjectStore(PackBasedObjectStore):
 
         # Move the pack in.
         entries.sort()
-        pack_base_name = os.path.join(
-          self.pack_dir, 'pack-' + iter_sha1(e[0] for e in entries))
+        pack_base_name = self._get_pack_basepath(entries)
         os.rename(path, pack_base_name + '.pack')
 
         # Write the index.
@@ -615,8 +621,7 @@ class DiskObjectStore(PackBasedObjectStore):
         """
         with PackData(path) as p:
             entries = p.sorted_entries()
-            basename = os.path.join(self.pack_dir,
-                "pack-%s" % iter_sha1(entry[0] for entry in entries))
+            basename = self._get_pack_basepath(entries)
             with GitFile(basename+".idx", "wb") as f:
                 write_pack_index_v2(f, entries, p.get_stored_checksum())
         os.rename(path, basename + ".pack")
@@ -651,13 +656,13 @@ class DiskObjectStore(PackBasedObjectStore):
 
         :param obj: Object to add
         """
-        dir = os.path.join(self.path, obj.id[:2])
+        path = self._get_shafile_path(obj.id)
+        dir = os.path.dirname(path)
         try:
             os.mkdir(dir)
         except OSError as e:
             if e.errno != errno.EEXIST:
                 raise
-        path = os.path.join(dir, obj.id[2:])
         if os.path.exists(path):
             return # Already there, no need to write again
         with GitFile(path, 'wb') as f:

+ 26 - 29
dulwich/repo.py

@@ -175,16 +175,16 @@ class BaseRepo(object):
     def _init_files(self, bare):
         """Initialize a default set of named files."""
         from dulwich.config import ConfigFile
-        self._put_named_file('description', "Unnamed repository")
+        self._put_named_file('description', b"Unnamed repository")
         f = BytesIO()
         cf = ConfigFile()
-        cf.set("core", "repositoryformatversion", "0")
-        cf.set("core", "filemode", "true")
-        cf.set("core", "bare", str(bare).lower())
-        cf.set("core", "logallrefupdates", "true")
+        cf.set(b"core", b"repositoryformatversion", b"0")
+        cf.set(b"core", b"filemode", b"true")
+        cf.set(b"core", b"bare", bare)
+        cf.set(b"core", b"logallrefupdates", True)
         cf.write_to_file(f)
         self._put_named_file('config', f.getvalue())
-        self._put_named_file(os.path.join('info', 'exclude'), '')
+        self._put_named_file(os.path.join('info', 'exclude'), b'')
 
     def get_named_file(self, path):
         """Get a file from the control dir with a specific name.
@@ -291,7 +291,7 @@ class BaseRepo(object):
         :return: A graph walker object
         """
         if heads is None:
-            heads = self.refs.as_dict('refs/heads').values()
+            heads = self.refs.as_dict(b'refs/heads').values()
         return ObjectStoreGraphWalker(heads, self.get_parents)
 
     def get_refs(self):
@@ -303,7 +303,7 @@ class BaseRepo(object):
 
     def head(self):
         """Return the SHA1 pointed at by HEAD."""
-        return self.refs['HEAD']
+        return self.refs[b'HEAD']
 
     def _get_object(self, sha, cls):
         assert len(sha) in (20, 40)
@@ -439,7 +439,7 @@ class BaseRepo(object):
         :return: A `ShaFile` object, such as a Commit or Blob
         :raise KeyError: when the specified ref or object does not exist
         """
-        if not isinstance(name, str):
+        if not isinstance(name, bytes):
             raise TypeError("'name' must be bytestring, not %.80s" %
                     type(name).__name__)
         if len(name) in (20, 40):
@@ -468,10 +468,10 @@ class BaseRepo(object):
         :param name: ref name
         :param value: Ref value - either a ShaFile object, or a hex sha
         """
-        if name.startswith("refs/") or name == "HEAD":
+        if name.startswith(b"refs/") or name == b'HEAD':
             if isinstance(value, ShaFile):
                 self.refs[name] = value.id
-            elif isinstance(value, str):
+            elif isinstance(value, bytes):
                 self.refs[name] = value
             else:
                 raise TypeError(value)
@@ -483,7 +483,7 @@ class BaseRepo(object):
 
         :param name: Name of the ref to remove
         """
-        if name.startswith("refs/") or name == "HEAD":
+        if name.startswith(b"refs/") or name == b"HEAD":
             del self.refs[name]
         else:
             raise ValueError(name)
@@ -492,9 +492,8 @@ class BaseRepo(object):
         """Determine the identity to use for new commits.
         """
         config = self.get_config_stack()
-        return "%s <%s>" % (
-            config.get(("user", ), "name"),
-            config.get(("user", ), "email"))
+        return (config.get((b"user", ), b"name") + b" <" +
+                config.get((b"user", ), b"email") + b">")
 
     def _add_graftpoints(self, updated_graftpoints):
         """Add or modify graftpoints
@@ -521,7 +520,7 @@ class BaseRepo(object):
                   author=None, commit_timestamp=None,
                   commit_timezone=None, author_timestamp=None,
                   author_timezone=None, tree=None, encoding=None,
-                  ref='HEAD', merge_heads=None):
+                  ref=b'HEAD', merge_heads=None):
         """Create a new commit.
 
         :param message: Commit message
@@ -759,7 +758,7 @@ class Repo(BaseRepo):
         index.write()
 
     def clone(self, target_path, mkdir=True, bare=False,
-            origin="origin"):
+            origin=b"origin"):
         """Clone this repository.
 
         :param target_path: Target path
@@ -775,21 +774,21 @@ class Repo(BaseRepo):
             target = self.init_bare(target_path)
         self.fetch(target)
         target.refs.import_refs(
-            'refs/remotes/' + origin, self.refs.as_dict('refs/heads'))
+            b'refs/remotes/' + origin, self.refs.as_dict(b'refs/heads'))
         target.refs.import_refs(
-            'refs/tags', self.refs.as_dict('refs/tags'))
+            b'refs/tags', self.refs.as_dict(b'refs/tags'))
         try:
             target.refs.add_if_new(
-                'refs/heads/master',
-                self.refs['refs/heads/master'])
+                b'refs/heads/master',
+                self.refs[b'refs/heads/master'])
         except KeyError:
             pass
 
         # Update target head
-        head, head_sha = self.refs._follow('HEAD')
+        head, head_sha = self.refs._follow(b'HEAD')
         if head is not None and head_sha is not None:
-            target.refs.set_symbolic_ref('HEAD', head)
-            target['HEAD'] = head_sha
+            target.refs.set_symbolic_ref(b'HEAD', head)
+            target[b'HEAD'] = head_sha
 
             if not bare:
                 # Checkout HEAD to target dir
@@ -808,7 +807,7 @@ class Repo(BaseRepo):
             validate_path_element_ntfs,
             )
         if tree is None:
-            tree = self['HEAD'].tree
+            tree = self[b'HEAD'].tree
         config = self.get_config()
         honor_filemode = config.get_boolean('core', 'filemode', os.name != "nt")
         if config.get_boolean('core', 'core.protectNTFS', os.name == "nt"):
@@ -858,9 +857,7 @@ class Repo(BaseRepo):
         :param description: Text to set as description for this repository.
         """
 
-        path = os.path.join(self._controldir, 'description')
-        with open(path, 'w') as f:
-            f.write(description)
+        self._put_named_file('description', description)
 
     @classmethod
     def _init_maybe_bare(cls, path, bare):
@@ -868,7 +865,7 @@ class Repo(BaseRepo):
             os.mkdir(os.path.join(path, *d))
         DiskObjectStore.init(os.path.join(path, OBJECTDIR))
         ret = cls(path)
-        ret.refs.set_symbolic_ref("HEAD", "refs/heads/master")
+        ret.refs.set_symbolic_ref(b'HEAD', b"refs/heads/master")
         ret._init_files(bare)
         return ret
 

+ 2 - 2
dulwich/tests/test_object_store.py

@@ -64,11 +64,11 @@ class ObjectStoreTests(object):
 
     def test_determine_wants_all(self):
         self.assertEqual([b"1" * 40],
-            self.store.determine_wants_all({"refs/heads/foo": b"1" * 40}))
+            self.store.determine_wants_all({b"refs/heads/foo": b"1" * 40}))
 
     def test_determine_wants_all_zero(self):
         self.assertEqual([],
-            self.store.determine_wants_all({"refs/heads/foo": b"0" * 40}))
+            self.store.determine_wants_all({b"refs/heads/foo": b"0" * 40}))
 
     def test_iter(self):
         self.assertEqual([], list(self.store))

+ 96 - 96
dulwich/tests/test_repository.py

@@ -22,6 +22,7 @@
 import os
 import stat
 import shutil
+import sys
 import tempfile
 import warnings
 
@@ -45,7 +46,7 @@ from dulwich.tests.utils import (
     skipIfPY3,
     )
 
-missing_sha = 'b91fa4d900e17e99b433218e988c4eb4a3e9a097'
+missing_sha = b'b91fa4d900e17e99b433218e988c4eb4a3e9a097'
 
 
 @skipIfPY3
@@ -87,7 +88,6 @@ class CreateRepositoryTests(TestCase):
         self._check_repo_contents(repo, True)
 
 
-@skipIfPY3
 class RepositoryTests(TestCase):
 
     def setUp(self):
@@ -105,17 +105,17 @@ class RepositoryTests(TestCase):
 
     def test_setitem(self):
         r = self._repo = open_repo('a.git')
-        r["refs/tags/foo"] = 'a90fa2d900a17e99b433217e988c4eb4a2e9a097'
-        self.assertEqual('a90fa2d900a17e99b433217e988c4eb4a2e9a097',
-                          r["refs/tags/foo"].id)
+        r[b"refs/tags/foo"] = b'a90fa2d900a17e99b433217e988c4eb4a2e9a097'
+        self.assertEqual(b'a90fa2d900a17e99b433217e988c4eb4a2e9a097',
+                          r[b"refs/tags/foo"].id)
 
     def test_getitem_unicode(self):
         r = self._repo = open_repo('a.git')
 
         test_keys = [
-            ('refs/heads/master', True),
-            ('a90fa2d900a17e99b433217e988c4eb4a2e9a097', True),
-            ('11' * 19 + '--', False),
+            (b'refs/heads/master', True),
+            (b'a90fa2d900a17e99b433217e988c4eb4a2e9a097', True),
+            (b'11' * 19 + b'--', False),
         ]
 
         for k, contained in test_keys:
@@ -123,38 +123,38 @@ class RepositoryTests(TestCase):
 
         for k, _ in test_keys:
             self.assertRaisesRegexp(
-                TypeError, "'name' must be bytestring, not unicode",
-                r.__getitem__, unicode(k)
+                TypeError, "'name' must be bytestring, not int",
+                r.__getitem__, 12
             )
 
     def test_delitem(self):
         r = self._repo = open_repo('a.git')
 
-        del r['refs/heads/master']
-        self.assertRaises(KeyError, lambda: r['refs/heads/master'])
+        del r[b'refs/heads/master']
+        self.assertRaises(KeyError, lambda: r[b'refs/heads/master'])
 
-        del r['HEAD']
-        self.assertRaises(KeyError, lambda: r['HEAD'])
+        del r[b'HEAD']
+        self.assertRaises(KeyError, lambda: r[b'HEAD'])
 
-        self.assertRaises(ValueError, r.__delitem__, 'notrefs/foo')
+        self.assertRaises(ValueError, r.__delitem__, b'notrefs/foo')
 
     def test_get_refs(self):
         r = self._repo = open_repo('a.git')
         self.assertEqual({
-            'HEAD': 'a90fa2d900a17e99b433217e988c4eb4a2e9a097',
-            'refs/heads/master': 'a90fa2d900a17e99b433217e988c4eb4a2e9a097',
-            'refs/tags/mytag': '28237f4dc30d0d462658d6b937b08a0f0b6ef55a',
-            'refs/tags/mytag-packed': 'b0931cadc54336e78a1d980420e3268903b57a50',
+            b'HEAD': b'a90fa2d900a17e99b433217e988c4eb4a2e9a097',
+            b'refs/heads/master': b'a90fa2d900a17e99b433217e988c4eb4a2e9a097',
+            b'refs/tags/mytag': b'28237f4dc30d0d462658d6b937b08a0f0b6ef55a',
+            b'refs/tags/mytag-packed': b'b0931cadc54336e78a1d980420e3268903b57a50',
             }, r.get_refs())
 
     def test_head(self):
         r = self._repo = open_repo('a.git')
-        self.assertEqual(r.head(), 'a90fa2d900a17e99b433217e988c4eb4a2e9a097')
+        self.assertEqual(r.head(), b'a90fa2d900a17e99b433217e988c4eb4a2e9a097')
 
     def test_get_object(self):
         r = self._repo = open_repo('a.git')
         obj = r.get_object(r.head())
-        self.assertEqual(obj.type_name, 'commit')
+        self.assertEqual(obj.type_name, b'commit')
 
     def test_get_object_non_existant(self):
         r = self._repo = open_repo('a.git')
@@ -166,7 +166,7 @@ class RepositoryTests(TestCase):
 
     def test_contains_ref(self):
         r = self._repo = open_repo('a.git')
-        self.assertTrue("HEAD" in r)
+        self.assertTrue(b"HEAD" in r)
 
     def test_get_no_description(self):
         r = self._repo = open_repo('a.git')
@@ -174,50 +174,50 @@ class RepositoryTests(TestCase):
 
     def test_get_description(self):
         r = self._repo = open_repo('a.git')
-        with open(os.path.join(r.path, 'description'), 'w') as f:
-            f.write("Some description")
-        self.assertEqual("Some description", r.get_description())
+        with open(os.path.join(r.path, 'description'), 'wb') as f:
+            f.write(b"Some description")
+        self.assertEqual(b"Some description", r.get_description())
 
     def test_set_description(self):
         r = self._repo = open_repo('a.git')
-        description = "Some description"
+        description = b"Some description"
         r.set_description(description)
         self.assertEqual(description, r.get_description())
 
     def test_contains_missing(self):
         r = self._repo = open_repo('a.git')
-        self.assertFalse("bar" in r)
+        self.assertFalse(b"bar" in r)
 
     def test_get_peeled(self):
         # unpacked ref
         r = self._repo = open_repo('a.git')
-        tag_sha = '28237f4dc30d0d462658d6b937b08a0f0b6ef55a'
+        tag_sha = b'28237f4dc30d0d462658d6b937b08a0f0b6ef55a'
         self.assertNotEqual(r[tag_sha].sha().hexdigest(), r.head())
-        self.assertEqual(r.get_peeled('refs/tags/mytag'), r.head())
+        self.assertEqual(r.get_peeled(b'refs/tags/mytag'), r.head())
 
         # packed ref with cached peeled value
-        packed_tag_sha = 'b0931cadc54336e78a1d980420e3268903b57a50'
+        packed_tag_sha = b'b0931cadc54336e78a1d980420e3268903b57a50'
         parent_sha = r[r.head()].parents[0]
         self.assertNotEqual(r[packed_tag_sha].sha().hexdigest(), parent_sha)
-        self.assertEqual(r.get_peeled('refs/tags/mytag-packed'), parent_sha)
+        self.assertEqual(r.get_peeled(b'refs/tags/mytag-packed'), parent_sha)
 
         # TODO: add more corner cases to test repo
 
     def test_get_peeled_not_tag(self):
         r = self._repo = open_repo('a.git')
-        self.assertEqual(r.get_peeled('HEAD'), r.head())
+        self.assertEqual(r.get_peeled(b'HEAD'), r.head())
 
     def test_get_walker(self):
         r = self._repo = open_repo('a.git')
         # include defaults to [r.head()]
         self.assertEqual([e.commit.id for e in r.get_walker()],
-                         [r.head(), '2a72d929692c41d8554c07f6301757ba18a65d91'])
+                         [r.head(), b'2a72d929692c41d8554c07f6301757ba18a65d91'])
         self.assertEqual(
-            [e.commit.id for e in r.get_walker(['2a72d929692c41d8554c07f6301757ba18a65d91'])],
-            ['2a72d929692c41d8554c07f6301757ba18a65d91'])
+            [e.commit.id for e in r.get_walker([b'2a72d929692c41d8554c07f6301757ba18a65d91'])],
+            [b'2a72d929692c41d8554c07f6301757ba18a65d91'])
         self.assertEqual(
-            [e.commit.id for e in r.get_walker('2a72d929692c41d8554c07f6301757ba18a65d91')],
-            ['2a72d929692c41d8554c07f6301757ba18a65d91'])
+            [e.commit.id for e in r.get_walker(b'2a72d929692c41d8554c07f6301757ba18a65d91')],
+            [b'2a72d929692c41d8554c07f6301757ba18a65d91'])
 
     def test_clone(self):
         r = self._repo = open_repo('a.git')
@@ -225,17 +225,17 @@ class RepositoryTests(TestCase):
         self.addCleanup(shutil.rmtree, tmp_dir)
         t = r.clone(tmp_dir, mkdir=False)
         self.assertEqual({
-            'HEAD': 'a90fa2d900a17e99b433217e988c4eb4a2e9a097',
-            'refs/remotes/origin/master':
-                'a90fa2d900a17e99b433217e988c4eb4a2e9a097',
-            'refs/heads/master': 'a90fa2d900a17e99b433217e988c4eb4a2e9a097',
-            'refs/tags/mytag': '28237f4dc30d0d462658d6b937b08a0f0b6ef55a',
-            'refs/tags/mytag-packed':
-                'b0931cadc54336e78a1d980420e3268903b57a50',
+            b'HEAD': b'a90fa2d900a17e99b433217e988c4eb4a2e9a097',
+            b'refs/remotes/origin/master':
+                b'a90fa2d900a17e99b433217e988c4eb4a2e9a097',
+            b'refs/heads/master': b'a90fa2d900a17e99b433217e988c4eb4a2e9a097',
+            b'refs/tags/mytag': b'28237f4dc30d0d462658d6b937b08a0f0b6ef55a',
+            b'refs/tags/mytag-packed':
+                b'b0931cadc54336e78a1d980420e3268903b57a50',
             }, t.refs.as_dict())
         shas = [e.commit.id for e in r.get_walker()]
         self.assertEqual(shas, [t.head(),
-                         '2a72d929692c41d8554c07f6301757ba18a65d91'])
+                         b'2a72d929692c41d8554c07f6301757ba18a65d91'])
 
     def test_clone_no_head(self):
         temp_dir = tempfile.mkdtemp()
@@ -245,13 +245,13 @@ class RepositoryTests(TestCase):
         shutil.copytree(os.path.join(repo_dir, 'a.git'),
                         dest_dir, symlinks=True)
         r = Repo(dest_dir)
-        del r.refs["refs/heads/master"]
-        del r.refs["HEAD"]
+        del r.refs[b"refs/heads/master"]
+        del r.refs[b"HEAD"]
         t = r.clone(os.path.join(temp_dir, 'b.git'), mkdir=True)
         self.assertEqual({
-            'refs/tags/mytag': '28237f4dc30d0d462658d6b937b08a0f0b6ef55a',
-            'refs/tags/mytag-packed':
-                'b0931cadc54336e78a1d980420e3268903b57a50',
+            b'refs/tags/mytag': b'28237f4dc30d0d462658d6b937b08a0f0b6ef55a',
+            b'refs/tags/mytag-packed':
+                b'b0931cadc54336e78a1d980420e3268903b57a50',
             }, t.refs.as_dict())
 
     def test_clone_empty(self):
@@ -270,24 +270,24 @@ class RepositoryTests(TestCase):
     def test_merge_history(self):
         r = self._repo = open_repo('simple_merge.git')
         shas = [e.commit.id for e in r.get_walker()]
-        self.assertEqual(shas, ['5dac377bdded4c9aeb8dff595f0faeebcc8498cc',
-                                'ab64bbdcc51b170d21588e5c5d391ee5c0c96dfd',
-                                '4cffe90e0a41ad3f5190079d7c8f036bde29cbe6',
-                                '60dacdc733de308bb77bb76ce0fb0f9b44c9769e',
-                                '0d89f20333fbb1d2f3a94da77f4981373d8f4310'])
+        self.assertEqual(shas, [b'5dac377bdded4c9aeb8dff595f0faeebcc8498cc',
+                                b'ab64bbdcc51b170d21588e5c5d391ee5c0c96dfd',
+                                b'4cffe90e0a41ad3f5190079d7c8f036bde29cbe6',
+                                b'60dacdc733de308bb77bb76ce0fb0f9b44c9769e',
+                                b'0d89f20333fbb1d2f3a94da77f4981373d8f4310'])
 
     def test_out_of_order_merge(self):
         """Test that revision history is ordered by date, not parent order."""
         r = self._repo = open_repo('ooo_merge.git')
         shas = [e.commit.id for e in r.get_walker()]
-        self.assertEqual(shas, ['7601d7f6231db6a57f7bbb79ee52e4d462fd44d1',
-                                'f507291b64138b875c28e03469025b1ea20bc614',
-                                'fb5b0425c7ce46959bec94d54b9a157645e114f5',
-                                'f9e39b120c68182a4ba35349f832d0e4e61f485c'])
+        self.assertEqual(shas, [b'7601d7f6231db6a57f7bbb79ee52e4d462fd44d1',
+                                b'f507291b64138b875c28e03469025b1ea20bc614',
+                                b'fb5b0425c7ce46959bec94d54b9a157645e114f5',
+                                b'f9e39b120c68182a4ba35349f832d0e4e61f485c'])
 
     def test_get_tags_empty(self):
         r = self._repo = open_repo('ooo_merge.git')
-        self.assertEqual({}, r.refs.as_dict('refs/tags'))
+        self.assertEqual({}, r.refs.as_dict(b'refs/tags'))
 
     def test_get_config(self):
         r = self._repo = open_repo('ooo_merge.git')
@@ -305,7 +305,7 @@ class RepositoryTests(TestCase):
         rel = os.path.relpath(os.path.join(repo_dir, 'submodule'), temp_dir)
         os.symlink(os.path.join(rel, 'dotgit'), os.path.join(temp_dir, '.git'))
         r = Repo(temp_dir)
-        self.assertEqual(r.head(), 'a90fa2d900a17e99b433217e988c4eb4a2e9a097')
+        self.assertEqual(r.head(), b'a90fa2d900a17e99b433217e988c4eb4a2e9a097')
 
     def test_common_revisions(self):
         """
@@ -315,7 +315,7 @@ class RepositoryTests(TestCase):
         ``Repo.fetch_objects()``).
         """
 
-        expected_shas = set(['60dacdc733de308bb77bb76ce0fb0f9b44c9769e'])
+        expected_shas = set([b'60dacdc733de308bb77bb76ce0fb0f9b44c9769e'])
 
         # Source for objects.
         r_base = open_repo('simple_merge.git')
@@ -326,25 +326,25 @@ class RepositoryTests(TestCase):
         # corrupted, but we're only checking for commits for the purpose of this
         # test, so it's immaterial.
         r1_dir = tempfile.mkdtemp()
-        r1_commits = ['ab64bbdcc51b170d21588e5c5d391ee5c0c96dfd', # HEAD
-                      '60dacdc733de308bb77bb76ce0fb0f9b44c9769e',
-                      '0d89f20333fbb1d2f3a94da77f4981373d8f4310']
+        r1_commits = [b'ab64bbdcc51b170d21588e5c5d391ee5c0c96dfd', # HEAD
+                      b'60dacdc733de308bb77bb76ce0fb0f9b44c9769e',
+                      b'0d89f20333fbb1d2f3a94da77f4981373d8f4310']
 
         r2_dir = tempfile.mkdtemp()
-        r2_commits = ['4cffe90e0a41ad3f5190079d7c8f036bde29cbe6', # HEAD
-                      '60dacdc733de308bb77bb76ce0fb0f9b44c9769e',
-                      '0d89f20333fbb1d2f3a94da77f4981373d8f4310']
+        r2_commits = [b'4cffe90e0a41ad3f5190079d7c8f036bde29cbe6', # HEAD
+                      b'60dacdc733de308bb77bb76ce0fb0f9b44c9769e',
+                      b'0d89f20333fbb1d2f3a94da77f4981373d8f4310']
 
         try:
             r1 = Repo.init_bare(r1_dir)
             for c in r1_commits:
                 r1.object_store.add_object(r_base.get_object(c))
-            r1.refs['HEAD'] = r1_commits[0]
+            r1.refs[b'HEAD'] = r1_commits[0]
 
             r2 = Repo.init_bare(r2_dir)
             for c in r2_commits:
                 r2.object_store.add_object(r_base.get_object(c))
-            r2.refs['HEAD'] = r2_commits[0]
+            r2.refs[b'HEAD'] = r2_commits[0]
 
             # Finally, the 'real' testing!
             shas = r2.object_store.find_common_revisions(r1.get_graph_walker())
@@ -360,11 +360,11 @@ class RepositoryTests(TestCase):
         if os.name != 'posix':
             self.skipTest('shell hook tests requires POSIX shell')
 
-        pre_commit_fail = """#!/bin/sh
+        pre_commit_fail = b"""#!/bin/sh
 exit 1
 """
 
-        pre_commit_success = """#!/bin/sh
+        pre_commit_success = b"""#!/bin/sh
 exit 0
 """
 
@@ -389,9 +389,9 @@ exit 0
         os.chmod(pre_commit, stat.S_IREAD | stat.S_IWRITE | stat.S_IEXEC)
 
         commit_sha = r.do_commit(
-            'empty commit',
-            committer='Test Committer <test@nodomain.com>',
-            author='Test Author <test@nodomain.com>',
+            b'empty commit',
+            committer=b'Test Committer <test@nodomain.com>',
+            author=b'Test Author <test@nodomain.com>',
             commit_timestamp=12395, commit_timezone=0,
             author_timestamp=12395, author_timezone=0)
         self.assertEqual([], r[commit_sha].parents)
@@ -400,11 +400,11 @@ exit 0
         if os.name != 'posix':
             self.skipTest('shell hook tests requires POSIX shell')
 
-        commit_msg_fail = """#!/bin/sh
+        commit_msg_fail = b"""#!/bin/sh
 exit 1
 """
 
-        commit_msg_success = """#!/bin/sh
+        commit_msg_success = b"""#!/bin/sh
 exit 0
 """
 
@@ -418,9 +418,9 @@ exit 0
             f.write(commit_msg_fail)
         os.chmod(commit_msg, stat.S_IREAD | stat.S_IWRITE | stat.S_IEXEC)
 
-        self.assertRaises(errors.CommitError, r.do_commit, 'failed commit',
-                          committer='Test Committer <test@nodomain.com>',
-                          author='Test Author <test@nodomain.com>',
+        self.assertRaises(errors.CommitError, r.do_commit, b'failed commit',
+                          committer=b'Test Committer <test@nodomain.com>',
+                          author=b'Test Author <test@nodomain.com>',
                           commit_timestamp=12345, commit_timezone=0,
                           author_timestamp=12345, author_timezone=0)
 
@@ -429,9 +429,9 @@ exit 0
         os.chmod(commit_msg, stat.S_IREAD | stat.S_IWRITE | stat.S_IEXEC)
 
         commit_sha = r.do_commit(
-            'empty commit',
-            committer='Test Committer <test@nodomain.com>',
-            author='Test Author <test@nodomain.com>',
+            b'empty commit',
+            committer=b'Test Committer <test@nodomain.com>',
+            author=b'Test Author <test@nodomain.com>',
             commit_timestamp=12395, commit_timezone=0,
             author_timestamp=12395, author_timezone=0)
         self.assertEqual([], r[commit_sha].parents)
@@ -445,14 +445,14 @@ exit 0
         self.addCleanup(shutil.rmtree, repo_dir)
 
         (fd, path) = tempfile.mkstemp(dir=repo_dir)
-        post_commit_msg = """#!/bin/sh
-rm %(file)s
-""" % {'file': path}
+        post_commit_msg = b"""#!/bin/sh
+rm """ + path.encode(sys.getfilesystemencoding()) + b"""
+"""
 
         root_sha = r.do_commit(
-            'empty commit',
-            committer='Test Committer <test@nodomain.com>',
-            author='Test Author <test@nodomain.com>',
+            b'empty commit',
+            committer=b'Test Committer <test@nodomain.com>',
+            author=b'Test Author <test@nodomain.com>',
             commit_timestamp=12345, commit_timezone=0,
             author_timestamp=12345, author_timezone=0)
         self.assertEqual([], r[root_sha].parents)
@@ -464,16 +464,16 @@ rm %(file)s
         os.chmod(post_commit, stat.S_IREAD | stat.S_IWRITE | stat.S_IEXEC)
 
         commit_sha = r.do_commit(
-            'empty commit',
-            committer='Test Committer <test@nodomain.com>',
-            author='Test Author <test@nodomain.com>',
+            b'empty commit',
+            committer=b'Test Committer <test@nodomain.com>',
+            author=b'Test Author <test@nodomain.com>',
             commit_timestamp=12345, commit_timezone=0,
             author_timestamp=12345, author_timezone=0)
         self.assertEqual([root_sha], r[commit_sha].parents)
 
         self.assertFalse(os.path.exists(path))
 
-        post_commit_msg_fail = """#!/bin/sh
+        post_commit_msg_fail = b"""#!/bin/sh
 exit 1
 """
         with open(post_commit, 'wb') as f:
@@ -486,9 +486,9 @@ exit 1
         self.addCleanup(restore_warnings)
 
         commit_sha2 = r.do_commit(
-            'empty commit',
-            committer='Test Committer <test@nodomain.com>',
-            author='Test Author <test@nodomain.com>',
+            b'empty commit',
+            committer=b'Test Committer <test@nodomain.com>',
+            author=b'Test Author <test@nodomain.com>',
             commit_timestamp=12345, commit_timezone=0,
             author_timestamp=12345, author_timezone=0)
         self.assertEqual(len(warnings_list), 1)

+ 2 - 0
dulwich/walk.py

@@ -226,6 +226,8 @@ class Walker(object):
         if order not in ALL_ORDERS:
             raise ValueError('Unknown walk order %s' % order)
         self.store = store
+        if not isinstance(include, list):
+            include = [include]
         self.include = include
         self.excluded = set(exclude or [])
         self.order = order