diff --git a/cytetype/__init__.py b/cytetype/__init__.py index 3cb2877..bbb4588 100644 --- a/cytetype/__init__.py +++ b/cytetype/__init__.py @@ -1,11 +1,14 @@ -__version__ = "0.17.0" +__version__ = "0.18.0" import requests from .config import logger from .main import CyteType +from .plotting import marker_dotplot +from .preprocessing.marker_detection import rank_genes_groups_backed +from .preprocessing.subsampling import subsample_by_group -__all__ = ["CyteType"] +__all__ = ["CyteType", "marker_dotplot", "rank_genes_groups_backed", "subsample_by_group"] _PYPI_JSON_URL = "https://pypi.org/pypi/cytetype/json" diff --git a/cytetype/api/client.py b/cytetype/api/client.py index 52fd4d1..a6280b3 100644 --- a/cytetype/api/client.py +++ b/cytetype/api/client.py @@ -14,19 +14,32 @@ MAX_UPLOAD_BYTES: dict[UploadFileKind, int] = { "obs_duckdb": 100 * 1024 * 1024, # 100MB - "vars_h5": 10 * 1024 * 1024 * 1024, # 10GB + "vars_h5": 50 * 1024 * 1024 * 1024, # 50GB } _CHUNK_RETRY_DELAYS = (1, 5, 20) _RETRYABLE_API_ERROR_CODES = frozenset({"INTERNAL_ERROR", "HTTP_ERROR"}) +def _try_import_tqdm() -> type | None: + try: + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + from tqdm.auto import tqdm + + return tqdm # type: ignore[no-any-return] + except ImportError: + return None + + def _upload_file( base_url: str, auth_token: str | None, file_kind: UploadFileKind, file_path: str, - timeout: float | tuple[float, float] = (30.0, 3600.0), + timeout: float | tuple[float, float] = (60.0, 3600.0), max_workers: int = 4, ) -> UploadResponse: path_obj = Path(file_path) @@ -62,6 +75,12 @@ def _upload_file( # Memory is bounded to ~max_workers × chunk_size because each thread # reads its chunk on demand via seek+read. 
_tls = threading.local() + tqdm_cls = _try_import_tqdm() + pbar = ( + tqdm_cls(total=n_chunks, desc="Uploading", unit="chunk") + if tqdm_cls is not None and n_chunks > 0 + else None + ) _progress_lock = threading.Lock() _chunks_done = [0] @@ -82,15 +101,18 @@ def _upload_chunk(chunk_idx: int) -> None: data=chunk_data, timeout=timeout, ) - with _progress_lock: - _chunks_done[0] += 1 - done = _chunks_done[0] - pct = 100 * done / n_chunks - print( - f"\r Uploading: {done}/{n_chunks} chunks ({pct:.0f}%)", - end="", - flush=True, - ) + if pbar is not None: + pbar.update(1) + else: + with _progress_lock: + _chunks_done[0] += 1 + done = _chunks_done[0] + pct = 100 * done / n_chunks + print( + f"\r Uploading: {done}/{n_chunks} chunks ({pct:.0f}%)", + end="", + flush=True, + ) return except (NetworkError, TimeoutError) as exc: last_exc = exc @@ -103,13 +125,9 @@ def _upload_chunk(chunk_idx: int) -> None: if attempt < len(_CHUNK_RETRY_DELAYS): delay = _CHUNK_RETRY_DELAYS[attempt] logger.warning( - "Chunk %d/%d upload failed (attempt %d/%d), retrying in %ds: %s", - chunk_idx + 1, - n_chunks, - attempt + 1, - 1 + len(_CHUNK_RETRY_DELAYS), - delay, - last_exc, + f"Chunk {chunk_idx + 1}/{n_chunks} upload failed " + f"(attempt {attempt + 1}/{1 + len(_CHUNK_RETRY_DELAYS)}), " + f"retrying in {delay}s: {last_exc}" ) time.sleep(delay) @@ -120,9 +138,17 @@ def _upload_chunk(chunk_idx: int) -> None: try: with ThreadPoolExecutor(max_workers=effective_workers) as pool: list(pool.map(_upload_chunk, range(n_chunks))) - print(f"\r \033[92m✓\033[0m Uploaded {n_chunks}/{n_chunks} chunks (100%)") + if pbar is not None: + pbar.close() + else: + print( + f"\r \033[92m✓\033[0m Uploaded {n_chunks}/{n_chunks} chunks (100%)" + ) except BaseException: - print() # ensure newline on failure + if pbar is not None: + pbar.close() + else: + print() raise # Step 3 – Complete upload (returns same UploadResponse shape as before) @@ -136,7 +162,7 @@ def upload_obs_duckdb( base_url: str, auth_token: str | 
None, file_path: str, - timeout: float | tuple[float, float] = (30.0, 3600.0), + timeout: float | tuple[float, float] = (60.0, 3600.0), max_workers: int = 4, ) -> UploadResponse: return _upload_file( @@ -153,7 +179,7 @@ def upload_vars_h5( base_url: str, auth_token: str | None, file_path: str, - timeout: float | tuple[float, float] = (30.0, 3600.0), + timeout: float | tuple[float, float] = (60.0, 3600.0), max_workers: int = 4, ) -> UploadResponse: return _upload_file( diff --git a/cytetype/api/exceptions.py b/cytetype/api/exceptions.py index 3ed2366..fe4c529 100644 --- a/cytetype/api/exceptions.py +++ b/cytetype/api/exceptions.py @@ -56,6 +56,12 @@ class LLMValidationError(APIError): pass +class ClientDisconnectedError(APIError): + """Server detected client disconnection mid-request - CLIENT_DISCONNECTED (HTTP 499).""" + + pass + + # Client-side errors with default messages class TimeoutError(CyteTypeError): """Client-side timeout waiting for results.""" @@ -87,6 +93,7 @@ def __init__( "JOB_NOT_FOUND": JobNotFoundError, "JOB_FAILED": JobFailedError, "LLM_VALIDATION_FAILED": LLMValidationError, + "CLIENT_DISCONNECTED": ClientDisconnectedError, "JOB_PROCESSING": APIError, # Generic - expected during polling "JOB_NOT_COMPLETED": APIError, # Generic "HTTP_ERROR": APIError, # Generic diff --git a/cytetype/config.py b/cytetype/config.py index 52bd61c..8979525 100644 --- a/cytetype/config.py +++ b/cytetype/config.py @@ -24,3 +24,5 @@ def _log_format(record: Record) -> str: level="INFO", format=_log_format, ) + +WRITE_MEM_BUDGET: int = 4 * 1024 * 1024 * 1024 # 4 GB diff --git a/cytetype/core/artifacts.py b/cytetype/core/artifacts.py index c916e1f..fd08aab 100644 --- a/cytetype/core/artifacts.py +++ b/cytetype/core/artifacts.py @@ -11,7 +11,7 @@ import scipy.sparse as sp from anndata.abc import CSCDataset, CSRDataset -from ..config import logger +from ..config import logger, WRITE_MEM_BUDGET def _safe_column_dataset_name( @@ -113,11 +113,310 @@ def _write_var_metadata( 
dataset.attrs["source_dtype"] = str(series.dtype) +def _is_integer_valued(mat: Any, sample_n_rows: int = 200) -> bool: + if hasattr(mat, "dtype") and np.issubdtype(mat.dtype, np.integer): + return True + + n_rows = mat.shape[0] + if n_rows <= sample_n_rows: + row_indices = np.arange(n_rows) + else: + rng = np.random.default_rng(42) + row_indices = np.sort(rng.choice(n_rows, size=sample_n_rows, replace=False)) + + chunk = mat[row_indices] + + if sp.issparse(chunk): + sample = chunk.data + elif hasattr(chunk, "toarray"): + sample = chunk.toarray().ravel() + else: + sample = np.asarray(chunk).ravel() + + if sample.size == 0: + return True + + sample = sample.astype(np.float64, copy=False) + return bool(np.all(np.isfinite(sample)) and np.all(sample == np.floor(sample))) + + +def _write_raw_group( + f: h5py.File, + mat: Any, + n_obs: int, + n_vars: int, + col_indices: "np.ndarray | None", + cell_batch: int, + min_chunk_size: int, +) -> None: + group = f.create_group("raw") + group.attrs["n_obs"] = n_obs + group.attrs["n_vars"] = n_vars + + chunk_size = max(1, min(n_vars * 10, min_chunk_size)) + + d_indices = group.create_dataset( + "indices", + shape=(0,), + maxshape=(n_obs * n_vars,), + chunks=(chunk_size,), + dtype=np.int32, + compression=hdf5plugin.LZ4(), + ) + d_data = group.create_dataset( + "data", + shape=(0,), + maxshape=(n_obs * n_vars,), + chunks=(chunk_size,), + dtype=np.int32, + compression=hdf5plugin.LZ4(), + ) + + indptr: list[int] = [0] + batch_starts = range(0, n_obs, cell_batch) + try: + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=FutureWarning) + from tqdm.auto import tqdm + + batch_iter = tqdm( + batch_starts, desc="Writing raw counts to H5 artifact", unit="batch" + ) + except ImportError: + batch_iter = batch_starts + + for start in batch_iter: + end = min(start + cell_batch, n_obs) + raw_chunk = mat[start:end] + + if col_indices is not None: + raw_chunk = raw_chunk[:, col_indices] + + chunk = ( + 
raw_chunk.tocsr() if sp.issparse(raw_chunk) else sp.csr_matrix(raw_chunk) + ) + + old_size = d_indices.shape[0] + chunk_col_indices = chunk.indices.astype(np.int32, copy=False) + chunk_data = chunk.data.astype(np.int32, copy=False) + + d_indices.resize(old_size + len(chunk_col_indices), axis=0) + d_indices[old_size : old_size + len(chunk_col_indices)] = chunk_col_indices + + d_data.resize(old_size + len(chunk_data), axis=0) + d_data[old_size : old_size + len(chunk_data)] = chunk_data + + indptr.extend((chunk.indptr[1:] + indptr[-1]).tolist()) + + group.create_dataset("indptr", data=np.asarray(indptr, dtype=np.int64)) + + +def _try_import_tqdm() -> type | None: + try: + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + from tqdm.auto import tqdm + + return tqdm # type: ignore[no-any-return] + except ImportError: + return None + + +def _write_csc_via_row_batches( + group: h5py.Group, + mat: Any, + n_rows: int, + n_cols: int, + min_chunk_size: int, +) -> None: + row_batch = max(1, int(100_000_000 / max(n_cols, 1))) + chunk_size = max(1, min(n_rows * 10, min_chunk_size)) + tqdm = _try_import_tqdm() + + # --- Pass 1: count nnz per column --- + col_counts = np.zeros(n_cols, dtype=np.int64) + row_starts = range(0, n_rows, row_batch) + pass1_iter = ( + tqdm(row_starts, desc="Counting non-zeros in normalized matrix", unit="batch") + if tqdm is not None + else row_starts + ) + + for start in pass1_iter: + end = min(start + row_batch, n_rows) + chunk = mat[start:end] + csr = chunk.tocsr() if sp.issparse(chunk) else sp.csr_matrix(chunk) + col_counts += np.bincount(csr.indices, minlength=n_cols) + + indptr = np.empty(n_cols + 1, dtype=np.int64) + indptr[0] = 0 + np.cumsum(col_counts, out=indptr[1:]) + total_nnz = int(indptr[-1]) + + d_indices = group.create_dataset( + "indices", + shape=(total_nnz,), + dtype=np.int32, + chunks=(chunk_size,) if total_nnz > 0 else None, + compression=hdf5plugin.LZ4() if total_nnz > 0 else None, + ) + d_data = 
group.create_dataset( + "data", + shape=(total_nnz,), + dtype=np.float32, + chunks=(chunk_size,) if total_nnz > 0 else None, + compression=hdf5plugin.LZ4() if total_nnz > 0 else None, + ) + + # --- Pass 2+: column-group scatter --- + max_nnz_per_group = max(1, WRITE_MEM_BUDGET // 8) + + c_start = 0 + group_idx = 0 + while c_start < n_cols: + cumulative = np.cumsum(col_counts[c_start:]) + over_budget = np.searchsorted(cumulative, max_nnz_per_group, side="right") + c_end = c_start + max(1, int(over_budget)) + c_end = min(c_end, n_cols) + + group_nnz = int(indptr[c_end] - indptr[c_start]) + if group_nnz == 0: + c_start = c_end + group_idx += 1 + continue + + grp_indices = np.empty(group_nnz, dtype=np.int32) + grp_data = np.empty(group_nnz, dtype=np.float32) + + n_group_cols = c_end - c_start + cursors = (indptr[c_start:c_end] - indptr[c_start]).copy() + + desc = ( + f"Writing normalized counts (group {group_idx + 1})" + if c_end < n_cols or c_start > 0 + else "Writing normalized counts to H5 artifact" + ) + pass2_iter = ( + tqdm(row_starts, desc=desc, unit="batch") + if tqdm is not None + else row_starts + ) + + for start in pass2_iter: + end = min(start + row_batch, n_rows) + chunk = mat[start:end] + csr = chunk.tocsr() if sp.issparse(chunk) else sp.csr_matrix(chunk) + csc = csr.tocsc() + + local_start_ptr = csc.indptr[c_start] + local_end_ptr = csc.indptr[c_end] + if local_start_ptr == local_end_ptr: + continue + + slice_indices = csc.indices[local_start_ptr:local_end_ptr] + start + slice_data = csc.data[local_start_ptr:local_end_ptr].astype( + np.float32, copy=False + ) + slice_indptr = csc.indptr[c_start : c_end + 1] - local_start_ptr + + chunk_col_counts = np.diff(slice_indptr) + col_ids = np.repeat( + np.arange(n_group_cols, dtype=np.int64), chunk_col_counts + ) + bases = cursors[col_ids] + local_offsets = np.arange(len(slice_indices), dtype=np.int64) - np.repeat( + slice_indptr[:-1].astype(np.int64), chunk_col_counts + ) + targets = bases + local_offsets + + 
grp_indices[targets] = slice_indices + grp_data[targets] = slice_data + cursors += chunk_col_counts + + hdf5_offset = int(indptr[c_start]) + d_indices[hdf5_offset : hdf5_offset + group_nnz] = grp_indices + d_data[hdf5_offset : hdf5_offset + group_nnz] = grp_data + + del grp_indices, grp_data + c_start = c_end + group_idx += 1 + + group.create_dataset("indptr", data=indptr) + + +def _write_csc_via_col_batches( + group: h5py.Group, + mat: Any, + n_rows: int, + n_cols: int, + min_chunk_size: int, + col_batch: int | None, +) -> None: + if col_batch is None: + col_batch = max(1, int(1_000_000_000 / max(n_rows, 1))) + + chunk_size = max(1, min(n_rows * 10, min_chunk_size)) + + d_indices = group.create_dataset( + "indices", + shape=(0,), + maxshape=(n_rows * n_cols,), + chunks=(chunk_size,), + dtype=np.int32, + compression=hdf5plugin.LZ4(), + ) + d_data = group.create_dataset( + "data", + shape=(0,), + maxshape=(n_rows * n_cols,), + chunks=(chunk_size,), + dtype=np.float32, + compression=hdf5plugin.LZ4(), + ) + + indptr: list[int] = [0] + col_starts = range(0, n_cols, col_batch) + tqdm = _try_import_tqdm() + col_iter = ( + tqdm(col_starts, desc="Writing normalized counts", unit="batch") + if tqdm is not None + else col_starts + ) + + for start in col_iter: + end = min(start + col_batch, n_cols) + raw_chunk = mat[:, start:end] + chunk = ( + raw_chunk.tocsc() if sp.issparse(raw_chunk) else sp.csc_matrix(raw_chunk) + ) + + old_size = d_indices.shape[0] + chunk_indices = chunk.indices.astype(np.int32, copy=False) + chunk_data = chunk.data.astype(np.float32, copy=False) + + d_indices.resize(old_size + len(chunk_indices), axis=0) + d_indices[old_size : old_size + len(chunk_indices)] = chunk_indices + + d_data.resize(old_size + len(chunk_data), axis=0) + d_data[old_size : old_size + len(chunk_data)] = chunk_data + + indptr.extend((chunk.indptr[1:] + indptr[-1]).tolist()) + + group.create_dataset("indptr", data=np.asarray(indptr, dtype=np.int64)) + + def save_features_matrix( 
out_file: str, mat: Any, var_df: pd.DataFrame | None = None, var_names: pd.Index | Sequence[Any] | None = None, + raw_mat: Any | None = None, + raw_col_indices: "np.ndarray | None" = None, + raw_cell_batch: int = 2000, min_chunk_size: int = 10_000_000, col_batch: int | None = None, ) -> None: @@ -129,62 +428,26 @@ def save_features_matrix( n_rows, n_cols = mat.shape + use_row_batch_path = isinstance(mat, CSRDataset) + if not isinstance(mat, (CSRDataset, CSCDataset)): logger.warning( "For large datasets, use AnnData backed mode (e.g. sc.read_h5ad(..., backed='r')) " "so `adata.X` is a backed sparse dataset and avoids loading the full matrix in memory.", ) - if col_batch is None: - col_batch = max(1, int(100_000_000 / max(n_rows, 1))) - - chunk_size = max(1, min(n_rows * 10, min_chunk_size)) with h5py.File(out_file, "w") as f: group = f.create_group("vars") group.attrs["n_obs"] = n_rows group.attrs["n_vars"] = n_cols - d_indices = group.create_dataset( - "indices", - shape=(0,), - maxshape=(n_rows * n_cols,), - chunks=(chunk_size,), - dtype=np.int32, - compression=hdf5plugin.LZ4(), - ) - d_data = group.create_dataset( - "data", - shape=(0,), - maxshape=(n_rows * n_cols,), - chunks=(chunk_size,), - dtype=np.float32, - compression=hdf5plugin.LZ4(), - ) - - indptr: list[int] = [0] - for start in range(0, n_cols, col_batch): - end = min(start + col_batch, n_cols) - raw_chunk = mat[:, start:end] - chunk = ( - raw_chunk.tocsc() - if sp.issparse(raw_chunk) - else sp.csc_matrix(raw_chunk) + if use_row_batch_path: + _write_csc_via_row_batches(group, mat, n_rows, n_cols, min_chunk_size) + else: + _write_csc_via_col_batches( + group, mat, n_rows, n_cols, min_chunk_size, col_batch ) - old_size = d_indices.shape[0] - chunk_indices = chunk.indices.astype(np.int32, copy=False) - chunk_data = chunk.data.astype(np.float32, copy=False) - - d_indices.resize(old_size + len(chunk_indices), axis=0) - d_indices[old_size : old_size + len(chunk_indices)] = chunk_indices - - 
d_data.resize(old_size + len(chunk_data), axis=0) - d_data[old_size : old_size + len(chunk_data)] = chunk_data - - indptr.extend((chunk.indptr[1:] + indptr[-1]).tolist()) - - group.create_dataset("indptr", data=np.asarray(indptr, dtype=np.int64)) - if var_df is not None: _write_var_metadata( out_file_group=f, @@ -193,6 +456,24 @@ def save_features_matrix( var_names=var_names, ) + if raw_mat is not None: + try: + _write_raw_group( + f, + raw_mat, + n_rows, + n_cols, + raw_col_indices, + raw_cell_batch, + min_chunk_size, + ) + except Exception: + logger.warning( + "Failed to write raw counts group to %s, skipping.", out_file + ) + if "raw" in f: + del f["raw"] + def save_obs_duckdb( out_file: str, diff --git a/cytetype/main.py b/cytetype/main.py index 993a0f2..8c69ec8 100644 --- a/cytetype/main.py +++ b/cytetype/main.py @@ -1,8 +1,10 @@ +import sys from pathlib import Path from typing import Any from importlib.metadata import PackageNotFoundError, version import anndata +import numpy as np from natsort import natsorted from .config import logger @@ -13,6 +15,8 @@ ) from .preprocessing import ( validate_adata, + resolve_gene_symbols_column, + clean_gene_names, aggregate_expression_percentages, extract_marker_genes, aggregate_cluster_metadata, @@ -20,6 +24,7 @@ ) from .core.payload import build_annotation_payload, save_query_to_file from .core.artifacts import ( + _is_integer_valued, save_features_matrix, save_obs_duckdb as save_obs_duckdb_file, ) @@ -53,7 +58,7 @@ class CyteType: adata: anndata.AnnData group_key: str rank_key: str - gene_symbols_column: str + gene_symbols_column: str | None n_top_genes: int pcent_batch_size: int coordinates_key: str | None @@ -70,13 +75,15 @@ def __init__( adata: anndata.AnnData, group_key: str, rank_key: str = "rank_genes_groups", - gene_symbols_column: str = "gene_symbols", + gene_symbols_column: str | None = None, n_top_genes: int = 50, aggregate_metadata: bool = True, min_percentage: int = 10, - pcent_batch_size: int = 2000, + 
pcent_batch_size: int = 5000, coordinates_key: str = "X_umap", max_cells_per_group: int = 1000, + vars_h5_path: str = "vars.h5", + obs_duckdb_path: str = "obs.duckdb", api_url: str = "https://prod.cytetype.nygen.io", auth_token: str | None = None, ) -> None: @@ -90,8 +97,10 @@ def __init__( rank_key (str, optional): The key in `adata.uns` containing differential expression results from `sc.tl.rank_genes_groups`. Must use the same `groupby` as `group_key`. Defaults to "rank_genes_groups". - gene_symbols_column (str, optional): Name of the column in `adata.var` that contains - the gene symbols. Defaults to "gene_symbols". + gene_symbols_column (str | None, optional): Name of the column in `adata.var` that + contains gene symbols. When None (default), auto-detects by checking well-known + column names (feature_name, gene_symbols, etc.), then adata.var_names, then + heuristic scan of all var columns. Defaults to None. n_top_genes (int, optional): Number of top marker genes per cluster to extract during initialization. Higher values may improve annotation quality but increase memory usage. Defaults to 50. @@ -99,8 +108,8 @@ def __init__( Defaults to True. min_percentage (int, optional): Minimum percentage of cells in a group to include in the cluster context. Defaults to 10. - pcent_batch_size (int, optional): Batch size for calculating expression percentages to - optimize memory usage. Defaults to 2000. + pcent_batch_size (int, optional): Number of cells to process per chunk when + calculating expression percentages. Defaults to 5000. coordinates_key (str, optional): Key in adata.obsm containing 2D coordinates for visualization. Must be a 2D array with same number of elements as clusters. Defaults to "X_umap". 
@@ -119,17 +128,20 @@ def __init__( self.adata = adata self.group_key = group_key self.rank_key = rank_key - self.gene_symbols_column = gene_symbols_column self.n_top_genes = n_top_genes self.pcent_batch_size = pcent_batch_size self.coordinates_key = coordinates_key self.max_cells_per_group = max_cells_per_group self.api_url = api_url self.auth_token = auth_token + self._artifact_build_errors: list[tuple[str, Exception]] = [] + + self.gene_symbols_column = resolve_gene_symbols_column( + adata, gene_symbols_column + ) - # Validate data and get the best available coordinates key self.coordinates_key = validate_adata( - adata, group_key, rank_key, gene_symbols_column, coordinates_key + adata, group_key, rank_key, self.gene_symbols_column, coordinates_key ) # Use original labels as IDs if all are short (<=3 chars), otherwise enumerate @@ -145,12 +157,17 @@ def __init__( self.cluster_map[str(x)] for x in adata.obs[group_key].values.tolist() ] - logger.info("Calculating expression percentages...") + gene_names = ( + adata.var[self.gene_symbols_column].tolist() + if self.gene_symbols_column is not None + else adata.var_names.tolist() + ) + gene_names = clean_gene_names(gene_names) self.expression_percentages = aggregate_expression_percentages( adata=adata, clusters=self.clusters, - batch_size=pcent_batch_size, - gene_names=adata.var[self.gene_symbols_column].tolist(), + gene_names=gene_names, + cell_batch_size=pcent_batch_size, ) logger.info("Extracting marker genes...") @@ -182,7 +199,6 @@ def __init__( self.group_metadata = {} # Prepare visualization data with sampling - logger.info("Extracting sampled visualization coordinates...") sampled_coordinates, sampled_cluster_labels = extract_visualization_coordinates( adata=adata, coordinates_key=self.coordinates_key, @@ -196,93 +212,157 @@ def __init__( "clusters": sampled_cluster_labels, } - logger.info("Data preparation completed. 
Ready for submitting jobs.") - - def _build_and_upload_artifacts( - self, - vars_h5_path: str, - obs_duckdb_path: str, - upload_timeout_seconds: int, - upload_max_workers: int = 4, - coordinates_key: str | None = None, - ) -> tuple[dict[str, str], list[tuple[str, Exception]]]: - """Build and upload each artifact as an independent unit. - - Returns (uploaded_ids, errors) so the caller can decide whether - partial success is acceptable. - """ - uploaded: dict[str, str] = {} - errors: list[tuple[str, Exception]] = [] - timeout = (30.0, float(upload_timeout_seconds)) + # Resolve raw counts once and cache + self._raw_counts_result = self._resolve_raw_counts() + if self._raw_counts_result is None: + logger.warning( + "No integer raw counts found in adata.layers['counts'], " + "adata.raw.X, or adata.X. Skipping raw counts in vars.h5." + ) - # --- vars.h5 (save then upload) --- + # Build vars.h5 try: - logger.info("Saving vars.h5 artifact from normalized counts...") + raw_mat, raw_col_indices = ( + self._raw_counts_result + if self._raw_counts_result is not None + else (None, None) + ) save_features_matrix( out_file=vars_h5_path, mat=self.adata.X, var_df=self.adata.var, var_names=self.adata.var_names, + raw_mat=raw_mat, + raw_col_indices=raw_col_indices, ) - logger.info("Uploading vars.h5 artifact...") - vars_upload = upload_vars_h5_file( - self.api_url, - self.auth_token, - vars_h5_path, - timeout=timeout, - max_workers=upload_max_workers, - ) - if vars_upload.file_kind != "vars_h5": - raise ValueError( - f"Unexpected upload file_kind for vars artifact: {vars_upload.file_kind}" - ) - uploaded["vars_h5"] = vars_upload.upload_id + sys.stderr.flush() + self._vars_h5_path: str | None = vars_h5_path except Exception as exc: - logger.warning(f"vars.h5 artifact failed: {exc}") - errors.append(("vars_h5", exc)) + logger.warning(f"vars.h5 artifact failed during build: {exc}") + self._vars_h5_path = None + self._artifact_build_errors.append(("vars_h5", exc)) - print() - - # --- 
obs.duckdb (save then upload) --- + # Build obs.duckdb try: - logger.info("Saving obs.duckdb artifact from observation metadata...") + logger.info("Writing obs data to duckdb artifact...") obsm_coordinates = ( - self.adata.obsm[coordinates_key] - if coordinates_key and coordinates_key in self.adata.obsm + self.adata.obsm[self.coordinates_key] + if self.coordinates_key and self.coordinates_key in self.adata.obsm else None ) save_obs_duckdb_file( out_file=obs_duckdb_path, obs_df=self.adata.obs, obsm_coordinates=obsm_coordinates, - coordinates_key=coordinates_key, - ) - logger.info("Uploading obs.duckdb artifact...") - obs_upload = upload_obs_duckdb_file( - self.api_url, - self.auth_token, - obs_duckdb_path, - timeout=timeout, - max_workers=upload_max_workers, + coordinates_key=self.coordinates_key, ) - if obs_upload.file_kind != "obs_duckdb": - raise ValueError( - f"Unexpected upload file_kind for obs artifact: {obs_upload.file_kind}" - ) - uploaded["obs_duckdb"] = obs_upload.upload_id + sys.stderr.flush() + self._obs_duckdb_path: str | None = obs_duckdb_path except Exception as exc: - logger.warning(f"obs.duckdb artifact failed: {exc}") - errors.append(("obs_duckdb", exc)) + logger.warning(f"obs.duckdb artifact failed during build: {exc}") + self._obs_duckdb_path = None + self._artifact_build_errors.append(("obs_duckdb", exc)) - return uploaded, errors + logger.info("Data preparation completed. Ready for submitting jobs.") + + def _resolve_raw_counts( + self, + ) -> "tuple[Any, np.ndarray | None] | None": + if "counts" in self.adata.layers: + mat = self.adata.layers["counts"] + if _is_integer_valued(mat): + return mat, None + + if self.adata.raw is not None: + raw_mat = self.adata.raw.X + col_indices = self.adata.raw.var_names.get_indexer(self.adata.var_names) + if (col_indices == -1).any(): + logger.warning( + "Some var_names not found in adata.raw — skipping adata.raw.X as raw counts source." 
+ ) + elif _is_integer_valued(raw_mat): + return raw_mat, col_indices.astype(np.intp) + + if _is_integer_valued(self.adata.X): + return self.adata.X, None + + return None + + def _upload_artifacts( + self, + upload_timeout_seconds: int, + upload_max_workers: int = 4, + ) -> tuple[dict[str, str], list[tuple[str, Exception]]]: + """Upload pre-built artifact files to the server. + + Returns (uploaded_ids, errors) so the caller can decide whether + partial success is acceptable. + """ + uploaded: dict[str, str] = {} + errors: list[tuple[str, Exception]] = list(self._artifact_build_errors) + timeout = (60.0, float(upload_timeout_seconds)) - @staticmethod - def _cleanup_artifact_files(paths: list[str]) -> None: - for artifact_path in paths: + # --- vars.h5 upload --- + if self._vars_h5_path is not None: try: - Path(artifact_path).unlink(missing_ok=True) - except OSError as exc: - logger.warning(f"Failed to cleanup artifact {artifact_path}: {exc}") + logger.info("Uploading vars.h5 artifact...") + vars_upload = upload_vars_h5_file( + self.api_url, + self.auth_token, + self._vars_h5_path, + timeout=timeout, + max_workers=upload_max_workers, + ) + if vars_upload.file_kind != "vars_h5": + raise ValueError( + f"Unexpected upload file_kind for vars artifact: {vars_upload.file_kind}" + ) + uploaded["vars_h5"] = vars_upload.upload_id + except Exception as exc: + logger.warning(f"vars.h5 upload failed: {exc}") + errors.append(("vars_h5", exc)) + + print() + + # --- obs.duckdb upload --- + if self._obs_duckdb_path is not None: + try: + logger.info("Uploading obs.duckdb artifact...") + obs_upload = upload_obs_duckdb_file( + self.api_url, + self.auth_token, + self._obs_duckdb_path, + timeout=timeout, + max_workers=upload_max_workers, + ) + if obs_upload.file_kind != "obs_duckdb": + raise ValueError( + f"Unexpected upload file_kind for obs artifact: {obs_upload.file_kind}" + ) + uploaded["obs_duckdb"] = obs_upload.upload_id + except Exception as exc: + logger.warning(f"obs.duckdb 
upload failed: {exc}") + errors.append(("obs_duckdb", exc)) + + return uploaded, errors + + def cleanup(self) -> None: + """Delete the artifact files built during initialization. + + Call this after run() completes to remove the vars.h5 and obs.duckdb + files from disk. Paths are cleared so repeated calls are safe. + """ + for attr, path in [ + ("_vars_h5_path", self._vars_h5_path), + ("_obs_duckdb_path", self._obs_duckdb_path), + ]: + if path is not None: + try: + Path(path).unlink(missing_ok=True) + logger.info(f"Deleted artifact file: {path}") + except OSError as exc: + logger.warning(f"Failed to delete artifact {path}: {exc}") + setattr(self, attr, None) def run( self, @@ -297,11 +377,8 @@ def run( auth_token: str | None = None, save_query: bool = True, query_filename: str = "query.json", - vars_h5_path: str = "vars.h5", - obs_duckdb_path: str = "obs.duckdb", upload_timeout_seconds: int = 3600, upload_max_workers: int = 4, - cleanup_artifacts: bool = False, require_artifacts: bool = True, show_progress: bool = True, override_existing_results: bool = False, @@ -337,16 +414,10 @@ def run( save_query (bool, optional): Whether to save the query to a file. Defaults to True. query_filename (str, optional): Filename for saving the query when save_query is True. Defaults to "query.json". - vars_h5_path (str, optional): Local output path for generated vars.h5 artifact. - Defaults to "vars.h5". - obs_duckdb_path (str, optional): Local output path for generated obs.duckdb artifact. - Defaults to "obs.duckdb". upload_timeout_seconds (int, optional): Socket read timeout used for each artifact upload. Defaults to 3600. upload_max_workers (int, optional): Number of parallel threads used to upload file chunks. Each worker holds one chunk in memory (~100 MB). Defaults to 4. - cleanup_artifacts (bool, optional): Whether to delete generated artifact files after run - completes or fails. Defaults to False. 
require_artifacts (bool, optional): Whether to raise an error if artifact building or uploading fails. When True (default), any artifact failure stops the run. Set to False to skip artifacts and continue with annotation only. Defaults to True. @@ -387,7 +458,7 @@ def run( if upload_timeout_seconds <= 0: raise ValueError("upload_timeout_seconds must be greater than 0") - # Validate the base payload before doing potentially expensive uploads. + # Validate the base payload before uploading. payload = build_annotation_payload( study_context, metadata, @@ -400,69 +471,61 @@ def run( llm_configs, ) - artifact_paths = [vars_h5_path, obs_duckdb_path] - try: - uploaded_file_refs, artifact_errors = self._build_and_upload_artifacts( - vars_h5_path=vars_h5_path, - obs_duckdb_path=obs_duckdb_path, - upload_timeout_seconds=upload_timeout_seconds, - upload_max_workers=upload_max_workers, - coordinates_key=self.coordinates_key, - ) - if uploaded_file_refs: - payload["uploaded_files"] = uploaded_file_refs - - if artifact_errors: - failed_names = ", ".join(name for name, _ in artifact_errors) - if require_artifacts: - logger.error( - f"Artifact build/upload failed for: {failed_names}. " - "Rerun with `require_artifacts=False` to skip this error.\n" - "Please report the error below in a new issue at " - "https://github.com/NygenAnalytics/CyteType\n" - f"({type(artifact_errors[0][1]).__name__}: {str(artifact_errors[0][1]).strip()})" - ) - raise artifact_errors[0][1] - logger.warning( + uploaded_file_refs, artifact_errors = self._upload_artifacts( + upload_timeout_seconds=upload_timeout_seconds, + upload_max_workers=upload_max_workers, + ) + if uploaded_file_refs: + payload["uploaded_files"] = uploaded_file_refs + + if artifact_errors: + failed_names = ", ".join(name for name, _ in artifact_errors) + if require_artifacts: + logger.error( f"Artifact build/upload failed for: {failed_names}. " - "Continuing without those artifacts. 
" - "Set `require_artifacts=True` to see the full traceback." + "Rerun with `require_artifacts=False` to skip this error.\n" + "Please report the error below in a new issue at " + "https://github.com/NygenAnalytics/CyteType\n" + f"({type(artifact_errors[0][1]).__name__}: {str(artifact_errors[0][1]).strip()})" ) - - # Save query if requested - if save_query: - save_query_to_file(payload["input_data"], query_filename) - - # Submit job and store details - print() - job_id = submit_annotation_job(self.api_url, self.auth_token, payload) - store_job_details(self.adata, job_id, self.api_url, results_prefix) - - # Wait for completion - result = wait_for_completion( - self.api_url, - self.auth_token, - job_id, - poll_interval_seconds, - timeout_seconds, - show_progress, + raise artifact_errors[0][1] + logger.warning( + f"Artifact build/upload failed for: {failed_names}. " + "Continuing without those artifacts. " + "Set `require_artifacts=True` to see the full traceback." ) - # Store results - store_annotations( - self.adata, - result, - job_id, - results_prefix, - self.group_key, - self.clusters, - check_unannotated=True, - ) + # Save query if requested + if save_query: + save_query_to_file(payload["input_data"], query_filename) + + # Submit job and store details + print() + job_id = submit_annotation_job(self.api_url, self.auth_token, payload) + store_job_details(self.adata, job_id, self.api_url, results_prefix) + + # Wait for completion + result = wait_for_completion( + self.api_url, + self.auth_token, + job_id, + poll_interval_seconds, + timeout_seconds, + show_progress, + ) + + # Store results + store_annotations( + self.adata, + result, + job_id, + results_prefix, + self.group_key, + self.clusters, + check_unannotated=True, + ) - return self.adata - finally: - if cleanup_artifacts: - self._cleanup_artifact_files(artifact_paths) + return self.adata def get_results( self, diff --git a/cytetype/plotting/__init__.py b/cytetype/plotting/__init__.py new file mode 100644 
def marker_dotplot(
    adata: anndata.AnnData,
    group_key: str,
    results_prefix: str = "cytetype",
    n_top_markers: int = 3,
    gene_symbols: str | None = None,
    **kwargs: Any,
) -> Any:
    """Dotplot of top marker genes grouped by CyteType cluster categories.

    Reads stored CyteType results from ``adata.uns`` and builds a
    category-grouped marker dict suitable for ``sc.pl.dotplot``.

    Parameters
    ----------
    adata
        AnnData object that has been annotated by CyteType
        (results stored in ``adata.uns``).
    group_key
        The original cluster key used during annotation (e.g. ``"leiden"``).
    results_prefix
        Prefix used when the results were stored. Must match the
        ``results_prefix`` passed to :meth:`CyteType.run`.
    n_top_markers
        Number of top supporting genes to display per cluster.
    gene_symbols
        Column in ``adata.var`` containing gene symbols. Forwarded to
        ``sc.pl.dotplot`` via the ``gene_symbols`` parameter.
    **kwargs
        Additional keyword arguments forwarded to ``sc.pl.dotplot``
        (e.g. ``cmap``, ``use_raw``).

    Returns
    -------
    The return value of ``sc.pl.dotplot``.

    Raises
    ------
    KeyError
        If no CyteType results are stored under ``results_prefix``.
    ValueError
        If the stored results lack cluster categories or raw annotations.
    ImportError
        If scanpy is not installed.
    """
    results = load_local_results(adata, results_prefix)
    if results is None:
        raise KeyError(
            f"No CyteType results found in adata.uns with prefix '{results_prefix}'. "
            "Run CyteType annotation first or check the results_prefix."
        )

    cluster_categories = results.get("clusterCategories", [])
    if not cluster_categories:
        raise ValueError(
            "No cluster categories found in CyteType results. "
            "The API response may not include category groupings for this run."
        )

    raw_annotations = results.get("raw_annotations", {})
    if not raw_annotations:
        raise ValueError("No raw_annotations found in CyteType results.")

    markers: dict[str, list[str]] = {}
    categories_order: list[str] = []

    for category in cluster_categories:
        category_name = category["categoryName"]
        markers[category_name] = []

        for cluster_id in category["clusterIds"]:
            cluster_data = raw_annotations.get(cluster_id)
            if cluster_data is None:
                # f-string instead of brace-style lazy args: the package logger
                # is not guaranteed to do {}-interpolation, and the rest of the
                # package logs with f-strings.
                logger.warning(
                    f"Cluster '{cluster_id}' listed in clusterCategories but "
                    "missing from raw_annotations, skipping."
                )
                continue

            full_output = cluster_data["latest"]["annotation"]["fullOutput"]
            cell_type = full_output["cellType"]

            # Two clusters can share the same cell-type label; duplicates in
            # `categories_order` would be invalid for sc.pl.dotplot, so record
            # each label only once, keeping first-seen order.
            label = cell_type["label"]
            if label not in categories_order:
                categories_order.append(label)
            markers[category_name].extend(
                cell_type["keySupportingGenes"][:n_top_markers]
            )

        # Deduplicate genes within the category; sorted for stable output.
        markers[category_name] = sorted(set(markers[category_name]))

    try:
        import scanpy as sc
    except ImportError:
        raise ImportError(
            "scanpy is required for plotting. Install it with: pip install scanpy"
        ) from None

    # Column written into adata.obs when CyteType stored its annotations.
    groupby = f"{results_prefix}_annotation_{group_key}"

    return sc.pl.dotplot(
        adata,
        markers,
        groupby=groupby,
        gene_symbols=gene_symbols,
        categories_order=categories_order,
        **kwargs,
    )
def aggregate_expression_percentages(
    adata: anndata.AnnData,
    clusters: list[str],
    gene_names: list[str],
    cell_batch_size: int = 5000,
) -> dict[str, dict[str, float]]:
    """Aggregate gene expression percentages per cluster.

    Uses a single-pass row-batched accumulation (fast for CSR / backed data).

    Args:
        adata: AnnData object containing expression data
        clusters: List of cluster assignments for each cell
        gene_names: List of gene names corresponding to columns in adata.X
        cell_batch_size: Number of cells to process per chunk

    Returns:
        Dictionary mapping gene names to cluster-level expression percentages
    """
    # Map each cluster label to a dense group index for the accumulator.
    cluster_labels = sorted(set(clusters))
    label_to_idx = {label: pos for pos, label in enumerate(cluster_labels)}
    group_idx_per_cell = np.array([label_to_idx[label] for label in clusters])
    n_groups = len(cluster_labels)

    group_stats = _accumulate_group_stats(
        adata.X,
        group_idx_per_cell,
        n_groups,
        adata.shape[1],
        cell_batch_size=cell_batch_size,
        compute_nnz=True,
        progress_desc="Calculating expression percentages",
    )

    # Percentage of cells with non-zero expression, per (cluster, gene);
    # empty groups cannot occur (labels come from `clusters` itself), but the
    # guarded division keeps any NaN from propagating.
    with np.errstate(divide="ignore", invalid="ignore"):
        pct = np.round(group_stats.nnz / group_stats.n[:, None] * 100, 2)
        pct = np.nan_to_num(pct, nan=0.0)

    return {
        gene: {
            cluster_labels[g_idx]: float(pct[g_idx, col])
            for g_idx in range(n_groups)
        }
        for col, gene in enumerate(gene_names)
    }
from ..config import logger +from .validation import _extract_symbol_from_composite, clean_gene_names def extract_marker_genes( @@ -10,7 +11,7 @@ def extract_marker_genes( rank_genes_key: str, cluster_map: dict[str, str], n_top_genes: int, - gene_symbols_col: str, + gene_symbols_col: str | None, ) -> dict[str, list[str]]: """Extract top marker genes from rank_genes_groups results. @@ -20,7 +21,8 @@ def extract_marker_genes( rank_genes_key: Key in adata.uns containing rank_genes_groups results cluster_map: Dictionary mapping original labels to cluster IDs n_top_genes: Number of top genes to extract per cluster - gene_symbols_col: Column in adata.var containing gene symbols + gene_symbols_col: Column in adata.var containing gene symbols, + or None to use var_names directly (identity mapping). Returns: Dictionary mapping cluster IDs to lists of marker gene symbols @@ -44,7 +46,15 @@ def extract_marker_genes( f"Failed to extract marker gene names from `rank_genes_groups`. Error: {e}" ) - gene_ids_to_name = adata.var[gene_symbols_col].to_dict() + if gene_symbols_col is not None: + raw_map = adata.var[gene_symbols_col].to_dict() + gene_ids_to_name = { + k: _extract_symbol_from_composite(str(v)) for k, v in raw_map.items() + } + else: + raw_names = adata.var_names.tolist() + cleaned = clean_gene_names(raw_names) + gene_ids_to_name = dict(zip(adata.var_names, cleaned)) markers = {} any_genes_found = False @@ -123,8 +133,24 @@ def extract_visualization_coordinates( ) # Sample cells from each group using pandas + unique_groups = coord_df["group"].unique() + try: + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=FutureWarning) + from tqdm.auto import tqdm + + group_iter = tqdm( + unique_groups, + desc=f"Sampling coordinates from {coordinates_key}", + unit="group", + ) + except ImportError: + group_iter = unique_groups + sampled_coords = [] - for group_label in coord_df["group"].unique(): + for group_label in group_iter: 
group_mask = coord_df["group"] == group_label group_size = group_mask.sum() sample_size = min(max_cells_per_group, group_size) @@ -134,12 +160,6 @@ def extract_visualization_coordinates( ) sampled_coords.append(sampled_group) - if group_size > max_cells_per_group: - logger.info( - f"Sampled {sample_size} cells from group '{group_label}' " - f"(originally {group_size} cells)" - ) - # Concatenate all sampled groups sampled_coord_df: pd.DataFrame = pd.concat(sampled_coords, ignore_index=True) @@ -152,9 +172,9 @@ def extract_visualization_coordinates( for label in sampled_coord_df["group"].values ] - logger.info( - f"Extracted {len(sampled_coordinates)} coordinate points " - f"(sampled from {len(coordinates)} total cells)" - ) + # logger.info( + # f"Extracted {len(sampled_coordinates)} coordinate points " + # f"(sampled from {len(coordinates)} total cells)" + # ) return sampled_coordinates, sampled_cluster_labels diff --git a/cytetype/preprocessing/marker_detection.py b/cytetype/preprocessing/marker_detection.py new file mode 100644 index 0000000..264d607 --- /dev/null +++ b/cytetype/preprocessing/marker_detection.py @@ -0,0 +1,315 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Literal + +import anndata +import numpy as np +import pandas as pd +from natsort import natsorted +import scipy.sparse as sp +from scipy.stats import ttest_ind_from_stats + +from ..config import logger + + +@dataclass +class GroupStats: + n: np.ndarray + nnz: np.ndarray | None + sum_: np.ndarray | None + sum_sq: np.ndarray | None + + +def _accumulate_group_stats( + X: Any, + cell_group_indices: np.ndarray, + n_groups: int, + n_genes: int, + cell_batch_size: int = 5000, + compute_nnz: bool = False, + compute_moments: bool = False, + progress_desc: str = "streaming", +) -> GroupStats: + """Single-pass row-batched accumulation of per-group statistics. 
+ + Streams row chunks from *X* (works with dense, sparse, and backed + matrices) and accumulates the requested statistics per group. + """ + n_cells = X.shape[0] + + n_ = np.zeros(n_groups, dtype=np.int64) + nnz_ = np.zeros((n_groups, n_genes), dtype=np.int64) if compute_nnz else None + sum_ = np.zeros((n_groups, n_genes), dtype=np.float64) if compute_moments else None + sum_sq_ = ( + np.zeros((n_groups, n_genes), dtype=np.float64) if compute_moments else None + ) + + chunk_starts = range(0, n_cells, cell_batch_size) + try: + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + from tqdm.auto import tqdm + + chunk_iter = tqdm(chunk_starts, desc=progress_desc, unit="chunk") + except ImportError: + chunk_iter = chunk_starts + + for start in chunk_iter: + end = min(start + cell_batch_size, n_cells) + chunk = X[start:end] + chunk_labels = cell_group_indices[start:end] + batch_len = end - start + + if sp.issparse(chunk): + chunk = chunk.tocsr() + + ind = sp.csr_matrix( + ( + np.ones(batch_len, dtype=np.float64), + (chunk_labels, np.arange(batch_len)), + ), + shape=(n_groups, batch_len), + ) + + n_ += np.asarray(ind.sum(axis=1), dtype=np.int64).ravel() + if sum_ is not None: + sum_ += np.asarray((ind @ chunk).toarray(), dtype=np.float64) + if sum_sq_ is not None: + sum_sq_ += np.asarray( + (ind @ chunk.multiply(chunk)).toarray(), dtype=np.float64 + ) + if nnz_ is not None: + binary = chunk.copy() + binary.eliminate_zeros() + binary.data = np.ones_like(binary.data, dtype=np.float64) + nnz_ += np.asarray((ind @ binary).toarray(), dtype=np.int64) + else: + if hasattr(chunk, "toarray"): + chunk = chunk.toarray() + chunk = np.asarray(chunk, dtype=np.float64) + + indicator = np.zeros((n_groups, batch_len), dtype=np.float64) + indicator[chunk_labels, np.arange(batch_len)] = 1.0 + + n_ += indicator.sum(axis=1).astype(np.int64) + if sum_ is not None: + sum_ += indicator @ chunk + if sum_sq_ is not None: + sum_sq_ += indicator @ (chunk**2) + if 
nnz_ is not None: + nnz_ += (indicator @ (chunk != 0).astype(np.float64)).astype(np.int64) + + return GroupStats(n=n_, nnz=nnz_, sum_=sum_, sum_sq=sum_sq_) + + +def _benjamini_hochberg(pvals: np.ndarray) -> np.ndarray: + n = len(pvals) + order = np.argsort(pvals) + pvals_sorted = pvals[order] + adjusted = pvals_sorted * n / np.arange(1, n + 1) + adjusted = np.minimum(adjusted, 1.0) + # enforce monotonicity from the right (largest p-value first) + adjusted = np.minimum.accumulate(adjusted[::-1])[::-1] + result = np.empty_like(pvals) + result[order] = adjusted + return result + + +def rank_genes_groups_backed( + adata: anndata.AnnData, + groupby: str, + *, + use_raw: bool = False, + layer: str | None = None, + n_genes: int | None = None, + key_added: str = "rank_genes_groups", + cell_batch_size: int = 5000, + corr_method: Literal["benjamini-hochberg", "bonferroni"] = "benjamini-hochberg", + pts: bool = False, + copy: bool = False, +) -> anndata.AnnData | None: + """Memory-efficient rank_genes_groups for backed/on-disk AnnData objects. + + Drop-in replacement for ``sc.tl.rank_genes_groups`` that works on backed + ``_CSRDataset`` matrices by streaming cell chunks instead of loading the + full matrix. Uses Welch's t-test with one-vs-rest comparison. + + Writes scanpy-compatible output to ``adata.uns[key_added]``. 
+ """ + adata = adata.copy() if copy else adata + + # --- resolve data source --- + if layer is not None: + if use_raw: + raise ValueError("Cannot specify `layer` and have `use_raw=True`.") + X = adata.layers[layer] + var_names = adata.var_names + elif use_raw: + if adata.raw is None: + raise ValueError("Received `use_raw=True`, but `adata.raw` is empty.") + X = adata.raw.X + var_names = adata.raw.var_names + else: + X = adata.X + var_names = adata.var_names + + n_cells, n_genes_total = X.shape + + # --- resolve groups --- + groupby_series = adata.obs[groupby] + if not hasattr(groupby_series, "cat"): + groupby_series = groupby_series.astype("category") + groups_order = np.array( + natsorted(groupby_series.cat.categories.tolist()), dtype=object + ) + group_labels = groupby_series.values + + # reject singlet groups + group_counts = groupby_series.value_counts() + singlets = group_counts[group_counts < 2].index.tolist() + if singlets: + raise ValueError( + f"Could not calculate statistics for groups {', '.join(str(s) for s in singlets)} " + "since they only contain one sample." 
+ ) + + n_groups = len(groups_order) + group_to_idx = {g: i for i, g in enumerate(groups_order)} + cell_group_indices = np.array([group_to_idx[g] for g in group_labels]) + + # --- log1p base handling (matches scanpy) --- + log1p_base = adata.uns.get("log1p", {}).get("base") + + def expm1_func(x: np.ndarray) -> np.ndarray: + if log1p_base is not None: + result: np.ndarray = np.expm1(x * np.log(log1p_base)) + return result + out: np.ndarray = np.expm1(x) + return out + + # --- accumulate sufficient statistics in one pass --- + stats = _accumulate_group_stats( + X, + cell_group_indices, + n_groups, + n_genes_total, + cell_batch_size=cell_batch_size, + compute_nnz=pts, + compute_moments=True, + progress_desc="rank_genes_groups_backed", + ) + n_ = stats.n + nnz_ = stats.nnz + sum_ = stats.sum_ + sum_sq_ = stats.sum_sq + assert sum_ is not None and sum_sq_ is not None + total_sum = sum_.sum(axis=0) + total_sum_sq = sum_sq_.sum(axis=0) + + # --- compute per-group statistics and t-test --- + logger.info("Computing t-test statistics for {} groups...", n_groups) + + n_out = n_genes_total if n_genes is None else min(n_genes, n_genes_total) + result_names = np.empty((n_out, n_groups), dtype=object) + result_scores = np.empty((n_out, n_groups), dtype=np.float32) + result_logfc = np.empty((n_out, n_groups), dtype=np.float32) + result_pvals = np.empty((n_out, n_groups), dtype=np.float64) + result_pvals_adj = np.empty((n_out, n_groups), dtype=np.float64) + + if pts and nnz_ is not None: + pts_arr = np.zeros((n_groups, n_genes_total), dtype=np.float64) + pts_rest_arr = np.zeros((n_groups, n_genes_total), dtype=np.float64) + total_nnz = nnz_.sum(axis=0) + + for g_idx in range(n_groups): + ng = n_[g_idx] + nr = n_cells - ng + + mean_g = sum_[g_idx] / ng + var_g = (sum_sq_[g_idx] - sum_[g_idx] ** 2 / ng) / (ng - 1) + + sum_rest = total_sum - sum_[g_idx] + sum_sq_rest = total_sum_sq - sum_sq_[g_idx] + mean_r = sum_rest / nr + var_r = (sum_sq_rest - sum_rest**2 / nr) / (nr - 1) + + with 
np.errstate(invalid="ignore"): + scores, pvals = ttest_ind_from_stats( + mean1=mean_g, + std1=np.sqrt(var_g), + nobs1=ng, + mean2=mean_r, + std2=np.sqrt(var_r), + nobs2=nr, + equal_var=False, + ) + + scores[np.isnan(scores)] = 0 + pvals[np.isnan(pvals)] = 1 + + if corr_method == "benjamini-hochberg": + pvals_adj = _benjamini_hochberg(pvals) + else: + pvals_adj = np.minimum(pvals * n_genes_total, 1.0) + + with np.errstate(divide="ignore", invalid="ignore"): + foldchanges = (expm1_func(mean_g) + 1e-9) / (expm1_func(mean_r) + 1e-9) + logfc = np.log2(foldchanges) + + top_indices = np.argsort(-scores)[:n_out] + + result_names[:, g_idx] = var_names[top_indices] + result_scores[:, g_idx] = scores[top_indices].astype(np.float32) + result_logfc[:, g_idx] = logfc[top_indices].astype(np.float32) + result_pvals[:, g_idx] = pvals[top_indices] + result_pvals_adj[:, g_idx] = pvals_adj[top_indices] + + if pts and nnz_ is not None: + pts_arr[g_idx] = nnz_[g_idx] / ng + rest_nnz = total_nnz - nnz_[g_idx] + pts_rest_arr[g_idx] = rest_nnz / nr + + # --- build scanpy-compatible recarrays --- + group_names = [str(g) for g in groups_order] + + def _to_recarray(data: np.ndarray, dtype_str: str) -> np.recarray: + df = pd.DataFrame(data, columns=group_names) + rec: np.recarray = df.to_records(index=False, column_dtypes=dtype_str) + return rec + + adata.uns[key_added] = { + "params": { + "groupby": groupby, + "reference": "rest", + "method": "t-test", + "use_raw": use_raw, + "layer": layer, + "corr_method": corr_method, + }, + "names": _to_recarray(result_names, "O"), + "scores": _to_recarray(result_scores, "float32"), + "logfoldchanges": _to_recarray(result_logfc, "float32"), + "pvals": _to_recarray(result_pvals, "float64"), + "pvals_adj": _to_recarray(result_pvals_adj, "float64"), + } + + if pts and nnz_ is not None: + adata.uns[key_added]["pts"] = pd.DataFrame( + pts_arr.T, index=var_names, columns=group_names + ) + adata.uns[key_added]["pts_rest"] = pd.DataFrame( + pts_rest_arr.T, 
index=var_names, columns=group_names + ) + + logger.info( + "rank_genes_groups_backed complete — {} genes per group written to adata.uns['{}']", + n_out, + key_added, + ) + + if copy: + return adata + return None diff --git a/cytetype/preprocessing/subsampling.py b/cytetype/preprocessing/subsampling.py new file mode 100644 index 0000000..69a919e --- /dev/null +++ b/cytetype/preprocessing/subsampling.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +import anndata +import pandas as pd +from natsort import natsorted + +from ..config import logger + + +def subsample_by_group( + adata: anndata.AnnData, + group_key: str, + max_cells_per_group: int = 1000, + random_state: int = 0, +) -> anndata.AnnData: + """Subsample cells from an AnnData object, capping each group to a maximum count. + + Groups smaller than *max_cells_per_group* are kept intact. + + Parameters + ---------- + adata + The AnnData object to subsample. + group_key + Column in ``adata.obs`` that defines the groups (e.g. cluster labels). + max_cells_per_group + Maximum number of cells to retain per group. Groups with fewer cells + are included in full. + random_state + Seed for reproducible sampling. + + Returns + ------- + anndata.AnnData + A new in-memory AnnData object containing at most *max_cells_per_group* + cells per group. + """ + if group_key not in adata.obs.columns: + raise KeyError( + f"Group key '{group_key}' not found in adata.obs. 
" + f"Available columns: {list(adata.obs.columns)}" + ) + + is_backed = getattr(adata, "isbacked", False) + groups = natsorted(adata.obs[group_key].unique()) + subsampled: list[anndata.AnnData] = [] + + try: + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + from tqdm.auto import tqdm + + group_iter = tqdm(groups, desc="Subsampling groups", unit="group") + except ImportError: + group_iter = groups + + for group in group_iter: + mask = adata.obs[group_key] == group + n_cells = mask.sum() + + if n_cells > max_cells_per_group: + keep = pd.Series(False, index=adata.obs.index) + sampled = mask[mask].sample( + n=max_cells_per_group, random_state=random_state + ) + keep[sampled.index] = True + subset = adata[keep] + else: + subset = adata[mask] + + subsampled.append(subset.to_memory() if is_backed else subset.copy()) + + result = anndata.concat(subsampled, merge="first") + + logger.info(f"Subsampling complete: {adata.n_obs} -> {result.n_obs} cells") + + return result diff --git a/cytetype/preprocessing/validation.py b/cytetype/preprocessing/validation.py index 0d6fdbc..fce5d5e 100644 --- a/cytetype/preprocessing/validation.py +++ b/cytetype/preprocessing/validation.py @@ -3,58 +3,89 @@ from ..config import logger +_KNOWN_GENE_SYMBOL_COLUMNS = [ + "feature_name", + "gene_symbols", + "gene_symbol", + "gene_short_name", + "gene_name", + "symbol", +] -def _is_gene_id_like(value: str) -> bool: - """Check if a value looks like a gene ID rather than a gene symbol. - - Common gene ID patterns: - - Ensembl: ENSG00000000003, ENSMUSG00000000001, etc. - - RefSeq: NM_000001, XM_000001, etc. 
- - Numeric IDs: just numbers - - Other database IDs with similar patterns - - Args: - value: String value to check - Returns: - bool: True if the value looks like a gene ID, False if it looks like a gene symbol - """ +def _is_gene_id_like(value: str) -> bool: if not isinstance(value, str) or not value.strip(): return False value = value.strip() - # Ensembl IDs (human, mouse, etc.) if re.match(r"^ENS[A-Z]*G\d{11}$", value, re.IGNORECASE): return True - # RefSeq IDs if re.match(r"^[NX][MR]_\d+$", value): return True - # Purely numeric IDs if re.match(r"^\d+$", value): return True - # Other common ID patterns (long alphanumeric with underscores/dots) if re.match(r"^[A-Z0-9]+[._][A-Z0-9._]+$", value) and len(value) > 10: return True return False -def _validate_gene_symbols_column( - adata: anndata.AnnData, gene_symbols_col: str -) -> None: - """Validate that the gene_symbols_col contains gene symbols rather than gene IDs. +def _has_composite_gene_values(values: list[str]) -> bool: + """Detect values like 'TSPAN6_ENSG00000000003' (symbol_id or id_symbol).""" + composite_count = 0 + for v in values[:200]: + parts = re.split(r"[_|]", v, maxsplit=1) + if len(parts) == 2: + id_flags = [_is_gene_id_like(p) for p in parts] + if id_flags[0] != id_flags[1]: + composite_count += 1 + return len(values) > 0 and (composite_count / min(200, len(values))) > 0.5 + + +def _extract_symbol_from_composite(value: str) -> str: + parts = re.split(r"[_|]", value, maxsplit=1) + if len(parts) != 2: + return value + id_flags = [_is_gene_id_like(p) for p in parts] + if id_flags[0] and not id_flags[1]: + return parts[1] + if not id_flags[0] and id_flags[1]: + return parts[0] + return value + - Args: - adata: AnnData object - gene_symbols_col: Column name in adata.var that should contain gene symbols +def clean_gene_names(names: list[str]) -> list[str]: + """Extract gene symbols from composite gene name/ID values. 
- Raises: - ValueError: If the column appears to contain gene IDs instead of gene symbols + If >50% of values are composite (e.g. ``TSPAN6_ENSG00000000003``), + splits each value and returns the gene-symbol part. Non-composite + lists are returned unchanged. """ + if not _has_composite_gene_values(names): + return names + cleaned = [_extract_symbol_from_composite(n) for n in names] + logger.info( + f"Cleaned {len(cleaned)} composite gene values " + f"(e.g., '{names[0]}' -> '{cleaned[0]}')." + ) + return cleaned + + +def _id_like_percentage(values: list[str]) -> float: + if not values: + return 100.0 + n = min(500, len(values)) + sample = values[:n] + return sum(1 for v in sample if _is_gene_id_like(v)) / n * 100 + + +def _validate_gene_symbols_column( + adata: anndata.AnnData, gene_symbols_col: str +) -> None: gene_values = adata.var[gene_symbols_col].dropna().astype(str) if len(gene_values) == 0: @@ -63,54 +94,146 @@ def _validate_gene_symbols_column( ) return - # Sample a subset for efficiency (check up to 1000 non-null values) - sample_size = min(1000, len(gene_values)) - sample_values = gene_values.sample(n=sample_size) - - # Count how many look like gene IDs vs gene symbols - id_like_count = sum(1 for value in sample_values if _is_gene_id_like(value)) - id_like_percentage = (id_like_count / len(sample_values)) * 100 + values_list = gene_values.tolist() + pct = _id_like_percentage(values_list) - if id_like_percentage > 50: - example_ids = [ - value for value in sample_values.iloc[:5] if _is_gene_id_like(value) - ] + if pct > 50: + example_ids = [v for v in values_list[:20] if _is_gene_id_like(v)][:3] logger.warning( f"Column '{gene_symbols_col}' appears to contain gene IDs rather than gene symbols. " - f"{id_like_percentage:.1f}% of values look like gene IDs (e.g., {example_ids[:3]}). " + f"{pct:.1f}% of values look like gene IDs (e.g., {example_ids}). " f"The annotation might not be accurate. 
Consider using a column that contains " f"human-readable gene symbols (e.g., 'TSPAN6', 'DPM1', 'SCYL3') instead of database identifiers." ) - elif id_like_percentage > 20: + elif pct > 20: logger.warning( - f"Column '{gene_symbols_col}' contains {id_like_percentage:.1f}% values that look like gene IDs. " + f"Column '{gene_symbols_col}' contains {pct:.1f}% values that look like gene IDs. " f"Please verify this column contains gene symbols rather than gene identifiers." ) +def resolve_gene_symbols_column( + adata: anndata.AnnData, gene_symbols_column: str | None +) -> str | None: + """Resolve which source contains gene symbols. + + Returns the column name in adata.var, or None if var_names should be used directly. + """ + if gene_symbols_column is not None: + if gene_symbols_column not in adata.var.columns: + raise KeyError( + f"Column '{gene_symbols_column}' not found in `adata.var`. " + f"Available columns: {list(adata.var.columns)}. " + f"Set gene_symbols_column=None for auto-detection." + ) + values = adata.var[gene_symbols_column].dropna().astype(str).tolist() + if _has_composite_gene_values(values): + logger.info( + f"Column '{gene_symbols_column}' contains composite gene name/ID values " + f"(e.g., '{values[0]}'). Gene symbols will be extracted automatically." + ) + _validate_gene_symbols_column(adata, gene_symbols_column) + logger.info(f"Using gene symbols from column '{gene_symbols_column}'.") + return gene_symbols_column + + # --- Auto-detection: score all candidates, then pick the best --- + # Each candidate: (column_name | None, id_like_pct, unique_ratio, priority) + # column_name=None → use var_names. + # priority: 0 = known column, 1 = var_names, 2 = other column. 
+ # Sorted by (id_like_pct ↑, priority ↑, unique_ratio ↓ with 1.0 penalized) + # so the lowest ID-like % wins; ties broken by known columns first, then + # by higher unique ratio (gene names have high cardinality, unlike + # categorical metadata — but exactly 1.0 is slightly penalized because + # IDs are always unique while gene symbols occasionally have duplicates). + _KNOWN_SET = set(_KNOWN_GENE_SYMBOL_COLUMNS) + candidates: list[tuple[str | None, float, float, int]] = [] + + for col in _KNOWN_GENE_SYMBOL_COLUMNS: + if col not in adata.var.columns: + continue + values = adata.var[col].dropna().astype(str).tolist() + if not values: + continue + score_values = ( + [_extract_symbol_from_composite(v) for v in values] + if _has_composite_gene_values(values) + else values + ) + pct = _id_like_percentage(score_values) + unique_ratio = len(set(score_values)) / len(score_values) + candidates.append((col, pct, unique_ratio, 0)) + + var_names_list = adata.var_names.astype(str).tolist() + if var_names_list: + var_score_values = ( + [_extract_symbol_from_composite(v) for v in var_names_list] + if _has_composite_gene_values(var_names_list) + else var_names_list + ) + var_id_pct = _id_like_percentage(var_score_values) + var_unique_ratio = len(set(var_score_values)) / len(var_score_values) + candidates.append((None, var_id_pct, var_unique_ratio, 1)) + + for col in adata.var.columns: + if col in _KNOWN_SET: + continue + try: + values = adata.var[col].dropna().astype(str).tolist() + except (TypeError, ValueError): + continue + if not values: + continue + score_values = ( + [_extract_symbol_from_composite(v) for v in values] + if _has_composite_gene_values(values) + else values + ) + n_unique = len(set(score_values)) + if n_unique < max(10, len(score_values) * 0.05): + continue + pct = _id_like_percentage(score_values) + unique_ratio = n_unique / len(score_values) + candidates.append((col, pct, unique_ratio, 2)) + + viable = [c for c in candidates if c[1] < 50] + + def 
_ur_sort_key(ur: float) -> float: + return ur if ur < 1.0 else ur - 0.02 + + if viable: + viable.sort(key=lambda c: (c[1], c[3], -_ur_sort_key(c[2]))) + best_col, best_pct, best_ur, _ = viable[0] + + if best_col is not None: + source = f"column '{best_col}'" + _validate_gene_symbols_column(adata, best_col) + else: + source = "adata.var_names (index)" + + logger.info( + f"Auto-detected gene symbols in {source} " + f"({best_pct:.0f}% ID-like, {best_ur:.0%} unique)." + ) + return best_col + + # No viable candidate: fall back to var_names with warning + fallback_pct = var_id_pct if var_names_list else 100.0 + logger.warning( + "Could not find a column containing gene symbols in adata.var. " + "Falling back to adata.var_names, but they appear to contain gene IDs " + f"({fallback_pct:.0f}% ID-like). Annotation quality may be affected. " + "Consider providing gene_symbols_column explicitly." + ) + return None + + def validate_adata( adata: anndata.AnnData, cell_group_key: str, rank_genes_key: str, - gene_symbols_col: str, + gene_symbols_col: str | None, coordinates_key: str, ) -> str | None: - """Validate the AnnData object structure and return the best available coordinates key. - - Args: - adata: AnnData object to validate - cell_group_key: Key in adata.obs containing cluster labels - rank_genes_key: Key in adata.uns containing rank_genes_groups results - gene_symbols_col: Column in adata.var containing gene symbols - coordinates_key: Preferred key in adata.obsm for coordinates - - Returns: - str | None: The coordinates key that was found and validated, or None if no suitable coordinates found. - - Raises: - KeyError: If required keys are missing - ValueError: If data format is incorrect - """ if cell_group_key not in adata.obs: raise KeyError(f"Cell group key '{cell_group_key}' not found in `adata.obs`.") if adata.X is None: @@ -123,9 +246,6 @@ def validate_adata( raise KeyError( f"'{rank_genes_key}' not found in `adata.uns`. Run `sc.tl.rank_genes_groups` first." 
) - if hasattr(adata.var, gene_symbols_col) is False: - raise KeyError(f"Column '{gene_symbols_col}' not found in `adata.var`.") - _validate_gene_symbols_column(adata, gene_symbols_col) if adata.uns[rank_genes_key]["params"]["groupby"] != cell_group_key: raise ValueError( diff --git a/tests/test_artifacts.py b/tests/test_artifacts.py index a041a62..5cdf666 100644 --- a/tests/test_artifacts.py +++ b/tests/test_artifacts.py @@ -1,8 +1,11 @@ +import pytest import h5py import anndata +import numpy as np +import scipy.sparse as sp from pathlib import Path -from cytetype.core.artifacts import save_features_matrix +from cytetype.core.artifacts import save_features_matrix, _is_integer_valued def test_save_features_matrix_writes_var_metadata( @@ -32,3 +35,160 @@ def test_save_features_matrix_writes_var_metadata( for dataset in columns_group.values(): assert "source_name" in dataset.attrs assert "source_dtype" in dataset.attrs + + +def test_save_features_matrix_writes_raw_group( + tmp_path: Path, + mock_adata: anndata.AnnData, +) -> None: + n_obs, n_vars = mock_adata.n_obs, mock_adata.n_vars + rng = np.random.default_rng(0) + raw_mat = sp.random(n_obs, n_vars, density=0.1, format="csr", random_state=rng) + raw_mat.data = rng.integers(1, 20, size=raw_mat.nnz).astype(np.int32) + + out_path = tmp_path / "vars.h5" + save_features_matrix( + out_file=str(out_path), + mat=mock_adata.X, + var_df=mock_adata.var, + var_names=mock_adata.var_names, + raw_mat=raw_mat, + raw_cell_batch=30, + col_batch=10, + ) + + with h5py.File(out_path, "r") as f: + assert "raw" in f + raw = f["raw"] + assert raw.attrs["n_obs"] == n_obs + assert raw.attrs["n_vars"] == n_vars + assert raw["data"].dtype == np.int32 + assert raw["indices"].dtype == np.int32 + assert len(raw["indptr"]) == n_obs + 1 + assert raw["indptr"][0] == 0 + assert raw["indptr"][-1] == raw_mat.nnz + + +def test_save_features_matrix_raw_skipped_when_none( + tmp_path: Path, + mock_adata: anndata.AnnData, +) -> None: + out_path = tmp_path 
/ "vars.h5" + save_features_matrix( + out_file=str(out_path), + mat=mock_adata.X, + col_batch=10, + ) + + with h5py.File(out_path, "r") as f: + assert "raw" not in f + assert "vars" in f + + +def test_save_features_matrix_raw_with_float_integers( + tmp_path: Path, + mock_adata: anndata.AnnData, +) -> None: + n_obs, n_vars = mock_adata.n_obs, mock_adata.n_vars + rng = np.random.default_rng(1) + int_vals = rng.integers(1, 10, size=(n_obs, n_vars)) + raw_mat = sp.csr_matrix(int_vals.astype(np.float32)) + + out_path = tmp_path / "vars.h5" + save_features_matrix( + out_file=str(out_path), + mat=mock_adata.X, + raw_mat=raw_mat, + raw_cell_batch=30, + col_batch=10, + ) + + with h5py.File(out_path, "r") as f: + assert "raw" in f + assert f["raw"]["data"].dtype == np.int32 + + +def test_save_features_matrix_backed_csr(tmp_path: Path) -> None: + n_obs, n_vars = 200, 80 + rng = np.random.default_rng(42) + mat = sp.random(n_obs, n_vars, density=0.3, format="csr", random_state=rng) + mat.data = rng.standard_normal(mat.nnz).astype(np.float32) + reference_csc = mat.tocsc() + + h5ad_path = tmp_path / "backed.h5ad" + adata = anndata.AnnData(X=mat) + adata.write_h5ad(h5ad_path) + del adata + + backed = anndata.read_h5ad(h5ad_path, backed="r") + out_path = tmp_path / "vars.h5" + save_features_matrix(out_file=str(out_path), mat=backed.X) + backed.file.close() + + with h5py.File(out_path, "r") as f: + grp = f["vars"] + assert grp.attrs["n_obs"] == n_obs + assert grp.attrs["n_vars"] == n_vars + + h5_indices = grp["indices"][:] + h5_data = grp["data"][:] + h5_indptr = grp["indptr"][:] + + assert h5_indices.dtype == np.int32 + assert h5_data.dtype == np.float32 + assert len(h5_indptr) == n_vars + 1 + assert h5_indptr[0] == 0 + assert h5_indptr[-1] == reference_csc.nnz + + np.testing.assert_array_equal(h5_indptr, reference_csc.indptr) + np.testing.assert_array_equal(h5_indices, reference_csc.indices) + np.testing.assert_allclose(h5_data, reference_csc.data, rtol=1e-6) + + +def 
test_save_features_matrix_backed_csr_multiple_groups( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + import cytetype.config + + monkeypatch.setattr(cytetype.config, "WRITE_MEM_BUDGET", 256) + + n_obs, n_vars = 50, 30 + rng = np.random.default_rng(7) + mat = sp.random(n_obs, n_vars, density=0.4, format="csr", random_state=rng) + mat.data = rng.standard_normal(mat.nnz).astype(np.float32) + reference_csc = mat.tocsc() + + h5ad_path = tmp_path / "backed.h5ad" + adata = anndata.AnnData(X=mat) + adata.write_h5ad(h5ad_path) + del adata + + backed = anndata.read_h5ad(h5ad_path, backed="r") + out_path = tmp_path / "vars.h5" + save_features_matrix(out_file=str(out_path), mat=backed.X) + backed.file.close() + + with h5py.File(out_path, "r") as f: + grp = f["vars"] + h5_indptr = grp["indptr"][:] + h5_indices = grp["indices"][:] + h5_data = grp["data"][:] + + np.testing.assert_array_equal(h5_indptr, reference_csc.indptr) + np.testing.assert_array_equal(h5_indices, reference_csc.indices) + np.testing.assert_allclose(h5_data, reference_csc.data, rtol=1e-6) + + +def test_is_integer_valued_with_true_integers() -> None: + mat = sp.csr_matrix(np.array([[1, 0, 3], [0, 2, 0]], dtype=np.int32)) + assert _is_integer_valued(mat) is True + + +def test_is_integer_valued_with_float_integers() -> None: + mat = sp.csr_matrix(np.array([[1.0, 0.0, 3.0], [0.0, 2.0, 0.0]], dtype=np.float32)) + assert _is_integer_valued(mat) is True + + +def test_is_integer_valued_with_floats() -> None: + mat = sp.csr_matrix(np.array([[1.5, 0.0, 3.2], [0.0, 2.7, 0.0]], dtype=np.float32)) + assert _is_integer_valued(mat) is False diff --git a/tests/test_cytetype_integration.py b/tests/test_cytetype_integration.py index dddcf24..0b6993c 100644 --- a/tests/test_cytetype_integration.py +++ b/tests/test_cytetype_integration.py @@ -207,7 +207,7 @@ def test_cytetype_run_artifact_failure_continues_when_not_required( @patch("cytetype.main.wait_for_completion") @patch("cytetype.main.submit_annotation_job") 
-def test_cytetype_run_cleanup_artifacts( +def test_cytetype_cleanup_deletes_artifact_files( mock_submit: MagicMock, mock_wait: MagicMock, mock_adata: anndata.AnnData, @@ -215,7 +215,7 @@ def test_cytetype_run_cleanup_artifacts( tmp_path: Path, monkeypatch: pytest.MonkeyPatch, ) -> None: - """Test run() can cleanup generated artifact files when requested.""" + """Test cleanup() deletes the artifact files built during initialization.""" mock_submit.return_value = "job_cleanup" mock_wait.return_value = mock_api_response @@ -231,16 +231,23 @@ def _save_obs(*args: Any, **kwargs: Any) -> None: monkeypatch.setattr("cytetype.main.save_features_matrix", _save_vars) monkeypatch.setattr("cytetype.main.save_obs_duckdb_file", _save_obs) - ct = CyteType(mock_adata, group_key="leiden") - ct.run( - study_context="Test", + ct = CyteType( + mock_adata, + group_key="leiden", vars_h5_path=str(vars_path), obs_duckdb_path=str(obs_path), - cleanup_artifacts=True, ) + ct.run(study_context="Test") + + assert vars_path.exists() + assert obs_path.exists() + + ct.cleanup() assert not vars_path.exists() assert not obs_path.exists() + assert ct._vars_h5_path is None + assert ct._obs_duckdb_path is None @patch("cytetype.main.wait_for_completion")