浏览代码

Allow non-os file-like objects passed to ShaFile.from_file.

Jelmer Vernooij 15 年之前
父节点
当前提交
87fe92ea91
共有 2 个文件被更改,包括 24 次插入18 次删除
  1. 24 16
      dulwich/objects.py
  2. 0 2
      dulwich/tests/test_objects.py

+ 24 - 16
dulwich/objects.py

@@ -184,9 +184,7 @@ class ShaFile(object):
         obj_class = object_class(type_name)
         if not obj_class:
             raise ObjectFormatException("Not a known type: %s" % type_name)
-        obj = obj_class()
-        obj._filename = f.name
-        return obj
+        return obj_class()
 
     def _parse_legacy_object(self, f):
         """Parse a legacy object, setting the raw string."""
@@ -233,8 +231,13 @@ class ShaFile(object):
     def _ensure_parsed(self):
         if self._needs_parsing:
             if not self._chunked_text:
-                assert self._filename, "ShaFile needs either text or filename"
-                self._parse_file()
+                if self._file is not None:
+                    self._parse_file(self._file)
+                elif self._path is not None:
+                    self._parse_path()
+                else:
+                    raise AssertionError(
+                        "ShaFile needs either text or filename")
             self._deserialize(self._chunked_text)
             self._needs_parsing = False
 
@@ -257,9 +260,7 @@ class ShaFile(object):
         obj_class = object_class(num_type)
         if not obj_class:
             raise ObjectFormatException("Not a known type: %d" % num_type)
-        obj = obj_class()
-        obj._filename = f.name
-        return obj
+        return obj_class()
 
     def _parse_object(self, f):
         """Parse a new style object, setting self._text."""
@@ -295,7 +296,8 @@ class ShaFile(object):
     def __init__(self):
         """Don't call this directly"""
         self._sha = None
-        self._filename = None
+        self._path = None
+        self._file = None
         self._chunked_text = []
         self._needs_parsing = False
         self._needs_serialization = True
@@ -306,23 +308,28 @@ class ShaFile(object):
     def _serialize(self):
         raise NotImplementedError(self._serialize)
 
-    def _parse_file(self):
-        f = GitFile(self._filename, 'rb')
+    def _parse_path(self):
+        f = GitFile(self._path, 'rb')
         try:
-            magic = f.read(2)
-            if self._is_legacy_object(magic):
-                self._parse_legacy_object(f)
-            else:
-                self._parse_object(f)
+            self._parse_file(f)
         finally:
             f.close()
 
+    def _parse_file(self, f):
+        magic = f.read(2)
+        if self._is_legacy_object(magic):
+            self._parse_legacy_object(f)
+        else:
+            self._parse_object(f)
+
     @classmethod
     def from_path(cls, path):
         f = GitFile(path, 'rb')
         try:
             obj = cls.from_file(f)
+            obj._path = path
             obj._sha = FixedSha(filename_to_hex(path))
+            obj._file = None
             return obj
         finally:
             f.close()
@@ -335,6 +342,7 @@ class ShaFile(object):
             obj._sha = None
             obj._needs_parsing = True
             obj._needs_serialization = True
+            obj._file = f
             return obj
         except (IndexError, ValueError), e:
             raise ObjectFormatException("invalid object header")

+ 0 - 2
dulwich/tests/test_objects.py

@@ -177,7 +177,6 @@ class BlobReadTests(unittest.TestCase):
         self.assertEqual(t.tag_time, 1231203091)
         self.assertEqual(t.message, 'This is a signed tag\n-----BEGIN PGP SIGNATURE-----\nVersion: GnuPG v1.4.9 (GNU/Linux)\n\niEYEABECAAYFAkliqx8ACgkQqSMmLy9u/kcx5ACfakZ9NnPl02tOyYP6pkBoEkU1\n5EcAn0UFgokaSvS371Ym/4W9iJj6vh3h\n=ql7y\n-----END PGP SIGNATURE-----\n')
   
-  
     def test_read_commit_from_file(self):
         sha = '60dacdc733de308bb77bb76ce0fb0f9b44c9769e'
         c = self.commit(sha)
@@ -425,7 +424,6 @@ class TreeTests(ShaFileCheckTests):
     def _do_test_parse_tree(self, parse_tree):
         dir = os.path.join(os.path.dirname(__file__), 'data', 'trees')
         o = Tree.from_path(hex_to_filename(dir, tree_sha))
-        o._parse_file()
         self.assertEquals([('a', 0100644, a_sha), ('b', 0100644, b_sha)],
                           list(parse_tree(o.as_raw_string())))