|
54 | 54 | KFOptimizerWrapper, |
55 | 55 | LKFOptimizer, |
56 | 56 | ) |
| 57 | +from deepmd.pt.train.validation import ( |
| 58 | + FullValidator, |
| 59 | + resolve_full_validation_start_step, |
| 60 | +) |
57 | 61 | from deepmd.pt.train.wrapper import ( |
58 | 62 | ModelWrapper, |
59 | 63 | ) |
@@ -857,11 +861,88 @@ def single_model_finetune( |
857 | 861 | self.enable_profiler = training_params.get("enable_profiler", False) |
858 | 862 | self.profiling = training_params.get("profiling", False) |
859 | 863 | self.profiling_file = training_params.get("profiling_file", "timeline.json") |
| 864 | + validating_params = config.get("validating") or {} |
| 865 | + self.full_validator = self._create_full_validator( |
| 866 | + validating_params=validating_params, |
| 867 | + validation_data=validation_data, |
| 868 | + ) |
860 | 869 |
|
861 | 870 | # Log model parameter count |
862 | 871 | if self.rank == 0: |
863 | 872 | self._log_parameter_count() |
864 | 873 |
|
| 874 | + def _create_full_validator( |
| 875 | + self, |
| 876 | + *, |
| 877 | + validating_params: dict[str, Any], |
| 878 | + validation_data: DpLoaderSet | None, |
| 879 | + ) -> FullValidator | None: |
 | 880 | +        """Create the runtime full validator, or None when full validation is inactive."""
| 881 | + if not self._is_full_validation_requested(validating_params): |
| 882 | + return None |
| 883 | + self._raise_if_full_validation_unsupported(validation_data) |
| 884 | + if validation_data is None: |
| 885 | + raise RuntimeError( |
| 886 | + "validation_data must be available after full validation checks." |
| 887 | + ) |
| 888 | + return FullValidator( |
| 889 | + validating_params=validating_params, |
| 890 | + validation_data=validation_data, |
| 891 | + model=self.model, |
| 892 | + train_infos=self._get_inner_module().train_infos, |
| 893 | + num_steps=self.num_steps, |
| 894 | + rank=self.rank, |
| 895 | + zero_stage=self.zero_stage, |
| 896 | + restart_training=self.restart_training, |
| 897 | + ) |
| 898 | + |
| 899 | + def _is_full_validation_requested(self, validating_params: dict[str, Any]) -> bool: |
| 900 | + """Check whether full validation can trigger during this training run.""" |
| 901 | + if not validating_params.get("full_validation", False): |
| 902 | + return False |
| 903 | + start_step = resolve_full_validation_start_step( |
| 904 | + validating_params.get("full_val_start", 0.0), |
| 905 | + self.num_steps, |
| 906 | + ) |
| 907 | + return start_step is not None and start_step <= self.num_steps |
| 908 | + |
| 909 | + def _raise_if_full_validation_unsupported( |
| 910 | + self, |
| 911 | + validation_data: DpLoaderSet | None, |
| 912 | + ) -> None: |
 | 913 | +        """Raise ValueError when full validation is unsupported for this training run."""
| 914 | + if self.multi_task: |
| 915 | + raise ValueError( |
| 916 | + "validating.full_validation only supports single-task energy " |
| 917 | + "training; multi-task training is not supported." |
| 918 | + ) |
| 919 | + |
| 920 | + has_spin = getattr(self.model, "has_spin", False) |
| 921 | + if callable(has_spin): |
| 922 | + has_spin = has_spin() |
| 923 | + if has_spin or isinstance(self.loss, EnergySpinLoss): |
| 924 | + raise ValueError( |
| 925 | + "validating.full_validation only supports single-task energy " |
| 926 | + "training; spin-energy training is not supported." |
| 927 | + ) |
| 928 | + |
| 929 | + if not isinstance(self.loss, EnergyStdLoss): |
| 930 | + raise ValueError( |
| 931 | + "validating.full_validation only supports single-task energy training." |
| 932 | + ) |
| 933 | + |
| 934 | + if validation_data is None: |
| 935 | + raise ValueError( |
| 936 | + "validating.full_validation requires `training.validation_data` " |
| 937 | + "to be configured." |
| 938 | + ) |
| 939 | + |
| 940 | + if self.zero_stage >= 2: |
| 941 | + raise ValueError( |
| 942 | + "validating.full_validation only supports single-task energy " |
| 943 | + "training with training.zero_stage < 2." |
| 944 | + ) |
| 945 | + |
865 | 946 | @staticmethod |
866 | 947 | def _count_parameters(model: torch.nn.Module) -> tuple[int, int]: |
867 | 948 | """ |
@@ -1363,6 +1444,14 @@ def log_loss_valid(_task_key: str = "Default") -> dict: |
1363 | 1444 | fout, display_step_id, cur_lr, train_results, valid_results |
1364 | 1445 | ) |
1365 | 1446 |
|
| 1447 | + if self.full_validator is not None: |
| 1448 | + self.full_validator.run( |
| 1449 | + step_id=_step_id, |
| 1450 | + display_step=display_step_id, |
| 1451 | + lr=cur_lr, |
| 1452 | + save_checkpoint=self.save_model, |
| 1453 | + ) |
| 1454 | + |
1366 | 1455 | if ( |
1367 | 1456 | ( |
1368 | 1457 | (display_step_id) % self.save_freq == 0 |
|
0 commit comments