瀏覽代碼

Various minor rust improvements (#1390)

* Avoid copying names in sorted_tree_items
* Add some more tests for sorted_tree_items
* Make sure we're actually testing the rust implementations and not
accidentally the Python ones
Jelmer Vernooij 5 月之前
父節點
當前提交
15d6c817bc
共有 7 個文件被更改,包括 149 次插入83 次删除
  1. 2 0
      .github/workflows/pythontest.yml
  2. 5 5
      Cargo.lock
  3. 6 1
      crates/diff-tree/src/lib.rs
  4. 35 45
      crates/objects/src/lib.rs
  5. 23 11
      crates/pack/src/lib.rs
  6. 13 5
      dulwich/objects.py
  7. 65 16
      tests/test_objects.py

+ 2 - 0
.github/workflows/pythontest.yml

@@ -51,6 +51,8 @@ jobs:
       - name: Build
         run: |
           python setup.py build_ext -i
+        env:
+          RUSTFLAGS: "-D warnings"
       - name: codespell
         run: |
           pip install --upgrade codespell

+ 5 - 5
Cargo.lock

@@ -16,7 +16,7 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
 
 [[package]]
 name = "diff-tree-py"
-version = "0.22.2"
+version = "0.22.3"
 dependencies = [
  "pyo3",
 ]
@@ -56,7 +56,7 @@ dependencies = [
 
 [[package]]
 name = "objects-py"
-version = "0.22.2"
+version = "0.22.3"
 dependencies = [
  "memchr",
  "pyo3",
@@ -70,7 +70,7 @@ checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775"
 
 [[package]]
 name = "pack-py"
-version = "0.22.2"
+version = "0.22.3"
 dependencies = [
  "memchr",
  "pyo3",
@@ -84,9 +84,9 @@ checksum = "cc9c68a3f6da06753e9335d63e27f6b9754dd1920d941135b7ea8224f141adb2"
 
 [[package]]
 name = "proc-macro2"
-version = "1.0.87"
+version = "1.0.88"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b3e4daa0dcf6feba26f985457cdf104d4b4256fc5a09547140f3631bb076b19a"
+checksum = "7c3a7fc5db1e57d5a779a352c8cdb57b29aa4c40cc69c3a68a7fedc815fbf2f9"
 dependencies = [
  "unicode-ident",
 ]

+ 6 - 1
crates/diff-tree/src/lib.rs

@@ -133,7 +133,12 @@ fn entry_path_cmp(entry1: &Bound<PyAny>, entry2: &Bound<PyAny>) -> PyResult<Orde
 }
 
 #[pyfunction]
-fn _merge_entries(py: Python, path: &[u8], tree1: &Bound<PyAny>, tree2: &Bound<PyAny>) -> PyResult<PyObject> {
+fn _merge_entries(
+    py: Python,
+    path: &[u8],
+    tree1: &Bound<PyAny>,
+    tree2: &Bound<PyAny>,
+) -> PyResult<PyObject> {
     let entries1 = tree_entries(path, tree1, py)?;
     let entries2 = tree_entries(path, tree2, py)?;
 

+ 35 - 45
crates/objects/src/lib.rs

@@ -19,7 +19,6 @@
  */
 
 use memchr::memchr;
-use std::borrow::Cow;
 
 use pyo3::exceptions::PyTypeError;
 use pyo3::import_exception;
@@ -30,6 +29,7 @@ import_exception!(dulwich.errors, ObjectFormatException);
 
 const S_IFDIR: u32 = 0o40000;
 
+#[inline]
 fn bytehex(byte: u8) -> u8 {
     match byte {
         0..=9 => byte + b'0',
@@ -58,36 +58,19 @@ fn parse_tree(
     let mut entries = Vec::new();
     let strict = strict.unwrap_or(false);
     while !text.is_empty() {
-        let mode_end = match memchr(b' ', text) {
-            Some(e) => e,
-            None => {
-                return Err(ObjectFormatException::new_err((
-                    "Missing terminator for mode",
-                )));
-            }
-        };
+        let mode_end = memchr(b' ', text)
+            .ok_or_else(|| ObjectFormatException::new_err(("Missing terminator for mode",)))?;
         let text_str = String::from_utf8_lossy(&text[..mode_end]).to_string();
-        let mode = match u32::from_str_radix(text_str.as_str(), 8) {
-            Ok(m) => m,
-            Err(e) => {
-                return Err(ObjectFormatException::new_err((format!(
-                    "invalid mode: {}",
-                    e
-                ),)));
-            }
-        };
+        let mode = u32::from_str_radix(text_str.as_str(), 8)
+            .map_err(|e| ObjectFormatException::new_err((format!("invalid mode: {}", e),)))?;
         if strict && text[0] == b'0' {
             return Err(ObjectFormatException::new_err((
                 "Illegal leading zero on mode",
             )));
         }
         text = &text[mode_end + 1..];
-        let namelen = match memchr(b'\0', text) {
-            Some(nl) => nl,
-            None => {
-                return Err(ObjectFormatException::new_err(("Missing trailing \\0",)));
-            }
-        };
+        let namelen = memchr(b'\0', text)
+            .ok_or_else(|| ObjectFormatException::new_err(("Missing trailing \\0",)))?;
         let name = &text[..namelen];
         if namelen + 20 >= text.len() {
             return Err(ObjectFormatException::new_err(("SHA truncated",)));
@@ -104,14 +87,20 @@ fn parse_tree(
     Ok(entries)
 }
 
-fn name_with_suffix(mode: u32, name: &[u8]) -> Cow<[u8]> {
-    if mode & S_IFDIR != 0 {
-        let mut v = name.to_vec();
-        v.push(b'/');
-        v.into()
-    } else {
-        name.into()
+fn cmp_with_suffix(a: (u32, &[u8]), b: (u32, &[u8])) -> std::cmp::Ordering {
+    let len = std::cmp::min(a.1.len(), b.1.len());
+    let cmp = a.1[..len].cmp(&b.1[..len]);
+    if cmp != std::cmp::Ordering::Equal {
+        return cmp;
     }
+
+    let c1 =
+        a.1.get(len)
+            .map_or_else(|| if a.0 & S_IFDIR != 0 { b'/' } else { 0 }, |&c| c);
+    let c2 =
+        b.1.get(len)
+            .map_or_else(|| if b.0 & S_IFDIR != 0 { b'/' } else { 0 }, |&c| c);
+    c1.cmp(&c2)
 }
 
 /// Iterate over a tree entries dictionary.
@@ -125,23 +114,24 @@ fn name_with_suffix(mode: u32, name: &[u8]) -> Cow<[u8]> {
 ///
 /// # Returns: Iterator over (name, mode, hexsha)
 #[pyfunction]
-fn sorted_tree_items(py: Python, entries: &Bound<PyDict>, name_order: bool) -> PyResult<Vec<PyObject>> {
-    let mut qsort_entries = Vec::new();
-    for (name, e) in entries.iter() {
-        let (mode, sha): (u32, Vec<u8>) = match e.extract() {
-            Ok(o) => o,
-            Err(e) => {
-                return Err(PyTypeError::new_err((format!("invalid type: {}", e),)));
-            }
-        };
-        qsort_entries.push((name.extract::<Vec<u8>>().unwrap(), mode, sha));
-    }
+fn sorted_tree_items(
+    py: Python,
+    entries: &Bound<PyDict>,
+    name_order: bool,
+) -> PyResult<Vec<PyObject>> {
+    let mut qsort_entries = entries
+        .iter()
+        .map(|(name, value)| -> PyResult<(Vec<u8>, u32, Vec<u8>)> {
+            let value = value
+                .extract::<(u32, Vec<u8>)>()
+                .map_err(|e| PyTypeError::new_err((format!("invalid type: {}", e),)))?;
+            Ok((name.extract::<Vec<u8>>().unwrap(), value.0, value.1))
+        })
+        .collect::<PyResult<Vec<(Vec<u8>, u32, Vec<u8>)>>>()?;
     if name_order {
         qsort_entries.sort_by(|a, b| a.0.cmp(&b.0));
     } else {
-        qsort_entries.sort_by(|a, b| {
-            name_with_suffix(a.1, a.0.as_slice()).cmp(&name_with_suffix(b.1, b.0.as_slice()))
-        });
+        qsort_entries.sort_by(|a, b| cmp_with_suffix((a.1, a.0.as_slice()), (b.1, b.0.as_slice())));
     }
     let objectsm = py.import_bound("dulwich.objects")?;
     let tree_entry_cls = objectsm.getattr("TreeEntry")?;

+ 23 - 11
crates/pack/src/lib.rs

@@ -18,9 +18,9 @@
  * License, Version 2.0.
  */
 
+use pyo3::exceptions::{PyTypeError, PyValueError};
 use pyo3::prelude::*;
-use pyo3::types::{PyList,PyBytes};
-use pyo3::exceptions::{PyValueError, PyTypeError};
+use pyo3::types::{PyBytes, PyList};
 
 pyo3::import_exception!(dulwich.errors, ApplyDeltaError);
 
@@ -39,7 +39,13 @@ fn py_is_sha(sha: &PyObject, py: Python) -> PyResult<bool> {
 }
 
 #[pyfunction]
-fn bisect_find_sha(py: Python, start: i32, end: i32, sha: Py<PyBytes>, unpack_name: PyObject) -> PyResult<Option<i32>> {
+fn bisect_find_sha(
+    py: Python,
+    start: i32,
+    end: i32,
+    sha: Py<PyBytes>,
+    unpack_name: PyObject,
+) -> PyResult<Option<i32>> {
     // Convert sha_obj to a byte slice
     let sha = sha.as_bytes(py);
     let sha_len = sha.len();
@@ -99,7 +105,10 @@ fn get_delta_header_size(delta: &[u8], index: &mut usize, length: usize) -> usiz
     size
 }
 
-fn py_chunked_as_string<'a>(py: Python<'a>, py_buf: &'a PyObject) -> PyResult<std::borrow::Cow<'a, [u8]>> {
+fn py_chunked_as_string<'a>(
+    py: Python<'a>,
+    py_buf: &'a PyObject,
+) -> PyResult<std::borrow::Cow<'a, [u8]>> {
     if let Ok(py_list) = py_buf.extract::<Bound<PyList>>(py) {
         let mut buf = Vec::new();
         for chunk in py_list.iter() {
@@ -108,14 +117,19 @@ fn py_chunked_as_string<'a>(py: Python<'a>, py_buf: &'a PyObject) -> PyResult<st
             } else if let Ok(chunk) = chunk.extract::<Vec<u8>>() {
                 buf.extend(chunk);
             } else {
-                return Err(PyTypeError::new_err(format!("chunk is not a byte string, but a {:?}", chunk.get_type().name())));
+                return Err(PyTypeError::new_err(format!(
+                    "chunk is not a byte string, but a {:?}",
+                    chunk.get_type().name()
+                )));
             }
         }
         Ok(buf.into())
     } else if py_buf.extract::<Bound<PyBytes>>(py).is_ok() {
         Ok(std::borrow::Cow::Borrowed(py_buf.extract::<&[u8]>(py)?))
     } else {
-        Err(PyTypeError::new_err("buf is not a string or a list of chunks"))
+        Err(PyTypeError::new_err(
+            "buf is not a string or a list of chunks",
+        ))
     }
 }
 
@@ -168,10 +182,7 @@ fn apply_delta(py: Python, py_src_buf: PyObject, py_delta: PyObject) -> PyResult
                 cp_size = 0x10000;
             }
 
-            if cp_off + cp_size < cp_size
-                || cp_off + cp_size > src_size
-                || cp_size > dest_size
-            {
+            if cp_off + cp_size < cp_size || cp_off + cp_size > src_size || cp_size > dest_size {
                 break;
             }
 
@@ -187,7 +198,8 @@ fn apply_delta(py: Python, py_src_buf: PyObject, py_delta: PyObject) -> PyResult
                 return Err(ApplyDeltaError::new_err("Not enough space to copy"));
             }
 
-            out[outindex..outindex + cmd as usize].copy_from_slice(&delta[index..index + cmd as usize]);
+            out[outindex..outindex + cmd as usize]
+                .copy_from_slice(&delta[index..index + cmd as usize]);
             outindex += cmd as usize;
             index += cmd as usize;
         } else {

+ 13 - 5
dulwich/objects.py

@@ -1027,19 +1027,19 @@ def sorted_tree_items(entries, name_order: bool):
         yield TreeEntry(name, mode, hexsha)
 
 
-def key_entry(entry) -> bytes:
+def key_entry(entry: Tuple[bytes, Tuple[int, ObjectID]]) -> bytes:
     """Sort key for tree entry.
 
     Args:
       entry: (name, value) tuple
     """
-    (name, value) = entry
-    if stat.S_ISDIR(value[0]):
+    (name, (mode, _sha)) = entry
+    if stat.S_ISDIR(mode):
         name += b"/"
     return name
 
 
-def key_entry_name_order(entry):
+def key_entry_name_order(entry: Tuple[bytes, Tuple[int, ObjectID]]) -> bytes:
     """Sort key for tree entry in name order."""
     return entry[0]
 
@@ -1667,6 +1667,14 @@ _parse_tree_py = parse_tree
 _sorted_tree_items_py = sorted_tree_items
 try:
     # Try to import Rust versions
-    from dulwich._objects import parse_tree, sorted_tree_items  # type: ignore
+    from dulwich._objects import (
+        parse_tree as _parse_tree_rs,
+    )
+    from dulwich._objects import (
+        sorted_tree_items as _sorted_tree_items_rs,
+    )
 except ImportError:
     pass
+else:
+    parse_tree = _parse_tree_rs
+    sorted_tree_items = _sorted_tree_items_rs

+ 65 - 16
tests/test_objects.py

@@ -45,13 +45,17 @@ from dulwich.objects import (
     format_timezone,
     hex_to_filename,
     hex_to_sha,
+    key_entry,
     object_class,
     parse_timezone,
-    parse_tree,
     pretty_format_tree_entry,
     sha_to_hex,
-    sorted_tree_items,
 )
+
+try:
+    from dulwich.objects import _parse_tree_rs, _sorted_tree_items_rs
+except ImportError:
+    _sorted_tree_items_rs = _parse_tree_rs = None
 from dulwich.tests.utils import (
     ext_functest_builder,
     functest_builder,
@@ -813,15 +817,37 @@ nHxksHfeNln9RKseIDcy4b2ATjhDNIJZARHNfr6oy4u3XPW4svRqtBsLoMiIeuI=
 
 
 _TREE_ITEMS = {
+    b"a-c": (0o100755, b"d80c186a03f423a81b39df39dc87fd269736ca86"),
     b"a.c": (0o100755, b"d80c186a03f423a81b39df39dc87fd269736ca86"),
+    b"aoc": (0o100755, b"d80c186a03f423a81b39df39dc87fd269736ca86"),
     b"a": (stat.S_IFDIR, b"d80c186a03f423a81b39df39dc87fd269736ca86"),
     b"a/c": (stat.S_IFDIR, b"d80c186a03f423a81b39df39dc87fd269736ca86"),
 }
 
 _SORTED_TREE_ITEMS = [
+    TreeEntry(b"a-c", 0o100755, b"d80c186a03f423a81b39df39dc87fd269736ca86"),
     TreeEntry(b"a.c", 0o100755, b"d80c186a03f423a81b39df39dc87fd269736ca86"),
     TreeEntry(b"a", stat.S_IFDIR, b"d80c186a03f423a81b39df39dc87fd269736ca86"),
     TreeEntry(b"a/c", stat.S_IFDIR, b"d80c186a03f423a81b39df39dc87fd269736ca86"),
+    TreeEntry(b"aoc", 0o100755, b"d80c186a03f423a81b39df39dc87fd269736ca86"),
+]
+
+
+_TREE_ITEMS_BUG_1325 = {
+    b"dir": (stat.S_IFDIR | 0o644, b"5944b31ff85b415573d1a43eb942e2dea30ab8be"),
+    b"dira": (0o100644, b"cf7a729ca69bfabd0995fc9b083e86a18215bd91"),
+}
+
+
+_SORTED_TREE_ITEMS_BUG_1325 = [
+    TreeEntry(
+        path=b"dir",
+        mode=stat.S_IFDIR | 0o644,
+        sha=b"5944b31ff85b415573d1a43eb942e2dea30ab8be",
+    ),
+    TreeEntry(
+        path=b"dira", mode=0o100644, sha=b"cf7a729ca69bfabd0995fc9b083e86a18215bd91"
+    ),
 ]
 
 
@@ -878,35 +904,47 @@ class TreeTests(ShaFileCheckTests):
         )
 
     test_parse_tree = functest_builder(_do_test_parse_tree, _parse_tree_py)
-    test_parse_tree_extension = ext_functest_builder(_do_test_parse_tree, parse_tree)
+    test_parse_tree_extension = ext_functest_builder(
+        _do_test_parse_tree, _parse_tree_rs
+    )
 
     def _do_test_sorted_tree_items(self, sorted_tree_items):
-        def do_sort(entries):
-            return list(sorted_tree_items(entries, False))
+        def do_sort(entries, name_order):
+            return list(sorted_tree_items(entries, name_order))
 
-        actual = do_sort(_TREE_ITEMS)
+        actual = do_sort(_TREE_ITEMS, False)
         self.assertEqual(_SORTED_TREE_ITEMS, actual)
         self.assertIsInstance(actual[0], TreeEntry)
 
+        actual = do_sort(_TREE_ITEMS_BUG_1325, False)
+        self.assertEqual(
+            key_entry((b"a", (0o40644, b"cf7a729ca69bfabd0995fc9b083e86a18215bd91"))),
+            b"a/",
+        )
+        self.assertEqual(_SORTED_TREE_ITEMS_BUG_1325, actual)
+        self.assertIsInstance(actual[0], TreeEntry)
+
         # C/Python implementations may differ in specific error types, but
         # should all error on invalid inputs.
         # For example, the Rust implementation has stricter type checks, so may
         # raise TypeError where the Python implementation raises
         # AttributeError.
         errors = (TypeError, ValueError, AttributeError)
-        self.assertRaises(errors, do_sort, b"foo")
-        self.assertRaises(errors, do_sort, {b"foo": (1, 2, 3)})
+        self.assertRaises(errors, do_sort, b"foo", False)
+        self.assertRaises(errors, do_sort, {b"foo": (1, 2, 3)}, False)
 
         myhexsha = b"d80c186a03f423a81b39df39dc87fd269736ca86"
-        self.assertRaises(errors, do_sort, {b"foo": (b"xxx", myhexsha)})
-        self.assertRaises(errors, do_sort, {b"foo": (0o100755, 12345)})
+        self.assertRaises(errors, do_sort, {b"foo": (b"xxx", myhexsha)}, False)
+        self.assertRaises(errors, do_sort, {b"foo": (0o100755, 12345)}, False)
 
     test_sorted_tree_items = functest_builder(
         _do_test_sorted_tree_items, _sorted_tree_items_py
     )
-    test_sorted_tree_items_extension = ext_functest_builder(
-        _do_test_sorted_tree_items, sorted_tree_items
-    )
+    if _sorted_tree_items_rs is not None:
+        assert _sorted_tree_items_rs != _sorted_tree_items_py
+        test_sorted_tree_items_extension = ext_functest_builder(
+            _do_test_sorted_tree_items, _sorted_tree_items_rs
+        )
 
     def _do_test_sorted_tree_items_name_order(self, sorted_tree_items):
         self.assertEqual(
@@ -916,6 +954,11 @@ class TreeTests(ShaFileCheckTests):
                     stat.S_IFDIR,
                     b"d80c186a03f423a81b39df39dc87fd269736ca86",
                 ),
+                TreeEntry(
+                    b"a-c",
+                    0o100755,
+                    b"d80c186a03f423a81b39df39dc87fd269736ca86",
+                ),
                 TreeEntry(
                     b"a.c",
                     0o100755,
@@ -926,6 +969,11 @@ class TreeTests(ShaFileCheckTests):
                     stat.S_IFDIR,
                     b"d80c186a03f423a81b39df39dc87fd269736ca86",
                 ),
+                TreeEntry(
+                    b"aoc",
+                    0o100755,
+                    b"d80c186a03f423a81b39df39dc87fd269736ca86",
+                ),
             ],
             list(sorted_tree_items(_TREE_ITEMS, True)),
         )
@@ -933,9 +981,10 @@ class TreeTests(ShaFileCheckTests):
     test_sorted_tree_items_name_order = functest_builder(
         _do_test_sorted_tree_items_name_order, _sorted_tree_items_py
     )
-    test_sorted_tree_items_name_order_extension = ext_functest_builder(
-        _do_test_sorted_tree_items_name_order, sorted_tree_items
-    )
+    if _sorted_tree_items_rs is not None:
+        test_sorted_tree_items_name_order_extension = ext_functest_builder(
+            _do_test_sorted_tree_items_name_order, _sorted_tree_items_rs
+        )
 
     def test_check(self):
         t = Tree