2727REMOTE_SWARM_URL = "http://localhost:10086" # Change to your swarm remote url
2828
2929# --------- configurations that take effect remotely -------------
30+ REMOTE_BATCH_SIZE = 32
3031REMOTE_ALLOCATE_GPU_PER_NODE = 8
3132REMOTE_TRAIN_MODEL_01 = '/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2.5-7B-Instruct'
32- REMOTE_BATCH_SIZE = 32
3333
3434class WeightUpdatedHalfway (Exception ):
3535 """Raised when the remote side starts updating model weights halfway through an episode."""
@@ -49,55 +49,51 @@ def main():
4949
5050 # Hand shake with remote swarm server
5151 swarm_remote = SwarmClient (REMOTE_SWARM_URL )
52- # swarm_remote.stop_engine()
5352 swarm_remote .auto_sync_train_config_and_start_engine (
5453 AgentJetJob (
5554 algorithm = "grpo" ,
5655 n_gpu = REMOTE_ALLOCATE_GPU_PER_NODE ,
5756 model = REMOTE_TRAIN_MODEL_01 ,
57+ batch_size = REMOTE_BATCH_SIZE ,
5858 grpo_n = LOCAL_GRPO_N ,
59- ),
60- force_restart = True ,
59+ )
6160 )
6261
63- # Define rollout
def rollout(task):
    """Run one GRPO group of episodes for ``task`` and print its reward summary.

    Runs ``LOCAL_GRPO_N`` episodes. Each episode is opened on the remote
    swarm, the agent is executed against the returned endpoint/key, and the
    workflow output is reported back. Rewards of the episodes that completed
    are collected and summarised on stdout.

    A failing episode is logged and skipped; it does not stop the remaining
    episodes of the group. The function returns ``None`` and never raises
    ``Exception`` subclasses.

    Args:
        task: The training task handed to ``execute_agent`` and to
            ``swarm_remote.end_episode``.
    """
    group_reward = []
    # Outer boundary: this function is submitted to a thread pool and the
    # returned futures are not inspected, so anything escaping here would be
    # lost silently. Log it instead.
    try:
        for _ in range(LOCAL_GRPO_N):
            try:
                # begin episode
                episode_uuid, api_baseurl_key = swarm_remote.begin_episode()
                # execute agent
                workflow_output = execute_agent(task, api_baseurl_key)
                # report output back to swarm remote
                swarm_remote.end_episode(task, episode_uuid, workflow_output)
                # collect reward
                group_reward.append(workflow_output.reward)
            except Exception:
                # logger.exception() already attaches the active traceback;
                # passing the exception as an extra positional argument
                # (without a placeholder in the message) is at best ignored
                # and at worst a formatting error inside the logging call.
                logger.exception("Exception during rollout:")

        if not group_reward:
            # Every episode failed: there is nothing to average, and the
            # summary below would divide by zero.
            logger.warning("No episode in this group completed; skipping reward summary")
            return

        reward_mean = sum(group_reward) / len(group_reward)
        # NOTE: the value printed after "+/-" is half of the reward range
        # (max - min) / 2, not a true standard deviation, although the
        # message labels it "std".
        reward_spread = (max(group_reward) - min(group_reward)) / 2
        print(f"Group reward mean & std: {reward_mean} +/- {reward_spread}")
    except Exception:
        logger.exception("Exception during rollout group")
8281
83- # Main Training loop
84- futures = []
85- with ThreadPoolExecutor (max_workers = LOCAL_MAX_PARALLEL ) as executor :
86- for epoch in range (LOCAL_NUM_EPOCH ):
87- for i , task in enumerate (dataset .generate_training_tasks ()):
88- print (f"Submitting task for epoch { epoch } " )
89- future = executor .submit (rollout , task )
82+ task_batch = []
83+ for i , task in enumerate (dataset .generate_training_tasks ()):
84+ task_batch += [task ]
9085
91- futures += [future ]
92- while (i % REMOTE_BATCH_SIZE ) == (REMOTE_BATCH_SIZE - 1 ) and futures :
93- futures = [f for f in futures if not f .done ()]
94- time .sleep (1 )
86+ if len (task_batch ) == REMOTE_BATCH_SIZE :
87+ print ('*********** beginning a new batch of tasks... ***********' )
88+ with ThreadPoolExecutor (max_workers = LOCAL_MAX_PARALLEL ) as executor :
89+ for task in task_batch :
90+ executor .submit (rollout , task )
91+ executor .shutdown (wait = True )
92+ task_batch = []
93+ print ('*********** tasks completed, wait a minute... ***********' )
94+ time .sleep (60 )
9595
9696
97- # swarm_remote.stop_engine()
98- # model_path = swarm_remote.download_latest_model(path='./swarm_saved_model')
99- time .sleep (10000 )
100- # Get tuned model from swarm remote
10197 return None
10298
10399
0 commit comments