Skip to content

Commit 96a7b3c

Browse files
committed
test(jax): isolate training test in subprocess
Run the JAX training end-to-end test in a child Python process so CUDA/XLA teardown failures cannot poison the parent pytest process. Shrink the model and use the single-frame water data to minimize memory use.\n\nAuthored by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.5)
1 parent a0d66c0 commit 96a7b3c

1 file changed

Lines changed: 37 additions & 71 deletions

File tree

source/tests/jax/test_training.py

Lines changed: 37 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -2,27 +2,17 @@
22
"""End-to-end tests for the local JAX training entrypoint."""
33

44
import argparse
5-
import functools
65
import json
76
import os
87
import shutil
9-
import signal
8+
import subprocess
9+
import sys
1010
import tempfile
11+
import textwrap
1112
import unittest
12-
from collections.abc import (
13-
Callable,
14-
)
15-
from copy import (
16-
deepcopy,
17-
)
1813
from pathlib import (
1914
Path,
2015
)
21-
from typing import (
22-
Any,
23-
TypeVar,
24-
cast,
25-
)
2616
from unittest.mock import (
2717
patch,
2818
)
@@ -33,65 +23,49 @@
3323
from deepmd.jax.entrypoints.main import (
3424
main,
3525
)
36-
from deepmd.jax.entrypoints.train import (
37-
train,
38-
)
3926
from deepmd.utils.compat import (
4027
convert_optimizer_v31_to_v32,
4128
)
4229

43-
_F = TypeVar("_F", bound=Callable[..., Any])
44-
45-
46-
def _training_timeout(seconds: int) -> Callable[[_F], _F]:
47-
"""Limit real training tests on platforms that support SIGALRM."""
48-
49-
def decorate(func: _F) -> _F:
50-
if not hasattr(signal, "SIGALRM"):
51-
return func
52-
53-
@functools.wraps(func)
54-
def wrapped(*args: Any, **kwargs: Any) -> Any:
55-
def raise_timeout(signum: int, frame: Any) -> None:
56-
raise TimeoutError(f"training test exceeded {seconds} seconds")
57-
58-
previous_handler = signal.signal(signal.SIGALRM, raise_timeout)
59-
signal.alarm(seconds)
60-
try:
61-
return func(*args, **kwargs)
62-
finally:
63-
signal.alarm(0)
64-
signal.signal(signal.SIGALRM, previous_handler)
65-
66-
return cast("_F", wrapped)
67-
68-
return decorate
69-
70-
71-
TRAINING_TEST_TIMEOUT = _training_timeout(60)
72-
7330
MODEL_SE_E2_A = {
7431
"type_map": ["O", "H", "B"],
7532
"descriptor": {
7633
"type": "se_e2_a",
77-
"sel": [46, 92, 4],
34+
"sel": [6, 12, 1],
7835
"rcut_smth": 0.50,
7936
"rcut": 4.00,
80-
"neuron": [25, 50, 100],
37+
"neuron": [2, 4, 8],
8138
"resnet_dt": False,
82-
"axis_neuron": 16,
39+
"axis_neuron": 2,
8340
"type_one_side": True,
8441
"seed": 1,
8542
},
8643
"fitting_net": {
87-
"neuron": [24, 24, 24],
44+
"neuron": [4, 4, 4],
8845
"resnet_dt": True,
8946
"seed": 1,
9047
},
91-
"data_stat_nbatch": 20,
48+
"data_stat_nbatch": 1,
9249
}
9350

9451

52+
TRAINING_SCRIPT = """
53+
from pathlib import Path
54+
from unittest.mock import patch
55+
56+
from deepmd.main import main
57+
58+
with patch("deepmd.jax.entrypoints.train.SummaryPrinter.__call__"):
59+
main(["--jax", "train", "input.json", "--log-level", "2"])
60+
61+
for path in ["out.json", "lcurve.out", "checkpoint", "model-1.jax"]:
62+
if not Path(path).exists():
63+
raise FileNotFoundError(path)
64+
if "1" not in Path("lcurve.out").read_text():
65+
raise AssertionError("lcurve.out does not contain the first training step")
66+
"""
67+
68+
9569
class TestJAXTraining(unittest.TestCase):
9670
"""Regression tests for complete JAX training runs."""
9771

@@ -103,12 +77,12 @@ def setUp(self) -> None:
10377

10478
source_dir = Path(__file__).resolve().parents[1] / "pt" / "water"
10579
shutil.copytree(source_dir, self.work_dir / "water")
106-
data_file = [str(self.work_dir / "water" / "data" / "data_0")]
80+
data_file = [str(self.work_dir / "water" / "data" / "single")]
10781

10882
with (self.work_dir / "water" / "se_atten.json").open() as f:
10983
self.config = json.load(f)
11084
self.config = convert_optimizer_v31_to_v32(self.config, warning=False)
111-
self.config["model"] = deepcopy(MODEL_SE_E2_A)
85+
self.config["model"] = MODEL_SE_E2_A
11286
self.config["model"]["data_stat_nbatch"] = 1
11387
self.config["training"]["training_data"]["systems"] = data_file
11488
self.config["training"]["validation_data"]["systems"] = data_file
@@ -126,26 +100,18 @@ def tearDown(self) -> None:
126100
os.chdir(self.cwd)
127101
shutil.rmtree(self.work_dir)
128102

129-
@TRAINING_TEST_TIMEOUT
130-
@patch("deepmd.jax.entrypoints.train.SummaryPrinter.__call__")
131-
def test_train_entrypoint_runs_one_step_from_scratch(self, _summary) -> None:
132-
"""Run local JAX training and check that expected artifacts are written."""
133-
train(
134-
INPUT=str(self.input_file),
135-
init_model=None,
136-
restart=None,
137-
output="out.json",
138-
init_frz_model=None,
139-
mpi_log="master",
140-
log_level=2,
141-
log_path=None,
103+
def test_train_entrypoint_runs_one_step_from_scratch(self) -> None:
104+
"""Run local JAX training in a child process and check artifacts."""
105+
proc = subprocess.run(
106+
[sys.executable, "-c", textwrap.dedent(TRAINING_SCRIPT)],
107+
cwd=self.work_dir,
108+
text=True,
109+
capture_output=True,
110+
timeout=60,
111+
check=False,
142112
)
143113

144-
self.assertTrue(Path("out.json").is_file())
145-
self.assertTrue(Path("lcurve.out").is_file())
146-
self.assertTrue(Path("checkpoint").is_file())
147-
self.assertTrue(Path("model-1.jax").is_dir())
148-
self.assertIn("1", Path("lcurve.out").read_text())
114+
self.assertEqual(proc.returncode, 0, proc.stdout + proc.stderr)
149115

150116
@patch("deepmd.jax.entrypoints.freeze.deserialize_to_file")
151117
@patch("deepmd.jax.entrypoints.freeze.serialize_from_file")

0 commit comments

Comments
 (0)