Skip to content

Commit d36677d

Browse files
fix(jax): make display_if_exist jit-able (#4766)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **Bug Fixes** - Enhanced loss display to support element-wise handling of properties, improving accuracy and consistency with array inputs. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn> Signed-off-by: Jinzhe Zeng <njzjz@qq.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent f460493 commit d36677d

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

deepmd/dpmodel/loss/loss.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,9 @@ def display_if_exist(loss: np.ndarray, find_property: float) -> np.ndarray:
5151
the loss scalar or NaN
5252
"""
5353
xp = array_api_compat.array_namespace(loss)
54-
return loss if bool(find_property) else xp.nan
54+
return xp.where(
55+
xp.asarray(find_property, dtype=xp.bool), loss, xp.asarray(xp.nan)
56+
)
5557

5658
@classmethod
5759
def get_loss(cls, loss_params: dict) -> "Loss":

0 commit comments

Comments
 (0)