Skip to content

Commit 0938e98

Browse files
Merge branch 'CyberTimon:main' into feat-brush-resizing-shortcuts
2 parents 1617b70 + a683964 commit 0938e98

10 files changed

Lines changed: 425 additions & 731 deletions

File tree

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -551,6 +551,7 @@ A huge thank you to the following projects and tools that were very important in
551551
- **[Google AI Studio](https://aistudio.google.com):** For providing amazing assistance in researching and implementing image processing algorithms, and for giving an overall speed boost.
552552
- **[rawler](https://github.com/dnglab/dnglab/tree/main/rawler):** For the excellent Rust crate that provides the foundation for RAW file processing in this project.
553553
- **[lensfun](https://lensfun.github.io/):** For its invaluable open-source library and comprehensive database for automatic lens correction.
554+
- **[LaMa](https://github.com/advimman/lama):** For the powerful & simple image inpainting model, which enables content-aware fill and object removal.
554555
- **[NegPy](https://github.com/marcinz606/NegPy):** For the inspiration behind the negative conversion logic, particularly the mathematical approach to film inversion using characteristic curves.
555556
- **[pixls.us](https://discuss.pixls.us/):** For being an incredible community full of knowledgeable people who offered inspiration, advice, and ideas.
556557
- **[darktable & co.](https://github.com/darktable-org/darktable):** For some reference implementations that guided parts of this work.

src-tauri/src/ai_processing.rs

Lines changed: 228 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ use std::sync::{Arc, Mutex};
55

66
use anyhow::Result;
77
use image::imageops::{self, FilterType};
8-
use image::{DynamicImage, GenericImageView, GrayImage, ImageBuffer, Luma, Rgb, Rgb32FImage};
8+
use image::{
9+
DynamicImage, GenericImageView, GrayImage, ImageBuffer, Luma, Rgb, Rgb32FImage, Rgba, RgbaImage,
10+
};
911
use ndarray::{Array, Array4, IxDyn};
1012
use ort::session::Session;
1113
use ort::value::Tensor;
@@ -46,6 +48,11 @@ const DENOISE_URL: &str = "https://huggingface.co/CyberTimon/RapidRAW-Models/res
4648
const DENOISE_FILENAME: &str = "nind_denoise_utnet_684.onnx";
4749
const DENOISE_SHA256: &str = "ee3586279d514df557ff3f7dec6df37fafc51ba5d3a3435b2cc9ac2d9017e7fe";
4850

51+
const LAMA_URL: &str =
52+
"https://huggingface.co/CyberTimon/RapidRAW-Models/resolve/main/lama_fp16.onnx?download=true";
53+
const LAMA_FILENAME: &str = "lama_fp16.onnx";
54+
const LAMA_SHA256: &str = "2d6be6277c400d6f1b91819737f7c3da935e5c63d1b521d393be1196a2bfa82c";
55+
4956
pub struct AiModels {
5057
pub sam_encoder: Mutex<Session>,
5158
pub sam_decoder: Mutex<Session>,
@@ -69,6 +76,7 @@ pub struct AiState {
6976
pub models: Option<Arc<AiModels>>,
7077
pub denoise_model: Option<Arc<Mutex<Session>>>,
7178
pub clip_models: Option<Arc<ClipModels>>,
79+
pub lama_model: Option<Arc<Mutex<Session>>>,
7280
pub embeddings: Option<ImageEmbeddings>,
7381
}
7482

@@ -203,18 +211,18 @@ pub async fn get_or_init_ai_models(
203211
ai_state_mutex: &Mutex<Option<AiState>>,
204212
ai_init_lock: &TokioMutex<()>,
205213
) -> Result<Arc<AiModels>> {
206-
if let Some(ai_state) = ai_state_mutex.lock().unwrap().as_ref()
207-
&& let Some(models) = &ai_state.models
208-
{
209-
return Ok(models.clone());
214+
if let Some(ai_state) = ai_state_mutex.lock().unwrap().as_ref() {
215+
if let Some(models) = &ai_state.models {
216+
return Ok(models.clone());
217+
}
210218
}
211219

212220
let _guard = ai_init_lock.lock().await;
213221

214-
if let Some(ai_state) = ai_state_mutex.lock().unwrap().as_ref()
215-
&& let Some(models) = &ai_state.models
216-
{
217-
return Ok(models.clone());
222+
if let Some(ai_state) = ai_state_mutex.lock().unwrap().as_ref() {
223+
if let Some(models) = &ai_state.models {
224+
return Ok(models.clone());
225+
}
218226
}
219227

220228
let models_dir = get_models_dir(app_handle)?;
@@ -285,6 +293,7 @@ pub async fn get_or_init_ai_models(
285293
models: Some(models.clone()),
286294
denoise_model: None,
287295
clip_models: None,
296+
lama_model: None,
288297
embeddings: None,
289298
});
290299
}
@@ -297,18 +306,18 @@ pub async fn get_or_init_denoise_model(
297306
ai_state_mutex: &Mutex<Option<AiState>>,
298307
ai_init_lock: &TokioMutex<()>,
299308
) -> Result<Arc<Mutex<Session>>> {
300-
if let Some(ai_state) = ai_state_mutex.lock().unwrap().as_ref()
301-
&& let Some(denoise_model) = &ai_state.denoise_model
302-
{
303-
return Ok(denoise_model.clone());
309+
if let Some(ai_state) = ai_state_mutex.lock().unwrap().as_ref() {
310+
if let Some(denoise_model) = &ai_state.denoise_model {
311+
return Ok(denoise_model.clone());
312+
}
304313
}
305314

306315
let _guard = ai_init_lock.lock().await;
307316

308-
if let Some(ai_state) = ai_state_mutex.lock().unwrap().as_ref()
309-
&& let Some(denoise_model) = &ai_state.denoise_model
310-
{
311-
return Ok(denoise_model.clone());
317+
if let Some(ai_state) = ai_state_mutex.lock().unwrap().as_ref() {
318+
if let Some(denoise_model) = &ai_state.denoise_model {
319+
return Ok(denoise_model.clone());
320+
}
312321
}
313322

314323
let models_dir = get_models_dir(app_handle)?;
@@ -318,11 +327,11 @@ pub async fn get_or_init_denoise_model(
318327
DENOISE_FILENAME,
319328
DENOISE_URL,
320329
DENOISE_SHA256,
321-
"AI Denoise Model",
330+
"NIND Denoise Model",
322331
)
323332
.await?;
324333

325-
let _ = ort::init().with_name("RapidRAW-Denoise").commit();
334+
let _ = ort::init().with_name("AI-Denoise").commit();
326335
let model_path = models_dir.join(DENOISE_FILENAME);
327336
let session = Session::builder()?.commit_from_file(model_path)?;
328337
let denoise_model = Arc::new(Mutex::new(session));
@@ -337,6 +346,7 @@ pub async fn get_or_init_denoise_model(
337346
models: None,
338347
denoise_model: Some(denoise_model.clone()),
339348
clip_models: None,
349+
lama_model: None,
340350
embeddings: None,
341351
});
342352
}
@@ -349,18 +359,18 @@ pub async fn get_or_init_clip_models(
349359
ai_state_mutex: &Mutex<Option<AiState>>,
350360
ai_init_lock: &TokioMutex<()>,
351361
) -> Result<Arc<ClipModels>> {
352-
if let Some(ai_state) = ai_state_mutex.lock().unwrap().as_ref()
353-
&& let Some(clip_models) = &ai_state.clip_models
354-
{
355-
return Ok(clip_models.clone());
362+
if let Some(ai_state) = ai_state_mutex.lock().unwrap().as_ref() {
363+
if let Some(clip_models) = &ai_state.clip_models {
364+
return Ok(clip_models.clone());
365+
}
356366
}
357367

358368
let _guard = ai_init_lock.lock().await;
359369

360-
if let Some(ai_state) = ai_state_mutex.lock().unwrap().as_ref()
361-
&& let Some(clip_models) = &ai_state.clip_models
362-
{
363-
return Ok(clip_models.clone());
370+
if let Some(ai_state) = ai_state_mutex.lock().unwrap().as_ref() {
371+
if let Some(clip_models) = &ai_state.clip_models {
372+
return Ok(clip_models.clone());
373+
}
364374
}
365375

366376
let models_dir = get_models_dir(app_handle)?;
@@ -400,13 +410,67 @@ pub async fn get_or_init_clip_models(
400410
models: None,
401411
denoise_model: None,
402412
clip_models: Some(clip_models.clone()),
413+
lama_model: None,
403414
embeddings: None,
404415
});
405416
}
406417

407418
Ok(clip_models)
408419
}
409420

421+
pub async fn get_or_init_lama_model(
422+
app_handle: &tauri::AppHandle,
423+
ai_state_mutex: &Mutex<Option<AiState>>,
424+
ai_init_lock: &TokioMutex<()>,
425+
) -> Result<Arc<Mutex<Session>>> {
426+
if let Some(ai_state) = ai_state_mutex.lock().unwrap().as_ref() {
427+
if let Some(lama_model) = &ai_state.lama_model {
428+
return Ok(lama_model.clone());
429+
}
430+
}
431+
432+
let _guard = ai_init_lock.lock().await;
433+
434+
if let Some(ai_state) = ai_state_mutex.lock().unwrap().as_ref() {
435+
if let Some(lama_model) = &ai_state.lama_model {
436+
return Ok(lama_model.clone());
437+
}
438+
}
439+
440+
let models_dir = get_models_dir(app_handle)?;
441+
download_and_verify_model(
442+
app_handle,
443+
&models_dir,
444+
LAMA_FILENAME,
445+
LAMA_URL,
446+
LAMA_SHA256,
447+
"Inpainting Model",
448+
)
449+
.await?;
450+
451+
let _ = ort::init().with_name("AI-Inpainting").commit();
452+
let model_path = models_dir.join(LAMA_FILENAME);
453+
let session = Session::builder()?.commit_from_file(model_path)?;
454+
let lama_model = Arc::new(Mutex::new(session));
455+
456+
crate::register_exit_handler();
457+
458+
let mut ai_state_lock = ai_state_mutex.lock().unwrap();
459+
if let Some(state) = ai_state_lock.as_mut() {
460+
state.lama_model = Some(lama_model.clone());
461+
} else {
462+
*ai_state_lock = Some(AiState {
463+
models: None,
464+
denoise_model: None,
465+
clip_models: None,
466+
lama_model: Some(lama_model.clone()),
467+
embeddings: None,
468+
});
469+
}
470+
471+
Ok(lama_model)
472+
}
473+
410474
#[derive(Clone, Copy)]
411475
struct TileParams {
412476
cs: usize,
@@ -653,6 +717,143 @@ pub fn run_ai_denoise(
653717
Ok(DynamicImage::ImageRgb32F(out_img_buffer))
654718
}
655719

720+
pub fn run_lama_inpainting(
721+
image: &DynamicImage,
722+
mask: &GrayImage,
723+
lama_session: &Mutex<Session>,
724+
) -> Result<RgbaImage> {
725+
let (w, h) = image.dimensions();
726+
727+
let (mut min_x, mut min_y) = (w, h);
728+
let (mut max_x, mut max_y) = (0u32, 0u32);
729+
let mut has_mask = false;
730+
731+
for (x, y, p) in mask.enumerate_pixels() {
732+
if p[0] > 0 {
733+
min_x = min_x.min(x);
734+
min_y = min_y.min(y);
735+
max_x = max_x.max(x);
736+
max_y = max_y.max(y);
737+
has_mask = true;
738+
}
739+
}
740+
741+
if !has_mask {
742+
return Ok(image.to_rgba8());
743+
}
744+
745+
let mask_w = max_x - min_x + 1;
746+
let mask_h = max_y - min_y + 1;
747+
748+
let pad_x = 64.max((mask_w as f32 * 0.5) as u32);
749+
let pad_y = 64.max((mask_h as f32 * 0.5) as u32);
750+
751+
let x0 = min_x.saturating_sub(pad_x);
752+
let y0 = min_y.saturating_sub(pad_y);
753+
let x1 = (max_x + pad_x).min(w.saturating_sub(1));
754+
let y1 = (max_y + pad_y).min(h.saturating_sub(1));
755+
756+
let crop_w = x1 - x0 + 1;
757+
let crop_h = y1 - y0 + 1;
758+
759+
let rgba = image.to_rgba8();
760+
761+
let cropped_img = imageops::crop_imm(&rgba, x0, y0, crop_w, crop_h).to_image();
762+
let cropped_mask = imageops::crop_imm(mask, x0, y0, crop_w, crop_h).to_image();
763+
764+
let max_dim_limit: u32 = 1024;
765+
let needs_downscale = crop_w > max_dim_limit || crop_h > max_dim_limit;
766+
767+
let (fw, fh, inf_img, inf_mask) = if needs_downscale {
768+
let scale = max_dim_limit as f32 / crop_w.max(crop_h) as f32;
769+
770+
let scaled_w = (crop_w as f32 * scale).round().max(1.0) as u32;
771+
let scaled_h = (crop_h as f32 * scale).round().max(1.0) as u32;
772+
773+
(
774+
scaled_w,
775+
scaled_h,
776+
imageops::resize(&cropped_img, scaled_w, scaled_h, FilterType::Lanczos3),
777+
imageops::resize(&cropped_mask, scaled_w, scaled_h, FilterType::Triangle),
778+
)
779+
} else {
780+
(crop_w, crop_h, cropped_img.clone(), cropped_mask.clone())
781+
};
782+
783+
let align = 64u32;
784+
let mut tensor_dim = fw.max(fh);
785+
if tensor_dim % align != 0 {
786+
tensor_dim += align - (tensor_dim % align);
787+
}
788+
let tensor_dim = tensor_dim.max(align) as usize;
789+
790+
let mut img_tensor = Array::<f32, _>::zeros((1, 3, tensor_dim, tensor_dim));
791+
let mut msk_tensor = Array::<f32, _>::zeros((1, 1, tensor_dim, tensor_dim));
792+
793+
for y in 0..tensor_dim {
794+
for x in 0..tensor_dim {
795+
let sx = (x as u32).min(fw.saturating_sub(1));
796+
let sy = (y as u32).min(fh.saturating_sub(1));
797+
798+
let p = inf_img.get_pixel(sx, sy);
799+
let m = inf_mask.get_pixel(sx, sy)[0];
800+
801+
img_tensor[[0, 0, y, x]] = p[0] as f32 / 255.0;
802+
img_tensor[[0, 1, y, x]] = p[1] as f32 / 255.0;
803+
img_tensor[[0, 2, y, x]] = p[2] as f32 / 255.0;
804+
msk_tensor[[0, 0, y, x]] = if m > 0 { 1.0 } else { 0.0 };
805+
}
806+
}
807+
808+
let t_img = Tensor::from_array(img_tensor.into_dyn().as_standard_layout().into_owned())?;
809+
let t_msk = Tensor::from_array(msk_tensor.into_dyn().as_standard_layout().into_owned())?;
810+
811+
let output_tensor = {
812+
let mut session = lama_session.lock().unwrap();
813+
let outputs = session.run(ort::inputs!["image" => t_img, "mask" => t_msk])?;
814+
outputs[0].try_extract_array::<f32>()?.to_owned()
815+
};
816+
817+
let mut result_inf = RgbaImage::new(fw, fh);
818+
for y in 0..fh {
819+
for x in 0..fw {
820+
let r = output_tensor[[0, 0, y as usize, x as usize]].clamp(0.0, 255.0) as u8;
821+
let g = output_tensor[[0, 1, y as usize, x as usize]].clamp(0.0, 255.0) as u8;
822+
let b = output_tensor[[0, 2, y as usize, x as usize]].clamp(0.0, 255.0) as u8;
823+
result_inf.put_pixel(x, y, Rgba([r, g, b, 255]));
824+
}
825+
}
826+
827+
let result_crop = if needs_downscale {
828+
imageops::resize(&result_inf, crop_w, crop_h, FilterType::Lanczos3)
829+
} else {
830+
result_inf
831+
};
832+
833+
let mut final_image = image.to_rgba8();
834+
835+
for y in 0..crop_h {
836+
for x in 0..crop_w {
837+
let m = cropped_mask.get_pixel(x, y)[0];
838+
if m > 0 {
839+
let alpha = m as f32 / 255.0;
840+
let p = result_crop.get_pixel(x, y);
841+
let gx = x0 + x;
842+
let gy = y0 + y;
843+
let orig = final_image.get_pixel(gx, gy);
844+
845+
let r = (p[0] as f32 * alpha + orig[0] as f32 * (1.0 - alpha)) as u8;
846+
let g = (p[1] as f32 * alpha + orig[1] as f32 * (1.0 - alpha)) as u8;
847+
let b = (p[2] as f32 * alpha + orig[2] as f32 * (1.0 - alpha)) as u8;
848+
849+
final_image.put_pixel(gx, gy, Rgba([r, g, b, 255]));
850+
}
851+
}
852+
}
853+
854+
Ok(final_image)
855+
}
856+
656857
pub fn generate_image_embeddings(
657858
image: &DynamicImage,
658859
encoder: &Mutex<Session>,

0 commit comments

Comments
 (0)