|
50 | 50 | title="Denoise - Z-Image", |
51 | 51 | tags=["image", "z-image"], |
52 | 52 | category="image", |
53 | | - version="1.4.0", |
| 53 | + version="1.5.0", |
54 | 54 | classification=Classification.Prototype, |
55 | 55 | ) |
56 | 56 | class ZImageDenoiseInvocation(BaseInvocation): |
@@ -104,6 +104,15 @@ class ZImageDenoiseInvocation(BaseInvocation): |
104 | 104 | description=FieldDescriptions.vae + " Required for control conditioning.", |
105 | 105 | input=Input.Connection, |
106 | 106 | ) |
| 107 | + # Shift override for the sigma schedule. If None, shift is auto-calculated from image dimensions. |
| 108 | + shift: Optional[float] = InputField( |
| 109 | + default=None, |
| 110 | + ge=0.0, |
| 111 | + description="Override the timestep shift (mu) for the sigma schedule. " |
| 112 | + "Leave blank to auto-calculate based on image dimensions (recommended). " |
| 113 | + "Lower values (~0.5) produce less noise shifting, higher values (~1.15) produce more.", |
| 114 | + title="Shift", |
| 115 | + ) |
107 | 116 | # Scheduler selection for the denoising process |
108 | 117 | scheduler: ZIMAGE_SCHEDULER_NAME_VALUES = InputField( |
109 | 118 | default="euler", |
@@ -225,34 +234,36 @@ def _calculate_shift( |
225 | 234 | """Calculate timestep shift based on image sequence length. |
226 | 235 |
|
227 | 236 | Based on diffusers ZImagePipeline.calculate_shift method. |
| 237 | + Returns a linear shift value (exp(mu) from the original formula). |
228 | 238 | """ |
| 239 | + import math |
| 240 | + |
229 | 241 | m = (max_shift - base_shift) / (max_image_seq_len - base_image_seq_len) |
230 | 242 | b = base_shift - m * base_image_seq_len |
231 | 243 | mu = image_seq_len * m + b |
232 | | - return mu |
| 244 | + # Convert from exponential mu to linear shift value |
| 245 | + return math.exp(mu) |
233 | 246 |
|
234 | | - def _get_sigmas(self, mu: float, num_steps: int) -> list[float]: |
235 | | - """Generate sigma schedule with time shift. |
| 247 | + def _get_sigmas(self, shift: float, num_steps: int) -> list[float]: |
| 248 | + """Generate sigma schedule with linear time shift. |
236 | 249 |
|
237 | | - Based on FlowMatchEulerDiscreteScheduler with shift. |
| 250 | + Uses linear time shift: shift / (shift + (1/t - 1)). |
| 251 | + The shift value is used directly as a multiplier. |
238 | 252 | Generates num_steps + 1 sigma values (including terminal 0.0). |
239 | 253 | """ |
240 | | - import math |
241 | 254 |
|
242 | | - def time_shift(mu: float, sigma: float, t: float) -> float: |
243 | | - """Apply time shift to a single timestep value.""" |
| 255 | + def time_shift(shift: float, t: float) -> float: |
| 256 | + """Apply linear time shift to a single timestep value.""" |
244 | 257 | if t <= 0: |
245 | 258 | return 0.0 |
246 | 259 | if t >= 1: |
247 | 260 | return 1.0 |
248 | | - return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) |
| 261 | + return shift / (shift + (1 / t - 1)) |
249 | 262 |
|
250 | | - # Generate linearly spaced values from 1 to 0 (excluding endpoints for safety) |
251 | | - # then apply time shift |
252 | 263 | sigmas = [] |
253 | 264 | for i in range(num_steps + 1): |
254 | 265 | t = 1.0 - i / num_steps # Goes from 1.0 to 0.0 |
255 | | - sigma = time_shift(mu, 1.0, t) |
| 266 | + sigma = time_shift(shift, t) |
256 | 267 | sigmas.append(sigma) |
257 | 268 |
|
258 | 269 | return sigmas |
@@ -313,11 +324,14 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor: |
313 | 324 | # Concatenate all negative embeddings |
314 | 325 | neg_prompt_embeds = torch.cat([tc.prompt_embeds for tc in neg_text_conditionings], dim=0) |
315 | 326 |
|
316 | | - # Calculate shift based on image sequence length |
317 | | - mu = self._calculate_shift(img_seq_len) |
| 327 | + # Calculate shift based on image sequence length, or use override |
| 328 | + if self.shift is not None: |
| 329 | + shift = self.shift |
| 330 | + else: |
| 331 | + shift = self._calculate_shift(img_seq_len) |
318 | 332 |
|
319 | 333 | # Generate sigma schedule with time shift |
320 | | - sigmas = self._get_sigmas(mu, self.steps) |
| 334 | + sigmas = self._get_sigmas(shift, self.steps) |
321 | 335 |
|
322 | 336 | # Apply denoising_start and denoising_end clipping |
323 | 337 | if self.denoising_start > 0 or self.denoising_end < 1: |
|
0 commit comments