Skip to content

Commit 175e259

Browse files
committed
rename to agentjet swarm
1 parent 3157658 commit 175e259

29 files changed

+202
-183
lines changed

ajet/backbone/main_vllm.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def run(config):
144144
max_parallel = config.ajet.debug.debug_max_parallel
145145
n_task = config.ajet.debug.debug_first_n_tasks
146146
vllm_port = config.ajet.debug.debug_vllm_port
147-
enable_tinkerscript_mode = config.ajet.enable_tinkerscript_mode
147+
enable_swarm_mode = config.ajet.enable_swarm_mode
148148

149149
# --------- init ---------
150150
async_rollout_manager = ChatCompletionScheduler(
@@ -168,7 +168,7 @@ def run(config):
168168
logger.info(tasks[:n_task])
169169
ctx_tracker = parallel_env.rollout(
170170
tasks=tasks[:n_task],
171-
mode="sample" if not enable_tinkerscript_mode else "sample-ts", # type: ignore
171+
mode="sample" if not enable_swarm_mode else "sample-ts", # type: ignore
172172
epoch="1"
173173
)
174174
_ = parallel_env.to_dataproto(ctx_tracker)
@@ -189,7 +189,7 @@ def main(config):
189189
if config.ajet.enable_experimental_interchange_server:
190190
from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_server import start_interchange_server
191191
start_interchange_server(config)
192-
if config.ajet.enable_tinkerscript_mode:
192+
if config.ajet.enable_swarm_mode:
193193
from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import http_change_engine_status
194194
http_change_engine_status(config, "ENGINE.ROLLING")
195195

ajet/backbone/trainer_verl.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -458,7 +458,7 @@ def init_workers(self):
458458

459459
def _update_interchange_server_status_flag(self, status: str):
460460
if self.config.ajet.enable_experimental_interchange_server:
461-
if self.config.ajet.enable_tinkerscript_mode:
461+
if self.config.ajet.enable_swarm_mode:
462462
from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import http_change_engine_status
463463
http_change_engine_status(self.config, status)
464464

@@ -493,7 +493,7 @@ def fit(self): # noqa: C901
493493

494494
# perform validation before training
495495
# currently, we only support validation using the reward_function.
496-
if (self.val_reward_fn is not None) and (self.config.trainer.get("val_before_train", True)) and (not self.config.ajet.enable_tinkerscript_mode):
496+
if (self.val_reward_fn is not None) and (self.config.trainer.get("val_before_train", True)) and (not self.config.ajet.enable_swarm_mode):
497497
val_metrics = self._validate()
498498
assert val_metrics, f"{val_metrics=}"
499499
pprint(f"Initial validation metrics: {val_metrics}")
@@ -651,7 +651,7 @@ def fit(self): # noqa: C901
651651
[str(uuid.uuid4()) for _ in range(len(batch.batch))],
652652
dtype=object,
653653
)
654-
discard_original_batch = self.config.ajet.enable_tinkerscript_mode
654+
discard_original_batch = self.config.ajet.enable_swarm_mode
655655
batch = union_gen_batch_via_task_id(tasks, batch, gen_batch_output, discard_original_batch)
656656
batch.batch["response_mask"] = compute_response_mask(batch)
657657

@@ -784,7 +784,7 @@ def fit(self): # noqa: C901
784784
self.val_reward_fn is not None
785785
and self.config.trainer.test_freq > 0
786786
and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)
787-
and (not self.config.ajet.enable_tinkerscript_mode)
787+
and (not self.config.ajet.enable_swarm_mode)
788788
):
789789
with marked_timer("testing", timing_raw, color="green"):
790790
val_metrics: dict = self._validate()
@@ -958,7 +958,7 @@ def _validate(self):
958958
dtype=object,
959959
)
960960
tasks = tasks[: len(main_val_dataset)]
961-
discard_original_batch = self.config.ajet.enable_tinkerscript_mode
961+
discard_original_batch = self.config.ajet.enable_swarm_mode
962962
test_batch = union_gen_batch_via_task_id(tasks, test_batch, test_output_gen_batch, discard_original_batch)
963963
# test_batch = test_batch.union(test_output_gen_batch)
964964
test_batch.meta_info["validate"] = True

ajet/context_tracker/multiagent_tracking.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,14 @@ def __init__(
4848
self,
4949
tokenizer: PreTrainedTokenizer,
5050
config,
51-
should_interrupt_fn,
51+
should_interrupt_soft_fn,
5252
should_interrupt_hard_fn,
5353
generated_token_callback_fn,
5454
**kwargs,
5555
):
5656
super().__init__(config, tokenizer, **kwargs)
5757
self.tokenizer = tokenizer
58-
self.should_interrupt_fn = should_interrupt_fn
58+
self.should_interrupt_soft_fn = should_interrupt_soft_fn
5959
self.should_interrupt_hard_fn = should_interrupt_hard_fn
6060
self.generated_token_callback_fn = generated_token_callback_fn
6161
self.context_overflow = False
@@ -601,7 +601,7 @@ def check_context_token_num_safe(
601601
token_overflow = False
602602
else:
603603
token_overflow = True
604-
if self.should_interrupt_fn():
604+
if self.should_interrupt_soft_fn():
605605
ret = (False, token_overflow, "externally_interrupted")
606606
elif self.already_mad_flag and self.config.ajet.rollout.agent_madness_termination:
607607
ret = (False, token_overflow, "already_mad")

ajet/copilot/job.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,11 @@ def __init__(
4545
n_gpu_for_infer: int | None = None, # only for trinity backbone
4646
grpo_n: int = 8,
4747
batch_size: int = 32,
48-
tinkerscript_mode: bool = True,
48+
swarm_mode: bool = True,
4949
*kwargs,
5050
) -> None:
5151
self.backbone = backbone
52-
if tinkerscript_mode:
52+
if swarm_mode:
5353
default_yaml = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', "default_config/ajet_ts_default.yaml"))
5454
else:
5555
default_yaml = None

ajet/default_config/ajet_default.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -282,8 +282,8 @@ ajet:
282282

283283

284284
# experimental ZeroMQ interchange server, which enables the `tuner.as_oai_baseurl_apikey` feature
285-
enable_tinkerscript_mode: False
286-
# both tinkerscript / oai share the same interchange server
285+
enable_swarm_mode: False
286+
# both swarm / oai share the same interchange server
287287
enable_experimental_interchange_server: False
288288
# interchange server configuration
289289
interchange_server:
@@ -292,7 +292,7 @@ ajet:
292292
num_fastapi_process: 2 # 1, 2 or 4 is fine
293293
max_fastapi_threads: 512 # 64 or 128 is fine
294294
max_inference_tracker_threads: 64 # recommend to be equal to `ajet.rollout.max_env_worker`
295-
already_started: False # do not edit, used by `tinkerscript`
295+
already_started: False # do not edit, used by `swarm`
296296

297297

298298
task_runner:

ajet/default_config/ajet_ts_default.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,15 @@ ajet:
2323
# experimental ZeroMQ interchange server, which enables the `tuner.as_oai_baseurl_apikey` feature
2424
enable_experimental_interchange_server: True
2525
# train in cloud, run episode locally
26-
enable_tinkerscript_mode: True
27-
# both tinkerscript / oai share the same interchange server
26+
enable_swarm_mode: True
27+
# both swarm / oai share the same interchange server
2828
interchange_server:
2929
interchange_method: 'ipc' # options: 'tcp' (multi-nodes) or 'ipc' (1 node)
3030
interchange_server_port: 10086
3131
num_fastapi_process: 2 # 1, 2 or 4 is fine
3232
max_fastapi_threads: 512 # 64 or 128 is fine
3333
max_inference_tracker_threads: 64 # recommend to be equal to `ajet.rollout.max_env_worker`
34-
already_started: False # do not edit, used by `tinkerscript`
34+
already_started: False # do not edit, used by `swarm`
3535

3636
rollout:
3737
# maximum number of parallel environments / simulate workers

ajet/launcher.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,10 @@ def parse_args():
3535
help="verl or trinity or debug",
3636
)
3737
parser.add_argument(
38-
"--tinkerscript-server",
38+
"--swarm-server",
3939
action="store_true",
4040
default=False,
41-
help="Enable TinkerScript server mode",
41+
help="Enable Swarm server mode",
4242
)
4343
parser.add_argument(
4444
"--conf",
@@ -146,12 +146,12 @@ def check_model_file_exists(exp_config):
146146
assert os.path.exists(model_path), f"Model path {model_path} does not exist. Please check your configuration."
147147

148148

149-
def start_tinkerscript_server(env, config):
149+
def start_swarm_server(env, config):
150150
config = dict_to_namespace(config)
151-
assert config.ajet.enable_tinkerscript_mode, \
152-
"Please enable_tinkerscript_mode in config to start tinkerscript server."
151+
assert config.ajet.enable_swarm_mode, \
152+
"Please enable_swarm_mode in config to start swarm server."
153153
assert config.ajet.enable_experimental_interchange_server, \
154-
"Please enable_experimental_interchange_server in config to start tinkerscript server."
154+
"Please enable_experimental_interchange_server in config to start swarm server."
155155
from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_server import start_interchange_server
156156
start_interchange_server(config, blocking=True, env=env)
157157

@@ -191,9 +191,9 @@ def main():
191191
# read configuration from yaml
192192
exp_config = None
193193
exp_dir = args.exp_dir or "saved_experiments"
194-
if args.tinkerscript_server and (not args.conf):
194+
if args.swarm_server and (not args.conf):
195195
args.conf = os.path.abspath(os.path.join(os.path.dirname(__file__), "default_config/ajet_ts_default.yaml"))
196-
assert os.path.exists(args.conf), "Please provide a valid config file for tinkerscript server mode."
196+
assert os.path.exists(args.conf), "Please provide a valid config file for swarm server mode."
197197
if args.conf:
198198
yaml_path = args.conf
199199
(
@@ -206,8 +206,8 @@ def main():
206206
# setup environment variables
207207
env, exp_config = setup_environment_vars(args, exp_config, main_yaml_fp)
208208

209-
if args.tinkerscript_server:
210-
start_tinkerscript_server(env, exp_config)
209+
if args.swarm_server:
210+
start_swarm_server(env, exp_config)
211211
return
212212

213213
if args.with_ray:

ajet/task_reader/document_reader/doc_reader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
try:
1212
from unstructured.partition.auto import partition
1313
except Exception:
14-
logger.warning("Cannot import dependency `unstructured`")
14+
logger.info("`unstructured` is not installed.")
1515

1616
from ajet.schema.document import Document
1717
from ajet.task_reader.document_reader.document_reader_base import (

ajet/task_rollout/native_parallel_worker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def rollout(
144144
epoch: str,
145145
) -> List[BaseContextTracker]:
146146
"""Delegate to dynamic rollout when oversampling is enabled."""
147-
if self.config.ajet.enable_tinkerscript_mode:
147+
if self.config.ajet.enable_swarm_mode:
148148
return self.rollout_swarm(tasks, mode, epoch)
149149
elif (
150150
mode == "sample"

ajet/task_rollout/single_worker.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from ajet.task_rollout.async_llm_bridge import AsyncLlmBridge
1616
from ajet.task_rollout.resource_keeper import ResourceKeeper
1717
from ajet.task_runner.general_runner import GeneralRunner
18-
from ajet.task_runner.tinkerscript_runner import TinkerScriptRunner
18+
from ajet.task_runner.swarm_runner import SwarmRunner
1919
from ajet.utils.retry import retry_with_backoff
2020
from ajet.utils.retry import SwarmReceiveAbortException
2121
from ajet.utils.sample import get_sample_params
@@ -64,7 +64,7 @@ def __init__(
6464
assert isinstance(self.pad_token_id, int), "pad_token_id must be an integer"
6565
self.current_token = 0
6666
self.current_global_steps: int | str = "NA"
67-
self.enable_tinkerscript_mode = config.ajet.enable_tinkerscript_mode
67+
self.enable_swarm_mode = config.ajet.enable_swarm_mode
6868
self.async_llm_bridge = AsyncLlmBridge(
6969
config=config,
7070
async_rollout_manager=async_rollout_manager,
@@ -116,8 +116,8 @@ def rollout_env_worker(
116116
with ResourceKeeper(workflow_task, config=self.config) as resource_keeper:
117117
try:
118118
workflow_task = resource_keeper.prepare()
119-
if self.enable_tinkerscript_mode:
120-
agent_runner = TinkerScriptRunner(
119+
if self.enable_swarm_mode:
120+
agent_runner = SwarmRunner(
121121
llm_inference_fn=llm_inference_fn, tokenizer=self.tokenizer, config=self.config
122122
)
123123
else:

0 commit comments

Comments
 (0)