@@ -127,9 +127,12 @@ def loss_fn(model, config, data, dropout_rng, params, sparsity_state=None, is_tr
127127 rng1 , aqt_rng = jax .random .split (dropout_rng )
128128
129129 # Flax Linen model
130- model_vars = {"params" : params }
131- if sparsity_state :
132- model_vars ["batch_stats" ] = sparsity_state
130+ if sparsity_enabled :
131+ model_vars = {"params" : params }
132+ if sparsity_state :
133+ model_vars ["batch_stats" ] = sparsity_state
134+ else :
135+ model_vars = params
133136 logits , intermediate_outputs = model .apply (
134137 model_vars ,
135138 data ["inputs" ],
@@ -336,7 +339,8 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat
336339 params ,
337340 params_shardings ,
338341 )
339- pure_params = params ["params" ] if "params" in params else params
342+ sparsity_enabled = config .weight_sparsity_n and config .weight_sparsity_m
343+ pure_params = params ["params" ] if sparsity_enabled else params
340344 batch_stats = params .get ("batch_stats" , {})
341345
342346 grad_func = jax .value_and_grad (_loss_fn , argnums = 4 , has_aux = True )
@@ -424,9 +428,10 @@ def move(path, value):
424428 )
425429 )
426430 # Re-wrap grads to match state.params structure if it's a dict of collections
427- if isinstance (state .params , dict ) and "params" in state .params :
431+ sparsity_enabled = config .weight_sparsity_n and config .weight_sparsity_m
432+ if sparsity_enabled :
428433 full_grads = {"params" : grads }
429- if "batch_stats" in state .params :
434+ if sparsity_enabled and "batch_stats" in state .params :
430435 batch_stats_grads = jax .tree_util .tree_map (jnp .zeros_like , state .params .get ("batch_stats" , {}))
431436 full_grads ["batch_stats" ] = batch_stats_grads
432437 full_grads = max_utils .unbox_logicallypartioned (full_grads )
@@ -521,7 +526,8 @@ def eval_step(model, config, state, data, dropout_rng):
521526 extra_dpo_args = [reference_params ]
522527 _loss_fn = dpo_loss_fn
523528
524- pure_params = state .params ["params" ] if "params" in state .params else state .params
529+ sparsity_enabled = config .weight_sparsity_n and config .weight_sparsity_m
530+ pure_params = state .params ["params" ] if sparsity_enabled else state .params
525531 batch_stats = state .params .get ("batch_stats" , {})
526532
527533 eval_loss_fn = functools .partial (_loss_fn , model , config , data , dropout_rng , is_train = False )
0 commit comments