@@ -5,7 +5,9 @@ use std::sync::{Arc, Mutex};
55
66use anyhow:: Result ;
77use image:: imageops:: { self , FilterType } ;
8- use image:: { DynamicImage , GenericImageView , GrayImage , ImageBuffer , Luma , Rgb , Rgb32FImage } ;
8+ use image:: {
9+ DynamicImage , GenericImageView , GrayImage , ImageBuffer , Luma , Rgb , Rgb32FImage , Rgba , RgbaImage ,
10+ } ;
911use ndarray:: { Array , Array4 , IxDyn } ;
1012use ort:: session:: Session ;
1113use ort:: value:: Tensor ;
@@ -46,6 +48,11 @@ const DENOISE_URL: &str = "https://huggingface.co/CyberTimon/RapidRAW-Models/res
4648const DENOISE_FILENAME : & str = "nind_denoise_utnet_684.onnx" ;
4749const DENOISE_SHA256 : & str = "ee3586279d514df557ff3f7dec6df37fafc51ba5d3a3435b2cc9ac2d9017e7fe" ;
4850
// Download source, on-disk filename and SHA-256 digest for the LaMa inpainting
// model. All three are handed to `download_and_verify_model` by
// `get_or_init_lama_model`; the filename is also joined onto the models
// directory to locate the file when the ONNX session is created.
// NOTE(review): the digest is presumably compared against the downloaded file
// by `download_and_verify_model` -- confirm there, it is not visible here.
51+ const LAMA_URL : & str =
52+ "https://huggingface.co/CyberTimon/RapidRAW-Models/resolve/main/lama_fp16.onnx?download=true" ;
53+ const LAMA_FILENAME : & str = "lama_fp16.onnx" ;
54+ const LAMA_SHA256 : & str = "2d6be6277c400d6f1b91819737f7c3da935e5c63d1b521d393be1196a2bfa82c" ;
55+
4956pub struct AiModels {
5057 pub sam_encoder : Mutex < Session > ,
5158 pub sam_decoder : Mutex < Session > ,
@@ -69,6 +76,7 @@ pub struct AiState {
6976 pub models : Option < Arc < AiModels > > ,
7077 pub denoise_model : Option < Arc < Mutex < Session > > > ,
7178 pub clip_models : Option < Arc < ClipModels > > ,
79+ pub lama_model : Option < Arc < Mutex < Session > > > ,
7280 pub embeddings : Option < ImageEmbeddings > ,
7381}
7482
@@ -203,18 +211,18 @@ pub async fn get_or_init_ai_models(
203211 ai_state_mutex : & Mutex < Option < AiState > > ,
204212 ai_init_lock : & TokioMutex < ( ) > ,
205213) -> Result < Arc < AiModels > > {
206- if let Some ( ai_state) = ai_state_mutex. lock ( ) . unwrap ( ) . as_ref ( )
207- && let Some ( models) = & ai_state. models
208- {
209- return Ok ( models . clone ( ) ) ;
214+ if let Some ( ai_state) = ai_state_mutex. lock ( ) . unwrap ( ) . as_ref ( ) {
215+ if let Some ( models) = & ai_state. models {
216+ return Ok ( models . clone ( ) ) ;
217+ }
210218 }
211219
212220 let _guard = ai_init_lock. lock ( ) . await ;
213221
214- if let Some ( ai_state) = ai_state_mutex. lock ( ) . unwrap ( ) . as_ref ( )
215- && let Some ( models) = & ai_state. models
216- {
217- return Ok ( models . clone ( ) ) ;
222+ if let Some ( ai_state) = ai_state_mutex. lock ( ) . unwrap ( ) . as_ref ( ) {
223+ if let Some ( models) = & ai_state. models {
224+ return Ok ( models . clone ( ) ) ;
225+ }
218226 }
219227
220228 let models_dir = get_models_dir ( app_handle) ?;
@@ -285,6 +293,7 @@ pub async fn get_or_init_ai_models(
285293 models : Some ( models. clone ( ) ) ,
286294 denoise_model : None ,
287295 clip_models : None ,
296+ lama_model : None ,
288297 embeddings : None ,
289298 } ) ;
290299 }
@@ -297,18 +306,18 @@ pub async fn get_or_init_denoise_model(
297306 ai_state_mutex : & Mutex < Option < AiState > > ,
298307 ai_init_lock : & TokioMutex < ( ) > ,
299308) -> Result < Arc < Mutex < Session > > > {
300- if let Some ( ai_state) = ai_state_mutex. lock ( ) . unwrap ( ) . as_ref ( )
301- && let Some ( denoise_model) = & ai_state. denoise_model
302- {
303- return Ok ( denoise_model . clone ( ) ) ;
309+ if let Some ( ai_state) = ai_state_mutex. lock ( ) . unwrap ( ) . as_ref ( ) {
310+ if let Some ( denoise_model) = & ai_state. denoise_model {
311+ return Ok ( denoise_model . clone ( ) ) ;
312+ }
304313 }
305314
306315 let _guard = ai_init_lock. lock ( ) . await ;
307316
308- if let Some ( ai_state) = ai_state_mutex. lock ( ) . unwrap ( ) . as_ref ( )
309- && let Some ( denoise_model) = & ai_state. denoise_model
310- {
311- return Ok ( denoise_model . clone ( ) ) ;
317+ if let Some ( ai_state) = ai_state_mutex. lock ( ) . unwrap ( ) . as_ref ( ) {
318+ if let Some ( denoise_model) = & ai_state. denoise_model {
319+ return Ok ( denoise_model . clone ( ) ) ;
320+ }
312321 }
313322
314323 let models_dir = get_models_dir ( app_handle) ?;
@@ -318,11 +327,11 @@ pub async fn get_or_init_denoise_model(
318327 DENOISE_FILENAME ,
319328 DENOISE_URL ,
320329 DENOISE_SHA256 ,
321- "AI Denoise Model" ,
330+ "NIND Denoise Model" ,
322331 )
323332 . await ?;
324333
325- let _ = ort:: init ( ) . with_name ( "RapidRAW -Denoise" ) . commit ( ) ;
334+ let _ = ort:: init ( ) . with_name ( "AI -Denoise" ) . commit ( ) ;
326335 let model_path = models_dir. join ( DENOISE_FILENAME ) ;
327336 let session = Session :: builder ( ) ?. commit_from_file ( model_path) ?;
328337 let denoise_model = Arc :: new ( Mutex :: new ( session) ) ;
@@ -337,6 +346,7 @@ pub async fn get_or_init_denoise_model(
337346 models : None ,
338347 denoise_model : Some ( denoise_model. clone ( ) ) ,
339348 clip_models : None ,
349+ lama_model : None ,
340350 embeddings : None ,
341351 } ) ;
342352 }
@@ -349,18 +359,18 @@ pub async fn get_or_init_clip_models(
349359 ai_state_mutex : & Mutex < Option < AiState > > ,
350360 ai_init_lock : & TokioMutex < ( ) > ,
351361) -> Result < Arc < ClipModels > > {
352- if let Some ( ai_state) = ai_state_mutex. lock ( ) . unwrap ( ) . as_ref ( )
353- && let Some ( clip_models) = & ai_state. clip_models
354- {
355- return Ok ( clip_models . clone ( ) ) ;
362+ if let Some ( ai_state) = ai_state_mutex. lock ( ) . unwrap ( ) . as_ref ( ) {
363+ if let Some ( clip_models) = & ai_state. clip_models {
364+ return Ok ( clip_models . clone ( ) ) ;
365+ }
356366 }
357367
358368 let _guard = ai_init_lock. lock ( ) . await ;
359369
360- if let Some ( ai_state) = ai_state_mutex. lock ( ) . unwrap ( ) . as_ref ( )
361- && let Some ( clip_models) = & ai_state. clip_models
362- {
363- return Ok ( clip_models . clone ( ) ) ;
370+ if let Some ( ai_state) = ai_state_mutex. lock ( ) . unwrap ( ) . as_ref ( ) {
371+ if let Some ( clip_models) = & ai_state. clip_models {
372+ return Ok ( clip_models . clone ( ) ) ;
373+ }
364374 }
365375
366376 let models_dir = get_models_dir ( app_handle) ?;
@@ -400,13 +410,67 @@ pub async fn get_or_init_clip_models(
400410 models : None ,
401411 denoise_model : None ,
402412 clip_models : Some ( clip_models. clone ( ) ) ,
413+ lama_model : None ,
403414 embeddings : None ,
404415 } ) ;
405416 }
406417
407418 Ok ( clip_models)
408419}
409420
421+ pub async fn get_or_init_lama_model (
422+ app_handle : & tauri:: AppHandle ,
423+ ai_state_mutex : & Mutex < Option < AiState > > ,
424+ ai_init_lock : & TokioMutex < ( ) > ,
425+ ) -> Result < Arc < Mutex < Session > > > {
426+ if let Some ( ai_state) = ai_state_mutex. lock ( ) . unwrap ( ) . as_ref ( ) {
427+ if let Some ( lama_model) = & ai_state. lama_model {
428+ return Ok ( lama_model. clone ( ) ) ;
429+ }
430+ }
431+
432+ let _guard = ai_init_lock. lock ( ) . await ;
433+
434+ if let Some ( ai_state) = ai_state_mutex. lock ( ) . unwrap ( ) . as_ref ( ) {
435+ if let Some ( lama_model) = & ai_state. lama_model {
436+ return Ok ( lama_model. clone ( ) ) ;
437+ }
438+ }
439+
440+ let models_dir = get_models_dir ( app_handle) ?;
441+ download_and_verify_model (
442+ app_handle,
443+ & models_dir,
444+ LAMA_FILENAME ,
445+ LAMA_URL ,
446+ LAMA_SHA256 ,
447+ "Inpainting Model" ,
448+ )
449+ . await ?;
450+
451+ let _ = ort:: init ( ) . with_name ( "AI-Inpainting" ) . commit ( ) ;
452+ let model_path = models_dir. join ( LAMA_FILENAME ) ;
453+ let session = Session :: builder ( ) ?. commit_from_file ( model_path) ?;
454+ let lama_model = Arc :: new ( Mutex :: new ( session) ) ;
455+
456+ crate :: register_exit_handler ( ) ;
457+
458+ let mut ai_state_lock = ai_state_mutex. lock ( ) . unwrap ( ) ;
459+ if let Some ( state) = ai_state_lock. as_mut ( ) {
460+ state. lama_model = Some ( lama_model. clone ( ) ) ;
461+ } else {
462+ * ai_state_lock = Some ( AiState {
463+ models : None ,
464+ denoise_model : None ,
465+ clip_models : None ,
466+ lama_model : Some ( lama_model. clone ( ) ) ,
467+ embeddings : None ,
468+ } ) ;
469+ }
470+
471+ Ok ( lama_model)
472+ }
473+
410474#[ derive( Clone , Copy ) ]
411475struct TileParams {
412476 cs : usize ,
@@ -653,6 +717,143 @@ pub fn run_ai_denoise(
653717 Ok ( DynamicImage :: ImageRgb32F ( out_img_buffer) )
654718}
655719
720+ pub fn run_lama_inpainting (
721+ image : & DynamicImage ,
722+ mask : & GrayImage ,
723+ lama_session : & Mutex < Session > ,
724+ ) -> Result < RgbaImage > {
725+ let ( w, h) = image. dimensions ( ) ;
726+
727+ let ( mut min_x, mut min_y) = ( w, h) ;
728+ let ( mut max_x, mut max_y) = ( 0u32 , 0u32 ) ;
729+ let mut has_mask = false ;
730+
731+ for ( x, y, p) in mask. enumerate_pixels ( ) {
732+ if p[ 0 ] > 0 {
733+ min_x = min_x. min ( x) ;
734+ min_y = min_y. min ( y) ;
735+ max_x = max_x. max ( x) ;
736+ max_y = max_y. max ( y) ;
737+ has_mask = true ;
738+ }
739+ }
740+
741+ if !has_mask {
742+ return Ok ( image. to_rgba8 ( ) ) ;
743+ }
744+
745+ let mask_w = max_x - min_x + 1 ;
746+ let mask_h = max_y - min_y + 1 ;
747+
748+ let pad_x = 64 . max ( ( mask_w as f32 * 0.5 ) as u32 ) ;
749+ let pad_y = 64 . max ( ( mask_h as f32 * 0.5 ) as u32 ) ;
750+
751+ let x0 = min_x. saturating_sub ( pad_x) ;
752+ let y0 = min_y. saturating_sub ( pad_y) ;
753+ let x1 = ( max_x + pad_x) . min ( w. saturating_sub ( 1 ) ) ;
754+ let y1 = ( max_y + pad_y) . min ( h. saturating_sub ( 1 ) ) ;
755+
756+ let crop_w = x1 - x0 + 1 ;
757+ let crop_h = y1 - y0 + 1 ;
758+
759+ let rgba = image. to_rgba8 ( ) ;
760+
761+ let cropped_img = imageops:: crop_imm ( & rgba, x0, y0, crop_w, crop_h) . to_image ( ) ;
762+ let cropped_mask = imageops:: crop_imm ( mask, x0, y0, crop_w, crop_h) . to_image ( ) ;
763+
764+ let max_dim_limit: u32 = 1024 ;
765+ let needs_downscale = crop_w > max_dim_limit || crop_h > max_dim_limit;
766+
767+ let ( fw, fh, inf_img, inf_mask) = if needs_downscale {
768+ let scale = max_dim_limit as f32 / crop_w. max ( crop_h) as f32 ;
769+
770+ let scaled_w = ( crop_w as f32 * scale) . round ( ) . max ( 1.0 ) as u32 ;
771+ let scaled_h = ( crop_h as f32 * scale) . round ( ) . max ( 1.0 ) as u32 ;
772+
773+ (
774+ scaled_w,
775+ scaled_h,
776+ imageops:: resize ( & cropped_img, scaled_w, scaled_h, FilterType :: Lanczos3 ) ,
777+ imageops:: resize ( & cropped_mask, scaled_w, scaled_h, FilterType :: Triangle ) ,
778+ )
779+ } else {
780+ ( crop_w, crop_h, cropped_img. clone ( ) , cropped_mask. clone ( ) )
781+ } ;
782+
783+ let align = 64u32 ;
784+ let mut tensor_dim = fw. max ( fh) ;
785+ if tensor_dim % align != 0 {
786+ tensor_dim += align - ( tensor_dim % align) ;
787+ }
788+ let tensor_dim = tensor_dim. max ( align) as usize ;
789+
790+ let mut img_tensor = Array :: < f32 , _ > :: zeros ( ( 1 , 3 , tensor_dim, tensor_dim) ) ;
791+ let mut msk_tensor = Array :: < f32 , _ > :: zeros ( ( 1 , 1 , tensor_dim, tensor_dim) ) ;
792+
793+ for y in 0 ..tensor_dim {
794+ for x in 0 ..tensor_dim {
795+ let sx = ( x as u32 ) . min ( fw. saturating_sub ( 1 ) ) ;
796+ let sy = ( y as u32 ) . min ( fh. saturating_sub ( 1 ) ) ;
797+
798+ let p = inf_img. get_pixel ( sx, sy) ;
799+ let m = inf_mask. get_pixel ( sx, sy) [ 0 ] ;
800+
801+ img_tensor[ [ 0 , 0 , y, x] ] = p[ 0 ] as f32 / 255.0 ;
802+ img_tensor[ [ 0 , 1 , y, x] ] = p[ 1 ] as f32 / 255.0 ;
803+ img_tensor[ [ 0 , 2 , y, x] ] = p[ 2 ] as f32 / 255.0 ;
804+ msk_tensor[ [ 0 , 0 , y, x] ] = if m > 0 { 1.0 } else { 0.0 } ;
805+ }
806+ }
807+
808+ let t_img = Tensor :: from_array ( img_tensor. into_dyn ( ) . as_standard_layout ( ) . into_owned ( ) ) ?;
809+ let t_msk = Tensor :: from_array ( msk_tensor. into_dyn ( ) . as_standard_layout ( ) . into_owned ( ) ) ?;
810+
811+ let output_tensor = {
812+ let mut session = lama_session. lock ( ) . unwrap ( ) ;
813+ let outputs = session. run ( ort:: inputs![ "image" => t_img, "mask" => t_msk] ) ?;
814+ outputs[ 0 ] . try_extract_array :: < f32 > ( ) ?. to_owned ( )
815+ } ;
816+
817+ let mut result_inf = RgbaImage :: new ( fw, fh) ;
818+ for y in 0 ..fh {
819+ for x in 0 ..fw {
820+ let r = output_tensor[ [ 0 , 0 , y as usize , x as usize ] ] . clamp ( 0.0 , 255.0 ) as u8 ;
821+ let g = output_tensor[ [ 0 , 1 , y as usize , x as usize ] ] . clamp ( 0.0 , 255.0 ) as u8 ;
822+ let b = output_tensor[ [ 0 , 2 , y as usize , x as usize ] ] . clamp ( 0.0 , 255.0 ) as u8 ;
823+ result_inf. put_pixel ( x, y, Rgba ( [ r, g, b, 255 ] ) ) ;
824+ }
825+ }
826+
827+ let result_crop = if needs_downscale {
828+ imageops:: resize ( & result_inf, crop_w, crop_h, FilterType :: Lanczos3 )
829+ } else {
830+ result_inf
831+ } ;
832+
833+ let mut final_image = image. to_rgba8 ( ) ;
834+
835+ for y in 0 ..crop_h {
836+ for x in 0 ..crop_w {
837+ let m = cropped_mask. get_pixel ( x, y) [ 0 ] ;
838+ if m > 0 {
839+ let alpha = m as f32 / 255.0 ;
840+ let p = result_crop. get_pixel ( x, y) ;
841+ let gx = x0 + x;
842+ let gy = y0 + y;
843+ let orig = final_image. get_pixel ( gx, gy) ;
844+
845+ let r = ( p[ 0 ] as f32 * alpha + orig[ 0 ] as f32 * ( 1.0 - alpha) ) as u8 ;
846+ let g = ( p[ 1 ] as f32 * alpha + orig[ 1 ] as f32 * ( 1.0 - alpha) ) as u8 ;
847+ let b = ( p[ 2 ] as f32 * alpha + orig[ 2 ] as f32 * ( 1.0 - alpha) ) as u8 ;
848+
849+ final_image. put_pixel ( gx, gy, Rgba ( [ r, g, b, 255 ] ) ) ;
850+ }
851+ }
852+ }
853+
854+ Ok ( final_image)
855+ }
856+
656857pub fn generate_image_embeddings (
657858 image : & DynamicImage ,
658859 encoder : & Mutex < Session > ,
0 commit comments