Преглед на файлове

parse and serialize graftpoints

milki преди 11 години
родител
ревизия
c024478d90
променени са 1 файла, в които са добавени 91 реда и са изтрити 0 реда
  1. 91 0
      dulwich/repo.py

+ 91 - 0
dulwich/repo.py

@@ -51,6 +51,7 @@ from dulwich.object_store import (
     ObjectStoreGraphWalker,
     )
 from dulwich.objects import (
+    check_hexsha,
     Blob,
     Commit,
     ShaFile,
@@ -96,6 +97,57 @@ BASE_DIRECTORIES = [
     ]
 
 
+def parse_graftpoints(graftpoints=iter([])):
+    """Convert a list of graftpoints into a dict
+
+    :param graftpoints: Iterator of graftpoint lines
+
+    Each line is formatted as:
+        <commit sha1> <parent sha1> [<parent sha1>]*
+
+    Resulting dictionary is:
+        <commit sha1>: [<parent sha1>*]
+
+    https://git.wiki.kernel.org/index.php/GraftPoint
+    """
+    grafts = {}
+    for l in graftpoints:
+        raw_graft = l.split(None, 1)
+
+        commit = raw_graft[0]
+        if len(raw_graft) == 2:
+            parents = raw_graft[1].split()
+        else:
+            parents = []
+
+        for sha in [commit] + parents:
+            check_hexsha(sha, 'Invalid graftpoint')
+
+        grafts[commit] = parents
+    return grafts
+
+
+def serialize_graftpoints(graftpoints={}):
+    """Convert a dictionary of grafts into string
+
+    The graft dictionary is:
+        <commit sha1>: [<parent sha1>*]
+
+    Each line is formatted as:
+        <commit sha1> <parent sha1> [<parent sha1>]*
+
+    https://git.wiki.kernel.org/index.php/GraftPoint
+
+    """
+    graft_lines = []
+    for commit, parents in graftpoints.iteritems():
+        if parents:
+            graft_lines.append('%s %s' % (commit, ' '.join(parents)))
+        else:
+            graft_lines.append(commit)
+    return '\n'.join(graft_lines)
+
+
 class BaseRepo(object):
     """Base class for a git repository.
 
@@ -117,6 +169,7 @@ class BaseRepo(object):
         self.object_store = object_store
         self.refs = refs
 
+        self.graftpoints = {}
         self.hooks = {}
 
     def _init_files(self, bare):
@@ -478,6 +531,34 @@ class BaseRepo(object):
             config.get(("user", ), "name"),
             config.get(("user", ), "email"))
 
+    def add_graftpoints(self, updated_graftpoints):
+        """Add or modify graftpoints
+
+        :param updated_graftpoints: Dict of commit shas to list of parent shas
+        """
+
+        # Simple validation
+        for commit, parents in updated_graftpoints.iteritems():
+            for sha in [commit] + parents:
+                check_hexsha(sha, 'Invalid graftpoint')
+
+        self.graftpoints.update(updated_graftpoints)
+
+    def remove_graftpoints(self, to_remove=[]):
+        """Remove graftpoints
+
+        :param to_remove: List of commit shas
+        """
+        for sha in to_remove:
+            del self.graftpoints[sha]
+
+    def serialize_graftpoints(self):
+        """Get the string representation of the graftpoints
+
+        This format is writable to a graftpoint file.
+        """
+        return serialize_graftpoints(self.graftpoints)
+
     def do_commit(self, message=None, committer=None,
                   author=None, commit_timestamp=None,
                   commit_timezone=None, author_timestamp=None,
@@ -620,6 +701,10 @@ class Repo(BaseRepo):
         refs = DiskRefsContainer(self.controldir())
         BaseRepo.__init__(self, object_store, refs)
 
+        graft_file = self.get_named_file(os.path.join("info", "grafts"))
+        if graft_file:
+            self.graftpoints = parse_graftpoints(graft_file)
+
         self.hooks['pre-commit'] = PreCommitShellHook(self.controldir())
         self.hooks['commit-msg'] = CommitMsgShellHook(self.controldir())
         self.hooks['post-commit'] = PostCommitShellHook(self.controldir())
@@ -641,6 +726,9 @@ class Repo(BaseRepo):
         finally:
             f.close()
 
+        if path == os.path.join("info", "grafts"):
+            self.graftpoints = parse_graftpoints(iter(contents.splitlines()))
+
     def get_named_file(self, path):
         """Get a file from the control dir with a specific name.
 
@@ -867,6 +955,9 @@ class MemoryRepo(BaseRepo):
         """
         self._named_files[path] = contents
 
+        if path == os.path.join("info", "grafts"):
+            self.graftpoints = parse_graftpoints(contents.splitlines())
+
     def get_named_file(self, path):
         """Get a file from the control dir with a specific name.