From 6801820503e4397932e4bdeef2f6aee916008cca Mon Sep 17 00:00:00 2001 From: OutisLi Date: Sun, 10 Aug 2025 16:29:12 +0800 Subject: [PATCH 1/2] fix:error when batch size use max/filter under beighbor-stat --- deepmd/utils/data_system.py | 39 +++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/deepmd/utils/data_system.py b/deepmd/utils/data_system.py index 07dab35a90..826934e0fd 100644 --- a/deepmd/utils/data_system.py +++ b/deepmd/utils/data_system.py @@ -152,6 +152,45 @@ def __init__( else: raise RuntimeError("batch size must be specified for mixed systems") self.batch_size = rule * np.ones(self.nsystems, dtype=int) + elif "max" == words[0]: + # Determine batch size so that batch_size * natoms <= rule, at least 1 + if len(words) != 2: + raise RuntimeError("batch size must be specified for max systems") + rule = int(words[1]) + bs = [] + for ii in self.data_systems: + ni = ii.get_natoms() + bsi = rule // ni + if bsi == 0: + bsi = 1 + bs.append(bsi) + self.batch_size = bs + elif "filter" == words[0]: + # Remove systems with natoms > rule, then set batch size like "max:rule" + if len(words) != 2: + raise RuntimeError("batch size must be specified for filter systems") + rule = int(words[1]) + filtered_data_systems = [] + filtered_system_dirs = [] + for sys_dir, data_sys in zip(self.system_dirs, self.data_systems): + if data_sys.get_natoms() <= rule: + filtered_data_systems.append(data_sys) + filtered_system_dirs.append(sys_dir) + if len(filtered_data_systems) == 0: + raise RuntimeError(f"No system left after removing systems with more than {rule} atoms") + if len(filtered_data_systems) != len(self.data_systems): + warnings.warn(f"Remove {len(self.data_systems) - len(filtered_data_systems)} systems with more than {rule} atoms") + self.data_systems = filtered_data_systems + self.system_dirs = filtered_system_dirs + self.nsystems = len(self.data_systems) + bs = [] + for ii in self.data_systems: + ni = ii.get_natoms() + bsi = rule // ni + if bsi == 0: + bsi = 1 + bs.append(bsi) + self.batch_size = bs else: raise RuntimeError("unknown batch_size rule " + words[0]) elif isinstance(self.batch_size, list): From 43a2f19c6ff75a59ab0c22d1ea3cb3c677d727a2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 10 Aug 2025 09:01:20 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/utils/data_system.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/deepmd/utils/data_system.py b/deepmd/utils/data_system.py index 826934e0fd..65ea9e4c7b 100644 --- a/deepmd/utils/data_system.py +++ b/deepmd/utils/data_system.py @@ -168,7 +168,9 @@ def __init__( elif "filter" == words[0]: # Remove systems with natoms > rule, then set batch size like "max:rule" if len(words) != 2: - raise RuntimeError("batch size must be specified for filter systems") + raise RuntimeError( + "batch size must be specified for filter systems" + ) rule = int(words[1]) filtered_data_systems = [] filtered_system_dirs = [] @@ -177,9 +179,13 @@ def __init__( filtered_data_systems.append(data_sys) filtered_system_dirs.append(sys_dir) if len(filtered_data_systems) == 0: - raise RuntimeError(f"No system left after removing systems with more than {rule} atoms") + raise RuntimeError( + f"No system left after removing systems with more than {rule} atoms" + ) if len(filtered_data_systems) != len(self.data_systems): - warnings.warn(f"Remove {len(self.data_systems) - len(filtered_data_systems)} systems with more than {rule} atoms") + warnings.warn( + f"Remove {len(self.data_systems) - len(filtered_data_systems)} systems with more than {rule} atoms" + ) self.data_systems = filtered_data_systems self.system_dirs = filtered_system_dirs self.nsystems = len(self.data_systems)