Skip to content

Commit 275d59d

Browse files
authored
feat: support images with fs_read (#1489)
1 parent dd6ca31 commit 275d59d

15 files changed

Lines changed: 779 additions & 18 deletions

File tree

crates/chat-cli/src/api_client/clients/streaming_client.rs

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -291,6 +291,7 @@ mod tests {
291291
.send_message(ConversationState {
292292
conversation_id: None,
293293
user_input_message: UserInputMessage {
294+
images: None,
294295
content: "Hello".into(),
295296
user_input_message_context: None,
296297
user_intent: None,
@@ -315,12 +316,14 @@ mod tests {
315316
.send_message(ConversationState {
316317
conversation_id: None,
317318
user_input_message: UserInputMessage {
319+
images: None,
318320
content: "How about rustc?".into(),
319321
user_input_message_context: None,
320322
user_intent: None,
321323
},
322324
history: Some(vec![
323325
ChatMessage::UserInputMessage(UserInputMessage {
326+
images: None,
324327
content: "What language is the linux kernel written in, and who wrote it?".into(),
325328
user_input_message_context: None,
326329
user_intent: None,

crates/chat-cli/src/api_client/model.rs

Lines changed: 106 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,7 @@
1-
use aws_smithy_types::Document;
1+
use aws_smithy_types::{
2+
Blob,
3+
Document,
4+
};
25
use serde::{
36
Deserialize,
47
Serialize,
@@ -565,17 +568,113 @@ impl From<GitState> for amzn_qdeveloper_streaming_client::types::GitState {
565568
}
566569
}
567570

571+
/// An image attachment: an [`ImageFormat`] tag paired with the [`ImageSource`] describing its data.
///
/// Converts into the `ImageBlock` type of both the codewhisperer and qdeveloper streaming
/// clients through the `From<ImageBlock>` impls in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
572+
pub struct ImageBlock {
573+
/// Encoding of the image.
pub format: ImageFormat,
574+
/// Source of the image data.
pub source: ImageSource,
575+
}
576+
577+
/// Image encodings that an [`ImageBlock`] can carry.
///
/// Parsed from text through the `FromStr` impl, which accepts `gif`, `jpeg`, `jpg`, `png`
/// and `webp` in any letter case (`jpg` and `jpeg` both yield [`ImageFormat::Jpeg`]).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
578+
pub enum ImageFormat {
579+
Gif,
580+
Jpeg,
581+
Png,
582+
Webp,
583+
}
584+
585+
/// Parses an [`ImageFormat`] from a format name such as `"png"` or `"JPG"`.
impl std::str::FromStr for ImageFormat {
586+
/// Human-readable message naming the input that could not be parsed.
type Err = String;
587+
588+
/// Trims surrounding whitespace, lowercases the input, and matches it against the
/// known format names.
///
/// # Errors
///
/// Returns `Err` with a message containing the original (untrimmed, un-lowercased)
/// input when it is not one of `gif`, `jpeg`, `jpg`, `png` or `webp`.
fn from_str(s: &str) -> Result<Self, Self::Err> {
589+
match s.trim().to_lowercase().as_str() {
590+
"gif" => Ok(ImageFormat::Gif),
591+
"jpeg" => Ok(ImageFormat::Jpeg),
592+
// "jpg" is accepted as an alias of "jpeg".
"jpg" => Ok(ImageFormat::Jpeg),
593+
"png" => Ok(ImageFormat::Png),
594+
"webp" => Ok(ImageFormat::Webp),
595+
_ => Err(format!("Failed to parse '{}' as ImageFormat", s)),
596+
}
597+
}
598+
}
599+
600+
/// Maps the local [`ImageFormat`] onto the codewhisperer streaming client's `ImageFormat`.
impl From<ImageFormat> for amzn_codewhisperer_streaming_client::types::ImageFormat {
601+
/// One-to-one conversion: each local variant becomes the client variant of the same name.
fn from(value: ImageFormat) -> Self {
602+
match value {
603+
ImageFormat::Gif => Self::Gif,
604+
ImageFormat::Jpeg => Self::Jpeg,
605+
ImageFormat::Png => Self::Png,
606+
ImageFormat::Webp => Self::Webp,
607+
}
608+
}
609+
}
610+
/// Maps the local [`ImageFormat`] onto the qdeveloper streaming client's `ImageFormat`.
impl From<ImageFormat> for amzn_qdeveloper_streaming_client::types::ImageFormat {
611+
/// One-to-one conversion: each local variant becomes the client variant of the same name.
fn from(value: ImageFormat) -> Self {
612+
match value {
613+
ImageFormat::Gif => Self::Gif,
614+
ImageFormat::Jpeg => Self::Jpeg,
615+
ImageFormat::Png => Self::Png,
616+
ImageFormat::Webp => Self::Webp,
617+
}
618+
}
619+
}
620+
621+
/// Where the data of an [`ImageBlock`] comes from.
///
/// Marked `#[non_exhaustive]`, so code outside this crate must include a wildcard arm
/// when matching on it.
#[non_exhaustive]
622+
#[derive(Debug, Clone, Serialize, Deserialize)]
623+
pub enum ImageSource {
624+
/// Raw image bytes; wrapped in a `Blob` when converted to the streaming client types.
Bytes(Vec<u8>),
625+
/// Converted to the streaming clients' own `Unknown` variant.
#[non_exhaustive]
626+
Unknown,
627+
}
628+
629+
/// Maps the local [`ImageSource`] onto the codewhisperer streaming client's `ImageSource`.
impl From<ImageSource> for amzn_codewhisperer_streaming_client::types::ImageSource {
630+
/// Moves the byte buffer into a `Blob` for `Bytes`; `Unknown` maps to the client's `Unknown`.
fn from(value: ImageSource) -> Self {
631+
match value {
632+
ImageSource::Bytes(bytes) => Self::Bytes(Blob::new(bytes)),
633+
ImageSource::Unknown => Self::Unknown,
634+
}
635+
}
636+
}
637+
/// Maps the local [`ImageSource`] onto the qdeveloper streaming client's `ImageSource`.
impl From<ImageSource> for amzn_qdeveloper_streaming_client::types::ImageSource {
638+
/// Moves the byte buffer into a `Blob` for `Bytes`; `Unknown` maps to the client's `Unknown`.
fn from(value: ImageSource) -> Self {
639+
match value {
640+
ImageSource::Bytes(bytes) => Self::Bytes(Blob::new(bytes)),
641+
ImageSource::Unknown => Self::Unknown,
642+
}
643+
}
644+
}
645+
646+
/// Builds the codewhisperer streaming client's `ImageBlock` from the local [`ImageBlock`].
impl From<ImageBlock> for amzn_codewhisperer_streaming_client::types::ImageBlock {
647+
/// Converts `format` and `source` with their own `From` impls and feeds them to the
/// client builder.
///
/// # Panics
///
/// Panics with "Failed to build ImageBlock" if the builder's `build()` returns an error.
/// NOTE(review): presumably that only happens when a required field is missing, and both
/// fields are set here — confirm against the generated builder.
fn from(value: ImageBlock) -> Self {
648+
Self::builder()
649+
.format(value.format.into())
650+
.source(value.source.into())
651+
.build()
652+
.expect("Failed to build ImageBlock")
653+
}
654+
}
655+
/// Builds the qdeveloper streaming client's `ImageBlock` from the local [`ImageBlock`].
impl From<ImageBlock> for amzn_qdeveloper_streaming_client::types::ImageBlock {
656+
/// Converts `format` and `source` with their own `From` impls and feeds them to the
/// client builder.
///
/// # Panics
///
/// Panics with "Failed to build ImageBlock" if the builder's `build()` returns an error.
/// NOTE(review): presumably that only happens when a required field is missing, and both
/// fields are set here — confirm against the generated builder.
fn from(value: ImageBlock) -> Self {
657+
Self::builder()
658+
.format(value.format.into())
659+
.source(value.source.into())
660+
.build()
661+
.expect("Failed to build ImageBlock")
662+
}
663+
}
664+
568665
/// A user message in the client-independent model.
///
/// The `From` impls that follow convert it into each streaming client's `UserInputMessage`,
/// forwarding every field through the client builder and setting the origin to `Cli`.
#[derive(Debug, Clone)]
569666
pub struct UserInputMessage {
570667
/// Text of the message.
pub content: String,
571668
/// Optional context forwarded with the message; `None` leaves it unset on the client type.
pub user_input_message_context: Option<UserInputMessageContext>,
572669
/// Optional intent forwarded with the message; `None` leaves it unset on the client type.
pub user_intent: Option<UserIntent>,
670+
/// Optional images attached to the message; `None` leaves the client's image list unset.
pub images: Option<Vec<ImageBlock>>,
573671
}
574672

575673
impl From<UserInputMessage> for amzn_codewhisperer_streaming_client::types::UserInputMessage {
576674
fn from(value: UserInputMessage) -> Self {
577675
Self::builder()
578676
.content(value.content)
677+
.set_images(value.images.map(|images| images.into_iter().map(Into::into).collect()))
579678
.set_user_input_message_context(value.user_input_message_context.map(Into::into))
580679
.set_user_intent(value.user_intent.map(Into::into))
581680
.origin(amzn_codewhisperer_streaming_client::types::Origin::Cli)
@@ -588,6 +687,7 @@ impl From<UserInputMessage> for amzn_qdeveloper_streaming_client::types::UserInp
588687
fn from(value: UserInputMessage) -> Self {
589688
Self::builder()
590689
.content(value.content)
690+
.set_images(value.images.map(|images| images.into_iter().map(Into::into).collect()))
591691
.set_user_input_message_context(value.user_input_message_context.map(Into::into))
592692
.set_user_intent(value.user_intent.map(Into::into))
593693
.origin(amzn_qdeveloper_streaming_client::types::Origin::Cli)
@@ -654,6 +754,10 @@ mod tests {
654754
#[test]
655755
fn build_user_input_message() {
656756
let user_input_message = UserInputMessage {
757+
images: Some(vec![ImageBlock {
758+
format: ImageFormat::Png,
759+
source: ImageSource::Bytes(vec![1, 2, 3]),
760+
}]),
657761
content: "test content".to_string(),
658762
user_input_message_context: Some(UserInputMessageContext {
659763
env_state: Some(EnvState {
@@ -690,6 +794,7 @@ mod tests {
690794
assert_eq!(format!("{codewhisper_input:?}"), format!("{qdeveloper_input:?}"));
691795

692796
let minimal_message = UserInputMessage {
797+
images: None,
693798
content: "test content".to_string(),
694799
user_input_message_context: None,
695800
user_intent: None,

crates/chat-cli/src/cli/chat/consts.rs

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -17,3 +17,8 @@ pub const MAX_USER_MESSAGE_SIZE: usize = 600_000;
1717
pub const CONTEXT_WINDOW_SIZE: usize = 200_000;
1818

1919
pub const MAX_CHARS: usize = TokenCounter::token_to_chars(CONTEXT_WINDOW_SIZE); // Character-based warning threshold
20+
21+
/// Cap on the number of images in a single request.
/// NOTE(review): the code that enforces this limit is not visible in this diff.
pub const MAX_NUMBER_OF_IMAGES_PER_REQUEST: usize = 10;
22+
23+
/// Upper bound on image size, in bytes: 10 * 1024 * 1024 = 10 MiB (10,485,760 bytes).
24+
pub const MAX_IMAGE_SIZE: usize = 10 * 1024 * 1024;

crates/chat-cli/src/cli/chat/conversation_state.rs

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -42,6 +42,7 @@ use crate::api_client::model::{
4242
AssistantResponseMessage,
4343
ChatMessage,
4444
ConversationState as FigConversationState,
45+
ImageBlock,
4546
Tool,
4647
ToolInputSchema,
4748
ToolResult,
@@ -294,6 +295,11 @@ impl ConversationState {
294295
self.next_message = Some(UserMessage::new_tool_use_results(tool_results));
295296
}
296297

298+
/// Sets the next user message to the given tool results together with the given images.
///
/// Overwrites `self.next_message`; debug builds assert that no next message was already set.
pub fn add_tool_results_with_images(&mut self, tool_results: Vec<ToolUseResult>, images: Vec<ImageBlock>) {
299+
// Checked in debug builds only; release builds silently replace any existing message.
debug_assert!(self.next_message.is_none());
300+
self.next_message = Some(UserMessage::new_tool_use_results_with_images(tool_results, images));
301+
}
302+
297303
/// Sets the next user message with "cancelled" tool results.
298304
pub fn abandon_tool_use(&mut self, tools_to_be_abandoned: Vec<QueuedTool>, deny_input: String) {
299305
self.next_message = Some(UserMessage::new_cancelled_tool_uses(
@@ -415,6 +421,7 @@ impl ConversationState {
415421
content: summary_content,
416422
user_input_message_context: None,
417423
user_intent: None,
424+
images: None,
418425
};
419426

420427
// If the last message contains tool uses, then add cancelled tool results to the summary

crates/chat-cli/src/cli/chat/message.rs

Lines changed: 19 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@ use super::util::truncate_safe;
1717
use crate::api_client::model::{
1818
AssistantResponseMessage,
1919
EnvState,
20+
ImageBlock,
2021
ToolResult,
2122
ToolResultContentBlock,
2223
ToolResultStatus,
@@ -33,6 +34,7 @@ pub struct UserMessage {
3334
pub additional_context: String,
3435
pub env_context: UserEnvContext,
3536
pub content: UserMessageContent,
37+
pub images: Option<Vec<ImageBlock>>,
3638
}
3739

3840
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -56,6 +58,7 @@ impl UserMessage {
5658
/// environment [UserEnvContext].
5759
pub fn new_prompt(prompt: String) -> Self {
5860
Self {
61+
images: None,
5962
additional_context: String::new(),
6063
env_context: UserEnvContext::generate_new(),
6164
content: UserMessageContent::Prompt { prompt },
@@ -64,6 +67,7 @@ impl UserMessage {
6467

6568
pub fn new_cancelled_tool_uses<'a>(prompt: Option<String>, tool_use_ids: impl Iterator<Item = &'a str>) -> Self {
6669
Self {
70+
images: None,
6771
additional_context: String::new(),
6872
env_context: UserEnvContext::generate_new(),
6973
content: UserMessageContent::CancelledToolUses {
@@ -88,13 +92,26 @@ impl UserMessage {
8892
content: UserMessageContent::ToolUseResults {
8993
tool_use_results: results,
9094
},
95+
images: None,
96+
}
97+
}
98+
99+
/// Creates a tool-use-results message that also carries images.
///
/// The message starts with an empty `additional_context`, a freshly generated
/// [UserEnvContext], the given `results` as its content, and `images` wrapped in `Some`
/// (even when the vector is empty).
pub fn new_tool_use_results_with_images(results: Vec<ToolUseResult>, images: Vec<ImageBlock>) -> Self {
100+
Self {
101+
additional_context: String::new(),
102+
env_context: UserEnvContext::generate_new(),
103+
content: UserMessageContent::ToolUseResults {
104+
tool_use_results: results,
105+
},
106+
images: Some(images),
91107
}
92108
}
93109

94110
/// Converts this message into a [UserInputMessage] to be stored in the history of
95111
/// [api_client::model::ConversationState].
96112
pub fn into_history_entry(self) -> UserInputMessage {
97113
UserInputMessage {
114+
images: None,
98115
content: self.prompt().unwrap_or_default().to_string(),
99116
user_input_message_context: Some(UserInputMessageContext {
100117
env_state: self.env_context.env_state,
@@ -122,6 +139,7 @@ impl UserMessage {
122139
_ => String::new(),
123140
};
124141
UserInputMessage {
142+
images: self.images,
125143
content: format!("{} {}", self.additional_context, formatted_prompt)
126144
.trim()
127145
.to_string(),
@@ -232,6 +250,7 @@ impl From<InvokeOutput> for ToolUseResultBlock {
232250
match value.output {
233251
OutputKind::Text(text) => Self::Text(text),
234252
OutputKind::Json(value) => Self::Json(value),
253+
OutputKind::Images(_) => Self::Text("See images data supplied".to_string()),
235254
}
236255
}
237256
}

crates/chat-cli/src/cli/chat/mod.rs

Lines changed: 29 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -156,6 +156,7 @@ use tool_manager::{
156156
};
157157
use tools::gh_issue::GhIssueContext;
158158
use tools::{
159+
OutputKind,
159160
QueuedTool,
160161
Tool,
161162
ToolPermissions,
@@ -169,6 +170,7 @@ use tracing::{
169170
warn,
170171
};
171172
use unicode_width::UnicodeWidthStr;
173+
use util::images::RichImageBlock;
172174
use util::{
173175
animate_output,
174176
play_notification_bell,
@@ -1283,14 +1285,6 @@ impl ChatContext {
12831285
// Otherwise continue with normal chat on 'n' or other responses
12841286
self.tool_use_status = ToolUseStatus::Idle;
12851287

1286-
if pending_tool_index.is_some() {
1287-
self.conversation_state.abandon_tool_use(tool_uses, user_input);
1288-
} else {
1289-
self.conversation_state.set_next_user_message(user_input).await;
1290-
}
1291-
1292-
let conv_state = self.conversation_state.as_sendable_conversation_state(true).await;
1293-
12941288
if self.interactive {
12951289
queue!(self.output, style::SetForegroundColor(Color::Magenta))?;
12961290
queue!(self.output, style::SetForegroundColor(Color::Reset))?;
@@ -1299,6 +1293,13 @@ impl ChatContext {
12991293
self.spinner = Some(Spinner::new(Spinners::Dots, "Thinking...".to_owned()));
13001294
}
13011295

1296+
if pending_tool_index.is_some() {
1297+
self.conversation_state.abandon_tool_use(tool_uses, user_input);
1298+
} else {
1299+
self.conversation_state.set_next_user_message(user_input).await;
1300+
}
1301+
1302+
let conv_state = self.conversation_state.as_sendable_conversation_state(true).await;
13021303
self.send_tool_use_telemetry().await;
13031304

13041305
ChatState::HandleResponseStream(self.client.send_message(conv_state).await?)
@@ -2673,6 +2674,7 @@ impl ChatContext {
26732674

26742675
// Execute the requested tools.
26752676
let mut tool_results = vec![];
2677+
let mut image_blocks: Vec<RichImageBlock> = Vec::new();
26762678

26772679
for tool in tool_uses {
26782680
let mut tool_telemetry = self.tool_use_telemetry_events.entry(tool.id.clone());
@@ -2700,9 +2702,20 @@ impl ChatContext {
27002702
});
27012703
}
27022704
let tool_time = format!("{}.{}", tool_time.as_secs(), tool_time.subsec_millis());
2703-
27042705
match invoke_result {
27052706
Ok(result) => {
2707+
match result.output {
2708+
OutputKind::Text(ref text) => {
2709+
debug!("Output is Text: {}", text);
2710+
},
2711+
OutputKind::Json(ref json) => {
2712+
debug!("Output is JSON: {}", json);
2713+
},
2714+
OutputKind::Images(ref image) => {
2715+
image_blocks.extend(image.clone());
2716+
},
2717+
}
2718+
27062719
debug!("tool result output: {:#?}", result);
27072720
execute!(
27082721
self.output,
@@ -2762,7 +2775,13 @@ impl ChatContext {
27622775
}
27632776
}
27642777

2765-
self.conversation_state.add_tool_results(tool_results);
2778+
if !image_blocks.is_empty() {
2779+
let images = image_blocks.into_iter().map(|(block, _)| block).collect();
2780+
self.conversation_state
2781+
.add_tool_results_with_images(tool_results, images);
2782+
} else {
2783+
self.conversation_state.add_tool_results(tool_results);
2784+
}
27662785

27672786
self.send_tool_use_telemetry().await;
27682787
return Ok(ChatState::HandleResponseStream(

0 commit comments

Comments
 (0)