Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/run-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ jobs:
python-version: ${{ matrix.python_version }}
- name: Install the latest version of uv
uses: astral-sh/setup-uv@v7
- name: Sync dependencies (with viz extra)
run: uv sync --frozen --extra viz
- name: Sync dependencies (with viz + dlpack extras)
run: uv sync --frozen --extra viz --extra dlpack
- name: Run pytest
run: uv run --frozen pytest
103 changes: 103 additions & 0 deletions docs/src/dlpack.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Exporting OME-Arrow pixel data via DLPack

OME-Arrow exposes a small tensor view API for pixel data. The returned
`TensorView` can export DLPack capsules for zero-copy interoperability on CPU
and (optionally) GPU.
Comment on lines +3 to +5
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


Key defaults:

- 2D views default to `CHW` layout.
- 5D views default to `TZCHW` layout.
- Use `layout="HWC"` (or any TZCHW permutation) to override.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do We skip 3D and 3D?

Also are we being opinionated about color being a dimension? If so 2D is actually 3D here. If not then 5D should then be 4D


## PyTorch

```python
from ome_arrow import OMEArrow

obj = OMEArrow("example.ome.parquet")
view = obj.tensor_view(t=0, z=0, c=0)

# DLPack capsule -> torch.Tensor
import torch

capsule = view.to_dlpack(mode="arrow", device="cpu")
flat = torch.utils.dlpack.from_dlpack(capsule)
tensor = flat.reshape(view.shape)
```

## JAX

```python
from ome_arrow import OMEArrow

obj = OMEArrow("example.ome.parquet")
view = obj.tensor_view(t=0, z=0, c=0, layout="HWC")

import jax

capsule = view.to_dlpack(mode="arrow", device="cpu")
flat = jax.dlpack.from_dlpack(capsule)
arr = flat.reshape(view.shape)
```

## Iteration examples

```python
from ome_arrow import OMEArrow
import numpy as np

obj = OMEArrow("example.ome.parquet")
view = obj.tensor_view()

# Batch over time (T) dimension.
for cap in view.iter_dlpack(batch_size=2, shuffle=False, mode="numpy"):
batch = np.from_dlpack(cap)
# batch shape: (batch, Z, C, H, W) in TZCHW layout
```

```python
from ome_arrow import OMEArrow
import numpy as np

obj = OMEArrow("example.ome.parquet")
view = obj.tensor_view(t=0, z=0)

# Tile over spatial region.
for cap in view.iter_dlpack(
tile_size=(256, 256), shuffle=True, seed=123, mode="numpy"
):
tile = np.from_dlpack(cap)
# tile shape: (C, H, W) in CHW layout
```

## Ownership and lifetime

`TensorView.to_dlpack()` returns a DLPack-capable object (with `__dlpack__`)
that references the underlying Arrow values buffer in `mode="arrow"`, or a
NumPy buffer in `mode="numpy"`. Keep the `TensorView` (or any NumPy array
returned by `to_numpy`) alive until the consumer finishes using the DLPack
object.

`mode="arrow"` currently requires a single `(t, z, c)` selection and a full-frame
ROI. Use `mode="numpy"` for batches, crops, or layout reshaping beyond a simple
reshape.

Zero-copy guarantees depend on the source: Arrow-backed inputs preserve buffers,
while records built from Python lists or NumPy arrays will materialize once into
Arrow buffers. The same applies to `StructScalar` inputs, which are normalized
through Python objects before Arrow-mode export.
For Parquet/Vortex sources, zero-copy also requires the on-disk struct schema
to match `OME_ARROW_STRUCT`; non-strict schema normalization materializes via
Python objects.

## Optional dependencies

CPU DLPack export uses Arrow buffers by default. For framework helpers and GPU
paths, install only what you need:

```bash
pip install "ome-arrow[dlpack-torch]" # torch only
pip install "ome-arrow[dlpack-jax]" # jax only
pip install "ome-arrow[dlpack]" # both
```
177 changes: 152 additions & 25 deletions docs/src/examples/learning_to_fly_with_ome-arrow.ipynb

Large diffs are not rendered by default.

43 changes: 43 additions & 0 deletions docs/src/examples/learning_to_fly_with_ome-arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,16 @@
# which handles all data I/O and manipulation
from ome_arrow import OMEArrow

try:
import jax
except ImportError:
jax = None

try:
import torch
except ImportError:
torch = None

# read a TIFF file and convert it to OME-Arrow
oa_image = OMEArrow(
data="../../../tests/data/examplehuman/AS_09125_050116030001_D03f00d0.tif"
Expand Down Expand Up @@ -103,3 +113,36 @@
oa_image = OMEArrow(data="../../../tests/data/idr0062A/6001240_labels.zarr")
# show the image using pyvista
oa_image.view(how="pyvista")

# ## DLPack tensor export (advanced)
# This is optional and requires torch: `pip install "ome-arrow[dlpack-torch]"`

# examples of exporting OME-Arrow data into DLPack format for zero-copy
oa = OMEArrow("example.ome.parquet")

# %%time
# DLPack Arrow mode: zero-copy 1D values buffer + reshape
view = oa.tensor_view(t=0, z=0, c=0)
cap = view.to_dlpack(mode="arrow")
if torch is not None:
flat = torch.utils.dlpack.from_dlpack(cap)
tensor = flat.reshape(view.shape)
tensor.shape


# %%time
# DLPack NumPy mode: shaped tensor directly (still zero-copy when possible)
view_hwc = oa.tensor_view(t=0, z=0, layout="HWC")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

consider clarifying the HWC here

cap_hwc = view_hwc.to_dlpack(mode="numpy", contiguous=True)
if torch is not None:
tensor_hwc = torch.utils.dlpack.from_dlpack(cap_hwc)
tensor_hwc.shape

# %%time
# DLPack Arrow mode: zero-copy 1D values buffer + reshape
view = oa.tensor_view(t=0, z=0, c=0)
caps = view.to_dlpack(mode="arrow")
if jax is not None:
flat = jax.dlpack.from_dlpack(caps)
arr = flat.reshape(view.shape)
arr.shape
1 change: 1 addition & 0 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ caption: 'Contents:'
maxdepth: 3
---
python-api
dlpack
```
9 changes: 9 additions & 0 deletions docs/src/python-api.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,12 @@ ome_arrow.meta
:undoc-members:
:show-inheritance:
```

```{eval-rst}
ome_arrow.tensor
-------------------
.. automodule:: src.ome_arrow.tensor
:members:
:undoc-members:
:show-inheritance:
```
10 changes: 10 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,16 @@ dependencies = [
"pillow>=12",
"pyarrow>=22",
]
optional-dependencies.dlpack = [
"jax>=0.4",
"torch>=2.1",
]
optional-dependencies.dlpack-jax = [
"jax>=0.4",
]
optional-dependencies.dlpack-torch = [
"torch>=2.1",
]
optional-dependencies.viz = [
"ipywidgets>=8.1.8",
"jupyterlab-widgets>=3.0.16",
Expand Down
1 change: 1 addition & 0 deletions src/ome_arrow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
to_ome_arrow,
)
from ome_arrow.meta import OME_ARROW_STRUCT, OME_ARROW_TAG_TYPE, OME_ARROW_TAG_VERSION
from ome_arrow.tensor import TensorView
from ome_arrow.utils import describe_ome_arrow, verify_ome_arrow
from ome_arrow.view import view_matplotlib, view_pyvista

Expand Down
65 changes: 60 additions & 5 deletions src/ome_arrow/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from __future__ import annotations

import pathlib
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Sequence, Tuple

import matplotlib
import numpy as np
Expand All @@ -27,6 +27,7 @@
from_tiff,
)
from ome_arrow.meta import OME_ARROW_STRUCT
from ome_arrow.tensor import TensorView
from ome_arrow.transform import slice_ome_arrow
from ome_arrow.utils import describe_ome_arrow
from ome_arrow.view import view_matplotlib, view_pyvista
Expand Down Expand Up @@ -77,6 +78,7 @@ def __init__(

# set the tcz for viewing
self.tcz = tcz
self._struct_array: pa.StructArray | None = None

# --- 1) Stack pattern (Bio-Formats-style) --------------------------------
if isinstance(data, str) and any(c in data for c in "<>*"):
Expand Down Expand Up @@ -109,17 +111,25 @@ def __init__(
".parquet",
".pq",
}:
self.data = from_ome_parquet(
s, column_name=column_name, row_index=row_index
parquet_result = from_ome_parquet(
s,
column_name=column_name,
row_index=row_index,
return_array=True,
)
self.data, self._struct_array = parquet_result
if image_type is not None:
self.data = self._wrap_with_image_type(self.data, image_type)

# Vortex
elif s.lower().endswith(".vortex") or path.suffix.lower() == ".vortex":
self.data = from_ome_vortex(
s, column_name=column_name, row_index=row_index
vortex_result = from_ome_vortex(
s,
column_name=column_name,
row_index=row_index,
return_array=True,
)
self.data, self._struct_array = vortex_result
if image_type is not None:
self.data = self._wrap_with_image_type(self.data, image_type)

Expand Down Expand Up @@ -496,6 +506,51 @@ def view(

return plotter

def tensor_view(
self,
*,
scene: int | None = None,
t: int | slice | Sequence[int] | None = None,
z: int | slice | Sequence[int] | None = None,
c: int | slice | Sequence[int] | None = None,
roi: tuple[int, int, int, int] | None = None,
tile: tuple[int, int] | None = None,
layout: str | None = None,
dtype: np.dtype | None = None,
) -> TensorView:
"""Create a TensorView of the pixel data.

Args:
scene: Scene index (only 0 is supported for single-image records).
t: Time index selection (int, slice, or sequence). Default: all.
z: Z index selection (int, slice, or sequence). Default: all.
c: Channel index selection (int, slice, or sequence). Default: all.
roi: Spatial crop (x, y, w, h) in pixels.
tile: Tile index (tile_y, tile_x) based on chunk grid.
layout: Desired layout string using TZCHW letters.
dtype: Output dtype override.

Returns:
TensorView: Tensor view over the selected pixels.

Raises:
ValueError: If an unsupported scene is requested.
"""

if scene not in (None, 0):
raise ValueError("Only scene=0 is supported for single-image records.")

return TensorView(
self._struct_array if self._struct_array is not None else self.data,
t=t,
z=z,
c=c,
roi=roi,
tile=tile,
layout=layout,
dtype=dtype,
Comment on lines +548 to +551
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these are nice to have for scaling up in the future!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Scaling up in this case I mean the size of data that ome-arrow can hold (1000um x 1000um x 1000um sections)

)

def slice(
self,
x_min: int,
Expand Down
Loading