2222import jax .numpy as jnp
2323
2424
def to_bf16(tree):
    """Return a copy of ``tree`` with every floating-point leaf cast to bfloat16.

    A leaf is converted only when it exposes a ``dtype`` attribute and that
    dtype is a sub-dtype of ``jnp.floating``. Every other leaf (integer or
    boolean arrays, plain Python scalars, arbitrary objects without a
    ``dtype``) is passed through unchanged.

    Args:
      tree: An arbitrary PyTree.

    Returns:
      A PyTree with the same structure as ``tree``.
    """

    def _cast_leaf(leaf):
        # Objects without a dtype (e.g. Python floats) are deliberately left alone.
        is_float_array = hasattr(leaf, "dtype") and jnp.issubdtype(leaf.dtype, jnp.floating)
        return leaf.astype(jnp.bfloat16) if is_float_array else leaf

    return jax.tree.map(_cast_leaf, tree)
30+
31+
def to_f32(tree):
    """Return a copy of ``tree`` with every floating-point leaf cast to float32.

    A leaf is converted only when it exposes a ``dtype`` attribute and that
    dtype is a sub-dtype of ``jnp.floating``. Every other leaf (integer or
    boolean arrays, plain Python scalars, arbitrary objects without a
    ``dtype``) is passed through unchanged.

    Args:
      tree: An arbitrary PyTree.

    Returns:
      A PyTree with the same structure as ``tree``.
    """

    def _cast_leaf(leaf):
        # Objects without a dtype (e.g. Python floats) are deliberately left alone.
        is_float_array = hasattr(leaf, "dtype") and jnp.issubdtype(leaf.dtype, jnp.floating)
        return leaf.astype(jnp.float32) if is_float_array else leaf

    return jax.tree.map(_cast_leaf, tree)
37+
38+
2539def get_mesh_axis_dim_indices (physical_partition_spec , axis_name = "fsdp" ):
2640 """Finds the tensor dimension index sharded across a specific physical mesh axis.
2741
@@ -310,7 +324,7 @@ def execute_pipeline_stage_pure(loop_state, w_curr, pipeline_weights):
310324 def execute_pipeline_stage_pure_fwd (loop_state , w_curr , pipeline_weights ):
311325 # Prefetch FSDP-sharded weights for the upcoming pipeline repeat
312326 w_next = model .weight_prefetching (
313- pipeline_weights ,
327+ to_bf16 ( pipeline_weights ) ,
314328 physical_partition_spec ,
315329 loop_state ["loop_iteration" ],
316330 )
@@ -332,9 +346,8 @@ def execute_pipeline_stage_pure_fwd(loop_state, w_curr, pipeline_weights):
332346 )
333347 # Execute the forward pass of the microbatches and generate its VJP.
334348 # The VJP captures necessary checkpoints to evaluate gradients later.
335- (loop_state , bsw ), scan_microbatches_vjp = jax .vjp (scan_microbatches_fn , loop_state , bsw , positions , segment_ids )
349+ (loop_state , _ ), scan_microbatches_vjp = jax .vjp (scan_microbatches_fn , loop_state , bsw , positions , segment_ids )
336350 # Discard the old weights (w_curr) and advance w_next to act as the current weights in the next iteration
337- _ , w_next = bsw
338351 return (loop_state , w_next ), (scan_microbatches_vjp , weight_prefetching_t )
339352
340353 def execute_pipeline_stage_pure_bwd (residuals , g_outputs ):
@@ -349,7 +362,7 @@ def execute_pipeline_stage_pure_bwd(residuals, g_outputs):
349362 # Apply the linear transpose of the weight prefetch to execute the reduce-scatter
350363 # This maps the gradients of the gathered weights back to the FSDP-sharded parameter space
351364 g_w_curr , g_w_next = g_bsw
352- (g_pipeline_weights ,) = weight_prefetching_t (g_w_next )
365+ (g_pipeline_weights ,) = weight_prefetching_t (to_f32 ( g_w_next ) )
353366 # Return gradients corresponding to the three original inputs of execute_pipeline_stage_pure
354367 return g_loop_state , g_w_curr , g_pipeline_weights
355368
0 commit comments