Skip to content

Commit f4dc0af

Browse files
author
Han Wang
committed
fix device of xp array
1 parent 4de9a56 commit f4dc0af

1 file changed

Lines changed: 5 additions & 2 deletions

File tree

deepmd/dpmodel/descriptor/se_r.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -309,9 +309,12 @@ def compute_input_stats(
309309
self.stats = env_mat_stat.stats
310310
mean, stddev = env_mat_stat()
311311
xp = array_api_compat.array_namespace(self.dstd)
312+
device = array_api_compat.device(self.dstd)
312313
if not self.set_davg_zero:
313-
self.davg = xp.asarray(mean, dtype=self.davg.dtype, copy=True)
314-
self.dstd = xp.asarray(stddev, dtype=self.dstd.dtype, copy=True)
314+
self.davg = xp.asarray(
315+
mean, dtype=self.davg.dtype, copy=True, device=device
316+
)
317+
self.dstd = xp.asarray(stddev, dtype=self.dstd.dtype, copy=True, device=device)
315318

316319
def set_stat_mean_and_stddev(
317320
self,

0 commit comments

Comments
 (0)