Parcourir la source

Improve error handling

Jelmer Vernooij il y a 5 jours
Parent
commit
557b776195
2 fichiers modifiés avec 61 ajouts et 48 suppressions
  1. 37 39
      dulwich/cli.py
  2. 24 9
      dulwich/porcelain.py

+ 37 - 39
dulwich/cli.py

@@ -69,6 +69,13 @@ class Command:
         raise NotImplementedError(self.run)
 
 
+class CommandError(Exception):
+    """An error occurred while running a command."""
+
+    def __init__(self, message: str) -> None:
+        self.message = message
+
+
 class cmd_archive(Command):
     def run(self, args) -> None:
         parser = argparse.ArgumentParser()
@@ -143,9 +150,9 @@ class cmd_fetch(Command):
             sys.stdout.buffer.write(msg)
 
         refs = client.fetch(path, r, progress=progress)
-        print("Remote refs:")
+        sys.stdout.write("Remote refs:\n")
         for item in refs.items():
-            print("{} -> {}".format(*item))
+            sys.stdout.write("{} -> {}\n".format(*item))
 
 
 class cmd_for_each_ref(Command):
@@ -212,8 +219,7 @@ class cmd_dump_pack(Command):
         opts, args = getopt(args, "", [])
 
         if args == []:
-            print("Usage: dulwich dump-pack FILENAME")
-            sys.exit(1)
+            raise CommandError("Usage: dulwich dump-pack FILENAME")
 
         basename, _ = os.path.splitext(args[0])
         x = Pack(basename)
@@ -235,8 +241,7 @@ class cmd_dump_index(Command):
         opts, args = getopt(args, "", [])
 
         if args == []:
-            print("Usage: dulwich dump-index FILENAME")
-            sys.exit(1)
+            raise CommandError("Usage: dulwich dump-index FILENAME")
 
         filename = args[0]
         idx = Index(filename)
@@ -298,8 +303,7 @@ class cmd_clone(Command):
         options, args = parser.parse_args(args)
 
         if args == []:
-            print("usage: dulwich clone host:path [PATH]")
-            sys.exit(1)
+            raise CommandError("usage: dulwich clone host:path [PATH]")
 
         source = args.pop(0)
         if len(args) > 0:
@@ -319,7 +323,7 @@ class cmd_clone(Command):
                 protocol_version=options.protocol,
             )
         except GitProtocolError as e:
-            print(f"{e}")
+            raise CommandError(str(e)) from e
 
 
 class cmd_commit(Command):
@@ -333,8 +337,7 @@ class cmd_commit_tree(Command):
     def run(self, args) -> None:
         opts, args = getopt(args, "", ["message="])
         if args == []:
-            print("usage: dulwich commit-tree tree")
-            sys.exit(1)
+            raise CommandError("usage: dulwich commit-tree tree") from e
         kwopts = dict(opts)
         porcelain.commit_tree(".", tree=args[0], message=kwopts["--message"])
 
@@ -348,8 +351,7 @@ class cmd_symbolic_ref(Command):
     def run(self, args) -> None:
         opts, args = getopt(args, "", ["ref-name", "force"])
         if not args:
-            print("Usage: dulwich symbolic-ref REF_NAME [--force]")
-            sys.exit(1)
+            raise CommandError("Usage: dulwich symbolic-ref REF_NAME [--force]")
 
         ref_name = args.pop(0)
         porcelain.symbolic_ref(".", ref_name=ref_name, force="--force" in args)
@@ -379,8 +381,7 @@ class cmd_diff_tree(Command):
     def run(self, args) -> None:
         opts, args = getopt(args, "", [])
         if len(args) < 2:
-            print("Usage: dulwich diff-tree OLD-TREE NEW-TREE")
-            sys.exit(1)
+            raise CommandError("Usage: dulwich diff-tree OLD-TREE NEW-TREE")
         porcelain.diff_tree(".", args[0], args[1])
 
 
@@ -388,8 +389,7 @@ class cmd_rev_list(Command):
     def run(self, args) -> None:
         opts, args = getopt(args, "", [])
         if len(args) < 1:
-            print("Usage: dulwich rev-list COMMITID...")
-            sys.exit(1)
+            raise CommandError("Usage: dulwich rev-list COMMITID...")
         porcelain.rev_list(".", args)
 
 
@@ -562,8 +562,7 @@ class cmd_ls_remote(Command):
     def run(self, args) -> None:
         opts, args = getopt(args, "", [])
         if len(args) < 1:
-            print("Usage: dulwich ls-remote URL")
-            sys.exit(1)
+            raise CommandError("Usage: dulwich ls-remote URL")
         refs = porcelain.ls_remote(args[0])
         for ref in sorted(refs):
             sys.stdout.write(f"{ref}\t{refs[ref]}\n")
@@ -600,8 +599,7 @@ class cmd_pack_objects(Command):
         opts, args = getopt(args, "", ["stdout", "deltify", "no-reuse-deltas"])
         kwopts = dict(opts)
         if len(args) < 1 and "--stdout" not in kwopts.keys():
-            print("Usage: dulwich pack-objects basename")
-            sys.exit(1)
+            raise CommandError("Usage: dulwich pack-objects basename")
         object_ids = [line.strip() for line in sys.stdin.readlines()]
         if "--deltify" in kwopts.keys():
             deltify = True
@@ -636,7 +634,7 @@ class cmd_pull(Command):
             args.from_location or None,
             args.refspec or None,
             filter_spec=args.filter,
-            protocol_version=args.protocol or None,
+            protocol_version=args.protocol
         )
 
 
@@ -651,9 +649,8 @@ class cmd_push(Command):
             porcelain.push(
                 ".", args.to_location, args.refspec or None, force=args.force
             )
-        except porcelain.DivergedBranches:
-            sys.stderr.write("Diverged branches; specify --force to override")
-            return 1
+        except porcelain.DivergedBranches as e:
+            raise CommandError("Diverged branches; specify --force to override") from e
 
         return None
 
@@ -678,9 +675,8 @@ class SuperCommand(Command):
         cmd = args[0]
         try:
             cmd_kls = self.subcommands[cmd]
-        except KeyError:
-            print(f"No such subcommand: {args[0]}")
-            return False
+        except KeyError as e:
+            raise CommandError(f"No such subcommand: {args[0]}") from e
         return cmd_kls().run(args[1:])
 
 
@@ -749,8 +745,7 @@ class cmd_branch(Command):
         )
         args = parser.parse_args(args)
         if not args.branch:
-            print("Usage: dulwich branch [-d] BRANCH_NAME")
-            sys.exit(1)
+            raise CommandError("Usage: dulwich branch [-d] BRANCH_NAME")
 
         if args.delete:
             porcelain.branch_delete(".", name=args.branch)
@@ -758,8 +753,7 @@ class cmd_branch(Command):
             try:
                 porcelain.branch_create(".", name=args.branch)
             except porcelain.Error as e:
-                sys.stderr.write(f"{e}")
-                sys.exit(1)
+                raise CommandError(f"{e}") from e
 
 
 class cmd_checkout(Command):
@@ -778,14 +772,12 @@ class cmd_checkout(Command):
         )
         args = parser.parse_args(args)
         if not args.branch:
-            print("Usage: dulwich checkout BRANCH_NAME [--force]")
-            sys.exit(1)
+            raise CommandError("Usage: dulwich checkout BRANCH_NAME [--force]")
 
         try:
             porcelain.checkout_branch(".", target=args.branch, force=args.force)
         except porcelain.CheckoutError as e:
-            sys.stderr.write(f"{e}\n")
-            sys.exit(1)
+            raise CommandError(f"{e}") from e
 
 
 class cmd_stash_list(Command):
@@ -916,17 +908,23 @@ def main(argv=None):
         argv = sys.argv[1:]
 
     if len(argv) < 1:
-        print("Usage: dulwich <{}> [OPTIONS...]".format("|".join(commands.keys())))
+        raise CommandError("Usage: dulwich <{}> [OPTIONS...]".format("|".join(commands.keys())))
         return 1
 
     cmd = argv[0]
     try:
         cmd_kls = commands[cmd]
     except KeyError:
-        print(f"No such subcommand: {cmd}")
+        sys.stderr.write(f"No such subcommand: {cmd}\n")
         return 1
     # TODO(jelmer): Return non-0 on errors
-    return cmd_kls().run(argv[1:])
+    try:
+        return cmd_kls().run(argv[1:])
+    except CommandError as e:
+        sys.stderr.write(f"Error: {e.message}\n")
+        return 1
+    except KeyboardInterrupt:
+        return 1
 
 
 def _main() -> None:

+ 24 - 9
dulwich/porcelain.py

@@ -729,7 +729,9 @@ def remove(repo=".", paths=None, cached=False) -> None:
 rm = remove
 
 
-def commit_decode(commit, contents, default_encoding=DEFAULT_ENCODING):
+def commit_decode(commit: Commit, contents: Optional[bytes], default_encoding: str = DEFAULT_ENCODING) -> str:
+    if contents is None:
+        return ""
     if commit.encoding:
         encoding = commit.encoding.decode("ascii")
     else:
@@ -737,7 +739,7 @@ def commit_decode(commit, contents, default_encoding=DEFAULT_ENCODING):
     return contents.decode(encoding, "replace")
 
 
-def commit_encode(commit, contents, default_encoding=DEFAULT_ENCODING):
+def commit_encode(commit: Commit, contents: str, default_encoding: str = DEFAULT_ENCODING) -> bytes:
     if commit.encoding:
         encoding = commit.encoding.decode("ascii")
     else:
@@ -745,7 +747,7 @@ def commit_encode(commit, contents, default_encoding=DEFAULT_ENCODING):
     return contents.encode(encoding)
 
 
-def print_commit(commit, decode, outstream=sys.stdout) -> None:
+def print_commit(commit: Commit, decode, outstream=sys.stdout) -> None:
     """Write a human-readable commit log entry.
 
     Args:
@@ -860,7 +862,7 @@ def show_object(repo, obj, decode, outstream):
     }[obj.type_name](repo, obj, decode, outstream)
 
 
-def print_name_status(changes):
+def print_name_status(changes, encoding: str = DEFAULT_ENCODING) -> None:
     """Print a simple status summary, listing changed files."""
     for change in changes:
         if not change:
@@ -869,15 +871,15 @@ def print_name_status(changes):
             change = change[0]
         if change.type == CHANGE_ADD:
             path1 = change.new.path
-            path2 = ""
+            path2 = b""
             kind = "A"
         elif change.type == CHANGE_DELETE:
             path1 = change.old.path
-            path2 = ""
+            path2 = b""
             kind = "D"
         elif change.type == CHANGE_MODIFY:
             path1 = change.new.path
-            path2 = ""
+            path2 = b""
             kind = "M"
         elif change.type in RENAME_CHANGE_TYPES:
             path1 = change.old.path
@@ -886,7 +888,7 @@ def print_name_status(changes):
                 kind = "R"
             elif change.type == CHANGE_COPY:
                 kind = "C"
-        yield "%-8s%-20s%-20s" % (kind, path1, path2)  # noqa: UP031
+        yield "%-8s%-20s%-20s" % (kind, path1.decode(encoding), path2.decode(encoding))  # noqa: UP031
 
 
 def log(
@@ -912,8 +914,21 @@ def log(
             include = [r.head()]
         except KeyError:
             include = []
+        if paths is None:
+            encoded_paths = None
+        else:
+            encoded_paths = []
+            for path in paths:
+                if isinstance(path, str):
+                    encoded_path = path.encode(DEFAULT_ENCODING)
+                elif isinstance(path, bytes):
+                    encoded_path = path
+                else:
+                    raise ValueError(f"Invalid path type: {type(path)}")
+                encoded_paths.append(encoded_path)
         walker = r.get_walker(
-            include=include, max_entries=max_entries, paths=paths, reverse=reverse
+            include=include, max_entries=max_entries, paths=encoded_paths,
+            reverse=reverse
         )
         for entry in walker: