Skip to content

Commit a66a182

Browse files
committed
Add service URL attribute to BaseGymEnv and implement DeepResearch workflow example
1 parent f6c5453 commit a66a182

2 files changed

Lines changed: 75 additions & 0 deletions

File tree

ajet/task_rollout/resource_keeper.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ def __init__(
150150
self.task_thread_index = task_thread_index
151151
self.observation_window = observation_window
152152
self.episode_uuid = episode_uuid
153+
self.service_url = self.env_client.base_url
153154

154155
def step(self, action: dict) -> Tuple[str, float, bool, dict]:
155156
"""Take a step in the gym environment."""
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
from typing import List
2+
from loguru import logger
3+
from pydantic import BaseModel, Field
4+
from ajet import AjetTuner, Workflow, WorkflowOutput, WorkflowTask
5+
from openai.types.chat.chat_completion import ChatCompletion
6+
from openai.types.chat import ChatCompletionMessageToolCall
7+
from textwrap import dedent
8+
9+
import json
10+
import os
11+
import asyncio
12+
import requests
13+
14+
15+
# ------------------------------------------------------
16+
# Simple version - no tool call
17+
# ------------------------------------------------------
18+
19+
20+
class DeepResearchInputSchema(BaseModel):
21+
base_url: str = Field(default="", description="The base URL of the OpenAI-compatible API.")
22+
api_key: str = Field(default="", description="The API key for authentication.")
23+
init_messages: List[dict] = Field(default=[], description="The initial messages for the deep research task.")
24+
task_id: str = Field(default="", description="The unique identifier for the research task.")
25+
main_query: str = Field(default="", description="The main query for the research task.")
26+
max_steps: int = Field(default=20, description="The maximum number of steps for the research task.")
27+
env_service_url: str = Field(default="", description="The URL of the environment service.")
28+
29+
30+
class ExampleMaDeepResearch(Workflow):
31+
name: str = "multiagent_deep_research_workflow"
32+
33+
async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> WorkflowOutput: # type: ignore
34+
# Extract base URL and API key from the tuner
35+
url_and_apikey = tuner.as_oai_baseurl_apikey()
36+
base_url = url_and_apikey.base_url
37+
api_key = url_and_apikey.api_key
38+
init_messages = workflow_task.task.init_messages
39+
40+
# Get the AGENT_SERVER_URL from environment variables or use a default value
41+
agent_server_url = os.getenv("AGENT_SERVER_URL", "http://localhost:8000")
42+
43+
# Prepare the payload using DeepResearchInputSchema
44+
payload = DeepResearchInputSchema(
45+
base_url=base_url,
46+
api_key=api_key,
47+
init_messages=init_messages,
48+
task_id=workflow_task.task.task_id,
49+
main_query=workflow_task.task.main_query,
50+
max_steps=tuner.config.astune.rollout.multi_turn.max_steps,
51+
env_service_url=workflow_task.gym_env.service_url,
52+
)
53+
54+
try:
55+
# Send the HTTP POST request to the AGENT_SERVER_URL
56+
headers = {
57+
"Content-Type": "application/json",
58+
}
59+
60+
response = requests.post(
61+
agent_server_url,
62+
headers=headers,
63+
data=payload.model_dump(),
64+
)
65+
66+
# Check if the request was successful
67+
if response.status_code == 200:
68+
result_data = response.json()
69+
logger.info(f"Successfully received response: {result_data}")
70+
result = WorkflowOutput(**result_data)
71+
return result
72+
73+
except Exception as e:
74+
logger.error(f"An error occurred while sending the request: {e}")

0 commit comments

Comments
 (0)