Skip to content

Commit 9eb6a2a

Browse files
hawkinsp and Google-ML-Automation
authored and committed
[JAX] Replace jnp.clip(..., a_min=..., a_max=...) with jnp.clip(..., min=..., max=...).
a_min and a_max are deprecated parameter names to jax.numpy.clip. PiperOrigin-RevId: 890539968
1 parent 61dc465 commit 9eb6a2a

2 files changed

Lines changed: 16 additions & 4 deletions

File tree

src/maxtext/layers/moe.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -598,8 +598,14 @@ def apply_ffn_activation(self, layer_w0, layer_w1):
598598
"""Applies FFN activation function."""
599599
with jax.named_scope("ffn_act"):
600600
if self.config.decoder_block == ctypes.DecoderBlockType.GPT_OSS:
601-
layer_w0 = jnp.clip(layer_w0, a_min=None, a_max=self.config.mlp_activations_limit)
602-
layer_w1 = jnp.clip(layer_w1, a_min=-self.config.mlp_activations_limit, a_max=self.config.mlp_activations_limit)
601+
layer_w0 = jnp.clip(
602+
layer_w0, min=None, max=self.config.mlp_activations_limit
603+
)
604+
layer_w1 = jnp.clip(
605+
layer_w1,
606+
min=-self.config.mlp_activations_limit,
607+
max=self.config.mlp_activations_limit,
608+
)
603609
layer_act = self.activation_fn(layer_w0 * 1.702)
604610
glu = jnp.multiply(layer_w0, layer_act)
605611
intermediate_layer = jnp.multiply(glu, (layer_w1 + 1))

tests/utils/forward_pass_logit_checker.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -332,8 +332,14 @@ def main(config, test_args): # pylint: disable=W0621
332332
max_logging.log(msg)
333333

334334
if test_args.clip_logits_epsilon is not None:
335-
model_probabilities = jnp.clip(jax.nn.softmax(train_logits_slice, axis=-1), a_min=test_args.clip_logits_epsilon)
336-
golden_probabilities = jnp.clip(jax.nn.softmax(golden_logits_slice, axis=-1), a_min=test_args.clip_logits_epsilon)
335+
model_probabilities = jnp.clip(
336+
jax.nn.softmax(train_logits_slice, axis=-1),
337+
min=test_args.clip_logits_epsilon,
338+
)
339+
golden_probabilities = jnp.clip(
340+
jax.nn.softmax(golden_logits_slice, axis=-1),
341+
min=test_args.clip_logits_epsilon,
342+
)
337343
else:
338344
model_probabilities = jax.nn.softmax(train_logits_slice, axis=-1)
339345
golden_probabilities = jax.nn.softmax(golden_logits_slice, axis=-1)

0 commit comments

Comments
 (0)