2424import torch .distributed as dist
2525
2626from deepmd .dpmodel .common import PRECISION_DICT as NP_PRECISION_DICT
27+ from deepmd .dpmodel .utils .lmdb_data import (
28+ LmdbTestData ,
29+ LmdbTestDataNlocView ,
30+ )
2731from deepmd .pt .utils .auto_batch_size import (
2832 AutoBatchSize ,
2933)
3539 GLOBAL_PT_FLOAT_PRECISION ,
3640 RESERVED_PRECISION_DICT ,
3741)
42+ from deepmd .pt .utils .lmdb_dataset import (
43+ LmdbDataset ,
44+ )
3845from deepmd .pt .utils .utils import (
3946 to_torch_tensor ,
4047)
5562log = logging .getLogger (__name__ )
5663
5764if TYPE_CHECKING :
65+ from collections .abc import (
66+ Iterator ,
67+ )
68+
5869 from deepmd .utils .data import (
5970 DeepmdData ,
6071 )
@@ -229,6 +240,12 @@ def __init__(
229240 if self .rank == 0 :
230241 self ._initialize_best_checkpoints (restart_training = restart_training )
231242
243+ # Lazily-populated full test snapshot for LMDB validation. Mixed-nloc
244+ # LMDB datasets cannot be stacked as a single (nframes, natoms*3)
245+ # tensor, so we materialize frames grouped by nloc the first time
246+ # full validation runs and reuse the snapshot on subsequent calls.
247+ self ._lmdb_test_data : LmdbTestData | None = None
248+
232249 def should_run (self , display_step : int ) -> bool :
233250 """Check whether the current step should trigger full validation."""
234251 if not self .enabled or self .start_step is None :
@@ -348,14 +365,10 @@ def evaluate_all_systems(self) -> dict[str, float]:
348365 if torch .cuda .is_available ():
349366 torch .cuda .empty_cache ()
350367
351- system_metrics = []
352- for dataset in self .validation_data .systems :
353- if not isinstance (dataset , DeepmdDataSetForLoader ):
354- raise TypeError (
355- "Full validation expects each dataset in validation_data.systems "
356- f"to be DeepmdDataSetForLoader, got { type (dataset )!r} ."
357- )
358- system_metrics .append (self ._evaluate_system (dataset .data_system ))
368+ system_metrics = [
369+ self ._evaluate_system (data_system )
370+ for data_system in self ._iter_validation_data_systems ()
371+ ]
359372
360373 aggregated = weighted_average ([metric for metric in system_metrics if metric ])
361374 return {
@@ -364,6 +377,54 @@ def evaluate_all_systems(self) -> dict[str, float]:
364377 if metric_key in aggregated
365378 }
366379
380+ def _iter_validation_data_systems (self ) -> Iterator [Any ]:
381+ """Yield ``DeepmdData``-like systems to evaluate in this run.
382+
383+ - For ``DpLoaderSet``-style validation data, each entry in
384+ ``validation_data.systems`` is a :class:`DeepmdDataSetForLoader`,
385+ and we forward its underlying ``DeepmdData`` instance.
386+ - For ``LmdbDataset`` validation data, we lazily materialize a
387+ :class:`LmdbTestData` snapshot (cached across calls) and yield one
388+ :class:`LmdbTestDataNlocView` per ``nloc`` group, so mixed-nloc
389+ frames can be stacked and evaluated group by group.
390+ """
391+ validation_data = self .validation_data
392+ if isinstance (validation_data , LmdbDataset ):
393+ lmdb_test_data = self ._get_lmdb_test_data_snapshot (validation_data )
394+ for nloc in sorted (lmdb_test_data .nloc_groups .keys ()):
395+ yield LmdbTestDataNlocView (lmdb_test_data , nloc )
396+ return
397+
398+ for dataset in validation_data .systems :
399+ if not isinstance (dataset , DeepmdDataSetForLoader ):
400+ raise TypeError (
401+ "Full validation expects each dataset in validation_data.systems "
402+ f"to be DeepmdDataSetForLoader, got { type (dataset )!r} ."
403+ )
404+ yield dataset .data_system
405+
406+ def _get_lmdb_test_data_snapshot (self , lmdb_dataset : LmdbDataset ) -> LmdbTestData :
407+ """Build (once) and return the cached LMDB test snapshot.
408+
409+ Reuses the ``type_map`` and previously-registered
410+ ``DataRequirementItem`` entries on ``LmdbDataset._reader`` so that
411+ the full-validation snapshot sees exactly the same fields and
412+ dtypes as training batches.
413+ """
414+ if self ._lmdb_test_data is not None :
415+ return self ._lmdb_test_data
416+
417+ reader = lmdb_dataset ._reader
418+ self ._lmdb_test_data = LmdbTestData (
419+ lmdb_dataset .lmdb_path ,
420+ type_map = list (reader ._type_map ),
421+ shuffle_test = False ,
422+ )
423+ data_requirements = list (reader ._data_requirements .values ())
424+ if data_requirements :
425+ self ._lmdb_test_data .add_data_requirement (data_requirements )
426+ return self ._lmdb_test_data
427+
367428 def _evaluate_system (
368429 self , data_system : DeepmdData
369430 ) -> dict [str , tuple [float , float ]]:
0 commit comments