Skip to content

Commit 3e47f4a

Browse files
authored
Merge branch 'IntelPython:master' into feature-sparse-linalg-solvers
2 parents ac3bed5 + 2a78c06 commit 3e47f4a

File tree

7 files changed

+35
-18
lines changed

7 files changed

+35
-18
lines changed

.github/workflows/build-sphinx.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ jobs:
224224
if: env.GH_EVENT_OPEN_PR_UPSTREAM == 'true'
225225
env:
226226
PR_NUM: ${{ github.event.number }}
227-
uses: mshick/add-pr-comment@ffd016c7e151d97d69d21a843022fd4cd5b96fe5 # v3.9.0.8.3.9.0
227+
uses: mshick/add-pr-comment@64b8e914979889d746c99dea15a76e77ef64580a # v3.10.0.8.3.10.0
228228
with:
229229
message-id: url_to_docs
230230
message: |
@@ -268,7 +268,7 @@ jobs:
268268
git push tokened_docs gh-pages
269269
270270
- name: Modify the comment with URL to official documentation
271-
uses: mshick/add-pr-comment@ffd016c7e151d97d69d21a843022fd4cd5b96fe5 # v3.9.0.8.3.9.0
271+
uses: mshick/add-pr-comment@64b8e914979889d746c99dea15a76e77ef64580a # v3.10.0.8.3.10.0
272272
with:
273273
message-id: url_to_docs
274274
find: |

.github/workflows/conda-package.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -654,7 +654,7 @@ jobs:
654654

655655
- name: Post result to PR
656656
if: ${{ github.event.pull_request && !github.event.pull_request.head.repo.fork }}
657-
uses: mshick/add-pr-comment@ffd016c7e151d97d69d21a843022fd4cd5b96fe5 # v3.9.0.8.3.9.0
657+
uses: mshick/add-pr-comment@64b8e914979889d746c99dea15a76e77ef64580a # v3.10.0.8.3.10.0
658658
with:
659659
message-id: array_api_results
660660
message: |

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ repos:
124124
- id: pretty-format-toml
125125
args: [--autofix]
126126
- repo: https://github.com/rhysd/actionlint
127-
rev: v1.7.11
127+
rev: v1.7.12
128128
hooks:
129129
- id: actionlint
130130
- repo: https://github.com/BlankSpruce/gersemi

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ Also, that release drops support for Python 3.9, making Python 3.10 the minimum
8383
* Fixed `.data.ptr` property on array views to correctly return the pointer to the view's data location instead of the base allocation pointer [#2812](https://github.com/IntelPython/dpnp/pull/2812)
8484
* Resolved an issue with strides calculation in `dpnp.diagonal` to return correct values for empty diagonals [#2814](https://github.com/IntelPython/dpnp/pull/2814)
8585
* Fixed test tolerance issues for float16 intermediate precision that became visible when testing against conda-forge's NumPy [#2828](https://github.com/IntelPython/dpnp/pull/2828)
86+
* Ensured device aware dtype handling in `dpnp.identity` and `dpnp.gradient` [#2835](https://github.com/IntelPython/dpnp/pull/2835)
8687

8788
### Security
8889

dpnp/dpnp_iface_arraycreation.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2664,10 +2664,9 @@ def identity(
26642664

26652665
dpnp.check_limitations(like=like)
26662666

2667-
_dtype = dpnp.default_float_type() if dtype is None else dtype
26682667
return dpnp.eye(
26692668
n,
2670-
dtype=_dtype,
2669+
dtype=dtype,
26712670
device=device,
26722671
usm_type=usm_type,
26732672
sycl_queue=sycl_queue,

dpnp/dpnp_iface_mathematical.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,9 @@ def _gradient_build_dx(f, axes, *varargs):
141141
if dpnp.issubdtype(distances.dtype, dpnp.integer):
142142
# Convert integer types to default float type to avoid modular
143143
# arithmetic in dpnp.diff(distances).
144-
distances = distances.astype(dpnp.default_float_type())
144+
distances = distances.astype(
145+
dpnp.default_float_type(sycl_queue=f.sycl_queue)
146+
)
145147
diffx = dpnp.diff(distances)
146148

147149
# if distances are constant reduce to the scalar case
@@ -2707,9 +2709,9 @@ def gradient(f, *varargs, axis=None, edge_order=1):
27072709
# All other types convert to floating point.
27082710
# First check if f is a dpnp integer type; if so, convert f to default
27092711
# float type to avoid modular arithmetic when computing changes in f.
2710-
if dpnp.issubdtype(otype, dpnp.integer):
2711-
f = f.astype(dpnp.default_float_type())
2712-
otype = dpnp.default_float_type()
2712+
otype = dpnp.default_float_type(sycl_queue=f.sycl_queue)
2713+
if dpnp.issubdtype(f.dtype, dpnp.integer):
2714+
f = f.astype(otype)
27132715

27142716
for axis_, ax_dx in zip(axes, dx):
27152717
if f.shape[axis_] < edge_order + 1:

dpnp/tests/test_sycl_queue.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,23 @@ def assert_sycl_queue_equal(result, expected):
5454
assert exec_queue is not None
5555

5656

57+
def get_all_dev_dtypes(no_float16=True, no_none=True):
58+
"""
59+
Build a list of (device, dtype) combinations for each device's
60+
supported dtype.
61+
62+
"""
63+
64+
device_dtype_pairs = []
65+
for device in valid_dev:
66+
dtypes = get_all_dtypes(
67+
no_float16=no_float16, no_none=no_none, device=device
68+
)
69+
for dtype in dtypes:
70+
device_dtype_pairs.append((device, dtype))
71+
return device_dtype_pairs
72+
73+
5774
@pytest.mark.parametrize(
5875
"func, arg, kwargs",
5976
[
@@ -1082,11 +1099,10 @@ def test_array_creation_from_dpctl(copy, device):
10821099
assert isinstance(result, dpnp_array)
10831100

10841101

1085-
@pytest.mark.parametrize("device", valid_dev, ids=dev_ids)
1086-
@pytest.mark.parametrize("arr_dtype", get_all_dtypes(no_float16=True))
1102+
@pytest.mark.parametrize("device, dt", get_all_dev_dtypes())
10871103
@pytest.mark.parametrize("shape", [tuple(), (2,), (3, 0, 1), (2, 2, 2)])
1088-
def test_from_dlpack(arr_dtype, shape, device):
1089-
X = dpnp.ones(shape=shape, dtype=arr_dtype, device=device)
1104+
def test_from_dlpack(shape, device, dt):
1105+
X = dpnp.ones(shape=shape, dtype=dt, device=device)
10901106
Y = dpnp.from_dlpack(X)
10911107
assert_array_equal(X, Y)
10921108
assert X.__dlpack_device__() == Y.__dlpack_device__()
@@ -1098,10 +1114,9 @@ def test_from_dlpack(arr_dtype, shape, device):
10981114
assert V.strides == W.strides
10991115

11001116

1101-
@pytest.mark.parametrize("device", valid_dev, ids=dev_ids)
1102-
@pytest.mark.parametrize("arr_dtype", get_all_dtypes(no_float16=True))
1103-
def test_from_dlpack_with_dpt(arr_dtype, device):
1104-
X = dpt.ones((64,), dtype=arr_dtype, device=device)
1117+
@pytest.mark.parametrize("device, dt", get_all_dev_dtypes())
1118+
def test_from_dlpack_with_dpt(device, dt):
1119+
X = dpt.ones((64,), dtype=dt, device=device)
11051120
Y = dpnp.from_dlpack(X)
11061121
assert_array_equal(X, Y)
11071122
assert isinstance(Y, dpnp.dpnp_array.dpnp_array)

0 commit comments

Comments
 (0)