Skip to content

Commit 7f8d27f

Browse files
Jake VanderPlaslearned_optimization authors
authored andcommitted
No public description
PiperOrigin-RevId: 903911142
1 parent 32186aa commit 7f8d27f

1 file changed

Lines changed: 8 additions & 1 deletion

File tree

learned_optimization/jax_utils.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,13 @@
2222
import jax.numpy as jnp
2323
import numpy as onp
2424

25+
try:
26+
# JAX v0.10.0 or newer
27+
from jax.extend.core import unsafe_am_i_under_a_jit_DO_NOT_USE # pylint: disable=g-import-not-at-top
28+
except ImportError:
29+
# JAX v0.9.2 or older
30+
from jax.core import unsafe_am_i_under_a_jit_DO_NOT_USE # pylint: disable=g-import-not-at-top
31+
2532

2633
def maybe_static_cond(pred, true_fn, false_fn, val):
2734
"""Conditional that first checks if pred can be determined at compile time."""
@@ -51,7 +58,7 @@ def in_jit() -> bool:
5158
jax.core.thread_local_state.trace_state.trace_stack # type: ignore
5259
)
5360

54-
return jax.core.unsafe_am_i_under_a_jit_DO_NOT_USE()
61+
return unsafe_am_i_under_a_jit_DO_NOT_USE()
5562

5663

5764
Carry = TypeVar("Carry")

0 commit comments

Comments
 (0)