@@ -262,6 +262,144 @@ def _flux2_compute_empirical_mu(image_seq_len, num_steps):
262262 return result [0 ], duration
263263
264264
265+ def tiled_hires_sample (latent_input , patched_model , config , positive_conditioning , negative_conditioning ,
266+ hires_steps , hires_denoise , tile_width , tile_height , mask_blur , tile_padding ,
267+ force_uniform , pixel_width , pixel_height ):
268+ """
269+ Run HiRes fix sampling in tiles to prevent OOM on large images.
270+ Splits the latent into overlapping tiles, samples each, then blends them back together.
271+
272+ Args:
273+ latent_input: Dict with "samples" key — the upscaled latent to denoise
274+ patched_model: Model for sampling
275+ config: Generation config (seed, cfg, sampler, scheduler)
276+ positive_conditioning, negative_conditioning: Conditioning tensors
277+ hires_steps: Number of sampling steps
278+ hires_denoise: Denoise strength
279+ tile_width, tile_height: Tile size in pixels (will be converted to latent space /8)
280+ mask_blur: Gaussian blur radius for tile seam blending (pixels)
281+ tile_padding: Extra context padding around each tile (pixels)
282+ force_uniform: If True, force all tiles to be the same size (may crop edges)
283+ pixel_width, pixel_height: Full image pixel dimensions (for logging)
284+
285+ Returns:
286+ dict: Result latent dict with "samples" key
287+ """
288+ import torch
289+
290+ samples = latent_input ["samples" ]
291+ # Convert pixel dimensions to latent space (8x smaller)
292+ # Handle both 4D [B, C, H, W] and 5D [B, C, T, H, W] latent formats (video VAEs add temporal dim)
293+ if samples .ndim == 5 :
294+ lat_h , lat_w = samples .shape [3 ], samples .shape [4 ]
295+ else :
296+ lat_h , lat_w = samples .shape [2 ], samples .shape [3 ]
297+ tw = tile_width // 8
298+ th = tile_height // 8
299+ pad = tile_padding // 8
300+ blur = max (1 , mask_blur // 8 ) # Blur in latent space
301+
302+ # Calculate tile grid
303+ def calc_tiles (total , tile_size , padding , uniform ):
304+ """Calculate tile start positions with overlap = 2 * padding."""
305+ if total <= tile_size :
306+ return [(0 , total )]
307+ stride = tile_size - 2 * padding
308+ if stride <= 0 :
309+ stride = tile_size // 2
310+ tiles = []
311+ pos = 0
312+ while pos < total :
313+ end = min (pos + tile_size , total )
314+ if uniform and end == total and end - pos < tile_size :
315+ # Shift last tile back to maintain uniform size
316+ pos = max (0 , total - tile_size )
317+ end = total
318+ tiles .append ((pos , end ))
319+ if end == total :
320+ break
321+ pos += stride
322+ return tiles
323+
324+ x_tiles = calc_tiles (lat_w , tw , pad , force_uniform )
325+ y_tiles = calc_tiles (lat_h , th , pad , force_uniform )
326+
327+ total_tiles = len (x_tiles ) * len (y_tiles )
328+ print (f"[GridTester] 🔍 Tiled HiRes sampling: { len (x_tiles )} x{ len (y_tiles )} = { total_tiles } tiles "
329+ f"(tile={ tile_width } x{ tile_height } px, padding={ tile_padding } px, blur={ mask_blur } px)" )
330+
331+ # Output accumulator with weighted blending
332+ result_samples = torch .zeros_like (samples )
333+ is_5d = samples .ndim == 5
334+
335+ # Weight map shape matches spatial dims only
336+ if is_5d :
337+ weight_map = torch .zeros (1 , 1 , 1 , lat_h , lat_w , device = samples .device )
338+ else :
339+ weight_map = torch .zeros (1 , 1 , lat_h , lat_w , device = samples .device )
340+
341+ tile_idx = 0
342+ for yi , (y_start , y_end ) in enumerate (y_tiles ):
343+ for xi , (x_start , x_end ) in enumerate (x_tiles ):
344+ tile_idx += 1
345+ # Extract tile from latent (handle both 4D and 5D)
346+ if is_5d :
347+ tile_latent = samples [:, :, :, y_start :y_end , x_start :x_end ].clone ()
348+ else :
349+ tile_latent = samples [:, :, y_start :y_end , x_start :x_end ].clone ()
350+
351+ print (f"[GridTester] 🔍 Tile { tile_idx } /{ total_tiles } : latent region [{ y_start } :{ y_end } , { x_start } :{ x_end } ]" )
352+
353+ # Run KSampler on this tile
354+ tile_result , _ = generate_image (
355+ patched_model , config .get ("seed" , 0 ), hires_steps , config .get ("cfg" , 7 ),
356+ config .get ("sampler" , "euler" ), config .get ("scheduler" , "normal" ),
357+ positive_conditioning , negative_conditioning ,
358+ {"samples" : tile_latent }, hires_denoise ,
359+ width = (x_end - x_start ) * 8 , height = (y_end - y_start ) * 8
360+ )
361+
362+ tile_out = tile_result ["samples" ]
363+ tile_h = y_end - y_start
364+ tile_w = x_end - x_start
365+
366+ # Create feathered weight mask for this tile (higher weight in center, fading at edges)
367+ mask = torch .ones (tile_h , tile_w , device = samples .device )
368+ if blur > 0 :
369+ # Feather edges: linear ramp over blur pixels
370+ for b in range (blur ):
371+ factor = (b + 1 ) / (blur + 1 )
372+ # Top edge
373+ if y_start > 0 and b < tile_h :
374+ mask [b , :] *= factor
375+ # Bottom edge
376+ if y_end < lat_h and b < tile_h :
377+ mask [tile_h - 1 - b , :] *= factor
378+ # Left edge
379+ if x_start > 0 and b < tile_w :
380+ mask [:, b ] *= factor
381+ # Right edge
382+ if x_end < lat_w and b < tile_w :
383+ mask [:, tile_w - 1 - b ] *= factor
384+
385+ # Accumulate weighted results (broadcast mask to match tensor dims)
386+ if is_5d :
387+ mask_shaped = mask .unsqueeze (0 ).unsqueeze (0 ).unsqueeze (0 ) # [1, 1, 1, H, W]
388+ result_samples [:, :, :, y_start :y_end , x_start :x_end ] += tile_out * mask_shaped
389+ weight_map [:, :, :, y_start :y_end , x_start :x_end ] += mask_shaped
390+ else :
391+ mask_shaped = mask .unsqueeze (0 ).unsqueeze (0 ) # [1, 1, H, W]
392+ result_samples [:, :, y_start :y_end , x_start :x_end ] += tile_out * mask_shaped
393+ weight_map [:, :, y_start :y_end , x_start :x_end ] += mask_shaped
394+
395+ # Normalize by weights to blend overlapping regions
396+ weight_map = torch .clamp (weight_map , min = 1e-6 )
397+ result_samples = result_samples / weight_map
398+
399+ print (f"[GridTester] 🔍 Tiled HiRes sampling complete ({ total_tiles } tiles)" )
400+ return {"samples" : result_samples }
401+
402+
265403def upscale_image (result_latent , vae , patched_model , upscaling_config , config , positive_conditioning , negative_conditioning , width , height ):
266404 """
267405 Apply upscaling to a generated latent based on upscaling settings.
@@ -289,8 +427,18 @@ def upscale_image(result_latent, vae, patched_model, upscaling_config, config, p
289427 hires_steps = int (upscaling_config .get ("hires_steps" , 0 )) or config .get ("steps" , 20 )
290428 tiled_vae = upscaling_config .get ("tiled_vae" , False )
291429 tile_size = int (upscaling_config .get ("tile_size" , 512 ))
430+ tile_overlap = int (upscaling_config .get ("tile_overlap" , 64 ))
431+ temporal_size = int (upscaling_config .get ("temporal_size" , 512 ))
432+ temporal_overlap = int (upscaling_config .get ("temporal_overlap" , 64 ))
292433 upscale_model_name = upscaling_config .get ("upscale_model" , "" )
293434 upscale_size = float (upscaling_config .get ("upscale_size" , 2.0 ))
435+ resize_method = upscaling_config .get ("resize_method" , "bilinear" )
436+ hires_tiled_sampling = upscaling_config .get ("hires_tiled_sampling" , False )
437+ hires_tile_width = int (upscaling_config .get ("hires_tile_width" , 512 ))
438+ hires_tile_height = int (upscaling_config .get ("hires_tile_height" , 512 ))
439+ hires_mask_blur = int (upscaling_config .get ("hires_mask_blur" , 8 ))
440+ hires_tile_padding = int (upscaling_config .get ("hires_tile_padding" , 32 ))
441+ hires_force_uniform_tiles = upscaling_config .get ("hires_force_uniform_tiles" , False )
294442
295443 new_w = int (width * upscale_ratio )
296444 new_h = int (height * upscale_ratio )
@@ -304,16 +452,25 @@ def upscale_image(result_latent, vae, patched_model, upscaling_config, config, p
304452 latent_samples = result_latent ["samples" ]
305453 # Latent space is 8x smaller than pixel space
306454 upscaled_latent = comfy .utils .common_upscale (
307- latent_samples , new_w // 8 , new_h // 8 , "bilinear" , "disabled"
455+ latent_samples , new_w // 8 , new_h // 8 , resize_method , "disabled"
308456 )
309457
310- hires_latent , hires_duration = generate_image (
311- patched_model , config .get ("seed" , 0 ), hires_steps , config .get ("cfg" , 7 ),
312- config .get ("sampler" , "euler" ), config .get ("scheduler" , "normal" ),
313- positive_conditioning , negative_conditioning ,
314- {"samples" : upscaled_latent }, hires_denoise ,
315- width = new_w , height = new_h
316- )
458+ if hires_tiled_sampling :
459+ hires_latent = tiled_hires_sample (
460+ {"samples" : upscaled_latent }, patched_model , config ,
461+ positive_conditioning , negative_conditioning ,
462+ hires_steps , hires_denoise ,
463+ hires_tile_width , hires_tile_height , hires_mask_blur , hires_tile_padding ,
464+ hires_force_uniform_tiles , new_w , new_h
465+ )
466+ else :
467+ hires_latent , hires_duration = generate_image (
468+ patched_model , config .get ("seed" , 0 ), hires_steps , config .get ("cfg" , 7 ),
469+ config .get ("sampler" , "euler" ), config .get ("scheduler" , "normal" ),
470+ positive_conditioning , negative_conditioning ,
471+ {"samples" : upscaled_latent }, hires_denoise ,
472+ width = new_w , height = new_h
473+ )
317474
318475 duration = round (time .time () - t0 , 3 )
319476 print (f"[GridTester] 🔍 HiRes fix complete in { duration } s → { new_w } x{ new_h } " )
@@ -329,6 +486,13 @@ def upscale_image(result_latent, vae, patched_model, upscaling_config, config, p
329486 from comfy_extras .nodes_post_processing import ImageScaleToTotalPixels
330487 vae .first_stage_model .tile_sample_min_size = tile_size
331488 vae .first_stage_model .tile_latent_min_size = tile_size // 8
489+ vae .first_stage_model .tile_overlap_factor = tile_overlap / tile_size if tile_size > 0 else 0.125
490+ if hasattr (vae .first_stage_model , 'tile_sample_min_size_temporal' ):
491+ vae .first_stage_model .tile_sample_min_size_temporal = temporal_size
492+ if hasattr (vae .first_stage_model , 'tile_latent_min_size_temporal' ):
493+ vae .first_stage_model .tile_latent_min_size_temporal = temporal_size // 8
494+ if hasattr (vae .first_stage_model , 'tile_overlap_factor_temporal' ):
495+ vae .first_stage_model .tile_overlap_factor_temporal = temporal_overlap / temporal_size if temporal_size > 0 else 0.125
332496 pil_image = decode_latent_with_vae (vae , result_latent ["samples" ])
333497
334498 img_np = np .array (pil_image ).astype (np .float32 ) / 255.0
@@ -347,7 +511,7 @@ def upscale_image(result_latent, vae, patched_model, upscaling_config, config, p
347511 import comfy .utils
348512 # Resize from model's native output to user-specified upscale_size
349513 upscaled_tensor = upscaled_tensor .permute (0 , 3 , 1 , 2 ) # NHWC → NCHW
350- upscaled_tensor = comfy .utils .common_upscale (upscaled_tensor , target_w , target_h , "bilinear" , "disabled" )
514+ upscaled_tensor = comfy .utils .common_upscale (upscaled_tensor , target_w , target_h , resize_method , "disabled" )
351515 upscaled_tensor = upscaled_tensor .permute (0 , 2 , 3 , 1 ) # NCHW → NHWC
352516
353517 up_np = upscaled_tensor [0 ].cpu ().float ().numpy ()
@@ -368,6 +532,13 @@ def upscale_image(result_latent, vae, patched_model, upscaling_config, config, p
368532 if tiled_vae :
369533 vae .first_stage_model .tile_sample_min_size = tile_size
370534 vae .first_stage_model .tile_latent_min_size = tile_size // 8
535+ vae .first_stage_model .tile_overlap_factor = tile_overlap / tile_size if tile_size > 0 else 0.125
536+ if hasattr (vae .first_stage_model , 'tile_sample_min_size_temporal' ):
537+ vae .first_stage_model .tile_sample_min_size_temporal = temporal_size
538+ if hasattr (vae .first_stage_model , 'tile_latent_min_size_temporal' ):
539+ vae .first_stage_model .tile_latent_min_size_temporal = temporal_size // 8
540+ if hasattr (vae .first_stage_model , 'tile_overlap_factor_temporal' ):
541+ vae .first_stage_model .tile_overlap_factor_temporal = temporal_overlap / temporal_size if temporal_size > 0 else 0.125
371542 pil_image = decode_latent_with_vae (vae , result_latent ["samples" ])
372543
373544 img_np = np .array (pil_image ).astype (np .float32 ) / 255.0
@@ -385,21 +556,30 @@ def upscale_image(result_latent, vae, patched_model, upscaling_config, config, p
385556 if abs (actual_w - target_w ) > 4 or abs (actual_h - target_h ) > 4 :
386557 import comfy .utils
387558 upscaled_tensor = upscaled_tensor .permute (0 , 3 , 1 , 2 ) # NHWC → NCHW
388- upscaled_tensor = comfy .utils .common_upscale (upscaled_tensor , target_w , target_h , "bilinear" , "disabled" )
559+ upscaled_tensor = comfy .utils .common_upscale (upscaled_tensor , target_w , target_h , resize_method , "disabled" )
389560 upscaled_tensor = upscaled_tensor .permute (0 , 2 , 3 , 1 ) # NCHW → NHWC
390561
391562 up_h , up_w = upscaled_tensor .shape [1 ], upscaled_tensor .shape [2 ]
392563
393564 # Encode back to latent space for HiRes fix
394565 encoded_latent = vae .encode (upscaled_tensor [:, :, :, :3 ])
395566
396- hires_latent , hires_duration = generate_image (
397- patched_model , config .get ("seed" , 0 ), hires_steps , config .get ("cfg" , 7 ),
398- config .get ("sampler" , "euler" ), config .get ("scheduler" , "normal" ),
399- positive_conditioning , negative_conditioning ,
400- {"samples" : encoded_latent }, hires_denoise ,
401- width = up_w , height = up_h
402- )
567+ if hires_tiled_sampling :
568+ hires_latent = tiled_hires_sample (
569+ {"samples" : encoded_latent }, patched_model , config ,
570+ positive_conditioning , negative_conditioning ,
571+ hires_steps , hires_denoise ,
572+ hires_tile_width , hires_tile_height , hires_mask_blur , hires_tile_padding ,
573+ hires_force_uniform_tiles , up_w , up_h
574+ )
575+ else :
576+ hires_latent , hires_duration = generate_image (
577+ patched_model , config .get ("seed" , 0 ), hires_steps , config .get ("cfg" , 7 ),
578+ config .get ("sampler" , "euler" ), config .get ("scheduler" , "normal" ),
579+ positive_conditioning , negative_conditioning ,
580+ {"samples" : encoded_latent }, hires_denoise ,
581+ width = up_w , height = up_h
582+ )
403583
404584 duration = round (time .time () - t0 , 3 )
405585 print (f"[GridTester] 🔍 Model+HiRes upscale complete in { duration } s → { up_w } x{ up_h } " )
0 commit comments