Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions py/core/agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
R2RXMLToolsRAGAgent,
R2RXMLToolsStreamingRAGAgent,
)
from .rag_pyd import RAGPydAgent

# Import the concrete implementations
from .research import (
Expand All @@ -33,4 +34,6 @@
"R2RStreamingResearchAgent",
"R2RXMLToolsResearchAgent",
"R2RXMLToolsStreamingResearchAgent",
# Pydantic Agents
"RAGPydAgent",
]
124 changes: 124 additions & 0 deletions py/core/agent/rag_pyd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import logging
from typing import Callable, Optional

from pydantic_ai import Agent as PydanticAgent

from core.base import (
Message,
)
from core.base.abstractions import (
GenerationConfig,
SearchSettings,
)
from core.base.agent.tools.registry import ToolRegistry
from core.base.providers import DatabaseProvider
from core.providers import (
AnthropicCompletionProvider,
LiteLLMCompletionProvider,
OpenAICompletionProvider,
R2RCompletionProvider,
)

from ..base.agent.agent import RAGAgentConfig # type: ignore
from .rag import R2RRAGAgent # type: ignore

logger = logging.getLogger(__name__)

INSTRUCTIONS = """
You are a helpful agent that can search for information, the date is {date}.

If you have access to tools that help you set some filters, like smart_filter_tool, to narrow down and speed up the search, use them BEFORE the search
When asked a question, YOU SHOULD ALWAYS USE YOUR SEARCH TOOL TO ATTEMPT TO SEARCH FOR RELEVANT INFORMATION THAT ANSWERS THE USER QUESTION.

The response should contain line-item attributions to relevant search results, and be as informative if possible.

If no relevant results are found, then state that no results were found. If no obvious question is present, then do not carry out a search, and instead ask for clarification.

REMINDER - Use line item references to like [c910e2e], [b12cd2f], to refer to the specific search result IDs returned in the provided context.
"""


def pydantic_to_r2r_message(pydantic_response) -> list[Message]:
messages = []
logger.debug(f"Pydantic response: {pydantic_response}")
try:
all_messages = pydantic_response.all_messages
logger.debug(f"All messages: {all_messages}")
except Exception as e:
logger.error(f"Error getting all messages: {e}")
all_messages = []
if hasattr(pydantic_response, "output"):
messages.append(
Message(
role="assistant",
content=pydantic_response.output,
)
)
return messages


class RAGPydAgent(R2RRAGAgent):
def __init__(
self,
database_provider: DatabaseProvider,
llm_provider: (
AnthropicCompletionProvider
| LiteLLMCompletionProvider
| OpenAICompletionProvider
| R2RCompletionProvider
),
config: RAGAgentConfig,
search_settings: SearchSettings,
rag_generation_config: GenerationConfig,
knowledge_search_method: Callable,
content_method: Callable,
file_search_method: Callable,
tool_registry: Optional[ToolRegistry] = None,
max_tool_context_length: int = 20_000,
**kwargs,
):
super().__init__(
database_provider=database_provider,
llm_provider=llm_provider,
config=config,
search_settings=search_settings,
rag_generation_config=rag_generation_config,
knowledge_search_method=knowledge_search_method,
content_method=content_method,
file_search_method=file_search_method,
tool_registry=tool_registry,
max_tool_context_length=max_tool_context_length,
**kwargs,
)
# Init pydantic agent
self._pydantic_tools = [
tool._pydantic_ai_tool
for tool in getattr(self, "_tools", [])
if hasattr(tool, "_pydantic_ai_tool")
]
self._pydantic_agent = PydanticAgent(
model=self.get_pyd_ai_model_name(rag_generation_config.model),
tools=self._pydantic_tools,
name="R2R Pydantic Agent",
instructions=INSTRUCTIONS,
)

def get_pyd_ai_model_name(self, model_name: str | None):
logger.debug(f"Fetching model name from: {model_name}")
if not model_name:
raise ValueError("Model name is required")
if "gpt" in model_name:
# Initialize the model with our reverse proxy
# remove openai/ prefix if there is one
model_id = model_name.replace("openai/", "")
return model_id
else:
raise ValueError(f"Model {model_name} is not supported")

async def arun(self, messages: list[Message], **kwargs):
# logger.debug(f"Running pydantic agent with messages: {messages}")
message = messages[0].content
# logger.debug(f"Message: {message}")
py_response = await self._pydantic_agent.run(message)
r2r_response = pydantic_to_r2r_message(py_response)
return r2r_response
2 changes: 2 additions & 0 deletions py/core/agent/research.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def _register_research_tools(self):
# Add our research tools to whatever tools are already registered
research_tools = []
for tool_name in set(self.config.research_tools):
logger.debug(f"Registering research tool: {tool_name}")
if tool_name == "rag":
research_tools.append(self.rag_tool())
elif tool_name == "reasoning":
Expand Down Expand Up @@ -115,6 +116,7 @@ def rag_tool(self) -> Tool:
},
"required": ["query"],
},
context=self,
)

def reasoning_tool(self) -> Tool:
Expand Down
2 changes: 2 additions & 0 deletions py/core/base/abstractions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
HybridSearchSettings,
SearchMode,
SearchSettings,
SmartFilterResult,
WebPageSearchResult,
WebSearchResult,
select_search_filters,
Expand Down Expand Up @@ -119,6 +120,7 @@
"ChunkSearchSettings",
"ChunkSearchResult",
"WebPageSearchResult",
"SmartFilterResult",
"SearchSettings",
"select_search_filters",
"SearchMode",
Expand Down
13 changes: 12 additions & 1 deletion py/core/base/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,17 @@ async def handle_function_or_tool_call(
)
)
# HACK - to fix issues with claude thinking + tool use [https://github.com/anthropics/anthropic-cookbook/blob/main/extended_thinking/extended_thinking_with_tool_use.ipynb]
if self.rag_generation_config.extended_thinking:
logger.debug(
f"Extended thinking - Claude needs a particular message continuation which however breaks other models. Model in use : {self.rag_generation_config.model}"
)
is_anthropic = (
self.rag_generation_config.model
and "anthropic/" in self.rag_generation_config.model
)
if (
self.rag_generation_config.extended_thinking
and is_anthropic
):
await self.conversation.add_message(
Message(
role="user",
Expand Down Expand Up @@ -258,6 +268,7 @@ class RAGAgentConfig(AgentConfig):
"search_file_descriptions",
"search_file_knowledge",
"get_file_content",
"smart_filter_tool",
# Web search tools - disabled by default
# "web_search",
# "web_scrape",
Expand Down
10 changes: 10 additions & 0 deletions py/core/base/agent/tools/built_in/get_file_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from typing import Any, Optional
from uuid import UUID

from pydantic_ai import Tool as PydanticTool

from shared.abstractions.tool import Tool

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -38,6 +40,14 @@ def __init__(self):
results_function=self.execute,
llm_format_function=None,
)
pyd_params = self.parameters.copy()
pyd_params["additionalProperties"] = False
self._pydantic_ai_tool = PydanticTool.from_schema(
function=self.execute,
name=self.name,
description=self.description,
json_schema=pyd_params,
)

async def execute(
self,
Expand Down
13 changes: 13 additions & 0 deletions py/core/base/agent/tools/built_in/search_file_descriptions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import logging

from pydantic_ai import Tool as PydanticTool

from shared.abstractions.tool import Tool

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -31,11 +33,22 @@ def __init__(self):
results_function=self.execute,
llm_format_function=None,
)
pyd_params = self.parameters.copy()
pyd_params["additionalProperties"] = False
self._pydantic_ai_tool = PydanticTool.from_schema(
function=self.execute,
name=self.name,
description=self.description,
json_schema=pyd_params,
)

async def execute(self, query: str, *args, **kwargs):
"""
Calls the file_search_method from context.
"""
logger.debug(
f"Executing SearchFileDescriptionsTool with query: {query}"
)
from core.base.abstractions import AggregateSearchResult

context = self.context
Expand Down
11 changes: 11 additions & 0 deletions py/core/base/agent/tools/built_in/search_file_knowledge.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import logging

from pydantic_ai import Tool as PydanticTool

from shared.abstractions.tool import Tool

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -30,11 +32,20 @@ def __init__(self):
results_function=self.execute,
llm_format_function=None,
)
pyd_params = self.parameters.copy()
pyd_params["additionalProperties"] = False
self._pydantic_ai_tool = PydanticTool.from_schema(
function=self.execute,
name=self.name,
description=self.description,
json_schema=pyd_params,
)

async def execute(self, query: str, *args, **kwargs):
"""
Calls the knowledge_search_method from context.
"""
logger.debug(f"Executing SearchFileKnowledgeTool with query: {query}")
from core.base.abstractions import AggregateSearchResult

context = self.context
Expand Down
Loading
Loading