Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,45 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
self.profiling = training_params.get("profiling", False)
self.profiling_file = training_params.get("profiling_file", "timeline.json")

# Log model summary info (descriptor type and parameter count)
if self.rank == 0:
self._log_model_summary()

def _log_model_summary(self) -> None:
    """Log model summary information including descriptor type and parameter count.

    Intended to be called once on rank 0 after the model is constructed.
    For multi-task training, logs one summary line pair per model key.
    """

    def get_descriptor_type(model: Any) -> str:
        """Return the descriptor type name (upper-cased) for *model*.

        Falls back to "UNKNOWN" when the model layout is not recognized,
        so that summary logging can never break training.
        """
        # Standard models expose the descriptor directly.
        if hasattr(model, "get_descriptor"):
            descriptor = model.get_descriptor()
            # Use .get() defensively: a serialized descriptor without a
            # "type" key must not raise KeyError from a logging helper.
            return descriptor.serialize().get("type", "unknown").upper()
        # ZBL models: the DP sub-model lives in atomic_model.models[0].
        if hasattr(model, "atomic_model") and hasattr(model.atomic_model, "models"):
            dp_model = model.atomic_model.models[0]
            if hasattr(dp_model, "descriptor"):
                return (
                    dp_model.descriptor.serialize().get("type", "unknown").upper()
                    + " (with ZBL)"
                )
        return "UNKNOWN"

    def count_parameters(model: Any) -> int:
        """Count the total number of parameters (trainable and frozen alike).

        NOTE: no ``requires_grad`` filter is applied, so frozen parameters
        are included in the reported count.
        """
        return sum(p.numel() for p in model.parameters())

    if not self.multi_task:
        desc_type = get_descriptor_type(self.model)
        num_params = count_parameters(self.model)
        log.info(f"Descriptor: {desc_type}")
        log.info(f"Model params: {num_params / 1e6:.3f} M")
    else:
        # For multi-task, log each sub-model's info under its model key.
        for model_key in self.model_keys:
            desc_type = get_descriptor_type(self.model[model_key])
            num_params = count_parameters(self.model[model_key])
            log.info(f"Descriptor [{model_key}]: {desc_type}")
            log.info(f"Model params [{model_key}]: {num_params / 1e6:.3f} M")
Comment on lines +728 to +761
Copy link

Copilot AI Jan 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new _log_model_summary() method that logs descriptor type and parameter count lacks explicit test coverage. While existing training tests will execute this code path, consider adding a dedicated test to verify that the descriptor type is correctly detected for different model types (standard models, ZBL models, multi-task models) and that the parameter counting logic works as expected. This would help catch potential issues with model structure assumptions.

Copilot uses AI. Check for mistakes.

def run(self) -> None:
fout = (
open(
Expand Down
5 changes: 5 additions & 0 deletions deepmd/utils/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ def __call__(self) -> None:
"computing device": self.get_compute_device(),
}
)
if build_info["Backend"] == "PyTorch":
import torch

if torch.cuda.is_available():
build_info["device name"] = torch.cuda.get_device_name(0)
if self.is_built_with_cuda():
env_value = os.environ.get("CUDA_VISIBLE_DEVICES", "unset")
build_info["CUDA_VISIBLE_DEVICES"] = env_value
Expand Down