Skip to content

Commit ea4def4

Browse files
committed
update z-image
1 parent 2868c5c commit ea4def4

4 files changed

Lines changed: 116 additions & 5 deletions

File tree

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,9 @@
1+
{
2+
"num_channels_latents": 16,
3+
"infer_steps": 9,
4+
"attn_type": "flash_attn3",
5+
"enable_cfg": false,
6+
"sample_guide_scale": 0.0,
7+
"patch_size": 2,
8+
"i2i_denoise_strength": 1.0
9+
}

lightx2v/models/runners/z_image/z_image_runner.py

Lines changed: 18 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -159,7 +159,10 @@ def read_image_input(self, img_path):
159159

160160
@ProfilingContext4DebugL2("Run Encoders")
161161
def _run_input_encoder_local_i2i(self):
162-
image_paths_list = self.input_info.image_path.split(",")
162+
image_paths_list = [image_path.strip() for image_path in self.input_info.image_path.split(",") if image_path.strip()]
163+
if len(image_paths_list) != 1:
164+
raise ValueError(f"z-image i2i currently supports exactly one input image, got {len(image_paths_list)}.")
165+
163166
images_list = []
164167
for image_path in image_paths_list:
165168
_, image = self.read_image_input(image_path)
@@ -299,7 +302,18 @@ def get_input_target_shape(self):
299302
logger.info(f"Z Image Runner got custom shape: {width}x{height}")
300303
return (width, height)
301304

302-
aspect_ratio = self.input_info.aspect_ratio if self.input_info.aspect_ratio else self.config.get("aspect_ratio", None)
305+
aspect_ratio = self.input_info.aspect_ratio
306+
if aspect_ratio in as_maps:
307+
logger.info(f"Z Image Runner got aspect ratio: {aspect_ratio}")
308+
width, height = as_maps[aspect_ratio]
309+
return (width, height)
310+
311+
if self.config["task"] == "i2i" and self.input_info.original_size:
312+
width, height = self.input_info.original_size[-1]
313+
logger.info(f"Z Image Runner got i2i source image shape: {width}x{height}")
314+
return (width, height)
315+
316+
aspect_ratio = self.config.get("aspect_ratio", None)
303317
if aspect_ratio in as_maps:
304318
logger.info(f"Z Image Runner got aspect ratio: {aspect_ratio}")
305319
width, height = as_maps[aspect_ratio]
@@ -309,7 +323,7 @@ def get_input_target_shape(self):
309323
raise NotImplementedError
310324

311325
def set_target_shape(self):
312-
height, width = self.get_input_target_shape()
326+
width, height = self.get_input_target_shape()
313327

314328
# VAE applies 8x compression on images but we must also account for packing which requires
315329
# latent height and width to be divisible by 2.
@@ -326,7 +340,7 @@ def set_img_shapes(self):
326340
raise ValueError(f"target_shape must be 4D [B, C, H, W], got {len(self.input_info.target_shape)}D: {self.input_info.target_shape}")
327341
_, _, latent_height, latent_width = self.input_info.target_shape
328342
else:
329-
height, width = self.get_input_target_shape()
343+
width, height = self.get_input_target_shape()
330344

331345
vae_scale_factor = self.config["vae_scale_factor"]
332346
latent_height = 2 * (int(height) // (vae_scale_factor * 2))

lightx2v/models/schedulers/z_image/scheduler.py

Lines changed: 66 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -431,6 +431,60 @@ def create_coordinate_grid(size, start=None, device=None):
431431
grids = torch.meshgrid(axes, indexing="ij")
432432
return torch.stack(grids, dim=-1)
433433

434+
def _get_i2i_denoise_strength(self, input_info):
    """Resolve and validate the i2i denoise strength.

    The value on ``input_info`` takes precedence; when that attribute is
    missing or ``None``, the ``i2i_denoise_strength`` entry of
    ``self.config`` is used instead.

    Args:
        input_info: Request object whose optional ``i2i_denoise_strength``
            attribute is read via ``getattr``.

    Returns:
        The strength converted to ``float``, or ``None`` when neither
        ``input_info`` nor ``self.config`` provides a value.

    Raises:
        ValueError: If the value lies outside ``[0.0, 1.0]`` or is NaN.
    """
    strength = getattr(input_info, "i2i_denoise_strength", None)
    if strength is None:
        strength = self.config.get("i2i_denoise_strength")
    if strength is None:
        return None
    strength = float(strength)
    # A chained comparison is False for NaN, so NaN is rejected here instead
    # of slipping through separate `< 0.0` / `> 1.0` checks and failing later
    # with an unrelated error when the step count is rounded.
    if not 0.0 <= strength <= 1.0:
        raise ValueError(f"The value of i2i_denoise_strength should be in [0.0, 1.0] but is {strength}")
    return strength
444+
445+
def _get_single_i2i_image_latents(self, input_info):
446+
image_encoder_output = getattr(input_info, "image_encoder_output", None)
447+
if not image_encoder_output:
448+
raise ValueError("z-image i2i requires exactly one input image with VAE image latents.")
449+
if len(image_encoder_output) != 1:
450+
raise ValueError(f"z-image i2i currently supports single-image editing only, got {len(image_encoder_output)} images.")
451+
return image_encoder_output[0]["image_latents"]
452+
453+
def get_timesteps(self, num_inference_steps, strength):
454+
target_steps = round(num_inference_steps * strength)
455+
if target_steps < 1:
456+
raise ValueError(
457+
"i2i_denoise_strength results in 0 denoising steps: "
458+
f"round(infer_steps * i2i_denoise_strength)=round({num_inference_steps} * {strength})={target_steps}; "
459+
"please increase it to run at least 1 step."
460+
)
461+
t_start = num_inference_steps - target_steps
462+
timesteps = self.timesteps[t_start * self.scheduler.order :]
463+
if hasattr(self.scheduler, "set_begin_index"):
464+
self.scheduler.set_begin_index(t_start * self.scheduler.order)
465+
return timesteps, target_steps
466+
467+
def _resize_i2i_image_latents(self, image_latents, target_height, target_width, target_channels):
468+
if image_latents.ndim != 4:
469+
raise ValueError(f"Expected z-image i2i image latents with shape [B, C, H, W], got {tuple(image_latents.shape)}")
470+
if image_latents.shape[1] != target_channels:
471+
raise ValueError(f"z-image i2i image latent channels {image_latents.shape[1]} do not match target channels {target_channels}.")
472+
if image_latents.shape[-2:] != (target_height, target_width):
473+
image_latents = F.interpolate(image_latents, size=(target_height, target_width), mode="bilinear", align_corners=False)
474+
return image_latents
475+
476+
def prepare_i2i_denoise_strength_latents(self, input_info):
477+
image_latents = self._get_single_i2i_image_latents(input_info).to(device=AI_DEVICE, dtype=self.dtype)
478+
if self.latents.shape[0] != 1:
479+
raise ValueError(f"z-image i2i currently supports single-image single-output editing only, got output latent batch {self.latents.shape[0]}.")
480+
481+
_, target_channels, target_height, target_width = self.latents.shape
482+
image_latents = self._resize_i2i_image_latents(image_latents, target_height, target_width, target_channels)
483+
484+
latent_timestep = self.timesteps[:1]
485+
noise = self.latents
486+
self.latents = self.scheduler.scale_noise(image_latents, latent_timestep, noise)
487+
434488
def prepare_latents(self, input_info):
435489
self.input_info = input_info
436490
shape = input_info.target_shape
@@ -477,7 +531,8 @@ def generate_freqs_cis_from_position_ids(self, position_ids: torch.Tensor, devic
477531

478532
def set_timesteps(self):
479533
sigmas = np.linspace(1.0, 1 / self.config["infer_steps"], self.config["infer_steps"])
480-
image_seq_len = self.latents.shape[1]
534+
_, _, latent_height, latent_width = self.latents.shape
535+
image_seq_len = (latent_height // 2) * (latent_width // 2)
481536
mu = calculate_shift(
482537
image_seq_len,
483538
self.scheduler_config.get("base_image_seq_len", 256),
@@ -497,6 +552,13 @@ def set_timesteps(self):
497552
self.timesteps = timesteps
498553
self.infer_steps = num_inference_steps
499554

555+
if self.config["task"] == "i2i":
556+
strength = self._get_i2i_denoise_strength(self.input_info)
557+
if strength is not None:
558+
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength)
559+
self.timesteps = timesteps
560+
self.infer_steps = num_inference_steps
561+
500562
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
501563
self._num_timesteps = len(timesteps)
502564
self.num_warmup_steps = num_warmup_steps
@@ -509,6 +571,9 @@ def prepare(self, input_info):
509571
logger.info(f"Generator is not None, using existing generator for latents")
510572
self.prepare_latents(input_info)
511573
self.set_timesteps()
574+
strength = self._get_i2i_denoise_strength(input_info)
575+
if self.config["task"] == "i2i" and strength is not None:
576+
self.prepare_i2i_denoise_strength_latents(input_info)
512577

513578
self.image_rotary_emb = self.pos_embed(self.input_info.image_shapes, input_info.txt_seq_lens[0], device=AI_DEVICE)
514579

Lines changed: 23 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,23 @@
1+
#!/bin/bash

# Run Z-Image-Turbo image-to-image (i2i) inference through lightx2v.infer.
# Adjust the three paths below for the local machine before running.

# set path firstly
lightx2v_path=/data/nvme1/yongyang/nb/LightX2V
model_path=/data/nvme1/models/Tongyi-MAI/Z-Image-Turbo
image_path="${lightx2v_path}/assets/inputs/imgs/img_0.jpg"

export CUDA_VISIBLE_DEVICES=0

# set environment variables
source "${lightx2v_path}/scripts/base/base.sh"

# All path expansions are quoted so that directories containing spaces or
# glob characters are passed to Python as single arguments.
python -m lightx2v.infer \
    --model_cls z_image \
    --task i2i \
    --model_path "$model_path" \
    --config_json "${lightx2v_path}/configs/z_image/z_image_turbo_i2i.json" \
    --image_path "$image_path" \
    --prompt "Change the cat to a dog." \
    --negative_prompt " " \
    --save_result_path "${lightx2v_path}/save_results/z_image_turbo_i2i.png" \
    --seed 42 \
    --i2i_denoise_strength 1.0

0 commit comments

Comments
 (0)