diff --git a/source/tests/pt_expt/conftest.py b/source/tests/pt_expt/conftest.py index f2a9b07a6a..3b48d2e51d 100644 --- a/source/tests/pt_expt/conftest.py +++ b/source/tests/pt_expt/conftest.py @@ -12,11 +12,21 @@ """ import pytest +import torch._inductor.config as _inductor_config import torch.utils._device as _device from torch.overrides import ( _get_current_function_mode_stack, ) +# Reduce AOTInductor (.pt2) compile time for unit tests. +# Tests only validate correctness, not runtime performance, so we can +# skip expensive C++ optimizations. This cuts compile time by ~50%. +_inductor_config.max_fusion_size = 8 +_inductor_config.epilogue_fusion = False +_inductor_config.pattern_matcher = False +_inductor_config.aot_inductor.package_cpp_only = True +_inductor_config.aot_inductor.compile_wrapper_opt_level = "O0" + def _pop_device_contexts() -> list: """Pop all stale DeviceContext modes from the torch function mode stack.""" diff --git a/source/tests/pt_expt/infer/test_deep_eval.py b/source/tests/pt_expt/infer/test_deep_eval.py index 6797fa2c03..0bc9a90a79 100644 --- a/source/tests/pt_expt/infer/test_deep_eval.py +++ b/source/tests/pt_expt/infer/test_deep_eval.py @@ -544,12 +544,13 @@ def setUpClass(cls) -> None: cls.tmpfile = tempfile.NamedTemporaryFile(suffix=".pt2", delete=False) cls.tmpfile.close() # Temporarily clear default device to avoid poisoning AOTInductor - # compilation (tests/pt/__init__.py sets it to "cuda:9999999"). + # compilation (tests/pt/__init__.py may set a fake CUDA device). + prev = torch.get_default_device() torch.set_default_device(None) try: deserialize_to_file(cls.tmpfile.name, cls.model_data) finally: - torch.set_default_device("cuda:9999999") + torch.set_default_device(prev) # Also save to .pte for cross-format comparison cls.pte_tmpfile = tempfile.NamedTemporaryFile(suffix=".pte", delete=False) @@ -606,6 +607,7 @@ def test_get_model_def_script_with_params(self) -> None: with tempfile.NamedTemporaryFile(suffix=".pt2", delete=False) as f: tmpfile2 = f.name try: + prev = torch.get_default_device() torch.set_default_device(None) try: data_with_config = { @@ -614,7 +616,7 @@ def test_get_model_def_script_with_params(self) -> None: } deserialize_to_file(tmpfile2, data_with_config) finally: - torch.set_default_device("cuda:9999999") + torch.set_default_device(prev) dp2 = DeepPot(tmpfile2) mds = dp2.deep_eval.get_model_def_script() self.assertEqual(mds, training_config) @@ -970,11 +972,12 @@ def setUpClass(cls) -> None: cls.model_data = {"model": cls.model.serialize()} cls.tmpfile = tempfile.NamedTemporaryFile(suffix=".pt2", delete=False) cls.tmpfile.close() + prev = torch.get_default_device() torch.set_default_device(None) try: deserialize_to_file(cls.tmpfile.name, cls.model_data) finally: - torch.set_default_device("cuda:9999999") + torch.set_default_device(prev) # Also save .pte for cross-format comparison cls.pte_tmpfile = tempfile.NamedTemporaryFile(suffix=".pte", delete=False) @@ -1185,11 +1188,12 @@ def setUpClass(cls) -> None: cls.model_data = {"model": cls.model.serialize()} cls.tmpfile = tempfile.NamedTemporaryFile(suffix=".pt2", delete=False) cls.tmpfile.close() + prev = torch.get_default_device() torch.set_default_device(None) try: deserialize_to_file(cls.tmpfile.name, cls.model_data) finally: - torch.set_default_device("cuda:9999999") + torch.set_default_device(prev) # Also save .pte for cross-format comparison cls.pte_tmpfile = tempfile.NamedTemporaryFile(suffix=".pte", delete=False)