Skip to content

Commit f1edf19

Browse files
committed
update pro-academic-trans agent
1 parent 175e259 commit f1edf19

File tree

4 files changed

+36
-40
lines changed

4 files changed

+36
-40
lines changed

ajet/tuner_lib/weight_tuner/experimental/as_swarm_client.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -108,19 +108,18 @@ def end_episode(self, task:Task, episode_uuid: str, workflow_output: WorkflowOut
108108
except Exception as e:
109109
logger.error(f"Error ending episode: {e}")
110110

111-
def abort_episode(self, task:Task, episode_uuid: str, workflow_output: WorkflowOutput):
111+
def abort_episode(self, episode_uuid: str):
112112
if not episode_uuid:
113113
logger.error("No episode to end.")
114114
return
115115

116116
try:
117-
task_id = task.task_id
118-
workflow_output.metadata["task_id"] = task_id
117+
workflow_output = WorkflowOutput(reward=0.0, metadata={})
119118
req_obj = EndEpisodeRequest(
120119
client_uuid=self.client_uuid,
121120
episode_uuid=episode_uuid,
122121
workflow_output=workflow_output,
123-
task_id=task_id
122+
task_id=""
124123
)
125124

126125
resp = httpx.post(

ajet_swarm_threading.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def rollout(task):
7777
logger.exception("Exception during rollout group", e)
7878

7979
task_batch = []
80-
for i, task in enumerate(dataset.get_training_tasks()):
80+
for i, task in enumerate(dataset.generate_training_tasks()):
8181
task_batch += [task]
8282

8383
if len(task_batch) == REMOTE_BATCH_SIZE:

tutorial/example_academic_trans/trans_reward.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ def get_translation_quality_system_prompt() -> str:
7676
4. **Subject-verb inconsistencies** - Mismatched subjects due to improper sentence structure (e.g., "在...中,本文展示..." where the subject is confused)
7777
5. **Inappropriate word choices** - Using colloquial or incorrect terms instead of proper academic expressions (e.g., "效率" vs "有效性" in certain contexts)
7878
6. **Redundant punctuation** - Unnecessary commas or other punctuation that disrupts Chinese reading flow
79+
7. **主语不清晰** - 中文句子主语缺失或不明确。例如:“通过该实验,证明了该药物对癌细胞有抑制作用”(缺少主语)
7980
8081
**Examples of these errors:**
8182
[[examples_text]]

tutorial/example_academic_trans/trans_roll.py

Lines changed: 31 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@
2727
REMOTE_SWARM_URL = "http://localhost:10086" # Change to your swarm remote url
2828

2929
# --------- configurations that take effect remotely -------------
30+
REMOTE_BATCH_SIZE = 32
3031
REMOTE_ALLOCATE_GPU_PER_NODE = 8
3132
REMOTE_TRAIN_MODEL_01 = '/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2.5-7B-Instruct'
32-
REMOTE_BATCH_SIZE = 32
3333

3434
class WeightUpdatedHalfway(Exception):
3535
"""Raised when the remote side starts updating model weights halfway through an episode."""
@@ -49,55 +49,51 @@ def main():
4949

5050
# Hand shake with remote swarm server
5151
swarm_remote = SwarmClient(REMOTE_SWARM_URL)
52-
# swarm_remote.stop_engine()
5352
swarm_remote.auto_sync_train_config_and_start_engine(
5453
AgentJetJob(
5554
algorithm="grpo",
5655
n_gpu=REMOTE_ALLOCATE_GPU_PER_NODE,
5756
model=REMOTE_TRAIN_MODEL_01,
57+
batch_size=REMOTE_BATCH_SIZE,
5858
grpo_n=LOCAL_GRPO_N,
59-
),
60-
force_restart=True,
59+
)
6160
)
6261

63-
# Define rollout
6462
def rollout(task):
6563
group_reward = []
66-
for i in range(LOCAL_GRPO_N):
67-
episode_uuid = None
68-
try:
69-
# begin episode
70-
episode_uuid, api_baseurl_key = swarm_remote.begin_episode()
71-
# execute agent
72-
workflow_output = execute_agent(task, api_baseurl_key)
73-
# report output back to swarm remote
74-
swarm_remote.end_episode(task, episode_uuid, workflow_output)
75-
# collect reward
76-
group_reward.append(workflow_output.reward)
77-
except Exception as e:
78-
logger.exception("Exception during rollout:", e)
79-
if episode_uuid:
80-
swarm_remote.abort_episode(episode_uuid)
64+
try:
65+
for _ in range(LOCAL_GRPO_N):
66+
try:
67+
# begin episode
68+
episode_uuid, api_baseurl_key = swarm_remote.begin_episode()
69+
# execute agent
70+
workflow_output = execute_agent(task, api_baseurl_key)
71+
# report output back to swarm remote
72+
swarm_remote.end_episode(task, episode_uuid, workflow_output)
73+
# collect reward
74+
group_reward.append(workflow_output.reward)
75+
except Exception as e:
76+
logger.exception("Exception during rollout:", e)
77+
8178
print(f"Group reward mean & std: {sum(group_reward)/len(group_reward)} +/- { (max(group_reward)-min(group_reward))/2 }")
79+
except Exception as e:
80+
logger.exception("Exception during rollout group", e)
8281

83-
# Main Training loop
84-
futures = []
85-
with ThreadPoolExecutor(max_workers=LOCAL_MAX_PARALLEL) as executor:
86-
for epoch in range(LOCAL_NUM_EPOCH):
87-
for i, task in enumerate(dataset.generate_training_tasks()):
88-
print(f"Submitting task for epoch {epoch}")
89-
future = executor.submit(rollout, task)
82+
task_batch = []
83+
for i, task in enumerate(dataset.generate_training_tasks()):
84+
task_batch += [task]
9085

91-
futures += [future]
92-
while (i % REMOTE_BATCH_SIZE) == (REMOTE_BATCH_SIZE - 1) and futures:
93-
futures = [f for f in futures if not f.done()]
94-
time.sleep(1)
86+
if len(task_batch) == REMOTE_BATCH_SIZE:
87+
print('*********** beginning a new batch of tasks... ***********')
88+
with ThreadPoolExecutor(max_workers=LOCAL_MAX_PARALLEL) as executor:
89+
for task in task_batch:
90+
executor.submit(rollout, task)
91+
executor.shutdown(wait=True)
92+
task_batch = []
93+
print('*********** tasks completed, wait a minute... ***********')
94+
time.sleep(60)
9595

9696

97-
# swarm_remote.stop_engine()
98-
# model_path = swarm_remote.download_latest_model(path='./swarm_saved_model')
99-
time.sleep(10000)
100-
# Get tuned model from swarm remote
10197
return None
10298

10399

0 commit comments

Comments
 (0)