-
Notifications
You must be signed in to change notification settings - Fork 1
Provide tensor-friendly methods for torch and jax through OMEArrow.to_dlpack
#34
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
ba8af01
7c3b504
ed1d437
a79a129
948993c
39e22d7
a83c3c4
2a69d24
5f88a86
54980b7
0a576df
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,103 @@ | ||
| # Exporting OME-Arrow pixel data via DLPack | ||
|
|
||
| OME-Arrow exposes a small tensor view API for pixel data. The returned | ||
| `TensorView` can export DLPack capsules for zero-copy interoperability on CPU | ||
| and (optionally) GPU. | ||
|
Comment on lines
+3
to
+5
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ❗ |
||
|
|
||
| Key defaults: | ||
|
|
||
| - 2D views default to `CHW` layout. | ||
| - 5D views default to `TZCHW` layout. | ||
| - Use `layout="HWC"` (or any TZCHW permutation) to override. | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do We skip 3D and 3D? Also are we being opinionated about color being a dimension? If so 2D is actually 3D here. If not then 5D should then be 4D |
||
|
|
||
| ## PyTorch | ||
|
|
||
| ```python | ||
| from ome_arrow import OMEArrow | ||
|
|
||
| obj = OMEArrow("example.ome.parquet") | ||
| view = obj.tensor_view(t=0, z=0, c=0) | ||
|
|
||
| # DLPack capsule -> torch.Tensor | ||
| import torch | ||
|
|
||
| capsule = view.to_dlpack(mode="arrow", device="cpu") | ||
| flat = torch.utils.dlpack.from_dlpack(capsule) | ||
| tensor = flat.reshape(view.shape) | ||
| ``` | ||
|
|
||
| ## JAX | ||
|
|
||
| ```python | ||
| from ome_arrow import OMEArrow | ||
|
|
||
| obj = OMEArrow("example.ome.parquet") | ||
| view = obj.tensor_view(t=0, z=0, c=0, layout="HWC") | ||
|
|
||
| import jax | ||
|
|
||
| capsule = view.to_dlpack(mode="arrow", device="cpu") | ||
| flat = jax.dlpack.from_dlpack(capsule) | ||
| arr = flat.reshape(view.shape) | ||
| ``` | ||
|
|
||
| ## Iteration examples | ||
|
|
||
| ```python | ||
| from ome_arrow import OMEArrow | ||
| import numpy as np | ||
|
|
||
| obj = OMEArrow("example.ome.parquet") | ||
| view = obj.tensor_view() | ||
|
|
||
| # Batch over time (T) dimension. | ||
| for cap in view.iter_dlpack(batch_size=2, shuffle=False, mode="numpy"): | ||
| batch = np.from_dlpack(cap) | ||
| # batch shape: (batch, Z, C, H, W) in TZCHW layout | ||
| ``` | ||
d33bs marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| ```python | ||
| from ome_arrow import OMEArrow | ||
| import numpy as np | ||
|
|
||
| obj = OMEArrow("example.ome.parquet") | ||
| view = obj.tensor_view(t=0, z=0) | ||
|
|
||
| # Tile over spatial region. | ||
| for cap in view.iter_dlpack( | ||
| tile_size=(256, 256), shuffle=True, seed=123, mode="numpy" | ||
| ): | ||
| tile = np.from_dlpack(cap) | ||
| # tile shape: (C, H, W) in CHW layout | ||
| ``` | ||
coderabbitai[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| ## Ownership and lifetime | ||
|
|
||
| `TensorView.to_dlpack()` returns a DLPack-capable object (with `__dlpack__`) | ||
| that references the underlying Arrow values buffer in `mode="arrow"`, or a | ||
| NumPy buffer in `mode="numpy"`. Keep the `TensorView` (or any NumPy array | ||
| returned by `to_numpy`) alive until the consumer finishes using the DLPack | ||
| object. | ||
|
|
||
| `mode="arrow"` currently requires a single `(t, z, c)` selection and a full-frame | ||
| ROI. Use `mode="numpy"` for batches, crops, or layout reshaping beyond a simple | ||
| reshape. | ||
|
|
||
| Zero-copy guarantees depend on the source: Arrow-backed inputs preserve buffers, | ||
| while records built from Python lists or NumPy arrays will materialize once into | ||
| Arrow buffers. The same applies to `StructScalar` inputs, which are normalized | ||
| through Python objects before Arrow-mode export. | ||
| For Parquet/Vortex sources, zero-copy also requires the on-disk struct schema | ||
| to match `OME_ARROW_STRUCT`; non-strict schema normalization materializes via | ||
| Python objects. | ||
|
|
||
| ## Optional dependencies | ||
|
|
||
| CPU DLPack export uses Arrow buffers by default. For framework helpers and GPU | ||
| paths, install only what you need: | ||
|
|
||
| ```bash | ||
| pip install "ome-arrow[dlpack-torch]" # torch only | ||
| pip install "ome-arrow[dlpack-jax]" # jax only | ||
| pip install "ome-arrow[dlpack]" # both | ||
| ``` | ||
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,6 +20,16 @@ | |
| # which handles all data I/O and manipulation | ||
| from ome_arrow import OMEArrow | ||
|
|
||
| try: | ||
| import jax | ||
| except ImportError: | ||
| jax = None | ||
|
|
||
| try: | ||
| import torch | ||
| except ImportError: | ||
| torch = None | ||
|
|
||
| # read a TIFF file and convert it to OME-Arrow | ||
| oa_image = OMEArrow( | ||
| data="../../../tests/data/examplehuman/AS_09125_050116030001_D03f00d0.tif" | ||
|
|
@@ -103,3 +113,36 @@ | |
| oa_image = OMEArrow(data="../../../tests/data/idr0062A/6001240_labels.zarr") | ||
| # show the image using pyvista | ||
| oa_image.view(how="pyvista") | ||
|
|
||
| # ## DLPack tensor export (advanced) | ||
| # This is optional and requires torch: `pip install "ome-arrow[dlpack-torch]"` | ||
|
|
||
| # examples of exporting OME-Arrow data into DLPack format for zero-copy | ||
| oa = OMEArrow("example.ome.parquet") | ||
|
|
||
| # %%time | ||
| # DLPack Arrow mode: zero-copy 1D values buffer + reshape | ||
| view = oa.tensor_view(t=0, z=0, c=0) | ||
| cap = view.to_dlpack(mode="arrow") | ||
| if torch is not None: | ||
| flat = torch.utils.dlpack.from_dlpack(cap) | ||
| tensor = flat.reshape(view.shape) | ||
| tensor.shape | ||
|
|
||
|
|
||
| # %%time | ||
| # DLPack NumPy mode: shaped tensor directly (still zero-copy when possible) | ||
| view_hwc = oa.tensor_view(t=0, z=0, layout="HWC") | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. consider clarifying the HWC here |
||
| cap_hwc = view_hwc.to_dlpack(mode="numpy", contiguous=True) | ||
| if torch is not None: | ||
| tensor_hwc = torch.utils.dlpack.from_dlpack(cap_hwc) | ||
| tensor_hwc.shape | ||
|
|
||
| # %%time | ||
| # DLPack Arrow mode: zero-copy 1D values buffer + reshape | ||
| view = oa.tensor_view(t=0, z=0, c=0) | ||
| caps = view.to_dlpack(mode="arrow") | ||
| if jax is not None: | ||
| flat = jax.dlpack.from_dlpack(caps) | ||
| arr = flat.reshape(view.shape) | ||
| arr.shape | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,4 +13,5 @@ caption: 'Contents:' | |
| maxdepth: 3 | ||
| --- | ||
| python-api | ||
| dlpack | ||
| ``` | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,7 +5,7 @@ | |
| from __future__ import annotations | ||
|
|
||
| import pathlib | ||
| from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Tuple | ||
| from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Sequence, Tuple | ||
|
|
||
| import matplotlib | ||
| import numpy as np | ||
|
|
@@ -27,6 +27,7 @@ | |
| from_tiff, | ||
| ) | ||
| from ome_arrow.meta import OME_ARROW_STRUCT | ||
| from ome_arrow.tensor import TensorView | ||
| from ome_arrow.transform import slice_ome_arrow | ||
| from ome_arrow.utils import describe_ome_arrow | ||
| from ome_arrow.view import view_matplotlib, view_pyvista | ||
|
|
@@ -77,6 +78,7 @@ def __init__( | |
|
|
||
| # set the tcz for viewing | ||
| self.tcz = tcz | ||
| self._struct_array: pa.StructArray | None = None | ||
|
|
||
| # --- 1) Stack pattern (Bio-Formats-style) -------------------------------- | ||
| if isinstance(data, str) and any(c in data for c in "<>*"): | ||
|
|
@@ -109,17 +111,25 @@ def __init__( | |
| ".parquet", | ||
| ".pq", | ||
| }: | ||
| self.data = from_ome_parquet( | ||
| s, column_name=column_name, row_index=row_index | ||
| parquet_result = from_ome_parquet( | ||
| s, | ||
| column_name=column_name, | ||
| row_index=row_index, | ||
| return_array=True, | ||
| ) | ||
| self.data, self._struct_array = parquet_result | ||
| if image_type is not None: | ||
| self.data = self._wrap_with_image_type(self.data, image_type) | ||
|
|
||
| # Vortex | ||
| elif s.lower().endswith(".vortex") or path.suffix.lower() == ".vortex": | ||
| self.data = from_ome_vortex( | ||
| s, column_name=column_name, row_index=row_index | ||
| vortex_result = from_ome_vortex( | ||
| s, | ||
| column_name=column_name, | ||
| row_index=row_index, | ||
| return_array=True, | ||
| ) | ||
| self.data, self._struct_array = vortex_result | ||
| if image_type is not None: | ||
| self.data = self._wrap_with_image_type(self.data, image_type) | ||
|
|
||
|
|
@@ -496,6 +506,51 @@ def view( | |
|
|
||
| return plotter | ||
|
|
||
| def tensor_view( | ||
| self, | ||
| *, | ||
| scene: int | None = None, | ||
| t: int | slice | Sequence[int] | None = None, | ||
| z: int | slice | Sequence[int] | None = None, | ||
| c: int | slice | Sequence[int] | None = None, | ||
| roi: tuple[int, int, int, int] | None = None, | ||
| tile: tuple[int, int] | None = None, | ||
| layout: str | None = None, | ||
| dtype: np.dtype | None = None, | ||
| ) -> TensorView: | ||
coderabbitai[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| """Create a TensorView of the pixel data. | ||
|
|
||
| Args: | ||
| scene: Scene index (only 0 is supported for single-image records). | ||
| t: Time index selection (int, slice, or sequence). Default: all. | ||
| z: Z index selection (int, slice, or sequence). Default: all. | ||
| c: Channel index selection (int, slice, or sequence). Default: all. | ||
| roi: Spatial crop (x, y, w, h) in pixels. | ||
| tile: Tile index (tile_y, tile_x) based on chunk grid. | ||
| layout: Desired layout string using TZCHW letters. | ||
| dtype: Output dtype override. | ||
|
|
||
| Returns: | ||
| TensorView: Tensor view over the selected pixels. | ||
|
|
||
| Raises: | ||
| ValueError: If an unsupported scene is requested. | ||
| """ | ||
|
|
||
| if scene not in (None, 0): | ||
| raise ValueError("Only scene=0 is supported for single-image records.") | ||
|
|
||
| return TensorView( | ||
| self._struct_array if self._struct_array is not None else self.data, | ||
| t=t, | ||
| z=z, | ||
| c=c, | ||
| roi=roi, | ||
| tile=tile, | ||
| layout=layout, | ||
| dtype=dtype, | ||
|
Comment on lines
+548
to
+551
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. these are nice to have for scaling up in the future!
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Scaling up in this case I mean the size of data that ome-arrow can hold (1000um x 1000um x 1000um sections) |
||
| ) | ||
|
|
||
| def slice( | ||
| self, | ||
| x_min: int, | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.