@@ -127,14 +127,9 @@ def loss_fn(model, config, data, dropout_rng, params, sparsity_state=None, is_tr
127127 rng1 , aqt_rng = jax .random .split (dropout_rng )
128128
129129 # Flax Linen model
130- if sparsity_enabled :
131- model_vars = {"params" : params }
132- else :
133- model_vars = params
134-
135- if sparsity_state and sparsity_enabled :
130+ model_vars = {"params" : params }
131+ if sparsity_state :
136132 model_vars ["batch_stats" ] = sparsity_state
137-
138133 logits , intermediate_outputs = model .apply (
139134 model_vars ,
140135 data ["inputs" ],
@@ -341,16 +336,20 @@ def train_step(model, config, state_mesh_shardings, params_shardings, state, dat
341336 params ,
342337 params_shardings ,
343338 )
344- sparsity_enabled = config .weight_sparsity_n and config .weight_sparsity_m
345- pure_params = params ["params" ] if sparsity_enabled else params
339+ pure_params = params ["params" ] if "params" in params else params
346340 batch_stats = params .get ("batch_stats" , {})
347341
348342 grad_func = jax .value_and_grad (_loss_fn , argnums = 4 , has_aux = True )
349-
350- kwargs = {"is_train" : True }
351- if sparsity_enabled :
352- kwargs ["sparsity_state" ] = batch_stats
353- (loss , aux ), raw_grads = grad_func (model , config , data , dropout_rng , pure_params , * extra_dpo_args , kwargs )
343+ (loss , aux ), raw_grads = grad_func (
344+ model ,
345+ config ,
346+ data ,
347+ dropout_rng ,
348+ pure_params ,
349+ * extra_dpo_args ,
350+ sparsity_state = batch_stats ,
351+ is_train = True ,
352+ )
354353
355354 raw_grads = jax .tree_util .tree_map (
356355 lambda x : x .astype (config .grad_dtype ) if x .dtype == jnp .float32 else x ,
@@ -425,10 +424,9 @@ def move(path, value):
425424 )
426425 )
427426 # Re-wrap grads to match state.params structure if it's a dict of collections
428- sparsity_enabled = config .weight_sparsity_n and config .weight_sparsity_m
429- if sparsity_enabled :
427+ if isinstance (state .params , dict ) and "params" in state .params :
430428 full_grads = {"params" : grads }
431- if sparsity_enabled and "batch_stats" in state .params :
429+ if "batch_stats" in state .params :
432430 batch_stats_grads = jax .tree_util .tree_map (jnp .zeros_like , state .params .get ("batch_stats" , {}))
433431 full_grads ["batch_stats" ] = batch_stats_grads
434432 full_grads = max_utils .unbox_logicallypartioned (full_grads )
@@ -461,7 +459,6 @@ def move(path, value):
461459 and "batch_stats" in state .params
462460 )
463461
464- jax .debug .print ("amanda has_batch_stats: {s}" , s = has_batch_stats )
465462 if has_batch_stats :
466463 new_params = dict (new_state .params )
467464 new_params ["batch_stats" ] = max_utils .unbox_logicallypartioned (aux ["batch_stats" ])
@@ -524,15 +521,11 @@ def eval_step(model, config, state, data, dropout_rng):
524521 extra_dpo_args = [reference_params ]
525522 _loss_fn = dpo_loss_fn
526523
527- sparsity_enabled = config .weight_sparsity_n and config .weight_sparsity_m
528- pure_params = state .params ["params" ] if sparsity_enabled else state .params
524+ pure_params = state .params ["params" ] if "params" in state .params else state .params
529525 batch_stats = state .params .get ("batch_stats" , {})
530526
531527 eval_loss_fn = functools .partial (_loss_fn , model , config , data , dropout_rng , is_train = False )
532- kwargs = {}
533- if sparsity_enabled :
534- kwargs ["sparsity_state" ] = batch_stats
535- loss , aux = eval_loss_fn (pure_params , * extra_dpo_args , ** kwargs )
528+ loss , aux = eval_loss_fn (pure_params , * extra_dpo_args , sparsity_state = batch_stats )
536529
537530 mtp_acceptance_rate = 0.0
538531 if config .mtp_eval_target_module > 0 :
0 commit comments