Skip to content

Commit 7111056

Browse files
committed
adopt
1 parent 94149a9 commit 7111056

3 files changed

Lines changed: 67 additions & 35 deletions

File tree

deepmd/pt/train/training.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -278,13 +278,17 @@ def resolve_model_prob(
278278
model_training_data: dict[str, DpLoaderSet],
279279
) -> np.ndarray:
280280
model_prob = np.zeros(len(model_keys), dtype=np.float64)
281-
if model_prob_config is not None:
281+
if model_prob_config:
282282
for ii, model_key in enumerate(model_keys):
283283
if model_key in model_prob_config:
284284
model_prob[ii] = float(model_prob_config[model_key])
285285
else:
286286
for ii, model_key in enumerate(model_keys):
287287
model_prob[ii] = float(len(model_training_data[model_key]))
288+
if not np.all(np.isfinite(model_prob)):
289+
raise ValueError("Model prob must be finite.")
290+
if np.any(model_prob < 0.0):
291+
raise ValueError("Model prob must be non-negative.")
288292
sum_prob = float(np.sum(model_prob))
289293
if sum_prob <= 0.0:
290294
raise ValueError("Sum of model prob must be larger than 0!")

deepmd/utils/argcheck.py

Lines changed: 19 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -3213,15 +3213,22 @@ def mixed_precision_args() -> list[Argument]: # ! added by Denghui.
32133213
def training_args(
32143214
multi_task: bool = False,
32153215
) -> list[Argument]: # ! modified by Ziyao: data configuration isolated.
3216-
doc_numb_steps = "Number of training batches. Each training uses one batch of data. If set, this value takes precedence over num_epoch."
3216+
doc_numb_steps = (
3217+
"Number of training batches. Each training uses one batch of data. "
3218+
"If set, this value takes precedence over num_epoch. If both numb_steps "
3219+
"and num_epoch are not set, a ValueError is raised."
3220+
)
32173221
doc_num_epoch = (
3218-
"Number of training epochs. "
3222+
"Number of training epochs (can be fractional). "
32193223
"When numb_steps is not set, the total steps are computed as "
3220-
"ceil(num_epoch * total_numb_batch). For each training dataset, "
3221-
"total_numb_batch is computed as ceil(max_i(n_bch_i / p_i)), where p_i "
3222-
"is the sampling probability of system i after sys_probs/auto_prob. "
3223-
"In multi-task mode, total_numb_batch is the model_prob-weighted sum "
3224-
"over tasks."
3224+
"ceil(num_epoch * total_numb_batch). For each task, total_numb_batch "
3225+
"is computed as ceil(max_i(n_bch_i / p_i)), where n_bch_i is the number "
3226+
"of batches for system i and p_i is the sampling probability after "
3227+
"sys_probs/auto_prob normalization. In multi-task mode, model_prob is "
3228+
"normalized to sum to 1, per-task total_numb_batch values are computed "
3229+
"as above, and the final total_numb_batch is their model_prob-weighted "
3230+
"sum. At least one of numb_steps or num_epoch must be set; otherwise a "
3231+
"ValueError is raised."
32253232
)
32263233
doc_seed = "The random seed for getting frames from the training data set."
32273234
doc_disp_file = "The file for printing learning curve."
@@ -3295,7 +3302,11 @@ def training_args(
32953302
args += [
32963303
mixed_precision_data,
32973304
Argument(
3298-
"numb_steps", int, optional=True, doc=doc_numb_steps, alias=["stop_batch"]
3305+
"numb_steps",
3306+
int,
3307+
optional=True,
3308+
doc=doc_numb_steps,
3309+
alias=["stop_batch", "num_steps"],
32993310
),
33003311
Argument(
33013312
"num_epoch",

source/tests/pt/test_sampler.py

Lines changed: 43 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -17,11 +17,6 @@
1717
from deepmd.pt.utils import (
1818
dp_random,
1919
)
20-
from deepmd.pt.utils.dataloader import (
21-
DpLoaderSet,
22-
get_sampler_from_params,
23-
get_weighted_sampler,
24-
)
2520
from deepmd.tf.common import (
2621
expand_sys_str,
2722
)
@@ -67,7 +62,7 @@ def setUp(self) -> None:
6762
self.systems = config["training"]["validation_data"]["systems"]
6863
if isinstance(self.systems, str):
6964
self.systems = expand_sys_str(self.systems)
70-
self.my_dataset = DpLoaderSet(
65+
self.my_dataset = pt_dataloader.DpLoaderSet(
7166
self.systems,
7267
self.batch_size,
7368
self.type_map,
@@ -81,7 +76,9 @@ def setUp(self) -> None:
8176
def tearDown(self) -> None:
8277
self._monkeypatch.undo()
8378

84-
def _make_dataloader(self, dataset: DpLoaderSet, sampler) -> DataLoader:
79+
def _make_dataloader(
80+
self, dataset: pt_dataloader.DpLoaderSet, sampler
81+
) -> DataLoader:
8582
return DataLoader(
8683
dataset,
8784
sampler=sampler,
@@ -96,6 +93,18 @@ def _normalize_probs(self, weights: np.ndarray) -> np.ndarray:
9693
return weights / np.sum(weights)
9794

9895
def _compute_total_numb_batch(self, nbatches: np.ndarray, probs: np.ndarray) -> int:
96+
# NOTE: This is a simplified test-only variant of training.py logic.
97+
nbatches = np.asarray(nbatches, dtype=np.float64)
98+
probs = np.asarray(probs, dtype=np.float64)
99+
if nbatches.shape != probs.shape:
100+
raise ValueError(
101+
"nbatches and probs must have the same shape in this test helper."
102+
)
103+
if not np.all(probs > 0.0):
104+
raise ValueError(
105+
"Zero or negative sampling probabilities are not supported in this "
106+
"test helper."
107+
)
99108
return int(np.ceil(np.max(nbatches / probs)))
100109

101110
def _sample_sid_counts(
@@ -156,7 +165,9 @@ def _sample_multitask_counts(
156165
def test_sampler_debug_info(self) -> None:
157166
dataloader = DataLoader(
158167
self.my_dataset,
159-
sampler=get_weighted_sampler(self.my_dataset, prob_style="prob_sys_size"),
168+
sampler=pt_dataloader.get_weighted_sampler(
169+
self.my_dataset, prob_style="prob_sys_size"
170+
),
160171
batch_size=None,
161172
num_workers=0, # setting to 0 diverges the behavior of its iterator; should be >=1
162173
drop_last=False,
@@ -171,31 +182,37 @@ def test_sampler_debug_info(self) -> None:
171182

172183
def test_auto_prob_uniform(self) -> None:
173184
auto_prob_style = "prob_uniform"
174-
sampler = get_weighted_sampler(self.my_dataset, prob_style=auto_prob_style)
185+
sampler = pt_dataloader.get_weighted_sampler(
186+
self.my_dataset, prob_style=auto_prob_style
187+
)
175188
my_probs = np.array(sampler.weights)
176189
self.dp_dataset.set_sys_probs(auto_prob_style=auto_prob_style)
177190
dp_probs = np.array(self.dp_dataset.sys_probs)
178191
self.assertTrue(np.allclose(my_probs, dp_probs))
179192

180193
def test_auto_prob_sys_size(self) -> None:
181194
auto_prob_style = "prob_sys_size"
182-
sampler = get_weighted_sampler(self.my_dataset, prob_style=auto_prob_style)
195+
sampler = pt_dataloader.get_weighted_sampler(
196+
self.my_dataset, prob_style=auto_prob_style
197+
)
183198
my_probs = np.array(sampler.weights)
184199
self.dp_dataset.set_sys_probs(auto_prob_style=auto_prob_style)
185200
dp_probs = np.array(self.dp_dataset.sys_probs)
186201
self.assertTrue(np.allclose(my_probs, dp_probs))
187202

188203
def test_auto_prob_sys_size_ext(self) -> None:
189204
auto_prob_style = "prob_sys_size;0:1:0.2;1:3:0.8"
190-
sampler = get_weighted_sampler(self.my_dataset, prob_style=auto_prob_style)
205+
sampler = pt_dataloader.get_weighted_sampler(
206+
self.my_dataset, prob_style=auto_prob_style
207+
)
191208
my_probs = np.array(sampler.weights)
192209
self.dp_dataset.set_sys_probs(auto_prob_style=auto_prob_style)
193210
dp_probs = np.array(self.dp_dataset.sys_probs)
194211
self.assertTrue(np.allclose(my_probs, dp_probs))
195212

196213
def test_sys_probs(self) -> None:
197214
sys_probs = [0.1, 0.4, 0.5]
198-
sampler = get_weighted_sampler(
215+
sampler = pt_dataloader.get_weighted_sampler(
199216
self.my_dataset, prob_style=sys_probs, sys_prob=True
200217
)
201218
my_probs = np.array(sampler.weights)
@@ -209,7 +226,7 @@ def test_sys_probs_end2end(self):
209226
"sys_probs": sys_probs,
210227
"auto_prob": "prob_sys_size",
211228
} # use sys_probs first
212-
sampler = get_sampler_from_params(self.my_dataset, _params)
229+
sampler = pt_dataloader.get_sampler_from_params(self.my_dataset, _params)
213230
my_probs = np.array(sampler.weights)
214231
self.dp_dataset.set_sys_probs(sys_probs=sys_probs)
215232
dp_probs = np.array(self.dp_dataset.sys_probs)
@@ -218,7 +235,7 @@ def test_sys_probs_end2end(self):
218235
def test_auto_prob_sys_size_ext_end2end(self):
219236
auto_prob_style = "prob_sys_size;0:1:0.2;1:3:0.8"
220237
_params = {"sys_probs": None, "auto_prob": auto_prob_style} # use auto_prob
221-
sampler = get_sampler_from_params(self.my_dataset, _params)
238+
sampler = pt_dataloader.get_sampler_from_params(self.my_dataset, _params)
222239
my_probs = np.array(sampler.weights)
223240
self.dp_dataset.set_sys_probs(auto_prob_style=auto_prob_style)
224241
dp_probs = np.array(self.dp_dataset.sys_probs)
@@ -231,7 +248,7 @@ def test_sampling_stability_single_task(self) -> None:
231248
str(Path(__file__).parent / "water/data/data_1"),
232249
str(Path(__file__).parent / "water/data/single"),
233250
]
234-
dataset_epoch = DpLoaderSet(
251+
dataset_epoch = pt_dataloader.DpLoaderSet(
235252
systems,
236253
self.batch_size,
237254
self.type_map,
@@ -240,7 +257,7 @@ def test_sampling_stability_single_task(self) -> None:
240257
)
241258
sys_probs = [0.2, 0.3, 0.5]
242259
params = {"sys_probs": sys_probs, "auto_prob": "prob_sys_size"}
243-
sampler_epoch = get_sampler_from_params(dataset_epoch, params)
260+
sampler_epoch = pt_dataloader.get_sampler_from_params(dataset_epoch, params)
244261
probs = self._normalize_probs(np.asarray(sampler_epoch.weights))
245262
nbatches = np.asarray(dataset_epoch.index, dtype=np.float64)
246263
total_numb_batch = self._compute_total_numb_batch(nbatches, probs)
@@ -257,14 +274,14 @@ def test_sampling_stability_single_task(self) -> None:
257274
self.assertTrue(np.allclose(empirical_epoch, probs, atol=0.1))
258275

259276
# === Step 3. Sample Using Explicit Steps ===
260-
dataset_steps = DpLoaderSet(
277+
dataset_steps = pt_dataloader.DpLoaderSet(
261278
systems,
262279
self.batch_size,
263280
self.type_map,
264281
seed=10,
265282
shuffle=False,
266283
)
267-
sampler_steps = get_sampler_from_params(dataset_steps, params)
284+
sampler_steps = pt_dataloader.get_sampler_from_params(dataset_steps, params)
268285
torch.manual_seed(123)
269286
dataloader_steps = self._make_dataloader(dataset_steps, sampler_steps)
270287
counts_steps = self._sample_sid_counts(
@@ -283,24 +300,24 @@ def test_sampling_stability_multi_task(self) -> None:
283300
str(Path(__file__).parent / "water/data/data_1"),
284301
str(Path(__file__).parent / "water/data/single"),
285302
]
286-
dataset_1 = DpLoaderSet(
303+
dataset_1 = pt_dataloader.DpLoaderSet(
287304
systems_1,
288305
self.batch_size,
289306
self.type_map,
290307
seed=10,
291308
shuffle=False,
292309
)
293-
dataset_2 = DpLoaderSet(
310+
dataset_2 = pt_dataloader.DpLoaderSet(
294311
systems_2,
295312
self.batch_size,
296313
self.type_map,
297314
seed=10,
298315
shuffle=False,
299316
)
300-
sampler_1 = get_sampler_from_params(
317+
sampler_1 = pt_dataloader.get_sampler_from_params(
301318
dataset_1, {"sys_probs": [0.7, 0.3], "auto_prob": "prob_sys_size"}
302319
)
303-
sampler_2 = get_sampler_from_params(
320+
sampler_2 = pt_dataloader.get_sampler_from_params(
304321
dataset_2, {"sys_probs": [0.4, 0.6], "auto_prob": "prob_sys_size"}
305322
)
306323
probs_1 = self._normalize_probs(np.asarray(sampler_1.weights))
@@ -352,24 +369,24 @@ def test_sampling_stability_multi_task(self) -> None:
352369
)
353370

354371
# === Step 3. Sample Using Explicit Steps ===
355-
dataset_1b = DpLoaderSet(
372+
dataset_1b = pt_dataloader.DpLoaderSet(
356373
systems_1,
357374
self.batch_size,
358375
self.type_map,
359376
seed=10,
360377
shuffle=False,
361378
)
362-
dataset_2b = DpLoaderSet(
379+
dataset_2b = pt_dataloader.DpLoaderSet(
363380
systems_2,
364381
self.batch_size,
365382
self.type_map,
366383
seed=10,
367384
shuffle=False,
368385
)
369-
sampler_1b = get_sampler_from_params(
386+
sampler_1b = pt_dataloader.get_sampler_from_params(
370387
dataset_1b, {"sys_probs": [0.7, 0.3], "auto_prob": "prob_sys_size"}
371388
)
372-
sampler_2b = get_sampler_from_params(
389+
sampler_2b = pt_dataloader.get_sampler_from_params(
373390
dataset_2b, {"sys_probs": [0.4, 0.6], "auto_prob": "prob_sys_size"}
374391
)
375392
dataloaders_steps = {

0 commit comments

Comments
 (0)