Skip to content

Commit 44ba37c

Browse files
committed
refactor(pt): refactor training code
1 parent 9465d71 commit 44ba37c

12 files changed

Lines changed: 5128 additions & 13 deletions

File tree

deepmd/pt/entrypoints/main.py

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
from deepmd.pt.train import (
5555
training,
5656
)
57+
from deepmd.pt.train.trainer import Trainer as NewTrainer
5758
from deepmd.pt.train.wrapper import (
5859
ModelWrapper,
5960
)
@@ -106,6 +107,7 @@ def get_trainer(
106107
init_frz_model: str | None = None,
107108
shared_links: dict[str, Any] | None = None,
108109
finetune_links: dict[str, Any] | None = None,
110+
use_legacy: bool = False,
109111
) -> training.Trainer:
110112
multi_task = "model_dict" in config.get("model", {})
111113

@@ -200,19 +202,34 @@ def prepare_trainer_input_single(
200202
seed=data_seed,
201203
)
202204

203-
trainer = training.Trainer(
204-
config,
205-
train_data,
206-
stat_file_path=stat_file_path,
207-
validation_data=validation_data,
208-
init_model=init_model,
209-
restart_model=restart_model,
210-
finetune_model=finetune_model,
211-
force_load=force_load,
212-
shared_links=shared_links,
213-
finetune_links=finetune_links,
214-
init_frz_model=init_frz_model,
215-
)
205+
if use_legacy:
206+
trainer = training.Trainer(
207+
config,
208+
train_data,
209+
stat_file_path=stat_file_path,
210+
validation_data=validation_data,
211+
init_model=init_model,
212+
restart_model=restart_model,
213+
finetune_model=finetune_model,
214+
force_load=force_load,
215+
shared_links=shared_links,
216+
finetune_links=finetune_links,
217+
init_frz_model=init_frz_model,
218+
)
219+
else:
220+
trainer = NewTrainer(
221+
config,
222+
train_data,
223+
stat_file_path=stat_file_path,
224+
validation_data=validation_data,
225+
init_model=init_model,
226+
restart_model=restart_model,
227+
finetune_model=finetune_model,
228+
force_load=force_load,
229+
shared_links=shared_links,
230+
finetune_links=finetune_links,
231+
init_frz_model=init_frz_model,
232+
)
216233
return trainer
217234

218235

deepmd/pt/train/__init__.py

100644 → 100755
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,102 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
"""PyTorch training module with modular, extensible design.
3+
4+
This module provides a clean, component-based training system:
5+
6+
- TrainingConfig: Configuration management with validation
7+
- DataManager: Data loading and batch iteration
8+
- OptimizerFactory: Strategy pattern for optimizer creation
9+
- CheckpointManager: Model persistence and recovery
10+
- Training loops (BaseTrainingLoop, AdamTrainingLoop, LKFEnergyTrainingLoop, TrainingLoopFactory): Specialized training step implementations
11+
- HookManager: Extensible callback system
12+
- TrainingLogger: Formatted output and file I/O
13+
- Trainer: Main orchestrator coordinating all components
14+
15+
Example:
16+
>>> from deepmd.pt.train import Trainer, TrainingConfig
17+
>>>
18+
>>> # Create trainer
19+
>>> trainer = Trainer(
20+
... config=config_dict,
21+
... training_data=train_dataset,
22+
... validation_data=valid_dataset,
23+
... )
24+
>>>
25+
>>> # Run training
26+
>>> trainer.run()
27+
28+
Future extensions for multi-backend support:
29+
- BaseTrainingLoop can be extended for JAX/NumPy
30+
- OptimizerFactory can support backend-specific optimizers
31+
- DataManager can use backend-specific data loading
32+
"""
33+
34+
from deepmd.pt.train.checkpoint_manager import (
35+
CheckpointManager,
36+
)
37+
from deepmd.pt.train.config import (
38+
CheckpointConfig,
39+
DisplayConfig,
40+
LearningRateConfig,
41+
OptimizerConfig,
42+
TrainingConfig,
43+
)
44+
from deepmd.pt.train.data_manager import (
45+
DataManager,
46+
)
47+
from deepmd.pt.train.hooks import (
48+
HookManager,
49+
HookPriority,
50+
TensorBoardHook,
51+
TimingHook,
52+
TrainingHook,
53+
)
54+
from deepmd.pt.train.logger import (
55+
LossAccumulator,
56+
TrainingLogger,
57+
)
58+
from deepmd.pt.train.optimizer_factory import (
59+
OptimizerFactory,
60+
)
61+
from deepmd.pt.train.trainer import (
62+
Trainer,
63+
)
64+
65+
# Keep old Trainer available for backward compatibility during transition
66+
from deepmd.pt.train.training import Trainer as LegacyTrainer
67+
from deepmd.pt.train.training_loop import (
68+
AdamTrainingLoop,
69+
BaseTrainingLoop,
70+
LKFEnergyTrainingLoop,
71+
TrainingLoopFactory,
72+
)
73+
from deepmd.pt.train.wrapper import (
74+
ModelWrapper,
75+
)
76+
77+
__all__ = [
78+
# New modular components
79+
"AdamTrainingLoop",
80+
"BaseTrainingLoop",
81+
"CheckpointConfig",
82+
"CheckpointManager",
83+
"DataManager",
84+
"DisplayConfig",
85+
"HookManager",
86+
"HookPriority",
87+
"LKFEnergyTrainingLoop",
88+
"LearningRateConfig",
89+
# Legacy support (applies to "LegacyTrainer" only; this list is sorted alphabetically)
90+
"LegacyTrainer",
91+
"LossAccumulator",
92+
"ModelWrapper",
93+
"OptimizerConfig",
94+
"OptimizerFactory",
95+
"TensorBoardHook",
96+
"TimingHook",
97+
"Trainer",
98+
"TrainingConfig",
99+
"TrainingHook",
100+
"TrainingLogger",
101+
"TrainingLoopFactory",
102+
]

0 commit comments

Comments
 (0)