@@ -341,76 +341,95 @@ def _auto_occlusion_threshold(flow_fwd: np.ndarray, flow_bwd: np.ndarray) -> flo
341341 threshold = p95 + max ((p95 - p85 ) * 0.5 , 0.5 )
342342 return float (np .clip (threshold , 1.0 , 15.0 ))
343343
344+ def _match_histogram (source : np .ndarray , template : np .ndarray ) -> np .ndarray :
345+ """
346+ Adjust the pixel values of a source image such that its histogram
347+ matches that of a target template image.
348+ Both source and template should be 2D numpy arrays (a single channel).
349+ """
350+ oldshape = source .shape
351+ source_flat = source .ravel ()
352+ template_flat = template .ravel ()
353+
354+ # get the set of unique pixel values and their corresponding indices and counts
355+ s_values , bin_idx , s_counts = np .unique (source_flat , return_inverse = True , return_counts = True )
356+ t_values , t_counts = np .unique (template_flat , return_counts = True )
357+
358+ # take the cumsum of the counts and normalize by the number of pixels to
359+ # get the empirical cumulative distribution functions for the source and
360+ # template images (maps pixel value --> quantile)
361+ s_quantiles = np .cumsum (s_counts ).astype (np .float64 )
362+ s_quantiles /= s_quantiles [- 1 ]
363+
364+ t_quantiles = np .cumsum (t_counts ).astype (np .float64 )
365+ t_quantiles /= t_quantiles [- 1 ]
366+
367+ # interpolate linearly to find the pixel values in the template image
368+ # that correspond most closely to the quantiles in the source image
369+ interp_t_values = np .interp (s_quantiles , t_quantiles , t_values )
370+
371+ return interp_t_values [bin_idx ].reshape (oldshape )
372+
344373def _match_image_properties (
345374 original_tensor : torch .Tensor ,
346375 generated_tensor : torch .Tensor ,
347376 overall_weight : float ,
348377 color_weight : float ,
349- saturation_weight : float ,
350378 lighting_weight : float ,
379+ mask_tensor : torch .Tensor = None ,
351380) -> torch .Tensor :
352- # We will do color and lighting transfer in LAB space.
353- # Saturation transfer will be handled via blending in HSV space if needed,
354- # but LAB a/b channels inherently affect colorfulness. We can just use HSV
355- # for saturation strictly.
356381
357382 batch_size = generated_tensor .size (0 )
358383 out_tensors = []
359384
360385 orig_batch = original_tensor .size (0 )
386+ mask_batch = mask_tensor .size (0 ) if mask_tensor is not None else 0
361387
362388 for i in range (batch_size ):
363389 orig_i = i if i < orig_batch else 0
364390
365391 orig_np = np .clip (255.0 * original_tensor [orig_i ].cpu ().numpy ().squeeze (), 0 , 255 ).astype (np .uint8 )
366392 gen_np = np .clip (255.0 * generated_tensor [i ].cpu ().numpy ().squeeze (), 0 , 255 ).astype (np .uint8 )
367393
368- # LAB for Color & Lighting
369- orig_lab = cv2 .cvtColor (orig_np , cv2 .COLOR_RGB2LAB ).astype (np .float32 )
370- gen_lab = cv2 .cvtColor (gen_np , cv2 .COLOR_RGB2LAB ).astype (np .float32 )
394+ mask_np = None
395+ if mask_tensor is not None :
396+ mask_i = i if i < mask_batch else 0
397+ # Extract mask, it might be (H, W) or (1, H, W) or (C, H, W)
398+ # Typically comfy masks are (H, W)
399+ m_t = mask_tensor [mask_i ].cpu ().numpy ()
400+ if m_t .ndim > 2 :
401+ m_t = m_t .squeeze ()
402+ if m_t .shape != gen_np .shape [:2 ]:
403+ m_t = cv2 .resize (m_t , (gen_np .shape [1 ], gen_np .shape [0 ]), interpolation = cv2 .INTER_LINEAR )
404+ mask_np = m_t [:, :, np .newaxis ] # (H, W, 1)
371405
372- # Means and Std Devs for LAB
373- orig_l_mean , orig_l_std = orig_lab [:, :, 0 ].mean (), orig_lab [:, :, 0 ].std ()
374- orig_a_mean , orig_a_std = orig_lab [:, :, 1 ].mean (), orig_lab [:, :, 1 ].std ()
375- orig_b_mean , orig_b_std = orig_lab [:, :, 2 ].mean (), orig_lab [:, :, 2 ].std ()
406+ orig_lab = cv2 .cvtColor (orig_np , cv2 .COLOR_RGB2LAB )
407+ gen_lab = cv2 .cvtColor (gen_np , cv2 .COLOR_RGB2LAB )
376408
377- gen_l_mean , gen_l_std = gen_lab [:, :, 0 ].mean (), gen_lab [:, :, 0 ].std ()
378- gen_a_mean , gen_a_std = gen_lab [:, :, 1 ].mean (), gen_lab [:, :, 1 ].std ()
379- gen_b_mean , gen_b_std = gen_lab [:, :, 2 ].mean (), gen_lab [:, :, 2 ].std ()
409+ out_lab = np .copy (gen_lab ).astype (np .float32 )
410+ gen_lab_f = gen_lab .astype (np .float32 )
380411
381- out_lab = np .copy (gen_lab )
382-
383- # Calculate full transfer
384- l_trans = (gen_lab [:, :, 0 ] - gen_l_mean ) * (orig_l_std / (gen_l_std + 1e-5 )) + orig_l_mean
385- a_trans = (gen_lab [:, :, 1 ] - gen_a_mean ) * (orig_a_std / (gen_a_std + 1e-5 )) + orig_a_mean
386- b_trans = (gen_lab [:, :, 2 ] - gen_b_mean ) * (orig_b_std / (gen_b_std + 1e-5 )) + orig_b_mean
387-
388- # Blend based on weights
412+ # 1. Lighting (L channel)
413+ l_trans = _match_histogram (gen_lab [:, :, 0 ], orig_lab [:, :, 0 ])
389414 l_weight = lighting_weight * overall_weight
415+ out_lab [:, :, 0 ] = gen_lab_f [:, :, 0 ] * (1.0 - l_weight ) + l_trans * l_weight
416+
417+ # 2. Color (A and B channels)
418+ # If color_weight > 0, we match the A and B histograms
419+ a_trans = _match_histogram (gen_lab [:, :, 1 ], orig_lab [:, :, 1 ])
420+ b_trans = _match_histogram (gen_lab [:, :, 2 ], orig_lab [:, :, 2 ])
390421 c_weight = color_weight * overall_weight
391422
392- out_lab [:, :, 0 ] = gen_lab [:, :, 0 ] * (1.0 - l_weight ) + l_trans * l_weight
393- out_lab [:, :, 1 ] = gen_lab [:, :, 1 ] * (1.0 - c_weight ) + a_trans * c_weight
394- out_lab [:, :, 2 ] = gen_lab [:, :, 2 ] * (1.0 - c_weight ) + b_trans * c_weight
423+ out_lab [:, :, 1 ] = gen_lab_f [:, :, 1 ] * (1.0 - c_weight ) + a_trans * c_weight
424+ out_lab [:, :, 2 ] = gen_lab_f [:, :, 2 ] * (1.0 - c_weight ) + b_trans * c_weight
425+
426+ # Apply soft masking if provided
427+ if mask_np is not None :
428+ out_lab = gen_lab_f * (1.0 - mask_np ) + out_lab * mask_np
395429
396430 out_lab = np .clip (out_lab , 0 , 255 ).astype (np .uint8 )
397431 res_rgb = cv2 .cvtColor (out_lab , cv2 .COLOR_LAB2RGB )
398432
399- # Now handle Saturation in HSV space
400- if saturation_weight > 0.0 :
401- res_hsv = cv2 .cvtColor (res_rgb , cv2 .COLOR_RGB2HSV ).astype (np .float32 )
402- orig_hsv = cv2 .cvtColor (orig_np , cv2 .COLOR_RGB2HSV ).astype (np .float32 )
403-
404- orig_s_mean = orig_hsv [:, :, 1 ].mean ()
405- gen_s_mean = res_hsv [:, :, 1 ].mean ()
406-
407- sat_ratio = (orig_s_mean + 1e-5 ) / (gen_s_mean + 1e-5 )
408- s_weight = saturation_weight * overall_weight
409-
410- effective_sat_ratio = 1.0 + (sat_ratio - 1.0 ) * s_weight
411- res_hsv [:, :, 1 ] = np .clip (res_hsv [:, :, 1 ] * effective_sat_ratio , 0 , 255 )
412- res_rgb = cv2 .cvtColor (res_hsv .astype (np .uint8 ), cv2 .COLOR_HSV2RGB )
413-
414433 out_tensor = torch .from_numpy (res_rgb .astype (np .float32 ) / 255.0 ).unsqueeze (0 )
415434 out_tensors .append (out_tensor )
416435
@@ -2213,10 +2232,10 @@ def define_schema(cls) -> io.Schema:
22132232 inputs = [
22142233 io .Image .Input ("original_image" ),
22152234 io .Image .Input ("generated_image" ),
2216- io .Float .Input ("overall_weight" , default = 1.0 , min = 0.0 , max = 10 .0 , step = 0.01 ),
2217- io .Float .Input ("color_weight" , default = 1.0 , min = 0.0 , max = 10 .0 , step = 0.01 ),
2218- io .Float .Input ("saturation_weight " , default = 1.0 , min = 0.0 , max = 10 .0 , step = 0.01 ),
2219- io .Float .Input ("lighting_weight " , default = 1.0 , min = 0.0 , max = 10.0 , step = 0.01 ),
2235+ io .Float .Input ("overall_weight" , default = 1.0 , min = 0.0 , max = 1 .0 , step = 0.001 ),
2236+ io .Float .Input ("color_weight" , default = 1.0 , min = 0.0 , max = 1 .0 , step = 0.001 ),
2237+ io .Float .Input ("lighting_weight " , default = 1.0 , min = 0.0 , max = 1 .0 , step = 0.001 ),
2238+ io .Mask .Input ("mask " , optional = True , tooltip = "Optional mask to softly blend the color/lighting changes onto the generated image." ),
22202239 ],
22212240 outputs = [
22222241 io .Image .Output (display_name = "image" ),
@@ -2230,16 +2249,16 @@ def execute(
22302249 generated_image : torch .Tensor ,
22312250 overall_weight : float ,
22322251 color_weight : float ,
2233- saturation_weight : float ,
22342252 lighting_weight : float ,
2253+ mask : torch .Tensor = None ,
22352254 ) -> io .NodeOutput :
22362255 result = _match_image_properties (
22372256 original_image ,
22382257 generated_image ,
22392258 overall_weight ,
22402259 color_weight ,
2241- saturation_weight ,
22422260 lighting_weight ,
2261+ mask ,
22432262 )
22442263 return io .NodeOutput (result )
22452264
0 commit comments