|
14 | 14 | from typing import ( |
15 | 15 | Any, |
16 | 16 | Optional, |
| 17 | + Union, |
17 | 18 | ) |
18 | 19 |
|
19 | 20 | import numpy as np |
@@ -103,13 +104,10 @@ def __init__( |
103 | 104 | f"Elements {missing_elements} are not present in the provided `type_map`." |
104 | 105 | ) |
105 | 106 | if not self.mixed_type: |
106 | | - # Use vectorized operation for better performance with large atom counts |
107 | | - # Create a mapping array where old_type_idx -> new_type_idx |
108 | | - max_old_type = max(self.atom_type) + 1 |
109 | | - type_mapping = np.zeros(max_old_type, dtype=np.int32) |
110 | | - for old_idx in range(len(self.type_map)): |
111 | | - type_mapping[old_idx] = type_map.index(self.type_map[old_idx]) |
112 | | - self.atom_type = type_mapping[self.atom_type].astype(np.int32) |
| 107 | + old_to_new_type_idx = np.array( |
| 108 | + [type_map.index(name) for name in self.type_map], dtype=np.int32 |
| 109 | + ) |
| 110 | + self.atom_type = old_to_new_type_idx[self.atom_type].astype(np.int32) |
113 | 111 | else: |
114 | 112 | self.enforce_type_map = True |
115 | 113 | sorter = np.argsort(type_map) |
@@ -138,8 +136,7 @@ def __init__( |
138 | 136 | self.shuffle_test = shuffle_test |
139 | 137 | # set modifier |
140 | 138 | self.modifier = modifier |
141 | | - # calculate prefix sum for get_item method |
142 | | - frames_list = [self._get_nframes(item) for item in self.dirs] |
| 139 | + frames_list = [self._get_nframes(set_name) for set_name in self.dirs] |
143 | 140 | self.nframes = np.sum(frames_list) |
144 | 141 | # The prefix sum stores the range of indices contained in each directory, which is needed by get_item method |
145 | 142 | self.prefix_sum = np.cumsum(frames_list).tolist() |
@@ -341,8 +338,10 @@ def get_numb_set(self) -> int: |
341 | 338 |
|
342 | 339 | def get_numb_batch(self, batch_size: int, set_idx: int) -> int: |
343 | 340 | """Get the number of batches in a set.""" |
344 | | - data = self._load_set(self.dirs[set_idx]) |
345 | | - ret = data["coord"].shape[0] // batch_size |
| 341 | + set_name = self.dirs[set_idx] |
| 342 | + # Directly obtain the number of frames to avoid loading the entire dataset |
| 343 | + nframes = self._get_nframes(set_name) |
| 344 | + ret = nframes // batch_size |
346 | 345 | if ret == 0: |
347 | 346 | ret = 1 |
348 | 347 | return ret |
@@ -581,18 +580,27 @@ def _shuffle_data(self, data: dict[str, Any]) -> dict[str, Any]: |
581 | 580 | ret[kk] = data[kk] |
582 | 581 | return ret, idx |
583 | 582 |
|
584 | | - def _get_nframes(self, set_name: DPPath) -> int: |
585 | | - # get nframes |
| 583 | + def _get_nframes(self, set_name: Union[DPPath, str]) -> int: |
586 | 584 | if not isinstance(set_name, DPPath): |
587 | 585 | set_name = DPPath(set_name) |
588 | 586 | path = set_name / "coord.npy" |
589 | | - if self.data_dict["coord"]["high_prec"]: |
590 | | - coord = path.load_numpy().astype(GLOBAL_ENER_FLOAT_PRECISION) |
| 587 | + if isinstance(set_name, DPH5Path): |
| 588 | + nframes = path.root[path._name].shape[0] |
591 | 589 | else: |
592 | | - coord = path.load_numpy().astype(GLOBAL_NP_FLOAT_PRECISION) |
593 | | - if coord.ndim == 1: |
594 | | - coord = coord.reshape([1, -1]) |
595 | | - nframes = coord.shape[0] |
| 590 | + # Read only the header to get shape |
| 591 | + with open(str(path), "rb") as f: |
| 592 | + version = np.lib.format.read_magic(f) |
| 593 | + if version[0] == 1: |
| 594 | + shape, _fortran_order, _dtype = np.lib.format.read_array_header_1_0( |
| 595 | + f |
| 596 | + ) |
| 597 | + elif version[0] in [2, 3]: |
| 598 | + shape, _fortran_order, _dtype = np.lib.format.read_array_header_2_0( |
| 599 | + f |
| 600 | + ) |
| 601 | + else: |
| 602 | + raise ValueError(f"Unsupported .npy file version: {version}") |
| 603 | + nframes = shape[0] if len(shape) > 1 else 1 |
596 | 604 | return nframes |
597 | 605 |
|
598 | 606 | def reformat_data_torch(self, data: dict[str, Any]) -> dict[str, Any]: |
|
0 commit comments