From bec9e00eb93101d36ba155229012321e345563e0 Mon Sep 17 00:00:00 2001 From: parashardhapola Date: Sat, 28 Feb 2026 22:53:24 +0100 Subject: [PATCH 01/19] Update version to 0.18.0 and enhance raw counts handling in save_features_matrix - Bump package version to 0.18.0. - Introduce _resolve_raw_counts method in CyteType to improve raw counts extraction from AnnData. - Add _is_integer_valued utility function to check if matrices contain integer values. - Update save_features_matrix to handle raw counts and include them in the output HDF5 file. - Enhance tests to cover new raw counts functionality and integer value checks. --- cytetype/__init__.py | 2 +- cytetype/core/artifacts.py | 102 +++++++++++++++++++++++++++++++++++++ cytetype/main.py | 36 +++++++++++++ tests/test_artifacts.py | 90 +++++++++++++++++++++++++++++++- 4 files changed, 228 insertions(+), 2 deletions(-) diff --git a/cytetype/__init__.py b/cytetype/__init__.py index 3cb2877..dd4464e 100644 --- a/cytetype/__init__.py +++ b/cytetype/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.17.0" +__version__ = "0.18.0" import requests diff --git a/cytetype/core/artifacts.py b/cytetype/core/artifacts.py index c916e1f..893c6d2 100644 --- a/cytetype/core/artifacts.py +++ b/cytetype/core/artifacts.py @@ -113,11 +113,95 @@ def _write_var_metadata( dataset.attrs["source_dtype"] = str(series.dtype) +def _is_integer_valued(mat: Any, sample_n_rows: int = 200) -> bool: + if hasattr(mat, "dtype") and np.issubdtype(mat.dtype, np.integer): + return True + + n_rows = mat.shape[0] + row_end = min(sample_n_rows, n_rows) + chunk = mat[:row_end] + + if sp.issparse(chunk): + sample = chunk.data + elif hasattr(chunk, "toarray"): + sample = chunk.toarray().ravel() + else: + sample = np.asarray(chunk).ravel() + + if sample.size == 0: + return True + + sample = sample.astype(np.float64, copy=False) + return bool(np.all(np.isfinite(sample)) and np.all(sample == np.floor(sample))) + + +def _write_raw_group( + f: h5py.File, + mat: Any, + n_obs: int, + n_vars: int, + col_indices: "np.ndarray | None", + cell_batch: int, + min_chunk_size: int, +) -> None: + group = f.create_group("raw") + group.attrs["n_obs"] = n_obs + group.attrs["n_vars"] = n_vars + + chunk_size = max(1, min(n_vars * 10, min_chunk_size)) + + d_indices = group.create_dataset( + "indices", + shape=(0,), + maxshape=(n_obs * n_vars,), + chunks=(chunk_size,), + dtype=np.int32, + compression=hdf5plugin.LZ4(), + ) + d_data = group.create_dataset( + "data", + shape=(0,), + maxshape=(n_obs * n_vars,), + chunks=(chunk_size,), + dtype=np.int32, + compression=hdf5plugin.LZ4(), + ) + + indptr: list[int] = [0] + for start in range(0, n_obs, cell_batch): + end = min(start + cell_batch, n_obs) + raw_chunk = mat[start:end] + + if col_indices is not None: + raw_chunk = raw_chunk[:, col_indices] + + chunk = ( + raw_chunk.tocsr() if sp.issparse(raw_chunk) else sp.csr_matrix(raw_chunk) + ) + + old_size = d_indices.shape[0] + chunk_col_indices = chunk.indices.astype(np.int32, copy=False) + chunk_data = chunk.data.astype(np.int32, copy=False) + + d_indices.resize(old_size + len(chunk_col_indices), axis=0) + d_indices[old_size : old_size + len(chunk_col_indices)] = chunk_col_indices + + d_data.resize(old_size + len(chunk_data), axis=0) + d_data[old_size : old_size + len(chunk_data)] = chunk_data + + indptr.extend((chunk.indptr[1:] + indptr[-1]).tolist()) + + group.create_dataset("indptr", data=np.asarray(indptr, dtype=np.int64)) + + def save_features_matrix( out_file: str, mat: Any, var_df: pd.DataFrame | None = None, var_names: pd.Index | Sequence[Any] | None = None, + raw_mat: Any | None = None, + raw_col_indices: "np.ndarray | None" = None, + raw_cell_batch: int = 2000, min_chunk_size: int = 10_000_000, col_batch: int | None = None, ) -> None: @@ -193,6 +277,24 @@ def save_features_matrix( var_names=var_names, ) + if raw_mat is not None: + try: + _write_raw_group( + f, + raw_mat, + n_rows, + n_cols, + raw_col_indices, + raw_cell_batch, + min_chunk_size, + ) + except Exception: + logger.warning( + "Failed to write raw counts group to %s, skipping.", out_file + ) + if "raw" in f: + del f["raw"] + def save_obs_duckdb( out_file: str, diff --git a/cytetype/main.py b/cytetype/main.py index 993a0f2..eaa98c8 100644 --- a/cytetype/main.py +++ b/cytetype/main.py @@ -3,6 +3,7 @@ from importlib.metadata import PackageNotFoundError, version import anndata +import numpy as np from natsort import natsorted from .config import logger @@ -20,6 +21,7 @@ ) from .core.payload import build_annotation_payload, save_query_to_file from .core.artifacts import ( + _is_integer_valued, save_features_matrix, save_obs_duckdb as save_obs_duckdb_file, ) @@ -198,6 +200,29 @@ def __init__( logger.info("Data preparation completed. Ready for submitting jobs.") + def _resolve_raw_counts( + self, + ) -> "tuple[Any, np.ndarray | None] | None": + if "counts" in self.adata.layers: + mat = self.adata.layers["counts"] + if _is_integer_valued(mat): + return mat, None + + if self.adata.raw is not None: + raw_mat = self.adata.raw.X + col_indices = self.adata.raw.var_names.get_indexer(self.adata.var_names) + if (col_indices == -1).any(): + logger.warning( + "Some var_names not found in adata.raw — skipping adata.raw.X as raw counts source." + ) + elif _is_integer_valued(raw_mat): + return raw_mat, col_indices.astype(np.intp) + + if _is_integer_valued(self.adata.X): + return self.adata.X, None + + return None + def _build_and_upload_artifacts( self, vars_h5_path: str, @@ -218,11 +243,22 @@ def _build_and_upload_artifacts( # --- vars.h5 (save then upload) --- try: logger.info("Saving vars.h5 artifact from normalized counts...") + raw_result = self._resolve_raw_counts() + if raw_result is None: + logger.warning( + "No integer raw counts found in adata.layers['counts'], " + "adata.raw.X, or adata.X. Skipping raw counts in vars.h5." + ) + raw_mat, raw_col_indices = None, None + else: + raw_mat, raw_col_indices = raw_result save_features_matrix( out_file=vars_h5_path, mat=self.adata.X, var_df=self.adata.var, var_names=self.adata.var_names, + raw_mat=raw_mat, + raw_col_indices=raw_col_indices, ) logger.info("Uploading vars.h5 artifact...") vars_upload = upload_vars_h5_file( diff --git a/tests/test_artifacts.py b/tests/test_artifacts.py index a041a62..968c781 100644 --- a/tests/test_artifacts.py +++ b/tests/test_artifacts.py @@ -1,8 +1,10 @@ import h5py import anndata +import numpy as np +import scipy.sparse as sp from pathlib import Path -from cytetype.core.artifacts import save_features_matrix +from cytetype.core.artifacts import save_features_matrix, _is_integer_valued def test_save_features_matrix_writes_var_metadata( @@ -32,3 +34,89 @@ def test_save_features_matrix_writes_var_metadata( for dataset in columns_group.values(): assert "source_name" in dataset.attrs assert "source_dtype" in dataset.attrs + + +def test_save_features_matrix_writes_raw_group( + tmp_path: Path, + mock_adata: anndata.AnnData, +) -> None: + n_obs, n_vars = mock_adata.n_obs, mock_adata.n_vars + rng = np.random.default_rng(0) + raw_mat = sp.random(n_obs, n_vars, density=0.1, format="csr", random_state=rng) + raw_mat.data = rng.integers(1, 20, size=raw_mat.nnz).astype(np.int32) + + out_path = tmp_path / "vars.h5" + save_features_matrix( + out_file=str(out_path), + mat=mock_adata.X, + var_df=mock_adata.var, + var_names=mock_adata.var_names, + raw_mat=raw_mat, + raw_cell_batch=30, + col_batch=10, + ) + + with h5py.File(out_path, "r") as f: + assert "raw" in f + raw = f["raw"] + assert raw.attrs["n_obs"] == n_obs + assert raw.attrs["n_vars"] == n_vars + assert raw["data"].dtype == np.int32 + assert raw["indices"].dtype == np.int32 + assert len(raw["indptr"]) == n_obs + 1 + assert raw["indptr"][0] == 0 + assert raw["indptr"][-1] == raw_mat.nnz + + +def test_save_features_matrix_raw_skipped_when_none( + tmp_path: Path, + mock_adata: anndata.AnnData, +) -> None: + out_path = tmp_path / "vars.h5" + save_features_matrix( + out_file=str(out_path), + mat=mock_adata.X, + col_batch=10, + ) + + with h5py.File(out_path, "r") as f: + assert "raw" not in f + assert "vars" in f + + +def test_save_features_matrix_raw_with_float_integers( + tmp_path: Path, + mock_adata: anndata.AnnData, +) -> None: + n_obs, n_vars = mock_adata.n_obs, mock_adata.n_vars + rng = np.random.default_rng(1) + int_vals = rng.integers(1, 10, size=(n_obs, n_vars)) + raw_mat = sp.csr_matrix(int_vals.astype(np.float32)) + + out_path = tmp_path / "vars.h5" + save_features_matrix( + out_file=str(out_path), + mat=mock_adata.X, + raw_mat=raw_mat, + raw_cell_batch=30, + col_batch=10, + ) + + with h5py.File(out_path, "r") as f: + assert "raw" in f + assert f["raw"]["data"].dtype == np.int32 + + +def test_is_integer_valued_with_true_integers() -> None: + mat = sp.csr_matrix(np.array([[1, 0, 3], [0, 2, 0]], dtype=np.int32)) + assert _is_integer_valued(mat) is True + + +def test_is_integer_valued_with_float_integers() -> None: + mat = sp.csr_matrix(np.array([[1.0, 0.0, 3.0], [0.0, 2.0, 0.0]], dtype=np.float32)) + assert _is_integer_valued(mat) is True + + +def test_is_integer_valued_with_floats() -> None: + mat = sp.csr_matrix(np.array([[1.5, 0.0, 3.2], [0.0, 2.7, 0.0]], dtype=np.float32)) + assert _is_integer_valued(mat) is False From 098df32c74911581318d402fd8a8691264420714 Mon Sep 17 00:00:00 2001 From: parashardhapola Date: Mon, 2 Mar 2026 12:12:21 +0100 Subject: [PATCH 02/19] Add artifact paths for vars.h5 and obs.duckdb, enhance artifact building and uploading - Introduced vars_h5_path and obs_duckdb_path parameters in CyteType for customizable artifact paths. - Implemented caching of raw counts and improved error handling during artifact creation. - Updated _upload_artifacts method to handle pre-built artifacts and log errors appropriately. - Modified integration tests to accommodate new parameters and ensure proper artifact cleanup. --- cytetype/main.py | 175 +++++++++++++++-------------- tests/test_cytetype_integration.py | 9 +- 2 files changed, 99 insertions(+), 85 deletions(-) diff --git a/cytetype/main.py b/cytetype/main.py index eaa98c8..cceb38b 100644 --- a/cytetype/main.py +++ b/cytetype/main.py @@ -79,6 +79,8 @@ def __init__( pcent_batch_size: int = 2000, coordinates_key: str = "X_umap", max_cells_per_group: int = 1000, + vars_h5_path: str = "vars.h5", + obs_duckdb_path: str = "obs.duckdb", api_url: str = "https://prod.cytetype.nygen.io", auth_token: str | None = None, ) -> None: @@ -128,6 +130,7 @@ def __init__( self.max_cells_per_group = max_cells_per_group self.api_url = api_url self.auth_token = auth_token + self._artifact_build_errors: list[tuple[str, Exception]] = [] # Validate data and get the best available coordinates key self.coordinates_key = validate_adata( @@ -198,6 +201,54 @@ def __init__( "clusters": sampled_cluster_labels, } + # Resolve raw counts once and cache + self._raw_counts_result = self._resolve_raw_counts() + if self._raw_counts_result is None: + logger.warning( + "No integer raw counts found in adata.layers['counts'], " + "adata.raw.X, or adata.X. Skipping raw counts in vars.h5." + ) + + # Build vars.h5 + try: + logger.info("Saving vars.h5 artifact from normalized counts...") + raw_mat, raw_col_indices = ( + self._raw_counts_result if self._raw_counts_result is not None else (None, None) + ) + save_features_matrix( + out_file=vars_h5_path, + mat=self.adata.X, + var_df=self.adata.var, + var_names=self.adata.var_names, + raw_mat=raw_mat, + raw_col_indices=raw_col_indices, + ) + self._vars_h5_path: str | None = vars_h5_path + except Exception as exc: + logger.warning(f"vars.h5 artifact failed during build: {exc}") + self._vars_h5_path = None + self._artifact_build_errors.append(("vars_h5", exc)) + + # Build obs.duckdb + try: + logger.info("Saving obs.duckdb artifact from observation metadata...") + obsm_coordinates = ( + self.adata.obsm[self.coordinates_key] + if self.coordinates_key and self.coordinates_key in self.adata.obsm + else None + ) + save_obs_duckdb_file( + out_file=obs_duckdb_path, + obs_df=self.adata.obs, + obsm_coordinates=obsm_coordinates, + coordinates_key=self.coordinates_key, + ) + self._obs_duckdb_path: str | None = obs_duckdb_path + except Exception as exc: + logger.warning(f"obs.duckdb artifact failed during build: {exc}") + self._obs_duckdb_path = None + self._artifact_build_errors.append(("obs_duckdb", exc)) + logger.info("Data preparation completed. Ready for submitting jobs.") def _resolve_raw_counts( @@ -223,92 +274,61 @@ def _resolve_raw_counts( return None - def _build_and_upload_artifacts( + def _upload_artifacts( self, - vars_h5_path: str, - obs_duckdb_path: str, upload_timeout_seconds: int, upload_max_workers: int = 4, - coordinates_key: str | None = None, ) -> tuple[dict[str, str], list[tuple[str, Exception]]]: - """Build and upload each artifact as an independent unit. + """Upload pre-built artifact files to the server. Returns (uploaded_ids, errors) so the caller can decide whether partial success is acceptable. """ uploaded: dict[str, str] = {} - errors: list[tuple[str, Exception]] = [] + errors: list[tuple[str, Exception]] = list(self._artifact_build_errors) timeout = (30.0, float(upload_timeout_seconds)) - # --- vars.h5 (save then upload) --- - try: - logger.info("Saving vars.h5 artifact from normalized counts...") - raw_result = self._resolve_raw_counts() - if raw_result is None: - logger.warning( - "No integer raw counts found in adata.layers['counts'], " - "adata.raw.X, or adata.X. Skipping raw counts in vars.h5." - ) - raw_mat, raw_col_indices = None, None - else: - raw_mat, raw_col_indices = raw_result - save_features_matrix( - out_file=vars_h5_path, - mat=self.adata.X, - var_df=self.adata.var, - var_names=self.adata.var_names, - raw_mat=raw_mat, - raw_col_indices=raw_col_indices, - ) - logger.info("Uploading vars.h5 artifact...") - vars_upload = upload_vars_h5_file( - self.api_url, - self.auth_token, - vars_h5_path, - timeout=timeout, - max_workers=upload_max_workers, - ) - if vars_upload.file_kind != "vars_h5": - raise ValueError( - f"Unexpected upload file_kind for vars artifact: {vars_upload.file_kind}" + # --- vars.h5 upload --- + if self._vars_h5_path is not None: + try: + logger.info("Uploading vars.h5 artifact...") + vars_upload = upload_vars_h5_file( + self.api_url, + self.auth_token, + self._vars_h5_path, + timeout=timeout, + max_workers=upload_max_workers, ) - uploaded["vars_h5"] = vars_upload.upload_id - except Exception as exc: - logger.warning(f"vars.h5 artifact failed: {exc}") - errors.append(("vars_h5", exc)) + if vars_upload.file_kind != "vars_h5": + raise ValueError( + f"Unexpected upload file_kind for vars artifact: {vars_upload.file_kind}" + ) + uploaded["vars_h5"] = vars_upload.upload_id + except Exception as exc: + logger.warning(f"vars.h5 upload failed: {exc}") + errors.append(("vars_h5", exc)) print() - # --- obs.duckdb (save then upload) --- - try: - logger.info("Saving obs.duckdb artifact from observation metadata...") - obsm_coordinates = ( - self.adata.obsm[coordinates_key] - if coordinates_key and coordinates_key in self.adata.obsm - else None - ) - save_obs_duckdb_file( - out_file=obs_duckdb_path, - obs_df=self.adata.obs, - obsm_coordinates=obsm_coordinates, - coordinates_key=coordinates_key, - ) - logger.info("Uploading obs.duckdb artifact...") - obs_upload = upload_obs_duckdb_file( - self.api_url, - self.auth_token, - obs_duckdb_path, - timeout=timeout, - max_workers=upload_max_workers, - ) - if obs_upload.file_kind != "obs_duckdb": - raise ValueError( - f"Unexpected upload file_kind for obs artifact: {obs_upload.file_kind}" + # --- obs.duckdb upload --- + if self._obs_duckdb_path is not None: + try: + logger.info("Uploading obs.duckdb artifact...") + obs_upload = upload_obs_duckdb_file( + self.api_url, + self.auth_token, + self._obs_duckdb_path, + timeout=timeout, + max_workers=upload_max_workers, ) - uploaded["obs_duckdb"] = obs_upload.upload_id - except Exception as exc: - logger.warning(f"obs.duckdb artifact failed: {exc}") - errors.append(("obs_duckdb", exc)) + if obs_upload.file_kind != "obs_duckdb": + raise ValueError( + f"Unexpected upload file_kind for obs artifact: {obs_upload.file_kind}" + ) + uploaded["obs_duckdb"] = obs_upload.upload_id + except Exception as exc: + logger.warning(f"obs.duckdb upload failed: {exc}") + errors.append(("obs_duckdb", exc)) return uploaded, errors @@ -333,8 +353,6 @@ def run( auth_token: str | None = None, save_query: bool = True, query_filename: str = "query.json", - vars_h5_path: str = "vars.h5", - obs_duckdb_path: str = "obs.duckdb", upload_timeout_seconds: int = 3600, upload_max_workers: int = 4, cleanup_artifacts: bool = False, @@ -373,10 +391,6 @@ def run( save_query (bool, optional): Whether to save the query to a file. Defaults to True. query_filename (str, optional): Filename for saving the query when save_query is True. Defaults to "query.json". - vars_h5_path (str, optional): Local output path for generated vars.h5 artifact. - Defaults to "vars.h5". - obs_duckdb_path (str, optional): Local output path for generated obs.duckdb artifact. - Defaults to "obs.duckdb". upload_timeout_seconds (int, optional): Socket read timeout used for each artifact upload. Defaults to 3600. upload_max_workers (int, optional): Number of parallel threads used to upload file @@ -423,7 +437,7 @@ def run( if upload_timeout_seconds <= 0: raise ValueError("upload_timeout_seconds must be greater than 0") - # Validate the base payload before doing potentially expensive uploads. + # Validate the base payload before uploading. payload = build_annotation_payload( study_context, metadata, @@ -436,14 +450,11 @@ def run( llm_configs, ) - artifact_paths = [vars_h5_path, obs_duckdb_path] + artifact_paths = [p for p in [self._vars_h5_path, self._obs_duckdb_path] if p is not None] try: - uploaded_file_refs, artifact_errors = self._build_and_upload_artifacts( - vars_h5_path=vars_h5_path, - obs_duckdb_path=obs_duckdb_path, + uploaded_file_refs, artifact_errors = self._upload_artifacts( upload_timeout_seconds=upload_timeout_seconds, upload_max_workers=upload_max_workers, - coordinates_key=self.coordinates_key, ) if uploaded_file_refs: payload["uploaded_files"] = uploaded_file_refs diff --git a/tests/test_cytetype_integration.py b/tests/test_cytetype_integration.py index dddcf24..86105ab 100644 --- a/tests/test_cytetype_integration.py +++ b/tests/test_cytetype_integration.py @@ -231,11 +231,14 @@ def _save_obs(*args: Any, **kwargs: Any) -> None: monkeypatch.setattr("cytetype.main.save_features_matrix", _save_vars) monkeypatch.setattr("cytetype.main.save_obs_duckdb_file", _save_obs) - ct = CyteType(mock_adata, group_key="leiden") - ct.run( - study_context="Test", + ct = CyteType( + mock_adata, + group_key="leiden", vars_h5_path=str(vars_path), obs_duckdb_path=str(obs_path), + ) + ct.run( + study_context="Test", cleanup_artifacts=True, ) From 6348c71e6d01f5d5ff3ebddc61e88a985baf2770 Mon Sep 17 00:00:00 2001 From: parashardhapola Date: Mon, 2 Mar 2026 12:14:18 +0100 Subject: [PATCH 03/19] Refactor artifact cleanup in CyteType and update tests - Replaced the static method _cleanup_artifact_files with an instance method cleanup to manage artifact file deletion after run completion. - Removed the cleanup_artifacts parameter from run method, simplifying the interface. - Updated integration tests to verify that cleanup correctly deletes artifact files and clears associated paths. --- cytetype/main.py | 129 ++++++++++++++--------------- tests/test_cytetype_integration.py | 16 ++-- 2 files changed, 74 insertions(+), 71 deletions(-) diff --git a/cytetype/main.py b/cytetype/main.py index cceb38b..3226866 100644 --- a/cytetype/main.py +++ b/cytetype/main.py @@ -332,13 +332,20 @@ def _upload_artifacts( return uploaded, errors - @staticmethod - def _cleanup_artifact_files(paths: list[str]) -> None: - for artifact_path in paths: - try: - Path(artifact_path).unlink(missing_ok=True) - except OSError as exc: - logger.warning(f"Failed to cleanup artifact {artifact_path}: {exc}") + def cleanup(self) -> None: + """Delete the artifact files built during initialization. + + Call this after run() completes to remove the vars.h5 and obs.duckdb + files from disk. Paths are cleared so repeated calls are safe. + """ + for attr, path in [("_vars_h5_path", self._vars_h5_path), ("_obs_duckdb_path", self._obs_duckdb_path)]: + if path is not None: + try: + Path(path).unlink(missing_ok=True) + logger.info(f"Deleted artifact file: {path}") + except OSError as exc: + logger.warning(f"Failed to delete artifact {path}: {exc}") + setattr(self, attr, None) def run( self, @@ -355,7 +362,6 @@ def run( query_filename: str = "query.json", upload_timeout_seconds: int = 3600, upload_max_workers: int = 4, - cleanup_artifacts: bool = False, require_artifacts: bool = True, show_progress: bool = True, override_existing_results: bool = False, @@ -395,8 +401,6 @@ def run( Defaults to 3600. upload_max_workers (int, optional): Number of parallel threads used to upload file chunks. Each worker holds one chunk in memory (~100 MB). Defaults to 4. - cleanup_artifacts (bool, optional): Whether to delete generated artifact files after run - completes or fails. Defaults to False. require_artifacts (bool, optional): Whether to raise an error if artifact building or uploading fails. When True (default), any artifact failure stops the run. Set to False to skip artifacts and continue with annotation only. Defaults to True. @@ -450,66 +454,61 @@ def run( llm_configs, ) - artifact_paths = [p for p in [self._vars_h5_path, self._obs_duckdb_path] if p is not None] - try: - uploaded_file_refs, artifact_errors = self._upload_artifacts( - upload_timeout_seconds=upload_timeout_seconds, - upload_max_workers=upload_max_workers, - ) - if uploaded_file_refs: - payload["uploaded_files"] = uploaded_file_refs - - if artifact_errors: - failed_names = ", ".join(name for name, _ in artifact_errors) - if require_artifacts: - logger.error( - f"Artifact build/upload failed for: {failed_names}. " - "Rerun with `require_artifacts=False` to skip this error.\n" - "Please report the error below in a new issue at " - "https://github.com/NygenAnalytics/CyteType\n" - f"({type(artifact_errors[0][1]).__name__}: {str(artifact_errors[0][1]).strip()})" - ) - raise artifact_errors[0][1] - logger.warning( + uploaded_file_refs, artifact_errors = self._upload_artifacts( + upload_timeout_seconds=upload_timeout_seconds, + upload_max_workers=upload_max_workers, + ) + if uploaded_file_refs: + payload["uploaded_files"] = uploaded_file_refs + + if artifact_errors: + failed_names = ", ".join(name for name, _ in artifact_errors) + if require_artifacts: + logger.error( f"Artifact build/upload failed for: {failed_names}. " - "Continuing without those artifacts. " - "Set `require_artifacts=True` to see the full traceback." + "Rerun with `require_artifacts=False` to skip this error.\n" + "Please report the error below in a new issue at " + "https://github.com/NygenAnalytics/CyteType\n" + f"({type(artifact_errors[0][1]).__name__}: {str(artifact_errors[0][1]).strip()})" ) - - # Save query if requested - if save_query: - save_query_to_file(payload["input_data"], query_filename) - - # Submit job and store details - print() - job_id = submit_annotation_job(self.api_url, self.auth_token, payload) - store_job_details(self.adata, job_id, self.api_url, results_prefix) - - # Wait for completion - result = wait_for_completion( - self.api_url, - self.auth_token, - job_id, - poll_interval_seconds, - timeout_seconds, - show_progress, + raise artifact_errors[0][1] + logger.warning( + f"Artifact build/upload failed for: {failed_names}. " + "Continuing without those artifacts. " + "Set `require_artifacts=True` to see the full traceback." ) - # Store results - store_annotations( - self.adata, - result, - job_id, - results_prefix, - self.group_key, - self.clusters, - check_unannotated=True, - ) + # Save query if requested + if save_query: + save_query_to_file(payload["input_data"], query_filename) + + # Submit job and store details + print() + job_id = submit_annotation_job(self.api_url, self.auth_token, payload) + store_job_details(self.adata, job_id, self.api_url, results_prefix) + + # Wait for completion + result = wait_for_completion( + self.api_url, + self.auth_token, + job_id, + poll_interval_seconds, + timeout_seconds, + show_progress, + ) + + # Store results + store_annotations( + self.adata, + result, + job_id, + results_prefix, + self.group_key, + self.clusters, + check_unannotated=True, + ) - return self.adata - finally: - if cleanup_artifacts: - self._cleanup_artifact_files(artifact_paths) + return self.adata def get_results( self, diff --git a/tests/test_cytetype_integration.py b/tests/test_cytetype_integration.py index 86105ab..0b6993c 100644 --- a/tests/test_cytetype_integration.py +++ b/tests/test_cytetype_integration.py @@ -207,7 +207,7 @@ def test_cytetype_run_artifact_failure_continues_when_not_required( @patch("cytetype.main.wait_for_completion") @patch("cytetype.main.submit_annotation_job") -def test_cytetype_run_cleanup_artifacts( +def test_cytetype_cleanup_deletes_artifact_files( mock_submit: MagicMock, mock_wait: MagicMock, mock_adata: anndata.AnnData, @@ -215,7 +215,7 @@ def test_cytetype_run_cleanup_artifacts( tmp_path: Path, monkeypatch: pytest.MonkeyPatch, ) -> None: - """Test run() can cleanup generated artifact files when requested.""" + """Test cleanup() deletes the artifact files built during initialization.""" mock_submit.return_value = "job_cleanup" mock_wait.return_value = mock_api_response @@ -237,13 +237,17 @@ def _save_obs(*args: Any, **kwargs: Any) -> None: vars_h5_path=str(vars_path), obs_duckdb_path=str(obs_path), ) - ct.run( - study_context="Test", - cleanup_artifacts=True, - ) + ct.run(study_context="Test") + + assert vars_path.exists() + assert obs_path.exists() + + ct.cleanup() assert not vars_path.exists() assert not obs_path.exists() + assert ct._vars_h5_path is None + assert ct._obs_duckdb_path is None @patch("cytetype.main.wait_for_completion") From d1e3f2203b1ecef2f8e27f6017dd48c4d38885b4 Mon Sep 17 00:00:00 2001 From: parashardhapola Date: Mon, 2 Mar 2026 12:54:55 +0100 Subject: [PATCH 04/19] Add rank_genes_groups_backed function and update exports - Introduced rank_genes_groups_backed in marker_detection.py for memory-efficient gene ranking on backed AnnData objects. - Updated __init__.py files to include rank_genes_groups_backed in the public API of cytetype and preprocessing modules. - Refactored code for improved readability in main.py, enhancing the formatting of artifact cleanup logic. --- cytetype/__init__.py | 3 +- cytetype/main.py | 9 +- cytetype/preprocessing/__init__.py | 2 + cytetype/preprocessing/marker_detection.py | 243 +++++++++++++++++++++ 4 files changed, 254 insertions(+), 3 deletions(-) create mode 100644 cytetype/preprocessing/marker_detection.py diff --git a/cytetype/__init__.py b/cytetype/__init__.py index dd4464e..3a771d4 100644 --- a/cytetype/__init__.py +++ b/cytetype/__init__.py @@ -4,8 +4,9 @@ from .config import logger from .main import CyteType +from .preprocessing.marker_detection import rank_genes_groups_backed -__all__ = ["CyteType"] +__all__ = ["CyteType", "rank_genes_groups_backed"] _PYPI_JSON_URL = "https://pypi.org/pypi/cytetype/json" diff --git a/cytetype/main.py b/cytetype/main.py index 3226866..bd41528 100644 --- a/cytetype/main.py +++ b/cytetype/main.py @@ -213,7 +213,9 @@ def __init__( try: logger.info("Saving vars.h5 artifact from normalized counts...") raw_mat, raw_col_indices = ( - self._raw_counts_result if self._raw_counts_result is not None else (None, None) + self._raw_counts_result + if self._raw_counts_result is not None + else (None, None) ) save_features_matrix( out_file=vars_h5_path, @@ -338,7 +340,10 @@ def cleanup(self) -> None: Call this after run() completes to remove the vars.h5 and obs.duckdb files from disk. Paths are cleared so repeated calls are safe. """ - for attr, path in [("_vars_h5_path", self._vars_h5_path), ("_obs_duckdb_path", self._obs_duckdb_path)]: + for attr, path in [ + ("_vars_h5_path", self._vars_h5_path), + ("_obs_duckdb_path", self._obs_duckdb_path), + ]: if path is not None: try: Path(path).unlink(missing_ok=True) diff --git a/cytetype/preprocessing/__init__.py b/cytetype/preprocessing/__init__.py index c33ce9d..0af3818 100644 --- a/cytetype/preprocessing/__init__.py +++ b/cytetype/preprocessing/__init__.py @@ -1,6 +1,7 @@ from .validation import validate_adata from .aggregation import aggregate_expression_percentages, aggregate_cluster_metadata from .extraction import extract_marker_genes, extract_visualization_coordinates +from .marker_detection import rank_genes_groups_backed __all__ = [ "validate_adata", @@ -8,4 +9,5 @@ "aggregate_cluster_metadata", "extract_marker_genes", "extract_visualization_coordinates", + "rank_genes_groups_backed", ] diff --git a/cytetype/preprocessing/marker_detection.py b/cytetype/preprocessing/marker_detection.py new file mode 100644 index 0000000..9058b1f --- /dev/null +++ b/cytetype/preprocessing/marker_detection.py @@ -0,0 +1,243 @@ +from __future__ import annotations + +from typing import Literal + +import anndata +import numpy as np +import pandas as pd +from natsort import natsorted +from scipy.stats import ttest_ind_from_stats + +from ..config import logger + + +def _benjamini_hochberg(pvals: np.ndarray) -> np.ndarray: + n = len(pvals) + order = np.argsort(pvals) + pvals_sorted = pvals[order] + adjusted = pvals_sorted * n / np.arange(1, n + 1) + adjusted = np.minimum(adjusted, 1.0) + # enforce monotonicity from the right (largest p-value first) + adjusted = np.minimum.accumulate(adjusted[::-1])[::-1] + result = np.empty_like(pvals) + result[order] = adjusted + return result + + +def rank_genes_groups_backed( + adata: anndata.AnnData, + groupby: str, + *, + use_raw: bool = False, + layer: str | None = None, + n_genes: int | None = None, + key_added: str = "rank_genes_groups", + cell_batch_size: int = 5000, + corr_method: Literal["benjamini-hochberg", "bonferroni"] = "benjamini-hochberg", + pts: bool = False, + copy: bool = False, +) -> anndata.AnnData | None: + """Memory-efficient rank_genes_groups for backed/on-disk AnnData objects. + + Drop-in replacement for ``sc.tl.rank_genes_groups`` that works on backed + ``_CSRDataset`` matrices by streaming cell chunks instead of loading the + full matrix. Uses Welch's t-test with one-vs-rest comparison. + + Writes scanpy-compatible output to ``adata.uns[key_added]``. + """ + adata = adata.copy() if copy else adata + + # --- resolve data source --- + if layer is not None: + if use_raw: + raise ValueError("Cannot specify `layer` and have `use_raw=True`.") + X = adata.layers[layer] + var_names = adata.var_names + elif use_raw: + if adata.raw is None: + raise ValueError("Received `use_raw=True`, but `adata.raw` is empty.") + X = adata.raw.X + var_names = adata.raw.var_names + else: + X = adata.X + var_names = adata.var_names + + n_cells, n_genes_total = X.shape + + # --- resolve groups --- + groupby_series = adata.obs[groupby] + if not hasattr(groupby_series, "cat"): + groupby_series = groupby_series.astype("category") + groups_order = np.array( + natsorted(groupby_series.cat.categories.tolist()), dtype=object + ) + group_labels = groupby_series.values + + # reject singlet groups + group_counts = groupby_series.value_counts() + singlets = group_counts[group_counts < 2].index.tolist() + if singlets: + raise ValueError( + f"Could not calculate statistics for groups {', '.join(str(s) for s in singlets)} " + "since they only contain one sample." + ) + + n_groups = len(groups_order) + group_to_idx = {g: i for i, g in enumerate(groups_order)} + cell_group_indices = np.array([group_to_idx[g] for g in group_labels]) + + # --- log1p base handling (matches scanpy) --- + log1p_base = adata.uns.get("log1p", {}).get("base") + + def expm1_func(x: np.ndarray) -> np.ndarray: + if log1p_base is not None: + result: np.ndarray = np.expm1(x * np.log(log1p_base)) + return result + out: np.ndarray = np.expm1(x) + return out + + # --- accumulate sufficient statistics in one pass --- + logger.info( + "Accumulating statistics over {} cells in chunks of {}...", + n_cells, + cell_batch_size, + ) + sum_ = np.zeros((n_groups, n_genes_total), dtype=np.float64) + sum_sq_ = np.zeros((n_groups, n_genes_total), dtype=np.float64) + n_ = np.zeros(n_groups, dtype=np.int64) + nnz_ = np.zeros((n_groups, n_genes_total), dtype=np.int64) if pts else None + + chunk_starts = range(0, n_cells, cell_batch_size) + try: + from tqdm.auto import tqdm + + chunk_iter = tqdm(chunk_starts, desc="rank_genes_groups_backed", unit="chunk") + except ImportError: + chunk_iter = chunk_starts + + for start in chunk_iter: + end = min(start + cell_batch_size, n_cells) + chunk = X[start:end] + if hasattr(chunk, "toarray"): + chunk = chunk.toarray() + chunk = np.asarray(chunk, dtype=np.float64) + chunk_labels = cell_group_indices[start:end] + + for g_idx in range(n_groups): + mask = chunk_labels == g_idx + if not mask.any(): + continue + g_data = chunk[mask] + sum_[g_idx] += g_data.sum(axis=0) + sum_sq_[g_idx] += (g_data**2).sum(axis=0) + n_[g_idx] += mask.sum() + if nnz_ is not None: + nnz_[g_idx] += (g_data != 0).sum(axis=0) + + total_sum = sum_.sum(axis=0) + total_sum_sq = sum_sq_.sum(axis=0) + + # --- compute per-group statistics and t-test --- + logger.info("Computing t-test statistics for {} groups...", n_groups) + + n_out = n_genes_total if n_genes is None else min(n_genes, n_genes_total) + result_names = np.empty((n_out, n_groups), dtype=object) + result_scores = np.empty((n_out, n_groups), dtype=np.float32) + result_logfc = np.empty((n_out, n_groups), dtype=np.float32) + result_pvals = np.empty((n_out, n_groups), dtype=np.float64) + result_pvals_adj = np.empty((n_out, n_groups), dtype=np.float64) + + if pts and nnz_ is not None: + pts_arr = np.zeros((n_groups, n_genes_total), dtype=np.float64) + pts_rest_arr = np.zeros((n_groups, n_genes_total), dtype=np.float64) + total_nnz = nnz_.sum(axis=0) + + for g_idx in range(n_groups): + ng = n_[g_idx] + nr = n_cells - ng + + mean_g = sum_[g_idx] / ng + var_g = (sum_sq_[g_idx] - sum_[g_idx] ** 2 / ng) / (ng - 1) + + sum_rest = total_sum - sum_[g_idx] + sum_sq_rest = total_sum_sq - sum_sq_[g_idx] + mean_r = sum_rest / nr + var_r = (sum_sq_rest - sum_rest**2 / nr) / (nr - 1) + + with np.errstate(invalid="ignore"): + scores, pvals = ttest_ind_from_stats( + mean1=mean_g, + std1=np.sqrt(var_g), + nobs1=ng, + mean2=mean_r, + std2=np.sqrt(var_r), + nobs2=nr, + equal_var=False, + ) + + scores[np.isnan(scores)] = 0 + pvals[np.isnan(pvals)] = 1 + + if corr_method == "benjamini-hochberg": + pvals_adj = _benjamini_hochberg(pvals) + else: + pvals_adj = np.minimum(pvals * n_genes_total, 1.0) + + with np.errstate(divide="ignore", invalid="ignore"): + foldchanges = (expm1_func(mean_g) + 1e-9) / (expm1_func(mean_r) + 1e-9) + logfc = np.log2(foldchanges) + + top_indices = np.argsort(-scores)[:n_out] + + result_names[:, g_idx] = var_names[top_indices] + result_scores[:, g_idx] = scores[top_indices].astype(np.float32) + result_logfc[:, g_idx] = logfc[top_indices].astype(np.float32) + result_pvals[:, g_idx] = pvals[top_indices] + result_pvals_adj[:, g_idx] = pvals_adj[top_indices] + + if pts and nnz_ is not None: + pts_arr[g_idx] = nnz_[g_idx] / ng + rest_nnz = total_nnz - nnz_[g_idx] + pts_rest_arr[g_idx] = rest_nnz / nr + + # --- build scanpy-compatible recarrays --- + group_names = [str(g) for g in groups_order] + + def _to_recarray(data: np.ndarray, dtype_str: str) -> np.recarray: + df = pd.DataFrame(data, columns=group_names) + rec: np.recarray = df.to_records(index=False, column_dtypes=dtype_str) + return rec + + adata.uns[key_added] = { + "params": { + "groupby": groupby, + "reference": "rest", + "method": "t-test", + "use_raw": use_raw, + "layer": layer, + "corr_method": corr_method, + }, + "names": _to_recarray(result_names, "O"), + "scores": _to_recarray(result_scores, "float32"), + "logfoldchanges": _to_recarray(result_logfc, "float32"), + "pvals": _to_recarray(result_pvals, "float64"), + "pvals_adj": _to_recarray(result_pvals_adj, "float64"), + } + + if pts and nnz_ is not None: + adata.uns[key_added]["pts"] = pd.DataFrame( + pts_arr.T, index=var_names, columns=group_names + ) + adata.uns[key_added]["pts_rest"] = pd.DataFrame( + pts_rest_arr.T, index=var_names, columns=group_names + ) + + logger.info( + "rank_genes_groups_backed complete — {} genes per group written to adata.uns['{}']", + n_out, + key_added, + ) + + if copy: + return adata + return None From fbb2c4dc87308fb2ad95228d5e91d286e0a81856 Mon Sep 17 00:00:00 2001 From: parashardhapola Date: Mon, 2 Mar 2026 14:37:29 +0100 Subject: [PATCH 05/19] Enhance gene symbol handling in CyteType - Introduced resolve_gene_symbols_column function to auto-detect gene symbols in AnnData, improving flexibility in gene symbol management. - Updated gene_symbols_column type to accept None, allowing for better handling of cases where gene symbols are not explicitly provided. - Refactored aggregate_expression_percentages and extract_marker_genes functions to accommodate the new gene symbol resolution logic. - Enhanced validation in _validate_gene_symbols_column to provide clearer warnings about potential gene ID misclassifications. --- cytetype/main.py | 26 ++-- cytetype/preprocessing/__init__.py | 3 +- cytetype/preprocessing/aggregation.py | 15 +- cytetype/preprocessing/extraction.py | 10 +- cytetype/preprocessing/validation.py | 200 ++++++++++++++++++-------- 5 files changed, 177 insertions(+), 77 deletions(-) diff --git a/cytetype/main.py b/cytetype/main.py index bd41528..19f152b 100644 --- a/cytetype/main.py +++ b/cytetype/main.py @@ -14,6 +14,7 @@ ) from .preprocessing import ( validate_adata, + resolve_gene_symbols_column, aggregate_expression_percentages, extract_marker_genes, aggregate_cluster_metadata, @@ -55,7 +56,7 @@ class CyteType: adata: anndata.AnnData group_key: str rank_key: str - gene_symbols_column: str + gene_symbols_column: str | None n_top_genes: int pcent_batch_size: int coordinates_key: str | None @@ -72,7 +73,7 @@ def __init__( adata: anndata.AnnData, group_key: str, rank_key: str = "rank_genes_groups", - gene_symbols_column: str = "gene_symbols", + gene_symbols_column: str | None = None, n_top_genes: int = 50, aggregate_metadata: bool = True, min_percentage: int = 10, @@ -94,8 +95,10 @@ def __init__( rank_key (str, optional): The key in `adata.uns` containing differential expression results from `sc.tl.rank_genes_groups`. Must use the same `groupby` as `group_key`. Defaults to "rank_genes_groups". - gene_symbols_column (str, optional): Name of the column in `adata.var` that contains - the gene symbols. Defaults to "gene_symbols". + gene_symbols_column (str | None, optional): Name of the column in `adata.var` that + contains gene symbols. When None (default), auto-detects by checking well-known + column names (feature_name, gene_symbols, etc.), then adata.var_names, then + heuristic scan of all var columns. Defaults to None. n_top_genes (int, optional): Number of top marker genes per cluster to extract during initialization. Higher values may improve annotation quality but increase memory usage. Defaults to 50. @@ -123,7 +126,6 @@ def __init__( self.adata = adata self.group_key = group_key self.rank_key = rank_key - self.gene_symbols_column = gene_symbols_column self.n_top_genes = n_top_genes self.pcent_batch_size = pcent_batch_size self.coordinates_key = coordinates_key @@ -132,9 +134,12 @@ def __init__( self.auth_token = auth_token self._artifact_build_errors: list[tuple[str, Exception]] = [] - # Validate data and get the best available coordinates key + self.gene_symbols_column = resolve_gene_symbols_column( + adata, gene_symbols_column + ) + self.coordinates_key = validate_adata( - adata, group_key, rank_key, gene_symbols_column, coordinates_key + adata, group_key, rank_key, self.gene_symbols_column, coordinates_key ) # Use original labels as IDs if all are short (<=3 chars), otherwise enumerate @@ -151,11 +156,16 @@ def __init__( ] logger.info("Calculating expression percentages...") + gene_names = ( + adata.var[self.gene_symbols_column].tolist() + if self.gene_symbols_column is not None + else adata.var_names.tolist() + ) self.expression_percentages = aggregate_expression_percentages( adata=adata, clusters=self.clusters, batch_size=pcent_batch_size, - gene_names=adata.var[self.gene_symbols_column].tolist(), + gene_names=gene_names, ) logger.info("Extracting marker genes...") diff --git a/cytetype/preprocessing/__init__.py b/cytetype/preprocessing/__init__.py index 0af3818..e1391f6 100644 --- a/cytetype/preprocessing/__init__.py +++ b/cytetype/preprocessing/__init__.py @@ -1,10 +1,11 @@ -from .validation import validate_adata +from .validation import validate_adata, resolve_gene_symbols_column from .aggregation import aggregate_expression_percentages, aggregate_cluster_metadata from .extraction import extract_marker_genes, extract_visualization_coordinates from .marker_detection import rank_genes_groups_backed __all__ = [ "validate_adata", + "resolve_gene_symbols_column", "aggregate_expression_percentages", "aggregate_cluster_metadata", "extract_marker_genes", diff --git a/cytetype/preprocessing/aggregation.py b/cytetype/preprocessing/aggregation.py index edd2dd3..6a8fed5 100644 --- a/cytetype/preprocessing/aggregation.py +++ b/cytetype/preprocessing/aggregation.py @@ -17,7 +17,7 @@ def aggregate_expression_percentages( Returns: Dictionary mapping gene names to cluster-level expression percentages """ - pcent = {} + pcent: dict[str, dict[str, float]] = {} n_genes = adata.shape[1] for s in range(0, n_genes, batch_size): @@ -32,10 +32,17 @@ def aggregate_expression_percentages( f"Unexpected data type in `adata.raw.X` slice: {type(batch_data)}" ) - df = pd.DataFrame(batch_data > 0, columns=gene_names[s:e]) * 100 + # Use integer columns to avoid duplicate-name warnings, then + # map back to gene names (last duplicate wins, matching dict semantics). + df = pd.DataFrame(batch_data > 0) * 100 df["clusters"] = clusters - pcent.update(df.groupby("clusters").mean().round(2).to_dict()) - del df, batch_data + means = df.groupby("clusters").mean().round(2) + + batch_names = gene_names[s:e] + for col_idx, name in enumerate(batch_names): + pcent[name] = means[col_idx].to_dict() + + del df, batch_data, means return pcent diff --git a/cytetype/preprocessing/extraction.py b/cytetype/preprocessing/extraction.py index f5c9252..fbcaeb2 100644 --- a/cytetype/preprocessing/extraction.py +++ b/cytetype/preprocessing/extraction.py @@ -10,7 +10,7 @@ def extract_marker_genes( rank_genes_key: str, cluster_map: dict[str, str], n_top_genes: int, - gene_symbols_col: str, + gene_symbols_col: str | None, ) -> dict[str, list[str]]: """Extract top marker genes from rank_genes_groups results. @@ -20,7 +20,8 @@ def extract_marker_genes( rank_genes_key: Key in adata.uns containing rank_genes_groups results cluster_map: Dictionary mapping original labels to cluster IDs n_top_genes: Number of top genes to extract per cluster - gene_symbols_col: Column in adata.var containing gene symbols + gene_symbols_col: Column in adata.var containing gene symbols, + or None to use var_names directly (identity mapping). Returns: Dictionary mapping cluster IDs to lists of marker gene symbols @@ -44,7 +45,10 @@ def extract_marker_genes( f"Failed to extract marker gene names from `rank_genes_groups`. Error: {e}" ) - gene_ids_to_name = adata.var[gene_symbols_col].to_dict() + if gene_symbols_col is not None: + gene_ids_to_name = adata.var[gene_symbols_col].to_dict() + else: + gene_ids_to_name = dict(zip(adata.var_names, adata.var_names)) markers = {} any_genes_found = False diff --git a/cytetype/preprocessing/validation.py b/cytetype/preprocessing/validation.py index 0d6fdbc..2eff60c 100644 --- a/cytetype/preprocessing/validation.py +++ b/cytetype/preprocessing/validation.py @@ -3,58 +3,60 @@ from ..config import logger +_KNOWN_GENE_SYMBOL_COLUMNS = [ + "feature_name", + "gene_symbols", + "gene_symbol", + "gene_short_name", + "gene_name", + "symbol", +] -def _is_gene_id_like(value: str) -> bool: - """Check if a value looks like a gene ID rather than a gene symbol. - - Common gene ID patterns: - - Ensembl: ENSG00000000003, ENSMUSG00000000001, etc. - - RefSeq: NM_000001, XM_000001, etc. - - Numeric IDs: just numbers - - Other database IDs with similar patterns - - Args: - value: String value to check - Returns: - bool: True if the value looks like a gene ID, False if it looks like a gene symbol - """ +def _is_gene_id_like(value: str) -> bool: if not isinstance(value, str) or not value.strip(): return False value = value.strip() - # Ensembl IDs (human, mouse, etc.) if re.match(r"^ENS[A-Z]*G\d{11}$", value, re.IGNORECASE): return True - # RefSeq IDs if re.match(r"^[NX][MR]_\d+$", value): return True - # Purely numeric IDs if re.match(r"^\d+$", value): return True - # Other common ID patterns (long alphanumeric with underscores/dots) if re.match(r"^[A-Z0-9]+[._][A-Z0-9._]+$", value) and len(value) > 10: return True return False +def _has_composite_gene_values(values: list[str]) -> bool: + """Detect values like 'TSPAN6_ENSG00000000003' (symbol_id or id_symbol).""" + composite_count = 0 + for v in values[:200]: + parts = re.split(r"[_|]", v, maxsplit=1) + if len(parts) == 2: + id_flags = [_is_gene_id_like(p) for p in parts] + if id_flags[0] != id_flags[1]: + composite_count += 1 + return len(values) > 0 and (composite_count / min(200, len(values))) > 0.5 + + +def _id_like_percentage(values: list[str]) -> float: + if not values: + return 100.0 + n = min(500, len(values)) + sample = values[:n] + return sum(1 for v in sample if _is_gene_id_like(v)) / n * 100 + + def _validate_gene_symbols_column( adata: anndata.AnnData, gene_symbols_col: str ) -> None: - """Validate that the gene_symbols_col contains gene symbols rather than gene IDs. - - Args: - adata: AnnData object - gene_symbols_col: Column name in adata.var that should contain gene symbols - - Raises: - ValueError: If the column appears to contain gene IDs instead of gene symbols - """ gene_values = adata.var[gene_symbols_col].dropna().astype(str) if len(gene_values) == 0: @@ -63,54 +65,133 @@ def _validate_gene_symbols_column( ) return - # Sample a subset for efficiency (check up to 1000 non-null values) - sample_size = min(1000, len(gene_values)) - sample_values = gene_values.sample(n=sample_size) - - # Count how many look like gene IDs vs gene symbols - id_like_count = sum(1 for value in sample_values if _is_gene_id_like(value)) - id_like_percentage = (id_like_count / len(sample_values)) * 100 + values_list = gene_values.tolist() + pct = _id_like_percentage(values_list) - if id_like_percentage > 50: - example_ids = [ - value for value in sample_values.iloc[:5] if _is_gene_id_like(value) - ] + if pct > 50: + example_ids = [v for v in values_list[:20] if _is_gene_id_like(v)][:3] logger.warning( f"Column '{gene_symbols_col}' appears to contain gene IDs rather than gene symbols. " - f"{id_like_percentage:.1f}% of values look like gene IDs (e.g., {example_ids[:3]}). " + f"{pct:.1f}% of values look like gene IDs (e.g., {example_ids}). " f"The annotation might not be accurate. Consider using a column that contains " f"human-readable gene symbols (e.g., 'TSPAN6', 'DPM1', 'SCYL3') instead of database identifiers." ) - elif id_like_percentage > 20: + elif pct > 20: logger.warning( - f"Column '{gene_symbols_col}' contains {id_like_percentage:.1f}% values that look like gene IDs. " + f"Column '{gene_symbols_col}' contains {pct:.1f}% values that look like gene IDs. " f"Please verify this column contains gene symbols rather than gene identifiers." ) +def resolve_gene_symbols_column( + adata: anndata.AnnData, gene_symbols_column: str | None +) -> str | None: + """Resolve which source contains gene symbols. + + Returns the column name in adata.var, or None if var_names should be used directly. + """ + if gene_symbols_column is not None: + if gene_symbols_column not in adata.var.columns: + raise KeyError( + f"Column '{gene_symbols_column}' not found in `adata.var`. " + f"Available columns: {list(adata.var.columns)}. " + f"Set gene_symbols_column=None for auto-detection." + ) + _validate_gene_symbols_column(adata, gene_symbols_column) + logger.info(f"Using gene symbols from column '{gene_symbols_column}'.") + return gene_symbols_column + + # --- Auto-detection: score all candidates, then pick the best --- + # Each candidate: (column_name | None, id_like_pct, unique_ratio, priority) + # column_name=None → use var_names. + # priority: 0 = known column, 1 = var_names, 2 = other column. + # Sorted by (id_like_pct ↑, priority ↑, unique_ratio ↓ with 1.0 penalized) + # so the lowest ID-like % wins; ties broken by known columns first, then + # by higher unique ratio (gene names have high cardinality, unlike + # categorical metadata — but exactly 1.0 is slightly penalized because + # IDs are always unique while gene symbols occasionally have duplicates). + _KNOWN_SET = set(_KNOWN_GENE_SYMBOL_COLUMNS) + candidates: list[tuple[str | None, float, float, int]] = [] + + for col in _KNOWN_GENE_SYMBOL_COLUMNS: + if col not in adata.var.columns: + continue + values = adata.var[col].dropna().astype(str).tolist() + if not values: + continue + if _has_composite_gene_values(values): + logger.warning( + f"Column '{col}' appears to contain composite gene name/ID values " + f"(e.g., '{values[0]}'). Skipping." + ) + continue + pct = _id_like_percentage(values) + unique_ratio = len(set(values)) / len(values) + candidates.append((col, pct, unique_ratio, 0)) + + var_names_list = adata.var_names.astype(str).tolist() + if var_names_list: + var_id_pct = _id_like_percentage(var_names_list) + var_unique_ratio = len(set(var_names_list)) / len(var_names_list) + candidates.append((None, var_id_pct, var_unique_ratio, 1)) + + for col in adata.var.columns: + if col in _KNOWN_SET: + continue + try: + values = adata.var[col].dropna().astype(str).tolist() + except (TypeError, ValueError): + continue + if not values: + continue + if _has_composite_gene_values(values): + continue + n_unique = len(set(values)) + if n_unique < max(10, len(values) * 0.05): + continue + pct = _id_like_percentage(values) + unique_ratio = n_unique / len(values) + candidates.append((col, pct, unique_ratio, 2)) + + viable = [c for c in candidates if c[1] < 50] + + def _ur_sort_key(ur: float) -> float: + return ur if ur < 1.0 else ur - 0.02 + + if viable: + viable.sort(key=lambda c: (c[1], c[3], -_ur_sort_key(c[2]))) + best_col, best_pct, best_ur, _ = viable[0] + + if best_col is not None: + source = f"column '{best_col}'" + _validate_gene_symbols_column(adata, best_col) + else: + source = "adata.var_names (index)" + + logger.info( + f"Auto-detected gene symbols in {source} " + f"({best_pct:.0f}% ID-like, {best_ur:.0%} unique)." + ) + return best_col + + # No viable candidate: fall back to var_names with warning + fallback_pct = var_id_pct if var_names_list else 100.0 + logger.warning( + "Could not find a column containing gene symbols in adata.var. " + "Falling back to adata.var_names, but they appear to contain gene IDs " + f"({fallback_pct:.0f}% ID-like). Annotation quality may be affected. " + "Consider providing gene_symbols_column explicitly." + ) + return None + + def validate_adata( adata: anndata.AnnData, cell_group_key: str, rank_genes_key: str, - gene_symbols_col: str, + gene_symbols_col: str | None, coordinates_key: str, ) -> str | None: - """Validate the AnnData object structure and return the best available coordinates key. - - Args: - adata: AnnData object to validate - cell_group_key: Key in adata.obs containing cluster labels - rank_genes_key: Key in adata.uns containing rank_genes_groups results - gene_symbols_col: Column in adata.var containing gene symbols - coordinates_key: Preferred key in adata.obsm for coordinates - - Returns: - str | None: The coordinates key that was found and validated, or None if no suitable coordinates found. - - Raises: - KeyError: If required keys are missing - ValueError: If data format is incorrect - """ if cell_group_key not in adata.obs: raise KeyError(f"Cell group key '{cell_group_key}' not found in `adata.obs`.") if adata.X is None: @@ -123,9 +204,6 @@ def validate_adata( raise KeyError( f"'{rank_genes_key}' not found in `adata.uns`. Run `sc.tl.rank_genes_groups` first." ) - if hasattr(adata.var, gene_symbols_col) is False: - raise KeyError(f"Column '{gene_symbols_col}' not found in `adata.var`.") - _validate_gene_symbols_column(adata, gene_symbols_col) if adata.uns[rank_genes_key]["params"]["groupby"] != cell_group_key: raise ValueError( From 06084552483e0f48997127d44b9d9a0bf082a15f Mon Sep 17 00:00:00 2001 From: parashardhapola Date: Mon, 2 Mar 2026 14:41:10 +0100 Subject: [PATCH 06/19] Update batch size for expression percentage calculations and refactor aggregation logic - Increased the default batch size for calculating expression percentages from 2000 to 5000 to optimize memory usage. - Refactored the aggregate_expression_percentages function to utilize a single-pass row-batched accumulation method for improved performance. - Introduced a new _accumulate_group_stats function to streamline the computation of per-group statistics, enhancing efficiency for large datasets. - Updated related documentation to reflect changes in parameters and processing logic. --- cytetype/main.py | 8 +- cytetype/preprocessing/aggregation.py | 60 ++++++----- cytetype/preprocessing/marker_detection.py | 110 +++++++++++++++------ 3 files changed, 116 insertions(+), 62 deletions(-) diff --git a/cytetype/main.py b/cytetype/main.py index 19f152b..920c15c 100644 --- a/cytetype/main.py +++ b/cytetype/main.py @@ -77,7 +77,7 @@ def __init__( n_top_genes: int = 50, aggregate_metadata: bool = True, min_percentage: int = 10, - pcent_batch_size: int = 2000, + pcent_batch_size: int = 5000, coordinates_key: str = "X_umap", max_cells_per_group: int = 1000, vars_h5_path: str = "vars.h5", @@ -106,8 +106,8 @@ def __init__( Defaults to True. min_percentage (int, optional): Minimum percentage of cells in a group to include in the cluster context. Defaults to 10. - pcent_batch_size (int, optional): Batch size for calculating expression percentages to - optimize memory usage. Defaults to 2000. + pcent_batch_size (int, optional): Number of cells to process per chunk when + calculating expression percentages. Defaults to 5000. coordinates_key (str, optional): Key in adata.obsm containing 2D coordinates for visualization. Must be a 2D array with same number of elements as clusters. Defaults to "X_umap". @@ -164,8 +164,8 @@ def __init__( self.expression_percentages = aggregate_expression_percentages( adata=adata, clusters=self.clusters, - batch_size=pcent_batch_size, gene_names=gene_names, + cell_batch_size=pcent_batch_size, ) logger.info("Extracting marker genes...") diff --git a/cytetype/preprocessing/aggregation.py b/cytetype/preprocessing/aggregation.py index 6a8fed5..c3a2cc7 100644 --- a/cytetype/preprocessing/aggregation.py +++ b/cytetype/preprocessing/aggregation.py @@ -2,47 +2,53 @@ import numpy as np import pandas as pd +from .marker_detection import _accumulate_group_stats + def aggregate_expression_percentages( - adata: anndata.AnnData, clusters: list[str], batch_size: int, gene_names: list[str] + adata: anndata.AnnData, + clusters: list[str], + gene_names: list[str], + cell_batch_size: int = 5000, ) -> dict[str, dict[str, float]]: """Aggregate gene expression percentages per cluster. + Uses a single-pass row-batched accumulation (fast for CSR / backed data). + Args: adata: AnnData object containing expression data clusters: List of cluster assignments for each cell - batch_size: Number of genes to process per batch (for memory efficiency) gene_names: List of gene names corresponding to columns in adata.X + cell_batch_size: Number of cells to process per chunk Returns: Dictionary mapping gene names to cluster-level expression percentages """ - pcent: dict[str, dict[str, float]] = {} - n_genes = adata.shape[1] - - for s in range(0, n_genes, batch_size): - e = min(s + batch_size, n_genes) - batch_data = adata.X[:, s:e] - if hasattr(batch_data, "toarray"): - batch_data = batch_data.toarray() - elif isinstance(batch_data, np.ndarray): - pass - else: - raise TypeError( - f"Unexpected data type in `adata.raw.X` slice: {type(batch_data)}" - ) - - # Use integer columns to avoid duplicate-name warnings, then - # map back to gene names (last duplicate wins, matching dict semantics). - df = pd.DataFrame(batch_data > 0) * 100 - df["clusters"] = clusters - means = df.groupby("clusters").mean().round(2) + unique_clusters = sorted(set(clusters)) + n_groups = len(unique_clusters) + cluster_to_idx = {c: i for i, c in enumerate(unique_clusters)} + cell_group_indices = np.array([cluster_to_idx[c] for c in clusters]) + + stats = _accumulate_group_stats( + adata.X, + cell_group_indices, + n_groups, + adata.shape[1], + cell_batch_size=cell_batch_size, + compute_nnz=True, + progress_desc="expression_percentages", + ) + + with np.errstate(divide="ignore", invalid="ignore"): + pct_matrix = np.round(stats.nnz / stats.n[:, None] * 100, 2) + pct_matrix = np.nan_to_num(pct_matrix, nan=0.0) - batch_names = gene_names[s:e] - for col_idx, name in enumerate(batch_names): - pcent[name] = means[col_idx].to_dict() - - del df, batch_data, means + pcent: dict[str, dict[str, float]] = {} + for gene_idx, name in enumerate(gene_names): + pcent[name] = { + unique_clusters[g_idx]: float(pct_matrix[g_idx, gene_idx]) + for g_idx in range(n_groups) + } return pcent diff --git a/cytetype/preprocessing/marker_detection.py b/cytetype/preprocessing/marker_detection.py index 9058b1f..9757d25 100644 --- a/cytetype/preprocessing/marker_detection.py +++ b/cytetype/preprocessing/marker_detection.py @@ -1,5 +1,6 @@ from __future__ import annotations +from dataclasses import dataclass from typing import Literal import anndata @@ -11,6 +12,70 @@ from ..config import logger +@dataclass +class GroupStats: + n: np.ndarray + nnz: np.ndarray | None + sum_: np.ndarray | None + sum_sq: np.ndarray | None + + +def _accumulate_group_stats( + X, + cell_group_indices: np.ndarray, + n_groups: int, + n_genes: int, + cell_batch_size: int = 5000, + compute_nnz: bool = False, + compute_moments: bool = False, + progress_desc: str = "streaming", +) -> GroupStats: + """Single-pass row-batched accumulation of per-group statistics. + + Streams row chunks from *X* (works with dense, sparse, and backed + matrices) and accumulates the requested statistics per group. + """ + n_cells = X.shape[0] + + n_ = np.zeros(n_groups, dtype=np.int64) + nnz_ = np.zeros((n_groups, n_genes), dtype=np.int64) if compute_nnz else None + sum_ = np.zeros((n_groups, n_genes), dtype=np.float64) if compute_moments else None + sum_sq_ = ( + np.zeros((n_groups, n_genes), dtype=np.float64) if compute_moments else None + ) + + chunk_starts = range(0, n_cells, cell_batch_size) + try: + from tqdm.auto import tqdm + + chunk_iter = tqdm(chunk_starts, desc=progress_desc, unit="chunk") + except ImportError: + chunk_iter = chunk_starts + + for start in chunk_iter: + end = min(start + cell_batch_size, n_cells) + chunk = X[start:end] + if hasattr(chunk, "toarray"): + chunk = chunk.toarray() + chunk = np.asarray(chunk, dtype=np.float64) + chunk_labels = cell_group_indices[start:end] + + for g_idx in range(n_groups): + mask = chunk_labels == g_idx + if not mask.any(): + continue + g_data = chunk[mask] + n_[g_idx] += mask.sum() + if sum_ is not None: + sum_[g_idx] += g_data.sum(axis=0) + if sum_sq_ is not None: + sum_sq_[g_idx] += (g_data**2).sum(axis=0) + if nnz_ is not None: + nnz_[g_idx] += (g_data != 0).sum(axis=0) + + return GroupStats(n=n_, nnz=nnz_, sum_=sum_, sum_sq=sum_sq_) + + def _benjamini_hochberg(pvals: np.ndarray) -> np.ndarray: n = len(pvals) order = np.argsort(pvals) @@ -102,37 +167,20 @@ def expm1_func(x: np.ndarray) -> np.ndarray: n_cells, cell_batch_size, ) - sum_ = np.zeros((n_groups, n_genes_total), dtype=np.float64) - sum_sq_ = np.zeros((n_groups, n_genes_total), dtype=np.float64) - n_ = np.zeros(n_groups, dtype=np.int64) - nnz_ = np.zeros((n_groups, n_genes_total), dtype=np.int64) if pts else None - - chunk_starts = range(0, n_cells, cell_batch_size) - try: - from tqdm.auto import tqdm - - chunk_iter = tqdm(chunk_starts, desc="rank_genes_groups_backed", unit="chunk") - except ImportError: - chunk_iter = chunk_starts - - for start in chunk_iter: - end = min(start + cell_batch_size, n_cells) - chunk = X[start:end] - if hasattr(chunk, "toarray"): - chunk = chunk.toarray() - chunk = np.asarray(chunk, dtype=np.float64) - chunk_labels = cell_group_indices[start:end] - - for g_idx in range(n_groups): - mask = chunk_labels == g_idx - if not mask.any(): - continue - g_data = chunk[mask] - sum_[g_idx] += g_data.sum(axis=0) - sum_sq_[g_idx] += (g_data**2).sum(axis=0) - n_[g_idx] += mask.sum() - if nnz_ is not None: - nnz_[g_idx] += (g_data != 0).sum(axis=0) + stats = _accumulate_group_stats( + X, + cell_group_indices, + n_groups, + n_genes_total, + cell_batch_size=cell_batch_size, + compute_nnz=pts, + compute_moments=True, + progress_desc="rank_genes_groups_backed", + ) + sum_ = stats.sum_ + sum_sq_ = stats.sum_sq + n_ = stats.n + nnz_ = stats.nnz total_sum = sum_.sum(axis=0) total_sum_sq = sum_sq_.sum(axis=0) From 10d69a2f9cd1ba8de436a24e1e7df43c8b8af6d1 Mon Sep 17 00:00:00 2001 From: parashardhapola Date: Mon, 2 Mar 2026 16:02:58 +0100 Subject: [PATCH 07/19] Refactor logging and enhance progress reporting in CyteType - Removed unnecessary logging statements for calculating expression percentages and extracting visualization coordinates to streamline output. - Updated logging message for saving obs.duckdb artifact for clarity. - Integrated progress reporting using tqdm for batch processing in save_features_matrix and extract_visualization_coordinates functions. - Improved handling of warnings during batch processing to suppress FutureWarnings from tqdm. - Adjusted progress descriptions for better user feedback during long-running operations. --- cytetype/core/artifacts.py | 32 ++++++++++++++++++++-- cytetype/main.py | 7 +++-- cytetype/preprocessing/aggregation.py | 3 +- cytetype/preprocessing/extraction.py | 32 ++++++++++++++-------- cytetype/preprocessing/marker_detection.py | 16 +++++++---- 5 files changed, 65 insertions(+), 25 deletions(-) diff --git a/cytetype/core/artifacts.py b/cytetype/core/artifacts.py index 893c6d2..eb00155 100644 --- a/cytetype/core/artifacts.py +++ b/cytetype/core/artifacts.py @@ -168,7 +168,19 @@ def _write_raw_group( ) indptr: list[int] = [0] - for start in range(0, n_obs, cell_batch): + batch_starts = range(0, n_obs, cell_batch) + try: + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=FutureWarning) + from tqdm.auto import tqdm + + batch_iter = tqdm(batch_starts, desc="Writing raw counts", unit="batch") + except ImportError: + batch_iter = batch_starts + + for start in batch_iter: end = min(start + cell_batch, n_obs) raw_chunk = mat[start:end] @@ -220,7 +232,7 @@ def save_features_matrix( ) if col_batch is None: - col_batch = max(1, int(100_000_000 / max(n_rows, 1))) + col_batch = max(1, int(1_000_000_000 / max(n_rows, 1))) chunk_size = max(1, min(n_rows * 10, min_chunk_size)) with h5py.File(out_file, "w") as f: @@ -246,7 +258,21 @@ def save_features_matrix( ) indptr: list[int] = [0] - for start in range(0, n_cols, col_batch): + col_starts = range(0, n_cols, col_batch) + try: + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=FutureWarning) + from tqdm.auto import tqdm + + col_iter = tqdm( + col_starts, desc="Writing pivoted normalized counts", unit="batch" + ) + except ImportError: + col_iter = col_starts + + for start in col_iter: end = min(start + col_batch, n_cols) raw_chunk = mat[:, start:end] chunk = ( diff --git a/cytetype/main.py b/cytetype/main.py index 920c15c..7e17d42 100644 --- a/cytetype/main.py +++ b/cytetype/main.py @@ -1,3 +1,4 @@ +import sys from pathlib import Path from typing import Any from importlib.metadata import PackageNotFoundError, version @@ -155,7 +156,6 @@ def __init__( self.cluster_map[str(x)] for x in adata.obs[group_key].values.tolist() ] - logger.info("Calculating expression percentages...") gene_names = ( adata.var[self.gene_symbols_column].tolist() if self.gene_symbols_column is not None @@ -197,7 +197,6 @@ def __init__( self.group_metadata = {} # Prepare visualization data with sampling - logger.info("Extracting sampled visualization coordinates...") sampled_coordinates, sampled_cluster_labels = extract_visualization_coordinates( adata=adata, coordinates_key=self.coordinates_key, @@ -235,6 +234,7 @@ def __init__( raw_mat=raw_mat, raw_col_indices=raw_col_indices, ) + sys.stderr.flush() self._vars_h5_path: str | None = vars_h5_path except Exception as exc: logger.warning(f"vars.h5 artifact failed during build: {exc}") @@ -243,7 +243,7 @@ def __init__( # Build obs.duckdb try: - logger.info("Saving obs.duckdb artifact from observation metadata...") + logger.info("Writing obs.duckdb artifact from observation metadata...") obsm_coordinates = ( self.adata.obsm[self.coordinates_key] if self.coordinates_key and self.coordinates_key in self.adata.obsm @@ -255,6 +255,7 @@ def __init__( obsm_coordinates=obsm_coordinates, coordinates_key=self.coordinates_key, ) + sys.stderr.flush() self._obs_duckdb_path: str | None = obs_duckdb_path except Exception as exc: logger.warning(f"obs.duckdb artifact failed during build: {exc}") diff --git a/cytetype/preprocessing/aggregation.py b/cytetype/preprocessing/aggregation.py index c3a2cc7..c279725 100644 --- a/cytetype/preprocessing/aggregation.py +++ b/cytetype/preprocessing/aggregation.py @@ -1,6 +1,5 @@ import anndata import numpy as np -import pandas as pd from .marker_detection import _accumulate_group_stats @@ -36,7 +35,7 @@ def aggregate_expression_percentages( adata.shape[1], cell_batch_size=cell_batch_size, compute_nnz=True, - progress_desc="expression_percentages", + progress_desc="Calculating expression percentages", ) with np.errstate(divide="ignore", invalid="ignore"): diff --git a/cytetype/preprocessing/extraction.py b/cytetype/preprocessing/extraction.py index fbcaeb2..79d96f0 100644 --- a/cytetype/preprocessing/extraction.py +++ b/cytetype/preprocessing/extraction.py @@ -127,8 +127,24 @@ def extract_visualization_coordinates( ) # Sample cells from each group using pandas + unique_groups = coord_df["group"].unique() + try: + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=FutureWarning) + from tqdm.auto import tqdm + + group_iter = tqdm( + unique_groups, + desc=f"Sampling coordinates from {coordinates_key}", + unit="group", + ) + except ImportError: + group_iter = unique_groups + sampled_coords = [] - for group_label in coord_df["group"].unique(): + for group_label in group_iter: group_mask = coord_df["group"] == group_label group_size = group_mask.sum() sample_size = min(max_cells_per_group, group_size) @@ -138,12 +154,6 @@ def extract_visualization_coordinates( ) sampled_coords.append(sampled_group) - if group_size > max_cells_per_group: - logger.info( - f"Sampled {sample_size} cells from group '{group_label}' " - f"(originally {group_size} cells)" - ) - # Concatenate all sampled groups sampled_coord_df: pd.DataFrame = pd.concat(sampled_coords, ignore_index=True) @@ -156,9 +166,9 @@ def extract_visualization_coordinates( for label in sampled_coord_df["group"].values ] - logger.info( - f"Extracted {len(sampled_coordinates)} coordinate points " - f"(sampled from {len(coordinates)} total cells)" - ) + # logger.info( + # f"Extracted {len(sampled_coordinates)} coordinate points " + # f"(sampled from {len(coordinates)} total cells)" + # ) return sampled_coordinates, sampled_cluster_labels diff --git a/cytetype/preprocessing/marker_detection.py b/cytetype/preprocessing/marker_detection.py index 9757d25..abcb125 100644 --- a/cytetype/preprocessing/marker_detection.py +++ b/cytetype/preprocessing/marker_detection.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Literal +from typing import Any, Literal import anndata import numpy as np @@ -21,7 +21,7 @@ class GroupStats: def _accumulate_group_stats( - X, + X: Any, cell_group_indices: np.ndarray, n_groups: int, n_genes: int, @@ -46,7 +46,11 @@ def _accumulate_group_stats( chunk_starts = range(0, n_cells, cell_batch_size) try: - from tqdm.auto import tqdm + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=FutureWarning) + from tqdm.auto import tqdm chunk_iter = tqdm(chunk_starts, desc=progress_desc, unit="chunk") except ImportError: @@ -177,11 +181,11 @@ def expm1_func(x: np.ndarray) -> np.ndarray: compute_moments=True, progress_desc="rank_genes_groups_backed", ) - sum_ = stats.sum_ - sum_sq_ = stats.sum_sq n_ = stats.n nnz_ = stats.nnz - + sum_ = stats.sum_ + sum_sq_ = stats.sum_sq + assert sum_ is not None and sum_sq_ is not None total_sum = sum_.sum(axis=0) total_sum_sq = sum_sq_.sum(axis=0) From fd133a2e2a47cd841520e2a6382f7e80effa3e9b Mon Sep 17 00:00:00 2001 From: parashardhapola Date: Mon, 2 Mar 2026 16:39:05 +0100 Subject: [PATCH 08/19] Add WRITE_MEM_BUDGET constant and enhance logging in CyteType - Introduced WRITE_MEM_BUDGET constant in config.py to define memory budget for writing artifacts. - Updated logging messages in main.py for clarity during artifact saving processes. - Enhanced progress reporting in artifact writing functions to improve user feedback. - Refactored warning handling to suppress FutureWarnings from tqdm during batch processing. - Added new functions in artifacts.py for improved handling of sparse matrix writing and progress tracking. --- cytetype/config.py | 2 + cytetype/core/artifacts.py | 266 ++++++++++++++++----- cytetype/main.py | 3 +- cytetype/preprocessing/marker_detection.py | 7 +- tests/test_artifacts.py | 72 ++++++ 5 files changed, 283 insertions(+), 67 deletions(-) diff --git a/cytetype/config.py b/cytetype/config.py index 52bd61c..8979525 100644 --- a/cytetype/config.py +++ b/cytetype/config.py @@ -24,3 +24,5 @@ def _log_format(record: Record) -> str: level="INFO", format=_log_format, ) + +WRITE_MEM_BUDGET: int = 4 * 1024 * 1024 * 1024 # 4 GB diff --git a/cytetype/core/artifacts.py b/cytetype/core/artifacts.py index eb00155..ee4386c 100644 --- a/cytetype/core/artifacts.py +++ b/cytetype/core/artifacts.py @@ -11,7 +11,7 @@ import scipy.sparse as sp from anndata.abc import CSCDataset, CSRDataset -from ..config import logger +from ..config import logger, WRITE_MEM_BUDGET def _safe_column_dataset_name( @@ -176,7 +176,9 @@ def _write_raw_group( warnings.simplefilter("ignore", category=FutureWarning) from tqdm.auto import tqdm - batch_iter = tqdm(batch_starts, desc="Writing raw counts", unit="batch") + batch_iter = tqdm( + batch_starts, desc="Writing raw counts to H5 artifact", unit="batch" + ) except ImportError: batch_iter = batch_starts @@ -206,6 +208,202 @@ def _write_raw_group( group.create_dataset("indptr", data=np.asarray(indptr, dtype=np.int64)) +def _try_import_tqdm() -> type | None: + try: + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + from tqdm.auto import tqdm + + return tqdm # type: ignore[no-any-return] + except ImportError: + return None + + +def _write_csc_via_row_batches( + group: h5py.Group, + mat: Any, + n_rows: int, + n_cols: int, + min_chunk_size: int, +) -> None: + row_batch = max(1, int(100_000_000 / max(n_cols, 1))) + chunk_size = max(1, min(n_rows * 10, min_chunk_size)) + tqdm = _try_import_tqdm() + + # --- Pass 1: count nnz per column --- + col_counts = np.zeros(n_cols, dtype=np.int64) + row_starts = range(0, n_rows, row_batch) + pass1_iter = ( + tqdm(row_starts, desc="Counting non-zeros in normalized matrix", unit="batch") + if tqdm is not None + else row_starts + ) + + for start in pass1_iter: + end = min(start + row_batch, n_rows) + chunk = mat[start:end] + csr = chunk.tocsr() if sp.issparse(chunk) else sp.csr_matrix(chunk) + col_counts += np.bincount(csr.indices, minlength=n_cols) + + indptr = np.empty(n_cols + 1, dtype=np.int64) + indptr[0] = 0 + np.cumsum(col_counts, out=indptr[1:]) + total_nnz = int(indptr[-1]) + + d_indices = group.create_dataset( + "indices", + shape=(total_nnz,), + dtype=np.int32, + chunks=(chunk_size,) if total_nnz > 0 else None, + compression=hdf5plugin.LZ4() if total_nnz > 0 else None, + ) + d_data = group.create_dataset( + "data", + shape=(total_nnz,), + dtype=np.float32, + chunks=(chunk_size,) if total_nnz > 0 else None, + compression=hdf5plugin.LZ4() if total_nnz > 0 else None, + ) + + # --- Pass 2+: column-group scatter --- + max_nnz_per_group = max(1, WRITE_MEM_BUDGET // 8) + + c_start = 0 + group_idx = 0 + while c_start < n_cols: + cumulative = np.cumsum(col_counts[c_start:]) + over_budget = np.searchsorted(cumulative, max_nnz_per_group, side="right") + c_end = c_start + max(1, int(over_budget)) + c_end = min(c_end, n_cols) + + group_nnz = int(indptr[c_end] - indptr[c_start]) + if group_nnz == 0: + c_start = c_end + group_idx += 1 + continue + + grp_indices = np.empty(group_nnz, dtype=np.int32) + grp_data = np.empty(group_nnz, dtype=np.float32) + + n_group_cols = c_end - c_start + cursors = (indptr[c_start:c_end] - indptr[c_start]).copy() + + desc = ( + f"Writing normalized counts (group {group_idx + 1})" + if c_end < n_cols or c_start > 0 + else "Writing normalized counts to H5 artifact" + ) + pass2_iter = ( + tqdm(row_starts, desc=desc, unit="batch") + if tqdm is not None + else row_starts + ) + + for start in pass2_iter: + end = min(start + row_batch, n_rows) + chunk = mat[start:end] + csr = chunk.tocsr() if sp.issparse(chunk) else sp.csr_matrix(chunk) + csc = csr.tocsc() + + local_start_ptr = csc.indptr[c_start] + local_end_ptr = csc.indptr[c_end] + if local_start_ptr == local_end_ptr: + continue + + slice_indices = csc.indices[local_start_ptr:local_end_ptr] + start + slice_data = csc.data[local_start_ptr:local_end_ptr].astype( + np.float32, copy=False + ) + slice_indptr = csc.indptr[c_start : c_end + 1] - local_start_ptr + + chunk_col_counts = np.diff(slice_indptr) + col_ids = np.repeat( + np.arange(n_group_cols, dtype=np.int64), chunk_col_counts + ) + bases = cursors[col_ids] + local_offsets = np.arange(len(slice_indices), dtype=np.int64) - np.repeat( + slice_indptr[:-1].astype(np.int64), chunk_col_counts + ) + targets = bases + local_offsets + + grp_indices[targets] = slice_indices + grp_data[targets] = slice_data + cursors += chunk_col_counts + + hdf5_offset = int(indptr[c_start]) + d_indices[hdf5_offset : hdf5_offset + group_nnz] = grp_indices + d_data[hdf5_offset : hdf5_offset + group_nnz] = grp_data + + del grp_indices, grp_data + c_start = c_end + group_idx += 1 + + group.create_dataset("indptr", data=indptr) + + +def _write_csc_via_col_batches( + group: h5py.Group, + mat: Any, + n_rows: int, + n_cols: int, + min_chunk_size: int, + col_batch: int | None, +) -> None: + if col_batch is None: + col_batch = max(1, int(1_000_000_000 / max(n_rows, 1))) + + chunk_size = max(1, min(n_rows * 10, min_chunk_size)) + + d_indices = group.create_dataset( + "indices", + shape=(0,), + maxshape=(n_rows * n_cols,), + chunks=(chunk_size,), + dtype=np.int32, + compression=hdf5plugin.LZ4(), + ) + d_data = group.create_dataset( + "data", + shape=(0,), + maxshape=(n_rows * n_cols,), + chunks=(chunk_size,), + dtype=np.float32, + compression=hdf5plugin.LZ4(), + ) + + indptr: list[int] = [0] + col_starts = range(0, n_cols, col_batch) + tqdm = _try_import_tqdm() + col_iter = ( + tqdm(col_starts, desc="Writing normalized counts", unit="batch") + if tqdm is not None + else col_starts + ) + + for start in col_iter: + end = min(start + col_batch, n_cols) + raw_chunk = mat[:, start:end] + chunk = ( + raw_chunk.tocsc() if sp.issparse(raw_chunk) else sp.csc_matrix(raw_chunk) + ) + + old_size = d_indices.shape[0] + chunk_indices = chunk.indices.astype(np.int32, copy=False) + chunk_data = chunk.data.astype(np.float32, copy=False) + + d_indices.resize(old_size + len(chunk_indices), axis=0) + d_indices[old_size : old_size + len(chunk_indices)] = chunk_indices + + d_data.resize(old_size + len(chunk_data), axis=0) + d_data[old_size : old_size + len(chunk_data)] = chunk_data + + indptr.extend((chunk.indptr[1:] + indptr[-1]).tolist()) + + group.create_dataset("indptr", data=np.asarray(indptr, dtype=np.int64)) + + def save_features_matrix( out_file: str, mat: Any, @@ -225,76 +423,26 @@ def save_features_matrix( n_rows, n_cols = mat.shape + use_row_batch_path = isinstance(mat, CSRDataset) + if not isinstance(mat, (CSRDataset, CSCDataset)): logger.warning( "For large datasets, use AnnData backed mode (e.g. sc.read_h5ad(..., backed='r')) " "so `adata.X` is a backed sparse dataset and avoids loading the full matrix in memory.", ) - if col_batch is None: - col_batch = max(1, int(1_000_000_000 / max(n_rows, 1))) - - chunk_size = max(1, min(n_rows * 10, min_chunk_size)) with h5py.File(out_file, "w") as f: group = f.create_group("vars") group.attrs["n_obs"] = n_rows group.attrs["n_vars"] = n_cols - d_indices = group.create_dataset( - "indices", - shape=(0,), - maxshape=(n_rows * n_cols,), - chunks=(chunk_size,), - dtype=np.int32, - compression=hdf5plugin.LZ4(), - ) - d_data = group.create_dataset( - "data", - shape=(0,), - maxshape=(n_rows * n_cols,), - chunks=(chunk_size,), - dtype=np.float32, - compression=hdf5plugin.LZ4(), - ) - - indptr: list[int] = [0] - col_starts = range(0, n_cols, col_batch) - try: - import warnings - - with warnings.catch_warnings(): - warnings.simplefilter("ignore", category=FutureWarning) - from tqdm.auto import tqdm - - col_iter = tqdm( - col_starts, desc="Writing pivoted normalized counts", unit="batch" - ) - except ImportError: - col_iter = col_starts - - for start in col_iter: - end = min(start + col_batch, n_cols) - raw_chunk = mat[:, start:end] - chunk = ( - raw_chunk.tocsc() - if sp.issparse(raw_chunk) - else sp.csc_matrix(raw_chunk) + if use_row_batch_path: + _write_csc_via_row_batches(group, mat, n_rows, n_cols, min_chunk_size) + else: + _write_csc_via_col_batches( + group, mat, n_rows, n_cols, min_chunk_size, col_batch ) - old_size = d_indices.shape[0] - chunk_indices = chunk.indices.astype(np.int32, copy=False) - chunk_data = chunk.data.astype(np.float32, copy=False) - - d_indices.resize(old_size + len(chunk_indices), axis=0) - d_indices[old_size : old_size + len(chunk_indices)] = chunk_indices - - d_data.resize(old_size + len(chunk_data), axis=0) - d_data[old_size : old_size + len(chunk_data)] = chunk_data - - indptr.extend((chunk.indptr[1:] + indptr[-1]).tolist()) - - group.create_dataset("indptr", data=np.asarray(indptr, dtype=np.int64)) - if var_df is not None: _write_var_metadata( out_file_group=f, diff --git a/cytetype/main.py b/cytetype/main.py index 7e17d42..9862367 100644 --- a/cytetype/main.py +++ b/cytetype/main.py @@ -220,7 +220,6 @@ def __init__( # Build vars.h5 try: - logger.info("Saving vars.h5 artifact from normalized counts...") raw_mat, raw_col_indices = ( self._raw_counts_result if self._raw_counts_result is not None @@ -243,7 +242,7 @@ def __init__( # Build obs.duckdb try: - logger.info("Writing obs.duckdb artifact from observation metadata...") + logger.info("Writing obs data to duckdb artifact...") obsm_coordinates = ( self.adata.obsm[self.coordinates_key] if self.coordinates_key and self.coordinates_key in self.adata.obsm diff --git a/cytetype/preprocessing/marker_detection.py b/cytetype/preprocessing/marker_detection.py index abcb125..e5ac7f0 100644 --- a/cytetype/preprocessing/marker_detection.py +++ b/cytetype/preprocessing/marker_detection.py @@ -49,7 +49,7 @@ def _accumulate_group_stats( import warnings with warnings.catch_warnings(): - warnings.simplefilter("ignore", category=FutureWarning) + warnings.simplefilter("ignore") from tqdm.auto import tqdm chunk_iter = tqdm(chunk_starts, desc=progress_desc, unit="chunk") @@ -166,11 +166,6 @@ def expm1_func(x: np.ndarray) -> np.ndarray: return out # --- accumulate sufficient statistics in one pass --- - logger.info( - "Accumulating statistics over {} cells in chunks of {}...", - n_cells, - cell_batch_size, - ) stats = _accumulate_group_stats( X, cell_group_indices, diff --git a/tests/test_artifacts.py b/tests/test_artifacts.py index 968c781..5cdf666 100644 --- a/tests/test_artifacts.py +++ b/tests/test_artifacts.py @@ -1,3 +1,4 @@ +import pytest import h5py import anndata import numpy as np @@ -107,6 +108,77 @@ def test_save_features_matrix_raw_with_float_integers( assert f["raw"]["data"].dtype == np.int32 +def test_save_features_matrix_backed_csr(tmp_path: Path) -> None: + n_obs, n_vars = 200, 80 + rng = np.random.default_rng(42) + mat = sp.random(n_obs, n_vars, density=0.3, format="csr", random_state=rng) + mat.data = rng.standard_normal(mat.nnz).astype(np.float32) + reference_csc = mat.tocsc() + + h5ad_path = tmp_path / "backed.h5ad" + adata = anndata.AnnData(X=mat) + adata.write_h5ad(h5ad_path) + del adata + + backed = anndata.read_h5ad(h5ad_path, backed="r") + out_path = tmp_path / "vars.h5" + save_features_matrix(out_file=str(out_path), mat=backed.X) + backed.file.close() + + with h5py.File(out_path, "r") as f: + grp = f["vars"] + assert grp.attrs["n_obs"] == n_obs + assert grp.attrs["n_vars"] == n_vars + + h5_indices = grp["indices"][:] + h5_data = grp["data"][:] + h5_indptr = grp["indptr"][:] + + assert h5_indices.dtype == np.int32 + assert h5_data.dtype == np.float32 + assert len(h5_indptr) == n_vars + 1 + assert h5_indptr[0] == 0 + assert h5_indptr[-1] == reference_csc.nnz + + np.testing.assert_array_equal(h5_indptr, reference_csc.indptr) + np.testing.assert_array_equal(h5_indices, reference_csc.indices) + np.testing.assert_allclose(h5_data, reference_csc.data, rtol=1e-6) + + +def test_save_features_matrix_backed_csr_multiple_groups( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + import cytetype.config + + monkeypatch.setattr(cytetype.config, "WRITE_MEM_BUDGET", 256) + + n_obs, n_vars = 50, 30 + rng = np.random.default_rng(7) + mat = sp.random(n_obs, n_vars, density=0.4, format="csr", random_state=rng) + mat.data = rng.standard_normal(mat.nnz).astype(np.float32) + reference_csc = mat.tocsc() + + h5ad_path = tmp_path / "backed.h5ad" + adata = anndata.AnnData(X=mat) + adata.write_h5ad(h5ad_path) + del adata + + backed = anndata.read_h5ad(h5ad_path, backed="r") + out_path = tmp_path / "vars.h5" + save_features_matrix(out_file=str(out_path), mat=backed.X) + backed.file.close() + + with h5py.File(out_path, "r") as f: + grp = f["vars"] + h5_indptr = grp["indptr"][:] + h5_indices = grp["indices"][:] + h5_data = grp["data"][:] + + np.testing.assert_array_equal(h5_indptr, reference_csc.indptr) + np.testing.assert_array_equal(h5_indices, reference_csc.indices) + np.testing.assert_allclose(h5_data, reference_csc.data, rtol=1e-6) + + def test_is_integer_valued_with_true_integers() -> None: mat = sp.csr_matrix(np.array([[1, 0, 3], [0, 2, 0]], dtype=np.int32)) assert _is_integer_valued(mat) is True From f953a2863de04b01e081264d2920c6fba086cf2c Mon Sep 17 00:00:00 2001 From: parashardhapola Date: Tue, 3 Mar 2026 12:40:54 +0100 Subject: [PATCH 09/19] Enhance file upload functionality and error handling in CyteType - Increased maximum upload size for vars_h5 from 10GB to 50GB to accommodate larger datasets. - Introduced a new ClientDisconnectedError exception to handle client disconnection scenarios. - Improved progress reporting during file uploads by integrating tqdm for better user feedback. - Refactored upload logic to ensure consistent progress updates and error handling across different upload scenarios. --- cytetype/api/client.py | 66 +++++++++++++++------- cytetype/api/exceptions.py | 7 +++ cytetype/preprocessing/marker_detection.py | 23 ++++---- 3 files changed, 64 insertions(+), 32 deletions(-) diff --git a/cytetype/api/client.py b/cytetype/api/client.py index 52fd4d1..000db3d 100644 --- a/cytetype/api/client.py +++ b/cytetype/api/client.py @@ -14,19 +14,32 @@ MAX_UPLOAD_BYTES: dict[UploadFileKind, int] = { "obs_duckdb": 100 * 1024 * 1024, # 100MB - "vars_h5": 10 * 1024 * 1024 * 1024, # 10GB + "vars_h5": 50 * 1024 * 1024 * 1024, # 10GB } _CHUNK_RETRY_DELAYS = (1, 5, 20) _RETRYABLE_API_ERROR_CODES = frozenset({"INTERNAL_ERROR", "HTTP_ERROR"}) +def _try_import_tqdm() -> type | None: + try: + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + from tqdm.auto import tqdm + + return tqdm # type: ignore[no-any-return] + except ImportError: + return None + + def _upload_file( base_url: str, auth_token: str | None, file_kind: UploadFileKind, file_path: str, - timeout: float | tuple[float, float] = (30.0, 3600.0), + timeout: float | tuple[float, float] = (60.0, 3600.0), max_workers: int = 4, ) -> UploadResponse: path_obj = Path(file_path) @@ -62,6 +75,12 @@ def _upload_file( # Memory is bounded to ~max_workers × chunk_size because each thread # reads its chunk on demand via seek+read. _tls = threading.local() + tqdm_cls = _try_import_tqdm() + pbar = ( + tqdm_cls(total=n_chunks, desc="Uploading", unit="chunk") + if tqdm_cls is not None and n_chunks > 0 + else None + ) _progress_lock = threading.Lock() _chunks_done = [0] @@ -82,15 +101,18 @@ def _upload_chunk(chunk_idx: int) -> None: data=chunk_data, timeout=timeout, ) - with _progress_lock: - _chunks_done[0] += 1 - done = _chunks_done[0] - pct = 100 * done / n_chunks - print( - f"\r Uploading: {done}/{n_chunks} chunks ({pct:.0f}%)", - end="", - flush=True, - ) + if pbar is not None: + pbar.update(1) + else: + with _progress_lock: + _chunks_done[0] += 1 + done = _chunks_done[0] + pct = 100 * done / n_chunks + print( + f"\r Uploading: {done}/{n_chunks} chunks ({pct:.0f}%)", + end="", + flush=True, + ) return except (NetworkError, TimeoutError) as exc: last_exc = exc @@ -103,13 +125,9 @@ def _upload_chunk(chunk_idx: int) -> None: if attempt < len(_CHUNK_RETRY_DELAYS): delay = _CHUNK_RETRY_DELAYS[attempt] logger.warning( - "Chunk %d/%d upload failed (attempt %d/%d), retrying in %ds: %s", - chunk_idx + 1, - n_chunks, - attempt + 1, - 1 + len(_CHUNK_RETRY_DELAYS), - delay, - last_exc, + f"Chunk {chunk_idx + 1}/{n_chunks} upload failed " + f"(attempt {attempt + 1}/{1 + len(_CHUNK_RETRY_DELAYS)}), " + f"retrying in {delay}s: {last_exc}" ) time.sleep(delay) @@ -120,9 +138,17 @@ def _upload_chunk(chunk_idx: int) -> None: try: with ThreadPoolExecutor(max_workers=effective_workers) as pool: list(pool.map(_upload_chunk, range(n_chunks))) - print(f"\r \033[92m✓\033[0m Uploaded {n_chunks}/{n_chunks} chunks (100%)") + if pbar is not None: + pbar.close() + else: + print( + f"\r \033[92m✓\033[0m Uploaded {n_chunks}/{n_chunks} chunks (100%)" + ) except BaseException: - print() # ensure newline on failure + if pbar is not None: + pbar.close() + else: + print() raise # Step 3 – Complete upload (returns same UploadResponse shape as before) diff --git a/cytetype/api/exceptions.py b/cytetype/api/exceptions.py index 3ed2366..fe4c529 100644 --- a/cytetype/api/exceptions.py +++ b/cytetype/api/exceptions.py @@ -56,6 +56,12 @@ class LLMValidationError(APIError): pass +class ClientDisconnectedError(APIError): + """Server detected client disconnection mid-request - CLIENT_DISCONNECTED (HTTP 499).""" + + pass + + # Client-side errors with default messages class TimeoutError(CyteTypeError): """Client-side timeout waiting for results.""" @@ -87,6 +93,7 @@ def __init__( "JOB_NOT_FOUND": JobNotFoundError, "JOB_FAILED": JobFailedError, "LLM_VALIDATION_FAILED": LLMValidationError, + "CLIENT_DISCONNECTED": ClientDisconnectedError, "JOB_PROCESSING": APIError, # Generic - expected during polling "JOB_NOT_COMPLETED": APIError, # Generic "HTTP_ERROR": APIError, # Generic diff --git a/cytetype/preprocessing/marker_detection.py b/cytetype/preprocessing/marker_detection.py index e5ac7f0..727ae1d 100644 --- a/cytetype/preprocessing/marker_detection.py +++ b/cytetype/preprocessing/marker_detection.py @@ -64,18 +64,17 @@ def _accumulate_group_stats( chunk = np.asarray(chunk, dtype=np.float64) chunk_labels = cell_group_indices[start:end] - for g_idx in range(n_groups): - mask = chunk_labels == g_idx - if not mask.any(): - continue - g_data = chunk[mask] - n_[g_idx] += mask.sum() - if sum_ is not None: - sum_[g_idx] += g_data.sum(axis=0) - if sum_sq_ is not None: - sum_sq_[g_idx] += (g_data**2).sum(axis=0) - if nnz_ is not None: - nnz_[g_idx] += (g_data != 0).sum(axis=0) + batch_len = end - start + indicator = np.zeros((n_groups, batch_len), dtype=np.float64) + indicator[chunk_labels, np.arange(batch_len)] = 1.0 + + n_ += indicator.sum(axis=1).astype(np.int64) + if sum_ is not None: + sum_ += indicator @ chunk + if sum_sq_ is not None: + sum_sq_ += indicator @ (chunk**2) + if nnz_ is not None: + nnz_ += (indicator @ (chunk != 0).astype(np.float64)).astype(np.int64) return GroupStats(n=n_, nnz=nnz_, sum_=sum_, sum_sq=sum_sq_) From bcdbd08a7be4a9728f8f0cc836a4901ac97438ab Mon Sep 17 00:00:00 2001 From: parashardhapola Date: Tue, 3 Mar 2026 13:35:59 +0100 Subject: [PATCH 10/19] Add subsampling functionality to preprocessing module - Introduced a new `subsample_by_group` function in `subsampling.py` to limit the number of cells per group in an AnnData object. - Updated `__init__.py` to include `subsample_by_group` in the public API of the preprocessing module. - Enhanced error handling to check for the existence of the specified group key in the AnnData object. - Added logging to report the results of the subsampling process. --- cytetype/preprocessing/__init__.py | 2 + cytetype/preprocessing/subsampling.py | 77 +++++++++++++++++++++++++++ 2 files changed, 79 insertions(+) create mode 100644 cytetype/preprocessing/subsampling.py diff --git a/cytetype/preprocessing/__init__.py b/cytetype/preprocessing/__init__.py index e1391f6..abbf71f 100644 --- a/cytetype/preprocessing/__init__.py +++ b/cytetype/preprocessing/__init__.py @@ -2,6 +2,7 @@ from .aggregation import aggregate_expression_percentages, aggregate_cluster_metadata from .extraction import extract_marker_genes, extract_visualization_coordinates from .marker_detection import rank_genes_groups_backed +from .subsampling import subsample_by_group __all__ = [ "validate_adata", @@ -11,4 +12,5 @@ "extract_marker_genes", "extract_visualization_coordinates", "rank_genes_groups_backed", + "subsample_by_group", ] diff --git a/cytetype/preprocessing/subsampling.py b/cytetype/preprocessing/subsampling.py new file mode 100644 index 0000000..8c9ff7c --- /dev/null +++ b/cytetype/preprocessing/subsampling.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +import anndata +import pandas as pd +from natsort import natsorted + +from ..config import logger + + +def subsample_by_group( + adata: anndata.AnnData, + group_key: str, + max_cells_per_group: int = 1000, + random_state: int = 0, +) -> anndata.AnnData: + """Subsample cells from an AnnData object, capping each group to a maximum count. + + Groups smaller than *max_cells_per_group* are kept intact. + + Parameters + ---------- + adata + The AnnData object to subsample. + group_key + Column in ``adata.obs`` that defines the groups (e.g. cluster labels). + max_cells_per_group + Maximum number of cells to retain per group. Groups with fewer cells + are included in full. + random_state + Seed for reproducible sampling. + + Returns + ------- + anndata.AnnData + A new in-memory AnnData object containing at most *max_cells_per_group* + cells per group. + """ + if group_key not in adata.obs.columns: + raise KeyError( + f"Group key '{group_key}' not found in adata.obs. " + f"Available columns: {list(adata.obs.columns)}" + ) + + is_backed = getattr(adata, "isbacked", False) + groups = natsorted(adata.obs[group_key].unique()) + subsampled: list[anndata.AnnData] = [] + + try: + import warnings + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + from tqdm.auto import tqdm + + group_iter = tqdm(groups, desc="Subsampling groups", unit="group") + except ImportError: + group_iter = groups + + for group in group_iter: + mask = adata.obs[group_key] == group + n_cells = mask.sum() + + if n_cells > max_cells_per_group: + keep = pd.Series(False, index=adata.obs.index) + sampled = mask[mask].sample(n=max_cells_per_group, random_state=random_state) + keep[sampled.index] = True + subset = adata[keep] + else: + subset = adata[mask] + + subsampled.append(subset.to_memory() if is_backed else subset.copy()) + + result = anndata.concat(subsampled) + + logger.info(f"Subsampling complete: {adata.n_obs} -> {result.n_obs} cells") + + return result From 7112e164554dbc8604bd3bddf40a7d3cd16fcd9d Mon Sep 17 00:00:00 2001 From: parashardhapola Date: Tue, 3 Mar 2026 13:48:15 +0100 Subject: [PATCH 11/19] Refactor subsampling functionality and improve logging in preprocessing module - Enhanced the `subsample_by_group` function to optimize performance and memory usage during subsampling. - Improved logging to provide clearer insights into the subsampling process and results. - Updated error handling to ensure robustness when dealing with edge cases in AnnData objects. - Refactored related tests to validate the new subsampling logic and logging enhancements. --- cytetype/plotting/__init__.py | 5 ++ cytetype/plotting/dotplot.py | 109 ++++++++++++++++++++++++++++++++++ 2 files changed, 114 insertions(+) create mode 100644 cytetype/plotting/__init__.py create mode 100644 cytetype/plotting/dotplot.py diff --git a/cytetype/plotting/__init__.py b/cytetype/plotting/__init__.py new file mode 100644 index 0000000..14e1513 --- /dev/null +++ b/cytetype/plotting/__init__.py @@ -0,0 +1,5 @@ +from .dotplot import marker_dotplot + +__all__ = [ + "marker_dotplot", +] diff --git a/cytetype/plotting/dotplot.py b/cytetype/plotting/dotplot.py new file mode 100644 index 0000000..e20a5eb --- /dev/null +++ b/cytetype/plotting/dotplot.py @@ -0,0 +1,109 @@ +from typing import Any + +import anndata + +from ..config import logger +from ..core.results import load_local_results + + +def marker_dotplot( + adata: anndata.AnnData, + group_key: str, + results_prefix: str = "cytetype", + n_top_markers: int = 3, + gene_symbols: str | None = None, + **kwargs: Any, +) -> Any: + """Dotplot of top marker genes grouped by CyteType cluster categories. + + Reads stored CyteType results from ``adata.uns`` and builds a + category-grouped marker dict suitable for ``sc.pl.dotplot``. + + Parameters + ---------- + adata + AnnData object that has been annotated by CyteType + (results stored in ``adata.uns``). + group_key + The original cluster key used during annotation (e.g. ``"leiden"``). + results_prefix + Prefix used when the results were stored. Must match the + ``results_prefix`` passed to :meth:`CyteType.run`. + n_top_markers + Number of top supporting genes to display per cluster. + gene_symbols + Column in ``adata.var`` containing gene symbols. Forwarded to + ``sc.pl.dotplot`` via the ``gene_symbols`` parameter. + **kwargs + Additional keyword arguments forwarded to ``sc.pl.dotplot`` + (e.g. ``cmap``, ``use_raw``). + + Returns + ------- + The return value of ``sc.pl.dotplot``. + """ + results = load_local_results(adata, results_prefix) + if results is None: + raise KeyError( + f"No CyteType results found in adata.uns with prefix '{results_prefix}'. " + "Run CyteType annotation first or check the results_prefix." + ) + + cluster_categories = results.get("clusterCategories", []) + if not cluster_categories: + raise ValueError( + "No cluster categories found in CyteType results. " + "The API response may not include category groupings for this run." + ) + + raw_annotations = results.get("raw_annotations", {}) + if not raw_annotations: + raise ValueError( + "No raw_annotations found in CyteType results." + ) + + markers: dict[str, list[str]] = {} + categories_order: list[str] = [] + + for category in cluster_categories: + category_name = category["categoryName"] + markers[category_name] = [] + + for cluster_id in category["clusterIds"]: + cluster_data = raw_annotations.get(cluster_id) + if cluster_data is None: + logger.warning( + "Cluster '{}' listed in clusterCategories but missing from raw_annotations, skipping.", + cluster_id, + ) + continue + + full_output = ( + cluster_data["latest"]["annotation"]["fullOutput"] + ) + cell_type = full_output["cellType"] + + categories_order.append(cell_type["label"]) + markers[category_name].extend( + cell_type["keySupportingGenes"][:n_top_markers] + ) + + markers[category_name] = sorted(set(markers[category_name])) + + try: + import scanpy as sc + except ImportError: + raise ImportError( + "scanpy is required for plotting. Install it with: pip install scanpy" + ) from None + + groupby = f"{results_prefix}_annotation_{group_key}" + + return sc.pl.dotplot( + adata, + markers, + groupby=groupby, + gene_symbols=gene_symbols, + categories_order=categories_order, + **kwargs, + ) From aa57457454c1162ac54a145fbcd674239d402297 Mon Sep 17 00:00:00 2001 From: parashardhapola Date: Tue, 3 Mar 2026 13:48:33 +0100 Subject: [PATCH 12/19] formatted --- cytetype/plotting/dotplot.py | 8 ++------ cytetype/preprocessing/subsampling.py | 4 +++- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/cytetype/plotting/dotplot.py b/cytetype/plotting/dotplot.py index e20a5eb..e24cbb5 100644 --- a/cytetype/plotting/dotplot.py +++ b/cytetype/plotting/dotplot.py @@ -58,9 +58,7 @@ def marker_dotplot( raw_annotations = results.get("raw_annotations", {}) if not raw_annotations: - raise ValueError( - "No raw_annotations found in CyteType results." - ) + raise ValueError("No raw_annotations found in CyteType results.") markers: dict[str, list[str]] = {} categories_order: list[str] = [] @@ -78,9 +76,7 @@ def marker_dotplot( ) continue - full_output = ( - cluster_data["latest"]["annotation"]["fullOutput"] - ) + full_output = cluster_data["latest"]["annotation"]["fullOutput"] cell_type = full_output["cellType"] categories_order.append(cell_type["label"]) diff --git a/cytetype/preprocessing/subsampling.py b/cytetype/preprocessing/subsampling.py index 8c9ff7c..1ebea4b 100644 --- a/cytetype/preprocessing/subsampling.py +++ b/cytetype/preprocessing/subsampling.py @@ -62,7 +62,9 @@ def subsample_by_group( if n_cells > max_cells_per_group: keep = pd.Series(False, index=adata.obs.index) - sampled = mask[mask].sample(n=max_cells_per_group, random_state=random_state) + sampled = mask[mask].sample( + n=max_cells_per_group, random_state=random_state + ) keep[sampled.index] = True subset = adata[keep] else: From bd92caee96be90311e7ff3dc375d2efbfda08c50 Mon Sep 17 00:00:00 2001 From: parashardhapola Date: Tue, 3 Mar 2026 15:14:28 +0100 Subject: [PATCH 13/19] Update subsampling logic to merge subsets by taking the first occurrence in the preprocessing module - Modified the `subsample_by_group` function to use `merge="first"` when concatenating subsampled subsets, ensuring that the first occurrence of each observation is retained. - This change enhances the subsampling process by providing a more consistent output when merging groups. --- cytetype/preprocessing/subsampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cytetype/preprocessing/subsampling.py b/cytetype/preprocessing/subsampling.py index 1ebea4b..69a919e 100644 --- a/cytetype/preprocessing/subsampling.py +++ b/cytetype/preprocessing/subsampling.py @@ -72,7 +72,7 @@ def subsample_by_group( subsampled.append(subset.to_memory() if is_backed else subset.copy()) - result = anndata.concat(subsampled) + result = anndata.concat(subsampled, merge="first") logger.info(f"Subsampling complete: {adata.n_obs} -> {result.n_obs} cells") From bed8784c4f6c0b6b13aea3bc1e7e2499754ae57b Mon Sep 17 00:00:00 2001 From: parashardhapola Date: Tue, 3 Mar 2026 16:01:24 +0100 Subject: [PATCH 14/19] Enhance gene name processing in preprocessing module - Added `clean_gene_names` function to extract gene symbols from composite gene names, improving the handling of gene identifiers. - Updated `extract_marker_genes` to utilize `clean_gene_names` for better gene name management. - Integrated `clean_gene_names` into the `CyteType` class for consistent gene name processing across the module. - Enhanced logging to provide insights when composite gene values are cleaned. --- cytetype/main.py | 2 + cytetype/preprocessing/__init__.py | 3 +- cytetype/preprocessing/extraction.py | 10 +++- cytetype/preprocessing/validation.py | 74 ++++++++++++++++++++++------ 4 files changed, 70 insertions(+), 19 deletions(-) diff --git a/cytetype/main.py b/cytetype/main.py index 9862367..2600ed9 100644 --- a/cytetype/main.py +++ b/cytetype/main.py @@ -16,6 +16,7 @@ from .preprocessing import ( validate_adata, resolve_gene_symbols_column, + clean_gene_names, aggregate_expression_percentages, extract_marker_genes, aggregate_cluster_metadata, @@ -161,6 +162,7 @@ def __init__( if self.gene_symbols_column is not None else adata.var_names.tolist() ) + gene_names = clean_gene_names(gene_names) self.expression_percentages = aggregate_expression_percentages( adata=adata, clusters=self.clusters, diff --git a/cytetype/preprocessing/__init__.py b/cytetype/preprocessing/__init__.py index abbf71f..d25b339 100644 --- a/cytetype/preprocessing/__init__.py +++ b/cytetype/preprocessing/__init__.py @@ -1,4 +1,4 @@ -from .validation import validate_adata, resolve_gene_symbols_column +from .validation import validate_adata, resolve_gene_symbols_column, clean_gene_names from .aggregation import aggregate_expression_percentages, aggregate_cluster_metadata from .extraction import extract_marker_genes, extract_visualization_coordinates from .marker_detection import rank_genes_groups_backed @@ -7,6 +7,7 @@ __all__ = [ "validate_adata", "resolve_gene_symbols_column", + "clean_gene_names", "aggregate_expression_percentages", "aggregate_cluster_metadata", "extract_marker_genes", diff --git a/cytetype/preprocessing/extraction.py b/cytetype/preprocessing/extraction.py index 79d96f0..44f9c9c 100644 --- a/cytetype/preprocessing/extraction.py +++ b/cytetype/preprocessing/extraction.py @@ -2,6 +2,7 @@ import pandas as pd from ..config import logger +from .validation import _extract_symbol_from_composite, clean_gene_names def extract_marker_genes( @@ -46,9 +47,14 @@ def extract_marker_genes( ) if gene_symbols_col is not None: - gene_ids_to_name = adata.var[gene_symbols_col].to_dict() + raw_map = adata.var[gene_symbols_col].to_dict() + gene_ids_to_name = { + k: _extract_symbol_from_composite(str(v)) for k, v in raw_map.items() + } else: - gene_ids_to_name = dict(zip(adata.var_names, adata.var_names)) + raw_names = adata.var_names.tolist() + cleaned = clean_gene_names(raw_names) + gene_ids_to_name = dict(zip(adata.var_names, cleaned)) markers = {} any_genes_found = False diff --git a/cytetype/preprocessing/validation.py b/cytetype/preprocessing/validation.py index 2eff60c..fce5d5e 100644 --- a/cytetype/preprocessing/validation.py +++ b/cytetype/preprocessing/validation.py @@ -46,6 +46,35 @@ def _has_composite_gene_values(values: list[str]) -> bool: return len(values) > 0 and (composite_count / min(200, len(values))) > 0.5 +def _extract_symbol_from_composite(value: str) -> str: + parts = re.split(r"[_|]", value, maxsplit=1) + if len(parts) != 2: + return value + id_flags = [_is_gene_id_like(p) for p in parts] + if id_flags[0] and not id_flags[1]: + return parts[1] + if not id_flags[0] and id_flags[1]: + return parts[0] + return value + + +def clean_gene_names(names: list[str]) -> list[str]: + """Extract gene symbols from composite gene name/ID values. + + If >50% of values are composite (e.g. ``TSPAN6_ENSG00000000003``), + splits each value and returns the gene-symbol part. Non-composite + lists are returned unchanged. + """ + if not _has_composite_gene_values(names): + return names + cleaned = [_extract_symbol_from_composite(n) for n in names] + logger.info( + f"Cleaned {len(cleaned)} composite gene values " + f"(e.g., '{names[0]}' -> '{cleaned[0]}')." + ) + return cleaned + + def _id_like_percentage(values: list[str]) -> float: if not values: return 100.0 @@ -97,6 +126,12 @@ def resolve_gene_symbols_column( f"Available columns: {list(adata.var.columns)}. " f"Set gene_symbols_column=None for auto-detection." ) + values = adata.var[gene_symbols_column].dropna().astype(str).tolist() + if _has_composite_gene_values(values): + logger.info( + f"Column '{gene_symbols_column}' contains composite gene name/ID values " + f"(e.g., '{values[0]}'). Gene symbols will be extracted automatically." + ) _validate_gene_symbols_column(adata, gene_symbols_column) logger.info(f"Using gene symbols from column '{gene_symbols_column}'.") return gene_symbols_column @@ -119,20 +154,24 @@ def resolve_gene_symbols_column( values = adata.var[col].dropna().astype(str).tolist() if not values: continue - if _has_composite_gene_values(values): - logger.warning( - f"Column '{col}' appears to contain composite gene name/ID values " - f"(e.g., '{values[0]}'). Skipping." - ) - continue - pct = _id_like_percentage(values) - unique_ratio = len(set(values)) / len(values) + score_values = ( + [_extract_symbol_from_composite(v) for v in values] + if _has_composite_gene_values(values) + else values + ) + pct = _id_like_percentage(score_values) + unique_ratio = len(set(score_values)) / len(score_values) candidates.append((col, pct, unique_ratio, 0)) var_names_list = adata.var_names.astype(str).tolist() if var_names_list: - var_id_pct = _id_like_percentage(var_names_list) - var_unique_ratio = len(set(var_names_list)) / len(var_names_list) + var_score_values = ( + [_extract_symbol_from_composite(v) for v in var_names_list] + if _has_composite_gene_values(var_names_list) + else var_names_list + ) + var_id_pct = _id_like_percentage(var_score_values) + var_unique_ratio = len(set(var_score_values)) / len(var_score_values) candidates.append((None, var_id_pct, var_unique_ratio, 1)) for col in adata.var.columns: @@ -144,13 +183,16 @@ def resolve_gene_symbols_column( continue if not values: continue - if _has_composite_gene_values(values): - continue - n_unique = len(set(values)) - if n_unique < max(10, len(values) * 0.05): + score_values = ( + [_extract_symbol_from_composite(v) for v in values] + if _has_composite_gene_values(values) + else values + ) + n_unique = len(set(score_values)) + if n_unique < max(10, len(score_values) * 0.05): continue - pct = _id_like_percentage(values) - unique_ratio = n_unique / len(values) + pct = _id_like_percentage(score_values) + unique_ratio = n_unique / len(score_values) candidates.append((col, pct, unique_ratio, 2)) viable = [c for c in candidates if c[1] < 50] From 2bc6628097b9bbae1401b5bcc501c2ff4ee0e701 Mon Sep 17 00:00:00 2001 From: parashardhapola Date: Tue, 3 Mar 2026 16:06:42 +0100 Subject: [PATCH 15/19] Optimize group statistics accumulation for sparse matrices in marker detection - Enhanced the `_accumulate_group_stats` function to handle both sparse and dense matrix inputs efficiently. - Implemented conditional logic to process sparse matrices using CSR format, improving memory usage and performance. - Maintained existing functionality for dense matrices, ensuring compatibility with previous implementations. --- cytetype/preprocessing/marker_detection.py | 52 ++++++++++++++++------ 1 file changed, 38 insertions(+), 14 deletions(-) diff --git a/cytetype/preprocessing/marker_detection.py b/cytetype/preprocessing/marker_detection.py index 727ae1d..6aa8d1a 100644 --- a/cytetype/preprocessing/marker_detection.py +++ b/cytetype/preprocessing/marker_detection.py @@ -7,6 +7,7 @@ import numpy as np import pandas as pd from natsort import natsorted +import scipy.sparse as sp from scipy.stats import ttest_ind_from_stats from ..config import logger @@ -59,22 +60,45 @@ def _accumulate_group_stats( for start in chunk_iter: end = min(start + cell_batch_size, n_cells) chunk = X[start:end] - if hasattr(chunk, "toarray"): - chunk = chunk.toarray() - chunk = np.asarray(chunk, dtype=np.float64) chunk_labels = cell_group_indices[start:end] - batch_len = end - start - indicator = np.zeros((n_groups, batch_len), dtype=np.float64) - indicator[chunk_labels, np.arange(batch_len)] = 1.0 - - n_ += indicator.sum(axis=1).astype(np.int64) - if sum_ is not None: - sum_ += indicator @ chunk - if sum_sq_ is not None: - sum_sq_ += indicator @ (chunk**2) - if nnz_ is not None: - nnz_ += (indicator @ (chunk != 0).astype(np.float64)).astype(np.int64) + + if sp.issparse(chunk): + chunk = chunk.tocsr() + + ind = sp.csr_matrix( + (np.ones(batch_len, dtype=np.float64), + (chunk_labels, np.arange(batch_len))), + shape=(n_groups, batch_len), + ) + + n_ += np.asarray(ind.sum(axis=1), dtype=np.int64).ravel() + if sum_ is not None: + sum_ += np.asarray((ind @ chunk).toarray(), dtype=np.float64) + if sum_sq_ is not None: + sum_sq_ += np.asarray( + (ind @ chunk.multiply(chunk)).toarray(), dtype=np.float64 + ) + if nnz_ is not None: + binary = chunk.copy() + binary.eliminate_zeros() + binary.data = np.ones_like(binary.data, dtype=np.float64) + nnz_ += np.asarray((ind @ binary).toarray(), dtype=np.int64) + else: + if hasattr(chunk, "toarray"): + chunk = chunk.toarray() + chunk = np.asarray(chunk, dtype=np.float64) + + indicator = np.zeros((n_groups, batch_len), dtype=np.float64) + indicator[chunk_labels, np.arange(batch_len)] = 1.0 + + n_ += indicator.sum(axis=1).astype(np.int64) + if sum_ is not None: + sum_ += indicator @ chunk + if sum_sq_ is not None: + sum_sq_ += indicator @ (chunk ** 2) + if nnz_ is not None: + nnz_ += (indicator @ (chunk != 0).astype(np.float64)).astype(np.int64) return GroupStats(n=n_, nnz=nnz_, sum_=sum_, sum_sq=sum_sq_) From 9a5b89dfedc3005033dc01e9b2b8081cf3c27231 Mon Sep 17 00:00:00 2001 From: parashardhapola Date: Tue, 3 Mar 2026 16:07:13 +0100 Subject: [PATCH 16/19] Increase default timeout for file uploads in CyteType - Updated the timeout settings in both `main.py` and `client.py` from 30 seconds to 60 seconds to allow for longer upload durations, improving reliability for larger files. --- cytetype/api/client.py | 4 ++-- cytetype/main.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cytetype/api/client.py b/cytetype/api/client.py index 000db3d..a6280b3 100644 --- a/cytetype/api/client.py +++ b/cytetype/api/client.py @@ -162,7 +162,7 @@ def upload_obs_duckdb( base_url: str, auth_token: str | None, file_path: str, - timeout: float | tuple[float, float] = (30.0, 3600.0), + timeout: float | tuple[float, float] = (60.0, 3600.0), max_workers: int = 4, ) -> UploadResponse: return _upload_file( @@ -179,7 +179,7 @@ def upload_vars_h5( base_url: str, auth_token: str | None, file_path: str, - timeout: float | tuple[float, float] = (30.0, 3600.0), + timeout: float | tuple[float, float] = (60.0, 3600.0), max_workers: int = 4, ) -> UploadResponse: return _upload_file( diff --git a/cytetype/main.py b/cytetype/main.py index 2600ed9..8c69ec8 100644 --- a/cytetype/main.py +++ b/cytetype/main.py @@ -300,7 +300,7 @@ def _upload_artifacts( """ uploaded: dict[str, str] = {} errors: list[tuple[str, Exception]] = list(self._artifact_build_errors) - timeout = (30.0, float(upload_timeout_seconds)) + timeout = (60.0, float(upload_timeout_seconds)) # --- vars.h5 upload --- if self._vars_h5_path is not None: From 1a6a3477ca80340011b9c34490ffcf282a6394ba Mon Sep 17 00:00:00 2001 From: parashardhapola Date: Tue, 3 Mar 2026 16:09:52 +0100 Subject: [PATCH 17/19] fomatted --- cytetype/preprocessing/marker_detection.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/cytetype/preprocessing/marker_detection.py b/cytetype/preprocessing/marker_detection.py index 6aa8d1a..264d607 100644 --- a/cytetype/preprocessing/marker_detection.py +++ b/cytetype/preprocessing/marker_detection.py @@ -67,8 +67,10 @@ def _accumulate_group_stats( chunk = chunk.tocsr() ind = sp.csr_matrix( - (np.ones(batch_len, dtype=np.float64), - (chunk_labels, np.arange(batch_len))), + ( + np.ones(batch_len, dtype=np.float64), + (chunk_labels, np.arange(batch_len)), + ), shape=(n_groups, batch_len), ) @@ -96,7 +98,7 @@ def _accumulate_group_stats( if sum_ is not None: sum_ += indicator @ chunk if sum_sq_ is not None: - sum_sq_ += indicator @ (chunk ** 2) + sum_sq_ += indicator @ (chunk**2) if nnz_ is not None: nnz_ += (indicator @ (chunk != 0).astype(np.float64)).astype(np.int64) From c24c9a48bc5905ae8c101bcc4ada1ec3bfd1a72c Mon Sep 17 00:00:00 2001 From: parashardhapola Date: Tue, 3 Mar 2026 16:10:34 +0100 Subject: [PATCH 18/19] Refactor subsampling logic in `_is_integer_valued` function to improve row selection - Updated the logic to select rows for sampling based on the number of rows in the input matrix. - Implemented random sampling when the number of rows exceeds the specified sample size, ensuring a more representative subset. - Maintained functionality for cases where the number of rows is less than or equal to the sample size. --- cytetype/core/artifacts.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/cytetype/core/artifacts.py b/cytetype/core/artifacts.py index ee4386c..fd08aab 100644 --- a/cytetype/core/artifacts.py +++ b/cytetype/core/artifacts.py @@ -118,8 +118,13 @@ def _is_integer_valued(mat: Any, sample_n_rows: int = 200) -> bool: return True n_rows = mat.shape[0] - row_end = min(sample_n_rows, n_rows) - chunk = mat[:row_end] + if n_rows <= sample_n_rows: + row_indices = np.arange(n_rows) + else: + rng = np.random.default_rng(42) + row_indices = np.sort(rng.choice(n_rows, size=sample_n_rows, replace=False)) + + chunk = mat[row_indices] if sp.issparse(chunk): sample = chunk.data From ef4e7bfb958a5083b3c15a098e5aa04dfb5396d7 Mon Sep 17 00:00:00 2001 From: parashardhapola Date: Tue, 3 Mar 2026 16:14:38 +0100 Subject: [PATCH 19/19] Update public API in `__init__.py` to include new plotting and subsampling functions - Added `marker_dotplot` and `subsample_by_group` to the `__all__` list, making them accessible for import. - This change enhances the module's functionality by exposing additional features for users. --- cytetype/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cytetype/__init__.py b/cytetype/__init__.py index 3a771d4..bbb4588 100644 --- a/cytetype/__init__.py +++ b/cytetype/__init__.py @@ -4,9 +4,11 @@ from .config import logger from .main import CyteType +from .plotting import marker_dotplot from .preprocessing.marker_detection import rank_genes_groups_backed +from .preprocessing.subsampling import subsample_by_group -__all__ = ["CyteType", "rank_genes_groups_backed"] +__all__ = ["CyteType", "marker_dotplot", "rank_genes_groups_backed", "subsample_by_group"] _PYPI_JSON_URL = "https://pypi.org/pypi/cytetype/json"