Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion deepmd/backend/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,11 @@ def entry_point_hook(self) -> Callable[["Namespace"], None]:
Callable[[Namespace], None]
The entry point hook of the backend.
"""
raise NotImplementedError
from deepmd.jax.entrypoints.main import (
main,
)

return main

@property
def deep_eval(self) -> type["DeepEvalBackend"]:
Expand Down
1 change: 1 addition & 0 deletions deepmd/jax/entrypoints/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
49 changes: 49 additions & 0 deletions deepmd/jax/entrypoints/freeze.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Freeze utilities for the JAX backend."""

from pathlib import (
Path,
)

from deepmd.backend.suffix import (
format_model_suffix,
)
from deepmd.jax.utils.serialization import (
deserialize_to_file,
serialize_from_file,
)


def freeze(
*,
checkpoint_folder: str,
output: str,
**kwargs: object,
) -> None:
"""Freeze a JAX checkpoint into a serialized model file.

Parameters
----------
checkpoint_folder : str
Location of either the checkpoint directory or a folder containing the
stable ``checkpoint`` pointer.
output : str
Output model filename or prefix. The JAX model suffix is added when the
filename has no supported backend suffix.
**kwargs
Other CLI arguments accepted for backend entry-point compatibility.
"""
del kwargs

checkpoint_path = Path(checkpoint_folder)
if (checkpoint_path / "checkpoint").is_file():
checkpoint_pointer = (checkpoint_path / "checkpoint").read_text().strip()
checkpoint_folder = str(checkpoint_path / checkpoint_pointer)

output = format_model_suffix(
output,
preferred_backend="jax",
strict_prefer=True,
)
data = serialize_from_file(checkpoint_folder)
deserialize_to_file(output, data)
57 changes: 57 additions & 0 deletions deepmd/jax/entrypoints/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""DeePMD-Kit entry point module."""

import argparse
from pathlib import (
Path,
)

from deepmd.jax.entrypoints.freeze import (
freeze,
Comment thread
njzjz-bot marked this conversation as resolved.
)
from deepmd.jax.entrypoints.train import (
train,
)
from deepmd.loggers.loggers import (
set_log_handles,
)
from deepmd.main import (
parse_args,
)

__all__ = ["main"]


def main(args: list[str] | argparse.Namespace | None = None) -> None:
"""DeePMD-Kit entry point.

Parameters
----------
args : list[str] or argparse.Namespace, optional
list of command line arguments, used to avoid calling from the subprocess,
as it is quite slow to import tensorflow; if Namespace is given, it will
be used directly

Raises
------
RuntimeError
if no command was input
"""
if not isinstance(args, argparse.Namespace):
args = parse_args(args=args)

dict_args = vars(args)
set_log_handles(
args.log_level,
Path(args.log_path) if args.log_path else None,
mpi_log=None,
)

if args.command == "train":
train(**dict_args)
elif args.command == "freeze":
freeze(**dict_args)
elif args.command is None:
pass
else:
raise RuntimeError(f"unknown command {args.command}")
203 changes: 203 additions & 0 deletions deepmd/jax/entrypoints/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""DeePMD training entrypoint script.

Can handle local training.
"""

import json
import logging
import time
from typing import (
Any,
)

from deepmd.common import (
j_loader,
)
from deepmd.jax.env import (
jax,
jax_export,
)
from deepmd.jax.train.trainer import (
DPTrainer,
)
from deepmd.utils import random as dp_random
from deepmd.utils.argcheck import (
normalize,
)
from deepmd.utils.compat import (
update_deepmd_input,
)
from deepmd.utils.data_system import (
get_data,
)
from deepmd.utils.summary import SummaryPrinter as BaseSummaryPrinter

__all__ = ["train"]

log = logging.getLogger(__name__)


class SummaryPrinter(BaseSummaryPrinter):
"""Summary printer for JAX."""

def is_built_with_cuda(self) -> bool:
"""Check if the backend is built with CUDA."""
return jax_export.default_export_platform() == "cuda"

def is_built_with_rocm(self) -> bool:
"""Check if the backend is built with ROCm."""
return jax_export.default_export_platform() == "rocm"

def get_compute_device(self) -> str:
"""Get Compute device."""
return jax.default_backend()

def get_ngpus(self) -> int:
"""Get the number of GPUs."""
return jax.device_count()

def get_backend_info(self) -> dict:
"""Get backend information."""
return {
"Backend": "JAX",
"JAX ver": jax.__version__,
}

def get_device_name(self) -> str:
"""Get the name of the device."""
devices = jax.devices()
if devices:
return devices[0].device_kind
else:
return "Unknown"


def train(
*,
INPUT: str,
init_model: str | None,
restart: str | None,
output: str,
init_frz_model: str | None,
mpi_log: str,
log_level: int,
log_path: str | None,
skip_neighbor_stat: bool = False,
finetune: str | None = None,
use_pretrain_script: bool = False,
**kwargs: Any,
) -> None:
Comment thread
njzjz-bot marked this conversation as resolved.
"""Run DeePMD model training.

Parameters
----------
INPUT : str
json/yaml control file
init_model : Optional[str]
path prefix of checkpoint files or None
restart : Optional[str]
path prefix of checkpoint files or None
output : str
path for dump file with arguments
init_frz_model : str | None
path to frozen model, or None if no frozen model is used
mpi_log : str
mpi logging mode
log_level : int
logging level defined by int 0-3
log_path : Optional[str]
logging file path or None if logs are to be output only to stdout
skip_neighbor_stat : bool, default=False
skip checking neighbor statistics
finetune : Optional[str]
path to pretrained model or None
use_pretrain_script : bool
Whether to use model script in pretrained model when doing init-model or init-frz-model.
Note that this option is true and unchangeable for fine-tuning.
**kwargs
additional arguments

Raises
------
RuntimeError
if the training command fails.
"""
# load json database
jdata = j_loader(INPUT)

if init_frz_model:
raise NotImplementedError("JAX training does not support init_frz_model yet")
if finetune:
raise NotImplementedError("JAX training does not support finetune yet")
if use_pretrain_script:
raise NotImplementedError(
"JAX training does not support use_pretrain_script yet"
)

jdata = update_deepmd_input(jdata, warning=True, dump="input_v2_compat.json")

jdata = normalize(jdata)
if not skip_neighbor_stat:
jdata = update_sel(jdata)

with open(output, "w") as fp:
json.dump(jdata, fp, indent=4)
SummaryPrinter()()

# make necessary checks
assert "training" in jdata

# init the model

model = DPTrainer(
jdata,
init_model=init_model,
restart=restart,
)
rcut = model.model.get_rcut()
type_map = model.model.get_type_map()
if len(type_map) == 0:
ipt_type_map = None
else:
ipt_type_map = type_map

# init random seed of data systems
seed = jdata["training"].get("seed", None)
if seed is not None:
seed += jax.process_index()
seed = seed % (2**32)
dp_random.seed(seed)

# init data
train_data = get_data(jdata["training"]["training_data"], rcut, ipt_type_map, None)
train_data.add_data_requirements(model.data_requirements)
train_data.print_summary("training")
if jdata["training"].get("validation_data", None) is not None:
valid_data = get_data(
jdata["training"]["validation_data"],
Comment thread
github-advanced-security[bot] marked this conversation as resolved.
Fixed
rcut,
train_data.type_map,
None,
)
valid_data.add_data_requirements(model.data_requirements)
Comment thread
github-advanced-security[bot] marked this conversation as resolved.
Fixed
valid_data.print_summary("validation")
else:
valid_data = None

# train the model with the provided systems in a cyclic way
start_time = time.time()
model.train(train_data, valid_data)
end_time = time.time()
Comment thread
github-advanced-security[bot] marked this conversation as resolved.
Fixed
log.info("finished training")
log.info(f"wall time: {(end_time - start_time):.3f} s")


def update_sel(jdata: dict) -> dict:
"""Update descriptor selections from neighbor statistics when available."""
log.info(
"Skip neighbor statistics update for JAX training; "
"BaseModel.update_sel currently needs more memory than expected."
)
# TODO: Restore BaseModel.update_sel once the JAX data path avoids OOM.
return jdata.copy()
1 change: 1 addition & 0 deletions deepmd/jax/train/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
Loading
Loading