@@ -143,70 +143,48 @@ def run_scanned_custom(loop_state, positions, segment_ids):
143143 return final_state
144144
145145 def run_scanned_custom_fwd (loop_state , positions , segment_ids ):
146- final_state , _ = run_scanned (model , loop_state )
147- # We return loop_state as residual. model is passed to bwd as arg.
148- return final_state , (
149- loop_state ,
150- positions ,
151- segment_ids ,
152- )
153-
154- def run_scanned_custom_bwd (residuals , g_final_state ):
155- init_loop_state , positions , segment_ids = residuals
156-
157- # Re-run forward pass to get saved states (checkpointing)
158- def scan_body_fwd (carry , _ ):
159- new_state = model .run_one_iteration (
160- carry ,
146+ def step_fn (s ):
147+ out = model .run_one_iteration (
148+ s ,
161149 positions ,
162150 segment_ids ,
163151 deterministic ,
164152 model_mode ,
165153 logical_partition_spec = logical_partition_spec ,
166154 )
167- # Return lightweight state for saving (exclude bsw/weights)
168- saved = {k : v for k , v in carry .items () if k not in ["bsw" , "weights" ]}
169- return new_state , saved
155+ return out
170156
171- _ , saved_states = jax .lax .scan (
157+ step_fn_remat = jax .remat (step_fn , prevent_cse = False , policy = model .get_pipeline_remat_policy ())
158+
159+ # Forward pass generating VJP functions and intermediate state
160+ def scan_body_fwd (carry , _ ):
161+ new_state , vjp_fun = jax .vjp (step_fn_remat , carry )
162+ return new_state , vjp_fun
163+
164+ final_state , vjp_funs = jax .lax .scan (
172165 scan_body_fwd ,
173- init_loop_state ,
166+ loop_state ,
174167 None ,
175168 length = length ,
176169 )
177170
171+ # We return vjp_funs as residual. model is passed to bwd as arg.
172+ return final_state , vjp_funs
173+
174+ def run_scanned_custom_bwd (residuals , g_final_state ):
175+ vjp_funs = residuals
176+
178177 # Backward scan to accumulate gradients
179- def scan_body_bwd (carry , saved_slice ):
178+ def scan_body_bwd (carry , vjp_fun ):
180179 d_next_state = carry
181180
182- # Reconstruct current loop_state (input to step)
183- curr_loop_state = {
184- ** saved_slice ,
185- "bsw" : init_loop_state ["bsw" ],
186- "weights" : init_loop_state ["weights" ],
187- }
188-
189- # Define function to differentiate w.r.t loop_state
190- def step_fn (s ):
191- out = model .run_one_iteration (
192- s ,
193- positions ,
194- segment_ids ,
195- deterministic ,
196- model_mode ,
197- logical_partition_spec = logical_partition_spec ,
198- )
199- return out
200-
201- _ , vjp_fun = jax .vjp (step_fn , curr_loop_state )
202-
203- # Backprop d_next_state
181+ # Backprop d_next_state using saved vjp_fun
204182 (d_curr_state ,) = vjp_fun (d_next_state )
205183
206184 return d_curr_state , None
207185
208186 # Run backward scan
209- d_init_state , _ = jax .lax .scan (scan_body_bwd , g_final_state , saved_states , reverse = True )
187+ d_init_state , _ = jax .lax .scan (scan_body_bwd , g_final_state , vjp_funs , reverse = True )
210188
211189 return (d_init_state , None , None )
212190
0 commit comments