Skip to content

Commit 83de5b0

Browse files
committed
Add support for networks whose apply function takes an "is_training" argument (e.g. so that dropout can be disabled when predicting)
1 parent 6900952 commit 83de5b0

2 files changed

Lines changed: 10 additions & 6 deletions

File tree

jax_toolkit/losses/tests/test_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@ def net_function(x: jnp.ndarray) -> jnp.ndarray:
2828
rng = jax.random.PRNGKey(42)
2929
params = net_transform.init(rng, jnp.array(0))
3030

31-
self.assertEqual(0, actual_loss_function_wrapper(params, x=jnp.array(0), y_true=jnp.array(0)))
32-
self.assertEqual(0, actual_loss_function_wrapper(params, x=jnp.array(1), y_true=jnp.array(1)))
33-
self.assertEqual(1, actual_loss_function_wrapper(params, x=jnp.array(0), y_true=jnp.array(1)))
31+
self.assertEqual(0, actual_loss_function_wrapper(params, jnp.array(0), jnp.array(0)))
32+
self.assertEqual(0, actual_loss_function_wrapper(params, jnp.array(1), jnp.array(1)))
33+
self.assertEqual(1, actual_loss_function_wrapper(params, jnp.array(0), jnp.array(1)))
3434

3535
def test_supported_loss_returns_correctly_with_loss_kwargs(self):
3636
import haiku as hk

jax_toolkit/losses/utils.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from functools import partial
12
from typing import Callable, Dict, Optional
23

34
import jax
@@ -43,16 +44,19 @@
4344

4445
def get_haiku_loss_function(
4546
net_transform: hk.Transformed, loss: str, **loss_kwargs: Dict[str, float]
46-
) -> Callable[[hk.Params, jnp.ndarray, jnp.ndarray], jnp.ndarray]:
47+
) -> Callable[[hk.Params, jnp.ndarray, jnp.ndarray, jnp.ndarray, bool], jnp.ndarray]:
4748
try:
4849
loss_function = SUPPORTED_LOSSES[loss]
4950

5051
@jax.jit
5152
def loss_function_wrapper(
52-
params: hk.Params, x: jnp.ndarray, y_true: jnp.ndarray, rng: jnp.ndarray = None
53+
params: hk.Params, x: jnp.ndarray, y_true: jnp.ndarray, rng: jnp.ndarray = None, is_training: bool = None
5354
) -> jnp.ndarray:
5455
# rng argument can be used if net_transform.apply() is non-deterministic, and you require a "random seed"
55-
y_pred: jnp.ndarray = net_transform.apply(params, rng, x)
56+
try:
57+
y_pred: jnp.ndarray = net_transform.apply(params, rng, x, is_training=is_training)
58+
except TypeError:
59+
y_pred: jnp.ndarray = net_transform.apply(params, rng, x)
5660
loss_value: jnp.ndarray = loss_function(y_true, y_pred, **loss_kwargs) # type: ignore
5761
return loss_value
5862

0 commit comments

Comments
 (0)