@@ -714,7 +714,7 @@ def process_layer_scannable(carry, layer_idx, group_id):
714714 pairwise_swap_and_negate_mask = yarn_mask ,
715715 )
716716 # Offload to host memory.
717- for residual_name in ("mlpwi_0" , "mlpwi_1" ):
717+ for residual_name in ("mlpwi_0" , "mlpwi_1" , "attn_out" , "layer_inputs" ):
718718 r = res .pop (residual_name )
719719 r = jax .tree .map (lambda x : jax .device_put (x , jax .typeof (x ).sharding .with_memory_kind ("pinned_host" )), r )
720720 res [residual_name ] = r
@@ -736,7 +736,7 @@ def process_layer_scannable(carry, layer_idx, group_id):
736736 pairwise_swap_and_negate_mask = yarn_mask ,
737737 )
738738 # Offload first layer residuals to host memory.
739- for residual_name in ("mlpwi_0" , "mlpwi_1" ):
739+ for residual_name in ("mlpwi_0" , "mlpwi_1" , "attn_out" , "layer_inputs" ):
740740 r = first_res .pop (residual_name )
741741 r = jax .tree .map (lambda x : jax .device_put (x , jax .typeof (x ).sharding .with_memory_kind ("pinned_host" )), r )
742742 first_res [residual_name ] = r
@@ -829,7 +829,7 @@ def process_layer_bwd_scannable(carry, res_and_layer_idx, group_id):
829829 next_next_ws_grad = all_reduce_ws_grad_dcn (next_next_ws_grad , mesh )
830830 all_layer_ws_grad = insert_layer_ws_grad (all_layer_ws_grad , next_next_ws_grad , layer_idx + 2 , cfg .param_scan_axis )
831831 # Get residuals from host.
832- for residual_name in ("mlpwi_0" , "mlpwi_1" ):
832+ for residual_name in ("mlpwi_0" , "mlpwi_1" , "attn_out" , "layer_inputs" ):
833833 r = res .pop (residual_name )
834834 r = jax .tree .map (lambda x : jax .device_put (x , jax .typeof (x ).sharding .with_memory_kind ("device" )), r )
835835 res [residual_name ] = r
@@ -890,7 +890,7 @@ def process_layer_bwd_scannable(carry, res_and_layer_idx, group_id):
890890 prev_prev_ws = gather_weights (extract_layer_weights (all_weights , num_layers - 3 , cfg .param_scan_axis ), mesh )
891891 ws_grad = reduce_scatter_ws_grad (ws_grad , mesh )
892892 # Get residuals from host.
893- for residual_name in ("mlpwi_0" , "mlpwi_1" ):
893+ for residual_name in ("mlpwi_0" , "mlpwi_1" , "attn_out" , "layer_inputs" ):
894894 r = last_last_res .pop (residual_name )
895895 r = jax .tree .map (lambda x : jax .device_put (x , jax .typeof (x ).sharding .with_memory_kind ("device" )), r )
896896 last_last_res [residual_name ] = r
@@ -931,7 +931,7 @@ def process_layer_bwd_scannable(carry, res_and_layer_idx, group_id):
931931 third_ws_grad = all_reduce_ws_grad_dcn (third_ws_grad , mesh )
932932 all_layer_ws_grad = insert_layer_ws_grad (all_layer_ws_grad , third_ws_grad , 2 , cfg .param_scan_axis )
933933 # Get residuals from host.
934- for residual_name in ("mlpwi_0" , "mlpwi_1" ):
934+ for residual_name in ("mlpwi_0" , "mlpwi_1" , "attn_out" , "layer_inputs" ):
935935 r = first_res .pop (residual_name )
936936 r = jax .tree .map (lambda x : jax .device_put (x , jax .typeof (x ).sharding .with_memory_kind ("device" )), r )
937937 first_res [residual_name ] = r
0 commit comments