@@ -458,7 +458,7 @@ def init_workers(self):
458458
459459 def _update_interchange_server_status_flag (self , status : str ):
460460 if self .config .ajet .enable_experimental_interchange_server :
461- if self .config .ajet .enable_tinkerscript_mode :
461+ if self .config .ajet .enable_swarm_mode :
462462 from ajet .tuner_lib .weight_tuner .experimental .interchange_utils import http_change_engine_status
463463 http_change_engine_status (self .config , status )
464464
@@ -493,7 +493,7 @@ def fit(self): # noqa: C901
493493
494494 # perform validation before training
495495 # currently, we only support validation using the reward_function.
496- if (self .val_reward_fn is not None ) and (self .config .trainer .get ("val_before_train" , True )) and (not self .config .ajet .enable_tinkerscript_mode ):
496+ if (self .val_reward_fn is not None ) and (self .config .trainer .get ("val_before_train" , True )) and (not self .config .ajet .enable_swarm_mode ):
497497 val_metrics = self ._validate ()
498498 assert val_metrics , f"{ val_metrics = } "
499499 pprint (f"Initial validation metrics: { val_metrics } " )
@@ -651,7 +651,7 @@ def fit(self): # noqa: C901
651651 [str (uuid .uuid4 ()) for _ in range (len (batch .batch ))],
652652 dtype = object ,
653653 )
654- discard_original_batch = self .config .ajet .enable_tinkerscript_mode
654+ discard_original_batch = self .config .ajet .enable_swarm_mode
655655 batch = union_gen_batch_via_task_id (tasks , batch , gen_batch_output , discard_original_batch )
656656 batch .batch ["response_mask" ] = compute_response_mask (batch )
657657
@@ -784,7 +784,7 @@ def fit(self): # noqa: C901
784784 self .val_reward_fn is not None
785785 and self .config .trainer .test_freq > 0
786786 and (is_last_step or self .global_steps % self .config .trainer .test_freq == 0 )
787- and (not self .config .ajet .enable_tinkerscript_mode )
787+ and (not self .config .ajet .enable_swarm_mode )
788788 ):
789789 with marked_timer ("testing" , timing_raw , color = "green" ):
790790 val_metrics : dict = self ._validate ()
@@ -958,7 +958,7 @@ def _validate(self):
958958 dtype = object ,
959959 )
960960 tasks = tasks [: len (main_val_dataset )]
961- discard_original_batch = self .config .ajet .enable_tinkerscript_mode
961+ discard_original_batch = self .config .ajet .enable_swarm_mode
962962 test_batch = union_gen_batch_via_task_id (tasks , test_batch , test_output_gen_batch , discard_original_batch )
963963 # test_batch = test_batch.union(test_output_gen_batch)
964964 test_batch .meta_info ["validate" ] = True
0 commit comments