From 87d2a5152f538a0ab2aa6072add8f691dd464c07 Mon Sep 17 00:00:00 2001 From: Chris Zuo Date: Fri, 5 Jun 2026 22:07:26 +0000 Subject: [PATCH] : add scripts to run vanilla diloco on v5p cluster : script to run diloco update. --- scripts/diloco/run_diloco.sh | 82 ++++++++++++ src/maxtext/trainers/diloco/diloco.py | 95 +++++++++++--- src/maxtext/trainers/pre_train/train.py | 1 + .../trainers/pre_train/train_compile.py | 120 ++++++------------ src/maxtext/utils/train_utils.py | 24 +++- 5 files changed, 220 insertions(+), 102 deletions(-) create mode 100755 scripts/diloco/run_diloco.sh diff --git a/scripts/diloco/run_diloco.sh b/scripts/diloco/run_diloco.sh new file mode 100755 index 0000000000..63c4c29e30 --- /dev/null +++ b/scripts/diloco/run_diloco.sh @@ -0,0 +1,82 @@ +#!/bin/bash + +# This script launches a DiLoCo pre-training workload on a GKE cluster using XPK. + +set -e + +# --- Environment Setup --- +if ! pip show xpk &> /dev/null; then + echo "xpk not found in the environment. Please install it by running:" + echo "uv pip install -e .[runner] --resolution=lowest" + exit 1 +fi + +# --- Environment Variables --- +export PROJECT_ID="${PROJECT_ID:-cloud-tpu-multipod-dev}" +export CLUSTER_NAME="${CLUSTER_NAME:-auto-v5p-8-bodaborg}" +export ZONE="${ZONE:-europe-west4-b}" +export RESERVATION="${RESERVATION:-}" +export BASE_OUTPUT_DIRECTORY="${BASE_OUTPUT_DIRECTORY:-gs://chriszuo-maxtext-logs}" # change to your own GCS bucket for logging and checkpointing +export DATASET_PATH="${DATASET_PATH:-gs://chriszuo-maxtext-datasets}" # change to your own GCS bucket for datasets. Make sure the datasets exist +export DOCKER_IMAGE="${DOCKER_IMAGE:-gcr.io/tpu-prod-env-multipod/maxtext_jax_stable:2026-06-04}" # should update if later versions come up +export TPU_TYPE="${TPU_TYPE:-v5p-8}" # At least v5p-32 is needed to run Qwen3-30b-a3b. 
For v5p-8 you may need to decrease the PER_DEVICE_BATCH_SIZE +export NUM_SLICES="${NUM_SLICES:-2}" # you need at least two slices to let diloco take effect +export WORKLOAD_NAME="${WORKLOAD_NAME:-$(whoami)-diloco-v5p-$(date +%Y%m%d-%H%M%S)}" # this will be the name of run, for logging purposes + +# --- Hyperparameters --- +export MODEL_NAME="${MODEL_NAME:-qwen3-8b}" +export PER_DEVICE_BATCH_SIZE="${PER_DEVICE_BATCH_SIZE:-2}" +export MAX_TARGET_LENGTH="${MAX_TARGET_LENGTH:-2048}" +export DILOCO_SYNC_PERIOD="${DILOCO_SYNC_PERIOD:-10}" +export DILOCO_OUTER_LR="${DILOCO_OUTER_LR:-0.3}" +export DILOCO_OUTER_MOMENTUM="${DILOCO_OUTER_MOMENTUM:-0.9}" +export TRAINING_STEPS="${TRAINING_STEPS:-20}" + +# --- Variable Validation --- +if [ -z "$PROJECT_ID" ] || [ -z "$CLUSTER_NAME" ] || [ -z "$ZONE" ]; then + echo "Error: PROJECT_ID, CLUSTER_NAME, or ZONE is not set." + exit 1 +fi + +if [ -z "$BASE_OUTPUT_DIRECTORY" ] || [ -z "$DATASET_PATH" ]; then + echo "Error: BASE_OUTPUT_DIRECTORY or DATASET_PATH is not set." + exit 1 +fi + +if [ "$NUM_SLICES" -lt 2 ]; then + echo "Warning: NUM_SLICES is less than 2. DiLoCo will not take effect." +fi + +# MaxText command +MAXTEXT_COMMAND="cd /deps/src/ && python3 maxtext/trainers/pre_train/train.py \ +maxtext/configs/base.yml \ +run_name=$WORKLOAD_NAME \ +save_config_to_gcs=true \ +base_output_directory=$BASE_OUTPUT_DIRECTORY \ +dataset_path=$DATASET_PATH \ +dataset_name='c4/en:3.0.1' \ +eval_dataset_name='c4/en:3.0.1' \ +model_name=$MODEL_NAME \ +tokenizer_type=huggingface \ +tokenizer_path=maxtext/assets/tokenizers/qwen3-tokenizer \ +per_device_batch_size=$PER_DEVICE_BATCH_SIZE \ +max_target_length=$MAX_TARGET_LENGTH \ +enable_diloco=true \ +dcn_diloco_parallelism=$NUM_SLICES \ +diloco_sync_period=$DILOCO_SYNC_PERIOD \ +diloco_outer_lr=$DILOCO_OUTER_LR \ +diloco_outer_momentum=$DILOCO_OUTER_MOMENTUM \ +steps=$TRAINING_STEPS" + +# Workload Creation +echo "Submitting DiLoCo job to XPK..." 
+xpk workload create \ + --cluster="$CLUSTER_NAME" \ + --project="$PROJECT_ID" \ + --reservation="$RESERVATION" \ + --zone="$ZONE" \ + --tpu-type="$TPU_TYPE" \ + --num-slices="$NUM_SLICES" \ + --docker-image="${DOCKER_IMAGE}" \ + --workload="${WORKLOAD_NAME}" \ + --command="${MAXTEXT_COMMAND}" diff --git a/src/maxtext/trainers/diloco/diloco.py b/src/maxtext/trainers/diloco/diloco.py index ef650b872e..e3d7b24581 100644 --- a/src/maxtext/trainers/diloco/diloco.py +++ b/src/maxtext/trainers/diloco/diloco.py @@ -205,34 +205,58 @@ def build_diloco_state( nesterov=True, ) + state = initialize_state() + @drjax.program(placements={"diloco": config.num_diloco_replicas}) - def init_diloco_state() -> tuple[DiLoCoTrainState, PyTree]: - state = initialize_state() + def init_inner_state() -> Any: # Inner state must be broadcast across clients. # Pass mesh explicitly because jax.set_mesh() uses a different thread-local # than pxla.thread_resources (which drjax reads), so drjax cannot find the # mesh automatically when jax.set_mesh is used. - inner_state = drjax.broadcast(state, mesh=mesh) - # Outer state retains a single copy of the model parameters and optimizer state. - # For NNX, model params (Param variables only) live under state.model; - # for Linen under state.params. - outer_params = state.model.filter(nnx.Param) if config.pure_nnx else state.params - outer_opt_state = outer_optimizer.init(outer_params) - outer_opt_state_sharding = jax.tree_util.tree_map(lambda x: x.sharding, outer_opt_state) - # For NNX, the step counter lives at state.optimizer.step; for Linen at state.step. - step = state.optimizer.step if config.pure_nnx else state.step + return drjax.broadcast(state, mesh=mesh) + + inner_state = init_inner_state() + + # Outer state retains a single copy of the model parameters and optimizer state. + # For NNX, model params (Param variables only) live under state.model; + # for Linen under state.params. 
+ outer_params = state.model.filter(nnx.Param) if config.pure_nnx else state.params + + # For NNX, the step counter lives at state.optimizer.step; for Linen at state.step. + step = state.optimizer.step if config.pure_nnx else state.step + + # Drop the reference to `state` so that JAX's asynchronous garbage collector + # can naturally free the inner optimizer arrays from the original non-DiLoCo state. + # This prevents the TPU from holding both copies in memory while allocating outer_opt_state. + del state + + outer_opt_state_sharding = ( + optax.TraceState(trace=jax.tree_util.tree_map(lambda x: getattr(x, "sharding", None), outer_params)), + optax.EmptyState(), + ) + + # Initialize outer_opt_state using jax.jit with explicit out_shardings to avoid + # creating unsharded zeros eagerly or relying on tracers' sharding inside drjax.program. + def init_outer_opt_state(): return ( - DiLoCoTrainState(inner_state=inner_state, params=outer_params, outer_opt_state=outer_opt_state, step=step), - outer_opt_state_sharding, + optax.TraceState(trace=jax.tree_util.tree_map(lambda p: jnp.zeros(p.shape, p.dtype), outer_params)), + optax.EmptyState(), ) - return init_diloco_state() + outer_opt_state = jax.jit(init_outer_opt_state, out_shardings=outer_opt_state_sharding)() + return ( + DiLoCoTrainState(inner_state=inner_state, params=outer_params, outer_opt_state=outer_opt_state, step=step), + outer_opt_state_sharding, + ) def build_diloco_train_step( config: pyconfig.HyperParameters, train_step: Callable[[Any, Batch, PRNGKey], tuple[Any, Metrics]], mesh: jax.sharding.Mesh | None = None, + outer_params_shardings: PyTree | None = None, + inner_model_params_shardings: PyTree | None = None, + outer_opt_state_shardings: PyTree | None = None, ) -> Callable[[DiLoCoTrainState, Batch, PRNGKey], tuple[DiLoCoTrainState, Metrics]]: """Convert a local state and train step into DiLoCo-compatible versions. 
@@ -256,17 +280,58 @@ def build_diloco_train_step( def synchronize(state): # Calculate the delta between the current replica's state and the global # state (since last synchronization). - broadcast_outer_params = drjax.broadcast(state.params, mesh=mesh) # For NNX, model Param vars live under inner_state.model; for Linen under inner_state.params. inner_model_params = ( nnx.filter_state(state.inner_state.model, nnx.Param) if config.pure_nnx else state.inner_state.params ) + + # Helper to enforce the correct FSDP/Tensor sharding so XLA doesn't gather the full tensor + # and OOM the TPU during the megascale All-Reduce over the DCN. + def _apply_sharding(reference, target, explicit_sharding=None): + if explicit_sharding is not None: + return jax.lax.with_sharding_constraint(target, explicit_sharding) + sharding = getattr(reference, "sharding", None) + if sharding is not None: + return jax.lax.with_sharding_constraint(target, sharding) + return target + + _inner_shardings = ( + inner_model_params_shardings + if inner_model_params_shardings is not None + else jax.tree.map(lambda _: None, inner_model_params) + ) + _outer_shardings = ( + outer_params_shardings if outer_params_shardings is not None else jax.tree.map(lambda _: None, state.params) + ) + _opt_shardings = ( + outer_opt_state_shardings + if outer_opt_state_shardings is not None + else jax.tree.map(lambda _: None, state.outer_opt_state) + ) + + broadcast_outer_params = drjax.broadcast(state.params, mesh=mesh) + broadcast_outer_params = jax.tree.map(_apply_sharding, inner_model_params, broadcast_outer_params, _inner_shardings) + model_delta = jax.tree.map(lambda x, y: y - x, inner_model_params, broadcast_outer_params) + model_delta = jax.tree.map(_apply_sharding, inner_model_params, model_delta, _inner_shardings) + # Treat the average delta as the outer optimizer's gradient and apply to # the global (outer) model params. 
averaged_pseudo_grad = drjax.reduce_mean(model_delta) + averaged_pseudo_grad = jax.tree.map(_apply_sharding, state.params, averaged_pseudo_grad, _outer_shardings) + updates, new_opt_state = outer_optimizer.update(averaged_pseudo_grad, state.outer_opt_state, state.params) new_outer_params = optax.apply_updates(state.params, updates) + + # Cast back to original dtype to prevent silent promotion to f32 (which doubles memory) + # and enforce sharding on the new params. + def _cast_and_shard(reference, target, explicit_sharding=None): + target = target.astype(reference.dtype) + return _apply_sharding(reference, target, explicit_sharding) + + new_outer_params = jax.tree.map(_cast_and_shard, state.params, new_outer_params, _outer_shardings) + new_opt_state = jax.tree.map(_cast_and_shard, state.outer_opt_state, new_opt_state, _opt_shardings) + # Replace inner model params with the new global model params. # NOTE: inner optimizer state is retained despite the change in parameters, # see section 6.1 in https://arxiv.org/pdf/2311.08105. 
diff --git a/src/maxtext/trainers/pre_train/train.py b/src/maxtext/trainers/pre_train/train.py index 9a13849248..ae403f7fc4 100644 --- a/src/maxtext/trainers/pre_train/train.py +++ b/src/maxtext/trainers/pre_train/train.py @@ -708,6 +708,7 @@ def train_loop(config, recorder, state=None): max_utils.print_mem_stats("After params initialized") last_step_completion = datetime.datetime.now() + # NOTE(review): metric logging is still enabled here; the call below may perform a cross-slice read (confirm whether it should be gated) metric_logger_instance.buffer_and_write_metrics(metrics, step, step_time_delta) if config.save_checkpoint_on_completion: diff --git a/src/maxtext/trainers/pre_train/train_compile.py b/src/maxtext/trainers/pre_train/train_compile.py index 471abac3f0..0e9e4cc879 100644 --- a/src/maxtext/trainers/pre_train/train_compile.py +++ b/src/maxtext/trainers/pre_train/train_compile.py @@ -61,12 +61,9 @@ def validate_config(config): """Validates the config is is setup correctly to compile, returning a useful error message if not.""" assert config.compile_topology != "", ( - "You must pass your desired target hardware in compile_topology, e.g." - " compile_topology=v5e-256" + "You must pass your desired target hardware in compile_topology, e.g." " compile_topology=v5e-256" ) - assert ( - config.compile_topology_num_slices > 0 - ), "You must set compile_topology_num_slices to a positive integer" + assert config.compile_topology_num_slices > 0, "You must set compile_topology_num_slices to a positive integer" def get_topology_mesh(config): @@ -78,18 +75,12 @@ def get_topology_mesh(config): num_slices=config.compile_topology_num_slices, ).devices else: - target_hardware = accelerator_to_spec_map.get_system_characteristics( - config.compile_topology - ) + target_hardware = accelerator_to_spec_map.get_system_characteristics(config.compile_topology) if target_hardware.platform == "gpu": # Disable sharded autotuning. This is an optimization to distribute # autotuning across the fleet, but can cause hangs with AoT compilation. 
- os.environ["XLA_FLAGS"] = ( - os.environ.get("XLA_FLAGS", "") + " --xla_gpu_shard_autotuning=false" - ) - jax.config.update( - "mock_num_gpu_processes", config.compile_topology_num_slices - ) + os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_gpu_shard_autotuning=false" + jax.config.update("mock_num_gpu_processes", config.compile_topology_num_slices) topology_devices = jax.devices() else: topology_devices = get_topology_desc( @@ -104,14 +95,8 @@ def get_topology_mesh(config): "jax_remove_size_one_mesh_axis_from_type", config.remove_size_one_mesh_axis_from_type, ) - topology_device_mesh = maxtext_utils.create_device_mesh( - config, topology_devices - ) - mesh_axis_type = ( - AxisType.Explicit - if config.shard_mode == ShardMode.EXPLICIT - else AxisType.Auto - ) + topology_device_mesh = maxtext_utils.create_device_mesh(config, topology_devices) + mesh_axis_type = AxisType.Explicit if config.shard_mode == ShardMode.EXPLICIT else AxisType.Auto topology_mesh = Mesh( topology_device_mesh, config.mesh_axes, @@ -129,9 +114,7 @@ def _collect_nnx_activation_shardings(create_model_fn, config, mesh): input_shape = (config.micro_batch_size_to_train_on, config.max_target_length) abstract_input = jax.ShapeDtypeStruct(input_shape, jnp.int32) - def _nnx_forward( - decoder_input_tokens, decoder_positions, decoder_segment_ids - ): + def _nnx_forward(decoder_input_tokens, decoder_positions, decoder_segment_ids): model_instance = create_model_fn() return model_instance( decoder_input_tokens=decoder_input_tokens, @@ -140,9 +123,7 @@ def _nnx_forward( enable_dropout=False, ) - with jax.set_mesh(mesh), nn_partitioning.axis_rules( - config.logical_axis_rules - ): + with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): jax.eval_shape(_nnx_forward, abstract_input, abstract_input, abstract_input) @@ -151,13 +132,9 @@ def get_shaped_inputs(topology_mesh, config): # Construct the model and optimizer to get shaped versions of the state quant = 
quantizations.configure_quantization(config) if config.pure_nnx: - _create_model_partial, model = ( - model_creation_utils.create_nnx_abstract_model(config, topology_mesh) - ) + _create_model_partial, model = model_creation_utils.create_nnx_abstract_model(config, topology_mesh) else: - model = Transformer( - config, topology_mesh, quant=quant, model_mode=MODEL_MODE_TRAIN - ) + model = Transformer(config, topology_mesh, quant=quant, model_mode=MODEL_MODE_TRAIN) # The learning_rate_schedule is baked into the compiled object. learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config) # pass in model for muon @@ -176,20 +153,14 @@ def create_train_state_fn(): init_state_fn = create_train_state_fn else: - init_state_fn = functools.partial( - maxtext_utils.init_initial_state, model, tx, config, True, example_rng - ) + init_state_fn = functools.partial(maxtext_utils.init_initial_state, model, tx, config, True, example_rng) # Shaped state - abstract_state, _, state_mesh_shardings = maxtext_utils.get_abstract_state( - config, topology_mesh, init_state_fn, True - ) + abstract_state, _, state_mesh_shardings = maxtext_utils.get_abstract_state(config, topology_mesh, init_state_fn, True) if config.pure_nnx: # NNX doesn't use Linen logical annotations; derive PartitionSpecs from the physical shardings. - logical_annotations = maxtext_utils_nnx.get_partition_spec_nnx( - state_mesh_shardings - ) + logical_annotations = maxtext_utils_nnx.get_partition_spec_nnx(state_mesh_shardings) # For NNX, get_functional_train_with_signature expects the graphdef (static structure), # not the raw model — mirroring how the training loop does nnx.split(train_state). 
with nn_partitioning.axis_rules(config.logical_axis_rules): @@ -198,9 +169,7 @@ def create_train_state_fn(): model = graphdef else: # unsharded logical annotations - logical_annotations = maxtext_utils.get_logical_annotations( - config, topology_mesh, init_state_fn - ) + logical_annotations = maxtext_utils.get_logical_annotations(config, topology_mesh, init_state_fn) # Shaped batch shaped_batch = maxtext_utils.get_shaped_batch(config) @@ -217,9 +186,7 @@ def create_train_state_fn(): # Collect NNX activation shardings via an abstract forward pass (must run # after get_abstract_state, which only traces __init__). if config.debug_sharding and config.pure_nnx: - _collect_nnx_activation_shardings( - _create_model_partial, config, topology_mesh - ) + _collect_nnx_activation_shardings(_create_model_partial, config, topology_mesh) return ( shaped_train_args, @@ -256,9 +223,7 @@ def jit_and_compile( maxtext_utils.maybe_dump_jaxpr(config, jitted, func_input_args) lowered = jitted.lower(*func_input_args, **func_input_kwargs) # Import libtpu flags as compiler options. Defaults to empty dict if string is empty. 
- compiler_options = max_utils.parse_libtpu_flags_to_dict( - config.compile_xla_flags - ) + compiler_options = max_utils.parse_libtpu_flags_to_dict(config.compile_xla_flags) compiled = lowered.compile(compiler_options=compiler_options) return compiled @@ -293,18 +258,12 @@ def is_oom(argv: Sequence[str]) -> bool: ) = get_shaped_inputs(topology_mesh, config) # Update params_shardings when shard_optimizer_over_data is enabled (Zero-1) - params_shardings, state_mesh_shardings = ( - sharding.maybe_update_params_sharding_with_opt( - config, state_mesh_shardings - ) - ) + params_shardings, state_mesh_shardings = sharding.maybe_update_params_sharding_with_opt(config, state_mesh_shardings) # When ZeRO-1 is enabled, we need to use the original params_shardings for input shardings # but keep the updated state_mesh_shardings for the optimizer state if config.shard_optimizer_over_data: - input_state_mesh_shardings = state_mesh_shardings.replace( - params=params_shardings - ) + input_state_mesh_shardings = state_mesh_shardings.replace(params=params_shardings) else: input_state_mesh_shardings = state_mesh_shardings @@ -355,8 +314,7 @@ def is_oom(argv: Sequence[str]) -> bool: def main(argv: Sequence[str]) -> None: jax.config.update("jax_default_prng_impl", "unsafe_rbg") os.environ["LIBTPU_INIT_ARGS"] = ( - os.environ.get("LIBTPU_INIT_ARGS", "") - + " --xla_tpu_spmd_rng_bit_generator_unsafe=true" + os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true" ) print("Starting train_compile.py...", flush=True) @@ -381,18 +339,12 @@ def main(argv: Sequence[str]) -> None: ) = get_shaped_inputs(topology_mesh, config) # Update params_shardings when shard_optimizer_over_data is enabled (Zero-1) - params_shardings, state_mesh_shardings = ( - sharding.maybe_update_params_sharding_with_opt( - config, state_mesh_shardings - ) - ) + params_shardings, state_mesh_shardings = sharding.maybe_update_params_sharding_with_opt(config, state_mesh_shardings) # When ZeRO-1 
is enabled, we need to use the original params_shardings for input shardings # but keep the updated state_mesh_shardings for the optimizer state if config.shard_optimizer_over_data: - input_state_mesh_shardings = state_mesh_shardings.replace( - params=params_shardings - ) + input_state_mesh_shardings = state_mesh_shardings.replace(params=params_shardings) else: input_state_mesh_shardings = state_mesh_shardings @@ -401,22 +353,25 @@ def main(argv: Sequence[str]) -> None: if config.enable_diloco: # Build abstract DiLoCo state and shardings for AOT compilation abstract_state = shaped_train_args[0] - diloco_state, state_mesh_shardings, inner_state_shardings = ( - diloco.build_abstract_diloco_state( - config, abstract_state, state_mesh_shardings, topology_mesh - ) + diloco_state, state_mesh_shardings, inner_state_shardings = diloco.build_abstract_diloco_state( + config, abstract_state, state_mesh_shardings, topology_mesh ) # For NNX, shaped_train_args has 2 elements (state, batch) — no rng; pass None for prng. 
- shaped_rng_arg = ( - shaped_train_args[2] if len(shaped_train_args) > 2 else None - ) + shaped_rng_arg = shaped_train_args[2] if len(shaped_train_args) > 2 else None shaped_train_args = (diloco_state, shaped_train_args[1], shaped_rng_arg) # Wrap train_step with diloco - train_step_partial = functools.partial( - train.train_step, model, config, inner_state_shardings, params_shardings + train_step_partial = functools.partial(train.train_step, model, config, inner_state_shardings, params_shardings) + train_step_fn = diloco.build_diloco_train_step( + config, + train_step_partial, + mesh=topology_mesh, + outer_params_shardings=params_shardings, + inner_model_params_shardings=inner_state_shardings.params + if not config.pure_nnx + else nnx.filter_state(inner_state_shardings.model, nnx.Param), + outer_opt_state_shardings=state_mesh_shardings.outer_opt_state, ) - train_step_fn = diloco.build_diloco_train_step(config, train_step_partial) # For DiLoCo, the train_step_fn is already fully wrapped and takes (state, batch, prng) func_to_compile = train_step_fn @@ -480,10 +435,7 @@ def main(argv: Sequence[str]) -> None: if config.compiled_trainstep_file != "": print("Saving compiled object...") save_compiled(compiled, config.compiled_trainstep_file) - print( - "Successfully saved compiled object as" - f" {config.compiled_trainstep_file}" - ) + print("Successfully saved compiled object as" f" {config.compiled_trainstep_file}") print("Finished train_compile.py successfully!", flush=True) print(f"Cost analysis: {compiled.cost_analysis()}") print(f"Memory analysis: {compiled.memory_analysis()}") diff --git a/src/maxtext/utils/train_utils.py b/src/maxtext/utils/train_utils.py index 12328bce35..1d00e7a5a8 100644 --- a/src/maxtext/utils/train_utils.py +++ b/src/maxtext/utils/train_utils.py @@ -174,8 +174,19 @@ def jit_train_and_eval_step( ): """Returns a JIT-compiled train and eval step function.""" if config.enable_diloco: - train_step_partial = functools.partial(train_step, model, 
config, state_mesh_shardings, params_shardings) - train_step = diloco.build_diloco_train_step(config, train_step_partial, mesh=mesh) + train_step_partial = functools.partial(train_step, model, config, state_mesh_shardings.inner_state, params_shardings) + train_step = diloco.build_diloco_train_step( + config, + train_step_partial, + mesh=mesh, + outer_params_shardings=state_mesh_shardings.params + if not config.pure_nnx + else nnx.filter_state(state_mesh_shardings.model, nnx.Param), + inner_model_params_shardings=state_mesh_shardings.inner_state.params + if not config.pure_nnx + else nnx.filter_state(state_mesh_shardings.inner_state.model, nnx.Param), + outer_opt_state_shardings=state_mesh_shardings.outer_opt_state, + ) data_sharding = sharding.get_input_data_sharding(config, mesh) p_train_step = jit_train_step( config, model, state, state_mesh_shardings, data_sharding, train_step, params_shardings, mesh=mesh @@ -288,7 +299,14 @@ def create_train_state_fn(): if config.enable_diloco: with jax.set_mesh(mesh), nn_partitioning.axis_rules(config.logical_axis_rules): - state, outer_opt_state_sharding = diloco.build_diloco_state(config, lambda: state, mesh=mesh) + + def _get_and_clear_state(): + nonlocal state + s = state + state = None + return s + + state, outer_opt_state_sharding = diloco.build_diloco_state(config, _get_and_clear_state, mesh=mesh) # create state_mesh_shardings for the DilocoState inner_state_shardings = diloco.add_diloco_to_sharding(state_mesh_shardings)