File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -194,10 +194,10 @@ def __call__(
194194 layer_output = _maybe_shard_with_logical (layer_output , logical_axis_names )
195195
196196 if cfg .record_internal_nn_metrics :
197- self .sow ("intermediates" , "activation_mean" , jnp .mean (layer_output ))
198- self .sow ("intermediates" , "activation_stdev" , jnp .std (layer_output ))
197+ self .sow (nnx . Intermediate , "activation_mean" , jnp .mean (layer_output ))
198+ self .sow (nnx . Intermediate , "activation_stdev" , jnp .std (layer_output ))
199199 self .sow (
200- "intermediates" ,
200+ nnx . Intermediate ,
201201 "activation_fraction_zero" ,
202202 jnp .sum (layer_output == 0 ) / jnp .size (layer_output ),
203203 )
@@ -380,10 +380,11 @@ def create_layer_fn(rng):
380380 except : # pylint: disable=bare-except
381381 pass
382382
383+ scan_axis = self .config .param_scan_axis
383384 layers_vmapped = nnx .vmap (
384385 create_layer_fn ,
385386 in_axes = 0 ,
386- out_axes = 0 ,
387+ out_axes = scan_axis ,
387388 axis_name = "layers" ,
388389 transform_metadata = {nnx .PARTITION_NAME : "layers" },
389390 )(forked_rngs )
You can’t perform that action at this time.
0 commit comments