Skip to content

Commit 086344d

Browse files
committed
feat: Implement search result caching, agent-specific search boosts, and new CLI commands for analytics and tool testing.
1 parent e5b0058 commit 086344d

9 files changed

Lines changed: 779 additions & 179 deletions

File tree

deploy.sh

Lines changed: 59 additions & 166 deletions
Large diffs are not rendered by default.

tooldns/api.py

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
from datetime import datetime, timedelta
5050
from pathlib import Path
5151

52-
from fastapi import APIRouter, Depends, HTTPException
52+
from fastapi import APIRouter, Depends, HTTPException, Response
5353
from tooldns.config import settings, TOOLDNS_HOME
5454
from tooldns.auth import require_api_key
5555
from tooldns.workflows import WorkflowEngine
@@ -66,6 +66,7 @@
6666
RegisterMCPRequest, CreateSkillRequest,
6767
CallToolRequest, CreateMacroRequest, MacroStep, MacroInfo,
6868
PreflightRequest, PreflightResponse, PreflightToolMatch,
69+
SearchSelectRequest,
6970
)
7071
from tooldns.caller import call_tool as caller_call_tool, load_skill_content, resolve_args
7172

@@ -235,6 +236,8 @@ def search_tools(req: SearchRequest, auth: dict = Depends(require_api_key)):
235236
session = _get_session(req.session_id)
236237
if session:
237238
seen_tool_ids = set(session["seen_tool_ids"])
239+
if not req.agent_id and session.get("agent_id"):
240+
req.agent_id = session["agent_id"]
238241
else:
239242
raise HTTPException(status_code=404, detail=f"Session not found: {req.session_id}")
240243

@@ -696,6 +699,14 @@ def batch_search_tools(req: BatchSearchRequest, auth: dict = Depends(require_api
696699
else:
697700
raise HTTPException(status_code=404, detail=f"Session not found: {req.session_id}")
698701

702+
# Get agent preference boosts if agent_id provided
703+
preference_boosts = {}
704+
if req.agent_id and _workflow_engine:
705+
try:
706+
preference_boosts = _workflow_engine.get_agent_boosts(req.agent_id)
707+
except Exception as e:
708+
logger.warning(f"Failed to get agent preferences: {e}")
709+
699710
batch_start = _time.time()
700711
results = []
701712
total_tokens_saved = 0
@@ -712,6 +723,7 @@ def batch_search_tools(req: BatchSearchRequest, auth: dict = Depends(require_api
712723
minimal=req.minimal,
713724
allowed_tool_ids=allowed_tool_ids,
714725
seen_tool_ids=seen_tool_ids,
726+
preference_boosts=preference_boosts if preference_boosts else None,
715727
)
716728

717729
# Update shared seen_tool_ids so next query in batch benefits from dedup
@@ -744,6 +756,24 @@ def batch_search_tools(req: BatchSearchRequest, auth: dict = Depends(require_api
744756
)
745757

746758

759+
# -----------------------------------------------------------------------
760+
# Search Select — record tool selection without calling the tool
761+
# -----------------------------------------------------------------------
762+
763+
@router.post("/search/select")
764+
def search_select(req: SearchSelectRequest):
765+
"""
766+
Record which search result an agent selected.
767+
768+
Allows agents to report tool selections without executing the tool
769+
via /v1/call. Updates agent preferences for personalized search.
770+
"""
771+
if not _workflow_engine:
772+
raise HTTPException(status_code=503, detail="Workflow engine not initialized")
773+
_workflow_engine.record_tool_selection(req.agent_id, req.tool_id, req.query, req.confidence)
774+
return {"status": "ok", "agent_id": req.agent_id, "tool_id": req.tool_id}
775+
776+
747777
# -----------------------------------------------------------------------
748778
# Agent Sessions
749779
# -----------------------------------------------------------------------
@@ -923,8 +953,11 @@ def create_profile(req: CreateProfileRequest):
923953

924954

925955
@router.get("/profiles", response_model=list[ProfileInfo])
926-
def list_profiles():
956+
def list_profiles(response: Response = None):
927957
"""List all tool profiles with their current matched tool counts."""
958+
if response:
959+
response.headers["Cache-Control"] = "public, max-age=60"
960+
928961
with _profiles_lock:
929962
profile_list = list(_profiles.values())
930963

@@ -1106,12 +1139,15 @@ def _sanitize_source(source: dict, is_admin: bool) -> dict:
11061139

11071140

11081141
@router.get("/sources")
1109-
def list_sources(key_info: dict = Depends(require_api_key)):
1142+
def list_sources(key_info: dict = Depends(require_api_key), response: Response = None):
11101143
"""
11111144
List all registered sources with their status and tool counts.
11121145
11131146
Admin keys see full config. Sub-keys see name/type/status only.
11141147
"""
1148+
if response:
1149+
response.headers["Cache-Control"] = "public, max-age=60"
1150+
11151151
sources = _database.get_all_sources()
11161152
is_admin = key_info.get("is_admin", False)
11171153
return [_sanitize_source(s, is_admin) for s in sources]
@@ -1181,7 +1217,7 @@ def list_categories():
11811217

11821218

11831219
@router.get("/tools")
1184-
def list_tools(source: str = None, category: str = None):
1220+
def list_tools(source: str = None, category: str = None, response: Response = None):
11851221
"""
11861222
List all indexed tools, optionally filtered by source or category.
11871223
@@ -1192,6 +1228,9 @@ def list_tools(source: str = None, category: str = None):
11921228
Returns:
11931229
dict: Tool list with count.
11941230
"""
1231+
if response:
1232+
response.headers["Cache-Control"] = "public, max-age=300"
1233+
11951234
if source:
11961235
tools = _database.get_tools_by_source(source)
11971236
else:
@@ -1207,7 +1246,7 @@ def list_tools(source: str = None, category: str = None):
12071246

12081247

12091248
@router.get("/tool/{tool_id:path}")
1210-
def get_tool(tool_id: str):
1249+
def get_tool(tool_id: str, response: Response = None):
12111250
"""
12121251
Get full details for a specific tool.
12131252
@@ -1227,6 +1266,7 @@ def get_tool(tool_id: str):
12271266
"""
12281267
from pathlib import Path
12291268
import json as json_mod
1269+
import hashlib
12301270

12311271
# Find the tool in the database
12321272
tool = _database.get_tool_by_id(tool_id)
@@ -1252,6 +1292,12 @@ def get_tool(tool_id: str):
12521292
if skill_content:
12531293
result["skill_content"] = skill_content
12541294

1295+
# Cache headers + ETag
1296+
if response:
1297+
response.headers["Cache-Control"] = "public, max-age=300"
1298+
etag = hashlib.md5(json_mod.dumps(result, sort_keys=True).encode()).hexdigest()
1299+
response.headers["ETag"] = f'"{etag}"'
1300+
12551301
return result
12561302

12571303

0 commit comments

Comments
 (0)