Skip to content

Commit 48703dc

Browse files
author
Han Wang
committed
fix(pt_expt): clear default device during spin .pt2 export in tests
AOTInductor internally creates tensors using PyTorch's default device. When pt/__init__.py sets torch.set_default_device("cuda:9999999") as a test sentinel, spin .pt2 compilation fails on CPU-only CI with "Torch not compiled with CUDA enabled". Clear the default device around deserialize_to_file calls in spin test fixtures, matching the pattern already used in non-spin .pt2 tests.
1 parent 590244a commit 48703dc

1 file changed

Lines changed: 20 additions & 3 deletions

File tree

source/tests/pt_expt/infer/test_deep_eval_spin.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,14 @@ def spin_model_files():
154154
tmpdir = tempfile.mkdtemp()
155155
for ext in (".pt2", ".pte"):
156156
path = os.path.join(tmpdir, f"spin_test{ext}")
157-
deserialize_to_file(path, copy.deepcopy(data))
157+
# AOTInductor (.pt2) internally creates tensors using the PyTorch
158+
# default device. Clear it so compilation stays on CPU.
159+
prev = torch.get_default_device()
160+
torch.set_default_device(None)
161+
try:
162+
deserialize_to_file(path, copy.deepcopy(data))
163+
finally:
164+
torch.set_default_device(prev)
158165
files[ext] = path
159166
yield files, ref_pbc, ref_nopbc
160167
for path in files.values():
@@ -355,7 +362,12 @@ def spin_fparam_model_files():
355362
tmpdir = tempfile.mkdtemp()
356363
for ext in (".pt2", ".pte"):
357364
path = os.path.join(tmpdir, f"spin_fparam_test{ext}")
358-
deserialize_to_file(path, copy.deepcopy(data))
365+
prev = torch.get_default_device()
366+
torch.set_default_device(None)
367+
try:
368+
deserialize_to_file(path, copy.deepcopy(data))
369+
finally:
370+
torch.set_default_device(prev)
359371
files[ext] = path
360372
yield files
361373
for path in files.values():
@@ -414,7 +426,12 @@ def spin_aparam_model_files():
414426
tmpdir = tempfile.mkdtemp()
415427
for ext in (".pt2", ".pte"):
416428
path = os.path.join(tmpdir, f"spin_aparam_test{ext}")
417-
deserialize_to_file(path, copy.deepcopy(data))
429+
prev = torch.get_default_device()
430+
torch.set_default_device(None)
431+
try:
432+
deserialize_to_file(path, copy.deepcopy(data))
433+
finally:
434+
torch.set_default_device(prev)
418435
files[ext] = path
419436
yield files
420437
for path in files.values():

0 commit comments

Comments
 (0)