|
19 | 19 | from jax.sharding import PartitionSpec as P |
20 | 20 | from flax import linen as nn |
21 | 21 | from flax.linen.spmd import LogicallyPartitioned |
| 22 | +import jax.numpy as jnp |
22 | 23 |
|
23 | 24 |
|
24 | 25 | def get_mesh_axis_dim_indices(physical_partition_spec, axis_name="fsdp"): |
@@ -248,8 +249,6 @@ def run_pipeline_microbatches_custom_bwd(residuals, g_final_state): |
248 | 249 |
|
249 | 250 |
|
250 | 251 | def create_rematerialized_pipeline_stage( |
251 | | - model, |
252 | | - run_iteration_scannable, |
253 | 252 | length, |
254 | 253 | deterministic, |
255 | 254 | model_mode, |
@@ -285,39 +284,61 @@ def create_rematerialized_pipeline_stage( |
285 | 284 | the updated `loop_state`. |
286 | 285 | """ |
287 | 286 |
|
288 | | - def execute_pipeline_stage(model, loop_state_and_bsw): |
289 | | - loop_state, w_curr = loop_state_and_bsw |
290 | | - # # Retrieve the specific weights needed for this pipeline chunk |
291 | | - # bsw = model.both_weight_prefetching(pipeline_weights, physical_partition_spec, loop_state["loop_iteration"]) |
292 | | - w_next = jax.remat( |
293 | | - model.one_weight_prefetching, |
294 | | - static_argnums=(1,), |
295 | | - policy=jax.checkpoint_policies.nothing_saveable, |
296 | | - )( |
297 | | - pipeline_weights, |
298 | | - physical_partition_spec, |
299 | | - loop_state["loop_iteration"], |
300 | | - ) |
301 | | - bsw = (w_curr, w_next) |
| 287 | + def execute_pipeline_stage_outer(model, loop_state_and_bsw): |
| 288 | + |
302 | 289 | scan_microbatches_fn = create_gradient_accumulation_scan( |
303 | 290 | model=model, |
304 | 291 | length=length, |
305 | 292 | deterministic=deterministic, |
306 | 293 | model_mode=model_mode, |
307 | 294 | logical_partition_spec=logical_partition_spec, |
308 | 295 | ) |
309 | | - loop_state, bsw = scan_microbatches_fn(loop_state, bsw, positions, segment_ids) |
310 | | - w_curr, w_next = bsw |
311 | | - del w_curr |
312 | | - return (loop_state, w_next), None |
313 | | - |
314 | | - return execute_pipeline_stage |
315 | | - |
316 | | - # return nn.remat( |
317 | | - # execute_pipeline_stage, |
318 | | - # prevent_cse=not model.config.scan_pipeline_iterations, |
319 | | - # policy=model.get_pipeline_remat_policy(), |
320 | | - # ) |
| 296 | + |
| 297 | + remat_weight_prefetching = model.one_weight_prefetching |
| 298 | + |
| 299 | + @jax.custom_vjp |
| 300 | + def execute_pipeline_stage(loop_state_and_bsw, pipeline_weights): |
| 301 | + return execute_pipeline_stage_custom_fwd(loop_state_and_bsw, pipeline_weights)[0] |
| 302 | + |
| 303 | + def execute_pipeline_stage_custom_fwd(loop_state_and_bsw, pipeline_weights): |
| 304 | + loop_state, w_curr = loop_state_and_bsw |
| 305 | + # # Retrieve the specific weights needed for this pipeline chunk |
| 306 | + w_next = remat_weight_prefetching( |
| 307 | + pipeline_weights, |
| 308 | + physical_partition_spec, |
| 309 | + loop_state["loop_iteration"], |
| 310 | + ) |
| 311 | + bsw = (w_curr, w_next) |
| 312 | + |
| 313 | + (loop_state, bsw), scan_fn_vjp = jax.vjp(scan_microbatches_fn, loop_state, bsw, positions, segment_ids) |
| 314 | + p_remat_weight_prefetching = functools.partial( |
| 315 | + remat_weight_prefetching, |
| 316 | + physical_partition_spec=physical_partition_spec, |
| 317 | + loop_iteration=loop_state["loop_iteration"], |
| 318 | + ) |
| 319 | + remat_weight_prefetching_t = jax.linear_transpose( |
| 320 | + p_remat_weight_prefetching, |
| 321 | + pipeline_weights, |
| 322 | + ) |
| 323 | + w_curr, w_next = bsw |
| 324 | + del w_curr |
| 325 | + return (loop_state, w_next), (scan_fn_vjp, remat_weight_prefetching_t) |
| 326 | + |
| 327 | + def execute_pipeline_stage_custom_bwd(residuals, g_outputs): |
| 328 | + g_loop_state, g_w_next = g_outputs |
| 329 | + scan_fn_vjp, remat_weight_prefetching_t = residuals |
| 330 | + g_w_curr = jax.tree.map(jnp.zeros_like, g_w_next) |
| 331 | + g_bsw = (g_w_curr, g_w_next) |
| 332 | + g_loop_state, g_bsw, _, _ = scan_fn_vjp((g_loop_state, g_bsw)) |
| 333 | + g_w_curr, g_w_next = g_bsw |
| 334 | + (g_pipeline_weights,) = remat_weight_prefetching_t(g_w_next) |
| 335 | + return (g_loop_state, g_w_curr), g_pipeline_weights |
| 336 | + |
| 337 | + execute_pipeline_stage.defvjp(execute_pipeline_stage_custom_fwd, execute_pipeline_stage_custom_bwd) |
| 338 | + |
| 339 | + return execute_pipeline_stage(loop_state_and_bsw, pipeline_weights), None |
| 340 | + |
| 341 | + return execute_pipeline_stage_outer |
321 | 342 |
|
322 | 343 |
|
323 | 344 | def create_flax_pipeline_scan(pipeline_stage_fn, length): |
|
0 commit comments