Skip to content

Commit fb73297

Browse files
committed
resolve comments
1 parent 637cf68 commit fb73297

5 files changed

Lines changed: 39 additions & 51 deletions

File tree

deepmd/dpmodel/utils/lmdb_data.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -146,11 +146,9 @@ def _remap_keys(frame: dict[str, Any]) -> dict[str, Any]:
146146
return out
147147

148148

149-
def is_lmdb(systems: Any) -> bool:
149+
def is_lmdb(systems: str) -> bool:
150150
"""Check if systems points to an LMDB dataset."""
151-
if not isinstance(systems, str):
152-
return False
153-
return systems.endswith(".lmdb") or Path(systems, "data.mdb").exists()
151+
return systems.endswith(".lmdb") or Path(systems, "data.mdb").is_file()
154152

155153

156154
def _parse_metadata(meta: dict) -> tuple[int, str, list[int]]:

deepmd/entrypoints/test.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import numpy as np
1414

1515
from deepmd.common import (
16-
expand_sys_str,
1716
j_loader,
1817
)
1918
from deepmd.dpmodel.utils.lmdb_data import (
@@ -148,11 +147,8 @@ def test(
148147
systems = str((root / Path(systems)).resolve())
149148
else:
150149
systems = [str((root / Path(ss)).resolve()) for ss in systems]
151-
if is_lmdb(systems):
152-
all_sys = [systems]
153-
else:
154-
patterns = data_params.get("rglob_patterns", None)
155-
all_sys = process_systems(systems, patterns=patterns)
150+
patterns = data_params.get("rglob_patterns", None)
151+
all_sys = process_systems(systems, patterns=patterns)
156152
elif valid_json is not None:
157153
jdata = j_loader(valid_json)
158154
jdata = update_deepmd_input(jdata)
@@ -165,19 +161,13 @@ def test(
165161
systems = str((root / Path(systems)).resolve())
166162
else:
167163
systems = [str((root / Path(ss)).resolve()) for ss in systems]
168-
if is_lmdb(systems):
169-
all_sys = [systems]
170-
else:
171-
patterns = data_params.get("rglob_patterns", None)
172-
all_sys = process_systems(systems, patterns=patterns)
164+
patterns = data_params.get("rglob_patterns", None)
165+
all_sys = process_systems(systems, patterns=patterns)
173166
elif datafile is not None:
174167
with open(datafile) as datalist:
175168
all_sys = datalist.read().splitlines()
176169
elif system is not None:
177-
if is_lmdb(system):
178-
all_sys = [system]
179-
else:
180-
all_sys = expand_sys_str(system)
170+
all_sys = process_systems(system)
181171
else:
182172
raise RuntimeError("No data source specified for testing")
183173

deepmd/pt/entrypoints/main.py

Lines changed: 23 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,23 @@ def prepare_trainer_input_single(
147147
Path(stat_file_path_single).mkdir()
148148
stat_file_path_single = DPPath(stat_file_path_single, "a")
149149

150+
rank_seed = [rank, seed % (2**32)] if seed is not None else None
151+
152+
def _make_dp_loader_set(
153+
systems: str | list[str],
154+
dataset_params: dict[str, Any],
155+
) -> DpLoaderSet:
156+
"""Create a DpLoaderSet from systems with pattern expansion."""
157+
patterns = dataset_params.get("rglob_patterns", None)
158+
systems = process_systems(systems, patterns=patterns)
159+
return DpLoaderSet(
160+
systems,
161+
dataset_params["batch_size"],
162+
model_params_single["type_map"],
163+
seed=rank_seed,
164+
modifier=modifier,
165+
)
166+
150167
# LMDB path: single string → LmdbDataset
151168
if is_lmdb(training_systems):
152169
auto_prob = training_dataset_params.get("auto_prob", None)
@@ -163,46 +180,21 @@ def prepare_trainer_input_single(
163180
validation_dataset_params["batch_size"],
164181
)
165182
elif validation_systems is not None:
166-
val_patterns = validation_dataset_params.get("rglob_patterns", None)
167-
validation_systems = process_systems(validation_systems, val_patterns)
168-
rank_seed = [rank, seed % (2**32)] if seed is not None else None
169-
validation_data_single = DpLoaderSet(
170-
validation_systems,
171-
validation_dataset_params["batch_size"],
172-
model_params_single["type_map"],
173-
seed=rank_seed,
174-
modifier=modifier,
183+
validation_data_single = _make_dp_loader_set(
184+
validation_systems, validation_dataset_params
175185
)
176186
else:
177187
validation_data_single = None
178188
else:
179189
# Standard npy path
180-
trn_patterns = training_dataset_params.get("rglob_patterns", None)
181-
training_systems = process_systems(training_systems, patterns=trn_patterns)
182-
if validation_systems is not None:
183-
val_patterns = validation_dataset_params.get("rglob_patterns", None)
184-
validation_systems = process_systems(validation_systems, val_patterns)
185-
186-
# avoid the same batch sequence among devices
187-
rank_seed = [rank, seed % (2**32)] if seed is not None else None
190+
train_data_single = _make_dp_loader_set(
191+
training_systems, training_dataset_params
192+
)
188193
validation_data_single = (
189-
DpLoaderSet(
190-
validation_systems,
191-
validation_dataset_params["batch_size"],
192-
model_params_single["type_map"],
193-
seed=rank_seed,
194-
modifier=modifier,
195-
)
194+
_make_dp_loader_set(validation_systems, validation_dataset_params)
196195
if validation_systems
197196
else None
198197
)
199-
train_data_single = DpLoaderSet(
200-
training_systems,
201-
training_dataset_params["batch_size"],
202-
model_params_single["type_map"],
203-
seed=rank_seed,
204-
modifier=modifier,
205-
)
206198
return (
207199
train_data_single,
208200
validation_data_single,

deepmd/utils/data_system.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -791,6 +791,7 @@ def process_systems(
791791
792792
If it is a single directory, search for all the systems in the directory.
793793
If it is a list, each item in the list is treated as a directory to search.
794+
If it is a single LMDB path, return it directly without expansion.
794795
Check if the systems are valid.
795796
796797
Parameters
@@ -805,6 +806,14 @@ def process_systems(
805806
result_systems: list of str
806807
The valid systems
807808
"""
809+
from deepmd.dpmodel.utils.lmdb_data import (
810+
is_lmdb,
811+
)
812+
813+
# LMDB path: return directly without expansion
814+
if isinstance(systems, str) and is_lmdb(systems):
815+
return [systems]
816+
808817
# Normalize input to a list of paths to search
809818
if isinstance(systems, str):
810819
search_paths = [systems]

source/tests/common/dpmodel/test_lmdb_data.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,6 @@ def test_is_lmdb(self):
309309
self.assertTrue(is_lmdb(self._lmdb_path))
310310
self.assertTrue(is_lmdb("something.lmdb"))
311311
self.assertFalse(is_lmdb("/some/npy/system"))
312-
self.assertFalse(is_lmdb(["list", "of", "systems"]))
313312

314313
def test_lmdb_test_data(self):
315314
td = LmdbTestData(self._lmdb_path, type_map=self._type_map, shuffle_test=False)

0 commit comments

Comments
 (0)