Skip to content

Commit c351c72

Browse files
committed
cleanup: add back only_include_docs_used for now
1 parent d3d4e7d commit c351c72

4 files changed

Lines changed: 196 additions & 20 deletions

File tree

server/src/data/models.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4547,6 +4547,7 @@ impl ApiKeyRequestParams {
45474547
typo_options: self.typo_options.or(payload.typo_options),
45484548
metadata: payload.metadata,
45494549
use_agentic_search: payload.use_agentic_search,
4550+
only_include_docs_used: payload.only_include_docs_used,
45504551
}
45514552
}
45524553

@@ -9432,6 +9433,7 @@ impl<'de> Deserialize<'de> for CreateMessageReqPayload {
94329433
use_quote_negated_terms: Option<bool>,
94339434
remove_stop_words: Option<bool>,
94349435
typo_options: Option<TypoOptions>,
9436+
pub only_include_docs_used: Option<bool>,
94359437
}
94369438

94379439
let mut helper = Helper::deserialize(deserializer)?;
@@ -9472,6 +9474,7 @@ impl<'de> Deserialize<'de> for CreateMessageReqPayload {
94729474
remove_stop_words: helper.remove_stop_words,
94739475
typo_options: helper.typo_options,
94749476
use_agentic_search: helper.use_agentic_search,
9477+
only_include_docs_used: helper.only_include_docs_used,
94759478
})
94769479
}
94779480
}
@@ -9506,6 +9509,7 @@ impl<'de> Deserialize<'de> for RegenerateMessageReqPayload {
95069509
pub typo_options: Option<TypoOptions>,
95079510
pub rag_context: Option<String>,
95089511
pub use_agentic_search: Option<bool>,
9512+
pub only_include_docs_used: Option<bool>,
95099513
}
95109514

95119515
let mut helper = Helper::deserialize(deserializer)?;
@@ -9543,6 +9547,7 @@ impl<'de> Deserialize<'de> for RegenerateMessageReqPayload {
95439547
typo_options: helper.typo_options,
95449548
rag_context: helper.rag_context,
95459549
use_agentic_search: helper.use_agentic_search,
9550+
only_include_docs_used: helper.only_include_docs_used,
95469551
})
95479552
}
95489553
}
@@ -9581,6 +9586,7 @@ impl<'de> Deserialize<'de> for EditMessageReqPayload {
95819586
pub typo_options: Option<TypoOptions>,
95829587
pub rag_context: Option<String>,
95839588
pub use_agentic_search: Option<bool>,
9589+
pub only_include_docs_used: Option<bool>,
95849590
}
95859591

95869592
let mut helper = Helper::deserialize(deserializer)?;
@@ -9622,6 +9628,7 @@ impl<'de> Deserialize<'de> for EditMessageReqPayload {
96229628
typo_options: helper.typo_options,
96239629
rag_context: helper.rag_context,
96249630
use_agentic_search: helper.use_agentic_search,
9631+
only_include_docs_used: helper.only_include_docs_used,
96259632
})
96269633
}
96279634
}

server/src/handlers/message_handler.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,8 @@ pub struct CreateMessageReqPayload {
106106
pub context_options: Option<ContextOptions>,
107107
/// No result message for when there are no chunks found above the score threshold.
108108
pub no_result_message: Option<String>,
109+
/// If true, the completion will only include the docs the model actually used to generate its response; if false, all retrieved docs are included.
110+
pub only_include_docs_used: Option<bool>,
109111
/// The currency to use for the completion. If not specified, this defaults to "USD".
110112
pub currency: Option<String>,
111113
/// Search_type can be either "semantic", "fulltext", or "hybrid". "hybrid" will pull in one page (10 chunks) of both semantic and full-text results then re-rank them using scores from a cross encoder model. "semantic" will pull in one page (10 chunks) of the nearest cosine distant vectors. "fulltext" will pull in one page (10 chunks) of full-text results based on SPLADE. Default is "hybrid".
@@ -271,6 +273,7 @@ pub async fn create_message(
271273
&dataset_config,
272274
previous_messages,
273275
create_message_data.use_agentic_search.unwrap_or(false),
276+
create_message_data.only_include_docs_used.unwrap_or(false),
274277
new_message,
275278
dataset_org_plan_sub.dataset.id,
276279
&create_message_pool,
@@ -406,6 +409,8 @@ pub struct RegenerateMessageReqPayload {
406409
pub context_options: Option<ContextOptions>,
407410
/// No result message for when there are no chunks found above the score threshold.
408411
pub no_result_message: Option<String>,
412+
/// If true, the completion will only include the docs the model actually used to generate its response; if false, all retrieved docs are included.
413+
pub only_include_docs_used: Option<bool>,
409414
/// The currency symbol to use for the completion. If not specified, this defaults to "$".
410415
pub currency: Option<String>,
411416
/// Search_type can be either "semantic", "fulltext", or "hybrid". "hybrid" will pull in one page (10 chunks) of both semantic and full-text results then re-rank them using scores from a cross encoder model. "semantic" will pull in one page (10 chunks) of the nearest cosine distant vectors. "fulltext" will pull in one page (10 chunks) of full-text results based on SPLADE. Default is "hybrid".
@@ -460,6 +465,8 @@ pub struct EditMessageReqPayload {
460465
pub context_options: Option<ContextOptions>,
461466
/// No result message for when there are no chunks found above the score threshold.
462467
pub no_result_message: Option<String>,
468+
/// If true, the completion will only include the docs the model actually used to generate its response; if false, all retrieved docs are included.
469+
pub only_include_docs_used: Option<bool>,
463470
/// The currency symbol to use for the completion. If not specified, this defaults to "$".
464471
pub currency: Option<String>,
465472
/// Search_type can be either "semantic", "fulltext", or "hybrid". "hybrid" will pull in one page (10 chunks) of both semantic and full-text results then re-rank them using scores from a cross encoder model. "semantic" will pull in one page (10 chunks) of the nearest cosine distant vectors. "fulltext" will pull in one page (10 chunks) of full-text results based on SPLADE. Default is "hybrid".
@@ -517,6 +524,7 @@ impl From<EditMessageReqPayload> for CreateMessageReqPayload {
517524
metadata: data.metadata,
518525
rag_context: data.rag_context,
519526
use_agentic_search: data.use_agentic_search,
527+
only_include_docs_used: data.only_include_docs_used,
520528
}
521529
}
522530
}
@@ -548,6 +556,7 @@ impl From<RegenerateMessageReqPayload> for CreateMessageReqPayload {
548556
metadata: data.metadata,
549557
rag_context: data.rag_context,
550558
use_agentic_search: data.use_agentic_search,
559+
only_include_docs_used: data.only_include_docs_used,
551560
}
552561
}
553562
}

server/src/operators/message_operator.rs

Lines changed: 134 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,25 @@ use itertools::Itertools;
44
use openai_dive::v1::models::WhisperModel;
55
use simple_server_timing_header::Timer;
66
use simsearch::SimSearch;
7+
use std::sync::atomic::{AtomicBool, AtomicU16, Ordering};
78
use std::sync::{Arc, Mutex};
89

910
#[cfg(not(feature = "hallucination-detection"))]
1011
use crate::data::models::DummyHallucinationScore;
1112
use crate::data::models::{
12-
self, escape_quotes, ChunkMetadata, ChunkMetadataStringTagSet, ChunkMetadataStringTagSetWithHighlightsScore, ChunkMetadataTypes, ConditionType, Dataset, DatasetConfiguration, FieldCondition, LLMOptions, MultiQuery, QdrantChunkMetadata, QueryTypes, RagQueryEventClickhouse, Range, RangeCondition, RedisPool, ScoreChunk, SearchMethod, SearchModalities, SuggestType
13+
self, escape_quotes, ChunkMetadata, ChunkMetadataStringTagSet,
14+
ChunkMetadataStringTagSetWithHighlightsScore, ChunkMetadataTypes, ConditionType, Dataset,
15+
DatasetConfiguration, FieldCondition, LLMOptions, MultiQuery, QdrantChunkMetadata, QueryTypes,
16+
RagQueryEventClickhouse, Range, RangeCondition, RedisPool, ScoreChunk, SearchMethod,
17+
SearchModalities, SuggestType,
1318
};
1419
use crate::diesel::prelude::*;
1520
use crate::get_env;
1621
use crate::handlers::chunk_handler::SearchChunksReqPayload;
1722
use crate::handlers::group_handler::SearchOverGroupsReqPayload;
1823
use crate::handlers::message_handler::{CreateMessageReqPayload, SuggestedQueriesReqPayload};
1924
use crate::operators::clickhouse_operator::ClickHouseEvent;
20-
use crate::operators::parse_operator::convert_html_to_text;
25+
use crate::operators::parse_operator::{convert_html_to_text, parse_streaming_completetion};
2126
use crate::operators::qdrant_operator::scroll_dataset_points;
2227
use crate::{
2328
data::models::{Message, Pool, SearchQueryEventClickhouse},
@@ -58,6 +63,42 @@ use super::search_operator::{
5863
assemble_qdrant_filter, hybrid_search_over_groups, search_chunks_query, search_hybrid_chunks,
5964
search_over_groups_query, ParsedQuery, ParsedQueryTypes,
6065
};
66+
67+
pub fn parse_text_into_docs_message(
68+
text: &str,
69+
score_chunks: Vec<ScoreChunk>,
70+
) -> Result<(String, Vec<ScoreChunk>), ServiceError> {
71+
let parsed: serde_json::Value = serde_json::from_str(text).map_err(|_| {
72+
log::error!("Invalid JSON response when trying to fetch used documents array");
73+
ServiceError::BadRequest(
74+
"Invalid JSON response when trying to fetch used documents array".to_string(),
75+
)
76+
})?;
77+
78+
let used_docs = parsed["documents"].as_array().ok_or_else(|| {
79+
log::error!("Missing documents array");
80+
ServiceError::BadRequest("Missing documents array".to_string())
81+
})?;
82+
83+
let rag_message = parsed["message"].as_str().ok_or_else(|| {
84+
log::error!("Missing message");
85+
ServiceError::BadRequest("Missing message".to_string())
86+
})?;
87+
88+
// Filter chunk_metadatas to only include used documents
89+
let filtered_chunks: Vec<_> = used_docs
90+
.iter()
91+
.filter_map(|doc_idx| {
92+
doc_idx
93+
.as_u64()
94+
.and_then(|idx| score_chunks.get(idx as usize - 1))
95+
.cloned()
96+
})
97+
.collect();
98+
99+
Ok((rag_message.to_string(), filtered_chunks))
100+
}
101+
61102
#[derive(Debug, Serialize, Deserialize)]
62103
pub struct ChatCompletionDTO {
63104
pub completion_message: Message,
@@ -151,9 +192,25 @@ If you use the search tool, you MUST use the chunks_used tool to respond with th
151192
respond with the chunks you MUST include in your response to the user's question.
152193
"#;
153194

195+
const STRUCTURE_SYSTEM_PROMPT: &str = r#"
196+
Before you start generating respond with the documents that you plan to use to generate your response, YOU MUST INCLUDE AT LEAST 1.
197+
YOU MUST DO THIS BEFORE YOU CONTINUE TO GENERATE A RESPONSE.
198+
After responding with the documents, YOU MUST RESPOND TO THE USERS PROMPT.
199+
```
200+
Example:
201+
User:
202+
Here's my prompt: what about for spreadsheets \n\n Use the following retrieved documents to respond briefly and accurately: {"doc": 1, "text": "chunk text..", "link": "chunk link.." }\n\n{"doc": 2, "text": "chunk text..", "link": "chunk link.." }... etc
203+
Assistant:
204+
documents: [1,2]
205+
...continue with model response
206+
```
207+
After you have done these things now follow:
208+
"#;
209+
154210
pub async fn create_generic_system_message(
155211
system_prompt: String,
156212
use_agentic_search: bool,
213+
only_include_docs_used: bool,
157214
messages_topic_id: uuid::Uuid,
158215
dataset_id: uuid::Uuid,
159216
pool: &web::Data<Pool>,
@@ -162,12 +219,16 @@ pub async fn create_generic_system_message(
162219
crate::operators::topic_operator::get_topic_query(messages_topic_id, dataset_id, pool)
163220
.await?;
164221

165-
let system_prompt = if use_agentic_search {
222+
let mut system_prompt = if use_agentic_search {
166223
format!("{}\n\n{}", AGENTIC_SEARCH_SYSTEM_PROMPT, system_prompt)
167224
} else {
168225
system_prompt
169226
};
170227

228+
if only_include_docs_used {
229+
system_prompt = format!("{}\n\n{}", STRUCTURE_SYSTEM_PROMPT, system_prompt);
230+
}
231+
171232
let system_message = Message::from_details(
172233
system_prompt,
173234
topic.id,
@@ -186,6 +247,7 @@ pub async fn create_topic_message_query(
186247
config: &DatasetConfiguration,
187248
previous_messages: Vec<Message>,
188249
use_agentic_search: bool,
250+
only_include_docs_used: bool,
189251
new_message: Message,
190252
dataset_id: uuid::Uuid,
191253
pool: &web::Data<Pool>,
@@ -198,7 +260,8 @@ pub async fn create_topic_message_query(
198260
let system_message = create_generic_system_message(
199261
config.SYSTEM_PROMPT.clone(),
200262
use_agentic_search,
201-
new_message.topic_id,
263+
only_include_docs_used,
264+
new_message.topic_id,
202265
dataset_id,
203266
pool,
204267
)
@@ -1005,7 +1068,17 @@ pub async fn stream_response(
10051068
content: Some(ChatMessageContent::Text(text)),
10061069
..
10071070
}) => {
1008-
(text.clone(), score_chunks.clone())
1071+
if create_message_req_payload
1072+
.only_include_docs_used
1073+
.unwrap_or(false)
1074+
{
1075+
match parse_text_into_docs_message(text, score_chunks.clone()) {
1076+
Ok((response_text, filtered_chunks)) => (response_text, filtered_chunks),
1077+
Err(_) => (text.clone(), vec![]),
1078+
}
1079+
} else {
1080+
(text.clone(), score_chunks.clone())
1081+
}
10091082
}
10101083
_ => {
10111084
return Err(ServiceError::BadRequest("Invalid response format, did not receive text on the assistant message from the LLM provider".to_string()).into())
@@ -1254,9 +1327,17 @@ pub async fn stream_response(
12541327
.parse::<u64>()
12551328
.unwrap_or(60);
12561329

1257-
let documents = Arc::new(Mutex::new(
1258-
(0..score_chunks.len() as u32).collect::<Vec<u32>>(),
1259-
));
1330+
let state = Arc::new(AtomicU16::new(0));
1331+
let documents = if create_message_req_payload
1332+
.only_include_docs_used
1333+
.unwrap_or(false)
1334+
{
1335+
Arc::new(Mutex::new(vec![]))
1336+
} else {
1337+
Arc::new(Mutex::new((0..score_chunks.len() as u32).collect()))
1338+
};
1339+
let started_parsing_completion = AtomicBool::new(false);
1340+
let mut bail_on_parsing = AtomicBool::new(false);
12601341

12611342
let completion_stream = stream
12621343
.take_until(tokio::time::sleep(std::time::Duration::from_secs(chat_completion_timeout)))
@@ -1292,11 +1373,45 @@ pub async fn stream_response(
12921373
content: Some(ChatMessageContent::Text(text)),
12931374
..
12941375
} => {
1295-
if !completion_first {
1376+
if create_message_req_payload
1377+
.only_include_docs_used
1378+
.unwrap_or(false) {
1379+
let bailed_on_iter = bail_on_parsing.get_mut();
1380+
let (text, docs) = if !*bailed_on_iter {
1381+
let (parsed_text, docs, bail) = parse_streaming_completetion(text, state.clone(), documents.clone());
1382+
if bail {
1383+
*bailed_on_iter = true;
1384+
documents.lock().unwrap().extend(0..score_chunks.len() as u32);
1385+
(Some(text.clone()), Some((0..score_chunks.len() as u32).collect()))
1386+
} else {
1387+
(parsed_text, docs)
1388+
}
1389+
} else {
1390+
(Some(text.clone()), None)
1391+
};
1392+
1393+
if let Some(docs) = docs {
1394+
if !completion_first {
1395+
let filtered_chunks = score_chunks.iter().enumerate().filter_map(|(idx, score_chunk)| {
1396+
if docs.contains(&(idx as u32)) {
1397+
Some(ChunkMetadataStringTagSetWithHighlightsScore::from(score_chunk.clone()))
1398+
} else {
1399+
None
1400+
}
1401+
}).collect::<Vec<ChunkMetadataStringTagSetWithHighlightsScore>>();
1402+
if *bailed_on_iter {
1403+
return Some(format!("{}||{}", serde_json::to_string(&filtered_chunks).unwrap_or_default().replace("||", ""), text.unwrap_or("".to_string())));
1404+
} else {
1405+
return Some(format!("{}||", serde_json::to_string(&filtered_chunks).unwrap_or_default().replace("||", "")));
1406+
}
1407+
}
1408+
}
1409+
text.clone()
1410+
} else if !completion_first && !started_parsing_completion.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |_| Some(true)).unwrap_or(true) {
12961411
let returned_chunks = score_chunks.iter().map(|score_chunk| {
12971412
ChunkMetadataStringTagSetWithHighlightsScore::from(score_chunk.clone())
12981413
}).collect::<Vec<ChunkMetadataStringTagSetWithHighlightsScore>>();
1299-
Some(format!("{}||", serde_json::to_string(&returned_chunks).unwrap_or_default().replace("||", "")))
1414+
return Some(format!("{}||", serde_json::to_string(&returned_chunks).unwrap_or_default().replace("||", "")));
13001415
} else {
13011416
Some(text.clone())
13021417
}
@@ -1339,7 +1454,6 @@ pub async fn stream_response(
13391454
}
13401455
}
13411456

1342-
13431457
#[derive(Deserialize, Debug)]
13441458
struct PriceFilter {
13451459
min: Option<f32>,
@@ -1367,15 +1481,16 @@ async fn search_chunks(
13671481
filters.map(|filter| {
13681482
filter.must.map(|mut must| {
13691483
must.push(ConditionType::Field(FieldCondition {
1370-
field: "num_value".to_string(),
1371-
range: Some(Range {
1372-
gte: price_filter.min.map(|x| RangeCondition::Float(x as f64)),
1373-
lte: price_filter.max.map(|x| RangeCondition::Float(x as f64)),
1484+
field: "num_value".to_string(),
1485+
range: Some(Range {
1486+
gte: price_filter.min.map(|x| RangeCondition::Float(x as f64)),
1487+
lte: price_filter.max.map(|x| RangeCondition::Float(x as f64)),
1488+
..Default::default()
1489+
}),
13741490
..Default::default()
1375-
}),
1376-
..Default::default()
1377-
}))
1378-
})});
1491+
}))
1492+
})
1493+
});
13791494
}
13801495

13811496
let search_type = create_message_req_payload

0 commit comments

Comments
 (0)