Skip to content
Merged
28 changes: 28 additions & 0 deletions dpgen/generator/arginfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,13 @@ def model_devi_lmp_args() -> list[Argument]:
doc_epsilon_v = (
"The level parameter for computing the relative virial model deviation."
)
doc_lmp_d3 = "D3 dispersion configuration for LAMMPS. When present, all sub-parameters are required."
doc_lmp_d3_enable = "Enable D3 dispersion correction. If false, D3 will be disabled even if other parameters are present."
doc_lmp_d3_damping_function = "Damping function for D3 dispersion. Options include 'original', 'zerom', 'bj', 'bjm'."
doc_lmp_d3_functional = "Exchange-correlation functional for D3 dispersion. The functional should match that used to train MLPs."
doc_lmp_d3_cutoff = "Cutoff radius for D3 dispersion (in Angstrom)."
doc_lmp_d3_cn_cutoff = "Coordination number cutoff for D3 dispersion (in Angstrom)."
doc_lmp_neigh_modify_one = "Maximum number of neighbors of one atom for 'neigh_modify one N' command. Helps with D3 compatibility."

return [
model_devi_jobs_args(),
Expand Down Expand Up @@ -491,6 +498,27 @@ def model_devi_lmp_args() -> list[Argument]:
"use_relative_v", bool, optional=True, default=False, doc=doc_use_relative_v
),
Argument("epsilon_v", float, optional=True, doc=doc_epsilon_v),
Argument(
"lmp_d3",
dict,
optional=True,
doc=doc_lmp_d3,
sub_fields=[
Argument("enable", bool, optional=False, doc=doc_lmp_d3_enable),
Argument(
"damping_function",
str,
optional=False,
doc=doc_lmp_d3_damping_function,
),
Argument("functional", str, optional=False, doc=doc_lmp_d3_functional),
Argument("cutoff", float, optional=False, doc=doc_lmp_d3_cutoff),
Argument("cn_cutoff", float, optional=False, doc=doc_lmp_d3_cn_cutoff),
],
),
Argument(
"lmp_neigh_modify_one", int, optional=True, doc=doc_lmp_neigh_modify_one
),
]


Expand Down
53 changes: 47 additions & 6 deletions dpgen/generator/lib/lammps.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,17 @@ def make_lammps_input(
ret += "atom_style atomic\n"
ret += "\n"
ret += "neighbor 1.0 bin\n"

# Build neigh_modify command with applicable options
neigh_modify_one = jdata.get("lmp_neigh_modify_one")
neigh_modify_options = []
if neidelay is not None:
ret += "neigh_modify delay %d\n" % neidelay # noqa: UP031
neigh_modify_options.append(f"delay {neidelay}")
if neigh_modify_one is not None:
neigh_modify_options.append(f"one {neigh_modify_one}")

if neigh_modify_options:
ret += f"neigh_modify {' '.join(neigh_modify_options)}\n"
ret += "\n"
ret += "box tilt large\n"
if nbeads is None:
Expand All @@ -96,9 +105,23 @@ def make_lammps_input(
graph_list = ""
for ii in graphs:
graph_list += ii + " "

# Check if D3 dispersion is configured
lmp_d3 = jdata.get("lmp_d3", {})
d3_enabled = lmp_d3.get("enable", False) if lmp_d3 else False

if d3_enabled:
# Build D3 parameter string from validated arguments
d3_params = f"{lmp_d3['damping_function']} {lmp_d3['functional']} {lmp_d3['cutoff']} {lmp_d3['cn_cutoff']}"

if Version(deepmd_version) < Version("1"):
# 0.x
ret += f"pair_style deepmd {graph_list} ${{THERMO_FREQ}} model_devi.out\n"
if d3_enabled:
ret += f"pair_style hybrid/overlay deepmd {graph_list} ${{THERMO_FREQ}} model_devi.out dispersion/d3 {d3_params}\n"
else:
ret += (
f"pair_style deepmd {graph_list} ${{THERMO_FREQ}} model_devi.out\n"
)
else:
# 1.x
keywords = ""
Expand All @@ -112,11 +135,29 @@ def make_lammps_input(
keywords += "fparam ${ELE_TEMP}"
if ele_temp_a is not None:
keywords += "aparam ${ELE_TEMP}"
if nbeads is None:
ret += f"pair_style deepmd {graph_list} out_freq ${{THERMO_FREQ}} out_file model_devi.out {keywords}\n"

if d3_enabled:
# Use hybrid/overlay with D3
if nbeads is None:
ret += f"pair_style hybrid/overlay deepmd {graph_list} out_freq ${{THERMO_FREQ}} out_file model_devi.out {keywords} dispersion/d3 {d3_params}\n"
else:
ret += f"pair_style hybrid/overlay deepmd {graph_list} out_freq ${{THERMO_FREQ}} out_file model_devi${{ibead}}.out {keywords} dispersion/d3 {d3_params}\n"
else:
ret += f"pair_style deepmd {graph_list} out_freq ${{THERMO_FREQ}} out_file model_devi${{ibead}}.out {keywords}\n"
ret += "pair_coeff * *\n"
# Standard deepmd only
if nbeads is None:
ret += f"pair_style deepmd {graph_list} out_freq ${{THERMO_FREQ}} out_file model_devi.out {keywords}\n"
else:
ret += f"pair_style deepmd {graph_list} out_freq ${{THERMO_FREQ}} out_file model_devi${{ibead}}.out {keywords}\n"

# Add pair_coeff lines
if d3_enabled:
# D3 requires type maps (element symbols)
type_map = jdata.get("type_map", [])
type_map_str = " ".join(type_map)
ret += "pair_coeff * * deepmd\n"
ret += f"pair_coeff * * dispersion/d3 {type_map_str}\n"
else:
ret += "pair_coeff * *\n"
ret += "\n"
ret += "thermo_style custom step temp pe ke etotal press vol lx ly lz xy xz yz\n"
ret += "thermo ${THERMO_FREQ}\n"
Expand Down
133 changes: 123 additions & 10 deletions dpgen/generator/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -1078,11 +1078,22 @@ def revise_lmp_input_model(
):
idx = find_only_one_key(lmp_lines, ["pair_style", "deepmd"])
graph_list = " ".join(task_model_list)

# Check if D3 dispersion is configured
lmp_d3 = jdata.get("lmp_d3", {}) if jdata else {}
d3_enabled = lmp_d3.get("enable", False) if lmp_d3 else False

if Version(deepmd_version) < Version("1"):
lmp_lines[idx] = "pair_style deepmd %s %d model_devi.out\n" % ( # noqa: UP031
graph_list,
trj_freq,
)
if d3_enabled:
d3_params = f"{lmp_d3['damping_function']} {lmp_d3['functional']} {lmp_d3['cutoff']} {lmp_d3['cn_cutoff']}"
lmp_lines[idx] = (
f"pair_style hybrid/overlay deepmd {graph_list} {trj_freq} model_devi.out dispersion/d3 {d3_params}\n"
)
else:
lmp_lines[idx] = "pair_style deepmd %s %d model_devi.out\n" % ( # noqa: UP031
graph_list,
trj_freq,
)
else:
# Build keywords string like in make_lammps_input
keywords = ""
Expand All @@ -1097,14 +1108,106 @@ def revise_lmp_input_model(
if use_ele_temp == 1:
keywords += "fparam ${ELE_TEMP}"

lmp_lines[idx] = (
"pair_style deepmd %s out_freq %d out_file model_devi.out %s\n" # noqa: UP031
% (
graph_list,
trj_freq,
keywords.rstrip(),
if d3_enabled:
d3_params = f"{lmp_d3['damping_function']} {lmp_d3['functional']} {lmp_d3['cutoff']} {lmp_d3['cn_cutoff']}"
lmp_lines[idx] = (
"pair_style hybrid/overlay deepmd %s out_freq %d out_file model_devi.out %s dispersion/d3 %s\n" # noqa: UP031
% (
graph_list,
trj_freq,
keywords.rstrip(),
d3_params,
)
)
else:
lmp_lines[idx] = (
"pair_style deepmd %s out_freq %d out_file model_devi.out %s\n" # noqa: UP031
% (
graph_list,
trj_freq,
keywords.rstrip(),
)
)
return lmp_lines


def revise_lmp_input_pair_coeff(lmp_lines, jdata=None):
"""Update pair_coeff lines for D3 support."""
if jdata is None:
return lmp_lines

lmp_d3 = jdata.get("lmp_d3", {})
d3_enabled = lmp_d3.get("enable", False) if lmp_d3 else False

if not d3_enabled:
return lmp_lines

# D3 requires type maps (element symbols)
type_map = jdata.get("type_map", [])
type_map_str = " ".join(type_map)

# Find pair_coeff line
pair_coeff_idx = None
for idx, line in enumerate(lmp_lines):
if line.strip().startswith("pair_coeff") and "* *" in line:
pair_coeff_idx = idx
break

if pair_coeff_idx is None:
# If no pair_coeff found, add them after pair_style
pair_style_idx = find_only_one_key(lmp_lines, ["pair_style"])
lmp_lines.insert(pair_style_idx + 1, "pair_coeff * * deepmd\n")
lmp_lines.insert(
pair_style_idx + 2, f"pair_coeff * * dispersion/d3 {type_map_str}\n"
)
else:
# Replace existing pair_coeff with D3 version
lmp_lines[pair_coeff_idx] = "pair_coeff * * deepmd\n"
lmp_lines.insert(
pair_coeff_idx + 1, f"pair_coeff * * dispersion/d3 {type_map_str}\n"
)

return lmp_lines


def revise_lmp_input_neigh_modify(lmp_lines, jdata=None):
"""Add neigh_modify one N if requested."""
if jdata is None:
return lmp_lines

neigh_modify_one = jdata.get("lmp_neigh_modify_one")
if neigh_modify_one is None:
return lmp_lines

# Find where to insert neigh_modify one N
# Look for existing neigh_modify lines or insert after neighbor command
neigh_modify_found = False
neighbor_idx = None

for idx, line in enumerate(lmp_lines):
if line.strip().startswith("neigh_modify") and " one " in line:
neigh_modify_found = True
break
elif line.strip().startswith("neighbor"):
neighbor_idx = idx

if not neigh_modify_found:
if neighbor_idx is not None:
lmp_lines.insert(
neighbor_idx + 1, f"neigh_modify one {neigh_modify_one}\n"
)
else:
# Insert after units command if neighbor not found
units_idx = None
for idx, line in enumerate(lmp_lines):
if line.strip().startswith("units"):
units_idx = idx
break
if units_idx is not None:
lmp_lines.insert(
units_idx + 1, f"neigh_modify one {neigh_modify_one}\n"
)

return lmp_lines


Expand Down Expand Up @@ -1479,6 +1582,9 @@ def _make_model_devi_revmat(iter_index, jdata, mdata, conf_systems):
use_ele_temp=use_ele_temp,
jdata=jdata,
)
# Add D3 pair_coeff and neigh_modify support for templates
lmp_lines = revise_lmp_input_pair_coeff(lmp_lines, jdata)
lmp_lines = revise_lmp_input_neigh_modify(lmp_lines, jdata)
else:
if len(lmp_lines[template_pair_deepmd_idx].split()) != (
len(models)
Expand All @@ -1501,6 +1607,9 @@ def _make_model_devi_revmat(iter_index, jdata, mdata, conf_systems):
use_ele_temp=use_ele_temp,
jdata=jdata,
)
# Add D3 pair_coeff and neigh_modify support for templates
lmp_lines = revise_lmp_input_pair_coeff(lmp_lines, jdata)
lmp_lines = revise_lmp_input_neigh_modify(lmp_lines, jdata)
# use revise_lmp_input_model to raise error message if "part_style" or "deepmd" not found
else:
lmp_lines = revise_lmp_input_model(
Expand All @@ -1512,6 +1621,10 @@ def _make_model_devi_revmat(iter_index, jdata, mdata, conf_systems):
jdata=jdata,
)

# Add D3 pair_coeff and neigh_modify support for templates
lmp_lines = revise_lmp_input_pair_coeff(lmp_lines, jdata)
lmp_lines = revise_lmp_input_neigh_modify(lmp_lines, jdata)

lmp_lines = revise_lmp_input_dump(
lmp_lines, trj_freq, model_devi_merge_traj
)
Expand Down
Loading