1717from deepmd .pt .utils import (
1818 dp_random ,
1919)
20- from deepmd .pt .utils .dataloader import (
21- DpLoaderSet ,
22- get_sampler_from_params ,
23- get_weighted_sampler ,
24- )
2520from deepmd .tf .common import (
2621 expand_sys_str ,
2722)
@@ -67,7 +62,7 @@ def setUp(self) -> None:
6762 self .systems = config ["training" ]["validation_data" ]["systems" ]
6863 if isinstance (self .systems , str ):
6964 self .systems = expand_sys_str (self .systems )
70- self .my_dataset = DpLoaderSet (
65+ self .my_dataset = pt_dataloader . DpLoaderSet (
7166 self .systems ,
7267 self .batch_size ,
7368 self .type_map ,
@@ -81,7 +76,9 @@ def setUp(self) -> None:
def tearDown(self) -> None:
    """Revert the patches tracked by this test case.

    Calls ``undo()`` on the ``_monkeypatch`` object stored on the instance.
    NOTE(review): ``_monkeypatch`` is presumably created in ``setUp`` (not
    fully visible here) -- confirm.
    """
    patcher = self._monkeypatch
    patcher.undo()
8378
84- def _make_dataloader (self , dataset : DpLoaderSet , sampler ) -> DataLoader :
79+ def _make_dataloader (
80+ self , dataset : pt_dataloader .DpLoaderSet , sampler
81+ ) -> DataLoader :
8582 return DataLoader (
8683 dataset ,
8784 sampler = sampler ,
@@ -96,6 +93,18 @@ def _normalize_probs(self, weights: np.ndarray) -> np.ndarray:
9693 return weights / np .sum (weights )
9794
9895 def _compute_total_numb_batch (self , nbatches : np .ndarray , probs : np .ndarray ) -> int :
96+ # NOTE: This is a simplified test-only variant of training.py logic.
97+ nbatches = np .asarray (nbatches , dtype = np .float64 )
98+ probs = np .asarray (probs , dtype = np .float64 )
99+ if nbatches .shape != probs .shape :
100+ raise ValueError (
101+ "nbatches and probs must have the same shape in this test helper."
102+ )
103+ if not np .all (probs > 0.0 ):
104+ raise ValueError (
105+ "Zero or negative sampling probabilities are not supported in this "
106+ "test helper."
107+ )
99108 return int (np .ceil (np .max (nbatches / probs )))
100109
101110 def _sample_sid_counts (
@@ -156,7 +165,9 @@ def _sample_multitask_counts(
156165 def test_sampler_debug_info (self ) -> None :
157166 dataloader = DataLoader (
158167 self .my_dataset ,
159- sampler = get_weighted_sampler (self .my_dataset , prob_style = "prob_sys_size" ),
168+ sampler = pt_dataloader .get_weighted_sampler (
169+ self .my_dataset , prob_style = "prob_sys_size"
170+ ),
160171 batch_size = None ,
161172 num_workers = 0 , # setting to 0 diverges the behavior of its iterator; should be >=1
162173 drop_last = False ,
@@ -171,31 +182,37 @@ def test_sampler_debug_info(self) -> None:
171182
def test_auto_prob_uniform(self) -> None:
    """Check the ``prob_uniform`` sampler weights against ``dp_dataset``.

    Builds a weighted sampler over ``self.my_dataset`` and compares its
    ``weights`` with ``self.dp_dataset.sys_probs`` after the same style
    has been applied through ``set_sys_probs``.
    """
    style = "prob_uniform"
    weighted = pt_dataloader.get_weighted_sampler(
        self.my_dataset,
        prob_style=style,
    )
    sampler_probs = np.array(weighted.weights)

    # Configure the comparison dataset with the identical style string.
    self.dp_dataset.set_sys_probs(auto_prob_style=style)
    reference_probs = np.array(self.dp_dataset.sys_probs)

    matches = np.allclose(sampler_probs, reference_probs)
    self.assertTrue(matches)
179192
def test_auto_prob_sys_size(self) -> None:
    """Check the ``prob_sys_size`` sampler weights against ``dp_dataset``.

    Builds a weighted sampler over ``self.my_dataset`` and compares its
    ``weights`` with ``self.dp_dataset.sys_probs`` after the same style
    has been applied through ``set_sys_probs``.
    """
    style = "prob_sys_size"
    weighted = pt_dataloader.get_weighted_sampler(
        self.my_dataset,
        prob_style=style,
    )
    sampler_probs = np.array(weighted.weights)

    # Configure the comparison dataset with the identical style string.
    self.dp_dataset.set_sys_probs(auto_prob_style=style)
    reference_probs = np.array(self.dp_dataset.sys_probs)

    matches = np.allclose(sampler_probs, reference_probs)
    self.assertTrue(matches)
187202
def test_auto_prob_sys_size_ext(self) -> None:
    """Check an extended ``prob_sys_size`` style string against ``dp_dataset``.

    Uses the style ``"prob_sys_size;0:1:0.2;1:3:0.8"``, builds a weighted
    sampler over ``self.my_dataset`` and compares its ``weights`` with
    ``self.dp_dataset.sys_probs`` after the same style has been applied
    through ``set_sys_probs``.
    """
    style = "prob_sys_size;0:1:0.2;1:3:0.8"
    weighted = pt_dataloader.get_weighted_sampler(
        self.my_dataset,
        prob_style=style,
    )
    sampler_probs = np.array(weighted.weights)

    # Configure the comparison dataset with the identical style string.
    self.dp_dataset.set_sys_probs(auto_prob_style=style)
    reference_probs = np.array(self.dp_dataset.sys_probs)

    matches = np.allclose(sampler_probs, reference_probs)
    self.assertTrue(matches)
195212
196213 def test_sys_probs (self ) -> None :
197214 sys_probs = [0.1 , 0.4 , 0.5 ]
198- sampler = get_weighted_sampler (
215+ sampler = pt_dataloader . get_weighted_sampler (
199216 self .my_dataset , prob_style = sys_probs , sys_prob = True
200217 )
201218 my_probs = np .array (sampler .weights )
@@ -209,7 +226,7 @@ def test_sys_probs_end2end(self):
209226 "sys_probs" : sys_probs ,
210227 "auto_prob" : "prob_sys_size" ,
211228 } # use sys_probs first
212- sampler = get_sampler_from_params (self .my_dataset , _params )
229+ sampler = pt_dataloader . get_sampler_from_params (self .my_dataset , _params )
213230 my_probs = np .array (sampler .weights )
214231 self .dp_dataset .set_sys_probs (sys_probs = sys_probs )
215232 dp_probs = np .array (self .dp_dataset .sys_probs )
@@ -218,7 +235,7 @@ def test_sys_probs_end2end(self):
218235 def test_auto_prob_sys_size_ext_end2end (self ):
219236 auto_prob_style = "prob_sys_size;0:1:0.2;1:3:0.8"
220237 _params = {"sys_probs" : None , "auto_prob" : auto_prob_style } # use auto_prob
221- sampler = get_sampler_from_params (self .my_dataset , _params )
238+ sampler = pt_dataloader . get_sampler_from_params (self .my_dataset , _params )
222239 my_probs = np .array (sampler .weights )
223240 self .dp_dataset .set_sys_probs (auto_prob_style = auto_prob_style )
224241 dp_probs = np .array (self .dp_dataset .sys_probs )
@@ -231,7 +248,7 @@ def test_sampling_stability_single_task(self) -> None:
231248 str (Path (__file__ ).parent / "water/data/data_1" ),
232249 str (Path (__file__ ).parent / "water/data/single" ),
233250 ]
234- dataset_epoch = DpLoaderSet (
251+ dataset_epoch = pt_dataloader . DpLoaderSet (
235252 systems ,
236253 self .batch_size ,
237254 self .type_map ,
@@ -240,7 +257,7 @@ def test_sampling_stability_single_task(self) -> None:
240257 )
241258 sys_probs = [0.2 , 0.3 , 0.5 ]
242259 params = {"sys_probs" : sys_probs , "auto_prob" : "prob_sys_size" }
243- sampler_epoch = get_sampler_from_params (dataset_epoch , params )
260+ sampler_epoch = pt_dataloader . get_sampler_from_params (dataset_epoch , params )
244261 probs = self ._normalize_probs (np .asarray (sampler_epoch .weights ))
245262 nbatches = np .asarray (dataset_epoch .index , dtype = np .float64 )
246263 total_numb_batch = self ._compute_total_numb_batch (nbatches , probs )
@@ -257,14 +274,14 @@ def test_sampling_stability_single_task(self) -> None:
257274 self .assertTrue (np .allclose (empirical_epoch , probs , atol = 0.1 ))
258275
259276 # === Step 3. Sample Using Explicit Steps ===
260- dataset_steps = DpLoaderSet (
277+ dataset_steps = pt_dataloader . DpLoaderSet (
261278 systems ,
262279 self .batch_size ,
263280 self .type_map ,
264281 seed = 10 ,
265282 shuffle = False ,
266283 )
267- sampler_steps = get_sampler_from_params (dataset_steps , params )
284+ sampler_steps = pt_dataloader . get_sampler_from_params (dataset_steps , params )
268285 torch .manual_seed (123 )
269286 dataloader_steps = self ._make_dataloader (dataset_steps , sampler_steps )
270287 counts_steps = self ._sample_sid_counts (
@@ -283,24 +300,24 @@ def test_sampling_stability_multi_task(self) -> None:
283300 str (Path (__file__ ).parent / "water/data/data_1" ),
284301 str (Path (__file__ ).parent / "water/data/single" ),
285302 ]
286- dataset_1 = DpLoaderSet (
303+ dataset_1 = pt_dataloader . DpLoaderSet (
287304 systems_1 ,
288305 self .batch_size ,
289306 self .type_map ,
290307 seed = 10 ,
291308 shuffle = False ,
292309 )
293- dataset_2 = DpLoaderSet (
310+ dataset_2 = pt_dataloader . DpLoaderSet (
294311 systems_2 ,
295312 self .batch_size ,
296313 self .type_map ,
297314 seed = 10 ,
298315 shuffle = False ,
299316 )
300- sampler_1 = get_sampler_from_params (
317+ sampler_1 = pt_dataloader . get_sampler_from_params (
301318 dataset_1 , {"sys_probs" : [0.7 , 0.3 ], "auto_prob" : "prob_sys_size" }
302319 )
303- sampler_2 = get_sampler_from_params (
320+ sampler_2 = pt_dataloader . get_sampler_from_params (
304321 dataset_2 , {"sys_probs" : [0.4 , 0.6 ], "auto_prob" : "prob_sys_size" }
305322 )
306323 probs_1 = self ._normalize_probs (np .asarray (sampler_1 .weights ))
@@ -352,24 +369,24 @@ def test_sampling_stability_multi_task(self) -> None:
352369 )
353370
354371 # === Step 3. Sample Using Explicit Steps ===
355- dataset_1b = DpLoaderSet (
372+ dataset_1b = pt_dataloader . DpLoaderSet (
356373 systems_1 ,
357374 self .batch_size ,
358375 self .type_map ,
359376 seed = 10 ,
360377 shuffle = False ,
361378 )
362- dataset_2b = DpLoaderSet (
379+ dataset_2b = pt_dataloader . DpLoaderSet (
363380 systems_2 ,
364381 self .batch_size ,
365382 self .type_map ,
366383 seed = 10 ,
367384 shuffle = False ,
368385 )
369- sampler_1b = get_sampler_from_params (
386+ sampler_1b = pt_dataloader . get_sampler_from_params (
370387 dataset_1b , {"sys_probs" : [0.7 , 0.3 ], "auto_prob" : "prob_sys_size" }
371388 )
372- sampler_2b = get_sampler_from_params (
389+ sampler_2b = pt_dataloader . get_sampler_from_params (
373390 dataset_2b , {"sys_probs" : [0.4 , 0.6 ], "auto_prob" : "prob_sys_size" }
374391 )
375392 dataloaders_steps = {
0 commit comments