Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions cuda_core/cuda/core/_memoryview.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1096,6 +1096,8 @@ cpdef StridedMemoryView view_as_cai(obj, stream_ptr, view=None):
buf.exporting_obj = obj
buf.metadata = cai_data
buf.dl_tensor = NULL
# Validate shape/strides/typestr eagerly so constructor paths fail fast.
buf.get_layout()
buf.ptr, buf.readonly = cai_data["data"]
buf.is_device_accessible = True
if buf.ptr != 0:
Expand Down Expand Up @@ -1138,6 +1140,8 @@ cpdef StridedMemoryView view_as_array_interface(obj, view=None):
buf.exporting_obj = obj
buf.metadata = data
buf.dl_tensor = NULL
# Validate shape/strides/typestr eagerly so constructor paths fail fast.
buf.get_layout()
buf.ptr, buf.readonly = data["data"]
buf.is_device_accessible = False
buf.device_id = handle_return(driver.cuCtxGetDevice())
Expand Down
49 changes: 46 additions & 3 deletions cuda_core/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,10 +582,53 @@ def test_from_array_interface_unsupported_strides(init_cuda):
# Create an array with strides that aren't a multiple of itemsize
x = np.array([(1, 2.0), (3, 4.0)], dtype=[("a", "i4"), ("b", "f8")])
b = x["b"]
smv = StridedMemoryView.from_array_interface(b)
with pytest.raises(ValueError, match="strides must be divisible by itemsize"):
# TODO: ideally this would raise on construction
smv.strides # noqa: B018
StridedMemoryView.from_array_interface(b)


def _make_cuda_array_interface_obj(*, shape, strides, typestr="<f8", data=(0, False), version=3):
return type(
"SyntheticCAI",
(),
{
"__cuda_array_interface__": {
"shape": shape,
"strides": strides,
"typestr": typestr,
"data": data,
"version": version,
}
},
)()


def test_from_cuda_array_interface_unsupported_strides(init_cuda):
cai_obj = _make_cuda_array_interface_obj(shape=(2,), strides=(10,))
with pytest.raises(ValueError, match="strides must be divisible by itemsize"):
StridedMemoryView.from_cuda_array_interface(cai_obj, stream_ptr=-1)


def test_from_cuda_array_interface_zero_strides(init_cuda):
cai_obj = _make_cuda_array_interface_obj(shape=(1, 1), strides=(0, 0))
smv = StridedMemoryView.from_cuda_array_interface(cai_obj, stream_ptr=-1)
assert smv.shape == (1, 1)
assert smv.strides == (0, 0)


@pytest.mark.skipif(cp is None, reason="CuPy is not installed")
def test_from_cuda_array_interface_negative_strides(init_cuda):
x = cp.arange(4, dtype=cp.float64)[::-1]
smv = StridedMemoryView.from_cuda_array_interface(_EnforceCAIView(x), stream_ptr=-1)
assert smv.shape == x.shape
assert smv.strides == (-1,)


def test_from_cuda_array_interface_empty_array(init_cuda):
cai_obj = _make_cuda_array_interface_obj(shape=(0, 3), strides=(24, 8))
smv = StridedMemoryView.from_cuda_array_interface(cai_obj, stream_ptr=-1)
assert smv.size == 0
assert smv.shape == (0, 3)
assert smv.strides == (3, 1)


@pytest.mark.parametrize(
Expand Down
Loading