Skip to content

Commit 89b6ce4

Browse files
JasonHokuclaude
andcommitted
feat: upscale pipelines system + save pre-upscale + HiRes prompt adjust
Replace flat upscale configs with pipeline/steps/repeat architecture: - Pipelines run independently from base image - Steps chain sequentially within a pipeline - Steps support repeat count for iterative refinement Add "Save Pre-Upscaled Output" option to also keep the base image. Add "Adjust Prompt During HiRes Fix" with prepend/append/replace behaviors. HiRes-adjusted prompts are batch pre-encoded for efficiency. Results written to manifest metadata and exposed as label overlays. Fix fmtVal scoping error in buildLabelOverlay (was referenced outside the upscale block where it was defined). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 9ede82f commit 89b6ce4

12 files changed

Lines changed: 1377 additions & 324 deletions

config_builder_node.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -707,7 +707,24 @@ def generate_config(
707707
session_settings = {}
708708
upscaling_data = state.get("upscaling", {})
709709
if upscaling_data and upscaling_data.get("enabled", False):
710-
session_settings["upscaling"] = upscaling_data
710+
# Filter out inactive pipelines and inactive steps within pipelines
711+
pipelines = upscaling_data.get("pipelines", [])
712+
active_pipelines = []
713+
for p in pipelines:
714+
if p.get("active", True) is False:
715+
continue
716+
active_steps = [s for s in p.get("steps", []) if s.get("active", True) is not False]
717+
if active_steps:
718+
active_pipelines.append({**p, "steps": active_steps})
719+
if active_pipelines:
720+
session_settings["upscaling"] = {
721+
"enabled": True,
722+
"save_pre_upscale": upscaling_data.get("save_pre_upscale", False),
723+
"hires_prompt_adjust": upscaling_data.get("hires_prompt_adjust", False),
724+
"hires_prompt_behavior": upscaling_data.get("hires_prompt_behavior", "append_end"),
725+
"hires_prompt_text": upscaling_data.get("hires_prompt_text", ""),
726+
"pipelines": active_pipelines
727+
}
711728
cooldown_data = state.get("cooldown", {})
712729
if cooldown_data and cooldown_data.get("enabled", False):
713730
session_settings["cooldown"] = cooldown_data

config_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -613,6 +613,12 @@ def to_list(x):
613613
"model_prompt_suffix": model_prompt_suffix.strip()
614614
})
615615

616+
# Apply full_run_seed override (per-config seed that overrides node seed)
617+
full_run_seed = entry.get("full_run_seed", 0)
618+
if full_run_seed and int(full_run_seed) > 0:
619+
for c in base_combos:
620+
c["seed"] = int(full_run_seed)
621+
616622
# Apply base seed and extra seeds
617623
for c in base_combos:
618624
expanded.append(c)

generation_orchestrator.py

Lines changed: 282 additions & 91 deletions
Large diffs are not rendered by default.

image_generation.py

Lines changed: 197 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,144 @@ def _flux2_compute_empirical_mu(image_seq_len, num_steps):
262262
return result[0], duration
263263

264264

265+
def tiled_hires_sample(latent_input, patched_model, config, positive_conditioning, negative_conditioning,
266+
hires_steps, hires_denoise, tile_width, tile_height, mask_blur, tile_padding,
267+
force_uniform, pixel_width, pixel_height):
268+
"""
269+
Run HiRes fix sampling in tiles to prevent OOM on large images.
270+
Splits the latent into overlapping tiles, samples each, then blends them back together.
271+
272+
Args:
273+
latent_input: Dict with "samples" key — the upscaled latent to denoise
274+
patched_model: Model for sampling
275+
config: Generation config (seed, cfg, sampler, scheduler)
276+
positive_conditioning, negative_conditioning: Conditioning tensors
277+
hires_steps: Number of sampling steps
278+
hires_denoise: Denoise strength
279+
tile_width, tile_height: Tile size in pixels (will be converted to latent space /8)
280+
mask_blur: Gaussian blur radius for tile seam blending (pixels)
281+
tile_padding: Extra context padding around each tile (pixels)
282+
force_uniform: If True, force all tiles to be the same size (may crop edges)
283+
pixel_width, pixel_height: Full image pixel dimensions (for logging)
284+
285+
Returns:
286+
dict: Result latent dict with "samples" key
287+
"""
288+
import torch
289+
290+
samples = latent_input["samples"]
291+
# Convert pixel dimensions to latent space (8x smaller)
292+
# Handle both 4D [B, C, H, W] and 5D [B, C, T, H, W] latent formats (video VAEs add temporal dim)
293+
if samples.ndim == 5:
294+
lat_h, lat_w = samples.shape[3], samples.shape[4]
295+
else:
296+
lat_h, lat_w = samples.shape[2], samples.shape[3]
297+
tw = tile_width // 8
298+
th = tile_height // 8
299+
pad = tile_padding // 8
300+
blur = max(1, mask_blur // 8) # Blur in latent space
301+
302+
# Calculate tile grid
303+
def calc_tiles(total, tile_size, padding, uniform):
304+
"""Calculate tile start positions with overlap = 2 * padding."""
305+
if total <= tile_size:
306+
return [(0, total)]
307+
stride = tile_size - 2 * padding
308+
if stride <= 0:
309+
stride = tile_size // 2
310+
tiles = []
311+
pos = 0
312+
while pos < total:
313+
end = min(pos + tile_size, total)
314+
if uniform and end == total and end - pos < tile_size:
315+
# Shift last tile back to maintain uniform size
316+
pos = max(0, total - tile_size)
317+
end = total
318+
tiles.append((pos, end))
319+
if end == total:
320+
break
321+
pos += stride
322+
return tiles
323+
324+
x_tiles = calc_tiles(lat_w, tw, pad, force_uniform)
325+
y_tiles = calc_tiles(lat_h, th, pad, force_uniform)
326+
327+
total_tiles = len(x_tiles) * len(y_tiles)
328+
print(f"[GridTester] 🔍 Tiled HiRes sampling: {len(x_tiles)}x{len(y_tiles)} = {total_tiles} tiles "
329+
f"(tile={tile_width}x{tile_height}px, padding={tile_padding}px, blur={mask_blur}px)")
330+
331+
# Output accumulator with weighted blending
332+
result_samples = torch.zeros_like(samples)
333+
is_5d = samples.ndim == 5
334+
335+
# Weight map shape matches spatial dims only
336+
if is_5d:
337+
weight_map = torch.zeros(1, 1, 1, lat_h, lat_w, device=samples.device)
338+
else:
339+
weight_map = torch.zeros(1, 1, lat_h, lat_w, device=samples.device)
340+
341+
tile_idx = 0
342+
for yi, (y_start, y_end) in enumerate(y_tiles):
343+
for xi, (x_start, x_end) in enumerate(x_tiles):
344+
tile_idx += 1
345+
# Extract tile from latent (handle both 4D and 5D)
346+
if is_5d:
347+
tile_latent = samples[:, :, :, y_start:y_end, x_start:x_end].clone()
348+
else:
349+
tile_latent = samples[:, :, y_start:y_end, x_start:x_end].clone()
350+
351+
print(f"[GridTester] 🔍 Tile {tile_idx}/{total_tiles}: latent region [{y_start}:{y_end}, {x_start}:{x_end}]")
352+
353+
# Run KSampler on this tile
354+
tile_result, _ = generate_image(
355+
patched_model, config.get("seed", 0), hires_steps, config.get("cfg", 7),
356+
config.get("sampler", "euler"), config.get("scheduler", "normal"),
357+
positive_conditioning, negative_conditioning,
358+
{"samples": tile_latent}, hires_denoise,
359+
width=(x_end - x_start) * 8, height=(y_end - y_start) * 8
360+
)
361+
362+
tile_out = tile_result["samples"]
363+
tile_h = y_end - y_start
364+
tile_w = x_end - x_start
365+
366+
# Create feathered weight mask for this tile (higher weight in center, fading at edges)
367+
mask = torch.ones(tile_h, tile_w, device=samples.device)
368+
if blur > 0:
369+
# Feather edges: linear ramp over blur pixels
370+
for b in range(blur):
371+
factor = (b + 1) / (blur + 1)
372+
# Top edge
373+
if y_start > 0 and b < tile_h:
374+
mask[b, :] *= factor
375+
# Bottom edge
376+
if y_end < lat_h and b < tile_h:
377+
mask[tile_h - 1 - b, :] *= factor
378+
# Left edge
379+
if x_start > 0 and b < tile_w:
380+
mask[:, b] *= factor
381+
# Right edge
382+
if x_end < lat_w and b < tile_w:
383+
mask[:, tile_w - 1 - b] *= factor
384+
385+
# Accumulate weighted results (broadcast mask to match tensor dims)
386+
if is_5d:
387+
mask_shaped = mask.unsqueeze(0).unsqueeze(0).unsqueeze(0) # [1, 1, 1, H, W]
388+
result_samples[:, :, :, y_start:y_end, x_start:x_end] += tile_out * mask_shaped
389+
weight_map[:, :, :, y_start:y_end, x_start:x_end] += mask_shaped
390+
else:
391+
mask_shaped = mask.unsqueeze(0).unsqueeze(0) # [1, 1, H, W]
392+
result_samples[:, :, y_start:y_end, x_start:x_end] += tile_out * mask_shaped
393+
weight_map[:, :, y_start:y_end, x_start:x_end] += mask_shaped
394+
395+
# Normalize by weights to blend overlapping regions
396+
weight_map = torch.clamp(weight_map, min=1e-6)
397+
result_samples = result_samples / weight_map
398+
399+
print(f"[GridTester] 🔍 Tiled HiRes sampling complete ({total_tiles} tiles)")
400+
return {"samples": result_samples}
401+
402+
265403
def upscale_image(result_latent, vae, patched_model, upscaling_config, config, positive_conditioning, negative_conditioning, width, height):
266404
"""
267405
Apply upscaling to a generated latent based on upscaling settings.
@@ -289,8 +427,18 @@ def upscale_image(result_latent, vae, patched_model, upscaling_config, config, p
289427
hires_steps = int(upscaling_config.get("hires_steps", 0)) or config.get("steps", 20)
290428
tiled_vae = upscaling_config.get("tiled_vae", False)
291429
tile_size = int(upscaling_config.get("tile_size", 512))
430+
tile_overlap = int(upscaling_config.get("tile_overlap", 64))
431+
temporal_size = int(upscaling_config.get("temporal_size", 512))
432+
temporal_overlap = int(upscaling_config.get("temporal_overlap", 64))
292433
upscale_model_name = upscaling_config.get("upscale_model", "")
293434
upscale_size = float(upscaling_config.get("upscale_size", 2.0))
435+
resize_method = upscaling_config.get("resize_method", "bilinear")
436+
hires_tiled_sampling = upscaling_config.get("hires_tiled_sampling", False)
437+
hires_tile_width = int(upscaling_config.get("hires_tile_width", 512))
438+
hires_tile_height = int(upscaling_config.get("hires_tile_height", 512))
439+
hires_mask_blur = int(upscaling_config.get("hires_mask_blur", 8))
440+
hires_tile_padding = int(upscaling_config.get("hires_tile_padding", 32))
441+
hires_force_uniform_tiles = upscaling_config.get("hires_force_uniform_tiles", False)
294442

295443
new_w = int(width * upscale_ratio)
296444
new_h = int(height * upscale_ratio)
@@ -304,16 +452,25 @@ def upscale_image(result_latent, vae, patched_model, upscaling_config, config, p
304452
latent_samples = result_latent["samples"]
305453
# Latent space is 8x smaller than pixel space
306454
upscaled_latent = comfy.utils.common_upscale(
307-
latent_samples, new_w // 8, new_h // 8, "bilinear", "disabled"
455+
latent_samples, new_w // 8, new_h // 8, resize_method, "disabled"
308456
)
309457

310-
hires_latent, hires_duration = generate_image(
311-
patched_model, config.get("seed", 0), hires_steps, config.get("cfg", 7),
312-
config.get("sampler", "euler"), config.get("scheduler", "normal"),
313-
positive_conditioning, negative_conditioning,
314-
{"samples": upscaled_latent}, hires_denoise,
315-
width=new_w, height=new_h
316-
)
458+
if hires_tiled_sampling:
459+
hires_latent = tiled_hires_sample(
460+
{"samples": upscaled_latent}, patched_model, config,
461+
positive_conditioning, negative_conditioning,
462+
hires_steps, hires_denoise,
463+
hires_tile_width, hires_tile_height, hires_mask_blur, hires_tile_padding,
464+
hires_force_uniform_tiles, new_w, new_h
465+
)
466+
else:
467+
hires_latent, hires_duration = generate_image(
468+
patched_model, config.get("seed", 0), hires_steps, config.get("cfg", 7),
469+
config.get("sampler", "euler"), config.get("scheduler", "normal"),
470+
positive_conditioning, negative_conditioning,
471+
{"samples": upscaled_latent}, hires_denoise,
472+
width=new_w, height=new_h
473+
)
317474

318475
duration = round(time.time() - t0, 3)
319476
print(f"[GridTester] 🔍 HiRes fix complete in {duration}s → {new_w}x{new_h}")
@@ -329,6 +486,13 @@ def upscale_image(result_latent, vae, patched_model, upscaling_config, config, p
329486
from comfy_extras.nodes_post_processing import ImageScaleToTotalPixels
330487
vae.first_stage_model.tile_sample_min_size = tile_size
331488
vae.first_stage_model.tile_latent_min_size = tile_size // 8
489+
vae.first_stage_model.tile_overlap_factor = tile_overlap / tile_size if tile_size > 0 else 0.125
490+
if hasattr(vae.first_stage_model, 'tile_sample_min_size_temporal'):
491+
vae.first_stage_model.tile_sample_min_size_temporal = temporal_size
492+
if hasattr(vae.first_stage_model, 'tile_latent_min_size_temporal'):
493+
vae.first_stage_model.tile_latent_min_size_temporal = temporal_size // 8
494+
if hasattr(vae.first_stage_model, 'tile_overlap_factor_temporal'):
495+
vae.first_stage_model.tile_overlap_factor_temporal = temporal_overlap / temporal_size if temporal_size > 0 else 0.125
332496
pil_image = decode_latent_with_vae(vae, result_latent["samples"])
333497

334498
img_np = np.array(pil_image).astype(np.float32) / 255.0
@@ -347,7 +511,7 @@ def upscale_image(result_latent, vae, patched_model, upscaling_config, config, p
347511
import comfy.utils
348512
# Resize from model's native output to user-specified upscale_size
349513
upscaled_tensor = upscaled_tensor.permute(0, 3, 1, 2) # NHWC → NCHW
350-
upscaled_tensor = comfy.utils.common_upscale(upscaled_tensor, target_w, target_h, "bilinear", "disabled")
514+
upscaled_tensor = comfy.utils.common_upscale(upscaled_tensor, target_w, target_h, resize_method, "disabled")
351515
upscaled_tensor = upscaled_tensor.permute(0, 2, 3, 1) # NCHW → NHWC
352516

353517
up_np = upscaled_tensor[0].cpu().float().numpy()
@@ -368,6 +532,13 @@ def upscale_image(result_latent, vae, patched_model, upscaling_config, config, p
368532
if tiled_vae:
369533
vae.first_stage_model.tile_sample_min_size = tile_size
370534
vae.first_stage_model.tile_latent_min_size = tile_size // 8
535+
vae.first_stage_model.tile_overlap_factor = tile_overlap / tile_size if tile_size > 0 else 0.125
536+
if hasattr(vae.first_stage_model, 'tile_sample_min_size_temporal'):
537+
vae.first_stage_model.tile_sample_min_size_temporal = temporal_size
538+
if hasattr(vae.first_stage_model, 'tile_latent_min_size_temporal'):
539+
vae.first_stage_model.tile_latent_min_size_temporal = temporal_size // 8
540+
if hasattr(vae.first_stage_model, 'tile_overlap_factor_temporal'):
541+
vae.first_stage_model.tile_overlap_factor_temporal = temporal_overlap / temporal_size if temporal_size > 0 else 0.125
371542
pil_image = decode_latent_with_vae(vae, result_latent["samples"])
372543

373544
img_np = np.array(pil_image).astype(np.float32) / 255.0
@@ -385,21 +556,30 @@ def upscale_image(result_latent, vae, patched_model, upscaling_config, config, p
385556
if abs(actual_w - target_w) > 4 or abs(actual_h - target_h) > 4:
386557
import comfy.utils
387558
upscaled_tensor = upscaled_tensor.permute(0, 3, 1, 2) # NHWC → NCHW
388-
upscaled_tensor = comfy.utils.common_upscale(upscaled_tensor, target_w, target_h, "bilinear", "disabled")
559+
upscaled_tensor = comfy.utils.common_upscale(upscaled_tensor, target_w, target_h, resize_method, "disabled")
389560
upscaled_tensor = upscaled_tensor.permute(0, 2, 3, 1) # NCHW → NHWC
390561

391562
up_h, up_w = upscaled_tensor.shape[1], upscaled_tensor.shape[2]
392563

393564
# Encode back to latent space for HiRes fix
394565
encoded_latent = vae.encode(upscaled_tensor[:, :, :, :3])
395566

396-
hires_latent, hires_duration = generate_image(
397-
patched_model, config.get("seed", 0), hires_steps, config.get("cfg", 7),
398-
config.get("sampler", "euler"), config.get("scheduler", "normal"),
399-
positive_conditioning, negative_conditioning,
400-
{"samples": encoded_latent}, hires_denoise,
401-
width=up_w, height=up_h
402-
)
567+
if hires_tiled_sampling:
568+
hires_latent = tiled_hires_sample(
569+
{"samples": encoded_latent}, patched_model, config,
570+
positive_conditioning, negative_conditioning,
571+
hires_steps, hires_denoise,
572+
hires_tile_width, hires_tile_height, hires_mask_blur, hires_tile_padding,
573+
hires_force_uniform_tiles, up_w, up_h
574+
)
575+
else:
576+
hires_latent, hires_duration = generate_image(
577+
patched_model, config.get("seed", 0), hires_steps, config.get("cfg", 7),
578+
config.get("sampler", "euler"), config.get("scheduler", "normal"),
579+
positive_conditioning, negative_conditioning,
580+
{"samples": encoded_latent}, hires_denoise,
581+
width=up_w, height=up_h
582+
)
403583

404584
duration = round(time.time() - t0, 3)
405585
print(f"[GridTester] 🔍 Model+HiRes upscale complete in {duration}s → {up_w}x{up_h}")

resources/logic_state.js

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,17 @@ var labelMode = labelMode || {
9797
cfg: false,
9898
steps: false,
9999
seed: false,
100-
denoise: false
100+
denoise: false,
101+
upscale: false,
102+
upscaleMode: false,
103+
upscaleModel: false,
104+
upscaleRatio: false,
105+
upscaleDenoise: false,
106+
upscaleResizeMethod: false,
107+
upscaleHiresSteps: false,
108+
upscaleTiling: false,
109+
hiresPromptBehavior: false,
110+
hiresPromptText: false
101111
}
102112
};
103113

0 commit comments

Comments
 (0)