@@ -840,7 +840,7 @@ class IciParallelism(BaseModel):
840840class PipelineParallelism (BaseModel ):
841841 """Configuration for pipeline parallelism."""
842842
843- use_pipeline_weight_prefetching : bool = Field (
843+ pipeline_fsdp_ag_per_repeat : bool = Field (
844844 False , description = "Enable weight prefetching for circular pipeline parallelism."
845845 )
846846 num_layers_per_pipeline_stage : int = Field (1 , description = "Number of layers to place on each pipeline stage." )
@@ -2240,7 +2240,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
22402240 )
22412241 self .num_pipeline_repeats = num_pipeline_repeats
22422242
2243- if self .use_pipeline_weight_prefetching :
2243+ if self .pipeline_fsdp_ag_per_repeat :
22442244 assert self .num_pipeline_repeats > 1 , "Pipeline weight prefetching only supports circular pipeline."
22452245 assert (
22462246 self .num_layers_per_pipeline_stage == 1
@@ -2556,7 +2556,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
25562556 "expert" : self .ici_expert_parallelism ,
25572557 "autoregressive" : self .ici_autoregressive_parallelism ,
25582558 "attn_dp" : 1 , # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads
2559- "attn_dp_expert" : 1 , # initialized to 1, vLLM will auto calculate this value based on EP
2559+ "attn_dp_expert" : 1 , # initialized to 1, vLLM will auto calculate this value based on EP
25602560 }
25612561 self .ici_parallelism = [ici_map [axis ] for axis in self .mesh_axes ]
25622562
@@ -2576,7 +2576,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
25762576 "expert" : self .dcn_expert_parallelism ,
25772577 "autoregressive" : self .dcn_autoregressive_parallelism ,
25782578 "attn_dp" : 1 , # initialized to 1, vLLM will auto calculate this value based on TP and num_kv_heads
2579- "attn_dp_expert" : 1 , # initialized to 1, vLLM will auto calculate this value based on EP
2579+ "attn_dp_expert" : 1 , # initialized to 1, vLLM will auto calculate this value based on EP
25802580 }
25812581 self .dcn_parallelism = [dcn_map [axis ] for axis in self .mesh_axes ]
25822582
0 commit comments