Skip to content

Commit d69632e

Browse files
author
Han Wang
committed
fix(pt_expt): centralize default-device guard in AOTInductor compilation
AOTInductor's lowering code creates tensors without explicit device=, inheriting any active torch.set_default_device. This caused compilation failures when tests/pt/__init__.py set a fake CUDA device. Move the set_default_device(None) guard into _deserialize_to_file_pt2 so all callers (tests, dp freeze, dp compress) are protected, and remove the 12 scattered workarounds from test files.
1 parent 345d162 commit d69632e

4 files changed

Lines changed: 25 additions & 56 deletions

File tree

deepmd/pt_expt/utils/serialization.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -554,8 +554,19 @@ def _deserialize_to_file_pt2(
554554
data, model_json_override
555555
)
556556

557-
# Compile via AOTInductor into a .pt2 package
558-
aoti_compile_and_package(exported, package_path=model_file)
557+
# AOTInductor's lowering code internally creates tensors (e.g.
558+
# ``torch.zeros``) without an explicit ``device=`` argument. If a
559+
# non-CPU default device is active (e.g. tests/pt/__init__.py sets
560+
# ``torch.set_default_device("cuda:9999999")``), the compilation fails
561+
# on CPU-only builds. Temporarily clear the default device so the
562+
# inductor always targets CPU.
563+
prev_device = torch.get_default_device()
564+
torch.set_default_device(None)
565+
try:
566+
# Compile via AOTInductor into a .pt2 package
567+
aoti_compile_and_package(exported, package_path=model_file)
568+
finally:
569+
torch.set_default_device(prev_device)
559570

560571
# Embed metadata into the .pt2 ZIP archive
561572
model_def_script = data.get("model_def_script") or {}

source/tests/pt_expt/infer/test_deep_eval.py

Lines changed: 8 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -543,13 +543,7 @@ def setUpClass(cls) -> None:
543543
cls.model_data = {"model": cls.model.serialize()}
544544
cls.tmpfile = tempfile.NamedTemporaryFile(suffix=".pt2", delete=False)
545545
cls.tmpfile.close()
546-
# Temporarily clear default device to avoid poisoning AOTInductor
547-
# compilation (tests/pt/__init__.py sets it to "cuda:9999999").
548-
torch.set_default_device(None)
549-
try:
550-
deserialize_to_file(cls.tmpfile.name, cls.model_data)
551-
finally:
552-
torch.set_default_device("cuda:9999999")
546+
deserialize_to_file(cls.tmpfile.name, cls.model_data)
553547

554548
# Also save to .pte for cross-format comparison
555549
cls.pte_tmpfile = tempfile.NamedTemporaryFile(suffix=".pte", delete=False)
@@ -606,15 +600,11 @@ def test_get_model_def_script_with_params(self) -> None:
606600
with tempfile.NamedTemporaryFile(suffix=".pt2", delete=False) as f:
607601
tmpfile2 = f.name
608602
try:
609-
torch.set_default_device(None)
610-
try:
611-
data_with_config = {
612-
**self.model_data,
613-
"model_def_script": training_config,
614-
}
615-
deserialize_to_file(tmpfile2, data_with_config)
616-
finally:
617-
torch.set_default_device("cuda:9999999")
603+
data_with_config = {
604+
**self.model_data,
605+
"model_def_script": training_config,
606+
}
607+
deserialize_to_file(tmpfile2, data_with_config)
618608
dp2 = DeepPot(tmpfile2)
619609
mds = dp2.deep_eval.get_model_def_script()
620610
self.assertEqual(mds, training_config)
@@ -970,11 +960,7 @@ def setUpClass(cls) -> None:
970960
cls.model_data = {"model": cls.model.serialize()}
971961
cls.tmpfile = tempfile.NamedTemporaryFile(suffix=".pt2", delete=False)
972962
cls.tmpfile.close()
973-
torch.set_default_device(None)
974-
try:
975-
deserialize_to_file(cls.tmpfile.name, cls.model_data)
976-
finally:
977-
torch.set_default_device("cuda:9999999")
963+
deserialize_to_file(cls.tmpfile.name, cls.model_data)
978964

979965
# Also save .pte for cross-format comparison
980966
cls.pte_tmpfile = tempfile.NamedTemporaryFile(suffix=".pte", delete=False)
@@ -1185,11 +1171,7 @@ def setUpClass(cls) -> None:
11851171
cls.model_data = {"model": cls.model.serialize()}
11861172
cls.tmpfile = tempfile.NamedTemporaryFile(suffix=".pt2", delete=False)
11871173
cls.tmpfile.close()
1188-
torch.set_default_device(None)
1189-
try:
1190-
deserialize_to_file(cls.tmpfile.name, cls.model_data)
1191-
finally:
1192-
torch.set_default_device("cuda:9999999")
1174+
deserialize_to_file(cls.tmpfile.name, cls.model_data)
11931175

11941176
# Also save .pte for cross-format comparison
11951177
cls.pte_tmpfile = tempfile.NamedTemporaryFile(suffix=".pte", delete=False)

source/tests/pt_expt/infer/test_deep_eval_spin.py

Lines changed: 3 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -154,14 +154,7 @@ def spin_model_files():
154154
tmpdir = tempfile.mkdtemp()
155155
for ext in (".pt2", ".pte"):
156156
path = os.path.join(tmpdir, f"spin_test{ext}")
157-
# AOTInductor (.pt2) internally creates tensors using the PyTorch
158-
# default device. Clear it so compilation stays on CPU.
159-
prev = torch.get_default_device()
160-
torch.set_default_device(None)
161-
try:
162-
deserialize_to_file(path, copy.deepcopy(data))
163-
finally:
164-
torch.set_default_device(prev)
157+
deserialize_to_file(path, copy.deepcopy(data))
165158
files[ext] = path
166159
yield files, ref_pbc, ref_nopbc
167160
for path in files.values():
@@ -362,12 +355,7 @@ def spin_fparam_model_files():
362355
tmpdir = tempfile.mkdtemp()
363356
for ext in (".pt2", ".pte"):
364357
path = os.path.join(tmpdir, f"spin_fparam_test{ext}")
365-
prev = torch.get_default_device()
366-
torch.set_default_device(None)
367-
try:
368-
deserialize_to_file(path, copy.deepcopy(data))
369-
finally:
370-
torch.set_default_device(prev)
358+
deserialize_to_file(path, copy.deepcopy(data))
371359
files[ext] = path
372360
yield files
373361
for path in files.values():
@@ -426,12 +414,7 @@ def spin_aparam_model_files():
426414
tmpdir = tempfile.mkdtemp()
427415
for ext in (".pt2", ".pte"):
428416
path = os.path.join(tmpdir, f"spin_aparam_test{ext}")
429-
prev = torch.get_default_device()
430-
torch.set_default_device(None)
431-
try:
432-
deserialize_to_file(path, copy.deepcopy(data))
433-
finally:
434-
torch.set_default_device(prev)
417+
deserialize_to_file(path, copy.deepcopy(data))
435418
files[ext] = path
436419
yield files
437420
for path in files.values():

source/tests/pt_expt/test_change_bias.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -156,14 +156,7 @@ def setUpClass(cls) -> None:
156156
cls.shared_pte = os.path.join(cls.tmpdir, "shared.pte")
157157
freeze(model=cls.model_path, output=cls.shared_pte)
158158
cls.shared_pt2 = os.path.join(cls.tmpdir, "shared.pt2")
159-
# Clear default device: tests/pt/__init__.py may set a fake device
160-
# for CPU fallback, which poisons AOTInductor compilation.
161-
saved_device = torch.get_default_device()
162-
torch.set_default_device(None)
163-
try:
164-
freeze(model=cls.model_path, output=cls.shared_pt2)
165-
finally:
166-
torch.set_default_device(saved_device)
159+
freeze(model=cls.model_path, output=cls.shared_pt2)
167160

168161
@classmethod
169162
def tearDownClass(cls) -> None:

0 commit comments

Comments
 (0)