@@ -339,6 +339,7 @@ def mla(
339339 qk_nope_head_dim = qk_nope_head_dim ,
340340 mscale = mscale ,
341341 )
342+ query = jax .ad_checkpoint .checkpoint_name (query , "query_proj" )
342343 key , value = kv_projection (
343344 inputs ,
344345 positions ,
@@ -358,6 +359,8 @@ def mla(
358359 qk_nope_head_dim = qk_nope_head_dim ,
359360 num_query_heads = num_query_heads ,
360361 )
362+ key = jax .ad_checkpoint .checkpoint_name (key , "key_proj" )
363+ value = jax .ad_checkpoint .checkpoint_name (value , "value_proj" )
361364 out = attention_op_fn (
362365 query ,
363366 key ,
@@ -366,7 +369,9 @@ def mla(
366369 model_mode ,
367370 cached_values = [None , None ],
368371 )
372+ out = jax .ad_checkpoint .checkpoint_name (out , "attention_out" )
369373 out = dot (out , out_weights , axes = 2 )
374+ out = jax .ad_checkpoint .checkpoint_name (out , "out_proj" )
370375 return out
371376
372377
@@ -405,6 +410,7 @@ def query_projection(
405410 epsilon = epsilon ,
406411 dtype = dtype ,
407412 )
413+ low_rank_q = jax .ad_checkpoint .checkpoint_name (low_rank_q , "mla_q" )
408414 q = dot (low_rank_q , wq_b_weights )
409415
410416 # Split into non-positional and rotary parts.
@@ -454,6 +460,7 @@ def kv_projection(
454460 epsilon = kv_norm_epsilon ,
455461 dtype = dtype ,
456462 )
463+ low_rank_main = jax .ad_checkpoint .checkpoint_name (low_rank_main , "mla_kv" )
457464 key_rope = jnp .expand_dims (low_rank_rope , axis = 2 )
458465 key_rope = yarn (
459466 key_rope ,
@@ -693,6 +700,8 @@ def compute(x, w0, w1, wo, group_sizes, weights, *, wi_tile_size, wo_tile_size,
693700 )
694701 layer_w0 = gmm_fn (x , w0 , tiling = wi_tile_size )
695702 layer_w1 = gmm_fn (x , w1 , tiling = wi_tile_size )
703+ layer_w0 = jax .ad_checkpoint .checkpoint_name (layer_w0 , "mlpwi_0" )
704+ layer_w1 = jax .ad_checkpoint .checkpoint_name (layer_w1 , "mlpwi_1" )
696705 intermediate_layer = jax .nn .silu (layer_w0 ) * layer_w1
697706 intermediate_layer *= weights [:, None ]
698707 return gmm_fn (intermediate_layer , wo , tiling = wo_tile_size )
0 commit comments