Browse Source

Add some typing, clarify behaviour of author argument to Repo.do_commit(). Fixes #786

Jelmer Vernooij 4 năm trước
mục cha
commit
3082274e39
3 tập tin đã thay đổi với 86 bổ sung và 48 xóa
  1. 13 6
      dulwich/config.py
  2. 0 1
      dulwich/porcelain.py
  3. 73 41
      dulwich/repo.py

+ 13 - 6
dulwich/config.py

@@ -29,6 +29,8 @@ TODO:
 import os
 import sys
 
+from typing import BinaryIO, Tuple, Optional
+
 from collections import (
     OrderedDict,
     )
@@ -380,12 +382,17 @@ class ConfigFile(ConfigDict):
     """A Git configuration file, like .git/config or ~/.gitconfig.
     """
 
+    def __init__(self, values=None, encoding=None):
+        super(ConfigFile, self).__init__(values=values, encoding=encoding)
+        self.path = None
+
     @classmethod
-    def from_file(cls, f):
+    def from_file(cls, f: BinaryIO) -> 'ConfigFile':
         """Read configuration from a file-like object."""
         ret = cls()
-        section = None
+        section = None  # type: Optional[Tuple[bytes, ...]]
         setting = None
+        continuation = None
         for lineno, line in enumerate(f.readlines()):
             line = line.lstrip()
             if setting is None:
@@ -429,7 +436,7 @@ class ConfigFile(ConfigDict):
                     value = b"true"
                 setting = setting.strip()
                 if not _check_variable_name(setting):
-                    raise ValueError("invalid variable name %s" % setting)
+                    raise ValueError("invalid variable name %r" % setting)
                 if value.endswith(b"\\\n"):
                     continuation = value[:-2]
                 else:
@@ -449,21 +456,21 @@ class ConfigFile(ConfigDict):
         return ret
 
     @classmethod
-    def from_path(cls, path):
+    def from_path(cls, path) -> 'ConfigFile':
         """Read configuration from a file on disk."""
         with GitFile(path, 'rb') as f:
             ret = cls.from_file(f)
             ret.path = path
             return ret
 
-    def write_to_path(self, path=None):
+    def write_to_path(self, path=None) -> None:
         """Write configuration to a file on disk."""
         if path is None:
             path = self.path
         with GitFile(path, 'wb') as f:
             self.write_to_file(f)
 
-    def write_to_file(self, f):
+    def write_to_file(self, f: BinaryIO) -> None:
         """Write configuration to a file-like object."""
         for section, values in self._values.items():
             try:

+ 0 - 1
dulwich/porcelain.py

@@ -928,7 +928,6 @@ def get_remote_repo(
         encoded_location = url
     else:
         remote_name = None
-        config = None
 
     return (remote_name, encoded_location.decode())
 

+ 73 - 41
dulwich/repo.py

@@ -33,6 +33,14 @@ import os
 import sys
 import stat
 import time
+from typing import Optional, Tuple, TYPE_CHECKING, List, Dict, Union, Iterable
+
+if TYPE_CHECKING:
+    # There are no circular imports here, but we try to defer imports as long
+    # as possible to reduce start-up time for anything that doesn't need
+    # these imports.
+    from dulwich.config import StackedConfig, ConfigFile
+    from dulwich.index import Index
 
 from dulwich.errors import (
     NoIndexPresent,
@@ -51,6 +59,7 @@ from dulwich.file import (
 from dulwich.object_store import (
     DiskObjectStore,
     MemoryObjectStore,
+    BaseObjectStore,
     ObjectStoreGraphWalker,
     )
 from dulwich.objects import (
@@ -66,6 +75,7 @@ from dulwich.pack import (
     )
 
 from dulwich.hooks import (
+    Hook,
     PreCommitShellHook,
     PostCommitShellHook,
     CommitMsgShellHook,
@@ -120,7 +130,7 @@ class InvalidUserIdentity(Exception):
         self.identity = identity
 
 
-def _get_default_identity():
+def _get_default_identity() -> Tuple[str, str]:
     import getpass
     import socket
     username = getpass.getuser()
@@ -143,19 +153,38 @@ def _get_default_identity():
     return (fullname, email)
 
 
-def get_user_identity(config, kind=None):
+def get_user_identity(
+        config: 'StackedConfig',
+        kind: Optional[str] = None) -> bytes:
     """Determine the identity to use for new commits.
+
+    If kind is set, this first checks
+    GIT_${KIND}_NAME and GIT_${KIND}_EMAIL.
+
+    If those variables are not set, then it will fall back
+    to reading the user.name and user.email settings from
+    the specified configuration.
+
+    If that also fails, then it will fall back to using
+    the current users' identity as obtained from the host
+    system (e.g. the gecos field, $EMAIL, $USER@$(hostname -f).
+
+    Args:
+      kind: Optional kind to return identity for,
+        usually either "AUTHOR" or "COMMITTER".
+
+    Returns:
+      A user identity
     """
+    user = None  # type: Optional[bytes]
+    email = None  # type: Optional[bytes]
     if kind:
-        user = os.environ.get("GIT_" + kind + "_NAME")
-        if user is not None:
-            user = user.encode('utf-8')
-        email = os.environ.get("GIT_" + kind + "_EMAIL")
-        if email is not None:
-            email = email.encode('utf-8')
-    else:
-        user = None
-        email = None
+        user_uc = os.environ.get("GIT_" + kind + "_NAME")
+        if user_uc is not None:
+            user = user_uc.encode('utf-8')
+        email_uc = os.environ.get("GIT_" + kind + "_EMAIL")
+        if email_uc is not None:
+            email = email_uc.encode('utf-8')
     if user is None:
         try:
             user = config.get(("user", ), "name")
@@ -168,16 +197,12 @@ def get_user_identity(config, kind=None):
             email = None
     default_user, default_email = _get_default_identity()
     if user is None:
-        user = default_user
-        if not isinstance(user, bytes):
-            user = user.encode('utf-8')
+        user = default_user.encode('utf-8')
     if email is None:
-        email = default_email
-        if not isinstance(email, bytes):
-            email = email.encode('utf-8')
+        email = default_email.encode('utf-8')
     if email.startswith(b'<') and email.endswith(b'>'):
         email = email[1:-1]
-    return (user + b" <" + email + b">")
+    return user + b" <" + email + b">"
 
 
 def check_user_identity(identity):
@@ -196,7 +221,8 @@ def check_user_identity(identity):
         raise InvalidUserIdentity(identity)
 
 
-def parse_graftpoints(graftpoints):
+def parse_graftpoints(
+        graftpoints: Iterable[bytes]) -> Dict[bytes, List[bytes]]:
     """Convert a list of graftpoints into a dict
 
     Args:
@@ -227,7 +253,7 @@ def parse_graftpoints(graftpoints):
     return grafts
 
 
-def serialize_graftpoints(graftpoints):
+def serialize_graftpoints(graftpoints: Dict[bytes, List[bytes]]) -> bytes:
     """Convert a dictionary of grafts into string
 
     The graft dictionary is:
@@ -279,7 +305,7 @@ class BaseRepo(object):
         repository
     """
 
-    def __init__(self, object_store, refs):
+    def __init__(self, object_store: BaseObjectStore, refs: RefsContainer):
         """Open a repository.
 
         This shouldn't be called directly, but rather through one of the
@@ -292,17 +318,17 @@ class BaseRepo(object):
         self.object_store = object_store
         self.refs = refs
 
-        self._graftpoints = {}
-        self.hooks = {}
+        self._graftpoints = {}  # type: Dict[bytes, List[bytes]]
+        self.hooks = {}  # type: Dict[str, Hook]
 
-    def _determine_file_mode(self):
+    def _determine_file_mode(self) -> bool:
         """Probe the file-system to determine whether permissions can be trusted.
 
         Returns: True if permissions can be trusted, False otherwise.
         """
         raise NotImplementedError(self._determine_file_mode)
 
-    def _init_files(self, bare):
+    def _init_files(self, bare: bool) -> None:
         """Initialize a default set of named files."""
         from dulwich.config import ConfigFile
         self._put_named_file('description', b"Unnamed repository")
@@ -501,14 +527,14 @@ class BaseRepo(object):
         return ObjectStoreGraphWalker(
             heads, self.get_parents, shallow=self.get_shallow())
 
-    def get_refs(self):
+    def get_refs(self) -> Dict[bytes, bytes]:
         """Get dictionary with all refs.
 
         Returns: A ``dict`` mapping ref names to SHA1s
         """
         return self.refs.as_dict()
 
-    def head(self):
+    def head(self) -> bytes:
         """Return the SHA1 pointed at by HEAD."""
         return self.refs[b'HEAD']
 
@@ -529,7 +555,7 @@ class BaseRepo(object):
                   ret.type_name, cls.type_name))
         return ret
 
-    def get_object(self, sha):
+    def get_object(self, sha: bytes) -> ShaFile:
         """Retrieve the object with the specified SHA.
 
         Args:
@@ -540,7 +566,7 @@ class BaseRepo(object):
         """
         return self.object_store[sha]
 
-    def get_parents(self, sha, commit=None):
+    def get_parents(self, sha: bytes, commit: Commit = None) -> List[bytes]:
         """Retrieve the parents of a specific commit.
 
         If the specific commit is a graftpoint, the graft parents
@@ -551,7 +577,6 @@ class BaseRepo(object):
           commit: Optional commit matching the sha
         Returns: List of parents
         """
-
         try:
             return self._graftpoints[sha]
         except KeyError:
@@ -582,7 +607,7 @@ class BaseRepo(object):
         """
         raise NotImplementedError(self.set_description)
 
-    def get_config_stack(self):
+    def get_config_stack(self) -> 'StackedConfig':
         """Return a config stack for this repository.
 
         This stack accesses the configuration for both this repository
@@ -695,7 +720,7 @@ class BaseRepo(object):
         except RefFormatError:
             raise KeyError(name)
 
-    def __contains__(self, name):
+    def __contains__(self, name: bytes) -> bool:
         """Check if a specific Git object or ref is present.
 
         Args:
@@ -706,7 +731,7 @@ class BaseRepo(object):
         else:
             return name in self.refs
 
-    def __setitem__(self, name, value):
+    def __setitem__(self, name: bytes, value: Union[ShaFile, bytes]):
         """Set a ref.
 
         Args:
@@ -723,7 +748,7 @@ class BaseRepo(object):
         else:
             raise ValueError(name)
 
-    def __delitem__(self, name):
+    def __delitem__(self, name: bytes):
         """Remove a ref.
 
         Args:
@@ -734,13 +759,14 @@ class BaseRepo(object):
         else:
             raise ValueError(name)
 
-    def _get_user_identity(self, config, kind=None):
+    def _get_user_identity(
+            self, config: 'StackedConfig', kind: str = None) -> bytes:
         """Determine the identity to use for new commits.
         """
         # TODO(jelmer): Deprecate this function in favor of get_user_identity
         return get_user_identity(config)
 
-    def _add_graftpoints(self, updated_graftpoints):
+    def _add_graftpoints(self, updated_graftpoints: Dict[bytes, List[bytes]]):
         """Add or modify graftpoints
 
         Args:
@@ -754,7 +780,7 @@ class BaseRepo(object):
 
         self._graftpoints.update(updated_graftpoints)
 
-    def _remove_graftpoints(self, to_remove=[]):
+    def _remove_graftpoints(self, to_remove: List[bytes] = []) -> None:
         """Remove graftpoints
 
         Args:
@@ -777,10 +803,14 @@ class BaseRepo(object):
                   ref=b'HEAD', merge_heads=None):
         """Create a new commit.
 
+        If not specified, `committer` and `author` default to
+        get_user_identity(..., 'COMMITTER')
+        and get_user_identity(..., 'AUTHOR') respectively.
+
         Args:
           message: Commit message
           committer: Committer fullname
-          author: Author fullname (defaults to committer)
+          author: Author fullname
           commit_timestamp: Commit timestamp (defaults to now)
           commit_timezone: Commit timestamp timezone (defaults to GMT)
           author_timestamp: Author timestamp (defaults to commit
@@ -792,7 +822,9 @@ class BaseRepo(object):
           encoding: Encoding
           ref: Optional ref to commit to (defaults to current branch)
           merge_heads: Merge heads (defaults to .git/MERGE_HEADS)
-        Returns: New commit SHA1
+
+        Returns:
+          New commit SHA1
         """
         import time
         c = Commit()
@@ -1093,7 +1125,7 @@ class Repo(BaseRepo):
         """Return path to the index file."""
         return os.path.join(self.controldir(), INDEX_FILENAME)
 
-    def open_index(self):
+    def open_index(self) -> 'Index':
         """Open the index for this repository.
 
         Raises:
@@ -1241,7 +1273,7 @@ class Repo(BaseRepo):
             honor_filemode=honor_filemode,
             validate_path_element=validate_path_element)
 
-    def get_config(self):
+    def get_config(self) -> 'ConfigFile':
         """Retrieve the config object.
 
         Returns: `ConfigFile` object for the ``.git/config`` file.