10 changes: 10 additions & 0 deletions source/tests/pt_expt/conftest.py
@@ -12,11 +12,21 @@
"""

import pytest
import torch._inductor.config as _inductor_config
import torch.utils._device as _device
from torch.overrides import (
_get_current_function_mode_stack,
)

# Reduce AOTInductor (.pt2) compile time for unit tests.
# Tests only validate correctness, not runtime performance, so we can
# skip expensive C++ optimizations. This cuts compile time by ~50%.
_inductor_config.max_fusion_size = 8
_inductor_config.epilogue_fusion = False
_inductor_config.pattern_matcher = False
_inductor_config.aot_inductor.package_cpp_only = True
_inductor_config.aot_inductor.compile_wrapper_opt_level = "O0"


def _pop_device_contexts() -> list:
"""Pop all stale DeviceContext modes from the torch function mode stack."""
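The overrides above are applied module-wide when conftest.py is imported, which is appropriate here since every test in the directory benefits. For reference, a minimal sketch of how the same settings could instead be scoped to a single test via a context manager; the attribute names are the ones set above, while the helper name fast_aoti_compile is hypothetical and not part of this PR.

import contextlib

import torch._inductor.config as _inductor_config


@contextlib.contextmanager
def fast_aoti_compile():
    """Temporarily relax Inductor optimizations, restoring the originals on exit."""
    saved = {
        "max_fusion_size": _inductor_config.max_fusion_size,
        "epilogue_fusion": _inductor_config.epilogue_fusion,
        "pattern_matcher": _inductor_config.pattern_matcher,
    }
    try:
        # Same trade-off as the conftest: smaller fusions and fewer pattern
        # rewrites mean faster compiles at the cost of runtime performance.
        _inductor_config.max_fusion_size = 8
        _inductor_config.epilogue_fusion = False
        _inductor_config.pattern_matcher = False
        yield
    finally:
        for name, value in saved.items():
            setattr(_inductor_config, name, value)

The aot_inductor sub-namespace settings could be snapshotted and restored the same way if a test needed the defaults back.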
14 changes: 9 additions & 5 deletions source/tests/pt_expt/infer/test_deep_eval.py
@@ -544,12 +544,13 @@ def setUpClass(cls) -> None:
cls.tmpfile = tempfile.NamedTemporaryFile(suffix=".pt2", delete=False)
cls.tmpfile.close()
# Temporarily clear default device to avoid poisoning AOTInductor
# compilation (tests/pt/__init__.py sets it to "cuda:9999999").
# compilation (tests/pt/__init__.py may set a fake CUDA device).
prev = torch.get_default_device()
torch.set_default_device(None)
try:
deserialize_to_file(cls.tmpfile.name, cls.model_data)
finally:
torch.set_default_device("cuda:9999999")
torch.set_default_device(prev)

# Also save to .pte for cross-format comparison
cls.pte_tmpfile = tempfile.NamedTemporaryFile(suffix=".pte", delete=False)
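This save/clear/restore dance around deserialize_to_file is repeated in several setUpClass methods below. A minimal sketch of how it could be factored into a shared helper, assuming a PyTorch version that provides torch.get_default_device (2.3+); the helper name cleared_default_device is hypothetical and not part of this PR.

import contextlib

import torch


@contextlib.contextmanager
def cleared_default_device():
    """Save the current default device, clear it for the block, then restore it."""
    prev = torch.get_default_device()
    torch.set_default_device(None)  # avoid poisoning AOTInductor compilation
    try:
        yield
    finally:
        torch.set_default_device(prev)

With such a helper, each setUpClass body would reduce to "with cleared_default_device(): deserialize_to_file(...)".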
@@ -606,6 +607,7 @@ def test_get_model_def_script_with_params(self) -> None:
with tempfile.NamedTemporaryFile(suffix=".pt2", delete=False) as f:
tmpfile2 = f.name
try:
prev = torch.get_default_device()
torch.set_default_device(None)
try:
data_with_config = {
@@ -614,7 +616,7 @@ def test_get_model_def_script_with_params(self) -> None:
}
deserialize_to_file(tmpfile2, data_with_config)
finally:
torch.set_default_device("cuda:9999999")
torch.set_default_device(prev)
dp2 = DeepPot(tmpfile2)
mds = dp2.deep_eval.get_model_def_script()
self.assertEqual(mds, training_config)
@@ -970,11 +972,12 @@ def setUpClass(cls) -> None:
cls.model_data = {"model": cls.model.serialize()}
cls.tmpfile = tempfile.NamedTemporaryFile(suffix=".pt2", delete=False)
cls.tmpfile.close()
prev = torch.get_default_device()
torch.set_default_device(None)
try:
deserialize_to_file(cls.tmpfile.name, cls.model_data)
finally:
torch.set_default_device("cuda:9999999")
torch.set_default_device(prev)

# Also save .pte for cross-format comparison
cls.pte_tmpfile = tempfile.NamedTemporaryFile(suffix=".pte", delete=False)
@@ -1185,11 +1188,12 @@ def setUpClass(cls) -> None:
cls.model_data = {"model": cls.model.serialize()}
cls.tmpfile = tempfile.NamedTemporaryFile(suffix=".pt2", delete=False)
cls.tmpfile.close()
prev = torch.get_default_device()
torch.set_default_device(None)
try:
deserialize_to_file(cls.tmpfile.name, cls.model_data)
finally:
torch.set_default_device("cuda:9999999")
torch.set_default_device(prev)

# Also save .pte for cross-format comparison
cls.pte_tmpfile = tempfile.NamedTemporaryFile(suffix=".pte", delete=False)