@@ -842,6 +842,9 @@ class IciParallelism(BaseModel):
842842class PipelineParallelism (BaseModel ):
843843 """Configuration for pipeline parallelism."""
844844
845+ pipeline_fsdp_ag_per_repeat : bool = Field (
846+ False , description = "Enable weight prefetching for circular pipeline parallelism."
847+ )
845848 num_layers_per_pipeline_stage : int = Field (1 , description = "Number of layers to place on each pipeline stage." )
846849 num_pipeline_repeats : int = Field (
847850 - 1 ,
@@ -857,6 +860,7 @@ class PipelineParallelism(BaseModel):
857860 )
858861 pipeline_fsdp_ag_once : bool = Field (False , description = "If True, all-gather FSDP weights once per pipeline repeat." )
859862 scan_pipeline_iterations : bool = Field (True , description = "Use jax.lax.scan over pipeline iterations." )
863+ scan_pipeline_repeats : bool = Field (True , description = "Use jax.lax.scan over pipeline repeats." )
860864 scan_layers_per_stage : bool = Field (False , description = "Use jax.lax.scan over layers within a stage." )
861865 set_remat_policy_on_pipeline_iterations : bool = Field (True , description = "Set remat policy on the pipeline scan." )
862866 set_remat_policy_on_layers_per_stage : bool = Field (False , description = "Set remat policy on the inner layer scan." )
@@ -2250,6 +2254,17 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
22502254 )
22512255 self .num_pipeline_repeats = num_pipeline_repeats
22522256
2257+ if self .pipeline_fsdp_ag_per_repeat :
2258+ assert self .num_pipeline_repeats > 1 , "Pipeline weight prefetching only supports circular pipeline."
2259+ assert (
2260+ self .num_layers_per_pipeline_stage == 1
2261+ ), "Pipeline weight prefetching currently only supports one layer per pipeline stage."
2262+ assert (
2263+ not self .pipeline_delay_activation_forwarding
2264+ ), "Pipeline weight prefetching does not support pipeline delay."
2265+ assert not self .quantization , "Quantization is currently not supported for pipeline prefetching."
2266+ assert not self .scan_layers_per_stage , "Pipeline weight prefetching currently does not support scan."
2267+
22532268 assert (num_stages * self .num_pipeline_repeats * self .num_layers_per_pipeline_stage ) == (
22542269 self .pipeline_parallel_layers
22552270 ), (
@@ -2539,78 +2554,45 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
25392554 raise ValueError ("`share_kv_projections` is not compatible with `attention_type='mla'`." )
25402555
25412556 # I. FINAL TYPE CONVERSIONS AND DERIVED LISTS
2542- # Create the ici_parallelism and dcn_parallelism lists for legacy compatibility.
2543- if self .using_pipeline_parallelism and self .mesh_axes and self .mesh_axes [0 ] == "stage" :
2544- self .ici_parallelism = [
2545- self .ici_diloco_parallelism ,
2546- self .ici_pipeline_parallelism ,
2547- self .ici_data_parallelism ,
2548- self .ici_fsdp_parallelism ,
2549- self .ici_fsdp_transpose_parallelism ,
2550- self .ici_sequence_parallelism ,
2551- self .ici_context_parallelism ,
2552- self .ici_context_autoregressive_parallelism ,
2553- self .ici_tensor_parallelism ,
2554- self .ici_tensor_transpose_parallelism ,
2555- self .ici_tensor_sequence_parallelism ,
2556- self .ici_expert_parallelism ,
2557- self .ici_autoregressive_parallelism ,
2558- ]
2559- self .dcn_parallelism = [
2560- self .dcn_diloco_parallelism ,
2561- self .dcn_pipeline_parallelism ,
2562- self .dcn_data_parallelism ,
2563- self .dcn_fsdp_parallelism ,
2564- self .dcn_fsdp_transpose_parallelism ,
2565- self .dcn_sequence_parallelism ,
2566- self .dcn_context_parallelism ,
2567- self .dcn_context_autoregressive_parallelism ,
2568- self .dcn_tensor_parallelism ,
2569- self .dcn_tensor_transpose_parallelism ,
2570- self .dcn_tensor_sequence_parallelism ,
2571- self .dcn_expert_parallelism ,
2572- self .dcn_autoregressive_parallelism ,
2573- ]
2574- else :
2575- ici_map = {
2576- "diloco" : self .ici_diloco_parallelism ,
2577- "data" : self .ici_data_parallelism ,
2578- "stage" : self .ici_pipeline_parallelism ,
2579- "fsdp" : self .ici_fsdp_parallelism ,
2580- "fsdp_transpose" : self .ici_fsdp_transpose_parallelism ,
2581- "sequence" : self .ici_sequence_parallelism ,
2582- "context" : self .ici_context_parallelism ,
2583- "context_autoregressive" : self .ici_context_autoregressive_parallelism ,
2584- "tensor" : self .ici_tensor_parallelism ,
2585- "tensor_transpose" : self .ici_tensor_transpose_parallelism ,
2586- "tensor_sequence" : self .ici_tensor_sequence_parallelism ,
2587- "model" : self .ici_tensor_parallelism ,
2588- "expert" : self .ici_expert_parallelism ,
2589- "autoregressive" : self .ici_autoregressive_parallelism ,
2590- "attn_dp" : 1 , # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads
2591- "attn_dp_expert" : 1 , # initialized to 1, vLLM will auto calculate this value based on EP
2592- }
2593- self .ici_parallelism = [ici_map [axis ] for axis in self .mesh_axes ]
2594-
2595- dcn_map = {
2596- "diloco" : self .dcn_diloco_parallelism ,
2597- "data" : self .dcn_data_parallelism ,
2598- "stage" : self .dcn_pipeline_parallelism ,
2599- "fsdp" : self .dcn_fsdp_parallelism ,
2600- "fsdp_transpose" : self .dcn_fsdp_transpose_parallelism ,
2601- "sequence" : self .dcn_sequence_parallelism ,
2602- "context" : self .dcn_context_parallelism ,
2603- "context_autoregressive" : self .dcn_context_autoregressive_parallelism ,
2604- "tensor" : self .dcn_tensor_parallelism ,
2605- "tensor_transpose" : self .dcn_tensor_transpose_parallelism ,
2606- "tensor_sequence" : self .dcn_tensor_sequence_parallelism ,
2607- "model" : self .dcn_tensor_parallelism ,
2608- "expert" : self .dcn_expert_parallelism ,
2609- "autoregressive" : self .dcn_autoregressive_parallelism ,
2610- "attn_dp" : 1 , # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads
2611- "attn_dp_expert" : 1 , # initialized to 1, vLLM will auto calculate this value based on EP
2612- }
2613- self .dcn_parallelism = [dcn_map [axis ] for axis in self .mesh_axes ]
2557+ ici_map = {
2558+ "diloco" : self .ici_diloco_parallelism ,
2559+ "data" : self .ici_data_parallelism ,
2560+ "stage" : self .ici_pipeline_parallelism ,
2561+ "fsdp" : self .ici_fsdp_parallelism ,
2562+ "fsdp_transpose" : self .ici_fsdp_transpose_parallelism ,
2563+ "sequence" : self .ici_sequence_parallelism ,
2564+ "context" : self .ici_context_parallelism ,
2565+ "context_autoregressive" : self .ici_context_autoregressive_parallelism ,
2566+ "tensor" : self .ici_tensor_parallelism ,
2567+ "tensor_transpose" : self .ici_tensor_transpose_parallelism ,
2568+ "tensor_sequence" : self .ici_tensor_sequence_parallelism ,
2569+ "model" : self .ici_tensor_parallelism ,
2570+ "expert" : self .ici_expert_parallelism ,
2571+ "autoregressive" : self .ici_autoregressive_parallelism ,
2572+ "attn_dp" : 1 , # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads
2573+ "attn_dp_expert" : 1 , # initialized to 1, vLLM will auto calculate this value based on EP
2574+ }
2575+ self .ici_parallelism = [ici_map [axis ] for axis in self .mesh_axes ]
2576+
2577+ dcn_map = {
2578+ "diloco" : self .dcn_diloco_parallelism ,
2579+ "data" : self .dcn_data_parallelism ,
2580+ "stage" : self .dcn_pipeline_parallelism ,
2581+ "fsdp" : self .dcn_fsdp_parallelism ,
2582+ "fsdp_transpose" : self .dcn_fsdp_transpose_parallelism ,
2583+ "sequence" : self .dcn_sequence_parallelism ,
2584+ "context" : self .dcn_context_parallelism ,
2585+ "context_autoregressive" : self .dcn_context_autoregressive_parallelism ,
2586+ "tensor" : self .dcn_tensor_parallelism ,
2587+ "tensor_transpose" : self .dcn_tensor_transpose_parallelism ,
2588+ "tensor_sequence" : self .dcn_tensor_sequence_parallelism ,
2589+ "model" : self .dcn_tensor_parallelism ,
2590+ "expert" : self .dcn_expert_parallelism ,
2591+ "autoregressive" : self .dcn_autoregressive_parallelism ,
2592+ "attn_dp" : 1 , # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads
2593+ "attn_dp_expert" : 1 , # initialized to 1, vLLM will auto calculate this value based on EP
2594+ }
2595+ self .dcn_parallelism = [dcn_map [axis ] for axis in self .mesh_axes ]
26142596
26152597 # Diloco params
26162598 self .num_diloco_replicas = int (self .ici_diloco_parallelism * self .dcn_diloco_parallelism )
0 commit comments