Skip to content

Commit 87f556c

Browse files
Update
1 parent f26e160 commit 87f556c

1 file changed

Lines changed: 5 additions & 4 deletions

File tree

src/MaxText/layers/nnx_decoders.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -194,10 +194,10 @@ def __call__(
194194
layer_output = _maybe_shard_with_logical(layer_output, logical_axis_names)
195195

196196
if cfg.record_internal_nn_metrics:
197-
self.sow("intermediates", "activation_mean", jnp.mean(layer_output))
198-
self.sow("intermediates", "activation_stdev", jnp.std(layer_output))
197+
self.sow(nnx.Intermediate, "activation_mean", jnp.mean(layer_output))
198+
self.sow(nnx.Intermediate, "activation_stdev", jnp.std(layer_output))
199199
self.sow(
200-
"intermediates",
200+
nnx.Intermediate,
201201
"activation_fraction_zero",
202202
jnp.sum(layer_output == 0) / jnp.size(layer_output),
203203
)
@@ -380,10 +380,11 @@ def create_layer_fn(rng):
380380
except: # pylint: disable=bare-except
381381
pass
382382

383+
scan_axis = self.config.param_scan_axis
383384
layers_vmapped = nnx.vmap(
384385
create_layer_fn,
385386
in_axes=0,
386-
out_axes=0,
387+
out_axes=scan_axis,
387388
axis_name="layers",
388389
transform_metadata={nnx.PARTITION_NAME: "layers"},
389390
)(forked_rngs)

0 commit comments

Comments
 (0)