Skip to content

Commit d523061

Browse files
authored
Fix device-aware dtype handling in identity, gradient functions (#2835)
This PR ensures that default dtype selection respects device-specific capabilities across multiple functions. The PR includes changes: - `dpnp.identity`: Remove redundant default dtype handling. The function now delegates dtype resolution to dpnp.eye(), which already handles device-aware default types correctly. - `dpnp.gradient`: Pass sycl_queue parameter to default_float_type() calls to ensure the selected float type is compatible with the device where the array resides. This prevents issues when converting integer arrays on devices with different dtype support. The PR also updates SYCL queue tests to fix the parametrization to generate device-dtype pairs using a new get_all_dev_dtypes() helper. Each device is now tested only with dtypes it actually supports (e.g., devices without fp64 support won't test fp64), preventing false failures and unnecessary test combinations.
1 parent e36835d commit d523061

File tree

4 files changed

+31
-14
lines changed

4 files changed

+31
-14
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ Also, that release drops support for Python 3.9, making Python 3.10 the minimum
8383
* Fixed `.data.ptr` property on array views to correctly return the pointer to the view's data location instead of the base allocation pointer [#2812](https://github.com/IntelPython/dpnp/pull/2812)
8484
* Resolved an issue with strides calculation in `dpnp.diagonal` to return correct values for empty diagonals [#2814](https://github.com/IntelPython/dpnp/pull/2814)
8585
* Fixed test tolerance issues for float16 intermediate precision that became visible when testing against conda-forge's NumPy [#2828](https://github.com/IntelPython/dpnp/pull/2828)
86+
* Ensured device aware dtype handling in `dpnp.identity` and `dpnp.gradient` [#2835](https://github.com/IntelPython/dpnp/pull/2835)
8687

8788
### Security
8889

dpnp/dpnp_iface_arraycreation.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2664,10 +2664,9 @@ def identity(
26642664

26652665
dpnp.check_limitations(like=like)
26662666

2667-
_dtype = dpnp.default_float_type() if dtype is None else dtype
26682667
return dpnp.eye(
26692668
n,
2670-
dtype=_dtype,
2669+
dtype=dtype,
26712670
device=device,
26722671
usm_type=usm_type,
26732672
sycl_queue=sycl_queue,

dpnp/dpnp_iface_mathematical.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,9 @@ def _gradient_build_dx(f, axes, *varargs):
141141
if dpnp.issubdtype(distances.dtype, dpnp.integer):
142142
# Convert integer types to default float type to avoid modular
143143
# arithmetic in dpnp.diff(distances).
144-
distances = distances.astype(dpnp.default_float_type())
144+
distances = distances.astype(
145+
dpnp.default_float_type(sycl_queue=f.sycl_queue)
146+
)
145147
diffx = dpnp.diff(distances)
146148

147149
# if distances are constant reduce to the scalar case
@@ -2707,9 +2709,9 @@ def gradient(f, *varargs, axis=None, edge_order=1):
27072709
# All other types convert to floating point.
27082710
# First check if f is a dpnp integer type; if so, convert f to default
27092711
# float type to avoid modular arithmetic when computing changes in f.
2710-
if dpnp.issubdtype(otype, dpnp.integer):
2711-
f = f.astype(dpnp.default_float_type())
2712-
otype = dpnp.default_float_type()
2712+
otype = dpnp.default_float_type(sycl_queue=f.sycl_queue)
2713+
if dpnp.issubdtype(f.dtype, dpnp.integer):
2714+
f = f.astype(otype)
27132715

27142716
for axis_, ax_dx in zip(axes, dx):
27152717
if f.shape[axis_] < edge_order + 1:

dpnp/tests/test_sycl_queue.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,23 @@ def assert_sycl_queue_equal(result, expected):
5454
assert exec_queue is not None
5555

5656

57+
def get_all_dev_dtypes(no_float16=True, no_none=True):
58+
"""
59+
Build a list of (device, dtype) combinations for each device's
60+
supported dtype.
61+
62+
"""
63+
64+
device_dtype_pairs = []
65+
for device in valid_dev:
66+
dtypes = get_all_dtypes(
67+
no_float16=no_float16, no_none=no_none, device=device
68+
)
69+
for dtype in dtypes:
70+
device_dtype_pairs.append((device, dtype))
71+
return device_dtype_pairs
72+
73+
5774
@pytest.mark.parametrize(
5875
"func, arg, kwargs",
5976
[
@@ -1082,11 +1099,10 @@ def test_array_creation_from_dpctl(copy, device):
10821099
assert isinstance(result, dpnp_array)
10831100

10841101

1085-
@pytest.mark.parametrize("device", valid_dev, ids=dev_ids)
1086-
@pytest.mark.parametrize("arr_dtype", get_all_dtypes(no_float16=True))
1102+
@pytest.mark.parametrize("device, dt", get_all_dev_dtypes())
10871103
@pytest.mark.parametrize("shape", [tuple(), (2,), (3, 0, 1), (2, 2, 2)])
1088-
def test_from_dlpack(arr_dtype, shape, device):
1089-
X = dpnp.ones(shape=shape, dtype=arr_dtype, device=device)
1104+
def test_from_dlpack(shape, device, dt):
1105+
X = dpnp.ones(shape=shape, dtype=dt, device=device)
10901106
Y = dpnp.from_dlpack(X)
10911107
assert_array_equal(X, Y)
10921108
assert X.__dlpack_device__() == Y.__dlpack_device__()
@@ -1098,10 +1114,9 @@ def test_from_dlpack(arr_dtype, shape, device):
10981114
assert V.strides == W.strides
10991115

11001116

1101-
@pytest.mark.parametrize("device", valid_dev, ids=dev_ids)
1102-
@pytest.mark.parametrize("arr_dtype", get_all_dtypes(no_float16=True))
1103-
def test_from_dlpack_with_dpt(arr_dtype, device):
1104-
X = dpt.ones((64,), dtype=arr_dtype, device=device)
1117+
@pytest.mark.parametrize("device, dt", get_all_dev_dtypes())
1118+
def test_from_dlpack_with_dpt(device, dt):
1119+
X = dpt.ones((64,), dtype=dt, device=device)
11051120
Y = dpnp.from_dlpack(X)
11061121
assert_array_equal(X, Y)
11071122
assert isinstance(Y, dpnp.dpnp_array.dpnp_array)

0 commit comments

Comments
 (0)