diff --git a/cytetype/__init__.py b/cytetype/__init__.py index db09ae3..08c2bb0 100644 --- a/cytetype/__init__.py +++ b/cytetype/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.18.1" +__version__ = "0.19.0" import requests diff --git a/cytetype/api/client.py b/cytetype/api/client.py index a6280b3..6819b92 100644 --- a/cytetype/api/client.py +++ b/cytetype/api/client.py @@ -40,7 +40,7 @@ def _upload_file( file_kind: UploadFileKind, file_path: str, timeout: float | tuple[float, float] = (60.0, 3600.0), - max_workers: int = 4, + max_workers: int = 6, ) -> UploadResponse: path_obj = Path(file_path) if not path_obj.is_file(): @@ -69,6 +69,16 @@ def _upload_file( n_chunks = math.ceil(size_bytes / chunk_size) if size_bytes > 0 else 0 + presigned_urls: list[str] | None = init_data.get("presigned_urls") + r2_upload_id: str | None = init_data.get("r2_upload_id") + use_r2 = presigned_urls is not None and r2_upload_id is not None + + if use_r2 and len(presigned_urls) != n_chunks: # type: ignore[arg-type] + raise ValueError( + f"Server returned {len(presigned_urls)} presigned URLs " # type: ignore[arg-type] + f"but expected {n_chunks} (one per chunk)." + ) + # Step 2 – Upload chunks in parallel. # Each worker thread gets its own HTTPTransport (and thus its own # requests.Session / connection pool) for thread safety. @@ -83,8 +93,61 @@ def _upload_file( ) _progress_lock = threading.Lock() _chunks_done = [0] + _etags: dict[int, str] = {} + _etags_lock = threading.Lock() + + def _update_progress() -> None: + if pbar is not None: + pbar.update(1) + else: + with _progress_lock: + _chunks_done[0] += 1 + done = _chunks_done[0] + pct = 100 * done / n_chunks + print( + f"\r Uploading: {done}/{n_chunks} chunks ({pct:.0f}%)", + end="", + flush=True, + ) + + def _upload_chunk_r2(chunk_idx: int) -> None: + if not hasattr(_tls, "transport"): + _tls.transport = HTTPTransport(base_url, auth_token) + offset = chunk_idx * chunk_size + read_size = min(chunk_size, size_bytes - offset) + with path_obj.open("rb") as f: + f.seek(offset) + chunk_data = f.read(read_size) + + url = presigned_urls[chunk_idx] # type: ignore[index] + last_exc: Exception | None = None + for attempt in range(1 + len(_CHUNK_RETRY_DELAYS)): + try: + etag = _tls.transport.put_to_presigned_url( + url, chunk_data, timeout=timeout + ) + with _etags_lock: + _etags[chunk_idx] = etag + _update_progress() + return + except (NetworkError, TimeoutError) as exc: + last_exc = exc + except APIError as exc: + if exc.error_code in _RETRYABLE_API_ERROR_CODES: + last_exc = exc + else: + raise + if attempt < len(_CHUNK_RETRY_DELAYS): + delay = _CHUNK_RETRY_DELAYS[attempt] + logger.warning( + f"Chunk {chunk_idx + 1}/{n_chunks} upload failed " + f"(attempt {attempt + 1}/{1 + len(_CHUNK_RETRY_DELAYS)}), " + f"retrying in {delay}s: {last_exc}" + ) + time.sleep(delay) + raise last_exc # type: ignore[misc] - def _upload_chunk(chunk_idx: int) -> None: + def _upload_chunk_server(chunk_idx: int) -> None: if not hasattr(_tls, "transport"): _tls.transport = HTTPTransport(base_url, auth_token) offset = chunk_idx * chunk_size @@ -101,18 +164,7 @@ def _upload_chunk(chunk_idx: int) -> None: data=chunk_data, timeout=timeout, ) - if pbar is not None: - pbar.update(1) - else: - with _progress_lock: - _chunks_done[0] += 1 - done = _chunks_done[0] - pct = 100 * done / n_chunks - print( - f"\r Uploading: {done}/{n_chunks} chunks ({pct:.0f}%)", - end="", - flush=True, - ) + _update_progress() return except (NetworkError, TimeoutError) as exc: last_exc = exc @@ -121,7 +173,6 @@ def _upload_chunk(chunk_idx: int) -> None: last_exc = exc else: raise - if attempt < len(_CHUNK_RETRY_DELAYS): delay = _CHUNK_RETRY_DELAYS[attempt] logger.warning( @@ -130,14 +181,15 @@ def _upload_chunk(chunk_idx: int) -> None: f"retrying in {delay}s: {last_exc}" ) time.sleep(delay) - raise last_exc # type: ignore[misc] + upload_fn = _upload_chunk_r2 if use_r2 else _upload_chunk_server + if n_chunks > 0: effective_workers = min(max_workers, n_chunks) try: with ThreadPoolExecutor(max_workers=effective_workers) as pool: - list(pool.map(_upload_chunk, range(n_chunks))) + list(pool.map(upload_fn, range(n_chunks))) if pbar is not None: pbar.close() else: @@ -151,10 +203,19 @@ def _upload_chunk(chunk_idx: int) -> None: print() raise - # Step 3 – Complete upload (returns same UploadResponse shape as before) - _, complete_data = transport.post_empty( - f"upload/{upload_id}/complete", timeout=timeout - ) + # Step 3 – Complete upload + if use_r2: + parts = [{"ETag": _etags[i], "PartNumber": i + 1} for i in range(n_chunks)] + _, complete_data = transport.post_json( + f"upload/{upload_id}/complete", + data={"parts": parts}, + timeout=timeout, + ) + else: + _, complete_data = transport.post_empty( + f"upload/{upload_id}/complete", timeout=timeout + ) + return UploadResponse(**complete_data) @@ -163,7 +224,7 @@ def upload_obs_duckdb( auth_token: str | None, file_path: str, timeout: float | tuple[float, float] = (60.0, 3600.0), - max_workers: int = 4, + max_workers: int = 6, ) -> UploadResponse: return _upload_file( base_url, @@ -180,7 +241,7 @@ def upload_vars_h5( auth_token: str | None, file_path: str, timeout: float | tuple[float, float] = (60.0, 3600.0), - max_workers: int = 4, + max_workers: int = 6, ) -> UploadResponse: return _upload_file( base_url, diff --git a/cytetype/api/transport.py b/cytetype/api/transport.py index 98bce28..c5b5a72 100644 --- a/cytetype/api/transport.py +++ b/cytetype/api/transport.py @@ -1,7 +1,7 @@ import requests from typing import Any, BinaryIO -from .exceptions import create_api_exception, NetworkError, TimeoutError +from .exceptions import create_api_exception, APIError, NetworkError, TimeoutError from .schemas import ErrorResponse @@ -124,6 +124,60 @@ def put_binary( self._handle_request_error(e) raise # For type checker + def put_to_presigned_url( + self, + url: str, + data: bytes, + timeout: float | tuple[float, float] = (30.0, 3600.0), + ) -> str: + """PUT raw bytes to a presigned URL. Returns the ETag header.""" + try: + response = self.session.put( + url, + data=data, + headers={"Content-Type": "application/octet-stream"}, + timeout=timeout, + ) + if 400 <= response.status_code < 500: + raise APIError( + f"Presigned URL upload rejected (HTTP {response.status_code}): " + f"{response.text[:200]}", + error_code="PRESIGNED_URL_REJECTED", + ) + response.raise_for_status() + etag = response.headers.get("ETag") + if not etag: + raise NetworkError( + "Presigned URL PUT succeeded but response is missing the ETag header", + error_code="MISSING_ETAG", + ) + return etag + except requests.RequestException as e: + self._handle_request_error(e) + raise + + def post_json( + self, + endpoint: str, + data: dict[str, Any], + timeout: float | tuple[float, float] = 30.0, + ) -> tuple[int, dict[str, Any]]: + """Make POST request with JSON body.""" + url = f"{self.base_url}/{endpoint.lstrip('/')}" + try: + response = self.session.post( + url, + json=data, + headers=self._build_headers(content_type="application/json"), + timeout=timeout, + ) + if not response.ok: + self._parse_error(response) + return response.status_code, response.json() + except requests.RequestException as e: + self._handle_request_error(e) + raise + def get(self, endpoint: str, timeout: int = 30) -> tuple[int, dict[str, Any]]: """Make GET request and return (status_code, data).""" url = f"{self.base_url}/{endpoint.lstrip('/')}"