1919GRPO_N = 4 # grpo group size
2020NUM_EPOCH = 10000
2121AJET_SWARM_URL = os .getenv ("AJET_SWARM_URL" , "http://localhost:10086" )
22-
22+ REMOTE_MODEL_PATH = os . getenv ( "REMOTE_MODEL_PATH" , "/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2.5-3B-Instruct" )
2323REMOTE_BATCH_SIZE = 32
2424REMOTE_ALLOCATE_GPU_PER_NODE = 8
25- # REMOTE_TRAIN_MODEL = '/root/agentjet/modelscope_cache/Qwen/Qwen2.5-7B-Instruct'
26- REMOTE_TRAIN_MODEL = '/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2.5-3B-Instruct'
2725
2826def main ():
2927
@@ -46,11 +44,11 @@ def main():
4644 experiment_name = "math_gsm8k_grpo" ,
4745 algorithm = "grpo" ,
4846 n_gpu = REMOTE_ALLOCATE_GPU_PER_NODE ,
49- model = REMOTE_TRAIN_MODEL ,
47+ model = REMOTE_MODEL_PATH ,
5048 batch_size = REMOTE_BATCH_SIZE ,
5149 num_repeat = GRPO_N ,
5250 ),
53- force_restart = True ,
51+ # force_restart=True,
5452 )
5553
5654 def rollout (task ):
@@ -76,7 +74,6 @@ def rollout(task):
7674
7775
7876
79- @retry_with_backoff (max_retry = 2 )
8077def execute_agent (task : Task , api_baseurl_key : OpenaiBaseUrlAndApiKey ):
8178 # Prepare base_url, api_key
8279 base_url , api_key = (api_baseurl_key .base_url , api_baseurl_key .api_key )
@@ -89,8 +86,15 @@ def execute_agent(task: Task, api_baseurl_key: OpenaiBaseUrlAndApiKey):
8986 { "role" : "user" , "content" : query }
9087 ]
9188 # Use a raw, non-streaming HTTP request to get the response
92- response = requests .post ( f"{ base_url } /chat/completions" , json = { "model" : "fill_whatever_model" , "messages" : messages , },
93- headers = { "Authorization" : f"Bearer { api_key } " } )
89+ # "Connection: close" disables keep-alive connection reuse. Under high concurrency,
90+ # a reused stale pooled connection can return residual bytes and raise BadStatusLine.
91+ response = requests .post (
92+ f"{ base_url } /chat/completions" ,
93+ json = { "model" : "fill_whatever_model" , "messages" : messages , "stream" : False },
94+ headers = { "Authorization" : f"Bearer { api_key } " , "Connection" : "close" },
95+ timeout = 300 ,
96+ )
97+ response .raise_for_status ()
9498 final_answer = response .json ()['choices' ][0 ]['message' ]['content' ]
9599
96100 reference_answer = reference_answer .split ("####" )[- 1 ].strip ()
0 commit comments