diff --git a/src/maxtext/configs/base.yml b/src/maxtext/configs/base.yml index 63967584d3..fb58aa79b4 100644 --- a/src/maxtext/configs/base.yml +++ b/src/maxtext/configs/base.yml @@ -306,7 +306,7 @@ pipeline_fsdp_ag_per_repeat: False # It may be useful to do the reverse when the layers_per_stage is very large. # The below settings only have effect when using pipeline parallelism. scan_pipeline_iterations: True -scan_pipeline_repeats: True +scan_pipeline_repeats: False scan_layers_per_stage: False set_remat_policy_on_pipeline_iterations: True set_remat_policy_on_layers_per_stage: False diff --git a/src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml b/src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml index 7b99d96978..8209dece2d 100644 --- a/src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml +++ b/src/maxtext/configs/custom_mesh_and_rule/pipeline-large-moe.yml @@ -20,48 +20,49 @@ # The `data` axis is preserved for two reasons: first, the pipeline stage acts as a # data parallel (DP) domain externally, making the `data` axis a necessary reference; # second, it may be required for DCN communication. +# +# The `context` axis is used for supporting fractional per device batch size # # Finally, the `tensor` axis is used to shard weights when `pipeline_fsdp_ag_once` or # `pipeline_fsdp_ag_per_repeat` is enabled, ensuring we have sufficient memory to # store prefetched weights. -mesh_axes: ['data', 'stage', 'fsdp', 'tensor', 'expert'] -data_sharding: [['data', 'stage', 'fsdp', 'tensor', 'expert']] +mesh_axes: ['data', 'stage', 'fsdp', 'context', 'tensor', 'expert'] +data_sharding: [['data', 'stage', 'fsdp', 'context', 'tensor', 'expert']] logical_axis_rules: [ ['activation_batch', ['data', 'fsdp', 'expert']], ['activation_batch_moe', ['data', 'fsdp', 'expert']], ['activation_batch_no_exp', ['data', 'fsdp']], ['activation_batch_no_exp_moe', ['data', 'fsdp']], ['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp', 'expert']], - ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'expert']], + ['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'context', 'expert']], ['activation_heads', ['tensor']], ['activation_kv_heads', ['tensor']], - ['activation_length', ['expert']], - ['activation_attn_length', ['expert']], - ['activation_q_length', ['expert']], + ['activation_length', ['context', 'expert']], + ['activation_attn_length', ['context', 'expert']], + ['activation_q_length', ['context', 'expert']], ['activation_attn_embed', ['tensor']], + ['activation_norm_length', ['context']], + ['activation_norm_length_moe', ['context']], ['activation_embed', ['tensor']], ['activation_embed_moe', ['tensor']], ['activation_mlp', ['tensor']], ['activation_kv', ['tensor']], - ['activation_prefill_kv_batch', ['data', 'fsdp', 'expert']], ['activation_kv_batch', ['data', 'fsdp', 'expert']], ['activation_kv_batch_no_exp', ['data', 'fsdp']], ['activation_kv_head_dim', ['tensor']], ['activation_vocab', ['tensor']], ['activation_stage', 'stage'], ['activation_exp', ['expert']], - ['decode_batch', ['data', 'fsdp', 'expert']], ['mlp', ['tensor']], ['mlp_no_fsdp', ['tensor']], ['vocab', ['tensor']], ['heads', ['tensor']], ['q_heads', ['tensor']], ['kv_heads', ['tensor']], - ['embed', ['fsdp', 'expert']], + ['embed', ['fsdp', 'expert']], # remove context from embed sharding ['embed_moe', ['fsdp', 'expert']], ['embed_no_exp', ['fsdp']], ['embed_no_exp_moe', ['fsdp']], - ['embed_moe', ['fsdp']], ['q_lora', ['fsdp']], ['kv_lora', ['fsdp']], ['norm', ['tensor']], diff --git a/src/maxtext/layers/pipeline.py b/src/maxtext/layers/pipeline.py index 60cb7d2ac2..1b130f1888 100644 --- a/src/maxtext/layers/pipeline.py +++ b/src/maxtext/layers/pipeline.py @@ -184,6 +184,7 @@ def get_microbatch_and_repeat_ids(self, loop_iteration): non-circular""" # Stage 0 has processed one microbatch every loop_iter, but Stage 1 is 1 behind due to bubble, etc for other stages microbatches_processed = jnp.maximum(loop_iteration - self.forwarding_delay * jnp.arange(self.num_stages), 0) + microbatches_processed = self._maybe_shard_with_name(microbatches_processed, NamedSharding(self.mesh, P("stage"))) microbatch_ids = microbatches_processed % self.config.num_pipeline_microbatches repeat_ids = microbatches_processed // self.config.num_pipeline_microbatches return microbatch_ids, repeat_ids @@ -1006,8 +1007,12 @@ def gather_microbatch_inputs_vmap(self, xs, ids, ids_dim): def _gather_one(x, i): idx = tuple(i if d == ids_dim else slice(None) for d in range(x.ndim)) - replicated_sharding = NamedSharding(self.mesh, P()) - return x.at[idx].get(out_sharding=replicated_sharding) + positions_sharding = ( + create_sharding(self.mesh, (None, "layers", "activation_length")) + if self.config.shard_mode == ShardMode.EXPLICIT + else None + ) + return x.at[idx].get(out_sharding=positions_sharding) return jax.vmap(_gather_one, in_axes=(None, 0), out_axes=ids_dim)(xs, ids) @@ -1168,11 +1173,12 @@ def from_repeat_weights_to_bsw( self, repeat_weights, physical_partition_spec, - axes_to_gather=("fsdp", "fsdp_transpose", "expert"), # three major FSDP-like axes + axes_to_gather=("fsdp", "fsdp_transpose", "context", "expert"), + # TODO (chengnuojin) set use_shardmap=true after JAX >= 10.0.0 and use all_gather(..., to='invarying') use_shardmap=False, # using shardmap produces additional reduce-scatter in backward pass ): """Executes the FSDP-like all-gathers to fully materialize a block of weights for the BSW.""" - axes_to_remove = ["fsdp", "fsdp_transpose"] + axes_to_remove = ["fsdp", "fsdp_transpose", "context"] bsw_pps = pipeline_utils.derive_stage_weight_partition_specs(physical_partition_spec, axes_to_remove) def _from_repeat_weights_to_bsw_shardmap( @@ -1229,7 +1235,7 @@ def _apply_sharding_hint(weight, pspec): weight, sharding_name, shard_mode=self.config.shard_mode, - debug_sharding=self.config.shard_mode, + debug_sharding=self.config.debug_sharding, extra_stack_level=0, ) @@ -1240,21 +1246,15 @@ def _apply_sharding_hint(weight, pspec): return _from_repeat_weights_to_bsw_hint(repeat_weights) def weight_prefetching(self, weights, physical_partition_spec, loop_iteration): - """Triggers asynchronous FSDP-like all-gathers for the current and next pipeline steps. + """Triggers asynchronous FSDP-like all-gathers for the next pipeline steps. By gathering weights for `loop_iteration + 1` right now, the network communication - can overlap with the compute happening in `loop_iteration`. The dual-buffers - are returned grouped in an explicit `jax.ad_checkpoint` to strictly control memory. + can overlap with the compute happening in `loop_iteration`. """ - cur_repeat_weights = self.from_all_variables_to_repeat_weights(weights, loop_iteration) - nxt_repeat_weights = self.from_all_variables_to_repeat_weights(weights, loop_iteration + 1) - bsw_0 = self.from_repeat_weights_to_bsw(cur_repeat_weights, physical_partition_spec) - bsw_1 = self.from_repeat_weights_to_bsw(nxt_repeat_weights, physical_partition_spec) - return jax.ad_checkpoint.checkpoint_name((bsw_0, bsw_1), "bsw") + repeat_weights = self.from_all_variables_to_repeat_weights(weights, loop_iteration + 1) + return self.from_repeat_weights_to_bsw(repeat_weights, physical_partition_spec) - def run_one_iteration( - self, loop_state, bsw, weights, positions, segment_ids, deterministic, model_mode, logical_partition_spec - ): + def run_one_iteration(self, loop_state, bsw, positions, segment_ids, deterministic, model_mode, logical_partition_spec): """Executes the forward/backward logic for a single microbatch inside the pipeline. This acts as the core step function that our `jax.lax.scan` wrappers call. It routes @@ -1339,7 +1339,6 @@ def __call__( segment_idx = None loop_state, bsw = self.init_states(inputs) - weights = self.layers.variables physical_partition_spec = logical_to_mesh( logical_partition_spec, mesh=self.mesh, rules=self.config.logical_axis_rules ) @@ -1353,12 +1352,11 @@ def __call__( logical_partition_spec = pipeline_utils.strip_pipeline_repeat_logical_axis(logical_partition_spec) - def run_iteration_scannable(model, loop_state, bsw, weights): + def run_iteration_scannable(model, loop_state, bsw): return ( model.run_one_iteration( loop_state, bsw, - weights, positions, segment_ids, deterministic, @@ -1377,9 +1375,7 @@ def run_iteration_scannable(model, loop_state, bsw, weights): # base scannable function used twice for real and bubble runs base_scannable = functools.partial( - pipeline_utils.create_rematerialized_pipeline_stage, - model=self, - run_iteration_scannable=run_iteration_scannable, + pipeline_utils.create_pipeline_stage, deterministic=deterministic, model_mode=model_mode, logical_partition_spec=logical_partition_spec, @@ -1391,18 +1387,22 @@ def run_iteration_scannable(model, loop_state, bsw, weights): run_one_repeat_scannable = base_scannable(length=self.config.num_pipeline_microbatches) run_bubbles_scannable = base_scannable(length=bubble_iterations) - if self.config.scan_pipeline_repeats: - run_repeats_scanned = pipeline_utils.create_flax_pipeline_scan( - pipeline_stage_fn=run_one_repeat_scannable, length=self.config.num_pipeline_repeats - ) - run_bubbles_scanned = pipeline_utils.create_flax_pipeline_scan(pipeline_stage_fn=run_bubbles_scannable, length=1) - (loop_state, bsw, weights), _ = run_repeats_scanned(self, (loop_state, bsw, weights)) - (loop_state, bsw, weights), _ = run_bubbles_scanned(self, (loop_state, bsw, weights)) - else: - for _ in range(self.config.num_pipeline_repeats): - (loop_state, bsw, weights), _ = run_one_repeat_scannable(self, (loop_state, bsw, weights)) - for _ in range(bubble_iterations): - (loop_state, bsw, weights), _ = run_iteration_scannable(self, loop_state, bsw, weights) + run_repeats_scanned = pipeline_utils.create_flax_pipeline_scan( + pipeline_stage_fn=run_one_repeat_scannable, + length=self.config.num_pipeline_repeats, + remat_policy=self.get_pipeline_remat_policy(), + use_scan=self.config.scan_pipeline_repeats, + ) + run_bubbles_scanned = pipeline_utils.create_flax_pipeline_scan( + pipeline_stage_fn=run_bubbles_scannable, + length=1, + remat_policy=self.get_pipeline_remat_policy(), + use_scan=self.config.scan_pipeline_repeats, + ) + initial_carry_repeats = (loop_state, bsw[0], self.layers.variables) + (loop_state, w_curr, pipeline_weights), _ = run_repeats_scanned(self, initial_carry_repeats) + initial_carry_bubbles = (loop_state, w_curr, pipeline_weights) + (loop_state, _, pipeline_weights), _ = run_bubbles_scanned(self, initial_carry_bubbles) final_output = self.realign_output_microbatches(loop_state["state_io"]) final_output = jnp.reshape( diff --git a/src/maxtext/utils/pipeline_utils.py b/src/maxtext/utils/pipeline_utils.py index ceefd2ac01..c02030f301 100644 --- a/src/maxtext/utils/pipeline_utils.py +++ b/src/maxtext/utils/pipeline_utils.py @@ -19,6 +19,7 @@ from jax.sharding import PartitionSpec as P from flax import linen as nn from flax.linen.spmd import LogicallyPartitioned +import jax.numpy as jnp def get_mesh_axis_dim_indices(physical_partition_spec, axis_name="fsdp"): @@ -193,65 +194,61 @@ def create_gradient_accumulation_scan( A JAX custom_vjp function that executes the `length` pipeline iterations. """ - @functools.partial(jax.custom_vjp) - def run_single_microbatch_custom(lightweight_state, bsw, weights, pos_arg, seg_arg): - return run_single_microbatch_custom_fwd(lightweight_state, bsw, weights, pos_arg, seg_arg)[0] + @jax.custom_vjp + def run_single_microbatch_custom(lightweight_state, bsw, pos_arg, seg_arg): + return run_single_microbatch_custom_fwd(lightweight_state, bsw, pos_arg, seg_arg)[0] - def run_single_microbatch_custom_fwd(lightweight_state, bsw, weights, pos_arg, seg_arg): - def _run(l, b, w): + def run_single_microbatch_custom_fwd(lightweight_state, bsw, pos_arg, seg_arg): + def _run(l, b): out = model.run_one_iteration( - l, b, w, pos_arg, seg_arg, deterministic, model_mode, logical_partition_spec=logical_partition_spec + l, b, pos_arg, seg_arg, deterministic, model_mode, logical_partition_spec=logical_partition_spec ) - return out, b, w + return out # Rematerialize the inner step to save activation memory - _run_remat = jax.remat(_run, prevent_cse=False, policy=model.get_pipeline_remat_policy()) - out, vjp_fun = jax.vjp(_run_remat, lightweight_state, bsw, weights) + _run_remat = jax.remat(_run, policy=model.get_pipeline_remat_policy()) + out, vjp_fun = jax.vjp(_run_remat, lightweight_state, bsw) return out, vjp_fun def run_single_microbatch_custom_bwd(res, g_out): vjp_fun = res - d_l, d_b, d_w = vjp_fun(g_out) - return d_l, d_b, d_w, None, None + d_l, d_b = vjp_fun(g_out) + return d_l, d_b, None, None run_single_microbatch_custom.defvjp(run_single_microbatch_custom_fwd, run_single_microbatch_custom_bwd) - @functools.partial(jax.custom_vjp) - def run_pipeline_microbatches_custom(loop_state, bsw, weights, positions, segment_ids): - return run_pipeline_microbatches_custom_fwd(loop_state, bsw, weights, positions, segment_ids)[0] + @jax.custom_vjp + def run_pipeline_microbatches_custom(loop_state, bsw, positions, segment_ids): + return run_pipeline_microbatches_custom_fwd(loop_state, bsw, positions, segment_ids)[0] - def run_pipeline_microbatches_custom_fwd(loop_state, bsw, weights, positions, segment_ids): + def run_pipeline_microbatches_custom_fwd(loop_state, bsw, positions, segment_ids): final_lightweight, scan_vjp_fun = jax.vjp( - lambda l, b, w: jax.lax.scan( - lambda carry, _: (run_single_microbatch_custom(carry, b, w, positions, segment_ids)[0], None), + lambda l, b: jax.lax.scan( + lambda carry, _: (run_single_microbatch_custom(carry, b, positions, segment_ids), None), l, None, length=length, )[0], loop_state, bsw, - weights, ) - return (final_lightweight, bsw, weights), scan_vjp_fun + return (final_lightweight, bsw), scan_vjp_fun def run_pipeline_microbatches_custom_bwd(residuals, g_final_state): scan_vjp_fun = residuals - g_lightweight, g_bsw, g_weights = g_final_state - d_init_lightweight, d_init_bsw, d_init_weights = scan_vjp_fun(g_lightweight) + g_lightweight, g_bsw = g_final_state + d_init_lightweight, d_init_bsw = scan_vjp_fun(g_lightweight) d_init_bsw = jax.tree.map(lambda d, g: d + g if hasattr(d, "shape") else d, d_init_bsw, g_bsw) - d_init_weights = jax.tree.map(lambda d, g: d + g if hasattr(d, "shape") else d, d_init_weights, g_weights) - return (d_init_lightweight, d_init_bsw, d_init_weights, None, None) + return (d_init_lightweight, d_init_bsw, None, None) run_pipeline_microbatches_custom.defvjp(run_pipeline_microbatches_custom_fwd, run_pipeline_microbatches_custom_bwd) return run_pipeline_microbatches_custom -def create_rematerialized_pipeline_stage( - model, - run_iteration_scannable, +def create_pipeline_stage( length, deterministic, model_mode, @@ -260,18 +257,14 @@ def create_rematerialized_pipeline_stage( positions, segment_ids, ): - """Builds a memory-checkpointed execution block for a single pipeline stage. + """Builds an execution block for a single pipeline stage. This function prepares the state for a specific chunk of pipeline execution by: - 1. Prefetching the required weights for the current stage/loop iteration. - 2. Executing `length` microbatches using either a memory-efficient `jax.lax.scan` - (if `scan_pipeline_iterations` is True) or an unrolled Python `for` loop. - 3. Wrapping the entire stage block in `flax.linen.remat` to discard and recompute - activations during the backward pass based on the model's policy. + 1. Prefetching the required weights (e.g., FSDP-gathered) for the current stage/loop iteration. + 2. Executing `length` microbatches using a memory-efficient `jax.lax.scan` via a custom VJP + that manages collective communication overlap. Args: - model: The model instance containing configuration and prefetching logic. - run_iteration_scannable: A fallback function for executing a single iteration unrolled. length: The number of microbatches to process in this stage. deterministic: Whether to run deterministically (e.g., disable dropout). model_mode: The operational mode (e.g., 'train'). @@ -281,60 +274,128 @@ def create_rematerialized_pipeline_stage( segment_ids: Segment/Attention routing IDs for the sequence. Returns: - A function decorated with `nn.remat` that takes `(model, loop_state)` and returns - the updated `loop_state`. + A function that takes `(model, carry)` and returns the updated `carry` and `None` for the scan outputs. """ - def execute_pipeline_stage(model, loop_state_and_bsw_and_weights): - loop_state, bsw, weights = loop_state_and_bsw_and_weights - # Retrieve the specific weights needed for this pipeline chunk - bsw = model.weight_prefetching(weights, physical_partition_spec, loop_state["loop_iteration"]) - - if model.config.scan_pipeline_iterations: - scan_microbatches_fn = create_gradient_accumulation_scan( - model=model, - length=length, - deterministic=deterministic, - model_mode=model_mode, - logical_partition_spec=logical_partition_spec, - ) - loop_state, bsw, weights = scan_microbatches_fn(loop_state, bsw, weights, positions, segment_ids) - else: - for _ in range(length): - (loop_state, bsw, weights), _ = run_iteration_scannable(model, loop_state, bsw, weights) - return (loop_state, bsw, weights), None - - return nn.remat( - execute_pipeline_stage, - prevent_cse=not model.config.scan_pipeline_iterations, - policy=model.get_pipeline_remat_policy(), - ) - - -def create_flax_pipeline_scan(pipeline_stage_fn, length): - """Wraps the pipeline stage execution in a `flax.linen.scan`. + def execute_pipeline_stage_flax(model, carry): + """ + A non-pure Flax closure of the pipeline stage. + + This function bridges the pure JAX custom VJP logic with Flax's object-oriented + lifting mechanisms. It unpacks the carry state and routes it through the pure VJP function. + + Args: + model: CircularPipeline Flax linen model instance. + carry: A tuple containing (loop_state, w_curr, pipeline_weights). + - loop_state: The current execution state of the pipeline. + - w_curr: The gathered weights used for the current pipeline step. + - pipeline_weights: The fully sharded baseline weights. + """ + + loop_state, w_curr, pipeline_weights = carry + + scan_microbatches_fn = create_gradient_accumulation_scan( + model=model, + length=length, + deterministic=deterministic, + model_mode=model_mode, + logical_partition_spec=logical_partition_spec, + ) - This lifts the pipeline stage function so it can be repeated sequentially over - the specified length. It safely handles Flax-specific state collections, ensuring - that metrics, intermediate values, and PRNG keys do not collide or overwrite - each other across the loop iterations. + # Establish a pure function boundary to allow for custom VJP definition + @jax.custom_vjp + def execute_pipeline_stage_pure(loop_state, w_curr, pipeline_weights): + return execute_pipeline_stage_pure_fwd(loop_state, w_curr, pipeline_weights)[0] + + def execute_pipeline_stage_pure_fwd(loop_state, w_curr, pipeline_weights): + # Prefetch FSDP-sharded weights for the upcoming pipeline repeat + w_next = model.weight_prefetching( + pipeline_weights, + physical_partition_spec, + loop_state["loop_iteration"], + ) + # Construct a buffered sliding window (BSW) of weights. + # w_curr: Weights actively used for the current microbatch steps. + # w_next: Newly gathered weights that will be carried forward as the new w_curr. + bsw = (w_curr, w_next) + # Bind arguments to the weight prefetching function to prepare it for linear transpose + p_weight_prefetching = functools.partial( + model.weight_prefetching, + physical_partition_spec=physical_partition_spec, + loop_iteration=loop_state["loop_iteration"], + ) + # Since weight gathering (all-gather) is a linear operation, we can derive its dual + # (reduce-scatter) via jax.linear_transpose. This avoids redundant forward passes + weight_prefetching_t = jax.linear_transpose( + p_weight_prefetching, + pipeline_weights, + ) + # Execute the forward pass of the microbatches and generate its VJP. + # The VJP captures necessary checkpoints to evaluate gradients later. + (loop_state, _), scan_microbatches_vjp = jax.vjp(scan_microbatches_fn, loop_state, bsw, positions, segment_ids) + # Discard the old weights (w_curr) and advance w_next to act as the current weights in the next iteration + return (loop_state, w_next), (scan_microbatches_vjp, weight_prefetching_t) + + def execute_pipeline_stage_pure_bwd(residuals, g_outputs): + # Unpack forward pass residuals (VJP closures) and the incoming output gradients + g_loop_state, g_w_next = g_outputs + scan_microbatches_vjp, weight_prefetching_t = residuals + # Initialize zero cotangents for w_curr, as it was consumed in the forward pass + g_w_curr = jax.tree.map(jnp.zeros_like, g_w_next) + g_bsw = (g_w_curr, g_w_next) + # Backpropagate gradients through the dual microbatch execution block + g_loop_state, g_bsw, _, _ = scan_microbatches_vjp((g_loop_state, g_bsw)) + # Apply the linear transpose of the weight prefetch to execute the reduce-scatter + # This maps the gradients of the gathered weights back to the FSDP-sharded parameter space + g_w_curr, g_w_next = g_bsw + (g_pipeline_weights,) = weight_prefetching_t(g_w_next) + # Return gradients corresponding to the three original inputs of execute_pipeline_stage_pure + return g_loop_state, g_w_curr, g_pipeline_weights + + execute_pipeline_stage_pure.defvjp(execute_pipeline_stage_pure_fwd, execute_pipeline_stage_pure_bwd) + + # Execute the pure pipeline stage. We unpack the two modified outputs (loop_state, w_next) + # and repack them alongside the unmodified pipeline_weights to maintain a consistent carry shape for nn.scan. + return (*execute_pipeline_stage_pure(loop_state, w_curr, pipeline_weights), pipeline_weights), None + + return execute_pipeline_stage_flax + + +def create_flax_pipeline_scan(pipeline_stage_fn, length, remat_policy, use_scan=True): + """Wraps the pipeline stage execution in `flax.linen.remat` and `flax.linen.scan`. + + This explicitly wraps the pipeline step in a gradient checkpointing policy + and then lifts it so it can be repeated sequentially over the specified length. + It safely handles Flax-specific state collections, ensuring that metrics, intermediate + values, and PRNG keys do not collide or overwrite each other across loop iterations. Args: pipeline_stage_fn: The function representing a single pipeline stage - (usually created by `create_rematerialized_pipeline_stage`). + (usually created by `create_pipeline_stage`). + remat_policy: The checkpointing policy used by `nn.remat` to manage activation memory. length: The total number of pipeline stages/repeats to scan over. + use_scan: Whether to use `jax.lax.scan` (True) or unroll the loop (False). Returns: A Flax scanned function that executes the full pipeline schedule. """ + unroll_length = 1 if use_scan else length return nn.scan( - pipeline_stage_fn, + nn.remat( + pipeline_stage_fn, + policy=remat_policy, + ), variable_axes={ "summaries": 0, "aux_loss": 0, "intermediates": 0, "hyper_params": 0, }, + variable_broadcast=[ + "_overwrite_with_gradient", + "non_trainable", + ], split_rngs={"random": True}, length=length, + unroll=unroll_length, ) diff --git a/tests/utils/forward_pass_logit_checker.py b/tests/utils/forward_pass_logit_checker.py index ef8059bc88..4d6f9982b6 100644 --- a/tests/utils/forward_pass_logit_checker.py +++ b/tests/utils/forward_pass_logit_checker.py @@ -341,8 +341,8 @@ def main(config, test_args): # pylint: disable=W0621 max_logging.log(msg) if test_args.clip_logits_epsilon is not None: - model_probabilities = jnp.clip(jax.nn.softmax(train_logits_slice, axis=-1), a_min=test_args.clip_logits_epsilon) - golden_probabilities = jnp.clip(jax.nn.softmax(golden_logits_slice, axis=-1), a_min=test_args.clip_logits_epsilon) + model_probabilities = jnp.clip(jax.nn.softmax(train_logits_slice, axis=-1), min=test_args.clip_logits_epsilon) + golden_probabilities = jnp.clip(jax.nn.softmax(golden_logits_slice, axis=-1), min=test_args.clip_logits_epsilon) else: model_probabilities = jax.nn.softmax(train_logits_slice, axis=-1) golden_probabilities = jax.nn.softmax(golden_logits_slice, axis=-1)