lfs_server.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276
  1. # lfs_server.py -- Simple Git LFS server implementation
  2. # Copyright (C) 2024 Jelmer Vernooij <jelmer@jelmer.uk>
  3. #
  4. # SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
  5. # Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
  6. # General Public License as published by the Free Software Foundation; version 2.0
  7. # or (at your option) any later version. You can redistribute it and/or
  8. # modify it under the terms of either of these two licenses.
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. # You should have received a copy of the licenses; if not, see
  17. # <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
  18. # and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
  19. # License, Version 2.0.
  20. #
  21. """Simple Git LFS server implementation for testing."""
  22. import hashlib
  23. import json
  24. import tempfile
  25. from http.server import BaseHTTPRequestHandler, HTTPServer
  26. from typing import Optional
  27. from .lfs import LFSStore
  class LFSRequestHandler(BaseHTTPRequestHandler):
      """HTTP request handler for LFS operations.

      Routes handled:

      * ``POST /objects/batch`` -- Git LFS batch API.
      * ``PUT /objects/<oid>`` -- upload an object.
      * ``GET /objects/<oid>`` -- download an object.
      * ``POST /objects/<oid>/verify`` -- check that an uploaded object is stored.

      Every other path gets a 404. OIDs coming from the URL or from a batch
      request are only handed to the store when they are 64 lowercase hex
      characters. That is the only form ``handle_upload`` can ever store,
      because it requires the OID to equal ``hashlib.sha256(...).hexdigest()``.
      """

      server: "LFSServer"  # Type annotation for the server attribute

      def send_json_response(self, status_code: int, data: dict) -> None:
          """Send ``data`` as a Git LFS JSON body with the given status code."""
          response = json.dumps(data).encode("utf-8")
          self.send_response(status_code)
          self.send_header("Content-Type", "application/vnd.git-lfs+json")
          self.send_header("Content-Length", str(len(response)))
          self.end_headers()
          self.wfile.write(response)

      def do_POST(self) -> None:
          """Handle POST requests (batch API and upload verification)."""
          if self.path == "/objects/batch":
              self.handle_batch()
          elif self.path.startswith("/objects/") and self.path.endswith("/verify"):
              self.handle_verify()
          else:
              self.send_error(404, "Not Found")

      def do_PUT(self) -> None:
          """Handle PUT requests (uploads)."""
          if self.path.startswith("/objects/"):
              self.handle_upload()
          else:
              self.send_error(404, "Not Found")

      def do_GET(self) -> None:
          """Handle GET requests (downloads)."""
          if self.path.startswith("/objects/"):
              self.handle_download()
          else:
              self.send_error(404, "Not Found")

      def handle_batch(self) -> None:
          """Handle batch API requests.

          Problems with the request as a whole get an HTTP 400 (411 for a
          missing Content-Length). These are: bad Content-Length, undecodable
          JSON, a body that is not a JSON object, an unknown operation, or a
          non-list ``objects``. Problems with individual objects are reported
          per object inside a 200 response:

          * 400 -- oid or size missing.
          * 404 -- download of an object that is not stored.
          * 422 -- upload of an oid that is not a SHA-256 hex digest.
          """
          content_length = self._read_content_length()
          if content_length is None:
              return
          request_data = self.rfile.read(content_length)
          try:
              batch_request = json.loads(request_data)
          except ValueError:
              # json.JSONDecodeError and the UnicodeDecodeError raised for
              # non-UTF-8 bytes are both ValueError subclasses.
              self.send_error(400, "Invalid JSON")
              return
          if not isinstance(batch_request, dict):
              self.send_error(400, "Invalid batch request")
              return
          operation = batch_request.get("operation")
          objects = batch_request.get("objects", [])
          if operation not in ["download", "upload"]:
              self.send_error(400, "Invalid operation")
              return
          if not isinstance(objects, list):
              self.send_error(400, "Invalid objects")
              return

          base_url = self._base_url()
          response_objects = []
          for obj in objects:
              if not isinstance(obj, dict):
                  # Treat a non-object entry as one with no oid/size.
                  obj = {}
              oid = obj.get("oid")
              size = obj.get("size")
              if not oid or size is None:
                  response_objects.append(
                      {
                          "oid": oid,
                          "size": size,
                          "error": {"code": 400, "message": "Missing oid or size"},
                      }
                  )
                  continue
              response_obj = {
                  "oid": oid,
                  "size": size,
              }
              if operation == "download":
                  # _object_exists is False for malformed oids, so hrefs are
                  # only ever built from valid SHA-256 hex strings.
                  if self._object_exists(oid):
                      response_obj["actions"] = {
                          "download": {
                              "href": f"{base_url}/objects/{oid}",
                              "header": {"Accept": "application/octet-stream"},
                          }
                      }
                  else:
                      response_obj["error"] = {"code": 404, "message": "Object not found"}
              elif self._is_valid_oid(oid):
                  response_obj["actions"] = {
                      "upload": {
                          "href": f"{base_url}/objects/{oid}",
                          "header": {"Content-Type": "application/octet-stream"},
                      },
                      "verify": {"href": f"{base_url}/objects/{oid}/verify"},
                  }
              else:
                  # Do not hand out upload URLs built from arbitrary client
                  # strings; handle_upload could never accept them anyway.
                  response_obj["error"] = {"code": 422, "message": "Invalid oid"}
              response_objects.append(response_obj)
          self.send_json_response(200, {"objects": response_objects})

      def handle_download(self) -> None:
          """Handle object download requests (``GET /objects/<oid>``).

          Sends the object bytes with a 200, or a 404 when the path is
          malformed, the oid is invalid, or the store has no such object.
          """
          path_parts = self.path.strip("/").split("/")
          if len(path_parts) != 2:
              self.send_error(404, "Not Found")
              return
          oid = path_parts[1]
          if not self._is_valid_oid(oid):
              # Reject before the oid reaches the store, which may turn it
              # into a filesystem path (e.g. "..").
              self.send_error(404, "Object not found")
              return
          try:
              with self.server.lfs_store.open_object(oid) as f:
                  content = f.read()
          except KeyError:
              self.send_error(404, "Object not found")
              return
          self.send_response(200)
          self.send_header("Content-Type", "application/octet-stream")
          self.send_header("Content-Length", str(len(content)))
          self.end_headers()
          self.wfile.write(content)

      def handle_upload(self) -> None:
          """Handle object upload requests (``PUT /objects/<oid>``).

          The body is stored only if its SHA-256 hex digest equals the oid in
          the URL (400 otherwise). Re-uploading an already stored object is
          accepted without writing it again.
          """
          path_parts = self.path.strip("/").split("/")
          if len(path_parts) != 2:
              self.send_error(404, "Not Found")
              return
          oid = path_parts[1]
          content_length = self._read_content_length()
          if content_length is None:
              return
          # Read the body in bounded chunks.
          chunks = []
          remaining = content_length
          while remaining > 0:
              chunk = self.rfile.read(min(8192, remaining))
              if not chunk:
                  # The client stopped early; the digest check below decides
                  # whether what did arrive is acceptable.
                  break
              chunks.append(chunk)
              remaining -= len(chunk)
          content = b"".join(chunks)
          calculated_oid = hashlib.sha256(content).hexdigest()
          # This check also guarantees that oid is a valid hex digest from here on.
          if calculated_oid != oid:
              self.send_error(400, f"OID mismatch: expected {oid}, got {calculated_oid}")
              return
          if not self._object_exists(oid):
              self.server.lfs_store.write_object(chunks)
          self.send_response(200)
          self.end_headers()

      def handle_verify(self) -> None:
          """Handle object verification requests (``POST /objects/<oid>/verify``).

          Responses:

          * 404 -- the object is not stored (or the oid is malformed).
          * 422 -- the JSON body carries an integer ``size`` that differs from
            the stored object's size.
          * 200 -- otherwise.

          The body is optional; a missing or unparseable body is tolerated.
          """
          path_parts = self.path.strip("/").split("/")
          if len(path_parts) != 3 or path_parts[2] != "verify":
              self.send_error(404, "Not Found")
              return
          oid = path_parts[1]
          content_length = self._read_content_length(default=0)
          if content_length is None:
              return
          expected_size = None
          if content_length > 0:
              request_data = self.rfile.read(content_length)
              try:
                  verify_request = json.loads(request_data)
              except ValueError:
                  verify_request = None  # best-effort: the body is optional
              if isinstance(verify_request, dict):
                  size = verify_request.get("size")
                  # bool is an int subclass; a JSON true/false is not a size.
                  if isinstance(size, int) and not isinstance(size, bool):
                      expected_size = size
          stored_size = self._stored_size(oid)
          if stored_size is None:
              self.send_error(404, "Object not found")
          elif expected_size is not None and expected_size != stored_size:
              self.send_error(422, "Size mismatch")
          else:
              self.send_response(200)
              self.end_headers()

      def _read_content_length(self, default: Optional[int] = None) -> Optional[int]:
          """Parse the Content-Length request header.

          Args:
            default: Value to return when the header is absent. If None, an
              absent header is an error.

          Returns:
            The non-negative length, or ``default`` when the header is absent
            and a default was given. Otherwise an error response has already
            been sent (411 if absent, 400 if not a non-negative integer) and
            None is returned.
          """
          raw = self.headers.get("Content-Length")
          if raw is None:
              if default is None:
                  self.send_error(411, "Length Required")
              return default
          try:
              length = int(raw)
          except ValueError:
              length = -1
          if length < 0:
              self.send_error(400, "Invalid Content-Length")
              return None
          return length

      def _base_url(self) -> str:
          """Return the ``http://host[:port]`` prefix used for action hrefs.

          Uses the request's Host header. When that header is absent (an
          HTTP/1.0 client may omit it), falls back to the server's bound address.
          """
          host = self.headers.get("Host")
          if not host:
              bound_host, bound_port = self.server.server_address[:2]
              host = f"{bound_host}:{bound_port}"
          return f"http://{host}"

      @staticmethod
      def _is_valid_oid(oid: object) -> bool:
          """Return True if ``oid`` is a 64-character lowercase hex string."""
          return (
              isinstance(oid, str)
              and len(oid) == 64
              and all(c in "0123456789abcdef" for c in oid)
          )

      def _object_exists(self, oid: str) -> bool:
          """Check if an object exists in the store (False for malformed oids)."""
          if not self._is_valid_oid(oid):
              return False
          try:
              # Open and immediately close the object; only existence matters.
              with self.server.lfs_store.open_object(oid):
                  return True
          except KeyError:
              return False

      def _stored_size(self, oid: str) -> Optional[int]:
          """Return the stored object's size in bytes, or None if it is absent or invalid."""
          if not self._is_valid_oid(oid):
              return None
          try:
              with self.server.lfs_store.open_object(oid) as f:
                  # Count by reading so no seek support is assumed of the store's file object.
                  size = 0
                  while chunk := f.read(65536):
                      size += len(chunk)
          except KeyError:
              return None
          return size

      def log_message(self, format, *args):
          """Override to suppress request logging during tests."""
          if self.server.log_requests:
              super().log_message(format, *args)
  class LFSServer(HTTPServer):
      """Test-oriented HTTP server that exposes an ``LFSStore`` via ``LFSRequestHandler``.

      Each request handler reaches ``lfs_store`` and ``log_requests`` through
      its ``self.server`` reference.
      """

      def __init__(self, server_address, lfs_store: LFSStore, log_requests: bool = False):
          # Record the handler-visible state first. HTTPServer.__init__ then
          # binds and activates the listening socket for server_address.
          self.lfs_store = lfs_store
          self.log_requests = log_requests
          super().__init__(server_address, LFSRequestHandler)
  def run_lfs_server(
      host: str = "localhost",
      port: int = 0,
      lfs_dir: Optional[str] = None,
      log_requests: bool = False,
  ) -> tuple[LFSServer, str]:
      """Create and bind an LFS server without starting its request loop.

      The caller is responsible for serving requests on the returned server
      (for example by calling ``serve_forever`` in a thread). When ``lfs_dir``
      is None, a fresh temporary directory is created and is not removed here.

      Args:
        host: Host to bind to
        port: Port to bind to (0 for random)
        lfs_dir: Directory for LFS storage (temp dir if None)
        log_requests: Whether to log HTTP requests

      Returns:
        Tuple of (server, url) where url is the base URL for the server
      """
      storage_dir = tempfile.mkdtemp() if lfs_dir is None else lfs_dir
      server = LFSServer((host, port), LFSStore.create(storage_dir), log_requests)
      # server_address holds the port actually bound, which matters when port == 0.
      bound_port = server.server_address[1]
      return server, f"http://{host}:{bound_port}"