|
| 1 | +# SPDX-License-Identifier: LGPL-3.0-or-later |
| 2 | +"""DeePMD training entrypoint script. |
| 3 | +
|
| 4 | +Can handle local training. |
| 5 | +""" |
| 6 | + |
| 7 | +import json |
| 8 | +import logging |
| 9 | +import time |
| 10 | +from typing import ( |
| 11 | + Any, |
| 12 | + Optional, |
| 13 | +) |
| 14 | + |
| 15 | + |
| 16 | +from deepmd.common import ( |
| 17 | + j_loader, |
| 18 | +) |
| 19 | +from deepmd.jax.env import ( |
| 20 | + jax, |
| 21 | + jax_export, |
| 22 | +) |
| 23 | +from deepmd.jax.train.trainer import ( |
| 24 | + DPTrainer, |
| 25 | +) |
| 26 | +from deepmd.utils import random as dp_random |
| 27 | +from deepmd.utils.argcheck import ( |
| 28 | + normalize, |
| 29 | +) |
| 30 | +from deepmd.utils.compat import ( |
| 31 | + update_deepmd_input, |
| 32 | +) |
| 33 | +from deepmd.utils.data_system import ( |
| 34 | + get_data, |
| 35 | +) |
| 36 | +from deepmd.utils.summary import SummaryPrinter as BaseSummaryPrinter |
| 37 | + |
# Public API of this module: only the training entrypoint is exported.
__all__ = ["train"]

# Module-level logger, named after this module.
log = logging.getLogger(__name__)
| 41 | + |
| 42 | + |
class SummaryPrinter(BaseSummaryPrinter):
    """Summary printer reporting JAX backend and device information."""

    def is_built_with_cuda(self) -> bool:
        """Return True when the default JAX export platform is ``"cuda"``."""
        platform = jax_export.default_export_platform()
        return platform == "cuda"

    def is_built_with_rocm(self) -> bool:
        """Return True when the default JAX export platform is ``"rocm"``."""
        platform = jax_export.default_export_platform()
        return platform == "rocm"

    def get_compute_device(self) -> str:
        """Return the name of the default JAX backend."""
        backend_name = jax.default_backend()
        return backend_name

    def get_ngpus(self) -> int:
        """Return the device count reported by ``jax.device_count()``."""
        n_devices = jax.device_count()
        return n_devices

    def get_backend_info(self) -> dict:
        """Return the backend name and the JAX version as a dict."""
        info = {}
        info["Backend"] = "JAX"
        info["JAX ver"] = jax.__version__
        return info

    def get_device_name(self) -> str:
        """Return the kind of the first JAX device, or "Unknown" if none exist."""
        available = jax.devices()
        return available[0].device_kind if available else "Unknown"
| 76 | + |
| 77 | + |
def train(
    *,
    INPUT: str,
    init_model: Optional[str],
    restart: Optional[str],
    output: str,
    init_frz_model: str,
    mpi_log: str,
    log_level: int,
    log_path: Optional[str],
    skip_neighbor_stat: bool = False,
    finetune: Optional[str] = None,
    use_pretrain_script: bool = False,
    **kwargs: Any,
) -> None:
    """Run DeePMD model training.

    The input file is loaded, upgraded to the current input format,
    normalized, optionally passed through ``update_sel``, and dumped to
    ``output``. A ``DPTrainer`` is then built, the training (and optional
    validation) data systems are loaded, and the trainer is run.

    Parameters
    ----------
    INPUT : str
        json/yaml control file
    init_model : Optional[str]
        path prefix of checkpoint files or None; forwarded to ``DPTrainer``
    restart : Optional[str]
        path prefix of checkpoint files or None; forwarded to ``DPTrainer``
    output : str
        path for dump file with arguments
    init_frz_model : str
        path to frozen model or None (not read by this function)
    mpi_log : str
        mpi logging mode (not read by this function)
    log_level : int
        logging level defined by int 0-3 (not read by this function)
    log_path : Optional[str]
        logging file path or None if logs are to be output only to stdout
        (not read by this function)
    skip_neighbor_stat : bool, default=False
        skip checking neighbor statistics, i.e. do not call ``update_sel``
    finetune : Optional[str]
        path to pretrained model or None (not read by this function)
    use_pretrain_script : bool
        Whether to use model script in pretrained model when doing init-model
        or init-frz-model (not read by this function).
    **kwargs
        additional arguments (ignored)

    Raises
    ------
    AssertionError
        if the normalized input has no ``"training"`` section.
    """
    # load json database
    jdata = j_loader(INPUT)

    # upgrade legacy inputs; the converted input is dumped for reference
    jdata = update_deepmd_input(jdata, warning=True, dump="input_v2_compat.json")

    jdata = normalize(jdata)
    # Honor --skip-neighbor-stat: the neighbor-statistics step loads the whole
    # training data set, so it must only run when the user has not opted out.
    if not skip_neighbor_stat:
        jdata = update_sel(jdata)

    # dump the fully resolved input next to the run
    with open(output, "w") as fp:
        json.dump(jdata, fp, indent=4)
    SummaryPrinter()()

    # make necessary checks
    assert "training" in jdata

    # init the model
    model = DPTrainer(
        jdata,
        init_model=init_model,
        restart=restart,
    )
    rcut = model.model.get_rcut()
    type_map = model.model.get_type_map()
    # an empty type map is passed to the data loader as None
    ipt_type_map = type_map if len(type_map) > 0 else None

    # init random seed of data systems
    seed = jdata["training"].get("seed", None)
    if seed is not None:
        # offset by the process index so each process gets a different seed,
        # then wrap into the 32-bit range
        seed += jax.process_index()
        seed = seed % (2**32)
    dp_random.seed(seed)

    # init data
    train_data = get_data(jdata["training"]["training_data"], rcut, ipt_type_map, None)
    train_data.add_data_requirements(model.data_requirements)
    train_data.print_summary("training")
    if jdata["training"].get("validation_data", None) is not None:
        # validation data reuses the type map resolved for the training data
        valid_data = get_data(
            jdata["training"]["validation_data"],
            rcut,
            train_data.type_map,
            None,
        )
        valid_data.add_data_requirements(model.data_requirements)
        valid_data.print_summary("validation")
    else:
        valid_data = None

    # train the model with the provided systems in a cyclic way
    start_time = time.time()
    model.train(train_data, valid_data)
    end_time = time.time()
    log.info("finished training")
    log.info(f"wall time: {(end_time - start_time):.3f} s")
| 198 | + |
| 199 | + |
def update_sel(jdata: dict) -> dict:
    """Return a shallow copy of ``jdata`` after loading the training data.

    The neighbor-statistics based update of the model section is currently
    disabled (see the TODO below), so the returned copy has the same
    contents as the input.

    Parameters
    ----------
    jdata : dict
        input parameters; must contain ``"model"`` and
        ``"training"/"training_data"`` entries.

    Returns
    -------
    dict
        shallow copy of ``jdata``.
    """
    log.info(
        "Calculate neighbor statistics... (add --skip-neighbor-stat to skip this step)"
    )
    updated = jdata.copy()
    model_type_map = jdata["model"].get("type_map")
    training_systems = jdata["training"]["training_data"]
    # rcut (second argument) and the last argument are not used here
    train_data = get_data(training_systems, 0, model_type_map, None)
    # TODO: OOM, need debug
    # updated["model"], min_nbor_dist = BaseModel.update_sel(
    #     train_data, model_type_map, jdata["model"]
    # )
    return updated
0 commit comments