Skip to content

Commit 2ef9444

Browse files
authored
Merge branch 'devel' into 0726_devel_zbl_ft
2 parents 2bb190f + cefce47 commit 2ef9444

1 file changed

Lines changed: 45 additions & 0 deletions

File tree

deepmd/utils/data_system.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,51 @@ def __init__(
152152
else:
153153
raise RuntimeError("batch size must be specified for mixed systems")
154154
self.batch_size = rule * np.ones(self.nsystems, dtype=int)
155+
elif "max" == words[0]:
156+
# Determine batch size so that batch_size * natoms <= rule, at least 1
157+
if len(words) != 2:
158+
raise RuntimeError("batch size must be specified for max systems")
159+
rule = int(words[1])
160+
bs = []
161+
for ii in self.data_systems:
162+
ni = ii.get_natoms()
163+
bsi = rule // ni
164+
if bsi == 0:
165+
bsi = 1
166+
bs.append(bsi)
167+
self.batch_size = bs
168+
elif "filter" == words[0]:
169+
# Remove systems with natoms > rule, then set batch size like "max:rule"
170+
if len(words) != 2:
171+
raise RuntimeError(
172+
"batch size must be specified for filter systems"
173+
)
174+
rule = int(words[1])
175+
filtered_data_systems = []
176+
filtered_system_dirs = []
177+
for sys_dir, data_sys in zip(self.system_dirs, self.data_systems):
178+
if data_sys.get_natoms() <= rule:
179+
filtered_data_systems.append(data_sys)
180+
filtered_system_dirs.append(sys_dir)
181+
if len(filtered_data_systems) == 0:
182+
raise RuntimeError(
183+
f"No system left after removing systems with more than {rule} atoms"
184+
)
185+
if len(filtered_data_systems) != len(self.data_systems):
186+
warnings.warn(
187+
f"Remove {len(self.data_systems) - len(filtered_data_systems)} systems with more than {rule} atoms"
188+
)
189+
self.data_systems = filtered_data_systems
190+
self.system_dirs = filtered_system_dirs
191+
self.nsystems = len(self.data_systems)
192+
bs = []
193+
for ii in self.data_systems:
194+
ni = ii.get_natoms()
195+
bsi = rule // ni
196+
if bsi == 0:
197+
bsi = 1
198+
bs.append(bsi)
199+
self.batch_size = bs
155200
else:
156201
raise RuntimeError("unknown batch_size rule " + words[0])
157202
elif isinstance(self.batch_size, list):

0 commit comments

Comments
 (0)