From 7272f95e9cdf550c440b2eb5c4a8252ba396bd45 Mon Sep 17 00:00:00 2001 From: Andrea Date: Thu, 5 Jun 2025 16:43:57 +0200 Subject: [PATCH 1/4] Enhance ResearchAgentMixin with context parameter for rag tool and skip extended thinking message in Agent class for non Anthropic models --- py/core/agent/research.py | 1 + py/core/base/agent/agent.py | 12 +++++++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/py/core/agent/research.py b/py/core/agent/research.py index f0e7f8cf5..2cf94a609 100644 --- a/py/core/agent/research.py +++ b/py/core/agent/research.py @@ -115,6 +115,7 @@ def rag_tool(self) -> Tool: }, "required": ["query"], }, + context=self, ) def reasoning_tool(self) -> Tool: diff --git a/py/core/base/agent/agent.py b/py/core/base/agent/agent.py index 199e18def..f25452b49 100644 --- a/py/core/base/agent/agent.py +++ b/py/core/base/agent/agent.py @@ -227,7 +227,17 @@ async def handle_function_or_tool_call( ) ) # HACK - to fix issues with claude thinking + tool use [https://github.com/anthropics/anthropic-cookbook/blob/main/extended_thinking/extended_thinking_with_tool_use.ipynb] - if self.rag_generation_config.extended_thinking: + logger.debug( + f"Extended thinking - Claude needs a particular message continuation which however breaks other models. Model in use : {self.rag_generation_config.model}" + ) + is_anthropic = ( + self.rag_generation_config.model + and "anthropic/" in self.rag_generation_config.model + ) + if ( + self.rag_generation_config.extended_thinking + and is_anthropic + ): await self.conversation.add_message( Message( role="user", From 3134e0101b83ae1c376e803b05707e285622d233 Mon Sep 17 00:00:00 2001 From: nPeppon Date: Fri, 6 Jun 2025 14:06:50 +0200 Subject: [PATCH 2/4] Add SmartFilterTool for enhanced metadata filtering in RAG searches - Introduced SmartFilterTool to refine metadata and collection filters based on user queries. - Updated relevant files to register the new tool and include it in the available tools list. - Enhanced documentation and prompts to reflect the use of smart_filter_tool for improved search efficiency. --- py/core/agent/research.py | 1 + py/core/base/abstractions/__init__.py | 2 + py/core/base/agent/agent.py | 1 + .../base/agent/tools/built_in/smart_filter.py | 180 ++++++++++++++++++ py/core/main/api/v3/retrieval_router.py | 4 +- .../database/prompts/dynamic_rag_agent.yaml | 2 + .../dynamic_rag_agent_xml_tooling.yaml | 1 + .../database/prompts/static_rag_agent.yaml | 1 + .../prompts/static_research_agent.yaml | 4 + py/sdk/sync_methods/retrieval.py | 2 +- py/shared/abstractions/search.py | 65 +++++++ 11 files changed, 261 insertions(+), 2 deletions(-) create mode 100644 py/core/base/agent/tools/built_in/smart_filter.py diff --git a/py/core/agent/research.py b/py/core/agent/research.py index 2cf94a609..61f0ece3e 100644 --- a/py/core/agent/research.py +++ b/py/core/agent/research.py @@ -79,6 +79,7 @@ def _register_research_tools(self): # Add our research tools to whatever tools are already registered research_tools = [] for tool_name in set(self.config.research_tools): + logger.debug(f"Registering research tool: {tool_name}") if tool_name == "rag": research_tools.append(self.rag_tool()) elif tool_name == "reasoning": diff --git a/py/core/base/abstractions/__init__.py b/py/core/base/abstractions/__init__.py index 2fd2729dc..0140ad1ce 100644 --- a/py/core/base/abstractions/__init__.py +++ b/py/core/base/abstractions/__init__.py @@ -49,6 +49,7 @@ HybridSearchSettings, SearchMode, SearchSettings, + SmartFilterResult, WebPageSearchResult, WebSearchResult, select_search_filters, @@ -119,6 +120,7 @@ "ChunkSearchSettings", "ChunkSearchResult", "WebPageSearchResult", + "SmartFilterResult", "SearchSettings", "select_search_filters", "SearchMode", diff --git a/py/core/base/agent/agent.py b/py/core/base/agent/agent.py index f25452b49..a01bd23cb 100644 --- a/py/core/base/agent/agent.py +++ b/py/core/base/agent/agent.py @@ -268,6 +268,7 @@ class RAGAgentConfig(AgentConfig): "search_file_descriptions", "search_file_knowledge", "get_file_content", + "smart_filter_tool", # Web search tools - disabled by default # "web_search", # "web_scrape", diff --git a/py/core/base/agent/tools/built_in/smart_filter.py b/py/core/base/agent/tools/built_in/smart_filter.py new file mode 100644 index 000000000..2720eb98d --- /dev/null +++ b/py/core/base/agent/tools/built_in/smart_filter.py @@ -0,0 +1,180 @@ +import logging +from typing import Any + +from shared.abstractions.tool import Tool + +logger = logging.getLogger(__name__) + + +class SmartFilterTool(Tool): + """ + A tool to refine metadata and collection filters for a RAG search using LLM analysis. + This tool does NOT perform the RAG search itself, only returns the refined filters and prompt. + """ + + def __init__(self): + super().__init__( + name="smart_filter_tool", + description=( + "Analyzes the user query and available collections, then returns refined collection IDs, " + "metadata filters, and a possibly modified prompt for downstream RAG search. " + "Does NOT perform the search itself." + "The tool is editing the search settings internally, so no need to use its results afterwards." + ), + parameters={ + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The user query to analyze.", + }, + }, + "required": ["query"], + }, + results_function=self.execute, + llm_format_function=None, + ) + + async def execute(self, query: str, *args, **kwargs): + """ + Uses the LLM to analyze the query and available collections, returning collection IDs, filters, and prompt_mod. + """ + + from core.base.abstractions import ( + AggregateSearchResult, + SmartFilterResult, + ) + + context = self.context + if ( + not context + or not hasattr(context, "database_provider") + or not hasattr(context, "config") + or not hasattr(context, "rag_generation_config") + ): + logger.error( + "Context missing database_provider or config or rag_generation_config for SmartRagTool" + ) + return {"collections": [], "filters": {}, "prompt_mod": query} + + if not hasattr(context, "search_settings"): + logger.error("Context missing search_settings for SmartRagTool") + return {"collections": [], "filters": {}, "prompt_mod": query} + + try: + collections_overview = await context.database_provider.collections_handler.get_collections_overview( + offset=0, limit=50 + ) + collections_brief = [ + { + "id": str(c.id), + "name": c.name, + "description": getattr(c, "description", ""), + } + for c in collections_overview.get("results", []) + ] + except Exception as e: + logger.error(f"Error fetching collections: {e}") + return {"collections": [], "filters": {}, "prompt_mod": query} + + collections_str = "\n".join( + [ + f"- {c['name']} (ID: {c['id']}): {c['description']}" + for c in collections_brief + ] + ) + llm_prompt = ( + f"Here are the available collections with data the user might be interested in (with IDs and descriptions):\n{collections_str}\n\n" + f'User query: "{query}"\n\n' + "Please return a JSON with:\n" + "- 'collections': [list of relevant collection IDs as strings]\n" + "- 'filters': (optional) metadata filters\n" + "- 'prompt_mod': (optional) a modified prompt for the RAG search" + "We will use this new params to query the database (RAG search) and filter the collections." + ) + model = context.rag_generation_config.model + try: + GenerationConfig = __import__( + "core.base.abstractions", fromlist=["GenerationConfig"] + ).GenerationConfig + gen_cfg = GenerationConfig( + model=model, + max_tokens_to_sample=1024, + temperature=0.0, + stream=False, + tools=None, + functions=None, + ) + response = await context.llm_provider.aget_completion( + [{"role": "user", "content": llm_prompt}], gen_cfg + ) + llm_response = response.choices[0].message.content + import json + + try: + result = json.loads(llm_response) + logger.debug(f"SmartFilter raw response:\n{result}") + filtered_collections_ids = result.get("collections", []) + # filters = result.get("filters", {}) + if len(filtered_collections_ids) > 0 and hasattr( + context, "search_settings" + ): + new_filters = self.merge_filters( + context.search_settings.filters, + filtered_collections_ids, + collections_brief, + ) + logger.debug(f"SmartFilter output Filters:\n{new_filters}") + context.search_settings.filters = new_filters + + smart_filter_result = SmartFilterResult( + collections=filtered_collections_ids, + filters=new_filters, + prompt_mod=query, + ) + result = AggregateSearchResult( + smart_filter_result=smart_filter_result + ) + return result + except Exception: + logger.error(f"LLM did not return valid JSON: {llm_response}") + return {"collections": [], "filters": {}, "prompt_mod": query} + except Exception as e: + logger.error(f"Error in SmartRagTool LLM analysis: {e}") + return {"collections": [], "filters": {}, "prompt_mod": query} + + def merge_filters( + self, + existing_filters: dict[str, Any], + collections_ids_to_filter_on: list[str], + collections_brief: list[dict[str, Any]], + ): + from uuid import UUID + + new_collection_filter = { + "collection_ids": { + "$in": [UUID(c) for c in collections_ids_to_filter_on] + } + } + new_category_metadata_filter = { + "metadata.category": { + "$in": [ + c["name"] + for c in collections_brief + if c["id"] in collections_ids_to_filter_on + ] + } + } + new_filters = [new_collection_filter, new_category_metadata_filter] + if not existing_filters: + return {"$and": new_filters} + if not collections_ids_to_filter_on: + return existing_filters + # If existing is already an $and, append new filters + if "$and" in existing_filters: + # Avoid duplicating filters if already present + combined = existing_filters["$and"] + new_filters + return {"$and": combined} + else: + # Combine the single filter with the new filters + return {"$and": [existing_filters] + new_filters} diff --git a/py/core/main/api/v3/retrieval_router.py b/py/core/main/api/v3/retrieval_router.py index 49f5c6ca2..6d9850597 100644 --- a/py/core/main/api/v3/retrieval_router.py +++ b/py/core/main/api/v3/retrieval_router.py @@ -580,11 +580,12 @@ async def agent_app( "search_file_descriptions", "search_file_knowledge", "get_file_content", + "smart_filter_tool", ] ] ] = Body( None, - description="List of tools to enable for RAG mode. Available tools: search_file_knowledge, get_file_content, web_search, web_scrape, search_file_descriptions", + description="List of tools to enable for RAG mode. Available tools: search_file_knowledge, get_file_content, web_search, web_scrape, search_file_descriptions, smart_filter_tool", ), # FIXME: We need a more generic way to handle this research_tools: Optional[ @@ -659,6 +660,7 @@ async def agent_app( - `content`: Fetch entire documents or chunk structures - `web_search`: Query external search APIs for up-to-date information - `web_scrape`: Scrape and extract content from specific web pages + - `smart_filter_tool`: Use a smart filter to narrow down the search results **Research Tools:** - `rag`: Leverage the underlying RAG agent for information retrieval diff --git a/py/core/providers/database/prompts/dynamic_rag_agent.yaml b/py/core/providers/database/prompts/dynamic_rag_agent.yaml index 5b2645300..3ec0dd6d6 100644 --- a/py/core/providers/database/prompts/dynamic_rag_agent.yaml +++ b/py/core/providers/database/prompts/dynamic_rag_agent.yaml @@ -2,6 +2,8 @@ dynamic_rag_agent: template: > ### You are a helpful agent that can search for information, the date is {date}. + If you have access to tools that help you set some filters, like smart_filter_tool, to narrow down and speed up the search, use them BEFORE the search + When asked a question, YOU SHOULD ALWAYS USE YOUR SEARCH TOOL TO ATTEMPT TO SEARCH FOR RELEVANT INFORMATION THAT ANSWERS THE USER QUESTION. The response should contain line-item attributions to relevant search results, and be as informative if possible. Note that you will only be able to load {max_tool_context_length} tokens of context at a time, if the context surpasses this then it will be truncated. If possible, set filters which will reduce the context returned to only that which is specific, by means of '$eq' or '$overlap' filters. diff --git a/py/core/providers/database/prompts/dynamic_rag_agent_xml_tooling.yaml b/py/core/providers/database/prompts/dynamic_rag_agent_xml_tooling.yaml index ce5784a30..135be58ad 100644 --- a/py/core/providers/database/prompts/dynamic_rag_agent_xml_tooling.yaml +++ b/py/core/providers/database/prompts/dynamic_rag_agent_xml_tooling.yaml @@ -1,6 +1,7 @@ dynamic_rag_agent_xml_tooling: template: | You are an AI research assistant with access to document retrieval tools. You should use both your internal knowledge store and web search tools to answer the user questions. Today is {date}. + If you have access to tools that help you set some filters to narrow down and speed up the search, use them before the search diff --git a/py/core/providers/database/prompts/static_rag_agent.yaml b/py/core/providers/database/prompts/static_rag_agent.yaml index 0e940af17..e244221b9 100644 --- a/py/core/providers/database/prompts/static_rag_agent.yaml +++ b/py/core/providers/database/prompts/static_rag_agent.yaml @@ -2,6 +2,7 @@ static_rag_agent: template: > ### You are a helpful agent that can search for information, the date is {date}. + If you have access to tools that help you set some filters, like smart_filter_tool, to narrow down and speed up the search, use them BEFORE the search When asked a question, YOU SHOULD ALWAYS USE YOUR SEARCH TOOL TO ATTEMPT TO SEARCH FOR RELEVANT INFORMATION THAT ANSWERS THE USER QUESTION. The response should contain line-item attributions to relevant search results, and be as informative if possible. diff --git a/py/core/providers/database/prompts/static_research_agent.yaml b/py/core/providers/database/prompts/static_research_agent.yaml index 417d161cd..d34a227bd 100644 --- a/py/core/providers/database/prompts/static_research_agent.yaml +++ b/py/core/providers/database/prompts/static_research_agent.yaml @@ -14,6 +14,10 @@ static_research_agent: Provide focused, precise, and strategic analyses. Clearly articulate cause-effect relationships, relevant context, and strategic significance. Prioritize accuracy, clarity, and concise insights. ## Research Guidance + - **RAG FIRST**: + - If you have access to any kind of rag tools, use them first to find sources that can help you narrowing the context. + - Fallback to your own knowledge base or other tools if you don't have access to rag tools or rag tools reply is not helpful. + - **Multi-thesis Approach (for qualitative/subjective queries):** - Identify and retrieve detailed information from credible sources covering multiple angles, including technical, economic, market-specific, geopolitical, psychological, and long-term strategic implications. - Seek contrasting viewpoints, expert opinions, market analyses, and nuanced discussions. diff --git a/py/sdk/sync_methods/retrieval.py b/py/sdk/sync_methods/retrieval.py index ff10ce1fa..26462defe 100644 --- a/py/sdk/sync_methods/retrieval.py +++ b/py/sdk/sync_methods/retrieval.py @@ -373,7 +373,7 @@ def agent( max_tool_context_length (Optional[int]): Maximum context length for tool replies. use_system_context (Optional[bool]): Whether to use system context in the prompt. rag_tools (Optional[list[str]]): List of tools to enable for RAG mode. - Available tools: "search_file_knowledge", "content", "web_search", "web_scrape", "search_file_descriptions". + Available tools: "search_file_knowledge", "content", "web_search", "web_scrape", "search_file_descriptions", "smart_filter_tool". research_tools (Optional[list[str]]): List of tools to enable for Research mode. Available tools: "rag", "reasoning", "critique", "python_executor". tools (Optional[list[str]]): Deprecated. List of tools to execute. diff --git a/py/shared/abstractions/search.py b/py/shared/abstractions/search.py index 163036cd9..6eadfd20a 100644 --- a/py/shared/abstractions/search.py +++ b/py/shared/abstractions/search.py @@ -252,6 +252,47 @@ def from_serper_results(cls, results: list[dict]) -> "WebSearchResult": ) +class SmartFilterResult(R2RSerializable): + collections: list[str] + filters: dict[str, Any] + prompt_mod: str + + class Config: + json_schema_extra = { + "example": { + "collections": ["collection1_id", "collection2_id"], + "filters": { + "$and": [ + { + "collection_ids": { + "$in": "['7c96bee1-d537-4b68-9ede-a0e5355ac957']" + } + }, + { + "metadata.category": { + "$in": "['7c96bee1-d537-4b68-9ede-a0e5355ac957']" + } + }, + ] + }, + "prompt_mod": "What is the capital of France?", + } + } + + def as_dict(self) -> dict: + return { + "collections": self.collections, + "filters": self.filters, + "prompt_mod": self.prompt_mod, + } + + def __str__(self) -> str: + return f"SmartFilterResult(collections={self.collections}, filters={self.filters}, prompt_mod={self.prompt_mod})" + + def __repr__(self) -> str: + return self.__str__() + + class AggregateSearchResult(R2RSerializable): """Result of an aggregate search operation.""" @@ -260,6 +301,7 @@ class AggregateSearchResult(R2RSerializable): web_page_search_results: Optional[list[WebPageSearchResult]] = None web_search_results: Optional[list[WebSearchResult]] = None document_search_results: Optional[list[DocumentResponse]] = None + smart_filter_result: Optional[SmartFilterResult] = None generic_tool_result: Optional[Any] = ( None # FIXME: Give this a proper generic type ) @@ -299,6 +341,11 @@ def as_dict(self) -> dict: if self.document_search_results else [] ), + "smart_filter_result": ( + self.smart_filter_result.as_dict() + if self.smart_filter_result + else None + ), "generic_tool_result": ( [result.to_dict() for result in self.generic_tool_result] if self.generic_tool_result @@ -379,6 +426,24 @@ class Config: }, } ], + "smart_filter_result": { + "collections": ["collection1_id", "collection2_id"], + "filters": { + "$and": [ + { + "collection_ids": { + "$in": "['7c96bee1-d537-4b68-9ede-a0e5355ac957']" + } + }, + { + "metadata.category": { + "$in": "['7c96bee1-d537-4b68-9ede-a0e5355ac957']" + } + }, + ] + }, + "prompt_mod": "What is the capital of France?", + }, "generic_tool_result": [ { "result": "Generic tool result", From d22b91a805399d5533beb574aa6b5455eae735eb Mon Sep 17 00:00:00 2001 From: nPeppon Date: Mon, 9 Jun 2025 17:59:21 +0200 Subject: [PATCH 3/4] Add Smolagent integration for enhanced retrieval capabilities - Introduced Smolagent tools and the R2RSmolRAGAgent for improved retrieval performance. - Updated the ToolRegistry to support discovery of Smolagent tools. - Enhanced the RetrievalRouter and RetrievalService to accommodate new modes including 'rag_smol'. - Added new tools for file content retrieval, semantic search, and metadata filtering specific to Smolagent. - Updated project dependencies to include smolagents package. --- py/core/__init__.py | 2 + .../built_in/search_file_descriptions.py | 3 + .../tools/built_in/search_file_knowledge.py | 1 + .../base/agent/tools/built_in/smart_filter.py | 1 + py/core/base/agent/tools/registry.py | 99 ++++++++++----- py/core/main/api/v3/retrieval_router.py | 2 +- py/core/main/services/retrieval_service.py | 37 ++++-- py/core/smolagent/__init__.py | 15 +++ py/core/smolagent/agent.py | 117 ++++++++++++++++++ py/core/smolagent/llm_provider/__init__.py | 0 .../smolagent/llm_provider/hf_llm_provider.py | 20 +++ py/core/smolagent/tools/__init__.py | 24 ++++ py/core/smolagent/tools/get_file_content.py | 43 +++++++ .../tools/search_file_descriptions.py | 51 ++++++++ .../smolagent/tools/search_file_knowledge.py | 51 ++++++++ py/core/smolagent/tools/smart_filter.py | 49 ++++++++ py/pyproject.toml | 1 + py/sdk/sync_methods/retrieval.py | 2 +- 18 files changed, 477 insertions(+), 41 deletions(-) create mode 100644 py/core/smolagent/__init__.py create mode 100644 py/core/smolagent/agent.py create mode 100644 py/core/smolagent/llm_provider/__init__.py create mode 100644 py/core/smolagent/llm_provider/hf_llm_provider.py create mode 100644 py/core/smolagent/tools/__init__.py create mode 100644 py/core/smolagent/tools/get_file_content.py create mode 100644 py/core/smolagent/tools/search_file_descriptions.py create mode 100644 py/core/smolagent/tools/search_file_knowledge.py create mode 100644 py/core/smolagent/tools/smart_filter.py diff --git a/py/core/__init__.py b/py/core/__init__.py index c27da75af..62799947a 100644 --- a/py/core/__init__.py +++ b/py/core/__init__.py @@ -6,6 +6,7 @@ from .main import * from .parsers import * from .providers import * +from .smolagent import * logger = logging.getLogger() logger.setLevel(logging.INFO) @@ -173,4 +174,5 @@ "UnstructuredIngestionProvider", "R2RIngestionProvider", "ChunkingStrategy", + "R2RSmolRAGAgent", ] diff --git a/py/core/base/agent/tools/built_in/search_file_descriptions.py b/py/core/base/agent/tools/built_in/search_file_descriptions.py index 2cc0f4741..439e678a4 100644 --- a/py/core/base/agent/tools/built_in/search_file_descriptions.py +++ b/py/core/base/agent/tools/built_in/search_file_descriptions.py @@ -36,6 +36,9 @@ async def execute(self, query: str, *args, **kwargs): """ Calls the file_search_method from context. """ + logger.debug( + f"Executing SearchFileDescriptionsTool with query: {query}" + ) from core.base.abstractions import AggregateSearchResult context = self.context diff --git a/py/core/base/agent/tools/built_in/search_file_knowledge.py b/py/core/base/agent/tools/built_in/search_file_knowledge.py index 1107aee30..377b41a7c 100644 --- a/py/core/base/agent/tools/built_in/search_file_knowledge.py +++ b/py/core/base/agent/tools/built_in/search_file_knowledge.py @@ -35,6 +35,7 @@ async def execute(self, query: str, *args, **kwargs): """ Calls the knowledge_search_method from context. """ + logger.debug(f"Executing SearchFileKnowledgeTool with query: {query}") from core.base.abstractions import AggregateSearchResult context = self.context diff --git a/py/core/base/agent/tools/built_in/smart_filter.py b/py/core/base/agent/tools/built_in/smart_filter.py index 2720eb98d..a7bc4989d 100644 --- a/py/core/base/agent/tools/built_in/smart_filter.py +++ b/py/core/base/agent/tools/built_in/smart_filter.py @@ -39,6 +39,7 @@ async def execute(self, query: str, *args, **kwargs): """ Uses the LLM to analyze the query and available collections, returning collection IDs, filters, and prompt_mod. """ + logger.debug(f"Executing SmartFilterTool with query: {query}") from core.base.abstractions import ( AggregateSearchResult, diff --git a/py/core/base/agent/tools/registry.py b/py/core/base/agent/tools/registry.py index 526e1956b..84179f969 100644 --- a/py/core/base/agent/tools/registry.py +++ b/py/core/base/agent/tools/registry.py @@ -21,6 +21,7 @@ def __init__( self, built_in_path: str | None = None, user_tools_path: str | None = None, + smolagent_tools_path: str | None = None, ): self.built_in_path = built_in_path or os.path.join( os.path.dirname(os.path.abspath(__file__)), "built_in" @@ -30,10 +31,15 @@ def __init__( or os.getenv("R2R_USER_TOOLS_PATH") or "../docker/user_tools" ) + self.smolagent_tools_path = smolagent_tools_path or os.path.join( + os.path.dirname(os.path.abspath(__file__)), + "../../../smolagent/tools", + ) # Tool storage self._built_in_tools: dict[str, Type[Tool]] = {} self._user_tools: dict[str, Type[Tool]] = {} + self._smolagent_tools: dict[str, Type[Tool]] = {} # Discover tools self._discover_built_in_tools() @@ -43,37 +49,56 @@ def __init__( logger.warning( f"User tools directory not found: {self.user_tools_path}" ) - - def _discover_built_in_tools(self): - """Load all built-in tools from the built_in directory.""" - if not os.path.exists(self.built_in_path): + if os.path.exists(self.smolagent_tools_path): + self._discover_smolagent_tools() + else: logger.warning( - f"Built-in tools directory not found: {self.built_in_path}" + f"Smolagent tools directory not found: {self.smolagent_tools_path}" ) + + def _discover_tools_from_path( + self, path: str, registry: dict, package_prefix: str = "" + ): + if not os.path.exists(path): + logger.warning(f"Tools directory not found: {path}") return # Add to Python path if needed - if self.built_in_path not in sys.path: - sys.path.append(os.path.dirname(self.built_in_path)) + if path not in sys.path: + sys.path.append(os.path.dirname(path)) - # Import the built_in package - try: - built_in_pkg = importlib.import_module("built_in") - except ImportError: - logger.error("Failed to import built_in tools package") - return - - # Discover all modules in the package - for _, module_name, is_pkg in pkgutil.iter_modules( - [self.built_in_path] - ): - if is_pkg: # Skip subpackages + # Import the package if a prefix is given + prefix_valid = True + if package_prefix: + try: + importlib.import_module(package_prefix) + except ImportError as e: + logger.warning( + f"Failed to import tools package with prefix only: {package_prefix}\t" + f"Error: {e}" + ) + try: + # prefix works if it is local package, let's also fetch by path + tools_pkg_name = os.path.basename(path) + importlib.import_module(tools_pkg_name) + prefix_valid = False + except ImportError as e: + logger.warning( + f"Failed to import tools package: {tools_pkg_name}\t" + f"Error: {e}" + ) + return + + for _, module_name, is_pkg in pkgutil.iter_modules([path]): + if is_pkg: continue - try: - module = importlib.import_module(f"built_in.{module_name}") - - # Find all tool classes in the module + module_path = f"{module_name}" + if prefix_valid and package_prefix: + module_path = f"{package_prefix}.{module_name}" + elif not prefix_valid and package_prefix: + module_path = f"{tools_pkg_name}.{module_name}" + module = importlib.import_module(module_path) for name, obj in inspect.getmembers(module, inspect.isclass): if ( issubclass(obj, Tool) @@ -81,19 +106,22 @@ def _discover_built_in_tools(self): and obj != Tool ): try: - tool_instance = obj() - self._built_in_tools[tool_instance.name] = obj + tool_instance = obj() # type: ignore + registry[tool_instance.name] = obj logger.debug( - f"Loaded built-in tool: {tool_instance.name}" + f"Loaded tool: {tool_instance.name} from {path}" ) except Exception as e: logger.error( - f"Error instantiating built-in tool {name}: {e}" + f"Error instantiating tool {name}: {e}" ) except Exception as e: - logger.error( - f"Error loading built-in tool module {module_name}: {e}" - ) + logger.error(f"Error loading tool module {module_name}: {e}") + + def _discover_built_in_tools(self): + self._discover_tools_from_path( + self.built_in_path, self._built_in_tools, "built_in" + ) def _discover_user_tools(self): """Scan the user tools directory for custom tools.""" @@ -142,8 +170,17 @@ def _discover_user_tools(self): f"Error loading user tool module {module_name}: {e}" ) + def _discover_smolagent_tools(self): + self._discover_tools_from_path( + self.smolagent_tools_path, self._smolagent_tools, "smolagent.tools" + ) + def get_tool_class(self, tool_name: str): - """Get a tool class by name.""" + """Get a tool class by name. + If the tool is a smolagent tool, it will return the R2R wrapper. + """ + if tool_name in self._smolagent_tools: + return self._smolagent_tools[tool_name] if tool_name in self._user_tools: return self._user_tools[tool_name] diff --git a/py/core/main/api/v3/retrieval_router.py b/py/core/main/api/v3/retrieval_router.py index 6d9850597..e1172b933 100644 --- a/py/core/main/api/v3/retrieval_router.py +++ b/py/core/main/api/v3/retrieval_router.py @@ -619,7 +619,7 @@ async def agent_app( description="Use extended prompt for generation", ), # FIXME: We need a more generic way to handle this - mode: Optional[Literal["rag", "research"]] = Body( + mode: Optional[Literal["rag", "research", "rag_smol"]] = Body( default="rag", description="Mode to use for generation: 'rag' for standard retrieval or 'research' for deep analysis with reasoning capabilities", ), diff --git a/py/core/main/services/retrieval_service.py b/py/core/main/services/retrieval_service.py index 29332cfde..c32178821 100644 --- a/py/core/main/services/retrieval_service.py +++ b/py/core/main/services/retrieval_service.py @@ -38,6 +38,7 @@ ) from core.base.agent.tools.registry import ToolRegistry from core.base.api.models import RAGResponse, User +from core.smolagent.agent import R2RSmolRAGAgent from core.utils import ( CitationTracker, SearchResultsCollector, @@ -65,7 +66,7 @@ class AgentFactory: @staticmethod def create_agent( - mode: Literal["rag", "research"], + mode: Literal["rag", "research", "rag_smol"], database_provider, llm_provider, config, # : AgentConfig @@ -84,7 +85,7 @@ def create_agent( Creates and returns the appropriate agent based on provided parameters. Args: - mode: Either "rag" or "research" to determine agent type + mode: Either "rag", "research", or "rag_smol" to determine agent type database_provider: Provider for database operations llm_provider: Provider for LLM operations config: Agent configuration @@ -107,12 +108,14 @@ def create_agent( tool_registry = ToolRegistry() # Handle tool specifications based on mode - if mode == "rag": + if mode == "rag" or mode == "rag_smol": # For RAG mode, prioritize explicitly passed rag_tools, then tools, then config defaults if rag_tools: agent_config.rag_tools = rag_tools + logger.debug(f"RAG tools: {rag_tools}") elif tools: # Backward compatibility agent_config.rag_tools = tools + logger.debug(f"Tools: {tools}") # If neither was provided, the config's default rag_tools will be used elif mode == "research": # For Research mode, prioritize explicitly passed research_tools, then tools, then config defaults @@ -187,7 +190,7 @@ def create_agent( file_search_method=file_search_method, tool_registry=tool_registry, ) - else: + elif mode == "research": # Research mode agents if is_streaming: if use_xml_format: @@ -243,6 +246,19 @@ def create_agent( content_method=content_method, file_search_method=file_search_method, ) + elif mode == "rag_smol": + return R2RSmolRAGAgent( + database_provider=database_provider, + llm_provider=llm_provider, + config=agent_config, + search_settings=search_settings, + rag_generation_config=generation_config, + max_tool_context_length=max_tool_context_length, + knowledge_search_method=knowledge_search_method, + content_method=content_method, + file_search_method=file_search_method, + tool_registry=tool_registry, + ) class RetrievalService(Service): @@ -1281,7 +1297,7 @@ async def agent( research_tools: Optional[list[str]] = None, research_generation_config: Optional[GenerationConfig] = None, needs_initial_conversation_name: Optional[bool] = None, - mode: Optional[Literal["rag", "research"]] = "rag", + mode: Optional[Literal["rag", "research", "rag_smol"]] = "rag", ): """ Engage with an intelligent agent for information retrieval, analysis, and research. @@ -1365,9 +1381,12 @@ async def agent( if mode == "research" and research_generation_config: effective_generation_config = research_generation_config + logger.debug( + f"Effective generation config: {effective_generation_config}" + ) # Set appropriate LLM model based on mode if not explicitly specified if "model" not in effective_generation_config.model_fields_set: - if mode == "rag": + if mode == "rag" or mode == "rag_smol": effective_generation_config.model = ( self.config.app.quality_llm ) @@ -1375,7 +1394,9 @@ async def agent( effective_generation_config.model = ( self.config.app.planning_llm ) - + logger.debug( + f"Effective generation config after model set: {effective_generation_config}" + ) # Transform UUID filters to strings for filter_key, value in search_settings.filters.items(): if isinstance(value, UUID): @@ -1476,7 +1497,7 @@ async def agent( # Configure agent with appropriate tools agent_config = deepcopy(self.config.agent) - if mode == "rag": + if mode == "rag" or mode == "rag_smol": # Use provided RAG tools or default from config agent_config.rag_tools = ( rag_tools or tools or self.config.agent.rag_tools diff --git a/py/core/smolagent/__init__.py b/py/core/smolagent/__init__.py new file mode 100644 index 000000000..e9111cd91 --- /dev/null +++ b/py/core/smolagent/__init__.py @@ -0,0 +1,15 @@ +from .agent import R2RSmolRAGAgent +from .tools import ( + SmolGetFileContentTool, + SmolSearchFileDescriptionsTool, + SmolSearchFileKnowledgeTool, + SmolSmartFilterTool, +) + +__all__ = [ + "R2RSmolRAGAgent", + "SmolSmartFilterTool", + "SmolSearchFileKnowledgeTool", + "SmolGetFileContentTool", + "SmolSearchFileDescriptionsTool", +] diff --git a/py/core/smolagent/agent.py b/py/core/smolagent/agent.py new file mode 100644 index 000000000..adfcd7a87 --- /dev/null +++ b/py/core/smolagent/agent.py @@ -0,0 +1,117 @@ +import logging +from typing import Any + +from smolagents import CodeAgent + +from core import R2RRAGAgent +from core.base import Message +from core.smolagent.llm_provider.hf_llm_provider import ( + fetch_hf_inference_from_model, +) + +logger = logging.getLogger(__name__) + + +def smol_to_r2r_messages(smol_run_result: Any) -> list[Message]: + """ + Convert smolagents RunResult.messages (list of dicts) to list of R2R Message objects. + """ + logger.debug( + f"Converting smol run result to R2R messages: {smol_run_result}" + ) + r2r_messages = [] + if hasattr(smol_run_result, "messages"): + for msg in smol_run_result.messages: + # Defensive: Only use fields that Message accepts + role = msg.get("role", "assistant") + content = msg.get("content", "") + # Optionally pass other fields if R2R Message supports them + r2r_messages.append(Message(role=role, content=content)) + else: + r2r_messages = [ + Message(role="assistant", content=str(smol_run_result)) + ] + logger.debug(f"R2R messages: {r2r_messages}") + return r2r_messages + + +class R2RSmolRAGAgent(R2RRAGAgent): + def __init__( + self, + database_provider, + llm_provider, + config, + search_settings, + rag_generation_config, + knowledge_search_method, + content_method, + file_search_method, + tool_registry=None, + max_tool_context_length=20000, + **kwargs, + ): + # Prefix all tool names in rag_tools with 'smol_' + if hasattr(config, "rag_tools") and config.rag_tools: + config.rag_tools = [ + f"smol_{name}" if not name.startswith("smol_") else name + for name in config.rag_tools + ] + logger.debug(f"Smol RAG tools: {config.rag_tools}") + super().__init__( + database_provider=database_provider, + llm_provider=llm_provider, + config=config, + search_settings=search_settings, + rag_generation_config=rag_generation_config, + knowledge_search_method=knowledge_search_method, + content_method=content_method, + file_search_method=file_search_method, + tool_registry=tool_registry, + max_tool_context_length=max_tool_context_length, + **kwargs, + ) + + # register tools method called in super init will fill _tools with the tools from the registry + self._smol_tools = [ + tool._smol_tool + for tool in getattr(self, "_tools", []) + if hasattr(tool, "_smol_tool") + ] + logger.debug(f"Smol tools: {self._smol_tools}") + + # logger.debug(f"Config: {config}") + # logger.debug(f"Fetching HF model: {rag_generation_config}") + self._hf_model = fetch_hf_inference_from_model( + rag_generation_config.model + ) + self.hf_agent = CodeAgent( + tools=self._smol_tools, + model=self._hf_model, + use_structured_outputs_internally=True, + add_base_tools=False, + ) + self.update_system_prompt() + + def run(self, messages: list[Message], **kwargs): + # logger.debug(f"Running smol agent with messages: {messages}") + message = messages[0].content + # logger.debug(f"Message: {message}") + return self.hf_agent.run(message) + + async def arun(self, messages: list[Message], **kwargs): + logger.debug(f"Running smol agent with messages: {messages}") + message = messages[0].content + logger.debug(f"Message: {message}") + # Non stream version + result = self.hf_agent.run(message) + return smol_to_r2r_messages(result) + + def update_system_prompt(self): + if hasattr(self, "hf_agent"): + # logger.debug(f"Old system prompt: {self.hf_agent.system_prompt}") + custom_system_prompt_addition = ( + "\n\nWhen asked a question, YOU SHOULD ALWAYS USE YOUR SEARCH TOOL TO ATTEMPT TO SEARCH FOR RELEVANT INFORMATION THAT ANSWERS THE USER QUESTION BUT" + " if you have access to tools that help you set some filters, like smol_smart_filter_tool, to narrow down and speed up the search, use them BEFORE the search" + ) + self.hf_agent.system_prompt += custom_system_prompt_addition + # logger.debug(f"Updated system prompt: {self.hf_agent.system_prompt}") diff --git a/py/core/smolagent/llm_provider/__init__.py b/py/core/smolagent/llm_provider/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/py/core/smolagent/llm_provider/hf_llm_provider.py b/py/core/smolagent/llm_provider/hf_llm_provider.py new file mode 100644 index 000000000..f3c8e6d89 --- /dev/null +++ b/py/core/smolagent/llm_provider/hf_llm_provider.py @@ -0,0 +1,20 @@ +import logging + +from smolagents import Model + +logger = logging.getLogger(__name__) + + +def fetch_hf_inference_from_model(model_name: str) -> Model: + logger.debug(f"Fetching HF inference from model: {model_name}") + if "gpt" in model_name: + from smolagents import OpenAIServerModel + + # Initialize the model with our reverse proxy + # remove openai/ prefix if there is one + model_id = model_name.replace("openai/", "") + return OpenAIServerModel( + model_id=model_id, + ) + else: + raise ValueError(f"Model {model_name} is not supported") diff --git a/py/core/smolagent/tools/__init__.py b/py/core/smolagent/tools/__init__.py new file mode 100644 index 000000000..91446b155 --- /dev/null +++ b/py/core/smolagent/tools/__init__.py @@ -0,0 +1,24 @@ +from .get_file_content import ( + SmolGetFileContentTool, + SmolGetFileContentToolR2RWrapper, +) +from .search_file_descriptions import ( + SmolSearchFileDescriptionsTool, + SmolSearchFileDescriptionsToolR2RWrapper, +) +from .search_file_knowledge import ( + SmolSearchFileKnowledgeTool, + SmolSearchFileKnowledgeToolR2RWrapper, +) +from .smart_filter import SmolSmartFilterTool, SmolSmartFilterToolR2RWrapper + +__all__ = [ + "SmolSmartFilterTool", + "SmolSmartFilterToolR2RWrapper", + "SmolSearchFileKnowledgeTool", + "SmolSearchFileKnowledgeToolR2RWrapper", + "SmolGetFileContentTool", + "SmolGetFileContentToolR2RWrapper", + "SmolSearchFileDescriptionsTool", + "SmolSearchFileDescriptionsToolR2RWrapper", +] diff --git a/py/core/smolagent/tools/get_file_content.py b/py/core/smolagent/tools/get_file_content.py new file mode 100644 index 000000000..2234217fb --- /dev/null +++ b/py/core/smolagent/tools/get_file_content.py @@ -0,0 +1,43 @@ +from smolagents import Tool as SmolTool + +from core.base.agent.tools.built_in.get_file_content import GetFileContentTool +from shared.abstractions.tool import Tool as R2RTool + + +class SmolGetFileContentTool(SmolTool): + name = "smol_get_file_content" + description = "Fetches the complete contents of a user document from the local database by document ID." + inputs = { + "document_id": { + "type": "string", + "description": "The unique UUID of the document to fetch.", + } + } + output_type = "object" + parameters = inputs + + def __init__(self, context, r2r_tool: R2RTool, **kwargs): + super().__init__() + self.context = context + self.results_function = self.forward + self._r2r_tool = r2r_tool + + async def forward(self, document_id: str): + return await self._r2r_tool.execute(document_id) + + +class SmolGetFileContentToolR2RWrapper(R2RTool): + def __init__(self): + super().__init__( + name="smol_get_file_content", + description="Wrapper for the SmolGetFileContentTool to be used in R2R", + results_function=self.execute, + ) + self._r2r_tool = GetFileContentTool() + self._smol_tool = SmolGetFileContentTool( + context=self, r2r_tool=self._r2r_tool + ) + + async def execute(self, document_id: str, **kwargs): + # Placeholder function, not to be used + pass diff --git a/py/core/smolagent/tools/search_file_descriptions.py b/py/core/smolagent/tools/search_file_descriptions.py new file mode 100644 index 000000000..d3ca598d1 --- /dev/null +++ b/py/core/smolagent/tools/search_file_descriptions.py @@ -0,0 +1,51 @@ +import logging + +from smolagents import Tool as SmolTool + +from core.base.agent.tools.built_in.search_file_descriptions import ( + SearchFileDescriptionsTool, +) +from shared.abstractions.tool import Tool as R2RTool + +logger = logging.getLogger(__name__) + + +class SmolSearchFileDescriptionsTool(SmolTool): + name = "smol_search_file_descriptions" + description = "Semantic search over AI-generated summaries of stored documents. Use for a broad overview of relevant files." + inputs = { + "query": { + "type": "string", + "description": "Query string to semantic search over available files.", + } + } + output_type = "object" + parameters = inputs + + def __init__(self, context, r2r_tool: R2RTool, **kwargs): + super().__init__() + self.context = context + self.results_function = self.forward + self._r2r_tool = r2r_tool + + async def forward(self, query: str): + return await self._r2r_tool.execute(query) + + +class SmolSearchFileDescriptionsToolR2RWrapper(R2RTool): + def __init__(self): + super().__init__( + name="smol_search_file_descriptions", + description="Wrapper for the SmolSearchFileDescriptionsTool to be used in R2R", + results_function=self.execute, + ) + self._r2r_tool = SearchFileDescriptionsTool() + self._smol_tool = SmolSearchFileDescriptionsTool( + context=self, r2r_tool=self._r2r_tool + ) + + async def execute(self, query: str, **kwargs): + logger.debug( + f"Executing SmolSearchFileDescriptionsToolR2RWrapper with query: {query}" + ) + return await self._r2r_tool.execute(query) diff --git a/py/core/smolagent/tools/search_file_knowledge.py b/py/core/smolagent/tools/search_file_knowledge.py new file mode 100644 index 000000000..251fa41f5 --- /dev/null +++ b/py/core/smolagent/tools/search_file_knowledge.py @@ -0,0 +1,51 @@ +import logging + +from smolagents import Tool as SmolTool + +from core.base.agent.tools.built_in.search_file_knowledge import ( + SearchFileKnowledgeTool, +) +from shared.abstractions.tool import Tool as R2RTool + +logger = logging.getLogger(__name__) + + +class SmolSearchFileKnowledgeTool(SmolTool): + name = "smol_search_file_knowledge" + description = "Search your local knowledge base using the R2R system. Use this for relevant text chunks or knowledge graph data." + inputs = { + "query": { + "type": "string", + "description": "User query to search in the local DB.", + } + } + output_type = "object" + parameters = inputs + + def __init__(self, context, r2r_tool: R2RTool, **kwargs): + super().__init__() + self.context = context + self.results_function = self.forward + self._r2r_tool = r2r_tool + + async def forward(self, query: str): + return await self._r2r_tool.execute(query) + + +class SmolSearchFileKnowledgeToolR2RWrapper(R2RTool): + def __init__(self): + super().__init__( + name="smol_search_file_knowledge", + description="Wrapper for the SmolSearchFileKnowledgeTool to be used in R2R", + results_function=self.execute, + ) + self._r2r_tool = SearchFileKnowledgeTool() + self._smol_tool = SmolSearchFileKnowledgeTool( + context=self, r2r_tool=self._r2r_tool + ) + + async def execute(self, query: str, **kwargs): + logger.debug( + f"Executing SmolSearchFileKnowledgeToolR2RWrapper with query: {query}" + ) + return await self._r2r_tool.execute(query) diff --git a/py/core/smolagent/tools/smart_filter.py b/py/core/smolagent/tools/smart_filter.py new file mode 100644 index 000000000..798e2191b --- /dev/null +++ b/py/core/smolagent/tools/smart_filter.py @@ -0,0 +1,49 @@ +import logging + +from smolagents import Tool as SmolTool + +from core.base.agent.tools.built_in.smart_filter import SmartFilterTool +from shared.abstractions.tool import Tool as R2RTool + +logger = logging.getLogger(__name__) + + +class SmolSmartFilterTool(SmolTool): + name = "smol_smart_filter" + description = "Refines metadata and collection filters for a RAG search using LLM analysis. To be used BEFORE the rag search" + inputs = { + "query": { + "type": "string", + "description": "The user query to analyze.", + } + } + output_type = "object" + parameters = inputs + + def __init__(self, context, r2r_tool: R2RTool, **kwargs): + super().__init__() + self.context = context + self.results_function = self.forward + self._r2r_tool = r2r_tool + + async def forward(self, query: str): + return await self._r2r_tool.execute(query) + + +class SmolSmartFilterToolR2RWrapper(R2RTool): + def __init__(self): + super().__init__( + name="smol_smart_filter_tool", + description="Wrapper for the SmolSmartFilterTool to be used in R2R", + results_function=self.execute, + ) + self._r2r_tool = SmartFilterTool() + self._smol_tool = SmolSmartFilterTool( + context=self, r2r_tool=self._r2r_tool + ) + + async def execute(self, query: str, **kwargs): + logger.debug( + f"Executing SmolSmartFilterToolR2RWrapper with query: {query}" + ) + return await self._r2r_tool.execute(query) diff --git a/py/pyproject.toml b/py/pyproject.toml index 21aa1c813..35a24987a 100644 --- a/py/pyproject.toml +++ b/py/pyproject.toml @@ -29,6 +29,7 @@ dependencies = [ "pydantic>=2.10.6", "python-json-logger>=3.2.1", "filetype>=1.2.0", + "smolagents>=1.17.0,<2.0.0", ] [project.optional-dependencies] diff --git a/py/sdk/sync_methods/retrieval.py b/py/sdk/sync_methods/retrieval.py index 26462defe..ac03ed923 100644 --- a/py/sdk/sync_methods/retrieval.py +++ b/py/sdk/sync_methods/retrieval.py @@ -377,7 +377,7 @@ def agent( research_tools (Optional[list[str]]): List of tools to enable for Research mode. Available tools: "rag", "reasoning", "critique", "python_executor". tools (Optional[list[str]]): Deprecated. List of tools to execute. - mode (Optional[str]): Mode to use for generation: "rag" for standard retrieval or "research" for deep analysis. + mode (Optional[str]): Mode to use for generation: "rag", "rag_smol" for smolagent, or "research" for deep analysis. Defaults to "rag". Returns: From 32caecf9a2383fee240a05403f89be66ea31c74c Mon Sep 17 00:00:00 2001 From: nPeppon Date: Tue, 10 Jun 2025 15:45:41 +0200 Subject: [PATCH 4/4] Refactor agent architecture and remove Smolagent integration - Replaced Smolagent tools with the new RAGPydAgent for enhanced retrieval capabilities. - Updated ToolRegistry to remove references to Smolagent tools and streamline tool discovery. - Modified RetrievalRouter and RetrievalService to support the new 'rag_pyd' mode. - Introduced new Pydantic-based tools for improved functionality and integration. - Updated project dependencies to reflect changes in agent architecture. --- py/core/__init__.py | 2 - py/core/agent/__init__.py | 3 + py/core/agent/rag_pyd.py | 124 ++++++++++++++++++ .../agent/tools/built_in/get_file_content.py | 10 ++ .../built_in/search_file_descriptions.py | 10 ++ .../tools/built_in/search_file_knowledge.py | 10 ++ .../base/agent/tools/built_in/smart_filter.py | 10 ++ .../agent/tools/built_in/tavily_extract.py | 10 ++ .../agent/tools/built_in/tavily_search.py | 10 ++ .../base/agent/tools/built_in/web_scrape.py | 10 ++ .../base/agent/tools/built_in/web_search.py | 10 ++ py/core/base/agent/tools/registry.py | 99 +++++--------- py/core/main/api/v3/retrieval_router.py | 2 +- py/core/main/services/retrieval_service.py | 18 +-- py/core/smolagent/__init__.py | 15 --- py/core/smolagent/agent.py | 117 ----------------- py/core/smolagent/llm_provider/__init__.py | 0 .../smolagent/llm_provider/hf_llm_provider.py | 20 --- py/core/smolagent/tools/__init__.py | 24 ---- py/core/smolagent/tools/get_file_content.py | 43 ------ .../tools/search_file_descriptions.py | 51 ------- .../smolagent/tools/search_file_knowledge.py | 51 ------- py/core/smolagent/tools/smart_filter.py | 49 ------- py/pyproject.toml | 6 +- py/sdk/sync_methods/retrieval.py | 2 +- 25 files changed, 252 insertions(+), 454 deletions(-) create mode 100644 py/core/agent/rag_pyd.py delete mode 100644 py/core/smolagent/__init__.py delete mode 100644 py/core/smolagent/agent.py delete mode 100644 py/core/smolagent/llm_provider/__init__.py delete mode 100644 py/core/smolagent/llm_provider/hf_llm_provider.py delete mode 100644 py/core/smolagent/tools/__init__.py delete mode 100644 py/core/smolagent/tools/get_file_content.py delete mode 100644 py/core/smolagent/tools/search_file_descriptions.py delete mode 100644 py/core/smolagent/tools/search_file_knowledge.py delete mode 100644 py/core/smolagent/tools/smart_filter.py diff --git a/py/core/__init__.py b/py/core/__init__.py index 62799947a..c27da75af 100644 --- a/py/core/__init__.py +++ b/py/core/__init__.py @@ -6,7 +6,6 @@ from .main import * from .parsers import * from .providers import * -from .smolagent import * logger = logging.getLogger() logger.setLevel(logging.INFO) @@ -174,5 +173,4 @@ "UnstructuredIngestionProvider", "R2RIngestionProvider", "ChunkingStrategy", - "R2RSmolRAGAgent", ] diff --git a/py/core/agent/__init__.py b/py/core/agent/__init__.py index bd6dda79b..6758853cb 100644 --- a/py/core/agent/__init__.py +++ b/py/core/agent/__init__.py @@ -10,6 +10,7 @@ R2RXMLToolsRAGAgent, R2RXMLToolsStreamingRAGAgent, ) +from .rag_pyd import RAGPydAgent # Import the concrete implementations from .research import ( @@ -33,4 +34,6 @@ "R2RStreamingResearchAgent", "R2RXMLToolsResearchAgent", "R2RXMLToolsStreamingResearchAgent", + # Pydantic Agents + "RAGPydAgent", ] diff --git a/py/core/agent/rag_pyd.py b/py/core/agent/rag_pyd.py new file mode 100644 index 000000000..072c30652 --- /dev/null +++ b/py/core/agent/rag_pyd.py @@ -0,0 +1,124 @@ +import logging +from typing import Callable, Optional + +from pydantic_ai import Agent as PydanticAgent + +from core.base import ( + Message, +) +from core.base.abstractions import ( + GenerationConfig, + SearchSettings, +) +from core.base.agent.tools.registry import ToolRegistry +from core.base.providers import DatabaseProvider +from core.providers import ( + AnthropicCompletionProvider, + LiteLLMCompletionProvider, + OpenAICompletionProvider, + R2RCompletionProvider, +) + +from ..base.agent.agent import RAGAgentConfig # type: ignore +from .rag import R2RRAGAgent # type: ignore + +logger = logging.getLogger(__name__) + +INSTRUCTIONS = """ +You are a helpful agent that can search for information, the date is {date}. + +If you have access to tools that help you set some filters, like smart_filter_tool, to narrow down and speed up the search, use them BEFORE the search +When asked a question, YOU SHOULD ALWAYS USE YOUR SEARCH TOOL TO ATTEMPT TO SEARCH FOR RELEVANT INFORMATION THAT ANSWERS THE USER QUESTION. + +The response should contain line-item attributions to relevant search results, and be as informative if possible. + +If no relevant results are found, then state that no results were found. If no obvious question is present, then do not carry out a search, and instead ask for clarification. + +REMINDER - Use line item references to like [c910e2e], [b12cd2f], to refer to the specific search result IDs returned in the provided context. +""" + + +def pydantic_to_r2r_message(pydantic_response) -> list[Message]: + messages = [] + logger.debug(f"Pydantic response: {pydantic_response}") + try: + all_messages = pydantic_response.all_messages + logger.debug(f"All messages: {all_messages}") + except Exception as e: + logger.error(f"Error getting all messages: {e}") + all_messages = [] + if hasattr(pydantic_response, "output"): + messages.append( + Message( + role="assistant", + content=pydantic_response.output, + ) + ) + return messages + + +class RAGPydAgent(R2RRAGAgent): + def __init__( + self, + database_provider: DatabaseProvider, + llm_provider: ( + AnthropicCompletionProvider + | LiteLLMCompletionProvider + | OpenAICompletionProvider + | R2RCompletionProvider + ), + config: RAGAgentConfig, + search_settings: SearchSettings, + rag_generation_config: GenerationConfig, + knowledge_search_method: Callable, + content_method: Callable, + file_search_method: Callable, + tool_registry: Optional[ToolRegistry] = None, + max_tool_context_length: int = 20_000, + **kwargs, + ): + super().__init__( + database_provider=database_provider, + llm_provider=llm_provider, + config=config, + search_settings=search_settings, + rag_generation_config=rag_generation_config, + knowledge_search_method=knowledge_search_method, + content_method=content_method, + file_search_method=file_search_method, + tool_registry=tool_registry, + max_tool_context_length=max_tool_context_length, + **kwargs, + ) + # Init pydantic agent + self._pydantic_tools = [ + tool._pydantic_ai_tool + for tool in getattr(self, "_tools", []) + if hasattr(tool, "_pydantic_ai_tool") + ] + self._pydantic_agent = PydanticAgent( + model=self.get_pyd_ai_model_name(rag_generation_config.model), + tools=self._pydantic_tools, + name="R2R Pydantic Agent", + instructions=INSTRUCTIONS, + ) + + def get_pyd_ai_model_name(self, model_name: str | None): + logger.debug(f"Fetching model name from: {model_name}") + if not model_name: + raise ValueError("Model name is required") + if "gpt" in model_name: + # Initialize the model with our reverse proxy + # remove openai/ prefix if there is one + model_id = model_name.replace("openai/", "") + return model_id + else: + raise ValueError(f"Model {model_name} is not supported") + + async def arun(self, messages: list[Message], **kwargs): + # logger.debug(f"Running pydantic agent with messages: {messages}") + message = messages[0].content + # logger.debug(f"Message: {message}") + py_response = await self._pydantic_agent.run(message) + r2r_response = pydantic_to_r2r_message(py_response) + return r2r_response diff --git a/py/core/base/agent/tools/built_in/get_file_content.py b/py/core/base/agent/tools/built_in/get_file_content.py index ce4e08181..0b1bc7bd2 100644 --- a/py/core/base/agent/tools/built_in/get_file_content.py +++ b/py/core/base/agent/tools/built_in/get_file_content.py @@ -2,6 +2,8 @@ from typing import Any, Optional from uuid import UUID +from pydantic_ai import Tool as PydanticTool + from shared.abstractions.tool import Tool logger = logging.getLogger(__name__) @@ -38,6 +40,14 @@ def __init__(self): results_function=self.execute, llm_format_function=None, ) + pyd_params = self.parameters.copy() + pyd_params["additionalProperties"] = False + self._pydantic_ai_tool = PydanticTool.from_schema( + function=self.execute, + name=self.name, + description=self.description, + json_schema=pyd_params, + ) async def execute( self, diff --git a/py/core/base/agent/tools/built_in/search_file_descriptions.py b/py/core/base/agent/tools/built_in/search_file_descriptions.py index 439e678a4..4665474be 100644 --- a/py/core/base/agent/tools/built_in/search_file_descriptions.py +++ b/py/core/base/agent/tools/built_in/search_file_descriptions.py @@ -1,5 +1,7 @@ import logging +from pydantic_ai import Tool as PydanticTool + from shared.abstractions.tool import Tool logger = logging.getLogger(__name__) @@ -31,6 +33,14 @@ def __init__(self): results_function=self.execute, llm_format_function=None, ) + pyd_params = self.parameters.copy() + pyd_params["additionalProperties"] = False + self._pydantic_ai_tool = PydanticTool.from_schema( + function=self.execute, + name=self.name, + description=self.description, + json_schema=pyd_params, + ) async def execute(self, query: str, *args, **kwargs): """ diff --git a/py/core/base/agent/tools/built_in/search_file_knowledge.py b/py/core/base/agent/tools/built_in/search_file_knowledge.py index 377b41a7c..6f06d3d09 100644 --- a/py/core/base/agent/tools/built_in/search_file_knowledge.py +++ b/py/core/base/agent/tools/built_in/search_file_knowledge.py @@ -1,5 +1,7 @@ import logging +from pydantic_ai import Tool as PydanticTool + from shared.abstractions.tool import Tool logger = logging.getLogger(__name__) @@ -30,6 +32,14 @@ def __init__(self): results_function=self.execute, llm_format_function=None, ) + pyd_params = self.parameters.copy() + pyd_params["additionalProperties"] = False + self._pydantic_ai_tool = PydanticTool.from_schema( + function=self.execute, + name=self.name, + description=self.description, + json_schema=pyd_params, + ) async def execute(self, query: str, *args, **kwargs): """ diff --git a/py/core/base/agent/tools/built_in/smart_filter.py b/py/core/base/agent/tools/built_in/smart_filter.py index a7bc4989d..68f8393a1 100644 --- a/py/core/base/agent/tools/built_in/smart_filter.py +++ b/py/core/base/agent/tools/built_in/smart_filter.py @@ -1,6 +1,8 @@ import logging from typing import Any +from pydantic_ai import Tool as PydanticTool + from shared.abstractions.tool import Tool logger = logging.getLogger(__name__) @@ -34,6 +36,14 @@ def __init__(self): results_function=self.execute, llm_format_function=None, ) + pyd_params = self.parameters.copy() + pyd_params["additionalProperties"] = False + self._pydantic_ai_tool = PydanticTool.from_schema( + function=self.execute, + name=self.name, + description=self.description, + json_schema=pyd_params, + ) async def execute(self, query: str, *args, **kwargs): """ diff --git a/py/core/base/agent/tools/built_in/tavily_extract.py b/py/core/base/agent/tools/built_in/tavily_extract.py index 6cf6ab9bb..a204a3be5 100644 --- a/py/core/base/agent/tools/built_in/tavily_extract.py +++ b/py/core/base/agent/tools/built_in/tavily_extract.py @@ -1,5 +1,7 @@ import logging +from pydantic_ai import Tool as PydanticTool + from core.utils import ( generate_id, ) @@ -37,6 +39,14 @@ def __init__(self): results_function=self.execute, llm_format_function=None, ) + pyd_params = self.parameters.copy() + pyd_params["additionalProperties"] = False + self._pydantic_ai_tool = PydanticTool.from_schema( + function=self.execute, + name=self.name, + description=self.description, + json_schema=pyd_params, + ) async def execute(self, url: str, *args, **kwargs): """ diff --git a/py/core/base/agent/tools/built_in/tavily_search.py b/py/core/base/agent/tools/built_in/tavily_search.py index d849846c4..a9e94111b 100644 --- a/py/core/base/agent/tools/built_in/tavily_search.py +++ b/py/core/base/agent/tools/built_in/tavily_search.py @@ -1,5 +1,7 @@ import logging +from pydantic_ai import Tool as PydanticTool + from core.utils import ( generate_id, ) @@ -41,6 +43,14 @@ def __init__(self): results_function=self.execute, llm_format_function=None, ) + pyd_params = self.parameters.copy() + pyd_params["additionalProperties"] = False + self._pydantic_ai_tool = PydanticTool.from_schema( + function=self.execute, + name=self.name, + description=self.description, + json_schema=pyd_params, + ) async def execute(self, query: str, *args, **kwargs): """ diff --git a/py/core/base/agent/tools/built_in/web_scrape.py b/py/core/base/agent/tools/built_in/web_scrape.py index b4de649a6..dfeafe05d 100644 --- a/py/core/base/agent/tools/built_in/web_scrape.py +++ b/py/core/base/agent/tools/built_in/web_scrape.py @@ -1,5 +1,7 @@ import logging +from pydantic_ai import Tool as PydanticTool + from core.utils import ( generate_id, ) @@ -38,6 +40,14 @@ def __init__(self): results_function=self.execute, llm_format_function=None, ) + pyd_params = self.parameters.copy() + pyd_params["additionalProperties"] = False + self._pydantic_ai_tool = PydanticTool.from_schema( + function=self.execute, + name=self.name, + description=self.description, + json_schema=pyd_params, + ) async def execute(self, url: str, *args, **kwargs): """ diff --git a/py/core/base/agent/tools/built_in/web_search.py b/py/core/base/agent/tools/built_in/web_search.py index 41d23d822..906ed485a 100644 --- a/py/core/base/agent/tools/built_in/web_search.py +++ b/py/core/base/agent/tools/built_in/web_search.py @@ -1,3 +1,5 @@ +from pydantic_ai import Tool as PydanticTool + from shared.abstractions.tool import Tool @@ -27,6 +29,14 @@ def __init__(self): results_function=self.execute, llm_format_function=None, ) + pyd_params = self.parameters.copy() + pyd_params["additionalProperties"] = False + self._pydantic_ai_tool = PydanticTool.from_schema( + function=self.execute, + name=self.name, + description=self.description, + json_schema=pyd_params, + ) async def execute(self, query: str, *args, **kwargs): """ diff --git a/py/core/base/agent/tools/registry.py b/py/core/base/agent/tools/registry.py index 84179f969..526e1956b 100644 --- a/py/core/base/agent/tools/registry.py +++ b/py/core/base/agent/tools/registry.py @@ -21,7 +21,6 @@ def __init__( self, built_in_path: str | None = None, user_tools_path: str | None = None, - smolagent_tools_path: str | None = None, ): self.built_in_path = built_in_path or os.path.join( os.path.dirname(os.path.abspath(__file__)), "built_in" @@ -31,15 +30,10 @@ def __init__( or os.getenv("R2R_USER_TOOLS_PATH") or "../docker/user_tools" ) - self.smolagent_tools_path = smolagent_tools_path or os.path.join( - os.path.dirname(os.path.abspath(__file__)), - "../../../smolagent/tools", - ) # Tool storage self._built_in_tools: dict[str, Type[Tool]] = {} self._user_tools: dict[str, Type[Tool]] = {} - self._smolagent_tools: dict[str, Type[Tool]] = {} # Discover tools self._discover_built_in_tools() @@ -49,56 +43,37 @@ def __init__( logger.warning( f"User tools directory not found: {self.user_tools_path}" ) - if os.path.exists(self.smolagent_tools_path): - self._discover_smolagent_tools() - else: + + def _discover_built_in_tools(self): + """Load all built-in tools from the built_in directory.""" + if not os.path.exists(self.built_in_path): logger.warning( - f"Smolagent tools directory not found: {self.smolagent_tools_path}" + f"Built-in tools directory not found: {self.built_in_path}" ) - - def _discover_tools_from_path( - self, path: str, registry: dict, package_prefix: str = "" - ): - if not os.path.exists(path): - logger.warning(f"Tools directory not found: {path}") return # Add to Python path if needed - if path not in sys.path: - sys.path.append(os.path.dirname(path)) + if self.built_in_path not in sys.path: + sys.path.append(os.path.dirname(self.built_in_path)) - # Import the package if a prefix is given - prefix_valid = True - if package_prefix: - try: - importlib.import_module(package_prefix) - except ImportError as e: - logger.warning( - f"Failed to import tools package with prefix only: {package_prefix}\t" - f"Error: {e}" - ) - try: - # prefix works if it is local package, let's also fetch by path - tools_pkg_name = os.path.basename(path) - importlib.import_module(tools_pkg_name) - prefix_valid = False - except ImportError as e: - logger.warning( - f"Failed to import tools package: {tools_pkg_name}\t" - f"Error: {e}" - ) - return - - for _, module_name, is_pkg in pkgutil.iter_modules([path]): - if is_pkg: + # Import the built_in package + try: + built_in_pkg = importlib.import_module("built_in") + except ImportError: + logger.error("Failed to import built_in tools package") + return + + # Discover all modules in the package + for _, module_name, is_pkg in pkgutil.iter_modules( + [self.built_in_path] + ): + if is_pkg: # Skip subpackages continue + try: - module_path = f"{module_name}" - if prefix_valid and package_prefix: - module_path = f"{package_prefix}.{module_name}" - elif not prefix_valid and package_prefix: - module_path = f"{tools_pkg_name}.{module_name}" - module = importlib.import_module(module_path) + module = importlib.import_module(f"built_in.{module_name}") + + # Find all tool classes in the module for name, obj in inspect.getmembers(module, inspect.isclass): if ( issubclass(obj, Tool) @@ -106,22 +81,19 @@ def _discover_tools_from_path( and obj != Tool ): try: - tool_instance = obj() # type: ignore - registry[tool_instance.name] = obj + tool_instance = obj() + self._built_in_tools[tool_instance.name] = obj logger.debug( - f"Loaded tool: {tool_instance.name} from {path}" + f"Loaded built-in tool: {tool_instance.name}" ) except Exception as e: logger.error( - f"Error instantiating tool {name}: {e}" + f"Error instantiating built-in tool {name}: {e}" ) except Exception as e: - logger.error(f"Error loading tool module {module_name}: {e}") - - def _discover_built_in_tools(self): - self._discover_tools_from_path( - self.built_in_path, self._built_in_tools, "built_in" - ) + logger.error( + f"Error loading built-in tool module {module_name}: {e}" + ) def _discover_user_tools(self): """Scan the user tools directory for custom tools.""" @@ -170,17 +142,8 @@ def _discover_user_tools(self): f"Error loading user tool module {module_name}: {e}" ) - def _discover_smolagent_tools(self): - self._discover_tools_from_path( - self.smolagent_tools_path, self._smolagent_tools, "smolagent.tools" - ) - def get_tool_class(self, tool_name: str): - """Get a tool class by name. - If the tool is a smolagent tool, it will return the R2R wrapper. - """ - if tool_name in self._smolagent_tools: - return self._smolagent_tools[tool_name] + """Get a tool class by name.""" if tool_name in self._user_tools: return self._user_tools[tool_name] diff --git a/py/core/main/api/v3/retrieval_router.py b/py/core/main/api/v3/retrieval_router.py index e1172b933..bd07db379 100644 --- a/py/core/main/api/v3/retrieval_router.py +++ b/py/core/main/api/v3/retrieval_router.py @@ -619,7 +619,7 @@ async def agent_app( description="Use extended prompt for generation", ), # FIXME: We need a more generic way to handle this - mode: Optional[Literal["rag", "research", "rag_smol"]] = Body( + mode: Optional[Literal["rag", "research", "rag_pyd"]] = Body( default="rag", description="Mode to use for generation: 'rag' for standard retrieval or 'research' for deep analysis with reasoning capabilities", ), diff --git a/py/core/main/services/retrieval_service.py b/py/core/main/services/retrieval_service.py index c32178821..ec1ff9cb7 100644 --- a/py/core/main/services/retrieval_service.py +++ b/py/core/main/services/retrieval_service.py @@ -17,6 +17,7 @@ R2RXMLToolsResearchAgent, R2RXMLToolsStreamingRAGAgent, R2RXMLToolsStreamingResearchAgent, + RAGPydAgent, ) from core.agent.research import R2RResearchAgent from core.base import ( @@ -38,7 +39,6 @@ ) from core.base.agent.tools.registry import ToolRegistry from core.base.api.models import RAGResponse, User -from core.smolagent.agent import R2RSmolRAGAgent from core.utils import ( CitationTracker, SearchResultsCollector, @@ -66,7 +66,7 @@ class AgentFactory: @staticmethod def create_agent( - mode: Literal["rag", "research", "rag_smol"], + mode: Literal["rag", "research", "rag_pyd"], database_provider, llm_provider, config, # : AgentConfig @@ -85,7 +85,7 @@ def create_agent( Creates and returns the appropriate agent based on provided parameters. Args: - mode: Either "rag", "research", or "rag_smol" to determine agent type + mode: Either "rag", "research", or "rag_pyd" to determine agent type database_provider: Provider for database operations llm_provider: Provider for LLM operations config: Agent configuration @@ -108,7 +108,7 @@ def create_agent( tool_registry = ToolRegistry() # Handle tool specifications based on mode - if mode == "rag" or mode == "rag_smol": + if mode == "rag" or mode == "rag_pyd": # For RAG mode, prioritize explicitly passed rag_tools, then tools, then config defaults if rag_tools: agent_config.rag_tools = rag_tools @@ -246,8 +246,8 @@ def create_agent( content_method=content_method, file_search_method=file_search_method, ) - elif mode == "rag_smol": - return R2RSmolRAGAgent( + elif mode == "rag_pyd": + return RAGPydAgent( database_provider=database_provider, llm_provider=llm_provider, config=agent_config, @@ -1297,7 +1297,7 @@ async def agent( research_tools: Optional[list[str]] = None, research_generation_config: Optional[GenerationConfig] = None, needs_initial_conversation_name: Optional[bool] = None, - mode: Optional[Literal["rag", "research", "rag_smol"]] = "rag", + mode: Optional[Literal["rag", "research", "rag_pyd"]] = "rag", ): """ Engage with an intelligent agent for information retrieval, analysis, and research. @@ -1386,7 +1386,7 @@ async def agent( ) # Set appropriate LLM model based on mode if not explicitly specified if "model" not in effective_generation_config.model_fields_set: - if mode == "rag" or mode == "rag_smol": + if mode == "rag" or mode == "rag_pyd": effective_generation_config.model = ( self.config.app.quality_llm ) @@ -1497,7 +1497,7 @@ async def agent( # Configure agent with appropriate tools agent_config = deepcopy(self.config.agent) - if mode == "rag" or mode == "rag_smol": + if mode == "rag" or mode == "rag_pyd": # Use provided RAG tools or default from config agent_config.rag_tools = ( rag_tools or tools or self.config.agent.rag_tools diff --git a/py/core/smolagent/__init__.py b/py/core/smolagent/__init__.py deleted file mode 100644 index e9111cd91..000000000 --- a/py/core/smolagent/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -from .agent import R2RSmolRAGAgent -from .tools import ( - SmolGetFileContentTool, - SmolSearchFileDescriptionsTool, - SmolSearchFileKnowledgeTool, - SmolSmartFilterTool, -) - -__all__ = [ - "R2RSmolRAGAgent", - "SmolSmartFilterTool", - "SmolSearchFileKnowledgeTool", - "SmolGetFileContentTool", - "SmolSearchFileDescriptionsTool", -] diff --git a/py/core/smolagent/agent.py b/py/core/smolagent/agent.py deleted file mode 100644 index adfcd7a87..000000000 --- a/py/core/smolagent/agent.py +++ /dev/null @@ -1,117 +0,0 @@ -import logging -from typing import Any - -from smolagents import CodeAgent - -from core import R2RRAGAgent -from core.base import Message -from core.smolagent.llm_provider.hf_llm_provider import ( - fetch_hf_inference_from_model, -) - -logger = logging.getLogger(__name__) - - -def smol_to_r2r_messages(smol_run_result: Any) -> list[Message]: - """ - Convert smolagents RunResult.messages (list of dicts) to list of R2R Message objects. - """ - logger.debug( - f"Converting smol run result to R2R messages: {smol_run_result}" - ) - r2r_messages = [] - if hasattr(smol_run_result, "messages"): - for msg in smol_run_result.messages: - # Defensive: Only use fields that Message accepts - role = msg.get("role", "assistant") - content = msg.get("content", "") - # Optionally pass other fields if R2R Message supports them - r2r_messages.append(Message(role=role, content=content)) - else: - r2r_messages = [ - Message(role="assistant", content=str(smol_run_result)) - ] - logger.debug(f"R2R messages: {r2r_messages}") - return r2r_messages - - -class R2RSmolRAGAgent(R2RRAGAgent): - def __init__( - self, - database_provider, - llm_provider, - config, - search_settings, - rag_generation_config, - knowledge_search_method, - content_method, - file_search_method, - tool_registry=None, - max_tool_context_length=20000, - **kwargs, - ): - # Prefix all tool names in rag_tools with 'smol_' - if hasattr(config, "rag_tools") and config.rag_tools: - config.rag_tools = [ - f"smol_{name}" if not name.startswith("smol_") else name - for name in config.rag_tools - ] - logger.debug(f"Smol RAG tools: {config.rag_tools}") - super().__init__( - database_provider=database_provider, - llm_provider=llm_provider, - config=config, - search_settings=search_settings, - rag_generation_config=rag_generation_config, - knowledge_search_method=knowledge_search_method, - content_method=content_method, - file_search_method=file_search_method, - tool_registry=tool_registry, - max_tool_context_length=max_tool_context_length, - **kwargs, - ) - - # register tools method called in super init will fill _tools with the tools from the registry - self._smol_tools = [ - tool._smol_tool - for tool in getattr(self, "_tools", []) - if hasattr(tool, "_smol_tool") - ] - logger.debug(f"Smol tools: {self._smol_tools}") - - # logger.debug(f"Config: {config}") - # logger.debug(f"Fetching HF model: {rag_generation_config}") - self._hf_model = fetch_hf_inference_from_model( - rag_generation_config.model - ) - self.hf_agent = CodeAgent( - tools=self._smol_tools, - model=self._hf_model, - use_structured_outputs_internally=True, - add_base_tools=False, - ) - self.update_system_prompt() - - def run(self, messages: list[Message], **kwargs): - # logger.debug(f"Running smol agent with messages: {messages}") - message = messages[0].content - # logger.debug(f"Message: {message}") - return self.hf_agent.run(message) - - async def arun(self, messages: list[Message], **kwargs): - logger.debug(f"Running smol agent with messages: {messages}") - message = messages[0].content - logger.debug(f"Message: {message}") - # Non stream version - result = self.hf_agent.run(message) - return smol_to_r2r_messages(result) - - def update_system_prompt(self): - if hasattr(self, "hf_agent"): - # logger.debug(f"Old system prompt: {self.hf_agent.system_prompt}") - custom_system_prompt_addition = ( - "\n\nWhen asked a question, YOU SHOULD ALWAYS USE YOUR SEARCH TOOL TO ATTEMPT TO SEARCH FOR RELEVANT INFORMATION THAT ANSWERS THE USER QUESTION BUT" - " if you have access to tools that help you set some filters, like smol_smart_filter_tool, to narrow down and speed up the search, use them BEFORE the search" - ) - self.hf_agent.system_prompt += custom_system_prompt_addition - # logger.debug(f"Updated system prompt: {self.hf_agent.system_prompt}") diff --git a/py/core/smolagent/llm_provider/__init__.py b/py/core/smolagent/llm_provider/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/py/core/smolagent/llm_provider/hf_llm_provider.py b/py/core/smolagent/llm_provider/hf_llm_provider.py deleted file mode 100644 index f3c8e6d89..000000000 --- a/py/core/smolagent/llm_provider/hf_llm_provider.py +++ /dev/null @@ -1,20 +0,0 @@ -import logging - -from smolagents import Model - -logger = logging.getLogger(__name__) - - -def fetch_hf_inference_from_model(model_name: str) -> Model: - logger.debug(f"Fetching HF inference from model: {model_name}") - if "gpt" in model_name: - from smolagents import OpenAIServerModel - - # Initialize the model with our reverse proxy - # remove openai/ prefix if there is one - model_id = model_name.replace("openai/", "") - return OpenAIServerModel( - model_id=model_id, - ) - else: - raise ValueError(f"Model {model_name} is not supported") diff --git a/py/core/smolagent/tools/__init__.py b/py/core/smolagent/tools/__init__.py deleted file mode 100644 index 91446b155..000000000 --- a/py/core/smolagent/tools/__init__.py +++ /dev/null @@ -1,24 +0,0 @@ -from .get_file_content import ( - SmolGetFileContentTool, - SmolGetFileContentToolR2RWrapper, -) -from .search_file_descriptions import ( - SmolSearchFileDescriptionsTool, - SmolSearchFileDescriptionsToolR2RWrapper, -) -from .search_file_knowledge import ( - SmolSearchFileKnowledgeTool, - SmolSearchFileKnowledgeToolR2RWrapper, -) -from .smart_filter import SmolSmartFilterTool, SmolSmartFilterToolR2RWrapper - -__all__ = [ - "SmolSmartFilterTool", - "SmolSmartFilterToolR2RWrapper", - "SmolSearchFileKnowledgeTool", - "SmolSearchFileKnowledgeToolR2RWrapper", - "SmolGetFileContentTool", - "SmolGetFileContentToolR2RWrapper", - "SmolSearchFileDescriptionsTool", - "SmolSearchFileDescriptionsToolR2RWrapper", -] diff --git a/py/core/smolagent/tools/get_file_content.py b/py/core/smolagent/tools/get_file_content.py deleted file mode 100644 index 2234217fb..000000000 --- a/py/core/smolagent/tools/get_file_content.py +++ /dev/null @@ -1,43 +0,0 @@ -from smolagents import Tool as SmolTool - -from core.base.agent.tools.built_in.get_file_content import GetFileContentTool -from shared.abstractions.tool import Tool as R2RTool - - -class SmolGetFileContentTool(SmolTool): - name = "smol_get_file_content" - description = "Fetches the complete contents of a user document from the local database by document ID." - inputs = { - "document_id": { - "type": "string", - "description": "The unique UUID of the document to fetch.", - } - } - output_type = "object" - parameters = inputs - - def __init__(self, context, r2r_tool: R2RTool, **kwargs): - super().__init__() - self.context = context - self.results_function = self.forward - self._r2r_tool = r2r_tool - - async def forward(self, document_id: str): - return await self._r2r_tool.execute(document_id) - - -class SmolGetFileContentToolR2RWrapper(R2RTool): - def __init__(self): - super().__init__( - name="smol_get_file_content", - description="Wrapper for the SmolGetFileContentTool to be used in R2R", - results_function=self.execute, - ) - self._r2r_tool = GetFileContentTool() - self._smol_tool = SmolGetFileContentTool( - context=self, r2r_tool=self._r2r_tool - ) - - async def execute(self, document_id: str, **kwargs): - # Placeholder function, not to be used - pass diff --git a/py/core/smolagent/tools/search_file_descriptions.py b/py/core/smolagent/tools/search_file_descriptions.py deleted file mode 100644 index d3ca598d1..000000000 --- a/py/core/smolagent/tools/search_file_descriptions.py +++ /dev/null @@ -1,51 +0,0 @@ -import logging - -from smolagents import Tool as SmolTool - -from core.base.agent.tools.built_in.search_file_descriptions import ( - SearchFileDescriptionsTool, -) -from shared.abstractions.tool import Tool as R2RTool - -logger = logging.getLogger(__name__) - - -class SmolSearchFileDescriptionsTool(SmolTool): - name = "smol_search_file_descriptions" - description = "Semantic search over AI-generated summaries of stored documents. Use for a broad overview of relevant files." - inputs = { - "query": { - "type": "string", - "description": "Query string to semantic search over available files.", - } - } - output_type = "object" - parameters = inputs - - def __init__(self, context, r2r_tool: R2RTool, **kwargs): - super().__init__() - self.context = context - self.results_function = self.forward - self._r2r_tool = r2r_tool - - async def forward(self, query: str): - return await self._r2r_tool.execute(query) - - -class SmolSearchFileDescriptionsToolR2RWrapper(R2RTool): - def __init__(self): - super().__init__( - name="smol_search_file_descriptions", - description="Wrapper for the SmolSearchFileDescriptionsTool to be used in R2R", - results_function=self.execute, - ) - self._r2r_tool = SearchFileDescriptionsTool() - self._smol_tool = SmolSearchFileDescriptionsTool( - context=self, r2r_tool=self._r2r_tool - ) - - async def execute(self, query: str, **kwargs): - logger.debug( - f"Executing SmolSearchFileDescriptionsToolR2RWrapper with query: {query}" - ) - return await self._r2r_tool.execute(query) diff --git a/py/core/smolagent/tools/search_file_knowledge.py b/py/core/smolagent/tools/search_file_knowledge.py deleted file mode 100644 index 251fa41f5..000000000 --- a/py/core/smolagent/tools/search_file_knowledge.py +++ /dev/null @@ -1,51 +0,0 @@ -import logging - -from smolagents import Tool as SmolTool - -from core.base.agent.tools.built_in.search_file_knowledge import ( - SearchFileKnowledgeTool, -) -from shared.abstractions.tool import Tool as R2RTool - -logger = logging.getLogger(__name__) - - -class SmolSearchFileKnowledgeTool(SmolTool): - name = "smol_search_file_knowledge" - description = "Search your local knowledge base using the R2R system. Use this for relevant text chunks or knowledge graph data." - inputs = { - "query": { - "type": "string", - "description": "User query to search in the local DB.", - } - } - output_type = "object" - parameters = inputs - - def __init__(self, context, r2r_tool: R2RTool, **kwargs): - super().__init__() - self.context = context - self.results_function = self.forward - self._r2r_tool = r2r_tool - - async def forward(self, query: str): - return await self._r2r_tool.execute(query) - - -class SmolSearchFileKnowledgeToolR2RWrapper(R2RTool): - def __init__(self): - super().__init__( - name="smol_search_file_knowledge", - description="Wrapper for the SmolSearchFileKnowledgeTool to be used in R2R", - results_function=self.execute, - ) - self._r2r_tool = SearchFileKnowledgeTool() - self._smol_tool = SmolSearchFileKnowledgeTool( - context=self, r2r_tool=self._r2r_tool - ) - - async def execute(self, query: str, **kwargs): - logger.debug( - f"Executing SmolSearchFileKnowledgeToolR2RWrapper with query: {query}" - ) - return await self._r2r_tool.execute(query) diff --git a/py/core/smolagent/tools/smart_filter.py b/py/core/smolagent/tools/smart_filter.py deleted file mode 100644 index 798e2191b..000000000 --- a/py/core/smolagent/tools/smart_filter.py +++ /dev/null @@ -1,49 +0,0 @@ -import logging - -from smolagents import Tool as SmolTool - -from core.base.agent.tools.built_in.smart_filter import SmartFilterTool -from shared.abstractions.tool import Tool as R2RTool - -logger = logging.getLogger(__name__) - - -class SmolSmartFilterTool(SmolTool): - name = "smol_smart_filter" - description = "Refines metadata and collection filters for a RAG search using LLM analysis. To be used BEFORE the rag search" - inputs = { - "query": { - "type": "string", - "description": "The user query to analyze.", - } - } - output_type = "object" - parameters = inputs - - def __init__(self, context, r2r_tool: R2RTool, **kwargs): - super().__init__() - self.context = context - self.results_function = self.forward - self._r2r_tool = r2r_tool - - async def forward(self, query: str): - return await self._r2r_tool.execute(query) - - -class SmolSmartFilterToolR2RWrapper(R2RTool): - def __init__(self): - super().__init__( - name="smol_smart_filter_tool", - description="Wrapper for the SmolSmartFilterTool to be used in R2R", - results_function=self.execute, - ) - self._r2r_tool = SmartFilterTool() - self._smol_tool = SmolSmartFilterTool( - context=self, r2r_tool=self._r2r_tool - ) - - async def execute(self, query: str, **kwargs): - logger.debug( - f"Executing SmolSmartFilterToolR2RWrapper with query: {query}" - ) - return await self._r2r_tool.execute(query) diff --git a/py/pyproject.toml b/py/pyproject.toml index 35a24987a..723437c41 100644 --- a/py/pyproject.toml +++ b/py/pyproject.toml @@ -29,7 +29,7 @@ dependencies = [ "pydantic>=2.10.6", "python-json-logger>=3.2.1", "filetype>=1.2.0", - "smolagents>=1.17.0,<2.0.0", + "pydantic-ai==0.2.16", ] [project.optional-dependencies] @@ -53,7 +53,7 @@ core = [ "future >=1.0.0,<2.0.0", "google-auth >=2.37.0,<3.0.0", "google-auth-oauthlib >=1.2.1,<2.0.0", - "google-genai >=0.6.0,<0.7.0", + "google-genai >=0.6.0", "gunicorn >=21.2.0,<22.0.0", "hatchet-sdk ==0.47.0", "litellm >=1.69.3", @@ -63,7 +63,7 @@ core = [ "networkx >=3.3,<4.0", "numpy >=1.22.4,<1.29.0", "olefile >=0.47,<0.48", - "ollama >=0.3.1,<0.4.0", + "ollama >=0.3.1", "openpyxl >=3.1.2,<4.0.0", "orgparse >=0.4.20231004,<0.5.0", "pdf2image>=1.17.0", diff --git a/py/sdk/sync_methods/retrieval.py b/py/sdk/sync_methods/retrieval.py index ac03ed923..068acf3db 100644 --- a/py/sdk/sync_methods/retrieval.py +++ b/py/sdk/sync_methods/retrieval.py @@ -377,7 +377,7 @@ def agent( research_tools (Optional[list[str]]): List of tools to enable for Research mode. Available tools: "rag", "reasoning", "critique", "python_executor". tools (Optional[list[str]]): Deprecated. List of tools to execute. - mode (Optional[str]): Mode to use for generation: "rag", "rag_smol" for smolagent, or "research" for deep analysis. + mode (Optional[str]): Mode to use for generation: "rag", "rag_pyd" for pydantic agent, or "research" for deep analysis. Defaults to "rag". Returns: