Skip to content

Commit 7cc8a77

Browse files
skeptrunedevcdxker
authored andcommitted
feat(server): add support for image source types and S3 signed URL generation
1 parent dfc17f7 commit 7cc8a77

2 files changed

Lines changed: 115 additions & 28 deletions

File tree

server/src/handlers/message_handler.rs

Lines changed: 85 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use crate::{
1111
get_env,
1212
operators::{
1313
clickhouse_operator::EventQueue,
14+
file_operator::put_file_in_s3_get_signed_url,
1415
message_operator::{
1516
create_topic_message_query, delete_message_query, get_llm_api_key,
1617
get_message_by_id_query, get_message_by_sort_for_topic_query,
@@ -1206,18 +1207,65 @@ impl From<InputImageSize> for ImageSize {
12061207
}
12071208
}
12081209

1210+
#[derive(Debug, Serialize, Deserialize, ToSchema, Clone)]
1211+
pub enum ImageSourceType {
1212+
#[serde(rename = "base64")]
1213+
#[schema(title = "Base64")]
1214+
/// Base64 encoded image data
1215+
Base64(String),
1216+
#[serde(rename = "url")]
1217+
#[schema(title = "URL")]
1218+
/// URL of the image
1219+
Url(String),
1220+
}
1221+
1222+
// impl a to base64 function for the ImageDTO enum
1223+
impl ImageSourceType {
1224+
pub async fn to_image_bytes(&self) -> Result<Vec<u8>, ServiceError> {
1225+
match self {
1226+
ImageSourceType::Base64(base64_string) => {
1227+
let decoded_bytes = base64::prelude::BASE64_STANDARD
1228+
.decode(base64_string)
1229+
.map_err(|_| ServiceError::BadRequest("Invalid base64 string".to_string()))?;
1230+
Ok(decoded_bytes)
1231+
}
1232+
ImageSourceType::Url(url) => {
1233+
let response = reqwest::get(url)
1234+
.await
1235+
.map_err(|_| ServiceError::BadRequest("Invalid URL".to_string()))?;
1236+
let bytes = response.bytes().await.map_err(|_| {
1237+
ServiceError::BadRequest("Failed to read bytes from URL".to_string())
1238+
})?;
1239+
Ok(bytes.to_vec())
1240+
}
1241+
}
1242+
}
1243+
}
1244+
12091245
#[derive(Debug, Serialize, Deserialize, ToSchema)]
12101246
pub struct ImageUpload {
12111247
/// The image base64 encoded
1212-
pub base64_image: String,
1248+
/// The image data - either a base64-encoded string or a URL
1249+
pub image_src: ImageSourceType,
12131250
/// The file name of the image
12141251
pub file_name: String,
12151252
}
12161253

1254+
impl ImageUpload {
1255+
pub async fn to_file_upload_bytes(&self) -> Result<FileUploadBytes, ServiceError> {
1256+
let image_bytes = self.image_src.to_image_bytes().await?;
1257+
let file_upload_bytes = FileUploadBytes {
1258+
bytes: image_bytes.into(),
1259+
filename: self.file_name.clone(),
1260+
};
1261+
Ok(file_upload_bytes)
1262+
}
1263+
}
1264+
12171265
#[derive(Debug, Serialize, Deserialize, ToSchema)]
12181266
pub struct ImageEditResponse {
12191267
/// The URL of the generated image
1220-
pub images: Vec<ImageResponseData>,
1268+
pub image_urls: Vec<String>,
12211269
}
12221270

12231271
#[derive(Debug, Serialize, Deserialize, ToSchema)]
@@ -1271,22 +1319,15 @@ pub async fn edit_image(
12711319
organization: None,
12721320
};
12731321

1322+
let mut file_upload_bytes_futures = Vec::new();
1323+
for input_image in &data.input_images {
1324+
let fut = input_image.to_file_upload_bytes();
1325+
file_upload_bytes_futures.push(fut);
1326+
}
1327+
let file_upload_bytes = futures::future::try_join_all(file_upload_bytes_futures).await?;
1328+
12741329
let parameters = EditImageParametersBuilder::default()
1275-
.image(FileUpload::BytesArray(
1276-
data.input_images
1277-
.iter()
1278-
.map(|image| {
1279-
let image_bytes = base64::prelude::BASE64_STANDARD
1280-
.decode(image.base64_image.clone())
1281-
.unwrap();
1282-
1283-
FileUploadBytes {
1284-
bytes: bytes::Bytes::from(image_bytes),
1285-
filename: image.file_name.clone(),
1286-
}
1287-
})
1288-
.collect(),
1289-
))
1330+
.image(FileUpload::BytesArray(file_upload_bytes))
12901331
.model("gpt-image-1")
12911332
.quality::<ImageQuality>(data.quality.unwrap_or_default().into())
12921333
.prompt(data.prompt.clone())
@@ -1302,17 +1343,33 @@ pub async fn edit_image(
13021343
.await
13031344
.map_err(|e| ServiceError::BadRequest(e.to_string()))?;
13041345

1305-
let images = result.data.iter().filter_map(|image| {
1306-
if let ImageData::B64Json { b64_json, .. } = &image {
1307-
Some(ImageResponseData {
1308-
b64_json: b64_json.clone(),
1309-
})
1310-
} else {
1311-
None
1312-
}
1346+
let images: Vec<ImageResponseData> = result
1347+
.data
1348+
.iter()
1349+
.filter_map(|image| {
1350+
if let ImageData::B64Json { b64_json, .. } = &image {
1351+
Some(ImageResponseData {
1352+
b64_json: b64_json.clone(),
1353+
})
1354+
} else {
1355+
None
1356+
}
1357+
})
1358+
.collect();
1359+
1360+
let image_signed_url_futures = images.iter().map(|image| {
1361+
let file_id = uuid::Uuid::new_v4();
1362+
let b64_json = image.b64_json.clone();
1363+
let decoded_bytes = base64::prelude::BASE64_STANDARD
1364+
.decode(b64_json)
1365+
.unwrap_or_default();
1366+
put_file_in_s3_get_signed_url(file_id, decoded_bytes)
13131367
});
13141368

1315-
Ok(HttpResponse::Ok().json(ImageEditResponse {
1316-
images: images.collect(),
1317-
}))
1369+
let image_urls = futures::future::join_all(image_signed_url_futures)
1370+
.await
1371+
.into_iter()
1372+
.collect::<Result<Vec<_>, _>>()?;
1373+
1374+
Ok(HttpResponse::Ok().json(ImageEditResponse { image_urls }))
13181375
}

server/src/operators/file_operator.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -642,3 +642,33 @@ pub async fn delete_file_query(
642642

643643
Ok(())
644644
}
645+
646+
pub async fn put_file_in_s3_get_signed_url(
647+
file_id: uuid::Uuid,
648+
file_data: Vec<u8>,
649+
) -> Result<String, ServiceError> {
650+
let bucket = get_aws_bucket().unwrap();
651+
652+
bucket
653+
.put_object(file_id.to_string(), &file_data)
654+
.await
655+
.map_err(|e| {
656+
log::error!(
657+
"Could not upload file to S3 before getting signed URL {:?}",
658+
e
659+
);
660+
ServiceError::BadRequest(
661+
"Could not upload file to S3 before getting signed URL".to_string(),
662+
)
663+
})?;
664+
665+
let signed_url = bucket
666+
.presign_get(file_id.to_string(), 86400, None)
667+
.await
668+
.map_err(|e| {
669+
log::error!("Could not get presigned url after putting object {:?}", e);
670+
ServiceError::BadRequest("Could not get presigned url after putting object".to_string())
671+
})?;
672+
673+
Ok(signed_url)
674+
}

0 commit comments

Comments
 (0)