|
| 1 | +from functools import partial |
1 | 2 | from typing import Callable, Dict, Optional |
2 | 3 |
|
3 | 4 | import jax |
|
43 | 44 |
|
44 | 45 | def get_haiku_loss_function( |
45 | 46 | net_transform: hk.Transformed, loss: str, **loss_kwargs: Dict[str, float] |
46 | | -) -> Callable[[hk.Params, jnp.ndarray, jnp.ndarray], jnp.ndarray]: |
| 47 | +) -> Callable[[hk.Params, jnp.ndarray, jnp.ndarray, jnp.ndarray, bool], jnp.ndarray]: |
47 | 48 | try: |
48 | 49 | loss_function = SUPPORTED_LOSSES[loss] |
49 | 50 |
|
50 | 51 | @jax.jit |
51 | 52 | def loss_function_wrapper( |
52 | | - params: hk.Params, x: jnp.ndarray, y_true: jnp.ndarray, rng: jnp.ndarray = None |
| 53 | + params: hk.Params, x: jnp.ndarray, y_true: jnp.ndarray, rng: jnp.ndarray = None, is_training: bool = None |
53 | 54 | ) -> jnp.ndarray: |
54 | 55 | # rng argument can be used if net_transform.apply() is non-deterministic, and you require a "random seed" |
55 | | - y_pred: jnp.ndarray = net_transform.apply(params, rng, x) |
| 56 | + try: |
| 57 | + y_pred: jnp.ndarray = net_transform.apply(params, rng, x, is_training=is_training) |
| 58 | + except TypeError: |
| 59 | + y_pred: jnp.ndarray = net_transform.apply(params, rng, x) |
56 | 60 | loss_value: jnp.ndarray = loss_function(y_true, y_pred, **loss_kwargs) # type: ignore |
57 | 61 | return loss_value |
58 | 62 |
|
|
0 commit comments