Skip to content

Commit 696f0fc

Browse files
committed
feat(server): rag_context field for adding better context
1 parent 07d8650 commit 696f0fc

File tree

3 files changed

+33
-13
lines changed

3 files changed

+33
-13
lines changed

server/src/data/models.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4514,6 +4514,7 @@ impl ApiKeyRequestParams {
45144514
) -> CreateMessageReqPayload {
45154515
CreateMessageReqPayload {
45164516
new_message_content: payload.new_message_content,
4517+
rag_context: payload.rag_context,
45174518
topic_id: payload.topic_id,
45184519
user_id: payload.user_id,
45194520
sort_options: payload.sort_options,
@@ -9417,6 +9418,7 @@ impl<'de> Deserialize<'de> for CreateMessageReqPayload {
94179418
pub only_include_docs_used: Option<bool>,
94189419
pub currency: Option<String>,
94199420
metadata: Option<serde_json::Value>,
9421+
pub rag_context: Option<String>,
94209422
#[serde(flatten)]
94219423
other: std::collections::HashMap<String, serde_json::Value>,
94229424
use_quote_negated_terms: Option<bool>,
@@ -9457,6 +9459,7 @@ impl<'de> Deserialize<'de> for CreateMessageReqPayload {
94579459
context_options,
94589460
no_result_message: helper.no_result_message,
94599461
metadata: helper.metadata,
9462+
rag_context: helper.rag_context,
94609463
only_include_docs_used: helper.only_include_docs_used,
94619464
use_quote_negated_terms: helper.use_quote_negated_terms,
94629465
remove_stop_words: helper.remove_stop_words,
@@ -9494,6 +9497,7 @@ impl<'de> Deserialize<'de> for RegenerateMessageReqPayload {
94949497
pub use_quote_negated_terms: Option<bool>,
94959498
pub remove_stop_words: Option<bool>,
94969499
pub typo_options: Option<TypoOptions>,
9500+
pub rag_context: Option<String>,
94979501
}
94989502

94999503
let mut helper = Helper::deserialize(deserializer)?;
@@ -9530,6 +9534,7 @@ impl<'de> Deserialize<'de> for RegenerateMessageReqPayload {
95309534
use_quote_negated_terms: helper.use_quote_negated_terms,
95319535
remove_stop_words: helper.remove_stop_words,
95329536
typo_options: helper.typo_options,
9537+
rag_context: helper.rag_context,
95339538
})
95349539
}
95359540
}
@@ -9567,6 +9572,7 @@ impl<'de> Deserialize<'de> for EditMessageReqPayload {
95679572
pub use_quote_negated_terms: Option<bool>,
95689573
pub remove_stop_words: Option<bool>,
95699574
pub typo_options: Option<TypoOptions>,
9575+
pub rag_context: Option<String>,
95709576
}
95719577

95729578
let mut helper = Helper::deserialize(deserializer)?;
@@ -9607,6 +9613,7 @@ impl<'de> Deserialize<'de> for EditMessageReqPayload {
96079613
use_quote_negated_terms: helper.use_quote_negated_terms,
96089614
remove_stop_words: helper.remove_stop_words,
96099615
typo_options: helper.typo_options,
9616+
rag_context: helper.rag_context,
96109617
})
96119618
}
96129619
}

server/src/handlers/message_handler.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,8 @@ pub struct CreateMessageReqPayload {
129129
pub typo_options: Option<TypoOptions>,
130130
/// Metadata is any metadata you want to associate w/ the event that is created from this request
131131
pub metadata: Option<serde_json::Value>,
132+
/// Overrides what the way chunks are placed into the context window
133+
pub rag_context: Option<String>,
132134
}
133135

134136
/// Create message
@@ -414,6 +416,8 @@ pub struct RegenerateMessageReqPayload {
414416
pub typo_options: Option<TypoOptions>,
415417
/// Metadata is any metadata you want to associate w/ the event that is created from this request
416418
pub metadata: Option<serde_json::Value>,
419+
/// Overrides what the way chunks are placed into the context window
420+
pub rag_context: Option<String>,
417421
}
418422

419423
#[derive(Serialize, Debug, ToSchema)]
@@ -466,6 +470,8 @@ pub struct EditMessageReqPayload {
466470
pub typo_options: Option<TypoOptions>,
467471
/// Metadata is any metadata you want to associate w/ the event that is created from this request
468472
pub metadata: Option<serde_json::Value>,
473+
/// Overrides what the way chunks are placed into the context window
474+
pub rag_context: Option<String>,
469475
}
470476

471477
impl From<EditMessageReqPayload> for CreateMessageReqPayload {
@@ -494,6 +500,7 @@ impl From<EditMessageReqPayload> for CreateMessageReqPayload {
494500
remove_stop_words: data.remove_stop_words,
495501
typo_options: data.typo_options,
496502
metadata: data.metadata,
503+
rag_context: data.rag_context,
497504
}
498505
}
499506
}
@@ -524,6 +531,7 @@ impl From<RegenerateMessageReqPayload> for CreateMessageReqPayload {
524531
remove_stop_words: data.remove_stop_words,
525532
typo_options: data.typo_options,
526533
metadata: data.metadata,
534+
rag_context: data.rag_context,
527535
}
528536
}
529537
}

server/src/operators/message_operator.rs

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -837,20 +837,25 @@ pub async fn stream_response(
837837
.streaming(response_stream));
838838
}
839839

840-
let rag_content = score_chunks
841-
.iter()
842-
.enumerate()
843-
.map(|(idx, score_chunk)| {
844-
json!({
845-
"doc": idx + 1,
846-
"text": convert_html_to_text(&(ChunkMetadata::from(score_chunk.chunk.clone()).chunk_html.clone().unwrap_or_default())),
847-
"num_value": ChunkMetadata::from(score_chunk.chunk.clone()).num_value.map(|x| format!("{} {}", create_message_req_payload.currency.clone().unwrap_or("".to_string()), x)).unwrap_or("".to_string()),
848-
"link": ChunkMetadata::from(score_chunk.chunk.clone()).link.clone().unwrap_or_default()
840+
let rag_content = match create_message_req_payload.rag_context {
841+
Some(rag_context) => rag_context,
842+
None => {
843+
score_chunks
844+
.iter()
845+
.enumerate()
846+
.map(|(idx, score_chunk)| {
847+
json!({
848+
"doc": idx + 1,
849+
"text": convert_html_to_text(&(ChunkMetadata::from(score_chunk.chunk.clone()).chunk_html.clone().unwrap_or_default())),
850+
"num_value": ChunkMetadata::from(score_chunk.chunk.clone()).num_value.map(|x| format!("{} {}", create_message_req_payload.currency.clone().unwrap_or("".to_string()), x)).unwrap_or("".to_string()),
851+
"link": ChunkMetadata::from(score_chunk.chunk.clone()).link.clone().unwrap_or_default()
852+
})
853+
.to_string()
849854
})
850-
.to_string()
851-
})
852-
.collect::<Vec<String>>()
853-
.join("\n\n");
855+
.collect::<Vec<String>>()
856+
.join("\n\n")
857+
}
858+
};
854859

855860
let user_message = match &openai_messages
856861
.last()

0 commit comments

Comments
 (0)