Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,18 @@ MAX_WORKERS=30
# API Keys and External Services
# =============================================================================

# Search backend selector. Set to "exa" to route web/scholar search through Exa,
# or leave as "serper" (default) to use Serper. Existing deployments are unaffected.
SEARCH_PROVIDER=serper

# Serper API for web search and Google Scholar
# Get your key from: https://serper.dev/
SERPER_KEY_ID=your_key

# Exa API for web search and research-paper search (used when SEARCH_PROVIDER=exa)
# Get your key from: https://dashboard.exa.ai/
EXA_API_KEY=your_exa_api_key

# Jina API for web page reading
# Get your key from: https://jina.ai/
JINA_API_KEYS=your_key
Expand Down
Empty file added inference/tests/__init__.py
Empty file.
159 changes: 159 additions & 0 deletions inference/tests/test_tool_exa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
"""Tests for the Exa search backend (``inference/tool_exa.py``).

These tests do not hit the live Exa API. They mock the ``exa_py.Exa`` class
and assert response parsing, snippet fallback ordering, the integration header,
and behavior when ``EXA_API_KEY`` is unset.
"""

import os
import sys
import unittest
from unittest import mock

# The inference package uses sibling-relative imports (e.g. ``from tool_exa import ...``)
# so we add the inference directory to sys.path for these tests.
HERE = os.path.dirname(os.path.abspath(__file__))
INFERENCE_DIR = os.path.dirname(HERE)
if INFERENCE_DIR not in sys.path:
sys.path.insert(0, INFERENCE_DIR)


class _FakeResult:
    """Attribute bag shaped like one hit returned by the Exa SDK.

    Every field is optional so a test only has to spell out the attributes
    it cares about; the rest keep their empty defaults.
    """

    def __init__(
        self,
        title="",
        url="",
        published_date=None,
        author=None,
        text=None,
        highlights=None,
        summary=None,
    ):
        # Identity of the hit.
        self.title, self.url = title, url
        # Optional metadata.
        self.published_date, self.author = published_date, author
        # Optional content fields a snippet can be derived from.
        self.text, self.highlights, self.summary = text, highlights, summary


class _FakeResponse:
    """Envelope shaped like an Exa SDK response: it only exposes ``results``."""

    def __init__(self, results):
        # Stored as given (no copy), so the caller's list object is what gets returned.
        self.results = results


class _FakeExa:
    """Stands in for ``exa_py.Exa`` and records the calls it receives.

    All recorded state lives on the class (not the instance) because the code
    under test constructs the client itself, so tests can only reach it
    through ``_FakeExa.<attribute>``.
    """

    # Canned hits handed back by the next ``search_and_contents`` call.
    next_results: list = []
    # Keyword arguments and query of the most recent ``search_and_contents`` call.
    last_kwargs: dict = {}
    last_query: str = ""
    # Most recently constructed instance, or None if none has been built yet.
    # Left unannotated on purpose: the previous string annotation referenced
    # ``Optional``, which this module does not import.
    last_instance = None

    def __init__(self, api_key):
        self.api_key = api_key
        # Plain dict so the code under test can attach request headers to it.
        self.headers: dict = {}
        _FakeExa.last_instance = self

    def search_and_contents(self, query, **kwargs):
        """Record the call and return the canned results wrapped in a response."""
        _FakeExa.last_query = query
        _FakeExa.last_kwargs = kwargs
        return _FakeResponse(_FakeExa.next_results)


def _patch_exa(results):
    """Install a fake ``exa_py`` module so ``from exa_py import Exa`` returns _FakeExa.

    Args:
        results: The hits the fake client should return from its next search.

    Returns:
        A ``mock.patch.dict`` patcher over ``sys.modules``; use it as a context
        manager so the fake module is removed again on exit.
    """
    _FakeExa.next_results = results
    # Forget whatever an earlier test recorded, so assertions on the
    # ``last_*`` attributes can only pass because of the current test's call.
    _FakeExa.last_kwargs = {}
    _FakeExa.last_query = ""
    _FakeExa.last_instance = None

    fake_module = mock.MagicMock()
    fake_module.Exa = _FakeExa
    return mock.patch.dict(sys.modules, {"exa_py": fake_module})


class ExaParsingTests(unittest.TestCase):
    """Parsing and formatting behavior of the Exa backend against a faked SDK."""

    def setUp(self):
        # Patch the environment instead of assigning to it, so the fake key is
        # rolled back after every test and cannot leak into other test modules
        # running in the same process.
        env_patch = mock.patch.dict(os.environ, {"EXA_API_KEY": "test-key"})
        env_patch.start()
        self.addCleanup(env_patch.stop)

    def test_search_parses_results_and_sets_integration_header(self):
        results = [
            _FakeResult(
                title="Example",
                url="https://example.com/a",
                published_date="2026-04-01",
                highlights=["Hello world", "More context"],
            ),
            _FakeResult(
                title="Second",
                url="https://example.com/b",
                summary="A short summary.",
            ),
        ]
        with _patch_exa(results):
            from tool_exa import search_with_exa

            output = search_with_exa("python testing")

        # Formatted output carries the query, titles, URLs, snippets and date.
        self.assertIn("python testing", output)
        self.assertIn("Example", output)
        self.assertIn("https://example.com/a", output)
        self.assertIn("Hello world ... More context", output)
        self.assertIn("A short summary.", output)
        self.assertIn("Date published: 2026-04-01", output)
        # The SDK was called with the expected query and content options.
        self.assertEqual(_FakeExa.last_query, "python testing")
        self.assertEqual(_FakeExa.last_kwargs.get("type"), "auto")
        self.assertIn("highlights", _FakeExa.last_kwargs)
        self.assertIn("text", _FakeExa.last_kwargs)
        # The client was tagged with the integration header.
        self.assertIsNotNone(_FakeExa.last_instance)
        self.assertEqual(
            _FakeExa.last_instance.headers.get("x-exa-integration"),
            "deepresearch",
        )

    def test_snippet_falls_back_through_highlights_summary_text(self):
        from tool_exa import ExaResult

        only_highlights = ExaResult(title="t", url="u", highlights=["h1", "h2"])
        only_summary = ExaResult(title="t", url="u", summary="s")
        only_text = ExaResult(title="t", url="u", text="t" * 600)
        empty = ExaResult(title="t", url="u")

        self.assertEqual(only_highlights.snippet(), "h1 ... h2")
        self.assertEqual(only_summary.snippet(), "s")
        self.assertTrue(only_text.snippet().endswith("..."))
        # 500 kept characters plus the three-character ellipsis.
        self.assertLessEqual(len(only_text.snippet()), 503)
        self.assertEqual(empty.snippet(), "")

    def test_scholar_uses_research_paper_category(self):
        results = [_FakeResult(title="Paper", url="https://arxiv.org/x", summary="S")]
        with _patch_exa(results):
            from tool_exa import scholar_with_exa

            output = scholar_with_exa("transformer scaling laws")

        self.assertIn("Scholar Results", output)
        self.assertIn("Paper", output)
        self.assertEqual(_FakeExa.last_kwargs.get("category"), "research paper")

    def test_no_results_returns_friendly_message(self):
        with _patch_exa([]):
            from tool_exa import search_with_exa

            output = search_with_exa("zzz nothing matches")

        self.assertIn("No results found", output)


class ExaDisabledTests(unittest.TestCase):
    """Behavior of the Exa backend when no API key is configured."""

    def test_missing_api_key_returns_error_string(self):
        # ``patch.dict`` snapshots the environment and restores it on exit, so
        # removing the key here cannot affect tests that run afterwards (or a
        # developer's real key when the suite is run locally).
        with mock.patch.dict(os.environ), _patch_exa([]):
            os.environ.pop("EXA_API_KEY", None)
            from tool_exa import search_with_exa

            output = search_with_exa("anything")

        self.assertIn("EXA_API_KEY is not set", output)

if __name__ == "__main__":
unittest.main()
158 changes: 158 additions & 0 deletions inference/tool_exa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
"""Exa search backend.

Provides Exa-powered web search and Exa-powered scholar search. The implementations
return strings with the same numbered-snippet format produced by the Serper backends
in ``tool_search.py`` and ``tool_scholar.py`` so the agent prompt format is unchanged.
"""

import os
from dataclasses import dataclass
from typing import Any, List, Optional


_EXA_INTEGRATION_HEADER = "deepresearch"


@dataclass
class ExaResult:
    """Normalized view of a single Exa search hit.

    ``title`` and ``url`` are always strings (possibly empty); every other
    field is ``None`` when the SDK did not supply it.
    """

    title: str
    url: str
    published_date: Optional[str] = None
    author: Optional[str] = None
    text: Optional[str] = None
    highlights: Optional[List[str]] = None
    summary: Optional[str] = None

    @classmethod
    def from_sdk(cls, item: Any) -> "ExaResult":
        """Build an ``ExaResult`` from an SDK result object or a plain dict.

        Args:
            item: Either a mapping or an object exposing the result fields as
                attributes. Missing fields are treated as ``None``.

        Returns:
            The normalized result. ``highlights`` is always a ``list`` or
            ``None``; the publication date is read from ``published_date``
            and, failing that, ``publishedDate``.
        """

        def get(key: str) -> Any:
            if isinstance(item, dict):
                return item.get(key)
            return getattr(item, key, None)

        highlights = get("highlights")
        if isinstance(highlights, str):
            # A bare string is one highlight; list("abc") would split it into
            # individual characters.
            highlights = [highlights]
        elif highlights is not None and not isinstance(highlights, list):
            highlights = list(highlights)

        return cls(
            title=get("title") or "",
            url=get("url") or "",
            published_date=get("published_date") or get("publishedDate"),
            author=get("author"),
            text=get("text"),
            highlights=highlights,
            summary=get("summary"),
        )

    def snippet(self) -> str:
        """Return the best available short description of this hit.

        Preference order is highlights (joined with ``" ... "``), then the
        summary, then the first 500 characters of the text followed by
        ``"..."``. A source that is missing or blank after stripping is
        skipped, so whitespace-only highlights never hide a usable summary.
        Returns ``""`` when nothing usable is present.
        """
        stripped = (h.strip() for h in (self.highlights or []) if isinstance(h, str))
        usable_highlights = [h for h in stripped if h]
        if usable_highlights:
            return " ... ".join(usable_highlights)

        summary = (self.summary or "").strip()
        if summary:
            return summary

        text = (self.text or "").strip()
        if len(text) > 500:
            text = text[:500].rstrip() + "..."
        return text


def _build_client():
    """Create an Exa SDK client from the ``EXA_API_KEY`` environment variable.

    Returns:
        An ``exa_py.Exa`` client, tagged with the ``x-exa-integration`` header
        when the client exposes a ``headers`` mapping.

    Raises:
        RuntimeError: If ``EXA_API_KEY`` is unset or empty.
        ImportError: If the ``exa_py`` package is not installed.
    """
    # Validate configuration before touching the SDK, so a missing key is
    # always reported as such, even on machines where exa_py is not installed.
    api_key = os.environ.get("EXA_API_KEY")
    if not api_key:
        raise RuntimeError("EXA_API_KEY is not set")

    # Imported lazily so merely importing this module does not require exa_py.
    from exa_py import Exa

    client = Exa(api_key=api_key)
    try:
        client.headers["x-exa-integration"] = _EXA_INTEGRATION_HEADER
    except AttributeError:
        # Best effort: a client without a ``headers`` attribute still works,
        # it just goes out untagged.
        pass
    return client


def _run_search(
    query: str,
    num_results: int,
    category: Optional[str] = None,
    include_domains: Optional[List[str]] = None,
    exclude_domains: Optional[List[str]] = None,
    start_published_date: Optional[str] = None,
    end_published_date: Optional[str] = None,
    max_attempts: int = 5,
    retry_delay: float = 0.0,
) -> List[ExaResult]:
    """Run one Exa search, retrying on failure, and return normalized hits.

    Args:
        query: The search query.
        num_results: Number of results to request.
        category: Optional Exa category filter (e.g. ``"research paper"``).
        include_domains: Optional list of domains to restrict results to.
        exclude_domains: Optional list of domains to exclude.
        start_published_date: Optional lower bound on publication date.
        end_published_date: Optional upper bound on publication date.
        max_attempts: How many times to call the API before giving up.
            Values below 1 are treated as 1.
        retry_delay: Base delay in seconds between attempts, doubled after
            each failure. The default of 0 retries immediately.

    Returns:
        One ``ExaResult`` per hit; an empty list when the response has none.

    Raises:
        RuntimeError: If the client cannot be built, or every attempt fails
            (the last underlying error is attached as ``__cause__``).
    """
    client = _build_client()

    # Optional filters are only sent when set, so the SDK's own defaults apply.
    optional_filters = {
        "category": category,
        "include_domains": include_domains,
        "exclude_domains": exclude_domains,
        "start_published_date": start_published_date,
        "end_published_date": end_published_date,
    }
    kwargs: dict = {
        "num_results": num_results,
        "type": "auto",
        "highlights": {"num_sentences": 3},
        "text": {"max_characters": 500},
    }
    kwargs.update({name: value for name, value in optional_filters.items() if value})

    attempts = max(1, max_attempts)
    last_err: Optional[Exception] = None
    response = None
    for attempt in range(attempts):
        try:
            # Only the network call is retried; parsing happens afterwards so a
            # local parsing bug cannot trigger repeated (billed) API requests.
            response = client.search_and_contents(query, **kwargs)
            break
        except Exception as e:
            last_err = e
            if retry_delay > 0 and attempt + 1 < attempts:
                import time

                time.sleep(retry_delay * (2 ** attempt))
    else:
        raise RuntimeError(
            f"Exa search failed after {attempts} attempts: {last_err}"
        ) from last_err

    raw_results = getattr(response, "results", None) or []
    return [ExaResult.from_sdk(r) for r in raw_results]


def _format_block(query: str, results: List[ExaResult], header: str, label: str) -> str:
    """Render search hits as the numbered text block handed back to the agent.

    Args:
        query: The original query, echoed in the heading line.
        results: Hits to render, in display order.
        header: Lead-in phrase for the heading line.
        label: Title of the ``##`` section that precedes the entries.

    Returns:
        A "no results" message when ``results`` is empty; otherwise a heading
        followed by one entry per hit, with blank lines between entries.
    """
    if not results:
        return f"No results found for '{query}'. Try with a more general query."

    def render(position: int, hit: ExaResult) -> str:
        # First line is the numbered link; optional detail lines follow only
        # when the hit actually carries that detail.
        lines = [f"{position}. [{hit.title or hit.url}]({hit.url})"]
        if hit.published_date:
            lines.append(f"Date published: {hit.published_date}")
        if hit.author:
            lines.append(f"Source: {hit.author}")
        body = hit.snippet()
        if body:
            lines.append(body)
        return "\n".join(lines)

    entries = [render(position, hit) for position, hit in enumerate(results, start=1)]
    heading = f"{header} for '{query}' found {len(entries)} results:\n\n## {label}\n"
    return heading + "\n\n".join(entries)


def search_with_exa(
    query: str,
    num_results: int = 10,
    category: Optional[str] = None,
    include_domains: Optional[List[str]] = None,
    exclude_domains: Optional[List[str]] = None,
    start_published_date: Optional[str] = None,
    end_published_date: Optional[str] = None,
) -> str:
    """Run a web search through Exa and return it as formatted text.

    The function never raises for a failed search: errors are reported as a
    string prefixed with ``[Exa search]`` so the result can be returned to
    the agent directly.

    Args:
        query: The search query.
        num_results: Number of results to request.
        category: Optional Exa category filter.
        include_domains: Optional list of domains to restrict results to.
        exclude_domains: Optional list of domains to exclude.
        start_published_date: Optional lower bound on publication date.
        end_published_date: Optional upper bound on publication date.

    Returns:
        The formatted "Web Results" block, a "no results" message, or an
        error string.
    """
    search_options = dict(
        num_results=num_results,
        category=category,
        include_domains=include_domains,
        exclude_domains=exclude_domains,
        start_published_date=start_published_date,
        end_published_date=end_published_date,
    )
    try:
        hits = _run_search(query, **search_options)
    except RuntimeError as e:
        # Expected failures (missing key, retries exhausted) carry their own message.
        return f"[Exa search] {e}"
    except Exception as e:
        return f"[Exa search] Unexpected error: {e}"
    else:
        return _format_block(query, hits, header="A web search", label="Web Results")


def scholar_with_exa(query: str, num_results: int = 10) -> str:
    """Run a research-paper search through Exa and return it as formatted text.

    Identical to a web search except that results are restricted to Exa's
    ``"research paper"`` category. Failures are reported as a string prefixed
    with ``[Exa scholar]`` rather than raised.

    Args:
        query: The search query.
        num_results: Number of results to request.

    Returns:
        The formatted "Scholar Results" block, a "no results" message, or an
        error string.
    """
    try:
        papers = _run_search(query, num_results=num_results, category="research paper")
    except RuntimeError as e:
        # Expected failures (missing key, retries exhausted) carry their own message.
        return f"[Exa scholar] {e}"
    except Exception as e:
        return f"[Exa scholar] Unexpected error: {e}"
    else:
        return _format_block(
            query,
            papers,
            header="A research-paper search",
            label="Scholar Results",
        )
14 changes: 11 additions & 3 deletions inference/tool_scholar.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@
from concurrent.futures import ThreadPoolExecutor
import http.client

from tool_exa import scholar_with_exa


SERPER_KEY=os.environ.get('SERPER_KEY_ID')
SEARCH_PROVIDER=os.environ.get('SEARCH_PROVIDER', 'serper').lower()


@register_tool("google_scholar", allow_overwrite=True)
Expand Down Expand Up @@ -91,20 +94,25 @@ def google_scholar_with_serp(self, query: str):
return f"No results found for '{query}'. Try with a more general query."


def _do_scholar_search(self, query: str) -> str:
    """Run one scholar query against whichever backend ``SEARCH_PROVIDER`` selects.

    ``"exa"`` routes the query to ``scholar_with_exa``; any other value uses
    the Serper-backed ``google_scholar_with_serp``.
    """
    use_exa = SEARCH_PROVIDER == "exa"
    return scholar_with_exa(query) if use_exa else self.google_scholar_with_serp(query)

def call(self, params: Union[str, dict], **kwargs) -> str:
    """Execute the scholar tool for one query or a list of queries.

    Args:
        params: Tool arguments, passed through ``self._verify_json_format_args``;
            the result must contain a ``"query"`` entry that is either a
            string or a list of strings.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        The search output for a single query, or the outputs for each query
        in order joined by ``"\\n=======\\n"``. Malformed input yields a
        ``[google_scholar] Invalid request format`` message instead of an
        exception.
    """
    try:
        params = self._verify_json_format_args(params)
        query = params["query"]
    except Exception:
        # ``Exception`` rather than a bare ``except`` so KeyboardInterrupt and
        # SystemExit still propagate.
        return "[google_scholar] Invalid request format: Input must be a JSON object containing 'query' field"

    # Every query goes through ``_do_scholar_search`` exactly once, so the
    # configured provider is honored and no backend is called twice.
    if isinstance(query, str):
        return self._do_scholar_search(query)

    if not isinstance(query, list):
        # Explicit check instead of ``assert``, which is stripped under ``-O``.
        return "[google_scholar] Invalid request format: 'query' must be a string or a list of strings"

    with ThreadPoolExecutor(max_workers=3) as executor:
        responses = list(executor.map(self._do_scholar_search, query))
    return "\n=======\n".join(responses)
Loading