diff --git a/deepmd/utils/data_system.py b/deepmd/utils/data_system.py index 07dab35a90..65ea9e4c7b 100644 --- a/deepmd/utils/data_system.py +++ b/deepmd/utils/data_system.py @@ -152,6 +152,51 @@ def __init__( else: raise RuntimeError("batch size must be specified for mixed systems") self.batch_size = rule * np.ones(self.nsystems, dtype=int) + elif "max" == words[0]: + # Determine batch size so that batch_size * natoms <= rule, at least 1 + if len(words) != 2: + raise RuntimeError("batch size must be specified for max systems") + rule = int(words[1]) + bs = [] + for ii in self.data_systems: + ni = ii.get_natoms() + bsi = rule // ni + if bsi == 0: + bsi = 1 + bs.append(bsi) + self.batch_size = bs + elif "filter" == words[0]: + # Remove systems with natoms > rule, then set batch size like "max:rule" + if len(words) != 2: + raise RuntimeError( + "batch size must be specified for filter systems" + ) + rule = int(words[1]) + filtered_data_systems = [] + filtered_system_dirs = [] + for sys_dir, data_sys in zip(self.system_dirs, self.data_systems): + if data_sys.get_natoms() <= rule: + filtered_data_systems.append(data_sys) + filtered_system_dirs.append(sys_dir) + if len(filtered_data_systems) == 0: + raise RuntimeError( + f"No system left after removing systems with more than {rule} atoms" + ) + if len(filtered_data_systems) != len(self.data_systems): + warnings.warn( + f"Remove {len(self.data_systems) - len(filtered_data_systems)} systems with more than {rule} atoms" + ) + self.data_systems = filtered_data_systems + self.system_dirs = filtered_system_dirs + self.nsystems = len(self.data_systems) + bs = [] + for ii in self.data_systems: + ni = ii.get_natoms() + bsi = rule // ni + if bsi == 0: + bsi = 1 + bs.append(bsi) + self.batch_size = bs else: raise RuntimeError("unknown batch_size rule " + words[0]) elif isinstance(self.batch_size, list):