@@ -152,6 +152,51 @@ def __init__(
152152 else :
153153 raise RuntimeError ("batch size must be specified for mixed systems" )
154154 self .batch_size = rule * np .ones (self .nsystems , dtype = int )
155+ elif "max" == words [0 ]:
156+ # Determine batch size so that batch_size * natoms <= rule, at least 1
157+ if len (words ) != 2 :
158+ raise RuntimeError ("batch size must be specified for max systems" )
159+ rule = int (words [1 ])
160+ bs = []
161+ for ii in self .data_systems :
162+ ni = ii .get_natoms ()
163+ bsi = rule // ni
164+ if bsi == 0 :
165+ bsi = 1
166+ bs .append (bsi )
167+ self .batch_size = bs
168+ elif "filter" == words [0 ]:
169+ # Remove systems with natoms > rule, then set batch size like "max:rule"
170+ if len (words ) != 2 :
171+ raise RuntimeError (
172+ "batch size must be specified for filter systems"
173+ )
174+ rule = int (words [1 ])
175+ filtered_data_systems = []
176+ filtered_system_dirs = []
177+ for sys_dir , data_sys in zip (self .system_dirs , self .data_systems ):
178+ if data_sys .get_natoms () <= rule :
179+ filtered_data_systems .append (data_sys )
180+ filtered_system_dirs .append (sys_dir )
181+ if len (filtered_data_systems ) == 0 :
182+ raise RuntimeError (
183+ f"No system left after removing systems with more than { rule } atoms"
184+ )
185+ if len (filtered_data_systems ) != len (self .data_systems ):
186+ warnings .warn (
187+ f"Remove { len (self .data_systems ) - len (filtered_data_systems )} systems with more than { rule } atoms"
188+ )
189+ self .data_systems = filtered_data_systems
190+ self .system_dirs = filtered_system_dirs
191+ self .nsystems = len (self .data_systems )
192+ bs = []
193+ for ii in self .data_systems :
194+ ni = ii .get_natoms ()
195+ bsi = rule // ni
196+ if bsi == 0 :
197+ bsi = 1
198+ bs .append (bsi )
199+ self .batch_size = bs
155200 else :
156201 raise RuntimeError ("unknown batch_size rule " + words [0 ])
157202 elif isinstance (self .batch_size , list ):
0 commit comments