|
25 | 25 | import jax.numpy as jnp |
26 | 26 | import numpy as np |
27 | 27 | from jax._src.sharding_impls import UNSPECIFIED |
28 | | -from jax.core import Literal, Var, Jaxpr |
| 28 | +if jax.__version__ >= '0.5.0': |
| 29 | + from jax.extend.core import Primitive, Literal, Var, Jaxpr |
| 30 | +else: |
| 31 | + from jax.core import Primitive, Literal, Var, Jaxpr |
29 | 32 |
|
30 | 33 | __all__ = [ |
31 | 34 | 'fn_to_python_code', |
@@ -187,7 +190,7 @@ def fn_to_python_code(fn, *args, **kwargs): |
187 | 190 | return source |
188 | 191 |
|
189 | 192 |
|
190 | | -def jaxpr_to_python_code(jaxpr: jax.core.Jaxpr, |
| 193 | +def jaxpr_to_python_code(jaxpr: Jaxpr, |
191 | 194 | fn_name: str = "generated_function"): |
192 | 195 | """ |
193 | 196 | Given a JAX jaxpr, return the Python code that would be generated by JAX for that jaxpr. |
@@ -367,7 +370,7 @@ def _maybe_wrap_fn_for_leaves(node, f, num_args): |
367 | 370 |
|
368 | 371 |
|
369 | 372 | def jaxpr_to_py_ast(state: SourcerorState, |
370 | | - jaxpr: jax.core.Jaxpr, |
| 373 | + jaxpr: Jaxpr, |
371 | 374 | fn_name: str = "function"): |
372 | 375 | # Generate argument declarations |
373 | 376 | ast_args = [ast.arg(arg=state.str_name(var), annotation=None) |
@@ -405,7 +408,7 @@ def jaxpr_to_py_ast(state: SourcerorState, |
405 | 408 | return ast.FunctionDef(name=fn_name, args=ast_args, body=stmts, decorator_list=[]) |
406 | 409 |
|
407 | 410 |
|
408 | | -def constant_fold_jaxpr(jaxpr: jax.core.Jaxpr): |
| 411 | +def constant_fold_jaxpr(jaxpr: Jaxpr): |
409 | 412 | """ |
410 | 413 | Given a jaxpr, return a new jaxpr with all constant folding done. |
411 | 414 | """ |
|
0 commit comments