22"""End-to-end tests for the local JAX training entrypoint."""
33
44import argparse
5- import functools
65import json
76import os
87import shutil
9- import signal
8+ import subprocess
9+ import sys
1010import tempfile
11+ import textwrap
1112import unittest
12- from collections .abc import (
13- Callable ,
14- )
15- from copy import (
16- deepcopy ,
17- )
1813from pathlib import (
1914 Path ,
2015)
21- from typing import (
22- Any ,
23- TypeVar ,
24- cast ,
25- )
2616from unittest .mock import (
2717 patch ,
2818)
3323from deepmd .jax .entrypoints .main import (
3424 main ,
3525)
36- from deepmd .jax .entrypoints .train import (
37- train ,
38- )
3926from deepmd .utils .compat import (
4027 convert_optimizer_v31_to_v32 ,
4128)
4229
43- _F = TypeVar ("_F" , bound = Callable [..., Any ])
44-
45-
46- def _training_timeout (seconds : int ) -> Callable [[_F ], _F ]:
47- """Limit real training tests on platforms that support SIGALRM."""
48-
49- def decorate (func : _F ) -> _F :
50- if not hasattr (signal , "SIGALRM" ):
51- return func
52-
53- @functools .wraps (func )
54- def wrapped (* args : Any , ** kwargs : Any ) -> Any :
55- def raise_timeout (signum : int , frame : Any ) -> None :
56- raise TimeoutError (f"training test exceeded { seconds } seconds" )
57-
58- previous_handler = signal .signal (signal .SIGALRM , raise_timeout )
59- signal .alarm (seconds )
60- try :
61- return func (* args , ** kwargs )
62- finally :
63- signal .alarm (0 )
64- signal .signal (signal .SIGALRM , previous_handler )
65-
66- return cast ("_F" , wrapped )
67-
68- return decorate
69-
70-
71- TRAINING_TEST_TIMEOUT = _training_timeout (60 )
72-
7330MODEL_SE_E2_A = {
7431 "type_map" : ["O" , "H" , "B" ],
7532 "descriptor" : {
7633 "type" : "se_e2_a" ,
77- "sel" : [46 , 92 , 4 ],
34+ "sel" : [6 , 12 , 1 ],
7835 "rcut_smth" : 0.50 ,
7936 "rcut" : 4.00 ,
80- "neuron" : [25 , 50 , 100 ],
37+ "neuron" : [2 , 4 , 8 ],
8138 "resnet_dt" : False ,
82- "axis_neuron" : 16 ,
39+ "axis_neuron" : 2 ,
8340 "type_one_side" : True ,
8441 "seed" : 1 ,
8542 },
8643 "fitting_net" : {
87- "neuron" : [24 , 24 , 24 ],
44+ "neuron" : [4 , 4 , 4 ],
8845 "resnet_dt" : True ,
8946 "seed" : 1 ,
9047 },
91- "data_stat_nbatch" : 20 ,
48+ "data_stat_nbatch" : 1 ,
9249}
9350
9451
52+ TRAINING_SCRIPT = """
53+ from pathlib import Path
54+ from unittest.mock import patch
55+
56+ from deepmd.main import main
57+
58+ with patch("deepmd.jax.entrypoints.train.SummaryPrinter.__call__"):
59+ main(["--jax", "train", "input.json", "--log-level", "2"])
60+
61+ for path in ["out.json", "lcurve.out", "checkpoint", "model-1.jax"]:
62+ if not Path(path).exists():
63+ raise FileNotFoundError(path)
64+ if "1" not in Path("lcurve.out").read_text():
65+ raise AssertionError("lcurve.out does not contain the first training step")
66+ """
67+
68+
9569class TestJAXTraining (unittest .TestCase ):
9670 """Regression tests for complete JAX training runs."""
9771
@@ -103,12 +77,12 @@ def setUp(self) -> None:
10377
10478 source_dir = Path (__file__ ).resolve ().parents [1 ] / "pt" / "water"
10579 shutil .copytree (source_dir , self .work_dir / "water" )
106- data_file = [str (self .work_dir / "water" / "data" / "data_0 " )]
80+ data_file = [str (self .work_dir / "water" / "data" / "single " )]
10781
10882 with (self .work_dir / "water" / "se_atten.json" ).open () as f :
10983 self .config = json .load (f )
11084 self .config = convert_optimizer_v31_to_v32 (self .config , warning = False )
111- self .config ["model" ] = deepcopy ( MODEL_SE_E2_A )
85+ self .config ["model" ] = MODEL_SE_E2_A
11286 self .config ["model" ]["data_stat_nbatch" ] = 1
11387 self .config ["training" ]["training_data" ]["systems" ] = data_file
11488 self .config ["training" ]["validation_data" ]["systems" ] = data_file
@@ -126,26 +100,18 @@ def tearDown(self) -> None:
126100 os .chdir (self .cwd )
127101 shutil .rmtree (self .work_dir )
128102
129- @TRAINING_TEST_TIMEOUT
130- @patch ("deepmd.jax.entrypoints.train.SummaryPrinter.__call__" )
131- def test_train_entrypoint_runs_one_step_from_scratch (self , _summary ) -> None :
132- """Run local JAX training and check that expected artifacts are written."""
133- train (
134- INPUT = str (self .input_file ),
135- init_model = None ,
136- restart = None ,
137- output = "out.json" ,
138- init_frz_model = None ,
139- mpi_log = "master" ,
140- log_level = 2 ,
141- log_path = None ,
103+ def test_train_entrypoint_runs_one_step_from_scratch (self ) -> None :
104+ """Run local JAX training in a child process and check artifacts."""
105+ proc = subprocess .run (
106+ [sys .executable , "-c" , textwrap .dedent (TRAINING_SCRIPT )],
107+ cwd = self .work_dir ,
108+ text = True ,
109+ capture_output = True ,
110+ timeout = 60 ,
111+ check = False ,
142112 )
143113
144- self .assertTrue (Path ("out.json" ).is_file ())
145- self .assertTrue (Path ("lcurve.out" ).is_file ())
146- self .assertTrue (Path ("checkpoint" ).is_file ())
147- self .assertTrue (Path ("model-1.jax" ).is_dir ())
148- self .assertIn ("1" , Path ("lcurve.out" ).read_text ())
114+ self .assertEqual (proc .returncode , 0 , proc .stdout + proc .stderr )
149115
150116 @patch ("deepmd.jax.entrypoints.freeze.deserialize_to_file" )
151117 @patch ("deepmd.jax.entrypoints.freeze.serialize_from_file" )
0 commit comments