Skip to content

Commit 1ac753c

Browse files
caic99, pre-commit-ci[bot], and coderabbitai[bot]
authored and committed
feat: add new batch size rules for large systems (deepmodeling#4659)
feat: add new batch size rules for large systems (deepmodeling#4659)
This PR adds two new types of batch size definition: `max`: Unlike `auto`, `max` does not apply ceiling to the batch size. Consider the case: the batch size is set as `auto:256`, and you have a system with 255 atoms. In this case, `auto` will infer a batch size of 2, while `max` keeps it at 1 to avoid potential OOM. `filter`: It removes any systems with more atoms than expected, and limits the total number of atoms in a batch as `max` does. I've run out of ideas for what the keywords should be, so please leave a comment if there is a better idea. - [ ] Todo: migrate changes to Paddle backend <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Enhanced batch size configuration options now allow the use of descriptive prefixes for more flexible control, including `"max:N"` and `"filter:N"`. - Clarified the existing `"auto:N"` option for batch size configuration. - **Bug Fixes** - Improved validation ensures that batch sizes are always at least one. - Updated error feedback provides clear messaging for unsupported batch size formats. - **Tests** - Introduced a new test suite to validate batch size behavior under various input scenarios, including assertions for different formats and conditions. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Chun Cai <amoycaic@gmail.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
1 parent 46cb516 commit 1ac753c

5 files changed

Lines changed: 149 additions & 9 deletions

File tree

deepmd/pd/utils/dataloader.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -117,16 +117,41 @@ def construct_dataset(system):
117117
if isinstance(batch_size, str):
118118
if batch_size == "auto":
119119
rule = 32
120+
ceiling = True
120121
elif batch_size.startswith("auto:"):
121122
rule = int(batch_size.split(":")[1])
123+
ceiling = True
124+
elif batch_size.startswith("max:"):
125+
rule = int(batch_size.split(":")[1])
126+
ceiling = False
127+
elif batch_size.startswith("filter:"):
128+
# remove system with more than `filter` atoms
129+
rule = int(batch_size.split(":")[1])
130+
len_before = len(self.systems)
131+
self.systems = [
132+
system for system in self.systems if system._natoms <= rule
133+
]
134+
len_after = len(self.systems)
135+
if len_before != len_after:
136+
log.warning(
137+
f"Remove {len_before - len_after} systems with more than {rule} atoms"
138+
)
139+
if len(self.systems) == 0:
140+
raise ValueError(
141+
f"No system left after removing systems with more than {rule} atoms"
142+
)
143+
ceiling = False
122144
else:
123-
rule = None
124-
log.error("Unsupported batch size type")
145+
raise ValueError(f"Unsupported batch size rule: {batch_size}")
125146
for ii in self.systems:
126147
ni = ii._natoms
127148
bsi = rule // ni
128-
if bsi * ni < rule:
129-
bsi += 1
149+
if ceiling:
150+
if bsi * ni < rule:
151+
bsi += 1
152+
else:
153+
if bsi == 0:
154+
bsi = 1
130155
self.batch_sizes.append(bsi)
131156
elif isinstance(batch_size, list):
132157
self.batch_sizes = batch_size

deepmd/pt/utils/dataloader.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -131,16 +131,41 @@ def construct_dataset(system):
131131
if isinstance(batch_size, str):
132132
if batch_size == "auto":
133133
rule = 32
134+
ceiling = True
134135
elif batch_size.startswith("auto:"):
135136
rule = int(batch_size.split(":")[1])
137+
ceiling = True
138+
elif batch_size.startswith("max:"):
139+
rule = int(batch_size.split(":")[1])
140+
ceiling = False
141+
elif batch_size.startswith("filter:"):
142+
# remove system with more than `filter` atoms
143+
rule = int(batch_size.split(":")[1])
144+
len_before = len(self.systems)
145+
self.systems = [
146+
system for system in self.systems if system._natoms <= rule
147+
]
148+
len_after = len(self.systems)
149+
if len_before != len_after:
150+
log.warning(
151+
f"Remove {len_before - len_after} systems with more than {rule} atoms"
152+
)
153+
if len(self.systems) == 0:
154+
raise ValueError(
155+
f"No system left after removing systems with more than {rule} atoms"
156+
)
157+
ceiling = False
136158
else:
137-
rule = None
138-
log.error("Unsupported batch size type")
159+
raise ValueError(f"Unsupported batch size rule: {batch_size}")
139160
for ii in self.systems:
140161
ni = ii._natoms
141162
bsi = rule // ni
142-
if bsi * ni < rule:
143-
bsi += 1
163+
if ceiling:
164+
if bsi * ni < rule:
165+
bsi += 1
166+
else:
167+
if bsi == 0:
168+
bsi = 1
144169
self.batch_sizes.append(bsi)
145170
elif isinstance(batch_size, list):
146171
self.batch_sizes = batch_size

deepmd/utils/argcheck.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3269,6 +3269,8 @@ def training_data_args(): # ! added by Ziyao: new specification style for data
32693269
- string "auto": automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than 32.\n\n\
32703270
- string "auto:N": automatically determines the batch size so that the batch_size times the number of atoms in the system is no less than N.\n\n\
32713271
- string "mixed:N": the batch data will be sampled from all systems and merged into a mixed system with the batch size N. Only support the se_atten descriptor for TensorFlow backend.\n\n\
3272+
- string "max:N": automatically determines the batch size so that the batch_size times the number of atoms in the system is no more than N.\n\n\
3273+
- string "filter:N": the same as `"max:N"` but removes the systems with the number of atoms larger than `N` from the data set.\n\n\
32723274
If MPI is used, the value should be considered as the batch size per task.'
32733275
doc_auto_prob_style = 'Determine the probability of systems automatically. The method is assigned by this key and can be\n\n\
32743276
- "prob_uniform" : the probability all the systems are equal, namely 1.0/self.get_nsystems()\n\n\

doc/train/training-advanced.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,9 @@ The sections {ref}`training_data <training/training_data>` and {ref}`validation_
106106
- `list`: the length of which is the same as the {ref}`systems`. The batch size of each system is given by the elements of the list.
107107
- `int`: all systems use the same batch size.
108108
- `"auto"`: the same as `"auto:32"`, see `"auto:N"`
109-
- `"auto:N"`: automatically determines the batch size so that the {ref}`batch_size <training/training_data/batch_size>` times the number of atoms in the system is no less than `N`.
109+
- `"auto:N"`: automatically determines the batch size so that the {ref}`batch_size <training/training_data/batch_size>` times the number of atoms in the system is **no less than** `N`.
110+
- `"max:N"`: automatically determines the batch size so that the {ref}`batch_size <training/training_data/batch_size>` times the number of atoms in the system is **no more than** `N`. The minimum batch size is 1. **Supported backends**: PyTorch {{ pytorch_icon }}, Paddle {{ paddle_icon }}
111+
- `"filter:N"`: the same as `"max:N"` but removes the systems with the number of atoms larger than `N` from the data set. Throws an error if no system is left in a dataset. **Supported backends**: PyTorch {{ pytorch_icon }}, Paddle {{ paddle_icon }}
110112
- The key {ref}`numb_batch <training/validation_data/numb_btch>` in {ref}`validate_data <training/validation_data>` gives the number of batches of model validation. Note that the batches may not be from the same system
111113

112114
The section {ref}`mixed_precision <training/mixed_precision>` specifies the mixed precision settings, which will enable the mixed precision training workflow for DeePMD-kit. The keys are explained below:
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import json
3+
import unittest
4+
from pathlib import (
5+
Path,
6+
)
7+
8+
from deepmd.common import (
9+
expand_sys_str,
10+
)
11+
from deepmd.pt.utils.dataloader import (
12+
DpLoaderSet,
13+
)
14+
15+
16+
class TestSampler(unittest.TestCase):
17+
def setUp(self) -> None:
18+
with open(str(Path(__file__).parent / "water/se_e2_a.json")) as fin:
19+
content = fin.read()
20+
config = json.loads(content)
21+
data_file = [
22+
str(Path(__file__).parent / "model/water/data/data_0"),
23+
]
24+
config["training"]["training_data"]["systems"] = data_file
25+
config["training"]["validation_data"]["systems"] = data_file
26+
model_config = config["model"]
27+
self.rcut = model_config["descriptor"]["rcut"]
28+
self.rcut_smth = model_config["descriptor"]["rcut_smth"]
29+
self.sel = model_config["descriptor"]["sel"]
30+
self.batch_size = config["training"]["training_data"]["batch_size"]
31+
self.systems = config["training"]["validation_data"]["systems"]
32+
self.type_map = model_config["type_map"]
33+
if isinstance(self.systems, str):
34+
self.systems = expand_sys_str(self.systems)
35+
36+
def get_batch_sizes(self, batch_size) -> int:
37+
dataset = DpLoaderSet(
38+
self.systems,
39+
batch_size,
40+
self.type_map,
41+
seed=10,
42+
shuffle=False,
43+
)
44+
return dataset.batch_sizes[0]
45+
46+
def test_batchsize(self) -> None:
47+
# 192 atoms, 1 system
48+
assert len(self.systems) == 1
49+
50+
# test: batch_size:int
51+
self.assertEqual(self.get_batch_sizes(3), 3)
52+
53+
# test: batch_size:list[int]
54+
self.assertEqual(self.get_batch_sizes([3]), 3)
55+
56+
# test: batch_size:str = "auto"
57+
self.assertEqual(self.get_batch_sizes("auto:384"), 2)
58+
self.assertEqual(self.get_batch_sizes("auto:383"), 2)
59+
self.assertEqual(self.get_batch_sizes("auto:193"), 2)
60+
self.assertEqual(self.get_batch_sizes("auto:192"), 1)
61+
self.assertEqual(self.get_batch_sizes("auto:191"), 1)
62+
self.assertEqual(self.get_batch_sizes("auto:32"), 1)
63+
self.assertEqual(self.get_batch_sizes("auto"), 1)
64+
65+
# test: batch_size:str = "max"
66+
self.assertEqual(self.get_batch_sizes("max:384"), 2)
67+
self.assertEqual(self.get_batch_sizes("max:383"), 1)
68+
self.assertEqual(self.get_batch_sizes("max:193"), 1)
69+
self.assertEqual(self.get_batch_sizes("max:192"), 1)
70+
self.assertEqual(self.get_batch_sizes("max:191"), 1)
71+
72+
# test: batch_size:str = "filter"
73+
self.assertEqual(self.get_batch_sizes("filter:193"), 1)
74+
self.assertEqual(self.get_batch_sizes("filter:192"), 1)
75+
with self.assertLogs(logger="deepmd") as cm:
76+
self.assertRaises(ValueError, self.get_batch_sizes, "filter:191")
77+
self.assertIn("Remove 1 systems with more than 191 atoms", cm.output[-1])
78+
79+
# test: unknown batch_size: str
80+
with self.assertRaises(ValueError) as context:
81+
self.get_batch_sizes("unknown")
82+
self.assertIn("Unsupported batch size rule: unknown", str(context.exception))
83+
84+
85+
if __name__ == "__main__":
86+
unittest.main()

0 commit comments

Comments
 (0)