Skip to content

Commit b2c70db

Browse files
committed
tinkerscript-v1
1 parent 0487de2 commit b2c70db

17 files changed

Lines changed: 922 additions & 20 deletions

ajet/backbone/main_vllm.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ def run(config):
144144
max_parallel = config.ajet.debug.debug_max_parallel
145145
n_task = config.ajet.debug.debug_first_n_tasks
146146
vllm_port = config.ajet.debug.debug_vllm_port
147+
enable_tinkerscript_mode = config.ajet.enable_tinkerscript_mode
147148

148149
# --------- init ---------
149150
async_rollout_manager = ChatCompletionScheduler(
@@ -166,8 +167,10 @@ def run(config):
166167
tasks = task_reader.get_validation_tasks()
167168
logger.info(tasks[:n_task])
168169
ctx_tracker = parallel_env.rollout(
169-
tasks=tasks[:n_task], mode="sample", epoch="1"
170-
) # "sample" or "validate"
170+
tasks=tasks[:n_task],
171+
mode="sample" if not enable_tinkerscript_mode else "sample-ts", # type: ignore
172+
epoch="1"
173+
)
171174
_ = parallel_env.to_dataproto(ctx_tracker)
172175

173176

ajet/default_config/ajet_default.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,9 @@ ajet:
282282

283283
# the experimental ZeroMQ interchange server feature that allows `tuner.as_oai_baseurl_apikey` feature
284284
enable_experimental_interchange_server: True
285+
# train in cloud, run episode locally
286+
enable_tinkerscript_mode: False
287+
# both tinkerscript / oai share the same interchange server
285288
interchange_server:
286289
interchange_method: 'ipc' # options: 'tcp' (multi-nodes) or 'ipc' (1 node)
287290
interchange_server_port: 'auto'

ajet/task_rollout/single_worker.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from ajet.task_rollout.async_llm_bridge import AsyncLlmBridge
1313
from ajet.task_rollout.resource_keeper import ResourceKeeper
1414
from ajet.task_runner.general_runner import GeneralRunner
15+
from ajet.task_runner.tinkerscript_runner import TinkerScriptRunner
1516
from ajet.utils.retry import retry_with_backoff
1617
from ajet.utils.sample import get_sample_params
1718
from ajet.utils.testing_utils import TestFailException, TestSuccessException
@@ -59,6 +60,7 @@ def __init__(
5960
assert isinstance(self.pad_token_id, int), "pad_token_id must be an integer"
6061
self.current_token = 0
6162
self.current_global_steps: int | str = "NA"
63+
self.enable_tinkerscript_mode = config.ajet.enable_tinkerscript_mode
6264
self.async_llm_bridge = AsyncLlmBridge(
6365
config=config,
6466
async_rollout_manager=async_rollout_manager,
@@ -110,9 +112,14 @@ def rollout_env_worker(
110112
with ResourceKeeper(workflow_task, config=self.config) as resource_keeper:
111113
try:
112114
workflow_task = resource_keeper.prepare()
113-
agent_runner = GeneralRunner(
114-
llm_inference_fn=llm_inference_fn, tokenizer=self.tokenizer, config=self.config
115-
)
115+
if self.enable_tinkerscript_mode:
116+
agent_runner = TinkerScriptRunner(
117+
llm_inference_fn=llm_inference_fn, tokenizer=self.tokenizer, config=self.config
118+
)
119+
else:
120+
agent_runner = GeneralRunner(
121+
llm_inference_fn=llm_inference_fn, tokenizer=self.tokenizer, config=self.config
122+
)
116123
tracker = agent_runner.execute(
117124
workflow_task=workflow_task,
118125
)
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
2+
import atexit
3+
import json
4+
import requests
5+
import zmq
6+
import os
7+
import time
8+
from ajet import AjetTuner
9+
from ajet import WorkflowOutput
10+
from ajet.context_tracker.multiagent_tracking import (
11+
MultiAgentContextTracker,
12+
)
13+
from ajet.context_tracker.basic_tracker import BaseContextTracker
14+
from ajet.schema.task import WorkflowTask
15+
from ajet.schema.trajectory import Reward
16+
from ajet.task_runner.base_runner import BaseAgentRunner
17+
from ajet.utils.networking import find_free_port
18+
from loguru import logger
19+
from ajet import Workflow
20+
21+
context = zmq.Context()
22+
atexit.register(context.term)
23+
24+
class TinkerScriptRunner(BaseAgentRunner):
    """Runner used when ``enable_tinkerscript_mode`` is on: training happens
    in the cloud while the episode itself is executed by an external
    TinkerScript client.

    Instead of running the workflow locally, this runner registers the episode
    with the TinkerScript data-interchange server (handing it an OpenAI-style
    base_url/api_key pair that proxies to the trainer) and then blocks on a
    ZeroMQ REP socket until the remote workflow pushes its ``WorkflowOutput``
    back.
    """

    def get_zmq_socket(self, episode_uuid: str) -> str:
        """Build the ZeroMQ address this runner binds to receive the result.

        Args:
            episode_uuid: Unique episode id; used to derive a per-episode
                ipc socket path (or, for tcp, a fresh port).

        Returns:
            A zmq address string, e.g. ``tcp://host:port`` or
            ``ipc:///tmp/ajet/<uuid>-workflow.sock``.

        Raises:
            RuntimeError: If the configured ``interchange_method`` is unknown.
        """
        interchange_method = self.config.ajet.interchange_server.interchange_method
        if interchange_method == 'tcp':
            master_node_ip = os.getenv("MASTER_NODE_IP", "localhost")
            # NOTE(review): find_free_port() probes a port on *this* host and
            # there is a small race window before the later bind(); this also
            # assumes the runner is co-located with MASTER_NODE_IP -- confirm.
            episode_contact_address = f"tcp://{master_node_ip}:{find_free_port()}"
        elif interchange_method == 'ipc':
            ipc_path = f"/tmp/ajet/{episode_uuid}-workflow.sock"
            # Ensure the socket directory exists; zmq bind() fails otherwise.
            os.makedirs(os.path.dirname(ipc_path), exist_ok=True)
            episode_contact_address = f"ipc://{ipc_path}"
        else:
            raise RuntimeError(f"Unknown interchange_method: {interchange_method}")
        return episode_contact_address

    def get_interchange_server_url(self) -> str:
        """Resolve the HTTP base url of the TinkerScript interchange server.

        The port comes from the ``AJET_DAT_INTERCHANGE_PORT`` env var unless a
        fixed port is configured via
        ``interchange_server.interchange_server_port`` (anything but 'auto').
        """
        port = os.getenv("AJET_DAT_INTERCHANGE_PORT")
        if self.config.ajet.interchange_server.interchange_server_port != 'auto':
            # An explicitly configured port overrides the env var.
            port = str(int(self.config.ajet.interchange_server.interchange_server_port))
        assert port is not None, "AJET_DAT_INTERCHANGE_PORT env var must be set"
        master_node_ip = os.getenv("MASTER_NODE_IP", "localhost")
        base_url = f"http://{master_node_ip}:{port}"
        return base_url

    def register_episode_and_wait_output(self, episode_uuid: str, openai_base_url: str, openai_api_key: str) -> WorkflowOutput:
        """Register the episode as ready in the TinkerScript data interchange
        center, then block until the remote workflow returns its output.

        Args:
            episode_uuid: Unique episode id to register.
            openai_base_url: Proxy base url the remote client should call.
            openai_api_key: Encoded key (carries episode_uuid etc.).

        Returns:
            The ``WorkflowOutput`` deserialized from the remote client's
            ZeroMQ message.
        """
        # Imported lazily to avoid a circular import at module load time
        # (presumably -- the server module lives under tuner_lib; confirm).
        from ajet.tuner_lib.weight_tuner.experimental.as_tinkerscript_server import RegisterEpisodeRequest

        zmq_listen_result_addr = self.get_zmq_socket(episode_uuid)
        interchange_http_addr = self.get_interchange_server_url()
        rer = RegisterEpisodeRequest(
            episode_uuid=episode_uuid,
            openai_base_url=openai_base_url,
            openai_api_key=openai_api_key,
            zmq_listen_result_addr=zmq_listen_result_addr,
        )
        logger.info(f"zmq_listen_result_addr: {zmq_listen_result_addr}, interchange_http_addr: {interchange_http_addr}")

        # Register with the interchange server, retrying until it is reachable.
        while True:
            try:
                response = requests.post(
                    f"{interchange_http_addr}/register_episode",
                    json=rer.model_dump(),  # Pydantic v2 serialization
                    timeout=30,
                )
                response.raise_for_status()
                result = response.json()
                if not result.get('success'):
                    # Server answered but refused the episode: not retryable here.
                    raise RuntimeError(f"Failed to register episode {episode_uuid}")
                logger.info(f"Successfully registered episode {episode_uuid}")
                break
            except requests.RequestException as e:
                logger.error(f"Error registering episode {episode_uuid}: {e}. Retrying...")
                time.sleep(5)

        # Wait for the result. Reuse the module-level zmq `context` (creating
        # a fresh zmq.Context() per call leaks it: only the module-level one
        # is terminated via atexit), and always close the socket.
        zmq_socket = context.socket(zmq.REP)
        try:
            zmq_socket.bind(zmq_listen_result_addr)
            message = zmq_socket.recv_string()
            logger.success(f"Received workflow output for episode {episode_uuid}")
            zmq_socket.send_string("ack")
        finally:
            zmq_socket.close()
        return WorkflowOutput(**json.loads(message))

    def execute(self, workflow_task: WorkflowTask) -> BaseContextTracker:
        """Run one episode in tinkerscript mode and return its tracker.

        Sets up the context tracker and tuner, exposes the tuner as an
        OpenAI-compatible base_url/api_key, hands the episode to the remote
        TinkerScript client, and converts the returned reward into the
        tracker's reward bookkeeping.

        Raises:
            ValueError: If the remote workflow returned no reward (currently
                mandatory in tinkerscript mode).
        """
        observation_window = workflow_task.observation_window
        task_thread_index = workflow_task.task_thread_index

        hooks = self.runner_hooks(
            observation_window=observation_window,
            task_thread_index=task_thread_index,
            workflow_task=workflow_task,
        )
        context_tracker = MultiAgentContextTracker(
            llm_inference_fn=self.llm_inference_fn,
            tokenizer=self.tokenizer,
            config=self.config,
            workflow_task=workflow_task,
            **hooks,
        )
        tuner = AjetTuner(
            context_tracker=context_tracker,
            llm_inference_fn=self.llm_inference_fn,
            workflow_cls=Workflow,
            config=self.config,
        )

        # Expose the trainer as an OpenAI-compatible endpoint for the client.
        baseurl_apikey = tuner.as_oai_baseurl_apikey()
        base_url = baseurl_apikey.base_url
        api_key = baseurl_apikey.api_key

        workflow_output: WorkflowOutput = self.register_episode_and_wait_output(
            episode_uuid=context_tracker.episode_uuid,
            openai_base_url=base_url,
            openai_api_key=api_key,
        )

        if workflow_output.reward is not None:
            raw_reward, is_success = (
                workflow_output.reward,
                workflow_output.is_success,
            )
        else:
            raise ValueError("workflow_output.reward is None in TinkerScriptRunner, this is currently not allowed.")

        workflow_task.gym_env = None  # clear gym env client reference to avoid serialization issue

        assert not isinstance(
            raw_reward, list
        ), "AgentJet will support step reward in future versions."

        # register reward
        # TODO: support multi-step reward
        reward = Reward(
            raw_reward=raw_reward,
            raw_step_reward=None,  # "AgentJet will support step reward in future versions."
            success_rate=1.0 if is_success else 0.0,
            madness=0,
            description="",
        )
        context_tracker.process_reward(reward)
        # generate token before merging
        context_tracker.group_merge()
        # after merging, process and align reward again
        context_tracker.process_reward(reward)
        # mark the thread as ended
        observation_window["step"][task_thread_index] = -1
        tuner.terminate_episode()
        context_tracker.log_metrics = workflow_output.log_metrics
        return context_tracker

ajet/tuner_lib/weight_tuner/as_oai_baseurl_apikey.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,17 @@ class MockAsyncChat(AsyncChat):
1818
def completions(self) -> MockAsyncCompletions: # type: ignore
1919
return MockAsyncCompletions(self._client)
2020

21+
class OpenaiBaseUrlAndApiKey(BaseModel):
22+
""" At this layer, we will determine which model to use:
23+
- training model
24+
- debug model assigned by user, used when this target is not being trained
25+
"""
26+
27+
base_url: str = Field(default="http://localhost:27788/v1", description="The base URL for the Ajet's fake OpenAI API")
28+
api_key: str = Field(default="invalid_apikey", description="The Ajet's fake key, which is not a real key, it is a encoded string contain episode_uuid and other stuff.")
29+
model: str = Field(default="reserved_field", description="reserved field.")
30+
31+
2132
class OpenaiClientBaseUrlTuner(BaseModel):
2233
""" At this layer, we will determine which model to use:
2334
- training model
@@ -40,6 +51,9 @@ def __init__(
4051
):
4152

4253
port = os.getenv("AJET_DAT_INTERCHANGE_PORT")
54+
if config.ajet.interchange_server.interchange_server_port != 'auto':
55+
port = str(int(config.ajet.interchange_server.interchange_server_port))
56+
4357
assert port is not None, "AJET_DAT_INTERCHANGE_PORT env var must be set"
4458
master_node_ip = os.getenv("MASTER_NODE_IP", "localhost")
4559

ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from ajet.utils.thread_executors import SharedInferenceTrackerThreadExecutor, SharedInterchangeThreadExecutor
1717
from ajet.utils.networking import find_free_port
1818

19-
2019
context = zmq.Context()
2120
atexit.register(context.term)
2221

@@ -158,7 +157,7 @@ def _begin_service_threading(self):
158157
break
159158
timepassed = time.time() - begin_time
160159
if timepassed > 60:
161-
logger.warning(f"[client] {self.episode_uuid} | Still waiting for first message... (time passed {timepassed}) for episode_uuid:{self.episode_uuid}...")
160+
if DEBUG: logger.warning(f"[client] {self.episode_uuid} | Still waiting for first message... (time passed {timepassed}) for episode_uuid:{self.episode_uuid}...")
162161
continue
163162

164163
# parse the incoming request

0 commit comments

Comments
 (0)