Commit 3f91293
feat(pt_expt): multi-task training support (#5397)
## Summary
- Add multi-task training to the pt_expt backend: multiple models share
a descriptor but have separate fitting nets/losses, trained on
heterogeneous datasets
- Support shared descriptor stats merging (probability-weighted
`merge_env_stat`), shared fitting via `share_params`, and `case_embd`
per-task embedding
- Add DDP (distributed) training support with rank-0 gating, stat
broadcast, and `find_unused_parameters`
- Add multi-task finetune (multi-task→multi-task and
single-task→multi-task) with extended type_map
- Enable `torch.compile` for multi-task with `backend="inductor"`,
`dynamic=True`, and symbolic tracing
(`make_fx(tracing_mode="symbolic")`); includes `silu_backward`
decomposition for second-order gradient compatibility
- Add `silut` activation variant and `DescrptBlockRepformers` accessors
(ported from #5393)
## Compile training benchmark (V100-SXM2-16GB, batch_size=1)
### Speed
| Model | Uncompiled (ms/step) | Compiled (ms/step) | Speedup |
|-------|---------------------|-------------------|---------|
| DPA1 (se_atten) | 21.5 | 10.2 | **2.11x** |
| DPA2 | 66.3 | 23.5 | **2.82x** |
| DPA3 | 161.9 | 50.6 | **3.20x** |
### Convergence (1000 steps)
Training data does not contain virial labels; only energy and force RMSE
are reported.
**DPA1** (se_atten_compressible, float64, rcut=6, sel=120)
| step | Uncompiled (e_rmse / f_rmse) | Compiled (e_rmse / f_rmse) |
|---|---|---|
| 1 | 4.24e-01 / 1.31 | 4.35e-01 / 1.31 |
| 500 | 3.74e-02 / 0.446 | 7.96e-03 / 0.437 |
| 1000 | 9.22e-03 / 0.285 | 8.65e-03 / 0.284 |
**DPA2** (input_torch_small, float32)
| step | Uncompiled (e_rmse / f_rmse) | Compiled (e_rmse / f_rmse) |
|---|---|---|
| 1 | 1.63e-01 / 0.983 | 1.59e-01 / 0.870 |
| 500 | 9.12e-02 / 0.562 | 6.50e-02 / 0.506 |
| 1000 | 4.78e-02 / 0.547 | 4.34e-02 / 0.430 |
**DPA3** (input_torch, float32, static sel)
| step | Uncompiled (e_rmse / f_rmse) | Compiled (e_rmse / f_rmse) |
|---|---|---|
| 1 | 3.88e-02 / 0.784 | 3.13e-02 / 0.893 |
| 500 | 2.25e-02 / 0.829 | 3.01e-03 / 0.879 |
| 1000 | 1.47e-03 / 0.400 | 1.78e-03 / 0.429 |
All models converge to comparable accuracy. Variation between compiled
and uncompiled is within normal run-to-run noise from random batch
ordering.
## Known limitations
- No `num_epoch_dict`: only `numb_steps` + `model_prob`; epoch-based
scheduling deferred
- Only `EnergyFittingNet` has `share_params`; other fitting types (DOS,
dipole, polar, property) need the same override
- No cross-backend (PT vs pt_expt) multi-task consistency test
- `share_fitting` + single-task finetune is incompatible (no
`dim_case_embd` in pretrained)
- DPA3 `use_dynamic_sel: true` cannot compile (symbolic tracer fails on
data-dependent `int()` in `get_graph_index`)
## Test plan
- [ ] 78 multi-task tests (`test_multitask.py`): training, freeze,
finetune, compile, shared fitting, DPA1/DPA2/DPA3/SeA
- [ ] 21 single-task training tests (`test_training.py`): compile
correctness for se_e2_a, DPA1 (with and without attention), DPA2, DPA3;
dynamic shapes, silu compile
- [ ] DDP training tests (`test_training_ddp.py`): single-task +
multi-task with compile
- [ ] Descriptor stat merge tests (`test_descrpt_stat_merge.py`)
- [ ] Fitting stat tests (`test_fitting_stat.py`)
- [ ] share_params tests for all descriptor types
- [ ] Activation tests (`test_activation.py`): silut export + compile
compatibility
---------
Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>1 parent d42732e commit 3f91293
45 files changed
Lines changed: 9090 additions & 582 deletions
File tree
- deepmd
- dpmodel
- descriptor
- fitting
- utils
- pt_expt
- descriptor
- entrypoints
- fitting
- train
- utils
- pt/model
- descriptor
- task
- source/tests
- common/dpmodel
- consistent
- pt_expt
- descriptor
- fitting
- utils
- pt
- model
Some content is hidden
Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
405 | 405 | | |
406 | 406 | | |
407 | 407 | | |
| 408 | + | |
| 409 | + | |
| 410 | + | |
| 411 | + | |
408 | 412 | | |
409 | 413 | | |
410 | 414 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
345 | 345 | | |
346 | 346 | | |
347 | 347 | | |
| 348 | + | |
| 349 | + | |
| 350 | + | |
| 351 | + | |
| 352 | + | |
| 353 | + | |
| 354 | + | |
| 355 | + | |
348 | 356 | | |
349 | 357 | | |
350 | 358 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
255 | 255 | | |
256 | 256 | | |
257 | 257 | | |
| 258 | + | |
258 | 259 | | |
259 | 260 | | |
260 | 261 | | |
| |||
296 | 297 | | |
297 | 298 | | |
298 | 299 | | |
| 300 | + | |
299 | 301 | | |
300 | 302 | | |
301 | 303 | | |
| |||
362 | 364 | | |
363 | 365 | | |
364 | 366 | | |
| 367 | + | |
365 | 368 | | |
366 | 369 | | |
367 | 370 | | |
| |||
407 | 410 | | |
408 | 411 | | |
409 | 412 | | |
| 413 | + | |
| 414 | + | |
| 415 | + | |
| 416 | + | |
410 | 417 | | |
411 | 418 | | |
412 | 419 | | |
| |||
666 | 673 | | |
667 | 674 | | |
668 | 675 | | |
669 | | - | |
| 676 | + | |
| 677 | + | |
| 678 | + | |
670 | 679 | | |
671 | | - | |
672 | | - | |
673 | | - | |
| 680 | + | |
| 681 | + | |
| 682 | + | |
674 | 683 | | |
675 | 684 | | |
676 | 685 | | |
| |||
687 | 696 | | |
688 | 697 | | |
689 | 698 | | |
690 | | - | |
| 699 | + | |
| 700 | + | |
| 701 | + | |
691 | 702 | | |
692 | | - | |
693 | | - | |
694 | | - | |
695 | | - | |
| 703 | + | |
| 704 | + | |
| 705 | + | |
696 | 706 | | |
697 | 707 | | |
698 | 708 | | |
| |||
735 | 745 | | |
736 | 746 | | |
737 | 747 | | |
738 | | - | |
| 748 | + | |
| 749 | + | |
739 | 750 | | |
740 | 751 | | |
741 | 752 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
68 | 68 | | |
69 | 69 | | |
70 | 70 | | |
71 | | - | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
| 75 | + | |
72 | 76 | | |
73 | 77 | | |
74 | 78 | | |
| |||
77 | 81 | | |
78 | 82 | | |
79 | 83 | | |
80 | | - | |
| 84 | + | |
81 | 85 | | |
82 | 86 | | |
83 | 87 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
40 | 40 | | |
41 | 41 | | |
42 | 42 | | |
| 43 | + | |
| 44 | + | |
| 45 | + | |
| 46 | + | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
| 51 | + | |
| 52 | + | |
| 53 | + | |
| 54 | + | |
| 55 | + | |
| 56 | + | |
| 57 | + | |
| 58 | + | |
| 59 | + | |
| 60 | + | |
| 61 | + | |
| 62 | + | |
| 63 | + | |
| 64 | + | |
| 65 | + | |
| 66 | + | |
| 67 | + | |
| 68 | + | |
| 69 | + | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
| 75 | + | |
| 76 | + | |
| 77 | + | |
| 78 | + | |
| 79 | + | |
| 80 | + | |
| 81 | + | |
| 82 | + | |
| 83 | + | |
| 84 | + | |
| 85 | + | |
| 86 | + | |
| 87 | + | |
| 88 | + | |
| 89 | + | |
| 90 | + | |
| 91 | + | |
| 92 | + | |
| 93 | + | |
| 94 | + | |
| 95 | + | |
| 96 | + | |
| 97 | + | |
| 98 | + | |
| 99 | + | |
| 100 | + | |
| 101 | + | |
| 102 | + | |
| 103 | + | |
| 104 | + | |
| 105 | + | |
| 106 | + | |
| 107 | + | |
| 108 | + | |
| 109 | + | |
| 110 | + | |
| 111 | + | |
43 | 112 | | |
44 | 113 | | |
45 | 114 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
383 | 383 | | |
384 | 384 | | |
385 | 385 | | |
| 386 | + | |
| 387 | + | |
| 388 | + | |
| 389 | + | |
386 | 390 | | |
387 | 391 | | |
388 | 392 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
779 | 779 | | |
780 | 780 | | |
781 | 781 | | |
782 | | - | |
| 782 | + | |
783 | 783 | | |
784 | | - | |
785 | | - | |
| 784 | + | |
| 785 | + | |
786 | 786 | | |
787 | 787 | | |
788 | 788 | | |
| |||
804 | 804 | | |
805 | 805 | | |
806 | 806 | | |
807 | | - | |
| 807 | + | |
808 | 808 | | |
809 | | - | |
810 | | - | |
| 809 | + | |
| 810 | + | |
811 | 811 | | |
812 | 812 | | |
813 | 813 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
9 | 9 | | |
10 | 10 | | |
11 | 11 | | |
| 12 | + | |
| 13 | + | |
| 14 | + | |
12 | 15 | | |
13 | 16 | | |
14 | 17 | | |
| |||
26 | 29 | | |
27 | 30 | | |
28 | 31 | | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
| 45 | + | |
| 46 | + | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
| 51 | + | |
| 52 | + | |
| 53 | + | |
| 54 | + | |
| 55 | + | |
| 56 | + | |
29 | 57 | | |
30 | 58 | | |
31 | 59 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
14 | 14 | | |
15 | 15 | | |
16 | 16 | | |
| 17 | + | |
| 18 | + | |
| 19 | + | |
17 | 20 | | |
18 | 21 | | |
19 | 22 | | |
| |||
30 | 33 | | |
31 | 34 | | |
32 | 35 | | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
| 45 | + | |
| 46 | + | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
| 51 | + | |
| 52 | + | |
| 53 | + | |
| 54 | + | |
| 55 | + | |
| 56 | + | |
| 57 | + | |
| 58 | + | |
| 59 | + | |
| 60 | + | |
| 61 | + | |
| 62 | + | |
| 63 | + | |
| 64 | + | |
| 65 | + | |
| 66 | + | |
| 67 | + | |
| 68 | + | |
| 69 | + | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
| 75 | + | |
| 76 | + | |
33 | 77 | | |
34 | 78 | | |
35 | 79 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1 | 1 | | |
2 | 2 | | |
3 | 3 | | |
| 4 | + | |
| 5 | + | |
| 6 | + | |
4 | 7 | | |
5 | 8 | | |
6 | 9 | | |
| |||
16 | 19 | | |
17 | 20 | | |
18 | 21 | | |
| 22 | + | |
| 23 | + | |
| 24 | + | |
| 25 | + | |
| 26 | + | |
| 27 | + | |
| 28 | + | |
| 29 | + | |
| 30 | + | |
| 31 | + | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
| 45 | + | |
| 46 | + | |
0 commit comments