diff --git a/pyproject.toml b/pyproject.toml
index aaca7c4f..4a4a2012 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -31,6 +31,7 @@ dependencies = [
     "simpleeval>=1.0.3",
     "jsonschema>=4.25.1",
     "duckduckgo-search>=8.1.1",
+    "tavily-python>=0.5.0",
    "pydantic>=2.12.5",
     "scrapegraph-py>=1.44.0",
 ]
diff --git a/scrapegraphai/utils/research_web.py b/scrapegraphai/utils/research_web.py
index d633084d..bce2ce07 100644
--- a/scrapegraphai/utils/research_web.py
+++ b/scrapegraphai/utils/research_web.py
@@ -57,13 +57,14 @@ class SearchConfig(BaseModel):
         None, description="Proxy configuration"
     )
     serper_api_key: Optional[str] = Field(None, description="API key for Serper")
+    tavily_api_key: Optional[str] = Field(None, description="API key for Tavily")
     region: Optional[str] = Field(None, description="Country/region code")
     language: str = Field("en", description="Language code")
     @validator("search_engine")
     def validate_search_engine(cls, v):
         """Validate search engine."""
-        valid_engines = {"duckduckgo", "bing", "searxng", "serper"}
+        valid_engines = {"duckduckgo", "bing", "searxng", "serper", "tavily"}
         if v.lower() not in valid_engines:
             raise ValueError(
                 f"Search engine must be one of: {', '.join(valid_engines)}"
             )
@@ -166,6 +167,7 @@ def search_on_web(
     timeout: int = 10,
     proxy: Optional[Union[str, Dict, ProxyConfig]] = None,
     serper_api_key: Optional[str] = None,
+    tavily_api_key: Optional[str] = None,
     region: Optional[str] = None,
     language: str = "en",
 ) -> List[str]:
@@ -180,6 +182,7 @@ def search_on_web(
         timeout (int): Request timeout in seconds
         proxy (str | dict | ProxyConfig): Proxy configuration
         serper_api_key (str): API key for Serper
+        tavily_api_key (str): API key for Tavily
         region (str): Country/region code (e.g., 'mx' for Mexico)
         language (str): Language code (e.g., 'es' for Spanish)
 
@@ -204,6 +207,7 @@ def search_on_web(
         timeout=timeout,
         proxy=proxy,
         serper_api_key=serper_api_key,
+        tavily_api_key=tavily_api_key,
         region=region,
         language=language,
     )
@@ -237,6 +241,11 @@ def search_on_web(
             results = _search_serper(
                 config.query, config.max_results, config.serper_api_key, config.timeout
             )
+        elif config.search_engine == "tavily":
+            results = _search_tavily(
+                config.query, config.max_results, config.tavily_api_key
+            )
+
         return filter_pdf_links(results)
 
     except requests.Timeout:
@@ -381,6 +390,44 @@ def _search_serper(
         raise SearchRequestError(f"Serper search failed: {str(e)}")
 
 
+def _search_tavily(
+    query: str, max_results: int, api_key: Optional[str] = None
+) -> List[str]:
+    """
+    Helper function for Tavily search.
+
+    Args:
+        query (str): Search query
+        max_results (int): Maximum number of results to return
+        api_key (str, optional): API key for Tavily. Falls back to TAVILY_API_KEY env var.
+
+    Returns:
+        List[str]: List of URLs from search results
+    """
+    if not api_key:
+        import os
+
+        api_key = os.getenv("TAVILY_API_KEY")
+    if not api_key:
+        raise SearchConfigError(
+            "Tavily API key is required. Provide tavily_api_key or set TAVILY_API_KEY."
+        )
+
+    try:
+        from tavily import TavilyClient
+
+        client = TavilyClient(api_key=api_key)
+        response = client.search(query=query, max_results=max_results)
+        results = [result["url"] for result in response.get("results", [])]
+        return results
+    except ImportError:
+        raise SearchConfigError(
+            "tavily-python package is required. Install it with: pip install tavily-python"
+        )
+    except Exception as e:
+        raise SearchRequestError(f"Tavily search failed: {str(e)}")
+
+
 def format_proxy(proxy_config: Union[str, Dict, ProxyConfig]) -> str:
     """
     Format proxy configuration into a string.