diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py
index d98b23d25c..552ebf3db9 100644
--- a/deepmd/pt/train/training.py
+++ b/deepmd/pt/train/training.py
@@ -721,6 +721,45 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
         self.profiling = training_params.get("profiling", False)
         self.profiling_file = training_params.get("profiling_file", "timeline.json")
 
+        # Log model summary info (descriptor type and parameter count)
+        if self.rank == 0:
+            self._log_model_summary()
+
+    def _log_model_summary(self) -> None:
+        """Log model summary information including descriptor type and parameter count."""
+
+        def get_descriptor_type(model: Any) -> str:
+            """Get the descriptor type name from model."""
+            # Standard models have get_descriptor method
+            if hasattr(model, "get_descriptor"):
+                descriptor = model.get_descriptor()
+                return descriptor.serialize()["type"].upper()
+            # ZBL models: descriptor is in atomic_model.models[0]
+            if hasattr(model, "atomic_model") and hasattr(model.atomic_model, "models"):
+                dp_model = model.atomic_model.models[0]
+                if hasattr(dp_model, "descriptor"):
+                    return (
+                        dp_model.descriptor.serialize()["type"].upper() + " (with ZBL)"
+                    )
+            return "UNKNOWN"
+
+        def count_parameters(model: Any) -> int:
+            """Count the total number of trainable parameters."""
+            # Only count parameters that are actually trained (requires_grad),
+            # matching the docstring; frozen parameters are excluded.
+            return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+        if not self.multi_task:
+            desc_type = get_descriptor_type(self.model)
+            num_params = count_parameters(self.model)
+            log.info(f"Descriptor: {desc_type}")
+            log.info(f"Model params: {num_params / 1e6:.3f} M")
+        else:
+            # For multi-task, log each model's info
+            for model_key in self.model_keys:
+                desc_type = get_descriptor_type(self.model[model_key])
+                num_params = count_parameters(self.model[model_key])
+                log.info(f"Descriptor [{model_key}]: {desc_type}")
+                log.info(f"Model params [{model_key}]: {num_params / 1e6:.3f} M")
+
     def run(self) -> None:
         fout = (
             open(
diff --git a/deepmd/utils/summary.py b/deepmd/utils/summary.py
index c00e6deb9e..28aab99a98 100644
--- a/deepmd/utils/summary.py
+++ b/deepmd/utils/summary.py
@@ -74,6 +74,11 @@ def __call__(self) -> None:
                 "computing device": self.get_compute_device(),
             }
         )
+        # Report the CUDA device name when running the PyTorch backend on GPU.
+        # NOTE(review): use .get() — nothing here guarantees the "Backend" key
+        # exists in build_info, and .get() is identical when it does.
+        if build_info.get("Backend") == "PyTorch":
+            import torch
+
+            if torch.cuda.is_available():
+                build_info["device name"] = torch.cuda.get_device_name(0)
         if self.is_built_with_cuda():
             env_value = os.environ.get("CUDA_VISIBLE_DEVICES", "unset")
             build_info["CUDA_VISIBLE_DEVICES"] = env_value