Skip to content

Commit 58db3cc

Browse files
committed
Removing deprecated get_eval in favor of typeof.
1 parent 020e0ec commit 58db3cc

1 file changed

Lines changed: 3 additions & 5 deletions

File tree

src/tfc/utils/TFCUtils.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,13 @@
1212
import numpy as onp
1313
import numpy.typing as npt
1414
import jax.numpy as np
15-
from jax import jvp, jit, lax, jacfwd
15+
from jax import jvp, jit, lax, jacfwd, typeof
1616
from jax.extend import linear_util as lu
1717
from jax.api_util import debug_info
1818
from jax.tree_util import register_pytree_node, tree_map
1919
from jax._src.api_util import flatten_fun
2020
from jax._src.tree_util import tree_flatten
21-
from jax.core import get_aval, eval_jaxpr
21+
from jax.core import eval_jaxpr
2222
from jax.interpreters.partial_eval import trace_to_jaxpr_nounits, PartialVal
2323
from jax.experimental import io_callback
2424
from typing import Any, Callable, Optional, cast, Union, overload
@@ -256,9 +256,7 @@ def get_arg(a, unknown):
256256
if unknown:
257257
return tree_flatten(
258258
(
259-
tree_map(
260-
lambda x: PartialVal.unknown(get_aval(x).at_least_vspace()), a
261-
),
259+
tree_map(lambda x: PartialVal.unknown(typeof(x).at_least_vspace()), a),
262260
{},
263261
)
264262
)[0]

0 commit comments

Comments
 (0)