@@ -143,7 +143,9 @@ def run_scanned_custom(loop_state, positions, segment_ids):
143143 return final_state
144144
145145 def run_scanned_custom_fwd (loop_state , positions , segment_ids ):
146- def step_fn (s ):
146+ def step_fn (lightweight_state , bsw , weights ):
147+ # Reconstruct full state for the model
148+ s = {** lightweight_state , "bsw" : bsw , "weights" : weights }
147149 out = model .run_one_iteration (
148150 s ,
149151 positions ,
@@ -152,40 +154,63 @@ def step_fn(s):
152154 model_mode ,
153155 logical_partition_spec = logical_partition_spec ,
154156 )
155- return out
157+ # Deconstruct back to lightweight to decouple gradients
158+ new_lightweight = {k : v for k , v in out .items () if k not in ["bsw" , "weights" ]}
159+ return new_lightweight , out ["bsw" ], out ["weights" ]
156160
157161 step_fn_remat = jax .remat (step_fn , prevent_cse = False , policy = model .get_pipeline_remat_policy ())
158162
159- # Forward pass generating VJP functions and intermediate state
163+ # Separate heavy and light state initially
164+ initial_lightweight = {k : v for k , v in loop_state .items () if k not in ["bsw" , "weights" ]}
165+ initial_bsw = loop_state ["bsw" ]
166+ initial_weights = loop_state ["weights" ]
167+
160168 def scan_body_fwd (carry , _ ):
161- new_state , vjp_fun = jax .vjp (step_fn_remat , carry )
162- return new_state , vjp_fun
169+ lightweight_carry , bsw_carry , weights_carry = carry
170+ (new_lightweight , new_bsw , new_weights ), vjp_fun = jax .vjp (
171+ step_fn_remat , lightweight_carry , bsw_carry , weights_carry
172+ )
173+ return (new_lightweight , new_bsw , new_weights ), vjp_fun
163174
164- final_state , vjp_funs = jax .lax .scan (
175+ ( final_lightweight , final_bsw , final_weights ) , vjp_funs = jax .lax .scan (
165176 scan_body_fwd ,
166- loop_state ,
177+ ( initial_lightweight , initial_bsw , initial_weights ) ,
167178 None ,
168179 length = length ,
169180 )
170181
171- # We return vjp_funs as residual. model is passed to bwd as arg.
182+ final_state = { ** final_lightweight , "bsw" : final_bsw , "weights" : final_weights }
172183 return final_state , vjp_funs
173184
174185 def run_scanned_custom_bwd (residuals , g_final_state ):
175186 vjp_funs = residuals
176187
177- # Backward scan to accumulate gradients
188+ # Split the gradient of the final state
189+ g_lightweight = {k : v for k , v in g_final_state .items () if k not in ["bsw" , "weights" ]}
190+ g_bsw = g_final_state ["bsw" ]
191+ g_weights = g_final_state ["weights" ]
192+
178193 def scan_body_bwd (carry , vjp_fun ):
179- d_next_state = carry
194+ d_next_lightweight , d_next_bsw , d_next_weights = carry
180195
181- # Backprop d_next_state using saved vjp_fun
182- ( d_curr_state ,) = vjp_fun (d_next_state )
196+ # Apply saved vjp_fun directly
197+ d_curr_lightweight , d_curr_bsw , d_curr_weights = vjp_fun (( d_next_lightweight , d_next_bsw , d_next_weights ) )
183198
184- return d_curr_state , None
199+ # Accumulate gradients for invariant parts
200+ d_bsw_accum = jax .tree .map (lambda x , y : x + y , d_next_bsw , d_curr_bsw )
201+ d_weights_accum = jax .tree .map (lambda x , y : x + y , d_next_weights , d_curr_weights )
202+
203+ return (d_curr_lightweight , d_bsw_accum , d_weights_accum ), None
185204
186205 # Run backward scan
187- d_init_state , _ = jax .lax .scan (scan_body_bwd , g_final_state , vjp_funs , reverse = True )
206+ (d_init_lightweight , d_init_bsw , d_init_weights ), _ = jax .lax .scan (
207+ scan_body_bwd ,
208+ (g_lightweight , g_bsw , g_weights ),
209+ vjp_funs ,
210+ reverse = True ,
211+ )
188212
213+ d_init_state = {** d_init_lightweight , "bsw" : d_init_bsw , "weights" : d_init_weights }
189214 return (d_init_state , None , None )
190215
191216 run_scanned_custom .defvjp (run_scanned_custom_fwd , run_scanned_custom_bwd )
0 commit comments