Skip to content

Commit 886d7c0

Browse files
authored
Merge pull request #2 from modelscope/feat/cc/langchain-demos
Add math and learn-to-ask langchain demos
2 parents cdecc4e + 0347c80 commit 886d7c0

5 files changed

Lines changed: 260 additions & 3 deletions

File tree

ajet/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
"WorkflowOutput",
1111
"AjetTuner",
1212
"AgentJetJob",
13-
"bp",
13+
"bp"
1414
]
1515

1616
__version__ = "0.1.0"

ajet/utils/launch_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def verify_python_env(args, exp_config):
144144
time.sleep(5)
145145
raise ImportError(cause + " " + solution)
146146
elif args.backbone == "verl":
147-
if not any([v in verl.__version__ for v in ["0.5.0.post", "0.7.0.post"]]): # you must install via `pip install -e .[verl]` to get every dependency right
147+
if not any([v in verl.__version__ for v in ["0.5.0.post", "0.5.0.dev", "0.7.0.post"]]): # you must install via `pip install -e .[verl]` to get every dependency right
148148
cause = "Python environment does not match current backbone 'verl'."
149149
solution = "Please `cd /path/to/project/AgentJet` and run `(uv) pip install -e .[verl]` to install the correct environment."
150150
print_dict(

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ dev = [
5151
"mypy>=1.7.0",
5252
"pytest>=8.0.0",
5353
"pytest-json-ctrf",
54+
"langchain>=1.2.3",
5455
]
5556

5657
reward = [
@@ -112,4 +113,4 @@ known_third_party = ["wandb"]
112113

113114

114115
[project.urls]
115-
"Homepage" = "https://github.com/modelscope/AgentJet"
116+
"Homepage" = "https://github.com/modelscope/AgentJet"
Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
2+
import re
3+
import time
4+
import asyncio
5+
import threading
6+
7+
from agentscope.message import Msg
8+
from loguru import logger
9+
10+
from ajet import AjetTuner, Workflow, WorkflowOutput, WorkflowTask
11+
from ajet.utils.robust_dashscope import RobustDashScopeChatModel
12+
13+
system_prompt = """# Task
14+
You are a medical assistant. Your task is to understand the ongoing conversation and continue the medical inquiry in English.
15+
16+
## Guidelines
17+
- Each response must contain exactly one clear and concise medical question with 2 to 3 answer choices.
18+
- Do not repeat any previous question.
19+
- Your response must be a single sentence.
20+
- If enough information has been gathered to make a medication suggestion, output only: <stop />
21+
"""
22+
23+
reward_prompt = """# Task
24+
You are an evaluation assistant. The user will provide a dialogue history between a doctor and a patient. You must analyze the dialogue and evaluate the doctor's last message.
25+
26+
# Grading Policy
27+
## Format Score
28+
- 1.0: The doctor's last message contains exactly **one question**.
29+
- 0.5: The doctor's last message contains **two questions**.
30+
- 0.0: The doctor's last message contains **three or more questions**.
31+
32+
## Content Score
33+
Reference Information contains the information that the doctor has not known.
34+
35+
- 1.0: The question(s) **directly ask about** item in the Reference Information.
36+
- 0.1: The question(s) are a general type of question that could be asked for any symptoms.
37+
- 0.0: The question(s) are **irrelevant** to all items in the Reference Information.
38+
39+
### You should
40+
41+
- ONLY if the doctor asks a question that helps to collect information and diagnose the patient, it is a good question.
42+
- A ambiguous question should get 0.
43+
- For example, the doctor asks "How long have you been feeling this way?", but "this way" is not clear in the previous messages.
44+
- For example, the doctor asks "Do you feel bad?". This is a meaningless question that does not provide any useful information.
45+
46+
# Reference Information
47+
48+
{}
49+
50+
# Output Format
51+
<think>Explain your reasoning for the format and content scores clearly and concisely.</think>
52+
<format_score>Insert only the format score as a float (e.g., 1.0, 0.5, 0.0)</format_score>
53+
<content_score>Insert only the content score as a float (e.g., 1.0, 0.5, 0.0)</content_score>
54+
55+
> ✅ Important:
56+
> - Output **exactly** the three tags shown above.
57+
> - Do **not** include any additional text, explanation, or formatting outside the tags.
58+
> - Scores must be based **only** on the doctor's **last message** and the provided Reference Information.
59+
> - Ensure clarity and precision in your evaluation reasoning within the `<think>` tag.
60+
"""
61+
62+
63+
llm = RobustDashScopeChatModel("qwen-plus", stream=False)
64+
65+
66+
async def llm_reward(init_messages: list[dict], response: str, truth_info: str):
67+
def format_messages(messages: list[dict]) -> str:
68+
result_str = ""
69+
for msg in messages:
70+
if msg["role"] == "user":
71+
result_str += f"patient: {msg['content']}\n"
72+
if msg["role"] == "assistant":
73+
result_str += f"doctor: {msg['content']}\n"
74+
return result_str
75+
76+
def parse_tag_string(text: str):
77+
pattern = r"<(\w+)>(.*?)</\1>"
78+
matches = re.findall(pattern, text)
79+
result = {}
80+
for tag, value in matches:
81+
result[tag] = value
82+
return result
83+
84+
history = format_messages([] + init_messages + [{"role": "assistant", "content": response}])
85+
messages = [
86+
{"role": "system", "content": reward_prompt.format(truth_info)},
87+
{"role": "user", "content": history},
88+
]
89+
90+
try_count, max_retries = 0, 5
91+
while try_count <= max_retries:
92+
try:
93+
94+
async def get_content():
95+
from agentscope.model import ChatResponse
96+
97+
response = await llm(messages)
98+
99+
if isinstance(response, ChatResponse):
100+
res = "".join([x["text"] for x in response.content if "text" in x])
101+
else:
102+
res = ""
103+
async for chunk in response:
104+
res += "".join([x["text"] for x in chunk.content if "text" in x])
105+
return res
106+
107+
content = await get_content()
108+
score_dict = parse_tag_string(content)
109+
return score_dict
110+
except Exception as e:
111+
if try_count > max_retries:
112+
logger.warning("retried too many times, abort task.")
113+
return None
114+
else:
115+
logger.warning(f"error: {e}, response:{response}, retrying...")
116+
time.sleep(2**try_count)
117+
118+
119+
async def reward_fn(init_messages: list[dict], response: str, truth_action: str, truth_info: str):
120+
"""
121+
content_score: R_a, the reward for response quality
122+
action_score: R_s, the reward for decision correctness
123+
format_score: P, the reward for response format
124+
"""
125+
126+
action_response = "stop" if "<stop />" in response else "continue"
127+
if truth_action == action_response:
128+
action_score = 1.0
129+
if truth_action == "continue":
130+
score_dict = await llm_reward(init_messages, response, truth_info)
131+
if score_dict is not None:
132+
format_score = float(score_dict.get("format_score", 0.0))
133+
content_score = float(score_dict.get("content_score", 0.0))
134+
else:
135+
format_score, content_score = 0.0, 0.0
136+
else:
137+
content_score = 1.0
138+
format_score = 1.0 if response == "<stop />" else 0.0
139+
else:
140+
action_score, format_score, content_score = 0.0, 0.0, 0.0
141+
142+
# treat as self.train_mode == "Ra+Rs", the default setting
143+
final_reward = action_score * (1 + 2 * content_score) + format_score
144+
145+
return final_reward
146+
147+
148+
_reward_semaphore = threading.Semaphore(16)
149+
150+
async def reward_fn_with_semaphore(*args, **kwargs):
151+
152+
get_sem_ok = False
153+
while not get_sem_ok:
154+
get_sem_ok = _reward_semaphore.acquire(blocking=False)
155+
if not get_sem_ok:
156+
await asyncio.sleep(1)
157+
158+
try:
159+
fn_result = await reward_fn(*args, **kwargs)
160+
finally:
161+
_reward_semaphore.release()
162+
163+
return fn_result
164+
165+
166+
class ExampleLearn2Ask(Workflow):
167+
name: str = "math_agent_workflow"
168+
169+
async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> WorkflowOutput:
170+
from langchain_openai import ChatOpenAI
171+
from langchain.agents import create_agent
172+
173+
messages = workflow_task.task.init_messages
174+
assert isinstance(messages, list)
175+
truth_action = workflow_task.task.metadata["decision_truth"] or "continue"
176+
truth_info = workflow_task.task.metadata["info_truth"]
177+
178+
llm_info=tuner.as_oai_baseurl_apikey()
179+
180+
llm=ChatOpenAI(
181+
base_url=llm_info.base_url,
182+
api_key=lambda:llm_info.api_key,
183+
)
184+
185+
agent=create_agent(
186+
model=llm,
187+
system_prompt=system_prompt,
188+
)
189+
190+
msg=[
191+
{"role": x["role"], "content": x["content"]} for x in messages
192+
]
193+
result = agent.invoke({
194+
"messages": msg, # type: ignore
195+
})
196+
197+
response = result["messages"][-1].content
198+
reward = await reward_fn_with_semaphore(msg, response, truth_action, truth_info)
199+
return WorkflowOutput(reward=reward)
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from loguru import logger
2+
from ajet import AjetTuner, Workflow, WorkflowOutput, WorkflowTask
3+
from openai.types.chat.chat_completion import ChatCompletion
4+
from openai.types.chat import ChatCompletionMessageToolCall
5+
from textwrap import dedent
6+
7+
import json
8+
import asyncio
9+
import requests
10+
from langchain.agents import create_agent
11+
12+
13+
# ------------------------------------------------------
14+
# Simple version - no tool call
15+
# ------------------------------------------------------
16+
17+
18+
class ExampleMathLearn(Workflow):
19+
20+
name: str = "math_agent_workflow"
21+
system_prompt: str = dedent("""
22+
You are an agent specialized in solving math problems.
23+
Please solve the math problem given to you.
24+
You can write and execute Python code to perform calculation or verify your answer.
25+
You should return your final answer within \\boxed{{}}.
26+
""")
27+
28+
async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> WorkflowOutput: # type: ignore
29+
# tuner to api key
30+
url_and_apikey = tuner.as_oai_baseurl_apikey()
31+
base_url = url_and_apikey.base_url
32+
api_key = url_and_apikey.api_key
33+
34+
from langchain_openai import ChatOpenAI
35+
llm=ChatOpenAI(
36+
base_url=base_url,
37+
api_key=lambda:api_key,
38+
)
39+
agent=create_agent(
40+
model=llm,
41+
system_prompt=self.system_prompt,
42+
)
43+
44+
# take out query
45+
query = workflow_task.task.main_query
46+
47+
response = agent.invoke({
48+
"messages": [
49+
{
50+
"role": "user",
51+
"content": query
52+
}
53+
],
54+
})
55+
56+
final_answer = response['messages'][-1].content
57+
return WorkflowOutput(reward=None, metadata={"final_answer": final_answer})

0 commit comments

Comments
 (0)