Procházet zdrojové kódy

Use urllib3 rather than urllib in dumb repo

Jelmer Vernooij před 1 měsícem
rodič
revize
9893ac33bf
Změněny 4 soubory: 123 přidání a 82 odebrání
  1. 24 7
      dulwich/client.py
  2. 39 44
      dulwich/dumb.py
  3. 7 4
      tests/test_client.py
  4. 53 27
      tests/test_dumb.py

+ 24 - 7
dulwich/client.py

@@ -2547,7 +2547,14 @@ class AbstractHttpGitClient(GitClient):
                     return refs, server_capabilities, base_url, symrefs, peeled
             else:
                 self.protocol_version = 0  # dumb servers only support protocol v0
-                (refs, peeled) = split_peeled_refs(read_info_refs(resp))
+                # Read all the response data
+                data = b""
+                while True:
+                    chunk = read(4096)
+                    if not chunk:
+                        break
+                    data += chunk
+                (refs, peeled) = split_peeled_refs(read_info_refs(BytesIO(data)))
                 if ref_prefix is not None:
                     refs = filter_ref_prefix(refs, ref_prefix)
                 return refs, set(), base_url, {}, peeled
@@ -2701,9 +2708,10 @@ class AbstractHttpGitClient(GitClient):
             return FetchPackResult(refs, symrefs, agent)
         if self.dumb:
             # Use dumb HTTP protocol
-            from .dumb import DumbRemoteRepo
+            from .dumb import DumbRemoteHTTPRepo
 
-            dumb_repo = DumbRemoteRepo(url, self._http_request)
+            # Pass http_request function
+            dumb_repo = DumbRemoteHTTPRepo(url, self._http_request)
 
             # Fetch pack data from dumb remote
             pack_data_list = list(
@@ -2714,16 +2722,25 @@ class AbstractHttpGitClient(GitClient):
 
             # Write pack data
             if pack_data:
-                from .pack import pack_objects_to_data
+                from .pack import pack_objects_to_data, write_pack_data
 
                 # Convert unpacked objects to ShaFile objects for packing
                 objects = []
                 for unpacked in pack_data_list:
                     objects.append(unpacked.sha_file())
 
-                # Generate pack data
-                pack_bytes = pack_objects_to_data(objects)
-                pack_data(pack_bytes)
+                # Generate pack data and write it to a buffer
+                pack_buffer = BytesIO()
+                count, unpacked_iter = pack_objects_to_data(objects)
+                write_pack_data(
+                    pack_buffer.write,
+                    unpacked_iter,
+                    num_records=count,
+                    progress=progress,
+                )
+
+                # Pass the raw pack data to pack_data callback
+                pack_data(pack_buffer.getvalue())
 
             return FetchPackResult(refs, symrefs, agent)
         req_data = BytesIO()

+ 39 - 44
dulwich/dumb.py

@@ -23,24 +23,24 @@
 
 import os
 import tempfile
+import zlib
 from collections.abc import Iterator
 from io import BytesIO
 from typing import Optional
 from urllib.parse import urljoin
-import zlib
 
 from .errors import NotGitRepository, ObjectFormatException
 from .object_store import BaseObjectStore
 from .objects import (
     ZERO_SHA,
-    ObjectID,
-    ShaFile,
-    hex_to_sha,
-    sha_to_hex,
     Blob,
     Commit,
+    ObjectID,
+    ShaFile,
     Tag,
     Tree,
+    hex_to_sha,
+    sha_to_hex,
 )
 from .pack import Pack, PackIndex, UnpackedObject, load_pack_index_file
 from .refs import Ref, read_info_refs, split_peeled_refs
@@ -56,7 +56,7 @@ class DumbHTTPObjectStore(BaseObjectStore):
         Args:
           base_url: Base URL of the remote repository (e.g. "https://example.com/repo.git/")
           http_request_func: Function to make HTTP requests, should accept (url, headers)
-                           and return (response, read_func)
+                           and return (response, read_func).
         """
         self.base_url = base_url.rstrip("/") + "/"
         self._http_request = http_request_func
@@ -102,14 +102,15 @@ class DumbHTTPObjectStore(BaseObjectStore):
         """Fetch a loose object by SHA.
 
         Args:
-          sha: SHA1 of the object
+          sha: SHA1 of the object (hex string as bytes)
+
         Returns:
           Tuple of (type_num, content)
 
         Raises:
           KeyError: If object not found
         """
-        hex_sha = sha_to_hex(sha).decode("ascii")
+        hex_sha = sha.decode("ascii")
         path = f"objects/{hex_sha[:2]}/{hex_sha[2:]}"
 
         try:
@@ -199,7 +200,8 @@ class DumbHTTPObjectStore(BaseObjectStore):
         """Try to fetch an object from pack files.
 
         Args:
-          sha: SHA1 of the object
+          sha: SHA1 of the object (hex string as bytes)
+
         Returns:
           Tuple of (type_num, content)
 
@@ -207,6 +209,8 @@ class DumbHTTPObjectStore(BaseObjectStore):
           KeyError: If object not found in any pack
         """
         self._load_packs()
+        # Convert hex to binary for pack operations
+        binsha = hex_to_sha(sha)
 
         for pack_name, idx in self._packs or []:
             if idx is None:
@@ -214,7 +218,7 @@ class DumbHTTPObjectStore(BaseObjectStore):
 
             try:
                 # Check if object is in this pack
-                idx.object_offset(sha)
+                idx.object_offset(binsha)
             except KeyError:
                 continue
 
@@ -235,7 +239,7 @@ class DumbHTTPObjectStore(BaseObjectStore):
             # Open the pack and get the object
             pack = Pack(pack_path[:-5])  # Remove .pack extension
             try:
-                return pack.get_raw(sha)
+                return pack.get_raw(binsha)
             finally:
                 pack.close()
 
@@ -312,7 +316,7 @@ class DumbHTTPObjectStore(BaseObjectStore):
             for sha in idx:
                 if sha not in seen:
                     seen.add(sha)
-                    yield sha
+                    yield sha_to_hex(sha)
 
     @property
     def packs(self):
@@ -338,15 +342,15 @@ class DumbHTTPObjectStore(BaseObjectStore):
             shutil.rmtree(self._temp_pack_dir, ignore_errors=True)
 
 
-class DumbRemoteRepo(BaseRepo):
+class DumbRemoteHTTPRepo(BaseRepo):
     """Repository implementation for dumb HTTP remotes."""
 
     def __init__(self, base_url: str, http_request_func):
-        """Initialize a DumbRemoteRepo.
+        """Initialize a DumbRemoteHTTPRepo.
 
         Args:
           base_url: Base URL of the remote repository
-          http_request_func: Function to make HTTP requests
+          http_request_func: Function to make HTTP requests.
         """
         self.base_url = base_url.rstrip("/") + "/"
         self._http_request = http_request_func
@@ -389,11 +393,8 @@ class DumbRemoteRepo(BaseRepo):
                 raise NotGitRepository(f"Cannot read refs from {self.base_url}")
 
             refs_hex = read_info_refs(BytesIO(refs_data))
-            # Convert hex SHAs to binary
-            refs = {}
-            for ref, hex_sha in refs_hex.items():
-                refs[ref] = hex_to_sha(hex_sha)
-            self._refs, self._peeled = split_peeled_refs(refs)
+            # Keep SHAs as hex
+            self._refs, self._peeled = split_peeled_refs(refs_hex)
 
         return dict(self._refs)
 
@@ -420,7 +421,6 @@ class DumbRemoteRepo(BaseRepo):
         Returns:
           Iterator of UnpackedObject instances
         """
-
         refs = self.get_refs()
         wants = determine_wants(refs)
 
@@ -444,29 +444,24 @@ class DumbRemoteRepo(BaseRepo):
                 continue
 
             # Fetch the object
-            try:
-                type_num, content = self._object_store.get_raw(sha)
-                unpacked = UnpackedObject(type_num, sha=sha)
-                unpacked.obj_type_num = type_num
-                unpacked.obj_chunks = [content]
-                yield unpacked
-
-                # If it's a commit or tag, we need to fetch its references
-                obj = ShaFile.from_raw_string(type_num, content)
-
-                if hasattr(obj, "tree"):  # Commit
-                    to_fetch.add(obj.tree)
-                    for parent in obj.parents:
-                        to_fetch.add(parent)
-                elif hasattr(obj, "object"):  # Tag
-                    to_fetch.add(obj.object[1])
-                elif hasattr(obj, "items"):  # Tree
-                    for _, _, item_sha in obj.items():
-                        to_fetch.add(item_sha)
-
-            except KeyError:
-                # Object not found, skip it
-                pass
+            type_num, content = self._object_store.get_raw(sha)
+            unpacked = UnpackedObject(type_num, sha=sha)
+            unpacked.obj_type_num = type_num
+            unpacked.obj_chunks = [content]
+            yield unpacked
+
+            # If it's a commit or tag, we need to fetch its references
+            obj = ShaFile.from_raw_string(type_num, content)
+
+            if isinstance(obj, Commit):  # Commit
+                to_fetch.add(obj.tree)
+                for parent in obj.parents:
+                    to_fetch.add(parent)
+            elif isinstance(obj, Tag):  # Tag
+                to_fetch.add(obj.object[1])
+            elif isinstance(obj, Tree):  # Tree
+                for _, _, item_sha in obj.items():
+                    to_fetch.add(item_sha)
 
             if progress:
                 progress(f"Fetching objects: {len(seen)} done")

+ 7 - 4
tests/test_client.py

@@ -59,7 +59,7 @@ from dulwich.client import (
     parse_rsync_url,
 )
 from dulwich.config import ConfigDict
-from dulwich.objects import Commit, Tree, hex_to_sha
+from dulwich.objects import Commit, Tree
 from dulwich.pack import pack_objects_to_data, write_pack_data, write_pack_objects
 from dulwich.protocol import DEFAULT_GIT_PROTOCOL_VERSION_FETCH, TCP_GIT_PORT, Protocol
 from dulwich.repo import MemoryRepo, Repo
@@ -1395,7 +1395,7 @@ class HttpGitClientTests(TestCase):
         )
 
         # Verify we got the refs
-        expected_sha = hex_to_sha(blob_hex.encode("ascii"))
+        expected_sha = blob_hex.encode("ascii")
         self.assertEqual({b"refs/heads/master": expected_sha}, result.refs)
 
         # Verify we received pack data
@@ -1403,8 +1403,11 @@ class HttpGitClientTests(TestCase):
         pack_data = b"".join(received_data)
         self.assertTrue(len(pack_data) > 0)
 
-        # The pack should contain our blob
-        self.assertIn(blob_content, pack_data)
+        # The pack should be valid pack format
+        self.assertTrue(pack_data.startswith(b"PACK"))
+        # Pack header: PACK + version (4 bytes) + num objects (4 bytes)
+        self.assertEqual(pack_data[4:8], b"\x00\x00\x00\x02")  # version 2
+        self.assertEqual(pack_data[8:12], b"\x00\x00\x00\x01")  # 1 object
 
 
 class TCPGitClientTests(TestCase):

+ 53 - 27
tests/test_dumb.py

@@ -25,9 +25,9 @@ import zlib
 from unittest import TestCase
 from unittest.mock import Mock
 
-from dulwich.dumb import DumbHTTPObjectStore, DumbRemoteRepo
+from dulwich.dumb import DumbHTTPObjectStore, DumbRemoteHTTPRepo
 from dulwich.errors import NotGitRepository
-from dulwich.objects import Blob, Commit, Tag, Tree, hex_to_sha, sha_to_hex
+from dulwich.objects import Blob, Commit, Tag, Tree, sha_to_hex
 
 
 class MockResponse:
@@ -56,7 +56,22 @@ class DumbHTTPObjectStoreTests(TestCase):
             resp = MockResponse(
                 resp_data.get("status", 200), resp_data.get("content", b"")
             )
-            return resp, lambda size: resp.content
+            # Create a mock read function that behaves like urllib3's read
+            content = resp.content
+            offset = [0]  # Use list to make it mutable in closure
+
+            def read_func(size=None):
+                if offset[0] >= len(content):
+                    return b""
+                if size is None:
+                    result = content[offset[0] :]
+                    offset[0] = len(content)
+                else:
+                    result = content[offset[0] : offset[0] + size]
+                    offset[0] += size
+                return result
+
+            return resp, read_func
         else:
             resp = MockResponse(404)
             return resp, lambda size: b""
@@ -83,21 +98,20 @@ class DumbHTTPObjectStoreTests(TestCase):
         # Create a blob object
         blob = Blob()
         blob.data = b"Hello, world!"
-        sha = blob.sha().digest()
-        hex_sha = sha_to_hex(sha)
+        hex_sha = blob.id
 
         # Add mock response
-        path = f"objects/{hex_sha[:2]}/{hex_sha[2:]}"
+        path = f"objects/{hex_sha[:2].decode('ascii')}/{hex_sha[2:].decode('ascii')}"
         self._add_response(path, self._make_object(blob))
 
         # Fetch the object
-        type_num, content = self.store._fetch_loose_object(sha)
+        type_num, content = self.store._fetch_loose_object(blob.id)
         self.assertEqual(Blob.type_num, type_num)
         self.assertEqual(b"Hello, world!", content)
 
     def test_fetch_loose_object_not_found(self):
-        sha = b"1" * 20
-        self.assertRaises(KeyError, self.store._fetch_loose_object, sha)
+        hex_sha = b"1" * 40
+        self.assertRaises(KeyError, self.store._fetch_loose_object, hex_sha)
 
     def test_fetch_loose_object_invalid_format(self):
         sha = b"1" * 20
@@ -130,7 +144,7 @@ P pack-abcdef1234567890abcdef1234567890abcdef12.pack
         )
 
     def test_get_raw_from_cache(self):
-        sha = b"1" * 20
+        sha = b"1" * 40
         self.store._cached_objects[sha] = (Blob.type_num, b"cached content")
 
         type_num, content = self.store.get_raw(sha)
@@ -141,15 +155,14 @@ P pack-abcdef1234567890abcdef1234567890abcdef12.pack
         # Create a blob object
         blob = Blob()
         blob.data = b"Test blob"
-        sha = blob.sha().digest()
-        hex_sha = sha_to_hex(sha)
+        hex_sha = blob.id
 
         # Add mock response
-        path = f"objects/{hex_sha[:2]}/{hex_sha[2:]}"
+        path = f"objects/{hex_sha[:2].decode('ascii')}/{hex_sha[2:].decode('ascii')}"
         self._add_response(path, self._make_object(blob))
 
-        self.assertTrue(self.store.contains_loose(sha))
-        self.assertFalse(self.store.contains_loose(b"0" * 20))
+        self.assertTrue(self.store.contains_loose(hex_sha))
+        self.assertFalse(self.store.contains_loose(b"0" * 40))
 
     def test_add_object_not_implemented(self):
         blob = Blob()
@@ -160,13 +173,13 @@ P pack-abcdef1234567890abcdef1234567890abcdef12.pack
         self.assertRaises(NotImplementedError, self.store.add_objects, [])
 
 
-class DumbRemoteRepoTests(TestCase):
-    """Tests for DumbRemoteRepo."""
+class DumbRemoteHTTPRepoTests(TestCase):
+    """Tests for DumbRemoteHTTPRepo."""
 
     def setUp(self):
         self.base_url = "https://example.com/repo.git/"
         self.responses = {}
-        self.repo = DumbRemoteRepo(self.base_url, self._mock_http_request)
+        self.repo = DumbRemoteHTTPRepo(self.base_url, self._mock_http_request)
 
     def _mock_http_request(self, url, headers):
         """Mock HTTP request function."""
@@ -175,7 +188,22 @@ class DumbRemoteRepoTests(TestCase):
             resp = MockResponse(
                 resp_data.get("status", 200), resp_data.get("content", b"")
             )
-            return resp, lambda size: resp.content[:size] if size else resp.content
+            # Create a mock read function that behaves like urllib3's read
+            content = resp.content
+            offset = [0]  # Use list to make it mutable in closure
+
+            def read_func(size=None):
+                if offset[0] >= len(content):
+                    return b""
+                if size is None:
+                    result = content[offset[0] :]
+                    offset[0] = len(content)
+                else:
+                    result = content[offset[0] : offset[0] + size]
+                    offset[0] += size
+                return result
+
+            return resp, read_func
         else:
             resp = MockResponse(404)
             return resp, lambda size: b""
@@ -195,15 +223,15 @@ fedcba9876543210fedcba9876543210fedcba98\trefs/tags/v1.0
         refs = self.repo.get_refs()
         self.assertEqual(3, len(refs))
         self.assertEqual(
-            hex_to_sha(b"0123456789abcdef0123456789abcdef01234567"),
+            b"0123456789abcdef0123456789abcdef01234567",
             refs[b"refs/heads/master"],
         )
         self.assertEqual(
-            hex_to_sha(b"abcdef0123456789abcdef0123456789abcdef01"),
+            b"abcdef0123456789abcdef0123456789abcdef01",
             refs[b"refs/heads/develop"],
         )
         self.assertEqual(
-            hex_to_sha(b"fedcba9876543210fedcba9876543210fedcba98"),
+            b"fedcba9876543210fedcba9876543210fedcba98",
             refs[b"refs/tags/v1.0"],
         )
 
@@ -216,9 +244,7 @@ fedcba9876543210fedcba9876543210fedcba98\trefs/tags/v1.0
 
         # For dumb HTTP, peeled just returns the ref value
         peeled = self.repo.get_peeled(b"refs/heads/master")
-        self.assertEqual(
-            hex_to_sha(b"0123456789abcdef0123456789abcdef01234567"), peeled
-        )
+        self.assertEqual(b"0123456789abcdef0123456789abcdef01234567", peeled)
 
     def test_fetch_pack_data_no_wants(self):
         refs_content = b"0123456789abcdef0123456789abcdef01234567\trefs/heads/master\n"
@@ -240,7 +266,7 @@ fedcba9876543210fedcba9876543210fedcba98\trefs/tags/v1.0
         # Create a simple blob object
         blob = Blob()
         blob.data = b"Test content"
-        blob_sha = blob.sha().digest()
+        blob_sha = blob.id
         # Add blob response
         self.repo._object_store._cached_objects[blob_sha] = (
             Blob.type_num,
@@ -257,7 +283,7 @@ fedcba9876543210fedcba9876543210fedcba98\trefs/tags/v1.0
         result = list(self.repo.fetch_pack_data(graph_walker, determine_wants))
         self.assertEqual(1, len(result))
         self.assertEqual(Blob.type_num, result[0].pack_type_num)
-        self.assertEqual(blob.as_raw_string(), result[0].obj)
+        self.assertEqual([blob.as_raw_string()], result[0].obj_chunks)
 
     def test_object_store_property(self):
         self.assertIsInstance(self.repo.object_store, DumbHTTPObjectStore)