@@ -397,6 +397,9 @@ def default_search_config(self):
397397 "disabled_layers" : None ,
398398 "verbose" : is_master (),
399399 "checkpoint" : None ,
400+ "cost_model" : COST_MODEL_WEIGHT ,
401+ "cost" : {},
402+ "active_moe_expert_ratio" : None ,
400403 }
401404
402405 @property
@@ -405,6 +408,7 @@ def default_state_dict(self) -> SearchStateDict:
405408 return {
406409 "method" : self .method_name ,
407410 "cost_model" : "weight" ,
411+ "cost" : {},
408412 "active_moe_expert_ratio" : None ,
409413 "cost_denominator" : None ,
410414 "candidate_stats" : defaultdict (dict ),
@@ -608,25 +612,22 @@ def initialize_candidate_stats(self):
608612 if not isinstance (hparam , QuantRecipeHparam ):
609613 continue
610614
611- formats , scores , costs , active_costs = [], [], [], []
615+ formats , scores , costs = [], [], []
612616 prev_score = float ("inf" )
613617 for recipe in hparam .choices :
614618 formats .append (recipe )
615619
616620 score = hparam .get_score (recipe ) # type: ignore [arg-type]
617621 cost = hparam .get_cost (recipe ) # type: ignore [arg-type]
618- active_cost = hparam .get_cost (recipe , cost_weight = hparam .cost_weight ) # type: ignore [arg-type]
619622
620623 score = min (score , prev_score ) # TODO: Should we get rid of this?
621624 scores .append (score )
622625 costs .append (cost )
623- active_costs .append (active_cost )
624626 prev_score = score
625627
626628 self .candidate_stats [name ]["formats" ] = formats
627629 self .candidate_stats [name ]["scores" ] = scores
628630 self .candidate_stats [name ]["costs" ] = costs
629- self .candidate_stats [name ]["active_costs" ] = active_costs
630631 self .candidate_stats [name ]["module_names" ] = hparam .quant_module_names
631632 self .candidate_stats [name ]["cost_weight" ] = hparam .cost_weight
632633
@@ -674,6 +675,7 @@ def before_search(self):
674675 )
675676 self .method = self .method_name
676677 self .cost_model = self .config ["cost_model" ]
678+ self .cost = self .config ["cost" ]
677679 self .active_moe_expert_ratio = self .config ["active_moe_expert_ratio" ]
678680 self .cost_denominator = getattr (self , "cost_denominator" , None )
679681
@@ -1466,7 +1468,20 @@ def _resolve_best_recipe(search_state, constraints, verbose=False):
14661468
14671469 searcher .candidate_stats = candidate_stats
14681470 searcher .cost_model = search_state .get ("cost_model" , COST_MODEL_WEIGHT )
1469- searcher .config = {"cost_model" : searcher .cost_model }
1471+ searcher .cost = search_state .get ("cost" , {})
1472+ searcher .active_moe_expert_ratio = search_state .get ("active_moe_expert_ratio" )
1473+ if (
1474+ searcher .cost_model == COST_MODEL_ACTIVE_MOE
1475+ and not searcher .cost
1476+ and searcher .active_moe_expert_ratio is not None
1477+ ):
1478+ searcher .cost = {ACTIVE_MOE_EXPERT_RATIO_KEY : searcher .active_moe_expert_ratio }
1479+ searcher .config = {
1480+ ** searcher .default_search_config ,
1481+ "cost_model" : searcher .cost_model ,
1482+ "cost" : searcher .cost ,
1483+ "active_moe_expert_ratio" : searcher .active_moe_expert_ratio ,
1484+ }
14701485 best_recipe_info , _ = searcher .run_search_with_stats (max_weight_size , verbose = verbose )
14711486
14721487 best_recipe = {name : info ["format" ] for name , info in best_recipe_info .items ()}
0 commit comments