From 1fd31990a6905ecf6e412bb601f08040214300cc Mon Sep 17 00:00:00 2001 From: OutisLi Date: Fri, 24 Oct 2025 09:48:49 +0800 Subject: [PATCH 01/14] perf: accelerate data loading by using memmap in get_numb_batch --- deepmd/utils/data.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index ae4a484909..7051345718 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -338,8 +338,9 @@ def get_numb_set(self) -> int: def get_numb_batch(self, batch_size: int, set_idx: int) -> int: """Get the number of batches in a set.""" - data = self._load_set(self.dirs[set_idx]) - ret = data["coord"].shape[0] // batch_size + # Directly obtain the number of frames to avoid loading the entire dataset + nframes = self._get_nframes(self.dirs[set_idx]) + ret = nframes // batch_size if ret == 0: ret = 1 return ret @@ -583,13 +584,16 @@ def _get_nframes(self, set_name: DPPath) -> int: if not isinstance(set_name, DPPath): set_name = DPPath(set_name) path = set_name / "coord.npy" - if self.data_dict["coord"]["high_prec"]: - coord = path.load_numpy().astype(GLOBAL_ENER_FLOAT_PRECISION) - else: - coord = path.load_numpy().astype(GLOBAL_NP_FLOAT_PRECISION) - if coord.ndim == 1: - coord = coord.reshape([1, -1]) - nframes = coord.shape[0] + # Read only the header to get shape without creating memmap + with open(str(path), "rb") as f: + version = np.lib.format.read_magic(f) + if version[0] == 1: + shape, fortran_order, dtype = np.lib.format.read_array_header_1_0(f) + elif version[0] in [2, 3]: + shape, fortran_order, dtype = np.lib.format.read_array_header_2_0(f) + else: + raise ValueError(f"Unsupported .npy file version: {version}") + nframes = shape[0] return nframes def reformat_data_torch(self, data: dict[str, Any]) -> dict[str, Any]: From ec18c0ba140eab12cfef0c6d100d4dbbb8fb93ef Mon Sep 17 00:00:00 2001 From: OutisLi Date: Fri, 24 Oct 2025 11:25:38 +0800 Subject: [PATCH 02/14] perf: use rglob("type.raw") to find systems instead of rglob("*") followed by a type.raw check; roughly 10x faster --- deepmd/common.py | 2 +- deepmd/utils/path.py | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/deepmd/common.py b/deepmd/common.py index 26b655f876..f463905131 100644 --- a/deepmd/common.py +++ b/deepmd/common.py @@ -202,7 +202,7 @@ def expand_sys_str(root_dir: Union[str, Path]) -> list[str]: list of string pointing to system directories """ root_dir = DPPath(root_dir) - matches = [str(d) for d in root_dir.rglob("*") if (d / "type.raw").is_file()] + matches = [str(p.parent) for p in root_dir.rglob("type.raw") if p.is_file()] if (root_dir / "type.raw").is_file(): matches.append(str(root_dir)) return matches diff --git a/deepmd/utils/path.py b/deepmd/utils/path.py index e6b00cdf80..821e760765 100644 --- a/deepmd/utils/path.py +++ b/deepmd/utils/path.py @@ -157,6 +157,11 @@ def mkdir(self, parents: bool = False, exist_ok: bool = False) -> None: If true, no error will be raised if the target directory already exists. """ + @property + @abstractmethod + def parent(self) -> "DPPath": + """Return the parent path.""" + class DPOSPath(DPPath): """The OS path class to data system (DeepmdData) for real directories. @@ -267,6 +272,11 @@ def name(self) -> str: """Name of the path.""" return self.path.name + @property + def parent(self) -> "DPPath": + """Return the parent path.""" + return type(self)(self.path.parent, mode=self.mode) + def mkdir(self, parents: bool = False, exist_ok: bool = False) -> None: """Make directory.
@@ -469,6 +479,14 @@ def name(self) -> str: """Name of the path.""" return self._name.split("/")[-1] + @property + def parent(self) -> "DPPath": + """Return the parent path.""" + parent_name = "/".join(self._name.split("/")[:-1]) + if not parent_name: + parent_name = "/" + return type(self)(f"{self.root_path}#{parent_name}", mode=self.mode) + def mkdir(self, parents: bool = False, exist_ok: bool = False) -> None: """Make directory. From 0b4c151b83116105df0ce512d5a68468d56e0cee Mon Sep 17 00:00:00 2001 From: OutisLi Date: Fri, 24 Oct 2025 13:51:17 +0800 Subject: [PATCH 03/14] perf: use multithread to accelarate stat computing and loading --- deepmd/pt/utils/stat.py | 119 +++++++++++++++++++++++------------ deepmd/utils/env_mat_stat.py | 40 +++++++++--- 2 files changed, 110 insertions(+), 49 deletions(-) diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 7312d95a06..b0d9ad8344 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -3,6 +3,9 @@ from collections import ( defaultdict, ) +from concurrent.futures import ( + ThreadPoolExecutor, +) from typing import ( Any, Callable, @@ -39,7 +42,7 @@ def make_stat_input( datasets: list[Any], dataloaders: list[Any], nbatches: int ) -> dict[str, Any]: - """Pack data for statistics. + """Pack data for statistics in parallel. Args: - dataset: A list of dataset to analyze. @@ -49,49 +52,83 @@ def make_stat_input( ------- - a list of dicts, each of which contains data from a system """ - lst = [] log.info(f"Packing data for statistics from {len(datasets)} systems") - for i in range(len(datasets)): - sys_stat = {} - with torch.device("cpu"): - iterator = iter(dataloaders[i]) - numb_batches = min(nbatches, len(dataloaders[i])) - for _ in range(numb_batches): - try: - stat_data = next(iterator) - except StopIteration: - iterator = iter(dataloaders[i]) - stat_data = next(iterator) - if ( - "find_fparam" in stat_data - and "fparam" in stat_data - and stat_data["find_fparam"] == 0.0 - ): - # for model using default fparam - stat_data.pop("fparam") - stat_data.pop("find_fparam") - for dd in stat_data: - if stat_data[dd] is None: - sys_stat[dd] = None - elif isinstance(stat_data[dd], torch.Tensor): - if dd not in sys_stat: - sys_stat[dd] = [] - sys_stat[dd].append(stat_data[dd]) - elif isinstance(stat_data[dd], np.float32): - sys_stat[dd] = stat_data[dd] - else: - pass - - for key in sys_stat: - if isinstance(sys_stat[key], np.float32): - pass - elif sys_stat[key] is None or sys_stat[key][0] is None: + dataloader_lens = [len(dl) for dl in dataloaders] + args_list = [ + (dataloaders[i], nbatches, dataloader_lens[i]) for i in range(len(datasets)) + ] + + lst = [] + # I/O intensive, set a larger number of workers + with ThreadPoolExecutor(max_workers=256) as executor: + lst = list(executor.map(_process_one_dataset, args_list)) + log.info("Finished packing data.") + return lst + + +def _process_one_dataset(args: tuple[Any, int, int]) -> dict[str, Any]: + """ + Helper function to process a single dataset's dataloader for statistics. + Designed to be called in parallel by a ThreadPoolExecutor. + + Parameters + ---------- + args : tuple(Any, int, int) + A tuple containing (dataloader, nbatches, dataloader_len) + + Returns + ------- + dict[str, Any] + The processed sys_stat dictionary for one dataset. 
+ """ + dataloader, nbatches, dataloader_len = args + sys_stat = {} + + with torch.device("cpu"): + iterator = iter(dataloader) + numb_batches = min(nbatches, dataloader_len) + + for _ in range(numb_batches): + try: + stat_data = next(iterator) + except StopIteration: + iterator = iter(dataloader) + stat_data = next(iterator) + + if ( + "find_fparam" in stat_data + and "fparam" in stat_data + and stat_data["find_fparam"] == 0.0 + ): + # for model using default fparam + stat_data.pop("fparam") + stat_data.pop("find_fparam") + + for dd in stat_data: + if stat_data[dd] is None: + sys_stat[dd] = None + elif isinstance(stat_data[dd], torch.Tensor): + if dd not in sys_stat: + sys_stat[dd] = [] + sys_stat[dd].append(stat_data[dd]) + elif isinstance(stat_data[dd], np.float32): + sys_stat[dd] = stat_data[dd] + else: + pass + + for key in sys_stat: + if isinstance(sys_stat[key], np.float32): + pass + elif isinstance(sys_stat[key], list): + if sys_stat[key][0] is None: sys_stat[key] = None - elif isinstance(stat_data[dd], torch.Tensor): + else: sys_stat[key] = torch.cat(sys_stat[key], dim=0) - dict_to_device(sys_stat) - lst.append(sys_stat) - return lst + elif sys_stat[key] is None: + pass + + dict_to_device(sys_stat) + return sys_stat def _restore_from_file( diff --git a/deepmd/utils/env_mat_stat.py b/deepmd/utils/env_mat_stat.py index ecc0b7b62f..42ec792772 100644 --- a/deepmd/utils/env_mat_stat.py +++ b/deepmd/utils/env_mat_stat.py @@ -10,6 +10,9 @@ from collections.abc import ( Iterator, ) +from concurrent.futures import ( + ThreadPoolExecutor, +) from typing import ( Optional, ) @@ -142,7 +145,7 @@ def save_stats(self, path: DPPath) -> None: (path / kk).save_numpy(np.array([vv.number, vv.sum, vv.squared_sum])) def load_stats(self, path: DPPath) -> None: - """Load the statistics of the environment matrix. + """Load the statistics of the environment matrix in parallel. 
Parameters ---------- @@ -151,13 +154,18 @@ def load_stats(self, path: DPPath) -> None: """ if len(self.stats) > 0: raise ValueError("The statistics has already been computed.") - for kk in path.glob("*"): - arr = kk.load_numpy() - self.stats[kk.name] = StatItem( - number=arr[0], - sum=arr[1], - squared_sum=arr[2], - ) + + files_to_load = list(path.glob("*")) + + if not files_to_load: + raise ValueError(f"No statistics files found in {path}.") + + with ThreadPoolExecutor(max_workers=128) as executor: + results = executor.map(self._load_stat_file, files_to_load) + + for name, stat_item in results: + if stat_item is not None: + self.stats[name] = stat_item def load_or_compute_stats( self, data: list[dict[str, np.ndarray]], path: Optional[DPPath] = None @@ -216,3 +224,19 @@ def get_std( kk: vv.compute_std(default=default, protection=protection) for kk, vv in self.stats.items() } + + @staticmethod + def _load_stat_file(file_path: DPPath) -> tuple[str, StatItem]: + """Helper function for parallel loading of stat files.""" + try: + arr = file_path.load_numpy() + if arr.shape == (3,): + return file_path.name, StatItem( + number=arr[0], sum=arr[1], squared_sum=arr[2] + ) + else: + log.warning(f"Skipping malformed stat file: {file_path.name}") + return file_path.name, None + except Exception as e: + log.warning(f"Failed to load stat file {file_path.name}: {e}") + return file_path.name, None From 4eb076d5953e858ac610b323a76406553adfa994 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Mon, 10 Nov 2025 10:57:01 +0800 Subject: [PATCH 04/14] bug fix --- deepmd/pt/utils/stat.py | 5 +++-- deepmd/utils/data.py | 11 ++++++++--- deepmd/utils/env_mat_stat.py | 3 ++- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index b0d9ad8344..c3f1e72a00 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import logging +import os from collections import ( defaultdict, ) @@ -60,7 +61,7 @@ def make_stat_input( lst = [] # I/O intensive, set a larger number of workers - with ThreadPoolExecutor(max_workers=256) as executor: + with ThreadPoolExecutor(min(128, (os.cpu_count() or 1) * 6)) as executor: lst = list(executor.map(_process_one_dataset, args_list)) log.info("Finished packing data.") return lst @@ -120,7 +121,7 @@ def _process_one_dataset(args: tuple[Any, int, int]) -> dict[str, Any]: if isinstance(sys_stat[key], np.float32): pass elif isinstance(sys_stat[key], list): - if sys_stat[key][0] is None: + if len(sys_stat[key]) == 0 or sys_stat[key][0] is None: sys_stat[key] = None else: sys_stat[key] = torch.cat(sys_stat[key], dim=0) diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index 7051345718..608d7755fa 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -338,8 +338,13 @@ def get_numb_set(self) -> int: def get_numb_batch(self, batch_size: int, set_idx: int) -> int: """Get the number of batches in a set.""" - # Directly obtain the number of frames to avoid loading the entire dataset - nframes = self._get_nframes(self.dirs[set_idx]) + set_name = self.dirs[set_idx] + if isinstance(set_name, DPH5Path): + data = self._load_set(set_name) + nframes = data["coord"].shape[0] + else: + # Directly obtain the number of frames to avoid loading the entire dataset + nframes = self._get_nframes(set_name) ret = nframes // batch_size if ret == 0: ret = 1 @@ -584,7 +589,7 @@ def _get_nframes(self, set_name: DPPath) -> int: if not isinstance(set_name, DPPath): set_name = 
DPPath(set_name) path = set_name / "coord.npy" - # Read only the header to get shape without creating memmap + # Read only the header to get shape with open(str(path), "rb") as f: version = np.lib.format.read_magic(f) if version[0] == 1: diff --git a/deepmd/utils/env_mat_stat.py b/deepmd/utils/env_mat_stat.py index 42ec792772..3b9625c1cc 100644 --- a/deepmd/utils/env_mat_stat.py +++ b/deepmd/utils/env_mat_stat.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import logging +import os from abc import ( ABC, abstractmethod, @@ -160,7 +161,7 @@ def load_stats(self, path: DPPath) -> None: if not files_to_load: raise ValueError(f"No statistics files found in {path}.") - with ThreadPoolExecutor(max_workers=128) as executor: + with ThreadPoolExecutor(min(64, (os.cpu_count() or 1) * 4)) as executor: results = executor.map(self._load_stat_file, files_to_load) for name, stat_item in results: From dae2028949614bd6ffe73017d14a523265f1fc40 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Mon, 10 Nov 2025 11:38:19 +0800 Subject: [PATCH 05/14] bug fix --- deepmd/utils/data.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index 608d7755fa..c29ca531c4 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -136,7 +136,12 @@ def __init__( # set modifier self.modifier = modifier # calculate prefix sum for get_item method - frames_list = [self._get_nframes(item) for item in self.dirs] + frames_list = [ + self._load_set(item)["coord"].shape[0] + if isinstance(item, DPH5Path) + else self._get_nframes(item) + for item in self.dirs + ] self.nframes = np.sum(frames_list) # The prefix sum stores the range of indices contained in each directory, which is needed by get_item method self.prefix_sum = np.cumsum(frames_list).tolist() From 293def1875f2331ae61e5b7a3d5e149c1af9f2af Mon Sep 17 00:00:00 2001 From: OutisLi Date: Mon, 10 Nov 2025 15:55:14 +0800 Subject: [PATCH 06/14] Revert "perf: use multithread to accelarate stat computing and loading" --- deepmd/pt/utils/stat.py | 120 ++++++++++++----------------------- deepmd/utils/env_mat_stat.py | 41 +++--------- 2 files changed, 49 insertions(+), 112 deletions(-) diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index c3f1e72a00..7312d95a06 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -1,12 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import logging -import os from collections import ( defaultdict, ) -from concurrent.futures import ( - ThreadPoolExecutor, -) from typing import ( Any, Callable, @@ -43,7 +39,7 @@ def make_stat_input( datasets: list[Any], dataloaders: list[Any], nbatches: int ) -> dict[str, Any]: - """Pack data for statistics in parallel. + """Pack data for statistics. Args: - dataset: A list of dataset to analyze. 
@@ -53,83 +49,49 @@ def make_stat_input( ------- - a list of dicts, each of which contains data from a system """ - log.info(f"Packing data for statistics from {len(datasets)} systems") - dataloader_lens = [len(dl) for dl in dataloaders] - args_list = [ - (dataloaders[i], nbatches, dataloader_lens[i]) for i in range(len(datasets)) - ] - lst = [] - # I/O intensive, set a larger number of workers - with ThreadPoolExecutor(min(128, (os.cpu_count() or 1) * 6)) as executor: - lst = list(executor.map(_process_one_dataset, args_list)) - log.info("Finished packing data.") - return lst - - -def _process_one_dataset(args: tuple[Any, int, int]) -> dict[str, Any]: - """ - Helper function to process a single dataset's dataloader for statistics. - Designed to be called in parallel by a ThreadPoolExecutor. - - Parameters - ---------- - args : tuple(Any, int, int) - A tuple containing (dataloader, nbatches, dataloader_len) - - Returns - ------- - dict[str, Any] - The processed sys_stat dictionary for one dataset. - """ - dataloader, nbatches, dataloader_len = args - sys_stat = {} - - with torch.device("cpu"): - iterator = iter(dataloader) - numb_batches = min(nbatches, dataloader_len) - - for _ in range(numb_batches): - try: - stat_data = next(iterator) - except StopIteration: - iterator = iter(dataloader) - stat_data = next(iterator) - - if ( - "find_fparam" in stat_data - and "fparam" in stat_data - and stat_data["find_fparam"] == 0.0 - ): - # for model using default fparam - stat_data.pop("fparam") - stat_data.pop("find_fparam") - - for dd in stat_data: - if stat_data[dd] is None: - sys_stat[dd] = None - elif isinstance(stat_data[dd], torch.Tensor): - if dd not in sys_stat: - sys_stat[dd] = [] - sys_stat[dd].append(stat_data[dd]) - elif isinstance(stat_data[dd], np.float32): - sys_stat[dd] = stat_data[dd] - else: - pass - - for key in sys_stat: - if isinstance(sys_stat[key], np.float32): - pass - elif isinstance(sys_stat[key], list): - if len(sys_stat[key]) == 0 or sys_stat[key][0] is None: + log.info(f"Packing data for statistics from {len(datasets)} systems") + for i in range(len(datasets)): + sys_stat = {} + with torch.device("cpu"): + iterator = iter(dataloaders[i]) + numb_batches = min(nbatches, len(dataloaders[i])) + for _ in range(numb_batches): + try: + stat_data = next(iterator) + except StopIteration: + iterator = iter(dataloaders[i]) + stat_data = next(iterator) + if ( + "find_fparam" in stat_data + and "fparam" in stat_data + and stat_data["find_fparam"] == 0.0 + ): + # for model using default fparam + stat_data.pop("fparam") + stat_data.pop("find_fparam") + for dd in stat_data: + if stat_data[dd] is None: + sys_stat[dd] = None + elif isinstance(stat_data[dd], torch.Tensor): + if dd not in sys_stat: + sys_stat[dd] = [] + sys_stat[dd].append(stat_data[dd]) + elif isinstance(stat_data[dd], np.float32): + sys_stat[dd] = stat_data[dd] + else: + pass + + for key in sys_stat: + if isinstance(sys_stat[key], np.float32): + pass + elif sys_stat[key] is None or sys_stat[key][0] is None: sys_stat[key] = None - else: + elif isinstance(stat_data[dd], torch.Tensor): sys_stat[key] = torch.cat(sys_stat[key], dim=0) - elif sys_stat[key] is None: - pass - - dict_to_device(sys_stat) - return sys_stat + dict_to_device(sys_stat) + lst.append(sys_stat) + return lst def _restore_from_file( diff --git a/deepmd/utils/env_mat_stat.py b/deepmd/utils/env_mat_stat.py index 3b9625c1cc..ecc0b7b62f 100644 --- a/deepmd/utils/env_mat_stat.py +++ b/deepmd/utils/env_mat_stat.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: 
LGPL-3.0-or-later import logging -import os from abc import ( ABC, abstractmethod, @@ -11,9 +10,6 @@ from collections.abc import ( Iterator, ) -from concurrent.futures import ( - ThreadPoolExecutor, -) from typing import ( Optional, ) @@ -146,7 +142,7 @@ def save_stats(self, path: DPPath) -> None: (path / kk).save_numpy(np.array([vv.number, vv.sum, vv.squared_sum])) def load_stats(self, path: DPPath) -> None: - """Load the statistics of the environment matrix in parallel. + """Load the statistics of the environment matrix. Parameters ---------- @@ -155,18 +151,13 @@ def load_stats(self, path: DPPath) -> None: """ if len(self.stats) > 0: raise ValueError("The statistics has already been computed.") - - files_to_load = list(path.glob("*")) - - if not files_to_load: - raise ValueError(f"No statistics files found in {path}.") - - with ThreadPoolExecutor(min(64, (os.cpu_count() or 1) * 4)) as executor: - results = executor.map(self._load_stat_file, files_to_load) - - for name, stat_item in results: - if stat_item is not None: - self.stats[name] = stat_item + for kk in path.glob("*"): + arr = kk.load_numpy() + self.stats[kk.name] = StatItem( + number=arr[0], + sum=arr[1], + squared_sum=arr[2], + ) def load_or_compute_stats( self, data: list[dict[str, np.ndarray]], path: Optional[DPPath] = None @@ -225,19 +216,3 @@ def get_std( kk: vv.compute_std(default=default, protection=protection) for kk, vv in self.stats.items() } - - @staticmethod - def _load_stat_file(file_path: DPPath) -> tuple[str, StatItem]: - """Helper function for parallel loading of stat files.""" - try: - arr = file_path.load_numpy() - if arr.shape == (3,): - return file_path.name, StatItem( - number=arr[0], sum=arr[1], squared_sum=arr[2] - ) - else: - log.warning(f"Skipping malformed stat file: {file_path.name}") - return file_path.name, None - except Exception as e: - log.warning(f"Failed to load stat file {file_path.name}: {e}") - return file_path.name, None From 1dcbc883bbd9deefa0e33119df336ba95251ab35 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Mon, 10 Nov 2025 16:29:21 +0800 Subject: [PATCH 07/14] fix: special case of single frame dataset --- deepmd/utils/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index c29ca531c4..2315162e43 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -603,7 +603,7 @@ def _get_nframes(self, set_name: DPPath) -> int: shape, fortran_order, dtype = np.lib.format.read_array_header_2_0(f) else: raise ValueError(f"Unsupported .npy file version: {version}") - nframes = shape[0] + nframes = shape[0] if (len(shape) if isinstance(shape, tuple) else 0) > 1 else 1 return nframes def reformat_data_torch(self, data: dict[str, Any]) -> dict[str, Any]: From 7c2077cb0a9ba80dacc6438a90f3d4181a06c08a Mon Sep 17 00:00:00 2001 From: OutisLi Date: Thu, 13 Nov 2025 16:31:20 +0800 Subject: [PATCH 08/14] minor update --- deepmd/utils/data.py | 33 ++++++++++++++++++++------------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index 2315162e43..65b87ebc52 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -589,21 +589,28 @@ def _shuffle_data(self, data: dict[str, Any]) -> dict[str, Any]: ret[kk] = data[kk] return ret, idx - def _get_nframes(self, set_name: DPPath) -> int: - # get nframes + def _get_nframes(self, set_name: DPPath | str) -> int: if not isinstance(set_name, DPPath): set_name = DPPath(set_name) - path = set_name / "coord.npy" - # Read only the header to 
get shape - with open(str(path), "rb") as f: - version = np.lib.format.read_magic(f) - if version[0] == 1: - shape, fortran_order, dtype = np.lib.format.read_array_header_1_0(f) - elif version[0] in [2, 3]: - shape, fortran_order, dtype = np.lib.format.read_array_header_2_0(f) - else: - raise ValueError(f"Unsupported .npy file version: {version}") - nframes = shape[0] if (len(shape) if isinstance(shape, tuple) else 0) > 1 else 1 + if isinstance(set_name, DPH5Path): + path = set_name / "coord.npy" + nframes = path.root[path._name].shape[0] + else: + path = set_name / "coord.npy" + # Read only the header to get shape + with open(str(path), "rb") as f: + version = np.lib.format.read_magic(f) + if version[0] == 1: + shape, _fortran_order, _dtype = np.lib.format.read_array_header_1_0( + f + ) + elif version[0] in [2, 3]: + shape, _fortran_order, _dtype = np.lib.format.read_array_header_2_0( + f + ) + else: + raise ValueError(f"Unsupported .npy file version: {version}") + nframes = shape[0] if len(shape) > 1 else 1 return nframes def reformat_data_torch(self, data: dict[str, Any]) -> dict[str, Any]: From f3a8341cfa8f54e64283e8daafc6634ed9a8c07e Mon Sep 17 00:00:00 2001 From: OutisLi Date: Thu, 13 Nov 2025 16:58:24 +0800 Subject: [PATCH 09/14] fix: repeated path collection --- deepmd/common.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/deepmd/common.py b/deepmd/common.py index f463905131..4945b33d7b 100644 --- a/deepmd/common.py +++ b/deepmd/common.py @@ -203,8 +203,6 @@ def expand_sys_str(root_dir: Union[str, Path]) -> list[str]: """ root_dir = DPPath(root_dir) matches = [str(p.parent) for p in root_dir.rglob("type.raw") if p.is_file()] - if (root_dir / "type.raw").is_file(): - matches.append(str(root_dir)) return matches From 55fa967ad4a514a93864f8c5794935c20131650c Mon Sep 17 00:00:00 2001 From: OutisLi Date: Thu, 13 Nov 2025 17:32:45 +0800 Subject: [PATCH 10/14] refactor: simplify frame count calculation in DeepmdData initialization --- deepmd/utils/data.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index 65b87ebc52..224af70baa 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -135,13 +135,7 @@ def __init__( self.shuffle_test = shuffle_test # set modifier self.modifier = modifier - # calculate prefix sum for get_item method - frames_list = [ - self._load_set(item)["coord"].shape[0] - if isinstance(item, DPH5Path) - else self._get_nframes(item) - for item in self.dirs - ] + frames_list = [self._get_nframes(set_name) for set_name in self.dirs] self.nframes = np.sum(frames_list) # The prefix sum stores the range of indices contained in each directory, which is needed by get_item method self.prefix_sum = np.cumsum(frames_list).tolist() From 8f7d6c930e3cefa79d55720d92e3eea7012f9da8 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Thu, 13 Nov 2025 18:22:47 +0800 Subject: [PATCH 11/14] refactor: update type hint for _get_nframes method to use Union --- deepmd/utils/data.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index 224af70baa..79689ccb67 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -14,6 +14,7 @@ from typing import ( Any, Optional, + Union, ) import numpy as np @@ -583,7 +584,7 @@ def _shuffle_data(self, data: dict[str, Any]) -> dict[str, Any]: ret[kk] = data[kk] return ret, idx - def _get_nframes(self, set_name: DPPath | str) -> int: + def _get_nframes(self, set_name: Union[DPPath, str]) -> int: if not 
isinstance(set_name, DPPath): set_name = DPPath(set_name) if isinstance(set_name, DPH5Path): From ffd9232334da472e50e048792ac83f6122d0a953 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Thu, 13 Nov 2025 18:46:10 +0800 Subject: [PATCH 12/14] refactor: streamline batch count calculation by removing unnecessary data loading --- deepmd/utils/data.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index 79689ccb67..02274d4e8e 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -339,12 +339,8 @@ def get_numb_set(self) -> int: def get_numb_batch(self, batch_size: int, set_idx: int) -> int: """Get the number of batches in a set.""" set_name = self.dirs[set_idx] - if isinstance(set_name, DPH5Path): - data = self._load_set(set_name) - nframes = data["coord"].shape[0] - else: - # Directly obtain the number of frames to avoid loading the entire dataset - nframes = self._get_nframes(set_name) + # Directly obtain the number of frames to avoid loading the entire dataset + nframes = self._get_nframes(set_name) ret = nframes // batch_size if ret == 0: ret = 1 From 6cb64351e2adab74212334f105681b028897575b Mon Sep 17 00:00:00 2001 From: OutisLi Date: Fri, 14 Nov 2025 09:28:10 +0800 Subject: [PATCH 13/14] revert changes in expand_sys_str since the modification cannot process symlink --- deepmd/common.py | 4 +++- deepmd/utils/path.py | 18 ------------------ 2 files changed, 3 insertions(+), 19 deletions(-) diff --git a/deepmd/common.py b/deepmd/common.py index 4945b33d7b..26b655f876 100644 --- a/deepmd/common.py +++ b/deepmd/common.py @@ -202,7 +202,9 @@ def expand_sys_str(root_dir: Union[str, Path]) -> list[str]: list of string pointing to system directories """ root_dir = DPPath(root_dir) - matches = [str(p.parent) for p in root_dir.rglob("type.raw") if p.is_file()] + matches = [str(d) for d in root_dir.rglob("*") if (d / "type.raw").is_file()] + if (root_dir / "type.raw").is_file(): + matches.append(str(root_dir)) return matches diff --git a/deepmd/utils/path.py b/deepmd/utils/path.py index 821e760765..e6b00cdf80 100644 --- a/deepmd/utils/path.py +++ b/deepmd/utils/path.py @@ -157,11 +157,6 @@ def mkdir(self, parents: bool = False, exist_ok: bool = False) -> None: If true, no error will be raised if the target directory already exists. """ - @property - @abstractmethod - def parent(self) -> "DPPath": - """Return the parent path.""" - class DPOSPath(DPPath): """The OS path class to data system (DeepmdData) for real directories. @@ -272,11 +267,6 @@ def name(self) -> str: """Name of the path.""" return self.path.name - @property - def parent(self) -> "DPPath": - """Return the parent path.""" - return type(self)(self.path.parent, mode=self.mode) - def mkdir(self, parents: bool = False, exist_ok: bool = False) -> None: """Make directory. @@ -479,14 +469,6 @@ def name(self) -> str: """Name of the path.""" return self._name.split("/")[-1] - @property - def parent(self) -> "DPPath": - """Return the parent path.""" - parent_name = "/".join(self._name.split("/")[:-1]) - if not parent_name: - parent_name = "/" - return type(self)(f"{self.root_path}#{parent_name}", mode=self.mode) - def mkdir(self, parents: bool = False, exist_ok: bool = False) -> None: """Make directory. 
From c5bfec210927ddcdb9827dfde96937b39be005a3 Mon Sep 17 00:00:00 2001 From: OutisLi Date: Fri, 14 Nov 2025 18:17:01 +0800 Subject: [PATCH 14/14] refactor: remove redundant path assignment in _get_nframes method --- deepmd/utils/data.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index 02274d4e8e..26a27c82d7 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -583,11 +583,10 @@ def _shuffle_data(self, data: dict[str, Any]) -> dict[str, Any]: def _get_nframes(self, set_name: Union[DPPath, str]) -> int: if not isinstance(set_name, DPPath): set_name = DPPath(set_name) + path = set_name / "coord.npy" if isinstance(set_name, DPH5Path): - path = set_name / "coord.npy" nframes = path.root[path._name].shape[0] else: - path = set_name / "coord.npy" # Read only the header to get shape with open(str(path), "rb") as f: version = np.lib.format.read_magic(f)
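The change that survives to the end of this series is the header-only frame count in _get_nframes: instead of loading or memory-mapping coord.npy, the final code reads just the .npy magic string and header record to obtain the array shape (HDF5-backed sets query the dataset shape directly). Below is a minimal standalone sketch of that approach, assuming only NumPy; the helper name count_npy_frames and the demo file are illustrative and not part of DeePMD-kit.

import numpy as np


def count_npy_frames(npy_file: str) -> int:
    """Return the number of frames in a .npy file by reading only its header."""
    with open(npy_file, "rb") as f:
        version = np.lib.format.read_magic(f)
        if version[0] == 1:
            shape, _fortran_order, _dtype = np.lib.format.read_array_header_1_0(f)
        elif version[0] in (2, 3):
            shape, _fortran_order, _dtype = np.lib.format.read_array_header_2_0(f)
        else:
            raise ValueError(f"Unsupported .npy file version: {version}")
    # A 1-D coord array is treated as a single frame, matching the convention in data.py.
    return shape[0] if len(shape) > 1 else 1


if __name__ == "__main__":
    # Hypothetical demo data: 50 frames of a 3-atom system (9 coordinates per frame).
    np.save("coord.npy", np.zeros((50, 9)))
    print(count_npy_frames("coord.npy"))  # prints 50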
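Patches 05 and 10 adjust how DeepmdData builds frames_list, whose cumulative sum self.prefix_sum = np.cumsum(frames_list).tolist() is, per the in-code comment, what get_item uses to map a global frame index onto a set directory. The get_item logic itself is not shown in these diffs; the sketch below only illustrates how such a prefix sum is typically queried, and locate_frame is a hypothetical name rather than a DeePMD-kit function.

import numpy as np


def locate_frame(prefix_sum: list[int], index: int) -> tuple[int, int]:
    """Map a global frame index to (set_idx, local_idx) using cumulative frame counts."""
    set_idx = int(np.searchsorted(prefix_sum, index, side="right"))
    start = prefix_sum[set_idx - 1] if set_idx > 0 else 0
    return set_idx, index - start


if __name__ == "__main__":
    frames_list = [10, 5, 20]  # frames per set directory
    prefix_sum = np.cumsum(frames_list).tolist()  # [10, 15, 35]
    print(locate_frame(prefix_sum, 12))  # (1, 2): third frame of the second set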