Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cytetype/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.18.1"
__version__ = "0.19.0"

import requests

Expand Down
107 changes: 84 additions & 23 deletions cytetype/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def _upload_file(
file_kind: UploadFileKind,
file_path: str,
timeout: float | tuple[float, float] = (60.0, 3600.0),
max_workers: int = 4,
max_workers: int = 6,
) -> UploadResponse:
path_obj = Path(file_path)
if not path_obj.is_file():
Expand Down Expand Up @@ -69,6 +69,16 @@ def _upload_file(

n_chunks = math.ceil(size_bytes / chunk_size) if size_bytes > 0 else 0

presigned_urls: list[str] | None = init_data.get("presigned_urls")
r2_upload_id: str | None = init_data.get("r2_upload_id")
use_r2 = presigned_urls is not None and r2_upload_id is not None

if use_r2 and len(presigned_urls) != n_chunks: # type: ignore[arg-type]
raise ValueError(
f"Server returned {len(presigned_urls)} presigned URLs " # type: ignore[arg-type]
f"but expected {n_chunks} (one per chunk)."
)

# Step 2 – Upload chunks in parallel.
# Each worker thread gets its own HTTPTransport (and thus its own
# requests.Session / connection pool) for thread safety.
Expand All @@ -83,8 +93,61 @@ def _upload_file(
)
_progress_lock = threading.Lock()
_chunks_done = [0]
_etags: dict[int, str] = {}
_etags_lock = threading.Lock()

def _update_progress() -> None:
    # Advance the shared upload-progress display by exactly one chunk.
    # Called once per successfully uploaded chunk by the worker threads.
    if pbar is not None:
        # A progress-bar object was supplied by the caller; delegate to it.
        # NOTE(review): assumes pbar.update(1) is safe to call from multiple
        # threads (true for tqdm) — confirm for other bar implementations.
        pbar.update(1)
    else:
        # Fallback: manual carriage-return counter on stdout. The counter
        # lives in a one-element list so the closure can mutate it; the lock
        # keeps increment+read atomic across concurrent worker threads.
        with _progress_lock:
            _chunks_done[0] += 1
            done = _chunks_done[0]
            pct = 100 * done / n_chunks
            print(
                f"\r Uploading: {done}/{n_chunks} chunks ({pct:.0f}%)",
                end="",
                flush=True,
            )

def _upload_chunk_r2(chunk_idx: int) -> None:
    # Upload one file chunk directly to object storage via its presigned URL,
    # recording the returned ETag so the multipart upload can be completed.
    # Each worker thread lazily builds its own HTTPTransport so the underlying
    # session/connection pool is never shared between threads.
    if not hasattr(_tls, "transport"):
        _tls.transport = HTTPTransport(base_url, auth_token)
    # Read exactly this chunk's byte range; the final chunk may be shorter
    # than chunk_size, hence the min() clamp against the file size.
    offset = chunk_idx * chunk_size
    read_size = min(chunk_size, size_bytes - offset)
    with path_obj.open("rb") as f:
        f.seek(offset)
        chunk_data = f.read(read_size)

    # use_r2 guarantees presigned_urls is not None here; the ignore silences
    # the type checker, which cannot see that invariant.
    url = presigned_urls[chunk_idx]  # type: ignore[index]
    last_exc: Exception | None = None
    # One initial attempt plus one retry per configured backoff delay.
    for attempt in range(1 + len(_CHUNK_RETRY_DELAYS)):
        try:
            etag = _tls.transport.put_to_presigned_url(
                url, chunk_data, timeout=timeout
            )
            # ETags are collected into a shared dict keyed by chunk index;
            # the lock guards concurrent writes from worker threads.
            with _etags_lock:
                _etags[chunk_idx] = etag
            _update_progress()
            return
        except (NetworkError, TimeoutError) as exc:
            # Transient transport failures are always worth retrying.
            last_exc = exc
        except APIError as exc:
            # Only a whitelisted set of API error codes is retryable;
            # anything else is a hard failure and propagates immediately.
            if exc.error_code in _RETRYABLE_API_ERROR_CODES:
                last_exc = exc
            else:
                raise
        if attempt < len(_CHUNK_RETRY_DELAYS):
            delay = _CHUNK_RETRY_DELAYS[attempt]
            logger.warning(
                f"Chunk {chunk_idx + 1}/{n_chunks} upload failed "
                f"(attempt {attempt + 1}/{1 + len(_CHUNK_RETRY_DELAYS)}), "
                f"retrying in {delay}s: {last_exc}"
            )
            time.sleep(delay)
    # All attempts exhausted — surface the most recent failure. last_exc is
    # always set on this path, which the type checker cannot prove.
    raise last_exc  # type: ignore[misc]

def _upload_chunk(chunk_idx: int) -> None:
def _upload_chunk_server(chunk_idx: int) -> None:
if not hasattr(_tls, "transport"):
_tls.transport = HTTPTransport(base_url, auth_token)
offset = chunk_idx * chunk_size
Expand All @@ -101,18 +164,7 @@ def _upload_chunk(chunk_idx: int) -> None:
data=chunk_data,
timeout=timeout,
)
if pbar is not None:
pbar.update(1)
else:
with _progress_lock:
_chunks_done[0] += 1
done = _chunks_done[0]
pct = 100 * done / n_chunks
print(
f"\r Uploading: {done}/{n_chunks} chunks ({pct:.0f}%)",
end="",
flush=True,
)
_update_progress()
return
except (NetworkError, TimeoutError) as exc:
last_exc = exc
Expand All @@ -121,7 +173,6 @@ def _upload_chunk(chunk_idx: int) -> None:
last_exc = exc
else:
raise

if attempt < len(_CHUNK_RETRY_DELAYS):
delay = _CHUNK_RETRY_DELAYS[attempt]
logger.warning(
Expand All @@ -130,14 +181,15 @@ def _upload_chunk(chunk_idx: int) -> None:
f"retrying in {delay}s: {last_exc}"
)
time.sleep(delay)

raise last_exc # type: ignore[misc]

upload_fn = _upload_chunk_r2 if use_r2 else _upload_chunk_server

if n_chunks > 0:
effective_workers = min(max_workers, n_chunks)
try:
with ThreadPoolExecutor(max_workers=effective_workers) as pool:
list(pool.map(_upload_chunk, range(n_chunks)))
list(pool.map(upload_fn, range(n_chunks)))
if pbar is not None:
pbar.close()
else:
Expand All @@ -151,10 +203,19 @@ def _upload_chunk(chunk_idx: int) -> None:
print()
raise

# Step 3 – Complete upload (returns same UploadResponse shape as before)
_, complete_data = transport.post_empty(
f"upload/{upload_id}/complete", timeout=timeout
)
# Step 3 – Complete upload
if use_r2:
parts = [{"ETag": _etags[i], "PartNumber": i + 1} for i in range(n_chunks)]
_, complete_data = transport.post_json(
f"upload/{upload_id}/complete",
data={"parts": parts},
timeout=timeout,
)
else:
_, complete_data = transport.post_empty(
f"upload/{upload_id}/complete", timeout=timeout
)

return UploadResponse(**complete_data)


Expand All @@ -163,7 +224,7 @@ def upload_obs_duckdb(
auth_token: str | None,
file_path: str,
timeout: float | tuple[float, float] = (60.0, 3600.0),
max_workers: int = 4,
max_workers: int = 6,
) -> UploadResponse:
return _upload_file(
base_url,
Expand All @@ -180,7 +241,7 @@ def upload_vars_h5(
auth_token: str | None,
file_path: str,
timeout: float | tuple[float, float] = (60.0, 3600.0),
max_workers: int = 4,
max_workers: int = 6,
) -> UploadResponse:
return _upload_file(
base_url,
Expand Down
56 changes: 55 additions & 1 deletion cytetype/api/transport.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import requests
from typing import Any, BinaryIO

from .exceptions import create_api_exception, NetworkError, TimeoutError
from .exceptions import create_api_exception, APIError, NetworkError, TimeoutError
from .schemas import ErrorResponse


Expand Down Expand Up @@ -124,6 +124,60 @@ def put_binary(
self._handle_request_error(e)
raise # For type checker

def put_to_presigned_url(
self,
url: str,
data: bytes,
timeout: float | tuple[float, float] = (30.0, 3600.0),
) -> str:
"""PUT raw bytes to a presigned URL. Returns the ETag header."""
try:
response = self.session.put(
url,
data=data,
headers={"Content-Type": "application/octet-stream"},
timeout=timeout,
)
if 400 <= response.status_code < 500:
raise APIError(
f"Presigned URL upload rejected (HTTP {response.status_code}): "
f"{response.text[:200]}",
error_code="PRESIGNED_URL_REJECTED",
)
response.raise_for_status()
etag = response.headers.get("ETag")
if not etag:
raise NetworkError(
"Presigned URL PUT succeeded but response is missing the ETag header",
error_code="MISSING_ETAG",
)
return etag
except requests.RequestException as e:
self._handle_request_error(e)
raise

def post_json(
self,
endpoint: str,
data: dict[str, Any],
timeout: float | tuple[float, float] = 30.0,
) -> tuple[int, dict[str, Any]]:
"""Make POST request with JSON body."""
url = f"{self.base_url}/{endpoint.lstrip('/')}"
try:
response = self.session.post(
url,
json=data,
headers=self._build_headers(content_type="application/json"),
timeout=timeout,
)
if not response.ok:
self._parse_error(response)
return response.status_code, response.json()
except requests.RequestException as e:
self._handle_request_error(e)
raise

def get(self, endpoint: str, timeout: int = 30) -> tuple[int, dict[str, Any]]:
"""Make GET request and return (status_code, data)."""
url = f"{self.base_url}/{endpoint.lstrip('/')}"
Expand Down