@@ -11,6 +11,7 @@ use crate::{
1111 get_env,
1212 operators:: {
1313 clickhouse_operator:: EventQueue ,
14+ file_operator:: put_file_in_s3_get_signed_url,
1415 message_operator:: {
1516 create_topic_message_query, delete_message_query, get_llm_api_key,
1617 get_message_by_id_query, get_message_by_sort_for_topic_query,
@@ -1206,18 +1207,65 @@ impl From<InputImageSize> for ImageSize {
12061207 }
12071208}
12081209
1210+ #[ derive( Debug , Serialize , Deserialize , ToSchema , Clone ) ]
1211+ pub enum ImageSourceType {
1212+ #[ serde( rename = "base64" ) ]
1213+ #[ schema( title = "Base64" ) ]
1214+ /// Base64 encoded image data
1215+ Base64 ( String ) ,
1216+ #[ serde( rename = "url" ) ]
1217+ #[ schema( title = "URL" ) ]
1218+ /// URL of the image
1219+ Url ( String ) ,
1220+ }
1221+
1222+ // impl a to base64 function for the ImageDTO enum
1223+ impl ImageSourceType {
1224+ pub async fn to_image_bytes ( & self ) -> Result < Vec < u8 > , ServiceError > {
1225+ match self {
1226+ ImageSourceType :: Base64 ( base64_string) => {
1227+ let decoded_bytes = base64:: prelude:: BASE64_STANDARD
1228+ . decode ( base64_string)
1229+ . map_err ( |_| ServiceError :: BadRequest ( "Invalid base64 string" . to_string ( ) ) ) ?;
1230+ Ok ( decoded_bytes)
1231+ }
1232+ ImageSourceType :: Url ( url) => {
1233+ let response = reqwest:: get ( url)
1234+ . await
1235+ . map_err ( |_| ServiceError :: BadRequest ( "Invalid URL" . to_string ( ) ) ) ?;
1236+ let bytes = response. bytes ( ) . await . map_err ( |_| {
1237+ ServiceError :: BadRequest ( "Failed to read bytes from URL" . to_string ( ) )
1238+ } ) ?;
1239+ Ok ( bytes. to_vec ( ) )
1240+ }
1241+ }
1242+ }
1243+ }
1244+
12091245#[ derive( Debug , Serialize , Deserialize , ToSchema ) ]
12101246pub struct ImageUpload {
12111247 /// The image base64 encoded
1212- pub base64_image : String ,
1248+ /// The image data - either a base64-encoded string or a URL
1249+ pub image_src : ImageSourceType ,
12131250 /// The file name of the image
12141251 pub file_name : String ,
12151252}
12161253
1254+ impl ImageUpload {
1255+ pub async fn to_file_upload_bytes ( & self ) -> Result < FileUploadBytes , ServiceError > {
1256+ let image_bytes = self . image_src . to_image_bytes ( ) . await ?;
1257+ let file_upload_bytes = FileUploadBytes {
1258+ bytes : image_bytes. into ( ) ,
1259+ filename : self . file_name . clone ( ) ,
1260+ } ;
1261+ Ok ( file_upload_bytes)
1262+ }
1263+ }
1264+
12171265#[ derive( Debug , Serialize , Deserialize , ToSchema ) ]
12181266pub struct ImageEditResponse {
12191267 /// The URL of the generated image
1220- pub images : Vec < ImageResponseData > ,
1268+ pub image_urls : Vec < String > ,
12211269}
12221270
12231271#[ derive( Debug , Serialize , Deserialize , ToSchema ) ]
@@ -1271,22 +1319,15 @@ pub async fn edit_image(
12711319 organization : None ,
12721320 } ;
12731321
1322+ let mut file_upload_bytes_futures = Vec :: new ( ) ;
1323+ for input_image in & data. input_images {
1324+ let fut = input_image. to_file_upload_bytes ( ) ;
1325+ file_upload_bytes_futures. push ( fut) ;
1326+ }
1327+ let file_upload_bytes = futures:: future:: try_join_all ( file_upload_bytes_futures) . await ?;
1328+
12741329 let parameters = EditImageParametersBuilder :: default ( )
1275- . image ( FileUpload :: BytesArray (
1276- data. input_images
1277- . iter ( )
1278- . map ( |image| {
1279- let image_bytes = base64:: prelude:: BASE64_STANDARD
1280- . decode ( image. base64_image . clone ( ) )
1281- . unwrap ( ) ;
1282-
1283- FileUploadBytes {
1284- bytes : bytes:: Bytes :: from ( image_bytes) ,
1285- filename : image. file_name . clone ( ) ,
1286- }
1287- } )
1288- . collect ( ) ,
1289- ) )
1330+ . image ( FileUpload :: BytesArray ( file_upload_bytes) )
12901331 . model ( "gpt-image-1" )
12911332 . quality :: < ImageQuality > ( data. quality . unwrap_or_default ( ) . into ( ) )
12921333 . prompt ( data. prompt . clone ( ) )
@@ -1302,17 +1343,33 @@ pub async fn edit_image(
13021343 . await
13031344 . map_err ( |e| ServiceError :: BadRequest ( e. to_string ( ) ) ) ?;
13041345
1305- let images = result. data . iter ( ) . filter_map ( |image| {
1306- if let ImageData :: B64Json { b64_json, .. } = & image {
1307- Some ( ImageResponseData {
1308- b64_json : b64_json. clone ( ) ,
1309- } )
1310- } else {
1311- None
1312- }
1346+ let images: Vec < ImageResponseData > = result
1347+ . data
1348+ . iter ( )
1349+ . filter_map ( |image| {
1350+ if let ImageData :: B64Json { b64_json, .. } = & image {
1351+ Some ( ImageResponseData {
1352+ b64_json : b64_json. clone ( ) ,
1353+ } )
1354+ } else {
1355+ None
1356+ }
1357+ } )
1358+ . collect ( ) ;
1359+
1360+ let image_signed_url_futures = images. iter ( ) . map ( |image| {
1361+ let file_id = uuid:: Uuid :: new_v4 ( ) ;
1362+ let b64_json = image. b64_json . clone ( ) ;
1363+ let decoded_bytes = base64:: prelude:: BASE64_STANDARD
1364+ . decode ( b64_json)
1365+ . unwrap_or_default ( ) ;
1366+ put_file_in_s3_get_signed_url ( file_id, decoded_bytes)
13131367 } ) ;
13141368
1315- Ok ( HttpResponse :: Ok ( ) . json ( ImageEditResponse {
1316- images : images. collect ( ) ,
1317- } ) )
1369+ let image_urls = futures:: future:: join_all ( image_signed_url_futures)
1370+ . await
1371+ . into_iter ( )
1372+ . collect :: < Result < Vec < _ > , _ > > ( ) ?;
1373+
1374+ Ok ( HttpResponse :: Ok ( ) . json ( ImageEditResponse { image_urls } ) )
13181375}
0 commit comments