Skip to content

Commit 155b7f3

Browse files
committed
improve the new node
1 parent 6852302 commit 155b7f3

1 file changed

Lines changed: 66 additions & 47 deletions

File tree

__init__.py

Lines changed: 66 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -341,76 +341,95 @@ def _auto_occlusion_threshold(flow_fwd: np.ndarray, flow_bwd: np.ndarray) -> flo
341341
threshold = p95 + max((p95 - p85) * 0.5, 0.5)
342342
return float(np.clip(threshold, 1.0, 15.0))
343343

344+
def _match_histogram(source: np.ndarray, template: np.ndarray) -> np.ndarray:
345+
"""
346+
Adjust the pixel values of a source image such that its histogram
347+
matches that of a target template image.
348+
Both source and template should be 2D numpy arrays (a single channel).
349+
"""
350+
oldshape = source.shape
351+
source_flat = source.ravel()
352+
template_flat = template.ravel()
353+
354+
# get the set of unique pixel values and their corresponding indices and counts
355+
s_values, bin_idx, s_counts = np.unique(source_flat, return_inverse=True, return_counts=True)
356+
t_values, t_counts = np.unique(template_flat, return_counts=True)
357+
358+
# take the cumsum of the counts and normalize by the number of pixels to
359+
# get the empirical cumulative distribution functions for the source and
360+
# template images (maps pixel value --> quantile)
361+
s_quantiles = np.cumsum(s_counts).astype(np.float64)
362+
s_quantiles /= s_quantiles[-1]
363+
364+
t_quantiles = np.cumsum(t_counts).astype(np.float64)
365+
t_quantiles /= t_quantiles[-1]
366+
367+
# interpolate linearly to find the pixel values in the template image
368+
# that correspond most closely to the quantiles in the source image
369+
interp_t_values = np.interp(s_quantiles, t_quantiles, t_values)
370+
371+
return interp_t_values[bin_idx].reshape(oldshape)
372+
344373
def _match_image_properties(
345374
original_tensor: torch.Tensor,
346375
generated_tensor: torch.Tensor,
347376
overall_weight: float,
348377
color_weight: float,
349-
saturation_weight: float,
350378
lighting_weight: float,
379+
mask_tensor: torch.Tensor = None,
351380
) -> torch.Tensor:
352-
# We will do color and lighting transfer in LAB space.
353-
# Saturation transfer will be handled via blending in HSV space if needed,
354-
# but LAB a/b channels inherently affect colorfulness. We can just use HSV
355-
# for saturation strictly.
356381

357382
batch_size = generated_tensor.size(0)
358383
out_tensors = []
359384

360385
orig_batch = original_tensor.size(0)
386+
mask_batch = mask_tensor.size(0) if mask_tensor is not None else 0
361387

362388
for i in range(batch_size):
363389
orig_i = i if i < orig_batch else 0
364390

365391
orig_np = np.clip(255.0 * original_tensor[orig_i].cpu().numpy().squeeze(), 0, 255).astype(np.uint8)
366392
gen_np = np.clip(255.0 * generated_tensor[i].cpu().numpy().squeeze(), 0, 255).astype(np.uint8)
367393

368-
# LAB for Color & Lighting
369-
orig_lab = cv2.cvtColor(orig_np, cv2.COLOR_RGB2LAB).astype(np.float32)
370-
gen_lab = cv2.cvtColor(gen_np, cv2.COLOR_RGB2LAB).astype(np.float32)
394+
mask_np = None
395+
if mask_tensor is not None:
396+
mask_i = i if i < mask_batch else 0
397+
# Extract mask, it might be (H, W) or (1, H, W) or (C, H, W)
398+
# Typically comfy masks are (H, W)
399+
m_t = mask_tensor[mask_i].cpu().numpy()
400+
if m_t.ndim > 2:
401+
m_t = m_t.squeeze()
402+
if m_t.shape != gen_np.shape[:2]:
403+
m_t = cv2.resize(m_t, (gen_np.shape[1], gen_np.shape[0]), interpolation=cv2.INTER_LINEAR)
404+
mask_np = m_t[:, :, np.newaxis] # (H, W, 1)
371405

372-
# Means and Std Devs for LAB
373-
orig_l_mean, orig_l_std = orig_lab[:, :, 0].mean(), orig_lab[:, :, 0].std()
374-
orig_a_mean, orig_a_std = orig_lab[:, :, 1].mean(), orig_lab[:, :, 1].std()
375-
orig_b_mean, orig_b_std = orig_lab[:, :, 2].mean(), orig_lab[:, :, 2].std()
406+
orig_lab = cv2.cvtColor(orig_np, cv2.COLOR_RGB2LAB)
407+
gen_lab = cv2.cvtColor(gen_np, cv2.COLOR_RGB2LAB)
376408

377-
gen_l_mean, gen_l_std = gen_lab[:, :, 0].mean(), gen_lab[:, :, 0].std()
378-
gen_a_mean, gen_a_std = gen_lab[:, :, 1].mean(), gen_lab[:, :, 1].std()
379-
gen_b_mean, gen_b_std = gen_lab[:, :, 2].mean(), gen_lab[:, :, 2].std()
409+
out_lab = np.copy(gen_lab).astype(np.float32)
410+
gen_lab_f = gen_lab.astype(np.float32)
380411

381-
out_lab = np.copy(gen_lab)
382-
383-
# Calculate full transfer
384-
l_trans = (gen_lab[:, :, 0] - gen_l_mean) * (orig_l_std / (gen_l_std + 1e-5)) + orig_l_mean
385-
a_trans = (gen_lab[:, :, 1] - gen_a_mean) * (orig_a_std / (gen_a_std + 1e-5)) + orig_a_mean
386-
b_trans = (gen_lab[:, :, 2] - gen_b_mean) * (orig_b_std / (gen_b_std + 1e-5)) + orig_b_mean
387-
388-
# Blend based on weights
412+
# 1. Lighting (L channel)
413+
l_trans = _match_histogram(gen_lab[:, :, 0], orig_lab[:, :, 0])
389414
l_weight = lighting_weight * overall_weight
415+
out_lab[:, :, 0] = gen_lab_f[:, :, 0] * (1.0 - l_weight) + l_trans * l_weight
416+
417+
# 2. Color (A and B channels)
418+
# If color_weight > 0, we match the A and B histograms
419+
a_trans = _match_histogram(gen_lab[:, :, 1], orig_lab[:, :, 1])
420+
b_trans = _match_histogram(gen_lab[:, :, 2], orig_lab[:, :, 2])
390421
c_weight = color_weight * overall_weight
391422

392-
out_lab[:, :, 0] = gen_lab[:, :, 0] * (1.0 - l_weight) + l_trans * l_weight
393-
out_lab[:, :, 1] = gen_lab[:, :, 1] * (1.0 - c_weight) + a_trans * c_weight
394-
out_lab[:, :, 2] = gen_lab[:, :, 2] * (1.0 - c_weight) + b_trans * c_weight
423+
out_lab[:, :, 1] = gen_lab_f[:, :, 1] * (1.0 - c_weight) + a_trans * c_weight
424+
out_lab[:, :, 2] = gen_lab_f[:, :, 2] * (1.0 - c_weight) + b_trans * c_weight
425+
426+
# Apply soft masking if provided
427+
if mask_np is not None:
428+
out_lab = gen_lab_f * (1.0 - mask_np) + out_lab * mask_np
395429

396430
out_lab = np.clip(out_lab, 0, 255).astype(np.uint8)
397431
res_rgb = cv2.cvtColor(out_lab, cv2.COLOR_LAB2RGB)
398432

399-
# Now handle Saturation in HSV space
400-
if saturation_weight > 0.0:
401-
res_hsv = cv2.cvtColor(res_rgb, cv2.COLOR_RGB2HSV).astype(np.float32)
402-
orig_hsv = cv2.cvtColor(orig_np, cv2.COLOR_RGB2HSV).astype(np.float32)
403-
404-
orig_s_mean = orig_hsv[:, :, 1].mean()
405-
gen_s_mean = res_hsv[:, :, 1].mean()
406-
407-
sat_ratio = (orig_s_mean + 1e-5) / (gen_s_mean + 1e-5)
408-
s_weight = saturation_weight * overall_weight
409-
410-
effective_sat_ratio = 1.0 + (sat_ratio - 1.0) * s_weight
411-
res_hsv[:, :, 1] = np.clip(res_hsv[:, :, 1] * effective_sat_ratio, 0, 255)
412-
res_rgb = cv2.cvtColor(res_hsv.astype(np.uint8), cv2.COLOR_HSV2RGB)
413-
414433
out_tensor = torch.from_numpy(res_rgb.astype(np.float32) / 255.0).unsqueeze(0)
415434
out_tensors.append(out_tensor)
416435

@@ -2213,10 +2232,10 @@ def define_schema(cls) -> io.Schema:
22132232
inputs=[
22142233
io.Image.Input("original_image"),
22152234
io.Image.Input("generated_image"),
2216-
io.Float.Input("overall_weight", default=1.0, min=0.0, max=10.0, step=0.01),
2217-
io.Float.Input("color_weight", default=1.0, min=0.0, max=10.0, step=0.01),
2218-
io.Float.Input("saturation_weight", default=1.0, min=0.0, max=10.0, step=0.01),
2219-
io.Float.Input("lighting_weight", default=1.0, min=0.0, max=10.0, step=0.01),
2235+
io.Float.Input("overall_weight", default=1.0, min=0.0, max=1.0, step=0.001),
2236+
io.Float.Input("color_weight", default=1.0, min=0.0, max=1.0, step=0.001),
2237+
io.Float.Input("lighting_weight", default=1.0, min=0.0, max=1.0, step=0.001),
2238+
io.Mask.Input("mask", optional=True, tooltip="Optional mask to softly blend the color/lighting changes onto the generated image."),
22202239
],
22212240
outputs=[
22222241
io.Image.Output(display_name="image"),
@@ -2230,16 +2249,16 @@ def execute(
22302249
generated_image: torch.Tensor,
22312250
overall_weight: float,
22322251
color_weight: float,
2233-
saturation_weight: float,
22342252
lighting_weight: float,
2253+
mask: torch.Tensor = None,
22352254
) -> io.NodeOutput:
22362255
result = _match_image_properties(
22372256
original_image,
22382257
generated_image,
22392258
overall_weight,
22402259
color_weight,
2241-
saturation_weight,
22422260
lighting_weight,
2261+
mask,
22432262
)
22442263
return io.NodeOutput(result)
22452264

0 commit comments

Comments
 (0)