diff --git a/examples/basic_modules/textual_memory_internet_search_example.py b/examples/basic_modules/textual_memory_internet_search_example.py index 9007d7e67..46dd01a93 100644 --- a/examples/basic_modules/textual_memory_internet_search_example.py +++ b/examples/basic_modules/textual_memory_internet_search_example.py @@ -288,6 +288,83 @@ print("\n Get your credentials from:") print(" https://developers.google.com/custom-search/v1/overview") +# ============================================================================ +# Step 7: Test Tavily Search API (Optional) +# ============================================================================ +print("\n" + "=" * 80) +print("TAVILY SEARCH API TEST") +print("=" * 80) + +tavily_api_key = os.environ.get("TAVILY_API_KEY", "") + +if tavily_api_key: + print("\n[Step 7.1] Configuring Tavily Search retriever...") + + tavily_retriever_config = InternetRetrieverConfigFactory.model_validate( + { + "backend": "tavily", + "config": { + "api_key": tavily_api_key, + "max_results": 5, + }, + } + ) + + print("✓ Tavily retriever configured") + print(f" Max results: {tavily_retriever_config.config.max_results}") + + print("\n[Step 7.2] Creating Tavily retriever instance...") + tavily_retriever = InternetRetrieverFactory.from_config(tavily_retriever_config, embedder) + print("✓ Tavily retriever initialized") + + print("\n[Step 7.3] Performing Tavily web search...") + tavily_query = "latest AI research breakthroughs 2024" + print(f" Query: '{tavily_query}'") + print(" Searching via Tavily Search API...\n") + + tavily_results = tavily_retriever.retrieve_from_internet(tavily_query) + + print("✓ Tavily search completed!") + print(f"✓ Retrieved {len(tavily_results)} memory items from Tavily search\n") + + print("=" * 80) + print("TAVILY SEARCH RESULTS") + print("=" * 80) + + if not tavily_results: + print("\nNo results found from Tavily.") + print(" This might indicate:") + print(" - Invalid Tavily API key") + print(" - API quota exceeded") + print(" - Network connectivity issues") + else: + for idx, item in enumerate(tavily_results, 1): + print(f"\n[Tavily Result #{idx}]") + print("-" * 80) + + content = item.memory + if len(content) > 300: + print(f"Content: {content[:300]}...") + print(f" (... {len(content) - 300} more characters)") + else: + print(f"Content: {content}") + + if hasattr(item, "metadata") and item.metadata: + metadata = item.metadata + if hasattr(metadata, "sources") and metadata.sources: + print(f"Source: {metadata.sources[0] if metadata.sources else 'N/A'}") + + print() + + print("=" * 80) + print("Tavily Search Test completed!") + print("=" * 80) +else: + print("\n Skipping Tavily Search API test") + print(" To enable this test, set the following environment variable:") + print(" - TAVILY_API_KEY: Your Tavily API key") + print("\n Get your API key from: https://app.tavily.com") + print("\n" + "=" * 80) print("ALL TESTS COMPLETED") print("=" * 80) @@ -297,6 +374,10 @@ print(" ✓ Tested Google Custom Search API") else: print(" ⏭️ Skipped Google Custom Search API (credentials not set)") +if tavily_api_key: + print(" Tested Tavily Search API") +else: + print(" Skipped Tavily Search API (credentials not set)") print("\n💡 Quick Start:") print(" # Set BochaAI API key") print(" export BOCHA_API_KEY='sk-your-bocha-api-key'") @@ -305,5 +386,8 @@ print(" export GOOGLE_API_KEY='your-google-api-key'") print(" export GOOGLE_SEARCH_ENGINE_ID='your-search-engine-id'") print(" ") +print(" # Set Tavily API key (optional)") +print(" export TAVILY_API_KEY='tvly-your-tavily-api-key'") +print(" ") print(" # Run the example") print(" python examples/basic_modules/textual_memory_internet_search_example.py\n") diff --git a/pyproject.toml b/pyproject.toml index de8e66ad1..4d20c91a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -102,6 +102,11 @@ skill-mem = [ "alibabacloud-oss-v2 (>=1.2.2,<1.2.3)", ] +# Tavily Search +tavily = [ + "tavily-python (>=0.5.0)", +] + # All optional dependencies # Allow users to install with `pip install MemoryOS[all]` all = [ @@ -129,6 +134,7 @@ all = [ "nltk (>=3.9.1,<4.0.0)", "rake-nltk (>=1.0.6,<1.1.0)", "alibabacloud-oss-v2 (>=1.2.2,<1.2.3)", + "tavily-python (>=0.5.0)", # Uncategorized dependencies ] diff --git a/src/memos/api/config.py b/src/memos/api/config.py index 87f1efd8e..c93aaf37c 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -626,7 +626,20 @@ def get_oss_config() -> dict[str, Any] | None: return config def get_internet_config() -> dict[str, Any]: - """Get embedder configuration.""" + """Get internet retriever configuration.""" + tavily_api_key = os.getenv("TAVILY_API_KEY", "") + bocha_api_key = os.getenv("BOCHA_API_KEY", "") + + # Use Tavily if TAVILY_API_KEY is set and BOCHA_API_KEY is absent + if tavily_api_key and not bocha_api_key: + return { + "backend": "tavily", + "config": { + "api_key": tavily_api_key, + "max_results": 15, + }, + } + reader_config = APIConfig.get_reader_config() return { "backend": "bocha", diff --git a/src/memos/configs/internet_retriever.py b/src/memos/configs/internet_retriever.py index 1c5e2b8ad..2ce8e70fc 100644 --- a/src/memos/configs/internet_retriever.py +++ b/src/memos/configs/internet_retriever.py @@ -67,6 +67,12 @@ class BochaSearchConfig(BaseInternetRetrieverConfig): ) +class TavilySearchConfig(BaseInternetRetrieverConfig): + """Configuration class for Tavily Search API.""" + + max_results: int = Field(default=20, description="Maximum number of results to retrieve") + + class InternetRetrieverConfigFactory(BaseConfig): """Factory class for creating internet retriever configurations.""" @@ -82,6 +88,7 @@ class InternetRetrieverConfigFactory(BaseConfig): "bing": BingSearchConfig, "xinyu": XinyuSearchConfig, "bocha": BochaSearchConfig, + "tavily": TavilySearchConfig, } @field_validator("backend") diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py b/src/memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py index 3498f596a..b947298f1 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/internet_retriever_factory.py @@ -9,6 +9,7 @@ from memos.memories.textual.tree_text_memory.retrieve.internet_retriever import ( InternetGoogleRetriever, ) +from memos.memories.textual.tree_text_memory.retrieve.tavilysearch import TavilySearchRetriever from memos.memories.textual.tree_text_memory.retrieve.xinyusearch import XinyuSearchRetriever from memos.memos_tools.singleton import singleton_factory @@ -21,6 +22,7 @@ class InternetRetrieverFactory: "bing": InternetGoogleRetriever, # TODO: Implement BingRetriever "xinyu": XinyuSearchRetriever, "bocha": BochaAISearchRetriever, + "tavily": TavilySearchRetriever, } @classmethod @@ -81,6 +83,12 @@ def from_config( reader=MemReaderFactory.from_config(config.reader), max_results=config.max_results, ) + elif backend == "tavily": + return retriever_class( + api_key=config.api_key, + embedder=embedder, + max_results=config.max_results, + ) else: raise ValueError(f"Unsupported backend: {backend}") diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/tavilysearch.py b/src/memos/memories/textual/tree_text_memory/retrieve/tavilysearch.py new file mode 100644 index 000000000..26091b041 --- /dev/null +++ b/src/memos/memories/textual/tree_text_memory/retrieve/tavilysearch.py @@ -0,0 +1,236 @@ +"""Tavily Search API retriever for tree text memory.""" + +from concurrent.futures import as_completed +from datetime import datetime +from typing import Any + +from memos.context.context import ContextThreadPoolExecutor +from memos.dependency import require_python_package +from memos.embedders.factory import OllamaEmbedder +from memos.log import get_logger +from memos.mem_reader.read_multi_modal import detect_lang +from memos.memories.textual.item import ( + SearchedTreeNodeTextualMemoryMetadata, + SourceMessage, + TextualMemoryItem, +) + + +logger = get_logger(__name__) + + +class TavilySearchRetriever: + """Tavily retriever that converts search results into TextualMemoryItem objects.""" + + @require_python_package( + import_name="tavily", + install_command="pip install tavily-python", + install_link="https://github.com/tavily-ai/tavily-python", + ) + @require_python_package( + import_name="jieba", + install_command="pip install jieba", + install_link="https://github.com/fxsjy/jieba", + ) + def __init__( + self, + api_key: str, + embedder: OllamaEmbedder, + max_results: int = 20, + ): + """ + Initialize Tavily Search retriever. + + Args: + api_key: Tavily API key + embedder: Embedder instance for generating embeddings + max_results: Maximum number of search results to retrieve + """ + from jieba.analyse import TextRank + from tavily import TavilyClient + + self.client = TavilyClient(api_key=api_key) + self.embedder = embedder + self.max_results = max_results + self.zh_fast_keywords_extractor = TextRank() + + def _extract_tags(self, title: str, content: str, summary: str, parsed_goal=None) -> list[str]: + """ + Extract tags from title, content and summary. + + Args: + title: Article title + content: Article content + summary: Article summary + parsed_goal: Parsed task goal (optional) + + Returns: + List of extracted tags + """ + tags = ["tavily_search", "news"] + + text = f"{title} {content} {summary}".lower() + + keywords = { + "economy": [ + "economy", "GDP", "growth", "production", "industry", + "investment", "consumption", "market", "trade", "finance", + ], + "politics": [ + "politics", "government", "policy", "meeting", "leader", + "election", "parliament", "ministry", + ], + "technology": [ + "technology", "tech", "innovation", "digital", "internet", + "AI", "artificial intelligence", "software", "hardware", + ], + "sports": [ + "sports", "game", "athlete", "olympic", "championship", + "tournament", "team", "player", + ], + "culture": [ + "culture", "education", "art", "history", "literature", + "music", "film", "museum", + ], + "health": [ + "health", "medical", "pandemic", "hospital", "doctor", + "medicine", "disease", "treatment", + ], + "environment": [ + "environment", "ecology", "pollution", "green", "climate", + "sustainability", "renewable", + ], + } + + for category, words in keywords.items(): + if any(word in text for word in words): + tags.append(category) + + if parsed_goal and hasattr(parsed_goal, "tags"): + tags.extend(parsed_goal.tags) + + return list(set(tags))[:15] + + def retrieve_from_internet( + self, query: str, top_k: int = 10, parsed_goal=None, info=None, mode="fast" + ) -> list[TextualMemoryItem]: + """ + Retrieve information from the internet using Tavily Search API. + + Args: + query: Search query + top_k: Number of results to retrieve + parsed_goal: Parsed task goal (optional) + info (dict): Metadata for memory consumption tracking + mode: Retrieval mode ("fast" for summaries only) + + Returns: + List of TextualMemoryItem + """ + try: + response = self.client.search( + query=query, + max_results=min(top_k, self.max_results), + search_depth="basic", + topic="general", + ) + search_results = response.get("results", []) + except Exception: + import traceback + + logger.error(f"Tavily search error: {traceback.format_exc()}") + search_results = [] + + return self._convert_to_mem_items(search_results, query, parsed_goal, info, mode=mode) + + def _convert_to_mem_items( + self, search_results: list[dict], query: str, parsed_goal=None, info=None, mode="fast" + ): + """Convert Tavily search results into TextualMemoryItem objects.""" + memory_items = [] + if not info: + info = {"user_id": "", "session_id": ""} + + with ContextThreadPoolExecutor(max_workers=8) as executor: + futures = [ + executor.submit(self._process_result, r, query, parsed_goal, info, mode=mode) + for r in search_results + ] + for future in as_completed(futures): + try: + memory_items.extend(future.result()) + except Exception as e: + logger.error(f"Error processing Tavily search result: {e}") + + # Deduplicate items by memory text + unique_memory_items = {item.memory: item for item in memory_items} + return list(unique_memory_items.values()) + + def _process_result( + self, result: dict, query: str, parsed_goal: str, info: dict[str, Any], mode="fast" + ) -> list[TextualMemoryItem]: + """Process one Tavily search result into TextualMemoryItem.""" + if mode != "fast": + logger.warning( + "TavilySearchRetriever only supports mode=\"fast\"; ignoring mode=%r", + mode, + ) + title = result.get("title", "") + content = result.get("content", "") + summary = content # Tavily returns content as the snippet/summary + url = result.get("url", "") + publish_time = result.get("published_date", "") + + if publish_time: + try: + publish_time = datetime.fromisoformat( + publish_time.replace("Z", "+00:00") + ).strftime("%Y-%m-%d") + except Exception: + publish_time = datetime.now().strftime("%Y-%m-%d") + else: + publish_time = datetime.now().strftime("%Y-%m-%d") + + info_ = info.copy() + user_id = info_.pop("user_id", "") + session_id = info_.pop("session_id", "") + lang = detect_lang(summary) + tags = ( + self.zh_fast_keywords_extractor.textrank(summary, topK=3)[:3] + if lang == "zh" + else self._extract_tags(title, content, summary)[:3] + ) + + return [ + TextualMemoryItem( + memory=( + f"[Outer internet view] Title: {title}\nNewsTime:" + f" {publish_time}\nSummary:" + f" {summary}\n" + ), + metadata=SearchedTreeNodeTextualMemoryMetadata( + user_id=user_id, + session_id=session_id, + memory_type="OuterMemory", + status="activated", + type="fact", + source="web", + sources=[SourceMessage(type="web", url=url)] if url else [], + visibility="public", + info=info_, + background="", + confidence=0.99, + usage=[], + tags=tags, + key=title, + embedding=self.embedder.embed([content])[0], + internet_info={ + "title": title, + "url": url, + "site_name": "", + "site_icon": None, + "summary": summary, + }, + ), + ) + ]