-
Notifications
You must be signed in to change notification settings - Fork 611
Expand file tree
/
Copy path__init__.py
More file actions
executable file
·102 lines (95 loc) · 2.56 KB
/
__init__.py
File metadata and controls
executable file
·102 lines (95 loc) · 2.56 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
# SPDX-License-Identifier: LGPL-3.0-or-later
"""PyTorch training module with modular, extensible design.

This module provides a clean, component-based training system:

- TrainingConfig: Configuration management with validation
- DataManager: Data loading and batch iteration
- OptimizerFactory: Strategy pattern for optimizer creation
- CheckpointManager: Model persistence and recovery
- TrainingLoop: Specialized training step implementations
- HookManager: Extensible callback system
- TrainingLogger: Formatted output and file I/O
- Trainer: Main orchestrator coordinating all components

Example:
    >>> from deepmd.pt.train import Trainer, TrainingConfig
    >>>
    >>> # Create trainer
    >>> trainer = Trainer(
    ...     config=config_dict,
    ...     training_data=train_dataset,
    ...     validation_data=valid_dataset,
    ... )
    >>>
    >>> # Run training
    >>> trainer.run()

Future extensions for multi-backend support:

- AbstractTrainingLoop can be extended for JAX/NumPy
- OptimizerFactory can support backend-specific optimizers
- DataManager can use backend-specific data loading
"""
# Model persistence and recovery.
from deepmd.pt.train.checkpoint_manager import (
    CheckpointManager,
)

# Configuration objects for the individual training subsystems.
from deepmd.pt.train.config import (
    CheckpointConfig,
    DisplayConfig,
    LearningRateConfig,
    OptimizerConfig,
    TrainingConfig,
)

# Data loading and batch iteration.
from deepmd.pt.train.data_manager import (
    DataManager,
)

# Extensible callback system (hooks fire at defined points of the loop).
from deepmd.pt.train.hooks import (
    HookManager,
    HookPriority,
    TensorBoardHook,
    TimingHook,
    TrainingHook,
)

# Formatted console/file output and loss bookkeeping.
from deepmd.pt.train.logger import (
    LossAccumulator,
    TrainingLogger,
)

# Strategy-pattern factory for optimizer creation.
from deepmd.pt.train.optimizer_factory import (
    OptimizerFactory,
)

# Main orchestrator coordinating all of the components above.
from deepmd.pt.train.trainer import (
    Trainer,
)

# Keep old Trainer available for backward compatibility during transition
# (re-exported under the alias "LegacyTrainer" so both APIs coexist).
from deepmd.pt.train.training import Trainer as LegacyTrainer

# Training-step implementations; the factory selects one per configuration.
from deepmd.pt.train.training_loop import (
    AdamTrainingLoop,
    BaseTrainingLoop,
    LKFEnergyTrainingLoop,
    TrainingLoopFactory,
)

# Wrapper presenting the model to the training machinery.
from deepmd.pt.train.wrapper import (
    ModelWrapper,
)

# Explicit public API — keep this list in sync with the imports above.
__all__ = [
    # New modular components
    "AdamTrainingLoop",
    "BaseTrainingLoop",
    "CheckpointConfig",
    "CheckpointManager",
    "DataManager",
    "DisplayConfig",
    "HookManager",
    "HookPriority",
    "LKFEnergyTrainingLoop",
    "LearningRateConfig",
    # Legacy support
    "LegacyTrainer",
    "LossAccumulator",
    "ModelWrapper",
    "OptimizerConfig",
    "OptimizerFactory",
    "TensorBoardHook",
    "TimingHook",
    "Trainer",
    "TrainingConfig",
    "TrainingHook",
    "TrainingLogger",
    "TrainingLoopFactory",
]