Skip to content

Commit e8aa5da

Browse files
committed
Add support for retrieval-augmented generation
1 parent 770753b commit e8aa5da

1 file changed

Lines changed: 11 additions & 2 deletions

File tree

crates/chatbot/src/lib.rs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use rand::{rngs::SmallRng, Rng, SeedableRng};
2-
use std::{cell::RefCell, time::Duration};
2+
use std::{cell::RefCell, path::PathBuf, time::Duration};
33

44
thread_local! {
55
static RNG: RefCell<SmallRng> = RefCell::new(SmallRng::from_entropy());
@@ -33,17 +33,26 @@ impl Chatbot {
3333
}
3434
}
3535

36+
pub fn retrieval_documents(&self, _messages: &[String]) -> Vec<PathBuf> {
37+
vec![
38+
PathBuf::from("data/doc1.txt"),
39+
PathBuf::from("data/doc2.txt"),
40+
]
41+
}
42+
3643
/// Generates a list of possible responses given the current chat.
3744
///
3845
/// Warning: may take a few seconds!
39-
pub async fn query_chat(&mut self, messages: &[String]) -> Vec<String> {
46+
pub async fn query_chat(&mut self, messages: &[String], docs: &[String]) -> Vec<String> {
4047
std::thread::sleep(Duration::from_secs(2));
4148
let most_recent = messages.last().unwrap();
4249
let emoji = &self.emojis[self.emoji_counter];
4350
self.emoji_counter = (self.emoji_counter + 1) % self.emojis.len();
4451
vec![
4552
format!("\"{most_recent}\"? And how does that make you feel? {emoji}",),
4653
format!("\"{most_recent}\"! Interesting! Go on... {emoji}"),
54+
format!("Have you considered: {}", docs.first().unwrap()),
55+
format!("I might recommend: {}", docs.last().unwrap()),
4756
]
4857
}
4958
}

0 commit comments

Comments
 (0)