Skip to content

Commit 13a7adb

Browse files
Add sycl queue validation for mask and values
1 parent 764a9e5 commit 13a7adb

2 files changed

Lines changed: 20 additions & 2 deletions

File tree

dpnp/dpnp_iface_indexing.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1870,8 +1870,17 @@ def putmask(a, /, mask, values):
18701870
if a.dtype != values.dtype:
18711871
values = dpnp.astype(values, a.dtype, casting="safe", copy=False)
18721872

1873-
_, exec_q = get_usm_allocations([a, mask, values])
1874-
1873+
exec_q = a.sycl_queue
1874+
if (
1875+
dpu.get_execution_queue(
1876+
[exec_q, mask.sycl_queue, values.sycl_queue]
1877+
)
1878+
is None
1879+
):
1880+
raise ValueError(
1881+
"`mask` and `values` must be allocated on "
1882+
"the same SYCL queue as `a`"
1883+
)
18751884
_manager = dpu.SequentialOrderManager[exec_q]
18761885
dep_evs = _manager.submitted_events
18771886

dpnp/tests/test_indexing.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1117,6 +1117,15 @@ def test_putmask_errors(self):
11171117
# values as list
11181118
assert_raises(TypeError, dpnp.putmask, a, mask, [1, 2, 3])
11191119

1120+
# values has a different SYCL queue
1121+
q1 = dpctl.SyclQueue()
1122+
q2 = dpctl.SyclQueue()
1123+
a = dpnp.arange(10, sycl_queue=q1)
1124+
mask = a > 3
1125+
val = dpnp.arange(5, sycl_queue=q2)
1126+
if q1 != q2:
1127+
assert_raises(ValueError, dpnp.putmask, a, mask, val)
1128+
11201129

11211130
@pytest.mark.parametrize("m", [None, 0, 1, 2, 3, 4])
11221131
@pytest.mark.parametrize("k", [-3, -2, -1, 0, 1, 2, 3])

0 commit comments

Comments
 (0)