Skip to content

Commit d67b912

Browse files
committed
rename functions
1 parent 1604076 commit d67b912

File tree

3 files changed

+15
-11
lines changed

3 files changed

+15
-11
lines changed

ajet/task_reader/document_reader/doc_reader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
try:
1212
from unstructured.partition.auto import partition
1313
except Exception:
14-
logger.info("`unstructured` is not installed.")
14+
logger.debug("`unstructured` is not installed.")
1515

1616
from ajet.schema.document import Document
1717
from ajet.task_reader.document_reader.document_reader_base import (

ajet/tuner_lib/experimental/as_swarm_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,9 +203,9 @@ def begin_episode(self, discard_episode_timeout=600, episode_type="train", throt
203203
Return:
204204
(episode_uuid, openai_base_url, openai_api_key)
205205
"""
206-
return self._begin_episode_auto_repeat(discard_episode_timeout, episode_type, throttle_policy)
206+
return self._begin_episode_auto_retry(discard_episode_timeout, episode_type, throttle_policy)
207207

208-
def _begin_episode_auto_repeat(self, discard_episode_timeout=600, episode_type="train", throttle_policy: SwarmThrottlePolicy|None = None) -> Tuple[str, OpenaiBaseUrlAndApiKey]:
208+
def _begin_episode_auto_retry(self, discard_episode_timeout=600, episode_type="train", throttle_policy: SwarmThrottlePolicy|None = None) -> Tuple[str, OpenaiBaseUrlAndApiKey]:
209209
# max_episode_time: when an episode has **lasted** for more than X seconds, it will be terminated **locally** by the client (a call to `end_episode` will be re-routed to `abort_episode`)
210210
max_episode_time = 2*discard_episode_timeout
211211

tutorial/example_math_swarm/math.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,9 @@
1919
GRPO_N = 4 # grpo group size
2020
NUM_EPOCH = 10000
2121
AJET_SWARM_URL = os.getenv("AJET_SWARM_URL", "http://localhost:10086")
22-
22+
REMOTE_MODEL_PATH = os.getenv("REMOTE_MODEL_PATH", "/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2.5-3B-Instruct")
2323
REMOTE_BATCH_SIZE = 32
2424
REMOTE_ALLOCATE_GPU_PER_NODE = 8
25-
# REMOTE_TRAIN_MODEL = '/root/agentjet/modelscope_cache/Qwen/Qwen2.5-7B-Instruct'
26-
REMOTE_TRAIN_MODEL = '/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2.5-3B-Instruct'
2725

2826
def main():
2927

@@ -46,11 +44,11 @@ def main():
4644
experiment_name="math_gsm8k_grpo",
4745
algorithm="grpo",
4846
n_gpu=REMOTE_ALLOCATE_GPU_PER_NODE,
49-
model=REMOTE_TRAIN_MODEL,
47+
model=REMOTE_MODEL_PATH,
5048
batch_size=REMOTE_BATCH_SIZE,
5149
num_repeat=GRPO_N,
5250
),
53-
force_restart=True,
51+
# force_restart=True,
5452
)
5553

5654
def rollout(task):
@@ -76,7 +74,6 @@ def rollout(task):
7674

7775

7876

79-
@retry_with_backoff(max_retry=2)
8077
def execute_agent(task: Task, api_baseurl_key: OpenaiBaseUrlAndApiKey):
8178
# Prepare base_url, api_key
8279
base_url, api_key = (api_baseurl_key.base_url, api_baseurl_key.api_key)
@@ -89,8 +86,15 @@ def execute_agent(task: Task, api_baseurl_key: OpenaiBaseUrlAndApiKey):
8986
{ "role": "user", "content": query }
9087
]
9188
# Use raw http requests (non-streaming) to get response
92-
response = requests.post( f"{base_url}/chat/completions", json = { "model": "fill_whatever_model", "messages": messages, },
93-
headers = { "Authorization": f"Bearer {api_key}" } )
89+
# "Connection: close" prevents keep-alive pool reuse, which can cause BadStatusLine
90+
# errors under high concurrency when stale pooled connections return residual bytes.
91+
response = requests.post(
92+
f"{base_url}/chat/completions",
93+
json = { "model": "fill_whatever_model", "messages": messages, "stream": False },
94+
headers = { "Authorization": f"Bearer {api_key}", "Connection": "close" },
95+
timeout = 300,
96+
)
97+
response.raise_for_status()
9498
final_answer = response.json()['choices'][0]['message']['content']
9599

96100
reference_answer = reference_answer.split("####")[-1].strip()

0 commit comments

Comments
 (0)