|
| 1 | +from typing import List |
| 2 | +from loguru import logger |
| 3 | +from pydantic import BaseModel, Field |
| 4 | +from ajet import AjetTuner, Workflow, WorkflowOutput, WorkflowTask |
| 5 | +from openai.types.chat.chat_completion import ChatCompletion |
| 6 | +from openai.types.chat import ChatCompletionMessageToolCall |
| 7 | +from textwrap import dedent |
| 8 | + |
| 9 | +import json |
| 10 | +import os |
| 11 | +import asyncio |
| 12 | +import requests |
| 13 | + |
| 14 | + |
| 15 | +# ------------------------------------------------------ |
| 16 | +# Simple version - no tool call |
| 17 | +# ------------------------------------------------------ |
| 18 | + |
| 19 | + |
| 20 | +class DeepResearchInputSchema(BaseModel): |
| 21 | + base_url: str = Field(default="", description="The base URL of the OpenAI-compatible API.") |
| 22 | + api_key: str = Field(default="", description="The API key for authentication.") |
| 23 | + init_messages: List[dict] = Field(default=[], description="The initial messages for the deep research task.") |
| 24 | + task_id: str = Field(default="", description="The unique identifier for the research task.") |
| 25 | + main_query: str = Field(default="", description="The main query for the research task.") |
| 26 | + max_steps: int = Field(default=20, description="The maximum number of steps for the research task.") |
| 27 | + env_service_url: str = Field(default="", description="The URL of the environment service.") |
| 28 | + |
| 29 | + |
| 30 | +class ExampleMaDeepResearch(Workflow): |
| 31 | + name: str = "multiagent_deep_research_workflow" |
| 32 | + |
| 33 | + async def execute(self, workflow_task: WorkflowTask, tuner: AjetTuner) -> WorkflowOutput: # type: ignore |
| 34 | + # Extract base URL and API key from the tuner |
| 35 | + url_and_apikey = tuner.as_oai_baseurl_apikey() |
| 36 | + base_url = url_and_apikey.base_url |
| 37 | + api_key = url_and_apikey.api_key |
| 38 | + init_messages = workflow_task.task.init_messages |
| 39 | + |
| 40 | + # Get the AGENT_SERVER_URL from environment variables or use a default value |
| 41 | + agent_server_url = os.getenv("AGENT_SERVER_URL", "http://localhost:8000") |
| 42 | + |
| 43 | + # Prepare the payload using DeepResearchInputSchema |
| 44 | + payload = DeepResearchInputSchema( |
| 45 | + base_url=base_url, |
| 46 | + api_key=api_key, |
| 47 | + init_messages=init_messages, |
| 48 | + task_id=workflow_task.task.task_id, |
| 49 | + main_query=workflow_task.task.main_query, |
| 50 | + max_steps=tuner.config.astune.rollout.multi_turn.max_steps, |
| 51 | + env_service_url=workflow_task.gym_env.service_url, |
| 52 | + ) |
| 53 | + |
| 54 | + try: |
| 55 | + # Send the HTTP POST request to the AGENT_SERVER_URL |
| 56 | + headers = { |
| 57 | + "Content-Type": "application/json", |
| 58 | + } |
| 59 | + |
| 60 | + response = requests.post( |
| 61 | + agent_server_url, |
| 62 | + headers=headers, |
| 63 | + data=payload.model_dump(), |
| 64 | + ) |
| 65 | + |
| 66 | + # Check if the request was successful |
| 67 | + if response.status_code == 200: |
| 68 | + result_data = response.json() |
| 69 | + logger.info(f"Successfully received response: {result_data}") |
| 70 | + result = WorkflowOutput(**result_data) |
| 71 | + return result |
| 72 | + |
| 73 | + except Exception as e: |
| 74 | + logger.error(f"An error occurred while sending the request: {e}") |
0 commit comments