Browse Source

Start writing reflog entries.

Jelmer Vernooij 7 năm trước cách đây
mục cha
commit
79540cdbfb
6 tập tin đã thay đổi với 164 bổ sung41 xóa
  1. 2 0
      NEWS
  2. 11 5
      dulwich/porcelain.py
  3. 1 1
      dulwich/reflog.py
  4. 79 17
      dulwich/refs.py
  5. 69 16
      dulwich/repo.py
  6. 2 2
      dulwich/tests/test_repository.py

+ 2 - 0
NEWS

@@ -13,6 +13,8 @@
 
 
   * Add a fastimport ``extra``. (Jelmer Vernooij)
   * Add a fastimport ``extra``. (Jelmer Vernooij)
 
 
+  * Start writing reflog entries. (Jelmer Vernooij)
+
  API CHANGES
  API CHANGES
 
 
   * ``GitClient.send_pack`` now accepts a ``generate_pack_data``
   * ``GitClient.send_pack`` now accepts a ``generate_pack_data``

+ 11 - 5
dulwich/porcelain.py

@@ -308,15 +308,18 @@ def clone(source, target=None, bare=False, checkout=None,
         fetch_result = client.fetch(
         fetch_result = client.fetch(
             host_path, r, determine_wants=r.object_store.determine_wants_all,
             host_path, r, determine_wants=r.object_store.determine_wants_all,
             progress=errstream.write)
             progress=errstream.write)
+        ref_message = b"clone: from " + source.encode('utf-8')
         r.refs.import_refs(
         r.refs.import_refs(
             b'refs/remotes/' + origin,
             b'refs/remotes/' + origin,
             {n[len(b'refs/heads/'):]: v for (n, v) in fetch_result.refs.items()
             {n[len(b'refs/heads/'):]: v for (n, v) in fetch_result.refs.items()
-                if n.startswith(b'refs/heads/')})
+                if n.startswith(b'refs/heads/')},
+            message=ref_message)
         r.refs.import_refs(
         r.refs.import_refs(
             b'refs/tags',
             b'refs/tags',
             {n[len(b'refs/tags/'):]: v for (n, v) in fetch_result.refs.items()
             {n[len(b'refs/tags/'):]: v for (n, v) in fetch_result.refs.items()
                 if n.startswith(b'refs/tags/') and
                 if n.startswith(b'refs/tags/') and
-                not n.endswith(ANNOTATED_TAG_SUFFIX)})
+                not n.endswith(ANNOTATED_TAG_SUFFIX)},
+            message=ref_message)
         target_config = r.get_config()
         target_config = r.get_config()
         if not isinstance(source, bytes):
         if not isinstance(source, bytes):
             source = source.encode(DEFAULT_ENCODING)
             source = source.encode(DEFAULT_ENCODING)
@@ -1025,9 +1028,12 @@ def branch_create(repo, name, objectish=None, force=False):
             objectish = "HEAD"
             objectish = "HEAD"
         object = parse_object(r, objectish)
         object = parse_object(r, objectish)
         refname = b"refs/heads/" + name
         refname = b"refs/heads/" + name
-        if refname in r.refs and not force:
-            raise KeyError("Branch with name %s already exists." % name)
-        r.refs[refname] = object.id
+        ref_message = b"branch: Created from " + objectish.encode('utf-8')
+        if force:
+            r.refs.set_if_equals(refname, None, object.id, message=ref_message)
+        else:
+            if not r.refs.add_if_new(refname, object.id, message=ref_message):
+                raise KeyError("Branch with name %s already exists." % name)
 
 
 
 
 def branch_list(repo):
 def branch_list(repo):

+ 1 - 1
dulwich/reflog.py

@@ -48,7 +48,7 @@ def format_reflog_line(old_sha, new_sha, committer, timestamp, timezone,
     if old_sha is None:
     if old_sha is None:
         old_sha = ZERO_SHA
         old_sha = ZERO_SHA
     return (old_sha + b' ' + new_sha + b' ' + committer + b' ' +
     return (old_sha + b' ' + new_sha + b' ' + committer + b' ' +
-            str(timestamp).encode('ascii') + b' ' +
+            str(int(timestamp)).encode('ascii') + b' ' +
             format_timezone(timezone) + b'\t' + message)
             format_timezone(timezone) + b'\t' + message)
 
 
 
 

+ 79 - 17
dulwich/refs.py

@@ -94,11 +94,25 @@ def check_ref_format(refname):
 class RefsContainer(object):
 class RefsContainer(object):
     """A container for refs."""
     """A container for refs."""
 
 
-    def set_symbolic_ref(self, name, other):
+    def __init__(self, logger=None):
+        self._logger = logger
+
+    def _log(self, ref, old_sha, new_sha, committer=None, timestamp=None,
+             timezone=None, message=None):
+        if self._logger is None:
+            return
+        if message is None:
+            return
+        self._logger(ref, old_sha, new_sha, committer, timestamp,
+                     timezone, message)
+
+    def set_symbolic_ref(self, name, other, committer=None, timestamp=None,
+                         timezone=None, message=None):
         """Make a ref point at another ref.
         """Make a ref point at another ref.
 
 
         :param name: Name of the ref to set
         :param name: Name of the ref to set
         :param other: Name of the ref to point at
         :param other: Name of the ref to point at
+        :param message: Optional message
         """
         """
         raise NotImplementedError(self.set_symbolic_ref)
         raise NotImplementedError(self.set_symbolic_ref)
 
 
@@ -122,9 +136,10 @@ class RefsContainer(object):
         """
         """
         return None
         return None
 
 
-    def import_refs(self, base, other):
+    def import_refs(self, base, other, committer=None, timestamp=None,
+                    timezone=None, message=None):
         for name, value in other.items():
         for name, value in other.items():
-            self[b'/'.join((base, name))] = value
+            self.set_if_equals(b'/'.join((base, name)), None, value, message=message)
 
 
     def allkeys(self):
     def allkeys(self):
         """All refs present in this container."""
         """All refs present in this container."""
@@ -256,7 +271,8 @@ class RefsContainer(object):
             raise KeyError(name)
             raise KeyError(name)
         return sha
         return sha
 
 
-    def set_if_equals(self, name, old_ref, new_ref):
+    def set_if_equals(self, name, old_ref, new_ref, committer=None,
+                      timestamp=None, timezone=None, message=None):
         """Set a refname to new_ref only if it currently equals old_ref.
         """Set a refname to new_ref only if it currently equals old_ref.
 
 
         This method follows all symbolic references if applicable for the
         This method follows all symbolic references if applicable for the
@@ -267,12 +283,18 @@ class RefsContainer(object):
         :param old_ref: The old sha the refname must refer to, or None to set
         :param old_ref: The old sha the refname must refer to, or None to set
             unconditionally.
             unconditionally.
         :param new_ref: The new sha the refname will refer to.
         :param new_ref: The new sha the refname will refer to.
+        :param message: Message for reflog
         :return: True if the set was successful, False otherwise.
         :return: True if the set was successful, False otherwise.
         """
         """
         raise NotImplementedError(self.set_if_equals)
         raise NotImplementedError(self.set_if_equals)
 
 
     def add_if_new(self, name, ref):
     def add_if_new(self, name, ref):
-        """Add a new reference only if it does not already exist."""
+        """Add a new reference only if it does not already exist.
+
+        :param name: Ref name
+        :param ref: Ref value
+        :param message: Message for reflog
+        """
         raise NotImplementedError(self.add_if_new)
         raise NotImplementedError(self.add_if_new)
 
 
     def __setitem__(self, name, ref):
     def __setitem__(self, name, ref):
@@ -289,7 +311,8 @@ class RefsContainer(object):
         """
         """
         self.set_if_equals(name, None, ref)
         self.set_if_equals(name, None, ref)
 
 
-    def remove_if_equals(self, name, old_ref):
+    def remove_if_equals(self, name, old_ref, committer=None,
+                         timestamp=None, timezone=None, message=None):
         """Remove a refname only if it currently equals old_ref.
         """Remove a refname only if it currently equals old_ref.
 
 
         This method does not follow symbolic references, even if applicable for
         This method does not follow symbolic references, even if applicable for
@@ -299,6 +322,7 @@ class RefsContainer(object):
         :param name: The refname to delete.
         :param name: The refname to delete.
         :param old_ref: The old sha the refname must refer to, or None to
         :param old_ref: The old sha the refname must refer to, or None to
             delete unconditionally.
             delete unconditionally.
+        :param message: Message for reflog
         :return: True if the delete was successful, False otherwise.
         :return: True if the delete was successful, False otherwise.
         """
         """
         raise NotImplementedError(self.remove_if_equals)
         raise NotImplementedError(self.remove_if_equals)
@@ -340,7 +364,8 @@ class DictRefsContainer(RefsContainer):
     threadsafe.
     threadsafe.
     """
     """
 
 
-    def __init__(self, refs):
+    def __init__(self, refs, logger=None):
+        super(DictRefsContainer, self).__init__(logger=logger)
         self._refs = refs
         self._refs = refs
         self._peeled = {}
         self._peeled = {}
 
 
@@ -353,31 +378,46 @@ class DictRefsContainer(RefsContainer):
     def get_packed_refs(self):
     def get_packed_refs(self):
         return {}
         return {}
 
 
-    def set_symbolic_ref(self, name, other):
+    def set_symbolic_ref(self, name, other, committer=None,
+                         timestamp=None, timezone=None, message=None):
+        old = self.follow(name)[-1]
         self._refs[name] = SYMREF + other
         self._refs[name] = SYMREF + other
+        self._log(name, old, old, committer=committer, timestamp=timestamp,
+                  timezone=timezone, message=message)
 
 
-    def set_if_equals(self, name, old_ref, new_ref):
+    def set_if_equals(self, name, old_ref, new_ref, committer=None,
+                      timestamp=None, timezone=None, message=None):
         if old_ref is not None and self._refs.get(name, ZERO_SHA) != old_ref:
         if old_ref is not None and self._refs.get(name, ZERO_SHA) != old_ref:
             return False
             return False
         realnames, _ = self.follow(name)
         realnames, _ = self.follow(name)
         for realname in realnames:
         for realname in realnames:
             self._check_refname(realname)
             self._check_refname(realname)
+            old = self._refs.get(realname)
             self._refs[realname] = new_ref
             self._refs[realname] = new_ref
+            self._log(realname, old, new_ref, committer=committer,
+                      timestamp=timestamp, timezone=timezone, message=message)
         return True
         return True
 
 
-    def add_if_new(self, name, ref):
+    def add_if_new(self, name, ref, committer=None, timestamp=None,
+                   timezone=None, message=None):
         if name in self._refs:
         if name in self._refs:
             return False
             return False
         self._refs[name] = ref
         self._refs[name] = ref
+        self._log(name, None, ref, committer=committer, timestamp=timestamp,
+                  timezone=timezone, message=message)
         return True
         return True
 
 
-    def remove_if_equals(self, name, old_ref):
+    def remove_if_equals(self, name, old_ref, committer=None, timestamp=None,
+                         timezone=None, message=None):
         if old_ref is not None and self._refs.get(name, ZERO_SHA) != old_ref:
         if old_ref is not None and self._refs.get(name, ZERO_SHA) != old_ref:
             return False
             return False
         try:
         try:
-            del self._refs[name]
+            old = self._refs.pop(name)
         except KeyError:
         except KeyError:
             pass
             pass
+        else:
+            self._log(name, old, None, committer=committer,
+                      timestamp=timestamp, timezone=timezone, message=message)
         return True
         return True
 
 
     def get_peeled(self, name):
     def get_peeled(self, name):
@@ -431,7 +471,8 @@ class InfoRefsContainer(RefsContainer):
 class DiskRefsContainer(RefsContainer):
 class DiskRefsContainer(RefsContainer):
     """Refs container that reads refs from disk."""
     """Refs container that reads refs from disk."""
 
 
-    def __init__(self, path, worktree_path=None):
+    def __init__(self, path, worktree_path=None, logger=None):
+        super(DiskRefsContainer, self).__init__(logger=logger)
         self.path = path
         self.path = path
         self.worktree_path = worktree_path or path
         self.worktree_path = worktree_path or path
         self._packed_refs = None
         self._packed_refs = None
@@ -589,11 +630,13 @@ class DiskRefsContainer(RefsContainer):
         finally:
         finally:
             f.abort()
             f.abort()
 
 
-    def set_symbolic_ref(self, name, other):
+    def set_symbolic_ref(self, name, other, committer=None, timestamp=None,
+                         timezone=None, message=None):
         """Make a ref point at another ref.
         """Make a ref point at another ref.
 
 
         :param name: Name of the ref to set
         :param name: Name of the ref to set
         :param other: Name of the ref to point at
         :param other: Name of the ref to point at
+        :param message: Optional message to describe the change
         """
         """
         self._check_refname(name)
         self._check_refname(name)
         self._check_refname(other)
         self._check_refname(other)
@@ -605,10 +648,16 @@ class DiskRefsContainer(RefsContainer):
             except (IOError, OSError):
             except (IOError, OSError):
                 f.abort()
                 f.abort()
                 raise
                 raise
+            else:
+                sha = self.follow(name)[-1]
+                self._log(name, sha, sha, committer=committer,
+                          timestamp=timestamp, timezone=timezone,
+                          message=message)
         finally:
         finally:
             f.close()
             f.close()
 
 
-    def set_if_equals(self, name, old_ref, new_ref):
+    def set_if_equals(self, name, old_ref, new_ref, committer=None,
+                      timestamp=None, timezone=None, message=None):
         """Set a refname to new_ref only if it currently equals old_ref.
         """Set a refname to new_ref only if it currently equals old_ref.
 
 
         This method follows all symbolic references, and can be used to perform
         This method follows all symbolic references, and can be used to perform
@@ -618,6 +667,7 @@ class DiskRefsContainer(RefsContainer):
         :param old_ref: The old sha the refname must refer to, or None to set
         :param old_ref: The old sha the refname must refer to, or None to set
             unconditionally.
             unconditionally.
         :param new_ref: The new sha the refname will refer to.
         :param new_ref: The new sha the refname will refer to.
+        :param message: Set message for reflog
         :return: True if the set was successful, False otherwise.
         :return: True if the set was successful, False otherwise.
         """
         """
         self._check_refname(name)
         self._check_refname(name)
@@ -647,9 +697,12 @@ class DiskRefsContainer(RefsContainer):
             except (OSError, IOError):
             except (OSError, IOError):
                 f.abort()
                 f.abort()
                 raise
                 raise
+            self._log(realname, old_ref, new_ref, committer=committer,
+                      timestamp=timestamp, timezone=timezone, message=message)
         return True
         return True
 
 
-    def add_if_new(self, name, ref):
+    def add_if_new(self, name, ref, committer=None, timestamp=None,
+                   timezone=None, message=None):
         """Add a new reference only if it does not already exist.
         """Add a new reference only if it does not already exist.
 
 
         This method follows symrefs, and only ensures that the last ref in the
         This method follows symrefs, and only ensures that the last ref in the
@@ -657,6 +710,7 @@ class DiskRefsContainer(RefsContainer):
 
 
         :param name: The refname to set.
         :param name: The refname to set.
         :param ref: The new sha the refname will refer to.
         :param ref: The new sha the refname will refer to.
+        :param message: Optional message for reflog
         :return: True if the add was successful, False otherwise.
         :return: True if the add was successful, False otherwise.
         """
         """
         try:
         try:
@@ -678,9 +732,14 @@ class DiskRefsContainer(RefsContainer):
             except (OSError, IOError):
             except (OSError, IOError):
                 f.abort()
                 f.abort()
                 raise
                 raise
+            else:
+                self._log(name, None, ref, committer=committer,
+                          timestamp=timestamp, timezone=timezone,
+                          message=message)
         return True
         return True
 
 
-    def remove_if_equals(self, name, old_ref):
+    def remove_if_equals(self, name, old_ref, committer=None, timestamp=None,
+                         timezone=None, message=None):
         """Remove a refname only if it currently equals old_ref.
         """Remove a refname only if it currently equals old_ref.
 
 
         This method does not follow symbolic references. It can be used to
         This method does not follow symbolic references. It can be used to
@@ -689,6 +748,7 @@ class DiskRefsContainer(RefsContainer):
         :param name: The refname to delete.
         :param name: The refname to delete.
         :param old_ref: The old sha the refname must refer to, or None to
         :param old_ref: The old sha the refname must refer to, or None to
             delete unconditionally.
             delete unconditionally.
+        :param message: Optional message
         :return: True if the delete was successful, False otherwise.
         :return: True if the delete was successful, False otherwise.
         """
         """
         self._check_refname(name)
         self._check_refname(name)
@@ -709,6 +769,8 @@ class DiskRefsContainer(RefsContainer):
                 if e.errno != errno.ENOENT:
                 if e.errno != errno.ENOENT:
                     raise
                     raise
             self._remove_packed_ref(name)
             self._remove_packed_ref(name)
+            self._log(name, old_ref, None, committer=committer,
+                      timestamp=timestamp, timezone=timezone, message=message)
         finally:
         finally:
             # never write, we just wanted the lock
             # never write, we just wanted the lock
             f.abort()
             f.abort()

+ 69 - 16
dulwich/repo.py

@@ -33,6 +33,7 @@ import errno
 import os
 import os
 import sys
 import sys
 import stat
 import stat
+import time
 
 
 from dulwich.errors import (
 from dulwich.errors import (
     NoIndexPresent,
     NoIndexPresent,
@@ -516,9 +517,26 @@ class BaseRepo(object):
     def _get_user_identity(self):
     def _get_user_identity(self):
         """Determine the identity to use for new commits.
         """Determine the identity to use for new commits.
         """
         """
+        user = os.environ.get("GIT_COMMITTER_NAME")
+        email = os.environ.get("GIT_COMMITTER_EMAIL")
         config = self.get_config_stack()
         config = self.get_config_stack()
-        return (config.get((b"user", ), b"name") + b" <" +
-                config.get((b"user", ), b"email") + b">")
+        if user is None:
+            try:
+                user = config.get((b"user", ), b"name")
+            except KeyError:
+                user = None
+        if email is None:
+            try:
+                email = config.get((b"user", ), b"email")
+            except KeyError:
+                email = None
+        if user is None:
+            import getpass
+            user = getpass.getuser()
+        if email is None:
+            import getpass, socket
+            email = b"%s@%s" % (getpass.getuser(), socket.gethostname())
+        return (user + b" <" + email + b">")
 
 
     def _add_graftpoints(self, updated_graftpoints):
     def _add_graftpoints(self, updated_graftpoints):
         """Add or modify graftpoints
         """Add or modify graftpoints
@@ -585,8 +603,6 @@ class BaseRepo(object):
             # FIXME: Read merge heads from .git/MERGE_HEADS
             # FIXME: Read merge heads from .git/MERGE_HEADS
             merge_heads = []
             merge_heads = []
         if committer is None:
         if committer is None:
-            # FIXME: Support GIT_COMMITTER_NAME/GIT_COMMITTER_EMAIL environment
-            # variables
             committer = self._get_user_identity()
             committer = self._get_user_identity()
         c.committer = committer
         c.committer = committer
         if commit_timestamp is None:
         if commit_timestamp is None:
@@ -633,11 +649,17 @@ class BaseRepo(object):
                 old_head = self.refs[ref]
                 old_head = self.refs[ref]
                 c.parents = [old_head] + merge_heads
                 c.parents = [old_head] + merge_heads
                 self.object_store.add_object(c)
                 self.object_store.add_object(c)
-                ok = self.refs.set_if_equals(ref, old_head, c.id)
+                ok = self.refs.set_if_equals(
+                    ref, old_head, c.id, message=b"commit: " + message,
+                    committer=committer, timestamp=commit_timestamp,
+                    timezone=commit_timezone)
             except KeyError:
             except KeyError:
                 c.parents = merge_heads
                 c.parents = merge_heads
                 self.object_store.add_object(c)
                 self.object_store.add_object(c)
-                ok = self.refs.add_if_new(ref, c.id)
+                ok = self.refs.add_if_new(ref, c.id,
+                        message=b"commit: " + message,
+                        committer=committer, timestamp=commit_timestamp,
+                        timezone=commit_timezone)
             if not ok:
             if not ok:
                 # Fail if the atomic compare-and-swap failed, leaving the
                 # Fail if the atomic compare-and-swap failed, leaving the
                 # commit and all its objects as garbage.
                 # commit and all its objects as garbage.
@@ -707,7 +729,8 @@ class Repo(BaseRepo):
         self.path = root
         self.path = root
         object_store = DiskObjectStore(
         object_store = DiskObjectStore(
             os.path.join(self.commondir(), OBJECTDIR))
             os.path.join(self.commondir(), OBJECTDIR))
-        refs = DiskRefsContainer(self.commondir(), self._controldir)
+        refs = DiskRefsContainer(self.commondir(), self._controldir,
+                                 logger=self._write_reflog)
         BaseRepo.__init__(self, object_store, refs)
         BaseRepo.__init__(self, object_store, refs)
 
 
         self._graftpoints = {}
         self._graftpoints = {}
@@ -726,6 +749,25 @@ class Repo(BaseRepo):
         self.hooks['commit-msg'] = CommitMsgShellHook(self.controldir())
         self.hooks['commit-msg'] = CommitMsgShellHook(self.controldir())
         self.hooks['post-commit'] = PostCommitShellHook(self.controldir())
         self.hooks['post-commit'] = PostCommitShellHook(self.controldir())
 
 
+    def _write_reflog(self, ref, old_sha, new_sha, committer, timestamp,
+                      timezone, message):
+        from .reflog import format_reflog_line
+        path = os.path.join(self.controldir(), 'logs', ref)
+        try:
+            os.makedirs(os.path.dirname(path))
+        except OSError, e:
+            if e.errno != errno.EEXIST:
+                raise
+        if committer is None:
+            committer = self._get_user_identity()
+        if timestamp is None:
+            timestamp = int(time.time())
+        if timezone is None:
+            timezone = 0  # FIXME
+        with open(path, 'ab') as f:
+            f.write(format_reflog_line(old_sha, new_sha, committer,
+                    timestamp, timezone, message) + b'\n')
+
     @classmethod
     @classmethod
     def discover(cls, start='.'):
     def discover(cls, start='.'):
         """Iterate parent directories to discover a repository
         """Iterate parent directories to discover a repository
@@ -896,18 +938,23 @@ class Repo(BaseRepo):
         else:
         else:
             target = self.init_bare(target_path, mkdir=mkdir)
             target = self.init_bare(target_path, mkdir=mkdir)
         self.fetch(target)
         self.fetch(target)
+        encoded_path = self.path
+        if not isinstance(encoded_path, bytes):
+            encoded_path = encoded_path.encode(sys.getfilesystemencoding())
+        ref_message = b"clone: from " + encoded_path
         target.refs.import_refs(
         target.refs.import_refs(
-            b'refs/remotes/' + origin, self.refs.as_dict(b'refs/heads'))
+            b'refs/remotes/' + origin, self.refs.as_dict(b'refs/heads'),
+            message=ref_message)
         target.refs.import_refs(
         target.refs.import_refs(
-            b'refs/tags', self.refs.as_dict(b'refs/tags'))
+            b'refs/tags', self.refs.as_dict(b'refs/tags'),
+            message=ref_message)
         try:
         try:
-            target.refs.add_if_new(DEFAULT_REF, self.refs[DEFAULT_REF])
+            target.refs.add_if_new(
+                    DEFAULT_REF, self.refs[DEFAULT_REF],
+                    message=ref_message)
         except KeyError:
         except KeyError:
             pass
             pass
         target_config = target.get_config()
         target_config = target.get_config()
-        encoded_path = self.path
-        if not isinstance(encoded_path, bytes):
-            encoded_path = encoded_path.encode(sys.getfilesystemencoding())
         target_config.set((b'remote', b'origin'), b'url', encoded_path)
         target_config.set((b'remote', b'origin'), b'url', encoded_path)
         target_config.set((b'remote', b'origin'), b'fetch',
         target_config.set((b'remote', b'origin'), b'fetch',
                           b'+refs/heads/*:refs/remotes/origin/*')
                           b'+refs/heads/*:refs/remotes/origin/*')
@@ -916,7 +963,8 @@ class Repo(BaseRepo):
         # Update target head
         # Update target head
         head_chain, head_sha = self.refs.follow(b'HEAD')
         head_chain, head_sha = self.refs.follow(b'HEAD')
         if head_chain and head_sha is not None:
         if head_chain and head_sha is not None:
-            target.refs.set_symbolic_ref(b'HEAD', head_chain[-1])
+            target.refs.set_symbolic_ref(b'HEAD', head_chain[-1],
+                                         message=ref_message)
             target[b'HEAD'] = head_sha
             target[b'HEAD'] = head_sha
 
 
             if not bare:
             if not bare:
@@ -1092,12 +1140,17 @@ class MemoryRepo(BaseRepo):
 
 
     def __init__(self):
     def __init__(self):
         from dulwich.config import ConfigFile
         from dulwich.config import ConfigFile
-        BaseRepo.__init__(self, MemoryObjectStore(), DictRefsContainer({}))
+        self._reflog = []
+        refs_container = DictRefsContainer({}, logger=self._append_reflog)
+        BaseRepo.__init__(self, MemoryObjectStore(), refs_container)
         self._named_files = {}
         self._named_files = {}
         self.bare = True
         self.bare = True
         self._config = ConfigFile()
         self._config = ConfigFile()
         self._description = None
         self._description = None
 
 
+    def _append_reflog(self, *args):
+        self._reflog.append(args)
+
     def set_description(self, description):
     def set_description(self, description):
         self._description = description
         self._description = description
 
 
@@ -1161,6 +1214,6 @@ class MemoryRepo(BaseRepo):
         for obj in objects:
         for obj in objects:
             ret.object_store.add_object(obj)
             ret.object_store.add_object(obj)
         for refname, sha in refs.items():
         for refname, sha in refs.items():
-            ret.refs[refname] = sha
+            ret.refs.add_if_new(refname, sha)
         ret._init_files(bare=True)
         ret._init_files(bare=True)
         return ret
         return ret

+ 2 - 2
dulwich/tests/test_repository.py

@@ -758,11 +758,11 @@ class BuildRepoRootTests(TestCase):
     def test_commit_fail_ref(self):
     def test_commit_fail_ref(self):
         r = self._repo
         r = self._repo
 
 
-        def set_if_equals(name, old_ref, new_ref):
+        def set_if_equals(name, old_ref, new_ref, **kwargs):
             return False
             return False
         r.refs.set_if_equals = set_if_equals
         r.refs.set_if_equals = set_if_equals
 
 
-        def add_if_new(name, new_ref):
+        def add_if_new(name, new_ref, **kwargs):
             self.fail('Unexpected call to add_if_new')
             self.fail('Unexpected call to add_if_new')
         r.refs.add_if_new = add_if_new
         r.refs.add_if_new = add_if_new