Skip to content

Commit 77ec71d

Browse files
densumeshskeptrunedev
authored andcommitted
feature: added search query to response from agentic search
1 parent e21e5a8 commit 77ec71d

1 file changed

Lines changed: 11 additions & 8 deletions

File tree

server/src/operators/message_operator.rs

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1486,7 +1486,7 @@ struct PriceFilter {
14861486
min: Option<f32>,
14871487
max: Option<f32>,
14881488
}
1489-
#[derive(Deserialize, Debug)]
1489+
#[derive(Deserialize, Debug, Default)]
14901490
struct SearchParams {
14911491
query: String,
14921492
price_filter: Option<PriceFilter>,
@@ -1763,17 +1763,14 @@ async fn search_chunks(
17631763
}
17641764

17651765
async fn handle_search_tool_call(
1766-
tool_call: ToolCall,
1766+
search_params: SearchParams,
17671767
create_message_req_payload: CreateMessageReqPayload,
17681768
dataset: Dataset,
17691769
pool: web::Data<Pool>,
17701770
redis_pool: web::Data<RedisPool>,
17711771
dataset_config: DatasetConfiguration,
17721772
event_queue: web::Data<EventQueue>,
17731773
) -> Result<(Vec<ScoreChunk>, String, SearchQueryEventClickhouse), ServiceError> {
1774-
let search_params = serde_json::from_str::<SearchParams>(&tool_call.function.arguments)
1775-
.map_err(|e| ServiceError::BadRequest(e.to_string()))?;
1776-
17771774
let (clickhouse_search_event, results) = search_chunks(
17781775
create_message_req_payload.clone(),
17791776
search_params,
@@ -2071,9 +2068,13 @@ pub async fn stream_response_with_agentic_search(
20712068
// Handle tool calls
20722069
for tool_call in tool_calls {
20732070
if tool_call.function.name == "search" {
2071+
let search_params = serde_json::from_str::<SearchParams>(
2072+
&tool_call.function.arguments,
2073+
)
2074+
.map_err(|e| ServiceError::BadRequest(e.to_string()))?;
20742075
let (results, formatted_results, clickhouse_search_event) =
20752076
handle_search_tool_call(
2076-
tool_call.clone(),
2077+
search_params,
20772078
create_message_req_payload.clone(),
20782079
dataset.clone(),
20792080
pool.clone(),
@@ -2430,15 +2431,17 @@ pub async fn stream_response_with_agentic_search(
24302431
// Process all tool calls
24312432
for tool_call in &tool_calls_vec {
24322433
if tool_call.function.name == "search" {
2434+
let search_params = serde_json::from_str::<SearchParams>(&tool_call.function.arguments)
2435+
.map_err(|e| ServiceError::BadRequest(e.to_string())).unwrap_or_default();
24332436
// Send search indicator
24342437
let _ = tx
24352438
.send(web::Bytes::from(
2436-
"\n\n[Searching...]\n\n",
2439+
format!("\n\n[Searching for {}...]\n\n", search_params.query),
24372440
))
24382441
.await;
24392442

24402443
match handle_search_tool_call(
2441-
tool_call.clone(),
2444+
search_params,
24422445
create_message_req_payload_clone.clone(),
24432446
dataset.clone(),
24442447
pool.clone(),

0 commit comments

Comments
 (0)