Skip to content

Commit 9b14bb7

Browse files
author
Han Wang
committed
fix: use PT_EXPT_DEVICE in consistent activation tests
tests/pt/__init__.py sets torch.set_default_device("cuda:9999999"), which causes bare torch.tensor() calls to attempt CUDA init on CPU-only CI. Use the pt_expt DEVICE (same pattern as the pt tests use their own DEVICE via to_torch_tensor).
1 parent 69dc9c1 commit 9b14bb7

1 file changed

Lines changed: 7 additions & 2 deletions

File tree

source/tests/consistent/test_activation.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
if INSTALLED_PT_EXPT:
3434
import torch
3535

36+
from deepmd.pt_expt.utils.env import DEVICE as PT_EXPT_DEVICE
3637
from deepmd.pt_expt.utils.network import (
3738
_torch_activation,
3839
)
@@ -109,7 +110,9 @@ def test_pd_consistent_with_ref(self):
109110
@unittest.skipUnless(INSTALLED_PT_EXPT, "PyTorch Exportable is not installed")
110111
def test_pt_expt_consistent_with_ref(self) -> None:
111112
if INSTALLED_PT_EXPT:
112-
x = torch.tensor(self.random_input, dtype=torch.float64)
113+
x = torch.tensor(
114+
self.random_input, dtype=torch.float64, device=PT_EXPT_DEVICE
115+
)
113116
test = _torch_activation(x, self.activation).detach().numpy()
114117
np.testing.assert_allclose(self.ref, test, atol=1e-10)
115118

@@ -149,6 +152,8 @@ def test_pt_consistent_with_ref(self) -> None:
149152
@unittest.skipUnless(INSTALLED_PT_EXPT, "PyTorch Exportable is not installed")
150153
def test_pt_expt_consistent_with_ref(self) -> None:
151154
if INSTALLED_PT_EXPT:
152-
x = torch.tensor(self.random_input, dtype=torch.float64)
155+
x = torch.tensor(
156+
self.random_input, dtype=torch.float64, device=PT_EXPT_DEVICE
157+
)
153158
test = _torch_activation(x, self.activation).detach().numpy()
154159
np.testing.assert_allclose(self.ref, test, atol=1e-10)

0 commit comments

Comments
 (0)