diff --git a/.env.example b/.env.example index 8558e9c4..9ad0e864 100644 --- a/.env.example +++ b/.env.example @@ -46,10 +46,18 @@ MAX_WORKERS=30 # API Keys and External Services # ============================================================================= +# Search backend selector. Set to "exa" to route web/scholar search through Exa, +# or leave as "serper" (default) to use Serper. Existing deployments are unaffected. +SEARCH_PROVIDER=serper + # Serper API for web search and Google Scholar # Get your key from: https://serper.dev/ SERPER_KEY_ID=your_key +# Exa API for web search and research-paper search (used when SEARCH_PROVIDER=exa) +# Get your key from: https://dashboard.exa.ai/ +EXA_API_KEY=your_exa_api_key + # Jina API for web page reading # Get your key from: https://jina.ai/ JINA_API_KEYS=your_key diff --git a/inference/tests/__init__.py b/inference/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/inference/tests/test_tool_exa.py b/inference/tests/test_tool_exa.py new file mode 100644 index 00000000..d54f83de --- /dev/null +++ b/inference/tests/test_tool_exa.py @@ -0,0 +1,159 @@ +"""Tests for the Exa search backend (``inference/tool_exa.py``). + +These tests do not hit the live Exa API. They mock the ``exa_py.Exa`` class +and assert response parsing, snippet fallback ordering, the integration header, +and behavior when ``EXA_API_KEY`` is unset. +""" + +import os +import sys +import unittest +from unittest import mock + +# The inference package uses sibling-relative imports (e.g. ``from tool_exa import ...``) +# so we add the inference directory to sys.path for these tests. 
# The inference package uses sibling-relative imports (``from tool_exa import ...``),
# so the inference directory has to be importable before the tests run.
HERE = os.path.dirname(os.path.abspath(__file__))
INFERENCE_DIR = os.path.dirname(HERE)
if INFERENCE_DIR not in sys.path:
    sys.path.insert(0, INFERENCE_DIR)


class _FakeResult:
    """Attribute-style stand-in for a single result object returned by the SDK."""

    def __init__(
        self,
        title="",
        url="",
        published_date=None,
        author=None,
        text=None,
        highlights=None,
        summary=None,
    ):
        self.title = title
        self.url = url
        self.published_date = published_date
        self.author = author
        self.text = text
        self.highlights = highlights
        self.summary = summary


class _FakeResponse:
    """Stand-in for the SDK response; only exposes ``results``."""

    def __init__(self, results):
        self.results = results


class _FakeExa:
    """Stands in for ``exa_py.Exa`` and records the arguments it receives.

    The recorder state lives on the class (not the instance) because the code
    under test builds its own client; tests inspect it through ``_FakeExa``.
    """

    next_results: list = []   # results the next search call will return
    last_kwargs: dict = {}    # keyword arguments of the most recent search
    last_query: str = ""      # query string of the most recent search
    last_instance = None      # most recently constructed _FakeExa, if any

    def __init__(self, api_key):
        self.api_key = api_key
        self.headers: dict = {}
        _FakeExa.last_instance = self

    @classmethod
    def reset(cls):
        """Clear recorded state so one test cannot satisfy another's assertions."""
        cls.next_results = []
        cls.last_kwargs = {}
        cls.last_query = ""
        cls.last_instance = None

    def search_and_contents(self, query, **kwargs):
        _FakeExa.last_query = query
        _FakeExa.last_kwargs = kwargs
        return _FakeResponse(_FakeExa.next_results)


def _patch_exa(results):
    """Install a fake ``exa_py`` module so ``from exa_py import Exa`` returns _FakeExa.

    Returns the ``mock.patch.dict`` context manager; ``sys.modules`` is restored
    when the ``with`` block exits.
    """
    _FakeExa.next_results = results
    fake_module = mock.MagicMock()
    fake_module.Exa = _FakeExa
    return mock.patch.dict(sys.modules, {"exa_py": fake_module})


class ExaParsingTests(unittest.TestCase):
    def setUp(self):
        # Patch the environment instead of assigning into it so the key does not
        # leak into other tests (or overwrite a developer's real key).
        env = mock.patch.dict(os.environ, {"EXA_API_KEY": "test-key"})
        env.start()
        self.addCleanup(env.stop)
        _FakeExa.reset()

    def test_search_parses_results_and_sets_integration_header(self):
        results = [
            _FakeResult(
                title="Example",
                url="https://example.com/a",
                published_date="2026-04-01",
                highlights=["Hello world", "More context"],
            ),
            _FakeResult(
                title="Second",
                url="https://example.com/b",
                summary="A short summary.",
            ),
        ]
        with _patch_exa(results):
            from tool_exa import search_with_exa

            output = search_with_exa("python testing")

        self.assertIn("python testing", output)
        self.assertIn("Example", output)
        self.assertIn("https://example.com/a", output)
        self.assertIn("Hello world ... More context", output)
        self.assertIn("A short summary.", output)
        self.assertIn("Date published: 2026-04-01", output)
        self.assertEqual(_FakeExa.last_query, "python testing")
        self.assertEqual(_FakeExa.last_kwargs.get("type"), "auto")
        self.assertIn("highlights", _FakeExa.last_kwargs)
        self.assertIn("text", _FakeExa.last_kwargs)
        self.assertIsNotNone(_FakeExa.last_instance)
        self.assertEqual(
            _FakeExa.last_instance.headers.get("x-exa-integration"),
            "deepresearch",
        )

    def test_snippet_falls_back_through_highlights_summary_text(self):
        from tool_exa import ExaResult

        only_highlights = ExaResult(title="t", url="u", highlights=["h1", "h2"])
        only_summary = ExaResult(title="t", url="u", summary="s")
        only_text = ExaResult(title="t", url="u", text="t" * 600)
        empty = ExaResult(title="t", url="u")

        self.assertEqual(only_highlights.snippet(), "h1 ... h2")
        self.assertEqual(only_summary.snippet(), "s")
        self.assertTrue(only_text.snippet().endswith("..."))
        self.assertLessEqual(len(only_text.snippet()), 504)
        self.assertEqual(empty.snippet(), "")

    def test_scholar_uses_research_paper_category(self):
        results = [_FakeResult(title="Paper", url="https://arxiv.org/x", summary="S")]
        with _patch_exa(results):
            from tool_exa import scholar_with_exa

            output = scholar_with_exa("transformer scaling laws")

        self.assertIn("Scholar Results", output)
        self.assertIn("Paper", output)
        self.assertEqual(_FakeExa.last_kwargs.get("category"), "research paper")

    def test_no_results_returns_friendly_message(self):
        with _patch_exa([]):
            from tool_exa import search_with_exa

            output = search_with_exa("zzz nothing matches")

        self.assertIn("No results found", output)


class ExaDisabledTests(unittest.TestCase):
    def setUp(self):
        # Snapshot the environment, then drop the key; the snapshot is restored
        # on cleanup so a real EXA_API_KEY survives the test run.
        env = mock.patch.dict(os.environ)
        env.start()
        self.addCleanup(env.stop)
        os.environ.pop("EXA_API_KEY", None)

    def test_missing_api_key_returns_error_string(self):
        with _patch_exa([]):
            from tool_exa import search_with_exa

            output = search_with_exa("anything")

        self.assertIn("EXA_API_KEY is not set", output)
"""Exa search backend.

Provides Exa-powered web search and Exa-powered scholar search. The implementations
return strings with the same numbered-snippet format produced by the Serper backends
in ``tool_search.py`` and ``tool_scholar.py`` so the agent prompt format is unchanged.
"""

import os
from dataclasses import dataclass
from typing import Any, List, Optional


_EXA_INTEGRATION_HEADER = "deepresearch"
_MAX_ATTEMPTS = 5               # how many times a failing search call is tried
_MAX_TEXT_SNIPPET_CHARS = 500   # cap for the raw-text fallback snippet


@dataclass
class ExaResult:
    """One search hit, normalised from whatever shape the SDK returned."""

    title: str
    url: str
    published_date: Optional[str] = None
    author: Optional[str] = None
    text: Optional[str] = None
    highlights: Optional[List[str]] = None
    summary: Optional[str] = None

    @classmethod
    def from_sdk(cls, item: Any) -> "ExaResult":
        """Build an ``ExaResult`` from an SDK result object or a plain dict.

        Missing fields become ``None`` (``title``/``url`` become ``""``). Both the
        ``published_date`` and ``publishedDate`` spellings are accepted.
        """

        def get(key: str) -> Any:
            if isinstance(item, dict):
                return item.get(key)
            return getattr(item, key, None)

        highlights = get("highlights")
        if isinstance(highlights, str):
            # A bare string must stay one highlight; list("abc") would split it
            # into single characters.
            highlights = [highlights]
        elif highlights is not None and not isinstance(highlights, list):
            highlights = list(highlights)

        return cls(
            title=get("title") or "",
            url=get("url") or "",
            published_date=get("published_date") or get("publishedDate"),
            author=get("author"),
            text=get("text"),
            highlights=highlights,
            summary=get("summary"),
        )

    def snippet(self) -> str:
        """Return the best available excerpt: highlights, then summary, then text.

        A source only counts if it has non-whitespace content, so blank highlights
        or a blank summary fall through to the next source instead of producing an
        empty snippet. Raw text is truncated to 500 characters plus ``"..."``.
        """
        cleaned = [
            h.strip()
            for h in (self.highlights or [])
            if isinstance(h, str) and h.strip()
        ]
        if cleaned:
            return " ... ".join(cleaned)

        summary = (self.summary or "").strip()
        if summary:
            return summary

        text = (self.text or "").strip()
        if len(text) > _MAX_TEXT_SNIPPET_CHARS:
            text = text[:_MAX_TEXT_SNIPPET_CHARS].rstrip() + "..."
        return text


def _build_client():
    """Create an Exa client tagged with the integration header.

    Raises:
        RuntimeError: if ``EXA_API_KEY`` is unset or the ``exa_py`` package is
            not importable.
    """
    # Check the key first so a missing key is reported as such even when the
    # SDK is not installed.
    api_key = os.environ.get("EXA_API_KEY")
    if not api_key:
        raise RuntimeError("EXA_API_KEY is not set")

    try:
        from exa_py import Exa
    except ImportError as err:
        raise RuntimeError(
            "exa-py is not installed; install it to use SEARCH_PROVIDER=exa"
        ) from err

    client = Exa(api_key=api_key)
    try:
        client.headers["x-exa-integration"] = _EXA_INTEGRATION_HEADER
    except AttributeError:
        # Best effort: a client without a ``headers`` mapping still works,
        # it just is not tagged.
        pass
    return client


def _run_search(
    query: str,
    num_results: int,
    category: Optional[str] = None,
    include_domains: Optional[List[str]] = None,
    exclude_domains: Optional[List[str]] = None,
    start_published_date: Optional[str] = None,
    end_published_date: Optional[str] = None,
) -> List[ExaResult]:
    """Run one Exa ``search_and_contents`` query and normalise the results.

    Optional filters are only sent when they are truthy. The call is attempted
    up to five times.

    Raises:
        RuntimeError: if the client cannot be built or every attempt fails; the
            last underlying exception is attached as ``__cause__``.
    """
    client = _build_client()

    kwargs: dict = {
        "num_results": num_results,
        "type": "auto",
        "highlights": {"num_sentences": 3},
        "text": {"max_characters": _MAX_TEXT_SNIPPET_CHARS},
    }
    optional_filters = {
        "category": category,
        "include_domains": include_domains,
        "exclude_domains": exclude_domains,
        "start_published_date": start_published_date,
        "end_published_date": end_published_date,
    }
    kwargs.update({name: value for name, value in optional_filters.items() if value})

    last_err: Optional[Exception] = None
    for _ in range(_MAX_ATTEMPTS):
        try:
            response = client.search_and_contents(query, **kwargs)
        except Exception as err:  # the SDK's exception types are not pinned here
            last_err = err
        else:
            raw_results = getattr(response, "results", None) or []
            return [ExaResult.from_sdk(r) for r in raw_results]
    raise RuntimeError(
        f"Exa search failed after {_MAX_ATTEMPTS} attempts: {last_err}"
    ) from last_err


def _format_block(query: str, results: List[ExaResult], header: str, label: str) -> str:
    """Render results as the numbered markdown block the agent prompt expects.

    Returns a "No results found" message when ``results`` is empty.
    """
    if not results:
        return f"No results found for '{query}'. Try with a more general query."

    snippets: List[str] = []
    for idx, r in enumerate(results, start=1):
        date_published = f"\nDate published: {r.published_date}" if r.published_date else ""
        author_line = f"\nSource: {r.author}" if r.author else ""
        snippet = r.snippet()
        snippet_block = f"\n{snippet}" if snippet else ""
        title = r.title or r.url  # untitled pages are labelled with their URL
        snippets.append(f"{idx}. [{title}]({r.url}){date_published}{author_line}{snippet_block}")

    return f"{header} for '{query}' found {len(snippets)} results:\n\n## {label}\n" + "\n\n".join(snippets)


def search_with_exa(
    query: str,
    num_results: int = 10,
    category: Optional[str] = None,
    include_domains: Optional[List[str]] = None,
    exclude_domains: Optional[List[str]] = None,
    start_published_date: Optional[str] = None,
    end_published_date: Optional[str] = None,
) -> str:
    """Web search through Exa, formatted for the agent.

    Never raises: failures are returned as a string prefixed with ``[Exa search]``.
    """
    try:
        results = _run_search(
            query,
            num_results=num_results,
            category=category,
            include_domains=include_domains,
            exclude_domains=exclude_domains,
            start_published_date=start_published_date,
            end_published_date=end_published_date,
        )
    except RuntimeError as e:
        return f"[Exa search] {e}"
    except Exception as e:
        return f"[Exa search] Unexpected error: {e}"
    return _format_block(query, results, header="A web search", label="Web Results")


def scholar_with_exa(query: str, num_results: int = 10) -> str:
    """Research-paper search through Exa (``category="research paper"``).

    Never raises: failures are returned as a string prefixed with ``[Exa scholar]``.
    """
    try:
        results = _run_search(query, num_results=num_results, category="research paper")
    except RuntimeError as e:
        return f"[Exa scholar] {e}"
    except Exception as e:
        return f"[Exa scholar] Unexpected error: {e}"
    return _format_block(query, results, header="A research-paper search", label="Scholar Results")
+SEARCH_PROVIDER=os.environ.get('SEARCH_PROVIDER', 'serper').lower() @register_tool("google_scholar", allow_overwrite=True) @@ -91,6 +94,11 @@ def google_scholar_with_serp(self, query: str): return f"No results found for '{query}'. Try with a more general query." + def _do_scholar_search(self, query: str) -> str: + if SEARCH_PROVIDER == "exa": + return scholar_with_exa(query) + return self.google_scholar_with_serp(query) + def call(self, params: Union[str, dict], **kwargs) -> str: # assert GOOGLE_SEARCH_KEY is not None, "Please set the IDEALAB_SEARCH_KEY environment variable." try: @@ -98,13 +106,13 @@ def call(self, params: Union[str, dict], **kwargs) -> str: query = params["query"] except: return "[google_scholar] Invalid request format: Input must be a JSON object containing 'query' field" - + if isinstance(query, str): - response = self.google_scholar_with_serp(query) + response = self._do_scholar_search(query) else: assert isinstance(query, List) with ThreadPoolExecutor(max_workers=3) as executor: - response = list(executor.map(self.google_scholar_with_serp, query)) + response = list(executor.map(self._do_scholar_search, query)) response = "\n=======\n".join(response) return response diff --git a/inference/tool_search.py b/inference/tool_search.py index 1a3f7b53..0e78bb7c 100644 --- a/inference/tool_search.py +++ b/inference/tool_search.py @@ -11,8 +11,11 @@ import os +from tool_exa import search_with_exa + SERPER_KEY=os.environ.get('SERPER_KEY_ID') +SEARCH_PROVIDER=os.environ.get('SEARCH_PROVIDER', 'serper').lower() @register_tool("search", allow_overwrite=True) @@ -110,22 +113,27 @@ def search_with_serp(self, query: str): result = self.google_search_with_serp(query) return result + def _do_search(self, query: str) -> str: + if SEARCH_PROVIDER == "exa": + return search_with_exa(query) + return self.search_with_serp(query) + def call(self, params: Union[str, dict], **kwargs) -> str: try: query = params["query"] except: return "[Search] Invalid request format: 
Input must be a JSON object containing 'query' field" - + if isinstance(query, str): # 单个查询 - response = self.search_with_serp(query) + response = self._do_search(query) else: # 多个查询 assert isinstance(query, List) responses = [] for q in query: - responses.append(self.search_with_serp(q)) + responses.append(self._do_search(q)) response = "\n=======\n".join(responses) - + return response diff --git a/requirements.txt b/requirements.txt index 5fd43bd1..798bd344 100644 --- a/requirements.txt +++ b/requirements.txt @@ -40,6 +40,7 @@ einops==0.8.1 email-validator==2.3.0 et_xmlfile==2.0.0 eval_type_backport==0.2.2 +exa-py>=2.0.0 exceptiongroup==1.3.0 fastapi==0.116.1 fastapi-cli==0.0.11