Skip to content

Commit f68fc18

Browse files
committed
fix hlo
1 parent 545a664 commit f68fc18

2 files changed

Lines changed: 2 additions & 0 deletions

File tree

deepmd/jax/jax2tf/tfmodel.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ def call_lower(
187187
fparam: jnp.ndarray | None = None,
188188
aparam: jnp.ndarray | None = None,
189189
do_atomic_virial: bool = False,
190+
charge_spin: jnp.ndarray | None = None,
190191
) -> dict[str, jnp.ndarray]:
191192
if do_atomic_virial:
192193
call_lower = self._call_lower_atomic_virial

deepmd/jax/model/hlo.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ def call_lower(
183183
fparam: jnp.ndarray | None = None,
184184
aparam: jnp.ndarray | None = None,
185185
do_atomic_virial: bool = False,
186+
charge_spin: jnp.ndarray | None = None,
186187
) -> dict[str, jnp.ndarray]:
187188
if extended_coord.shape[1] > nlist.shape[1]:
188189
if do_atomic_virial:

0 commit comments

Comments
 (0)