@@ -87,9 +87,7 @@ def get_first_step(model, state):
8787# -----------------------------------------------------------------------------
8888
8989
90- def loss_fn (
91- model , config , data , dropout_rng , params , sparsity_state = None , is_train = True
92- ):
90+ def loss_fn (model , config , data , dropout_rng , params , sparsity_state = None , is_train = True ):
9391 """loss_fn for both train and eval.
9492
9593 Args:
@@ -121,9 +119,7 @@ def loss_fn(
121119 # make its specific collection mutable so the MTPBlock can sow into it.
122120 if config .mtp_eval_target_module > 0 and not is_train :
123121 mutable_collections .append ("mtp_acceptance" )
124- sparsity_enabled = (
125- is_train and config .weight_sparsity_n and config .weight_sparsity_m
126- )
122+ sparsity_enabled = is_train and config .weight_sparsity_n and config .weight_sparsity_m
127123 if sparsity_enabled :
128124 mutable_collections .append ("batch_stats" )
129125 if isinstance (model , nn .Module ):
@@ -143,9 +139,7 @@ def loss_fn(
143139 data ["inputs_position" ],
144140 decoder_segment_ids = data ["inputs_segmentation" ],
145141 encoder_images = data ["images" ] if config .use_multimodal else None ,
146- encoder_image_masks = data ["image_masks" ]
147- if config .use_multimodal and "image_masks" in data
148- else None ,
142+ encoder_image_masks = data ["image_masks" ] if config .use_multimodal and "image_masks" in data else None ,
149143 enable_dropout = config .enable_dropout if is_train else False ,
150144 rngs = {"dropout" : rng1 , "params" : aqt_rng },
151145 mutable = mutable_collections ,
@@ -286,11 +280,7 @@ def loss_fn(
286280 "indexer_loss" : indexer_loss ,
287281 "moe_bias_updates" : moe_bias_updates ,
288282 "mtp_loss" : mtp_loss ,
289- "batch_stats" : (
290- intermediate_outputs .get ("batch_stats" , None )
291- if hasattr (intermediate_outputs , "get" )
292- else None
293- ),
283+ "batch_stats" : (intermediate_outputs .get ("batch_stats" , None ) if hasattr (intermediate_outputs , "get" ) else None ),
294284 }
295285 return loss , aux
296286
@@ -416,9 +406,7 @@ def move(path, value):
416406 if sparsity_enabled :
417407 full_grads = {"params" : grads }
418408 if sparsity_enabled and "batch_stats" in state .params :
419- batch_stats_grads = jax .tree_util .tree_map (
420- jnp .zeros_like , state .params .get ("batch_stats" , {})
421- )
409+ batch_stats_grads = jax .tree_util .tree_map (jnp .zeros_like , state .params .get ("batch_stats" , {}))
422410 full_grads ["batch_stats" ] = batch_stats_grads
423411 full_grads = max_utils .unbox_logicallypartioned (full_grads )
424412 else :
@@ -501,9 +489,7 @@ def eval_step(model, config, state, data, dropout_rng):
501489 batch_stats = state .params .get ("batch_stats" , {})
502490
503491 eval_loss_fn = functools .partial (_loss_fn , model , config , data , dropout_rng , is_train = False )
504- loss , aux = eval_loss_fn (
505- pure_params , * extra_dpo_args , sparsity_state = batch_stats
506- )
492+ loss , aux = eval_loss_fn (pure_params , * extra_dpo_args , sparsity_state = batch_stats )
507493
508494 mtp_acceptance_rate = 0.0
509495 if config .mtp_eval_target_module > 0 :
@@ -630,6 +616,8 @@ def train_loop(config, recorder, state=None):
630616 eval_step_count = 0
631617 # pylint: disable=not-callable
632618 for eval_batch in eval_data_iterator :
619+ # Shard input eval data
620+ eval_batch = jax .device_put (eval_batch , sharding .get_input_data_sharding (config , mesh ))
633621 if config .eval_steps > 0 and eval_step_count >= config .eval_steps :
634622 break
635623 with jax .set_mesh (mesh ), nn_partitioning .axis_rules (config .logical_axis_rules ):
0 commit comments