Browse Source

Fix porcelain tests on all versions.

Jelmer Vernooij 9 years ago
parent
commit
cdb640f8f5
2 changed files with 53 additions and 35 deletions
  1. 39 25
      dulwich/porcelain.py
  2. 14 10
      dulwich/tests/test_porcelain.py

+ 39 - 25
dulwich/porcelain.py

@@ -67,6 +67,7 @@ from dulwich.errors import (
     )
 from dulwich.index import get_unstaged_changes
 from dulwich.objects import (
+    Commit,
     Tag,
     parse_timezone,
     )
@@ -290,13 +291,13 @@ def rm(repo=".", paths=None):
         index.write()
 
 
-def commit_decode(commit, contents):
+def commit_decode(commit, contents, default_encoding='utf-8'):
     if commit.encoding is not None:
         return contents.decode(commit.encoding, "replace")
-    return contents.decode("utf-8", "replace")
+    return contents.decode(default_encoding, "replace")
 
 
-def print_commit(commit, outstream=sys.stdout):
+def print_commit(commit, decode, outstream=sys.stdout):
     """Write a human-readable commit log entry.
 
     :param commit: A `Commit` object
@@ -307,77 +308,82 @@ def print_commit(commit, outstream=sys.stdout):
     if len(commit.parents) > 1:
         outstream.write("merge: " +
             "...".join([c.decode('ascii') for c in commit.parents[1:]]) + "\n")
-    outstream.write("author: " + commit_decode(commit, commit.author) + "\n")
-    outstream.write("committer: " + commit_decode(commit, commit.committer) + "\n")
+    outstream.write("author: " + decode(commit.author) + "\n")
+    outstream.write("committer: " + decode(commit.committer) + "\n")
     outstream.write("\n")
-    outstream.write(commit_decode(commit, commit.message) + "\n")
+    outstream.write(decode(commit.message) + "\n")
     outstream.write("\n")
 
 
-def print_tag(tag, outstream=sys.stdout):
+def print_tag(tag, decode, outstream=sys.stdout):
     """Write a human-readable tag.
 
     :param tag: A `Tag` object
+    :param decode: Function for decoding bytes to unicode string
     :param outstream: A stream to write to
     """
-    outstream.write(b"Tagger: " + tag.tagger + b"\n")
-    outstream.write(b"Date:   " + tag.tag_time + b"\n")
-    outstream.write(b"\n")
-    outstream.write(tag.message + b"\n")
-    outstream.write(b"\n")
+    outstream.write("Tagger: " + decode(tag.tagger) + "\n")
+    outstream.write("Date:   " + decode(tag.tag_time) + "\n")
+    outstream.write("\n")
+    outstream.write(decode(tag.message) + "\n")
+    outstream.write("\n")
 
 
-def show_blob(repo, blob, outstream=sys.stdout):
+def show_blob(repo, blob, decode, outstream=sys.stdout):
     """Write a blob to a stream.
 
     :param repo: A `Repo` object
     :param blob: A `Blob` object
+    :param decode: Function for decoding bytes to unicode string
     :param outstream: A stream file to write to
     """
-    outstream.write(blob.data)
+    outstream.write(decode(blob.data))
 
 
-def show_commit(repo, commit, outstream=sys.stdout):
+def show_commit(repo, commit, decode, outstream=sys.stdout):
     """Show a commit to a stream.
 
     :param repo: A `Repo` object
     :param commit: A `Commit` object
+    :param decode: Function for decoding bytes to unicode string
     :param outstream: Stream to write to
     """
-    print_commit(commit, outstream)
+    print_commit(commit, decode=decode, outstream=outstream)
     parent_commit = repo[commit.parents[0]]
     write_tree_diff(outstream, repo.object_store, parent_commit.tree, commit.tree)
 
 
-def show_tree(repo, tree, outstream=sys.stdout):
+def show_tree(repo, tree, decode, outstream=sys.stdout):
     """Print a tree to a stream.
 
     :param repo: A `Repo` object
     :param tree: A `Tree` object
+    :param decode: Function for decoding bytes to unicode string
     :param outstream: Stream to write to
     """
     for n in tree:
-        outstream.write("%s\n" % n)
+        outstream.write(decode(n) + "\n")
 
 
-def show_tag(repo, tag, outstream=sys.stdout):
+def show_tag(repo, tag, decode, outstream=sys.stdout):
     """Print a tag to a stream.
 
     :param repo: A `Repo` object
     :param tag: A `Tag` object
+    :param decode: Function for decoding bytes to unicode string
     :param outstream: Stream to write to
     """
-    print_tag(tag, outstream)
+    print_tag(tag, decode, outstream)
     show_object(repo, repo[tag.object[1]], outstream)
 
 
-def show_object(repo, obj, outstream):
+def show_object(repo, obj, decode, outstream):
     return {
         b"tree": show_tree,
         b"blob": show_blob,
         b"commit": show_commit,
         b"tag": show_tag,
-            }[obj.type_name](repo, obj, outstream)
+            }[obj.type_name](repo, obj, decode, outstream)
 
 
 def log(repo=".", outstream=sys.stdout, max_entries=None):
@@ -390,15 +396,18 @@ def log(repo=".", outstream=sys.stdout, max_entries=None):
     with open_repo_closing(repo) as r:
         walker = r.get_walker(max_entries=max_entries)
         for entry in walker:
-            print_commit(entry.commit, outstream)
+            decode = lambda x: commit_decode(entry.commit, x)
+            print_commit(entry.commit, decode, outstream)
 
 
-def show(repo=".", objects=None, outstream=sys.stdout):
+# TODO(jelmer): better default for encoding?
+def show(repo=".", objects=None, outstream=sys.stdout, default_encoding='utf-8'):
     """Print the changes in a commit.
 
     :param repo: Path to repository
     :param objects: Objects to show (defaults to [HEAD])
     :param outstream: Stream to write to
+    :param default_encoding: Default encoding to use if none is set in the commit
     """
     if objects is None:
         objects = ["HEAD"]
@@ -406,7 +415,12 @@ def show(repo=".", objects=None, outstream=sys.stdout):
         objects = [objects]
     with open_repo_closing(repo) as r:
         for objectish in objects:
-            show_object(r, parse_object(r, objectish), outstream)
+            o = parse_object(r, objectish)
+            if isinstance(o, Commit):
+                decode = lambda x: commit_decode(o, x, default_encoding)
+            else:
+                decode = lambda x: x.decode(default_encoding)
+            show_object(r, o, decode, outstream)
 
 
 def diff_tree(repo, old_tree, new_tree, outstream=sys.stdout):

+ 14 - 10
dulwich/tests/test_porcelain.py

@@ -20,6 +20,10 @@
 
 from contextlib import closing
 from io import BytesIO
+try:
+    from StringIO import StringIO
+except ImportError:
+    from io import StringIO
 import os
 import shutil
 import tarfile
@@ -232,17 +236,17 @@ class LogTests(PorcelainTestCase):
         c1, c2, c3 = build_commit_graph(self.repo.object_store, [[1], [2, 1],
             [3, 1, 2]])
         self.repo.refs[b"HEAD"] = c3.id
-        outstream = BytesIO()
+        outstream = StringIO()
         porcelain.log(self.repo.path, outstream=outstream)
-        self.assertEqual(3, outstream.getvalue().count(b"-" * 50))
+        self.assertEqual(3, outstream.getvalue().count("-" * 50))
 
     def test_max_entries(self):
         c1, c2, c3 = build_commit_graph(self.repo.object_store, [[1], [2, 1],
             [3, 1, 2]])
         self.repo.refs[b"HEAD"] = c3.id
-        outstream = BytesIO()
+        outstream = StringIO()
         porcelain.log(self.repo.path, outstream=outstream, max_entries=1)
-        self.assertEqual(1, outstream.getvalue().count(b"-" * 50))
+        self.assertEqual(1, outstream.getvalue().count("-" * 50))
 
 
 class ShowTests(PorcelainTestCase):
@@ -251,24 +255,24 @@ class ShowTests(PorcelainTestCase):
         c1, c2, c3 = build_commit_graph(self.repo.object_store, [[1], [2, 1],
             [3, 1, 2]])
         self.repo.refs[b"HEAD"] = c3.id
-        outstream = BytesIO()
+        outstream = StringIO()
         porcelain.show(self.repo.path, objects=c3.id, outstream=outstream)
-        self.assertTrue(outstream.getvalue().startswith(b"-" * 50))
+        self.assertTrue(outstream.getvalue().startswith("-" * 50))
 
     def test_simple(self):
         c1, c2, c3 = build_commit_graph(self.repo.object_store, [[1], [2, 1],
             [3, 1, 2]])
         self.repo.refs[b"HEAD"] = c3.id
-        outstream = BytesIO()
+        outstream = StringIO()
         porcelain.show(self.repo.path, objects=[c3.id], outstream=outstream)
-        self.assertTrue(outstream.getvalue().startswith(b"-" * 50))
+        self.assertTrue(outstream.getvalue().startswith("-" * 50))
 
     def test_blob(self):
         b = Blob.from_string(b"The Foo\n")
         self.repo.object_store.add_object(b)
-        outstream = BytesIO()
+        outstream = StringIO()
         porcelain.show(self.repo.path, objects=[b.id], outstream=outstream)
-        self.assertEqual(outstream.getvalue(), b"The Foo\n")
+        self.assertEqual(outstream.getvalue(), "The Foo\n")
 
 
 class SymbolicRefTests(PorcelainTestCase):