@@ -785,6 +785,7 @@ class LayoutAndSharding(BaseModel):
785785class DcnParallelism (BaseModel ):
786786 """Parallelism dimensions across the DCN (Data Center Network)."""
787787
788+ dcn_diloco_parallelism : int = Field (1 , description = "DCN axis for Diloco parallelism." )
788789 dcn_data_parallelism : int = Field (- 1 , description = "DCN axis for data parallelism." )
789790 dcn_fsdp_parallelism : int = Field (1 , description = "DCN axis for FSDP." )
790791 dcn_fsdp_transpose_parallelism : int = Field (1 , description = "DCN axis for FSDP transpose." )
@@ -804,6 +805,7 @@ class DcnParallelism(BaseModel):
804805class IciParallelism (BaseModel ):
805806 """Parallelism dimensions within the ICI (Inter-Chip Interconnect)."""
806807
808+ ici_diloco_parallelism : int = Field (1 , description = "ICI axis for Diloco parallelism." )
807809 ici_data_parallelism : int = Field (1 , description = "ICI axis for data parallelism." )
808810 ici_fsdp_parallelism : int = Field (- 1 , description = "ICI axis for FSDP." )
809811 ici_fsdp_transpose_parallelism : int = Field (1 , description = "ICI axis for FSDP transpose." )
@@ -1083,6 +1085,15 @@ class ManifoldConstrainedHyperConnections(BaseModel):
10831085 sinkhorn_iterations : PositiveInt = Field (20 , description = "The number of iterations for the Sinkhorn-Knopp algorithm." )
10841086
10851087
1088+ class DilocoParams (BaseModel ):
1089+ """Diloco Hyperparameters"""
1090+
1091+ enable_diloco : bool = Field (False , description = "Enable Diloco parallelism" )
1092+ diloco_sync_period : int = Field (36 , description = "Diloco sync period." )
1093+ diloco_outer_lr : float = Field (0.3 , description = "learning rate for outer optimizer." )
1094+ diloco_outer_momentum : float = Field (0.9 , description = "momentum for outer optimizer." )
1095+
1096+
10861097class Optimizer (BaseModel ):
10871098 """Configuration for the optimizer and learning rate schedule."""
10881099
@@ -1633,6 +1644,11 @@ class DerivedValues(BaseModel):
16331644 description = "Effective number of query heads, scaled by `global_parameter_scale`." ,
16341645 )
16351646
1647+ num_diloco_replicas : None | int = Field (
1648+ None ,
1649+ description = "The number of diloco replicas, derived from ICI and DCN values." ,
1650+ )
1651+
16361652 ici_parallelism : None | list [int ] = Field (
16371653 None ,
16381654 description = "Aggregated list of all ICI parallelism values for legacy compatibility." ,
@@ -1780,6 +1796,7 @@ class MaxTextConfig(
17801796 RematAndOffload ,
17811797 TrainingLoop ,
17821798 ManifoldConstrainedHyperConnections ,
1799+ DilocoParams ,
17831800 Optimizer ,
17841801 AdamW ,
17851802 Muon ,
@@ -2380,6 +2397,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
23802397 # Create the ici_parallelism and dcn_parallelism lists for legacy compatibility.
23812398 if self .using_pipeline_parallelism and self .mesh_axes and self .mesh_axes [0 ] == "stage" :
23822399 self .ici_parallelism = [
2400+ self .ici_diloco_parallelism ,
23832401 self .ici_pipeline_parallelism ,
23842402 self .ici_data_parallelism ,
23852403 self .ici_fsdp_parallelism ,
@@ -2394,6 +2412,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
23942412 self .ici_autoregressive_parallelism ,
23952413 ]
23962414 self .dcn_parallelism = [
2415+ self .dcn_diloco_parallelism ,
23972416 self .dcn_pipeline_parallelism ,
23982417 self .dcn_data_parallelism ,
23992418 self .dcn_fsdp_parallelism ,
@@ -2409,6 +2428,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
24092428 ]
24102429 else :
24112430 ici_map = {
2431+ "diloco" : self .ici_diloco_parallelism ,
24122432 "data" : self .ici_data_parallelism ,
24132433 "stage" : self .ici_pipeline_parallelism ,
24142434 "fsdp" : self .ici_fsdp_parallelism ,
@@ -2427,6 +2447,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
24272447 self .ici_parallelism = [ici_map [axis ] for axis in self .mesh_axes ]
24282448
24292449 dcn_map = {
2450+ "diloco" : self .dcn_diloco_parallelism ,
24302451 "data" : self .dcn_data_parallelism ,
24312452 "stage" : self .dcn_pipeline_parallelism ,
24322453 "fsdp" : self .dcn_fsdp_parallelism ,
@@ -2444,6 +2465,9 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
24442465 }
24452466 self .dcn_parallelism = [dcn_map [axis ] for axis in self .mesh_axes ]
24462467
2468+ # Diloco params
2469+ self .num_diloco_replicas = int (self .ici_diloco_parallelism * self .dcn_diloco_parallelism )
2470+
24472471 # Final string-to-enum conversions if they haven't been coerced by pydantic yet.
24482472 if isinstance (self .decoder_block , str ):
24492473 self .decoder_block = DecoderBlockType (self .decoder_block .lower ())
0 commit comments