|
54 | 54 | from deepmd.pt.train import ( |
55 | 55 | training, |
56 | 56 | ) |
| 57 | +from deepmd.pt.train.trainer import Trainer as NewTrainer |
57 | 58 | from deepmd.pt.train.wrapper import ( |
58 | 59 | ModelWrapper, |
59 | 60 | ) |
@@ -106,6 +107,7 @@ def get_trainer( |
106 | 107 | init_frz_model: str | None = None, |
107 | 108 | shared_links: dict[str, Any] | None = None, |
108 | 109 | finetune_links: dict[str, Any] | None = None, |
| 110 | + use_legacy: bool = False, |
109 | 111 | ) -> training.Trainer: |
110 | 112 | multi_task = "model_dict" in config.get("model", {}) |
111 | 113 |
|
@@ -200,19 +202,34 @@ def prepare_trainer_input_single( |
200 | 202 | seed=data_seed, |
201 | 203 | ) |
202 | 204 |
|
203 | | - trainer = training.Trainer( |
204 | | - config, |
205 | | - train_data, |
206 | | - stat_file_path=stat_file_path, |
207 | | - validation_data=validation_data, |
208 | | - init_model=init_model, |
209 | | - restart_model=restart_model, |
210 | | - finetune_model=finetune_model, |
211 | | - force_load=force_load, |
212 | | - shared_links=shared_links, |
213 | | - finetune_links=finetune_links, |
214 | | - init_frz_model=init_frz_model, |
215 | | - ) |
| 205 | + if use_legacy: |
| 206 | + trainer = training.Trainer( |
| 207 | + config, |
| 208 | + train_data, |
| 209 | + stat_file_path=stat_file_path, |
| 210 | + validation_data=validation_data, |
| 211 | + init_model=init_model, |
| 212 | + restart_model=restart_model, |
| 213 | + finetune_model=finetune_model, |
| 214 | + force_load=force_load, |
| 215 | + shared_links=shared_links, |
| 216 | + finetune_links=finetune_links, |
| 217 | + init_frz_model=init_frz_model, |
| 218 | + ) |
| 219 | + else: |
| 220 | + trainer = NewTrainer( |
| 221 | + config, |
| 222 | + train_data, |
| 223 | + stat_file_path=stat_file_path, |
| 224 | + validation_data=validation_data, |
| 225 | + init_model=init_model, |
| 226 | + restart_model=restart_model, |
| 227 | + finetune_model=finetune_model, |
| 228 | + force_load=force_load, |
| 229 | + shared_links=shared_links, |
| 230 | + finetune_links=finetune_links, |
| 231 | + init_frz_model=init_frz_model, |
| 232 | + ) |
216 | 233 | return trainer |
217 | 234 |
|
218 | 235 |
|
|
0 commit comments