Skip to content

Commit f07e5ae

Browse files
committed
refactor again
1 parent 93b53b3 commit f07e5ae

1 file changed

Lines changed: 13 additions & 7 deletions

File tree

src/maxtext/trainers/pre_train/train.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,12 @@ def loss_fn(model, config, data, dropout_rng, params, sparsity_state=None, is_tr
127127
rng1, aqt_rng = jax.random.split(dropout_rng)
128128

129129
# Flax Linen model
130-
model_vars = {"params": params}
131-
if sparsity_state:
132-
model_vars["batch_stats"] = sparsity_state
130+
if sparsity_enabled:
131+
model_vars = {"params": params}
132+
if sparsity_state:
133+
model_vars["batch_stats"] = sparsity_state
134+
else:
135+
model_vars = params
133136
logits, intermediate_outputs = model.apply(
134137
model_vars,
135138
data["inputs"],
@@ -336,7 +339,8 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat
336339
params,
337340
params_shardings,
338341
)
339-
pure_params = params["params"] if "params" in params else params
342+
sparsity_enabled = config.weight_sparsity_n and config.weight_sparsity_m
343+
pure_params = params["params"] if sparsity_enabled else params
340344
batch_stats = params.get("batch_stats", {})
341345

342346
grad_func = jax.value_and_grad(_loss_fn, argnums=4, has_aux=True)
@@ -424,9 +428,10 @@ def move(path, value):
424428
)
425429
)
426430
# Re-wrap grads to match state.params structure if it's a dict of collections
427-
if isinstance(state.params, dict) and "params" in state.params:
431+
sparsity_enabled = config.weight_sparsity_n and config.weight_sparsity_m
432+
if sparsity_enabled:
428433
full_grads = {"params": grads}
429-
if "batch_stats" in state.params:
434+
if sparsity_enabled and "batch_stats" in state.params:
430435
batch_stats_grads = jax.tree_util.tree_map(jnp.zeros_like, state.params.get("batch_stats", {}))
431436
full_grads["batch_stats"] = batch_stats_grads
432437
full_grads = max_utils.unbox_logicallypartioned(full_grads)
@@ -521,7 +526,8 @@ def eval_step(model, config, state, data, dropout_rng):
521526
extra_dpo_args = [reference_params]
522527
_loss_fn = dpo_loss_fn
523528

524-
pure_params = state.params["params"] if "params" in state.params else state.params
529+
sparsity_enabled = config.weight_sparsity_n and config.weight_sparsity_m
530+
pure_params = state.params["params"] if sparsity_enabled else state.params
525531
batch_stats = state.params.get("batch_stats", {})
526532

527533
eval_loss_fn = functools.partial(_loss_fn, model, config, data, dropout_rng, is_train=False)

0 commit comments

Comments
 (0)