We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 4de9a56 commit f4dc0afCopy full SHA for f4dc0af
1 file changed
deepmd/dpmodel/descriptor/se_r.py
@@ -309,9 +309,12 @@ def compute_input_stats(
309
self.stats = env_mat_stat.stats
310
mean, stddev = env_mat_stat()
311
xp = array_api_compat.array_namespace(self.dstd)
312
+ device = array_api_compat.device(self.dstd)
313
if not self.set_davg_zero:
- self.davg = xp.asarray(mean, dtype=self.davg.dtype, copy=True)
314
- self.dstd = xp.asarray(stddev, dtype=self.dstd.dtype, copy=True)
+ self.davg = xp.asarray(
315
+ mean, dtype=self.davg.dtype, copy=True, device=device
316
+ )
317
+ self.dstd = xp.asarray(stddev, dtype=self.dstd.dtype, copy=True, device=device)
318
319
def set_stat_mean_and_stddev(
320
self,
0 commit comments