Skip to content

Commit 6945c38

Browse files
committed
Enhance logging in AjetDataParallelPPOActor and increase max_parallel in AIMESwarmTrainer
1 parent 8b978b1 commit 6945c38

3 files changed

Lines changed: 47 additions & 44 deletions

File tree

ajet/backbone/verl/dp_actor.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -198,7 +198,8 @@ def update_policy(self, data: DataProto):
198198

199199
self.actor_optimizer.zero_grad()
200200

201-
for micro_batch in micro_batches:
201+
num_micro_batches = len(micro_batches)
202+
for micro_batch_idx, micro_batch in enumerate(micro_batches, 1):
202203
micro_batch = micro_batch.to(get_device_id())
203204
micro_batch_metrics = {}
204205
model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch, "pad_token_id": pad_token_id}
@@ -207,7 +208,7 @@ def update_policy(self, data: DataProto):
207208
advantages = model_inputs["advantages"]
208209
# [AJET] Debug logging for tensor shapes
209210
input_ids = model_inputs["input_ids"]
210-
print(f'[Update Policy] -> Micro batch shape, input_ids {input_ids.shape}, response {response_mask.shape}')
211+
print(f'[Update Policy] -> Micro batch shape, input_ids {input_ids.shape}, response {response_mask.shape} @{micro_batch_idx}/{num_micro_batches}')
211212

212213
entropy_coeff = self.config.entropy_coeff
213214
loss_agg_mode = self.config.loss_agg_mode

tutorial/opencode_build_aime/agent_roll.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -186,7 +186,7 @@ def train(self):
186186
task_count = 0
187187
executor = PeriodicDrainThreadPoolExecutor(
188188
workers=self.grpo_n * self.remote_batch_size,
189-
max_parallel=64,
189+
max_parallel=256,
190190
auto_retry=True
191191
)
192192

tutorial/opencode_build_aime/agent_run_v3.py

Lines changed: 43 additions & 41 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020
import tempfile
2121
import time
2222
from dataclasses import dataclass
23+
from textwrap import dedent
2324
from uuid import uuid4
2425

2526
from openai import OpenAI
@@ -96,39 +97,39 @@ def __init__(self, timeout: int = 30, memory_limit_mb: int = 512):
9697
self.memory_limit_mb = memory_limit_mb
9798

9899
def execute(self, code: str, stdin: str = "") -> ProcessExecuteResult:
99-
pre_template = f"""
100-
import signal
101-
import resource
102-
import os
103-
import sys
100+
pre_template = dedent(f"""\
101+
import signal
102+
import resource
103+
import os
104+
import sys
104105
105-
os.environ['OPENBLAS_NUM_THREADS'] = '1'
106+
os.environ['OPENBLAS_NUM_THREADS'] = '1'
106107
107-
def _exec_set_alarm_timeout(timeout):
108-
signal.signal(signal.SIGALRM, _exec_time_exceeded)
109-
signal.alarm(timeout)
108+
def _exec_set_alarm_timeout(timeout):
109+
signal.signal(signal.SIGALRM, _exec_time_exceeded)
110+
signal.alarm(timeout)
110111
111-
def _exec_time_exceeded(*_):
112-
print('Suicide from timeout.', flush=True)
113-
try:
114-
os.killpg(0, 9)
115-
except Exception:
116-
pass
117-
os._exit({TIMEOUT_EXIT_CODE})
112+
def _exec_time_exceeded(*_):
113+
print('Suicide from timeout.', flush=True)
114+
try:
115+
os.killpg(0, 9)
116+
except Exception:
117+
pass
118+
os._exit({TIMEOUT_EXIT_CODE})
118119
119-
def _exec_set_max_runtime(seconds):
120-
soft, hard = resource.getrlimit(resource.RLIMIT_CPU)
121-
resource.setrlimit(resource.RLIMIT_CPU, (seconds, hard))
120+
def _exec_set_max_runtime(seconds):
121+
soft, hard = resource.getrlimit(resource.RLIMIT_CPU)
122+
resource.setrlimit(resource.RLIMIT_CPU, (seconds, hard))
122123
123-
_exec_set_alarm_timeout({self.timeout})
124-
_exec_set_max_runtime({self.timeout})
124+
_exec_set_alarm_timeout({self.timeout})
125+
_exec_set_max_runtime({self.timeout})
125126
126-
_exec_time_start = time.perf_counter()
127-
"""
128-
post_template = f"""
129-
_exec_time_end = time.perf_counter()
130-
_exec_duration = _exec_time_end - _exec_time_start
131-
"""
127+
_exec_time_start = time.perf_counter()
128+
""")
129+
post_template = dedent("""\
130+
_exec_time_end = time.perf_counter()
131+
_exec_duration = _exec_time_end - _exec_time_start
132+
""")
132133

133134
with tempfile.TemporaryDirectory() as tmp_path:
134135
source_path = f"{tmp_path}/source.py"
@@ -307,23 +308,24 @@ async def run(self, messages: list[dict], sampling_params: dict) -> tuple[str, l
307308
user_turns, assistant_turns = 0, 0
308309
all_response_content = []
309310

310-
system_prompt = """You are an expert mathematician specialized in solving challenging math competition problems.
311+
system_prompt = dedent("""\
312+
You are an expert mathematician specialized in solving challenging math competition problems.
311313
312-
You have access to a Python code execution tool. Use it to:
313-
1. Perform calculations and verify your answers
314-
2. Run code when you need precise computation
315-
3. Test your hypotheses before giving final answers
314+
You have access to a Python code execution tool. Use it to:
315+
1. Perform calculations and verify your answers
316+
2. Run code when you need precise computation
317+
3. Test your hypotheses before giving final answers
316318
317-
Instructions:
318-
1. Think through the problem step by step
319-
2. Use the python_code_with_standard_io tool when you need to execute code
320-
3. Show your reasoning clearly
321-
4. Put your final numerical answer inside \\boxed{} at the end
319+
Instructions:
320+
1. Think through the problem step by step
321+
2. Use the python_code_with_standard_io tool when you need to execute code
322+
3. Show your reasoning clearly
323+
4. Put your final numerical answer inside \\boxed{} at the end
322324
323-
For each function call, return a json object within <tool_call></tool_call> XML tags:
324-
<tool_call>
325-
{"name": "python_code_with_standard_io", "arguments": {"code": "your python code", "input": "stdin input if needed"}}
326-
</tool_call>"""
325+
For each function call, return a json object within <tool_call></tool_call> XML tags:
326+
<tool_call>
327+
{"name": "python_code_with_standard_io", "arguments": {"code": "your python code", "input": "stdin input if needed"}}
328+
</tool_call>""")
327329

328330
formatted_messages = [msg for msg in messages if msg.get("role") != "system"]
329331
if not any(msg.get("role") == "system" for msg in messages):

0 commit comments

Comments
 (0)