Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ ocr = [
"ipywidgets>=8.1.0",
"pillow>=10.4.0",
]
tavily = ["tavily-python>=0.5.0"]

[build-system]
requires = ["hatchling==1.26.3"]
Expand Down
46 changes: 45 additions & 1 deletion scrapegraphai/utils/research_web.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,14 @@ class SearchConfig(BaseModel):
None, description="Proxy configuration"
)
serper_api_key: Optional[str] = Field(None, description="API key for Serper")
tavily_api_key: Optional[str] = Field(None, description="API key for Tavily")
region: Optional[str] = Field(None, description="Country/region code")
language: str = Field("en", description="Language code")

@validator("search_engine")
def validate_search_engine(cls, v):
"""Validate search engine."""
valid_engines = {"duckduckgo", "bing", "searxng", "serper"}
valid_engines = {"duckduckgo", "bing", "searxng", "serper", "tavily"}
if v.lower() not in valid_engines:
raise ValueError(
f"Search engine must be one of: {', '.join(valid_engines)}"
Expand Down Expand Up @@ -166,6 +167,7 @@ def search_on_web(
timeout: int = 10,
proxy: Optional[Union[str, Dict, ProxyConfig]] = None,
serper_api_key: Optional[str] = None,
tavily_api_key: Optional[str] = None,
region: Optional[str] = None,
language: str = "en",
) -> List[str]:
Expand Down Expand Up @@ -204,6 +206,7 @@ def search_on_web(
timeout=timeout,
proxy=proxy,
serper_api_key=serper_api_key,
tavily_api_key=tavily_api_key,
region=region,
language=language,
)
Expand Down Expand Up @@ -237,6 +240,11 @@ def search_on_web(
config.query, config.max_results, config.serper_api_key, config.timeout
)

elif config.search_engine == "tavily":
results = _search_tavily(
config.query, config.max_results, config.tavily_api_key
)

return filter_pdf_links(results)

except requests.Timeout:
Expand Down Expand Up @@ -381,6 +389,42 @@ def _search_serper(
raise SearchRequestError(f"Serper search failed: {str(e)}")


def _search_tavily(query: str, max_results: int, api_key: str) -> List[str]:
"""
Helper function for Tavily search.

Args:
query (str): Search query
max_results (int): Maximum number of results to return
api_key (str): API key for Tavily

Returns:
List[str]: List of URLs from search results
"""
if not api_key:
raise SearchConfigError("Tavily API key is required")

try:
from tavily import TavilyClient

client = TavilyClient(api_key=api_key)
response = client.search(query=query, max_results=max_results)

results = []
for result in response.get("results", []):
if "url" in result:
results.append(result["url"])

return results
except ImportError:
raise SearchConfigError(
"tavily-python package is required for Tavily search. "
"Install it with: pip install tavily-python"
)
except Exception as e:
raise SearchRequestError(f"Tavily search failed: {str(e)}")


def format_proxy(proxy_config: Union[str, Dict, ProxyConfig]) -> str:
"""
Format proxy configuration into a string.
Expand Down
93 changes: 92 additions & 1 deletion tests/utils/research_web_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
from unittest.mock import MagicMock, patch

import pytest

from scrapegraphai.utils.research_web import ( # Replace with actual path to your file
from scrapegraphai.utils.research_web import (
SearchConfig,
SearchConfigError,
SearchRequestError,
_search_tavily,
search_on_web,
)

Expand Down Expand Up @@ -30,3 +36,88 @@ def test_max_results():
results_5 = search_on_web("test query", max_results=5)
results_10 = search_on_web("test query", max_results=10)
assert len(results_5) <= len(results_10)


# --- Tavily search engine tests ---


def test_search_config_accepts_tavily():
"""Tests that SearchConfig accepts 'tavily' as a valid search engine."""
config = SearchConfig(query="test query", search_engine="tavily")
assert config.search_engine == "tavily"


def test_search_tavily_success():
"""Tests _search_tavily returns URLs from Tavily search results."""
mock_client = MagicMock()
mock_client.search.return_value = {
"results": [
{"title": "Result 1", "url": "https://example.com/1", "score": 0.9},
{"title": "Result 2", "url": "https://example.com/2", "score": 0.8},
]
}
mock_tavily_module = MagicMock()
mock_tavily_module.TavilyClient.return_value = mock_client

with patch.dict("sys.modules", {"tavily": mock_tavily_module}):
results = _search_tavily("test query", max_results=2, api_key="tvly-test-key")

assert results == ["https://example.com/1", "https://example.com/2"]
mock_tavily_module.TavilyClient.assert_called_once_with(api_key="tvly-test-key")
mock_client.search.assert_called_once_with(query="test query", max_results=2)


def test_search_tavily_missing_api_key():
"""Tests _search_tavily raises SearchConfigError when API key is missing."""
with pytest.raises(SearchConfigError, match="Tavily API key is required"):
_search_tavily("test query", max_results=5, api_key=None)


def test_search_tavily_empty_api_key():
"""Tests _search_tavily raises SearchConfigError when API key is empty string."""
with pytest.raises(SearchConfigError, match="Tavily API key is required"):
_search_tavily("test query", max_results=5, api_key="")


def test_search_tavily_api_error():
"""Tests _search_tavily raises SearchRequestError on API failure."""
mock_client = MagicMock()
mock_client.search.side_effect = Exception("API rate limit exceeded")
mock_tavily_module = MagicMock()
mock_tavily_module.TavilyClient.return_value = mock_client

with patch.dict("sys.modules", {"tavily": mock_tavily_module}):
with pytest.raises(SearchRequestError, match="Tavily search failed"):
_search_tavily("test query", max_results=5, api_key="tvly-test-key")


def test_search_tavily_empty_results():
"""Tests _search_tavily returns empty list when no results found."""
mock_client = MagicMock()
mock_client.search.return_value = {"results": []}
mock_tavily_module = MagicMock()
mock_tavily_module.TavilyClient.return_value = mock_client

with patch.dict("sys.modules", {"tavily": mock_tavily_module}):
results = _search_tavily("test query", max_results=5, api_key="tvly-test-key")

assert results == []


@patch("scrapegraphai.utils.research_web._search_tavily")
def test_search_on_web_tavily_integration(mock_search_tavily):
"""Tests search_on_web dispatches to tavily engine correctly."""
mock_search_tavily.return_value = [
"https://example.com/1",
"https://example.com/2",
]

results = search_on_web(
"test query",
search_engine="tavily",
max_results=2,
tavily_api_key="tvly-test-key",
)

assert results == ["https://example.com/1", "https://example.com/2"]
mock_search_tavily.assert_called_once()