Skip to content

Commit 21f9bb8

Browse files
committed
improve tinkerscript
1 parent b2c70db commit 21f9bb8

11 files changed

Lines changed: 413 additions & 165 deletions

ajet/backbone/main_vllm.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,9 @@ def main(config):
189189
if config.ajet.enable_experimental_interchange_server:
190190
from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_server import start_interchange_server
191191
start_interchange_server(config)
192+
if config.ajet.enable_tinkerscript_mode:
193+
from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import http_change_engine_status
194+
http_change_engine_status(config, "ROLLING")
192195

193196
def companion_launch():
194197
import torch

ajet/backbone/trainer_verl.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,13 @@ def init_workers(self):
443443
tokenizer=self.tokenizer,
444444
)
445445

446+
def _update_interchange_server_status_flag(self, status: str):
447+
# if interchange server and tinkerscript mode are enabled, push the given engine status
448+
if self.config.ajet.enable_experimental_interchange_server:
449+
if self.config.ajet.enable_tinkerscript_mode:
450+
from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import http_change_engine_status
451+
http_change_engine_status(self.config, status)
452+
446453
# #######################################
447454
# training loop
448455
# #######################################
@@ -552,6 +559,7 @@ def fit(self): # noqa: C901
552559
assert self.async_rollout_mode
553560
logger.info("=== wake up begin ===")
554561
self.async_rollout_manager.wake_up()
562+
self._update_interchange_server_status_flag("ROLLING")
555563
logger.info("=== wake up end ===")
556564
tasks: List[Task] = [
557565
dict_to_ajet_task(dict(
@@ -577,6 +585,7 @@ def fit(self): # noqa: C901
577585
tasks, mode="sample", epoch=f"train.{epoch}"
578586
)
579587
logger.info("=" * 10 + "end fit rollout" + "=" * 10)
588+
self._update_interchange_server_status_flag("UPDATE_WEIGHT")
580589
logger.info("begin to convert context_tracker_arr to dataproto")
581590
gen_batch_output = self.parallel_env.to_dataproto(context_tracker_arr)
582591
logger.info("end conversion")

ajet/task_runner/tinkerscript_runner.py

Lines changed: 9 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11

22
import atexit
33
import json
4-
import requests
54
import zmq
65
import os
7-
import time
86
from ajet import AjetTuner
97
from ajet import WorkflowOutput
108
from ajet.context_tracker.multiagent_tracking import (
@@ -14,77 +12,38 @@
1412
from ajet.schema.task import WorkflowTask
1513
from ajet.schema.trajectory import Reward
1614
from ajet.task_runner.base_runner import BaseAgentRunner
17-
from ajet.utils.networking import find_free_port
15+
from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import http_register_episode, get_zmq_socket
1816
from loguru import logger
1917
from ajet import Workflow
2018

2119
context = zmq.Context()
2220
atexit.register(context.term)
21+
DEBUG = True
2322

2423
class TinkerScriptRunner(BaseAgentRunner):
2524

26-
def get_zmq_socket(self, episode_uuid: str):
27-
interchange_method = self.config.ajet.interchange_server.interchange_method
28-
if interchange_method == 'tcp':
29-
master_node_ip = os.getenv("MASTER_NODE_IP", "localhost")
30-
episode_contect_address = f"tcp://{master_node_ip}:{find_free_port()}"
31-
elif interchange_method == 'ipc':
32-
ipc_path = f"/tmp/ajet/{episode_uuid}-workflow.sock"
33-
episode_contect_address = f"ipc://{ipc_path}"
34-
else:
35-
raise RuntimeError(f"Unknown interchange_method: {interchange_method}")
36-
return episode_contect_address
37-
38-
39-
def get_interchange_server_url(self):
40-
port = os.getenv("AJET_DAT_INTERCHANGE_PORT")
41-
if self.config.ajet.interchange_server.interchange_server_port != 'auto':
42-
port = str(int(self.config.ajet.interchange_server.interchange_server_port))
43-
assert port is not None, "AJET_DAT_INTERCHANGE_PORT env var must be set"
44-
master_node_ip = os.getenv("MASTER_NODE_IP", "localhost")
45-
base_url = f"http://{master_node_ip}:{port}"
46-
return base_url
47-
48-
4925
def register_episode_and_wait_output(self, episode_uuid: str, openai_base_url: str, openai_api_key: str) -> WorkflowOutput:
5026
"""Register the episode as ready in the TinkerScript data interchange center."""
51-
from ajet.tuner_lib.weight_tuner.experimental.as_tinkerscript_server import RegisterEpisodeRequest
52-
5327
# register episode_uuid, openai_base_url, openai_api_key with the interchange server
54-
zmq_listen_result_addr = self.get_zmq_socket(episode_uuid)
55-
interchange_http_addr = self.get_interchange_server_url()
56-
rer = RegisterEpisodeRequest(
28+
zmq_listen_result_addr, ipc_path = get_zmq_socket(self.config, episode_uuid, tag="workflow")
29+
http_register_episode(
30+
self.config,
5731
episode_uuid=episode_uuid,
5832
openai_base_url=openai_base_url,
5933
openai_api_key=openai_api_key,
6034
zmq_listen_result_addr=zmq_listen_result_addr,
6135
)
62-
logger.info(f"zmq_listen_result_addr: {zmq_listen_result_addr}, interchange_http_addr: {interchange_http_addr}")
63-
64-
# send http request to tinkerscript server to register episode
65-
while True:
66-
try:
67-
response = requests.post(
68-
f"{interchange_http_addr}/register_episode",
69-
json=rer.model_dump(),  # Pydantic v2 serialization (use rer.dict() on Pydantic v1)
70-
timeout=30
71-
)
72-
response.raise_for_status()
73-
result = response.json()
74-
if not result.get('success'):
75-
raise RuntimeError(f"Failed to register episode {episode_uuid}")
76-
logger.info(f"Successfully registered episode {episode_uuid}")
77-
break
78-
except requests.RequestException as e:
79-
logger.error(f"Error registering episode {episode_uuid}: {e}. Retrying...")
80-
time.sleep(5)
36+
logger.info(f"zmq_listen_result_addr: {zmq_listen_result_addr}")
8137

8238
# begin wait for result
8339
zmq_socket = zmq.Context().socket(zmq.REP)
8440
zmq_socket.bind(zmq_listen_result_addr)
8541
message = zmq_socket.recv_string()
8642
logger.success(f"Received workflow output for episode {episode_uuid}")
8743
zmq_socket.send_string("ack")
44+
zmq_socket.close()
45+
if ipc_path and os.path.exists(ipc_path): os.remove(ipc_path)
46+
8847
return WorkflowOutput(**json.loads(message))
8948

9049

ajet/tuner_lib/weight_tuner/as_oai_baseurl_apikey.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class OpenaiBaseUrlAndApiKey(BaseModel):
2727
base_url: str = Field(default="http://localhost:27788/v1", description="The base URL for the Ajet's fake OpenAI API")
2828
api_key: str = Field(default="invalid_apikey", description="The Ajet's fake key, which is not a real key, it is a encoded string contain episode_uuid and other stuff.")
2929
model: str = Field(default="reserved_field", description="reserved field.")
30+
episode_uuid: str = Field(default="episode_id", description="reserved field.")
3031

3132

3233
class OpenaiClientBaseUrlTuner(BaseModel):

ajet/tuner_lib/weight_tuner/experimental/as_oai_model_client.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from openai.types.chat.chat_completion import ChatCompletion
1515
from ajet.tuner_lib.weight_tuner.experimental.as_oai_model_server import InterchangeCompletionRequest, API_KEY_PREFIX
1616
from ajet.utils.thread_executors import SharedInferenceTrackerThreadExecutor, SharedInterchangeThreadExecutor
17-
from ajet.utils.networking import find_free_port
17+
from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import get_zmq_socket
1818

1919
context = zmq.Context()
2020
atexit.register(context.term)
@@ -67,17 +67,11 @@ def __init__(self, episode_uuid: str, context_tracker: "MultiAgentContextTracker
6767
self.llm_inference_fn = llm_inference_fn
6868
self.config = config
6969
self._should_terminate = False
70-
70+
self.episode_contect_address, ipc_path = get_zmq_socket(config, episode_uuid, tag="llm")
71+
self.ipc_path = ipc_path
7172
self.interchange_method = config.ajet.interchange_server.interchange_method
72-
if self.interchange_method == 'tcp':
73-
master_node_ip = os.getenv("MASTER_NODE_IP", "localhost")
74-
self.episode_contect_address = f"tcp://{master_node_ip}:{find_free_port()}"
75-
elif self.interchange_method == 'ipc':
76-
self.ipc_path = f"/tmp/ajet/{self.episode_uuid}.sock"
77-
self.episode_contect_address = f"ipc://{self.ipc_path}"
7873
self.max_inference_tracker_threads = config.ajet.interchange_server.max_inference_tracker_threads
7974

80-
8175
async def llm_infer(
8276
self,
8377
req: ChatCompletionRequest,

ajet/tuner_lib/weight_tuner/experimental/as_oai_model_server.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
3434
from openai.types.chat.chat_completion import ChatCompletion
3535

36+
from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import EpisodeStatus
37+
3638
API_KEY_PREFIX = "sk-ajet-"
3739

3840
class InterchangeCompletionRequest(BaseModel):
@@ -151,6 +153,21 @@ async def chat_completions(request: Request, authorization: str = Header(None)):
151153
# Create timeline UUID
152154
timeline_uuid = uuid.uuid4().hex
153155

156+
# enable_tinkerscript_mode
157+
if enable_tinkerscript_mode:
158+
assert shared_mem_dict is not None
159+
assert shared_mem_dict_lock is not None
160+
if shared_mem_dict['engine_status'] != "ROLLING":
161+
logger.error(f"The server is not in ROLLING status (current status: [{shared_mem_dict['engine_status']}]), cannot accept new requests.")
162+
raise HTTPException(status_code=503, detail="The server is not in ROLLING status, cannot accept new requests.")
163+
if (f"episodes-{episode_uuid}") not in shared_mem_dict:
164+
raise HTTPException(status_code=404, detail=f"Episode {episode_uuid} not found.")
165+
# update activate timestamp
166+
with shared_mem_dict_lock:
167+
es:EpisodeStatus = shared_mem_dict[f"episodes-{episode_uuid}"]
168+
es.latest_activity_timestamp = time.time()
169+
shared_mem_dict[f"episodes-{episode_uuid}"] = es
170+
154171
# Add to received queue
155172
int_req = InterchangeCompletionRequest(
156173
completion_request = new_req,

ajet/tuner_lib/weight_tuner/experimental/as_tinkerscript_client.py

Lines changed: 77 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2,47 +2,32 @@
22
import time
33
import httpx
44
import yaml
5+
from typing import List, Tuple
56
from loguru import logger
6-
from pydantic import BaseModel
77
from ajet.schema.task import WorkflowOutput
88
from ajet.copilot.job import AgentJetJob
99
from ajet.tuner_lib.weight_tuner.as_oai_baseurl_apikey import OpenaiBaseUrlAndApiKey
10+
from ajet.tuner_lib.weight_tuner.experimental.interchange_utils import (
11+
SyncTrainConfigRequest,
12+
ClaimEpisodeRequest,
13+
ClaimEpisodeResponse,
14+
CanContinueEpisodeRequest,
15+
CanContinueEpisodeResponse,
16+
EndEpisodeRequest,
17+
EndEpisodeResponse,
18+
EpisodeStatus,
19+
EpisodeBufferResponse,
20+
)
1021

11-
# --- Schema Definitions ---
12-
13-
class SyncTrainConfigRequest(BaseModel):
14-
yaml_as_string: str
15-
16-
class ClaimEpisodeRequest(BaseModel):
17-
client_uuid: str
18-
episode_type: str
19-
20-
class ClaimEpisodeResponse(BaseModel):
21-
success: bool
22-
client_uuid: str
23-
episode_uuid: str
24-
openai_base_url: str = ""
25-
openai_api_key: str = ""
26-
fail_cause: str = ""
27-
28-
class EndEpisodeRequest(BaseModel):
29-
client_uuid: str
30-
episode_uuid: str
31-
workflow_output: WorkflowOutput
32-
33-
class EndEpisodeResponse(BaseModel):
34-
success: bool
3522

3623
class TinkerScriptClient(object):
3724

3825
def __init__(self, server_url: str):
3926
self.server_url = server_url
4027
self.client_uuid = str(uuid.uuid4())
41-
self.episode_uuid = None
42-
self.openai_base_url = None
43-
self.openai_api_key = None
4428

45-
def begin_episode(self) -> OpenaiBaseUrlAndApiKey:
29+
30+
def begin_episode(self, allow_discard_timeout=60) -> Tuple[str, OpenaiBaseUrlAndApiKey]:
4631
"""
4732
Block until an episode is claimed.
4833
Return (episode_uuid, openai_base_url, openai_api_key)
@@ -51,7 +36,8 @@ def begin_episode(self) -> OpenaiBaseUrlAndApiKey:
5136
try:
5237
req_obj = ClaimEpisodeRequest(
5338
client_uuid=self.client_uuid,
54-
episode_type="default"
39+
episode_type="default",
40+
allow_discard_timeout=allow_discard_timeout,
5541
)
5642
resp = httpx.post(
5743
f"{self.server_url}/claim_episode",
@@ -60,15 +46,17 @@ def begin_episode(self) -> OpenaiBaseUrlAndApiKey:
6046
)
6147
resp.raise_for_status()
6248
data = ClaimEpisodeResponse.model_validate(resp.json())
49+
episode_uuid = data.episode_uuid
6350

6451
if data.success:
65-
self.episode_uuid = data.episode_uuid
66-
self.openai_base_url = data.openai_base_url
67-
self.openai_api_key = data.openai_api_key
68-
logger.info(f"Claimed episode {self.episode_uuid}")
69-
return OpenaiBaseUrlAndApiKey(
70-
base_url=self.openai_base_url,
71-
api_key=self.openai_api_key,
52+
episode_uuid = data.episode_uuid
53+
openai_base_url = data.openai_base_url
54+
openai_api_key = data.openai_api_key
55+
logger.info(f"Claimed episode {episode_uuid}")
56+
return episode_uuid, OpenaiBaseUrlAndApiKey(
57+
base_url=openai_base_url,
58+
api_key=openai_api_key,
59+
episode_uuid=episode_uuid
7260
)
7361
else:
7462
logger.info(f"Failed to claim episode: {data.fail_cause}. Retrying in 5s...")
@@ -77,15 +65,15 @@ def begin_episode(self) -> OpenaiBaseUrlAndApiKey:
7765
logger.error(f"Error claiming episode: {e}. Retrying in 5s...")
7866
time.sleep(5)
7967

80-
def end_episode(self, workflow_output: WorkflowOutput):
81-
if not self.episode_uuid:
68+
def end_episode(self, episode_uuid: str, workflow_output: WorkflowOutput):
69+
if not episode_uuid:
8270
logger.error("No episode to end.")
8371
return
8472

8573
try:
8674
req_obj = EndEpisodeRequest(
8775
client_uuid=self.client_uuid,
88-
episode_uuid=self.episode_uuid,
76+
episode_uuid=episode_uuid,
8977
workflow_output=workflow_output
9078
)
9179

@@ -98,10 +86,9 @@ def end_episode(self, workflow_output: WorkflowOutput):
9886
data = EndEpisodeResponse.model_validate(resp.json())
9987

10088
if data.success:
101-
logger.info(f"Ended episode {self.episode_uuid}")
102-
self.episode_uuid = None
89+
logger.info(f"Ended episode {episode_uuid}")
10390
else:
104-
logger.error(f"Failed to end episode {self.episode_uuid}")
91+
logger.error(f"Failed to end episode {episode_uuid}")
10592

10693
except Exception as e:
10794
logger.error(f"Error ending episode: {e}")
@@ -122,3 +109,50 @@ def sync_train_config(self, agent_jet_job: AgentJetJob):
122109
logger.info("Synced train config")
123110
except Exception as e:
124111
logger.error(f"Error syncing train config: {e}")
112+
113+
def get_engine_status(self) -> str:
114+
try:
115+
resp = httpx.get(
116+
f"{self.server_url}/get_engine_status",
117+
timeout=10
118+
)
119+
resp.raise_for_status()
120+
return resp.json().get("engine_status", "unknown")
121+
except Exception as e:
122+
logger.error(f"Error getting engine status: {e}")
123+
return "unknown"
124+
125+
def can_continue_episode(self, episode_uuid: str) -> bool:
126+
if not episode_uuid:
127+
return False
128+
129+
try:
130+
req_obj = CanContinueEpisodeRequest(
131+
client_uuid=self.client_uuid,
132+
episode_uuid=episode_uuid
133+
)
134+
resp = httpx.post(
135+
f"{self.server_url}/can_continue_episode",
136+
json=req_obj.model_dump(),
137+
timeout=10
138+
)
139+
resp.raise_for_status()
140+
data = CanContinueEpisodeResponse.model_validate(resp.json())
141+
return data.can_continue
142+
except Exception as e:
143+
logger.error(f"Error checking can_continue_episode: {e}")
144+
return False
145+
146+
def get_episode_buffer(self) -> List[EpisodeStatus]:
147+
try:
148+
resp = httpx.post(
149+
f"{self.server_url}/get_episode_buffer",
150+
json={},
151+
timeout=10
152+
)
153+
resp.raise_for_status()
154+
data = EpisodeBufferResponse.model_validate(resp.json())
155+
return data.buffer
156+
except Exception as e:
157+
logger.error(f"Error getting episode buffer: {e}")
158+
return []

0 commit comments

Comments
 (0)