Browse Source

Close files for Pack objects.

Gary van der Merwe 11 years ago
parent
commit
e00ce69f19
3 changed files with 162 additions and 143 deletions
  1. 12 0
      dulwich/pack.py
  2. 15 15
      dulwich/tests/compat/test_pack.py
  3. 135 128
      dulwich/tests/test_pack.py

+ 12 - 0
dulwich/pack.py

@@ -989,6 +989,12 @@ class PackData(object):
     def close(self):
         self._file.close()
 
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.close()
+
     def _get_size(self):
         if self._size is not None:
             return self._size
@@ -1801,6 +1807,12 @@ class Pack(object):
         if self._idx is not None:
             self._idx.close()
 
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.close()
+
     def __eq__(self, other):
         return isinstance(self, type(other)) and self.index == other.index
 

+ 15 - 15
dulwich/tests/compat/test_pack.py

@@ -48,19 +48,19 @@ class TestPack(PackTests):
         self.addCleanup(shutil.rmtree, self._tempdir)
 
     def test_copy(self):
-        origpack = self.get_pack(pack1_sha)
-        self.assertSucceeds(origpack.index.check)
-        pack_path = os.path.join(self._tempdir, "Elch")
-        write_pack(pack_path, origpack.pack_tuples())
-        output = run_git_or_fail(['verify-pack', '-v', pack_path])
+        with self.get_pack(pack1_sha) as origpack:
+            self.assertSucceeds(origpack.index.check)
+            pack_path = os.path.join(self._tempdir, "Elch")
+            write_pack(pack_path, origpack.pack_tuples())
+            output = run_git_or_fail(['verify-pack', '-v', pack_path])
 
-        pack_shas = set()
-        for line in output.splitlines():
-            sha = line[:40]
-            try:
-                binascii.unhexlify(sha)
-            except (TypeError, binascii.Error):
-                continue  # non-sha line
-            pack_shas.add(sha)
-        orig_shas = set(o.id for o in origpack.iterobjects())
-        self.assertEqual(orig_shas, pack_shas)
+            pack_shas = set()
+            for line in output.splitlines():
+                sha = line[:40]
+                try:
+                    binascii.unhexlify(sha)
+                except (TypeError, binascii.Error):
+                    continue  # non-sha line
+                pack_shas.add(sha)
+            orig_shas = set(o.id for o in origpack.iterobjects())
+            self.assertEqual(orig_shas, pack_shas)

+ 135 - 128
dulwich/tests/test_pack.py

@@ -184,64 +184,64 @@ class TestPackData(PackTests):
     """Tests getting the data from the packfile."""
 
     def test_create_pack(self):
-        self.get_pack_data(pack1_sha)
+        self.get_pack_data(pack1_sha).close()
 
     def test_from_file(self):
         path = os.path.join(self.datadir, 'pack-%s.pack' % pack1_sha)
         PackData.from_file(open(path), os.path.getsize(path))
 
     def test_pack_len(self):
-        p = self.get_pack_data(pack1_sha)
-        self.assertEqual(3, len(p))
+        with self.get_pack_data(pack1_sha) as p:
+            self.assertEqual(3, len(p))
 
     def test_index_check(self):
-        p = self.get_pack_data(pack1_sha)
-        self.assertSucceeds(p.check)
+        with self.get_pack_data(pack1_sha) as p:
+            self.assertSucceeds(p.check)
 
     def test_iterobjects(self):
-        p = self.get_pack_data(pack1_sha)
-        commit_data = ('tree b2a2766a2879c209ab1176e7e778b81ae422eeaa\n'
-                       'author James Westby <jw+debian@jameswestby.net> '
-                       '1174945067 +0100\n'
-                       'committer James Westby <jw+debian@jameswestby.net> '
-                       '1174945067 +0100\n'
-                       '\n'
-                       'Test commit\n')
-        blob_sha = '6f670c0fb53f9463760b7295fbb814e965fb20c8'
-        tree_data = '100644 a\0%s' % hex_to_sha(blob_sha)
-        actual = []
-        for offset, type_num, chunks, crc32 in p.iterobjects():
-            actual.append((offset, type_num, ''.join(chunks), crc32))
-        self.assertEqual([
-          (12, 1, commit_data, 3775879613),
-          (138, 2, tree_data, 912998690),
-          (178, 3, 'test 1\n', 1373561701)
-          ], actual)
+        with self.get_pack_data(pack1_sha) as p:
+            commit_data = ('tree b2a2766a2879c209ab1176e7e778b81ae422eeaa\n'
+                           'author James Westby <jw+debian@jameswestby.net> '
+                           '1174945067 +0100\n'
+                           'committer James Westby <jw+debian@jameswestby.net> '
+                           '1174945067 +0100\n'
+                           '\n'
+                           'Test commit\n')
+            blob_sha = '6f670c0fb53f9463760b7295fbb814e965fb20c8'
+            tree_data = '100644 a\0%s' % hex_to_sha(blob_sha)
+            actual = []
+            for offset, type_num, chunks, crc32 in p.iterobjects():
+                actual.append((offset, type_num, ''.join(chunks), crc32))
+            self.assertEqual([
+              (12, 1, commit_data, 3775879613),
+              (138, 2, tree_data, 912998690),
+              (178, 3, 'test 1\n', 1373561701)
+              ], actual)
 
     def test_iterentries(self):
-        p = self.get_pack_data(pack1_sha)
-        entries = set((sha_to_hex(s), o, c) for s, o, c in p.iterentries())
-        self.assertEqual(set([
-          ('6f670c0fb53f9463760b7295fbb814e965fb20c8', 178, 1373561701),
-          ('b2a2766a2879c209ab1176e7e778b81ae422eeaa', 138, 912998690),
-          ('f18faa16531ac570a3fdc8c7ca16682548dafd12', 12, 3775879613),
-          ]), entries)
+        with self.get_pack_data(pack1_sha) as p:
+            entries = set((sha_to_hex(s), o, c) for s, o, c in p.iterentries())
+            self.assertEqual(set([
+              ('6f670c0fb53f9463760b7295fbb814e965fb20c8', 178, 1373561701),
+              ('b2a2766a2879c209ab1176e7e778b81ae422eeaa', 138, 912998690),
+              ('f18faa16531ac570a3fdc8c7ca16682548dafd12', 12, 3775879613),
+              ]), entries)
 
     def test_create_index_v1(self):
-        p = self.get_pack_data(pack1_sha)
-        filename = os.path.join(self.tempdir, 'v1test.idx')
-        p.create_index_v1(filename)
-        idx1 = load_pack_index(filename)
-        idx2 = self.get_pack_index(pack1_sha)
-        self.assertEqual(idx1, idx2)
+        with self.get_pack_data(pack1_sha) as p:
+            filename = os.path.join(self.tempdir, 'v1test.idx')
+            p.create_index_v1(filename)
+            idx1 = load_pack_index(filename)
+            idx2 = self.get_pack_index(pack1_sha)
+            self.assertEqual(idx1, idx2)
 
     def test_create_index_v2(self):
-        p = self.get_pack_data(pack1_sha)
-        filename = os.path.join(self.tempdir, 'v2test.idx')
-        p.create_index_v2(filename)
-        idx1 = load_pack_index(filename)
-        idx2 = self.get_pack_index(pack1_sha)
-        self.assertEqual(idx1, idx2)
+        with self.get_pack_data(pack1_sha) as p:
+            filename = os.path.join(self.tempdir, 'v2test.idx')
+            p.create_index_v2(filename)
+            idx1 = load_pack_index(filename)
+            idx2 = self.get_pack_index(pack1_sha)
+            self.assertEqual(idx1, idx2)
 
     def test_compute_file_sha(self):
         f = BytesIO('abcd1234wxyz')
@@ -261,46 +261,46 @@ class TestPackData(PackTests):
 class TestPack(PackTests):
 
     def test_len(self):
-        p = self.get_pack(pack1_sha)
-        self.assertEqual(3, len(p))
+        with self.get_pack(pack1_sha) as p:
+            self.assertEqual(3, len(p))
 
     def test_contains(self):
-        p = self.get_pack(pack1_sha)
-        self.assertTrue(tree_sha in p)
+        with self.get_pack(pack1_sha) as p:
+            self.assertTrue(tree_sha in p)
 
     def test_get(self):
-        p = self.get_pack(pack1_sha)
-        self.assertEqual(type(p[tree_sha]), Tree)
+        with self.get_pack(pack1_sha) as p:
+            self.assertEqual(type(p[tree_sha]), Tree)
 
     def test_iter(self):
-        p = self.get_pack(pack1_sha)
-        self.assertEqual(set([tree_sha, commit_sha, a_sha]), set(p))
+        with self.get_pack(pack1_sha) as p:
+            self.assertEqual(set([tree_sha, commit_sha, a_sha]), set(p))
 
     def test_iterobjects(self):
-        p = self.get_pack(pack1_sha)
-        expected = set([p[s] for s in [commit_sha, tree_sha, a_sha]])
-        self.assertEqual(expected, set(list(p.iterobjects())))
+        with self.get_pack(pack1_sha) as p:
+            expected = set([p[s] for s in [commit_sha, tree_sha, a_sha]])
+            self.assertEqual(expected, set(list(p.iterobjects())))
 
     def test_pack_tuples(self):
-        p = self.get_pack(pack1_sha)
-        tuples = p.pack_tuples()
-        expected = set([(p[s], None) for s in [commit_sha, tree_sha, a_sha]])
-        self.assertEqual(expected, set(list(tuples)))
-        self.assertEqual(expected, set(list(tuples)))
-        self.assertEqual(3, len(tuples))
+        with self.get_pack(pack1_sha) as p:
+            tuples = p.pack_tuples()
+            expected = set([(p[s], None) for s in [commit_sha, tree_sha, a_sha]])
+            self.assertEqual(expected, set(list(tuples)))
+            self.assertEqual(expected, set(list(tuples)))
+            self.assertEqual(3, len(tuples))
 
     def test_get_object_at(self):
         """Tests random access for non-delta objects"""
-        p = self.get_pack(pack1_sha)
-        obj = p[a_sha]
-        self.assertEqual(obj.type_name, 'blob')
-        self.assertEqual(obj.sha().hexdigest(), a_sha)
-        obj = p[tree_sha]
-        self.assertEqual(obj.type_name, 'tree')
-        self.assertEqual(obj.sha().hexdigest(), tree_sha)
-        obj = p[commit_sha]
-        self.assertEqual(obj.type_name, 'commit')
-        self.assertEqual(obj.sha().hexdigest(), commit_sha)
+        with self.get_pack(pack1_sha) as p:
+            obj = p[a_sha]
+            self.assertEqual(obj.type_name, 'blob')
+            self.assertEqual(obj.sha().hexdigest(), a_sha)
+            obj = p[tree_sha]
+            self.assertEqual(obj.type_name, 'tree')
+            self.assertEqual(obj.sha().hexdigest(), tree_sha)
+            obj = p[commit_sha]
+            self.assertEqual(obj.type_name, 'commit')
+            self.assertEqual(obj.sha().hexdigest(), commit_sha)
 
     def test_copy(self):
         origpack = self.get_pack(pack1_sha)
@@ -328,11 +328,11 @@ class TestPack(PackTests):
             origpack.close()
 
     def test_commit_obj(self):
-        p = self.get_pack(pack1_sha)
-        commit = p[commit_sha]
-        self.assertEqual('James Westby <jw+debian@jameswestby.net>',
-                          commit.author)
-        self.assertEqual([], commit.parents)
+        with self.get_pack(pack1_sha) as p:
+            commit = p[commit_sha]
+            self.assertEqual('James Westby <jw+debian@jameswestby.net>',
+                             commit.author)
+            self.assertEqual([], commit.parents)
 
     def _copy_pack(self, origpack):
         basename = os.path.join(self.tempdir, 'somepack')
@@ -340,10 +340,12 @@ class TestPack(PackTests):
         return Pack(basename)
 
     def test_keep_no_message(self):
-        p = self.get_pack(pack1_sha)
-        p = self._copy_pack(p)
+        with self.get_pack(pack1_sha) as p:
+            p = self._copy_pack(p)
+
+        with p:
+            keepfile_name = p.keep()
 
-        keepfile_name = p.keep()
         # file should exist
         self.assertTrue(os.path.exists(keepfile_name))
 
@@ -355,11 +357,12 @@ class TestPack(PackTests):
             f.close()
 
     def test_keep_message(self):
-        p = self.get_pack(pack1_sha)
-        p = self._copy_pack(p)
+        with self.get_pack(pack1_sha) as p:
+            p = self._copy_pack(p)
 
         msg = 'some message'
-        keepfile_name = p.keep(msg)
+        with p:
+            keepfile_name = p.keep(msg)
 
         # file should exist
         self.assertTrue(os.path.exists(keepfile_name))
@@ -373,46 +376,46 @@ class TestPack(PackTests):
             f.close()
 
     def test_name(self):
-        p = self.get_pack(pack1_sha)
-        self.assertEqual(pack1_sha, p.name())
+        with self.get_pack(pack1_sha) as p:
+            self.assertEqual(pack1_sha, p.name())
 
     def test_length_mismatch(self):
-        data = self.get_pack_data(pack1_sha)
-        index = self.get_pack_index(pack1_sha)
-        Pack.from_objects(data, index).check_length_and_checksum()
-
-        data._file.seek(12)
-        bad_file = BytesIO()
-        write_pack_header(bad_file, 9999)
-        bad_file.write(data._file.read())
-        bad_file = BytesIO(bad_file.getvalue())
-        bad_data = PackData('', file=bad_file)
-        bad_pack = Pack.from_lazy_objects(lambda: bad_data, lambda: index)
-        self.assertRaises(AssertionError, lambda: bad_pack.data)
-        self.assertRaises(AssertionError,
-                          lambda: bad_pack.check_length_and_checksum())
+        with self.get_pack_data(pack1_sha) as data:
+            index = self.get_pack_index(pack1_sha)
+            Pack.from_objects(data, index).check_length_and_checksum()
+
+            data._file.seek(12)
+            bad_file = BytesIO()
+            write_pack_header(bad_file, 9999)
+            bad_file.write(data._file.read())
+            bad_file = BytesIO(bad_file.getvalue())
+            bad_data = PackData('', file=bad_file)
+            bad_pack = Pack.from_lazy_objects(lambda: bad_data, lambda: index)
+            self.assertRaises(AssertionError, lambda: bad_pack.data)
+            self.assertRaises(AssertionError,
+                              lambda: bad_pack.check_length_and_checksum())
 
     def test_checksum_mismatch(self):
-        data = self.get_pack_data(pack1_sha)
-        index = self.get_pack_index(pack1_sha)
-        Pack.from_objects(data, index).check_length_and_checksum()
-
-        data._file.seek(0)
-        bad_file = BytesIO(data._file.read()[:-20] + ('\xff' * 20))
-        bad_data = PackData('', file=bad_file)
-        bad_pack = Pack.from_lazy_objects(lambda: bad_data, lambda: index)
-        self.assertRaises(ChecksumMismatch, lambda: bad_pack.data)
-        self.assertRaises(ChecksumMismatch, lambda:
-                          bad_pack.check_length_and_checksum())
+        with self.get_pack_data(pack1_sha) as data:
+            index = self.get_pack_index(pack1_sha)
+            Pack.from_objects(data, index).check_length_and_checksum()
+
+            data._file.seek(0)
+            bad_file = BytesIO(data._file.read()[:-20] + ('\xff' * 20))
+            bad_data = PackData('', file=bad_file)
+            bad_pack = Pack.from_lazy_objects(lambda: bad_data, lambda: index)
+            self.assertRaises(ChecksumMismatch, lambda: bad_pack.data)
+            self.assertRaises(ChecksumMismatch, lambda:
+                              bad_pack.check_length_and_checksum())
 
     def test_iterobjects_2(self):
-        p = self.get_pack(pack1_sha)
-        objs = dict((o.id, o) for o in p.iterobjects())
-        self.assertEqual(3, len(objs))
-        self.assertEqual(sorted(objs), sorted(p.index))
-        self.assertTrue(isinstance(objs[a_sha], Blob))
-        self.assertTrue(isinstance(objs[tree_sha], Tree))
-        self.assertTrue(isinstance(objs[commit_sha], Commit))
+        with self.get_pack(pack1_sha) as p:
+            objs = dict((o.id, o) for o in p.iterobjects())
+            self.assertEqual(3, len(objs))
+            self.assertEqual(sorted(objs), sorted(p.index))
+            self.assertTrue(isinstance(objs[a_sha], Blob))
+            self.assertTrue(isinstance(objs[tree_sha], Tree))
+            self.assertTrue(isinstance(objs[commit_sha], Commit))
 
 
 class TestThinPack(PackTests):
@@ -443,10 +446,10 @@ class TestThinPack(PackTests):
             f.close()
 
         # Index the new pack.
-        pack = self.make_pack(True)
-        data = PackData(pack._data_path)
-        data.pack = pack
-        data.create_index(self.pack_prefix + '.idx')
+        with self.make_pack(True) as pack:
+            with PackData(pack._data_path) as data:
+                data.pack = pack
+                data.create_index(self.pack_prefix + '.idx')
 
         del self.store[self.blobs['bar'].id]
 
@@ -456,18 +459,22 @@ class TestThinPack(PackTests):
             resolve_ext_ref=self.store.get_raw if resolve_ext_ref else None)
 
     def test_get_raw(self):
-        self.assertRaises(
-            KeyError, self.make_pack(False).get_raw, self.blobs['foo1234'].id)
-        self.assertEqual(
-            (3, 'foo1234'),
-            self.make_pack(True).get_raw(self.blobs['foo1234'].id))
+        with self.make_pack(False) as p:
+            self.assertRaises(
+                KeyError, p.get_raw, self.blobs['foo1234'].id)
+        with self.make_pack(True) as p:
+            self.assertEqual(
+                (3, 'foo1234'),
+                p.get_raw(self.blobs['foo1234'].id))
 
     def test_iterobjects(self):
-        self.assertRaises(KeyError, list, self.make_pack(False).iterobjects())
-        self.assertEqual(
-            sorted([self.blobs['foo1234'].id, self.blobs['bar'].id,
-                    self.blobs['bar2468'].id]),
-            sorted(o.id for o in self.make_pack(True).iterobjects()))
+        with self.make_pack(False) as p:
+            self.assertRaises(KeyError, list, p.iterobjects())
+        with self.make_pack(True) as p:
+            self.assertEqual(
+                sorted([self.blobs['foo1234'].id, self.blobs[b'bar'].id,
+                        self.blobs['bar2468'].id]),
+                sorted(o.id for o in p.iterobjects()))
 
 
 class WritePackTests(TestCase):