Sfoglia il codice sorgente

Merge pull request #1195 from jelmer/enum-stage

Define enum for stage values
Jelmer Vernooij 1 anno fa
parent
commit
9853838097
2 ha cambiato i file con 63 aggiunte e 59 eliminazioni
  1. 61 58
      dulwich/index.py
  2. 2 1
      dulwich/tests/test_index.py

+ 61 - 58
dulwich/index.py

@@ -25,6 +25,7 @@ import os
 import stat
 import stat
 import struct
 import struct
 import sys
 import sys
+from enum import Enum
 from typing import (
 from typing import (
     Any,
     Any,
     BinaryIO,
     BinaryIO,
@@ -90,19 +91,20 @@ EXTENDED_FLAG_INTEND_TO_ADD = 0x2000
 DEFAULT_VERSION = 2
 DEFAULT_VERSION = 2
 
 
 
 
+class Stage(Enum):
+    NORMAL = 0
+    MERGE_CONFLICT_ANCESTOR = 1
+    MERGE_CONFLICT_THIS = 2
+    MERGE_CONFLICT_OTHER = 3
+
+
 class UnmergedEntriesInIndexEx(Exception):
 class UnmergedEntriesInIndexEx(Exception):
     def __init__(self, message):
     def __init__(self, message):
         super().__init__(message)
         super().__init__(message)
 
 
 
 
-def read_stage(entry: IndexEntry) -> int:
-    """Stage of an Entry
-       0 - normal
-       1 - merge conflict 'ancestor' entry
-       2 - merge conflict 'this' entry
-       3 - merge conflict 'other' entry
-     """
-    return (entry.flags & FLAG_STAGEMASK) >> FLAG_STAGESHIFT
+def read_stage(entry: IndexEntry) -> Stage:
+    return Stage((entry.flags & FLAG_STAGEMASK) >> FLAG_STAGESHIFT)
 
 
 
 
 def pathsplit(path: bytes) -> Tuple[bytes, bytes]:
 def pathsplit(path: bytes) -> Tuple[bytes, bytes]:
@@ -155,7 +157,7 @@ def write_cache_time(f, t):
     f.write(struct.pack(">LL", *t))
     f.write(struct.pack(">LL", *t))
 
 
 
 
-def read_cache_entry(f, version: int) -> Tuple[str, IndexEntry]:
+def read_cache_entry(f, version: int) -> Tuple[bytes, IndexEntry]:
     """Read an entry from a cache file.
     """Read an entry from a cache file.
 
 
     Args:
     Args:
@@ -260,12 +262,12 @@ def read_index(f: BinaryIO):
         yield read_cache_entry(f, version)
         yield read_cache_entry(f, version)
 
 
 
 
-def read_index_dict(f) -> Dict[Tuple[bytes, int], IndexEntry]:
+def read_index_dict(f) -> Dict[Tuple[bytes, Stage], IndexEntry]:
     """Read an index file and return it as a dictionary.
     """Read an index file and return it as a dictionary.
        Dict Key is tuple of path and stage number, as
        Dict Key is tuple of path and stage number, as
             path alone is not unique
             path alone is not unique
     Args:
     Args:
-      f: File object to read fromls
+      f: File object to read fromls.
     """
     """
     ret = {}
     ret = {}
     for name, entry in read_index(f):
     for name, entry in read_index(f):
@@ -292,11 +294,11 @@ def write_index(f: BinaryIO, entries: List[Tuple[bytes, IndexEntry]], version: O
 
 
 def write_index_dict(
 def write_index_dict(
     f: BinaryIO,
     f: BinaryIO,
-    entries: Dict[Tuple[bytes, int], IndexEntry],
+    entries: Dict[Tuple[bytes, Stage], IndexEntry],
     version: Optional[int] = None,
     version: Optional[int] = None,
 ) -> None:
 ) -> None:
     """Write an index file based on the contents of a dictionary.
     """Write an index file based on the contents of a dictionary.
-       being careful to sort by path and then by stage
+    being careful to sort by path and then by stage.
     """
     """
     entries_list = []
     entries_list = []
     for key in sorted(entries):
     for key in sorted(entries):
@@ -304,7 +306,7 @@ def write_index_dict(
             name, stage = key
             name, stage = key
         else:
         else:
             name = key
             name = key
-            stage = 0
+            stage = Stage.NORMAL
         entries_list.append((name, entries[(name, stage)]))
         entries_list.append((name, entries[(name, stage)]))
     write_index(f, entries_list, version=version)
     write_index(f, entries_list, version=version)
 
 
@@ -335,6 +337,8 @@ def cleanup_mode(mode: int) -> int:
 class Index:
 class Index:
     """A Git Index file."""
     """A Git Index file."""
 
 
+    _bynamestage: Dict[Tuple[bytes, Stage], IndexEntry]
+
     def __init__(self, filename: Union[bytes, str], read=True) -> None:
     def __init__(self, filename: Union[bytes, str], read=True) -> None:
         """Create an index object associated with the given filename.
         """Create an index object associated with the given filename.
 
 
@@ -385,7 +389,7 @@ class Index:
         """Number of entries in this index file."""
         """Number of entries in this index file."""
         return len(self._bynamestage)
         return len(self._bynamestage)
 
 
-    def __getitem__(self, key: Union[Tuple[bytes, int], bytes]) -> IndexEntry:
+    def __getitem__(self, key: Union[Tuple[bytes, Stage], bytes]) -> IndexEntry:
         """Retrieve entry by relative path and stage.
         """Retrieve entry by relative path and stage.
 
 
         Returns: tuple with (ctime, mtime, dev, ino, mode, uid, gid, size, sha,
         Returns: tuple with (ctime, mtime, dev, ino, mode, uid, gid, size, sha,
@@ -393,32 +397,32 @@ class Index:
         """
         """
         if isinstance(key, tuple):
         if isinstance(key, tuple):
             return self._bynamestage[key]
             return self._bynamestage[key]
-        if (key, 0) in self._bynamestage:
-            return self._bynamestage[(key, 0)]
+        if (key, Stage.NORMAL) in self._bynamestage:
+            return self._bynamestage[(key, Stage.NORMAL)]
         # there is a conflict return 'this' entry
         # there is a conflict return 'this' entry
-        return self._bynamestage[(key, 2)]
+        return self._bynamestage[(key, Stage.MERGE_CONFLICT_THIS)]
 
 
     def __iter__(self) -> Iterator[bytes]:
     def __iter__(self) -> Iterator[bytes]:
         """Iterate over the paths and stages in this index."""
         """Iterate over the paths and stages in this index."""
         for (name, stage) in self._bynamestage:
         for (name, stage) in self._bynamestage:
-            if stage == 1 or stage == 3:
+            if stage == Stage.MERGE_CONFLICT_ANCESTOR or stage == Stage.MERGE_CONFLICT_OTHER:
                 continue
                 continue
             yield name
             yield name
 
 
     def __contains__(self, key):
     def __contains__(self, key):
         if isinstance(key, tuple):
         if isinstance(key, tuple):
             return key in self._bynamestage
             return key in self._bynamestage
-        if (key, 0) in self._bynamestage:
+        if (key, Stage.NORMAL) in self._bynamestage:
             return True
             return True
-        if (key, 2) in self._bynamestage:
+        if (key, Stage.MERGE_CONFLICT_THIS) in self._bynamestage:
             return True
             return True
         return False
         return False
-    
-    def get_sha1(self, path: bytes, stage: int = 0) -> bytes:
+
+    def get_sha1(self, path: bytes, stage: Stage = Stage.NORMAL) -> bytes:
         """Return the (git object) SHA1 for the object at a path."""
         """Return the (git object) SHA1 for the object at a path."""
         return self[(path, stage)].sha
         return self[(path, stage)].sha
 
 
-    def get_mode(self, path: bytes, stage: int = 0) -> int:
+    def get_mode(self, path: bytes, stage: Stage = Stage.NORMAL) -> int:
         """Return the POSIX file mode for the object at a path."""
         """Return the POSIX file mode for the object at a path."""
         return self[(path, stage)].mode
         return self[(path, stage)].mode
 
 
@@ -428,15 +432,15 @@ class Index:
             entry = self[path]
             entry = self[path]
             yield path, entry.sha, cleanup_mode(entry.mode)
             yield path, entry.sha, cleanup_mode(entry.mode)
 
 
-    def iterconflicts(self) -> Iterable[Tuple[int, bytes, int, bytes]]:
+    def iterconflicts(self) -> Iterable[Tuple[int, bytes, Stage, bytes]]:
         """Iterate over path, sha, mode tuples for use with commit_tree."""
         """Iterate over path, sha, mode tuples for use with commit_tree."""
         for (name, stage), entry in self._bynamestage.items():
         for (name, stage), entry in self._bynamestage.items():
-            if stage > 0:
+            if stage != Stage.NORMAL:
                 yield cleanup_mode(entry.mode), entry.sha, stage, name
                 yield cleanup_mode(entry.mode), entry.sha, stage, name
 
 
     def has_conflicts(self):
     def has_conflicts(self):
         for (name, stage) in self._bynamestage.keys():
         for (name, stage) in self._bynamestage.keys():
-            if stage > 0:
+            if stage != Stage.NORMAL:
                 return True
                 return True
         return False
         return False
 
 
@@ -452,66 +456,66 @@ class Index:
                            sha,
                            sha,
                            stage << FLAG_STAGESHIFT,
                            stage << FLAG_STAGESHIFT,
                            0)
                            0)
-        if (apath, 0) in self._bynamestage:
-            del self._bynamestage[(apath, 0)]
+        if (apath, Stage.NORMAL) in self._bynamestage:
+            del self._bynamestage[(apath, Stage.NORMAL)]
         self._bynamestage[(apath, stage)] = entry
         self._bynamestage[(apath, stage)] = entry
 
 
     def clear(self):
     def clear(self):
         """Remove all contents from this index."""
         """Remove all contents from this index."""
         self._bynamestage = {}
         self._bynamestage = {}
 
 
-    def __setitem__(self, key: Union[Tuple[bytes, int], bytes], x: IndexEntry) -> None:
+    def __setitem__(self, key: Union[Tuple[bytes, Stage], bytes], x: IndexEntry) -> None:
         assert len(x) == len(IndexEntry._fields)
         assert len(x) == len(IndexEntry._fields)
         if isinstance(key, tuple):
         if isinstance(key, tuple):
             name, stage = key
             name, stage = key
         else:
         else:
             name = key
             name = key
-            stage = 0  # default when stage not explicitly specified
+            stage = Stage.NORMAL  # default when stage not explicitly specified
         assert isinstance(name, bytes)
         assert isinstance(name, bytes)
         # Remove merge conflict entries if new entry is stage 0
         # Remove merge conflict entries if new entry is stage 0
-        # Remove stage 0 entry if new entry has conflicts (stage > 0)
-        if stage == 0:
-            if (name, 1) in self._bynamestage:
-                del self._bynamestage[(name, 1)]
-            if (name, 2) in self._bynamestage:
-                del self._bynamestage[(name, 2)]
-            if (name, 3) in self._bynamestage:
-                del self._bynamestage[(name, 3)]
-        if stage > 0 and (name, 0) in self._bynamestage:
-            del self._bynamestage[(name, 0)]
+        # Remove normal stage entry if new entry has conflicts (stage > 0)
+        if stage == Stage.NORMAL:
+            if (name, Stage.MERGE_CONFLICT_ANCESTOR) in self._bynamestage:
+                del self._bynamestage[(name, Stage.MERGE_CONFLICT_ANCESTOR)]
+            if (name, Stage.MERGE_CONFLICT_THIS) in self._bynamestage:
+                del self._bynamestage[(name, Stage.MERGE_CONFLICT_THIS)]
+            if (name, Stage.MERGE_CONFLICT_OTHER) in self._bynamestage:
+                del self._bynamestage[(name, Stage.MERGE_CONFLICT_OTHER)]
+        if stage != Stage.NORMAL and (name, Stage.NORMAL) in self._bynamestage:
+            del self._bynamestage[(name, Stage.NORMAL)]
         self._bynamestage[(name, stage)] = IndexEntry(*x)
         self._bynamestage[(name, stage)] = IndexEntry(*x)
 
 
-    def __delitem__(self, key: Union[Tuple[bytes, int], bytes]) -> None:
+    def __delitem__(self, key: Union[Tuple[bytes, Stage], bytes]) -> None:
         if isinstance(key, tuple):
         if isinstance(key, tuple):
             del self._bynamestage[key]
             del self._bynamestage[key]
             return
             return
         name = key
         name = key
         assert isinstance(name, bytes)
         assert isinstance(name, bytes)
-        if (name, 0) in self._bynamestage:
-            del self._bynamestage[(name, 0)]
-        if (name, 1) in self._bynamestage:
-            del self._bynamestage[(name, 1)]
-        if (name, 2) in self._bynamestage:
-            del self._bynamestage[(name, 2)]
-        if (name, 3) in self._bynamestage:
-            del self._bynamestage[(name, 3)]
+        if (name, Stage.NORMAL) in self._bynamestage:
+            del self._bynamestage[(name, Stage.NORMAL)]
+        if (name, Stage.MERGE_CONFLICT_ANCESTOR) in self._bynamestage:
+            del self._bynamestage[(name, Stage.MERGE_CONFLICT_ANCESTOR)]
+        if (name, Stage.MERGE_CONFLICT_THIS) in self._bynamestage:
+            del self._bynamestage[(name, Stage.MERGE_CONFLICT_THIS)]
+        if (name, Stage.MERGE_CONFLICT_OTHER) in self._bynamestage:
+            del self._bynamestage[(name, Stage.MERGE_CONFLICT_OTHER)]
 
 
     def iteritems(self) -> Iterator[Tuple[bytes, IndexEntry]]:
     def iteritems(self) -> Iterator[Tuple[bytes, IndexEntry]]:
         for (name, stage), entry in self._bynamestage.items():
         for (name, stage), entry in self._bynamestage.items():
             yield name, entry
             yield name, entry
 
 
-    def items(self) -> Iterator[Tuple[Tuple[bytes, int], IndexEntry]]:
-        return self._bynamestage.items()
+    def items(self) -> Iterator[Tuple[Tuple[bytes, Stage], IndexEntry]]:
+        return iter(self._bynamestage.items())
 
 
-    def update(self, entries: Dict[Tuple[bytes, int], IndexEntry]):
+    def update(self, entries: Dict[Tuple[bytes, Stage], IndexEntry]):
         for key, value in entries.items():
         for key, value in entries.items():
             self[key] = value
             self[key] = value
 
 
     def paths(self):
     def paths(self):
         for (name, stage) in self._bynamestage.keys():
         for (name, stage) in self._bynamestage.keys():
-            if stage == 0 or stage == 2:  # normal or conflict 'this'
+            if stage == Stage.NORMAL or stage == Stage.MERGE_CONFLICT_THIS:
                 yield name
                 yield name
-    
+
     def changes_from_tree(
     def changes_from_tree(
             self, object_store, tree: ObjectID, want_unchanged: bool = False):
             self, object_store, tree: ObjectID, want_unchanged: bool = False):
         """Find the differences between the contents of this index and a tree.
         """Find the differences between the contents of this index and a tree.
@@ -561,7 +565,6 @@ def commit_tree(
     Returns:
     Returns:
       SHA1 of the created tree.
       SHA1 of the created tree.
     """
     """
-
     trees: Dict[bytes, Any] = {b"": {}}
     trees: Dict[bytes, Any] = {b"": {}}
 
 
     def add_tree(path):
     def add_tree(path):
@@ -855,7 +858,7 @@ def build_index_from_tree(
             st = st.__class__(st_tuple)
             st = st.__class__(st_tuple)
             # default to a stage 0 index entry (normal)
             # default to a stage 0 index entry (normal)
             # when reading from the filesystem
             # when reading from the filesystem
-        index[(entry.path, 0)] = index_entry_from_stat(st, entry.sha, 0)
+        index[(entry.path, Stage.NORMAL)] = index_entry_from_stat(st, entry.sha, 0)
 
 
     index.write()
     index.write()
 
 
@@ -960,7 +963,7 @@ def get_unstaged_changes(
     for tree_path, entry in index.iteritems():
     for tree_path, entry in index.iteritems():
         full_path = _tree_to_fs_path(root_path, tree_path)
         full_path = _tree_to_fs_path(root_path, tree_path)
         stage = read_stage(entry)
         stage = read_stage(entry)
-        if stage == 1 or stage == 3:
+        if stage == Stage.MERGE_CONFLICT_ANCESTOR or stage == Stage.MERGE_CONFLICT_OTHER:
             continue
             continue
         try:
         try:
             st = os.lstat(full_path)
             st = os.lstat(full_path)

+ 2 - 1
dulwich/tests/test_index.py

@@ -34,6 +34,7 @@ from dulwich.tests import TestCase, skipIf
 from ..index import (
 from ..index import (
     Index,
     Index,
     IndexEntry,
     IndexEntry,
+    Stage,
     _fs_to_tree_path,
     _fs_to_tree_path,
     _tree_to_fs_path,
     _tree_to_fs_path,
     build_index_from_tree,
     build_index_from_tree,
@@ -168,7 +169,7 @@ class ReadIndexDictTests(IndexTestCase):
 
 
     def test_simple_write(self):
     def test_simple_write(self):
         entries = {
         entries = {
-            (b"barbla", 0): IndexEntry(
+            (b"barbla", Stage.NORMAL): IndexEntry(
                 (1230680220, 0),
                 (1230680220, 0),
                 (1230680220, 0),
                 (1230680220, 0),
                 2050,
                 2050,