Преглед изворни кода

Split out a separate ConflictedIndexEntry class

This should make the changes transparent to existing API users (so long
as they don't work in trees with conflicts). It also prevents repeated
dictionary access.
Jelmer Vernooij пре 1 година
родитељ
комит
52328a607b
2 измењених фајлова са 90 додато и 151 уклоњено
  1. 90 117
      dulwich/index.py
  2. 0 34
      dulwich/tests/test_index.py

+ 90 - 117
dulwich/index.py

@@ -71,6 +71,27 @@ IndexEntry = collections.namedtuple(
 )
 
 
+class ConflictedIndexEntry:
+    """Index entry for a path that has unmerged (conflicted) stages.
+
+    Each attribute holds the IndexEntry recorded for one merge stage, or
+    None when no entry exists for that stage.
+    """
+
+    # Filled from Stage.MERGE_CONFLICT_ANCESTOR entries.
+    ancestor: Optional[IndexEntry]
+    # Filled from Stage.MERGE_CONFLICT_THIS entries.
+    this: Optional[IndexEntry]
+    # Filled from Stage.MERGE_CONFLICT_OTHER entries.
+    other: Optional[IndexEntry]
+
+    def __init__(self) -> None:
+        """Create a conflict record with no stage populated yet."""
+        self.ancestor = None
+        self.this = None
+        self.other = None
+
+    def entries(self) -> Iterable[IndexEntry]:
+        """Yield the populated stage entries.
+
+        Entries are produced in ancestor, this, other order; stages that
+        are None are skipped.
+        """
+        # Test against None explicitly (as write_index_dict does) so that
+        # a stored entry is never dropped merely for being falsy.
+        if self.ancestor is not None:
+            yield self.ancestor
+        if self.this is not None:
+            yield self.this
+        if self.other is not None:
+            yield self.other
+
+
 # 2-bit stage (during merge)
 FLAG_STAGEMASK = 0x3000
 FLAG_STAGESHIFT = 12
@@ -90,7 +111,6 @@ EXTENDED_FLAG_INTEND_TO_ADD = 0x2000
 
 DEFAULT_VERSION = 2
 
-
 class Stage(Enum):
     NORMAL = 0
     MERGE_CONFLICT_ANCESTOR = 1
@@ -98,9 +118,8 @@ class Stage(Enum):
     MERGE_CONFLICT_OTHER = 3
 
 
-class UnmergedEntriesInIndexEx(Exception):
-    def __init__(self, message):
-        super().__init__(message)
+class UnmergedEntries(Exception):
+    """Raised when the index still contains unmerged (conflicted) entries."""
 
 
 def read_stage(entry: IndexEntry) -> Stage:
@@ -250,7 +269,7 @@ class UnsupportedIndexFormat(Exception):
         self.index_format_version = version
 
 
-def read_index(f: BinaryIO):
+def read_index(f: BinaryIO) -> Iterator[Tuple[bytes, IndexEntry]]:
     """Read an index file, yielding the individual entries."""
     header = f.read(4)
     if header != b"DIRC":
@@ -262,20 +281,6 @@ def read_index(f: BinaryIO):
         yield read_cache_entry(f, version)
 
 
-def read_index_dict(f) -> Dict[Tuple[bytes, Stage], IndexEntry]:
-    """Read an index file and return it as a dictionary.
-       Dict Key is tuple of path and stage number, as
-            path alone is not unique
-    Args:
-      f: File object to read fromls.
-    """
-    ret = {}
-    for name, entry in read_index(f):
-        stage = read_stage(entry)
-        ret[(name, stage)] = entry
-    return ret
-
-
 def write_index(f: BinaryIO, entries: List[Tuple[bytes, IndexEntry]], version: Optional[int] = None):
     """Write an index file.
 
@@ -294,7 +299,7 @@ def write_index(f: BinaryIO, entries: List[Tuple[bytes, IndexEntry]], version: O
 
 def write_index_dict(
     f: BinaryIO,
-    entries: Dict[Tuple[bytes, Stage], IndexEntry],
+    entries: Dict[bytes, IndexEntry | ConflictedIndexEntry],
     version: Optional[int] = None,
 ) -> None:
     """Write an index file based on the contents of a dictionary.
@@ -302,12 +307,16 @@ def write_index_dict(
     """
     entries_list = []
     for key in sorted(entries):
-        if isinstance(key, tuple):
-            name, stage = key
+        value = entries[key]
+        if isinstance(value, ConflictedIndexEntry):
+            if value.ancestor is not None:
+                entries_list.append((key, value.ancestor))
+            if value.this is not None:
+                entries_list.append((key, value.this))
+            if value.other is not None:
+                entries_list.append((key, value.other))
         else:
-            name = key
-            stage = Stage.NORMAL
-        entries_list.append((name, entries[(name, stage)]))
+            entries_list.append((key, value))
     write_index(f, entries_list, version=version)
 
 
@@ -337,7 +346,7 @@ def cleanup_mode(mode: int) -> int:
 class Index:
     """A Git Index file."""
 
-    _bynamestage: Dict[Tuple[bytes, Stage], IndexEntry]
+    _byname: Dict[bytes, IndexEntry | ConflictedIndexEntry]
 
     def __init__(self, filename: Union[bytes, str], read=True) -> None:
         """Create an index object associated with the given filename.
@@ -365,7 +374,7 @@ class Index:
         f = GitFile(self._filename, "wb")
         try:
             f = SHA1Writer(f)
-            write_index_dict(f, self._bynamestage, version=self._version)
+            write_index_dict(f, self._byname, version=self._version)
         finally:
             f.close()
 
@@ -378,7 +387,19 @@ class Index:
             f = SHA1Reader(f)
             for name, entry in read_index(f):
                 stage = read_stage(entry)
-                self[(name, stage)] = entry
+                if stage == Stage.NORMAL:
+                    self[name] = entry
+                else:
+                    # Non-normal stages of one path share a single ConflictedIndexEntry.
+                    existing = self._byname.setdefault(name, ConflictedIndexEntry())
+                    if isinstance(existing, IndexEntry):
+                        raise AssertionError("Non-conflicted entry for %r exists" % name)
+                    if stage == Stage.MERGE_CONFLICT_ANCESTOR:
+                        existing.ancestor = entry
+                    elif stage == Stage.MERGE_CONFLICT_THIS:
+                        existing.this = entry
+                    elif stage == Stage.MERGE_CONFLICT_OTHER:
+                        existing.other = entry
             # FIXME: Additional data?
             f.read(os.path.getsize(self._filename) - f.tell() - 20)
             f.check_sha()
@@ -387,60 +408,42 @@ class Index:
 
     def __len__(self) -> int:
         """Number of entries in this index file."""
-        return len(self._bynamestage)
+        return len(self._byname)
 
-    def __getitem__(self, key: Union[Tuple[bytes, Stage], bytes]) -> IndexEntry:
+    def __getitem__(self, key: bytes) -> IndexEntry | ConflictedIndexEntry:
-        """Retrieve entry by relative path and stage.
+        """Retrieve the entry stored for a relative path.
 
-        Returns: tuple with (ctime, mtime, dev, ino, mode, uid, gid, size, sha,
-            flags)
+        Returns: the IndexEntry for the path, or a ConflictedIndexEntry
+            when the path has unmerged stages
+        Raises:
+          KeyError: if the path is not in the index
         """
-        if isinstance(key, tuple):
-            return self._bynamestage[key]
-        if (key, Stage.NORMAL) in self._bynamestage:
-            return self._bynamestage[(key, Stage.NORMAL)]
-        # there is a conflict return 'this' entry
-        return self._bynamestage[(key, Stage.MERGE_CONFLICT_THIS)]
+        return self._byname[key]
 
     def __iter__(self) -> Iterator[bytes]:
-        """Iterate over the paths and stages in this index."""
+        """Iterate over the paths in this index (one per path, conflicted or not)."""
-        for (name, stage) in self._bynamestage:
-            if stage == Stage.MERGE_CONFLICT_ANCESTOR or stage == Stage.MERGE_CONFLICT_OTHER:
-                continue
-            yield name
+        return iter(self._byname)
 
     def __contains__(self, key):
-        if isinstance(key, tuple):
-            return key in self._bynamestage
-        if (key, Stage.NORMAL) in self._bynamestage:
-            return True
-        if (key, Stage.MERGE_CONFLICT_THIS) in self._bynamestage:
-            return True
-        return False
+        return key in self._byname
 
-    def get_sha1(self, path: bytes, stage: Stage = Stage.NORMAL) -> bytes:
+    def get_sha1(self, path: bytes) -> bytes:
-        """Return the (git object) SHA1 for the object at a path."""
+        """Return the (git object) SHA1 for the object at a path.
+
+        Raises:
+          KeyError: if the path is not in the index
+          UnmergedEntries: if the path is conflicted, in which case there
+            is no single SHA1 to report
+        """
-        return self[(path, stage)].sha
+        entry = self[path]
+        # A ConflictedIndexEntry has no .sha; fail the same way iterobjects
+        # does instead of leaking an AttributeError.
+        if isinstance(entry, ConflictedIndexEntry):
+            raise UnmergedEntries()
+        return entry.sha
 
-    def get_mode(self, path: bytes, stage: Stage = Stage.NORMAL) -> int:
+    def get_mode(self, path: bytes) -> int:
-        """Return the POSIX file mode for the object at a path."""
+        """Return the POSIX file mode for the object at a path.
+
+        Raises:
+          KeyError: if the path is not in the index
+          UnmergedEntries: if the path is conflicted, in which case there
+            is no single mode to report
+        """
-        return self[(path, stage)].mode
+        entry = self[path]
+        # A ConflictedIndexEntry has no .mode; fail the same way iterobjects
+        # does instead of leaking an AttributeError.
+        if isinstance(entry, ConflictedIndexEntry):
+            raise UnmergedEntries()
+        return entry.mode
 
     def iterobjects(self) -> Iterable[Tuple[bytes, bytes, int]]:
         """Iterate over path, sha, mode tuples for use with commit_tree."""
         for path in self:
             entry = self[path]
+            if isinstance(entry, ConflictedIndexEntry):
+                raise UnmergedEntries()
             yield path, entry.sha, cleanup_mode(entry.mode)
 
-    def iterconflicts(self) -> Iterable[Tuple[int, bytes, Stage, bytes]]:
-        """Iterate over path, sha, mode tuples for use with commit_tree."""
-        for (name, stage), entry in self._bynamestage.items():
-            if stage != Stage.NORMAL:
-                yield cleanup_mode(entry.mode), entry.sha, stage, name
-
-    def has_conflicts(self):
-        for (name, stage) in self._bynamestage.keys():
-            if stage != Stage.NORMAL:
+    def has_conflicts(self) -> bool:
+        for value in self._byname.values():
+            if isinstance(value, ConflictedIndexEntry):
                 return True
         return False
 
@@ -456,65 +459,36 @@ class Index:
                            sha,
                            stage << FLAG_STAGESHIFT,
                            0)
-        if (apath, Stage.NORMAL) in self._bynamestage:
-            del self._bynamestage[(apath, Stage.NORMAL)]
-        self._bynamestage[(apath, stage)] = entry
+        if (apath, Stage.NORMAL) in self._byname:
+            del self._byname[(apath, Stage.NORMAL)]
+        self._byname[(apath, stage)] = entry
 
     def clear(self):
         """Remove all contents from this index."""
-        self._bynamestage = {}
+        self._byname = {}
 
-    def __setitem__(self, key: Union[Tuple[bytes, Stage], bytes], x: IndexEntry) -> None:
-        assert len(x) == len(IndexEntry._fields)
-        if isinstance(key, tuple):
-            name, stage = key
-        else:
-            name = key
-            stage = Stage.NORMAL  # default when stage not explicitly specified
-        assert isinstance(name, bytes)
-        # Remove merge conflict entries if new entry is stage 0
-        # Remove normal stage entry if new entry has conflicts (stage > 0)
-        if stage == Stage.NORMAL:
-            if (name, Stage.MERGE_CONFLICT_ANCESTOR) in self._bynamestage:
-                del self._bynamestage[(name, Stage.MERGE_CONFLICT_ANCESTOR)]
-            if (name, Stage.MERGE_CONFLICT_THIS) in self._bynamestage:
-                del self._bynamestage[(name, Stage.MERGE_CONFLICT_THIS)]
-            if (name, Stage.MERGE_CONFLICT_OTHER) in self._bynamestage:
-                del self._bynamestage[(name, Stage.MERGE_CONFLICT_OTHER)]
-        if stage != Stage.NORMAL and (name, Stage.NORMAL) in self._bynamestage:
-            del self._bynamestage[(name, Stage.NORMAL)]
-        self._bynamestage[(name, stage)] = IndexEntry(*x)
-
-    def __delitem__(self, key: Union[Tuple[bytes, Stage], bytes]) -> None:
-        if isinstance(key, tuple):
-            del self._bynamestage[key]
-            return
-        name = key
+    def __setitem__(self, name: bytes, value: IndexEntry | ConflictedIndexEntry) -> None:
         assert isinstance(name, bytes)
-        if (name, Stage.NORMAL) in self._bynamestage:
-            del self._bynamestage[(name, Stage.NORMAL)]
-        if (name, Stage.MERGE_CONFLICT_ANCESTOR) in self._bynamestage:
-            del self._bynamestage[(name, Stage.MERGE_CONFLICT_ANCESTOR)]
-        if (name, Stage.MERGE_CONFLICT_THIS) in self._bynamestage:
-            del self._bynamestage[(name, Stage.MERGE_CONFLICT_THIS)]
-        if (name, Stage.MERGE_CONFLICT_OTHER) in self._bynamestage:
-            del self._bynamestage[(name, Stage.MERGE_CONFLICT_OTHER)]
-
-    def iteritems(self) -> Iterator[Tuple[bytes, IndexEntry]]:
-        for (name, stage), entry in self._bynamestage.items():
-            yield name, entry
-
-    def items(self) -> Iterator[Tuple[Tuple[bytes, Stage], IndexEntry]]:
-        return iter(self._bynamestage.items())
-
-    def update(self, entries: Dict[Tuple[bytes, Stage], IndexEntry]):
+        if not isinstance(value, (IndexEntry, ConflictedIndexEntry)):
+            value = IndexEntry(*value)
+        self._byname[name] = value
+
+    def __delitem__(self, name: bytes) -> None:
+        del self._byname[name]
+
+    def iteritems(self) -> Iterator[Tuple[bytes, IndexEntry | ConflictedIndexEntry]]:
+        return iter(self._byname.items())
+
+    def items(self) -> Iterator[Tuple[bytes, IndexEntry | ConflictedIndexEntry]]:
+        return iter(self._byname.items())
+
+    def update(self, entries: Dict[bytes, IndexEntry]):
         for key, value in entries.items():
             self[key] = value
 
     def paths(self):
-        for (name, stage) in self._bynamestage.keys():
-            if stage == Stage.NORMAL or stage == Stage.MERGE_CONFLICT_THIS:
-                yield name
+        for name in self._byname.keys():
+            yield name
 
     def changes_from_tree(
             self, object_store, tree: ObjectID, want_unchanged: bool = False):
@@ -548,9 +522,6 @@ class Index:
         Returns:
           Root tree SHA
         """
-        # as with git check for unmerged entries in the index and fail if found
-        if self.has_conflicts():
-            raise UnmergedEntriesInIndexEx('Unmerged entries exist in index these need to be handled first')
         return commit_tree(object_store, self.iterobjects())
 
 
@@ -858,7 +829,7 @@ def build_index_from_tree(
             st = st.__class__(st_tuple)
             # default to a stage 0 index entry (normal)
             # when reading from the filesystem
-        index[(entry.path, Stage.NORMAL)] = index_entry_from_stat(st, entry.sha, 0)
+        index[entry.path] = index_entry_from_stat(st, entry.sha, 0)
 
     index.write()
 
@@ -962,9 +933,11 @@ def get_unstaged_changes(
 
     for tree_path, entry in index.iteritems():
         full_path = _tree_to_fs_path(root_path, tree_path)
-        stage = read_stage(entry)
-        if stage == Stage.MERGE_CONFLICT_ANCESTOR or stage == Stage.MERGE_CONFLICT_OTHER:
+        if isinstance(entry, ConflictedIndexEntry):
+            # Conflicted files are always unstaged
+            yield tree_path
             continue
+
         try:
             st = os.lstat(full_path)
             if stat.S_ISDIR(st.st_mode):
@@ -1143,7 +1116,7 @@ class locked_index:
             return
         try:
             f = SHA1Writer(self._file)
-            write_index_dict(f, self._index._bynamestage)
+            write_index_dict(f, self._index._byname)
         except BaseException:
             self._file.abort()
         else:

+ 0 - 34
dulwich/tests/test_index.py

@@ -43,7 +43,6 @@ from ..index import (
     get_unstaged_changes,
     index_entry_from_stat,
     read_index,
-    read_index_dict,
     validate_path_element_default,
     validate_path_element_ntfs,
     write_cache_time,
@@ -158,39 +157,6 @@ class SimpleIndexWriterTestCase(IndexTestCase):
             self.assertEqual(entries, list(read_index(x)))
 
 
-class ReadIndexDictTests(IndexTestCase):
-    def setUp(self):
-        IndexTestCase.setUp(self)
-        self.tempdir = tempfile.mkdtemp()
-
-    def tearDown(self):
-        IndexTestCase.tearDown(self)
-        shutil.rmtree(self.tempdir)
-
-    def test_simple_write(self):
-        entries = {
-            (b"barbla", Stage.NORMAL): IndexEntry(
-                (1230680220, 0),
-                (1230680220, 0),
-                2050,
-                3761020,
-                33188,
-                1000,
-                1000,
-                0,
-                b"e69de29bb2d1d6434b8b29ae775ad8c2e48c5391",
-                0,
-                0,
-            )
-        }
-        filename = os.path.join(self.tempdir, "test-simple-write-index")
-        with open(filename, "wb+") as x:
-            write_index_dict(x, entries)
-
-        with open(filename, "rb") as x:
-            self.assertEqual(entries, read_index_dict(x))
-
-
 class CommitTreeTests(TestCase):
     def setUp(self):
         super().setUp()