import os
import re
from textwrap import dedent

import requests

from ajet.schema.task import Task, WorkflowOutput
from ajet.copilot.job import AgentJetJob
from ajet.task_reader import RouterTaskReader
from ajet.utils.retry import retry_with_backoff
from ajet.utils.thread_executors import PeriodicDrainThreadPoolExecutor
from ajet.tuner_lib.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey
from ajet.tuner_lib.experimental.interchange_utils import SwarmThrottlePolicy
from ajet.default_config.ajet_default import AjetTaskReader, HuggingfaceDatRepo
# NOTE(review): SwarmThrottlePolicy is imported again below from a different module
# and shadows the import above. Kept byte-for-byte to preserve behavior — confirm
# which module is the intended source and drop the other.
from ajet.tuner_lib.experimental.as_swarm_client import SwarmClient, SwarmThrottlePolicy
# Run with: python -m tutorial.example_math_swarm.math

# --- Rollout (client-side) configuration ---
GRPO_N = 4  # GRPO group size: number of rollouts sampled per task
NUM_EPOCH = 10000  # effectively "run until stopped"
DATASET_PATH = "/mnt/data_cpfs/qingxu.fu/dataset/openai/gsm8k/main"
# Swarm server endpoint; override with the AJET_SWARM_URL environment variable.
AJET_SWARM_URL = os.getenv("AJET_SWARM_URL", "http://localhost:10086")

# --- Remote training configuration (applied on the swarm server) ---
REMOTE_BATCH_SIZE = 32
REMOTE_ALLOCATE_GPU_PER_NODE = 4
REMOTE_TRAIN_MODEL = '/mnt/data_cpfs/model_cache/modelscope/hub/Qwen/Qwen/Qwen2.5-3B-Instruct'
3326def main ():
3427
@@ -37,21 +30,21 @@ def main():
3730 reader_type = "huggingface_dat_repo" ,
3831 reader_config = AjetTaskReader (
3932 huggingface_dat_repo = HuggingfaceDatRepo (
40- dataset_path = LOCAL_DATASET_PATH
33+ dataset_path = DATASET_PATH
4134 )
4235 )
4336 )
4437
4538 # # Hand shake with remote swarm server
46- swarm_worker = SwarmClient (REMOTE_SWARM_URL )
39+ swarm_worker = SwarmClient (AJET_SWARM_URL )
4740 swarm_worker .auto_sync_train_config_and_start_engine (
4841 AgentJetJob (
4942 experiment_name = "math_gsm8k_grpo" ,
5043 algorithm = "grpo" ,
5144 n_gpu = REMOTE_ALLOCATE_GPU_PER_NODE ,
52- model = REMOTE_TRAIN_MODEL_01 ,
45+ model = REMOTE_TRAIN_MODEL ,
5346 batch_size = REMOTE_BATCH_SIZE ,
54- num_repeat = LOCAL_GRPO_N ,
47+ num_repeat = GRPO_N ,
5548 )
5649 )
5750
@@ -62,7 +55,7 @@ def rollout(task):
6255 throttle_policy = SwarmThrottlePolicy (
6356 ratio = 0.5 ,
6457 expected_batch_size = REMOTE_BATCH_SIZE ,
65- expected_num_repeat = LOCAL_GRPO_N ,
58+ expected_num_repeat = GRPO_N ,
6659 current_task_id = task .task_id
6760 )
6861 )
@@ -76,11 +69,11 @@ def rollout(task):
7669 except :
7770 pass
7871
79- executor = BoundedThreadPoolExecutor ( max_workers = LOCAL_MAX_PARALLEL )
80- for epoch in range (LOCAL_NUM_EPOCH ):
72+ executor = PeriodicDrainThreadPoolExecutor ( workers = GRPO_N * REMOTE_BATCH_SIZE , auto_retry = True )
73+ for _ in range (NUM_EPOCH ):
8174 for _ , task in enumerate (dataset .generate_training_tasks ()):
82- for _ in range (LOCAL_GRPO_N ):
83- executor .submit ( rollout , task )
75+ for _ in range (GRPO_N ):
76+ executor .submit_with_periodic_drain ( fn = rollout , task = task )
8477
8578 return None
8679
0 commit comments