Skip to content

Commit 57dc06b

Browse files
committed
Enhance AgentJet framework and update documentation
- Updated README.md to clarify AgentJet's capabilities and added swarm training instructions. - Refactored job.py to import ray only when needed. - Fixed file opening in doc_reader.py, tracing_reader, and native_parallel_worker.py to include UTF-8 encoding. - Improved as_oai_baseurl_apikey.py by using TYPE_CHECKING for conditional imports. - Moved generate_auth_token function to interchange_utils.py and added API_KEY_PREFIX. - Updated various file opening methods across the codebase to ensure UTF-8 encoding. - Added new dependencies (hydra-core, datasets) to pyproject.toml. - Modified deep_finance_judge.py to improve code readability and structure. - Updated math.py to change the remote training model path for consistency.
1 parent 484d1bc commit 57dc06b

17 files changed

Lines changed: 199 additions & 182 deletions

File tree

README.md

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,25 +12,31 @@
1212
</div>
1313

1414

15-
**AgentJet (AJet)** is a cutting-edge, user-friendly training framework designed to optimize agents and workflows (built with OpenAI SDK, AgentScope, Langchain, or just HTTP requests), fine-tuning language model weights behind the scenes.
15+
**AgentJet (AJet)** is a cutting-edge, user-friendly agent RL training framework designed to optimize agents and agentic workflows (supporting any agent built with OpenAI SDK, AgentScope, Langchain, or raw HTTP requests), fine-tuning LLM weights to enhance model performance.
1616

17-
Simply provide your agent **workflow**, training **dataset**, and **reward** function, and **AgentJet** will be ready to enhance your agents to their optimal performance!
17+
**AgentJet (AJet)** has fully-distributed **swarm training** capability, which means that you can **deploy `ajet-swarm start` in GPU server(s) and then start training agents in your laptop(s)**! Simply provide your agent workflow, training dataset, and reward function, and AgentJet will be ready to go!
1818

1919

2020

2121
## ✈️ Minimum Example
2222

23-
Let's begin with the simplest example: a math agent with a tool call.
23+
### Classic Mode
2424

25-
- First, please check out the [installation guide](https://modelscope.github.io/AgentJet/en/installation/) to set up the training environment.
26-
- Then, tune your first model using the minimum example.
27-
```python
28-
ajet --conf tutorial/example_math_agent/math_agent.yaml --backbone='verl'
25+
Let's begin with the simplest example: a math agent with a tool call. This is a simple & centralized training example.
2926

30-
# change to --backbone='trinity' if you want to switch to trinity training engine;
31-
# or --backbone='debug' if you want to debug with only vLLM
32-
```
27+
1. please check out the [installation guide](https://modelscope.github.io/AgentJet/en/installation/) to set up the training environment.
28+
2. tune your first model using the minimum example.
29+
```python
30+
ajet --conf ./tutorial/example_math_agent/math_agent.yaml --backbone='verl'
31+
```
32+
33+
### Swarm Mode
3334

35+
1. Start swarm server and begin swarm overwatch: `ajet-swarm start` and `ajet-swarm overwatch --swarm-url=http://localhost:10086`.
36+
2. From another device (or localhost), run [this script to train](https://github.com/modelscope/AgentJet/blob/main/tutorial/example_math_swarm/math.py):
37+
```python
38+
AJET_SWARM_URL="http://swarm-server-ip:10086" python ./tutorial/example_math_swarm/math.py
39+
```
3440

3541
## ✈️ Features
3642

ajet/copilot/job.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from types import SimpleNamespace
1313
from typing import Any, Callable, Union
1414

15-
import ray
1615
import yaml
1716
from loguru import logger
1817

@@ -138,6 +137,7 @@ def set_data(
138137
return self
139138

140139
def tune(self, *args, **kwargs) -> "AgentJetJob":
140+
import ray
141141
ast_cfg = self.config.ajet
142142
if not ast_cfg.rollout or not ast_cfg.rollout.user_workflow:
143143
raise ValueError("Workflow must be set via set_workflow before tuning.")

ajet/task_reader/document_reader/doc_reader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def _calculate_file_hash(self, file_path: str) -> str:
7272
"""Calculate SHA256 hash of a file."""
7373
try:
7474
hash_sha256 = hashlib.sha256()
75-
with open(file_path, "rb") as f:
75+
with open(file_path, "rb", encoding="utf-8") as f:
7676
for chunk in iter(lambda: f.read(4096), b""):
7777
hash_sha256.update(chunk)
7878
return hash_sha256.hexdigest()

ajet/task_reader/tracing_reader/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def _load_existing_tasks(self, path: str) -> List[Task]:
5353
if not os.path.exists(path):
5454
return []
5555
tasks: List[Task] = []
56-
with open(path, "r") as f:
56+
with open(path, "r", encoding="utf-8") as f:
5757
for line in f:
5858
line = line.strip()
5959
if not line:
@@ -66,7 +66,7 @@ def _append_tasks(self, path: str, tasks: List[Task]) -> None:
6666
if not tasks:
6767
return
6868
mode = "a" if os.path.exists(path) else "w"
69-
with open(path, mode) as f:
69+
with open(path, mode, encoding="utf-8") as f:
7070
for task in tasks:
7171
obj = task.model_dump()
7272
f.write(json.dumps(obj, ensure_ascii=False) + "\n")

ajet/task_rollout/native_parallel_worker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def _write_swarm_rollout_dynamic_log(self, observation_window):
8585
string_buffer = ""
8686
for info in observation_window["info"]:
8787
string_buffer += f"{info}\n"
88-
with open(fp, "w") as f:
88+
with open(fp, "w", encoding="utf-8") as f:
8989
f.write(string_buffer)
9090
return
9191

ajet/tuner_lib/as_oai_baseurl_apikey.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
import os
2-
from typing import Any
2+
from typing import Any, TYPE_CHECKING
33
from pydantic import BaseModel, Field
4-
from ajet.context_tracker.multiagent_tracking import (
5-
MultiAgentContextTracker,
6-
)
74
from openai.resources.chat.chat import AsyncChat
85
from openai.resources.completions import AsyncCompletions
9-
from .experimental.as_oai_model_client import generate_auth_token
6+
from ajet.tuner_lib.experimental.interchange_utils import generate_auth_token
107

8+
if TYPE_CHECKING:
9+
from ajet.context_tracker.multiagent_tracking import MultiAgentContextTracker
1110

1211
class MockAsyncCompletions(AsyncCompletions):
1312
async def create(self, *args, **kwargs) -> Any: # type: ignore
@@ -43,7 +42,7 @@ class OpenaiClientBaseUrlTuner(BaseModel):
4342
def __init__(
4443
self,
4544
config,
46-
context_tracker: MultiAgentContextTracker,
45+
context_tracker: "MultiAgentContextTracker",
4746
target_tag: str,
4847
agent_name: str,
4948
episode_uuid: str,

ajet/tuner_lib/experimental/as_oai_model_client.py

Lines changed: 5 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,14 @@
1010

1111
from loguru import logger
1212
from typing import TYPE_CHECKING
13-
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
1413
from openai.types.chat.chat_completion import ChatCompletion
15-
from ajet.tuner_lib.experimental.as_oai_model_server import InterchangeCompletionRequest, API_KEY_PREFIX
14+
from ajet.tuner_lib.experimental.as_oai_model_server import InterchangeCompletionRequest
1615
from ajet.utils.thread_executors import SharedInferenceTrackerThreadExecutor, SharedInterchangeThreadExecutor
1716
from ajet.tuner_lib.experimental.interchange_utils import get_zmq_socket
18-
from ajet.tuner_lib.experimental.interchange_utils import DEBUG
17+
from ajet.tuner_lib.experimental.interchange_utils import DEBUG, API_KEY_PREFIX
18+
19+
if TYPE_CHECKING:
20+
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
1921

2022
context = zmq.Context()
2123
atexit.register(context.term)
@@ -24,38 +26,6 @@
2426
from ajet.context_tracker.multiagent_tracking import MultiAgentContextTracker
2527

2628

27-
def generate_auth_token(agent_name, target_tag, episode_uuid, episode_address):
28-
"""
29-
Generate a Base64-encoded auth_token from the given agent_name, target_tag, and episode_uuid.
30-
31-
Args:
32-
agent_name (str): The name of the agent.
33-
target_tag (str): The target tag.
34-
episode_uuid (str): The UUID of the episode.
35-
36-
Returns:
37-
str: The generated auth_token in the format "Bearer <base64_encoded_string>".
38-
"""
39-
# Step 1: Construct the auth_data dictionary
40-
auth_data = {
41-
"agent_name": agent_name,
42-
"target_tag": target_tag,
43-
"episode_uuid": episode_uuid,
44-
"episode_address": episode_address,
45-
}
46-
47-
# Step 2: Convert the dictionary to a JSON string
48-
json_string = json.dumps(auth_data)
49-
50-
# Step 3: Encode the JSON string into Base64
51-
base64_encoded = base64.b64encode(json_string.encode('utf-8')).decode('utf-8')
52-
53-
# Step 4: Prepend "Bearer " to the Base64-encoded string
54-
auth_token = f"{API_KEY_PREFIX}{base64_encoded}" # API_KEY_PREFIX: Literal['sk-ajet-']
55-
56-
return auth_token
57-
58-
5929
class InterchangeClient:
6030
""" InterchangeClient is re-created in each episode
6131
"""

ajet/tuner_lib/experimental/as_oai_model_server.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,8 @@
3535

3636
from ajet.utils.networking import get_host_ip
3737
from ajet.tuner_lib.experimental.interchange_utils import EpisodeStatus
38-
from ajet.tuner_lib.experimental.interchange_utils import DEBUG, VERBOSE
38+
from ajet.tuner_lib.experimental.interchange_utils import DEBUG, VERBOSE, API_KEY_PREFIX
3939

40-
API_KEY_PREFIX = "sk-ajet-"
4140

4241
class InterchangeCompletionRequest(BaseModel):
4342
completion_request: ChatCompletionRequest

ajet/tuner_lib/experimental/as_swarm_server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ def _write_swarm_server_dynamic_log(shared_mem_dict):
243243
p = es.model_dump_json()
244244
string_buffer += f"{p}\n"
245245

246-
with open(fp, "w") as f:
246+
with open(fp, "w", encoding="utf-8") as f:
247247
f.write(string_buffer)
248248
return
249249

ajet/tuner_lib/experimental/interchange_utils.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import os
22
import time
33
import httpx
4+
import base64
5+
import json
6+
47
from typing import List
58
from pydantic import BaseModel, Field
69
from loguru import logger
@@ -19,6 +22,7 @@
1922
"ENGINE.WEIGHT_EXPORTING"
2023
]
2124

25+
API_KEY_PREFIX = "sk-ajet-"
2226

2327
class SyncTrainConfigRequest(BaseModel):
2428
yaml_as_string: str
@@ -205,3 +209,36 @@ def get_zmq_socket(config, episode_uuid: str, tag: str = ""):
205209
else:
206210
raise RuntimeError(f"Unknown interchange_method: {interchange_method}")
207211
return zmq_contect_address, ipc_path
212+
213+
214+
215+
def generate_auth_token(agent_name, target_tag, episode_uuid, episode_address):
216+
"""
217+
Generate a Base64-encoded auth_token from the given agent_name, target_tag, and episode_uuid.
218+
219+
Args:
220+
agent_name (str): The name of the agent.
221+
target_tag (str): The target tag.
222+
episode_uuid (str): The UUID of the episode.
223+
224+
Returns:
225+
str: The generated auth_token in the format "Bearer <base64_encoded_string>".
226+
"""
227+
# Step 1: Construct the auth_data dictionary
228+
auth_data = {
229+
"agent_name": agent_name,
230+
"target_tag": target_tag,
231+
"episode_uuid": episode_uuid,
232+
"episode_address": episode_address,
233+
}
234+
235+
# Step 2: Convert the dictionary to a JSON string
236+
json_string = json.dumps(auth_data)
237+
238+
# Step 3: Encode the JSON string into Base64
239+
base64_encoded = base64.b64encode(json_string.encode('utf-8')).decode('utf-8')
240+
241+
# Step 4: Prepend "Bearer " to the Base64-encoded string
242+
auth_token = f"{API_KEY_PREFIX}{base64_encoded}" # API_KEY_PREFIX: Literal['sk-ajet-']
243+
244+
return auth_token

0 commit comments

Comments
 (0)