瀏覽代碼

Allow passing in unicode strings to config objects.

Jelmer Vernooij 7 年之前
父節點
當前提交
5a8ba757f4
共有 2 個文件被更改,包括 31 次插入27 次删除
  1. 15 14
      dulwich/config.py
  2. 16 13
      dulwich/repo.py

+ 15 - 14
dulwich/config.py

@@ -28,6 +28,7 @@ TODO:
 
 import errno
 import os
+import sys
 
 from collections import (
     Iterable,
@@ -39,9 +40,6 @@ from collections import (
 from dulwich.file import GitFile
 
 
-# TODO(jelmer): Allow passing in unicode strings; default encoding to
-# sys.getdefaultencoding()
-
 SENTINAL = object()
 
 
@@ -177,8 +175,11 @@ class Config(object):
 class ConfigDict(Config, MutableMapping):
     """Git configuration stored in a dictionary."""
 
-    def __init__(self, values=None):
+    def __init__(self, values=None, encoding=None):
         """Create a new ConfigDict."""
+        if encoding is None:
+            encoding = sys.getdefaultencoding()
+        self.encoding = encoding
         self._values = CaseInsensitiveDict.make(values)
 
     def __repr__(self):
@@ -212,14 +213,18 @@ class ConfigDict(Config, MutableMapping):
         else:
             return (parts[0], None, parts[1])
 
-    @staticmethod
-    def _check_section_and_name(section, name):
+    def _check_section_and_name(self, section, name):
         if not isinstance(section, tuple):
             section = (section, )
-        if not all([isinstance(subsection, bytes) for subsection in section]):
-            raise TypeError(section)
+
+        section = tuple([
+            subsection.encode(self.encoding)
+            if not isinstance(subsection, bytes) else subsection
+            for subsection in section
+            ])
+
         if not isinstance(name, bytes):
-            raise TypeError(name)
+            name = name.encode(self.encoding)
 
         return section, name
 
@@ -238,7 +243,7 @@ class ConfigDict(Config, MutableMapping):
         section, name = self._check_section_and_name(section, name)
 
         if type(value) not in (bool, bytes):
-            raise TypeError(value)
+            value = value.encode(self.encoding)
 
         self._values.setdefault(section)[name] = value
 
@@ -510,10 +515,6 @@ class StackedConfig(Config):
     def get(self, section, name):
         if not isinstance(section, tuple):
             section = (section, )
-        if not all([isinstance(subsection, bytes) for subsection in section]):
-            raise TypeError(section)
-        if not isinstance(name, bytes):
-            raise TypeError(name)
         for backend in self.backends:
             try:
                 return backend.get(section, name)

+ 16 - 13
dulwich/repo.py

@@ -195,14 +195,14 @@ class BaseRepo(object):
         self._put_named_file('description', b"Unnamed repository")
         f = BytesIO()
         cf = ConfigFile()
-        cf.set(b"core", b"repositoryformatversion", b"0")
+        cf.set("core", "repositoryformatversion", "0")
         if self._determine_file_mode():
-            cf.set(b"core", b"filemode", True)
+            cf.set("core", "filemode", True)
         else:
-            cf.set(b"core", b"filemode", False)
+            cf.set("core", "filemode", False)
 
-        cf.set(b"core", b"bare", bare)
-        cf.set(b"core", b"logallrefupdates", True)
+        cf.set("core", "bare", bare)
+        cf.set("core", "logallrefupdates", True)
         cf.write_to_file(f)
         self._put_named_file('config', f.getvalue())
         self._put_named_file(os.path.join('info', 'exclude'), b'')
@@ -522,21 +522,22 @@ class BaseRepo(object):
         config = self.get_config_stack()
         if user is None:
             try:
-                user = config.get((b"user", ), b"name")
+                user = config.get(("user", ), "name")
             except KeyError:
                 user = None
         if email is None:
             try:
-                email = config.get((b"user", ), b"email")
+                email = config.get(("user", ), "email")
             except KeyError:
                 email = None
         if user is None:
             import getpass
-            user = getpass.getuser()
+            user = getpass.getuser().encode(sys.getdefaultencoding())
         if email is None:
             import getpass
             import socket
-            email = b"%s@%s" % (getpass.getuser(), socket.gethostname())
+            email = ("%s@%s" % (getpass.getuser(), socket.gethostname())
+                    ).encode(sys.getdefaultencoding())
         return (user + b" <" + email + b">")
 
     def _add_graftpoints(self, updated_graftpoints):
@@ -753,7 +754,9 @@ class Repo(BaseRepo):
     def _write_reflog(self, ref, old_sha, new_sha, committer, timestamp,
                       timezone, message):
         from .reflog import format_reflog_line
-        path = os.path.join(self.controldir(), 'logs', ref)
+        path = os.path.join(
+                self.controldir(), 'logs',
+                ref.decode(sys.getfilesystemencoding()))
         try:
             os.makedirs(os.path.dirname(path))
         except OSError as e:
@@ -956,9 +959,9 @@ class Repo(BaseRepo):
         except KeyError:
             pass
         target_config = target.get_config()
-        target_config.set((b'remote', b'origin'), b'url', encoded_path)
-        target_config.set((b'remote', b'origin'), b'fetch',
-                          b'+refs/heads/*:refs/remotes/origin/*')
+        target_config.set(('remote', 'origin'), 'url', encoded_path)
+        target_config.set(('remote', 'origin'), 'fetch',
+                          '+refs/heads/*:refs/remotes/origin/*')
         target_config.write_to_path()
 
         # Update target head