Skip to content

Commit 281989d

Browse files
author
Han Wang
committed
fix: reset default device before .pt2 AOTInductor compilation
tests/pt/__init__.py may set a fake default device for CPU fallback, which poisons AOTInductor compilation. Temporarily clear the default device before converting to .pt2, matching the pattern used in test_change_bias.py.
1 parent 34d4cb5 commit 281989d

1 file changed

Lines changed: 13 additions & 1 deletion

File tree

source/tests/infer/test_models.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,19 @@ def setUpClass(cls) -> None:
4848
"se_e2_r type_one_side is not supported for PyTorch models"
4949
)
5050
cls.case = get_cases()[key]
51-
cls.model_name = cls.case.get_model(extension)
51+
if extension == ".pt2":
52+
import torch
53+
54+
# Clear default device: tests/pt/__init__.py may set a fake
55+
# device for CPU fallback, which poisons AOTInductor compilation.
56+
saved_device = torch.get_default_device()
57+
torch.set_default_device(None)
58+
try:
59+
cls.model_name = cls.case.get_model(extension)
60+
finally:
61+
torch.set_default_device(saved_device)
62+
else:
63+
cls.model_name = cls.case.get_model(extension)
5264
cls.dp = DeepEval(cls.model_name)
5365

5466
@classmethod

0 commit comments

Comments
 (0)