Skip to content

Commit 1acfa9d

Browse files
committed
fix(jax): fix support with default fparam
1 parent b98f6c5 commit 1acfa9d

9 files changed

Lines changed: 77 additions & 0 deletions

File tree

deepmd/dpmodel/atomic_model/base_atomic_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,10 @@ def has_default_fparam(self) -> bool:
9595
"""Check if the model has default frame parameters."""
9696
return False
9797

98+
def get_default_fparam(self) -> list[int] | None:
99+
"""Get the default frame parameters."""
100+
return []
101+
98102
def reinit_atom_exclude(
99103
self,
100104
exclude_types: list[int] = [],

deepmd/dpmodel/atomic_model/dp_atomic_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,10 @@ def has_default_fparam(self) -> bool:
240240
"""Check if the model has default frame parameters."""
241241
return self.fitting.has_default_fparam()
242242

243+
def get_default_fparam(self) -> list[int] | None:
244+
"""Get the default frame parameters."""
245+
return self.fitting.get_default_fparam()
246+
243247
def get_sel_type(self) -> list[int]:
244248
"""Get the selected atom types of this model.
245249

deepmd/dpmodel/fitting/general_fitting.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,10 @@ def has_default_fparam(self) -> bool:
304304
"""Check if the fitting has default frame parameters."""
305305
return self.default_fparam is not None
306306

307+
def get_default_fparam(self) -> list[int] | None:
308+
"""Get the default frame parameters."""
309+
return self.default_fparam
310+
307311
def get_sel_type(self) -> list[int]:
308312
"""Get the selected atom types of this model.
309313

deepmd/dpmodel/model/make_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -567,6 +567,10 @@ def has_default_fparam(self) -> bool:
567567
"""Check if the model has default frame parameters."""
568568
return self.atomic_model.has_default_fparam()
569569

570+
def get_default_fparam(self) -> list[int] | None:
571+
"""Get the default frame parameters."""
572+
return self.atomic_model.get_default_fparam()
573+
570574
def get_sel_type(self) -> list[int]:
571575
"""Get the selected atom types of this model.
572576

deepmd/jax/infer/deep_eval.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,14 @@ def _eval_model(
354354
box_input = None
355355
if fparam is not None:
356356
fparam_input = fparam.reshape(nframes, self.get_dim_fparam())
357+
elif self.dp.has_default_fparam():
358+
# JAX (XLA) requires static shapes, so default must be implemented here
359+
default_fparam = self.dp.get_default_fparam()
360+
assert default_fparam is not None
361+
fparam_input = np.tile(
362+
np.array(default_fparam, dtype=GLOBAL_NP_FLOAT_PRECISION),
363+
(nframes, 1),
364+
)
357365
else:
358366
fparam_input = None
359367
if aparam is not None:
@@ -433,3 +441,7 @@ def get_model(self) -> Any:
433441
The JAX model as BaseModel instance.
434442
"""
435443
return self.dp
444+
445+
def has_default_fparam(self) -> bool:
446+
"""Check if the model has default frame parameters."""
447+
return self.dp.has_default_fparam()

deepmd/jax/jax2tf/serialization.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,23 @@ def has_message_passing() -> tf.Tensor:
319319
return tf.constant(model.has_message_passing(), dtype=tf.bool)
320320

321321
tf_model.has_message_passing = has_message_passing
322+
323+
@tf.function
324+
def has_default_fparam() -> tf.Tensor:
325+
return tf.constant(model.has_default_fparam(), dtype=tf.bool)
326+
327+
tf_model.has_default_fparam = has_default_fparam
328+
329+
@tf.function
330+
def get_default_fparam() -> tf.Tensor:
331+
default_fparam = model.get_default_fparam()
332+
if default_fparam is None:
333+
return tf.constant([], dtype=tf.double)
334+
else:
335+
return tf.constant(default_fparam, dtype=tf.double)
336+
337+
tf_model.get_default_fparam = get_default_fparam
338+
322339
tf.saved_model.save(
323340
tf_model,
324341
model_file,

deepmd/jax/jax2tf/tfmodel.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,15 @@ def __init__(
6969
self.min_nbor_dist = None
7070
self.sel = self.model.get_sel().numpy().tolist()
7171
self.model_def_script = self.model.get_model_def_script().numpy().decode()
72+
if hasattr(self.model, "has_default_fparam"):
73+
# No attrs before v3.1.2
74+
self._has_default_fparam = self.model.has_default_fparam().numpy().item()
75+
else:
76+
self._has_default_fparam = False
77+
if hasattr(self.model, "get_default_fparam"):
78+
self.default_fparam = self.model.get_default_fparam().numpy().tolist()
79+
else:
80+
self.default_fparam = None
7281

7382
def __call__(
7483
self,
@@ -331,3 +340,11 @@ def get_model(cls, model_params: dict) -> "TFModelWrapper":
331340
The model
332341
"""
333342
raise NotImplementedError("Not implemented")
343+
344+
def has_default_fparam(self) -> bool:
345+
"""Check whether the model has default frame parameters."""
346+
return self._has_default_fparam
347+
348+
def get_default_fparam(self) -> list[int] | None:
349+
"""Get the default frame parameters."""
350+
return self.default_fparam

deepmd/jax/model/hlo.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,9 @@ def __init__(
5858
mixed_types: bool,
5959
min_nbor_dist: float | None,
6060
sel: list[int],
61+
# new in v3.1.1
62+
has_default_fparam: bool = False,
63+
default_fparam: list[int] | None = None,
6164
) -> None:
6265
self._call_lower = jax_export.deserialize(stablehlo).call
6366
self._call_lower_atomic_virial = jax_export.deserialize(
@@ -79,6 +82,8 @@ def __init__(
7982
self.min_nbor_dist = min_nbor_dist
8083
self.sel = sel
8184
self.model_def_script = model_def_script
85+
self._has_default_fparam = has_default_fparam
86+
self.default_fparam = default_fparam
8287

8388
def __call__(
8489
self,
@@ -327,3 +332,11 @@ def get_model(cls, model_params: dict) -> "BaseModel":
327332
The model
328333
"""
329334
raise NotImplementedError("Not implemented")
335+
336+
def has_default_fparam(self) -> bool:
337+
"""Check whether the model has default frame parameters."""
338+
return self._has_default_fparam
339+
340+
def get_default_fparam(self) -> list[int] | None:
341+
"""Get the default frame parameters."""
342+
return self.default_fparam

deepmd/jax/utils/serialization.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,8 @@ def call_lower_with_fixed_do_atomic_virial(
133133
"mixed_types": model.mixed_types(),
134134
"min_nbor_dist": model.get_min_nbor_dist(),
135135
"sel": model.get_sel(),
136+
"has_default_fparam": model.has_default_fparam(),
137+
"default_fparam": model.get_default_fparam(),
136138
}
137139
save_dp_model(filename=model_file, model_dict=data)
138140
elif model_file.endswith(".savedmodel"):

0 commit comments

Comments
 (0)