We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 32186aa commit 7f8d27fCopy full SHA for 7f8d27f
1 file changed
learned_optimization/jax_utils.py
@@ -22,6 +22,13 @@
22
import jax.numpy as jnp
23
import numpy as onp
24
25
+try:
26
+ # JAX v0.10.0 or newer
27
+ from jax.extend.core import unsafe_am_i_under_a_jit_DO_NOT_USE # pylint: disable=g-import-not-at-top
28
+except ImportError:
29
+ # JAX v0.9.2 or older
30
+ from jax.core import unsafe_am_i_under_a_jit_DO_NOT_USE # pylint: disable=g-import-not-at-top
31
+
32
33
def maybe_static_cond(pred, true_fn, false_fn, val):
34
"""Conditional that first checks if pred can be determined at compile time."""
@@ -51,7 +58,7 @@ def in_jit() -> bool:
51
58
jax.core.thread_local_state.trace_state.trace_stack # type: ignore
52
59
)
53
60
54
- return jax.core.unsafe_am_i_under_a_jit_DO_NOT_USE()
61
+ return unsafe_am_i_under_a_jit_DO_NOT_USE()
55
62
56
63
57
64
Carry = TypeVar("Carry")
0 commit comments