Provide tensor-friendly methods for torch and jax through OMEArrow.to_dlpack (#34)
base: main
Conversation
Walkthrough

This PR implements DLPack export and streaming tensor views for OME-Arrow, enabling zero-copy consumption of pixel data by PyTorch and JAX.
Sequence diagrams

```mermaid
sequenceDiagram
    participant User
    participant OMEArrow
    participant TensorView
    participant DLPack
    participant PyTorch
    participant JAX
    User->>OMEArrow: tensor_view(t, z, c, roi, layout)
    OMEArrow->>TensorView: __init__(struct_array/data, selections)
    TensorView-->>OMEArrow: TensorView instance
    OMEArrow-->>User: TensorView
    User->>TensorView: to_dlpack(device="cpu"|"cuda", contiguous=True|False)
    TensorView->>TensorView: _materialize() if needed
    TensorView->>DLPack: PyCapsule (zero-copy or materialized)
    DLPack-->>TensorView: DLPack capsule
    TensorView-->>User: capsule
    User->>PyTorch: from_dlpack(capsule)
    PyTorch-->>User: torch.Tensor
    Note over User,PyTorch: or
    User->>JAX: from_dlpack(capsule)
    JAX-->>User: jax.Array
```

```mermaid
sequenceDiagram
    participant User
    participant TensorView
    participant iter_dlpack
    participant DLPack
    participant Consumer
    User->>TensorView: iter_dlpack(batch_size, tiles, shuffle, seed)
    TensorView->>iter_dlpack: __init__(enumerate batches/tiles)
    iter_dlpack-->>TensorView: iterator
    TensorView-->>User: iterator
    loop For each batch/tile
        User->>iter_dlpack: next()
        iter_dlpack->>iter_dlpack: _slice_batch() or _slice_tile()
        iter_dlpack->>DLPack: to_dlpack() for subset
        DLPack-->>iter_dlpack: capsule
        iter_dlpack-->>User: capsule
        User->>Consumer: from_dlpack(capsule)
        Consumer-->>User: tensor/array
    end
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~70 minutes
Actionable comments posted: 6
🤖 Fix all issues with AI agents
In `@docs/src/dlpack.md`:
- Around line 92-93: The fenced code block containing the pip command (the block
with pip install "ome-arrow[dlpack]") lacks a language specifier; update that
block to use a shell language tag (e.g., change the opening ``` to ```bash or
```shell) so the snippet complies with markdownlint MD040 and renders correctly.

- Around line 46-68: The examples call np.from_dlpack(...) but never import
numpy; add "import numpy as np" at the top of each snippet (above OMEArrow
usage) so the examples using OMEArrow, tensor_view(), and view.iter_dlpack(...)
work as shown; ensure the import appears before any np.from_dlpack calls in both
the batch (iter_dlpack(batch_size=2, ...)) and tile
(iter_dlpack(tiles=(256,256), ...)) examples.
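As an illustration of the missing-import fix, NumPy's `np.from_dlpack` (available since NumPy 1.22) can consume any object exposing the DLPack protocol. The sketch below uses a plain `ndarray` as the producer rather than the OMEArrow objects from the docs, since those are specific to this PR:

```python
import numpy as np

# np.from_dlpack consumes any object with __dlpack__/__dlpack_device__;
# an ndarray qualifies, so the round trip below shares memory.
src = np.arange(12, dtype=np.uint16).reshape(3, 4)
view = np.from_dlpack(src)

# Zero-copy: mutating the source is visible through the imported array.
src[0, 0] = 99
print(view[0, 0])  # 99
```

The same `import numpy as np` line is all the docs snippets need before their `np.from_dlpack` calls.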
In `@src/ome_arrow/core.py`:
- Around line 509-520: The type annotations for tensor_view's t, z, and c
parameters are too broad (Iterable[int]) and must be changed to Sequence[int] to
match TensorView.__init__ and the runtime check in _normalize_index; update the
signature of tensor_view (the parameters t, z, c) from "int | slice |
Iterable[int] | None" to "int | slice | Sequence[int] | None" and add/import
typing.Sequence where needed so callers and the isinstance(index, Sequence)
check in _normalize_index behave correctly.
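To illustrate why the narrower annotation matters: a generator satisfies `Iterable` but not `Sequence`, so it would pass type checking under the old signature and still fail the runtime `isinstance` check. The sketch below uses a hypothetical `normalize_index` stand-in for `_normalize_index`:

```python
from collections.abc import Iterable, Sequence

def normalize_index(index):
    # Mirrors the runtime check described above (illustrative stand-in
    # for _normalize_index): only Sequence inputs pass, so annotating
    # the parameter as Iterable[int] over-promises.
    if isinstance(index, Sequence):
        return list(index)
    raise TypeError(f"expected a sequence, got {type(index).__name__}")

gen = (i for i in range(3))
print(isinstance(gen, Iterable), isinstance(gen, Sequence))  # True False
print(normalize_index([0, 2]))  # [0, 2]
```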
In `@src/ome_arrow/tensor.py`:
- Around line 482-488: _pixels_meta currently only checks _data_py and
_struct_array and raises if both are None; add handling for the
StructScalar-only case by checking for self._struct_scalar and returning the
pixels_meta from it (e.g., call
self._struct_scalar.field("pixels_meta").as_py()) so the method works when the
object was constructed from a StructScalar; keep existing behavior for _data_py
and _struct_array paths and only raise if none of the three sources (_data_py,
_struct_array, _struct_scalar) are available.
- Around line 120-131: The current shape property materializes the full array
via self.to_numpy(contiguous=False).shape which is expensive; change the shape
property (method name: shape) to compute and return the tensor shape directly
from the Tensor's selection/metadata (e.g., use whatever selection, bounds, and
layout fields on self such as self.selection, self.bounds, self.layout or
equivalent) without calling to_numpy, while leaving the strides property
unchanged (keep strides calling to_numpy for now); ensure the new shape logic
mirrors the existing layout ordering used by to_numpy so callers get the same
tuple but computed cheaply.
- Around line 67-80: In __init__, stop eagerly calling data.as_py() for
StructScalar: when isinstance(data, pa.StructScalar) set self._struct_scalar =
data and self._data_py = None (don’t materialize pixels); ensure _data_py_dict()
still calls as_py() on-demand to populate _data_py. Update _pixels_meta(),
_has_chunks(), and _chunk_grid() to handle the case where both self._data_py and
self._struct_array are None by reading necessary fields directly from
self._struct_scalar (extract struct fields from the scalar rather than relying
on _data_py), and only fall back to calling _data_py_dict() if you need the full
Python dict. Keep identifiers: __init__, _struct_scalar, _data_py,
_data_py_dict, _struct_array, _pixels_meta, _has_chunks, and _chunk_grid when
making these changes.
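The lazy-materialization pattern this prompt describes can be sketched generically; `LazyRecord` and `FakeScalar` below are illustrative stand-ins, not the real `TensorView` or an Arrow scalar:

```python
class FakeScalar:
    """Stand-in for a pa.StructScalar that counts as_py() calls."""
    def __init__(self):
        self.calls = 0
    def as_py(self):
        self.calls += 1
        return {"pixels_meta": {"dtype": "uint16"}}

class LazyRecord:
    """Simplified version of the pattern: hold the scalar cheaply and
    materialize the Python dict only on first use, caching the result."""
    def __init__(self, scalar):
        self._struct_scalar = scalar  # cheap to hold
        self._data_py = None          # not materialized yet

    def _data_py_dict(self):
        if self._data_py is None:
            # Expensive conversion happens at most once, on demand.
            self._data_py = self._struct_scalar.as_py()
        return self._data_py

rec = LazyRecord(FakeScalar())
rec._data_py_dict()
rec._data_py_dict()
print(rec._struct_scalar.calls)  # 1
```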
🧹 Nitpick comments (9)
pyproject.toml (1)
34-37: Consider splitting `jax` and `torch` into separate extras. Bundling both frameworks into a single `dlpack` extra forces users who only need one (e.g., PyTorch-only workflows) to install the other. JAX and PyTorch are both large dependencies. Consider offering finer-grained extras (e.g., `dlpack-torch`, `dlpack-jax`) alongside the combined `dlpack` group.

docs/src/python-api.md (1)

12-21: Consider whether `:private-members:` is intentional for the public API page. This mirrors the existing `ome_arrow.meta` block, but `:private-members:` will surface internal helpers (e.g., `_DLPackWrapper`, `_normalize_device`, `_read_plane`) in the public API docs. If these are implementation details not meant for consumers, dropping `:private-members:` would reduce noise.

tests/test_tensor.py (1)
60-80: Zero-copy assertion is fragile: depends on internal caching order. Line 80 asserts pointer equality between a torch tensor (from DLPack) and a numpy array obtained from a separate `to_numpy` call. This works because `_materialize` caches the result and `to_dlpack(mode="numpy")` builds the capsule from the same cached array. If `to_dlpack` is called before `to_numpy`, the cache is populated by `to_dlpack` and `to_numpy` afterwards returns the same object, so pointer equality still holds.

But if caching behavior in `_materialize` changes (e.g., layout-dependent invalidation), this test would fail for non-obvious reasons. Consider adding a brief comment clarifying the dependency on the shared internal cache, so future maintainers understand the coupling.

src/ome_arrow/ingest.py (2)
109-128: Zero-copy is lost in the non-strict schema path (lines 121-128). When `strict_schema=False` and the column type doesn't exactly match `OME_ARROW_STRUCT`, the code falls back to `to_pylist()` (line 121) and reconstructs both the scalar and struct_array from a Python dict (lines 127-128). This materializes all data into Python objects and back, defeating the zero-copy goal for `tensor_view`. Since `OMEArrow.__init__` calls with `strict_schema=False` (the default), this path will be hit whenever the on-disk schema has extra or differently ordered fields, which is common with evolving schemas.

This isn't necessarily a bug (correctness is preserved), but it's worth documenting that zero-copy DLPack export from Parquet only works when the on-disk schema exactly matches `OME_ARROW_STRUCT`. Consider logging a warning or adding a note in the `tensor_view` docs.
130-136: Pre-existing: soft validation comparisons are no-ops. Lines 133-134 evaluate equality but discard the results. This is pre-existing code and not introduced by this PR, but it's worth noting for a future cleanup.

```python
meta.get(b"ome.arrow.type", b"").decode() == str(OME_ARROW_TAG_TYPE)        # result unused
meta.get(b"ome.arrow.version", b"").decode() == str(OME_ARROW_TAG_VERSION)  # result unused
```

src/ome_arrow/tensor.py (4)
646-660: `_batched` prefetch path uses `list.pop(0)`, which is O(n) per dequeue. When `prefetch > 0`, `queue.pop(0)` is O(n) since list elements must be shifted. For large iteration counts this degrades to O(n²). Use `collections.deque` instead:

Proposed fix

```diff
+from collections import deque
 ...
 def _batched(items: List[Any], size: int, *, prefetch: int) -> Iterator[List[Any]]:
     if size <= 0:
         raise ValueError("batch size must be positive")
     if prefetch <= 0:
         for i in range(0, len(items), size):
             yield items[i : i + size]
         return
-    queue: List[List[Any]] = []
+    queue: deque[List[Any]] = deque()
     idx = 0
     while idx < len(items) or queue:
         while idx < len(items) and len(queue) <= prefetch:
             queue.append(items[idx : idx + size])
             idx += size
-        yield queue.pop(0)
+        yield queue.popleft()
```
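The diff above can be checked in isolation; the following self-contained version of the batching helper (renamed `batched` for illustration) applies the `deque` change and behaves identically to the list-based code:

```python
from collections import deque
from typing import Any, Iterator, List

def batched(items: List[Any], size: int, *, prefetch: int = 0) -> Iterator[List[Any]]:
    """Sketch of _batched with the O(1) deque dequeue suggested above."""
    if size <= 0:
        raise ValueError("batch size must be positive")
    if prefetch <= 0:
        for i in range(0, len(items), size):
            yield items[i : i + size]
        return
    queue: deque = deque()
    idx = 0
    while idx < len(items) or queue:
        while idx < len(items) and len(queue) <= prefetch:
            queue.append(items[idx : idx + size])
            idx += size
        yield queue.popleft()  # O(1), unlike list.pop(0)

print(list(batched([1, 2, 3, 4, 5], 2, prefetch=1)))  # [[1, 2], [3, 4], [5]]
```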
690-702: `_ensure_struct_array` from StructScalar calls `as_py()`, so it is not zero-copy. Line 698: `pa.array([data.as_py()], type=OME_ARROW_STRUCT)` deserializes and re-serializes the entire record. This means `mode="arrow"` DLPack export from a StructScalar input won't be zero-copy. The path works for Parquet/Vortex where `_data` is already a StructArray, but for other inputs (TIFF, numpy, dict), it will materialize.

This is correctly documented via the `mode='arrow' requires Arrow-backed data` error on line 517, but that error is only raised when `_ensure_struct_array` returns `None` (i.e., the input is not one of the known types). For StructScalar, a new array IS created, so the code silently works but without zero-copy semantics. Consider either documenting this behavior or raising when true zero-copy can't be guaranteed.
510-534: `_arrow_values` silently creates a non-zero-copy array for StructScalar inputs. When the input is a `StructScalar` (e.g., from TIFF/NumPy paths), `_ensure_struct_array` (line 515) will create a new `StructArray` via an `as_py()` round-trip. The subsequent `_select_plane_values` call on line 534 then returns an Arrow array from this copy, not from the original data. The DLPack capsule will point to the copy's buffer.

This is functionally correct, but it may surprise users who expect `mode="arrow"` to mean zero-copy. Consider adding a note in the docstring for `to_dlpack` that zero-copy Arrow mode is only available for Parquet/Vortex-ingested data.
168-212: `to_dlpack` ownership semantics should be documented. The DLPack protocol specifies that `__dlpack__()` should be called at most once on a capsule (ownership transfer). `_DLPackWrapper` stores the capsule and returns it on every `__dlpack__()` call, which works correctly because the consumer (torch/jax) marks the capsule as consumed. However, a second call would fail silently or raise.

Consider adding a brief docstring note on `_DLPackWrapper` or `to_dlpack` that the returned object is single-use per the DLPack protocol. The linked issue also specifically requests documenting "ownership and lifetime semantics for returned DLPack capsules".
Pull request overview
This PR adds a new tensor-oriented API to OME-Arrow that enables exporting pixel data via DLPack for consumption by PyTorch and JAX (aiming for zero-copy on Arrow-backed CPU data), along with tests and documentation.
Changes:
- Introduce `TensorView` with `to_numpy()`, `to_dlpack()`, `to_torch()`, `to_jax()`, and `iter_dlpack()` for batch/tile iteration.
- Extend Parquet/Vortex ingest paths to optionally return a 1-row `StructArray` to preserve Arrow buffers for downstream zero-copy exports.
- Add docs and tests for DLPack export and iteration; add a new `dlpack` extra (jax/torch) and update CI dependency sync.
Reviewed changes
Copilot reviewed 11 out of 13 changed files in this pull request and generated 5 comments.
Show a summary per file
| File | Description |
|---|---|
| src/ome_arrow/tensor.py | New TensorView implementation and DLPack export/iteration support. |
| src/ome_arrow/core.py | Add OMEArrow.tensor_view() and preserve an Arrow StructArray for zero-copy paths. |
| src/ome_arrow/ingest.py | Add return_array option to return (StructScalar, StructArray) for Parquet/Vortex. |
| src/ome_arrow/__init__.py | Export TensorView from the package namespace. |
| tests/test_tensor.py | New tests for layout, DLPack round-trips, and iterators. |
| docs/src/dlpack.md | New documentation page describing DLPack export and examples. |
| docs/src/index.md | Add dlpack page to the docs TOC. |
| docs/src/python-api.md | Add autodoc section for the tensor module. |
| docs/src/examples/learning_to_fly_with_ome-arrow.py | Add example usage snippets for DLPack export with torch/jax. |
| pyproject.toml | Add dlpack optional dependency extra (jax, torch). |
| .github/workflows/run-tests.yml | Install dlpack extra in CI test job. |
| uv.lock | Lockfile updates to include dlpack extra dependencies and transitive packages. |
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@src/ome_arrow/tensor.py`:
- Around line 728-734: The _DLPackWrapper class is returning the same capsule on
every call to __dlpack__, violating the DLPack single-use rule; update the
__dlpack__ method to return the stored capsule once, then immediately invalidate
it (set self._capsule to None) and raise a clear error (e.g., RuntimeError) if
__dlpack__ is called again when _capsule is None; touch the __init__ only to
keep the _capsule and _device fields, but add the post-return invalidation and a
guard at the top of __dlpack__ that checks _capsule and raises on reuse so
consumers cannot receive the capsule twice.
🧹 Nitpick comments (2)
src/ome_arrow/tensor.py (2)
690-704: `_batched` prefetch is synchronous; the name is misleading. The `prefetch` parameter pre-slices items into a `deque` but doesn't perform any actual async precomputation. The iterator still materializes each batch lazily when yielded. This is functionally correct but won't deliver the performance benefit users expect from a "prefetch" parameter (as documented in `iter_dlpack`: "Number of items to precompute ahead"). Consider either documenting this as a placeholder for future async prefetch, or removing the parameter until real prefetch (e.g., via `concurrent.futures`) is implemented.
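For contrast, genuine prefetching needs work to happen off the consuming thread. A sketch of what a `concurrent.futures`-based variant could look like (illustrative only, not the PR's implementation):

```python
from concurrent.futures import ThreadPoolExecutor

def prefetched(iterable, work, depth=2):
    """Yield work(item) for each item, keeping up to `depth` results
    in flight on a worker thread, so the next batch is computed while
    the consumer processes the current one."""
    with ThreadPoolExecutor(max_workers=1) as pool:
        pending = []
        for item in iterable:
            pending.append(pool.submit(work, item))
            if len(pending) > depth:
                yield pending.pop(0).result()
        for future in pending:  # flush remaining in-flight results
            yield future.result()

print(list(prefetched(range(5), lambda x: x * x, depth=2)))  # [0, 1, 4, 9, 16]
```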
430-441: Triple-nested loop in `_build_tzchw` reads planes one at a time. For selections with many (t, z, c) combinations this is O(T×Z×C) individual plane reads, each potentially calling `plane_from_chunks`, which iterates all chunks. This is acceptable for the initial implementation but could become a bottleneck for large 5D datasets.
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In `@src/ome_arrow/tensor.py`:
- Around line 505-516: The _has_chunks method currently treats non-null but
empty Arrow lists as present, causing inconsistency with the dict path and
downstream errors in _plane_map/_read_plane and plane_from_chunks; update
_has_chunks to check for empty list lengths on both Arrow paths: when using
self._struct_array, inspect the 'chunks' field's length (and ensure it's not
zero) rather than only testing is_null, and when using self._struct_scalar,
obtain the 'chunks' scalar, verify it's valid and also check its list length (or
empty-ness) before returning True so that an empty list yields False
consistently with the dict path.
- Around line 505-516: The _has_chunks method returns true for a StructScalar
when the scalar is valid but contains an empty list; update _has_chunks to
mirror the dict/array checks by first verifying
self._struct_scalar["chunks"].is_valid and then checking the list length (e.g.,
via len(chunks_scalar.as_py()) > 0) before returning True so that empty lists
are treated as no-chunks; this will prevent _read_plane from calling
plane_from_chunks with no actual chunks.
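The consistent three-path check can be sketched with plain-Python stand-ins for the dict, StructArray field, and StructScalar sources; the real implementation would use pyarrow validity and list-length APIs instead of these mock dicts:

```python
def has_chunks(data_py=None, chunks_array_field=None, chunks_scalar=None):
    """Consistent 'no chunks' handling across the three sources described
    above. The arguments stand in for the _data_py dict, the StructArray
    'chunks' field, and the StructScalar 'chunks' value; an empty list must
    read as False on every path."""
    if data_py is not None:
        return bool(data_py.get("chunks"))
    if chunks_array_field is not None:
        # Real code would also test field.is_null(0) before taking len().
        return len(chunks_array_field) > 0
    if chunks_scalar is not None and chunks_scalar.get("valid"):
        # Mirrors: scalar.is_valid and len(scalar.as_py()) > 0
        return len(chunks_scalar.get("value", [])) > 0
    return False

print(has_chunks(data_py={"chunks": []}))                          # False
print(has_chunks(chunks_array_field=[]))                           # False
print(has_chunks(chunks_scalar={"valid": True, "value": [b"c0"]})) # True
```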
🧹 Nitpick comments (2)
src/ome_arrow/tensor.py (1)
708-722: `_batched` with `prefetch > 0` doesn't provide actual prefetching. The `prefetch` parameter fills a synchronous deque before yielding, but since there's no background thread or async I/O, buffering items into a deque before yielding them one at a time is functionally equivalent to simple sequential iteration; it just adds overhead. The docstring on `iter_dlpack` (lines 288-289) does note this is not async, which is good, but the implementation could be simplified to just the `prefetch <= 0` path for now. Consider documenting this as a placeholder for future async prefetch, or removing the deque-based path until true async prefetch is implemented.
tests/test_tensor.py (1)
13-28: `_Wrapper` inside `_from_dlpack_capsule` doesn't enforce single-use semantics. This is a test-only helper, so it's not a real concern; each capsule is consumed exactly once by `np.from_dlpack`. Just flagging for awareness since `_DLPackWrapper` in production code was specifically fixed for this.
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@src/ome_arrow/tensor.py`:
- Around line 674-680: The _dtype_from_meta function currently falls back
silently to np.uint16 on errors; change the try/except around np.dtype(dtype) to
emit a warning when the conversion fails (include the original dtype value and
the exception message) before returning np.dtype(np.uint16). Use the warnings
module (or processLogger.warn if project logging convention prefers) so callers
are informed that an unrecognized dtype (e.g., "floa32") caused a fallback to
uint16; reference the function name _dtype_from_meta and the dtype parameter to
locate and modify the code.
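A sketch of the requested behavior for `_dtype_from_meta`; the function body is illustrative, and the exception types caught are an assumption about what `np.dtype` raises for bad inputs:

```python
import warnings

import numpy as np

def dtype_from_meta(dtype):
    """Illustrative version of _dtype_from_meta that warns instead of
    silently falling back to uint16 on an unrecognized dtype string."""
    try:
        return np.dtype(dtype)
    except (TypeError, ValueError) as exc:
        warnings.warn(f"unrecognized dtype {dtype!r} ({exc}); falling back to uint16")
        return np.dtype(np.uint16)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    fallback = dtype_from_meta("floa32")  # the typo from the prompt: warns, returns uint16
    exact = dtype_from_meta("float32")    # valid dtype: no warning

print(fallback, exact, len(caught))
```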
🧹 Nitpick comments (4)
src/ome_arrow/tensor.py (2)
406-419: `_has_chunks()` called repeatedly inside the triple loop; minor performance concern. `_has_chunks()` is invoked on every `_read_plane` call inside the (T × Z × C) loop. For non-dict paths it re-queries Arrow fields each time. Consider caching the result in a `self._chunks_present` flag computed once during `__init__` or on first access.
356-394: Minor naming ambiguity between `tile` (index) and `tiles` (size). The `__init__` parameter `tile` refers to a tile index `(tile_y, tile_x)` into the chunk grid, while `iter_dlpack`'s `tiles` parameter specifies tile size `(tile_h, tile_w)` in pixels. Both are `tuple[int, int]` with similar names but very different semantics. Consider renaming one (e.g., `tile_size` for `iter_dlpack`) to reduce confusion.

tests/test_tensor.py (2)
256-284: Test helper `_selected_values_for_arrow_mode` duplicates filtering logic from `tensor.py`. This helper mirrors the logic in `_select_plane_values` and `_select_chunk_values`. While acceptable in tests to independently verify behavior, consider documenting its purpose more explicitly or noting it's intentionally reimplemented for test independence.
235-253: `_jax_buffer_ptr` is brittle across JAX versions. The helper tries multiple JAX internal attributes (`device_buffer`, `device_buffers`, `addressable_data`, `unsafe_buffer_pointer`). These are implementation details that change across JAX versions. The overall zero-copy test already uses `pytest.skip` as a safety net, so this is acceptable, but worth noting for future maintenance.
Actionable comments posted: 3
🤖 Fix all issues with AI agents
In `@pyproject.toml`:
- Around line 130-131: Replace the TOML table using a dotted key
(`[tool.pytest]` with `ini_options.pythonpath`) with the canonical pytest table
`[tool.pytest.ini_options]` and move the pythonpath setting into that table
(i.e., set `pythonpath = ["." ]`) so pytest picks up the configuration
correctly; update the section header and the key accordingly where
`ini_options.pythonpath` appears.
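Applied, the canonical pytest table in `pyproject.toml` would read as follows (a sketch; only the key mentioned in the prompt is shown):

```toml
# Canonical pytest configuration table, replacing [tool.pytest] with a
# dotted ini_options.pythonpath key:
[tool.pytest.ini_options]
pythonpath = ["."]
```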
In `@src/ome_arrow/tensor.py`:
- Around line 71-72: Handle the zero-chunk case before calling
combine_chunks()/data[0]: inside the existing ChunkedArray branch (the code that
checks isinstance(data, pa.ChunkedArray) and uses data.chunk(0),
data.num_chunks, and data.combine_chunks()), add a guard for data.num_chunks ==
0 and produce/assign an appropriate empty StructArray (or return an empty
tensor) matching the expected schema, otherwise keep the current logic (use
chunk(0) when num_chunks == 1 and combine_chunks() when >1); also ensure any
subsequent indexing like data[0] only runs when the resulting array has length >
0.
- Around line 261-267: The code currently calls jax.dlpack.from_dlpack(...)
after ensuring JAX is installed; replace that public API call with the
NumPy-compatible jax.numpy.from_dlpack(...) to follow the recommended interface.
In the method that builds dlpack via self.to_dlpack(...), import or reference
jax.numpy (e.g., jax.numpy.from_dlpack) and call jax.numpy.from_dlpack(dlpack)
instead of jax.dlpack.from_dlpack(dlpack) so the returned array uses the public,
supported API.
🧹 Nitpick comments (3)
src/ome_arrow/tensor.py (1)
68-72: `self._data` retains the raw `ChunkedArray`, causing repeated `combine_chunks()` in iteration. Line 68 stores the original `data` (potentially a multi-chunk `ChunkedArray`) in `self._data`. Since `_iter_batches` and `_iter_tiles` pass `self._data` to each new `TensorView`, every batch/tile re-runs the `combine_chunks()` call. Consider storing the unwrapped array instead:

Suggested fix

```diff
-        self._data = data
-        self._struct_array: pa.StructArray | None = None
+        self._struct_array: pa.StructArray | None = None
         self._struct_scalar: pa.StructScalar | None = None
         if isinstance(data, pa.ChunkedArray):
             data = data.chunk(0) if data.num_chunks == 1 else data.combine_chunks()
         if isinstance(data, pa.StructArray):
             self._struct_array = data
             self._struct_scalar = data[0]
             self._data_py: dict[str, Any] | None = None
         elif isinstance(data, pa.StructScalar):
             self._struct_scalar = data
             self._data_py = None
         else:
             self._data_py = data
+        self._data = data
```

tests/test_tensor.py (2)
89-96: Note: `jax.dlpack.from_dlpack` usage mirrors the deprecation concern in `tensor.py`. If the source is updated to use `jax.numpy.from_dlpack`, this test should be updated accordingly (line 96).
235-259: `_jax_buffer_ptr` is fragile but appropriately guarded. The helper probes version-specific JAX internals with multiple fallback paths. The docstring clearly states the intent to skip rather than hard-fail. This is acceptable for best-effort zero-copy tests. Consider adding a `pytest.skip` call in the `AssertionError` raises (lines 247, 252, 259) if you'd prefer tests to skip rather than fail when JAX internals change.
Description
This PR provides tensor-friendly methods for loading images in `torch` or `jax` by using `OMEArrow.to_dlpack`. The goal is to provide zero-copy capabilities in tensor environments through Arrow, avoiding unnecessary conversions where possible.

Closes #33
Summary by CodeRabbit
Release Notes
New Features
- `tensor_view()` method for tensor-like access to pixel data with rich indexing support

Documentation
Tests