|
54 | 54 | KFOptimizerWrapper, |
55 | 55 | LKFOptimizer, |
56 | 56 | ) |
| 57 | +from deepmd.pt.train.validation import ( |
| 58 | + FullValidator, |
| 59 | + resolve_full_validation_start_step, |
| 60 | +) |
57 | 61 | from deepmd.pt.train.wrapper import ( |
58 | 62 | ModelWrapper, |
59 | 63 | ) |
@@ -857,6 +861,57 @@ def single_model_finetune( |
857 | 861 | self.enable_profiler = training_params.get("enable_profiler", False) |
858 | 862 | self.profiling = training_params.get("profiling", False) |
859 | 863 | self.profiling_file = training_params.get("profiling_file", "timeline.json") |
| 864 | + self.full_validator = None |
| 865 | + |
| 866 | + validating_params = config.get("validating") or {} |
| 867 | + validation_start_step = resolve_full_validation_start_step( |
| 868 | + validating_params.get("full_val_start", 0.0), |
| 869 | + self.num_steps, |
| 870 | + ) |
| 871 | + full_validation_requested = ( |
| 872 | + bool(validating_params.get("full_validation", False)) |
| 873 | + and validation_start_step is not None |
| 874 | + and validation_start_step < self.num_steps |
| 875 | + ) |
| 876 | + if full_validation_requested: |
| 877 | + if self.multi_task: |
| 878 | + raise ValueError( |
| 879 | + "validating.full_validation only supports single-task energy " |
| 880 | + "training; multi-task training is not supported." |
| 881 | + ) |
| 882 | + has_spin = getattr(self.model, "has_spin", False) |
| 883 | + if callable(has_spin): |
| 884 | + has_spin = has_spin() |
| 885 | + if has_spin or isinstance(self.loss, EnergySpinLoss): |
| 886 | + raise ValueError( |
| 887 | + "validating.full_validation only supports single-task energy " |
| 888 | + "training; spin-energy training is not supported." |
| 889 | + ) |
| 890 | + if not isinstance(self.loss, EnergyStdLoss): |
| 891 | + raise ValueError( |
| 892 | + "validating.full_validation only supports single-task energy " |
| 893 | + "training." |
| 894 | + ) |
| 895 | + if validation_data is None: |
| 896 | + raise ValueError( |
| 897 | + "validating.full_validation requires `training.validation_data` " |
| 898 | + "to be configured." |
| 899 | + ) |
| 900 | + if self.zero_stage >= 2: |
| 901 | + raise ValueError( |
| 902 | + "validating.full_validation only supports single-task energy " |
| 903 | + "training with training.zero_stage < 2." |
| 904 | + ) |
| 905 | + self.full_validator = FullValidator( |
| 906 | + validating_params=validating_params, |
| 907 | + validation_data=validation_data, |
| 908 | + model=self.model, |
| 909 | + train_infos=self._get_inner_module().train_infos, |
| 910 | + num_steps=self.num_steps, |
| 911 | + rank=self.rank, |
| 912 | + zero_stage=self.zero_stage, |
| 913 | + restart_training=self.restart_training, |
| 914 | + ) |
860 | 915 |
|
861 | 916 | # Log model parameter count |
862 | 917 | if self.rank == 0: |
@@ -1363,6 +1418,14 @@ def log_loss_valid(_task_key: str = "Default") -> dict: |
1363 | 1418 | fout, display_step_id, cur_lr, train_results, valid_results |
1364 | 1419 | ) |
1365 | 1420 |
|
| 1421 | + if self.full_validator is not None: |
| 1422 | + self.full_validator.run( |
| 1423 | + step_id=_step_id, |
| 1424 | + display_step=display_step_id, |
| 1425 | + lr=cur_lr, |
| 1426 | + save_checkpoint=self.save_model, |
| 1427 | + ) |
| 1428 | + |
1366 | 1429 | if ( |
1367 | 1430 | ( |
1368 | 1431 | (display_step_id) % self.save_freq == 0 |
|
0 commit comments