diff --git a/pyproject.toml b/pyproject.toml index aaca7c4f..8f0cec99 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,6 +77,7 @@ ocr = [ "ipywidgets>=8.1.0", "pillow>=10.4.0", ] +tavily = ["tavily-python>=0.5.0"] [build-system] requires = ["hatchling==1.26.3"] diff --git a/scrapegraphai/utils/research_web.py b/scrapegraphai/utils/research_web.py index d633084d..376de720 100644 --- a/scrapegraphai/utils/research_web.py +++ b/scrapegraphai/utils/research_web.py @@ -57,13 +57,14 @@ class SearchConfig(BaseModel): None, description="Proxy configuration" ) serper_api_key: Optional[str] = Field(None, description="API key for Serper") + tavily_api_key: Optional[str] = Field(None, description="API key for Tavily") region: Optional[str] = Field(None, description="Country/region code") language: str = Field("en", description="Language code") @validator("search_engine") def validate_search_engine(cls, v): """Validate search engine.""" - valid_engines = {"duckduckgo", "bing", "searxng", "serper"} + valid_engines = {"duckduckgo", "bing", "searxng", "serper", "tavily"} if v.lower() not in valid_engines: raise ValueError( f"Search engine must be one of: {', '.join(valid_engines)}" @@ -166,6 +167,7 @@ def search_on_web( timeout: int = 10, proxy: Optional[Union[str, Dict, ProxyConfig]] = None, serper_api_key: Optional[str] = None, + tavily_api_key: Optional[str] = None, region: Optional[str] = None, language: str = "en", ) -> List[str]: @@ -204,6 +206,7 @@ def search_on_web( timeout=timeout, proxy=proxy, serper_api_key=serper_api_key, + tavily_api_key=tavily_api_key, region=region, language=language, ) @@ -237,6 +240,11 @@ def search_on_web( config.query, config.max_results, config.serper_api_key, config.timeout ) + elif config.search_engine == "tavily": + results = _search_tavily( + config.query, config.max_results, config.tavily_api_key + ) + return filter_pdf_links(results) except requests.Timeout: @@ -381,6 +389,42 @@ def _search_serper( raise SearchRequestError(f"Serper search failed: {str(e)}") +def _search_tavily(query: str, max_results: int, api_key: str) -> List[str]: + """ + Helper function for Tavily search. + + Args: + query (str): Search query + max_results (int): Maximum number of results to return + api_key (str): API key for Tavily + + Returns: + List[str]: List of URLs from search results + """ + if not api_key: + raise SearchConfigError("Tavily API key is required") + + try: + from tavily import TavilyClient + + client = TavilyClient(api_key=api_key) + response = client.search(query=query, max_results=max_results) + + results = [] + for result in response.get("results", []): + if "url" in result: + results.append(result["url"]) + + return results + except ImportError: + raise SearchConfigError( + "tavily-python package is required for Tavily search. " + "Install it with: pip install tavily-python" + ) + except Exception as e: + raise SearchRequestError(f"Tavily search failed: {str(e)}") + + def format_proxy(proxy_config: Union[str, Dict, ProxyConfig]) -> str: """ Format proxy configuration into a string. diff --git a/tests/utils/research_web_test.py b/tests/utils/research_web_test.py index a4a37191..359fa1a2 100644 --- a/tests/utils/research_web_test.py +++ b/tests/utils/research_web_test.py @@ -1,6 +1,12 @@ +from unittest.mock import MagicMock, patch + import pytest -from scrapegraphai.utils.research_web import ( # Replace with actual path to your file +from scrapegraphai.utils.research_web import ( + SearchConfig, + SearchConfigError, + SearchRequestError, + _search_tavily, search_on_web, ) @@ -30,3 +36,88 @@ def test_max_results(): results_5 = search_on_web("test query", max_results=5) results_10 = search_on_web("test query", max_results=10) assert len(results_5) <= len(results_10) + + +# --- Tavily search engine tests --- + + +def test_search_config_accepts_tavily(): + """Tests that SearchConfig accepts 'tavily' as a valid search engine.""" + config = SearchConfig(query="test query", search_engine="tavily") + assert config.search_engine == "tavily" + + +def test_search_tavily_success(): + """Tests _search_tavily returns URLs from Tavily search results.""" + mock_client = MagicMock() + mock_client.search.return_value = { + "results": [ + {"title": "Result 1", "url": "https://example.com/1", "score": 0.9}, + {"title": "Result 2", "url": "https://example.com/2", "score": 0.8}, + ] + } + mock_tavily_module = MagicMock() + mock_tavily_module.TavilyClient.return_value = mock_client + + with patch.dict("sys.modules", {"tavily": mock_tavily_module}): + results = _search_tavily("test query", max_results=2, api_key="tvly-test-key") + + assert results == ["https://example.com/1", "https://example.com/2"] + mock_tavily_module.TavilyClient.assert_called_once_with(api_key="tvly-test-key") + mock_client.search.assert_called_once_with(query="test query", max_results=2) + + +def test_search_tavily_missing_api_key(): + """Tests _search_tavily raises SearchConfigError when API key is missing.""" + with pytest.raises(SearchConfigError, match="Tavily API key is required"): + _search_tavily("test query", max_results=5, api_key=None) + + +def test_search_tavily_empty_api_key(): + """Tests _search_tavily raises SearchConfigError when API key is empty string.""" + with pytest.raises(SearchConfigError, match="Tavily API key is required"): + _search_tavily("test query", max_results=5, api_key="") + + +def test_search_tavily_api_error(): + """Tests _search_tavily raises SearchRequestError on API failure.""" + mock_client = MagicMock() + mock_client.search.side_effect = Exception("API rate limit exceeded") + mock_tavily_module = MagicMock() + mock_tavily_module.TavilyClient.return_value = mock_client + + with patch.dict("sys.modules", {"tavily": mock_tavily_module}): + with pytest.raises(SearchRequestError, match="Tavily search failed"): + _search_tavily("test query", max_results=5, api_key="tvly-test-key") + + +def test_search_tavily_empty_results(): + """Tests _search_tavily returns empty list when no results found.""" + mock_client = MagicMock() + mock_client.search.return_value = {"results": []} + mock_tavily_module = MagicMock() + mock_tavily_module.TavilyClient.return_value = mock_client + + with patch.dict("sys.modules", {"tavily": mock_tavily_module}): + results = _search_tavily("test query", max_results=5, api_key="tvly-test-key") + + assert results == [] + + +@patch("scrapegraphai.utils.research_web._search_tavily") +def test_search_on_web_tavily_integration(mock_search_tavily): + """Tests search_on_web dispatches to tavily engine correctly.""" + mock_search_tavily.return_value = [ + "https://example.com/1", + "https://example.com/2", + ] + + results = search_on_web( + "test query", + search_engine="tavily", + max_results=2, + tavily_api_key="tvly-test-key", + ) + + assert results == ["https://example.com/1", "https://example.com/2"] + mock_search_tavily.assert_called_once()