
Commit 86ef13f

Merge pull request #2920 from AI-Hypercomputer:mohit/diloco_trainer
PiperOrigin-RevId: 868302856
2 parents f227fa2 + 1683fb7 commit 86ef13f

33 files changed

Lines changed: 3782 additions & 20 deletions


dependencies/requirements/base_requirements/requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ array-record
 cloud-accelerator-diagnostics
 cloud-tpu-diagnostics
 datasets
+drjax
 flax
 gcsfs
 google-api-python-client

dependencies/requirements/generated_requirements/cuda12-requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@ dill>=0.4.0
 distlib>=0.4.0
 dm-tree>=0.1.9
 docstring-parser>=0.17.0
+drjax>=0.1.4
 editdistance>=0.8.1
 einops>=0.8.1
 einshape>=1.0

dependencies/requirements/generated_requirements/tpu-requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -41,6 +41,7 @@ dill>=0.4.0
 distlib>=0.4.0
 dm-tree>=0.1.9
 docstring-parser>=0.17.0
+drjax>=0.1.4
 editdistance>=0.8.1
 einops>=0.8.1
 einshape>=1.0

dependencies/requirements/requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ array-record
 cloud-accelerator-diagnostics
 cloud-tpu-diagnostics
 datasets
+drjax>=0.1.4
 flax
 gcsfs
 google-api-python-client
Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
+{
+  "architectures": [
+    "MaxTextForCausalLM"
+  ],
+  "attention_bias": false,
+  "attention_dropout": 0.0,
+  "auto_map": {
+    "AutoConfig": "configuration_deepseek.DeepseekV3Config",
+    "AutoModel": "modeling_deepseek.DeepseekV3Model",
+    "AutoModelForCausalLM": "modeling_deepseek.DeepseekV3ForCausalLM"
+  },
+  "bos_token_id": 0,
+  "eos_token_id": 1,
+  "ep_size": 1,
+  "first_k_dense_replace": 3,
+  "hidden_act": "silu",
+  "hidden_size": 7168,
+  "initializer_range": 0.02,
+  "intermediate_size": 18432,
+  "kv_lora_rank": 512,
+  "max_position_embeddings": 163840,
+  "model_type": "deepseek_v3",
+  "moe_intermediate_size": 2048,
+  "moe_layer_freq": 1,
+  "n_group": 8,
+  "n_routed_experts": 256,
+  "n_shared_experts": 1,
+  "norm_topk_prob": true,
+  "num_attention_heads": 128,
+  "num_experts_per_tok": 8,
+  "num_hidden_layers": 61,
+  "num_key_value_heads": 128,
+  "num_nextn_predict_layers": 1,
+  "q_lora_rank": 1536,
+  "qk_nope_head_dim": 128,
+  "qk_rope_head_dim": 64,
+  "rms_norm_eps": 1e-06,
+  "rope_scaling": {
+    "beta_fast": 32,
+    "beta_slow": 1,
+    "factor": 40,
+    "mscale": 1.0,
+    "mscale_all_dim": 1.0,
+    "original_max_position_embeddings": 4096,
+    "type": "yarn"
+  },
+  "rope_theta": 10000,
+  "routed_scaling_factor": 2.5,
+  "scoring_func": "sigmoid",
+  "tie_word_embeddings": false,
+  "topk_group": 4,
+  "topk_method": "noaux_tc",
+  "torch_dtype": "bfloat16",
+  "transformers_version": "4.33.1",
+  "use_cache": true,
+  "v_head_dim": 128,
+  "vocab_size": 129280
+}

src/MaxText/sharding.py

Lines changed: 7 additions & 1 deletion
@@ -36,7 +36,13 @@

 def get_input_data_sharding(config, mesh):
   """Get the input data sharding for the model"""
-  return create_sharding(mesh, config.input_data_sharding_logical_axes, rules=config.logical_axis_rules)
+  if config.enable_diloco:
+    data_sharding = create_sharding(
+        mesh, ["diloco"] + config.input_data_sharding_logical_axes, rules=config.logical_axis_rules
+    )
+  else:
+    data_sharding = create_sharding(mesh, config.input_data_sharding_logical_axes, rules=config.logical_axis_rules)
+  return data_sharding


 def maybe_shard_with_name(inputs, named_sharding, shard_mode, debug_sharding=False, extra_stack_level=0):
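
The effect of this branch: with DiLoCo enabled, the input batch carries an extra leading logical axis that maps onto the new `diloco` mesh axis (via the `['diloco', 'diloco']` rule added to base.yml below). A minimal, self-contained sketch of the sharding this should reduce to, assuming 8 devices split into 2 DiLoCo replicas with 4-way data parallelism — this is not MaxText's `create_sharding`, only an illustration of the resulting `NamedSharding`:

# Sketch only: 2 DiLoCo replicas x 4-way data parallelism on 8 devices.
# On CPU, run with XLA_FLAGS=--xla_force_host_platform_device_count=8.
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices()[:8]).reshape(2, 4)
mesh = Mesh(devices, axis_names=("diloco", "data"))

# The data loader reshapes the batch to [num_diloco_replicas, per_replica_batch, seq];
# the leading axis lands on "diloco", the batch axis on "data".
spec = P("diloco", "data", None)
batch = np.zeros((2, 8, 16), dtype=np.float32)
sharded = jax.device_put(batch, NamedSharding(mesh, spec))
print(sharded.sharding.spec)  # PartitionSpec('diloco', 'data', None)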

src/MaxText/train_compile.py

Lines changed: 28 additions & 7 deletions
@@ -24,6 +24,7 @@
 from typing import Sequence
 import os
 import pickle
+import functools

 from absl import app

@@ -45,6 +46,7 @@
 from maxtext.utils import gcs_utils
 from maxtext.utils import max_utils
 from maxtext.utils import maxtext_utils
+from maxtext.trainers.diloco import diloco

 # pylint: disable=too-many-positional-arguments

@@ -235,13 +237,32 @@ def main(argv: Sequence[str]) -> None:

   # Get data sharding
   data_sharding = sharding.get_input_data_sharding(config, topology_mesh)
-
-  # Get function to compile and shardings
-  func_to_compile, in_shard, out_shard, static_argnums, donate_argnums = (
-      maxtext_utils.get_functional_train_with_signature(
-          train.train_step, data_sharding, state_mesh_shardings, model, config
-      )
-  )
+  if config.enable_diloco:
+    # Build abstract DiLoCo state and shardings for AOT compilation
+    abstract_state = shaped_train_args[0]
+    diloco_state, state_mesh_shardings, inner_state_shardings = diloco.build_abstract_diloco_state(
+        config, abstract_state, state_mesh_shardings, topology_mesh
+    )
+    shaped_train_args = (diloco_state, shaped_train_args[1], shaped_train_args[2])
+
+    # Wrap train_step with diloco
+    train_step_partial = functools.partial(train.train_step, model, config, inner_state_shardings, None)
+    train_step_fn = diloco.build_diloco_train_step(config, train_step_partial)
+
+    # For DiLoCo, the train_step_fn is already fully wrapped and takes (state, batch, prng)
+    func_to_compile = train_step_fn
+    func_to_compile.__name__ = "train_step"
+    in_shard = (state_mesh_shardings, data_sharding, None)  # State, batch, rng
+    out_shard = (state_mesh_shardings, None)  # State, metrics
+    static_argnums = ()
+    donate_argnums = 0
+  else:
+    # Get function to compile and shardings
+    func_to_compile, in_shard, out_shard, static_argnums, donate_argnums = (
+        maxtext_utils.get_functional_train_with_signature(
+            train.train_step, data_sharding, state_mesh_shardings, model, config
+        )
+    )

   # print weights sharding info under debug sharding mode
   if config.debug_sharding:
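
For context on what the wrapped step does: DiLoCo (Douillard et al., 2023) lets each replica run its inner optimizer locally, and every `diloco_sync_period` steps folds the replicas back together by treating their parameter drift as a pseudo-gradient for an outer SGD step with Nesterov momentum. The sketch below shows that outer update with optax; names such as `outer_params` and `outer_sync` are illustrative, and the real `diloco.build_diloco_train_step` is built on drjax rather than this hand-rolled form:

# Illustrative DiLoCo outer update, using the base.yml defaults.
# Not the drjax-based implementation in maxtext.trainers.diloco.diloco.
import jax
import jax.numpy as jnp
import optax

outer_opt = optax.sgd(learning_rate=0.3, momentum=0.9, nesterov=True)  # diloco_outer_lr / diloco_outer_momentum

def outer_sync(outer_params, inner_params, outer_opt_state):
  """Runs once every diloco_sync_period inner steps."""
  # Average the replicas' local weights (leading axis = diloco replicas).
  avg_inner = jax.tree_util.tree_map(lambda x: jnp.mean(x, axis=0), inner_params)
  # Pseudo-gradient: how far the replicas drifted from the shared outer params.
  pseudo_grad = jax.tree_util.tree_map(lambda o, i: o - i, outer_params, avg_inner)
  updates, outer_opt_state = outer_opt.update(pseudo_grad, outer_opt_state, outer_params)
  outer_params = optax.apply_updates(outer_params, updates)
  # Each replica then restarts its inner loop from the new outer params.
  return outer_params, outer_opt_state

Here `outer_opt_state = outer_opt.init(outer_params)` would be created once at startup, and between syncs the replicas never communicate — which is the point of the method.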

src/maxtext/common/data_loader.py

Lines changed: 5 additions & 1 deletion
@@ -25,6 +25,7 @@
     maybe_record_goodput,
 )
 from maxtext.utils import exceptions
+from maxtext.trainers.diloco import diloco


 class DataLoader:

@@ -70,10 +71,13 @@ def load_next_batch_pre_sharding(self):

   def load_next_batch(self, *args, **kwargs):
     """Loads the next batch with sharding hint"""
-    return jax.device_put(
+    example_batch = jax.device_put(
         self.load_next_batch_pre_sharding(),
         self.input_data_shardings,
     )
+    if self.config.enable_diloco:
+      example_batch = diloco.reshape_first_axis_with_diloco(self.config.num_diloco_replicas, example_batch)
+    return example_batch

   def check_example_batch(self):
     if self.config.max_checkify:
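
`reshape_first_axis_with_diloco` itself is not shown in this view; from its call site it evidently splits the leading global-batch axis into `(num_diloco_replicas, per_replica_batch)` so the batch lines up with the `('diloco', ...)` input sharding above. A hypothetical stand-in:

# Hypothetical stand-in for diloco.reshape_first_axis_with_diloco, inferred
# from the call site above; the real helper lives in maxtext.trainers.diloco.
import jax

def reshape_first_axis_with_diloco(num_diloco_replicas, batch):
  def split(x):
    # [global_batch, ...] -> [num_diloco_replicas, global_batch // num_diloco_replicas, ...]
    assert x.shape[0] % num_diloco_replicas == 0, "global batch must divide evenly"
    return x.reshape(num_diloco_replicas, -1, *x.shape[1:])
  return jax.tree_util.tree_map(split, batch)

Under this reading, a `{"inputs": (16, 2048)}` batch becomes `(2, 8, 2048)` with two replicas.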

src/maxtext/configs/base.yml

Lines changed: 10 additions & 1 deletion
@@ -400,7 +400,7 @@ hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu', 'gpu_multiprocess'

 # Parallelism
 shard_mode: "auto" # can be either auto or explicit
-mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']
+mesh_axes: ['diloco', 'data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']
 logical_axis_rules: [
   ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
   ['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']],

@@ -483,6 +483,7 @@ logical_axis_rules: [
   ['paged_kv_head_dim_size', []],
   ['dense_layers', []],
   ['moe_layers', []],
+  ['diloco', 'diloco'],
 ]
 # Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
 data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]

@@ -495,6 +496,7 @@ sharding_tolerance: 0.02
 # value to auto-shard based on available slices and devices.
 # By default, product of the DCN axes should equal number of slices
 # and product of the ICI axes should equal number of devices per slice.
+dcn_diloco_parallelism: 1
 dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded
 dcn_fsdp_parallelism: 1
 dcn_fsdp_transpose_parallelism: 1

@@ -507,6 +509,7 @@ dcn_tensor_sequence_parallelism: 1 # never recommended
 dcn_pipeline_parallelism: 1
 dcn_expert_parallelism: 1
 dcn_autoregressive_parallelism: 1 # never recommended
+ici_diloco_parallelism: 1
 ici_data_parallelism: 1
 ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
 ici_fsdp_transpose_parallelism: 1

@@ -738,6 +741,12 @@ enable_data_shuffling: True
 data_shuffle_seed: 0
 init_weights_seed: 0

+# DiLoCo params.
+enable_diloco: False
+diloco_sync_period: 36
+diloco_outer_lr: 0.3
+diloco_outer_momentum: 0.9
+
 # You may disable clipping by setting gradient_clipping_threshold to zero.
 gradient_clipping_threshold: 1.0
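
The usual mesh constraints still apply: per the comment above these keys, the product of the `dcn_*` axes must equal the slice count and the product of the `ici_*` axes the chips per slice, so the natural placement is one DiLoCo replica per slice. A hypothetical sizing check for two 4-chip slices (numbers chosen for illustration, not taken from this PR):

# Hypothetical DiLoCo sizing: one replica per slice, FSDP inside each slice.
num_slices, chips_per_slice = 2, 4

dcn_diloco_parallelism = 2   # replicas placed across slices (over DCN)
ici_fsdp_parallelism = 4     # model sharding within each slice (over ICI)
ici_diloco_parallelism = 1   # base.yml default

assert dcn_diloco_parallelism == num_slices     # product of DCN axes == slices
assert ici_fsdp_parallelism == chips_per_slice  # product of ICI axes == chips/slice

num_diloco_replicas = ici_diloco_parallelism * dcn_diloco_parallelism
global_batch = 16
print(num_diloco_replicas, global_batch // num_diloco_replicas)  # 2 8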

src/maxtext/configs/types.py

Lines changed: 24 additions & 0 deletions
@@ -785,6 +785,7 @@ class LayoutAndSharding(BaseModel):
 class DcnParallelism(BaseModel):
   """Parallelism dimensions across the DCN (Data Center Network)."""

+  dcn_diloco_parallelism: int = Field(1, description="DCN axis for Diloco parallelism.")
   dcn_data_parallelism: int = Field(-1, description="DCN axis for data parallelism.")
   dcn_fsdp_parallelism: int = Field(1, description="DCN axis for FSDP.")
   dcn_fsdp_transpose_parallelism: int = Field(1, description="DCN axis for FSDP transpose.")

@@ -804,6 +805,7 @@ class DcnParallelism(BaseModel):
 class IciParallelism(BaseModel):
   """Parallelism dimensions within the ICI (Inter-Chip Interconnect)."""

+  ici_diloco_parallelism: int = Field(1, description="ICI axis for Diloco parallelism.")
   ici_data_parallelism: int = Field(1, description="ICI axis for data parallelism.")
   ici_fsdp_parallelism: int = Field(-1, description="ICI axis for FSDP.")
   ici_fsdp_transpose_parallelism: int = Field(1, description="ICI axis for FSDP transpose.")

@@ -1083,6 +1085,15 @@ class ManifoldConstrainedHyperConnections(BaseModel):
   sinkhorn_iterations: PositiveInt = Field(20, description="The number of iterations for the Sinkhorn-Knopp algorithm.")


+class DilocoParams(BaseModel):
+  """Diloco Hyperparameters"""
+
+  enable_diloco: bool = Field(False, description="Enable Diloco parallelism")
+  diloco_sync_period: int = Field(36, description="Diloco sync period.")
+  diloco_outer_lr: float = Field(0.3, description="learning rate for outer optimizer.")
+  diloco_outer_momentum: float = Field(0.9, description="momentum for outer optimizer.")
+
+
 class Optimizer(BaseModel):
   """Configuration for the optimizer and learning rate schedule."""

@@ -1633,6 +1644,11 @@ class DerivedValues(BaseModel):
       description="Effective number of query heads, scaled by `global_parameter_scale`.",
   )

+  num_diloco_replicas: None | int = Field(
+      None,
+      description="The number of diloco replicas, derived from ICI and DCN values.",
+  )
+
   ici_parallelism: None | list[int] = Field(
       None,
       description="Aggregated list of all ICI parallelism values for legacy compatibility.",

@@ -1780,6 +1796,7 @@ class MaxTextConfig(
     RematAndOffload,
     TrainingLoop,
     ManifoldConstrainedHyperConnections,
+    DilocoParams,
     Optimizer,
     AdamW,
     Muon,

@@ -2380,6 +2397,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
     # Create the ici_parallelism and dcn_parallelism lists for legacy compatibility.
     if self.using_pipeline_parallelism and self.mesh_axes and self.mesh_axes[0] == "stage":
       self.ici_parallelism = [
+          self.ici_diloco_parallelism,
           self.ici_pipeline_parallelism,
           self.ici_data_parallelism,
           self.ici_fsdp_parallelism,

@@ -2394,6 +2412,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
           self.ici_autoregressive_parallelism,
       ]
       self.dcn_parallelism = [
+          self.dcn_diloco_parallelism,
           self.dcn_pipeline_parallelism,
           self.dcn_data_parallelism,
           self.dcn_fsdp_parallelism,

@@ -2409,6 +2428,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
       ]
     else:
       ici_map = {
+          "diloco": self.ici_diloco_parallelism,
           "data": self.ici_data_parallelism,
           "stage": self.ici_pipeline_parallelism,
           "fsdp": self.ici_fsdp_parallelism,

@@ -2427,6 +2447,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
       self.ici_parallelism = [ici_map[axis] for axis in self.mesh_axes]

       dcn_map = {
+          "diloco": self.dcn_diloco_parallelism,
           "data": self.dcn_data_parallelism,
           "stage": self.dcn_pipeline_parallelism,
           "fsdp": self.dcn_fsdp_parallelism,

@@ -2444,6 +2465,9 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
       }
       self.dcn_parallelism = [dcn_map[axis] for axis in self.mesh_axes]

+    # Diloco params
+    self.num_diloco_replicas = int(self.ici_diloco_parallelism * self.dcn_diloco_parallelism)
+
     # Final string-to-enum conversions if they haven't been coerced by pydantic yet.
     if isinstance(self.decoder_block, str):
       self.decoder_block = DecoderBlockType(self.decoder_block.lower())
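
`DilocoParams` reaches the runtime config the same way the other parameter groups do: `MaxTextConfig` lists it as a base class and pydantic merges every group's fields into one model. A stripped-down illustration of that mixin pattern — only `DilocoParams` mirrors the real code; the other class names and fields are invented for the example:

# Minimal pydantic sketch of the config mixin pattern used by MaxTextConfig;
# OtherGroup and TinyConfig are invented names for illustration.
from pydantic import BaseModel, Field

class DilocoParams(BaseModel):
  enable_diloco: bool = Field(False, description="Enable Diloco parallelism")
  diloco_sync_period: int = Field(36, description="Diloco sync period.")

class OtherGroup(BaseModel):
  steps: int = Field(1000, description="Total train steps (illustrative).")

class TinyConfig(OtherGroup, DilocoParams):
  """Fields from every base model are merged into the final config."""

cfg = TinyConfig(enable_diloco=True)
print(cfg.enable_diloco, cfg.diloco_sync_period, cfg.steps)  # True 36 1000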
