diff --git a/cytetype/__init__.py b/cytetype/__init__.py index 37dfad9..72ad186 100644 --- a/cytetype/__init__.py +++ b/cytetype/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.19.2" +__version__ = "0.19.3" import requests diff --git a/cytetype/core/artifacts.py b/cytetype/core/artifacts.py index c589fca..d508a99 100644 --- a/cytetype/core/artifacts.py +++ b/cytetype/core/artifacts.py @@ -40,6 +40,7 @@ def _write_var_metadata( n_cols: int, var_df: pd.DataFrame, var_names: pd.Index | Sequence[Any] | None, + gene_symbols_column: str | None = None, ) -> None: if len(var_df) != n_cols: raise ValueError( @@ -68,6 +69,8 @@ def _write_var_metadata( data=_as_string_values(var_df.index), dtype=text_dtype, ) + if gene_symbols_column is not None: + var_group.attrs["gene_symbols_column"] = gene_symbols_column columns_group = var_group.create_group("columns") for i, col_name in enumerate(var_df.columns): @@ -414,6 +417,7 @@ def save_features_matrix( mat: Any, var_df: pd.DataFrame | None = None, var_names: pd.Index | Sequence[Any] | None = None, + gene_symbols_column: str | None = None, raw_mat: Any | None = None, raw_col_indices: "np.ndarray | None" = None, raw_cell_batch: int = 2000, @@ -454,6 +458,7 @@ def save_features_matrix( n_cols=n_cols, var_df=var_df, var_names=var_names, + gene_symbols_column=gene_symbols_column, ) if raw_mat is not None: diff --git a/cytetype/main.py b/cytetype/main.py index be2ea27..55b577a 100644 --- a/cytetype/main.py +++ b/cytetype/main.py @@ -16,12 +16,12 @@ from .preprocessing import ( validate_adata, resolve_gene_symbols_column, - clean_gene_names, aggregate_expression_percentages, extract_marker_genes, aggregate_cluster_metadata, extract_visualization_coordinates, ) +from .preprocessing.validation import materialize_canonical_gene_symbols_column from .core.payload import build_annotation_payload, save_query_to_file from .core.artifacts import ( _is_integer_valued, @@ -140,136 +140,163 @@ def __init__( self.api_url = api_url self.auth_token = auth_token self._artifact_build_errors: list[tuple[str, Exception]] = [] + self._vars_h5_path: str | None = None + self._obs_duckdb_path: str | None = None + self._original_gene_symbols_column: str | None = None + self._temporary_gene_symbols_column: str | None = None - self.gene_symbols_column = resolve_gene_symbols_column( - adata, gene_symbols_column - ) - - self.coordinates_key = validate_adata( - adata, group_key, rank_key, self.gene_symbols_column, coordinates_key - ) + try: + self.gene_symbols_column = resolve_gene_symbols_column( + adata, gene_symbols_column + ) + self._original_gene_symbols_column = self.gene_symbols_column - # Use original labels as IDs if all are short (<=3 chars), otherwise enumerate - _unique_group_categories: list[str | int] = natsorted( - adata.obs[group_key].unique().tolist() - ) - _short_ids = all(len(str(x)) <= 3 for x in _unique_group_categories) - self.cluster_map = { - str(x): str(x) if _short_ids else str(n) - for n, x in enumerate(_unique_group_categories) - } - self.clusters = [ - self.cluster_map[str(x)] for x in adata.obs[group_key].values.tolist() - ] - - gene_names = ( - adata.var[self.gene_symbols_column].tolist() - if self.gene_symbols_column is not None - else adata.var_names.tolist() - ) - gene_names = clean_gene_names(gene_names) - self.expression_percentages = aggregate_expression_percentages( - adata=adata, - clusters=self.clusters, - gene_names=gene_names, - cell_batch_size=pcent_batch_size, - ) + self.coordinates_key = validate_adata( + adata, group_key, rank_key, self.gene_symbols_column, coordinates_key + ) + ( + self.gene_symbols_column, + self._original_gene_symbols_column, + ) = materialize_canonical_gene_symbols_column( + adata, self.gene_symbols_column + ) + self._temporary_gene_symbols_column = self.gene_symbols_column - logger.info("Extracting marker genes...") - self.marker_genes = extract_marker_genes( - adata=self.adata, - cell_group_key=self.group_key, - rank_genes_key=self.rank_key, - cluster_map=self.cluster_map, - n_top_genes=n_top_genes, - gene_symbols_col=self.gene_symbols_column, - ) + # Use original labels as IDs if all are short (<=3 chars), otherwise enumerate + _unique_group_categories: list[str | int] = natsorted( + adata.obs[group_key].unique().tolist() + ) + _short_ids = all(len(str(x)) <= 3 for x in _unique_group_categories) + self.cluster_map = { + str(x): str(x) if _short_ids else str(n) + for n, x in enumerate(_unique_group_categories) + } + self.clusters = [ + self.cluster_map[str(x)] for x in adata.obs[group_key].values.tolist() + ] + + gene_names = adata.var[self.gene_symbols_column].tolist() + self.expression_percentages = aggregate_expression_percentages( + adata=adata, + clusters=self.clusters, + gene_names=gene_names, + cell_batch_size=pcent_batch_size, + ) - if aggregate_metadata: - logger.info("Aggregating cluster metadata...") - self.group_metadata = aggregate_cluster_metadata( + logger.info("Extracting marker genes...") + self.marker_genes = extract_marker_genes( adata=self.adata, - group_key=self.group_key, - min_percentage=min_percentage, - max_categories=max_metadata_categories, + cell_group_key=self.group_key, + rank_genes_key=self.rank_key, + cluster_map=self.cluster_map, + n_top_genes=n_top_genes, + gene_symbols_col=self.gene_symbols_column, ) - # Replace keys in group_metadata using cluster_map - self.group_metadata = { - self.cluster_map.get(str(key), str(key)): value - for key, value in self.group_metadata.items() - } - self.group_metadata = { - k: self.group_metadata[k] for k in sorted(self.group_metadata.keys()) + + if aggregate_metadata: + logger.info("Aggregating cluster metadata...") + self.group_metadata = aggregate_cluster_metadata( + adata=self.adata, + group_key=self.group_key, + min_percentage=min_percentage, + max_categories=max_metadata_categories, + ) + # Replace keys in group_metadata using cluster_map + self.group_metadata = { + self.cluster_map.get(str(key), str(key)): value + for key, value in self.group_metadata.items() + } + self.group_metadata = { + k: self.group_metadata[k] + for k in sorted(self.group_metadata.keys()) + } + else: + self.group_metadata = {} + + # Prepare visualization data with sampling + sampled_coordinates, sampled_cluster_labels = ( + extract_visualization_coordinates( + adata=adata, + coordinates_key=self.coordinates_key, + group_key=self.group_key, + cluster_map=self.cluster_map, + max_cells_per_group=self.max_cells_per_group, + ) + ) + + self.visualization_data = { + "coordinates": sampled_coordinates, + "clusters": sampled_cluster_labels, } - else: - self.group_metadata = {} - - # Prepare visualization data with sampling - sampled_coordinates, sampled_cluster_labels = extract_visualization_coordinates( - adata=adata, - coordinates_key=self.coordinates_key, - group_key=self.group_key, - cluster_map=self.cluster_map, - max_cells_per_group=self.max_cells_per_group, - ) - self.visualization_data = { - "coordinates": sampled_coordinates, - "clusters": sampled_cluster_labels, - } + # Resolve raw counts once and cache + self._raw_counts_result = self._resolve_raw_counts() + if self._raw_counts_result is None: + logger.warning( + "No integer raw counts found in adata.layers['counts'], " + "adata.raw.X, or adata.X. Skipping raw counts in vars.h5." + ) - # Resolve raw counts once and cache - self._raw_counts_result = self._resolve_raw_counts() - if self._raw_counts_result is None: - logger.warning( - "No integer raw counts found in adata.layers['counts'], " - "adata.raw.X, or adata.X. Skipping raw counts in vars.h5." - ) + # Build vars.h5 + try: + raw_mat, raw_col_indices = ( + self._raw_counts_result + if self._raw_counts_result is not None + else (None, None) + ) + save_features_matrix( + out_file=vars_h5_path, + mat=self.adata.X, + var_df=self.adata.var, + var_names=self.adata.var_names, + raw_mat=raw_mat, + raw_col_indices=raw_col_indices, + gene_symbols_column=self.gene_symbols_column, + ) + sys.stderr.flush() + self._vars_h5_path = vars_h5_path + except Exception as exc: + logger.warning(f"vars.h5 artifact failed during build: {exc}") + self._artifact_build_errors.append(("vars_h5", exc)) - # Build vars.h5 - try: - raw_mat, raw_col_indices = ( - self._raw_counts_result - if self._raw_counts_result is not None - else (None, None) - ) - save_features_matrix( - out_file=vars_h5_path, - mat=self.adata.X, - var_df=self.adata.var, - var_names=self.adata.var_names, - raw_mat=raw_mat, - raw_col_indices=raw_col_indices, - ) - sys.stderr.flush() - self._vars_h5_path: str | None = vars_h5_path - except Exception as exc: - logger.warning(f"vars.h5 artifact failed during build: {exc}") - self._vars_h5_path = None - self._artifact_build_errors.append(("vars_h5", exc)) - - # Build obs.duckdb - try: - logger.info("Writing obs data to duckdb artifact...") - obsm_coordinates = ( - self.adata.obsm[self.coordinates_key] - if self.coordinates_key and self.coordinates_key in self.adata.obsm - else None - ) - save_obs_duckdb_file( - out_file=obs_duckdb_path, - obs_df=self.adata.obs, - obsm_coordinates=obsm_coordinates, - coordinates_key=self.coordinates_key, + # Build obs.duckdb + try: + logger.info("Writing obs data to duckdb artifact...") + obsm_coordinates = ( + self.adata.obsm[self.coordinates_key] + if self.coordinates_key and self.coordinates_key in self.adata.obsm + else None + ) + save_obs_duckdb_file( + out_file=obs_duckdb_path, + obs_df=self.adata.obs, + obsm_coordinates=obsm_coordinates, + coordinates_key=self.coordinates_key, + ) + sys.stderr.flush() + self._obs_duckdb_path = obs_duckdb_path + except Exception as exc: + logger.warning(f"obs.duckdb artifact failed during build: {exc}") + self._artifact_build_errors.append(("obs_duckdb", exc)) + + logger.info("Data preparation completed. Ready for submitting jobs.") + except Exception: + self._cleanup_temporary_gene_symbols_column() + raise + + def _cleanup_temporary_gene_symbols_column(self) -> None: + temp_column = self._temporary_gene_symbols_column + if temp_column is None: + return + + if temp_column in self.adata.var.columns: + del self.adata.var[temp_column] + logger.info( + f"Deleted temporary canonical gene-symbol column '{temp_column}'." ) - sys.stderr.flush() - self._obs_duckdb_path: str | None = obs_duckdb_path - except Exception as exc: - logger.warning(f"obs.duckdb artifact failed during build: {exc}") - self._obs_duckdb_path = None - self._artifact_build_errors.append(("obs_duckdb", exc)) - logger.info("Data preparation completed. Ready for submitting jobs.") + self.gene_symbols_column = self._original_gene_symbols_column + self._temporary_gene_symbols_column = None def _resolve_raw_counts( self, @@ -356,7 +383,8 @@ def cleanup(self) -> None: """Delete the artifact files built during initialization. Call this after run() completes to remove the vars.h5 and obs.duckdb - files from disk. Paths are cleared so repeated calls are safe. + files from disk and drop the temporary canonical gene-symbol column. + Paths are cleared so repeated calls are safe. """ for attr, path in [ ("_vars_h5_path", self._vars_h5_path), @@ -370,6 +398,8 @@ def cleanup(self) -> None: logger.warning(f"Failed to delete artifact {path}: {exc}") setattr(self, attr, None) + self._cleanup_temporary_gene_symbols_column() + def run( self, study_context: str, diff --git a/cytetype/preprocessing/validation.py b/cytetype/preprocessing/validation.py index 075689f..bbede86 100644 --- a/cytetype/preprocessing/validation.py +++ b/cytetype/preprocessing/validation.py @@ -13,6 +13,7 @@ "gene_name", "symbol", ] +_CANONICAL_GENE_SYMBOLS_COLUMN = "__cytetype_gene_symbols" def _is_gene_id_like(value: str) -> bool: @@ -77,6 +78,38 @@ def clean_gene_names(names: list[str]) -> list[str]: return cleaned +def _temporary_gene_symbols_column_name(adata: anndata.AnnData) -> str: + candidate = _CANONICAL_GENE_SYMBOLS_COLUMN + suffix = 1 + while candidate in adata.var.columns: + candidate = f"{_CANONICAL_GENE_SYMBOLS_COLUMN}_{suffix}" + suffix += 1 + return candidate + + +def materialize_canonical_gene_symbols_column( + adata: anndata.AnnData, gene_symbols_column: str | None +) -> tuple[str, str | None]: + """Create a temporary canonical gene-symbol column in ``adata.var``.""" + if gene_symbols_column is None: + source_values = adata.var_names.astype(str).tolist() + source_name = "adata.var_names" + else: + source_values = [ + str(value) + for value in adata.var[gene_symbols_column].astype("string").fillna("") + ] + source_name = f"column '{gene_symbols_column}'" + + canonical_column = _temporary_gene_symbols_column_name(adata) + adata.var[canonical_column] = clean_gene_names(source_values) + logger.info( + f"Materialized canonical gene symbols in temporary column '{canonical_column}' " + f"from {source_name}." + ) + return canonical_column, gene_symbols_column + + def _id_like_percentage(values: list[str], seed: int = 42) -> float: if not values: return 100.0 diff --git a/tests/test_artifacts.py b/tests/test_artifacts.py index 5cdf666..6bde248 100644 --- a/tests/test_artifacts.py +++ b/tests/test_artifacts.py @@ -19,6 +19,7 @@ def test_save_features_matrix_writes_var_metadata( mat=mock_adata.X, var_df=mock_adata.var, var_names=mock_adata.var_names, + gene_symbols_column="gene_symbols", col_batch=10, ) @@ -29,6 +30,7 @@ def test_save_features_matrix_writes_var_metadata( assert "columns" in f["info/var"] assert len(f["info/var/var_names"]) == mock_adata.n_vars assert len(f["info/var/index"]) == mock_adata.n_vars + assert f["info/var"].attrs["gene_symbols_column"] == "gene_symbols" columns_group = f["info/var/columns"] assert len(columns_group.keys()) == mock_adata.var.shape[1] @@ -37,6 +39,43 @@ def test_save_features_matrix_writes_var_metadata( assert "source_dtype" in dataset.attrs +def test_save_features_matrix_omits_gene_symbols_attr_when_not_provided( + tmp_path: Path, + mock_adata: anndata.AnnData, +) -> None: + out_path = tmp_path / "vars.h5" + + save_features_matrix( + out_file=str(out_path), + mat=mock_adata.X, + var_df=mock_adata.var, + var_names=mock_adata.var_names, + gene_symbols_column=None, + col_batch=10, + ) + + with h5py.File(out_path, "r") as f: + assert "gene_symbols_column" not in f["info/var"].attrs + + +def test_save_features_matrix_omits_gene_symbols_attr_when_omitted( + tmp_path: Path, + mock_adata: anndata.AnnData, +) -> None: + out_path = tmp_path / "vars.h5" + + save_features_matrix( + out_file=str(out_path), + mat=mock_adata.X, + var_df=mock_adata.var, + var_names=mock_adata.var_names, + col_batch=10, + ) + + with h5py.File(out_path, "r") as f: + assert "gene_symbols_column" not in f["info/var"].attrs + + def test_save_features_matrix_writes_raw_group( tmp_path: Path, mock_adata: anndata.AnnData, diff --git a/tests/test_cytetype_integration.py b/tests/test_cytetype_integration.py index 0b6993c..d7dfa12 100644 --- a/tests/test_cytetype_integration.py +++ b/tests/test_cytetype_integration.py @@ -6,6 +6,7 @@ from pydantic import ValidationError from typing import Any import anndata +import scanpy as sc from cytetype import CyteType from cytetype.api.exceptions import RateLimitError, AuthenticationError @@ -31,7 +32,13 @@ def test_cytetype_initialization(mock_adata: anndata.AnnData) -> None: assert ct.adata is mock_adata assert ct.group_key == "leiden" assert ct.rank_key == "rank_genes_groups" - assert ct.gene_symbols_column == "gene_symbols" + assert ct.gene_symbols_column is not None + assert ct.gene_symbols_column.startswith("__cytetype_gene_symbols") + assert ct.gene_symbols_column in mock_adata.var.columns + assert ( + mock_adata.var[ct.gene_symbols_column].tolist() + == mock_adata.var["gene_symbols"].tolist() + ) # Verify data preparation completed assert len(ct.clusters) == len(mock_adata) @@ -77,6 +84,57 @@ def _upload_vars(*args: Any, **kwargs: Any) -> UploadResponse: monkeypatch.setattr("cytetype.main.upload_vars_h5_file", _upload_vars) +def test_cytetype_materializes_canonical_column_from_composite_source( + mock_adata: anndata.AnnData, monkeypatch: pytest.MonkeyPatch +) -> None: + captured: dict[str, Any] = {} + original_var_names = mock_adata.var_names.tolist() + mock_adata.var["gene_symbols"] = [ + f"GENE{i}|{var_name}" for i, var_name in enumerate(original_var_names) + ] + + def _save_vars(*args: Any, **kwargs: Any) -> None: + captured.update(kwargs) + + monkeypatch.setattr("cytetype.main.save_features_matrix", _save_vars) + + ct = CyteType(mock_adata, group_key="leiden") + + assert ct.gene_symbols_column is not None + assert ct.gene_symbols_column.startswith("__cytetype_gene_symbols") + assert mock_adata.var[ct.gene_symbols_column].tolist() == [ + f"GENE{i}" for i in range(mock_adata.n_vars) + ] + assert captured["gene_symbols_column"] == ct.gene_symbols_column + + +def test_cytetype_materializes_canonical_column_from_composite_var_names( + mock_adata: anndata.AnnData, monkeypatch: pytest.MonkeyPatch +) -> None: + captured: dict[str, Any] = {} + original_var_names = mock_adata.var_names.tolist() + expected_gene_names = [f"GENE{i}" for i in range(mock_adata.n_vars)] + + mock_adata.var = mock_adata.var.drop(columns=["gene_symbols"]) + mock_adata.var_names = [ + f"{var_name}|{gene_name}" + for var_name, gene_name in zip(original_var_names, expected_gene_names) + ] + sc.tl.rank_genes_groups(mock_adata, "leiden", method="t-test") + + def _save_vars(*args: Any, **kwargs: Any) -> None: + captured.update(kwargs) + + monkeypatch.setattr("cytetype.main.save_features_matrix", _save_vars) + + ct = CyteType(mock_adata, group_key="leiden") + + assert ct.gene_symbols_column is not None + assert ct.gene_symbols_column.startswith("__cytetype_gene_symbols") + assert mock_adata.var[ct.gene_symbols_column].tolist() == expected_gene_names + assert captured["gene_symbols_column"] == ct.gene_symbols_column + + @patch("cytetype.main.wait_for_completion") @patch("cytetype.main.submit_annotation_job") def test_cytetype_run_success( @@ -241,6 +299,9 @@ def _save_obs(*args: Any, **kwargs: Any) -> None: assert vars_path.exists() assert obs_path.exists() + assert ct.gene_symbols_column is not None + assert ct.gene_symbols_column.startswith("__cytetype_gene_symbols") + assert ct.gene_symbols_column in ct.adata.var.columns ct.cleanup() @@ -248,6 +309,29 @@ def _save_obs(*args: Any, **kwargs: Any) -> None: assert not obs_path.exists() assert ct._vars_h5_path is None assert ct._obs_duckdb_path is None + assert "__cytetype_gene_symbols" not in ct.adata.var.columns + assert not any( + col.startswith("__cytetype_gene_symbols") for col in ct.adata.var.columns + ) + assert ct.gene_symbols_column == "gene_symbols" + + +def test_cytetype_init_failure_rolls_back_temporary_gene_symbols_column( + mock_adata: anndata.AnnData, monkeypatch: pytest.MonkeyPatch +) -> None: + original_columns = mock_adata.var.columns.tolist() + monkeypatch.setattr( + "cytetype.main.aggregate_expression_percentages", + MagicMock(side_effect=RuntimeError("aggregation failed")), + ) + + with pytest.raises(RuntimeError, match="aggregation failed"): + CyteType(mock_adata, group_key="leiden") + + assert mock_adata.var.columns.tolist() == original_columns + assert not any( + col.startswith("__cytetype_gene_symbols") for col in mock_adata.var.columns + ) @patch("cytetype.main.wait_for_completion")