@@ -431,6 +431,60 @@ def create_coordinate_grid(size, start=None, device=None):
431431 grids = torch .meshgrid (axes , indexing = "ij" )
432432 return torch .stack (grids , dim = - 1 )
433433
def _get_i2i_denoise_strength(self, input_info):
    """Resolve and validate the i2i denoise strength.

    The ``i2i_denoise_strength`` attribute of ``input_info`` takes
    precedence; when it is missing or ``None`` the
    ``"i2i_denoise_strength"`` entry of ``self.config`` is used instead.

    Args:
        input_info: Object that may carry an ``i2i_denoise_strength``
            attribute.

    Returns:
        The strength converted to ``float``, or ``None`` when neither
        source provides a value.

    Raises:
        ValueError: If the value is NaN or lies outside ``[0.0, 1.0]``
            (``float()`` itself also raises ``ValueError`` for
            unparsable strings).
    """
    strength = getattr(input_info, "i2i_denoise_strength", None)
    if strength is None:
        strength = self.config.get("i2i_denoise_strength")
    if strength is None:
        return None

    strength = float(strength)
    # A chained comparison is False for NaN, so NaN is rejected here as
    # well; `strength < 0.0 or strength > 1.0` would let it through.
    if not 0.0 <= strength <= 1.0:
        raise ValueError(f"The value of i2i_denoise_strength should be in [0.0, 1.0] but is { strength } ")
    return strength
444+
def _get_single_i2i_image_latents(self, input_info):
    """Return the ``"image_latents"`` entry of the only encoded input image.

    Args:
        input_info: Object expected to expose ``image_encoder_output``, a
            sequence holding exactly one mapping with an
            ``"image_latents"`` key.

    Returns:
        The value stored under ``"image_latents"`` of the single entry.

    Raises:
        ValueError: If ``image_encoder_output`` is missing or empty, or
            if it holds more than one entry.
    """
    encoded_images = getattr(input_info, "image_encoder_output", None)
    if not encoded_images:
        raise ValueError("z-image i2i requires exactly one input image with VAE image latents.")

    image_count = len(encoded_images)
    if image_count != 1:
        raise ValueError(f"z-image i2i currently supports single-image editing only, got {image_count} images.")

    first_entry = encoded_images[0]
    return first_entry["image_latents"]
452+
def get_timesteps(self, num_inference_steps, strength):
    """Select the tail of ``self.timesteps`` that ``strength`` asks for.

    The number of steps to run is ``round(num_inference_steps * strength)``,
    capped at ``num_inference_steps``. The leading part of the schedule is
    dropped and, when the scheduler exposes ``set_begin_index``, it is told
    where the retained part starts.

    Args:
        num_inference_steps: Length of the full schedule, in steps.
        strength: Fraction of the schedule to run.

    Returns:
        A ``(timesteps, target_steps)`` tuple: the retained slice of
        ``self.timesteps`` and the number of steps it represents.

    Raises:
        ValueError: If the rounded step count is below 1.
    """
    target_steps = round(num_inference_steps * strength)
    if target_steps < 1:
        raise ValueError(
            "i2i_denoise_strength results in 0 denoising steps: "
            f"round(infer_steps * i2i_denoise_strength)=round({ num_inference_steps } * { strength } )={ target_steps } ; "
            "please increase it to run at least 1 step."
        )

    # Cap at the schedule length. Without this a strength above 1.0 makes
    # t_start negative, which slices the schedule from its end, hands a
    # negative begin index to the scheduler and reports more steps than
    # there are timesteps.
    target_steps = min(target_steps, num_inference_steps)
    t_start = num_inference_steps - target_steps

    # Each inference step spans `scheduler.order` entries of the schedule.
    begin_index = t_start * self.scheduler.order
    timesteps = self.timesteps[begin_index:]
    if hasattr(self.scheduler, "set_begin_index"):
        self.scheduler.set_begin_index(begin_index)
    return timesteps, target_steps
466+
def _resize_i2i_image_latents(self, image_latents, target_height, target_width, target_channels):
    """Validate image latents and bring them to the target spatial size.

    Args:
        image_latents: 4-D latents; the error text describes the layout
            as ``[B, C, H, W]``.
        target_height: Required size of the second-to-last dimension.
        target_width: Required size of the last dimension.
        target_channels: Required size of dimension 1.

    Returns:
        ``image_latents`` unchanged when its last two dimensions already
        match, otherwise a bilinearly interpolated copy of that size.

    Raises:
        ValueError: If the input is not 4-D or its channel count differs
            from ``target_channels``.
    """
    if image_latents.ndim != 4:
        raise ValueError(f"Expected z-image i2i image latents with shape [B, C, H, W], got {tuple(image_latents.shape)} ")

    source_channels = image_latents.shape[1]
    if source_channels != target_channels:
        raise ValueError(f"z-image i2i image latent channels {source_channels} do not match target channels {target_channels} .")

    target_size = (target_height, target_width)
    if image_latents.shape[-2:] == target_size:
        # Already the right size: hand back the very same object.
        return image_latents
    return F.interpolate(image_latents, size=target_size, mode="bilinear", align_corners=False)
475+
def prepare_i2i_denoise_strength_latents(self, input_info):
    """Replace ``self.latents`` with the scheduler's noised input-image latents.

    The single input image's latents are moved to ``AI_DEVICE`` with
    ``self.dtype``, resized to the shape of ``self.latents``, and passed to
    ``self.scheduler.scale_noise`` together with the first entry of
    ``self.timesteps``; the current ``self.latents`` is passed as the noise
    argument. The result is stored back on ``self.latents``.

    Args:
        input_info: Request object forwarded to
            ``_get_single_i2i_image_latents``.

    Raises:
        ValueError: If ``self.latents`` has a leading dimension other than
            1, or if a helper rejects the input image latents.
    """
    source_latents = self._get_single_i2i_image_latents(input_info)
    source_latents = source_latents.to(device=AI_DEVICE, dtype=self.dtype)

    output_batch = self.latents.shape[0]
    if output_batch != 1:
        raise ValueError(f"z-image i2i currently supports single-image single-output editing only, got output latent batch {output_batch} .")

    _, channels, height, width = self.latents.shape
    source_latents = self._resize_i2i_image_latents(source_latents, height, width, channels)

    # Only the first timestep of the schedule is handed to scale_noise.
    # NOTE(review): presumably self.timesteps has already been trimmed by
    # the denoise strength at this point - confirm against the caller.
    first_timestep = self.timesteps[:1]
    self.latents = self.scheduler.scale_noise(source_latents, first_timestep, self.latents)
487+
434488 def prepare_latents (self , input_info ):
435489 self .input_info = input_info
436490 shape = input_info .target_shape
@@ -477,7 +531,8 @@ def generate_freqs_cis_from_position_ids(self, position_ids: torch.Tensor, devic
477531
478532 def set_timesteps (self ):
479533 sigmas = np .linspace (1.0 , 1 / self .config ["infer_steps" ], self .config ["infer_steps" ])
480- image_seq_len = self .latents .shape [1 ]
534+ _ , _ , latent_height , latent_width = self .latents .shape
535+ image_seq_len = (latent_height // 2 ) * (latent_width // 2 )
481536 mu = calculate_shift (
482537 image_seq_len ,
483538 self .scheduler_config .get ("base_image_seq_len" , 256 ),
@@ -497,6 +552,13 @@ def set_timesteps(self):
497552 self .timesteps = timesteps
498553 self .infer_steps = num_inference_steps
499554
555+ if self .config ["task" ] == "i2i" :
556+ strength = self ._get_i2i_denoise_strength (self .input_info )
557+ if strength is not None :
558+ timesteps , num_inference_steps = self .get_timesteps (num_inference_steps , strength )
559+ self .timesteps = timesteps
560+ self .infer_steps = num_inference_steps
561+
500562 num_warmup_steps = max (len (timesteps ) - num_inference_steps * self .scheduler .order , 0 )
501563 self ._num_timesteps = len (timesteps )
502564 self .num_warmup_steps = num_warmup_steps
@@ -509,6 +571,9 @@ def prepare(self, input_info):
509571 logger .info (f"Generator is not None, using existing generator for latents" )
510572 self .prepare_latents (input_info )
511573 self .set_timesteps ()
574+ strength = self ._get_i2i_denoise_strength (input_info )
575+ if self .config ["task" ] == "i2i" and strength is not None :
576+ self .prepare_i2i_denoise_strength_latents (input_info )
512577
513578 self .image_rotary_emb = self .pos_embed (self .input_info .image_shapes , input_info .txt_seq_lens [0 ], device = AI_DEVICE )
514579
0 commit comments