Skip to content

Commit 188dae3

Browse files
authored
fix(jax): setattr case_embd (#5104)
`case_embd` was supported but the JAX backend was not touched. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Bug Fixes** * Ensured the case_embd parameter is consistently converted and handled during fitting, improving compatibility across array backends and preventing mis-coercion. * **Tests** * Adjusted test setup to reset the default computation graph before enabling eager execution, stabilizing related test runs. <sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jinzhe Zeng <jinzhe.zeng@ustc.edu.cn>
1 parent 5274f69 commit 188dae3

3 files changed

Lines changed: 3 additions & 0 deletions

File tree

deepmd/jax/fitting/fitting.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def setattr_for_general_fitting(name: str, value: Any) -> Any:
4343
"fparam_inv_std",
4444
"aparam_avg",
4545
"aparam_inv_std",
46+
"case_embd",
4647
"default_fparam_tensor",
4748
}:
4849
value = to_jax_array(value)

source/tests/array_api_strict/fitting/fitting.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def setattr_for_general_fitting(name: str, value: Any) -> Any:
3131
"fparam_inv_std",
3232
"aparam_avg",
3333
"aparam_inv_std",
34+
"case_embd",
3435
"default_fparam_tensor",
3536
}:
3637
value = to_array_api_strict_array(value)

source/tests/pt/test_tabulate.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def get_activation_function(functype: int):
4141

4242

4343
def setUpModule() -> None:
44+
tf.reset_default_graph()
4445
tf.compat.v1.enable_eager_execution()
4546

4647

0 commit comments

Comments
 (0)