-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathalign_nodes.py
More file actions
611 lines (514 loc) · 25.2 KB
/
Copy pathalign_nodes.py
File metadata and controls
611 lines (514 loc) · 25.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
"""LTX-Relight - Alignment / Drift-Reconciliation nodes.
LTX (and other diffusion video relighters) produce output frames that are
visually plausible but not pixel-locked to the source plate. There's typically
a fraction-of-a-pixel global affine drift plus a small amount of local
non-rigid wobble around features (eyes, mouth, hair edges).
For VFX work — especially comp passes like muzzle flashes, lightning, lamp
flicker — the relight needs to align cleanly to the source so artists can
composite the *light contribution* onto the original plate without ringing.
This file provides three small, composable nodes:
LTXRelightAlign - aligns relight frames to source (affine + flow)
LTXRelightDelta - source-subtracted "light contribution" pass
LTXRelightGuard - confidence-gated composite (relight where reliable,
source where the AI hallucinated)
All three operate on standard ComfyUI IMAGE batches ([N,H,W,3] float 0-1) and
emit additional MASK outputs ([N,H,W] float 0-1) where useful.
Dependency note: only `cv2` (already in ltx_relight's requirements.txt) plus
numpy/torch. No models, no GPU work, no extra installs.
Caveman version
---------------
The relight is a tracing of the original drawing on a different sheet of
paper, and the sheets aren't perfectly stacked. Comparing colors won't help
us line them up because the colors are *meant* to be different. So we line
them up by *outlines* (which are the same in both), then squeeze out the last
bit of slop with a flow technique that compares "is this pixel brighter or
darker than its neighbor" — a yes/no question that doesn't care if the whole
face is lit up. Then we can either output the aligned relight, or subtract
source from it to get just the new light (perfect for Nuke flash passes).
"""
from __future__ import annotations
import logging
import numpy as np
import torch
# cv2 is provided by opencv-python (already a runtime dependency of this pack).
# We import it lazily inside functions so a missing-cv2 install doesn't break
# the entire node pack at import time — the existing relight nodes still work.
try:
    import cv2  # noqa: F401
except Exception:  # pragma: no cover
    # Any failure (missing wheel, broken shared libs, ...) just disables the
    # cv2-backed features; the rest of the pack keeps importing.
    _HAS_CV2 = False
else:
    _HAS_CV2 = True

# Module logger shared by all alignment helpers and nodes.
log = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Tensor / numpy bridges
# ---------------------------------------------------------------------------
def _img_to_np_u8(img_t: torch.Tensor) -> np.ndarray:
    """Convert one ComfyUI IMAGE frame to a uint8 numpy array for OpenCV.

    Args:
        img_t: a single frame, [H,W,C] float in 0..1 on any device and in any
            floating dtype (float32/float16/bfloat16/float64).

    Returns:
        np.uint8 array of the same shape with channels kept in RGB order.
        No BGR swap is done: grayscale conversion in this module uses
        ``COLOR_RGB2GRAY`` explicitly and the warps treat channels
        independently, so RGB is kept end-to-end.
    """
    # `.float()` first: numpy has no bfloat16, so `.numpy()` on a bf16 tensor
    # raises; narrower floats are widened so the arithmetic below is float32.
    arr = img_t.detach().float().cpu().numpy()
    # Round to the nearest code value rather than truncating: a bare
    # astype(uint8) floors, biasing every frame ~0.5/255 darker and mapping
    # values like 0.999 to 254 instead of 255.
    arr = np.rint(np.clip(arr, 0.0, 1.0) * 255.0)
    return arr.astype(np.uint8)
def _np_to_img_t(arr_u8: np.ndarray, ref: torch.Tensor) -> torch.Tensor:
    """Turn a uint8 [H,W,3] RGB array back into a ComfyUI IMAGE frame.

    Values are rescaled from 0..255 to float 0..1 (computed in float32), then
    moved to ``ref``'s device and cast to ``ref``'s dtype.
    """
    scaled = np.asarray(arr_u8, dtype=np.float32) / np.float32(255.0)
    frame = torch.from_numpy(scaled)
    return frame.to(device=ref.device, dtype=ref.dtype)
def _to_gray_u8(img_u8: np.ndarray) -> np.ndarray:
    """Collapse an RGB uint8 frame to a single-channel uint8 luma plane.

    Args:
        img_u8: [H,W,3] uint8 array in RGB order (as produced by
            ``_img_to_np_u8``).

    Returns:
        [H,W] uint8 luma from OpenCV's ``COLOR_RGB2GRAY``. That conversion uses
        the Rec.601 weights (0.299 R + 0.587 G + 0.114 B), not Rec.709. For
        edge-based alignment the exact weighting is immaterial; what matters
        is that source and relight pass through the same conversion.
    """
    import cv2  # local import — see module docstring

    # Frames are RGB (no BGR swap upstream), so use the RGB variant so the
    # channel weights land on the right channels.
    return cv2.cvtColor(img_u8, cv2.COLOR_RGB2GRAY)
# ---------------------------------------------------------------------------
# Edge maps for lighting-invariant alignment
# ---------------------------------------------------------------------------
def _edge_map(gray_u8: np.ndarray, method: str) -> np.ndarray:
    """Turn a grayscale frame into a float32 [H,W] structure map in 0..1.

    The aim is to give ECC something that survives a change of lighting:
      * ``"luma"``  - no edge detection; the gray frame scaled to 0..1. Only
                      sensible when the relight barely changes brightness.
      * ``"canny"`` - binary Canny edges, thresholds derived from Otsu.
      * anything else (``"sobel"``) - Sobel gradient magnitude divided by its
                      per-frame maximum (sharp, smooth, no tuning).
    """
    import cv2

    if method == "luma":
        return gray_u8.astype(np.float32) / 255.0

    # A 3x3 Gaussian suppresses sensor noise before differentiating without
    # noticeably softening structural edges.
    smoothed = cv2.GaussianBlur(gray_u8, (3, 3), 0)

    if method == "canny":
        # Otsu's threshold becomes the high Canny threshold; the low one is half.
        high, _ = cv2.threshold(smoothed, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        return cv2.Canny(smoothed, high * 0.5, high).astype(np.float32) / 255.0

    grad_x = cv2.Sobel(smoothed, cv2.CV_32F, 1, 0, ksize=3)
    grad_y = cv2.Sobel(smoothed, cv2.CV_32F, 0, 1, ksize=3)
    magnitude = cv2.magnitude(grad_x, grad_y)
    peak = float(magnitude.max())
    # A completely flat frame has peak 0; divide by 1 instead.
    return magnitude / (peak if peak else 1.0)
# ---------------------------------------------------------------------------
# Layer 1 — global affine via ECC on edge maps
# ---------------------------------------------------------------------------
def _fit_affine_ecc(
    src_edges: np.ndarray,
    rel_edges: np.ndarray,
    iters: int = 200,
    eps: float = 1e-4,
    multi_scale: bool = True,
    warp_model: str = "euclidean",
) -> np.ndarray:
    """Estimate the global 2x3 warp between two edge maps with OpenCV's ECC.

    ``src_edges`` is the ECC *template* and ``rel_edges`` the *input*. ECC
    maximizes the correlation between ``src(x, y)`` and ``rel(x', y')`` with
    ``[x', y'] = M @ [x, y, 1]``. The returned ``M`` therefore maps source
    coordinates into relight coordinates. To resample the relight into
    register with the source, apply it as an inverse map::

        cv2.warpAffine(rel, M, (w, h),
                       flags=cv2.INTER_LINEAR | cv2.WARP_INVERSE_MAP)

    Applying ``M`` without ``WARP_INVERSE_MAP`` moves the relight the wrong
    way, doubling the drift instead of removing it.

    Args:
        src_edges: template map, single-channel float32 (e.g. from ``_edge_map``).
        rel_edges: input map, same size and type as ``src_edges``.
        iters: ECC maximum iteration count.
        eps: ECC minimum correlation increment before stopping.
        multi_scale: when both dimensions are >= 256, first fit at half
            resolution and use the result as the full-resolution warm start.
        warp_model:
            'translation' - 2 DoF (tx, ty). Use when camera is locked.
            'euclidean'   - 3 DoF rigid (tx, ty, rotation); no scale. Default.
            'affine'      - 6 DoF. Allows scale, non-uniform scale + shear
                            (can cause lens-distortion-like widening on
                            relight outputs).
            Unknown names fall back to 'euclidean'.

    Returns:
        float32 2x3 matrix. Identity if the full-resolution ECC fit raises
        ``cv2.error`` (e.g. near-black or featureless frames).
    """
    import cv2

    motion = {
        "translation": cv2.MOTION_TRANSLATION,
        "euclidean": cv2.MOTION_EUCLIDEAN,
        "affine": cv2.MOTION_AFFINE,
    }.get(warp_model, cv2.MOTION_EUCLIDEAN)
    criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, iters, eps)

    M = np.eye(2, 3, dtype=np.float32)

    # Multi-scale warm start: a half-res fit is cheap, handles larger drifts,
    # and gives the full-res pass a good starting point.
    if multi_scale:
        h, w = src_edges.shape
        if min(h, w) >= 256:
            half = (w // 2, h // 2)
            ds_src = cv2.resize(src_edges, half, interpolation=cv2.INTER_AREA)
            ds_rel = cv2.resize(rel_edges, half, interpolation=cv2.INTER_AREA)
            try:
                _, M_half = cv2.findTransformECC(
                    ds_src, ds_rel, np.eye(2, 3, dtype=np.float32),
                    motion, criteria, None, 1,
                )
            except cv2.error as e:
                # Best-effort: a failed coarse fit only means no warm start.
                log.debug(
                    "[RelightAlign] half-res ECC failed (%s) — full-res starts from identity",
                    e,
                )
            else:
                # Only the translation column scales with resolution; the
                # rotation/scale (linear) part is resolution independent.
                M = M_half.astype(np.float32, copy=True)
                M[:, 2] *= 2.0

    try:
        _, M = cv2.findTransformECC(src_edges, rel_edges, M, motion, criteria, None, 1)
    except cv2.error as e:
        log.warning("[RelightAlign] ECC failed (%s) — returning identity", e)
        return np.eye(2, 3, dtype=np.float32)
    return M.astype(np.float32)
def _clamp_affine(M: np.ndarray, max_shift_px: float, h: int, w: int) -> np.ndarray:
    """Sanity-gate an ECC result on its translation component.

    If |tx| or |ty| exceeds ``max_shift_px``, the fit is treated as an ECC
    failure and a fresh identity matrix is returned. The matrix is replaced,
    not clamped. Rotation/scale terms are not inspected. ``max_shift_px <= 0``
    disables the check. ``h`` and ``w`` are accepted for signature stability
    but unused.

    Returns:
        ``M`` itself when it passes (or the check is disabled), else a float32
        identity.
    """
    disabled = max_shift_px <= 0
    if disabled:
        return M

    shift_x, shift_y = float(M[0, 2]), float(M[1, 2])
    too_far = abs(shift_x) > max_shift_px or abs(shift_y) > max_shift_px
    if not too_far:
        return M

    log.warning(
        "[RelightAlign] affine shift (%.2f, %.2f) exceeds max %d — clamping to identity",
        shift_x, shift_y, max_shift_px,
    )
    return np.eye(2, 3, dtype=np.float32)
# ---------------------------------------------------------------------------
# Layer 2 — local refinement via census-transform optical flow
# ---------------------------------------------------------------------------
def _census_transform(gray_u8: np.ndarray, window: int = 7) -> np.ndarray:
    """Census-encode a grayscale frame.

    Each pixel becomes one uint8 flag per neighbour in its ``window`` x
    ``window`` neighbourhood (centre excluded). A flag is 1 where the
    neighbour is darker than the centre. Borders use edge replication.
    Neighbours are ordered row by row (dy outer, dx inner).

    Returns:
        [H, W, K] uint8 array, K = (2 * (window // 2) + 1) ** 2 - 1.

    Comparing codes by Hamming distance only looks at local ordering, so it
    is unaffected by monotonic brightness changes, which is what relight
    alignment needs.
    """
    rows, cols = gray_u8.shape
    r = window // 2
    padded = np.pad(gray_u8.astype(np.int16), r, mode="edge")
    centre = padded[r:r + rows, r:r + cols]

    offsets = [
        (oy, ox)
        for oy in range(-r, r + 1)
        for ox in range(-r, r + 1)
        if (oy, ox) != (0, 0)
    ]
    planes = [
        (padded[r + oy:r + oy + rows, r + ox:r + ox + cols] < centre).astype(np.uint8)
        for oy, ox in offsets
    ]
    return np.stack(planes, axis=-1)
def _refine_flow(
    src_gray: np.ndarray,
    rel_gray_warped: np.ndarray,
    radius_px: int = 2,
    window: int = 7,
) -> np.ndarray:
    """Residual per-pixel flow via a census-transform block search.

    For every integer offset (dx, dy) in the ±radius_px square, the source
    census code at (y, x) is compared with the relight census code at
    (y + dy, x + dx), with coordinates clamped at the borders. Each pixel
    keeps the offset with the smallest Hamming distance. Ties keep the zero
    offset, so featureless regions, where every offset matches equally well,
    stay put.

    The integer offsets are then median-filtered (9x9) to collapse isolated
    mismatches into the dominant local motion. They are also Gaussian-smoothed
    (9x9) so the field varies continuously. The output is fractional, but the
    search itself has integer resolution. Layer 1 (ECC affine) is expected to
    have removed everything larger than a pixel or so already.

    Args:
        src_gray: [H,W] uint8 source luma.
        rel_gray_warped: [H,W] uint8 relight luma (after any affine warp).
        radius_px: search radius in pixels (0..127).
        window: census window size.

    Returns:
        [H,W,2] float32 flow (dx, dy) in the remap convention used by
        ``_warp_image_by_flow``: ``aligned(y, x) = rel(y + dy, x + dx)``.

    Raises:
        ValueError: ``radius_px`` > 127 (offsets must fit the uint8 median).
    """
    import cv2

    r = max(int(radius_px), 0)
    if r > 127:
        raise ValueError(f"radius_px must be <= 127, got {radius_px}")

    src_census = _census_transform(src_gray, window=window)
    rel_census = _census_transform(rel_gray_warped, window=window)
    h, w, _ = src_census.shape
    rows = np.arange(h)
    cols = np.arange(w)

    # Seed with the zero offset and only accept strict improvements. Ties
    # (e.g. flat regions where every offset costs 0) then resolve to "no
    # motion" instead of to whichever offset happens to be searched first.
    best_cost = np.count_nonzero(src_census != rel_census, axis=-1)
    best_dx = np.zeros((h, w), dtype=np.int32)
    best_dy = np.zeros((h, w), dtype=np.int32)

    for dy in range(-r, r + 1):
        ys = np.clip(rows + dy, 0, h - 1)
        for dx in range(-r, r + 1):
            if dx == 0 and dy == 0:
                continue
            xs = np.clip(cols + dx, 0, w - 1)
            # shifted[y, x] = rel_census[y + dy, x + dx] (edge-clamped): the
            # same sampling convention _warp_image_by_flow applies later.
            shifted = rel_census[ys[:, None], xs[None, :]]
            cost = np.count_nonzero(src_census != shifted, axis=-1)
            better = cost < best_cost
            best_cost[better] = cost[better]
            best_dx[better] = dx
            best_dy[better] = dy

    # cv2.medianBlur accepts float32 only for ksize 3/5; ksize 9 needs uint8.
    # Offsets are integers in [-r, r], so shift them into 0..2r, filter as
    # uint8, and shift back.
    def _median9(offsets: np.ndarray) -> np.ndarray:
        as_u8 = (offsets + r).astype(np.uint8)
        return cv2.medianBlur(as_u8, 9).astype(np.float32) - r

    # Final Gaussian smooth for sub-pixel continuity.
    flow_x = cv2.GaussianBlur(_median9(best_dx), (9, 9), 0)
    flow_y = cv2.GaussianBlur(_median9(best_dy), (9, 9), 0)
    return np.stack([flow_x, flow_y], axis=-1).astype(np.float32)
def _warp_image_by_flow(img_u8: np.ndarray, flow: np.ndarray) -> np.ndarray:
    """Resample ``img_u8`` through a dense flow field.

    ``flow`` is [H,W,2] holding (dx, dy) per pixel. Output pixel (y, x) is
    bilinearly sampled from ``img_u8`` at (y + dy, x + dx). Samples that land
    outside the frame are reflected back in (``BORDER_REFLECT``).
    """
    import cv2

    rows, cols = img_u8.shape[:2]
    base_y, base_x = np.indices((rows, cols), dtype=np.float32)
    return cv2.remap(
        img_u8,
        base_x + flow[..., 0],
        base_y + flow[..., 1],
        interpolation=cv2.INTER_LINEAR,
        borderMode=cv2.BORDER_REFLECT,
    )
# ---------------------------------------------------------------------------
# Confidence map: structural agreement after alignment
# ---------------------------------------------------------------------------
def _confidence_map(src_u8: np.ndarray, aligned_u8: np.ndarray) -> np.ndarray:
    """Per-pixel structural-agreement score between source and aligned relight.

    Both frames are reduced to Sobel edge maps, so the intended lighting change
    does not count against them. The maps are compared with a local SSIM whose
    means, variances and covariance come from 11x11 box filters. Values near 1
    mean the structure agrees. Low values flag regions whose structure diverges
    from the source (plausibly hallucinated detail or residual misalignment).
    Negative SSIM is clipped to 0.

    Returns:
        [H,W] float32 array in 0..1.
    """
    import cv2

    edges_src = _edge_map(_to_gray_u8(src_u8), "sobel")
    edges_aln = _edge_map(_to_gray_u8(aligned_u8), "sobel")

    win = (11, 11)

    def local_mean(a):
        return cv2.boxFilter(a, -1, win)

    mean_s = local_mean(edges_src)
    mean_a = local_mean(edges_aln)
    var_s = local_mean(edges_src * edges_src) - mean_s * mean_s
    var_a = local_mean(edges_aln * edges_aln) - mean_a * mean_a
    cov = local_mean(edges_src * edges_aln) - mean_s * mean_a

    # Standard SSIM stabilizers for a 0..1 data range.
    C1 = 0.01 ** 2
    C2 = 0.03 ** 2
    numerator = (2 * mean_s * mean_a + C1) * (2 * cov + C2)
    denominator = (mean_s * mean_s + mean_a * mean_a + C1) * (var_s + var_a + C2)
    score = np.clip(numerator / np.maximum(denominator, 1e-8), 0.0, 1.0)
    return score.astype(np.float32)
# ---------------------------------------------------------------------------
# NODE 1 — RelightAlign
# ---------------------------------------------------------------------------
class LTXRelightAlign:
    """Align relight frames to source using ECC affine + optional census flow.

    Outputs the warped relight plus a per-pixel confidence mask so downstream
    nodes can decide where to trust the AI vs fall back to source.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "source": ("IMAGE", {
                    "tooltip": "Original source plate (the ground truth for alignment).",
                }),
                "relight": ("IMAGE", {
                    "tooltip": "Diffusion-relit output to be aligned to source.",
                }),
                "mode": (["affine only", "affine + flow", "flow only"], {
                    "default": "affine only",
                    "tooltip": "Affine = global drift fix (fast, safe). Flow = sub-pixel local fix (can over-warp on flat regions — only enable for tight identity-bound shots). Default is affine-only.",
                }),
                "warp_model": (["euclidean", "translation", "affine"], {
                    "default": "euclidean",
                    "tooltip": "Warp DoF. 'euclidean' = shift+rotate, rigid, no scale (recommended). 'translation' = pure shift (locked-off cams). 'affine' = full 6-DoF (allows non-uniform scale & shear — can cause lens-distortion-like widening).",
                }),
                "edge_method": (["sobel", "canny", "luma"], {
                    "default": "sobel",
                    "tooltip": "What ECC fits on. Sobel = fast & robust. Canny = sharper. Luma = no edges (only for tiny lighting tweaks).",
                }),
                "max_shift_px": ("INT", {
                    "default": 32, "min": 0, "max": 256,
                    "tooltip": "Reject affine fits that translate more than this many pixels (sanity clamp). 0 = no clamp.",
                }),
                "flow_radius_px": ("INT", {
                    "default": 1, "min": 1, "max": 8,
                    "tooltip": "Census-flow search radius. 1 px is plenty after affine. Higher = slower AND more risk of warping artifacts in flat regions.",
                }),
            },
            "optional": {
                "anchor_first_frame": ("BOOLEAN", {
                    "default": False,
                    "tooltip": "Align every relight frame to source frame 0 instead of pair-wise. Locks output to a clean reference even if source has its own micro-jitter.",
                }),
            },
        }

    RETURN_TYPES = ("IMAGE", "MASK")
    RETURN_NAMES = ("aligned", "confidence")
    FUNCTION = "align"
    CATEGORY = "LTX-Relight"

    def align(
        self,
        source: torch.Tensor,
        relight: torch.Tensor,
        mode: str,
        edge_method: str,
        max_shift_px: int,
        flow_radius_px: int,
        warp_model: str = "euclidean",
        anchor_first_frame: bool = False,
    ):
        """Align every relight frame to its source frame.

        Args:
            source: IMAGE batch [N,H,W,C]; a 1-frame batch is broadcast.
            relight: IMAGE batch [N,H,W,C]; a 1-frame batch is broadcast.
            mode: "affine only", "affine + flow" or "flow only".
            edge_method: structure map ECC fits on (see ``_edge_map``).
            max_shift_px: translation sanity limit (see ``_clamp_affine``).
            flow_radius_px: census-flow search radius.
            warp_model: ECC motion model (see ``_fit_affine_ecc``).
            anchor_first_frame: align every frame to source frame 0.

        Returns:
            (aligned IMAGE [N,H,W,C], confidence MASK [N,H,W]), both on the
            relight's device and dtype.

        Raises:
            RuntimeError: opencv-python is not installed.
            ValueError: frame counts or resolutions are incompatible.
        """
        if not _HAS_CV2:
            raise RuntimeError(
                "LTXRelightAlign requires opencv-python. Install via "
                "`pip install opencv-python` in the ComfyUI Python env."
            )
        import cv2

        if source.shape[0] != relight.shape[0]:
            # Broadcast the shorter one (typically source has 1 frame, relight has N)
            if source.shape[0] == 1:
                source = source.expand(relight.shape[0], -1, -1, -1)
            elif relight.shape[0] == 1:
                relight = relight.expand(source.shape[0], -1, -1, -1)
            else:
                raise ValueError(
                    f"Source and relight have different frame counts "
                    f"({source.shape[0]} vs {relight.shape[0]}) and neither is 1."
                )
        if source.shape[1:3] != relight.shape[1:3]:
            raise ValueError(
                f"Source/relight resolutions differ: {source.shape[1:3]} vs {relight.shape[1:3]}. "
                "Resize one to match before aligning."
            )

        n, h, w, c = relight.shape
        # Allocate explicitly: `relight` may be an expanded (stride-0) view.
        out_aligned = torch.empty((n, h, w, c), dtype=relight.dtype, device=relight.device)
        out_conf = torch.zeros((n, h, w), dtype=relight.dtype, device=relight.device)

        use_affine = mode in ("affine + flow", "affine only")
        use_flow = mode in ("affine + flow", "flow only")

        # In anchor mode every frame shares source frame 0; convert it once.
        anchor_src_u8 = _img_to_np_u8(source[0]) if anchor_first_frame else None
        anchor_src_gray = _to_gray_u8(anchor_src_u8) if anchor_first_frame else None

        for i in range(n):
            if anchor_first_frame:
                src_u8, src_gray = anchor_src_u8, anchor_src_gray
            else:
                src_u8 = _img_to_np_u8(source[i])
                src_gray = _to_gray_u8(src_u8)
            rel_u8 = _img_to_np_u8(relight[i])
            warped_u8 = rel_u8

            # ---- Layer 1: ECC affine on edges ----
            if use_affine:
                src_edges = _edge_map(src_gray, edge_method)
                rel_edges = _edge_map(_to_gray_u8(rel_u8), edge_method)
                M = _fit_affine_ecc(src_edges, rel_edges, warp_model=warp_model)
                M = _clamp_affine(M, float(max_shift_px), h, w)
                # findTransformECC(template=src, input=rel) returns the warp
                # mapping source coords -> relight coords. It must be applied
                # as an inverse map; without WARP_INVERSE_MAP, warpAffine
                # applies M's inverse and the drift is doubled, not removed.
                warped_u8 = cv2.warpAffine(
                    rel_u8, M, (w, h),
                    flags=cv2.INTER_LINEAR | cv2.WARP_INVERSE_MAP,
                    borderMode=cv2.BORDER_REFLECT,
                )

            # ---- Layer 2: census-transform residual flow ----
            if use_flow:
                flow = _refine_flow(
                    src_gray, _to_gray_u8(warped_u8), radius_px=int(flow_radius_px),
                )
                warped_u8 = _warp_image_by_flow(warped_u8, flow)

            # ---- Confidence map ----
            conf = _confidence_map(src_u8, warped_u8)
            out_aligned[i] = _np_to_img_t(warped_u8, relight)
            out_conf[i] = torch.from_numpy(conf).to(device=relight.device, dtype=relight.dtype)

        return (out_aligned, out_conf)
# ---------------------------------------------------------------------------
# NODE 2 — RelightDelta
# ---------------------------------------------------------------------------
class LTXRelightDelta:
    """Source-subtracted output: gives you a pure light-contribution layer
    suitable for screen/plus-comp in Nuke.

    Modes
    -----
    add   - (relight - source) * gain, clamped to 0..1 — "extra light" only
    diff  - signed (relight - source) * gain * 0.5 + 0.5 — 0.5 means "no change"
    ratio - log2(relight / source) * gain, mapped so -3..+3 stops -> 0..1
            (0.5 means "no change")

    The ``delta_matte`` measures the *magnitude* of the change. For ``add`` it
    is the Rec.709 luma of the delta. For ``diff``/``ratio`` it is twice the
    luma's distance from the 0.5 "no change" level. Unchanged pixels therefore
    score 0 in every mode, and values below ``matte_threshold`` are zeroed.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "source": ("IMAGE",),
                "aligned_relight": ("IMAGE",),
                "mode": (["add", "diff", "ratio"], {
                    "default": "add",
                    "tooltip": "add = light contribution (most common for VFX flash passes). diff = signed delta. ratio = log delta.",
                }),
                "gain": ("FLOAT", {
                    "default": 1.0, "min": 0.0, "max": 8.0, "step": 0.05,
                    "tooltip": "Multiplier on the delta before clamping. >1 boosts subtle changes.",
                }),
                "matte_threshold": ("FLOAT", {
                    "default": 0.05, "min": 0.0, "max": 1.0, "step": 0.005,
                    "tooltip": "Pixels with delta luma below this become 0 in the matte (kills noise floor).",
                }),
            },
        }

    RETURN_TYPES = ("IMAGE", "MASK")
    RETURN_NAMES = ("delta", "delta_matte")
    FUNCTION = "delta"
    CATEGORY = "LTX-Relight"

    def delta(self, source, aligned_relight, mode, gain, matte_threshold):
        """Compute the light-contribution image and its matte.

        Args:
            source: IMAGE batch [N,H,W,3].
            aligned_relight: IMAGE batch [N,H,W,3]. A 1-frame batch on either
                side is broadcast to the other's frame count.
            mode: "add", "diff" or "ratio" (any other value behaves as "ratio").
            gain: multiplier on the delta before mapping/clamping.
            matte_threshold: matte values below this become 0.

        Returns:
            (delta IMAGE [N,H,W,3] in 0..1, delta_matte MASK [N,H,W] in 0..1)

        Raises:
            ValueError: shapes differ in anything other than a broadcastable
                1-frame batch.
        """
        if source.shape != aligned_relight.shape:
            same_frame = source.shape[1:] == aligned_relight.shape[1:]
            if same_frame and source.shape[0] == 1:
                source = source.expand_as(aligned_relight)
            elif same_frame and aligned_relight.shape[0] == 1:
                aligned_relight = aligned_relight.expand_as(source)
            else:
                raise ValueError(
                    f"Shape mismatch: source {tuple(source.shape)} vs "
                    f"aligned_relight {tuple(aligned_relight.shape)}"
                )

        s = source.clamp(0.0, 1.0)
        r = aligned_relight.clamp(0.0, 1.0)

        if mode == "add":
            d = ((r - s) * gain).clamp(0.0, 1.0)
        elif mode == "diff":
            d = ((r - s) * gain * 0.5 + 0.5).clamp(0.0, 1.0)
        else:  # ratio
            # +1e-3 keeps black pixels from producing log(0) / division by 0.
            d = torch.log2((r + 1e-3) / (s + 1e-3)) * gain
            # Map -3..+3 EV stops -> 0..1 for visualization
            d = (d / 6.0 + 0.5).clamp(0.0, 1.0)

        # Matte from delta luma (Rec.709)
        luma = 0.2126 * d[..., 0] + 0.7152 * d[..., 1] + 0.0722 * d[..., 2]
        if mode == "add":
            # In "add", 0 already means "no extra light".
            strength = luma.clamp(0.0, 1.0)
        else:
            # diff/ratio encode "no change" as 0.5. Measure the distance from
            # it, otherwise untouched pixels get a ~0.5 matte that clears any
            # reasonable threshold.
            strength = ((luma - 0.5).abs() * 2.0).clamp(0.0, 1.0)
        matte = (strength >= matte_threshold).to(d.dtype) * strength
        return (d, matte)
# ---------------------------------------------------------------------------
# NODE 3 — RelightGuard
# ---------------------------------------------------------------------------
class LTXRelightGuard:
    """Composite aligned relight onto source, falling back to source where the
    confidence map says the alignment / identity is unreliable.

    Use this when you want a finished plate (not a comp element). Pixels with
    confidence above ``threshold`` ramp linearly toward pure relight, and
    pixels below it stay pure source.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "source": ("IMAGE",),
                "aligned_relight": ("IMAGE",),
                "confidence": ("MASK",),
                "threshold": ("FLOAT", {
                    "default": 0.4, "min": 0.0, "max": 1.0, "step": 0.01,
                    "tooltip": "Confidence below this becomes pure source. Above scales linearly to pure relight.",
                }),
                "feather_px": ("INT", {
                    "default": 8, "min": 0, "max": 64,
                    "tooltip": "Gaussian blur radius applied to the confidence mask. Softens hard cutoffs.",
                }),
            },
        }

    RETURN_TYPES = ("IMAGE", "MASK")
    RETURN_NAMES = ("composite", "blend_mask")
    FUNCTION = "guard"
    CATEGORY = "LTX-Relight"

    def guard(self, source, aligned_relight, confidence, threshold, feather_px):
        """Blend relight over source using a remapped, optionally feathered mask.

        Args:
            source: IMAGE batch [N,H,W,3]; a 1-frame batch is broadcast.
            aligned_relight: IMAGE batch [N,H,W,3].
            confidence: MASK [N,H,W] (or a bare [H,W], treated as one frame).
            threshold: confidence mapped to 0; 1.0 maps to 1 (linear between).
            feather_px: Gaussian blur radius for the remapped mask (0 = off).

        Returns:
            (composite IMAGE clamped to 0..1, blend_mask MASK [N,H,W]).
        """
        # Promote a bare [H,W] mask to [1,H,W] so the per-frame feather loop
        # iterates frames rather than rows.
        if confidence.dim() == 2:
            confidence = confidence.unsqueeze(0)

        if feather_px > 0 and not _HAS_CV2:
            log.warning("[RelightGuard] cv2 unavailable — skipping feather")
            feather_px = 0

        # Broadcast source if needed
        if source.shape[0] == 1 and aligned_relight.shape[0] > 1:
            source = source.expand_as(aligned_relight)

        # Remap confidence: below threshold -> 0, above -> 0..1
        c = confidence.clamp(0.0, 1.0)
        c = ((c - threshold) / max(1.0 - threshold, 1e-3)).clamp(0.0, 1.0)

        if feather_px > 0:
            c = self._feather(c, int(feather_px))

        # Per-pixel blend, broadcasting mask across RGB
        c3 = c.unsqueeze(-1)
        composite = aligned_relight * c3 + source * (1.0 - c3)
        return (composite.clamp(0.0, 1.0), c)

    @staticmethod
    def _feather(mask, radius):
        """Gaussian-blur each [H,W] frame of a [N,H,W] mask ((2r+1)^2 kernel)."""
        import cv2

        ksize = radius * 2 + 1
        out = torch.empty_like(mask)
        for i in range(mask.shape[0]):
            # Widen to float32 first: numpy has no bfloat16, and OpenCV's
            # half-float support is limited.
            arr = mask[i].detach().float().cpu().numpy()
            blurred = cv2.GaussianBlur(arr, (ksize, ksize), 0)
            out[i] = torch.from_numpy(blurred).to(device=out.device, dtype=out.dtype)
        return out
# ---------------------------------------------------------------------------
# Registration helpers (imported by __init__.py / nodes.py)
# ---------------------------------------------------------------------------
# Node id -> implementing class, merged into the pack's NODE_CLASS_MAPPINGS.
ALIGN_NODE_CLASS_MAPPINGS = dict(
    LTXRelightAlign=LTXRelightAlign,
    LTXRelightDelta=LTXRelightDelta,
    LTXRelightGuard=LTXRelightGuard,
)

# Node id -> human-readable title shown in the ComfyUI node menu.
ALIGN_NODE_DISPLAY_NAME_MAPPINGS = dict(
    LTXRelightAlign="LTX Relight - Align to Source",
    LTXRelightDelta="LTX Relight - Source-Subtracted Delta",
    LTXRelightGuard="LTX Relight - Confidence Composite",
)