Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cytetype/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.19.2"
__version__ = "0.19.3"

import requests

Expand Down
5 changes: 5 additions & 0 deletions cytetype/core/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def _write_var_metadata(
n_cols: int,
var_df: pd.DataFrame,
var_names: pd.Index | Sequence[Any] | None,
gene_symbols_column: str | None = None,
) -> None:
if len(var_df) != n_cols:
raise ValueError(
Expand Down Expand Up @@ -68,6 +69,8 @@ def _write_var_metadata(
data=_as_string_values(var_df.index),
dtype=text_dtype,
)
if gene_symbols_column is not None:
var_group.attrs["gene_symbols_column"] = gene_symbols_column

columns_group = var_group.create_group("columns")
for i, col_name in enumerate(var_df.columns):
Expand Down Expand Up @@ -414,6 +417,7 @@ def save_features_matrix(
mat: Any,
var_df: pd.DataFrame | None = None,
var_names: pd.Index | Sequence[Any] | None = None,
gene_symbols_column: str | None = None,
raw_mat: Any | None = None,
raw_col_indices: "np.ndarray | None" = None,
raw_cell_batch: int = 2000,
Expand Down Expand Up @@ -454,6 +458,7 @@ def save_features_matrix(
n_cols=n_cols,
var_df=var_df,
var_names=var_names,
gene_symbols_column=gene_symbols_column,
)

if raw_mat is not None:
Expand Down
270 changes: 150 additions & 120 deletions cytetype/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
from .preprocessing import (
validate_adata,
resolve_gene_symbols_column,
clean_gene_names,
aggregate_expression_percentages,
extract_marker_genes,
aggregate_cluster_metadata,
extract_visualization_coordinates,
)
from .preprocessing.validation import materialize_canonical_gene_symbols_column
from .core.payload import build_annotation_payload, save_query_to_file
from .core.artifacts import (
_is_integer_valued,
Expand Down Expand Up @@ -140,136 +140,163 @@ def __init__(
self.api_url = api_url
self.auth_token = auth_token
self._artifact_build_errors: list[tuple[str, Exception]] = []
self._vars_h5_path: str | None = None
self._obs_duckdb_path: str | None = None
self._original_gene_symbols_column: str | None = None
self._temporary_gene_symbols_column: str | None = None

self.gene_symbols_column = resolve_gene_symbols_column(
adata, gene_symbols_column
)

self.coordinates_key = validate_adata(
adata, group_key, rank_key, self.gene_symbols_column, coordinates_key
)
try:
self.gene_symbols_column = resolve_gene_symbols_column(
adata, gene_symbols_column
)
self._original_gene_symbols_column = self.gene_symbols_column

# Use original labels as IDs if all are short (<=3 chars), otherwise enumerate
_unique_group_categories: list[str | int] = natsorted(
adata.obs[group_key].unique().tolist()
)
_short_ids = all(len(str(x)) <= 3 for x in _unique_group_categories)
self.cluster_map = {
str(x): str(x) if _short_ids else str(n)
for n, x in enumerate(_unique_group_categories)
}
self.clusters = [
self.cluster_map[str(x)] for x in adata.obs[group_key].values.tolist()
]

gene_names = (
adata.var[self.gene_symbols_column].tolist()
if self.gene_symbols_column is not None
else adata.var_names.tolist()
)
gene_names = clean_gene_names(gene_names)
self.expression_percentages = aggregate_expression_percentages(
adata=adata,
clusters=self.clusters,
gene_names=gene_names,
cell_batch_size=pcent_batch_size,
)
self.coordinates_key = validate_adata(
adata, group_key, rank_key, self.gene_symbols_column, coordinates_key
)
(
self.gene_symbols_column,
self._original_gene_symbols_column,
) = materialize_canonical_gene_symbols_column(
adata, self.gene_symbols_column
)
self._temporary_gene_symbols_column = self.gene_symbols_column

logger.info("Extracting marker genes...")
self.marker_genes = extract_marker_genes(
adata=self.adata,
cell_group_key=self.group_key,
rank_genes_key=self.rank_key,
cluster_map=self.cluster_map,
n_top_genes=n_top_genes,
gene_symbols_col=self.gene_symbols_column,
)
# Use original labels as IDs if all are short (<=3 chars), otherwise enumerate
_unique_group_categories: list[str | int] = natsorted(
adata.obs[group_key].unique().tolist()
)
_short_ids = all(len(str(x)) <= 3 for x in _unique_group_categories)
self.cluster_map = {
str(x): str(x) if _short_ids else str(n)
for n, x in enumerate(_unique_group_categories)
}
self.clusters = [
self.cluster_map[str(x)] for x in adata.obs[group_key].values.tolist()
]

gene_names = adata.var[self.gene_symbols_column].tolist()
self.expression_percentages = aggregate_expression_percentages(
adata=adata,
clusters=self.clusters,
gene_names=gene_names,
cell_batch_size=pcent_batch_size,
)

if aggregate_metadata:
logger.info("Aggregating cluster metadata...")
self.group_metadata = aggregate_cluster_metadata(
logger.info("Extracting marker genes...")
self.marker_genes = extract_marker_genes(
adata=self.adata,
group_key=self.group_key,
min_percentage=min_percentage,
max_categories=max_metadata_categories,
cell_group_key=self.group_key,
rank_genes_key=self.rank_key,
cluster_map=self.cluster_map,
n_top_genes=n_top_genes,
gene_symbols_col=self.gene_symbols_column,
)
# Replace keys in group_metadata using cluster_map
self.group_metadata = {
self.cluster_map.get(str(key), str(key)): value
for key, value in self.group_metadata.items()
}
self.group_metadata = {
k: self.group_metadata[k] for k in sorted(self.group_metadata.keys())

if aggregate_metadata:
logger.info("Aggregating cluster metadata...")
self.group_metadata = aggregate_cluster_metadata(
adata=self.adata,
group_key=self.group_key,
min_percentage=min_percentage,
max_categories=max_metadata_categories,
)
# Replace keys in group_metadata using cluster_map
self.group_metadata = {
self.cluster_map.get(str(key), str(key)): value
for key, value in self.group_metadata.items()
}
self.group_metadata = {
k: self.group_metadata[k]
for k in sorted(self.group_metadata.keys())
}
else:
self.group_metadata = {}

# Prepare visualization data with sampling
sampled_coordinates, sampled_cluster_labels = (
extract_visualization_coordinates(
adata=adata,
coordinates_key=self.coordinates_key,
group_key=self.group_key,
cluster_map=self.cluster_map,
max_cells_per_group=self.max_cells_per_group,
)
)

self.visualization_data = {
"coordinates": sampled_coordinates,
"clusters": sampled_cluster_labels,
}
else:
self.group_metadata = {}

# Prepare visualization data with sampling
sampled_coordinates, sampled_cluster_labels = extract_visualization_coordinates(
adata=adata,
coordinates_key=self.coordinates_key,
group_key=self.group_key,
cluster_map=self.cluster_map,
max_cells_per_group=self.max_cells_per_group,
)

self.visualization_data = {
"coordinates": sampled_coordinates,
"clusters": sampled_cluster_labels,
}
# Resolve raw counts once and cache
self._raw_counts_result = self._resolve_raw_counts()
if self._raw_counts_result is None:
logger.warning(
"No integer raw counts found in adata.layers['counts'], "
"adata.raw.X, or adata.X. Skipping raw counts in vars.h5."
)

# Resolve raw counts once and cache
self._raw_counts_result = self._resolve_raw_counts()
if self._raw_counts_result is None:
logger.warning(
"No integer raw counts found in adata.layers['counts'], "
"adata.raw.X, or adata.X. Skipping raw counts in vars.h5."
)
# Build vars.h5
try:
raw_mat, raw_col_indices = (
self._raw_counts_result
if self._raw_counts_result is not None
else (None, None)
)
save_features_matrix(
out_file=vars_h5_path,
mat=self.adata.X,
var_df=self.adata.var,
var_names=self.adata.var_names,
raw_mat=raw_mat,
raw_col_indices=raw_col_indices,
gene_symbols_column=self.gene_symbols_column,
)
sys.stderr.flush()
self._vars_h5_path = vars_h5_path
except Exception as exc:
logger.warning(f"vars.h5 artifact failed during build: {exc}")
self._artifact_build_errors.append(("vars_h5", exc))

# Build vars.h5
try:
raw_mat, raw_col_indices = (
self._raw_counts_result
if self._raw_counts_result is not None
else (None, None)
)
save_features_matrix(
out_file=vars_h5_path,
mat=self.adata.X,
var_df=self.adata.var,
var_names=self.adata.var_names,
raw_mat=raw_mat,
raw_col_indices=raw_col_indices,
)
sys.stderr.flush()
self._vars_h5_path: str | None = vars_h5_path
except Exception as exc:
logger.warning(f"vars.h5 artifact failed during build: {exc}")
self._vars_h5_path = None
self._artifact_build_errors.append(("vars_h5", exc))

# Build obs.duckdb
try:
logger.info("Writing obs data to duckdb artifact...")
obsm_coordinates = (
self.adata.obsm[self.coordinates_key]
if self.coordinates_key and self.coordinates_key in self.adata.obsm
else None
)
save_obs_duckdb_file(
out_file=obs_duckdb_path,
obs_df=self.adata.obs,
obsm_coordinates=obsm_coordinates,
coordinates_key=self.coordinates_key,
# Build obs.duckdb
try:
logger.info("Writing obs data to duckdb artifact...")
obsm_coordinates = (
self.adata.obsm[self.coordinates_key]
if self.coordinates_key and self.coordinates_key in self.adata.obsm
else None
)
save_obs_duckdb_file(
out_file=obs_duckdb_path,
obs_df=self.adata.obs,
obsm_coordinates=obsm_coordinates,
coordinates_key=self.coordinates_key,
)
sys.stderr.flush()
self._obs_duckdb_path = obs_duckdb_path
except Exception as exc:
logger.warning(f"obs.duckdb artifact failed during build: {exc}")
self._artifact_build_errors.append(("obs_duckdb", exc))

logger.info("Data preparation completed. Ready for submitting jobs.")
except Exception:
self._cleanup_temporary_gene_symbols_column()
raise

def _cleanup_temporary_gene_symbols_column(self) -> None:
temp_column = self._temporary_gene_symbols_column
if temp_column is None:
return

if temp_column in self.adata.var.columns:
del self.adata.var[temp_column]
logger.info(
f"Deleted temporary canonical gene-symbol column '{temp_column}'."
)
sys.stderr.flush()
self._obs_duckdb_path: str | None = obs_duckdb_path
except Exception as exc:
logger.warning(f"obs.duckdb artifact failed during build: {exc}")
self._obs_duckdb_path = None
self._artifact_build_errors.append(("obs_duckdb", exc))

logger.info("Data preparation completed. Ready for submitting jobs.")
self.gene_symbols_column = self._original_gene_symbols_column
self._temporary_gene_symbols_column = None

def _resolve_raw_counts(
self,
Expand Down Expand Up @@ -356,7 +383,8 @@ def cleanup(self) -> None:
"""Delete the artifact files built during initialization.

Call this after run() completes to remove the vars.h5 and obs.duckdb
files from disk. Paths are cleared so repeated calls are safe.
files from disk and drop the temporary canonical gene-symbol column.
Paths are cleared so repeated calls are safe.
"""
for attr, path in [
("_vars_h5_path", self._vars_h5_path),
Expand All @@ -370,6 +398,8 @@ def cleanup(self) -> None:
logger.warning(f"Failed to delete artifact {path}: {exc}")
setattr(self, attr, None)

self._cleanup_temporary_gene_symbols_column()

def run(
self,
study_context: str,
Expand Down
Loading