88from __future__ import annotations
99
1010import os
11+ import time
12+ import yaml
1113import tempfile
12- from types import SimpleNamespace
13- from typing import Any , Callable , Union
1414
15- import yaml
15+ from types import SimpleNamespace
16+ from typing import Any , Callable , Union , cast
1617from loguru import logger
17-
18-
1918from ajet .default_config .ajet_default import Config
2019from ajet .utils .config_utils import (
2120 expand_ajet_hierarchical_config ,
3029 setup_environment_vars ,
3130)
3231
33- DEFAULT_DIR = "saved_experiments"
32+
33+ def override_current_yaml_value_if_given (override_value , current_value ):
34+ if override_value is not None :
35+ return override_value
36+ else :
37+ return current_value
38+
39+ def _set_nested_attr (obj , attr_path : str , value ):
40+ keys = attr_path .split ("." )
41+ for key in keys [:- 1 ]:
42+ obj = getattr (obj , key )
43+ setattr (obj , keys [- 1 ], value )
44+
45+ def _get_nested_attr (obj , attr_path : str ):
46+ for key in attr_path .split ("." ):
47+ obj = getattr (obj , key )
48+ return obj
3449
3550class AgentJetJob :
36- """Lightweight builder that launches AgentJet training as a subprocess."""
51+ """
52+ arg: base_yaml_config + **kwargs (yaml config, then override with kwargs)
53+ arg: base_yaml_config (yaml config)
54+ arg: **kwargs (yaml config, then override with kwargs)
55+ """
3756
3857 def __init__ (
3958 self ,
40- backbone : str = "verl" ,
41- model : str = "Qwen/Qwen2___5-7B-Instruct" ,
42- n_gpu : int = 8 ,
43- algorithm : str = "grpo" ,
44- project_name = "ajet-swarm" ,
45- experiment_name = "test" ,
46- n_gpu_for_infer : int | None = None , # only for trinity backbone
47- num_repeat : int = 8 ,
48- batch_size : int = 32 ,
49- swarm_mode : bool = True ,
50- sample_collection_method : str = "rollout_until_finish_enough_tasks" ,
51- * kwargs ,
59+ base_yaml_config : str | None = None ,
60+ experiment_dir : str | None = None ,
61+ project_name : str | None = None ,
62+ experiment_name : str | None = None ,
63+ n_gpu : int | None = None ,
64+ model : str | None = None ,
65+ algorithm : str | None = None ,
66+ num_repeat : int | None = None ,
67+ batch_size : int | None = None ,
68+ swarm_mode : bool | None = None ,
69+ swarm_mode_sample_collection_method : str | None = None ,
70+ max_env_worker : int | None = None ,
71+ backbone : str | None = None ,
5272 ) -> None :
53- self .backbone = backbone
54- self .exp_dir = DEFAULT_DIR
55- self .project_name = project_name
56- self .exp_name = experiment_name
57- self .sample_collection_method = sample_collection_method
58- if swarm_mode :
59- default_yaml = os .path .abspath (os .path .join (os .path .dirname (__file__ ), '..' , "default_config/ajet_ts_default.yaml" ))
73+
74+ if base_yaml_config is None :
75+ base_yaml_config = os .path .abspath (os .path .join (os .path .dirname (__file__ ), '..' , "default_config/ajet_ts_default.yaml" ))
6076 else :
61- default_yaml = None
62- self .config_as_dict : dict = self .build_job_from_yaml (default_yaml )
77+ logger .warning (f"Reading config from { base_yaml_config } ." )
78+ time .sleep (1 )
79+ self .config_as_dict : dict = self .build_job_from_yaml (base_yaml_config )
6380 self .config = Config .update_from_dict_recursive (Config (), self .config_as_dict )
6481
65- self .config .ajet .experiment_name = experiment_name
66- self .config .ajet .backbone = backbone
67- self .config .ajet .model .path = model
68- self .config .ajet .trainer_common .n_gpus_per_node = n_gpu
69- self .config .ajet .trainer_common .algorithm .adv_estimator = algorithm
70- self .config .ajet .rollout .num_repeat = num_repeat
71- self .config .ajet .data .train_batch_size = batch_size
72- self .config .ajet .enable_swarm_mode = swarm_mode
73- self .config .ajet .swarm_mode_sample_collection_method = sample_collection_method
74- if n_gpu_for_infer is None and backbone == "trinity" :
75- raise ValueError ("Please specify `n_gpu_for_infer` (n_gpu_for_infer < n_gpu) for trinity backbone." )
76- if (n_gpu_for_infer is not None ) and backbone == "verl" :
77- raise ValueError ("n_gpu_for_infer is only for trinity backbone, please set it to `None`." )
78- else :
79- if backbone == "trinity" :
80- assert isinstance (n_gpu_for_infer , int ), f"`n_gpu_for_infer` should be int, got { type (n_gpu_for_infer )} ."
81- assert n_gpu_for_infer < n_gpu , "`n_gpu_for_infer` should be less than `n_gpu`."
82- self .config .ajet .rollout .n_vllm_engine = n_gpu_for_infer
83- self .config .ajet .rollout .tensor_model_parallel_size = 1
82+ self .base_yaml_config : str = cast (str , base_yaml_config ) # currently may be None, but will be set later
83+ self .experiment_dir : str = cast (str , experiment_dir )
84+ self .project_name : str = cast (str , project_name )
85+ self .experiment_name : str = cast (str , experiment_name )
86+ self .n_gpu : int = cast (int , n_gpu )
87+ self .model : str = cast (str , model )
88+ self .algorithm : str = cast (str , algorithm )
89+ self .num_repeat : int = cast (int , num_repeat )
90+ self .batch_size : int = cast (int , batch_size )
91+ self .swarm_mode : bool = cast (bool , swarm_mode )
92+ self .swarm_mode_sample_collection_method : str = cast (str , swarm_mode_sample_collection_method )
93+ self .max_env_worker : int = cast (int , max_env_worker )
94+ self .backbone : str = cast (str , backbone )
95+
96+ # see `ajet/default_config/ajet_ts_default.yaml`
97+ overrides = {
98+ "ajet.experiment_dir" : "experiment_dir" ,
99+ "ajet.project_name" : "project_name" ,
100+ "ajet.experiment_name" : "experiment_name" ,
101+ "ajet.model.path" : "model" ,
102+ "ajet.trainer_common.n_gpus_per_node" : "n_gpu" ,
103+ "ajet.trainer_common.algorithm.adv_estimator" : "algorithm" ,
104+ "ajet.rollout.num_repeat" : "num_repeat" ,
105+ "ajet.data.train_batch_size" : "batch_size" ,
106+ "ajet.enable_swarm_mode" : "swarm_mode" ,
107+ "ajet.swarm_mode_sample_collection_method" : "swarm_mode_sample_collection_method" ,
108+ "ajet.rollout.max_env_worker" : "max_env_worker" ,
109+ "ajet.backbone" : "backbone" ,
110+ }
111+
112+ # if any value given in kwargs, override the corresponding value in config
113+ for attr_path , override_val in overrides .items ():
114+ # get value from yaml config
115+ # >> e.g. current_model = self.config.model.path
116+ current_val = _get_nested_attr (self .config , attr_path )
117+
118+ # if override_val (given in __init__) is not None, use it to override the value from yaml config
119+ # >> e.g. new_model = self.model if (self.model is not None) else current_model
120+ new_val = override_current_yaml_value_if_given (getattr (self , override_val ), current_val )
121+
122+ # write final value to `self.config``
123+ # >> e.g. self.config.model.path = new_model
124+ _set_nested_attr (self .config , attr_path , new_val )
125+
126+ # write final value to `self`
127+ # >> e.g. self.model = new_model
128+ setattr (self , override_val , new_val )
129+
130+ if self .backbone == "trinity" :
131+ raise NotImplementedError ("Trinity backbone is not yet supported in AgentJetJob." )
132+
84133
85134 def build_job_from_yaml (self , yaml_path : str | None ) -> dict :
86135 self .config_as_dict = read_ajet_hierarchical_config (
87136 yaml_path ,
88- exp_name = self .exp_name ,
89- backbone = self .backbone ,
90137 write_to = None ,
91- exp_dir = self .exp_dir ,
92138 )
93139 self .config_as_dict = expand_ajet_hierarchical_config (self .config_as_dict , write_to = None )
94140 logger .info (f"Built AgentJet job config: { yaml_path } " )
95141 return self .config_as_dict
96142
143+
97144 def dump_job_as_yaml (self , yaml_path : str ) -> str :
98145 if os .path .dirname (yaml_path ):
99146 os .makedirs (os .path .dirname (yaml_path ), exist_ok = True )
@@ -102,6 +149,7 @@ def dump_job_as_yaml(self, yaml_path: str) -> str:
102149 logger .info (f"Saved training config to { yaml_path } " )
103150 return yaml_path
104151
152+
105153 def set_workflow (
106154 self , workflow : Union [str , Callable [..., Any ]], ensure_reward_in_workflow : bool = False
107155 ) -> "AgentJetJob" :
@@ -110,6 +158,7 @@ def set_workflow(
110158 # ensure_reward_in_workflow
111159 return self
112160
161+
113162 def set_data (
114163 self ,
115164 type : str ,
@@ -136,60 +185,3 @@ def set_data(
136185
137186 return self
138187
139- def tune (self , * args , ** kwargs ) -> "AgentJetJob" :
140- import ray
141- ast_cfg = self .config .ajet
142- if not ast_cfg .rollout or not ast_cfg .rollout .user_workflow :
143- raise ValueError ("Workflow must be set via set_workflow before tuning." )
144- if not ast_cfg .task_reader :
145- raise ValueError ("Data source must be set via set_data before tuning." )
146-
147- backbone = self .config .ajet .backbone
148- exp_dir = self .config .ajet .experiment_dir
149-
150- with tempfile .NamedTemporaryFile (mode = "w+" , delete = False , suffix = ".yaml" ) as temp_yaml :
151- yaml_path = temp_yaml .name
152- self .dump_job_as_yaml (yaml_path )
153- args = SimpleNamespace (
154- conf = yaml_path ,
155- backbone = backbone ,
156- exp_dir = exp_dir ,
157- with_logview = False ,
158- debug = False ,
159- )
160-
161- if args .backbone != "debug" :
162- # Enforce GPU availability and free memory threshold before proceeding
163- check_avail_gpu (min_free_ratio = 0.95 )
164-
165- # finalize experiment config
166- main_yaml_fp , exe_exp_base , exp_name , exp_config = prepare_experiment_config (
167- yaml_path , exp_dir , backbone
168- )
169-
170- # setup environment variables for ray
171- env = setup_environment_vars (args , exp_config , main_yaml_fp )
172-
173- # start ray if not already started
174- if not ray .is_initialized ():
175- from ajet .utils .launch_utils import start_ray_service
176-
177- start_ray_service (args , env )
178- else :
179- raise RuntimeError (
180- "Ray is already initialized. Please shutdown existing Ray instance before starting a new tuning job."
181- )
182-
183- # start training process
184- if args .conf and main_yaml_fp and exe_exp_base and exp_config :
185- execute_training_process (
186- args ,
187- get_backbone_target (args .backbone ),
188- main_yaml_fp ,
189- exe_exp_base ,
190- main_yaml_fp ,
191- env ,
192- exp_config ,
193- )
194-
195- return self
0 commit comments