@modifier_args_plugin.register("random_tester", doc=doc_random_tester)
def modifier_random_tester() -> list:
    """Return the argument schema accepted by the random_tester modifier."""
    return [
        Argument(
            "seed",
            int,
            optional=True,
            doc="Random seed used to initialize the random number generator for deterministic scaling factors.",
        ),
        Argument(
            "use_cache",
            bool,
            optional=True,
            doc="Whether to cache modified frames to improve performance by avoiding recomputation.",
        ),
    ]
7072
7173
@modifier_args_plugin.register("zero_tester", doc=doc_zero_tester)
def modifier_zero_tester() -> list:
    """Return the argument schema accepted by the zero_tester modifier."""
    return [
        Argument(
            "use_cache",
            bool,
            optional=True,
            doc="Whether to cache modified frames to improve performance by avoiding recomputation.",
        ),
    ]
7580
7681
@modifier_args_plugin.register("scaling_tester", doc=doc_scaling_tester)
def modifier_scaling_tester() -> list[Argument]:
    """Return the argument schema accepted by the scaling_tester modifier."""
    return [
        Argument(
            "model_name",
            str,
            optional=False,
            doc="The name of the frozen energy model file.",
        ),
        Argument(
            "sfactor",
            float,
            optional=False,
            doc="The scaling factor for correction.",
        ),
        Argument(
            "use_cache",
            bool,
            optional=True,
            doc="Whether to cache modified frames to improve performance by avoiding recomputation.",
        ),
    ]
8592
8693
@@ -92,12 +99,14 @@ def __new__(cls, *args, **kwargs):
    def __init__(
        self,
        seed: int = 1,
        use_cache: bool = True,
    ) -> None:
        """Construct a random_tester modifier that scales data by deterministic random factors for testing.

        Parameters
        ----------
        seed : int
            Seed for NumPy's random generator; a fixed seed makes the
            scaling factor reproducible across runs.
        use_cache : bool
            Forwarded to the base modifier; presumably enables caching of
            modified frames — confirm against BaseModifier.__init__.
        """
        super().__init__(use_cache)
        self.modifier_type = "random_tester"
        # Use a fixed seed for deterministic behavior.
        self.rng = np.random.default_rng(seed)
        # Draw the scaling factor once at construction so every modified
        # quantity (energy/force/virial) is scaled by the same value.
        self.sfactor = self.rng.random()
101110
102111 def forward (
103112 self ,
@@ -121,21 +130,24 @@ def modify_data(self, data: dict[str, Array | float], data_sys: DeepmdData) -> N
121130 return
122131
123132 if "find_energy" in data and data ["find_energy" ] == 1.0 :
124- data ["energy" ] = data ["energy" ] * self .rng . random ()
133+ data ["energy" ] = data ["energy" ] * self .sfactor
125134 if "find_force" in data and data ["find_force" ] == 1.0 :
126- data ["force" ] = data ["force" ] * self .rng . random ()
135+ data ["force" ] = data ["force" ] * self .sfactor
127136 if "find_virial" in data and data ["find_virial" ] == 1.0 :
128- data ["virial" ] = data ["virial" ] * self .rng . random ()
137+ data ["virial" ] = data ["virial" ] * self .sfactor
129138
130139
131140@BaseModifier .register ("zero_tester" )
132141class ModifierZeroTester (BaseModifier ):
    def __new__(cls, *args, **kwargs):
        # Accept and ignore constructor arguments: only ``cls`` is passed to
        # object.__new__. NOTE(review): presumably required so registry-driven
        # instantiation with keyword args does not break — confirm against
        # BaseModifier's registration machinery.
        return super().__new__(cls)
135144
136- def __init__ (self ) -> None :
145+ def __init__ (
146+ self ,
147+ use_cache : bool = True ,
148+ ) -> None :
137149 """Construct a modifier that zeros out data for testing."""
138- super ().__init__ ()
150+ super ().__init__ (use_cache )
139151 self .modifier_type = "zero_tester"
140152
141153 def forward (
    def __init__(
        self,
        model_name: str,
        sfactor: float = 1.0,
        use_cache: bool = True,
    ) -> None:
        """Initialize a test modifier that applies scaled model predictions using a frozen model.

        Parameters
        ----------
        model_name : str
            Name of the frozen energy model file to load.
        sfactor : float
            Scaling factor applied to the model's correction.
        use_cache : bool
            Forwarded to the base modifier; presumably enables caching of
            modified frames — confirm against BaseModifier.__init__.
        """
        super().__init__(use_cache)
        self.modifier_type = "scaling_tester"
        self.model_name = model_name
        self.sfactor = sfactor
@@ -212,6 +225,7 @@ def forward(
212225@parameterized (
213226 (1 , 2 ), # training data batch_size
214227 (1 , 2 ), # validation data batch_size
228+ (True , False ), # use_cache
215229)
216230class TestDataModifier (unittest .TestCase ):
217231 def setUp (self ) -> None :
@@ -240,7 +254,10 @@ def test_init_modify_data(self):
240254 """Ensure modify_data applied."""
241255 tmp_config = self .config .copy ()
242256 # add tester data modifier
243- tmp_config ["model" ]["modifier" ] = {"type" : "zero_tester" }
257+ tmp_config ["model" ]["modifier" ] = {
258+ "type" : "zero_tester" ,
259+ "use_cache" : self .param [2 ],
260+ }
244261
245262 # data modification is finished in __init__
246263 trainer = get_trainer (tmp_config )
@@ -262,6 +279,7 @@ def test_full_modify_data(self):
262279 tmp_config ["model" ]["modifier" ] = {
263280 "type" : "random_tester" ,
264281 "seed" : 1024 ,
282+ "use_cache" : self .param [2 ],
265283 }
266284
267285 # data modification is finished in __init__
@@ -307,6 +325,7 @@ def test_inference(self):
307325 "type" : "scaling_tester" ,
308326 "model_name" : "frozen_model_dm.pth" ,
309327 "sfactor" : sfactor ,
328+ "use_cache" : True ,
310329 }
311330
312331 trainer = get_trainer (tmp_config )
0 commit comments