Skip to content

Commit 8e17656

Browse files
author
Charles Li
committed
change func name in case it can avoid trace leaking
1 parent e2f076b commit 8e17656

1 file changed

Lines changed: 3 additions & 3 deletions

File tree

src/maxtext/layers/nnx_decoders.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -505,7 +505,7 @@ def _extract_matching_state(template, full):
505505
return {k: _extract_matching_state(v, full[k]) for k, v in template.items()}
506506
return full
507507

508-
def layer_fn(carry, scanned_vars):
508+
def _layer_fn(carry, scanned_vars):
509509
current_params, current_state = scanned_vars
510510

511511
if self.config.parameter_memory_host_offload:
@@ -521,9 +521,9 @@ def layer_fn(carry, scanned_vars):
521521
# ONLY return non-param state to prevent memory duplication of weights
522522
return new_carry, new_current_state
523523

524-
layer_fn = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse)
524+
checkpointed_layer_fn = jax.checkpoint(_layer_fn, policy=policy, prevent_cse=prevent_cse)
525525

526-
final_carry, scanned_other = jax.lax.scan(layer_fn, x_in, (params, state))
526+
final_carry, scanned_other = jax.lax.scan(checkpointed_layer_fn, x_in, (params, state))
527527

528528
if scan_axis != 0:
529529
params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, scan_axis), params)

0 commit comments

Comments
 (0)