Skip to content

Commit 64001e1

Browse files
committed
Simplify is_training logic
1 parent 83de5b0 commit 64001e1

3 files changed

Lines changed: 2 additions & 5 deletions

File tree

jax_toolkit/losses/tests/__init__.py

Whitespace-only changes.

jax_toolkit/losses/utils.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,10 @@ def get_haiku_loss_function(
5050

5151
@jax.jit
5252
def loss_function_wrapper(
53-
params: hk.Params, x: jnp.ndarray, y_true: jnp.ndarray, rng: jnp.ndarray = None, is_training: bool = None
53+
params: hk.Params, x: jnp.ndarray, y_true: jnp.ndarray, rng: jnp.ndarray = None
5454
) -> jnp.ndarray:
5555
# rng argument can be used if net_transform.apply() is non-deterministic, and you require a "random seed"
56-
try:
57-
y_pred: jnp.ndarray = net_transform.apply(params, rng, x, is_training=is_training)
58-
except TypeError:
59-
y_pred: jnp.ndarray = net_transform.apply(params, rng, x)
56+
y_pred: jnp.ndarray = net_transform.apply(params, rng, x)
6057
loss_value: jnp.ndarray = loss_function(y_true, y_pred, **loss_kwargs) # type: ignore
6158
return loss_value
6259

jax_toolkit/metrics/tests/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)