Skip to content

Commit 2384835

Browse files
author
Han Wang
committed
fix device of xp array
1 parent f4dc0af commit 2384835

6 files changed

Lines changed: 38 additions & 12 deletions

File tree

deepmd/dpmodel/descriptor/dpa1.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -909,9 +909,14 @@ def compute_input_stats(
909909
self.stats = env_mat_stat.stats
910910
mean, stddev = env_mat_stat()
911911
xp = array_api_compat.array_namespace(self.stddev)
912+
device = array_api_compat.device(self.stddev)
912913
if not self.set_davg_zero:
913-
self.mean = xp.asarray(mean, dtype=self.mean.dtype, copy=True)
914-
self.stddev = xp.asarray(stddev, dtype=self.stddev.dtype, copy=True)
914+
self.mean = xp.asarray(
915+
mean, dtype=self.mean.dtype, copy=True, device=device
916+
)
917+
self.stddev = xp.asarray(
918+
stddev, dtype=self.stddev.dtype, copy=True, device=device
919+
)
915920

916921
def get_stats(self) -> dict[str, StatItem]:
917922
"""Get the statistics of the descriptor."""

deepmd/dpmodel/descriptor/repflows.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -453,9 +453,14 @@ def compute_input_stats(
453453
self.stats = env_mat_stat.stats
454454
mean, stddev = env_mat_stat()
455455
xp = array_api_compat.array_namespace(self.stddev)
456+
device = array_api_compat.device(self.stddev)
456457
if not self.set_davg_zero:
457-
self.mean = xp.asarray(mean, dtype=self.mean.dtype, copy=True)
458-
self.stddev = xp.asarray(stddev, dtype=self.stddev.dtype, copy=True)
458+
self.mean = xp.asarray(
459+
mean, dtype=self.mean.dtype, copy=True, device=device
460+
)
461+
self.stddev = xp.asarray(
462+
stddev, dtype=self.stddev.dtype, copy=True, device=device
463+
)
459464

460465
def get_stats(self) -> dict[str, StatItem]:
461466
"""Get the statistics of the descriptor."""

deepmd/dpmodel/descriptor/repformers.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -417,9 +417,14 @@ def compute_input_stats(
417417
self.stats = env_mat_stat.stats
418418
mean, stddev = env_mat_stat()
419419
xp = array_api_compat.array_namespace(self.stddev)
420+
device = array_api_compat.device(self.stddev)
420421
if not self.set_davg_zero:
421-
self.mean = xp.asarray(mean, dtype=self.mean.dtype, copy=True)
422-
self.stddev = xp.asarray(stddev, dtype=self.stddev.dtype, copy=True)
422+
self.mean = xp.asarray(
423+
mean, dtype=self.mean.dtype, copy=True, device=device
424+
)
425+
self.stddev = xp.asarray(
426+
stddev, dtype=self.stddev.dtype, copy=True, device=device
427+
)
423428

424429
def get_stats(self) -> dict[str, StatItem]:
425430
"""Get the statistics of the descriptor."""

deepmd/dpmodel/descriptor/se_e2_a.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -350,9 +350,12 @@ def compute_input_stats(
350350
self.stats = env_mat_stat.stats
351351
mean, stddev = env_mat_stat()
352352
xp = array_api_compat.array_namespace(self.dstd)
353+
device = array_api_compat.device(self.dstd)
353354
if not self.set_davg_zero:
354-
self.davg = xp.asarray(mean, dtype=self.davg.dtype, copy=True)
355-
self.dstd = xp.asarray(stddev, dtype=self.dstd.dtype, copy=True)
355+
self.davg = xp.asarray(
356+
mean, dtype=self.davg.dtype, copy=True, device=device
357+
)
358+
self.dstd = xp.asarray(stddev, dtype=self.dstd.dtype, copy=True, device=device)
356359

357360
def set_stat_mean_and_stddev(
358361
self,

deepmd/dpmodel/descriptor/se_t.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,9 +290,12 @@ def compute_input_stats(
290290
self.stats = env_mat_stat.stats
291291
mean, stddev = env_mat_stat()
292292
xp = array_api_compat.array_namespace(self.dstd)
293+
device = array_api_compat.device(self.dstd)
293294
if not self.set_davg_zero:
294-
self.davg = xp.asarray(mean, dtype=self.davg.dtype, copy=True)
295-
self.dstd = xp.asarray(stddev, dtype=self.dstd.dtype, copy=True)
295+
self.davg = xp.asarray(
296+
mean, dtype=self.davg.dtype, copy=True, device=device
297+
)
298+
self.dstd = xp.asarray(stddev, dtype=self.dstd.dtype, copy=True, device=device)
296299

297300
def set_stat_mean_and_stddev(
298301
self,

deepmd/dpmodel/descriptor/se_t_tebd.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -694,9 +694,14 @@ def compute_input_stats(
694694
self.stats = env_mat_stat.stats
695695
mean, stddev = env_mat_stat()
696696
xp = array_api_compat.array_namespace(self.stddev)
697+
device = array_api_compat.device(self.stddev)
697698
if not self.set_davg_zero:
698-
self.mean = xp.asarray(mean, dtype=self.mean.dtype, copy=True)
699-
self.stddev = xp.asarray(stddev, dtype=self.stddev.dtype, copy=True)
699+
self.mean = xp.asarray(
700+
mean, dtype=self.mean.dtype, copy=True, device=device
701+
)
702+
self.stddev = xp.asarray(
703+
stddev, dtype=self.stddev.dtype, copy=True, device=device
704+
)
700705

701706
def get_stats(self) -> dict[str, StatItem]:
702707
"""Get the statistics of the descriptor."""

0 commit comments

Comments
 (0)