Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cytetype/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.19.1"
__version__ = "0.19.2"

import requests

Expand Down
2 changes: 1 addition & 1 deletion cytetype/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


MAX_UPLOAD_BYTES: dict[UploadFileKind, int] = {
"obs_duckdb": 100 * 1024 * 1024, # 100MB
"obs_duckdb": 2 * 1024 * 1024 * 1024, # 2GB
"vars_h5": 50 * 1024 * 1024 * 1024, # 50GB
}

Expand Down
19 changes: 13 additions & 6 deletions cytetype/core/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,16 +491,23 @@ def save_obs_duckdb(
"Invalid table_name. Use letters, numbers, and underscores only."
)

added_cols: list[str] = []
if obsm_coordinates is not None and coordinates_key is not None:
obs_df = obs_df.copy()
obs_df[f"__vis_coordinates_{coordinates_key}_1"] = obsm_coordinates[:, 0]
obs_df[f"__vis_coordinates_{coordinates_key}_2"] = obsm_coordinates[:, 1]
col1 = f"__vis_coordinates_{coordinates_key}_1"
col2 = f"__vis_coordinates_{coordinates_key}_2"
obs_df[col1] = obsm_coordinates[:, 0]
obs_df[col2] = obsm_coordinates[:, 1]
added_cols = [col1, col2]

dd_config: dict[str, Any] = {
"threads": threads,
"memory_limit": memory_limit,
"temp_directory": temp_directory,
}
with duckdb.connect(out_file, config=dd_config) as con:
con.register("obs_df", obs_df)
con.execute(f"CREATE OR REPLACE TABLE {table_name} AS SELECT * FROM obs_df")
try:
with duckdb.connect(out_file, config=dd_config) as con:
con.register("obs_df", obs_df)
con.execute(f"CREATE OR REPLACE TABLE {table_name} AS SELECT * FROM obs_df")
finally:
for col in added_cols:
obs_df.drop(columns=col, inplace=True, errors="ignore")
6 changes: 6 additions & 0 deletions cytetype/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def __init__(
max_cells_per_group: int = 1000,
vars_h5_path: str = "vars.h5",
obs_duckdb_path: str = "obs.duckdb",
max_metadata_categories: int = 500,
api_url: str = "https://prod.cytetype.nygen.io",
auth_token: str | None = None,
) -> None:
Expand Down Expand Up @@ -116,6 +117,10 @@ def __init__(
max_cells_per_group (int, optional): Maximum number of cells to sample per group
for visualization. If a group has more cells than this limit, a random sample
will be taken. Defaults to 1000.
max_metadata_categories (int, optional): Maximum number of unique values a categorical
obs column may have to be included in cluster metadata aggregation. Columns with
more unique values (e.g. cell barcodes, per-cell IDs) are skipped to avoid
excessive memory usage. Defaults to 500.
api_url (str, optional): URL for the CyteType API endpoint. Only change if using a custom
deployment. Defaults to "https://prod.cytetype.nygen.io".
auth_token (str | None, optional): Bearer token for API authentication. If provided,
Expand Down Expand Up @@ -186,6 +191,7 @@ def __init__(
adata=self.adata,
group_key=self.group_key,
min_percentage=min_percentage,
max_categories=max_metadata_categories,
)
# Replace keys in group_metadata using cluster_map
self.group_metadata = {
Expand Down
17 changes: 15 additions & 2 deletions cytetype/preprocessing/aggregation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import anndata
import numpy as np

from ..config import logger
from .marker_detection import _accumulate_group_stats


Expand Down Expand Up @@ -55,6 +56,7 @@ def aggregate_cluster_metadata(
adata: anndata.AnnData,
group_key: str,
min_percentage: int = 10,
max_categories: int = 500,
) -> dict[str, dict[str, dict[str, int]]]:
"""Aggregate categorical metadata per cluster.

Expand All @@ -66,6 +68,9 @@ def aggregate_cluster_metadata(
adata: AnnData object containing single-cell data
group_key: Column name in adata.obs to group cells by
min_percentage: Minimum percentage of cells in a group to include
max_categories: Maximum number of unique values a column may have to be
included. Columns exceeding this threshold are skipped to avoid
memory-expensive intermediate DataFrames.

Returns:
Nested dictionary structure:
Expand All @@ -76,14 +81,22 @@ def aggregate_cluster_metadata(
grouped_data = adata.obs.groupby(group_key, observed=False)
column_distributions: dict[str, dict[str, dict[str, int]]] = {}

# Process each column in adata.obs
for column_name in adata.obs.columns:
if column_name == group_key:
continue

column_dtype = adata.obs[column_name].dtype
if column_dtype in ["object", "category", "string"]:
# Calculate value counts for each group
n_unique = adata.obs[column_name].nunique()
if n_unique > max_categories:
logger.debug(
"Skipping column '{}' ({} unique values > max_categories={}).",
column_name,
n_unique,
max_categories,
)
continue

value_counts_df = grouped_data[column_name].value_counts().unstack().T

# Convert to percentages and filter for values >min_percentage
Expand Down