Merge pull request #1196 from jelmer/index-refactor

More refactoring of index
Jelmer Vernooij 1 year ago
commit 3c84bc409d
4 changed files with 194 additions and 207 deletions
  1. dulwich/index.py (+175 -179)
  2. dulwich/porcelain.py (+1 -1)
  3. dulwich/repo.py (+1 -3)
  4. dulwich/tests/test_index.py (+17 -24)

+ 175 - 179
dulwich/index.py

@@ -20,11 +20,11 @@
 
 """Parser for the git index file format."""
 
-import collections
 import os
 import stat
 import struct
 import sys
+from dataclasses import dataclass
 from enum import Enum
 from typing import (
     Any,
@@ -52,25 +52,6 @@ from .objects import (
 )
 from .pack import ObjectContainer, SHA1Reader, SHA1Writer
 
-# TODO(jelmer): Switch to dataclass?
-IndexEntry = collections.namedtuple(
-    "IndexEntry",
-    [
-        "ctime",
-        "mtime",
-        "dev",
-        "ino",
-        "mode",
-        "uid",
-        "gid",
-        "size",
-        "sha",
-        "flags",
-        "extended_flags",
-    ],
-)
-
-
 # 2-bit stage (during merge)
 FLAG_STAGEMASK = 0x3000
 FLAG_STAGESHIFT = 12
@@ -98,13 +79,85 @@ class Stage(Enum):
     MERGE_CONFLICT_OTHER = 3
 
 
-class UnmergedEntriesInIndexEx(Exception):
-    def __init__(self, message):
-        super().__init__(message)
+@dataclass
+class SerializedIndexEntry:
+    name: bytes
+    ctime: Union[int, float, Tuple[int, int]]
+    mtime: Union[int, float, Tuple[int, int]]
+    dev: int
+    ino: int
+    mode: int
+    uid: int
+    gid: int
+    size: int
+    sha: bytes
+    flags: int
+    extended_flags: int
+
+    def stage(self) -> Stage:
+        return Stage((self.flags & FLAG_STAGEMASK) >> FLAG_STAGESHIFT)
+
+
+@dataclass
+class IndexEntry:
+    ctime: Union[int, float, Tuple[int, int]]
+    mtime: Union[int, float, Tuple[int, int]]
+    dev: int
+    ino: int
+    mode: int
+    uid: int
+    gid: int
+    size: int
+    sha: bytes
+
+    @classmethod
+    def from_serialized(cls, serialized: SerializedIndexEntry) -> "IndexEntry":
+        return cls(
+            ctime=serialized.ctime,
+            mtime=serialized.mtime,
+            dev=serialized.dev,
+            ino=serialized.ino,
+            mode=serialized.mode,
+            uid=serialized.uid,
+            gid=serialized.gid,
+            size=serialized.size,
+            sha=serialized.sha,
+        )
+
+    def serialize(self, name: bytes, stage: Stage) -> SerializedIndexEntry:
+        return SerializedIndexEntry(
+            name=name,
+            ctime=self.ctime,
+            mtime=self.mtime,
+            dev=self.dev,
+            ino=self.ino,
+            mode=self.mode,
+            uid=self.uid,
+            gid=self.gid,
+            size=self.size,
+            sha=self.sha,
+            flags=stage.value << FLAG_STAGESHIFT,
+            extended_flags=0,
+        )
+
+
+class ConflictedIndexEntry:
+    """Index entry that represents a conflict."""
+
+    ancestor: Optional[IndexEntry]
+    this: Optional[IndexEntry]
+    other: Optional[IndexEntry]
+
+    def __init__(self, ancestor: Optional[IndexEntry] = None,
+                 this: Optional[IndexEntry] = None,
+                 other: Optional[IndexEntry] = None) -> None:
+        self.ancestor = ancestor
+        self.this = this
+        self.other = other
 
 
-def read_stage(entry: IndexEntry) -> Stage:
-    return Stage((entry.flags & FLAG_STAGEMASK) >> FLAG_STAGESHIFT)
+class UnmergedEntries(Exception):
+    """Unmerged entries exist in the index."""
 
 
 def pathsplit(path: bytes) -> Tuple[bytes, bytes]:
@@ -157,13 +210,11 @@ def write_cache_time(f, t):
     f.write(struct.pack(">LL", *t))
 
 
-def read_cache_entry(f, version: int) -> Tuple[bytes, IndexEntry]:
+def read_cache_entry(f, version: int) -> SerializedIndexEntry:
     """Read an entry from a cache file.
 
     Args:
       f: File-like object to read from
-    Returns:
-      tuple with: name, IndexEntry
     """
     beginoffset = f.tell()
     ctime = read_cache_time(f)
@@ -190,24 +241,23 @@ def read_cache_entry(f, version: int) -> Tuple[bytes, IndexEntry]:
     if version < 4:
         real_size = (f.tell() - beginoffset + 8) & ~7
         f.read((beginoffset + real_size) - f.tell())
-    return (
+    return SerializedIndexEntry(
         name,
-        IndexEntry(
-            ctime,
-            mtime,
-            dev,
-            ino,
-            mode,
-            uid,
-            gid,
-            size,
-            sha_to_hex(sha),
-            flags & ~FLAG_NAMEMASK,
-            extended_flags,
-        ))
-
-
-def write_cache_entry(f, name: bytes, entry: IndexEntry, version: int) -> None:
+        ctime,
+        mtime,
+        dev,
+        ino,
+        mode,
+        uid,
+        gid,
+        size,
+        sha_to_hex(sha),
+        flags & ~FLAG_NAMEMASK,
+        extended_flags,
+    )
+
+
+def write_cache_entry(f, entry: SerializedIndexEntry, version: int) -> None:
     """Write an index entry to a file.
 
     Args:
@@ -217,7 +267,7 @@ def write_cache_entry(f, name: bytes, entry: IndexEntry, version: int) -> None:
     beginoffset = f.tell()
     write_cache_time(f, entry.ctime)
     write_cache_time(f, entry.mtime)
-    flags = len(name) | (entry.flags & ~FLAG_NAMEMASK)
+    flags = len(entry.name) | (entry.flags & ~FLAG_NAMEMASK)
     if entry.extended_flags:
         flags |= FLAG_EXTENDED
     if flags & FLAG_EXTENDED and version is not None and version < 3:
@@ -237,7 +287,7 @@ def write_cache_entry(f, name: bytes, entry: IndexEntry, version: int) -> None:
     )
     if flags & FLAG_EXTENDED:
         f.write(struct.pack(b">H", entry.extended_flags))
-    f.write(name)
+    f.write(entry.name)
     if version < 4:
         real_size = (f.tell() - beginoffset + 8) & ~7
         f.write(b"\0" * ((beginoffset + real_size) - f.tell()))
@@ -250,7 +300,7 @@ class UnsupportedIndexFormat(Exception):
         self.index_format_version = version
 
 
-def read_index(f: BinaryIO):
+def read_index(f: BinaryIO) -> Iterator[SerializedIndexEntry]:
     """Read an index file, yielding the individual entries."""
     header = f.read(4)
     if header != b"DIRC":
@@ -262,21 +312,32 @@ def read_index(f: BinaryIO):
         yield read_cache_entry(f, version)
 
 
-def read_index_dict(f) -> Dict[Tuple[bytes, Stage], IndexEntry]:
+def read_index_dict(f) -> Dict[bytes, Union[IndexEntry, ConflictedIndexEntry]]:
     """Read an index file and return it as a dictionary.
        Dict Key is tuple of path and stage number, as
             path alone is not unique
     Args:
       f: File object to read fromls.
     """
-    ret = {}
-    for name, entry in read_index(f):
-        stage = read_stage(entry)
-        ret[(name, stage)] = entry
+    ret: Dict[bytes, Union[IndexEntry, ConflictedIndexEntry]] = {}
+    for entry in read_index(f):
+        stage = entry.stage()
+        if stage == Stage.NORMAL:
+            ret[entry.name] = IndexEntry.from_serialized(entry)
+        else:
+            existing = ret.setdefault(entry.name, ConflictedIndexEntry())
+            if isinstance(existing, IndexEntry):
+                raise AssertionError("Non-conflicted entry for %r exists" % entry.name)
+            if stage == Stage.MERGE_CONFLICT_ANCESTOR:
+                existing.ancestor = IndexEntry.from_serialized(entry)
+            elif stage == Stage.MERGE_CONFLICT_THIS:
+                existing.this = IndexEntry.from_serialized(entry)
+            elif stage == Stage.MERGE_CONFLICT_OTHER:
+                existing.other = IndexEntry.from_serialized(entry)
     return ret
 
 
-def write_index(f: BinaryIO, entries: List[Tuple[bytes, IndexEntry]], version: Optional[int] = None):
+def write_index(f: BinaryIO, entries: List[SerializedIndexEntry], version: Optional[int] = None):
     """Write an index file.
 
     Args:
@@ -288,13 +349,13 @@ def write_index(f: BinaryIO, entries: List[Tuple[bytes, IndexEntry]], version: O
         version = DEFAULT_VERSION
     f.write(b"DIRC")
     f.write(struct.pack(b">LL", version, len(entries)))
-    for name, entry in entries:
-        write_cache_entry(f, name, entry, version)
+    for entry in entries:
+        write_cache_entry(f, entry, version)
 
 
 def write_index_dict(
     f: BinaryIO,
-    entries: Dict[Tuple[bytes, Stage], IndexEntry],
+    entries: Dict[bytes, Union[IndexEntry, ConflictedIndexEntry]],
     version: Optional[int] = None,
 ) -> None:
     """Write an index file based on the contents of a dictionary.
@@ -302,12 +363,16 @@ def write_index_dict(
     """
     entries_list = []
     for key in sorted(entries):
-        if isinstance(key, tuple):
-            name, stage = key
+        value = entries[key]
+        if isinstance(value, ConflictedIndexEntry):
+            if value.ancestor is not None:
+                entries_list.append(value.ancestor.serialize(key, Stage.MERGE_CONFLICT_ANCESTOR))
+            if value.this is not None:
+                entries_list.append(value.this.serialize(key, Stage.MERGE_CONFLICT_THIS))
+            if value.other is not None:
+                entries_list.append(value.other.serialize(key, Stage.MERGE_CONFLICT_OTHER))
         else:
-            name = key
-            stage = Stage.NORMAL
-        entries_list.append((name, entries[(name, stage)]))
+            entries_list.append(value.serialize(key, Stage.NORMAL))
     write_index(f, entries_list, version=version)
 
 
@@ -337,7 +402,7 @@ def cleanup_mode(mode: int) -> int:
 class Index:
     """A Git Index file."""
 
-    _bynamestage: Dict[Tuple[bytes, Stage], IndexEntry]
+    _byname: Dict[bytes, Union[IndexEntry, ConflictedIndexEntry]]
 
     def __init__(self, filename: Union[bytes, str], read=True) -> None:
         """Create an index object associated with the given filename.
@@ -365,7 +430,7 @@ class Index:
         f = GitFile(self._filename, "wb")
         try:
             f = SHA1Writer(f)
-            write_index_dict(f, self._bynamestage, version=self._version)
+            write_index_dict(f, self._byname, version=self._version)
         finally:
             f.close()
 
@@ -376,9 +441,7 @@ class Index:
         f = GitFile(self._filename, "rb")
         try:
             f = SHA1Reader(f)
-            for name, entry in read_index(f):
-                stage = read_stage(entry)
-                self[(name, stage)] = entry
+            self.update(read_index_dict(f))
             # FIXME: Additional data?
             f.read(os.path.getsize(self._filename) - f.tell() - 20)
             f.check_sha()
@@ -387,134 +450,74 @@
 
     def __len__(self) -> int:
         """Number of entries in this index file."""
-        return len(self._bynamestage)
+        return len(self._byname)
 
-    def __getitem__(self, key: Union[Tuple[bytes, Stage], bytes]) -> IndexEntry:
+    def __getitem__(self, key: bytes) -> Union[IndexEntry, ConflictedIndexEntry]:
         """Retrieve entry by relative path and stage.
 
         Returns: tuple with (ctime, mtime, dev, ino, mode, uid, gid, size, sha,
             flags)
         """
-        if isinstance(key, tuple):
-            return self._bynamestage[key]
-        if (key, Stage.NORMAL) in self._bynamestage:
-            return self._bynamestage[(key, Stage.NORMAL)]
-        # there is a conflict return 'this' entry
-        return self._bynamestage[(key, Stage.MERGE_CONFLICT_THIS)]
+        return self._byname[key]
 
     def __iter__(self) -> Iterator[bytes]:
         """Iterate over the paths and stages in this index."""
-        for (name, stage) in self._bynamestage:
-            if stage == Stage.MERGE_CONFLICT_ANCESTOR or stage == Stage.MERGE_CONFLICT_OTHER:
-                continue
-            yield name
+        return iter(self._byname)
 
     def __contains__(self, key):
-        if isinstance(key, tuple):
-            return key in self._bynamestage
-        if (key, Stage.NORMAL) in self._bynamestage:
-            return True
-        if (key, Stage.MERGE_CONFLICT_THIS) in self._bynamestage:
-            return True
-        return False
+        return key in self._byname
 
-    def get_sha1(self, path: bytes, stage: Stage = Stage.NORMAL) -> bytes:
+    def get_sha1(self, path: bytes) -> bytes:
         """Return the (git object) SHA1 for the object at a path."""
-        return self[(path, stage)].sha
+        value = self[path]
+        if isinstance(value, ConflictedIndexEntry):
+            raise UnmergedEntries()
+        return value.sha
 
-    def get_mode(self, path: bytes, stage: Stage = Stage.NORMAL) -> int:
+    def get_mode(self, path: bytes) -> int:
         """Return the POSIX file mode for the object at a path."""
-        return self[(path, stage)].mode
+        value = self[path]
+        if isinstance(value, ConflictedIndexEntry):
+            raise UnmergedEntries()
+        return value.mode
 
     def iterobjects(self) -> Iterable[Tuple[bytes, bytes, int]]:
         """Iterate over path, sha, mode tuples for use with commit_tree."""
         for path in self:
             entry = self[path]
+            if isinstance(entry, ConflictedIndexEntry):
+                raise UnmergedEntries()
             yield path, entry.sha, cleanup_mode(entry.mode)
 
-    def iterconflicts(self) -> Iterable[Tuple[int, bytes, Stage, bytes]]:
-        """Iterate over path, sha, mode tuples for use with commit_tree."""
-        for (name, stage), entry in self._bynamestage.items():
-            if stage != Stage.NORMAL:
-                yield cleanup_mode(entry.mode), entry.sha, stage, name
-
-    def has_conflicts(self):
-        for (name, stage) in self._bynamestage.keys():
-            if stage != Stage.NORMAL:
+    def has_conflicts(self) -> bool:
+        for value in self._byname.values():
+            if isinstance(value, ConflictedIndexEntry):
                 return True
         return False
 
-    def set_merge_conflict(self, apath, stage, mode, sha, time):
-        entry = IndexEntry(time,
-                           time,
-                           0,
-                           0,
-                           mode,
-                           0,
-                           0,
-                           0,
-                           sha,
-                           stage << FLAG_STAGESHIFT,
-                           0)
-        if (apath, Stage.NORMAL) in self._bynamestage:
-            del self._bynamestage[(apath, Stage.NORMAL)]
-        self._bynamestage[(apath, stage)] = entry
-
     def clear(self):
         """Remove all contents from this index."""
-        self._bynamestage = {}
+        self._byname = {}
 
-    def __setitem__(self, key: Union[Tuple[bytes, Stage], bytes], x: IndexEntry) -> None:
-        assert len(x) == len(IndexEntry._fields)
-        if isinstance(key, tuple):
-            name, stage = key
-        else:
-            name = key
-            stage = Stage.NORMAL  # default when stage not explicitly specified
+    def __setitem__(self, name: bytes, value: Union[IndexEntry, ConflictedIndexEntry]) -> None:
         assert isinstance(name, bytes)
-        # Remove merge conflict entries if new entry is stage 0
-        # Remove normal stage entry if new entry has conflicts (stage > 0)
-        if stage == Stage.NORMAL:
-            if (name, Stage.MERGE_CONFLICT_ANCESTOR) in self._bynamestage:
-                del self._bynamestage[(name, Stage.MERGE_CONFLICT_ANCESTOR)]
-            if (name, Stage.MERGE_CONFLICT_THIS) in self._bynamestage:
-                del self._bynamestage[(name, Stage.MERGE_CONFLICT_THIS)]
-            if (name, Stage.MERGE_CONFLICT_OTHER) in self._bynamestage:
-                del self._bynamestage[(name, Stage.MERGE_CONFLICT_OTHER)]
-        if stage != Stage.NORMAL and (name, Stage.NORMAL) in self._bynamestage:
-            del self._bynamestage[(name, Stage.NORMAL)]
-        self._bynamestage[(name, stage)] = IndexEntry(*x)
-
-    def __delitem__(self, key: Union[Tuple[bytes, Stage], bytes]) -> None:
-        if isinstance(key, tuple):
-            del self._bynamestage[key]
-            return
-        name = key
-        assert isinstance(name, bytes)
-        if (name, Stage.NORMAL) in self._bynamestage:
-            del self._bynamestage[(name, Stage.NORMAL)]
-        if (name, Stage.MERGE_CONFLICT_ANCESTOR) in self._bynamestage:
-            del self._bynamestage[(name, Stage.MERGE_CONFLICT_ANCESTOR)]
-        if (name, Stage.MERGE_CONFLICT_THIS) in self._bynamestage:
-            del self._bynamestage[(name, Stage.MERGE_CONFLICT_THIS)]
-        if (name, Stage.MERGE_CONFLICT_OTHER) in self._bynamestage:
-            del self._bynamestage[(name, Stage.MERGE_CONFLICT_OTHER)]
-
-    def iteritems(self) -> Iterator[Tuple[bytes, IndexEntry]]:
-        for (name, stage), entry in self._bynamestage.items():
-            yield name, entry
-
-    def items(self) -> Iterator[Tuple[Tuple[bytes, Stage], IndexEntry]]:
-        return iter(self._bynamestage.items())
-
-    def update(self, entries: Dict[Tuple[bytes, Stage], IndexEntry]):
+        self._byname[name] = value
+
+    def __delitem__(self, name: bytes) -> None:
+        del self._byname[name]
+
+    def iteritems(self) -> Iterator[Tuple[bytes, Union[IndexEntry, ConflictedIndexEntry]]]:
+        return iter(self._byname.items())
+
+    def items(self) -> Iterator[Tuple[bytes, Union[IndexEntry, ConflictedIndexEntry]]]:
+        return iter(self._byname.items())
+
+    def update(self, entries: Dict[bytes, Union[IndexEntry, ConflictedIndexEntry]]):
         for key, value in entries.items():
             self[key] = value
 
     def paths(self):
-        for (name, stage) in self._bynamestage.keys():
-            if stage == Stage.NORMAL or stage == Stage.MERGE_CONFLICT_THIS:
-                yield name
+        yield from self._byname.keys()
 
     def changes_from_tree(
             self, object_store, tree: ObjectID, want_unchanged: bool = False):
@@ -548,9 +551,6 @@ class Index:
         Returns:
           Root tree SHA
         """
-        # as with git check for unmerged entries in the index and fail if found
-        if self.has_conflicts():
-            raise UnmergedEntriesInIndexEx('Unmerged entries exist in index these need to be handled first')
         return commit_tree(object_store, self.iterobjects())
 
 
@@ -661,15 +661,13 @@
 
 
 def index_entry_from_stat(
-    stat_val, hex_sha: bytes, flags: int, mode: Optional[int] = None,
-    extended_flags: Optional[int] = None
+    stat_val, hex_sha: bytes, mode: Optional[int] = None,
 ):
     """Create a new index entry from a stat value.
 
     Args:
       stat_val: POSIX stat_result instance
       hex_sha: Hex sha of the object
-      flags: Index flags
     """
     if mode is None:
         mode = cleanup_mode(stat_val.st_mode)
@@ -684,8 +682,6 @@ def index_entry_from_stat(
         stat_val.st_gid,
         stat_val.st_size,
         hex_sha,
-        flags,
-        extended_flags
     )
 
 
@@ -858,7 +854,7 @@ def build_index_from_tree(
             st = st.__class__(st_tuple)
             # default to a stage 0 index entry (normal)
             # when reading from the filesystem
-        index[(entry.path, Stage.NORMAL)] = index_entry_from_stat(st, entry.sha, 0)
+        index[entry.path] = index_entry_from_stat(st, entry.sha)
 
     index.write()
 
@@ -962,9 +958,11 @@
 
     for tree_path, entry in index.iteritems():
         full_path = _tree_to_fs_path(root_path, tree_path)
-        stage = read_stage(entry)
-        if stage == Stage.MERGE_CONFLICT_ANCESTOR or stage == Stage.MERGE_CONFLICT_OTHER:
+        if isinstance(entry, ConflictedIndexEntry):
+            # Conflicted files are always unstaged
+            yield tree_path
             continue
+
         try:
             st = os.lstat(full_path)
             if stat.S_ISDIR(st.st_mode):
@@ -1032,7 +1030,7 @@ def index_entry_from_directory(st, path: bytes) -> Optional[IndexEntry]:
         head = read_submodule_head(path)
         if head is None:
             return None
-        return index_entry_from_stat(st, head, 0, mode=S_IFGITLINK)
+        return index_entry_from_stat(st, head, mode=S_IFGITLINK)
     return None
 
 
@@ -1060,7 +1058,7 @@ def index_entry_from_path(
         blob = blob_from_path_and_stat(path, st)
         if object_store is not None:
             object_store.add_object(blob)
-        return index_entry_from_stat(st, blob.id, 0)
+        return index_entry_from_stat(st, blob.id)
 
     return None
 
@@ -1105,7 +1103,6 @@ def iter_fresh_objects(
             if include_deleted:
                 yield path, None, None
         else:
-            entry = IndexEntry(*entry)
             yield path, entry.sha, cleanup_mode(entry.mode)
 
 
@@ -1120,8 +1117,7 @@ def refresh_index(index: Index, root_path: bytes):
     """
     for path, entry in iter_fresh_entries(index, root_path):
         if entry:
-            stage = read_stage(entry)
-            index[(path, stage)] = entry
+            index[path] = entry
 
 
 class locked_index:
@@ -1143,7 +1139,7 @@ class locked_index:
             return
         try:
             f = SHA1Writer(self._file)
-            write_index_dict(f, self._index._bynamestage)
+            write_index_dict(f, self._index._byname)
         except BaseException:
             self._file.abort()
         else:

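For readers skimming the diff, here is a brief illustrative sketch (not part of the commit) of how the refactored dulwich.index.Index API reads after this change; the repository path is hypothetical.

```python
from dulwich.index import ConflictedIndexEntry, Index, UnmergedEntries

# Hypothetical path; any repository's .git/index behaves the same way.
index = Index("/path/to/repo/.git/index")

# Entries are now keyed by path alone; a merge conflict is carried by a
# single ConflictedIndexEntry value instead of extra (path, stage) keys.
for path, entry in index.items():
    if isinstance(entry, ConflictedIndexEntry):
        print(path, entry.ancestor, entry.this, entry.other)
    else:
        print(path, entry.sha, entry.mode)

# Accessors now raise UnmergedEntries instead of silently picking a stage.
for path in index.paths():
    try:
        print(path, index.get_sha1(path))
    except UnmergedEntries:
        print(path, "is conflicted")
```

The design change this sketch reflects: conflict stages are folded into one value per path rather than being spread across separate (path, stage) dictionary keys.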
+ 1 - 1
dulwich/porcelain.py

@@ -1970,7 +1970,7 @@ def checkout_branch(repo, target: Union[bytes, str], force: bool = False):
             blob = repo.object_store[entry.sha]
             ensure_dir_exists(os.path.dirname(full_path))
             st = build_file_from_blob(blob, entry.mode, full_path)
-            repo_index[entry.path] = index_entry_from_stat(st, entry.sha, 0)
+            repo_index[entry.path] = index_entry_from_stat(st, entry.sha)
 
         repo_index.write()
 

+ 1 - 3
dulwich/repo.py

@@ -1422,7 +1422,7 @@ class Repo(BaseRepo):
                     blob = blob_from_path_and_stat(full_path, st)
                     blob = blob_normalizer.checkin_normalize(blob, fs_path)
                     self.object_store.add_object(blob)
-                    index[tree_path] = index_entry_from_stat(st, blob.id, 0)
+                    index[tree_path] = index_entry_from_stat(st, blob.id)
         index.write()
 
     def unstage(self, fs_paths: List[str]):
@@ -1478,8 +1478,6 @@ class Repo(BaseRepo):
                 gid=st.st_gid if st else 0,
                 size=len(self[tree_entry[1]].data),
                 sha=tree_entry[1],
-                flags=0,
-                extended_flags=0
             )
 
             index[tree_path] = index_entry

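Both call sites above (porcelain.checkout_branch and Repo.stage/unstage) now rely on the trimmed-down index_entry_from_stat signature, which no longer accepts flags or extended_flags. A minimal, hedged sketch with an illustrative file path and SHA:

```python
import os

from dulwich.index import index_entry_from_stat

# Illustrative inputs; in real callers the SHA comes from the object store.
st = os.lstat("some-tracked-file")
entry = index_entry_from_stat(st, b"2" * 40)
print(entry.mode, entry.size, entry.sha)
```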
+ 17 - 24
dulwich/tests/test_index.py

@@ -34,7 +34,7 @@ from dulwich.tests import TestCase, skipIf
 from ..index import (
     Index,
     IndexEntry,
-    Stage,
+    SerializedIndexEntry,
     _fs_to_tree_path,
     _tree_to_fs_path,
     build_index_from_tree,
@@ -93,7 +93,7 @@ class SimpleIndexTestCase(IndexTestCase):
 
     def test_getitem(self):
         self.assertEqual(
-            (
+            IndexEntry(
                 (1230680220, 0),
                 (1230680220, 0),
                 2050,
@@ -103,8 +103,6 @@ class SimpleIndexTestCase(IndexTestCase):
                 1000,
                 0,
                 b"e69de29bb2d1d6434b8b29ae775ad8c2e48c5391",
-                0,
-                0,
             ),
             self.get_simple_index("index")[b"bla"],
         )
@@ -135,8 +133,8 @@ class SimpleIndexWriterTestCase(IndexTestCase):
     def test_simple_write(self):
         entries = [
             (
-                b"barbla",
-                IndexEntry(
+                SerializedIndexEntry(
+                    b"barbla",
                     (1230680220, 0),
                     (1230680220, 0),
                     2050,
@@ -159,6 +157,7 @@ class SimpleIndexWriterTestCase(IndexTestCase):
 
 
 class ReadIndexDictTests(IndexTestCase):
+
     def setUp(self):
         IndexTestCase.setUp(self)
         self.tempdir = tempfile.mkdtemp()
@@ -169,7 +168,7 @@ class ReadIndexDictTests(IndexTestCase):
 
     def test_simple_write(self):
         entries = {
-            (b"barbla", Stage.NORMAL): IndexEntry(
+            b"barbla": IndexEntry(
                 (1230680220, 0),
                 (1230680220, 0),
                 2050,
@@ -179,8 +178,6 @@ class ReadIndexDictTests(IndexTestCase):
                 1000,
                 0,
                 b"e69de29bb2d1d6434b8b29ae775ad8c2e48c5391",
-                0,
-                0,
             )
         }
         filename = os.path.join(self.tempdir, "test-simple-write-index")
@@ -278,7 +275,7 @@ class IndexEntryFromStatTests(TestCase):
                 1324180496,
             )
         )
-        entry = index_entry_from_stat(st, "22" * 20, 0)
+        entry = index_entry_from_stat(st, b"22" * 20)
         self.assertEqual(
             entry,
             IndexEntry(
@@ -290,9 +287,7 @@ class IndexEntryFromStatTests(TestCase):
                 1000,
                 1000,
                 12288,
-                "2222222222222222222222222222222222222222",
-                0,
-                None,
+                b"2222222222222222222222222222222222222222",
             ),
         )
 
@@ -311,7 +306,7 @@ class IndexEntryFromStatTests(TestCase):
                 1324180496,
             )
         )
-        entry = index_entry_from_stat(st, "22" * 20, 0, mode=stat.S_IFREG + 0o755)
+        entry = index_entry_from_stat(st, b"22" * 20, mode=stat.S_IFREG + 0o755)
         self.assertEqual(
             entry,
             IndexEntry(
@@ -323,18 +318,16 @@ class IndexEntryFromStatTests(TestCase):
                 1000,
                 1000,
                 12288,
-                "2222222222222222222222222222222222222222",
-                0,
-                None,
+                b"2222222222222222222222222222222222222222",
             ),
         )
 
 
 class BuildIndexTests(TestCase):
     def assertReasonableIndexEntry(self, index_entry, mode, filesize, sha):
-        self.assertEqual(index_entry[4], mode)  # mode
-        self.assertEqual(index_entry[7], filesize)  # filesize
-        self.assertEqual(index_entry[8], sha)  # sha
+        self.assertEqual(index_entry.mode, mode)  # mode
+        self.assertEqual(index_entry.size, filesize)  # filesize
+        self.assertEqual(index_entry.sha, sha)  # sha
 
     def assertFileContents(self, path, contents, symlink=False):
         if symlink:
@@ -607,8 +600,8 @@ class BuildIndexTests(TestCase):
             # dir c
             cpath = os.path.join(repo.path, "c")
             self.assertTrue(os.path.isdir(cpath))
-            self.assertEqual(index[b"c"][4], S_IFGITLINK)  # mode
-            self.assertEqual(index[b"c"][8], c.id)  # sha
+            self.assertEqual(index[b"c"].mode, S_IFGITLINK)  # mode
+            self.assertEqual(index[b"c"].sha, c.id)  # sha
 
     def test_git_submodule_exists(self):
         repo_dir = tempfile.mkdtemp()
@@ -648,8 +641,8 @@ class BuildIndexTests(TestCase):
             # dir c
             cpath = os.path.join(repo.path, "c")
             self.assertTrue(os.path.isdir(cpath))
-            self.assertEqual(index[b"c"][4], S_IFGITLINK)  # mode
-            self.assertEqual(index[b"c"][8], c.id)  # sha
+            self.assertEqual(index[b"c"].mode, S_IFGITLINK)  # mode
+            self.assertEqual(index[b"c"].sha, c.id)  # sha
 
 
 class GetUnstagedChangesTests(TestCase):
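As the updated tests suggest, the in-memory IndexEntry and the on-disk SerializedIndexEntry are now distinct types. A hedged round-trip sketch (illustrative field values, not taken from the commit):

```python
from dulwich.index import IndexEntry, SerializedIndexEntry, Stage

# Illustrative values; only the shape of the round-trip matters here.
entry = IndexEntry(
    ctime=(1230680220, 0),
    mtime=(1230680220, 0),
    dev=2050,
    ino=3761020,
    mode=0o100644,
    uid=1000,
    gid=1000,
    size=0,
    sha=b"e69de29bb2d1d6434b8b29ae775ad8c2e48c5391",
)

# serialize() attaches the name and encodes the stage into the flags field;
# from_serialized() drops both again when loading back into memory.
serialized = entry.serialize(b"bla", Stage.NORMAL)
assert isinstance(serialized, SerializedIndexEntry)
assert serialized.name == b"bla"
assert serialized.stage() == Stage.NORMAL
assert IndexEntry.from_serialized(serialized) == entry
```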