Skip to content

Commit 8d1c75b

Browse files
Extend dpctl_ext.tensor with the remaining functions (#2806)
This PR extends `dpctl_ext.tensor` API with the remaining statistical and testing functions adding `std(), var(), mean(), allclose()`
1 parent aa816fd commit 8d1c75b

File tree

11 files changed

+586
-22
lines changed

11 files changed

+586
-22
lines changed

dpctl_ext/tensor/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,8 @@
179179
unique_values,
180180
)
181181
from ._sorting import argsort, sort, top_k
182+
from ._statistical_functions import mean, std, var
183+
from ._testing import allclose
182184
from ._type_utils import can_cast, finfo, iinfo, isdtype, result_type
183185
from ._utility_functions import all, any, diff
184186

@@ -188,6 +190,7 @@
188190
"acosh",
189191
"add",
190192
"all",
193+
"allclose",
191194
"angle",
192195
"any",
193196
"arange",
@@ -267,6 +270,7 @@
267270
"log10",
268271
"max",
269272
"maximum",
273+
"mean",
270274
"meshgrid",
271275
"min",
272276
"minimum",
@@ -308,6 +312,7 @@
308312
"square",
309313
"squeeze",
310314
"stack",
315+
"std",
311316
"subtract",
312317
"sum",
313318
"swapaxes",
@@ -327,6 +332,7 @@
327332
"unique_inverse",
328333
"unique_values",
329334
"unstack",
335+
"var",
330336
"vecdot",
331337
"where",
332338
"zeros",

dpctl_ext/tensor/_clip.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,12 @@
2828

2929
import dpctl
3030
import dpctl.tensor as dpt
31-
import dpctl.tensor._tensor_elementwise_impl as tei
3231
from dpctl.utils import ExecutionPlacementError, SequentialOrderManager
3332

3433
# TODO: revert to `import dpctl.tensor...`
3534
# when dpnp fully migrates dpctl/tensor
3635
import dpctl_ext.tensor as dpt_ext
36+
import dpctl_ext.tensor._tensor_elementwise_impl as tei
3737
import dpctl_ext.tensor._tensor_impl as ti
3838

3939
from ._copy_utils import (

dpctl_ext/tensor/_ctors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,7 @@ def _copy_through_host_walker(seq_o, usm_res):
361361
)
362362
is None
363363
):
364-
usm_res[...] = dpt.asnumpy(seq_o).copy()
364+
usm_res[...] = dpt_ext.asnumpy(seq_o).copy()
365365
return
366366
else:
367367
usm_res[...] = seq_o

dpctl_ext/tensor/_reduction.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -506,7 +506,7 @@ def count_nonzero(x, /, *, axis=None, keepdims=False, out=None):
506506
type.
507507
"""
508508
if x.dtype != dpt.bool:
509-
x = dpt.astype(x, dpt.bool, copy=False)
509+
x = dpt_ext.astype(x, dpt.bool, copy=False)
510510
return sum(
511511
x,
512512
axis=axis,

dpctl_ext/tensor/_set_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,11 @@
3030

3131
import dpctl.tensor as dpt
3232
import dpctl.utils as du
33-
from dpctl.tensor._tensor_elementwise_impl import _not_equal, _subtract
3433

3534
# TODO: revert to `import dpctl.tensor...`
3635
# when dpnp fully migrates dpctl/tensor
3736
import dpctl_ext.tensor as dpt_ext
37+
from dpctl_ext.tensor._tensor_elementwise_impl import _not_equal, _subtract
3838

3939
from ._copy_utils import _empty_like_orderK
4040
from ._scalar_utils import (

0 commit comments

Comments
 (0)